vyre-primitives 0.6.3

Compositional primitives for vyre - marker types (always on) + Tier 2.5 LEGO substrate (feature-gated per domain).
Documentation
use super::*;

/// Checks that `i4x8_batched_matmul_f32_scaled_cpu`, fed packed inputs,
/// reproduces bit for bit a reference computed from the unpacked integers.
///
/// The reference accumulates each weight-row/activation dot product in `f32`,
/// then multiplies by the row scale and the batch scale, in that order.
/// Expected values are laid out batch-major: all rows for the first
/// activation vector, then all rows for the second.
#[test]
fn batched_matmul_cpu_matches_dequantized_reference() {
    let weights = vec![
        vec![1, 2, 3, 4, -1, -2, -3, -4, 5],
        vec![4, 3, 2, 1, -4, -3, -2, -1, -5],
        vec![-8, -7, -6, -5, -4, -3, -2, -1, 0],
    ];
    let activations = vec![
        vec![1, -1, 2, -2, 3, -3, 4, -4, 5],
        vec![-5, 4, -4, 3, -3, 2, -2, 1, -1],
    ];
    let row_scales = [0.5_f32, 0.25, 0.125];
    let batch_scales = [0.25_f32, 0.5];

    let packed_weights = pack_i4_matrix_rows(&weights);
    let packed_activations = pack_i4_matrix_rows(&activations);
    // The trailing arguments are (batch, rows, cols) = (2, 3, 9) for the
    // matrices declared above.
    let actual = i4x8_batched_matmul_f32_scaled_cpu(
        &packed_weights,
        &packed_activations,
        &row_scales,
        &batch_scales,
        2,
        3,
        9,
    );

    // Reference result straight from the unpacked integers. The comparison is
    // bit-exact, so the floating-point operation order below is significant.
    let mut expected = Vec::with_capacity(activations.len() * weights.len());
    for (activation, &batch_scale) in activations.iter().zip(&batch_scales) {
        for (row, &row_scale) in weights.iter().zip(&row_scales) {
            let mut dot = 0.0_f32;
            for (&w, &x) in row.iter().zip(activation) {
                dot += w as f32 * x as f32;
            }
            expected.push(dot * row_scale * batch_scale);
        }
    }

    // Compare raw bit patterns rather than float values.
    let actual_bits: Vec<_> = actual.iter().map(|value| value.to_bits()).collect();
    let expected_bits: Vec<_> = expected.iter().map(|value| value.to_bits()).collect();
    assert_eq!(actual_bits, expected_bits);
}

/// Checks that `i4x8_batched_matmul_top1_f32_scaled_cpu` returns, for every
/// batch entry, the same winning row index and score as an argmax taken over
/// the full output of `i4x8_batched_matmul_f32_scaled_cpu`.
///
/// The argmax orders scores with `f32::total_cmp`; when two rows compare
/// equal, the later row is the expected winner.
#[test]
fn batched_matmul_top1_cpu_matches_full_matmul_argmax() {
    let weights = vec![
        vec![1, 2, 3, 4, -1, -2, -3, -4, 5],
        vec![4, 3, 2, 1, -4, -3, -2, -1, -5],
        vec![-8, -7, -6, -5, -4, -3, -2, -1, 0],
    ];
    let activations = vec![
        vec![1, -1, 2, -2, 3, -3, 4, -4, 5],
        vec![-5, 4, -4, 3, -3, 2, -2, 1, -1],
    ];
    let row_scales = [0.5_f32, 0.25, 0.125];
    let batch_scales = [0.75_f32, 0.375];

    // Derive the dimensions once and pass the same values to both kernels.
    let batch_count = activations.len();
    let row_count = weights.len();
    let batch = batch_count as u32;
    let rows = row_count as u32;
    let cols = weights[0].len() as u32;

    let packed_weights = pack_i4_matrix_rows(&weights);
    let packed_activations = pack_i4_matrix_rows(&activations);
    let logits = i4x8_batched_matmul_f32_scaled_cpu(
        &packed_weights,
        &packed_activations,
        &row_scales,
        &batch_scales,
        batch,
        rows,
        cols,
    );
    let (scores, indices) = i4x8_batched_matmul_top1_f32_scaled_cpu(
        &packed_weights,
        &packed_activations,
        &row_scales,
        &batch_scales,
        batch,
        rows,
        cols,
    );

    for batch_index in 0..batch_count {
        // `logits` is indexed batch-major: one run of `row_count` per batch.
        let base = batch_index * row_count;
        let (expected_index, expected_score) = (0..row_count)
            .map(|row| (row as u32, logits[base + row]))
            // Keep the candidate unless the current best is strictly greater,
            // so the last of several equal maxima wins.
            .reduce(|best, candidate| {
                if candidate.1.total_cmp(&best.1).is_ge() {
                    candidate
                } else {
                    best
                }
            })
            .expect("Fix: top1 test requires at least one row.");
        assert_eq!(indices[batch_index], expected_index);
        assert_eq!(scores[batch_index].to_bits(), expected_score.to_bits());
    }
}

