// mousiki 0.2.1
//
// Pure Rust Opus codec.
use alloc::vec::Vec;

/// Length in bytes of the fixed-size header that precedes each array payload
/// in a serialized weight blob; the payload starts at this offset in a record.
const WEIGHT_BLOCK_SIZE: usize = 64;
/// Length in bytes of the name field, which starts at byte 20 of the header.
/// The parser requires the last byte of this field to be NUL.
const WEIGHT_NAME_LEN: usize = 44;

/// Errors reported while parsing a weight blob or looking up arrays in it.
///
/// The `&'static str` carried by the lookup variants is the array name that
/// was requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightError {
    /// The blob bytes are malformed: a record is truncated, its size fields
    /// are inconsistent, or its name field is not a NUL-terminated UTF-8 string.
    InvalidBlob,
    /// No array with the requested name exists in the blob.
    MissingArray(&'static str),
    /// The named array exists but its size differs from the expected size.
    SizeMismatch(&'static str),
    /// NOTE(review): not constructed anywhere in the code visible here;
    /// presumably reported by callers that validate index data stored in a
    /// named array. Confirm against the call sites.
    InvalidIndex(&'static str),
}

/// One named array parsed out of a weight blob.
///
/// All fields borrow from the blob bytes; nothing is copied.
#[derive(Debug, Clone, Copy)]
pub(crate) struct WeightArray<'a> {
    /// Array name: the bytes of the header's name field up to the first NUL.
    pub name: &'a str,
    /// Payload length in bytes, as declared in the record header.
    pub size: usize,
    /// The payload bytes. When produced by `WeightBlob::parse` this slice is
    /// exactly `size` bytes long.
    pub data: &'a [u8],
}

/// A parsed weight blob: the list of named arrays found in the input bytes.
#[derive(Debug)]
pub(crate) struct WeightBlob<'a> {
    /// Arrays in the order their records appear in the blob.
    arrays: Vec<WeightArray<'a>>,
}

impl<'a> WeightBlob<'a> {
    /// Parses a serialized weight blob into a list of borrowed array records.
    ///
    /// The input is read as a sequence of records. Each record is a
    /// `WEIGHT_BLOCK_SIZE`-byte header followed by `block_size` bytes, of
    /// which the first `size` bytes are the array payload. Inside the header,
    /// `size` is the little-endian `i32` at byte 12, `block_size` is the one
    /// at byte 16, and the NUL-terminated name occupies the
    /// `WEIGHT_NAME_LEN` bytes starting at byte 20.
    ///
    /// Empty input yields a blob with no arrays.
    ///
    /// # Errors
    ///
    /// Returns [`WeightError::InvalidBlob`] when:
    /// - fewer than `WEIGHT_BLOCK_SIZE` bytes remain for a header,
    /// - `size` is negative or larger than `block_size`,
    /// - `block_size` exceeds the bytes remaining after the header,
    /// - the last byte of the name field is not NUL, or
    /// - the name is not valid UTF-8.
    pub(crate) fn parse(mut data: &'a [u8]) -> Result<Self, WeightError> {
        let mut arrays = Vec::new();

        while !data.is_empty() {
            // A record must start with a complete header.
            if data.len() < WEIGHT_BLOCK_SIZE {
                return Err(WeightError::InvalidBlob);
            }
            let (header, body) = data.split_at(WEIGHT_BLOCK_SIZE);

            let declared_size = read_i32(header, 12)?;
            let declared_block = read_i32(header, 16)?;
            if declared_size < 0 || declared_block < declared_size {
                return Err(WeightError::InvalidBlob);
            }
            let (Ok(size), Ok(block_size)) = (
                usize::try_from(declared_size),
                usize::try_from(declared_block),
            ) else {
                return Err(WeightError::InvalidBlob);
            };
            // The whole block must be present. Because `size <= block_size`,
            // this also guarantees that the payload slice below is in bounds.
            if body.len() < block_size {
                return Err(WeightError::InvalidBlob);
            }

            // The name field must be terminated in its final byte; the name
            // itself is everything before the first NUL.
            let name_field = &header[20..20 + WEIGHT_NAME_LEN];
            if name_field.last() != Some(&0) {
                return Err(WeightError::InvalidBlob);
            }
            let name_bytes = name_field
                .split(|&byte| byte == 0)
                .next()
                .unwrap_or(name_field);
            let Ok(name) = core::str::from_utf8(name_bytes) else {
                return Err(WeightError::InvalidBlob);
            };

            arrays.push(WeightArray {
                name,
                size,
                data: &body[..size],
            });

            // Skip the entire block, including any bytes after the payload.
            data = &body[block_size..];
        }

        Ok(Self { arrays })
    }

    /// Returns the first array whose name equals `name`, or `None` if the
    /// blob contains no such array.
    pub(crate) fn find(&self, name: &str) -> Option<&WeightArray<'a>> {
        self.arrays.iter().find(|candidate| name == candidate.name)
    }
}

