midi-toolkit-rs 0.3.1

A library for ultra-high-performance MIDI operations, designed for black MIDI. The library isn't perfect.
Documentation
use std::{
    collections::{BTreeMap, BTreeSet, HashSet},
    io::{self, Write},
};

use crate::events::SerializeEventWithDelta;

use super::MIDIWriteError;

/// Encodes `val` as two bytes in big-endian order (most significant byte first).
pub fn encode_u16(val: u16) -> [u8; 2] {
    val.to_be_bytes()
}

/// Encodes `val` as four bytes in big-endian order (most significant byte first).
pub fn encode_u32(val: u32) -> [u8; 4] {
    val.to_be_bytes()
}

/// Writes the 14-byte `MThd` header chunk to `output`.
///
/// The chunk consists of the ASCII tag `MThd`, the 32-bit value 6 (the size
/// of the three fields that follow) and then `format`, `track_count` and
/// `ppq`, each encoded with `encode_u16`.
///
/// # Errors
///
/// Returns the first I/O error reported by `output`; anything written before
/// the failure stays written.
pub fn write_midi_header(
    output: &mut dyn Write,
    format: u16,
    track_count: u16,
    ppq: u16,
) -> io::Result<()> {
    output.write_all(b"MThd")?;
    output.write_all(&encode_u32(6))?;

    // The three header fields share one encoding, so emit them in order.
    for &field in &[format, track_count, ppq] {
        output.write_all(&encode_u16(field))?;
    }

    Ok(())
}

/// Writes an 8-byte `MTrk` chunk header to `output`: the ASCII tag `MTrk`
/// followed by `length` encoded with `encode_u32`.
///
/// # Errors
///
/// Returns the first I/O error reported by `output`.
pub fn write_track_header(output: &mut dyn Write, length: u32) -> io::Result<()> {
    let size_field = encode_u32(length);
    output.write_all(b"MTrk")?;
    output.write_all(&size_field)
}

/// Returns the ids in `tracks` as a vector in ascending order.
fn sorted_track_ids(tracks: &HashSet<i32>) -> Vec<i32> {
    let mut ordered = Vec::with_capacity(tracks.len());
    ordered.extend(tracks.iter().copied());
    ordered.sort_unstable();
    ordered
}

/// Bookkeeping for track ids whose payloads may be finished out of order.
///
/// An id is first *opened* (`open_track` / `open_next_track`). `finish_track`
/// then marks it *completed* and parks its payload in an id-ordered pending
/// map until one of the drain methods hands it back.
#[derive(Debug)]
pub struct OrderedTrackRegistry<T> {
    /// Ids that were opened and have not been finished yet.
    opened_tracks: HashSet<i32>,
    /// Every id that was ever finished; entries are never removed.
    completed_tracks: BTreeSet<i32>,
    /// Payloads of finished tracks that have not been drained yet.
    pending_tracks: BTreeMap<i32, T>,
    /// Id that the next `open_next_track` call will try to open.
    next_init_track: i32,
    /// Id that `drain_ready_tracks` is waiting for next.
    next_flush_track: i32,
}

impl<T> OrderedTrackRegistry<T> {
    /// Creates an empty registry whose sequential counters both start at 0.
    pub fn new() -> Self {
        Self {
            opened_tracks: HashSet::new(),
            completed_tracks: BTreeSet::new(),
            pending_tracks: BTreeMap::new(),
            next_init_track: 0,
            next_flush_track: 0,
        }
    }

    /// Opens the track at the internal sequential counter and returns its id.
    ///
    /// # Errors
    ///
    /// Propagates the error from `open_track`. The counter only advances on
    /// success, so if that id was already taken through an explicit
    /// `open_track` call, every later call keeps failing on the same id.
    pub fn open_next_track(&mut self) -> Result<i32, MIDIWriteError> {
        let track_id = self.next_init_track;
        self.open_track(track_id)?;
        self.next_init_track += 1;
        Ok(track_id)
    }

