//! trueno 0.17.4
//!
//! High-performance SIMD compute library with GPU support for matrix operations.
//!
//! Coverage tests for `matmul_q6k_f32_parallel` — the threaded dispatch path
//! triggered when `out_dim * in_dim >= 8_000_000`.
//!
//! These tests exercise:
//! - The parallel path entry via `matmul_q6k_f32_dispatch`
//! - Multi-thread chunk splitting with even and uneven out_dim
//! - Parity between parallel and scalar results
//! - Edge cases: single output row, prime out_dim, just-at-threshold

use super::super::*;

/// Helper: build deterministic Q6K row-major test data.
///
/// Each super-block has (210 bytes):
///   ql (128 bytes) | qh (64 bytes) | scales (16 bytes) | d (2 bytes, f16)
fn build_q6k_test_data(out_dim: usize, in_dim: usize) -> Vec<u8> {
    // Ceiling division: partial trailing block still occupies a full super-block.
    let blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let row_bytes = blocks_per_row * SUPER_BLOCK_BYTES;
    let mut data = vec![0u8; out_dim * row_bytes];

    for row in 0..out_dim {
        for sb in 0..blocks_per_row {
            let base = row * row_bytes + sb * SUPER_BLOCK_BYTES;
            let block = &mut data[base..base + SUPER_BLOCK_BYTES];

            // ql: 128 bytes at block offset 0, varying with row/block/index.
            for (i, b) in block[..128].iter_mut().enumerate() {
                *b = ((row + sb + i + 1) % 256) as u8;
            }
            // qh: 64 bytes at block offset 128, values restricted to 0..=3.
            for (i, b) in block[128..192].iter_mut().enumerate() {
                *b = ((row ^ sb ^ i) % 4) as u8;
            }
            // scales: 16 bytes at block offset 192, kept nonzero (1..=64).
            for (i, b) in block[192..208].iter_mut().enumerate() {
                *b = ((row + sb + i) % 64 + 1) as u8;
            }
            // d: little-endian f16 0.5 (bit pattern 0x3800) at block offset 208.
            block[208..210].copy_from_slice(&0x3800u16.to_le_bytes());
        }
    }
    data
}

/// Helper: build a deterministic input vector.
fn build_input(in_dim: usize) -> Vec<f32> {
    let mut v = Vec::with_capacity(in_dim);
    for i in 0..in_dim {
        // Slowly-varying sinusoid keeps values bounded and reproducible.
        v.push((i as f32 * 0.00137).sin());
    }
    v
}

// ============================================================================
// Parallel dispatch path tests
// ============================================================================

/// Core test: parallel dispatch produces results matching scalar.
///
/// Uses out_dim=4096, in_dim=2048 => total_work = 8,388,608 (>= 8M threshold).
/// This directly exercises the `matmul_q6k_f32_parallel` function.
#[test]
fn test_q6k_parallel_dispatch_matches_scalar() {
    let out_dim = 4096;
    let in_dim = 2048; // 8 super-blocks per row
    assert!(out_dim * in_dim >= 8_000_000, "Must trigger parallel path");
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);

    let expected = matmul_q6k_f32_scalar(&weights, &x, out_dim, in_dim);
    let actual = matmul_q6k_f32_dispatch(&weights, &x, out_dim, in_dim);

    assert_eq!(expected.len(), out_dim);
    assert_eq!(actual.len(), out_dim);

    // Sample every 64th row for speed, and always include the final row.
    for i in (0..out_dim).step_by(64).chain(std::iter::once(out_dim - 1)) {
        let delta = (expected[i] - actual[i]).abs();
        let bound = expected[i].abs() * 1e-4 + 1e-4;
        assert!(
            delta < bound,
            "Row {}: scalar={}, dispatch={}, diff={}",
            i,
            expected[i],
            actual[i],
            delta
        );
    }
}

/// Uneven chunk splitting: out_dim not divisible by typical thread counts.
///
/// Uses a prime out_dim (4099) so no thread count evenly divides it.
/// The last thread's chunk will be smaller, exercising remainder handling.
#[test]
fn test_q6k_parallel_dispatch_prime_outdim() {
    let out_dim = 4099; // prime => no thread count divides it evenly
    let in_dim = 2048;
    assert!(out_dim * in_dim >= 8_000_000);
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);

    let expected = matmul_q6k_f32_scalar(&weights, &x, out_dim, in_dim);
    let actual = matmul_q6k_f32_dispatch(&weights, &x, out_dim, in_dim);

    assert_eq!(actual.len(), out_dim);

    // First, middle, and last rows cover both full and remainder chunks.
    for &i in &[0, out_dim / 2, out_dim - 1] {
        let delta = (expected[i] - actual[i]).abs();
        let bound = expected[i].abs() * 1e-4 + 1e-4;
        assert!(
            delta < bound,
            "Row {}: scalar={}, dispatch={}, diff={}",
            i,
            expected[i],
            actual[i],
            delta
        );
    }
}

