midi_toolkit/io/
midi_writer.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fs::File,
4    io::{self, copy, Cursor, Read, Seek, SeekFrom, Write},
5    sync::Mutex,
6};
7
8use crate::events::SerializeEventWithDelta;
9
10use super::errors::MIDIWriteError;
11
/// Combination of [`Write`] and [`Seek`], so that a seekable output stream can
/// be stored behind a single trait object (`Box<dyn WriteSeek>`).
///
/// Implemented here for [`File`] and for an in-memory `Cursor<Vec<u8>>`.
pub trait WriteSeek: Write + Seek {}
impl WriteSeek for File {}
impl WriteSeek for Cursor<Vec<u8>> {}
15
/// A finished track body waiting to be copied into the output stream.
pub struct QueuedOutput {
    /// Source the track bytes are read from when the track is flushed.
    write: Box<dyn Read>,
    /// Byte count written into the `MTrk` chunk header ahead of the body.
    length: u32,
}
20
/// Bookkeeping shared between a `MIDIWriter` and its `TrackWriter`s.
struct TrackStatus {
    /// Ids of tracks that have been opened and not yet ended.
    opened_tracks: HashSet<i32>,
    /// Ids of tracks whose `TrackWriter` has been ended. An id is added here as
    /// soon as the track ends, even if its bytes are still queued.
    written_tracks: HashSet<i32>,
    /// Id that the next `open_next_track` call will hand out.
    next_init_track: i32,
    /// Id of the next track to be copied into the output; tracks are flushed
    /// in ascending id order.
    next_write_track: i32,
    /// Ended tracks that cannot be flushed yet because a lower id is pending.
    queued_writes: HashMap<i32, QueuedOutput>,
}
28
/// Writes a MIDI file: the header first, then track chunks in ascending id order.
pub struct MIDIWriter {
    /// Destination stream; set to `None` once `end()` has finished the file.
    output: Option<Mutex<Box<dyn WriteSeek>>>,
    /// Which tracks are open, ended, queued, or already flushed.
    tracks: Mutex<TrackStatus>,
}
33
/// Handle for writing the events of one track; the bytes are collected in
/// memory and handed to the parent `MIDIWriter` when the track is ended.
pub struct TrackWriter<'a> {
    /// Writer this track belongs to.
    midi_writer: &'a MIDIWriter,
    /// Id of this track; determines its position in the output file.
    track_id: i32,
    /// In-memory buffer for the track body; `None` once the track has ended.
    writer: Option<Cursor<Vec<u8>>>,
}
39
/// Encodes `val` as two bytes, most significant byte first (big-endian).
fn encode_u16(val: u16) -> [u8; 2] {
    val.to_be_bytes()
}
46
/// Encodes `val` as four bytes, most significant byte first (big-endian).
fn encode_u32(val: u32) -> [u8; 4] {
    val.to_be_bytes()
}
55
56fn flush_track(writer: &mut Box<dyn WriteSeek>, mut output: QueuedOutput) -> Result<(), io::Error> {
57    writer.write_all("MTrk".as_bytes())?;
58    writer.write_all(&encode_u32(output.length))?;
59    copy(&mut output.write, writer)?;
60    Ok(())
61}
62
63impl MIDIWriter {
64    pub fn new(filename: &str, ppq: u16) -> Result<MIDIWriter, MIDIWriteError> {
65        let reader = File::create(filename)?;
66        MIDIWriter::new_from_stram(Box::new(reader), ppq)
67    }
68
69    pub fn new_from_stram(
70        mut output: Box<dyn WriteSeek>,
71        ppq: u16,
72    ) -> Result<MIDIWriter, MIDIWriteError> {
73        output.seek(SeekFrom::Start(0))?;
74        output.write_all("MThd".as_bytes())?;
75        output.write_all(&encode_u32(6))?;
76        output.write_all(&encode_u16(1))?;
77        output.write_all(&encode_u16(0))?;
78        output.write_all(&encode_u16(ppq))?;
79
80        Ok(MIDIWriter {
81            output: Some(Mutex::new(output)),
82            tracks: Mutex::new(TrackStatus {
83                opened_tracks: HashSet::new(),
84                next_init_track: 0,
85                next_write_track: 0,
86                queued_writes: HashMap::new(),
87                written_tracks: HashSet::new(),
88            }),
89        })
90    }
91
92    fn get_writer(&self) -> &Mutex<Box<dyn WriteSeek>> {
93        self.output
94            .as_ref()
95            .expect("Can't get the writer of an ended MIDIWriter")
96    }
97
98    fn write_u16_at(&self, at: u64, val: u16) -> Result<(), io::Error> {
99        let mut output = self.get_writer().lock().unwrap();
100        let pos = output.stream_position()?;
101        output.seek(SeekFrom::Start(at))?;
102        output.write_all(&encode_u16(val))?;
103        output.seek(SeekFrom::Start(pos))?;
104        Ok(())
105    }
106
107    pub fn write_ppq(&self, ppq: u16) -> Result<(), MIDIWriteError> {
108        Ok(self.write_u16_at(12, ppq)?)
109    }
110
111    pub fn write_format(&self, ppq: u16) -> Result<(), MIDIWriteError> {
112        Ok(self.write_u16_at(8, ppq)?)
113    }
114
115    fn write_ntrks(&self, ppq: u16) -> Result<(), MIDIWriteError> {
116        Ok(self.write_u16_at(10, ppq)?)
117    }
118
119    pub fn open_next_track(&self) -> TrackWriter {
120        let track_id = {
121            let mut tracks = self.tracks.lock().unwrap();
122            let track_id = tracks.next_init_track;
123            tracks.next_init_track += 1;
124            track_id
125        };
126        self.open_track(track_id)
127    }
128
129    pub fn open_track(&self, track_id: i32) -> TrackWriter {
130        self.add_opened_track(track_id);
131        TrackWriter {
132            midi_writer: self,
133            track_id,
134            writer: Some(Cursor::new(Vec::new())),
135        }
136    }
137
138    fn add_opened_track(&self, track_id: i32) {
139        let mut tracks = self.tracks.lock().unwrap();
140        if tracks.written_tracks.contains(&track_id) || !tracks.opened_tracks.insert(track_id) {
141            panic!("Track with id {} has aready been opened before", track_id);
142        }
143    }
144
145    pub fn end(&mut self) -> Result<(), MIDIWriteError> {
146        let tracks = self.tracks.lock().unwrap();
147        if !tracks.opened_tracks.is_empty() {
148            let unwritten: Vec<&i32> = tracks.queued_writes.keys().collect();
149            panic!("Not all tracks have been ended! Make sure you drop or call .end() on each track before ending the MIDIWriter\nMissing tracks {:?}", unwritten);
150        }
151        if !tracks.queued_writes.is_empty() {
152            let max_track = tracks.queued_writes.keys().max().unwrap();
153            let unwritten: Vec<i32> = (0..*max_track)
154                .filter(|track_id| !tracks.written_tracks.contains(track_id))
155                .collect();
156            panic!(
157                "Not all tracks have been opened! Missing tracks {:?}",
158                unwritten
159            );
160        }
161
162        let track_count = tracks.written_tracks.len();
163        self.write_ntrks(track_count.min(u16::MAX as usize) as u16)?;
164
165        self.output.take();
166
167        Ok(())
168    }
169
170    pub fn is_ended(&self) -> bool {
171        self.output.is_some()
172    }
173}
174
175impl<'a> TrackWriter<'a> {
176    pub fn end(&mut self) -> Result<(), MIDIWriteError> {
177        self.write_bytes(&[0x00, 0xFF, 0x2F, 0x00])?;
178
179        let mut status = self.midi_writer.tracks.lock().unwrap();
180        if !status.written_tracks.insert(self.track_id)
181            || !status.opened_tracks.remove(&self.track_id)
182        {
183            panic!("Invalid MIDIWriter state, unknown error");
184        }
185
186        let mut writer = match self.writer.take() {
187            Some(cursor) => cursor,
188            None => panic!(".end() was called more than once on TrackWriter"),
189        };
190
191        let length = writer.stream_position()? as u32;
192        writer.seek(SeekFrom::Start(0))?;
193
194        status.queued_writes.insert(
195            self.track_id,
196            QueuedOutput {
197                write: Box::new(writer),
198                length,
199            },
200        );
201
202        if self.track_id == status.next_write_track {
203            let mut writer = self.midi_writer.get_writer().lock().unwrap();
204            loop {
205                let next_write_track = status.next_write_track;
206                match status.queued_writes.remove_entry(&next_write_track) {
207                    None => break,
208                    Some(output) => {
209                        flush_track(&mut writer, output.1)?;
210                        status.next_write_track += 1;
211                    }
212                }
213            }
214        }
215
216        Ok(())
217    }
218
219    pub fn is_ended(&self) -> bool {
220        self.writer.is_some()
221    }
222
223    pub fn get_writer_mut(&mut self) -> &mut impl Write {
224        self.writer
225            .as_mut()
226            .expect("Tried to write to TrackWriter after .end() was called")
227    }
228
229    pub fn write_event<T: SerializeEventWithDelta>(
230        &mut self,
231        event: T,
232    ) -> Result<usize, MIDIWriteError> {
233        let writer = self.get_writer_mut();
234        event.serialize_event_with_delta(writer)
235    }
236
237    pub fn write_events_iter<T: SerializeEventWithDelta>(
238        &mut self,
239        events: impl Iterator<Item = T>,
240    ) -> Result<usize, MIDIWriteError> {
241        let mut count = 0;
242        for event in events {
243            count += self.write_event(event)?;
244        }
245        Ok(count)
246    }
247
248    pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, MIDIWriteError> {
249        let writer = self.get_writer_mut();
250        Ok(writer.write(bytes)?)
251    }
252}
253
254impl<'a> Drop for TrackWriter<'a> {
255    fn drop(&mut self) {
256        if self.is_ended() {
257            match self.end() {
258                Ok(()) => {}
259                Err(e) => {
260                    panic!("TrackWriter errored when being dropped with: {:?}\n\nIf you want to handle these errors in the future, manually call .end() before dropping", e);
261                }
262            }
263        }
264    }
265}
266
267impl Drop for MIDIWriter {
268    fn drop(&mut self) {
269        if self.is_ended() {
270            match self.end() {
271                Ok(()) => {}
272                Err(e) => {
273                    panic!("TrackWriter errored when being dropped with: {:?}\n\nIf you want to handle these errors in the future, manually call .end() before dropping", e);
274                }
275            }
276        }
277    }
278}