Skip to main content

midi_toolkit/io/
midi_writer.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fs::File,
4    io::{self, copy, Cursor, Read, Seek, SeekFrom, Write},
5    sync::Mutex,
6};
7
8use crate::events::SerializeEventWithDelta;
9
10use super::errors::MIDIWriteError;
11
/// Marker trait for output sinks that can be both written to and seeked.
///
/// The writer seeks back into the already-written header to patch 16-bit fields,
/// so a plain `Write` is not sufficient.
pub trait WriteSeek: Write + Seek {}

// On-disk output.
impl WriteSeek for File {}

// In-memory output.
impl WriteSeek for Cursor<Vec<u8>> {}
15
/// A finished track that is waiting to be copied into the output stream.
pub struct QueuedOutput {
    /// Source the track's bytes are read from when the track is flushed.
    write: Box<dyn Read>,
    /// Value written into the length field of the track's `MTrk` chunk header.
    length: u32,
}
20
21struct TrackStatus {
22    opened_tracks: HashSet<i32>,
23    written_tracks: HashSet<i32>,
24    next_init_track: i32,
25    next_write_track: i32,
26    queued_writes: HashMap<i32, QueuedOutput>,
27}
28
29pub struct MIDIWriter {
30    output: Option<Mutex<Box<dyn WriteSeek>>>,
31    tracks: Mutex<TrackStatus>,
32}
33
34pub struct TrackWriter<'a> {
35    midi_writer: &'a MIDIWriter,
36    track_id: i32,
37    writer: Option<Cursor<Vec<u8>>>,
38}
39
/// Returns the ids contained in `tracks` as a vector in ascending order.
fn sorted_track_ids(tracks: &HashSet<i32>) -> Vec<i32> {
    let mut ordered: Vec<i32> = Vec::with_capacity(tracks.len());
    ordered.extend(tracks);
    ordered.sort_unstable();
    ordered
}
45
/// Encodes `val` as two bytes, most significant byte first.
fn encode_u16(val: u16) -> [u8; 2] {
    val.to_be_bytes()
}
52
/// Encodes `val` as four bytes, most significant byte first.
fn encode_u32(val: u32) -> [u8; 4] {
    val.to_be_bytes()
}
61
62fn flush_track(writer: &mut dyn WriteSeek, mut output: QueuedOutput) -> Result<(), io::Error> {
63    writer.write_all("MTrk".as_bytes())?;
64    writer.write_all(&encode_u32(output.length))?;
65    copy(&mut output.write, writer)?;
66    Ok(())
67}
68
69impl MIDIWriter {
70    pub fn new(filename: &str, ppq: u16) -> Result<MIDIWriter, MIDIWriteError> {
71        let reader = File::create(filename)?;
72        MIDIWriter::new_from_stream(Box::new(reader), ppq)
73    }
74
75    pub fn new_from_stream(
76        mut output: Box<dyn WriteSeek>,
77        ppq: u16,
78    ) -> Result<MIDIWriter, MIDIWriteError> {
79        output.seek(SeekFrom::Start(0))?;
80        output.write_all("MThd".as_bytes())?;
81        output.write_all(&encode_u32(6))?;
82        output.write_all(&encode_u16(1))?;
83        output.write_all(&encode_u16(0))?;
84        output.write_all(&encode_u16(ppq))?;
85
86        Ok(MIDIWriter {
87            output: Some(Mutex::new(output)),
88            tracks: Mutex::new(TrackStatus {
89                opened_tracks: HashSet::new(),
90                next_init_track: 0,
91                next_write_track: 0,
92                queued_writes: HashMap::new(),
93                written_tracks: HashSet::new(),
94            }),
95        })
96    }
97
98    #[deprecated(note = "use new_from_stream")]
99    pub fn new_from_stram(
100        output: Box<dyn WriteSeek>,
101        ppq: u16,
102    ) -> Result<MIDIWriter, MIDIWriteError> {
103        Self::new_from_stream(output, ppq)
104    }
105
106    fn with_writer<R>(
107        &self,
108        f: impl FnOnce(&mut dyn WriteSeek) -> Result<R, io::Error>,
109    ) -> Result<R, MIDIWriteError> {
110        let output = self.output.as_ref().ok_or(MIDIWriteError::WriterEnded)?;
111        let mut output = output.lock().unwrap();
112        Ok(f(output.as_mut())?)
113    }
114
115    fn write_u16_at(&self, at: u64, val: u16) -> Result<(), MIDIWriteError> {
116        self.with_writer(|output| {
117            let pos = output.stream_position()?;
118            output.seek(SeekFrom::Start(at))?;
119            output.write_all(&encode_u16(val))?;
120            output.seek(SeekFrom::Start(pos))?;
121            Ok(())
122        })
123    }
124
125    pub fn write_ppq(&self, ppq: u16) -> Result<(), MIDIWriteError> {
126        self.write_u16_at(12, ppq)
127    }
128
129    pub fn write_format(&self, format: u16) -> Result<(), MIDIWriteError> {
130        self.write_u16_at(8, format)
131    }
132
133    fn write_ntrks(&self, track_count: u16) -> Result<(), MIDIWriteError> {
134        self.write_u16_at(10, track_count)
135    }
136
137    pub fn try_open_next_track(&self) -> Result<TrackWriter<'_>, MIDIWriteError> {
138        let mut tracks = self.tracks.lock().unwrap();
139        if self.output.is_none() {
140            return Err(MIDIWriteError::WriterEnded);
141        }
142
143        let track_id = tracks.next_init_track;
144        if tracks.written_tracks.contains(&track_id) || tracks.opened_tracks.contains(&track_id) {
145            return Err(MIDIWriteError::TrackAlreadyOpened { track_id });
146        }
147
148        tracks.next_init_track += 1;
149        tracks.opened_tracks.insert(track_id);
150
151        Ok(TrackWriter {
152            midi_writer: self,
153            track_id,
154            writer: Some(Cursor::new(Vec::new())),
155        })
156    }
157
158    #[deprecated(note = "use try_open_next_track")]
159    pub fn open_next_track(&self) -> TrackWriter<'_> {
160        self.try_open_next_track()
161            .expect("failed to open next track")
162    }
163
164    pub fn try_open_track(&self, track_id: i32) -> Result<TrackWriter<'_>, MIDIWriteError> {
165        if self.output.is_none() {
166            return Err(MIDIWriteError::WriterEnded);
167        }
168
169        let mut tracks = self.tracks.lock().unwrap();
170        if tracks.written_tracks.contains(&track_id) || tracks.opened_tracks.contains(&track_id) {
171            return Err(MIDIWriteError::TrackAlreadyOpened { track_id });
172        }
173
174        tracks.opened_tracks.insert(track_id);
175
176        Ok(TrackWriter {
177            midi_writer: self,
178            track_id,
179            writer: Some(Cursor::new(Vec::new())),
180        })
181    }
182
183    #[deprecated(note = "use try_open_track")]
184    pub fn open_track(&self, track_id: i32) -> TrackWriter<'_> {
185        self.try_open_track(track_id).expect("failed to open track")
186    }
187
188    pub fn try_end(&mut self) -> Result<(), MIDIWriteError> {
189        if self.is_ended() {
190            return Err(MIDIWriteError::WriterEnded);
191        }
192
193        let (open_tracks, missing_tracks, track_count) = {
194            let tracks = self.tracks.lock().unwrap();
195            if !tracks.opened_tracks.is_empty() {
196                (
197                    Some(sorted_track_ids(&tracks.opened_tracks)),
198                    None,
199                    tracks.written_tracks.len(),
200                )
201            } else if !tracks.queued_writes.is_empty() {
202                let max_track = *tracks
203                    .queued_writes
204                    .keys()
205                    .max()
206                    .expect("queued_writes checked to be non-empty");
207                let mut missing = (0..max_track)
208                    .filter(|track_id| !tracks.written_tracks.contains(track_id))
209                    .collect::<Vec<_>>();
210                missing.sort_unstable();
211                (None, Some(missing), tracks.written_tracks.len())
212            } else {
213                (None, None, tracks.written_tracks.len())
214            }
215        };
216
217        if let Some(track_ids) = open_tracks {
218            return Err(MIDIWriteError::OpenTracksRemaining { track_ids });
219        }
220
221        if let Some(track_ids) = missing_tracks {
222            return Err(MIDIWriteError::TrackGapsRemaining { track_ids });
223        }
224
225        if track_count > u16::MAX as usize {
226            return Err(MIDIWriteError::TrackCountOverflow { track_count });
227        }
228
229        self.write_ntrks(track_count as u16)?;
230        self.output.take();
231
232        Ok(())
233    }
234
235    pub fn end(&mut self) -> Result<(), MIDIWriteError> {
236        self.try_end()
237    }
238
239    pub fn is_ended(&self) -> bool {
240        self.output.is_none()
241    }
242}
243
244impl<'a> TrackWriter<'a> {
245    fn writer_mut(&mut self) -> Result<&mut Cursor<Vec<u8>>, MIDIWriteError> {
246        let track_id = self.track_id;
247        self.writer
248            .as_mut()
249            .ok_or(MIDIWriteError::TrackAlreadyEnded { track_id })
250    }
251
252    pub fn end(&mut self) -> Result<(), MIDIWriteError> {
253        if self.is_ended() {
254            return Err(MIDIWriteError::TrackAlreadyEnded {
255                track_id: self.track_id,
256            });
257        }
258
259        self.write_bytes(&[0x00, 0xFF, 0x2F, 0x00])?;
260
261        let length = self
262            .writer
263            .as_ref()
264            .expect("writer presence was checked above")
265            .position() as u32;
266
267        let mut status = self.midi_writer.tracks.lock().unwrap();
268        if !status.opened_tracks.remove(&self.track_id) {
269            return Err(MIDIWriteError::TrackAlreadyEnded {
270                track_id: self.track_id,
271            });
272        }
273        status.written_tracks.insert(self.track_id);
274
275        let mut writer = self
276            .writer
277            .take()
278            .ok_or(MIDIWriteError::TrackAlreadyEnded {
279                track_id: self.track_id,
280            })?;
281        writer.seek(SeekFrom::Start(0))?;
282
283        status.queued_writes.insert(
284            self.track_id,
285            QueuedOutput {
286                write: Box::new(writer),
287                length,
288            },
289        );
290
291        if self.track_id == status.next_write_track {
292            let output = self
293                .midi_writer
294                .output
295                .as_ref()
296                .ok_or(MIDIWriteError::WriterEnded)?;
297            let mut writer = output.lock().unwrap();
298            loop {
299                let next_write_track = status.next_write_track;
300                match status.queued_writes.remove_entry(&next_write_track) {
301                    None => break,
302                    Some(output) => {
303                        flush_track(writer.as_mut(), output.1)?;
304                        status.next_write_track += 1;
305                    }
306                }
307            }
308        }
309
310        Ok(())
311    }
312
313    pub fn is_ended(&self) -> bool {
314        self.writer.is_none()
315    }
316
317    pub fn get_writer_mut(&mut self) -> &mut impl Write {
318        self.writer_mut()
319            .expect("Tried to write to TrackWriter after .end() was called")
320    }
321
322    pub fn write_event<T: SerializeEventWithDelta>(
323        &mut self,
324        event: T,
325    ) -> Result<usize, MIDIWriteError> {
326        let writer = self.writer_mut()?;
327        event.serialize_event_with_delta(writer)
328    }
329
330    pub fn write_events_iter<T: SerializeEventWithDelta>(
331        &mut self,
332        events: impl Iterator<Item = T>,
333    ) -> Result<usize, MIDIWriteError> {
334        let mut count = 0;
335        for event in events {
336            count += self.write_event(event)?;
337        }
338        Ok(count)
339    }
340
341    pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, MIDIWriteError> {
342        let writer = self.writer_mut()?;
343        writer.write_all(bytes)?;
344        Ok(bytes.len())
345    }
346}
347
348impl<'a> std::fmt::Debug for TrackWriter<'a> {
349    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        f.debug_struct("TrackWriter")
351            .field("track_id", &self.track_id)
352            .field("is_ended", &self.is_ended())
353            .finish()
354    }
355}
356
357impl<'a> Drop for TrackWriter<'a> {
358    fn drop(&mut self) {
359        if !self.is_ended() {
360            let _ = self.end();
361        }
362    }
363}
364
365impl Drop for MIDIWriter {
366    fn drop(&mut self) {
367        if !self.is_ended() {
368            let _ = self.end();
369        }
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::{MIDIWriter, WriteSeek};
    use crate::io::MIDIWriteError;
    use std::{
        io::{Cursor, Seek, SeekFrom, Write},
        sync::{Arc, Mutex},
    };

    /// In-memory sink whose clones all share one buffer, so a test can keep a handle
    /// and inspect the bytes after the writer that owned the other handle is gone.
    #[derive(Clone, Default)]
    struct SharedCursor {
        inner: Arc<Mutex<Cursor<Vec<u8>>>>,
    }

    impl SharedCursor {
        /// Snapshot of everything written so far.
        fn bytes(&self) -> Vec<u8> {
            let cursor = self.inner.lock().unwrap();
            cursor.get_ref().to_vec()
        }
    }

    impl Write for SharedCursor {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let mut cursor = self.inner.lock().unwrap();
            cursor.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            let mut cursor = self.inner.lock().unwrap();
            cursor.flush()
        }
    }

    impl Seek for SharedCursor {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            let mut cursor = self.inner.lock().unwrap();
            cursor.seek(pos)
        }
    }

    impl WriteSeek for SharedCursor {}

    /// Starts a 480 PPQ writer on top of `sink`.
    fn writer_over(sink: SharedCursor) -> MIDIWriter {
        MIDIWriter::new_from_stream(Box::new(sink), 480).unwrap()
    }

    #[test]
    fn is_ended_matches_writer_lifecycle() {
        let mut midi = writer_over(SharedCursor::default());
        assert!(!midi.is_ended());

        let mut first = midi.try_open_next_track().unwrap();
        assert!(!first.is_ended());

        first.end().unwrap();
        assert!(first.is_ended());
        drop(first);
        assert!(!midi.is_ended());

        midi.end().unwrap();
        assert!(midi.is_ended());
    }

    #[test]
    fn drop_still_finalizes_open_writers() {
        let sink = SharedCursor::default();
        {
            // Locals drop in reverse order: the track ends first, then the writer.
            let midi = writer_over(sink.clone());
            let mut first = midi.try_open_next_track().unwrap();
            first.write_bytes(&[0x00, 0x90, 0x3C, 0x40]).unwrap();
        }

        let written = sink.bytes();
        assert_eq!(&written[0..4], b"MThd");
        assert_eq!(&written[10..12], &[0x00, 0x01]);
        assert_eq!(&written[14..18], b"MTrk");
        assert_eq!(&written[18..22], &[0x00, 0x00, 0x00, 0x08]);
        assert_eq!(
            &written[22..30],
            &[0x00, 0x90, 0x3C, 0x40, 0x00, 0xFF, 0x2F, 0x00]
        );
    }

    #[test]
    fn try_open_track_reports_duplicates_and_writer_end() {
        let mut midi = writer_over(SharedCursor::default());

        let mut first = midi.try_open_next_track().unwrap();
        let duplicate = midi
            .try_open_track(0)
            .expect_err("duplicate track should fail");
        assert!(matches!(
            duplicate,
            MIDIWriteError::TrackAlreadyOpened { track_id: 0 }
        ));

        first.end().unwrap();
        drop(first);
        midi.end().unwrap();

        let after_end = midi
            .try_open_track(3)
            .expect_err("ended writer should fail");
        assert!(matches!(after_end, MIDIWriteError::WriterEnded));
    }

    #[test]
    fn try_end_reports_missing_track_gaps() {
        let mut midi = writer_over(SharedCursor::default());

        // Only track 1 is written, which leaves a gap at track 0.
        let mut second = midi.try_open_track(1).unwrap();
        second.end().unwrap();
        drop(second);

        let gap_err = midi
            .try_end()
            .expect_err("missing track 0 should prevent end");
        assert!(matches!(
            gap_err,
            MIDIWriteError::TrackGapsRemaining { ref track_ids } if track_ids == &[0]
        ));
    }

    #[test]
    fn write_bytes_writes_all_bytes() {
        let sink = SharedCursor::default();
        {
            let midi = writer_over(sink.clone());
            let mut first = midi.try_open_next_track().unwrap();
            first.write_bytes(&[1, 2, 3, 4]).unwrap();
            first.end().unwrap();
        }

        let written = sink.bytes();
        assert!(written.windows(4).any(|chunk| chunk == [1, 2, 3, 4]));
    }
}