    /// Marks `track_id` as opened.
    ///
    /// # Errors
    ///
    /// * `InvalidTrackId` if `track_id` is negative.
    /// * `TrackAlreadyOpened` if the id is currently open or was already
    ///   finished.
    pub fn open_track(&mut self, track_id: i32) -> Result<(), MIDIWriteError> {
        if track_id < 0 {
            return Err(MIDIWriteError::InvalidTrackId { track_id });
        }

        if self.completed_tracks.contains(&track_id) || self.opened_tracks.contains(&track_id) {
            return Err(MIDIWriteError::TrackAlreadyOpened { track_id });
        }

        self.opened_tracks.insert(track_id);
        Ok(())
    }

    /// Moves `track_id` from opened to completed and stores `payload` until it
    /// is drained.
    ///
    /// # Errors
    ///
    /// `TrackAlreadyEnded` if the id is not currently open. This covers both
    /// ids that were finished before and ids that were never opened.
    pub fn finish_track(&mut self, track_id: i32, payload: T) -> Result<(), MIDIWriteError> {
        if !self.opened_tracks.remove(&track_id) {
            return Err(MIDIWriteError::TrackAlreadyEnded { track_id });
        }

        self.completed_tracks.insert(track_id);
        self.pending_tracks.insert(track_id, payload);
        Ok(())
    }

    /// Removes and returns the pending payloads that continue the sequence
    /// 0, 1, 2, … from where the previous call stopped.
    ///
    /// Stops at the first id that has no pending payload; later payloads stay
    /// pending until that id shows up.
    pub fn drain_ready_tracks(&mut self) -> Vec<(i32, T)> {
        let mut ready = Vec::new();
        while let Some(track) = self.pending_tracks.remove(&self.next_flush_track) {
            ready.push((self.next_flush_track, track));
            self.next_flush_track += 1;
        }
        ready
    }

    /// Removes and returns every pending payload in ascending id order,
    /// whether or not the ids are contiguous.
    ///
    /// The sequential position used by `drain_ready_tracks` is left untouched.
    pub fn drain_all_tracks(&mut self) -> Vec<(i32, T)> {
        std::mem::take(&mut self.pending_tracks)
            .into_iter()
            .collect()
    }

    /// Validates the registry and returns the number of completed tracks.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    ///
    /// * `OpenTracksRemaining` (ids ascending) if any track is still open.
    /// * `TrackGapsRemaining` (ids ascending) if some id below the highest
    ///   completed id was never completed.
    /// * `TrackCountOverflow` if the count does not fit in a `u16`.
    pub fn finalize_track_count(&self) -> Result<u16, MIDIWriteError> {
        if !self.opened_tracks.is_empty() {
            return Err(MIDIWriteError::OpenTracksRemaining {
                track_ids: sorted_track_ids(&self.opened_tracks),
            });
        }

        // Walk the completed ids in ascending order and collect whatever lies
        // between consecutive entries. This visits each completed id once and
        // yields the gaps already sorted, so no per-id lookup or re-sort is
        // needed.
        let mut missing = Vec::new();
        let mut expected: i32 = 0;
        for &track_id in &self.completed_tracks {
            missing.extend(expected..track_id);
            // saturating_add: an id of i32::MAX is necessarily the last entry,
            // so the clamped value is never used as a range start.
            expected = expected.max(track_id.saturating_add(1));
        }
        if !missing.is_empty() {
            return Err(MIDIWriteError::TrackGapsRemaining { track_ids: missing });
        }

        let track_count = self.completed_tracks.len();
        if track_count > u16::MAX as usize {
            return Err(MIDIWriteError::TrackCountOverflow { track_count });
        }

        // Lossless: the bound was checked just above.
        Ok(track_count as u16)
    }
}

impl<T> Default for OrderedTrackRegistry<T> {
    /// Same as `OrderedTrackRegistry::new`.
    fn default() -> Self {
        Self::new()
    }
}

/// Write adapter that forwards to a borrowed writer and tallies how many
/// bytes that writer reported as accepted.
struct LengthTrackingWriter<'a, W> {
    /// Destination that receives every write and flush.
    inner: &'a mut W,
    /// Sum of the counts returned by successful `write` calls.
    written: usize,
}

impl<'a, W> LengthTrackingWriter<'a, W> {
    /// Wraps `inner` with the tally starting at zero.
    fn new(inner: &'a mut W) -> Self {
        LengthTrackingWriter { inner, written: 0 }
    }
}

