vyre-self-substrate 0.6.3

Vyre self-substrate: vyre using its own primitives on its own scheduler problems. The recursion-thesis layer between vyre-primitives and vyre-driver.
Documentation
mod batched_matmul_contracts;
mod batched_matmul_top1_contracts;
mod batched_matvec_contracts;
mod dot_contracts;
mod generated_contracts;
mod matvec_contracts;
mod unpack_contracts;

use super::*;

/// Test fake that unpacks i4x8-packed words into `i32` lanes on the CPU
/// instead of running the unpack program on a device.
///
/// Expects a single-workgroup launch and exactly two input buffers: the
/// packed `u32` words, followed by an output slot whose byte length decides
/// how many `i32` lanes are produced.
struct QuantizedDispatcher;

impl OptimizerDispatcher for QuantizedDispatcher {
    fn dispatch(
        &self,
        _program: &Program,
        inputs: &[Vec<u8>],
        grid_override: Option<[u32; 3]>,
    ) -> Result<Vec<Vec<u8>>, DispatchError> {
        assert_eq!(grid_override, Some([1, 1, 1]));
        assert_eq!(inputs.len(), 2);
        let (packed_bytes, out_slot) = (&inputs[0], &inputs[1]);
        let words = crate::hardware::dispatch_buffers::read_u32s(packed_bytes);
        // The output slot is only a size reservation: one i32 per lane.
        let lanes = out_slot.len() / std::mem::size_of::<i32>();
        let mut unpacked = Vec::new();
        unpack_i4x8_cpu_into(&words, lanes as u32, &mut unpacked);
        let packed_out = vyre_primitives::wire::pack_i32_slice(&unpacked);
        Ok(vec![packed_out])
    }
}

/// Test fake for the scaled i4x8 dot-product program, evaluated on the CPU.
///
/// Inputs, in order:
/// 1. packed lhs words,
/// 2. packed rhs words,
/// 3. lhs scale (first `f32` of the buffer),
/// 4. rhs scale (first `f32` of the buffer),
/// 5. the output slot, which must reserve exactly one `f32`.
///
/// The launch must be a single workgroup. Returns one buffer holding the
/// scaled dot product.
struct QuantizedDotDispatcher;

impl OptimizerDispatcher for QuantizedDotDispatcher {
    fn dispatch(
        &self,
        _program: &Program,
        inputs: &[Vec<u8>],
        grid_override: Option<[u32; 3]>,
    ) -> Result<Vec<Vec<u8>>, DispatchError> {
        assert_eq!(grid_override, Some([1, 1, 1]));
        assert_eq!(inputs.len(), 5);
        let lhs = crate::hardware::dispatch_buffers::read_u32s(&inputs[0]);
        let rhs = crate::hardware::dispatch_buffers::read_u32s(&inputs[1]);
        let lhs_scale = crate::hardware::dispatch_buffers::read_f32s(&inputs[2])[0];
        let rhs_scale = crate::hardware::dispatch_buffers::read_f32s(&inputs[3])[0];
        let output_words = (inputs[4].len() / std::mem::size_of::<f32>()) as u32;
        assert_eq!(
            output_words, 1,
            "Fix: dot output slot must reserve exactly one f32 word."
        );
        // The packed buffer does not record how many lanes of its final word
        // are meaningful, so every lane of every word participates (8 per u32).
        // NOTE(review): this assumes the packer zero-fills unused trailing
        // lanes so they add nothing to the dot product — TODO confirm.
        // Computing it directly also avoids the `len - 1` underflow the old
        // derivation hit on an empty lhs buffer.
        let lane_count = (lhs.len() as u32) * 8;
        let out = i4x8_dot_f32_scaled_cpu(&lhs, &rhs, lhs_scale, rhs_scale, lane_count);
        Ok(vec![vyre_primitives::wire::pack_f32_slice(&[out])])
    }
}

/// Test fake that ignores its program, inputs and launch grid, and replays a
/// canned set of output buffers on every call.
///
/// Presumably used to hand deliberately malformed dispatch results back to
/// the caller (per the type name) — the buffers are returned unvalidated.
struct MalformedDotDispatcher {
    /// Buffers returned (as a fresh copy) from every `dispatch` call.
    outputs: Vec<Vec<u8>>,
}

impl OptimizerDispatcher for MalformedDotDispatcher {
    fn dispatch(
        &self,
        _program: &Program,
        _inputs: &[Vec<u8>],
        _grid_override: Option<[u32; 3]>,
    ) -> Result<Vec<Vec<u8>>, DispatchError> {
        let replay = self.outputs.to_vec();
        Ok(replay)
    }
}

