1use crossbeam_channel::{bounded, unbounded, Sender};
2use std::{
3 io::{self, Read, Seek, SeekFrom},
4 sync::Arc,
5 thread::{self, JoinHandle},
6};
7
8use crate::DelayedReceiver;
9
10use super::errors::{MIDILoadError, MIDIParseError};
11
12use std::fmt::Debug;
13#[derive(Debug)]
14pub struct DiskReader {
15 reader: Arc<BufferReadProvider>,
16 length: u64,
17}
18
/// MIDI source whose entire contents have been loaded into memory.
#[derive(Debug)]
pub struct RAMReader {
    /// Complete file contents, shared with every track reader opened from this reader.
    bytes: Arc<Vec<u8>>,
    /// Cursor used by `read_byte`.
    pos: usize,
}
24
25pub struct ReadCommand {
26 destination: Sender<Result<Vec<u8>, io::Error>>,
27 buffer: Vec<u8>,
28 start: u64,
29 length: usize,
30}
31
32#[derive(Debug)]
33pub struct BufferReadProvider {
34 _thread: JoinHandle<()>,
35 send: Sender<ReadCommand>,
36}
37
38impl BufferReadProvider {
39 pub fn new<T: 'static + Read + Seek + Send>(mut reader: T) -> BufferReadProvider {
40 let (snd, rcv) = unbounded::<ReadCommand>();
41
42 let handle = thread::spawn(move || {
43 let mut read = move |mut buffer: Vec<u8>,
44 start: u64,
45 length: usize|
46 -> Result<Vec<u8>, io::Error> {
47 reader.seek(SeekFrom::Start(start))?;
48 if length < buffer.len() {
49 buffer.truncate(length)
50 }
51 reader.read_exact(&mut buffer)?;
52 Ok(buffer)
53 };
54
55 loop {
56 match rcv.recv() {
57 Err(_) => return,
58 Ok(cmd) => match read(cmd.buffer, cmd.start, cmd.length) {
59 Ok(buf) => {
60 cmd.destination.send(Ok(buf)).ok();
61 }
62 Err(e) => {
63 cmd.destination.send(Err(e)).ok();
64 }
65 },
66 }
67 }
68 });
69
70 BufferReadProvider {
71 send: snd,
72 _thread: handle,
73 }
74 }
75
76 pub fn send_read_command(
77 &self,
78 destination: Sender<Result<Vec<u8>, io::Error>>,
79 buffer: Vec<u8>,
80 start: u64,
81 length: usize,
82 ) {
83 let cmd = ReadCommand {
84 destination,
85 buffer,
86 start,
87 length,
88 };
89
90 self.send.send(cmd).unwrap();
91 }
92
93 pub fn read_sync(&self, buf: Vec<u8>, start: u64) -> Result<Vec<u8>, io::Error> {
94 let (send, receive) = bounded::<Result<Vec<u8>, io::Error>>(1);
95
96 let len = buf.len();
97 self.send_read_command(send, buf, start, len);
98
99 receive.recv().unwrap()
100 }
101}
102
103fn get_reader_len<T: Seek>(reader: &mut T) -> Result<u64, MIDILoadError> {
104 let pos = reader.seek(SeekFrom::End(0))?;
105 reader.seek(SeekFrom::Start(0))?;
106 Ok(pos)
107}
108
109impl DiskReader {
110 pub fn new<T: 'static + Read + Seek + Send>(
111 mut reader: T,
112 ) -> Result<DiskReader, MIDILoadError> {
113 let len = get_reader_len(&mut reader);
114 let reader = BufferReadProvider::new(reader);
115
116 match len {
117 Err(e) => Err(e),
118 Ok(length) => Ok(DiskReader {
119 reader: Arc::new(reader),
120 length,
121 }),
122 }
123 }
124}
125
126impl RAMReader {
127 pub fn new<T: Read + Seek>(mut reader: T) -> Result<RAMReader, MIDILoadError> {
128 let len = get_reader_len(&mut reader);
129
130 match len {
131 Err(e) => Err(e),
132 Ok(length) => {
133 let max_supported: u64 = 2147483648;
134 if length > max_supported {
135 panic!(
136 "The maximum length allowed for a memory loaded MIDI file is {}",
137 max_supported
138 );
139 }
140
141 let mut bytes = vec![0; length as usize];
142 reader.read_exact(&mut bytes)?;
143 Ok(RAMReader {
144 bytes: Arc::new(bytes),
145 pos: 0,
146 })
147 }
148 }
149 }
150
151 pub fn read_byte(&mut self) -> Result<u8, MIDILoadError> {
152 let b = self.bytes.get(self.pos);
153 self.pos += 1;
154 match b {
155 Some(v) => Ok(*v),
156 None => Err(MIDILoadError::CorruptChunks),
157 }
158 }
159}
160
161pub trait MIDIReader: Debug {
162 type ByteReader: TrackReader;
163
164 fn read_bytes_to(&self, pos: u64, bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError>;
165 fn read_bytes(&self, pos: u64, count: usize) -> Result<Vec<u8>, MIDILoadError> {
166 let bytes = vec![0u8; count];
167
168 self.read_bytes_to(pos, bytes)
169 }
170
171 fn len(&self) -> u64;
172 fn is_empty(&self) -> bool {
173 self.len() == 0
174 }
175
176 fn open_reader(&self, track_number: Option<u32>, start: u64, len: u64) -> Self::ByteReader;
177}
178
179impl MIDIReader for DiskReader {
180 type ByteReader = DiskTrackReader;
181
182 fn open_reader(&self, track_number: Option<u32>, start: u64, len: u64) -> DiskTrackReader {
183 DiskTrackReader::new(track_number, self.reader.clone(), start, len)
184 }
185
186 fn read_bytes_to(&self, pos: u64, bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError> {
187 Ok(self.reader.read_sync(bytes, pos)?)
188 }
189
190 fn len(&self) -> u64 {
191 self.length
192 }
193}
194
195impl MIDIReader for RAMReader {
196 type ByteReader = FullRamTrackReader;
197
198 fn open_reader<'a>(
199 &self,
200 track_number: Option<u32>,
201 start: u64,
202 len: u64,
203 ) -> FullRamTrackReader {
204 FullRamTrackReader {
205 track_number,
206 start: start as usize,
207 pos: start as usize,
208 end: (start + len) as usize,
209 bytes: self.bytes.clone(),
210 }
211 }
212
213 fn read_bytes_to(&self, pos: u64, mut bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError> {
214 let count = bytes.len();
215 if pos + count as u64 > self.len() {
216 return Err(MIDILoadError::CorruptChunks);
217 }
218
219 bytes[..].clone_from_slice(&self.bytes[pos as usize..pos as usize + count]);
220
221 Ok(bytes)
222 }
223
224 fn len(&self) -> u64 {
225 self.bytes.len() as u64
226 }
227}
228
229pub trait TrackReader: Send + Sync {
230 fn track_number(&self) -> Option<u32>;
232
233 fn read(&mut self) -> Result<u8, MIDIParseError>;
234 fn pos(&self) -> u64;
235 fn is_at_end(&self) -> bool;
236}
237
238pub struct DiskTrackReader {
239 track_number: Option<u32>,
241
242 reader: Arc<BufferReadProvider>,
243 start: u64, len: u64, buffer: Option<Vec<u8>>, buffer_start: u64, buffer_pos: usize, unrequested_data_start: u64, receiver: DelayedReceiver<Result<Vec<u8>, io::Error>>,
251 receiver_sender: Option<Sender<Result<Vec<u8>, io::Error>>>, }
253
/// Track reader over a byte range of a shared, fully in-memory buffer.
pub struct FullRamTrackReader {
    /// Index of the track being read, if known.
    track_number: Option<u32>,
    /// Offset in `bytes` at which the track begins.
    start: usize,

    /// Shared backing storage.
    bytes: Arc<Vec<u8>>,
    /// Offset of the next byte to read.
    pos: usize,
    /// Offset one past the last byte of the track.
    end: usize,
}

impl FullRamTrackReader {
    /// Creates a reader over `bytes[start..end]`, positioned at `start`.
    ///
    /// The range is stored as given; it is not checked against `bytes.len()` here.
    pub fn new(
        track_number: Option<u32>,
        bytes: Arc<Vec<u8>>,
        start: usize,
        end: usize,
    ) -> FullRamTrackReader {
        FullRamTrackReader {
            track_number,
            start,
            pos: start,
            end,
            bytes,
        }
    }

    /// Creates a reader that takes ownership of `bytes` and covers all of it.
    pub fn new_from_vec(track_number: Option<u32>, bytes: Vec<u8>) -> FullRamTrackReader {
        let end = bytes.len();
        Self::new(track_number, Arc::new(bytes), 0, end)
    }
}
291
292impl TrackReader for FullRamTrackReader {
293 #[inline(always)]
294 fn read(&mut self) -> Result<u8, MIDIParseError> {
295 if self.pos == self.end {
296 return Err(MIDIParseError::UnexpectedTrackEnd {
297 track_number: self.track_number,
298 track_start: self.start as u64,
299 expected_track_end: self.end as u64,
300 found_track_end: self.pos as u64,
301 });
302 }
303 let b = self.bytes[self.pos];
304 self.pos += 1;
305 Ok(b)
306 }
307
308 #[inline(always)]
309 fn pos(&self) -> u64 {
310 self.pos as u64
311 }
312
313 fn is_at_end(&self) -> bool {
314 self.pos == self.end
315 }
316
317 fn track_number(&self) -> Option<u32> {
318 self.track_number
319 }
320}
321
322impl DiskTrackReader {
323 fn finished_sending_reads(&self) -> bool {
324 self.unrequested_data_start == self.len
325 }
326
327 fn next_buffer_req_length(&self) -> usize {
328 (self.len - self.unrequested_data_start).min(1 << 19) as usize
329 }
330
331 fn send_next_read(&mut self, buffer: Option<Vec<u8>>) {
332 if self.finished_sending_reads() {
333 self.receiver_sender.take();
334 return;
335 }
336
337 let mut next_len = self.next_buffer_req_length();
338
339 let buffer = match buffer {
340 None => vec![0u8; next_len],
341 Some(b) => b,
342 };
343
344 next_len = next_len.min(buffer.len());
345
346 self.reader.send_read_command(
347 self.receiver_sender.clone().unwrap(),
348 buffer,
349 self.unrequested_data_start + self.start,
350 next_len,
351 );
352
353 self.unrequested_data_start += next_len as u64;
354 }
355
356 fn receive_next_buffer(&mut self) -> Option<Result<Vec<u8>, MIDIParseError>> {
357 match self.receiver.recv() {
358 Ok(v) => match v {
359 Ok(v) => Some(Ok(v)),
360 Err(e) => Some(Err(e.into())),
361 },
362 Err(_) => None,
363 }
364 }
365
366 pub fn new(
367 track_number: Option<u32>,
368 reader: Arc<BufferReadProvider>,
369 start: u64,
370 len: u64,
371 ) -> DiskTrackReader {
372 let buffer_count = 3;
373
374 let (send, receive) = unbounded();
375
376 let mut reader = DiskTrackReader {
377 track_number,
378 reader,
379 start,
380 len,
381 buffer: None,
382 buffer_start: 0,
383 buffer_pos: 0,
384 unrequested_data_start: 0,
385 receiver: DelayedReceiver::new(receive),
386 receiver_sender: Some(send),
387 };
388
389 for _ in 0..buffer_count {
390 reader.send_next_read(None);
391 }
392
393 reader.receiver.wait_first();
394
395 reader
396 }
397}
398
399impl TrackReader for DiskTrackReader {
400 fn read(&mut self) -> Result<u8, MIDIParseError> {
401 match self.buffer {
402 None => {
403 if let Some(next) = self.receive_next_buffer() {
404 self.buffer = Some(next?);
405 } else {
406 return Err(MIDIParseError::UnexpectedTrackEnd {
407 track_number: self.track_number,
408 track_start: self.start,
409 expected_track_end: self.start + self.len,
410 found_track_end: self.pos(),
411 });
412 }
413 }
414 Some(_) => {}
415 }
416
417 let buffer = self.buffer.as_ref().unwrap();
418 let byte = buffer[self.buffer_pos];
419
420 self.buffer_pos += 1;
421 if self.buffer_pos == buffer.len() {
422 let buffer = self.buffer.take().unwrap();
423 self.buffer_start += buffer.len() as u64;
424 self.buffer_pos = 0;
425 self.send_next_read(Some(buffer));
426 }
427
428 Ok(byte)
429 }
430
431 #[inline(always)]
432 fn pos(&self) -> u64 {
433 self.start + self.buffer_start + self.buffer_pos as u64
434 }
435
436 fn is_at_end(&self) -> bool {
437 self.buffer_start + self.buffer_pos as u64 >= self.len
438 }
439
440 fn track_number(&self) -> Option<u32> {
441 self.track_number
442 }
443}