Skip to main content

midi_toolkit/io/
midi_writer.rs

1use std::{
2    fs::File,
3    io::{self, copy, Cursor, Read, Seek, SeekFrom, Write},
4    sync::Mutex,
5};
6
7use crate::events::SerializeEventWithDelta;
8
9use super::{
10    errors::MIDIWriteError,
11    writer_common::{
12        encode_u16, write_midi_header, write_track_header, OrderedTrackRegistry, TrackByteSink,
13    },
14};
15
/// Marker trait combining [`Write`] and [`Seek`] so both capabilities can be
/// named behind a single `dyn` pointer.
///
/// It declares no methods of its own; implement it for any type that should be
/// usable as boxed output.
pub trait WriteSeek: Write + Seek {}

// In-memory buffers can be used as output.
impl WriteSeek for Cursor<Vec<u8>> {}

// Files on disk can be used as output.
impl WriteSeek for File {}
19
/// Serialized data of a finished track, waiting to be copied into the output.
pub struct QueuedOutput {
    /// Source the track bytes are read back from when the track is flushed.
    write: Box<dyn Read>,
    /// Length value placed in the track header when the track is flushed.
    length: u32,
}
24
25pub struct MIDIWriter {
26    output: Option<Mutex<Box<dyn WriteSeek>>>,
27    tracks: Mutex<OrderedTrackRegistry<QueuedOutput>>,
28}
29
30pub struct TrackWriter<'a> {
31    midi_writer: &'a MIDIWriter,
32    track: TrackByteSink<Cursor<Vec<u8>>>,
33}
34
35fn flush_track(writer: &mut dyn WriteSeek, mut output: QueuedOutput) -> Result<(), io::Error> {
36    write_track_header(writer, output.length)?;
37    copy(&mut output.write, writer)?;
38    Ok(())
39}
40
41impl MIDIWriter {
42    pub fn new(filename: &str, ppq: u16) -> Result<MIDIWriter, MIDIWriteError> {
43        let reader = File::create(filename)?;
44        MIDIWriter::new_from_stream(Box::new(reader), ppq)
45    }
46
47    pub fn new_from_stream(
48        mut output: Box<dyn WriteSeek>,
49        ppq: u16,
50    ) -> Result<MIDIWriter, MIDIWriteError> {
51        output.seek(SeekFrom::Start(0))?;
52        write_midi_header(output.as_mut(), 1, 0, ppq)?;
53
54        Ok(MIDIWriter {
55            output: Some(Mutex::new(output)),
56            tracks: Mutex::new(OrderedTrackRegistry::new()),
57        })
58    }
59
60    #[deprecated(note = "use new_from_stream")]
61    pub fn new_from_stram(
62        output: Box<dyn WriteSeek>,
63        ppq: u16,
64    ) -> Result<MIDIWriter, MIDIWriteError> {
65        Self::new_from_stream(output, ppq)
66    }
67
68    fn with_writer<R>(
69        &self,
70        f: impl FnOnce(&mut dyn WriteSeek) -> Result<R, io::Error>,
71    ) -> Result<R, MIDIWriteError> {
72        let output = self.output.as_ref().ok_or(MIDIWriteError::WriterEnded)?;
73        let mut output = output.lock().unwrap();
74        Ok(f(output.as_mut())?)
75    }
76
77    fn write_u16_at(&self, at: u64, val: u16) -> Result<(), MIDIWriteError> {
78        self.with_writer(|output| {
79            let pos = output.stream_position()?;
80            output.seek(SeekFrom::Start(at))?;
81            output.write_all(&encode_u16(val))?;
82            output.seek(SeekFrom::Start(pos))?;
83            Ok(())
84        })
85    }
86
87    pub fn write_ppq(&self, ppq: u16) -> Result<(), MIDIWriteError> {
88        self.write_u16_at(12, ppq)
89    }
90
91    pub fn write_format(&self, format: u16) -> Result<(), MIDIWriteError> {
92        self.write_u16_at(8, format)
93    }
94
95    fn write_ntrks(&self, track_count: u16) -> Result<(), MIDIWriteError> {
96        self.write_u16_at(10, track_count)
97    }
98
99    pub fn try_open_next_track(&self) -> Result<TrackWriter<'_>, MIDIWriteError> {
100        let mut tracks = self.tracks.lock().unwrap();
101        if self.output.is_none() {
102            return Err(MIDIWriteError::WriterEnded);
103        }
104
105        let track_id = tracks.open_next_track()?;
106
107        Ok(TrackWriter {
108            midi_writer: self,
109            track: TrackByteSink::new(track_id, Cursor::new(Vec::new())),
110        })
111    }
112
113    #[deprecated(note = "use try_open_next_track")]
114    pub fn open_next_track(&self) -> TrackWriter<'_> {
115        self.try_open_next_track()
116            .expect("failed to open next track")
117    }
118
119    pub fn try_open_track(&self, track_id: i32) -> Result<TrackWriter<'_>, MIDIWriteError> {
120        if self.output.is_none() {
121            return Err(MIDIWriteError::WriterEnded);
122        }
123
124        let mut tracks = self.tracks.lock().unwrap();
125        tracks.open_track(track_id)?;
126
127        Ok(TrackWriter {
128            midi_writer: self,
129            track: TrackByteSink::new(track_id, Cursor::new(Vec::new())),
130        })
131    }
132
133    #[deprecated(note = "use try_open_track")]
134    pub fn open_track(&self, track_id: i32) -> TrackWriter<'_> {
135        self.try_open_track(track_id).expect("failed to open track")
136    }
137
138    pub fn try_end(&mut self) -> Result<(), MIDIWriteError> {
139        if self.is_ended() {
140            return Err(MIDIWriteError::WriterEnded);
141        }
142
143        let track_count = self.tracks.lock().unwrap().finalize_track_count()?;
144        self.write_ntrks(track_count)?;
145        self.output.take();
146
147        Ok(())
148    }
149
150    pub fn end(&mut self) -> Result<(), MIDIWriteError> {
151        self.try_end()
152    }
153
154    pub fn is_ended(&self) -> bool {
155        self.output.is_none()
156    }
157}
158
159impl<'a> TrackWriter<'a> {
160    fn writer_mut(&mut self) -> Result<&mut Cursor<Vec<u8>>, MIDIWriteError> {
161        self.track.writer_mut()
162    }
163
164    pub fn end(&mut self) -> Result<(), MIDIWriteError> {
165        let track_id = self.track.track_id();
166        let (mut writer, length) = self.track.finish()?;
167        writer.seek(SeekFrom::Start(0))?;
168
169        let queued_outputs = {
170            let mut status = self.midi_writer.tracks.lock().unwrap();
171            status.finish_track(
172                track_id,
173                QueuedOutput {
174                    write: Box::new(writer),
175                    length,
176                },
177            )?;
178            status.drain_ready_tracks()
179        };
180
181        if !queued_outputs.is_empty() {
182            let output = self
183                .midi_writer
184                .output
185                .as_ref()
186                .ok_or(MIDIWriteError::WriterEnded)?;
187            let mut writer = output.lock().unwrap();
188            for (_, queued_output) in queued_outputs {
189                flush_track(writer.as_mut(), queued_output)?;
190            }
191        }
192
193        Ok(())
194    }
195
196    pub fn is_ended(&self) -> bool {
197        self.track.is_ended()
198    }
199
200    pub fn get_writer_mut(&mut self) -> &mut impl Write {
201        self.writer_mut()
202            .expect("Tried to write to TrackWriter after .end() was called")
203    }
204
205    pub fn write_event<T: SerializeEventWithDelta>(
206        &mut self,
207        event: T,
208    ) -> Result<usize, MIDIWriteError> {
209        self.track.write_event(event)
210    }
211
212    pub fn write_events_iter<T: SerializeEventWithDelta>(
213        &mut self,
214        events: impl Iterator<Item = T>,
215    ) -> Result<usize, MIDIWriteError> {
216        self.track.write_events_iter(events)
217    }
218
219    pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, MIDIWriteError> {
220        self.track.write_bytes(bytes)
221    }
222}
223
224impl<'a> std::fmt::Debug for TrackWriter<'a> {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        f.debug_struct("TrackWriter")
227            .field("track_id", &self.track.track_id())
228            .field("is_ended", &self.is_ended())
229            .finish()
230    }
231}
232
233impl<'a> Drop for TrackWriter<'a> {
234    fn drop(&mut self) {
235        if !self.is_ended() {
236            let _ = self.end();
237        }
238    }
239}
240
241impl Drop for MIDIWriter {
242    fn drop(&mut self) {
243        if !self.is_ended() {
244            let _ = self.end();
245        }
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::{MIDIWriter, WriteSeek};
    use crate::io::MIDIWriteError;
    use std::{
        io::{Cursor, Seek, SeekFrom, Write},
        sync::{Arc, Mutex},
    };

    /// Cloneable in-memory output: one handle is given to the writer while the
    /// test keeps another to inspect the bytes afterwards.
    #[derive(Clone, Default)]
    struct SharedCursor {
        inner: Arc<Mutex<Cursor<Vec<u8>>>>,
    }

    impl SharedCursor {
        /// Snapshot of everything written so far.
        fn bytes(&self) -> Vec<u8> {
            let guard = self.inner.lock().unwrap();
            guard.get_ref().clone()
        }
    }

    impl Write for SharedCursor {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let mut guard = self.inner.lock().unwrap();
            guard.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            let mut guard = self.inner.lock().unwrap();
            guard.flush()
        }
    }

    impl Seek for SharedCursor {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            let mut guard = self.inner.lock().unwrap();
            guard.seek(pos)
        }
    }

    impl WriteSeek for SharedCursor {}

    /// Builds a writer with a PPQ of 480 on top of `sink`.
    fn writer_over(sink: SharedCursor) -> MIDIWriter {
        MIDIWriter::new_from_stream(Box::new(sink), 480).unwrap()
    }

    #[test]
    fn is_ended_matches_writer_lifecycle() {
        let mut writer = writer_over(SharedCursor::default());
        assert!(!writer.is_ended());

        let mut track = writer.try_open_next_track().unwrap();
        assert!(!track.is_ended());

        track.end().unwrap();
        assert!(track.is_ended());
        drop(track);
        // Ending a track must not end the writer itself.
        assert!(!writer.is_ended());

        writer.end().unwrap();
        assert!(writer.is_ended());
    }

    #[test]
    fn drop_still_finalizes_open_writers() {
        let sink = SharedCursor::default();
        {
            // Neither the track nor the writer is ended explicitly.
            let writer = writer_over(sink.clone());
            let mut track = writer.try_open_next_track().unwrap();
            track.write_bytes(&[0x00, 0x90, 0x3C, 0x40]).unwrap();
        }

        let written = sink.bytes();
        assert_eq!(&written[0..4], b"MThd");
        assert_eq!(&written[10..12], &[0x00, 0x01]);
        assert_eq!(&written[14..18], b"MTrk");
        assert_eq!(&written[18..22], &[0x00, 0x00, 0x00, 0x08]);
        assert_eq!(
            &written[22..30],
            &[0x00, 0x90, 0x3C, 0x40, 0x00, 0xFF, 0x2F, 0x00]
        );
    }

    #[test]
    fn try_open_track_reports_duplicates_and_writer_end() {
        let mut writer = writer_over(SharedCursor::default());

        let mut track = writer.try_open_next_track().unwrap();
        let duplicate = writer
            .try_open_track(0)
            .expect_err("duplicate track should fail");
        assert!(matches!(
            duplicate,
            MIDIWriteError::TrackAlreadyOpened { track_id: 0 }
        ));

        track.end().unwrap();
        drop(track);
        writer.end().unwrap();

        let after_end = writer
            .try_open_track(3)
            .expect_err("ended writer should fail");
        assert!(matches!(after_end, MIDIWriteError::WriterEnded));
    }

    #[test]
    fn try_end_reports_missing_track_gaps() {
        let mut writer = writer_over(SharedCursor::default());

        // Only track 1 is written, leaving track 0 missing.
        let mut track = writer.try_open_track(1).unwrap();
        track.end().unwrap();
        drop(track);

        let gap_err = writer
            .try_end()
            .expect_err("missing track 0 should prevent end");
        assert!(matches!(
            gap_err,
            MIDIWriteError::TrackGapsRemaining { ref track_ids } if track_ids == &[0]
        ));
    }

    #[test]
    fn write_bytes_writes_all_bytes() {
        let sink = SharedCursor::default();
        {
            let writer = writer_over(sink.clone());
            let mut track = writer.try_open_next_track().unwrap();
            track.write_bytes(&[1, 2, 3, 4]).unwrap();
            track.end().unwrap();
        }

        let written = sink.bytes();
        assert!(written.windows(4).any(|window| window == [1, 2, 3, 4]));
    }
}
377}