1use std::{
2 fs::File,
3 io::{Read, Seek},
4 path::Path,
5};
6
7use crate::{
8 events::Event,
9 sequence::{
10 channels_into_threadpool,
11 event::{
12 convert_events_into_batches, flatten_batches_to_events,
13 flatten_track_batches_to_events, into_track_events, merge_events_array, Delta,
14 EventBatch, Track,
15 },
16 },
17};
18use std::fmt::Debug;
19
20use super::{
21 errors::{MIDILoadError, MIDIParseError},
22 readers::{DiskReader, MIDIReader, RAMReader},
23 track_parser::TrackParser,
24};
25
/// Location of a single track chunk's payload inside the MIDI source.
#[derive(Debug)]
struct TrackPos {
    /// Byte offset of the first data byte, i.e. just past the 8-byte `MTrk` chunk header.
    pos: u64,
    /// Number of data bytes in the track, as declared by the chunk header.
    len: u32,
}
31
32#[derive(Debug)]
33pub struct MIDIFile<T: MIDIReader> {
34 reader: T,
35 track_positions: Vec<TrackPos>,
36
37 format: u16,
38 ppq: u16,
39}
40
41impl<T: 'static + MIDIReader> MIDIFile<T> {
42 fn new_from_disk_reader(
43 reader: T,
44 mut read_progress: Option<&mut dyn FnMut(u32)>,
45 ) -> Result<Self, MIDILoadError> {
46 fn bytes_to_val(bytes: &[u8]) -> u32 {
47 assert!(bytes.len() <= 4);
48 let mut num: u32 = 0;
49 for b in bytes {
50 num = (num << 8) + *b as u32;
51 }
52
53 num
54 }
55
56 fn read_header<T: MIDIReader>(
57 reader: &T,
58 pos: u64,
59 text: &str,
60 ) -> Result<u32, MIDILoadError> {
61 assert!(text.len() == 4);
62
63 let bytes = reader.read_bytes(pos, 8)?;
64
65 let (header, len) = bytes.split_at(4);
66
67 let chars = text.as_bytes();
68
69 for i in 0..chars.len() {
70 if chars[i] != header[i] {
71 return Err(MIDILoadError::CorruptChunks);
72 }
73 }
74
75 Ok(bytes_to_val(len))
76 }
77
78 let mut pos = 0u64;
79
80 let header_len = read_header(&reader, pos, "MThd")?;
81 pos += 8;
82 if header_len != 6 {
83 return Err(MIDILoadError::CorruptChunks);
84 }
85
86 let (format, ppq) = {
87 let header_data = reader.read_bytes(pos, 6)?;
88 pos += 6;
89 let (format_bytes, rest) = header_data.split_at(2);
90 let (_, ppq_bytes) = rest.split_at(2);
91 (
92 bytes_to_val(format_bytes) as u16,
93 bytes_to_val(ppq_bytes) as u16,
94 )
95 };
96
97 let mut track_count = 0;
98 let mut track_positions = Vec::<TrackPos>::new();
99 while pos != reader.len() {
100 let len = read_header(&reader, pos, "MTrk")?;
101 pos += 8;
102 track_count += 1;
103 track_positions.push(TrackPos { len, pos });
104 pos += len as u64;
105
106 if let Some(progress) = read_progress.as_mut().take() {
107 progress(track_count);
108 }
109 }
110
111 track_positions.shrink_to_fit();
112 Ok(MIDIFile {
113 reader,
114 ppq,
115 format,
116 track_positions,
117 })
118 }
119
120 pub fn open_track_reader(&self, track: u32) -> T::ByteReader {
121 let pos = &self.track_positions[track as usize];
122 self.reader
123 .open_reader(Some(track), pos.pos, pos.len as u64)
124 }
125
126 pub fn iter_all_tracks(
127 &self,
128 ) -> impl Iterator<Item = impl Iterator<Item = Result<Delta<u64, Event>, MIDIParseError>>> {
129 let mut tracks = Vec::new();
130 for i in 0..self.track_count() {
131 tracks.push(self.iter_track(i as u32));
132 }
133 tracks.into_iter()
134 }
135
136 pub fn iter_all_events_merged(
137 &self,
138 ) -> impl Iterator<Item = Result<Delta<u64, Event>, MIDIParseError>> {
139 let merged_batches = self.iter_all_events_merged_batches();
140 flatten_batches_to_events(merged_batches)
141 }
142
143 pub fn iter_all_track_events_merged(
144 &self,
145 ) -> impl Iterator<Item = Result<Delta<u64, Track<Event>>, MIDIParseError>> {
146 let merged_batches = self.iter_all_track_events_merged_batches();
147 flatten_track_batches_to_events(merged_batches)
148 }
149
150 pub fn iter_all_events_merged_batches(
151 &self,
152 ) -> impl Iterator<Item = Result<Delta<u64, EventBatch<Event>>, MIDIParseError>> {
153 let batched_tracks = self
154 .iter_all_tracks()
155 .map(convert_events_into_batches)
156 .collect();
157 let batched_tracks_threaded = channels_into_threadpool(batched_tracks, 10);
158 merge_events_array(batched_tracks_threaded)
159 }
160
161 pub fn iter_all_track_events_merged_batches(
162 &self,
163 ) -> impl Iterator<Item = Result<Delta<u64, Track<EventBatch<Event>>>, MIDIParseError>> {
164 let batched_tracks = self
165 .iter_all_tracks()
166 .map(convert_events_into_batches)
167 .enumerate()
168 .map(|(i, track)| into_track_events(track, i as u32))
169 .collect();
170 let batched_tracks_threaded = channels_into_threadpool(batched_tracks, 10);
171 merge_events_array(batched_tracks_threaded)
172 }
173
174 pub fn iter_track(
175 &self,
176 track: u32,
177 ) -> impl Iterator<Item = Result<Delta<u64, Event>, MIDIParseError>> {
178 let reader = self.open_track_reader(track);
179 TrackParser::new(reader)
180 }
181
182 pub fn ppq(&self) -> u16 {
183 self.ppq
184 }
185
186 pub fn format(&self) -> u16 {
187 self.format
188 }
189
190 pub fn track_count(&self) -> usize {
191 self.track_positions.len()
192 }
193}
194
195impl MIDIFile<DiskReader> {
196 pub fn open(
197 filename: impl AsRef<Path>,
198 read_progress: Option<&mut dyn FnMut(u32)>,
199 ) -> Result<Self, MIDILoadError> {
200 let reader = File::open(filename)?;
201 let reader = DiskReader::new(reader)?;
202
203 MIDIFile::new_from_disk_reader(reader, read_progress)
204 }
205
206 pub fn open_from_stream<T: 'static + Read + Seek + Send>(
207 stream: T,
208 read_progress: Option<&mut dyn FnMut(u32)>,
209 ) -> Result<Self, MIDILoadError> {
210 let reader = DiskReader::new(stream)?;
211
212 MIDIFile::new_from_disk_reader(reader, read_progress)
213 }
214}
215
216impl MIDIFile<RAMReader> {
217 pub fn open_in_ram(
218 filename: impl AsRef<Path>,
219 read_progress: Option<&mut dyn FnMut(u32)>,
220 ) -> Result<Self, MIDILoadError> {
221 let reader = File::open(filename)?;
222 let reader = RAMReader::new(reader)?;
223
224 MIDIFile::new_from_disk_reader(reader, read_progress)
225 }
226
227 pub fn open_from_stream_in_ram<T: 'static + Read + Seek + Send>(
228 stream: T,
229 read_progress: Option<&mut dyn FnMut(u32)>,
230 ) -> Result<Self, MIDILoadError> {
231 let reader = RAMReader::new(stream)?;
232
233 MIDIFile::new_from_disk_reader(reader, read_progress)
234 }
235}