1use std::{
2 collections::{HashMap, HashSet},
3 fs::File,
4 io::{self, copy, Cursor, Read, Seek, SeekFrom, Write},
5 sync::Mutex,
6};
7
8use crate::events::SerializeEventWithDelta;
9
10use super::errors::MIDIWriteError;
11
/// A seekable byte sink: the output type accepted by [`MIDIWriter`].
///
/// Implemented here for an in-memory cursor and for files; any other `Write + Seek`
/// type can opt in with an empty `impl`.
pub trait WriteSeek: Write + Seek {}

impl WriteSeek for Cursor<Vec<u8>> {}
impl WriteSeek for File {}
15
/// A finished track waiting to be copied into the output as an `MTrk` chunk.
pub struct QueuedOutput {
    /// Number of bytes `write` yields; emitted as the `MTrk` chunk length.
    length: u32,
    /// Reader over the track's event bytes.
    write: Box<dyn Read>,
}
20
21struct TrackStatus {
22 opened_tracks: HashSet<i32>,
23 written_tracks: HashSet<i32>,
24 next_init_track: i32,
25 next_write_track: i32,
26 queued_writes: HashMap<i32, QueuedOutput>,
27}
28
29pub struct MIDIWriter {
30 output: Option<Mutex<Box<dyn WriteSeek>>>,
31 tracks: Mutex<TrackStatus>,
32}
33
34pub struct TrackWriter<'a> {
35 midi_writer: &'a MIDIWriter,
36 track_id: i32,
37 writer: Option<Cursor<Vec<u8>>>,
38}
39
/// Returns the ids in `tracks` in ascending order, so error reports list them stably.
fn sorted_track_ids(tracks: &HashSet<i32>) -> Vec<i32> {
    let mut sorted: Vec<i32> = Vec::with_capacity(tracks.len());
    sorted.extend(tracks.iter().cloned());
    sorted.sort();
    sorted
}
45
/// Encodes `val` as two big-endian bytes (most significant byte first).
fn encode_u16(val: u16) -> [u8; 2] {
    val.to_be_bytes()
}
52
/// Encodes `val` as four big-endian bytes (most significant byte first).
fn encode_u32(val: u32) -> [u8; 4] {
    val.to_be_bytes()
}
61
62fn flush_track(writer: &mut dyn WriteSeek, mut output: QueuedOutput) -> Result<(), io::Error> {
63 writer.write_all("MTrk".as_bytes())?;
64 writer.write_all(&encode_u32(output.length))?;
65 copy(&mut output.write, writer)?;
66 Ok(())
67}
68
69impl MIDIWriter {
70 pub fn new(filename: &str, ppq: u16) -> Result<MIDIWriter, MIDIWriteError> {
71 let reader = File::create(filename)?;
72 MIDIWriter::new_from_stream(Box::new(reader), ppq)
73 }
74
75 pub fn new_from_stream(
76 mut output: Box<dyn WriteSeek>,
77 ppq: u16,
78 ) -> Result<MIDIWriter, MIDIWriteError> {
79 output.seek(SeekFrom::Start(0))?;
80 output.write_all("MThd".as_bytes())?;
81 output.write_all(&encode_u32(6))?;
82 output.write_all(&encode_u16(1))?;
83 output.write_all(&encode_u16(0))?;
84 output.write_all(&encode_u16(ppq))?;
85
86 Ok(MIDIWriter {
87 output: Some(Mutex::new(output)),
88 tracks: Mutex::new(TrackStatus {
89 opened_tracks: HashSet::new(),
90 next_init_track: 0,
91 next_write_track: 0,
92 queued_writes: HashMap::new(),
93 written_tracks: HashSet::new(),
94 }),
95 })
96 }
97
98 #[deprecated(note = "use new_from_stream")]
99 pub fn new_from_stram(
100 output: Box<dyn WriteSeek>,
101 ppq: u16,
102 ) -> Result<MIDIWriter, MIDIWriteError> {
103 Self::new_from_stream(output, ppq)
104 }
105
106 fn with_writer<R>(
107 &self,
108 f: impl FnOnce(&mut dyn WriteSeek) -> Result<R, io::Error>,
109 ) -> Result<R, MIDIWriteError> {
110 let output = self.output.as_ref().ok_or(MIDIWriteError::WriterEnded)?;
111 let mut output = output.lock().unwrap();
112 Ok(f(output.as_mut())?)
113 }
114
115 fn write_u16_at(&self, at: u64, val: u16) -> Result<(), MIDIWriteError> {
116 self.with_writer(|output| {
117 let pos = output.stream_position()?;
118 output.seek(SeekFrom::Start(at))?;
119 output.write_all(&encode_u16(val))?;
120 output.seek(SeekFrom::Start(pos))?;
121 Ok(())
122 })
123 }
124
125 pub fn write_ppq(&self, ppq: u16) -> Result<(), MIDIWriteError> {
126 self.write_u16_at(12, ppq)
127 }
128
129 pub fn write_format(&self, format: u16) -> Result<(), MIDIWriteError> {
130 self.write_u16_at(8, format)
131 }
132
133 fn write_ntrks(&self, track_count: u16) -> Result<(), MIDIWriteError> {
134 self.write_u16_at(10, track_count)
135 }
136
137 pub fn try_open_next_track(&self) -> Result<TrackWriter<'_>, MIDIWriteError> {
138 let mut tracks = self.tracks.lock().unwrap();
139 if self.output.is_none() {
140 return Err(MIDIWriteError::WriterEnded);
141 }
142
143 let track_id = tracks.next_init_track;
144 if tracks.written_tracks.contains(&track_id) || tracks.opened_tracks.contains(&track_id) {
145 return Err(MIDIWriteError::TrackAlreadyOpened { track_id });
146 }
147
148 tracks.next_init_track += 1;
149 tracks.opened_tracks.insert(track_id);
150
151 Ok(TrackWriter {
152 midi_writer: self,
153 track_id,
154 writer: Some(Cursor::new(Vec::new())),
155 })
156 }
157
158 #[deprecated(note = "use try_open_next_track")]
159 pub fn open_next_track(&self) -> TrackWriter<'_> {
160 self.try_open_next_track()
161 .expect("failed to open next track")
162 }
163
164 pub fn try_open_track(&self, track_id: i32) -> Result<TrackWriter<'_>, MIDIWriteError> {
165 if self.output.is_none() {
166 return Err(MIDIWriteError::WriterEnded);
167 }
168
169 let mut tracks = self.tracks.lock().unwrap();
170 if tracks.written_tracks.contains(&track_id) || tracks.opened_tracks.contains(&track_id) {
171 return Err(MIDIWriteError::TrackAlreadyOpened { track_id });
172 }
173
174 tracks.opened_tracks.insert(track_id);
175
176 Ok(TrackWriter {
177 midi_writer: self,
178 track_id,
179 writer: Some(Cursor::new(Vec::new())),
180 })
181 }
182
183 #[deprecated(note = "use try_open_track")]
184 pub fn open_track(&self, track_id: i32) -> TrackWriter<'_> {
185 self.try_open_track(track_id).expect("failed to open track")
186 }
187
188 pub fn try_end(&mut self) -> Result<(), MIDIWriteError> {
189 if self.is_ended() {
190 return Err(MIDIWriteError::WriterEnded);
191 }
192
193 let (open_tracks, missing_tracks, track_count) = {
194 let tracks = self.tracks.lock().unwrap();
195 if !tracks.opened_tracks.is_empty() {
196 (
197 Some(sorted_track_ids(&tracks.opened_tracks)),
198 None,
199 tracks.written_tracks.len(),
200 )
201 } else if !tracks.queued_writes.is_empty() {
202 let max_track = *tracks
203 .queued_writes
204 .keys()
205 .max()
206 .expect("queued_writes checked to be non-empty");
207 let mut missing = (0..max_track)
208 .filter(|track_id| !tracks.written_tracks.contains(track_id))
209 .collect::<Vec<_>>();
210 missing.sort_unstable();
211 (None, Some(missing), tracks.written_tracks.len())
212 } else {
213 (None, None, tracks.written_tracks.len())
214 }
215 };
216
217 if let Some(track_ids) = open_tracks {
218 return Err(MIDIWriteError::OpenTracksRemaining { track_ids });
219 }
220
221 if let Some(track_ids) = missing_tracks {
222 return Err(MIDIWriteError::TrackGapsRemaining { track_ids });
223 }
224
225 if track_count > u16::MAX as usize {
226 return Err(MIDIWriteError::TrackCountOverflow { track_count });
227 }
228
229 self.write_ntrks(track_count as u16)?;
230 self.output.take();
231
232 Ok(())
233 }
234
235 pub fn end(&mut self) -> Result<(), MIDIWriteError> {
236 self.try_end()
237 }
238
239 pub fn is_ended(&self) -> bool {
240 self.output.is_none()
241 }
242}
243
244impl<'a> TrackWriter<'a> {
245 fn writer_mut(&mut self) -> Result<&mut Cursor<Vec<u8>>, MIDIWriteError> {
246 let track_id = self.track_id;
247 self.writer
248 .as_mut()
249 .ok_or(MIDIWriteError::TrackAlreadyEnded { track_id })
250 }
251
252 pub fn end(&mut self) -> Result<(), MIDIWriteError> {
253 if self.is_ended() {
254 return Err(MIDIWriteError::TrackAlreadyEnded {
255 track_id: self.track_id,
256 });
257 }
258
259 self.write_bytes(&[0x00, 0xFF, 0x2F, 0x00])?;
260
261 let length = self
262 .writer
263 .as_ref()
264 .expect("writer presence was checked above")
265 .position() as u32;
266
267 let mut status = self.midi_writer.tracks.lock().unwrap();
268 if !status.opened_tracks.remove(&self.track_id) {
269 return Err(MIDIWriteError::TrackAlreadyEnded {
270 track_id: self.track_id,
271 });
272 }
273 status.written_tracks.insert(self.track_id);
274
275 let mut writer = self
276 .writer
277 .take()
278 .ok_or(MIDIWriteError::TrackAlreadyEnded {
279 track_id: self.track_id,
280 })?;
281 writer.seek(SeekFrom::Start(0))?;
282
283 status.queued_writes.insert(
284 self.track_id,
285 QueuedOutput {
286 write: Box::new(writer),
287 length,
288 },
289 );
290
291 if self.track_id == status.next_write_track {
292 let output = self
293 .midi_writer
294 .output
295 .as_ref()
296 .ok_or(MIDIWriteError::WriterEnded)?;
297 let mut writer = output.lock().unwrap();
298 loop {
299 let next_write_track = status.next_write_track;
300 match status.queued_writes.remove_entry(&next_write_track) {
301 None => break,
302 Some(output) => {
303 flush_track(writer.as_mut(), output.1)?;
304 status.next_write_track += 1;
305 }
306 }
307 }
308 }
309
310 Ok(())
311 }
312
313 pub fn is_ended(&self) -> bool {
314 self.writer.is_none()
315 }
316
317 pub fn get_writer_mut(&mut self) -> &mut impl Write {
318 self.writer_mut()
319 .expect("Tried to write to TrackWriter after .end() was called")
320 }
321
322 pub fn write_event<T: SerializeEventWithDelta>(
323 &mut self,
324 event: T,
325 ) -> Result<usize, MIDIWriteError> {
326 let writer = self.writer_mut()?;
327 event.serialize_event_with_delta(writer)
328 }
329
330 pub fn write_events_iter<T: SerializeEventWithDelta>(
331 &mut self,
332 events: impl Iterator<Item = T>,
333 ) -> Result<usize, MIDIWriteError> {
334 let mut count = 0;
335 for event in events {
336 count += self.write_event(event)?;
337 }
338 Ok(count)
339 }
340
341 pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, MIDIWriteError> {
342 let writer = self.writer_mut()?;
343 writer.write_all(bytes)?;
344 Ok(bytes.len())
345 }
346}
347
348impl<'a> std::fmt::Debug for TrackWriter<'a> {
349 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 f.debug_struct("TrackWriter")
351 .field("track_id", &self.track_id)
352 .field("is_ended", &self.is_ended())
353 .finish()
354 }
355}
356
357impl<'a> Drop for TrackWriter<'a> {
358 fn drop(&mut self) {
359 if !self.is_ended() {
360 let _ = self.end();
361 }
362 }
363}
364
365impl Drop for MIDIWriter {
366 fn drop(&mut self) {
367 if !self.is_ended() {
368 let _ = self.end();
369 }
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::{MIDIWriter, WriteSeek};
    use crate::io::MIDIWriteError;
    use std::{
        io::{Cursor, Seek, SeekFrom, Write},
        sync::{Arc, Mutex},
    };

    /// In-memory `WriteSeek` whose contents stay readable after the writer holding a
    /// clone of it has been dropped.
    #[derive(Clone, Default)]
    struct SharedCursor {
        inner: Arc<Mutex<Cursor<Vec<u8>>>>,
    }

    impl SharedCursor {
        /// Copies out everything written so far.
        fn bytes(&self) -> Vec<u8> {
            let cursor = self.inner.lock().unwrap();
            cursor.get_ref().to_vec()
        }
    }

    impl Write for SharedCursor {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let mut cursor = self.inner.lock().unwrap();
            cursor.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            let mut cursor = self.inner.lock().unwrap();
            cursor.flush()
        }
    }

    impl Seek for SharedCursor {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            let mut cursor = self.inner.lock().unwrap();
            cursor.seek(pos)
        }
    }

    impl WriteSeek for SharedCursor {}

    /// Builds a 480-PPQ writer that outputs into a clone of `sink`.
    fn new_writer(sink: &SharedCursor) -> MIDIWriter {
        MIDIWriter::new_from_stream(Box::new(sink.clone()), 480).unwrap()
    }

    #[test]
    fn is_ended_matches_writer_lifecycle() {
        let mut writer = new_writer(&SharedCursor::default());
        assert!(!writer.is_ended());

        {
            let mut track = writer.try_open_next_track().unwrap();
            assert!(!track.is_ended());
            track.end().unwrap();
            assert!(track.is_ended());
        }
        assert!(!writer.is_ended());

        writer.end().unwrap();
        assert!(writer.is_ended());
    }

    #[test]
    fn drop_still_finalizes_open_writers() {
        let sink = SharedCursor::default();
        {
            let writer = new_writer(&sink);
            let mut track = writer.try_open_next_track().unwrap();
            track.write_bytes(&[0x00, 0x90, 0x3C, 0x40]).unwrap();
            // `track` drops before `writer`, ending the track first.
        }

        let written = sink.bytes();
        assert_eq!(&written[0..4], b"MThd");
        assert_eq!(&written[10..12], &[0x00, 0x01]);
        assert_eq!(&written[14..18], b"MTrk");
        assert_eq!(&written[18..22], &[0x00, 0x00, 0x00, 0x08]);
        assert_eq!(
            &written[22..30],
            &[0x00, 0x90, 0x3C, 0x40, 0x00, 0xFF, 0x2F, 0x00]
        );
    }

    #[test]
    fn try_open_track_reports_duplicates_and_writer_end() {
        let mut writer = new_writer(&SharedCursor::default());

        {
            let mut track = writer.try_open_next_track().unwrap();
            let duplicate = writer
                .try_open_track(0)
                .expect_err("duplicate track should fail");
            assert!(matches!(
                duplicate,
                MIDIWriteError::TrackAlreadyOpened { track_id: 0 }
            ));
            track.end().unwrap();
        }

        writer.end().unwrap();
        let after_end = writer
            .try_open_track(3)
            .expect_err("ended writer should fail");
        assert!(matches!(after_end, MIDIWriteError::WriterEnded));
    }

    #[test]
    fn try_end_reports_missing_track_gaps() {
        let mut writer = new_writer(&SharedCursor::default());

        {
            let mut track = writer.try_open_track(1).unwrap();
            track.end().unwrap();
        }

        let gap_err = writer
            .try_end()
            .expect_err("missing track 0 should prevent end");
        assert!(matches!(
            gap_err,
            MIDIWriteError::TrackGapsRemaining { ref track_ids } if track_ids == &[0]
        ));
    }

    #[test]
    fn write_bytes_writes_all_bytes() {
        let sink = SharedCursor::default();
        {
            let writer = new_writer(&sink);
            let mut track = writer.try_open_next_track().unwrap();
            track.write_bytes(&[1, 2, 3, 4]).unwrap();
            track.end().unwrap();
        }

        let written = sink.bytes();
        assert!(written.windows(4).any(|chunk| chunk == [1, 2, 3, 4]));
    }
}