midi-toolkit-rs 0.3.1

A library for ultra-high-performance MIDI operations, designed for black MIDI. The library isn't perfect.
Documentation
use std::{
    fs::File,
    io::{Read, Seek},
    path::Path,
};

use crate::{
    events::Event,
    sequence::{
        channels_into_threadpool,
        event::{
            convert_events_into_batches, flatten_batches_to_events,
            flatten_track_batches_to_events, into_track_events, merge_events_array, Delta,
            EventBatch, Track,
        },
    },
};
use std::fmt::Debug;

use super::{
    errors::{MIDILoadError, MIDIParseError},
    readers::{DiskReader, MIDIReader, RAMReader},
    track_parser::TrackParser,
};

/// Location of one `MTrk` chunk's body inside the underlying reader.
#[derive(Debug)]
struct TrackPos {
    /// Byte offset of the first body byte, i.e. just past the chunk's 8-byte
    /// header (magic + length).
    pos: u64,
    /// Body length in bytes, as declared in the chunk header.
    len: u32,
}

/// A MIDI file whose header has been parsed and whose track chunks have been
/// located, but whose events have not yet been decoded.
///
/// Events are produced on demand by the iterator methods, which open a reader
/// over each track's byte range. `T` is the byte source backing the file
/// (for example [`DiskReader`] or [`RAMReader`]).
#[derive(Debug)]
pub struct MIDIFile<T: MIDIReader> {
    /// Source the header and track bytes are read from.
    reader: T,
    /// Offset and length of each `MTrk` chunk body, in file order.
    track_positions: Vec<TrackPos>,

    /// Format word from the `MThd` header chunk.
    format: u16,
    /// Division word from the `MThd` header chunk (ticks per quarter note).
    ppq: u16,
}

