1use std::{
2 fs::File,
3 io::{self, copy, Cursor, Read, Seek, SeekFrom, Write},
4 sync::Mutex,
5};
6
7use crate::events::SerializeEventWithDelta;
8
9use super::{
10 errors::MIDIWriteError,
11 writer_common::{
12 encode_u16, write_midi_header, write_track_header, OrderedTrackRegistry, TrackByteSink,
13 },
14};
15
/// Marker trait for output sinks that support both writing and seeking.
///
/// It declares no methods of its own; it exists so that `Write + Seek` can be
/// named as a single trait object (`Box<dyn WriteSeek>`).
pub trait WriteSeek: Write + Seek {}

// Files on disk are usable as a MIDI output sink.
impl WriteSeek for File {}

// Growable in-memory buffers are usable as a MIDI output sink.
impl WriteSeek for Cursor<Vec<u8>> {}
19
/// A finished track body waiting to be copied into the output stream.
pub struct QueuedOutput {
    /// Source the track bytes are read back from when the track is flushed.
    write: Box<dyn Read>,
    /// Length value written into the track header ahead of the body.
    length: u32,
}
24
25pub struct MIDIWriter {
26 output: Option<Mutex<Box<dyn WriteSeek>>>,
27 tracks: Mutex<OrderedTrackRegistry<QueuedOutput>>,
28}
29
30pub struct TrackWriter<'a> {
31 midi_writer: &'a MIDIWriter,
32 track: TrackByteSink<Cursor<Vec<u8>>>,
33}
34
35fn flush_track(writer: &mut dyn WriteSeek, mut output: QueuedOutput) -> Result<(), io::Error> {
36 write_track_header(writer, output.length)?;
37 copy(&mut output.write, writer)?;
38 Ok(())
39}
40
41impl MIDIWriter {
42 pub fn new(filename: &str, ppq: u16) -> Result<MIDIWriter, MIDIWriteError> {
43 let reader = File::create(filename)?;
44 MIDIWriter::new_from_stream(Box::new(reader), ppq)
45 }
46
47 pub fn new_from_stream(
48 mut output: Box<dyn WriteSeek>,
49 ppq: u16,
50 ) -> Result<MIDIWriter, MIDIWriteError> {
51 output.seek(SeekFrom::Start(0))?;
52 write_midi_header(output.as_mut(), 1, 0, ppq)?;
53
54 Ok(MIDIWriter {
55 output: Some(Mutex::new(output)),
56 tracks: Mutex::new(OrderedTrackRegistry::new()),
57 })
58 }
59
60 #[deprecated(note = "use new_from_stream")]
61 pub fn new_from_stram(
62 output: Box<dyn WriteSeek>,
63 ppq: u16,
64 ) -> Result<MIDIWriter, MIDIWriteError> {
65 Self::new_from_stream(output, ppq)
66 }
67
68 fn with_writer<R>(
69 &self,
70 f: impl FnOnce(&mut dyn WriteSeek) -> Result<R, io::Error>,
71 ) -> Result<R, MIDIWriteError> {
72 let output = self.output.as_ref().ok_or(MIDIWriteError::WriterEnded)?;
73 let mut output = output.lock().unwrap();
74 Ok(f(output.as_mut())?)
75 }
76
77 fn write_u16_at(&self, at: u64, val: u16) -> Result<(), MIDIWriteError> {
78 self.with_writer(|output| {
79 let pos = output.stream_position()?;
80 output.seek(SeekFrom::Start(at))?;
81 output.write_all(&encode_u16(val))?;
82 output.seek(SeekFrom::Start(pos))?;
83 Ok(())
84 })
85 }
86
87 pub fn write_ppq(&self, ppq: u16) -> Result<(), MIDIWriteError> {
88 self.write_u16_at(12, ppq)
89 }
90
91 pub fn write_format(&self, format: u16) -> Result<(), MIDIWriteError> {
92 self.write_u16_at(8, format)
93 }
94
95 fn write_ntrks(&self, track_count: u16) -> Result<(), MIDIWriteError> {
96 self.write_u16_at(10, track_count)
97 }
98
99 pub fn try_open_next_track(&self) -> Result<TrackWriter<'_>, MIDIWriteError> {
100 let mut tracks = self.tracks.lock().unwrap();
101 if self.output.is_none() {
102 return Err(MIDIWriteError::WriterEnded);
103 }
104
105 let track_id = tracks.open_next_track()?;
106
107 Ok(TrackWriter {
108 midi_writer: self,
109 track: TrackByteSink::new(track_id, Cursor::new(Vec::new())),
110 })
111 }
112
113 #[deprecated(note = "use try_open_next_track")]
114 pub fn open_next_track(&self) -> TrackWriter<'_> {
115 self.try_open_next_track()
116 .expect("failed to open next track")
117 }
118
119 pub fn try_open_track(&self, track_id: i32) -> Result<TrackWriter<'_>, MIDIWriteError> {
120 if self.output.is_none() {
121 return Err(MIDIWriteError::WriterEnded);
122 }
123
124 let mut tracks = self.tracks.lock().unwrap();
125 tracks.open_track(track_id)?;
126
127 Ok(TrackWriter {
128 midi_writer: self,
129 track: TrackByteSink::new(track_id, Cursor::new(Vec::new())),
130 })
131 }
132
133 #[deprecated(note = "use try_open_track")]
134 pub fn open_track(&self, track_id: i32) -> TrackWriter<'_> {
135 self.try_open_track(track_id).expect("failed to open track")
136 }
137
138 pub fn try_end(&mut self) -> Result<(), MIDIWriteError> {
139 if self.is_ended() {
140 return Err(MIDIWriteError::WriterEnded);
141 }
142
143 let track_count = self.tracks.lock().unwrap().finalize_track_count()?;
144 self.write_ntrks(track_count)?;
145 self.output.take();
146
147 Ok(())
148 }
149
150 pub fn end(&mut self) -> Result<(), MIDIWriteError> {
151 self.try_end()
152 }
153
154 pub fn is_ended(&self) -> bool {
155 self.output.is_none()
156 }
157}
158
159impl<'a> TrackWriter<'a> {
160 fn writer_mut(&mut self) -> Result<&mut Cursor<Vec<u8>>, MIDIWriteError> {
161 self.track.writer_mut()
162 }
163
164 pub fn end(&mut self) -> Result<(), MIDIWriteError> {
165 let track_id = self.track.track_id();
166 let (mut writer, length) = self.track.finish()?;
167 writer.seek(SeekFrom::Start(0))?;
168
169 let queued_outputs = {
170 let mut status = self.midi_writer.tracks.lock().unwrap();
171 status.finish_track(
172 track_id,
173 QueuedOutput {
174 write: Box::new(writer),
175 length,
176 },
177 )?;
178 status.drain_ready_tracks()
179 };
180
181 if !queued_outputs.is_empty() {
182 let output = self
183 .midi_writer
184 .output
185 .as_ref()
186 .ok_or(MIDIWriteError::WriterEnded)?;
187 let mut writer = output.lock().unwrap();
188 for (_, queued_output) in queued_outputs {
189 flush_track(writer.as_mut(), queued_output)?;
190 }
191 }
192
193 Ok(())
194 }
195
196 pub fn is_ended(&self) -> bool {
197 self.track.is_ended()
198 }
199
200 pub fn get_writer_mut(&mut self) -> &mut impl Write {
201 self.writer_mut()
202 .expect("Tried to write to TrackWriter after .end() was called")
203 }
204
205 pub fn write_event<T: SerializeEventWithDelta>(
206 &mut self,
207 event: T,
208 ) -> Result<usize, MIDIWriteError> {
209 self.track.write_event(event)
210 }
211
212 pub fn write_events_iter<T: SerializeEventWithDelta>(
213 &mut self,
214 events: impl Iterator<Item = T>,
215 ) -> Result<usize, MIDIWriteError> {
216 self.track.write_events_iter(events)
217 }
218
219 pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, MIDIWriteError> {
220 self.track.write_bytes(bytes)
221 }
222}
223
224impl<'a> std::fmt::Debug for TrackWriter<'a> {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 f.debug_struct("TrackWriter")
227 .field("track_id", &self.track.track_id())
228 .field("is_ended", &self.is_ended())
229 .finish()
230 }
231}
232
233impl<'a> Drop for TrackWriter<'a> {
234 fn drop(&mut self) {
235 if !self.is_ended() {
236 let _ = self.end();
237 }
238 }
239}
240
241impl Drop for MIDIWriter {
242 fn drop(&mut self) {
243 if !self.is_ended() {
244 let _ = self.end();
245 }
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::{MIDIWriter, WriteSeek};
    use crate::io::MIDIWriteError;
    use std::{
        io::{Cursor, Seek, SeekFrom, Write},
        sync::{Arc, Mutex},
    };

    /// Cloneable in-memory sink: every clone shares one buffer, so a test can
    /// keep a handle and inspect the bytes after the writer has been dropped.
    #[derive(Clone, Default)]
    struct SharedCursor {
        inner: Arc<Mutex<Cursor<Vec<u8>>>>,
    }

    impl SharedCursor {
        /// Snapshot of everything written so far.
        fn bytes(&self) -> Vec<u8> {
            let cursor = self.inner.lock().unwrap();
            cursor.get_ref().clone()
        }
    }

    impl Write for SharedCursor {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let mut cursor = self.inner.lock().unwrap();
            cursor.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            let mut cursor = self.inner.lock().unwrap();
            cursor.flush()
        }
    }

    impl Seek for SharedCursor {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            let mut cursor = self.inner.lock().unwrap();
            cursor.seek(pos)
        }
    }

    impl WriteSeek for SharedCursor {}

    /// Builds a writer with PPQ 480 on top of a clone of `sink`.
    fn writer_over(sink: &SharedCursor) -> MIDIWriter {
        MIDIWriter::new_from_stream(Box::new(sink.clone()), 480).unwrap()
    }

    #[test]
    fn is_ended_matches_writer_lifecycle() {
        let sink = SharedCursor::default();
        let mut midi = writer_over(&sink);
        assert!(!midi.is_ended());

        let mut first_track = midi.try_open_next_track().unwrap();
        assert!(!first_track.is_ended());

        first_track.end().unwrap();
        assert!(first_track.is_ended());
        drop(first_track);
        // Ending a track must not end the file itself.
        assert!(!midi.is_ended());

        midi.end().unwrap();
        assert!(midi.is_ended());
    }

    #[test]
    fn drop_still_finalizes_open_writers() {
        let sink = SharedCursor::default();
        {
            // Neither the track nor the writer is ended explicitly; both are
            // finalized by their destructors at the end of this scope.
            let midi = writer_over(&sink);
            let mut track = midi.try_open_next_track().unwrap();
            track.write_bytes(&[0x00, 0x90, 0x3C, 0x40]).unwrap();
        }

        let written = sink.bytes();
        assert_eq!(&written[0..4], b"MThd");
        assert_eq!(&written[10..12], &[0x00, 0x01]);
        assert_eq!(&written[14..18], b"MTrk");
        assert_eq!(&written[18..22], &[0x00, 0x00, 0x00, 0x08]);
        assert_eq!(
            &written[22..30],
            &[0x00, 0x90, 0x3C, 0x40, 0x00, 0xFF, 0x2F, 0x00]
        );
    }

    #[test]
    fn try_open_track_reports_duplicates_and_writer_end() {
        let sink = SharedCursor::default();
        let mut midi = writer_over(&sink);

        let mut track = midi.try_open_next_track().unwrap();
        let duplicate = midi
            .try_open_track(0)
            .expect_err("duplicate track should fail");
        assert!(matches!(
            duplicate,
            MIDIWriteError::TrackAlreadyOpened { track_id: 0 }
        ));

        track.end().unwrap();
        drop(track);
        midi.end().unwrap();

        let after_end = midi
            .try_open_track(3)
            .expect_err("ended writer should fail");
        assert!(matches!(after_end, MIDIWriteError::WriterEnded));
    }

    #[test]
    fn try_end_reports_missing_track_gaps() {
        let sink = SharedCursor::default();
        let mut midi = writer_over(&sink);

        // Only track 1 is written, leaving track 0 as a gap.
        let mut track = midi.try_open_track(1).unwrap();
        track.end().unwrap();
        drop(track);

        let gap_err = midi
            .try_end()
            .expect_err("missing track 0 should prevent end");
        assert!(matches!(
            gap_err,
            MIDIWriteError::TrackGapsRemaining { ref track_ids } if track_ids == &[0]
        ));
    }

    #[test]
    fn write_bytes_writes_all_bytes() {
        let sink = SharedCursor::default();
        {
            let midi = writer_over(&sink);
            let mut track = midi.try_open_next_track().unwrap();
            track.write_bytes(&[1, 2, 3, 4]).unwrap();
            track.end().unwrap();
        }

        let written = sink.bytes();
        assert!(written.windows(4).any(|window| window == [1, 2, 3, 4]));
    }
}