impl<W: Write> Write for LengthTrackingWriter<'_, W> {
    /// Forwards `buf` to the wrapped writer and adds the accepted byte count
    /// to the tally; a failed write leaves the tally unchanged.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.inner.write(buf) {
            Ok(accepted) => {
                self.written += accepted;
                Ok(accepted)
            }
            Err(err) => Err(err),
        }
    }

    /// Flushes the wrapped writer.
    fn flush(&mut self) -> io::Result<()> {
        let target: &mut W = self.inner;
        target.flush()
    }
}

/// Byte sink for one track that counts the bytes written through it.
///
/// The sink owns its writer until `finish` hands it back. Every write routed
/// through the sink's own methods (`write_event`, `write_events_iter`,
/// `write_bytes` and the `io::Write` impl) is added to a running length that
/// is never allowed to exceed `u32::MAX`. The recorded length always equals
/// the number of bytes the inner writer reported as accepted, including when
/// a write fails part-way.
#[derive(Debug)]
pub struct TrackByteSink<W> {
    /// Id reported by `track_id` and attached to errors.
    track_id: i32,
    /// `Some` while the track is open; taken by `finish`.
    writer: Option<W>,
    /// Bytes recorded so far; kept at or below `u32::MAX`.
    length: u64,
}

impl<W> TrackByteSink<W> {
    /// Creates an open sink for `track_id` with a recorded length of zero.
    pub fn new(track_id: i32, writer: W) -> Self {
        Self {
            track_id,
            writer: Some(writer),
            length: 0,
        }
    }

    /// Returns the id this sink was created with.
    pub fn track_id(&self) -> i32 {
        self.track_id
    }

    /// Returns `true` once `finish` has taken the writer.
    pub fn is_ended(&self) -> bool {
        self.writer.is_none()
    }

    /// Gives direct access to the wrapped writer.
    ///
    /// Bytes written through the returned reference bypass the sink and are
    /// not added to the recorded length.
    ///
    /// # Errors
    ///
    /// `TrackAlreadyEnded` if the sink has been finished.
    pub fn writer_mut(&mut self) -> Result<&mut W, MIDIWriteError> {
        let track_id = self.track_id;
        self.writer
            .as_mut()
            .ok_or(MIDIWriteError::TrackAlreadyEnded { track_id })
    }

    /// Computes `current + additional` without committing it.
    ///
    /// Takes the fields by value rather than `&self` so callers can validate
    /// a write while they already hold a mutable borrow of the writer.
    ///
    /// # Errors
    ///
    /// `TrackLengthOverflow` if the sum overflows `u64` (reported with
    /// `u64::MAX`) or exceeds `u32::MAX` (reported with the actual sum).
    fn projected_length(
        track_id: i32,
        current: u64,
        additional: usize,
    ) -> Result<u64, MIDIWriteError> {
        let total =
            current
                .checked_add(additional as u64)
                .ok_or(MIDIWriteError::TrackLengthOverflow {
                    track_id,
                    length: u64::MAX,
                })?;

        if total > u32::MAX as u64 {
            return Err(MIDIWriteError::TrackLengthOverflow {
                track_id,
                length: total,
            });
        }

        Ok(total)
    }

    /// Adds `written` to the recorded length; on overflow the length is left
    /// unchanged and `TrackLengthOverflow` is returned.
    fn note_write(&mut self, written: usize) -> Result<(), MIDIWriteError> {
        self.length = Self::projected_length(self.track_id, self.length, written)?;
        Ok(())
    }
}

impl<W: Write> TrackByteSink<W> {
    /// Serializes `event` into the track and returns the number of bytes that
    /// reached the writer.
    ///
    /// The size of an event is only known after serialization, so the length
    /// limit is checked after the bytes have been written.
    ///
    /// # Errors
    ///
    /// * `TrackAlreadyEnded` if the sink has been finished.
    /// * Whatever `serialize_event_with_delta` reports. Bytes that reached
    ///   the writer before the failure are still recorded, so the length
    ///   keeps matching the writer's content.
    /// * `TrackLengthOverflow` if the event pushes the track past `u32::MAX`
    ///   bytes.
    pub fn write_event<T: SerializeEventWithDelta>(
        &mut self,
        event: T,
    ) -> Result<usize, MIDIWriteError> {
        let (outcome, written) = {
            let mut counter = LengthTrackingWriter::new(self.writer_mut()?);
            let outcome = event.serialize_event_with_delta(&mut counter);
            (outcome, counter.written)
        };

        // Record first, then surface errors; a serialization failure takes
        // precedence over a length overflow.
        let recorded = self.note_write(written);
        outcome?;
        recorded?;
        Ok(written)
    }