impl<T: 'static + MIDIReader> MIDIFile<T> {
    /// Scans the chunk layout of a Standard MIDI File held by `reader`.
    ///
    /// Parses the `MThd` header chunk (format and division), then walks the
    /// following `MTrk` chunks and records each track's byte range. Track
    /// contents are not decoded here; see [`MIDIFile::iter_track`].
    ///
    /// Despite its name, this is the shared constructor for every
    /// [`MIDIReader`] backend.
    ///
    /// `read_progress`, if given, is called once per located track with the
    /// number of tracks found so far.
    ///
    /// # Errors
    ///
    /// Returns [`MIDILoadError::CorruptChunks`] if:
    /// - a chunk's magic bytes are wrong;
    /// - the header chunk declares fewer than 6 bytes;
    /// - any chunk header or chunk body extends past the end of the data.
    ///
    /// Errors raised by the reader itself are propagated unchanged.
    fn new_from_disk_reader(
        reader: T,
        mut read_progress: Option<&mut dyn FnMut(u32)>,
    ) -> Result<Self, MIDILoadError> {
        // Size of every chunk header: 4 magic bytes + 4-byte big-endian length.
        const CHUNK_HEADER_LEN: u64 = 8;
        // Header fields this parser reads: format (2), ntrks (2), division (2).
        const MTHD_MIN_LEN: u32 = 6;

        // Interprets up to four bytes as a big-endian unsigned integer.
        fn bytes_to_val(bytes: &[u8]) -> u32 {
            assert!(bytes.len() <= 4);
            bytes.iter().fold(0u32, |acc, &b| (acc << 8) | u32::from(b))
        }

        // Reads the chunk header at `pos`, checks that its magic equals `text`,
        // and returns the declared body length.
        fn read_header<T: MIDIReader>(
            reader: &T,
            pos: u64,
            text: &str,
        ) -> Result<u32, MIDILoadError> {
            assert!(text.len() == 4);

            // Reject a truncated header explicitly rather than depending on how
            // a particular reader handles out-of-range reads.
            if reader.len().saturating_sub(pos) < CHUNK_HEADER_LEN {
                return Err(MIDILoadError::CorruptChunks);
            }

            let bytes = reader.read_bytes(pos, 8)?;
            let (header, len) = bytes.split_at(4);

            if header != text.as_bytes() {
                return Err(MIDILoadError::CorruptChunks);
            }

            Ok(bytes_to_val(len))
        }

        let file_len = reader.len();
        let mut pos = 0u64;

        let header_len = read_header(&reader, pos, "MThd")?;
        pos += CHUNK_HEADER_LEN;

        // The SMF spec requires readers to honor the declared header length:
        // it may be longer than 6 bytes, and any extra bytes are skipped.
        if header_len < MTHD_MIN_LEN {
            return Err(MIDILoadError::CorruptChunks);
        }
        let header_end = pos + u64::from(header_len);
        if header_end > file_len {
            return Err(MIDILoadError::CorruptChunks);
        }

        let (format, ppq) = {
            let header_data = reader.read_bytes(pos, 6)?;
            let (format_bytes, rest) = header_data.split_at(2);
            // The ntrks field is skipped; the track count is taken from the
            // `MTrk` chunks actually present.
            let (_, ppq_bytes) = rest.split_at(2);
            (
                bytes_to_val(format_bytes) as u16,
                bytes_to_val(ppq_bytes) as u16,
            )
        };
        pos = header_end;

        let mut track_count: u32 = 0;
        let mut track_positions = Vec::<TrackPos>::new();
        // Every chunk is bounds-checked before `pos` advances past it, so `pos`
        // never exceeds `file_len` and the loop ends exactly at the end of data.
        while pos < file_len {
            let len = read_header(&reader, pos, "MTrk")?;
            pos += CHUNK_HEADER_LEN;

            let track_end = pos + u64::from(len);
            if track_end > file_len {
                return Err(MIDILoadError::CorruptChunks);
            }

            track_count += 1;
            track_positions.push(TrackPos { len, pos });
            pos = track_end;

            if let Some(progress) = read_progress.as_mut() {
                progress(track_count);
            }
        }

        track_positions.shrink_to_fit();
        Ok(MIDIFile {
            reader,
            ppq,
            format,
            track_positions,
        })
    }

    /// Opens a byte reader over the body of track `track` (0-based).
    ///
    /// Returns `None` if `track` is not less than [`MIDIFile::track_count`].
    pub fn open_track_reader(&self, track: u32) -> Option<T::ByteReader> {
        let pos = self.track_positions.get(track as usize)?;
        // NOTE(review): the `Some(track)` argument is presumably a track id the
        // reader uses for bookkeeping — confirm against `MIDIReader::open_reader`.
        Some(
            self.reader
                .open_reader(Some(track), pos.pos, pos.len as u64),
        )
    }

    /// Returns one event iterator per track, in file order.
    ///
    /// Every track reader is opened when this is called, before the returned
    /// iterator is consumed.
    pub fn iter_all_tracks(
        &self,
    ) -> impl Iterator<Item = impl Iterator<Item = Result<Delta<u64, Event>, MIDIParseError>>> {
        // Track indices come from a `u32` counter at load time, so the cast
        // cannot truncate.
        let tracks: Vec<_> = (0..self.track_count())
            .map(|i| {
                self.iter_track(i as u32)
                    .expect("track iteration should exist for a known track index")
            })
            .collect();
        tracks.into_iter()
    }

    /// Iterates over the events of all tracks merged into a single stream.
    pub fn iter_all_events_merged(
        &self,
    ) -> impl Iterator<Item = Result<Delta<u64, Event>, MIDIParseError>> {
        let merged_batches = self.iter_all_events_merged_batches();
        flatten_batches_to_events(merged_batches)
    }

    /// Iterates over the events of all tracks merged into a single stream.
    ///
    /// Each event is tagged with the index of the track it came from.
    pub fn iter_all_track_events_merged(
        &self,
    ) -> impl Iterator<Item = Result<Delta<u64, Track<Event>>, MIDIParseError>> {
        let merged_batches = self.iter_all_track_events_merged_batches();
        flatten_track_batches_to_events(merged_batches)
    }

    /// Merges all tracks into a single stream of event batches.
    ///
    /// Each track's events are grouped into batches, and the tracks are parsed
    /// through a thread pool before merging.
    pub fn iter_all_events_merged_batches(
        &self,
    ) -> impl Iterator<Item = Result<Delta<u64, EventBatch<Event>>, MIDIParseError>> {
        let batched_tracks = self
            .iter_all_tracks()
            .map(convert_events_into_batches)
            .collect();
        // NOTE(review): the meaning of `10` is defined by `channels_into_threadpool`
        // (presumably a channel buffer size) — confirm before tuning.
        let batched_tracks_threaded = channels_into_threadpool(batched_tracks, 10);
        merge_events_array(batched_tracks_threaded)
    }

    /// Same as [`MIDIFile::iter_all_events_merged_batches`], except that each
    /// batch is tagged with the index of its source track.
    pub fn iter_all_track_events_merged_batches(
        &self,
    ) -> impl Iterator<Item = Result<Delta<u64, Track<EventBatch<Event>>>, MIDIParseError>> {
        let batched_tracks = self
            .iter_all_tracks()
            .map(convert_events_into_batches)
            .enumerate()
            .map(|(i, track)| into_track_events(track, i as u32))
            .collect();
        // NOTE(review): same `channels_into_threadpool` parameter as the untagged
        // variant; keep the two in sync.
        let batched_tracks_threaded = channels_into_threadpool(batched_tracks, 10);
        merge_events_array(batched_tracks_threaded)
    }

    /// Returns an iterator that parses the events of track `track` (0-based).
    ///
    /// Returns `None` if the track index is out of range.
    pub fn iter_track(
        &self,
        track: u32,
    ) -> Option<impl Iterator<Item = Result<Delta<u64, Event>, MIDIParseError>>> {
        let reader = self.open_track_reader(track)?;
        Some(TrackParser::new(reader))
    }

    /// Returns the division word from the header (ticks per quarter note).
    ///
    /// The value is returned as stored. Per the SMF spec, a word with its top
    /// bit set encodes SMPTE timing instead; that encoding is not decoded here.
    pub fn ppq(&self) -> u16 {
        self.ppq
    }

    /// Returns the SMF format word from the header, as stored (not validated).
    pub fn format(&self) -> u16 {
        self.format
    }

    /// Returns the number of `MTrk` chunks found in the file.
    ///
    /// The header's ntrks field is not consulted.
    pub fn track_count(&self) -> usize {
        self.track_positions.len()
    }
}

