// midi-codec 0.4.0
//
// Tools for encoding and decoding a stream of MIDI messages.
use core::{fmt::Debug, time::Duration};

use crate::{
    error::DecodeError,
    io::ReadSimple,
    message::{ChannelVoiceMessage, Message, RealTimeMessage, SystemCommonMessage},
    status::{ChannelStatus, ChannelVoiceStatus, RequiredDataBytes, Status, SystemCommonStatus},
};

use self::builder::DecoderBuilder;

pub mod builder;

/// A decoder state machine.
///
/// Pulls raw bytes from a reader `R` through a buffer `B` and turns them into
/// [`Message`]s, keeping enough state between calls to resume a message whose
/// data bytes have not all been read yet.
pub struct Decoder<R, B> {
    /// Buffered access to the underlying reader.
    decoder_buf: DecoderBuf<R, B>,
    /// Channel-voice status used to decode data bytes that arrive without a
    /// preceding status byte.
    running_status: Option<ChannelVoiceStatus>,
    /// Message whose status is known but whose data bytes are still pending.
    partial_message: Option<PartialMessage>,
}

/// A reader paired with a fixed buffer and a read cursor into it.
struct DecoderBuf<R, B> {
    /// Source of raw bytes.
    reader: R,
    /// Storage that each refill overwrites from the start.
    buf: B,
    /// Number of valid bytes in `buf` after the most recent refill.
    buf_fill: usize,
    /// Index of the next unread byte in `buf`; the buffer is exhausted when
    /// this reaches `buf_fill`.
    buf_used: usize,
}

impl Decoder<(), ()> {
    pub fn builder() -> DecoderBuilder {
        DecoderBuilder::default()
    }
}

impl<R, B> DecoderBuf<R, B>
where
    R: ReadSimple,
    B: AsMut<[u8]>,
{
    /// Overwrites the buffer with whatever the reader returns and rewinds the
    /// cursor to the start of it.
    ///
    /// On a reader error the buffer state is left untouched.
    fn refill_buf(&mut self) -> Result<(), R::Error> {
        self.buf_fill = self.reader.read_simple(self.buf.as_mut())?;
        self.buf_used = 0;
        Ok(())
    }

    /// Returns the next unread byte without consuming it.
    ///
    /// When the buffer is exhausted this keeps refilling until the reader
    /// yields at least one byte. With the `std` feature it sleeps 1 ms between
    /// empty refills; without it, it polls in a tight loop. NOTE(review): if
    /// `buf` is empty, or the reader never produces data again, this does not
    /// return.
    fn peek_byte(&mut self) -> Result<u8, R::Error> {
        if self.buf_used >= self.buf_fill {
            // The first refill happens immediately; only back off once a
            // refill has come back empty.
            self.refill_buf()?;
            while self.buf_used >= self.buf_fill {
                #[cfg(feature = "std")]
                std::thread::sleep(Duration::from_millis(1));

                self.refill_buf()?;
            }
        }

        tracing::trace!("peeked {}", self.buf_used);

        Ok(self.buf.as_mut()[self.buf_used])
    }

    /// Consumes the byte most recently returned by [`Self::peek_byte`].
    fn skip_byte(&mut self) {
        tracing::trace!("skipped {}", self.buf_used);
        self.buf_used += 1;
    }

    /// Consumes and returns the next byte.
    fn read_byte(&mut self) -> Result<u8, R::Error> {
        let byte = self.peek_byte()?;
        self.skip_byte();

        Ok(byte)
    }

    /// Consumes the next byte and expects it to be a data byte.
    ///
    /// Returns `Ok(Ok(data))` for a data byte and `Ok(Err(real_time))` for an
    /// interleaved real-time status byte. Any other status byte yields
    /// [`DecodeError::UnexpectedStatus`]. In every case the byte has already
    /// been consumed.
    fn read_data_byte(&mut self) -> Result<Result<u8, RealTimeMessage>, DecodeError<R::Error>> {
        let byte = self.read_byte()?;
        match parse_byte(byte) {
            ByteType::Data(data) => Ok(Ok(data)),
            ByteType::Status(_) => Err(DecodeError::UnexpectedStatus(byte)),
            ByteType::RealTimeStatus(real_time) => Ok(Err(real_time)),
        }
    }

    /// Discards the buffered bytes, then keeps reading until the reader
    /// returns zero bytes.
    fn flush(&mut self) -> Result<(), R::Error> {
        self.buf_fill = 0;
        self.buf_used = 0;

        while self.reader.read_simple(self.buf.as_mut())? != 0 {}

        Ok(())
    }
}