/// Sweeps batch sizes 1..=4, row counts 1..=5 and a set of column counts,
/// checking each time that `i4x8_batched_matmul_f32_scaled_cpu` matches, bit
/// for bit, a reference computed from the unpacked integers.
///
/// The column counts sit on and around multiples of 8, presumably the pack
/// width implied by the `i4x8` name (confirm against the packing code).
#[test]
fn generated_batched_matmul_matches_dequantized_reference_across_pack_boundaries() {
    // Returns `len` values read cyclically from `pattern`, starting `offset`
    // elements in.
    fn cycled<T: Copy>(pattern: &[T], offset: usize, len: usize) -> Vec<T> {
        pattern
            .iter()
            .copied()
            .cycle()
            .skip(offset)
            .take(len)
            .collect()
    }

    let pattern = [-8, -3, -1, 0, 1, 3, 7, 6, 5, 4, 2, -2, -4, -6, -7, -5];
    for batch in 1..=4_u32 {
        for rows in 1..=5_u32 {
            for cols in [1_u32, 7, 8, 9, 16, 17, 31, 32, 33] {
                let width = cols as usize;

                // Each row and each batch entry starts at a different phase
                // of the shared pattern.
                let weights: Vec<_> = (0..rows as usize)
                    .map(|row| cycled(&pattern, row * 3, width))
                    .collect();
                let activations: Vec<_> = (0..batch as usize)
                    .map(|batch_index| cycled(&pattern, batch_index * 5 + 1, width))
                    .collect();
                let row_scales: Vec<_> = (0..rows)
                    .map(|row| 0.125_f32 + row as f32 * 0.0625)
                    .collect();
                let batch_scales: Vec<_> = (0..batch)
                    .map(|batch_index| 0.25_f32 + batch_index as f32 * 0.03125)
                    .collect();

                let packed_weights = pack_i4_matrix_rows(&weights);
                let packed_activations = pack_i4_matrix_rows(&activations);
                let actual = i4x8_batched_matmul_f32_scaled_cpu(
                    &packed_weights,
                    &packed_activations,
                    &row_scales,
                    &batch_scales,
                    batch,
                    rows,
                    cols,
                );

                // Batch-major reference: dot product accumulated in f32, then
                // scaled by the row scale and the batch scale, in that order.
                let mut expected = Vec::with_capacity(activations.len() * weights.len());
                for (activation, &batch_scale) in activations.iter().zip(&batch_scales) {
                    for (row, &row_scale) in weights.iter().zip(&row_scales) {
                        let mut dot = 0.0_f32;
                        for (&w, &x) in row.iter().zip(activation) {
                            dot += w as f32 * x as f32;
                        }
                        expected.push(dot * row_scale * batch_scale);
                    }
                }

                let actual_bits: Vec<_> = actual.iter().map(|value| value.to_bits()).collect();
                let expected_bits: Vec<_> =
                    expected.iter().map(|value| value.to_bits()).collect();
                assert_eq!(
                    actual_bits, expected_bits,
                    "batch={batch} rows={rows} cols={cols}"
                );
            }
        }
    }
}

