use half::f16;
use crate::codebook::get_codebook;
use crate::error::{check_values_match, require, Result, TurboQuantError};
use crate::packed::{PackedBlock, TurboQuantConfig};
use crate::qjl::{
estimate_inner_product_with_codebook, precompute_query_projections, quantize_with_qjl,
EstimationContext, QjlBlock,
};
use crate::quantize::{dequantize_into_with_codebook, DequantScratch};
use crate::rotation::generate_sign_pattern;
const BYTES_PER_F16: usize = 2;
const KV_PAIR_COUNT: usize = 2;
const RESIDUAL_NORM_BYTES: usize = 2;
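/// Per-layer cache of QJL-quantized attention keys and values.
///
/// Each pushed vector is quantized into a [`QjlBlock`] (a packed polar block
/// plus QJL sign bits and an f16 residual norm) and stored in one `Vec` per
/// layer. The rotation seed lives in `config`; `qjl_seed` drives the sign
/// projections shared by all entries, so a query only needs to be projected
/// once per lookup.
///
/// # Example
///
/// A minimal sketch of the intended flow (marked `ignore` because the
/// enclosing crate paths and the input vectors are assumed):
///
/// ```ignore
/// let config = TurboQuantConfig::new(3, 64)?.with_seed(42);
/// let mut cache = QuantizedKVCache::new(config, 4, 12345);
/// cache.push(0, &key, &value)?;
/// let scores = cache.attention_scores(0, &query)?; // ~ <query, key_i>
/// let out = cache.weighted_values(0, &weights)?; // sum_i w_i * v_i
/// ```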
pub struct QuantizedKVCache {
config: TurboQuantConfig,
qjl_seed: u64,
keys: Vec<Vec<QjlBlock>>,
values: Vec<Vec<QjlBlock>>,
}
fn validate_layer(index: usize, num_layers: usize) -> bool {
index < num_layers
}
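/// `result[i] += values[i] * weight`, over the shorter of the two slices.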
fn accumulate_weighted(result: &mut [f32], values: &[f32], weight: f32) {
for (r, &v) in result.iter_mut().zip(values.iter()) {
*r += v * weight;
}
}
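/// Stored size of one block: packed polar data, QJL sign bytes, and the
/// f16 residual norm.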
fn qjl_block_size_bytes(block: &QjlBlock) -> usize {
block.polar_block.size_bytes() + block.qjl_signs.len() + RESIDUAL_NORM_BYTES
}
fn layer_len(layer_keys: &[QjlBlock]) -> usize {
layer_keys.len()
}
fn check_layer(index: usize, num_layers: usize) -> Result<()> {
require(
validate_layer(index, num_layers),
TurboQuantError::LayerOutOfRange { index, num_layers },
)
}
fn validate_batch_dims(keys: &[&[f32]], values: &[&[f32]], dim: usize) -> Result<()> {
for (i, key) in keys.iter().enumerate() {
check_values_match(key.len(), dim).map_err(|_| TurboQuantError::DimensionMismatch {
expected: dim,
actual: key.len(),
})?;
check_values_match(values[i].len(), dim).map_err(|_| {
TurboQuantError::DimensionMismatch {
expected: dim,
actual: values[i].len(),
}
})?;
}
Ok(())
}
fn validate_range(start: usize, end: usize, entry_count: usize) -> bool {
start <= end && end <= entry_count
}
fn check_range(start: usize, end: usize, entry_count: usize) -> Result<()> {
require(
validate_range(start, end, entry_count),
TurboQuantError::RangeOutOfBounds {
start,
end,
entry_count,
},
)
}
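/// Borrowed, flattened form of a run of [`QjlBlock`]s for import.
///
/// `packed_bytes` carries one `bytes_per_block` chunk of packed polar
/// indices per block, `qjl_signs_flat` one `signs_per_block` chunk of sign
/// bytes, and `scales` / `residual_norms` one f16 bit pattern each per block.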
pub struct PackedImport<'a> {
pub layer: usize,
pub polar_bits: u8,
pub packed_bytes: &'a [u8],
pub scales: &'a [u16],
pub qjl_signs_flat: &'a [u8],
pub residual_norms: &'a [u16],
pub bytes_per_block: usize,
pub signs_per_block: usize,
pub is_keys: bool,
}
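/// Flattens a run of blocks for export: concatenated packed polar indices
/// plus each block's scale as an f16 bit pattern.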
fn collect_packed_data(blocks: &[QjlBlock]) -> (Vec<u8>, Vec<u16>) {
let packed_bytes: Vec<u8> = blocks
.iter()
.flat_map(|b| &b.polar_block.packed_indices)
.copied()
.collect();
let scales: Vec<u16> = blocks
.iter()
.map(|b| b.polar_block.scale.to_bits())
.collect();
(packed_bytes, scales)
}
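/// Rebuilds the `index`-th [`QjlBlock`] from the flat buffers in `import`,
/// copying its packed-byte and sign strides verbatim.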
fn reconstruct_block(import: &PackedImport<'_>, index: usize) -> QjlBlock {
let pb_start = index * import.bytes_per_block;
let pb_end = pb_start + import.bytes_per_block;
let polar_block = PackedBlock::from_raw(
import.polar_bits,
f16::from_bits(import.scales[index]),
import.packed_bytes[pb_start..pb_end].to_vec(),
);
let qs_start = index * import.signs_per_block;
let qs_end = qs_start + import.signs_per_block;
let qjl_signs = import.qjl_signs_flat[qs_start..qs_end].to_vec();
let residual_norm = f16::from_bits(import.residual_norms[index]);
QjlBlock::from_parts(polar_block, qjl_signs, residual_norm)
}
impl QuantizedKVCache {
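/// Creates an empty cache with `num_layers` layers that share `config` and
/// the QJL projection seed.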
pub fn new(config: TurboQuantConfig, num_layers: usize, qjl_seed: u64) -> Self {
let keys = (0..num_layers).map(|_| Vec::new()).collect();
let values = (0..num_layers).map(|_| Vec::new()).collect();
Self {
config,
qjl_seed,
keys,
values,
}
}
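/// Quantizes `key` and `value` with QJL and appends them to `layer`.
///
/// Errors if `layer` is out of range or either slice's length differs from
/// `config.dim`.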
pub fn push(&mut self, layer: usize, key: &[f32], value: &[f32]) -> Result<()> {
check_layer(layer, self.num_layers())?;
check_values_match(key.len(), self.config.dim)?;
check_values_match(value.len(), self.config.dim)?;
let key_block = quantize_with_qjl(&self.config, key, self.qjl_seed)?;
let value_block = quantize_with_qjl(&self.config, value, self.qjl_seed)?;
self.keys[layer].push(key_block);
self.values[layer].push(value_block);
Ok(())
}
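/// Quantizes and appends a batch of key/value pairs, reusing scratch
/// resources across the batch instead of reallocating per entry.
///
/// `keys` and `values` must have equal length and every slice must match
/// `config.dim`; the stored blocks match what repeated [`push`](Self::push)
/// calls would produce.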
pub fn push_batch(&mut self, layer: usize, keys: &[&[f32]], values: &[&[f32]]) -> Result<()> {
check_layer(layer, self.num_layers())?;
check_values_match(keys.len(), values.len())?;
validate_batch_dims(keys, values, self.config.dim)?;
self.quantize_and_store(layer, keys, values)
}
fn quantize_and_store(
&mut self,
layer: usize,
keys: &[&[f32]],
values: &[&[f32]],
) -> Result<()> {
use crate::qjl::{quantize_with_qjl_resources, QjlBatchResources};
let mut res = QjlBatchResources::new(&self.config)?;
self.keys[layer].reserve(keys.len());
self.values[layer].reserve(values.len());
for (key, value) in keys.iter().zip(values.iter()) {
let key_block = quantize_with_qjl_resources(key, self.qjl_seed, &mut res)?;
let value_block = quantize_with_qjl_resources(value, self.qjl_seed, &mut res)?;
self.keys[layer].push(key_block);
self.values[layer].push(value_block);
}
Ok(())
}
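/// Estimates the inner product between `query` and every cached key in
/// `layer`, returned in insertion order.
///
/// The query's QJL projections are precomputed once and reused across all
/// key blocks; an empty layer yields an empty vector.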
pub fn attention_scores(&self, layer: usize, query: &[f32]) -> Result<Vec<f32>> {
check_layer(layer, self.num_layers())?;
check_values_match(query.len(), self.config.dim)?;
let r_query = precompute_query_projections(query, self.config.dim, self.qjl_seed);
let keys = &self.keys[layer];
if keys.is_empty() {
return Ok(Vec::new());
}
let polar_bits = keys[0].polar_block.bits;
let polar_config = TurboQuantConfig::new(polar_bits, self.config.dim)?
.with_seed(self.config.rotation_seed);
let codebook = get_codebook(polar_bits, self.config.dim)?;
let sign_pattern = generate_sign_pattern(self.config.dim, self.config.rotation_seed);
let mut ctx = EstimationContext {
polar_config: &polar_config,
codebook: &codebook,
sign_pattern: &sign_pattern,
dim: self.config.dim,
scratch: DequantScratch::new(self.config.dim),
};
let mut scores = Vec::with_capacity(keys.len());
for key_block in keys {
let score =
estimate_inner_product_with_codebook(query, &r_query, key_block, &mut ctx)?;
scores.push(score);
}
Ok(scores)
}
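/// Dequantizes each value in `layer` and returns the weighted sum
/// `sum_i weights[i] * v_i` as a vector of length `config.dim`.
///
/// `weights.len()` must equal the layer's entry count; an empty layer
/// yields the zero vector.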
pub fn weighted_values(&self, layer: usize, weights: &[f32]) -> Result<Vec<f32>> {
check_layer(layer, self.num_layers())?;
check_values_match(weights.len(), layer_len(&self.keys[layer]))?;
let dim = self.config.dim;
let mut result = vec![0.0_f32; dim];
let values = &self.values[layer];
if values.is_empty() {
return Ok(result);
}
let polar_bits = values[0].polar_block.bits;
let polar_config =
TurboQuantConfig::new(polar_bits, dim)?.with_seed(self.config.rotation_seed);
let codebook = get_codebook(polar_bits, dim)?;
let sign_pattern = generate_sign_pattern(dim, self.config.rotation_seed);
let mut scratch = DequantScratch::new(dim);
for (block, &w) in values.iter().zip(weights.iter()) {
dequantize_into_with_codebook(
&polar_config,
&block.polar_block,
&codebook,
&sign_pattern,
&mut scratch,
)?;
accumulate_weighted(&mut result, &scratch.values, w);
}
Ok(result)
}
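/// Total bytes held by the quantized key and value blocks across all
/// layers (packed polar data, QJL sign bytes, and f16 residual norms).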
pub fn memory_usage(&self) -> usize {
self.keys
.iter()
.chain(self.values.iter())
.flat_map(|layer| layer.iter())
.map(qjl_block_size_bytes)
.sum()
}
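/// Bytes an unquantized f16 cache with the same entries would occupy:
/// `entries * dim * BYTES_PER_F16 * KV_PAIR_COUNT`, counting both the key
/// and the value of each entry.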
pub fn fp16_equivalent_memory(&self) -> usize {
let total_entries: usize = self.keys.iter().map(|layer| layer.len()).sum();
total_entries * self.config.dim * BYTES_PER_F16 * KV_PAIR_COUNT
}
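/// Number of entries stored in `layer`.
///
/// # Panics
///
/// Panics if `layer >= self.num_layers()`; this accessor indexes directly
/// rather than returning a `Result`.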
pub fn entry_count(&self, layer: usize) -> usize {
layer_len(&self.keys[layer])
}
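/// The quantized key block at (`layer`, `index`), or `None` when either
/// index is out of range.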
pub fn key_block(&self, layer: usize, index: usize) -> Option<&QjlBlock> {
self.keys.get(layer).and_then(|blocks| blocks.get(index))
}
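/// Number of layers this cache was created with.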
pub fn num_layers(&self) -> usize {
self.keys.len()
}
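/// Dequantizes every key block in `layer` back to full `f32` vectors.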
pub fn dequantize_all_keys(&self, layer: usize) -> Result<Vec<Vec<f32>>> {
self.dequantize_all_blocks(layer, &self.keys)
}
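/// Dequantizes every value block in `layer` back to full `f32` vectors.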
pub fn dequantize_all_values(&self, layer: usize) -> Result<Vec<Vec<f32>>> {
self.dequantize_all_blocks(layer, &self.values)
}
fn dequantize_all_blocks(
&self,
layer: usize,
blocks_per_layer: &[Vec<QjlBlock>],
) -> Result<Vec<Vec<f32>>> {
check_layer(layer, self.num_layers())?;
self.dequantize_block_slice(&blocks_per_layer[layer])
}
fn dequantize_block_slice(&self, blocks: &[QjlBlock]) -> Result<Vec<Vec<f32>>> {
if blocks.is_empty() {
return Ok(Vec::new());
}
let dim = self.config.dim;
let polar_bits = blocks[0].polar_block.bits;
let polar_config =
TurboQuantConfig::new(polar_bits, dim)?.with_seed(self.config.rotation_seed);
let codebook = get_codebook(polar_bits, dim)?;
let sign_pattern = generate_sign_pattern(dim, self.config.rotation_seed);
let mut scratch = DequantScratch::new(dim);
let mut result = Vec::with_capacity(blocks.len());
for block in blocks {
dequantize_into_with_codebook(
&polar_config,
&block.polar_block,
&codebook,
&sign_pattern,
&mut scratch,
)?;
result.push(scratch.values.clone());
}
Ok(result)
}
fn dequantize_blocks_range(
&self,
layer: usize,
start: usize,
end: usize,
blocks_per_layer: &[Vec<QjlBlock>],
) -> Result<Vec<Vec<f32>>> {
check_layer(layer, self.num_layers())?;
let blocks = &blocks_per_layer[layer];
check_range(start, end, blocks.len())?;
self.dequantize_block_slice(&blocks[start..end])
}
pub fn dequantize_keys_range(
&self,
layer: usize,
start: usize,
end: usize,
) -> Result<Vec<Vec<f32>>> {
self.dequantize_blocks_range(layer, start, end, &self.keys)
}
pub fn dequantize_values_range(
&self,
layer: usize,
start: usize,
end: usize,
) -> Result<Vec<Vec<f32>>> {
self.dequantize_blocks_range(layer, start, end, &self.values)
}
fn select_blocks(&self, layer: usize, is_keys: bool) -> &[QjlBlock] {
if is_keys {
&self.keys[layer]
} else {
&self.values[layer]
}
}
fn select_blocks_mut(&mut self, layer: usize, is_keys: bool) -> &mut Vec<QjlBlock> {
if is_keys {
&mut self.keys[layer]
} else {
&mut self.values[layer]
}
}
pub fn config(&self) -> &TurboQuantConfig {
&self.config
}
pub fn qjl_seed(&self) -> u64 {
self.qjl_seed
}
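/// Exports blocks `[start, end)` of a layer's keys (`is_keys`) or values
/// as flat buffers: concatenated packed indices plus one f16 scale bit
/// pattern per block.
///
/// QJL signs and residual norms are not part of this payload; see
/// [`PackedImport`] for the buffers an import additionally needs.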
pub fn export_packed_range(
&self,
layer: usize,
start: usize,
end: usize,
is_keys: bool,
) -> Result<(Vec<u8>, Vec<u16>)> {
check_layer(layer, self.num_layers())?;
let blocks = self.select_blocks(layer, is_keys);
check_range(start, end, blocks.len())?;
Ok(collect_packed_data(&blocks[start..end]))
}
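/// Reconstructs blocks from `import` and appends them to the selected side
/// of `import.layer`; the block count is taken from `import.scales.len()`.
///
/// The flat buffers must cover at least that many strides, or the slicing
/// in `reconstruct_block` will panic.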
pub fn import_packed_range(&mut self, import: &PackedImport<'_>) -> Result<()> {
check_layer(import.layer, self.num_layers())?;
let count = import.scales.len();
let target = self.select_blocks_mut(import.layer, import.is_keys);
target.reserve(count);
for i in 0..count {
target.push(reconstruct_block(import, i));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::qjl::dot_product;
use crate::quantize::l2_norm;
use crate::test_utils::{pseudo_random_vec, LCG_MULTIPLIER};
const TEST_DIM: usize = 64;
const BITS_3: u8 = 3;
const TEST_ROTATION_SEED: u64 = 42;
const TEST_QJL_SEED: u64 = 12345;
const TEST_NUM_LAYERS: usize = 4;
const TEST_KEY_SEED: u64 = 11111;
const TEST_VALUE_SEED: u64 = 22222;
const TEST_QUERY_SEED: u64 = 33333;
const TEST_KEY_SEED_2: u64 = 44444;
const TEST_VALUE_SEED_2: u64 = 55555;
const TEST_KEY_SEED_3: u64 = 66666;
const TEST_VALUE_SEED_3: u64 = 77777;
const SCORE_RELATIVE_TOLERANCE: f32 = 0.5;
const WEIGHTED_RELATIVE_TOLERANCE: f32 = 1.0;
const MEMORY_TEST_ENTRIES: usize = 10;
const MIN_COMPRESSION_RATIO: f32 = 2.0;
const TEST_LAYER: usize = 0;
const TEST_LAYER_2: usize = 1;
const INVALID_LAYER: usize = 999;
const FLOAT_EPSILON: f32 = 1e-6;
const TEST_WEIGHT_A: f32 = 0.6;
const TEST_WEIGHT_B: f32 = 0.4;
const FP16_TEST_ENTRIES: usize = 5;
const SEED_OFFSET_BASE: u64 = 100;
fn test_config() -> TurboQuantConfig {
TurboQuantConfig::new(BITS_3, TEST_DIM)
.unwrap()
.with_seed(TEST_ROTATION_SEED)
}
#[test]
fn push_increases_len() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
assert_eq!(cache.entry_count(TEST_LAYER), 0);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let value = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
cache.push(TEST_LAYER, &key, &value).unwrap();
assert_eq!(cache.entry_count(TEST_LAYER), 1);
let key2 = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED_2);
let value2 = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED_2);
cache.push(TEST_LAYER, &key2, &value2).unwrap();
assert_eq!(cache.entry_count(TEST_LAYER), 2);
}
#[test]
fn attention_scores_approximate_dot_product() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let value = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
cache.push(TEST_LAYER, &key, &value).unwrap();
let query = pseudo_random_vec(TEST_DIM, TEST_QUERY_SEED);
let scores = cache.attention_scores(TEST_LAYER, &query).unwrap();
assert_eq!(scores.len(), 1);
let true_ip = dot_product(&query, &key);
let relative_error = (scores[0] - true_ip).abs() / true_ip.abs().max(1.0);
assert!(
relative_error < SCORE_RELATIVE_TOLERANCE,
"score relative error {relative_error} exceeds tolerance {SCORE_RELATIVE_TOLERANCE} \
(estimated={}, true={true_ip})",
scores[0]
);
}
#[test]
fn weighted_values_approximate_naive_sum() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key1 = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val1 = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
cache.push(TEST_LAYER, &key1, &val1).unwrap();
let key2 = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED_2);
let val2 = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED_2);
cache.push(TEST_LAYER, &key2, &val2).unwrap();
let weights = vec![TEST_WEIGHT_A, TEST_WEIGHT_B];
let result = cache.weighted_values(TEST_LAYER, &weights).unwrap();
let naive: Vec<f32> = (0..TEST_DIM)
.map(|i| val1[i] * TEST_WEIGHT_A + val2[i] * TEST_WEIGHT_B)
.collect();
let error_vec: Vec<f32> = result
.iter()
.zip(naive.iter())
.map(|(&r, &n)| r - n)
.collect();
let error_norm = l2_norm(&error_vec);
let naive_norm = l2_norm(&naive);
let relative_error = error_norm / naive_norm;
assert!(
relative_error < WEIGHTED_RELATIVE_TOLERANCE,
"weighted values relative error {relative_error} exceeds tolerance {WEIGHTED_RELATIVE_TOLERANCE}"
);
}
#[test]
fn memory_usage_shows_compression() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
for i in 0..MEMORY_TEST_ENTRIES {
let key = pseudo_random_vec(
TEST_DIM,
TEST_KEY_SEED.wrapping_add(i as u64 * SEED_OFFSET_BASE),
);
let val = pseudo_random_vec(
TEST_DIM,
TEST_VALUE_SEED.wrapping_add(i as u64 * SEED_OFFSET_BASE),
);
cache.push(TEST_LAYER, &key, &val).unwrap();
}
let quantized_bytes = cache.memory_usage();
let fp16_bytes = cache.fp16_equivalent_memory();
assert!(quantized_bytes > 0);
assert!(fp16_bytes > 0);
let ratio = fp16_bytes as f32 / quantized_bytes as f32;
assert!(
ratio > MIN_COMPRESSION_RATIO,
"compression ratio {ratio} is below minimum {MIN_COMPRESSION_RATIO}"
);
}
#[test]
fn fp16_equivalent_memory_correct_calculation() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
for i in 0..FP16_TEST_ENTRIES {
let key = pseudo_random_vec(
TEST_DIM,
TEST_KEY_SEED_3.wrapping_add(i as u64 * SEED_OFFSET_BASE),
);
let val = pseudo_random_vec(
TEST_DIM,
TEST_VALUE_SEED_3.wrapping_add(i as u64 * SEED_OFFSET_BASE),
);
cache.push(TEST_LAYER, &key, &val).unwrap();
}
let expected = FP16_TEST_ENTRIES * TEST_DIM * BYTES_PER_F16 * KV_PAIR_COUNT;
assert_eq!(cache.fp16_equivalent_memory(), expected);
}
#[test]
fn multi_layer_cache_independent() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key1 = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val1 = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
cache.push(TEST_LAYER, &key1, &val1).unwrap();
let key2 = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED_2);
let val2 = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED_2);
cache.push(TEST_LAYER_2, &key2, &val2).unwrap();
cache.push(TEST_LAYER_2, &key2, &val2).unwrap();
assert_eq!(cache.entry_count(TEST_LAYER), 1);
assert_eq!(cache.entry_count(TEST_LAYER_2), 2);
assert_eq!(cache.num_layers(), TEST_NUM_LAYERS);
}
#[test]
fn empty_cache_returns_empty_scores() {
let config = test_config();
let cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let query = pseudo_random_vec(TEST_DIM, TEST_QUERY_SEED);
let scores = cache.attention_scores(TEST_LAYER, &query).unwrap();
assert!(scores.is_empty());
}
#[test]
fn push_rejects_invalid_layer() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
let result = cache.push(INVALID_LAYER, &key, &val);
assert!(result.is_err());
}
#[test]
fn attention_scores_rejects_invalid_layer() {
let config = test_config();
let cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let query = pseudo_random_vec(TEST_DIM, TEST_QUERY_SEED);
let result = cache.attention_scores(INVALID_LAYER, &query);
assert!(result.is_err());
}
#[test]
fn weighted_values_rejects_invalid_layer() {
let config = test_config();
let cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let weights = vec![1.0_f32];
let result = cache.weighted_values(INVALID_LAYER, &weights);
assert!(result.is_err());
}
#[test]
fn push_rejects_wrong_key_dimension() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let wrong_key = vec![1.0_f32; TEST_DIM + 1];
let val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
let result = cache.push(TEST_LAYER, &wrong_key, &val);
assert!(result.is_err());
}
#[test]
fn push_rejects_wrong_value_dimension() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let wrong_val = vec![1.0_f32; TEST_DIM + 1];
let result = cache.push(TEST_LAYER, &key, &wrong_val);
assert!(result.is_err());
}
#[test]
fn validate_layer_in_bounds() {
assert!(validate_layer(0, TEST_NUM_LAYERS));
assert!(validate_layer(TEST_NUM_LAYERS - 1, TEST_NUM_LAYERS));
}
#[test]
fn validate_layer_out_of_bounds() {
assert!(!validate_layer(TEST_NUM_LAYERS, TEST_NUM_LAYERS));
assert!(!validate_layer(INVALID_LAYER, TEST_NUM_LAYERS));
}
#[test]
fn accumulate_weighted_basic() {
let mut result = vec![0.0_f32; TEST_DIM];
let values: Vec<f32> = (0..TEST_DIM).map(|i| i as f32).collect();
let weight = TEST_WEIGHT_A;
accumulate_weighted(&mut result, &values, weight);
for (i, &r) in result.iter().enumerate() {
let expected = i as f32 * weight;
assert!(
(r - expected).abs() < FLOAT_EPSILON,
"mismatch at index {i}: expected {expected}, got {r}"
);
}
}
const WV_ROUNDTRIP_DIM: usize = 128;
const WV_ROUNDTRIP_ENTRIES: usize = 50;
const WV_KEY_SEED_OFFSET: u64 = 90000;
const WV_VALUE_SEED_OFFSET: u64 = 91000;
const WV_UNIFORM_WEIGHT: f32 = 1.0 / WV_ROUNDTRIP_ENTRIES as f32;
const WV_MAX_RELATIVE_ERROR: f32 = 1.0;
const WV_ROTATION_SEED: u64 = 42;
const WV_QJL_SEED: u64 = 31415;
#[test]
fn weighted_values_uniform_roundtrip_quality() {
let config = TurboQuantConfig::new(BITS_3, WV_ROUNDTRIP_DIM)
.unwrap()
.with_seed(WV_ROTATION_SEED);
let mut cache = QuantizedKVCache::new(config, 1, WV_QJL_SEED);
let mut original_values: Vec<Vec<f32>> = Vec::with_capacity(WV_ROUNDTRIP_ENTRIES);
for i in 0..WV_ROUNDTRIP_ENTRIES {
let key_seed = (i as u64)
.wrapping_mul(LCG_MULTIPLIER)
.wrapping_add(WV_KEY_SEED_OFFSET);
let val_seed = (i as u64)
.wrapping_mul(LCG_MULTIPLIER)
.wrapping_add(WV_VALUE_SEED_OFFSET);
let key = pseudo_random_vec(WV_ROUNDTRIP_DIM, key_seed);
let val = pseudo_random_vec(WV_ROUNDTRIP_DIM, val_seed);
original_values.push(val.clone());
cache.push(0, &key, &val).unwrap();
}
let weights = vec![WV_UNIFORM_WEIGHT; WV_ROUNDTRIP_ENTRIES];
let result = cache.weighted_values(0, &weights).unwrap();
assert_eq!(result.len(), WV_ROUNDTRIP_DIM);
let mut naive_avg = vec![0.0_f32; WV_ROUNDTRIP_DIM];
for val in &original_values {
for (j, &v) in val.iter().enumerate() {
naive_avg[j] += v * WV_UNIFORM_WEIGHT;
}
}
let error_vec: Vec<f32> = result
.iter()
.zip(naive_avg.iter())
.map(|(&r, &n)| r - n)
.collect();
let error_norm = l2_norm(&error_vec);
let naive_norm = l2_norm(&naive_avg);
const NORM_FLOOR: f32 = 1e-10;
let relative_error = error_norm / naive_norm.max(NORM_FLOOR);
eprintln!(
"Weighted values uniform roundtrip: relative_error={relative_error:.4}, \
error_norm={error_norm:.6}, naive_norm={naive_norm:.6}"
);
assert!(
relative_error < WV_MAX_RELATIVE_ERROR,
"Weighted values uniform roundtrip relative error {relative_error:.4} \
exceeds tolerance {WV_MAX_RELATIVE_ERROR}"
);
}
const RATIO_TEST_DIM: usize = 128;
const RATIO_TEST_ENTRIES: usize = 100;
const BITS_4: u8 = 4;
const TQ3_MIN_COMPRESSION_RATIO: f32 = 4.0;
const TQ4_MIN_COMPRESSION_RATIO: f32 = 3.0;
const RATIO_SEED_OFFSET: u64 = 7000;
#[test]
fn tq3_compression_ratio_meets_minimum() {
let config = TurboQuantConfig::new(BITS_3, RATIO_TEST_DIM)
.unwrap()
.with_seed(TEST_ROTATION_SEED);
let mut cache = QuantizedKVCache::new(config, 1, TEST_QJL_SEED);
for i in 0..RATIO_TEST_ENTRIES {
let key = pseudo_random_vec(
RATIO_TEST_DIM,
TEST_KEY_SEED.wrapping_add(i as u64 * RATIO_SEED_OFFSET),
);
let val = pseudo_random_vec(
RATIO_TEST_DIM,
TEST_VALUE_SEED.wrapping_add(i as u64 * RATIO_SEED_OFFSET),
);
cache.push(0, &key, &val).unwrap();
}
let quantized_bytes = cache.memory_usage();
let fp16_bytes = cache.fp16_equivalent_memory();
let ratio = fp16_bytes as f32 / quantized_bytes as f32;
assert!(
ratio >= TQ3_MIN_COMPRESSION_RATIO,
"TQ3 compression ratio {ratio:.2}x is below minimum {TQ3_MIN_COMPRESSION_RATIO}x \
(quantized={quantized_bytes} bytes, fp16={fp16_bytes} bytes)"
);
}
const RANGE_TEST_ENTRIES: usize = 5;
const RANGE_SEED_OFFSET: u64 = 500;
fn make_range_test_cache() -> QuantizedKVCache {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
for i in 0..RANGE_TEST_ENTRIES {
let key = pseudo_random_vec(
TEST_DIM,
TEST_KEY_SEED.wrapping_add(i as u64 * RANGE_SEED_OFFSET),
);
let val = pseudo_random_vec(
TEST_DIM,
TEST_VALUE_SEED.wrapping_add(i as u64 * RANGE_SEED_OFFSET),
);
cache.push(TEST_LAYER, &key, &val).unwrap();
}
cache
}
#[test]
fn dequantize_keys_range_full_matches_all() {
let cache = make_range_test_cache();
let all = cache.dequantize_all_keys(TEST_LAYER).unwrap();
let range_all = cache
.dequantize_keys_range(TEST_LAYER, 0, RANGE_TEST_ENTRIES)
.unwrap();
assert_eq!(all.len(), range_all.len());
for (a, r) in all.iter().zip(range_all.iter()) {
assert_eq!(a, r);
}
}
#[test]
fn dequantize_values_range_full_matches_all() {
let cache = make_range_test_cache();
let all = cache.dequantize_all_values(TEST_LAYER).unwrap();
let range_all = cache
.dequantize_values_range(TEST_LAYER, 0, RANGE_TEST_ENTRIES)
.unwrap();
assert_eq!(all.len(), range_all.len());
for (a, r) in all.iter().zip(range_all.iter()) {
assert_eq!(a, r);
}
}
#[test]
fn dequantize_keys_range_subset_matches_slice() {
let cache = make_range_test_cache();
let all = cache.dequantize_all_keys(TEST_LAYER).unwrap();
let start = 1;
let end = 3;
let range_subset = cache.dequantize_keys_range(TEST_LAYER, start, end).unwrap();
assert_eq!(range_subset.len(), end - start);
for (i, vec) in range_subset.iter().enumerate() {
assert_eq!(vec, &all[start + i]);
}
}
#[test]
fn dequantize_values_range_subset_matches_slice() {
let cache = make_range_test_cache();
let all = cache.dequantize_all_values(TEST_LAYER).unwrap();
let start = 2;
let end = 4;
let range_subset = cache
.dequantize_values_range(TEST_LAYER, start, end)
.unwrap();
assert_eq!(range_subset.len(), end - start);
for (i, vec) in range_subset.iter().enumerate() {
assert_eq!(vec, &all[start + i]);
}
}
#[test]
fn dequantize_keys_range_empty_returns_empty() {
let cache = make_range_test_cache();
let empty = cache.dequantize_keys_range(TEST_LAYER, 2, 2).unwrap();
assert!(empty.is_empty());
}
#[test]
fn dequantize_values_range_empty_returns_empty() {
let cache = make_range_test_cache();
let empty = cache.dequantize_values_range(TEST_LAYER, 0, 0).unwrap();
assert!(empty.is_empty());
}
#[test]
fn dequantize_keys_range_single_entry() {
let cache = make_range_test_cache();
let all = cache.dequantize_all_keys(TEST_LAYER).unwrap();
let single = cache.dequantize_keys_range(TEST_LAYER, 3, 4).unwrap();
assert_eq!(single.len(), 1);
assert_eq!(single[0], all[3]);
}
#[test]
fn dequantize_keys_range_rejects_invalid_layer() {
let cache = make_range_test_cache();
let result = cache.dequantize_keys_range(INVALID_LAYER, 0, 1);
assert!(result.is_err());
}
#[test]
fn dequantize_values_range_rejects_invalid_layer() {
let cache = make_range_test_cache();
let result = cache.dequantize_values_range(INVALID_LAYER, 0, 1);
assert!(result.is_err());
}
#[test]
fn dequantize_keys_range_rejects_start_greater_than_end() {
let cache = make_range_test_cache();
let result = cache.dequantize_keys_range(TEST_LAYER, 3, 1);
assert!(result.is_err());
}
#[test]
fn dequantize_values_range_rejects_end_beyond_entry_count() {
let cache = make_range_test_cache();
let result = cache.dequantize_values_range(TEST_LAYER, 0, RANGE_TEST_ENTRIES + 1);
assert!(result.is_err());
}
#[test]
fn validate_range_pure_operation() {
assert!(validate_range(0, 0, 0));
assert!(validate_range(0, 5, 5));
assert!(validate_range(2, 3, 5));
assert!(!validate_range(3, 2, 5));
assert!(!validate_range(0, 6, 5));
}
#[test]
fn tq4_compression_ratio_meets_minimum() {
let config = TurboQuantConfig::new(BITS_4, RATIO_TEST_DIM)
.unwrap()
.with_seed(TEST_ROTATION_SEED);
let mut cache = QuantizedKVCache::new(config, 1, TEST_QJL_SEED);
for i in 0..RATIO_TEST_ENTRIES {
let key = pseudo_random_vec(
RATIO_TEST_DIM,
TEST_KEY_SEED.wrapping_add(i as u64 * RATIO_SEED_OFFSET),
);
let val = pseudo_random_vec(
RATIO_TEST_DIM,
TEST_VALUE_SEED.wrapping_add(i as u64 * RATIO_SEED_OFFSET),
);
cache.push(0, &key, &val).unwrap();
}
let quantized_bytes = cache.memory_usage();
let fp16_bytes = cache.fp16_equivalent_memory();
let ratio = fp16_bytes as f32 / quantized_bytes as f32;
assert!(
ratio >= TQ4_MIN_COMPRESSION_RATIO,
"TQ4 compression ratio {ratio:.2}x is below minimum {TQ4_MIN_COMPRESSION_RATIO}x \
(quantized={quantized_bytes} bytes, fp16={fp16_bytes} bytes)"
);
}
const BATCH_TEST_COUNT: usize = 8;
const BATCH_SEED_OFFSET: u64 = 3000;
#[test]
fn push_batch_produces_same_results_as_individual_pushes() {
let config = test_config();
let mut key_vecs: Vec<Vec<f32>> = Vec::new();
let mut val_vecs: Vec<Vec<f32>> = Vec::new();
for i in 0..BATCH_TEST_COUNT {
key_vecs.push(pseudo_random_vec(
TEST_DIM,
TEST_KEY_SEED.wrapping_add(i as u64 * BATCH_SEED_OFFSET),
));
val_vecs.push(pseudo_random_vec(
TEST_DIM,
TEST_VALUE_SEED.wrapping_add(i as u64 * BATCH_SEED_OFFSET),
));
}
let mut cache_individual = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
for i in 0..BATCH_TEST_COUNT {
cache_individual
.push(TEST_LAYER, &key_vecs[i], &val_vecs[i])
.unwrap();
}
let mut cache_batch = QuantizedKVCache::new(test_config(), TEST_NUM_LAYERS, TEST_QJL_SEED);
let key_refs: Vec<&[f32]> = key_vecs.iter().map(|v| v.as_slice()).collect();
let val_refs: Vec<&[f32]> = val_vecs.iter().map(|v| v.as_slice()).collect();
cache_batch
.push_batch(TEST_LAYER, &key_refs, &val_refs)
.unwrap();
assert_eq!(
cache_individual.entry_count(TEST_LAYER),
cache_batch.entry_count(TEST_LAYER),
);
let keys_ind = cache_individual.dequantize_all_keys(TEST_LAYER).unwrap();
let keys_bat = cache_batch.dequantize_all_keys(TEST_LAYER).unwrap();
assert_eq!(keys_ind.len(), keys_bat.len());
for (a, b) in keys_ind.iter().zip(keys_bat.iter()) {
assert_eq!(a, b, "batch and individual key dequantizations differ");
}
let vals_ind = cache_individual.dequantize_all_values(TEST_LAYER).unwrap();
let vals_bat = cache_batch.dequantize_all_values(TEST_LAYER).unwrap();
assert_eq!(vals_ind.len(), vals_bat.len());
for (a, b) in vals_ind.iter().zip(vals_bat.iter()) {
assert_eq!(a, b, "batch and individual value dequantizations differ");
}
}
#[test]
fn push_batch_empty_is_noop() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let empty_keys: Vec<&[f32]> = Vec::new();
let empty_vals: Vec<&[f32]> = Vec::new();
cache
.push_batch(TEST_LAYER, &empty_keys, &empty_vals)
.unwrap();
assert_eq!(cache.entry_count(TEST_LAYER), 0);
}
#[test]
fn push_batch_rejects_invalid_layer() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
let keys: Vec<&[f32]> = vec![&key];
let vals: Vec<&[f32]> = vec![&val];
let result = cache.push_batch(INVALID_LAYER, &keys, &vals);
assert!(result.is_err());
}
#[test]
fn push_batch_rejects_wrong_dimension() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let wrong_key = vec![1.0_f32; TEST_DIM + 1];
let val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
let keys: Vec<&[f32]> = vec![wrong_key.as_slice()];
let vals: Vec<&[f32]> = vec![val.as_slice()];
let result = cache.push_batch(TEST_LAYER, &keys, &vals);
assert!(result.is_err());
}
#[test]
fn push_batch_rejects_mismatched_lengths() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val1 = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
let val2 = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED_2);
let keys: Vec<&[f32]> = vec![key.as_slice()];
let vals: Vec<&[f32]> = vec![val1.as_slice(), val2.as_slice()];
let result = cache.push_batch(TEST_LAYER, &keys, &vals);
assert!(result.is_err());
}
#[test]
fn push_batch_single_element_matches_push() {
let config = test_config();
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
let mut cache_push = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
cache_push.push(TEST_LAYER, &key, &val).unwrap();
let mut cache_batch = QuantizedKVCache::new(test_config(), TEST_NUM_LAYERS, TEST_QJL_SEED);
let keys: Vec<&[f32]> = vec![key.as_slice()];
let vals: Vec<&[f32]> = vec![val.as_slice()];
cache_batch.push_batch(TEST_LAYER, &keys, &vals).unwrap();
assert_eq!(
cache_push.entry_count(TEST_LAYER),
cache_batch.entry_count(TEST_LAYER)
);
let k_push = cache_push.dequantize_all_keys(TEST_LAYER).unwrap();
let k_batch = cache_batch.dequantize_all_keys(TEST_LAYER).unwrap();
assert_eq!(k_push, k_batch);
let v_push = cache_push.dequantize_all_values(TEST_LAYER).unwrap();
let v_batch = cache_batch.dequantize_all_values(TEST_LAYER).unwrap();
assert_eq!(v_push, v_batch);
}
#[test]
fn push_batch_then_individual_push_consistent() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let mut key_vecs: Vec<Vec<f32>> = Vec::new();
let mut val_vecs: Vec<Vec<f32>> = Vec::new();
for i in 0..3 {
key_vecs.push(pseudo_random_vec(
TEST_DIM,
TEST_KEY_SEED.wrapping_add(i as u64 * BATCH_SEED_OFFSET),
));
val_vecs.push(pseudo_random_vec(
TEST_DIM,
TEST_VALUE_SEED.wrapping_add(i as u64 * BATCH_SEED_OFFSET),
));
}
let key_refs: Vec<&[f32]> = key_vecs.iter().map(|v| v.as_slice()).collect();
let val_refs: Vec<&[f32]> = val_vecs.iter().map(|v| v.as_slice()).collect();
cache.push_batch(TEST_LAYER, &key_refs, &val_refs).unwrap();
assert_eq!(cache.entry_count(TEST_LAYER), 3);
let extra_key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED_3);
let extra_val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED_3);
cache.push(TEST_LAYER, &extra_key, &extra_val).unwrap();
assert_eq!(cache.entry_count(TEST_LAYER), 4);
let all_keys = cache.dequantize_all_keys(TEST_LAYER).unwrap();
assert_eq!(all_keys.len(), 4);
}
#[test]
fn collect_packed_data_empty() {
let (bytes, scales) = collect_packed_data(&[]);
assert!(bytes.is_empty());
assert!(scales.is_empty());
}
#[test]
fn collect_packed_data_roundtrip() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
cache.push(TEST_LAYER, &key, &val).unwrap();
let (packed, scales) = cache.export_packed_range(TEST_LAYER, 0, 1, true).unwrap();
assert!(!packed.is_empty());
assert_eq!(scales.len(), 1);
let key2 = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED_2);
let val2 = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED_2);
cache.push(TEST_LAYER, &key2, &val2).unwrap();
let (packed2, scales2) = cache.export_packed_range(TEST_LAYER, 0, 2, true).unwrap();
assert_eq!(scales2.len(), 2);
let bytes_per_block = packed.len();
assert_eq!(&packed2[..bytes_per_block], &packed[..]);
}
#[test]
fn reconstruct_block_preserves_data() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
cache.push(TEST_LAYER, &key, &val).unwrap();
let (packed, scales) = cache.export_packed_range(TEST_LAYER, 0, 1, true).unwrap();
let bytes_per_block = packed.len();
const BITS_PER_BYTE: usize = 8;
let signs_per_block = TEST_DIM.div_ceil(BITS_PER_BYTE);
let qjl_signs = vec![0u8; signs_per_block];
let residual_norms = vec![0u16; 1];
let import = PackedImport {
layer: TEST_LAYER,
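// reconstruct_block stores polar_bits as metadata and copies the packed
// bytes verbatim, so this width does not need to match the cache config.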
polar_bits: BITS_3 - 1,
packed_bytes: &packed,
scales: &scales,
qjl_signs_flat: &qjl_signs,
residual_norms: &residual_norms,
bytes_per_block,
signs_per_block,
is_keys: false,
};
let block = reconstruct_block(&import, 0);
assert_eq!(block.polar_block.packed_indices, &packed[..]);
assert_eq!(block.polar_block.scale.to_bits(), scales[0]);
}
#[test]
fn select_blocks_returns_correct_side() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
cache.push(TEST_LAYER, &key, &val).unwrap();
let keys = cache.select_blocks(TEST_LAYER, true);
let vals = cache.select_blocks(TEST_LAYER, false);
assert_eq!(keys.len(), 1);
assert_eq!(vals.len(), 1);
assert_ne!(
keys[0].polar_block.packed_indices,
vals[0].polar_block.packed_indices
);
}
#[test]
fn select_blocks_mut_allows_push() {
let config = test_config();
let mut cache = QuantizedKVCache::new(config, TEST_NUM_LAYERS, TEST_QJL_SEED);
assert_eq!(cache.select_blocks(TEST_LAYER, true).len(), 0);
let key = pseudo_random_vec(TEST_DIM, TEST_KEY_SEED);
let val = pseudo_random_vec(TEST_DIM, TEST_VALUE_SEED);
cache.push(TEST_LAYER, &key, &val).unwrap();
let keys_mut = cache.select_blocks_mut(TEST_LAYER, true);
assert_eq!(keys_mut.len(), 1);
}
}