impl<R, B> Decoder<R, B>
where
    R: ReadSimple,
    B: AsMut<[u8]>,
{
    fn new(reader: R, buf: B) -> Self {
        Self {
            decoder_buf: DecoderBuf {
                reader,
                buf,

                buf_fill: 0,
                buf_used: 0,
            },

            running_status: None,
            partial_message: None,
        }
    }

    pub fn next_message(&mut self) -> Result<Message, DecodeError<R::Error>> {
        tracing::trace!("read message");
        let partial_message = if let Some(ref mut partial_message) = self.partial_message {
            partial_message
        } else {
            let mut data0 = None;

            let status = match parse_byte(self.decoder_buf.read_byte()?) {
                ByteType::Data(data) => {
                    if let Some(running_status) = self.running_status {
                        data0 = Some(data);
                        Status::ChannelVoice(running_status)
                    } else {
                        return Err(DecodeError::ExtraneousData(data));
                    }
                }
                ByteType::Status(status) => status,
                ByteType::RealTimeStatus(real_time) => {
                    if real_time == RealTimeMessage::SystemReset {
                        self.running_status = None;
                    }
                    return Ok(Message::RealTime(real_time));
                }
            };

            let required_data_bytes = status.required_data_bytes();

            if required_data_bytes == RequiredDataBytes::D0 {
                return Ok(construct_message(status, [0, 0]));
            }

            self.partial_message
                .insert(PartialMessage { status, data0 })
        };
        let required_data_bytes = partial_message.status.required_data_bytes();

        let next_data = match self.decoder_buf.read_data_byte()? {
            Ok(d) => d,
            Err(real_time) => {
                if real_time == RealTimeMessage::SystemReset {
                    self.running_status = None;
                }
                return Ok(Message::RealTime(real_time));
            }
        };

        if required_data_bytes == RequiredDataBytes::D1 {
            let status = partial_message.status;
            self.partial_message = None;

            return Ok(construct_message(status, [next_data, 0]));
        } else {
            let (data0, data1) = match partial_message.data0 {
                Some(data0) => (data0, next_data),
                None => {
                    let data0 = next_data;
                    partial_message.data0 = Some(data0); // In case we return a realtime message below

                    let data1 = match self.decoder_buf.read_data_byte()? {
                        Ok(d) => d,
                        Err(real_time) => {
                            if real_time == RealTimeMessage::SystemReset {
                                self.running_status = None;
                            }
                            return Ok(Message::RealTime(real_time));
                        }
                    };

                    (data0, data1)
                }
            };

            let status = partial_message.status;
            self.partial_message = None;

            return Ok(construct_message(status, [data0, data1]));
        }
    }

    pub fn read_sysex(&mut self) -> SysexReader<'_, R, B> {
        tracing::trace!("read sysex");
        SysexReader { decoder: self }
    }

    pub fn flush(&mut self) -> Result<(), R::Error> {
        tracing::trace!("flush");
        self.decoder_buf.flush()
    }
}

/// Iterator over the data bytes of a System Exclusive message.
///
/// Created by [`Decoder::read_sysex`]. It yields bytes until it peeks a
/// non-data byte, which is left unconsumed in the decoder's buffer.
pub struct SysexReader<'a, R, B> {
    /// Decoder whose buffered input is being read.
    decoder: &'a mut Decoder<R, B>,
}

impl<'a, R, B> Iterator for SysexReader<'a, R, B>
where
    R: ReadSimple,
    B: AsMut<[u8]>,
{
    type Item = Result<u8, R::Error>;

    /// Yields the next data byte and consumes it.
    ///
    /// Returns `None`, without consuming anything, when the next byte is a
    /// status byte. Reader errors are yielded as `Some(Err(_))`.
    fn next(&mut self) -> Option<Self::Item> {
        let input = &mut self.decoder.decoder_buf;

        match input.peek_byte() {
            Err(error) => Some(Err(error)),
            // Any status byte ends the payload; leave it for `next_message`.
            Ok(byte) if !is_data_byte(byte) => None,
            Ok(byte) => {
                tracing::trace!("returning sysex byte");
                input.skip_byte();
                Some(Ok(byte))
            }
        }
    }
}

/// Returns `true` for MIDI data bytes (`0x00..=0x7F`, top bit clear) and
/// `false` for status bytes.
fn is_data_byte(byte: u8) -> bool {
    byte < 0x80
}