/// Small out_dim but very large in_dim to trigger parallel path.
///
/// out_dim=2, in_dim=4194304 (16384 super-blocks) => total_work = 8,388,608.
/// With only 2 output rows, each thread gets at most 1 row.
#[test]
fn test_q6k_parallel_dispatch_few_rows_large_indim() {
    let out_dim = 2;
    let in_dim = 4_194_304; // 16384 super-blocks
    assert!(out_dim * in_dim >= 8_000_000);
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);

    let expected = matmul_q6k_f32_scalar(&weights, &x, out_dim, in_dim);
    let actual = matmul_q6k_f32_dispatch(&weights, &x, out_dim, in_dim);

    assert_eq!(actual.len(), out_dim);
    // Only two rows, so check both exhaustively.
    for i in 0..out_dim {
        let delta = (expected[i] - actual[i]).abs();
        let bound = expected[i].abs() * 1e-4 + 1e-4;
        assert!(
            delta < bound,
            "Row {}: scalar={}, dispatch={}, diff={}",
            i,
            expected[i],
            actual[i],
            delta
        );
    }
}

/// Just at threshold: total_work = 8_000_000 exactly.
///
/// out_dim=31250, in_dim=256 => 31250 * 256 = 8,000,000.
#[test]
fn test_q6k_parallel_dispatch_exact_threshold() {
    let out_dim = 31_250;
    let in_dim = 256; // 1 super-block per row
    assert_eq!(out_dim * in_dim, 8_000_000);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);

    let expected = matmul_q6k_f32_scalar(&weights, &x, out_dim, in_dim);
    let actual = matmul_q6k_f32_dispatch(&weights, &x, out_dim, in_dim);

    assert_eq!(actual.len(), out_dim);

    // Spot-check a handful of rows including first and last.
    for &i in &[0, 100, 10_000, 31_249] {
        let delta = (expected[i] - actual[i]).abs();
        let bound = expected[i].abs() * 1e-4 + 1e-4;
        assert!(
            delta < bound,
            "Row {}: scalar={}, dispatch={}, diff={}",
            i,
            expected[i],
            actual[i],
            delta
        );
    }
}

/// Just below threshold: total_work < 8M should NOT use parallel.
///
/// Verifies dispatch still produces correct results for the non-parallel path.
#[test]
fn test_q6k_dispatch_just_below_threshold() {
    let out_dim = 31_249;
    let in_dim = 256;
    assert!(out_dim * in_dim < 8_000_000);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);

    let expected = matmul_q6k_f32_scalar(&weights, &x, out_dim, in_dim);
    let actual = matmul_q6k_f32_dispatch(&weights, &x, out_dim, in_dim);

    assert_eq!(actual.len(), out_dim);

    // Spot-check rows spread across the output, including both ends.
    for &i in &[0, 100, 15_000, 31_248] {
        let delta = (expected[i] - actual[i]).abs();
        let bound = expected[i].abs() * 1e-4 + 1e-4;
        assert!(
            delta < bound,
            "Row {}: scalar={}, dispatch={}, diff={}",
            i,
            expected[i],
            actual[i],
            delta
        );
    }
}

/// Single output row with massive in_dim.
///
/// out_dim=1, in_dim=8M => total_work=8M. Only 1 chunk, assigned to 1 thread.
#[test]
fn test_q6k_parallel_dispatch_single_row() {
    let out_dim = 1;
    let in_dim = 8_388_608; // 32768 super-blocks
    assert!(out_dim * in_dim >= 8_000_000);
    assert_eq!(in_dim % SUPER_BLOCK_SIZE, 0);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);

    let expected = matmul_q6k_f32_scalar(&weights, &x, out_dim, in_dim);
    let actual = matmul_q6k_f32_dispatch(&weights, &x, out_dim, in_dim);

    assert_eq!(actual.len(), 1);
    let delta = (expected[0] - actual[0]).abs();
    let bound = expected[0].abs() * 1e-4 + 1e-4;
    assert!(delta < bound, "scalar={}, dispatch={}, diff={}", expected[0], actual[0], delta);
}

/// All-zero input: parallel path should produce all-zero output.
#[test]
fn test_q6k_parallel_dispatch_zero_input() {
    let out_dim = 4096;
    let in_dim = 2048;
    assert!(out_dim * in_dim >= 8_000_000);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let zeros = vec![0.0f32; in_dim];

    let result = matmul_q6k_f32_dispatch(&weights, &zeros, out_dim, in_dim);

    assert_eq!(result.len(), out_dim);
    // Q6K: sum += d * scale * q6 * input[idx], with input=0 => sum=0
    for (i, &val) in result.iter().enumerate() {
        assert_eq!(val, 0.0, "Row {}: expected 0.0 with zero input, got {}", i, val);
    }
}

/// Verify the public `matmul_q6k_f32` alias also routes through parallel for large inputs.
#[test]
fn test_q6k_public_api_parallel_route() {
    let out_dim = 4096;
    let in_dim = 2048;
    assert!(out_dim * in_dim >= 8_000_000);

    let weights = build_q6k_test_data(out_dim, in_dim);
    let x = build_input(in_dim);

    // matmul_q6k_f32 is the public alias for dispatch
    let via_public = matmul_q6k_f32(&weights, &x, out_dim, in_dim);
    let expected = matmul_q6k_f32_scalar(&weights, &x, out_dim, in_dim);

    assert_eq!(via_public.len(), out_dim);
    // Spot-check first, an interior, and the last row.
    for &i in &[0, 1000, 4095] {
        let delta = (expected[i] - via_public[i]).abs();
        let bound = expected[i].abs() * 1e-4 + 1e-4;
        assert!(
            delta < bound,
            "Row {}: scalar={}, public_api={}, diff={}",
            i,
            expected[i],
            via_public[i],
            delta
        );
    }
}
}