/// Runs a table of (batch, rows, cols) shapes and checks that
/// `i4x8_batched_matmul_top1_f32_scaled_cpu` reports, for every batch entry,
/// the same winning row index and score as an argmax over the full output of
/// `i4x8_batched_matmul_f32_scaled_cpu`.
///
/// The argmax orders scores with `f32::total_cmp`; when two rows compare
/// equal, the later row is the expected winner.
#[test]
fn generated_batched_matmul_top1_matches_full_matmul_across_pack_boundaries() {
    // Returns `len` values read cyclically from `pattern`, starting `offset`
    // elements in.
    fn cycled<T: Copy>(pattern: &[T], offset: usize, len: usize) -> Vec<T> {
        pattern
            .iter()
            .copied()
            .cycle()
            .skip(offset)
            .take(len)
            .collect()
    }

    let weight_pattern = [-8, -3, -1, 0, 1, 3, 7, 6, 5, 4, 2, -2, -4, -6, -7, -5];
    let activation_pattern = [7, 5, 3, 1, -1, -3, -5, -7, 6, 4, 2, 0, -2, -4, -6, -8];
    let shapes = [
        (1_u32, 1_u32, 1_u32),
        (2, 2, 7),
        (3, 3, 8),
        (4, 4, 9),
        (5, 5, 17),
        (6, 6, 33),
        (3, 7, 65),
    ];

    for (batch, rows, cols) in shapes {
        let batch_count = batch as usize;
        let row_count = rows as usize;
        let width = cols as usize;

        let weights: Vec<_> = (0..row_count)
            .map(|row| cycled(&weight_pattern, row * 5, width))
            .collect();
        let activations: Vec<_> = (0..batch_count)
            .map(|batch_index| cycled(&activation_pattern, batch_index * 7, width))
            .collect();
        let row_scales: Vec<_> = (0..rows)
            .map(|row| 0.125_f32 + row as f32 * 0.0625)
            .collect();
        let batch_scales: Vec<_> = (0..batch)
            .map(|batch_index| 0.25_f32 + batch_index as f32 * 0.03125)
            .collect();

        let packed_weights = pack_i4_matrix_rows(&weights);
        let packed_activations = pack_i4_matrix_rows(&activations);
        let logits = i4x8_batched_matmul_f32_scaled_cpu(
            &packed_weights,
            &packed_activations,
            &row_scales,
            &batch_scales,
            batch,
            rows,
            cols,
        );
        let (scores, indices) = i4x8_batched_matmul_top1_f32_scaled_cpu(
            &packed_weights,
            &packed_activations,
            &row_scales,
            &batch_scales,
            batch,
            rows,
            cols,
        );

        for batch_index in 0..batch_count {
            // `logits` is indexed batch-major: one run of `row_count` per batch.
            let base = batch_index * row_count;
            let (expected_index, expected_score) = (0..row_count)
                .map(|row| (row as u32, logits[base + row]))
                // Keep the candidate unless the current best is strictly
                // greater, so the last of several equal maxima wins.
                .reduce(|best, candidate| {
                    if candidate.1.total_cmp(&best.1).is_ge() {
                        candidate
                    } else {
                        best
                    }
                })
                .expect("Fix: top1 generated test requires at least one row.");
            assert_eq!(
                indices[batch_index], expected_index,
                "batch={batch} rows={rows} cols={cols} batch_index={batch_index}"
            );
            assert_eq!(
                scores[batch_index].to_bits(),
                expected_score.to_bits(),
                "batch={batch} rows={rows} cols={cols} batch_index={batch_index}"
            );
        }
    }
}