impl MIDIFile<DiskReader> {
    /// Opens the MIDI file at `filename` using a [`DiskReader`] as the byte source.
    ///
    /// `read_progress`, if given, receives the running track count while the
    /// file's chunks are scanned.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened, the reader cannot be created, or
    /// the chunk structure is invalid.
    pub fn open(
        filename: impl AsRef<Path>,
        read_progress: Option<&mut dyn FnMut(u32)>,
    ) -> Result<Self, MIDILoadError> {
        let file = File::open(filename)?;
        Self::open_from_stream(file, read_progress)
    }

    /// Loads a MIDI file from any seekable `stream`, wrapped in a [`DiskReader`].
    ///
    /// `read_progress` behaves as in [`MIDIFile::open`].
    pub fn open_from_stream<T: 'static + Read + Seek + Send>(
        stream: T,
        read_progress: Option<&mut dyn FnMut(u32)>,
    ) -> Result<Self, MIDILoadError> {
        let disk_reader = DiskReader::new(stream)?;
        Self::new_from_disk_reader(disk_reader, read_progress)
    }
}

impl MIDIFile<RAMReader> {
    /// Opens the MIDI file at `filename` using a [`RAMReader`] as the byte source.
    ///
    /// `read_progress`, if given, receives the running track count while the
    /// file's chunks are scanned.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened, the reader cannot be created, or
    /// the chunk structure is invalid.
    pub fn open_in_ram(
        filename: impl AsRef<Path>,
        read_progress: Option<&mut dyn FnMut(u32)>,
    ) -> Result<Self, MIDILoadError> {
        let file = File::open(filename)?;
        Self::open_from_stream_in_ram(file, read_progress)
    }

    /// Loads a MIDI file from any seekable `stream`, wrapped in a [`RAMReader`].
    ///
    /// `read_progress` behaves as in [`MIDIFile::open_in_ram`].
    pub fn open_from_stream_in_ram<T: 'static + Read + Seek + Send>(
        stream: T,
        read_progress: Option<&mut dyn FnMut(u32)>,
    ) -> Result<Self, MIDILoadError> {
        let ram_reader = RAMReader::new(stream)?;
        Self::new_from_disk_reader(ram_reader, read_progress)
    }
}

#[cfg(test)]
mod tests {
    use super::MIDIFile;
    use crate::io::RAMReader;
    use std::io::Cursor;

    /// Minimal valid file: a format-1 `MThd` header (1 track, division 480)
    /// followed by one `MTrk` chunk containing only an end-of-track meta event.
    fn single_track_midi_bytes() -> Vec<u8> {
        let mut bytes = Vec::new();
        // MThd: length 6, format 1, ntrks 1, division 0x01E0 (480).
        bytes.extend_from_slice(b"MThd");
        bytes.extend_from_slice(&[0x00, 0x00, 0x00, 0x06]);
        bytes.extend_from_slice(&[0x00, 0x01, 0x00, 0x01, 0x01, 0xE0]);
        // MTrk: length 4, delta-time 0 + end-of-track (FF 2F 00).
        bytes.extend_from_slice(b"MTrk");
        bytes.extend_from_slice(&[0x00, 0x00, 0x00, 0x04]);
        bytes.extend_from_slice(&[0x00, 0xFF, 0x2F, 0x00]);
        bytes
    }

    /// Loads the fixture through the in-RAM stream constructor.
    fn open_fixture() -> MIDIFile<RAMReader> {
        MIDIFile::open_from_stream_in_ram(Cursor::new(single_track_midi_bytes()), None)
            .expect("valid midi fixture")
    }

    #[test]
    fn header_fields_are_parsed() {
        let midi = open_fixture();

        assert_eq!(midi.format(), 1);
        assert_eq!(midi.ppq(), 480);
        assert_eq!(midi.track_count(), 1);
    }

    #[test]
    fn valid_track_access_returns_some() {
        let midi = open_fixture();

        assert!(midi.open_track_reader(0).is_some());
        assert!(midi.iter_track(0).is_some());
    }

    #[test]
    fn invalid_track_access_returns_none() {
        let midi = open_fixture();

        assert!(midi.open_track_reader(1).is_none());
        assert!(midi.iter_track(1).is_none());
    }

    #[test]
    fn progress_reports_running_track_count() {
        let mut reported = Vec::new();
        let mut record = |count: u32| reported.push(count);
        let progress: &mut dyn FnMut(u32) = &mut record;

        let midi: MIDIFile<RAMReader> = MIDIFile::open_from_stream_in_ram(
            Cursor::new(single_track_midi_bytes()),
            Some(progress),
        )
        .expect("valid midi fixture");

        assert_eq!(midi.track_count(), 1);
        assert_eq!(reported, vec![1]);
    }
}