/// Test fake for the scaled i4x8 matrix-vector program, evaluated on the CPU.
///
/// Inputs, in order:
/// 1. packed weight words,
/// 2. the `f32` input vector,
/// 3. one `f32` scale per row,
/// 4. the output slot, which must reserve one `f32` per row.
///
/// The row count comes from the number of scales and the column count from
/// the input-vector length. The launch must be `[rows, 1, 1]`.
struct QuantizedMatvecDispatcher;

impl OptimizerDispatcher for QuantizedMatvecDispatcher {
    fn dispatch(
        &self,
        _program: &Program,
        inputs: &[Vec<u8>],
        grid_override: Option<[u32; 3]>,
    ) -> Result<Vec<Vec<u8>>, DispatchError> {
        assert_eq!(inputs.len(), 4);
        let weights = crate::hardware::dispatch_buffers::read_u32s(&inputs[0]);
        let vector = crate::hardware::dispatch_buffers::read_f32s(&inputs[1]);
        let scales = crate::hardware::dispatch_buffers::read_f32s(&inputs[2]);
        let (rows, cols) = (scales.len() as u32, vector.len() as u32);
        assert_eq!(grid_override, Some([rows, 1, 1]));
        let expected_out_bytes = scales.len() * std::mem::size_of::<f32>();
        assert_eq!(
            inputs[3].len(),
            expected_out_bytes,
            "Fix: matvec output slot must reserve exactly one f32 per row."
        );
        let product = i4x8_matvec_f32_scaled_cpu(&weights, &vector, &scales, rows, cols);
        Ok(vec![vyre_primitives::wire::pack_f32_slice(&product)])
    }
}

/// Test fake for the batched scaled i4x8 matrix-vector program, evaluated on
/// the CPU.
///
/// Inputs, in order:
/// 1. packed weight words,
/// 2. `batch` input vectors concatenated as `f32`s,
/// 3. one `f32` scale per row,
/// 4. the output slot, which must reserve one `f32` per (batch, row).
///
/// The launch grid `[rows, batch, 1]` supplies the row and batch counts. The
/// column count is the per-batch share of the concatenated input vectors.
struct QuantizedBatchedMatvecDispatcher;

impl OptimizerDispatcher for QuantizedBatchedMatvecDispatcher {
    fn dispatch(
        &self,
        _program: &Program,
        inputs: &[Vec<u8>],
        grid_override: Option<[u32; 3]>,
    ) -> Result<Vec<Vec<u8>>, DispatchError> {
        assert_eq!(inputs.len(), 4);
        let weights = crate::hardware::dispatch_buffers::read_u32s(&inputs[0]);
        let vectors = crate::hardware::dispatch_buffers::read_f32s(&inputs[1]);
        let scales = crate::hardware::dispatch_buffers::read_f32s(&inputs[2]);
        let (rows, batch) = match grid_override {
            Some([rows, batch, 1]) => (rows, batch),
            _ => panic!("Fix: batched matvec dispatch must launch with [rows, batch, 1]."),
        };
        // Integer division: a zero batch is rejected, any remainder is dropped.
        let per_batch = vectors
            .len()
            .checked_div(batch as usize)
            .expect("Fix: fake batched matvec dispatcher requires nonzero batch");
        let cols = per_batch as u32;
        assert_eq!(rows as usize, scales.len());
        let expected_out_bytes = batch as usize * rows as usize * std::mem::size_of::<f32>();
        assert_eq!(
            inputs[3].len(),
            expected_out_bytes,
            "Fix: batched matvec output slot must reserve exactly one f32 per batch row."
        );
        let product =
            i4x8_batched_matvec_f32_scaled_cpu(&weights, &vectors, &scales, batch, rows, cols);
        Ok(vec![vyre_primitives::wire::pack_f32_slice(&product)])
    }
}

/// Test fake for the batched scaled i4x8 matmul program, evaluated on the CPU.
///
/// Inputs, in order:
/// 1. packed weight words,
/// 2. packed activation words for all batches, concatenated,
/// 3. one `f32` scale per row,
/// 4. one `f32` scale per batch,
/// 5. the output slot, which must reserve one `f32` per (batch, row).
///
/// The row and batch counts come from the scale buffer lengths. The launch
/// must be one-dimensional with `ceil_div_u32(batch * rows, 64)` workgroups.
struct QuantizedBatchedMatmulDispatcher;

