use half::f16;
use oxibonsai_core::tensor::{BlockQ1_0G128, QK1_0_G128};
use oxibonsai_kernels::dequant::dequant_1bit_g128;
use oxibonsai_kernels::gemm::gemm_1bit_g128;
use oxibonsai_kernels::gemv::gemv_1bit_g128;
use proptest::prelude::*;
/// Strategy yielding a block whose scale is reinterpreted from an arbitrary
/// `u16` bit pattern and whose 16 weight bytes are arbitrary.
///
/// Because the scale comes straight from raw bits, no attempt is made to keep
/// it finite or non-zero.
fn arb_block_any_scale() -> impl Strategy<Value = BlockQ1_0G128> {
    let scale_bits = any::<u16>();
    let weight_bytes = prop::array::uniform16(any::<u8>());
    (scale_bits, weight_bytes).prop_map(|(bits, qs)| {
        let d = f16::from_bits(bits);
        BlockQ1_0G128 { d, qs }
    })
}
/// Strategy yielding a block with a well-behaved scale and arbitrary weight bytes.
///
/// The scale is sampled as a normal `f32`, kept only when its magnitude lies
/// strictly between `1e-5` and `1e4`, and then converted to `f16`.
fn arb_block_finite() -> impl Strategy<Value = BlockQ1_0G128> {
    let scale = prop::num::f32::NORMAL.prop_filter("finite nonzero scale", |candidate| {
        let magnitude = candidate.abs();
        candidate.is_finite() && magnitude > 1e-5 && magnitude < 1e4
    });
    let weight_bytes = prop::array::uniform16(any::<u8>());
    (scale, weight_bytes).prop_map(|(scale, qs)| {
        let d = f16::from_f32(scale);
        BlockQ1_0G128 { d, qs }
    })
}
/// Strategy yielding a vector of exactly `count` blocks, each drawn from
/// `arb_block_finite`.
// Not referenced by the tests visible in this file, hence the lint allowance.
#[allow(dead_code)]
fn arb_blocks_finite(count: usize) -> impl Strategy<Value = Vec<BlockQ1_0G128>> {
    // An inclusive range with equal bounds pins the vector length to `count`.
    let exact_len = count..=count;
    let element = arb_block_finite();
    prop::collection::vec(element, exact_len)
}
/// Strategy yielding an input vector of exactly `len` values.
///
/// Each element is a normal `f32` that is kept only when it is finite and its
/// magnitude is below `1e6`.
fn arb_input(len: usize) -> impl Strategy<Value = Vec<f32>> {
    let element = prop::num::f32::NORMAL.prop_filter("finite input", |candidate| {
        candidate.is_finite() && candidate.abs() < 1e6
    });
    // An inclusive range with equal bounds pins the vector length to `len`.
    let exact_len = len..=len;
    prop::collection::vec(element, exact_len)
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(512))]

    // Dequantizing a single block with an arbitrary scale bit pattern only has
    // to return; the Ok/Err outcome is deliberately ignored.
    #[test]
    fn prop_test_dequant_never_panics(block in arb_block_any_scale()) {
        let single = [block];
        let mut dequantized = vec![0.0f32; QK1_0_G128];
        let _ = dequant_1bit_g128(&single, &mut dequantized);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(128))]

    // Runs GEMV with one block per row and an all-ones input of length
    // QK1_0_G128; when the kernel reports success, the output buffer must
    // still hold one value per row.
    #[test]
    fn prop_test_gemv_output_shape(
        n_rows in 1usize..=8,
        blocks in prop::collection::vec(arb_block_finite(), 1..=8),
    ) {
        let k = QK1_0_G128;
        // Never ask for more rows than there are generated blocks.
        let rows = std::cmp::min(n_rows, blocks.len());
        let weight_blocks: Vec<_> = blocks.iter().take(rows).cloned().collect();

        let input = vec![1.0f32; k];
        let mut output = vec![0.0f32; rows];

        let outcome = gemv_1bit_g128(&weight_blocks, &input, &mut output, rows, k);
        // A kernel-reported error is tolerated; only a successful call is checked.
        if outcome.is_err() {
            return Ok(());
        }
        prop_assert_eq!(output.len(), rows,
            "output length must equal n_rows");
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(128))]

    // Runs GEMM with `m` all-ones input rows of length QK1_0_G128 against
    // `n_rows` weight rows (one block per row). When the kernel reports
    // success, the output buffer must still hold `m * n_rows` values.
    //
    // Inputs:
    //   m      - number of input rows (1..=4)
    //   n_rows - requested number of weight rows (1..=4)
    //   blocks - pool of 1..=16 finite-scale blocks to take weight rows from
    #[test]
    fn prop_test_gemm_output_shape(
        m in 1usize..=4,
        n_rows in 1usize..=4,
        blocks in prop::collection::vec(arb_block_finite(), 1..=16),
    ) {
        let k = QK1_0_G128;
        // Clamp the row count to the number of generated blocks instead of
        // returning early: an early `return Ok(())` counts as a passing case
        // without ever calling the kernel. This mirrors the GEMV shape test.
        // Note that a failure report shows the unclamped `n_rows` input.
        let n_rows = n_rows.min(blocks.len());
        let weight_blocks = blocks[..n_rows].to_vec();

        let input = vec![1.0f32; m * k];
        let mut output = vec![0.0f32; m * n_rows];

        let result = gemm_1bit_g128(&weight_blocks, &input, &mut output, m, n_rows, k);
        // A kernel-reported error is tolerated; only a successful call is checked.
        if result.is_ok() {
            prop_assert_eq!(output.len(), m * n_rows,
                "GEMM output length must be m * n_rows");
        }
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(256))]

    // Every dequantized value of a finite-scale block must be finite and have
    // a magnitude within 2% (plus f32::EPSILON) of the block's scale magnitude.
    #[test]
    fn prop_test_dequant_values_bounded(block in arb_block_finite()) {
        let d = block.d.to_f32();
        // Nothing meaningful to check for a zero or non-finite scale.
        if d == 0.0 || !d.is_finite() {
            return Ok(());
        }

        let scale = d.abs();
        let tol = scale * 0.02 + f32::EPSILON;

        let mut output = vec![0.0f32; QK1_0_G128];
        let outcome = dequant_1bit_g128(&[block], &mut output);
        // A kernel-reported error is tolerated; only successful output is checked.
        if outcome.is_err() {
            return Ok(());
        }

        for v in output.iter().copied() {
            prop_assert!(
                v.is_finite(),
                "dequant output must be finite, got {v}"
            );
            let deviation = (v.abs() - scale).abs();
            prop_assert!(
                deviation <= tol,
                "dequant value {v} not within ±{scale} (tol={tol})"
            );
        }
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(128))]

    // Checks that scaling the input by `alpha` scales the single-row GEMV
    // result by `alpha`, within 5% of the expected magnitude plus 1.0.
    #[test]
    fn prop_test_gemv_linearity(
        block in arb_block_finite(),
        alpha in prop::num::f32::NORMAL.prop_filter("finite nonzero alpha", |v| {
            v.is_finite() && v.abs() > 0.01 && v.abs() < 100.0
        }),
        input in arb_input(QK1_0_G128),
    ) {
        let blocks = vec![block];
        let scaled_input: Vec<f32> = input.iter().copied().map(|x| x * alpha).collect();

        let mut out_base = vec![0.0f32; 1];
        let r1 = gemv_1bit_g128(&blocks, &input, &mut out_base, 1, QK1_0_G128);

        let mut out_scaled = vec![0.0f32; 1];
        let r2 = gemv_1bit_g128(&blocks, &scaled_input, &mut out_scaled, 1, QK1_0_G128);

        // Both calls are always made; the comparison only runs if both succeeded.
        if r1.is_err() || r2.is_err() {
            return Ok(());
        }

        let expected = out_base[0] * alpha;
        let actual = out_scaled[0];
        let tol = expected.abs() * 0.05 + 1.0;
        let difference = (expected - actual).abs();
        prop_assert!(
            difference < tol,
            "linearity failed: α·gemv(x)={expected:.4} vs gemv(α·x)={actual:.4}, α={alpha}"
        );
    }
}
/// Calls GEMV with zero rows, no weight blocks and an empty output buffer.
///
/// The kernel may accept or reject this shape; if it reports success, the
/// output must still be empty.
#[test]
fn prop_test_kernel_with_zero_rows() {
    let no_blocks: &[BlockQ1_0G128] = &[];
    let input = vec![0.0f32; QK1_0_G128];
    let mut output: Vec<f32> = Vec::new();

    let outcome = gemv_1bit_g128(no_blocks, &input, &mut output, 0, QK1_0_G128);
    if outcome.is_ok() {
        assert!(output.is_empty(), "n_rows=0 should yield empty output");
    }
}
/// Single-row GEMV with scale 1.0, every weight byte set to `0xFF`, and an
/// all-ones input of length `QK1_0_G128`.
///
/// The call must succeed and the single output must be within 1.0 of 128.0.
#[test]
fn prop_test_kernel_with_single_row() {
    let weights = [BlockQ1_0G128 {
        d: f16::from_f32(1.0),
        qs: [0xFF; 16],
    }];
    let input = vec![1.0f32; QK1_0_G128];
    let mut output = vec![0.0f32; 1];

    gemv_1bit_g128(&weights, &input, &mut output, 1, QK1_0_G128)
        .expect("single-row GEMV should succeed");

    let got = output[0];
    let error = (got - 128.0).abs();
    assert!(
        error < 1.0,
        "single-row GEMV with all-one weights: expected 128.0, got {}",
        got
    );
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(32))]

    // Dequantizes between 1 and 1024 identical blocks (scale 1.0, weight bytes
    // 0xAA) into a buffer of block_count * QK1_0_G128 floats. The Ok/Err
    // outcome is ignored; the call only has to return.
    #[test]
    fn prop_test_large_block_count(
        block_count in 1usize..=1024,
    ) {
        let make_block = || BlockQ1_0G128 {
            d: f16::from_f32(1.0),
            qs: [0xAAu8; 16],
        };
        let blocks: Vec<BlockQ1_0G128> = std::iter::repeat_with(make_block)
            .take(block_count)
            .collect();

        let total_values = block_count * QK1_0_G128;
        let mut output = vec![0.0f32; total_values];
        let _ = dequant_1bit_g128(&blocks, &mut output);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(256))]

    // Single-row GEMV over a block with an arbitrary scale bit pattern only
    // has to return; the Ok/Err outcome is deliberately ignored.
    #[test]
    fn prop_test_gemv_never_panics(block in arb_block_any_scale()) {
        let weights = [block];
        let ones = vec![1.0f32; QK1_0_G128];
        let mut result_buf = vec![0.0f32; 1];
        let _ = gemv_1bit_g128(&weights, &ones, &mut result_buf, 1, QK1_0_G128);
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(256))]

    // 1x1 GEMM over a block with an arbitrary scale bit pattern only has to
    // return; the Ok/Err outcome is deliberately ignored.
    #[test]
    fn prop_test_gemm_never_panics(block in arb_block_any_scale()) {
        let weights = [block];
        let ones = vec![1.0f32; QK1_0_G128];
        let mut result_buf = vec![0.0f32; 1];
        let _ = gemm_1bit_g128(&weights, &ones, &mut result_buf, 1, 1, QK1_0_G128);
    }
}