use super::codebook::Codebook;
use super::qjl::QjlProjector;
use super::rotation::HadamardRotation;
/// Construction parameters for [`TurboQuantEngine`].
#[derive(Debug, Clone)]
pub struct TurboQuantConfig {
/// Bit width handed to `Codebook::new` when the engine is built.
pub bits: u8,
/// When `true`, the engine also builds a `QjlProjector`, and `compress`
/// stores a QJL encoding of the quantization residual with every vector.
pub use_qjl: bool,
/// Dimension passed to the rotation, which derives the padded working
/// dimension from it (presumably the length of the input vectors).
pub dim: usize,
}
/// Quantization engine combining a seeded rotation, a scalar codebook and an
/// optional QJL projector for residual correction.
#[derive(Debug, Clone)]
pub struct TurboQuantEngine {
/// Rotation applied to every key and query before quantization or scoring.
rotation: HadamardRotation,
/// Codebook used to quantize rotated vectors and to score against them.
codebook: Codebook,
/// Residual projector; `None` when the engine was built with `use_qjl == false`.
qjl: Option<QjlProjector>,
/// Dimension taken from the configuration (`TurboQuantConfig::dim`).
dim: usize,
/// Working dimension reported by the rotation; all scratch buffers are
/// resized to this length.
padded_dim: usize,
}
/// One vector as produced by `TurboQuantEngine::compress`.
///
/// `qjl_bits` and `residual_norm` are either both `Some` (engine built with a
/// QJL projector) or both `None`.
#[derive(Debug, Clone)]
pub struct CompressedVector {
    /// Packed codebook indices written by `Codebook::quantize_vector`.
    pub packed_indices: Vec<u8>,
    /// QJL encoding of the quantization residual, as 64-bit words.
    pub qjl_bits: Option<Vec<u64>>,
    /// Second value returned by `QjlProjector::compress` for the residual
    /// (presumably its norm, going by the name).
    pub residual_norm: Option<f32>,
}
impl TurboQuantEngine {
pub fn new(config: &TurboQuantConfig, rotation_seed: u64, qjl_seed: u64) -> Self {
let rotation = HadamardRotation::new(rotation_seed, config.dim);
let padded_dim = rotation.padded_dim();
let codebook = Codebook::new(padded_dim, config.bits);
let qjl = if config.use_qjl {
Some(QjlProjector::new(qjl_seed, padded_dim))
} else {
None
};
Self {
rotation,
codebook,
qjl,
dim: config.dim,
padded_dim,
}
}
pub fn compress(&self, x: &[f32], rotated_buf: &mut Vec<f32>, deq_buf: &mut Vec<f32>) -> CompressedVector {
rotated_buf.resize(self.padded_dim, 0.0);
self.rotation.rotate(x, rotated_buf);
let mut packed_indices = Vec::new();
self.codebook.quantize_vector(rotated_buf, &mut packed_indices);
let (qjl_bits, residual_norm) = if let Some(ref proj) = self.qjl {
deq_buf.clear();
self.codebook
.dequantize_vector(&packed_indices, self.padded_dim, deq_buf);
let residual: Vec<f32> = rotated_buf
.iter()
.zip(deq_buf.iter())
.map(|(&r, &d)| r - d)
.collect();
let (bits, norm) = proj.compress(&residual);
(Some(bits), Some(norm))
} else {
(None, None)
};
CompressedVector {
packed_indices,
qjl_bits,
residual_norm,
}
}
pub fn attention_score(
&self,
query: &[f32],
compressed: &CompressedVector,
rotated_query_buf: &mut Vec<f32>,
) -> f32 {
rotated_query_buf.resize(self.padded_dim, 0.0);
self.rotation.rotate(query, rotated_query_buf);
let polar_score = self
.codebook
.dot_with_packed(rotated_query_buf, &compressed.packed_indices, self.padded_dim);
if let (Some(proj), Some(bits), Some(norm)) =
(&self.qjl, &compressed.qjl_bits, compressed.residual_norm)
{
let qjl_correction = proj.inner_product(rotated_query_buf, bits, norm);
polar_score + qjl_correction
} else {
polar_score
}
}
pub fn attention_scores(
&self,
query: &[f32],
keys: &[CompressedVector],
rotated_query_buf: &mut Vec<f32>,
scores: &mut Vec<f32>,
) {
rotated_query_buf.resize(self.padded_dim, 0.0);
self.rotation.rotate(query, rotated_query_buf);
let projected_query = self.qjl.as_ref().map(|proj| proj.project_query(rotated_query_buf));
scores.clear();
scores.reserve(keys.len());
for key in keys {
let polar_score = self
.codebook
.dot_with_packed(rotated_query_buf, &key.packed_indices, self.padded_dim);
let score = if let (Some(proj_q), Some(proj), Some(bits), Some(norm)) = (
&projected_query,
&self.qjl,
&key.qjl_bits,
key.residual_norm,
) {
let _ = proj;
let qjl_correction = proj.inner_product_fast(proj_q, bits, norm);
polar_score + qjl_correction
} else {
polar_score
};
scores.push(score);
}
}
pub fn bytes_per_entry(&self) -> usize {
let codebook_bytes = self.codebook.packed_bytes(self.padded_dim);
if self.qjl.is_some() {
let qjl_words = (self.padded_dim + 63) / 64;
codebook_bytes + qjl_words * 8 + 4 } else {
codebook_bytes
}
}
pub fn bits_per_element(&self) -> f32 {
let total_bits = self.bytes_per_entry() as f32 * 8.0;
total_bits / self.dim as f32
}
#[inline]
pub fn padded_dim(&self) -> usize {
self.padded_dim
}
#[inline]
pub fn dim(&self) -> usize {
self.dim
}
pub fn rotation(&self) -> &HadamardRotation {
&self.rotation
}
pub fn codebook(&self) -> &Codebook {
&self.codebook
}
pub fn qjl(&self) -> Option<&QjlProjector> {
self.qjl.as_ref()
}
}
/// Replaces `scores` with their softmax, in place.
///
/// The maximum score is subtracted before exponentiation so large inputs do
/// not overflow. An empty slice is left untouched. If the maximum is not
/// finite (every score is `-inf`, or any score is `+inf`) the subtraction
/// produces NaN and the output is NaN.
pub fn softmax_inplace(scores: &mut [f32]) {
    if scores.is_empty() {
        return;
    }

    let peak = scores
        .iter()
        .fold(f32::NEG_INFINITY, |hi, &value| hi.max(value));

    // First pass: exponentiate the shifted scores.
    scores
        .iter_mut()
        .for_each(|value| *value = (*value - peak).exp());

    // Second pass: accumulate left to right, then scale by the reciprocal.
    let total = scores.iter().fold(0.0f32, |acc, &value| acc + value);
    let scale = 1.0 / total;
    scores.iter_mut().for_each(|value| *value *= scale);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a configuration in tests.
    fn make_config(dim: usize, bits: u8, use_qjl: bool) -> TurboQuantConfig {
        TurboQuantConfig { bits, use_qjl, dim }
    }

    /// Without QJL, compression must not produce any residual data.
    #[test]
    fn test_mse_compress_decompress() {
        let config = make_config(64, 2, false);
        let engine = TurboQuantEngine::new(&config, 100, 200);
        let x: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.01).collect();
        let mut rot_buf = Vec::new();
        let mut deq_buf = Vec::new();
        let compressed = engine.compress(&x, &mut rot_buf, &mut deq_buf);
        assert!(compressed.qjl_bits.is_none());
        assert!(compressed.residual_norm.is_none());
    }

    /// With QJL enabled, both residual fields must be populated.
    #[test]
    fn test_prod_compress_has_qjl() {
        let config = make_config(64, 2, true);
        let engine = TurboQuantEngine::new(&config, 100, 200);
        let x: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.01).collect();
        let mut rot_buf = Vec::new();
        let mut deq_buf = Vec::new();
        let compressed = engine.compress(&x, &mut rot_buf, &mut deq_buf);
        assert!(compressed.qjl_bits.is_some());
        assert!(compressed.residual_norm.is_some());
    }

    /// A key aligned with the query must outscore its negation.
    #[test]
    fn test_attention_score_mse_direction() {
        let config = make_config(64, 2, false);
        let engine = TurboQuantEngine::new(&config, 100, 200);
        let q: Vec<f32> = (0..64).map(|i| (i as f32) * 0.02).collect();
        let k_similar: Vec<f32> = (0..64).map(|i| (i as f32) * 0.02 + 0.001).collect();
        let k_opposite: Vec<f32> = (0..64).map(|i| -(i as f32) * 0.02).collect();
        let mut rot_buf = Vec::new();
        let mut deq_buf = Vec::new();
        let mut rot_q_buf = Vec::new();
        let c_similar = engine.compress(&k_similar, &mut rot_buf, &mut deq_buf);
        let c_opposite = engine.compress(&k_opposite, &mut rot_buf, &mut deq_buf);
        let s1 = engine.attention_score(&q, &c_similar, &mut rot_q_buf);
        let s2 = engine.attention_score(&q, &c_opposite, &mut rot_q_buf);
        assert!(
            s1 > s2,
            "similar key should score higher: {s1} vs {s2}"
        );
    }

    /// The batch API must return one score per key, equal to the single-key API.
    #[test]
    fn test_batch_scores_match_individual() {
        let config = make_config(32, 2, false);
        let engine = TurboQuantEngine::new(&config, 100, 200);
        let q: Vec<f32> = (0..32).map(|i| (i as f32) * 0.1).collect();
        let keys_raw: Vec<Vec<f32>> = (0..5)
            .map(|k| (0..32).map(|i| ((i + k) as f32) * 0.05).collect())
            .collect();
        let mut rot_buf = Vec::new();
        let mut deq_buf = Vec::new();
        let mut rot_q_buf = Vec::new();
        let compressed: Vec<CompressedVector> = keys_raw
            .iter()
            .map(|k| engine.compress(k, &mut rot_buf, &mut deq_buf))
            .collect();
        let individual: Vec<f32> = compressed
            .iter()
            .map(|c| engine.attention_score(&q, c, &mut rot_q_buf))
            .collect();
        let mut batch = Vec::new();
        engine.attention_scores(&q, &compressed, &mut rot_q_buf, &mut batch);
        // `zip` stops at the shorter side, so without this check an empty or
        // truncated batch would pass the loop below vacuously.
        assert_eq!(
            individual.len(),
            batch.len(),
            "batch should produce one score per key"
        );
        for (a, b) in individual.iter().zip(batch.iter()) {
            assert!(
                (a - b).abs() < 1e-4,
                "batch should match individual: {a} vs {b}"
            );
        }
    }

    /// Softmax output sums to one and preserves ordering.
    #[test]
    fn test_softmax() {
        let mut scores = vec![1.0, 2.0, 3.0];
        softmax_inplace(&mut scores);
        let sum: f32 = scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(scores[2] > scores[1]);
        assert!(scores[1] > scores[0]);
    }

    /// Softmax leaves an empty slice alone and maps a single score to 1.
    #[test]
    fn test_softmax_edge_cases() {
        let mut empty: Vec<f32> = Vec::new();
        softmax_inplace(&mut empty);
        assert!(empty.is_empty());

        let mut single = vec![7.5f32];
        softmax_inplace(&mut single);
        assert_eq!(single, vec![1.0f32]);
    }

    /// Per-entry size is the packed indices, plus QJL words and norm if enabled.
    #[test]
    fn test_bytes_per_entry() {
        let config = make_config(128, 2, false);
        let engine = TurboQuantEngine::new(&config, 1, 2);
        let pd = engine.padded_dim();
        assert_eq!(pd, 128);
        assert_eq!(engine.bytes_per_entry(), pd / 4);
        let config_qjl = make_config(128, 2, true);
        let engine_qjl = TurboQuantEngine::new(&config_qjl, 1, 2);
        let expected = pd / 4 + (pd + 63) / 64 * 8 + 4;
        assert_eq!(engine_qjl.bytes_per_entry(), expected);
    }

    /// A dimension that is not a power of two is padded up and still scores.
    #[test]
    fn test_non_power_of_two_dim() {
        let config = make_config(80, 2, false);
        let engine = TurboQuantEngine::new(&config, 100, 200);
        assert_eq!(engine.dim(), 80);
        assert_eq!(engine.padded_dim(), 128);
        let x: Vec<f32> = (0..80).map(|i| (i as f32 - 40.0) * 0.01).collect();
        let mut rot_buf = Vec::new();
        let mut deq_buf = Vec::new();
        let mut rot_q_buf = Vec::new();
        let compressed = engine.compress(&x, &mut rot_buf, &mut deq_buf);
        let score = engine.attention_score(&x, &compressed, &mut rot_q_buf);
        assert!(score > 0.0, "self-attention score should be positive");
    }
}