Skip to main content

midi_toolkit/io/
readers.rs

1use crossbeam_channel::{bounded, unbounded, Sender};
2use std::{
3    io::{self, Read, Seek, SeekFrom},
4    sync::Arc,
5    thread::{self, JoinHandle},
6};
7
8use crate::DelayedReceiver;
9
10use super::errors::{MIDILoadError, MIDIParseError};
11
12use std::fmt::Debug;
/// A [`MIDIReader`] that leaves the data in the underlying `Read + Seek` stream and
/// fetches byte ranges on demand through a background [`BufferReadProvider`].
#[derive(Debug)]
pub struct DiskReader {
    /// Worker handle, shared (via `Arc`) with every [`DiskTrackReader`] opened from this reader.
    reader: Arc<BufferReadProvider>,
    /// Total stream length in bytes, measured once at construction by seeking to the end.
    length: u64,
}
18
/// A [`MIDIReader`] that loads the entire stream into memory up front.
#[derive(Debug)]
pub struct RAMReader {
    /// The whole file, shared (via `Arc`) with every [`FullRamTrackReader`] opened from it.
    bytes: Arc<Vec<u8>>,
    /// Cursor used by `RAMReader::read_byte`.
    pos: usize,
}
24
/// A single read request for the [`BufferReadProvider`] worker thread.
pub struct ReadCommand {
    /// Channel on which the worker sends back the filled buffer or the I/O error.
    destination: Sender<Result<Vec<u8>, io::Error>>,
    /// Storage for the read; passed in so callers can recycle allocations between reads.
    buffer: Vec<u8>,
    /// Absolute byte offset in the stream to seek to before reading.
    start: u64,
    /// Requested number of bytes; the worker truncates `buffer` to this size when it is longer.
    length: usize,
}
31
/// Owns a background thread that performs every seek/read on the wrapped stream, so
/// several track readers can share one `Read + Seek` source through a command queue.
#[derive(Debug)]
pub struct BufferReadProvider {
    /// Worker thread handle; kept but never joined. The worker exits once `send` is dropped.
    _thread: JoinHandle<()>,
    /// Unbounded command queue into the worker thread.
    send: Sender<ReadCommand>,
}
37
38impl BufferReadProvider {
39    pub fn new<T: 'static + Read + Seek + Send>(mut reader: T) -> BufferReadProvider {
40        let (snd, rcv) = unbounded::<ReadCommand>();
41
42        let handle = thread::spawn(move || {
43            let mut read = move |mut buffer: Vec<u8>,
44                                 start: u64,
45                                 length: usize|
46                  -> Result<Vec<u8>, io::Error> {
47                reader.seek(SeekFrom::Start(start))?;
48                if length < buffer.len() {
49                    buffer.truncate(length)
50                }
51                reader.read_exact(&mut buffer)?;
52                Ok(buffer)
53            };
54
55            loop {
56                match rcv.recv() {
57                    Err(_) => return,
58                    Ok(cmd) => match read(cmd.buffer, cmd.start, cmd.length) {
59                        Ok(buf) => {
60                            cmd.destination.send(Ok(buf)).ok();
61                        }
62                        Err(e) => {
63                            cmd.destination.send(Err(e)).ok();
64                        }
65                    },
66                }
67            }
68        });
69
70        BufferReadProvider {
71            send: snd,
72            _thread: handle,
73        }
74    }
75
76    pub fn send_read_command(
77        &self,
78        destination: Sender<Result<Vec<u8>, io::Error>>,
79        buffer: Vec<u8>,
80        start: u64,
81        length: usize,
82    ) {
83        let cmd = ReadCommand {
84            destination,
85            buffer,
86            start,
87            length,
88        };
89
90        self.send.send(cmd).unwrap();
91    }
92
93    pub fn read_sync(&self, buf: Vec<u8>, start: u64) -> Result<Vec<u8>, io::Error> {
94        let (send, receive) = bounded::<Result<Vec<u8>, io::Error>>(1);
95
96        let len = buf.len();
97        self.send_read_command(send, buf, start, len);
98
99        receive.recv().unwrap()
100    }
101}
102
103fn get_reader_len<T: Seek>(reader: &mut T) -> Result<u64, MIDILoadError> {
104    let pos = reader.seek(SeekFrom::End(0))?;
105    reader.seek(SeekFrom::Start(0))?;
106    Ok(pos)
107}
108
109impl DiskReader {
110    pub fn new<T: 'static + Read + Seek + Send>(
111        mut reader: T,
112    ) -> Result<DiskReader, MIDILoadError> {
113        let len = get_reader_len(&mut reader);
114        let reader = BufferReadProvider::new(reader);
115
116        match len {
117            Err(e) => Err(e),
118            Ok(length) => Ok(DiskReader {
119                reader: Arc::new(reader),
120                length,
121            }),
122        }
123    }
124}
125
126impl RAMReader {
127    pub fn new<T: Read + Seek>(mut reader: T) -> Result<RAMReader, MIDILoadError> {
128        let len = get_reader_len(&mut reader);
129
130        match len {
131            Err(e) => Err(e),
132            Ok(length) => {
133                let max_supported: u64 = 2147483648;
134                if length > max_supported {
135                    return Err(MIDILoadError::FileTooBig);
136                }
137
138                let mut bytes = vec![0; length as usize];
139                reader.read_exact(&mut bytes)?;
140                Ok(RAMReader {
141                    bytes: Arc::new(bytes),
142                    pos: 0,
143                })
144            }
145        }
146    }
147
148    pub fn read_byte(&mut self) -> Result<u8, MIDILoadError> {
149        let b = self.bytes.get(self.pos);
150        self.pos += 1;
151        match b {
152            Some(v) => Ok(*v),
153            None => Err(MIDILoadError::CorruptChunks),
154        }
155    }
156}
157
158pub trait MIDIReader: Debug {
159    type ByteReader: TrackReader;
160
161    fn read_bytes_to(&self, pos: u64, bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError>;
162    fn read_bytes(&self, pos: u64, count: usize) -> Result<Vec<u8>, MIDILoadError> {
163        let bytes = vec![0u8; count];
164
165        self.read_bytes_to(pos, bytes)
166    }
167
168    fn len(&self) -> u64;
169    fn is_empty(&self) -> bool {
170        self.len() == 0
171    }
172
173    fn open_reader(&self, track_number: Option<u32>, start: u64, len: u64) -> Self::ByteReader;
174}
175
176impl MIDIReader for DiskReader {
177    type ByteReader = DiskTrackReader;
178
179    fn open_reader(&self, track_number: Option<u32>, start: u64, len: u64) -> DiskTrackReader {
180        DiskTrackReader::new(track_number, self.reader.clone(), start, len)
181    }
182
183    fn read_bytes_to(&self, pos: u64, bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError> {
184        Ok(self.reader.read_sync(bytes, pos)?)
185    }
186
187    fn len(&self) -> u64 {
188        self.length
189    }
190}
191
192impl MIDIReader for RAMReader {
193    type ByteReader = FullRamTrackReader;
194
195    fn open_reader<'a>(
196        &self,
197        track_number: Option<u32>,
198        start: u64,
199        len: u64,
200    ) -> FullRamTrackReader {
201        FullRamTrackReader {
202            track_number,
203            start: start as usize,
204            pos: start as usize,
205            end: (start + len) as usize,
206            bytes: self.bytes.clone(),
207        }
208    }
209
210    fn read_bytes_to(&self, pos: u64, mut bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError> {
211        let count = bytes.len();
212        if pos + count as u64 > self.len() {
213            return Err(MIDILoadError::CorruptChunks);
214        }
215
216        bytes[..].clone_from_slice(&self.bytes[pos as usize..pos as usize + count]);
217
218        Ok(bytes)
219    }
220
221    fn len(&self) -> u64 {
222        self.bytes.len() as u64
223    }
224}
225
/// Sequential, byte-at-a-time access to one track's data.
pub trait TrackReader: Send + Sync {
    /// The stored track number for diagnostic purposes
    fn track_number(&self) -> Option<u32>;

    /// Returns the next byte and advances the read position.
    ///
    /// Both implementations in this file return `MIDIParseError::UnexpectedTrackEnd` when
    /// asked to read past the end of the track.
    fn read(&mut self) -> Result<u8, MIDIParseError>;
    /// Current read position as an offset into the reader's backing data; both
    /// implementations in this file include the track's start offset.
    fn pos(&self) -> u64;
    /// `true` once every byte of the track has been consumed.
    fn is_at_end(&self) -> bool;
}
234
/// Sequential [`TrackReader`] over a byte range of a [`DiskReader`]'s stream.
///
/// The range is fetched in chunks of up to 512 KiB by the shared [`BufferReadProvider`].
/// Several chunk requests are queued ahead of the consumer, and each exhausted chunk's
/// allocation is recycled for the next request.
pub struct DiskTrackReader {
    /// The track number used only for error logging purposes
    track_number: Option<u32>,

    /// Shared worker that performs the actual seeks and reads.
    reader: Arc<BufferReadProvider>,
    /// Absolute offset of the track's first byte within the MIDI stream.
    start: u64,
    /// Length of the track in bytes.
    len: u64,
    /// Chunk currently being consumed; `None` until the next chunk has been received.
    buffer: Option<Vec<u8>>,
    /// Offset of `buffer`'s first byte, relative to `start`.
    buffer_start: u64,
    /// Index within `buffer` of the next byte to return.
    buffer_pos: usize,
    /// Offset (relative to `start`) of the first byte not yet requested from the worker.
    unrequested_data_start: u64,

    /// Delivers completed chunk reads. The single worker handles requests in FIFO order, so
    /// replies are sent in the order the chunks were requested.
    receiver: DelayedReceiver<Result<Vec<u8>, io::Error>>,
    /// Cloned into every read request. Becomes `None` once the whole track has been
    /// requested, so the channel disconnects after the last in-flight chunk is delivered.
    receiver_sender: Option<Sender<Result<Vec<u8>, io::Error>>>,
}
250
/// Sequential [`TrackReader`] over a range of an in-memory byte buffer.
pub struct FullRamTrackReader {
    /// The track number and start are only for error logging purposes
    track_number: Option<u32>,
    /// Index in `bytes` where the track begins.
    start: usize,

    /// Backing data: shared with the reader it was opened from, or owned via `new_from_vec`.
    bytes: Arc<Vec<u8>>,
    /// Index in `bytes` of the next byte to return.
    pos: usize,
    /// Exclusive index in `bytes` where the track ends.
    end: usize,
}
260
261impl FullRamTrackReader {
262    pub fn new(
263        track_number: Option<u32>,
264        bytes: Arc<Vec<u8>>,
265        start: usize,
266        end: usize,
267    ) -> FullRamTrackReader {
268        FullRamTrackReader {
269            track_number,
270            bytes,
271            start,
272            pos: start,
273            end,
274        }
275    }
276
277    pub fn new_from_vec(track_number: Option<u32>, bytes: Vec<u8>) -> FullRamTrackReader {
278        let len = bytes.len();
279        FullRamTrackReader {
280            track_number,
281            bytes: Arc::new(bytes),
282            pos: 0,
283            start: 0,
284            end: len,
285        }
286    }
287}
288
289impl TrackReader for FullRamTrackReader {
290    #[inline(always)]
291    fn read(&mut self) -> Result<u8, MIDIParseError> {
292        if self.pos == self.end {
293            return Err(MIDIParseError::UnexpectedTrackEnd {
294                track_number: self.track_number,
295                track_start: self.start as u64,
296                expected_track_end: self.end as u64,
297                found_track_end: self.pos as u64,
298            });
299        }
300        let b = self.bytes[self.pos];
301        self.pos += 1;
302        Ok(b)
303    }
304
305    #[inline(always)]
306    fn pos(&self) -> u64 {
307        self.pos as u64
308    }
309
310    fn is_at_end(&self) -> bool {
311        self.pos == self.end
312    }
313
314    fn track_number(&self) -> Option<u32> {
315        self.track_number
316    }
317}
318
319impl DiskTrackReader {
320    fn finished_sending_reads(&self) -> bool {
321        self.unrequested_data_start == self.len
322    }
323
324    fn next_buffer_req_length(&self) -> usize {
325        (self.len - self.unrequested_data_start).min(1 << 19) as usize
326    }
327
328    fn send_next_read(&mut self, buffer: Option<Vec<u8>>) {
329        if self.finished_sending_reads() {
330            self.receiver_sender.take();
331            return;
332        }
333
334        let mut next_len = self.next_buffer_req_length();
335
336        let buffer = match buffer {
337            None => vec![0u8; next_len],
338            Some(b) => b,
339        };
340
341        next_len = next_len.min(buffer.len());
342
343        self.reader.send_read_command(
344            self.receiver_sender.clone().unwrap(),
345            buffer,
346            self.unrequested_data_start + self.start,
347            next_len,
348        );
349
350        self.unrequested_data_start += next_len as u64;
351    }
352
353    fn receive_next_buffer(&mut self) -> Option<Result<Vec<u8>, MIDIParseError>> {
354        match self.receiver.recv() {
355            Ok(v) => match v {
356                Ok(v) => Some(Ok(v)),
357                Err(e) => Some(Err(e.into())),
358            },
359            Err(_) => None,
360        }
361    }
362
363    pub fn new(
364        track_number: Option<u32>,
365        reader: Arc<BufferReadProvider>,
366        start: u64,
367        len: u64,
368    ) -> DiskTrackReader {
369        let buffer_count = 3;
370
371        let (send, receive) = unbounded();
372
373        let mut reader = DiskTrackReader {
374            track_number,
375            reader,
376            start,
377            len,
378            buffer: None,
379            buffer_start: 0,
380            buffer_pos: 0,
381            unrequested_data_start: 0,
382            receiver: DelayedReceiver::new(receive),
383            receiver_sender: Some(send),
384        };
385
386        for _ in 0..buffer_count {
387            reader.send_next_read(None);
388        }
389
390        reader.receiver.wait_first();
391
392        reader
393    }
394}
395
impl TrackReader for DiskTrackReader {
    /// Returns the next byte of the track. When the current chunk is exhausted, it pulls
    /// the next prefetched chunk from the worker.
    ///
    /// # Errors
    /// - The I/O error from a failed chunk read, converted into a `MIDIParseError`.
    ///   NOTE(review): `buffer_start` is not advanced past a failed chunk, so reading on
    ///   after such an error would mislabel positions; callers presumably stop at the
    ///   first error.
    /// - `MIDIParseError::UnexpectedTrackEnd` when no further chunk will arrive, i.e. the
    ///   caller read past `len`.
    fn read(&mut self) -> Result<u8, MIDIParseError> {
        if self.buffer.is_none() {
            // `None` here means `receiver.recv()` failed. That is expected once
            // `send_next_read` has dropped our sender (all data requested) and every
            // in-flight reply has been consumed.
            if let Some(next) = self.receive_next_buffer() {
                self.buffer = Some(next?);
            } else {
                return Err(MIDIParseError::UnexpectedTrackEnd {
                    track_number: self.track_number,
                    track_start: self.start,
                    expected_track_end: self.start + self.len,
                    found_track_end: self.pos(),
                });
            }
        }

        // The branch above either filled `self.buffer` or returned, so this cannot fail.
        let buffer = self.buffer.as_ref().unwrap();
        let byte = buffer[self.buffer_pos];

        self.buffer_pos += 1;
        if self.buffer_pos == buffer.len() {
            // Chunk exhausted: advance the logical offset past it, then hand its allocation
            // back to `send_next_read` so it is reused for the next request.
            let buffer = self.buffer.take().unwrap();
            self.buffer_start += buffer.len() as u64;
            self.buffer_pos = 0;
            self.send_next_read(Some(buffer));
        }

        Ok(byte)
    }

    /// Absolute offset (within the MIDI stream) of the next byte to be read.
    #[inline(always)]
    fn pos(&self) -> u64 {
        self.start + self.buffer_start + self.buffer_pos as u64
    }

    /// `true` once `len` bytes have been consumed from the track.
    fn is_at_end(&self) -> bool {
        self.buffer_start + self.buffer_pos as u64 >= self.len
    }

    fn track_number(&self) -> Option<u32> {
        self.track_number
    }
}
438
#[cfg(test)]
mod tests {
    use super::RAMReader;
    use crate::io::errors::MIDILoadError;
    use std::io::{Read, Seek, SeekFrom};

    /// A fake stream that reports a huge length but must never actually be read:
    /// `RAMReader::new` is expected to give up right after measuring it.
    struct OversizedReader {
        pos: u64,
        len: u64,
    }

    impl Read for OversizedReader {
        fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
            panic!("RAMReader::new should reject oversized files before reading")
        }
    }

    impl Seek for OversizedReader {
        fn seek(&mut self, target: SeekFrom) -> std::io::Result<u64> {
            let new_pos = match target {
                SeekFrom::Start(offset) => offset,
                SeekFrom::End(0) => self.len,
                SeekFrom::Current(0) => self.pos,
                _ => panic!("unexpected seek request in oversized reader test"),
            };
            self.pos = new_pos;
            Ok(new_pos)
        }
    }

    #[test]
    fn ram_reader_returns_file_too_big_error() {
        // One byte over the 2 GiB limit enforced by `RAMReader::new`.
        let oversized = OversizedReader {
            pos: 0,
            len: 2_147_483_649,
        };

        let err = RAMReader::new(oversized).expect_err("oversized RAM MIDI should error");

        assert!(matches!(err, MIDILoadError::FileTooBig));
    }
}