vyre-self-substrate 0.6.3

Vyre self-substrate: vyre using its own primitives on its own scheduler problems. The recursion-thesis layer between vyre-primitives and vyre-driver.
Documentation
use crate::optimizer::dispatcher::DispatchError;
use vyre_primitives::math::quantized::i4_packed_words;

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(super) struct PackedI4BatchedMatmulShape {
    pub(super) total_outputs_u32: u32,
    pub(super) output_words: usize,
    pub(super) output_bytes: usize,
    pub(super) top1_output_bytes: usize,
}

pub(super) fn validate_batched_packed_matmul_shape(
    context: &str,
    weights_packed: &[u32],
    activation_batches_packed: &[u32],
    row_scales: &[f32],
    batch_scales: &[f32],
    batch: u32,
    rows: u32,
    cols: u32,
) -> Result<PackedI4BatchedMatmulShape, DispatchError> {
    if batch == 0 || rows == 0 || cols == 0 {
        return Err(DispatchError::BadInputs(format!(
            "Fix: {context} requires batch > 0, rows > 0, and cols > 0, got batch={batch} rows={rows} cols={cols}."
        )));
    }
    let words_per_row = i4_packed_words(cols) as usize;
    let expected_weight_words = (rows as usize).checked_mul(words_per_row).ok_or_else(|| {
        DispatchError::BadInputs(format!(
            "Fix: {context} weight word count overflows usize for rows={rows} cols={cols}."
        ))
    })?;
    if weights_packed.len() != expected_weight_words {
        return Err(DispatchError::BadInputs(format!(
            "Fix: {context} requires weights_packed.len() == rows*i4_packed_words(cols), got len={} expected={expected_weight_words} for rows={rows} cols={cols}.",
            weights_packed.len()
        )));
    }
    let expected_activation_words =
        (batch as usize).checked_mul(words_per_row).ok_or_else(|| {
            DispatchError::BadInputs(format!(
                "Fix: {context} activation word count overflows usize for batch={batch} cols={cols}."
            ))
        })?;
    if activation_batches_packed.len() != expected_activation_words {
        return Err(DispatchError::BadInputs(format!(
            "Fix: {context} requires activation_batches_packed.len() == batch*i4_packed_words(cols), got len={} expected={expected_activation_words} for batch={batch} cols={cols}.",
            activation_batches_packed.len()
        )));
    }
    if row_scales.len() != rows as usize {
        return Err(DispatchError::BadInputs(format!(
            "Fix: {context} requires row_scales.len() == rows, got len={} rows={rows}.",
            row_scales.len()
        )));
    }
    if batch_scales.len() != batch as usize {
        return Err(DispatchError::BadInputs(format!(
            "Fix: {context} requires batch_scales.len() == batch, got len={} batch={batch}.",
            batch_scales.len()
        )));
    }
    let total_outputs_u32 = batch.checked_mul(rows).ok_or_else(|| {
        DispatchError::BadInputs(format!(
            "Fix: {context} output grid overflows u32 for batch={batch} rows={rows}."
        ))
    })?;
    let output_words = (batch as usize).checked_mul(rows as usize).ok_or_else(|| {
        DispatchError::BadInputs(format!(
            "Fix: {context} output word count overflows usize for batch={batch} rows={rows}."
        ))
    })?;
    let output_bytes = output_words
        .checked_mul(std::mem::size_of::<f32>())
        .ok_or_else(|| {
            DispatchError::BadInputs(format!(
                "Fix: {context} output byte count overflows usize for batch={batch} rows={rows}."
            ))
        })?;
    let top1_output_bytes = (batch as usize)
        .checked_mul(std::mem::size_of::<f32>())
        .ok_or_else(|| {
            DispatchError::BadInputs(format!(
                "Fix: {context} top-1 output byte count overflows usize for batch={batch}."
            ))
        })?;

    Ok(PackedI4BatchedMatmulShape {
        total_outputs_u32,
        output_words,
        output_bytes,
        top1_output_bytes,
    })
}

pub(super) fn expect_one_output<'a>(
    context: &str,
    outputs: &'a [Vec<u8>],
) -> Result<&'a [u8], DispatchError> {
    if outputs.len() != 1 {
        return Err(DispatchError::BackendError(format!(
            "Fix: {context} expected exactly one output buffer, got {}.",
            outputs.len()
        )));
    }
    Ok(&outputs[0])
}

pub(super) fn expect_two_outputs<'a>(
    context: &str,
    outputs: &'a [Vec<u8>],
) -> Result<(&'a [u8], &'a [u8]), DispatchError> {
    if outputs.len() != 2 {
        return Err(DispatchError::BackendError(format!(
            "Fix: {context} expected exactly two output buffers, got {}.",
            outputs.len()
        )));
    }
    Ok((&outputs[0], &outputs[1]))
}