1use std::{
2 collections::{HashMap, HashSet},
3 fs::File,
4 io::{self, copy, Cursor, Read, Seek, SeekFrom, Write},
5 sync::Mutex,
6};
7
8use crate::events::SerializeEventWithDelta;
9
10use super::errors::MIDIWriteError;
11
/// Output sink accepted by `MIDIWriter`.
///
/// The writer needs `Write` to emit chunks and `Seek` to go back and patch
/// header fields (track count, format, ppq) after tracks have been written.
pub trait WriteSeek: Write + Seek {}

// In-memory output.
impl WriteSeek for Cursor<Vec<u8>> {}
// On-disk output.
impl WriteSeek for File {}
15
/// A finished track body waiting to be copied into the output as an `MTrk` chunk.
pub struct QueuedOutput {
    // Byte count of the track body; written into the chunk header.
    length: u32,
    // Reader positioned at the start of the track's encoded bytes.
    write: Box<dyn Read>,
}
20
21struct TrackStatus {
22 opened_tracks: HashSet<i32>,
23 written_tracks: HashSet<i32>,
24 next_init_track: i32,
25 next_write_track: i32,
26 queued_writes: HashMap<i32, QueuedOutput>,
27}
28
29pub struct MIDIWriter {
30 output: Option<Mutex<Box<dyn WriteSeek>>>,
31 tracks: Mutex<TrackStatus>,
32}
33
34pub struct TrackWriter<'a> {
35 midi_writer: &'a MIDIWriter,
36 track_id: i32,
37 writer: Option<Cursor<Vec<u8>>>,
38}
39
/// Encodes `val` as 2 big-endian bytes, the byte order used by MIDI file fields.
fn encode_u16(val: u16) -> [u8; 2] {
    val.to_be_bytes()
}
46
/// Encodes `val` as 4 big-endian bytes, the byte order used by MIDI chunk lengths.
fn encode_u32(val: u32) -> [u8; 4] {
    val.to_be_bytes()
}
55
56fn flush_track(writer: &mut Box<dyn WriteSeek>, mut output: QueuedOutput) -> Result<(), io::Error> {
57 writer.write_all("MTrk".as_bytes())?;
58 writer.write_all(&encode_u32(output.length))?;
59 copy(&mut output.write, writer)?;
60 Ok(())
61}
62
63impl MIDIWriter {
64 pub fn new(filename: &str, ppq: u16) -> Result<MIDIWriter, MIDIWriteError> {
65 let reader = File::create(filename)?;
66 MIDIWriter::new_from_stram(Box::new(reader), ppq)
67 }
68
69 pub fn new_from_stram(
70 mut output: Box<dyn WriteSeek>,
71 ppq: u16,
72 ) -> Result<MIDIWriter, MIDIWriteError> {
73 output.seek(SeekFrom::Start(0))?;
74 output.write_all("MThd".as_bytes())?;
75 output.write_all(&encode_u32(6))?;
76 output.write_all(&encode_u16(1))?;
77 output.write_all(&encode_u16(0))?;
78 output.write_all(&encode_u16(ppq))?;
79
80 Ok(MIDIWriter {
81 output: Some(Mutex::new(output)),
82 tracks: Mutex::new(TrackStatus {
83 opened_tracks: HashSet::new(),
84 next_init_track: 0,
85 next_write_track: 0,
86 queued_writes: HashMap::new(),
87 written_tracks: HashSet::new(),
88 }),
89 })
90 }
91
92 fn get_writer(&self) -> &Mutex<Box<dyn WriteSeek>> {
93 self.output
94 .as_ref()
95 .expect("Can't get the writer of an ended MIDIWriter")
96 }
97
98 fn write_u16_at(&self, at: u64, val: u16) -> Result<(), io::Error> {
99 let mut output = self.get_writer().lock().unwrap();
100 let pos = output.stream_position()?;
101 output.seek(SeekFrom::Start(at))?;
102 output.write_all(&encode_u16(val))?;
103 output.seek(SeekFrom::Start(pos))?;
104 Ok(())
105 }
106
107 pub fn write_ppq(&self, ppq: u16) -> Result<(), MIDIWriteError> {
108 Ok(self.write_u16_at(12, ppq)?)
109 }
110
111 pub fn write_format(&self, ppq: u16) -> Result<(), MIDIWriteError> {
112 Ok(self.write_u16_at(8, ppq)?)
113 }
114
115 fn write_ntrks(&self, ppq: u16) -> Result<(), MIDIWriteError> {
116 Ok(self.write_u16_at(10, ppq)?)
117 }
118
119 pub fn open_next_track(&self) -> TrackWriter {
120 let track_id = {
121 let mut tracks = self.tracks.lock().unwrap();
122 let track_id = tracks.next_init_track;
123 tracks.next_init_track += 1;
124 track_id
125 };
126 self.open_track(track_id)
127 }
128
129 pub fn open_track(&self, track_id: i32) -> TrackWriter {
130 self.add_opened_track(track_id);
131 TrackWriter {
132 midi_writer: self,
133 track_id,
134 writer: Some(Cursor::new(Vec::new())),
135 }
136 }
137
138 fn add_opened_track(&self, track_id: i32) {
139 let mut tracks = self.tracks.lock().unwrap();
140 if tracks.written_tracks.contains(&track_id) || !tracks.opened_tracks.insert(track_id) {
141 panic!("Track with id {} has aready been opened before", track_id);
142 }
143 }
144
145 pub fn end(&mut self) -> Result<(), MIDIWriteError> {
146 let tracks = self.tracks.lock().unwrap();
147 if !tracks.opened_tracks.is_empty() {
148 let unwritten: Vec<&i32> = tracks.queued_writes.keys().collect();
149 panic!("Not all tracks have been ended! Make sure you drop or call .end() on each track before ending the MIDIWriter\nMissing tracks {:?}", unwritten);
150 }
151 if !tracks.queued_writes.is_empty() {
152 let max_track = tracks.queued_writes.keys().max().unwrap();
153 let unwritten: Vec<i32> = (0..*max_track)
154 .filter(|track_id| !tracks.written_tracks.contains(track_id))
155 .collect();
156 panic!(
157 "Not all tracks have been opened! Missing tracks {:?}",
158 unwritten
159 );
160 }
161
162 let track_count = tracks.written_tracks.len();
163 self.write_ntrks(track_count.min(u16::MAX as usize) as u16)?;
164
165 self.output.take();
166
167 Ok(())
168 }
169
170 pub fn is_ended(&self) -> bool {
171 self.output.is_some()
172 }
173}
174
175impl<'a> TrackWriter<'a> {
176 pub fn end(&mut self) -> Result<(), MIDIWriteError> {
177 self.write_bytes(&[0x00, 0xFF, 0x2F, 0x00])?;
178
179 let mut status = self.midi_writer.tracks.lock().unwrap();
180 if !status.written_tracks.insert(self.track_id)
181 || !status.opened_tracks.remove(&self.track_id)
182 {
183 panic!("Invalid MIDIWriter state, unknown error");
184 }
185
186 let mut writer = match self.writer.take() {
187 Some(cursor) => cursor,
188 None => panic!(".end() was called more than once on TrackWriter"),
189 };
190
191 let length = writer.stream_position()? as u32;
192 writer.seek(SeekFrom::Start(0))?;
193
194 status.queued_writes.insert(
195 self.track_id,
196 QueuedOutput {
197 write: Box::new(writer),
198 length,
199 },
200 );
201
202 if self.track_id == status.next_write_track {
203 let mut writer = self.midi_writer.get_writer().lock().unwrap();
204 loop {
205 let next_write_track = status.next_write_track;
206 match status.queued_writes.remove_entry(&next_write_track) {
207 None => break,
208 Some(output) => {
209 flush_track(&mut writer, output.1)?;
210 status.next_write_track += 1;
211 }
212 }
213 }
214 }
215
216 Ok(())
217 }
218
219 pub fn is_ended(&self) -> bool {
220 self.writer.is_some()
221 }
222
223 pub fn get_writer_mut(&mut self) -> &mut impl Write {
224 self.writer
225 .as_mut()
226 .expect("Tried to write to TrackWriter after .end() was called")
227 }
228
229 pub fn write_event<T: SerializeEventWithDelta>(
230 &mut self,
231 event: T,
232 ) -> Result<usize, MIDIWriteError> {
233 let writer = self.get_writer_mut();
234 event.serialize_event_with_delta(writer)
235 }
236
237 pub fn write_events_iter<T: SerializeEventWithDelta>(
238 &mut self,
239 events: impl Iterator<Item = T>,
240 ) -> Result<usize, MIDIWriteError> {
241 let mut count = 0;
242 for event in events {
243 count += self.write_event(event)?;
244 }
245 Ok(count)
246 }
247
248 pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, MIDIWriteError> {
249 let writer = self.get_writer_mut();
250 Ok(writer.write(bytes)?)
251 }
252}
253
254impl<'a> Drop for TrackWriter<'a> {
255 fn drop(&mut self) {
256 if self.is_ended() {
257 match self.end() {
258 Ok(()) => {}
259 Err(e) => {
260 panic!("TrackWriter errored when being dropped with: {:?}\n\nIf you want to handle these errors in the future, manually call .end() before dropping", e);
261 }
262 }
263 }
264 }
265}
266
267impl Drop for MIDIWriter {
268 fn drop(&mut self) {
269 if self.is_ended() {
270 match self.end() {
271 Ok(()) => {}
272 Err(e) => {
273 panic!("TrackWriter errored when being dropped with: {:?}\n\nIf you want to handle these errors in the future, manually call .end() before dropping", e);
274 }
275 }
276 }
277 }
278}