1use crossbeam_channel::{bounded, unbounded, Sender};
2use std::{
3 io::{self, Read, Seek, SeekFrom},
4 sync::Arc,
5 thread::{self, JoinHandle},
6};
7
8use crate::DelayedReceiver;
9
10use super::errors::{MIDILoadError, MIDIParseError};
11
12use std::fmt::Debug;
13#[derive(Debug)]
14pub struct DiskReader {
15 reader: Arc<BufferReadProvider>,
16 length: u64,
17}
18
/// A [`MIDIReader`] holding the complete source contents in memory.
#[derive(Debug)]
pub struct RAMReader {
    /// Whole file contents, shared with every track reader opened from it.
    bytes: Arc<Vec<u8>>,
    /// Cursor used by `read_byte`.
    pos: usize,
}
24
25pub struct ReadCommand {
26 destination: Sender<Result<Vec<u8>, io::Error>>,
27 buffer: Vec<u8>,
28 start: u64,
29 length: usize,
30}
31
32#[derive(Debug)]
33pub struct BufferReadProvider {
34 _thread: JoinHandle<()>,
35 send: Sender<ReadCommand>,
36}
37
38impl BufferReadProvider {
39 pub fn new<T: 'static + Read + Seek + Send>(mut reader: T) -> BufferReadProvider {
40 let (snd, rcv) = unbounded::<ReadCommand>();
41
42 let handle = thread::spawn(move || {
43 let mut read = move |mut buffer: Vec<u8>,
44 start: u64,
45 length: usize|
46 -> Result<Vec<u8>, io::Error> {
47 reader.seek(SeekFrom::Start(start))?;
48 if length < buffer.len() {
49 buffer.truncate(length)
50 }
51 reader.read_exact(&mut buffer)?;
52 Ok(buffer)
53 };
54
55 loop {
56 match rcv.recv() {
57 Err(_) => return,
58 Ok(cmd) => match read(cmd.buffer, cmd.start, cmd.length) {
59 Ok(buf) => {
60 cmd.destination.send(Ok(buf)).ok();
61 }
62 Err(e) => {
63 cmd.destination.send(Err(e)).ok();
64 }
65 },
66 }
67 }
68 });
69
70 BufferReadProvider {
71 send: snd,
72 _thread: handle,
73 }
74 }
75
76 pub fn send_read_command(
77 &self,
78 destination: Sender<Result<Vec<u8>, io::Error>>,
79 buffer: Vec<u8>,
80 start: u64,
81 length: usize,
82 ) {
83 let cmd = ReadCommand {
84 destination,
85 buffer,
86 start,
87 length,
88 };
89
90 self.send.send(cmd).unwrap();
91 }
92
93 pub fn read_sync(&self, buf: Vec<u8>, start: u64) -> Result<Vec<u8>, io::Error> {
94 let (send, receive) = bounded::<Result<Vec<u8>, io::Error>>(1);
95
96 let len = buf.len();
97 self.send_read_command(send, buf, start, len);
98
99 receive.recv().unwrap()
100 }
101}
102
103fn get_reader_len<T: Seek>(reader: &mut T) -> Result<u64, MIDILoadError> {
104 let pos = reader.seek(SeekFrom::End(0))?;
105 reader.seek(SeekFrom::Start(0))?;
106 Ok(pos)
107}
108
109impl DiskReader {
110 pub fn new<T: 'static + Read + Seek + Send>(
111 mut reader: T,
112 ) -> Result<DiskReader, MIDILoadError> {
113 let len = get_reader_len(&mut reader);
114 let reader = BufferReadProvider::new(reader);
115
116 match len {
117 Err(e) => Err(e),
118 Ok(length) => Ok(DiskReader {
119 reader: Arc::new(reader),
120 length,
121 }),
122 }
123 }
124}
125
126impl RAMReader {
127 pub fn new<T: Read + Seek>(mut reader: T) -> Result<RAMReader, MIDILoadError> {
128 let len = get_reader_len(&mut reader);
129
130 match len {
131 Err(e) => Err(e),
132 Ok(length) => {
133 let max_supported: u64 = 2147483648;
134 if length > max_supported {
135 return Err(MIDILoadError::FileTooBig);
136 }
137
138 let mut bytes = vec![0; length as usize];
139 reader.read_exact(&mut bytes)?;
140 Ok(RAMReader {
141 bytes: Arc::new(bytes),
142 pos: 0,
143 })
144 }
145 }
146 }
147
148 pub fn read_byte(&mut self) -> Result<u8, MIDILoadError> {
149 let b = self.bytes.get(self.pos);
150 self.pos += 1;
151 match b {
152 Some(v) => Ok(*v),
153 None => Err(MIDILoadError::CorruptChunks),
154 }
155 }
156}
157
158pub trait MIDIReader: Debug {
159 type ByteReader: TrackReader;
160
161 fn read_bytes_to(&self, pos: u64, bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError>;
162 fn read_bytes(&self, pos: u64, count: usize) -> Result<Vec<u8>, MIDILoadError> {
163 let bytes = vec![0u8; count];
164
165 self.read_bytes_to(pos, bytes)
166 }
167
168 fn len(&self) -> u64;
169 fn is_empty(&self) -> bool {
170 self.len() == 0
171 }
172
173 fn open_reader(&self, track_number: Option<u32>, start: u64, len: u64) -> Self::ByteReader;
174}
175
176impl MIDIReader for DiskReader {
177 type ByteReader = DiskTrackReader;
178
179 fn open_reader(&self, track_number: Option<u32>, start: u64, len: u64) -> DiskTrackReader {
180 DiskTrackReader::new(track_number, self.reader.clone(), start, len)
181 }
182
183 fn read_bytes_to(&self, pos: u64, bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError> {
184 Ok(self.reader.read_sync(bytes, pos)?)
185 }
186
187 fn len(&self) -> u64 {
188 self.length
189 }
190}
191
192impl MIDIReader for RAMReader {
193 type ByteReader = FullRamTrackReader;
194
195 fn open_reader<'a>(
196 &self,
197 track_number: Option<u32>,
198 start: u64,
199 len: u64,
200 ) -> FullRamTrackReader {
201 FullRamTrackReader {
202 track_number,
203 start: start as usize,
204 pos: start as usize,
205 end: (start + len) as usize,
206 bytes: self.bytes.clone(),
207 }
208 }
209
210 fn read_bytes_to(&self, pos: u64, mut bytes: Vec<u8>) -> Result<Vec<u8>, MIDILoadError> {
211 let count = bytes.len();
212 if pos + count as u64 > self.len() {
213 return Err(MIDILoadError::CorruptChunks);
214 }
215
216 bytes[..].clone_from_slice(&self.bytes[pos as usize..pos as usize + count]);
217
218 Ok(bytes)
219 }
220
221 fn len(&self) -> u64 {
222 self.bytes.len() as u64
223 }
224}
225
226pub trait TrackReader: Send + Sync {
227 fn track_number(&self) -> Option<u32>;
229
230 fn read(&mut self) -> Result<u8, MIDIParseError>;
231 fn pos(&self) -> u64;
232 fn is_at_end(&self) -> bool;
233}
234
235pub struct DiskTrackReader {
236 track_number: Option<u32>,
238
239 reader: Arc<BufferReadProvider>,
240 start: u64, len: u64, buffer: Option<Vec<u8>>, buffer_start: u64, buffer_pos: usize, unrequested_data_start: u64, receiver: DelayedReceiver<Result<Vec<u8>, io::Error>>,
248 receiver_sender: Option<Sender<Result<Vec<u8>, io::Error>>>, }
250
/// Track reader over a byte range of an in-memory buffer.
pub struct FullRamTrackReader {
    /// Track index reported by `track_number()` and in errors.
    track_number: Option<u32>,
    /// Index of the first byte of the track in `bytes`.
    start: usize,

    /// Backing data, possibly shared with other track readers.
    bytes: Arc<Vec<u8>>,
    /// Index of the next byte to return.
    pos: usize,
    /// One past the index of the last byte of the track.
    end: usize,
}
260
261impl FullRamTrackReader {
262 pub fn new(
263 track_number: Option<u32>,
264 bytes: Arc<Vec<u8>>,
265 start: usize,
266 end: usize,
267 ) -> FullRamTrackReader {
268 FullRamTrackReader {
269 track_number,
270 bytes,
271 start,
272 pos: start,
273 end,
274 }
275 }
276
277 pub fn new_from_vec(track_number: Option<u32>, bytes: Vec<u8>) -> FullRamTrackReader {
278 let len = bytes.len();
279 FullRamTrackReader {
280 track_number,
281 bytes: Arc::new(bytes),
282 pos: 0,
283 start: 0,
284 end: len,
285 }
286 }
287}
288
289impl TrackReader for FullRamTrackReader {
290 #[inline(always)]
291 fn read(&mut self) -> Result<u8, MIDIParseError> {
292 if self.pos == self.end {
293 return Err(MIDIParseError::UnexpectedTrackEnd {
294 track_number: self.track_number,
295 track_start: self.start as u64,
296 expected_track_end: self.end as u64,
297 found_track_end: self.pos as u64,
298 });
299 }
300 let b = self.bytes[self.pos];
301 self.pos += 1;
302 Ok(b)
303 }
304
305 #[inline(always)]
306 fn pos(&self) -> u64 {
307 self.pos as u64
308 }
309
310 fn is_at_end(&self) -> bool {
311 self.pos == self.end
312 }
313
314 fn track_number(&self) -> Option<u32> {
315 self.track_number
316 }
317}
318
319impl DiskTrackReader {
320 fn finished_sending_reads(&self) -> bool {
321 self.unrequested_data_start == self.len
322 }
323
324 fn next_buffer_req_length(&self) -> usize {
325 (self.len - self.unrequested_data_start).min(1 << 19) as usize
326 }
327
328 fn send_next_read(&mut self, buffer: Option<Vec<u8>>) {
329 if self.finished_sending_reads() {
330 self.receiver_sender.take();
331 return;
332 }
333
334 let mut next_len = self.next_buffer_req_length();
335
336 let buffer = match buffer {
337 None => vec![0u8; next_len],
338 Some(b) => b,
339 };
340
341 next_len = next_len.min(buffer.len());
342
343 self.reader.send_read_command(
344 self.receiver_sender.clone().unwrap(),
345 buffer,
346 self.unrequested_data_start + self.start,
347 next_len,
348 );
349
350 self.unrequested_data_start += next_len as u64;
351 }
352
353 fn receive_next_buffer(&mut self) -> Option<Result<Vec<u8>, MIDIParseError>> {
354 match self.receiver.recv() {
355 Ok(v) => match v {
356 Ok(v) => Some(Ok(v)),
357 Err(e) => Some(Err(e.into())),
358 },
359 Err(_) => None,
360 }
361 }
362
363 pub fn new(
364 track_number: Option<u32>,
365 reader: Arc<BufferReadProvider>,
366 start: u64,
367 len: u64,
368 ) -> DiskTrackReader {
369 let buffer_count = 3;
370
371 let (send, receive) = unbounded();
372
373 let mut reader = DiskTrackReader {
374 track_number,
375 reader,
376 start,
377 len,
378 buffer: None,
379 buffer_start: 0,
380 buffer_pos: 0,
381 unrequested_data_start: 0,
382 receiver: DelayedReceiver::new(receive),
383 receiver_sender: Some(send),
384 };
385
386 for _ in 0..buffer_count {
387 reader.send_next_read(None);
388 }
389
390 reader.receiver.wait_first();
391
392 reader
393 }
394}
395
396impl TrackReader for DiskTrackReader {
397 fn read(&mut self) -> Result<u8, MIDIParseError> {
398 if self.buffer.is_none() {
399 if let Some(next) = self.receive_next_buffer() {
400 self.buffer = Some(next?);
401 } else {
402 return Err(MIDIParseError::UnexpectedTrackEnd {
403 track_number: self.track_number,
404 track_start: self.start,
405 expected_track_end: self.start + self.len,
406 found_track_end: self.pos(),
407 });
408 }
409 }
410
411 let buffer = self.buffer.as_ref().unwrap();
412 let byte = buffer[self.buffer_pos];
413
414 self.buffer_pos += 1;
415 if self.buffer_pos == buffer.len() {
416 let buffer = self.buffer.take().unwrap();
417 self.buffer_start += buffer.len() as u64;
418 self.buffer_pos = 0;
419 self.send_next_read(Some(buffer));
420 }
421
422 Ok(byte)
423 }
424
425 #[inline(always)]
426 fn pos(&self) -> u64 {
427 self.start + self.buffer_start + self.buffer_pos as u64
428 }
429
430 fn is_at_end(&self) -> bool {
431 self.buffer_start + self.buffer_pos as u64 >= self.len
432 }
433
434 fn track_number(&self) -> Option<u32> {
435 self.track_number
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::RAMReader;
    use crate::io::errors::MIDILoadError;
    use std::io::{Read, Seek, SeekFrom};

    /// Fake source that reports a length just over the in-memory limit and
    /// panics if anything tries to read from it.
    struct HugeSource {
        pos: u64,
        len: u64,
    }

    impl Read for HugeSource {
        fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
            panic!("RAMReader::new should reject oversized files before reading")
        }
    }

    impl Seek for HugeSource {
        fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
            let target = match pos {
                SeekFrom::Start(offset) => offset,
                SeekFrom::End(0) => self.len,
                SeekFrom::Current(0) => self.pos,
                _ => panic!("unexpected seek request in oversized reader test"),
            };
            self.pos = target;
            Ok(target)
        }
    }

    #[test]
    fn ram_reader_returns_file_too_big_error() {
        let source = HugeSource {
            pos: 0,
            len: 2_147_483_649,
        };

        let err = RAMReader::new(source).expect_err("oversized RAM MIDI should error");

        assert!(matches!(err, MIDILoadError::FileTooBig));
    }
}
478}