use half::f16;
use crate::codebook::Codebook;
use crate::error::{require, Result, TurboQuantError};
use crate::packed::{PackedBlock, TurboQuantConfig};
use crate::quantize::{
dequantize_into_with_codebook, dequantize_vec, l2_norm, quantize_vec, DequantScratch,
};
// SplitMix64 finalizer constants (Steele/Lea/Flood); used to hash a
// (seed, row, col) position into a deterministic Rademacher sign.
const SPLITMIX_GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;
const SPLITMIX_MUL_1: u64 = 0xbf58_476d_1ce4_e5b9;
const SPLITMIX_MUL_2: u64 = 0x94d0_49bb_1331_11eb;
const SPLITMIX_SHIFT_1: u32 = 30;
const SPLITMIX_SHIFT_2: u32 = 27;
const SPLITMIX_SHIFT_3: u32 = 31;
// sqrt(pi / 2) ~= 1.2533141: the scaling that turns the sign-projection
// correction into an (approximately) unbiased inner-product estimate.
const SQRT_PI_OVER_2: f32 = 1.253_314_1;
const BITS_PER_BYTE: usize = 8;
// Supported total QJL bit budgets: only TQ3 (3 bits) and TQ4 (4 bits).
const MIN_POLAR_BITS: u8 = 3;
const MAX_POLAR_BITS: u8 = 4;
// Multiply/xor mixing constants for deriving per-row seeds in `mix_seed`.
const SEED_MIX_MULTIPLIER: u64 = 0x517c_c1b7_2722_0a95;
const SEED_MIX_XOR: u64 = 0x6c62_272e_07bb_0142;
// Rademacher sign values and the affine map (2 * bit - 1) used by `sign_bit`
// to expand a packed bit into +/-1.0.
const POSITIVE_SIGN: f32 = 1.0;
const NEGATIVE_SIGN: f32 = -1.0;
const SIGN_BIT_SCALE: f32 = 2.0;
const SIGN_BIT_OFFSET: f32 = -1.0;
// Fold the upper half of the mixed seed into the lower half.
const SEED_MIX_SHIFT: u32 = 32;
// Signs are packed one per bit, eight to a byte.
const SIGN_PACK_BITS: usize = BITS_PER_BYTE;
/// A vector quantized with PolarQuant plus a QJL residual correction.
pub struct QjlBlock {
    /// PolarQuant encoding of the vector (at a reduced bit width).
    pub polar_block: PackedBlock,
    /// One packed sign bit per dimension: the signs of the residual's
    /// Rademacher projections.
    pub qjl_signs: Vec<u8>,
    /// L2 norm of the quantization residual, stored compactly as f16.
    pub residual_norm: f16,
}
impl QjlBlock {
    /// Reassembles a block from its stored parts (e.g. after deserialization).
    pub fn from_parts(polar_block: PackedBlock, qjl_signs: Vec<u8>, residual_norm: f16) -> Self {
        Self {
            polar_block,
            qjl_signs,
            residual_norm,
        }
    }
}
/// Derives a per-row seed by mixing the base seed with the row index.
///
/// A wrapping multiply-add followed by an xor-shift fold and a fixed xor
/// constant spreads nearby row indices into well-separated 64-bit seeds.
fn mix_seed(seed: u64, row_index: usize) -> u64 {
    let combined = seed
        .wrapping_mul(SEED_MIX_MULTIPLIER)
        .wrapping_add(row_index as u64);
    SEED_MIX_XOR ^ combined ^ (combined >> SEED_MIX_SHIFT)
}
/// Hashes a (seed, row, col) position to a deterministic +1.0 or -1.0.
///
/// Runs the SplitMix64 finalizer over a position-dependent input and uses
/// the lowest output bit to pick the Rademacher sign.
fn rademacher_sign_from_hash(seed: u64, row: usize, col: usize) -> f32 {
    let mut state = seed
        .wrapping_add(SPLITMIX_GAMMA.wrapping_mul(row as u64))
        .wrapping_add(col as u64);
    state = SPLITMIX_MUL_1.wrapping_mul(state ^ (state >> SPLITMIX_SHIFT_1));
    state = SPLITMIX_MUL_2.wrapping_mul(state ^ (state >> SPLITMIX_SHIFT_2));
    state ^= state >> SPLITMIX_SHIFT_3;
    match state & 1 {
        0 => POSITIVE_SIGN,
        _ => NEGATIVE_SIGN,
    }
}
/// Materializes one row of the implicit Rademacher projection matrix.
///
/// Each entry is +/- 1/sqrt(dim); the sign comes from a deterministic hash
/// of the (seed, row, column) position, so rows never need to be stored.
pub fn generate_rademacher_row(dim: usize, seed: u64, row_index: usize) -> Vec<f32> {
    let row_seed = mix_seed(seed, row_index);
    let scale = 1.0 / (dim as f32).sqrt();
    let mut row = Vec::with_capacity(dim);
    for col in 0..dim {
        row.push(rademacher_sign_from_hash(row_seed, row_index, col) * scale);
    }
    row
}
/// Dot product of `data` with one implicit Rademacher row.
///
/// Equivalent to `dot_product(data, generate_rademacher_row(dim, seed,
/// row_index))` without materializing the row: signs are re-derived by
/// hashing, and the 1/sqrt(dim) scale is applied once at the end.
fn rademacher_vector_product(data: &[f32], dim: usize, seed: u64, row_index: usize) -> f32 {
    let row_seed = mix_seed(seed, row_index);
    let inv_sqrt_d = 1.0 / (dim as f32).sqrt();
    let mut acc = 0.0_f32;
    for (col, &val) in data.iter().enumerate() {
        acc += rademacher_sign_from_hash(row_seed, row_index, col) * val;
    }
    acc * inv_sqrt_d
}
/// Element-wise difference `original - reconstructed`.
///
/// The output length is the shorter of the two inputs (zip truncates).
pub fn compute_residual(original: &[f32], reconstructed: &[f32]) -> Vec<f32> {
    let mut residual = Vec::with_capacity(original.len().min(reconstructed.len()));
    for (&o, &r) in original.iter().zip(reconstructed) {
        residual.push(o - r);
    }
    residual
}
/// Standard inner product; pairs are taken up to the shorter slice.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0_f32;
    for (&x, &y) in a.iter().zip(b) {
        acc += x * y;
    }
    acc
}
/// Reads the sign stored at `index` in a packed bit array, mapping bit 1 to
/// +1.0 and bit 0 to -1.0 via the affine map `2 * bit - 1`.
///
/// Panics if `index / 8` is out of bounds for `signs`.
pub fn sign_bit(signs: &[u8], index: usize) -> f32 {
    let byte = signs[index / BITS_PER_BYTE];
    let bit = (byte >> (index % BITS_PER_BYTE)) & 1;
    f32::from(bit) * SIGN_BIT_SCALE + SIGN_BIT_OFFSET
}
/// Packs booleans into a bit array: element `i` lands in byte `i / 8` at bit
/// position `i % 8` (LSB-first within each byte); `true` sets the bit.
pub fn pack_sign_bits(signs: &[bool]) -> Vec<u8> {
    let mut packed = vec![0u8; ceiling_div(signs.len(), BITS_PER_BYTE)];
    for (i, &positive) in signs.iter().enumerate() {
        if positive {
            packed[i / SIGN_PACK_BITS] |= 1u8 << (i % SIGN_PACK_BITS);
        }
    }
    packed
}
/// QJL estimator scale: `||residual|| * sqrt(pi/2) / sqrt(dim)`.
///
/// This is the constant that weights the sign-projection correction so it
/// approximates the residual's contribution to the inner product.
pub fn qjl_scaling_constant(residual_norm: f32, dim: usize) -> f32 {
    let root_dim = (dim as f32).sqrt();
    residual_norm * SQRT_PI_OVER_2 / root_dim
}
/// Integer division rounded toward positive infinity.
///
/// Panics if `divisor` is zero (same as the underlying `div_ceil`).
fn ceiling_div(numerator: usize, divisor: usize) -> usize {
    numerator.div_ceil(divisor)
}
/// Only the TQ3 (3-bit) and TQ4 (4-bit) budgets are supported.
fn is_valid_qjl_bits(bits: u8) -> bool {
    matches!(bits, MIN_POLAR_BITS..=MAX_POLAR_BITS)
}
/// Bits per dimension left for PolarQuant after reserving the QJL sign bit.
///
/// Callers validate `total_bits` first (3..=4), so this cannot underflow.
fn polar_bit_width(total_bits: u8) -> u8 {
    total_bits - QJL_SIGN_BITS
}
/// One bit per dimension is reserved for the residual projection sign.
const QJL_SIGN_BITS: u8 = 1;
pub fn compute_qjl_signs(residual: &[f32], dim: usize, seed: u64) -> crate::error::Result<Vec<u8>> {
if residual.len() != dim {
return Err(crate::error::TurboQuantError::DimensionMismatch {
expected: dim,
actual: residual.len(),
});
}
let sign_bools: Vec<bool> = (0..dim)
.map(|j| {
let projection = rademacher_vector_product(residual, dim, seed, j);
projection >= 0.0
})
.collect();
Ok(pack_sign_bits(&sign_bools))
}
/// Rejects bit budgets outside the supported TQ3/TQ4 range (3..=4 bits).
///
/// # Errors
///
/// Returns [`TurboQuantError::UnsupportedBits`] for any other width.
fn validate_qjl_config(bits: u8) -> Result<()> {
    require(
        is_valid_qjl_bits(bits),
        TurboQuantError::UnsupportedBits(bits),
    )
}
/// Quantizes `data` with PolarQuant at a reduced bit width, then captures the
/// quantization residual as QJL sign bits plus its L2 norm.
///
/// `qjl_seed` drives the Rademacher projections and is independent of the
/// rotation seed carried by `config`.
///
/// # Errors
///
/// Fails when `config.bits` is outside 3..=4 or when the underlying polar
/// quantize/dequantize steps fail.
pub fn quantize_with_qjl(
    config: &TurboQuantConfig,
    data: &[f32],
    qjl_seed: u64,
) -> Result<QjlBlock> {
    validate_qjl_config(config.bits)?;
    // One bit of the budget goes to QJL signs; the rest to PolarQuant.
    let polar_bits = polar_bit_width(config.bits);
    let polar_config =
        TurboQuantConfig::new(polar_bits, config.dim)?.with_seed(config.rotation_seed);
    let polar_block = quantize_vec(&polar_config, data)?;
    // Round-trip the polar encoding to measure exactly what it loses.
    let reconstructed = dequantize_vec(&polar_config, &polar_block)?;
    let residual = compute_residual(data, &reconstructed);
    let residual_norm = l2_norm(&residual);
    let qjl_signs = compute_qjl_signs(&residual, config.dim, qjl_seed)?;
    Ok(QjlBlock {
        polar_block,
        qjl_signs,
        residual_norm: f16::from_f32(residual_norm),
    })
}
pub struct QjlBatchResources {
pub polar_config: TurboQuantConfig,
pub polar_codebook: crate::codebook::Codebook,
pub polar_sign_pattern: Vec<f32>,
pub scratch: DequantScratch,
}
impl QjlBatchResources {
    /// Builds the shared resources for `config` once, so repeated calls to
    /// [`quantize_with_qjl_resources`] avoid codebook and pattern setup.
    ///
    /// # Errors
    ///
    /// Fails when `config.bits` is outside 3..=4 or when the derived polar
    /// config/codebook cannot be constructed.
    pub fn new(config: &TurboQuantConfig) -> Result<Self> {
        validate_qjl_config(config.bits)?;
        let polar_bits = polar_bit_width(config.bits);
        let polar_config =
            TurboQuantConfig::new(polar_bits, config.dim)?.with_seed(config.rotation_seed);
        let polar_codebook = crate::codebook::get_codebook(polar_bits, config.dim)?;
        let polar_sign_pattern =
            crate::rotation::generate_sign_pattern(config.dim, config.rotation_seed);
        let scratch = DequantScratch::new(config.dim);
        Ok(Self {
            polar_config,
            polar_codebook,
            polar_sign_pattern,
            scratch,
        })
    }
}
/// Same pipeline as [`quantize_with_qjl`] but reuses the prebuilt codebook,
/// sign pattern, and scratch buffer in `res` to avoid per-call setup.
///
/// # Errors
///
/// Propagates failures from the underlying quantize/dequantize steps or a
/// residual/dimension mismatch.
pub fn quantize_with_qjl_resources(
    data: &[f32],
    qjl_seed: u64,
    res: &mut QjlBatchResources,
) -> Result<QjlBlock> {
    use crate::quantize::quantize_vec_with_codebook;
    let polar_block = quantize_vec_with_codebook(
        &res.polar_config,
        data,
        &res.polar_codebook,
        &res.polar_sign_pattern,
    )?;
    // Reconstruct into the reusable scratch buffer instead of a fresh Vec.
    dequantize_into_with_codebook(
        &res.polar_config,
        &polar_block,
        &res.polar_codebook,
        &res.polar_sign_pattern,
        &mut res.scratch,
    )?;
    let residual = compute_residual(data, &res.scratch.values);
    let residual_norm = l2_norm(&residual);
    let dim = res.polar_config.dim;
    let qjl_signs = compute_qjl_signs(&residual, dim, qjl_seed)?;
    Ok(QjlBlock {
        polar_block,
        qjl_signs,
        residual_norm: f16::from_f32(residual_norm),
    })
}
/// Final QJL estimate: polar-quant base plus the scaled sign correction.
fn corrected_estimate(base: f32, scaling: f32, correction: f32) -> f32 {
    let adjustment = scaling * correction;
    base + adjustment
}
/// Projects `query` through all `dim` implicit Rademacher rows once, so that
/// per-block estimation only needs a cheap sign-weighted sum.
pub fn precompute_query_projections(query: &[f32], dim: usize, qjl_seed: u64) -> Vec<f32> {
    let mut projections = Vec::with_capacity(dim);
    for row in 0..dim {
        projections.push(rademacher_vector_product(query, dim, qjl_seed, row));
    }
    projections
}
/// Test helper: expands a packed sign array into explicit +/-1.0 values.
#[cfg(test)]
fn unpack_signs_to_f32(signs: &[u8], count: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(count);
    for i in 0..count {
        out.push(sign_bit(signs, i));
    }
    out
}
/// Sign-weighted sum over the first `dim` query projections:
/// `sum_j r_query[j] * sign_j`.
fn qjl_correction(r_query: &[f32], signs: &[u8], dim: usize) -> f32 {
    let mut total = 0.0_f32;
    for (i, &rq) in r_query.iter().take(dim).enumerate() {
        total += rq * sign_bit(signs, i);
    }
    total
}
/// Estimates `<query, key>` from a QJL block: the dequantized polar dot
/// product plus the scaled sign-weighted sum of the query projections.
///
/// `r_query` must come from [`precompute_query_projections`] with the same
/// seed used when quantizing the block.
///
/// # Errors
///
/// Fails when the polar config cannot be rebuilt or dequantization fails.
pub fn estimate_inner_product(
    query: &[f32],
    r_query: &[f32],
    qjl_block: &QjlBlock,
    config: &TurboQuantConfig,
) -> Result<f32> {
    // Rebuild the reduced-bit polar config from the stored block width.
    let polar_config = TurboQuantConfig::new(qjl_block.polar_block.bits, config.dim)?
        .with_seed(config.rotation_seed);
    let reconstructed = dequantize_vec(&polar_config, &qjl_block.polar_block)?;
    let base = dot_product(query, &reconstructed);
    let dim = config.dim;
    let residual_norm = qjl_block.residual_norm.to_f32();
    let c = qjl_scaling_constant(residual_norm, dim);
    let correction = qjl_correction(r_query, &qjl_block.qjl_signs, dim);
    Ok(corrected_estimate(base, c, correction))
}
/// Borrowed resources for repeated inner-product estimation over one config.
pub struct EstimationContext<'a> {
    /// Polar config matching the blocks being estimated.
    pub polar_config: &'a TurboQuantConfig,
    /// Codebook for the polar bit width and dimension.
    pub codebook: &'a Codebook,
    /// Rotation sign pattern shared by all blocks.
    pub sign_pattern: &'a [f32],
    /// Vector dimensionality.
    pub dim: usize,
    /// Owned scratch buffer reused for each dequantization.
    pub scratch: DequantScratch,
}
/// Like [`estimate_inner_product`] but reuses the codebook, sign pattern, and
/// scratch buffer carried by `ctx`, avoiding per-call setup and allocation.
///
/// # Errors
///
/// Propagates dequantization failures.
pub fn estimate_inner_product_with_codebook(
    query: &[f32],
    r_query: &[f32],
    qjl_block: &QjlBlock,
    ctx: &mut EstimationContext,
) -> Result<f32> {
    // Reconstruct the polar approximation into the reusable scratch buffer.
    dequantize_into_with_codebook(
        ctx.polar_config,
        &qjl_block.polar_block,
        ctx.codebook,
        ctx.sign_pattern,
        &mut ctx.scratch,
    )?;
    let base = dot_product(query, &ctx.scratch.values);
    let dim = ctx.dim;
    let residual_norm = qjl_block.residual_norm.to_f32();
    let c = qjl_scaling_constant(residual_norm, dim);
    let correction = qjl_correction(r_query, &qjl_block.qjl_signs, dim);
    Ok(corrected_estimate(base, c, correction))
}
/// One-shot estimate that derives the query projections internally.
///
/// Prefer [`precompute_query_projections`] + [`estimate_inner_product`] when
/// scoring the same query against many blocks.
///
/// # Errors
///
/// Propagates failures from [`estimate_inner_product`].
pub fn estimate_inner_product_single(
    query: &[f32],
    qjl_block: &QjlBlock,
    config: &TurboQuantConfig,
    qjl_seed: u64,
) -> Result<f32> {
    let r_query = precompute_query_projections(query, config.dim, qjl_seed);
    estimate_inner_product(query, &r_query, qjl_block, config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{pseudo_random_vec, LCG_MULTIPLIER};

    // Shared dimensions, seeds, and tolerances.
    // NOTE(review): TEST_DIM and TEST_DIM_64 are both 64 and look mergeable.
    const TEST_DIM: usize = 64;
    const TEST_DIM_64: usize = 64;
    const TEST_ROTATION_SEED: u64 = 42;
    const TEST_QJL_SEED: u64 = 12345;
    const TEST_SIGN_INDEX: usize = 7;
    const TEST_ALT_SEED: u64 = 999;
    const TEST_SEED_A: u64 = 11111;
    const TEST_SEED_B: u64 = 22222;
    const TEST_SEED_C: u64 = 33333;
    const TEST_SEED_D: u64 = 44444;
    const TEST_SEED_E: u64 = 55555;
    const TEST_SEED_F: u64 = 66666;
    const TEST_SEED_G: u64 = 77777;
    const TEST_SEED_H: u64 = 88888;
    const TEST_EXPECTED_DOT: f32 = 32.0;
    const TEST_SCALE_FACTOR: f32 = 2.0;
    const TEST_SAMPLE_COUNT: u64 = 100;
    const TEST_LARGE_SAMPLE_COUNT: u64 = 200;
    const TEST_SIGN_PACK_COUNT: usize = 3;
    const FLOAT_EPSILON: f32 = 1e-6;
    const BIAS_TOLERANCE: f32 = 0.20;
    const STATISTICAL_SAMPLE_COUNT: usize = 500;
    const BITS_3: u8 = 3;
    const BITS_4: u8 = 4;
    const PACK_TEST_LEN: usize = 19;
    const SIGN_PACK_PATTERN_STRIDE: usize = 3;
    const MAX_RELATIVE_VARIANCE: f32 = 1.0;
    const ORTHO_TOLERANCE: f32 = 0.3;
    const PARALLEL_RELATIVE_TOLERANCE: f32 = 0.5;

    /// Standard basis vector along `axis`.
    fn unit_vec(dim: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0_f32; dim];
        v[axis] = 1.0;
        v
    }

    // --- Rademacher row generation ---

    #[test]
    fn rademacher_row_has_correct_magnitude() {
        let row = generate_rademacher_row(TEST_DIM, TEST_QJL_SEED, 0);
        let expected_magnitude = 1.0 / (TEST_DIM as f32).sqrt();
        assert_eq!(row.len(), TEST_DIM);
        for &val in &row {
            assert!(
                (val.abs() - expected_magnitude).abs() < FLOAT_EPSILON,
                "expected +/-{expected_magnitude}, got {val}"
            );
        }
    }

    #[test]
    fn rademacher_row_is_deterministic() {
        let row_a = generate_rademacher_row(TEST_DIM, TEST_QJL_SEED, TEST_SIGN_INDEX);
        let row_b = generate_rademacher_row(TEST_DIM, TEST_QJL_SEED, TEST_SIGN_INDEX);
        assert_eq!(row_a, row_b);
    }

    #[test]
    fn rademacher_different_rows_differ() {
        let row_a = generate_rademacher_row(TEST_DIM, TEST_QJL_SEED, 0);
        let row_b = generate_rademacher_row(TEST_DIM, TEST_QJL_SEED, 1);
        assert_ne!(row_a, row_b);
    }

    #[test]
    fn rademacher_vector_product_matches_explicit_row() {
        let data = pseudo_random_vec(TEST_DIM, TEST_ALT_SEED);
        let row = generate_rademacher_row(TEST_DIM, TEST_QJL_SEED, TEST_SIGN_PACK_COUNT);
        let expected = dot_product(&data, &row);
        let actual =
            rademacher_vector_product(&data, TEST_DIM, TEST_QJL_SEED, TEST_SIGN_PACK_COUNT);
        assert!(
            (expected - actual).abs() < FLOAT_EPSILON,
            "expected {expected}, got {actual}"
        );
    }

    // --- Residuals and dot products ---

    const RESIDUAL_ORIGINAL: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    const RESIDUAL_RECONSTRUCTED: [f32; 4] = [0.9, 2.1, 2.8, 4.2];
    const RESIDUAL_EXPECTED: [f32; 4] = [0.1, -0.1, 0.2, -0.2];

    #[test]
    fn compute_residual_basic() {
        let residual = compute_residual(&RESIDUAL_ORIGINAL, &RESIDUAL_RECONSTRUCTED);
        for (i, (&r, &e)) in residual.iter().zip(RESIDUAL_EXPECTED.iter()).enumerate() {
            assert!(
                (r - e).abs() < FLOAT_EPSILON,
                "residual[{i}]: expected {e}, got {r}"
            );
        }
    }

    #[test]
    fn compute_residual_zero_when_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let residual = compute_residual(&v, &v);
        for &r in &residual {
            assert!(r.abs() < FLOAT_EPSILON);
        }
    }

    #[test]
    fn dot_product_known_value() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let expected = TEST_EXPECTED_DOT;
        let actual = dot_product(&a, &b);
        assert!((actual - expected).abs() < FLOAT_EPSILON);
    }

    #[test]
    fn dot_product_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(dot_product(&a, &b).abs() < FLOAT_EPSILON);
    }

    // --- Sign bit packing ---

    #[test]
    fn pack_unpack_sign_bits_roundtrip() {
        let bools: Vec<bool> = (0..PACK_TEST_LEN)
            .map(|i| i % SIGN_PACK_PATTERN_STRIDE == 0)
            .collect();
        let packed = pack_sign_bits(&bools);
        for (i, &expected_positive) in bools.iter().enumerate() {
            let extracted = sign_bit(&packed, i);
            let expected = if expected_positive {
                POSITIVE_SIGN
            } else {
                NEGATIVE_SIGN
            };
            assert!(
                (extracted - expected).abs() < FLOAT_EPSILON,
                "bit {i}: expected {expected}, got {extracted}"
            );
        }
    }

    #[test]
    fn pack_sign_bits_all_true() {
        let bools = vec![true; BITS_PER_BYTE];
        let packed = pack_sign_bits(&bools);
        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0], 0xFF);
    }

    #[test]
    fn pack_sign_bits_all_false() {
        let bools = vec![false; BITS_PER_BYTE];
        let packed = pack_sign_bits(&bools);
        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0], 0x00);
    }

    // --- Scaling constant ---

    #[test]
    fn qjl_scaling_constant_correct_formula() {
        let residual_norm = TEST_SCALE_FACTOR;
        let dim = TEST_DIM_64;
        let expected = residual_norm * SQRT_PI_OVER_2 / (dim as f32).sqrt();
        let actual = qjl_scaling_constant(residual_norm, dim);
        assert!(
            (actual - expected).abs() < FLOAT_EPSILON,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn qjl_scaling_constant_zero_norm() {
        let actual = qjl_scaling_constant(0.0, TEST_DIM);
        assert!(actual.abs() < FLOAT_EPSILON);
    }

    // --- End-to-end quantization ---

    #[test]
    fn quantize_with_qjl_produces_valid_block() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_SEED_A);
        let block = quantize_with_qjl(&config, &data, TEST_QJL_SEED).unwrap();
        let expected_sign_bytes = ceiling_div(TEST_DIM, BITS_PER_BYTE);
        assert_eq!(block.qjl_signs.len(), expected_sign_bytes);
        assert!(block.residual_norm.to_f32() >= 0.0);
    }

    #[test]
    fn quantize_with_qjl_4bit() {
        let config = TurboQuantConfig::new(BITS_4, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_SEED_B);
        let block = quantize_with_qjl(&config, &data, TEST_QJL_SEED);
        assert!(block.is_ok());
    }

    #[test]
    fn quantize_with_qjl_rejects_invalid_bits() {
        assert!(!is_valid_qjl_bits(2));
        assert!(!is_valid_qjl_bits(5));
    }

    #[test]
    fn qjl_quantize_is_deterministic() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_SEED_C);
        let block_a = quantize_with_qjl(&config, &data, TEST_QJL_SEED).unwrap();
        let block_b = quantize_with_qjl(&config, &data, TEST_QJL_SEED).unwrap();
        assert_eq!(block_a.qjl_signs, block_b.qjl_signs);
        assert_eq!(
            block_a.residual_norm.to_f32(),
            block_b.residual_norm.to_f32()
        );
    }

    #[test]
    fn estimate_inner_product_single_is_deterministic() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let key = pseudo_random_vec(TEST_DIM, TEST_SEED_D);
        let query = pseudo_random_vec(TEST_DIM, TEST_SEED_E);
        let block = quantize_with_qjl(&config, &key, TEST_QJL_SEED).unwrap();
        let est_a = estimate_inner_product_single(&query, &block, &config, TEST_QJL_SEED).unwrap();
        let est_b = estimate_inner_product_single(&query, &block, &config, TEST_QJL_SEED).unwrap();
        assert!(
            (est_a - est_b).abs() < FLOAT_EPSILON,
            "not deterministic: {est_a} vs {est_b}"
        );
    }

    // --- Statistical properties of the estimator ---

    #[test]
    fn qjl_reduces_bias_vs_plain_polarquant() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let mut qjl_bias_sum = 0.0_f64;
        let mut plain_bias_sum = 0.0_f64;
        for i in 0..STATISTICAL_SAMPLE_COUNT {
            let key_seed = (i as u64)
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(TEST_SAMPLE_COUNT);
            let query_seed = (i as u64)
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(TEST_LARGE_SAMPLE_COUNT);
            let qjl_seed = TEST_QJL_SEED.wrapping_add(i as u64);
            let key = pseudo_random_vec(TEST_DIM, key_seed);
            let query = pseudo_random_vec(TEST_DIM, query_seed);
            let true_ip = dot_product(&key, &query) as f64;
            let block = quantize_with_qjl(&config, &key, qjl_seed).unwrap();
            let qjl_est =
                estimate_inner_product_single(&query, &block, &config, qjl_seed).unwrap() as f64;
            let plain_block = quantize_vec(&config, &key).unwrap();
            let plain_recon = dequantize_vec(&config, &plain_block).unwrap();
            let plain_est = dot_product(&query, &plain_recon) as f64;
            qjl_bias_sum += qjl_est - true_ip;
            plain_bias_sum += plain_est - true_ip;
        }
        let qjl_mean_bias = (qjl_bias_sum / STATISTICAL_SAMPLE_COUNT as f64).abs() as f32;
        let plain_mean_bias = (plain_bias_sum / STATISTICAL_SAMPLE_COUNT as f64).abs() as f32;
        assert!(
            qjl_mean_bias < BIAS_TOLERANCE,
            "QJL mean bias {qjl_mean_bias} exceeds tolerance {BIAS_TOLERANCE}"
        );
        // Computed for diagnostics only; not asserted against.
        let _ = plain_mean_bias;
    }

    #[test]
    fn qjl_estimation_variance_is_bounded() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let key = pseudo_random_vec(TEST_DIM, TEST_SEED_F);
        let query = pseudo_random_vec(TEST_DIM, TEST_SEED_G);
        let true_ip = dot_product(&key, &query);
        let mut sum_sq_error = 0.0_f64;
        for seed_offset in 0..STATISTICAL_SAMPLE_COUNT {
            let qjl_seed = TEST_QJL_SEED.wrapping_add(seed_offset as u64);
            let block = quantize_with_qjl(&config, &key, qjl_seed).unwrap();
            let est = estimate_inner_product_single(&query, &block, &config, qjl_seed).unwrap();
            let error = (est - true_ip) as f64;
            sum_sq_error += error * error;
        }
        let variance = sum_sq_error / STATISTICAL_SAMPLE_COUNT as f64;
        let true_ip_sq = (true_ip as f64) * (true_ip as f64);
        let relative_variance = if true_ip_sq > FLOAT_EPSILON as f64 {
            variance / true_ip_sq
        } else {
            variance
        };
        assert!(
            (relative_variance as f32) < MAX_RELATIVE_VARIANCE,
            "relative variance {relative_variance} exceeds bound {MAX_RELATIVE_VARIANCE}"
        );
    }

    #[test]
    fn orthogonal_vectors_estimate_near_zero() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let key = unit_vec(TEST_DIM, 0);
        let query = unit_vec(TEST_DIM, 1);
        let block = quantize_with_qjl(&config, &key, TEST_QJL_SEED).unwrap();
        let est = estimate_inner_product_single(&query, &block, &config, TEST_QJL_SEED).unwrap();
        assert!(
            est.abs() < ORTHO_TOLERANCE,
            "orthogonal estimate {est} not near zero"
        );
    }

    #[test]
    fn parallel_vectors_estimate_near_product() {
        let config = TurboQuantConfig::new(BITS_4, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let key = pseudo_random_vec(TEST_DIM, TEST_SEED_H);
        let key_norm = l2_norm(&key);
        let query: Vec<f32> = key.iter().map(|&v| v * TEST_SCALE_FACTOR).collect();
        let true_ip = dot_product(&query, &key);
        let block = quantize_with_qjl(&config, &key, TEST_QJL_SEED).unwrap();
        let est = estimate_inner_product_single(&query, &block, &config, TEST_QJL_SEED).unwrap();
        let relative_error = (est - true_ip).abs() / true_ip.abs();
        assert!(
            relative_error < PARALLEL_RELATIVE_TOLERANCE,
            "parallel relative error {relative_error} exceeds tolerance \
            {PARALLEL_RELATIVE_TOLERANCE} (est={est}, true={true_ip}, key_norm={key_norm})"
        );
    }

    // --- Helper functions ---

    #[test]
    fn ceiling_div_exact() {
        assert_eq!(ceiling_div(16, 8), 2);
        assert_eq!(ceiling_div(8, 8), 1);
    }

    #[test]
    fn ceiling_div_with_remainder() {
        assert_eq!(ceiling_div(17, 8), 3);
        assert_eq!(ceiling_div(1, 8), 1);
    }

    #[test]
    fn is_valid_qjl_bits_accepts_3_and_4() {
        assert!(is_valid_qjl_bits(BITS_3));
        assert!(is_valid_qjl_bits(BITS_4));
    }

    #[test]
    fn is_valid_qjl_bits_rejects_others() {
        assert!(!is_valid_qjl_bits(1));
        assert!(!is_valid_qjl_bits(2));
        assert!(!is_valid_qjl_bits(5));
    }

    // --- Bit budget: polar width and block size ---

    const BIT_BUDGET_DIM: usize = 128;
    const TQ3_EXPECTED_POLAR_BITS: u8 = 2;
    const TQ4_EXPECTED_POLAR_BITS: u8 = 3;
    const BIT_BUDGET_SEED: u64 = 42;
    const BIT_BUDGET_QJL_SEED: u64 = 99999;

    #[test]
    fn tq3_uses_2bit_polar_quant() {
        let config = TurboQuantConfig::new(BITS_3, BIT_BUDGET_DIM)
            .unwrap()
            .with_seed(BIT_BUDGET_SEED);
        let data = pseudo_random_vec(BIT_BUDGET_DIM, TEST_SEED_A);
        let block = quantize_with_qjl(&config, &data, BIT_BUDGET_QJL_SEED).unwrap();
        assert_eq!(
            block.polar_block.bits, TQ3_EXPECTED_POLAR_BITS,
            "TQ3 should use {TQ3_EXPECTED_POLAR_BITS}-bit PolarQuant, got {}",
            block.polar_block.bits
        );
    }

    #[test]
    fn tq4_uses_3bit_polar_quant() {
        let config = TurboQuantConfig::new(BITS_4, BIT_BUDGET_DIM)
            .unwrap()
            .with_seed(BIT_BUDGET_SEED);
        let data = pseudo_random_vec(BIT_BUDGET_DIM, TEST_SEED_B);
        let block = quantize_with_qjl(&config, &data, BIT_BUDGET_QJL_SEED).unwrap();
        assert_eq!(
            block.polar_block.bits, TQ4_EXPECTED_POLAR_BITS,
            "TQ4 should use {TQ4_EXPECTED_POLAR_BITS}-bit PolarQuant, got {}",
            block.polar_block.bits
        );
    }

    const BYTE_BITS: usize = 8;
    const RESIDUAL_NORM_STORAGE_BYTES: usize = 2;
    const TQ3_D128_EXPECTED_BLOCK_BYTES: usize = 52;
    const TQ4_D128_EXPECTED_BLOCK_BYTES: usize = 68;

    #[test]
    fn tq3_qjl_block_size_matches_expected() {
        let config = TurboQuantConfig::new(BITS_3, BIT_BUDGET_DIM)
            .unwrap()
            .with_seed(BIT_BUDGET_SEED);
        let data = pseudo_random_vec(BIT_BUDGET_DIM, TEST_SEED_C);
        let block = quantize_with_qjl(&config, &data, BIT_BUDGET_QJL_SEED).unwrap();
        let polar_bytes = block.polar_block.size_bytes();
        let sign_bytes = block.qjl_signs.len();
        let total = polar_bytes + sign_bytes + RESIDUAL_NORM_STORAGE_BYTES;
        assert_eq!(
            sign_bytes,
            BIT_BUDGET_DIM / BYTE_BITS,
            "TQ3 sign bytes: expected {}, got {sign_bytes}",
            BIT_BUDGET_DIM / BYTE_BITS
        );
        assert_eq!(
            total, TQ3_D128_EXPECTED_BLOCK_BYTES,
            "TQ3 QjlBlock total size: expected {TQ3_D128_EXPECTED_BLOCK_BYTES}, got {total} \
            (polar={polar_bytes}, signs={sign_bytes}, residual_norm={RESIDUAL_NORM_STORAGE_BYTES})"
        );
    }

    #[test]
    fn tq4_qjl_block_size_matches_expected() {
        let config = TurboQuantConfig::new(BITS_4, BIT_BUDGET_DIM)
            .unwrap()
            .with_seed(BIT_BUDGET_SEED);
        let data = pseudo_random_vec(BIT_BUDGET_DIM, TEST_SEED_D);
        let block = quantize_with_qjl(&config, &data, BIT_BUDGET_QJL_SEED).unwrap();
        let polar_bytes = block.polar_block.size_bytes();
        let sign_bytes = block.qjl_signs.len();
        let total = polar_bytes + sign_bytes + RESIDUAL_NORM_STORAGE_BYTES;
        assert_eq!(
            sign_bytes,
            BIT_BUDGET_DIM / BYTE_BITS,
            "TQ4 sign bytes: expected {}, got {sign_bytes}",
            BIT_BUDGET_DIM / BYTE_BITS
        );
        assert_eq!(
            total, TQ4_D128_EXPECTED_BLOCK_BYTES,
            "TQ4 QjlBlock total size: expected {TQ4_D128_EXPECTED_BLOCK_BYTES}, got {total} \
            (polar={polar_bytes}, signs={sign_bytes}, residual_norm={RESIDUAL_NORM_STORAGE_BYTES})"
        );
    }

    // --- Hash-based Rademacher signs ---

    const RADEMACHER_BIAS_TOLERANCE: f32 = 0.15;
    const RADEMACHER_SAMPLE_COUNT: usize = 1000;

    #[test]
    fn hash_rademacher_is_deterministic() {
        const TEST_ROW: usize = 5;
        const TEST_COL: usize = 10;
        let sign_a = rademacher_sign_from_hash(TEST_QJL_SEED, TEST_ROW, TEST_COL);
        let sign_b = rademacher_sign_from_hash(TEST_QJL_SEED, TEST_ROW, TEST_COL);
        assert!((sign_a - sign_b).abs() < FLOAT_EPSILON);
    }

    #[test]
    fn hash_rademacher_produces_unit_signs() {
        for row in 0..TEST_DIM {
            for col in 0..TEST_DIM {
                let sign = rademacher_sign_from_hash(TEST_QJL_SEED, row, col);
                assert!(
                    (sign - POSITIVE_SIGN).abs() < FLOAT_EPSILON
                        || (sign - NEGATIVE_SIGN).abs() < FLOAT_EPSILON,
                    "sign at ({row},{col}) = {sign}, expected +1.0 or -1.0"
                );
            }
        }
    }

    #[test]
    fn hash_rademacher_approximately_unbiased() {
        let mut positive_count = 0usize;
        for row in 0..RADEMACHER_SAMPLE_COUNT {
            for col in 0..TEST_DIM {
                let sign = rademacher_sign_from_hash(TEST_QJL_SEED, row, col);
                if sign > 0.0 {
                    positive_count += 1;
                }
            }
        }
        let total = RADEMACHER_SAMPLE_COUNT * TEST_DIM;
        let positive_fraction = positive_count as f32 / total as f32;
        const EXPECTED_FRACTION: f32 = 0.5;
        let bias = (positive_fraction - EXPECTED_FRACTION).abs();
        assert!(
            bias < RADEMACHER_BIAS_TOLERANCE,
            "hash Rademacher bias {bias} (positive fraction {positive_fraction}) \
            exceeds tolerance {RADEMACHER_BIAS_TOLERANCE}"
        );
    }

    #[test]
    fn hash_rademacher_different_positions_differ() {
        let mut same_count = 0usize;
        const HASH_DIVERSITY_CHECK_COUNT: usize = 100;
        for i in 0..HASH_DIVERSITY_CHECK_COUNT {
            let sign_a = rademacher_sign_from_hash(TEST_QJL_SEED, i, 0);
            let sign_b = rademacher_sign_from_hash(TEST_QJL_SEED, i, 1);
            if (sign_a - sign_b).abs() < FLOAT_EPSILON {
                same_count += 1;
            }
        }
        assert!(
            same_count < HASH_DIVERSITY_CHECK_COUNT,
            "all {HASH_DIVERSITY_CHECK_COUNT} sign pairs were identical -- hash is not mixing positions"
        );
    }

    // --- Estimation with precomputed projections ---

    #[test]
    fn estimate_inner_product_is_deterministic() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let key = pseudo_random_vec(TEST_DIM, TEST_SEED_D);
        let query = pseudo_random_vec(TEST_DIM, TEST_SEED_E);
        let block = quantize_with_qjl(&config, &key, TEST_QJL_SEED).unwrap();
        let r_query = precompute_query_projections(&query, TEST_DIM, TEST_QJL_SEED);
        let est_a = estimate_inner_product(&query, &r_query, &block, &config).unwrap();
        let est_b = estimate_inner_product(&query, &r_query, &block, &config).unwrap();
        assert!(
            (est_a - est_b).abs() < FLOAT_EPSILON,
            "not deterministic: {est_a} vs {est_b}"
        );
    }

    #[test]
    fn estimate_inner_product_unbiased_over_many_samples() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let mut bias_sum = 0.0_f64;
        for i in 0..STATISTICAL_SAMPLE_COUNT {
            let key_seed = (i as u64)
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(TEST_SAMPLE_COUNT);
            let query_seed = (i as u64)
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(TEST_LARGE_SAMPLE_COUNT);
            let qjl_seed = TEST_QJL_SEED.wrapping_add(i as u64);
            let key = pseudo_random_vec(TEST_DIM, key_seed);
            let query = pseudo_random_vec(TEST_DIM, query_seed);
            let true_ip = dot_product(&key, &query) as f64;
            let block = quantize_with_qjl(&config, &key, qjl_seed).unwrap();
            let r_query = precompute_query_projections(&query, TEST_DIM, qjl_seed);
            let est = estimate_inner_product(&query, &r_query, &block, &config).unwrap() as f64;
            bias_sum += est - true_ip;
        }
        let mean_bias = (bias_sum / STATISTICAL_SAMPLE_COUNT as f64).abs() as f32;
        assert!(
            mean_bias < BIAS_TOLERANCE,
            "estimate mean bias {mean_bias} exceeds tolerance {BIAS_TOLERANCE}"
        );
    }

    #[test]
    fn precompute_projections_correct_length() {
        let query = pseudo_random_vec(TEST_DIM, TEST_SEED_A);
        let projections = precompute_query_projections(&query, TEST_DIM, TEST_QJL_SEED);
        assert_eq!(projections.len(), TEST_DIM);
    }

    #[test]
    fn precompute_projections_deterministic() {
        let query = pseudo_random_vec(TEST_DIM, TEST_SEED_A);
        let proj_a = precompute_query_projections(&query, TEST_DIM, TEST_QJL_SEED);
        let proj_b = precompute_query_projections(&query, TEST_DIM, TEST_QJL_SEED);
        assert_eq!(proj_a, proj_b);
    }

    #[test]
    fn unpack_signs_roundtrip() {
        let bools: Vec<bool> = (0..PACK_TEST_LEN)
            .map(|i| i % SIGN_PACK_PATTERN_STRIDE == 0)
            .collect();
        let packed = pack_sign_bits(&bools);
        let unpacked = unpack_signs_to_f32(&packed, PACK_TEST_LEN);
        for (i, &expected_positive) in bools.iter().enumerate() {
            let expected = if expected_positive {
                POSITIVE_SIGN
            } else {
                NEGATIVE_SIGN
            };
            assert!(
                (unpacked[i] - expected).abs() < FLOAT_EPSILON,
                "bit {i}: expected {expected}, got {}",
                unpacked[i]
            );
        }
    }

    #[test]
    fn qjl_correction_matches_dot_product() {
        const R_QUERY_SCALE: f32 = 0.1;
        const R_QUERY_OFFSET: f32 = 3.0;
        let r_query: Vec<f32> = (0..TEST_DIM)
            .map(|i| (i as f32) * R_QUERY_SCALE - R_QUERY_OFFSET)
            .collect();
        let bools: Vec<bool> = (0..TEST_DIM).map(|i| i % 3 == 0).collect();
        let packed_signs = pack_sign_bits(&bools);
        let expected: f32 = (0..TEST_DIM)
            .map(|j| r_query[j] * sign_bit(&packed_signs, j))
            .sum();
        let actual = qjl_correction(&r_query, &packed_signs, TEST_DIM);
        assert!(
            (expected - actual).abs() < FLOAT_EPSILON,
            "qjl_correction mismatch: expected {expected}, got {actual}"
        );
    }

    // --- Block (de)composition ---

    #[test]
    fn from_parts_roundtrip() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_QJL_SEED);
        let original = quantize_with_qjl(&config, &data, TEST_QJL_SEED).unwrap();
        let reconstructed = QjlBlock::from_parts(
            PackedBlock::from_raw(
                original.polar_block.bits,
                original.polar_block.scale,
                original.polar_block.packed_indices.to_vec(),
            ),
            original.qjl_signs.to_vec(),
            original.residual_norm,
        );
        assert_eq!(reconstructed.polar_block.bits, original.polar_block.bits);
        assert_eq!(reconstructed.polar_block.scale, original.polar_block.scale);
        assert_eq!(
            reconstructed.polar_block.packed_indices,
            original.polar_block.packed_indices
        );
        assert_eq!(reconstructed.qjl_signs, original.qjl_signs);
        assert_eq!(reconstructed.residual_norm, original.residual_norm);
    }

    #[test]
    fn accessor_polar_block_matches_quantized() {
        let config = TurboQuantConfig::new(BITS_3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_ROTATION_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_QJL_SEED);
        let block = quantize_with_qjl(&config, &data, TEST_QJL_SEED).unwrap();
        assert_eq!(block.polar_block.bits, BITS_3 - 1);
        assert!(block.polar_block.scale.to_f32() > 0.0);
        assert!(!block.polar_block.packed_indices.is_empty());
        const BITS_PER_BYTE: usize = 8;
        let expected_sign_bytes = TEST_DIM.div_ceil(BITS_PER_BYTE);
        assert_eq!(block.qjl_signs.len(), expected_sign_bytes);
    }
}