    /// Writes every event from `events` in order and returns the total number
    /// of bytes written.
    ///
    /// # Errors
    ///
    /// Stops at the first event whose `write_event` call fails and returns
    /// that error; events written before it stay written.
    pub fn write_events_iter<T: SerializeEventWithDelta>(
        &mut self,
        events: impl Iterator<Item = T>,
    ) -> Result<usize, MIDIWriteError> {
        let mut total = 0;
        for event in events {
            total += self.write_event(event)?;
        }
        Ok(total)
    }

    /// Writes `bytes` verbatim into the track and returns `bytes.len()`.
    ///
    /// # Errors
    ///
    /// * `TrackAlreadyEnded` if the sink has been finished.
    /// * `TrackLengthOverflow` if the bytes would not fit; nothing is written
    ///   and the recorded length is unchanged.
    /// * The I/O error from the writer; only the bytes it accepted before
    ///   failing are recorded.
    pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, MIDIWriteError> {
        let track_id = self.track_id;
        let writer = self
            .writer
            .as_mut()
            .ok_or(MIDIWriteError::TrackAlreadyEnded { track_id })?;

        // Validate before writing so an oversized request leaves both the
        // writer and the recorded length untouched.
        Self::projected_length(track_id, self.length, bytes.len())?;

        let (outcome, written) = {
            let mut counter = LengthTrackingWriter::new(writer);
            let outcome = counter.write_all(bytes);
            (outcome, counter.written)
        };

        let recorded = self.note_write(written);
        outcome?;
        recorded?;
        Ok(bytes.len())
    }

    /// Terminates the track and returns the writer together with the final
    /// recorded length.
    ///
    /// Appends the bytes `00 FF 2F 00` (in a Standard MIDI File: a zero delta
    /// time followed by the End of Track meta event); the returned length
    /// includes them. Afterwards `is_ended` returns `true`.
    ///
    /// # Errors
    ///
    /// * `TrackAlreadyEnded` if the sink was finished before.
    /// * Any error from writing the terminator; the sink then stays open.
    pub fn finish(&mut self) -> Result<(W, u32), MIDIWriteError> {
        if self.is_ended() {
            return Err(MIDIWriteError::TrackAlreadyEnded {
                track_id: self.track_id,
            });
        }

        self.write_bytes(&[0x00, 0xFF, 0x2F, 0x00])?;

        // Lossless: note_write never lets the length exceed u32::MAX.
        let length = self.length as u32;
        let track_id = self.track_id;
        let writer = self
            .writer
            .take()
            .ok_or(MIDIWriteError::TrackAlreadyEnded { track_id })?;
        Ok((writer, length))
    }
}

impl<W: Write> Write for TrackByteSink<W> {
    /// Forwards `buf` to the wrapped writer and records the accepted bytes.
    ///
    /// Fails with `BrokenPipe` once the sink is finished. Fails with
    /// `InvalidData` (wrapping `TrackLengthOverflow`) when `buf` would push
    /// the track past `u32::MAX` bytes. Both checks happen before the inner
    /// writer is touched, so an `Err` from them means no byte of `buf` was
    /// written.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let writer = self
            .writer
            .as_mut()
            .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "track writer ended"))?;

        Self::projected_length(self.track_id, self.length, buf.len())
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;

        let written = writer.write(buf)?;

        // Cannot overflow for a writer that honours `written <= buf.len()`;
        // kept as a check in case the inner writer misreports.
        self.note_write(written)
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
        Ok(written)
    }

    /// Flushes the wrapped writer; fails with `BrokenPipe` once the sink is
    /// finished.
    fn flush(&mut self) -> io::Result<()> {
        self.writer
            .as_mut()
            .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "track writer ended"))?
            .flush()
    }
}