fib-quant 0.1.0-alpha.1

Experimental Rust implementation of the FibQuant radial-angular vector quantization core
Documentation
#![cfg(feature = "kv")]

use fib_quant::kv::{
    decode_kv_pages, encode_kv_tensor, KvAttentionKind, KvAxisPolicyV1, KvCacheLayoutV1,
    KvCompressionProfileV1, KvDType, KvPageGeometryV1, KvRole, KvRopeState, KvTensorShapeV1,
};
use fib_quant::{FibQuantProfileV1, FibQuantizer};

/// Builds the shared inputs for the KV encode/decode tests in this file.
///
/// Returns, in order: the tensor shape, the canonical cache layout derived
/// from that shape, a compression profile labelled `"encode-decode"` that
/// uses the supplied `axis` policy, and a deterministic `Vec<f32>` holding
/// one value per element reported by `shape.element_count()`.
///
/// # Panics
///
/// Panics if any of the `unwrap`ped calls (layout, FibQuant profile,
/// quantizer, compression profile, element count) fails.
fn fixture(
    axis: KvAxisPolicyV1,
) -> (
    KvTensorShapeV1,
    KvCacheLayoutV1,
    KvCompressionProfileV1,
    Vec<f32>,
) {
    // NOTE(review): the meaning of the six integer arguments is defined by
    // `KvTensorShapeV1::new`, which is not visible here — confirm against it.
    let shape = KvTensorShapeV1::new(
        KvRole::Value,
        KvAttentionKind::Mha,
        1,
        1,
        1,
        1,
        4,
        8,
        KvDType::F32,
        KvRopeState::NotApplicable,
    );
    let layout = KvCacheLayoutV1::canonical(&shape).unwrap();

    // Start from `paper_default`, then override the training knobs with
    // small values. NOTE(review): presumably to keep the test fast — confirm.
    let mut quant_profile = FibQuantProfileV1::paper_default(8, 2, 8, 13).unwrap();
    quant_profile.lloyd_iterations = 1;
    quant_profile.lloyd_restarts = 1;
    quant_profile.training_samples = 64;

    // The quantizer receives a clone because the profile itself is moved
    // into the compression profile below.
    let quantizer = FibQuantizer::new(quant_profile.clone()).unwrap();
    let codebook_digest = quantizer.codebook().codebook_digest.clone();
    let page_geometry = KvPageGeometryV1::new(2, 8, 64);

    let profile = KvCompressionProfileV1::from_parts(
        "encode-decode",
        &shape,
        quant_profile,
        codebook_digest,
        axis,
        page_geometry,
    )
    .unwrap();

    // Deterministic, non-constant input: a shifted sine sampled per element index.
    let element_count = shape.element_count().unwrap();
    let values: Vec<f32> = (0..element_count)
        .map(|idx| ((idx as f32 + 1.0) * 0.07).sin() + 0.25)
        .collect();

    (shape, layout, profile, values)
}

/// Encodes the fixture tensor under the per-token axis policy, decodes it,
/// and checks the reported shape, the encode/decode receipts, and that the
/// decoded buffer has the same length as the input.
#[test]
fn per_token_encode_decode_preserves_shape_and_receipts() {
    let (shape, layout, profile, input) = fixture(KvAxisPolicyV1::PerToken);
    // Keep a copy for comparison; the original shape is consumed by the encoder.
    let expected_shape = shape.clone();

    let encoded = encode_kv_tensor(shape, layout, profile, &input).unwrap();
    let page_count = encoded.pages.len();
    assert_eq!(encoded.shape, expected_shape);
    assert_eq!(encoded.receipt.encoded_pages, 2);
    assert!(encoded.receipt.compressed_blocks > 0);
    // One digest is expected per emitted page.
    assert_eq!(encoded.receipt.page_digests.len(), page_count);

    let decoded = decode_kv_pages(&encoded).unwrap();
    assert_eq!(decoded.values.len(), input.len());
    assert_eq!(decoded.receipt.decoded_pages, page_count as u32);
}

/// Encodes the fixture tensor under the per-channel axis policy and checks
/// that the receipt reports zero compressed blocks and four raw-fallback
/// blocks, and that decoding returns exactly the input values.
#[test]
fn unsupported_axis_uses_raw_fallback() {
    let (shape, layout, profile, input) = fixture(KvAxisPolicyV1::PerChannel);
    let encoded = encode_kv_tensor(shape, layout, profile, &input).unwrap();

    let receipt = &encoded.receipt;
    assert_eq!(receipt.compressed_blocks, 0);
    assert_eq!(receipt.raw_fallback_blocks, 4);

    // Exact equality is asserted here, unlike the per-token test.
    let decoded = decode_kv_pages(&encoded).unwrap();
    assert_eq!(decoded.values, input);
}