midi_toolkit/io/
readers.rs

1use crossbeam_channel::{bounded, unbounded, Sender};
2use std::{
3    io::{self, Read, Seek, SeekFrom},
4    sync::Arc,
5    thread::{self, JoinHandle},
6};
7
8use crate::DelayedReceiver;
9
10use super::errors::{MIDILoadError, MIDIParseError};
11
12use std::fmt::Debug;
/// A [`MIDIReader`] that reads from a seekable source on demand.
///
/// Every read is performed by a shared background [`BufferReadProvider`];
/// track readers opened from this reader share the same provider.
#[derive(Debug)]
pub struct DiskReader {
    /// Worker that owns the underlying reader; cloned into every track reader.
    reader: Arc<BufferReadProvider>,
    /// Total length of the source in bytes, measured once at construction.
    length: u64,
}
18
/// A [`MIDIReader`] backed by the entire source loaded into memory.
#[derive(Debug)]
pub struct RAMReader {
    /// The complete source contents, shared (not copied) with every track
    /// reader opened from this reader.
    bytes: Arc<Vec<u8>>,
    /// Cursor used by `read_byte`.
    pos: usize,
}
24
25pub struct ReadCommand {
26    destination: Sender<Result<Vec<u8>, io::Error>>,
27    buffer: Vec<u8>,
28    start: u64,
29    length: usize,
30}
31
/// Owns a reader on a dedicated thread and serves the [`ReadCommand`]s sent to it.
#[derive(Debug)]
pub struct BufferReadProvider {
    /// Handle of the worker thread. It is not joined anywhere in this module;
    /// the worker's loop ends once the command channel (`send`) is dropped.
    _thread: JoinHandle<()>,
    /// Command channel to the worker thread.
    send: Sender<ReadCommand>,
}
37
38impl BufferReadProvider {
39    pub fn new<T: 'static + Read + Seek + Send>(mut reader: T) -> BufferReadProvider {
40        let (snd, rcv) = unbounded::<ReadCommand>();
41
42        let handle = thread::spawn(move || {
43            let mut read = move |mut buffer: Vec<u8>,
44                                 start: u64,
45                                 length: usize|
46                  -> Result<Vec<u8>, io::Error> {
47                reader.seek(SeekFrom::Start(start))?;
48                if length < buffer.len() {
49                    buffer.truncate(length)
50                }
51                reader.read_exact(&mut buffer)?;
52                Ok(buffer)
53            };
54
55            loop {
56                match rcv.recv() {
57                    Err(_) => return,
58                    Ok(cmd) => match read(cmd.buffer, cmd.start, cmd.length) {
59                        Ok(buf) => {
60                            cmd.destination.send(Ok(buf)).ok();
61                        }
62                        Err(e) => {
63                            cmd.destination.send(Err(e)).ok();
64                        }
65                    },
66                }
67            }
68        });
69
70        BufferReadProvider {
71            send: snd,
72            _thread: handle,
73        }
74    }
75
76    pub fn send_read_command(
77        &self,
78        destination: Sender<Result<Vec<u8>, io::Error>>,
79        buffer: Vec<u8>,
80        start: u64,
81        length: usize,
82    ) {
83        let cmd = ReadCommand {
84            destination,
85            buffer,
86            start,
87            length,
88        };
89
90        self.send.send(cmd).unwrap();
91    }
92
93    pub fn read_sync(&self, buf: Vec<u8>, start: u64) -> Result<Vec<u8>, io::Error> {
94        let (send, receive) = bounded::<Result<Vec<u8>, io::Error>>(1);
95
96        let len = buf.len();
97        self.send_read_command(send, buf, start, len);
98
99        receive.recv().unwrap()
100    }
101}
102
103fn get_reader_len<T: Seek>(reader: &mut T) -> Result<u64, MIDILoadError> {
104    let pos = reader.seek(SeekFrom::End(0))?;
105    reader.seek(SeekFrom::Start(0))?;
106    Ok(pos)
107}
108
109impl DiskReader {
110    pub fn new<T: 'static + Read + Seek + Send>(
111        mut reader: T,
112    ) -> Result<DiskReader, MIDILoadError> {
113        let len = get_reader_len(&mut reader);
114        let reader = BufferReadProvider::new(reader);
115
116        match len {
117            Err(e) => Err(e),
118            Ok(length) => Ok(DiskReader {
119                reader: Arc::new(reader),
120                length,
121            }),
122        }
123    }
124}
125
126impl RAMReader {
127    pub fn new<T: Read + Seek>(mut reader: T) -> Result<RAMReader, MIDILoadError> {
128        let len = get_reader_len(&mut reader);
129
130        match len {
131            Err(e) => Err(e),
132            Ok(length) => {
133                let max_supported: u64 = 2147483648;
134                if length > max_supported {
135                    panic!(
136                        "The maximum length allowed for a memory loaded MIDI file is {}",
137                        max_supported
138                    );
139                }
140
141                let mut bytes = vec![0; length as usize];
142                reader.read_exact(&mut bytes)?;
143                Ok(RAMReader {
144                    bytes: Arc::new(bytes),
145                    pos: 0,
146                })
147            }
148        }
149    }
150
151    pub fn read_byte(&mut self) -> Result<u8, MIDILoadError> {
152        let b = self.bytes.get(self.pos);
153        self.pos += 1;
154        match b {
155            Some(v) => Ok(*v),
156            None => Err(MIDILoadError::CorruptChunks),
157        }
158    }
159}
160
161pub trait MIDIReader: Debug {
162    type ByteReader: TrackReader;
163
164    fn read_bytes_to(&self, pos: u64, bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError>;
165    fn read_bytes(&self, pos: u64, count: usize) -> Result<Vec<u8>, MIDILoadError> {
166        let bytes = vec![0u8; count];
167
168        self.read_bytes_to(pos, bytes)
169    }
170
171    fn len(&self) -> u64;
172    fn is_empty(&self) -> bool {
173        self.len() == 0
174    }
175
176    fn open_reader(&self, track_number: Option<u32>, start: u64, len: u64) -> Self::ByteReader;
177}
178
179impl MIDIReader for DiskReader {
180    type ByteReader = DiskTrackReader;
181
182    fn open_reader(&self, track_number: Option<u32>, start: u64, len: u64) -> DiskTrackReader {
183        DiskTrackReader::new(track_number, self.reader.clone(), start, len)
184    }
185
186    fn read_bytes_to(&self, pos: u64, bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError> {
187        Ok(self.reader.read_sync(bytes, pos)?)
188    }
189
190    fn len(&self) -> u64 {
191        self.length
192    }
193}
194
195impl MIDIReader for RAMReader {
196    type ByteReader = FullRamTrackReader;
197
198    fn open_reader<'a>(
199        &self,
200        track_number: Option<u32>,
201        start: u64,
202        len: u64,
203    ) -> FullRamTrackReader {
204        FullRamTrackReader {
205            track_number,
206            start: start as usize,
207            pos: start as usize,
208            end: (start + len) as usize,
209            bytes: self.bytes.clone(),
210        }
211    }
212
213    fn read_bytes_to(&self, pos: u64, mut bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError> {
214        let count = bytes.len();
215        if pos + count as u64 > self.len() {
216            return Err(MIDILoadError::CorruptChunks);
217        }
218
219        bytes[..].clone_from_slice(&self.bytes[pos as usize..pos as usize + count]);
220
221        Ok(bytes)
222    }
223
224    fn len(&self) -> u64 {
225        self.bytes.len() as u64
226    }
227}
228
/// Byte-at-a-time access to the data of a single track.
pub trait TrackReader: Send + Sync {
    /// The stored track number, for diagnostic purposes.
    fn track_number(&self) -> Option<u32>;

    /// Returns the next byte of the track. Returns an error if the track ends
    /// before its declared length or its data cannot be read.
    fn read(&mut self) -> Result<u8, MIDIParseError>;
    /// Offset of the next byte to be read, in the same coordinates as the
    /// track's start offset. For readers opened through a `MIDIReader`, this
    /// is an absolute offset into the source.
    fn pos(&self) -> u64;
    /// Returns `true` once every byte of the track has been consumed.
    fn is_at_end(&self) -> bool;
}
237
/// A [`TrackReader`] that streams one track from a [`BufferReadProvider`],
/// keeping several chunk reads queued ahead of the consumer.
pub struct DiskTrackReader {
    /// The track number used only for error logging purposes
    track_number: Option<u32>,

    /// Worker thread that performs the actual seeks and reads.
    reader: Arc<BufferReadProvider>,
    /// Absolute offset of the track's first byte within the MIDI source.
    start: u64,
    /// Length of the track in bytes.
    len: u64,
    /// Chunk currently being consumed; `None` until the next chunk is received.
    buffer: Option<Vec<u8>>,
    /// Offset of `buffer`'s first byte, relative to `start`.
    buffer_start: u64,
    /// Index of the next byte to return within `buffer`.
    buffer_pos: usize,
    /// First byte (relative to `start`) not yet requested from the worker.
    unrequested_data_start: u64,

    /// Receives completed chunk reads.
    receiver: DelayedReceiver<Result<Vec<u8>, io::Error>>,
    /// Sender cloned into each read request. Becomes `None` when there's nothing
    /// left to request; in-flight requests still hold their own clones.
    receiver_sender: Option<Sender<Result<Vec<u8>, io::Error>>>,
}
253
/// A [`TrackReader`] over the index range `start..end` of an in-memory byte buffer.
pub struct FullRamTrackReader {
    /// The track number and start are only for error logging purposes
    track_number: Option<u32>,
    /// Index of the track's first byte; only reported in errors.
    start: usize,

    /// Shared backing bytes (the whole file for `RAMReader`, or an owned vector).
    bytes: Arc<Vec<u8>>,
    /// Index of the next byte to return.
    pos: usize,
    /// Exclusive end index of the track. Not validated against `bytes.len()`.
    end: usize,
}
263
264impl FullRamTrackReader {
265    pub fn new(
266        track_number: Option<u32>,
267        bytes: Arc<Vec<u8>>,
268        start: usize,
269        end: usize,
270    ) -> FullRamTrackReader {
271        FullRamTrackReader {
272            track_number,
273            bytes,
274            start,
275            pos: start,
276            end,
277        }
278    }
279
280    pub fn new_from_vec(track_number: Option<u32>, bytes: Vec<u8>) -> FullRamTrackReader {
281        let len = bytes.len();
282        FullRamTrackReader {
283            track_number,
284            bytes: Arc::new(bytes),
285            pos: 0,
286            start: 0,
287            end: len,
288        }
289    }
290}
291
292impl TrackReader for FullRamTrackReader {
293    #[inline(always)]
294    fn read(&mut self) -> Result<u8, MIDIParseError> {
295        if self.pos == self.end {
296            return Err(MIDIParseError::UnexpectedTrackEnd {
297                track_number: self.track_number,
298                track_start: self.start as u64,
299                expected_track_end: self.end as u64,
300                found_track_end: self.pos as u64,
301            });
302        }
303        let b = self.bytes[self.pos];
304        self.pos += 1;
305        Ok(b)
306    }
307
308    #[inline(always)]
309    fn pos(&self) -> u64 {
310        self.pos as u64
311    }
312
313    fn is_at_end(&self) -> bool {
314        self.pos == self.end
315    }
316
317    fn track_number(&self) -> Option<u32> {
318        self.track_number
319    }
320}
321
impl DiskTrackReader {
    /// Returns `true` once read requests covering the whole track have been sent.
    fn finished_sending_reads(&self) -> bool {
        self.unrequested_data_start == self.len
    }

    /// Size of the next read request: the remaining unrequested bytes,
    /// capped at `1 << 19` bytes (512 KiB).
    fn next_buffer_req_length(&self) -> usize {
        (self.len - self.unrequested_data_start).min(1 << 19) as usize
    }

    /// Queues the next chunk read on the background worker.
    ///
    /// `buffer` is an optional recycled allocation; with `None`, a new zeroed
    /// buffer of the requested size is allocated. The request is clamped to the
    /// buffer's size, so a smaller recycled buffer simply requests fewer bytes.
    ///
    /// Once every byte has been requested, the stored sender is dropped. Clones
    /// held by in-flight commands keep the channel open until each is answered.
    fn send_next_read(&mut self, buffer: Option<Vec<u8>>) {
        if self.finished_sending_reads() {
            self.receiver_sender.take();
            return;
        }

        let mut next_len = self.next_buffer_req_length();

        let buffer = match buffer {
            None => vec![0u8; next_len],
            Some(b) => b,
        };

        next_len = next_len.min(buffer.len());

        // The sender is only taken once all reads are sent (checked above), so
        // `unwrap` cannot fail here. The worker seeks to an absolute offset,
        // hence `+ self.start`.
        self.reader.send_read_command(
            self.receiver_sender.clone().unwrap(),
            buffer,
            self.unrequested_data_start + self.start,
            next_len,
        );

        self.unrequested_data_start += next_len as u64;
    }

    /// Takes the next completed chunk from the worker.
    ///
    /// I/O errors from the worker are converted into [`MIDIParseError`].
    /// Returns `None` when `self.receiver.recv()` fails. That is presumably only
    /// once the channel is disconnected and drained; TODO confirm against
    /// `DelayedReceiver`.
    fn receive_next_buffer(&mut self) -> Option<Result<Vec<u8>, MIDIParseError>> {
        match self.receiver.recv() {
            Ok(v) => match v {
                Ok(v) => Some(Ok(v)),
                Err(e) => Some(Err(e.into())),
            },
            Err(_) => None,
        }
    }

    /// Creates a reader for `len` bytes starting at absolute offset `start`.
    ///
    /// Up to `buffer_count` chunk requests are queued immediately. `read` hands
    /// each exhausted chunk back for reuse, so roughly that many buffers circulate.
    pub fn new(
        track_number: Option<u32>,
        reader: Arc<BufferReadProvider>,
        start: u64,
        len: u64,
    ) -> DiskTrackReader {
        // Number of chunk requests queued ahead of the consumer.
        let buffer_count = 3;

        let (send, receive) = unbounded();

        let mut reader = DiskTrackReader {
            track_number,
            reader,
            start,
            len,
            buffer: None,
            buffer_start: 0,
            buffer_pos: 0,
            unrequested_data_start: 0,
            receiver: DelayedReceiver::new(receive),
            receiver_sender: Some(send),
        };

        for _ in 0..buffer_count {
            reader.send_next_read(None);
        }

        // NOTE(review): presumably blocks until the first queued read has been
        // answered; confirm against `DelayedReceiver::wait_first`.
        reader.receiver.wait_first();

        reader
    }
}
398
399impl TrackReader for DiskTrackReader {
400    fn read(&mut self) -> Result<u8, MIDIParseError> {
401        match self.buffer {
402            None => {
403                if let Some(next) = self.receive_next_buffer() {
404                    self.buffer = Some(next?);
405                } else {
406                    return Err(MIDIParseError::UnexpectedTrackEnd {
407                        track_number: self.track_number,
408                        track_start: self.start,
409                        expected_track_end: self.start + self.len,
410                        found_track_end: self.pos(),
411                    });
412                }
413            }
414            Some(_) => {}
415        }
416
417        let buffer = self.buffer.as_ref().unwrap();
418        let byte = buffer[self.buffer_pos];
419
420        self.buffer_pos += 1;
421        if self.buffer_pos == buffer.len() {
422            let buffer = self.buffer.take().unwrap();
423            self.buffer_start += buffer.len() as u64;
424            self.buffer_pos = 0;
425            self.send_next_read(Some(buffer));
426        }
427
428        Ok(byte)
429    }
430
431    #[inline(always)]
432    fn pos(&self) -> u64 {
433        self.start + self.buffer_start + self.buffer_pos as u64
434    }
435
436    fn is_at_end(&self) -> bool {
437        self.buffer_start + self.buffer_pos as u64 >= self.len
438    }
439
440    fn track_number(&self) -> Option<u32> {
441        self.track_number
442    }
443}