/// Looks up the array called `name` and returns its payload bytes.
///
/// # Errors
///
/// - [`WeightError::MissingArray`] if the blob has no array with that name.
/// - [`WeightError::SizeMismatch`] if the array's recorded size is not
///   `expected`.
pub(crate) fn require_bytes<'a>(
    blob: &'a WeightBlob<'a>,
    name: &'static str,
    expected: usize,
) -> Result<&'a [u8], WeightError> {
    match blob.find(name) {
        None => Err(WeightError::MissingArray(name)),
        Some(entry) if entry.size != expected => Err(WeightError::SizeMismatch(name)),
        Some(entry) => Ok(entry.data),
    }
}

/// Looks up the array called `name`, treating its absence as a valid outcome.
///
/// Returns `Ok(None)` when the blob has no array with that name, and
/// `Ok(Some(bytes))` when it exists with a recorded size of `expected`.
///
/// # Errors
///
/// [`WeightError::SizeMismatch`] if the array exists but its recorded size is
/// not `expected`.
pub(crate) fn optional_bytes<'a>(
    blob: &'a WeightBlob<'a>,
    name: &'static str,
    expected: usize,
) -> Result<Option<&'a [u8]>, WeightError> {
    blob.find(name)
        .map(|entry| {
            if entry.size == expected {
                Ok(entry.data)
            } else {
                Err(WeightError::SizeMismatch(name))
            }
        })
        .transpose()
}

/// Reads a little-endian `i32` from the four bytes of `data` starting at
/// `offset`.
///
/// # Errors
///
/// Returns [`WeightError::InvalidBlob`] if `offset + 4` overflows or the four
/// bytes do not lie entirely inside `data`.
fn read_i32(data: &[u8], offset: usize) -> Result<i32, WeightError> {
    offset
        .checked_add(4)
        .and_then(|end| data.get(offset..end))
        .and_then(|window| <[u8; 4]>::try_from(window).ok())
        .map(i32::from_le_bytes)
        .ok_or(WeightError::InvalidBlob)
}

#[cfg(test)]
mod tests {
    use super::{WEIGHT_BLOCK_SIZE, WEIGHT_NAME_LEN, WeightBlob, WeightError};
    use alloc::vec;
    use alloc::vec::Vec;

    /// Builds a blob containing a single record named `name` whose block holds
    /// exactly `payload` (so `size == block_size`).
    ///
    /// Panics if `payload` is empty, if its length does not fit in an `i32`,
    /// or if `name` leaves no room for the terminating NUL in the name field.
    fn build_blob(name: &str, payload: &[u8]) -> Vec<u8> {
        assert!(!payload.is_empty());
        let mut blob = vec![0u8; WEIGHT_BLOCK_SIZE + payload.len()];
        // A checked conversion, so an oversized payload cannot be silently
        // truncated into a bogus header value.
        let size = i32::try_from(payload.len()).expect("payload length fits in i32");
        blob[12..16].copy_from_slice(&size.to_le_bytes());
        blob[16..20].copy_from_slice(&size.to_le_bytes());
        let name_bytes = name.as_bytes();
        assert!(name_bytes.len() < WEIGHT_NAME_LEN);
        blob[20..20 + name_bytes.len()].copy_from_slice(name_bytes);
        blob[WEIGHT_BLOCK_SIZE..].copy_from_slice(payload);
        blob
    }

    #[test]
    fn parse_rejects_short_header() {
        let err = WeightBlob::parse(&[0u8; 12]).unwrap_err();
        assert_eq!(err, WeightError::InvalidBlob);
    }

    #[test]
    fn parse_accepts_single_entry() {
        let payload = [1u8, 2, 3, 4];
        let blob = build_blob("test", &payload);
        let parsed = WeightBlob::parse(&blob).expect("parse");
        let entry = parsed.find("test").expect("entry");
        assert_eq!(entry.size, payload.len());
        assert_eq!(entry.data, payload);
    }

    #[test]
    fn parse_accepts_empty_input() {
        let parsed = WeightBlob::parse(&[]).expect("parse");
        assert!(parsed.find("test").is_none());
    }

    #[test]
    fn parse_rejects_truncated_payload() {
        let mut blob = build_blob("test", &[1u8, 2, 3, 4]);
        // Drop the last payload byte so the declared block no longer fits.
        blob.pop();
        let err = WeightBlob::parse(&blob).unwrap_err();
        assert_eq!(err, WeightError::InvalidBlob);
    }

    #[test]
    fn parse_rejects_unterminated_name() {
        let mut blob = build_blob("test", &[1u8]);
        // Overwrite the whole name field so its final byte is not NUL.
        blob[20..20 + WEIGHT_NAME_LEN].fill(b'a');
        let err = WeightBlob::parse(&blob).unwrap_err();
        assert_eq!(err, WeightError::InvalidBlob);
    }
}