impl OptimizerDispatcher for QuantizedBatchedMatmulDispatcher {
    fn dispatch(
        &self,
        _program: &Program,
        inputs: &[Vec<u8>],
        grid_override: Option<[u32; 3]>,
    ) -> Result<Vec<Vec<u8>>, DispatchError> {
        assert_eq!(inputs.len(), 5);
        let weights = crate::hardware::dispatch_buffers::read_u32s(&inputs[0]);
        let activations = crate::hardware::dispatch_buffers::read_u32s(&inputs[1]);
        let row_scales = crate::hardware::dispatch_buffers::read_f32s(&inputs[2]);
        let batch_scales = crate::hardware::dispatch_buffers::read_f32s(&inputs[3]);
        let rows = row_scales.len() as u32;
        let batch = batch_scales.len() as u32;
        let Some([grid_x, 1, 1]) = grid_override else {
            panic!(
                "Fix: batched matmul dispatch must launch one-dimensional 64-wide workgroup grid."
            );
        };
        assert_eq!(grid_x, ceil_div_u32(batch * rows, 64));
        assert_eq!(
            inputs[4].len(),
            batch as usize * rows as usize * std::mem::size_of::<f32>(),
            "Fix: batched matmul output slot must reserve exactly one f32 per batch row."
        );
        // Guard the division like the batched matvec fake does, so an empty
        // batch fails with an actionable message instead of a bare
        // divide-by-zero panic.
        let words_per_activation = activations
            .len()
            .checked_div(batch as usize)
            .expect("Fix: fake batched matmul dispatcher requires nonzero batch");
        // Each packed u32 word carries 8 i4 lanes.
        let cols = (words_per_activation as u32) * 8;
        let out = i4x8_batched_matmul_f32_scaled_cpu(
            &weights,
            &activations,
            &row_scales,
            &batch_scales,
            batch,
            rows,
            cols,
        );
        Ok(vec![vyre_primitives::wire::pack_f32_slice(&out)])
    }
}

/// Test fake for the batched scaled i4x8 matmul + top-1 program, evaluated on
/// the CPU.
///
/// Inputs, in order:
/// 1. packed weight words,
/// 2. packed activation words for all batches, concatenated,
/// 3. one `f32` scale per row,
/// 4. one `f32` scale per batch,
/// 5. a score output slot (one `f32` per batch),
/// 6. an index output slot (one `u32` per batch).
///
/// The launch must be `[ceil_div_u32(batch, 64), 1, 1]`. Returns the score
/// and index buffers, in that order.
struct QuantizedBatchedMatmulTop1Dispatcher;

impl OptimizerDispatcher for QuantizedBatchedMatmulTop1Dispatcher {
    fn dispatch(
        &self,
        _program: &Program,
        inputs: &[Vec<u8>],
        grid_override: Option<[u32; 3]>,
    ) -> Result<Vec<Vec<u8>>, DispatchError> {
        assert_eq!(inputs.len(), 6);
        let weights = crate::hardware::dispatch_buffers::read_u32s(&inputs[0]);
        let activations = crate::hardware::dispatch_buffers::read_u32s(&inputs[1]);
        let row_scales = crate::hardware::dispatch_buffers::read_f32s(&inputs[2]);
        let batch_scales = crate::hardware::dispatch_buffers::read_f32s(&inputs[3]);
        let rows = row_scales.len() as u32;
        let batch = batch_scales.len() as u32;
        assert_eq!(grid_override, Some([ceil_div_u32(batch, 64), 1, 1]));
        assert_eq!(
            inputs[4].len(),
            batch as usize * std::mem::size_of::<f32>(),
            "Fix: top-1 score output slot must reserve exactly one f32 per batch."
        );
        assert_eq!(
            inputs[5].len(),
            batch as usize * std::mem::size_of::<u32>(),
            "Fix: top-1 index output slot must reserve exactly one u32 per batch."
        );
        // Guard the division like the batched matvec fake does, so an empty
        // batch fails with an actionable message instead of a bare
        // divide-by-zero panic.
        let words_per_activation = activations
            .len()
            .checked_div(batch as usize)
            .expect("Fix: fake batched matmul top-1 dispatcher requires nonzero batch");
        // Each packed u32 word carries 8 i4 lanes.
        let cols = (words_per_activation as u32) * 8;
        let (scores, indices) = i4x8_batched_matmul_top1_f32_scaled_cpu(
            &weights,
            &activations,
            &row_scales,
            &batch_scales,
            batch,
            rows,
            cols,
        );
        Ok(vec![
            vyre_primitives::wire::pack_f32_slice(&scores),
            vyre_primitives::wire::pack_u32_slice(&indices),
        ])
    }
}

/// Packs each row independently with `pack_i4x8_cpu` and concatenates the
/// resulting words in row order.
fn pack_i4_rows(rows: &[&[i32]]) -> Vec<u32> {
    rows.iter().flat_map(|row| pack_i4x8_cpu(row)).collect()
}