midi_toolkit/io/midi_file.rs

1use std::{
2    fs::File,
3    io::{Read, Seek},
4    path::Path,
5};
6
7use crate::{
8    events::Event,
9    sequence::{
10        channels_into_threadpool,
11        event::{
12            convert_events_into_batches, flatten_batches_to_events,
13            flatten_track_batches_to_events, into_track_events, merge_events_array, Delta,
14            EventBatch, Track,
15        },
16    },
17};
18use std::fmt::Debug;
19
20use super::{
21    errors::{MIDILoadError, MIDIParseError},
22    readers::{DiskReader, MIDIReader, RAMReader},
23    track_parser::TrackParser,
24};
25
/// Location of one `MTrk` chunk body inside the underlying reader.
#[derive(Debug)]
struct TrackPos {
    /// Absolute byte offset of the first byte after the 8-byte chunk header.
    pos: u64,
    /// Length in bytes of the chunk body, as declared in the chunk header.
    len: u32,
}
31
32#[derive(Debug)]
33pub struct MIDIFile<T: MIDIReader> {
34    reader: T,
35    track_positions: Vec<TrackPos>,
36
37    format: u16,
38    ppq: u16,
39}
40
41impl<T: 'static + MIDIReader> MIDIFile<T> {
42    fn new_from_disk_reader(
43        reader: T,
44        mut read_progress: Option<&mut dyn FnMut(u32)>,
45    ) -> Result<Self, MIDILoadError> {
46        fn bytes_to_val(bytes: &[u8]) -> u32 {
47            assert!(bytes.len() <= 4);
48            let mut num: u32 = 0;
49            for b in bytes {
50                num = (num << 8) + *b as u32;
51            }
52
53            num
54        }
55
56        fn read_header<T: MIDIReader>(
57            reader: &T,
58            pos: u64,
59            text: &str,
60        ) -> Result<u32, MIDILoadError> {
61            assert!(text.len() == 4);
62
63            let bytes = reader.read_bytes(pos, 8)?;
64
65            let (header, len) = bytes.split_at(4);
66
67            let chars = text.as_bytes();
68
69            for i in 0..chars.len() {
70                if chars[i] != header[i] {
71                    return Err(MIDILoadError::CorruptChunks);
72                }
73            }
74
75            Ok(bytes_to_val(len))
76        }
77
78        let mut pos = 0u64;
79
80        let header_len = read_header(&reader, pos, "MThd")?;
81        pos += 8;
82        if header_len != 6 {
83            return Err(MIDILoadError::CorruptChunks);
84        }
85
86        let (format, ppq) = {
87            let header_data = reader.read_bytes(pos, 6)?;
88            pos += 6;
89            let (format_bytes, rest) = header_data.split_at(2);
90            let (_, ppq_bytes) = rest.split_at(2);
91            (
92                bytes_to_val(format_bytes) as u16,
93                bytes_to_val(ppq_bytes) as u16,
94            )
95        };
96
97        let mut track_count = 0;
98        let mut track_positions = Vec::<TrackPos>::new();
99        while pos != reader.len() {
100            let len = read_header(&reader, pos, "MTrk")?;
101            pos += 8;
102            track_count += 1;
103            track_positions.push(TrackPos { len, pos });
104            pos += len as u64;
105
106            if let Some(progress) = read_progress.as_mut().take() {
107                progress(track_count);
108            }
109        }
110
111        track_positions.shrink_to_fit();
112        Ok(MIDIFile {
113            reader,
114            ppq,
115            format,
116            track_positions,
117        })
118    }
119
120    pub fn open_track_reader(&self, track: u32) -> T::ByteReader {
121        let pos = &self.track_positions[track as usize];
122        self.reader
123            .open_reader(Some(track), pos.pos, pos.len as u64)
124    }
125
126    pub fn iter_all_tracks(
127        &self,
128    ) -> impl Iterator<Item = impl Iterator<Item = Result<Delta<u64, Event>, MIDIParseError>>> {
129        let mut tracks = Vec::new();
130        for i in 0..self.track_count() {
131            tracks.push(self.iter_track(i as u32));
132        }
133        tracks.into_iter()
134    }
135
136    pub fn iter_all_events_merged(
137        &self,
138    ) -> impl Iterator<Item = Result<Delta<u64, Event>, MIDIParseError>> {
139        let merged_batches = self.iter_all_events_merged_batches();
140        flatten_batches_to_events(merged_batches)
141    }
142
143    pub fn iter_all_track_events_merged(
144        &self,
145    ) -> impl Iterator<Item = Result<Delta<u64, Track<Event>>, MIDIParseError>> {
146        let merged_batches = self.iter_all_track_events_merged_batches();
147        flatten_track_batches_to_events(merged_batches)
148    }
149
150    pub fn iter_all_events_merged_batches(
151        &self,
152    ) -> impl Iterator<Item = Result<Delta<u64, EventBatch<Event>>, MIDIParseError>> {
153        let batched_tracks = self
154            .iter_all_tracks()
155            .map(convert_events_into_batches)
156            .collect();
157        let batched_tracks_threaded = channels_into_threadpool(batched_tracks, 10);
158        merge_events_array(batched_tracks_threaded)
159    }
160
161    pub fn iter_all_track_events_merged_batches(
162        &self,
163    ) -> impl Iterator<Item = Result<Delta<u64, Track<EventBatch<Event>>>, MIDIParseError>> {
164        let batched_tracks = self
165            .iter_all_tracks()
166            .map(convert_events_into_batches)
167            .enumerate()
168            .map(|(i, track)| into_track_events(track, i as u32))
169            .collect();
170        let batched_tracks_threaded = channels_into_threadpool(batched_tracks, 10);
171        merge_events_array(batched_tracks_threaded)
172    }
173
174    pub fn iter_track(
175        &self,
176        track: u32,
177    ) -> impl Iterator<Item = Result<Delta<u64, Event>, MIDIParseError>> {
178        let reader = self.open_track_reader(track);
179        TrackParser::new(reader)
180    }
181
182    pub fn ppq(&self) -> u16 {
183        self.ppq
184    }
185
186    pub fn format(&self) -> u16 {
187        self.format
188    }
189
190    pub fn track_count(&self) -> usize {
191        self.track_positions.len()
192    }
193}
194
195impl MIDIFile<DiskReader> {
196    pub fn open(
197        filename: impl AsRef<Path>,
198        read_progress: Option<&mut dyn FnMut(u32)>,
199    ) -> Result<Self, MIDILoadError> {
200        let reader = File::open(filename)?;
201        let reader = DiskReader::new(reader)?;
202
203        MIDIFile::new_from_disk_reader(reader, read_progress)
204    }
205
206    pub fn open_from_stream<T: 'static + Read + Seek + Send>(
207        stream: T,
208        read_progress: Option<&mut dyn FnMut(u32)>,
209    ) -> Result<Self, MIDILoadError> {
210        let reader = DiskReader::new(stream)?;
211
212        MIDIFile::new_from_disk_reader(reader, read_progress)
213    }
214}
215
216impl MIDIFile<RAMReader> {
217    pub fn open_in_ram(
218        filename: impl AsRef<Path>,
219        read_progress: Option<&mut dyn FnMut(u32)>,
220    ) -> Result<Self, MIDILoadError> {
221        let reader = File::open(filename)?;
222        let reader = RAMReader::new(reader)?;
223
224        MIDIFile::new_from_disk_reader(reader, read_progress)
225    }
226
227    pub fn open_from_stream_in_ram<T: 'static + Read + Seek + Send>(
228        stream: T,
229        read_progress: Option<&mut dyn FnMut(u32)>,
230    ) -> Result<Self, MIDILoadError> {
231        let reader = RAMReader::new(stream)?;
232
233        MIDIFile::new_from_disk_reader(reader, read_progress)
234    }
235}