/// Classifies one raw MIDI byte.
///
/// * `0x00..=0x7F`: data byte.
/// * `0x80..=0xEF`: channel-voice status. The high nibble selects the message
///   kind and the low nibble is the channel.
/// * `0xF0`: System Exclusive status.
/// * `0xF1..=0xF7`: System Common status.
/// * `0xF8..=0xFF`: System Real-Time status.
fn parse_byte(byte: u8) -> ByteType {
    match byte {
        0x00..=0x7F => ByteType::Data(byte),
        0x80..=0xEF => {
            let kind = match byte >> 4 {
                0x8 => ChannelStatus::NoteOff,
                0x9 => ChannelStatus::NoteOn,
                0xA => ChannelStatus::PolyphonicPressure,
                0xB => ChannelStatus::ControlChange,
                0xC => ChannelStatus::ProgramChange,
                0xD => ChannelStatus::ChannelPressure,
                0xE => ChannelStatus::PitchBend,
                _ => unreachable!(),
            };

            ByteType::Status(Status::ChannelVoice(ChannelVoiceStatus {
                status: kind,
                channel: byte & 0x0F,
            }))
        }
        0xF0 => ByteType::Status(Status::SystemExclusive),
        0xF1..=0xF7 => {
            let common = match byte {
                0xF1 => SystemCommonStatus::MTCQuarterFrame,
                0xF2 => SystemCommonStatus::SongPositionPointer,
                0xF3 => SystemCommonStatus::SongSelect,
                0xF4 => SystemCommonStatus::Undefined1,
                0xF5 => SystemCommonStatus::Undefined2,
                0xF6 => SystemCommonStatus::TuneRequest,
                0xF7 => SystemCommonStatus::EOX,
                _ => unreachable!(),
            };

            ByteType::Status(Status::SystemCommon(common))
        }
        0xF8..=0xFF => {
            let real_time = match byte {
                0xF8 => RealTimeMessage::TimingClock,
                0xF9 => RealTimeMessage::Undefined1,
                0xFA => RealTimeMessage::Start,
                0xFB => RealTimeMessage::Continue,
                0xFC => RealTimeMessage::Stop,
                0xFD => RealTimeMessage::Undefined2,
                0xFE => RealTimeMessage::ActiveSensing,
                0xFF => RealTimeMessage::SystemReset,
                _ => unreachable!(),
            };

            ByteType::RealTimeStatus(real_time)
        }
    }
}

/// Builds a [`Message`] from a decoded `status` and its data bytes.
///
/// `data` always has two slots. Statuses that carry fewer data bytes ignore
/// the unused trailing slot(s).
fn construct_message(status: Status, data: [u8; 2]) -> Message {
    let [first, second] = data;

    match status {
        Status::SystemExclusive => Message::SystemExclusive,
        Status::SystemCommon(common) => Message::SystemCommon(match common {
            SystemCommonStatus::MTCQuarterFrame => {
                SystemCommonMessage::MTCQuarterFrame { data: first }
            }
            SystemCommonStatus::SongPositionPointer => SystemCommonMessage::SongPositionPointer {
                low: first,
                high: second,
            },
            SystemCommonStatus::SongSelect => SystemCommonMessage::SongSelect { song: first },
            SystemCommonStatus::Undefined1 => SystemCommonMessage::Undefined1,
            SystemCommonStatus::Undefined2 => SystemCommonMessage::Undefined2,
            SystemCommonStatus::TuneRequest => SystemCommonMessage::TuneRequest,
            SystemCommonStatus::EOX => SystemCommonMessage::EOX,
        }),
        Status::ChannelVoice(ChannelVoiceStatus {
            status: kind,
            channel,
        }) => {
            let message = match kind {
                ChannelStatus::NoteOff => ChannelVoiceMessage::NoteOff {
                    note: first,
                    velocity: second,
                },
                ChannelStatus::NoteOn => ChannelVoiceMessage::NoteOn {
                    note: first,
                    velocity: second,
                },
                ChannelStatus::PolyphonicPressure => ChannelVoiceMessage::PolyphonicPressure {
                    note: first,
                    pressure: second,
                },
                ChannelStatus::ControlChange => ChannelVoiceMessage::ControlChange {
                    control: first,
                    value: second,
                },
                ChannelStatus::ProgramChange => ChannelVoiceMessage::ProgramChange { program: first },
                ChannelStatus::ChannelPressure => {
                    ChannelVoiceMessage::ChannelPressure { pressure: first }
                }
                // 14-bit value: `first` supplies the low 7 bits, `second` the high 7.
                ChannelStatus::PitchBend => ChannelVoiceMessage::PitchBend {
                    pitch_bend: u16::from(first) | (u16::from(second) << 7),
                },
            };

            Message::ChannelVoice { channel, message }
        }
    }
}

/// A message whose status is known but whose data bytes are still arriving.
struct PartialMessage {
    /// Status of the message being assembled.
    status: Status,
    /// First data byte, once it has been received.
    data0: Option<u8>,
}

/// Classification of a single raw byte, as produced by `parse_byte`.
#[derive(Clone, Copy, Debug)]
enum ByteType {
    /// A data byte (top bit clear); the payload is the byte itself.
    Data(u8),
    /// A status byte other than a System Real-Time status.
    Status(Status),
    /// A System Real-Time status byte (`0xF8..=0xFF`).
    RealTimeStatus(RealTimeMessage),
}