use std::f32::consts::PI;
use ndarray::{Array1, Array2};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
/// A batch of encoded vectors stored as two flat, parallel arrays.
///
/// `PolarCodec::encode_batch` appends entries vector by vector, so the entry
/// for pair `p` of vector `v` is at position `v * pairs + p` in both arrays.
pub struct CompressedCorpus {
/// Number of encoded vectors.
pub n: usize,
/// Number of coordinate pairs stored per vector.
pub pairs: usize,
/// Radius of every pair; `n * pairs` entries when built by `encode_batch`.
pub radii: Vec<f32>,
/// Quantized angle index of every pair, parallel to `radii`.
pub indices: Vec<u8>,
}
/// A single encoded vector: one radius and one quantized angle index per
/// coordinate pair.
#[derive(Clone)]
pub struct CompressedCode {
    /// Radius of each coordinate pair.
    pub radii: Vec<f32>,
    /// Quantized angle index of each pair, parallel to `radii`.
    pub angle_indices: Vec<u8>,
}

impl CompressedCode {
    /// Payload size in bytes: one `f32` per radius plus one byte per angle
    /// index.
    #[must_use]
    pub fn encoded_bytes(&self) -> usize {
        let radius_bytes = self.radii.len() * std::mem::size_of::<f32>();
        let index_bytes = self.angle_indices.len();
        radius_bytes + index_bytes
    }
}
/// Quantizer that rotates a vector, splits it into consecutive coordinate
/// pairs and stores each pair as a radius plus a quantized angle index.
pub struct PolarCodec {
/// Input dimensionality; `new` asserts it is even and non-zero.
dim: usize,
/// Bits per quantized angle; `new` asserts `1..=8`.
#[expect(dead_code, reason = "stored for serialization / reconstruction")]
bits: u8,
/// Number of angle levels, `1 << bits`.
levels: usize,
/// Number of coordinate pairs per vector, `dim / 2`.
pairs: usize,
/// `dim × dim` matrix from `generate_rotation`, applied to every input
/// before it is split into pairs.
rotation: Array2<f32>,
/// Cosine of the angle associated with each level (`levels` entries).
cos_table: Vec<f32>,
/// Sine of the angle associated with each level (`levels` entries).
sin_table: Vec<f32>,
}
impl PolarCodec {
/// Builds a codec for `dim`-dimensional vectors with `bits` bits per angle.
///
/// The codec owns a `dim × dim` rotation generated from `seed` and one
/// cosine/sine entry per quantization level.
///
/// `encode_pair` assigns an angle to level `j` by flooring, so level `j`
/// covers the bin `[-π + j·Δ, -π + (j + 1)·Δ)` with `Δ = 2π / levels`. The
/// tables are therefore evaluated at the bin midpoint `-π + (j + 0.5)·Δ`:
/// the reconstruction error is at most `Δ / 2`, whereas the lower bin edge
/// would leave every angle up to a full `Δ` too low.
///
/// # Panics
///
/// Panics if `dim` is zero or odd, or if `bits` is outside `1..=8`.
#[must_use]
pub fn new(dim: usize, bits: u8, seed: u64) -> Self {
    assert!(
        dim > 0 && dim.is_multiple_of(2),
        "dim must be even and non-zero"
    );
    assert!((1..=8).contains(&bits), "bits must be 1..=8");
    let levels = 1usize << bits;
    let pairs = dim / 2;
    let rotation = generate_rotation(dim, seed);
    let mut cos_table = Vec::with_capacity(levels);
    let mut sin_table = Vec::with_capacity(levels);
    for j in 0..levels {
        // Midpoint of bin `j`; see the doc comment for why not the lower edge.
        let theta = ((j as f32 + 0.5) / levels as f32) * 2.0 * PI - PI;
        cos_table.push(theta.cos());
        sin_table.push(theta.sin());
    }
    Self {
        dim,
        bits,
        levels,
        pairs,
        rotation,
        cos_table,
        sin_table,
    }
}
/// Returns the number of coordinate pairs per encoded vector (`dim / 2`).
#[must_use]
pub fn pairs(&self) -> usize {
self.pairs
}
/// Encodes one vector.
///
/// The input is copied into an owned array, multiplied by the codec's
/// rotation, and every consecutive coordinate pair `(2p, 2p + 1)` of the
/// result is converted to a radius and a quantized angle index.
///
/// # Panics
///
/// Panics if `vector.len()` differs from the codec's dimension.
#[must_use]
pub fn encode(&self, vector: &[f32]) -> CompressedCode {
    assert_eq!(vector.len(), self.dim);
    let rotated = self.rotation.dot(&Array1::from_vec(vector.to_vec()));
    let (radii, angle_indices): (Vec<f32>, Vec<u8>) = (0..self.pairs)
        .map(|pair| self.encode_pair(rotated[2 * pair], rotated[2 * pair + 1]))
        .unzip();
    CompressedCode {
        radii,
        angle_indices,
    }
}
/// Encodes every row of `vectors` into one flat [`CompressedCorpus`].
///
/// All rows are rotated with a single product `vectors · rotationᵀ`, then
/// each row's consecutive coordinate pairs are encoded in order, so the
/// output arrays are laid out vector by vector.
///
/// # Panics
///
/// Panics if the column count of `vectors` differs from the codec's
/// dimension.
#[must_use]
pub fn encode_batch(&self, vectors: &Array2<f32>) -> CompressedCorpus {
    assert_eq!(vectors.ncols(), self.dim);
    let pairs = self.pairs;
    let n = vectors.nrows();
    let rotated = vectors.dot(&self.rotation.t());
    let capacity = n * pairs;
    let mut radii = Vec::with_capacity(capacity);
    let mut indices = Vec::with_capacity(capacity);
    // (row, first column of the pair), visited row by row.
    let cells = (0..n).flat_map(|row| (0..pairs).map(move |pair| (row, 2 * pair)));
    for (row, col) in cells {
        let (radius, angle) = self.encode_pair(rotated[[row, col]], rotated[[row, col + 1]]);
        radii.push(radius);
        indices.push(angle);
    }
    CompressedCorpus {
        n,
        pairs,
        radii,
        indices,
    }
}
/// Encodes every row of `vectors` and returns one [`CompressedCode`] per
/// row.
///
/// The rows are first encoded with `encode_batch`; the flat arrays are then
/// cut into per-vector pieces of `pairs` entries each and copied out.
#[must_use]
pub fn encode_batch_codes(&self, vectors: &Array2<f32>) -> Vec<CompressedCode> {
    let corpus = self.encode_batch(vectors);
    // `pairs` is at least 1 because `new` rejects a zero dimension.
    let per_vector_radii = corpus.radii.chunks_exact(corpus.pairs);
    let per_vector_angles = corpus.indices.chunks_exact(corpus.pairs);
    per_vector_radii
        .zip(per_vector_angles)
        .map(|(radii, angles)| CompressedCode {
            radii: radii.to_vec(),
            angle_indices: angles.to_vec(),
        })
        .collect()
}
/// Precomputes the per-pair, per-level lookup table for one query.
///
/// The query is rotated like an encoded vector. For pair `p` with rotated
/// coordinates `(q_a, q_b)` and level `j`, slot `p * levels + j` holds
/// `q_a * cos_table[j] + q_b * sin_table[j]`, so a scan only has to
/// multiply a stored radius by the slot selected by the stored index.
///
/// # Panics
///
/// Panics if `query.len()` differs from the codec's dimension.
#[must_use]
pub fn prepare_query(&self, query: &[f32]) -> QueryState {
    assert_eq!(query.len(), self.dim);
    let rotated = self.rotation.dot(&Array1::from_vec(query.to_vec()));
    let mut centroid_q = vec![0.0f32; self.pairs * self.levels];
    for (pair, slots) in centroid_q.chunks_exact_mut(self.levels).enumerate() {
        let (q_a, q_b) = (rotated[2 * pair], rotated[2 * pair + 1]);
        let trig = self.cos_table.iter().zip(&self.sin_table);
        for (slot, (&c, &s)) in slots.iter_mut().zip(trig) {
            *slot = q_a * c + q_b * s;
        }
    }
    QueryState {
        centroid_q,
        pairs: self.pairs,
        levels: self.levels,
    }
}
/// Scores every vector of `corpus` against a prepared query.
///
/// The score of vector `v` is the sum over pairs `p` of
/// `radii[v * pairs + p] * centroid_q[p * levels + indices[v * pairs + p]]`.
/// Terms are accumulated strictly in pair order into one `f32`.
///
/// An empty corpus (`n == 0`) yields an empty vector without further checks.
///
/// # Panics
///
/// For a non-empty corpus, panics if `corpus.pairs != qs.pairs`, if
/// `corpus.radii` or `corpus.indices` holds fewer than `n * pairs` entries,
/// or if a stored index addresses a slot past the end of `qs.centroid_q`.
#[must_use]
pub fn scan_corpus(&self, corpus: &CompressedCorpus, qs: &QueryState) -> Vec<f32> {
    let n = corpus.n;
    if n == 0 {
        return Vec::new();
    }
    let pairs = corpus.pairs;
    let levels = qs.levels;
    // A pair-count mismatch would otherwise give truncated or misaligned
    // scores, or an index panic with no hint at the cause.
    assert_eq!(
        pairs, qs.pairs,
        "corpus and query state must use the same number of pairs"
    );
    let total = n * pairs;
    assert!(
        corpus.radii.len() >= total,
        "corpus.radii holds fewer than n * pairs entries"
    );
    assert!(
        corpus.indices.len() >= total,
        "corpus.indices holds fewer than n * pairs entries"
    );
    (0..n)
        .map(|v| {
            let start = v * pairs;
            let end = start + pairs;
            corpus.radii[start..end]
                .iter()
                .zip(&corpus.indices[start..end])
                .enumerate()
                .fold(0.0f32, |score, (p, (&radius, &idx))| {
                    score + radius * qs.centroid_q[p * levels + usize::from(idx)]
                })
        })
        .collect()
}
/// Scores each code in `codes` against a prepared query.
///
/// For every code the first `qs.pairs` radii are multiplied by the table
/// slot selected by the matching angle index (`pair * levels + index`) and
/// summed in pair order.
///
/// # Panics
///
/// Panics if a code holds fewer than `qs.pairs` radii or angle indices, or
/// if an index addresses a slot past the end of `qs.centroid_q`.
#[must_use]
pub fn batch_scan(&self, codes: &[CompressedCode], qs: &QueryState) -> Vec<f32> {
    let pairs = qs.pairs;
    let levels = qs.levels;
    let mut scores = Vec::with_capacity(codes.len());
    for code in codes {
        let angles = &code.angle_indices[..pairs];
        let radii = &code.radii[..pairs];
        let mut total = 0.0f32;
        for (pair, (&angle, &radius)) in angles.iter().zip(radii).enumerate() {
            total += radius * qs.centroid_q[pair * levels + usize::from(angle)];
        }
        scores.push(total);
    }
    scores
}
/// Converts one coordinate pair `(a, b)` to polar form and quantizes the
/// angle.
///
/// Returns the radius `sqrt(a² + b²)` and the level index of the angle
/// `atan2(b, a)`. The angle is shifted from `[-π, π]` to `[0, 1]`, scaled by
/// `levels` and floored. The clamp to `levels - 1` folds the single value
/// `θ = π` (normalized `1.0`) into the last level.
#[inline]
#[expect(
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    reason = "normalized angle is in [0,1]; index is clamped to levels - 1 <= 255 (bits <= 8), so it fits in u8"
)]
fn encode_pair(&self, a: f32, b: f32) -> (f32, u8) {
    let r = (a * a + b * b).sqrt();
    let theta = b.atan2(a);
    let normalized = (theta + PI) / (2.0 * PI);
    let idx = ((normalized * self.levels as f32) as usize).min(self.levels - 1);
    (r, idx as u8)
}
}
/// Per-query lookup table produced by `PolarCodec::prepare_query` and read
/// by the scan methods.
pub struct QueryState {
/// Flat table of `pairs * levels` entries; the slot for pair `p` and level
/// `j` is at `p * levels + j`.
pub centroid_q: Vec<f32>,
/// Number of coordinate pairs the table covers.
pub pairs: usize,
/// Number of angle levels per pair.
pub levels: usize,
}
/// Builds the `dim × dim` rotation used by the codec.
///
/// Draws `dim * dim` standard-normal samples from a ChaCha8 generator seeded
/// with `seed`, arranges them into a square matrix and orthonormalizes it
/// with `gram_schmidt_qr`.
fn generate_rotation(dim: usize, seed: u64) -> Array2<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let entries: Vec<f32> = (0..dim * dim)
        .map(|_| StandardNormal.sample(&mut rng))
        .collect();
    let gaussian =
        Array2::from_shape_vec((dim, dim), entries).expect("shape matches data length");
    gram_schmidt_qr(gaussian)
}
/// Orthonormalizes the columns of `q` in place and returns the matrix.
///
/// Columns are processed left to right. Each one is scaled to unit length,
/// then its component is subtracted from every later column. A column whose
/// norm is below `1e-10` is skipped: it is neither scaled nor used to
/// project the others.
fn gram_schmidt_qr(mut q: Array2<f32>) -> Array2<f32> {
    let rows = q.nrows();
    let cols = q.ncols();
    for pivot in 0..cols {
        let length = (0..rows)
            .map(|r| q[[r, pivot]] * q[[r, pivot]])
            .sum::<f32>()
            .sqrt();
        if length < 1e-10 {
            continue;
        }
        let scale = 1.0 / length;
        for r in 0..rows {
            q[[r, pivot]] *= scale;
        }
        for later in (pivot + 1)..cols {
            let projection: f32 = (0..rows).map(|r| q[[r, pivot]] * q[[r, later]]).sum();
            for r in 0..rows {
                q[[r, later]] -= projection * q[[r, pivot]];
            }
        }
    }
    q
}
#[cfg(test)]
mod tests {
use super::*;
/// Scales `v` in place to unit Euclidean length; a slice whose norm is not
/// above `1e-10` is left unchanged.
fn l2_normalize(v: &mut [f32]) {
    let squared_sum: f32 = v.iter().map(|&x| x * x).sum();
    let norm = squared_sum.sqrt();
    if norm > 1e-10 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
/// The generated rotation times its transpose must be the identity to
/// within `1e-5` per entry.
#[test]
fn rotation_is_orthogonal() {
    const DIM: usize = 8;
    let r = generate_rotation(DIM, 42);
    let eye = r.dot(&r.t());
    for i in 0..DIM {
        for j in 0..DIM {
            let expected = if i == j { 1.0 } else { 0.0 };
            let got = eye[[i, j]];
            let deviation = (got - expected).abs();
            assert!(
                deviation < 1e-5,
                "Q×Qᵀ[{i},{j}] = {}, expected {expected}",
                got
            );
        }
    }
}
/// Encodes one unit vector and checks the invariants of the code.
///
/// No decoder is visible in this file, so beyond the output lengths the
/// test checks what can be verified from the code alone: every angle index
/// lies below the 16 levels of a 4-bit codec, and the squared radii sum to
/// the squared norm of the input (the rotation is orthogonal, so it
/// preserves length).
#[test]
fn encode_decode_roundtrip() {
    let codec = PolarCodec::new(8, 4, 42);
    let mut v = vec![0.3, -0.1, 0.5, 0.2, -0.4, 0.1, 0.3, -0.2];
    l2_normalize(&mut v);
    let code = codec.encode(&v);
    assert_eq!(codec.pairs(), 4);
    assert_eq!(code.radii.len(), 4);
    assert_eq!(code.angle_indices.len(), 4);
    assert!(
        code.angle_indices.iter().all(|&idx| idx < 16),
        "angle indices must be below 16 for a 4-bit codec, got {:?}",
        code.angle_indices
    );
    let energy: f32 = code.radii.iter().map(|r| r * r).sum();
    assert!(
        (energy - 1.0).abs() < 1e-3,
        "squared radii should sum to the input's squared norm (1.0), got {energy}"
    );
}
/// Encodes 1000 synthetic unit vectors of dimension 768, scans them with a
/// synthetic query and compares the approximate top 10 with the exact
/// dot-product top 10. Timings are printed to stderr; only the recall
/// (at least 4 of 10) is asserted.
#[test]
fn corpus_scan_recall_and_throughput() {
    use std::time::Instant;

    let dim = 768;
    let n = 1000;
    let codec = PolarCodec::new(dim, 4, 42);

    // Deterministic corpus: sine-valued rows, each scaled to unit length.
    let mut vecs = Array2::<f32>::zeros((n, dim));
    for i in 0..n {
        for d in 0..dim {
            vecs[[i, d]] = ((i * 17 + d * 31) as f32).sin();
        }
        let norm = (0..dim)
            .map(|d| vecs[[i, d]] * vecs[[i, d]])
            .sum::<f32>()
            .sqrt();
        for d in 0..dim {
            vecs[[i, d]] /= norm;
        }
    }

    let t0 = Instant::now();
    let corpus = codec.encode_batch(&vecs);
    let encode_ms = t0.elapsed().as_secs_f64() * 1000.0;
    eprintln!(
        "encode {n} → SoA corpus: {encode_ms:.1}ms ({:.1}µs/vec)",
        encode_ms * 1000.0 / n as f64
    );

    let mut query: Vec<f32> = (0..dim)
        .map(|d| ((42 * 7 + d * 13) as f32).sin())
        .collect();
    l2_normalize(&mut query);

    // Exact ranking by dot product, best first.
    let query_arr = Array1::from_vec(query.clone());
    let mut exact: Vec<(usize, f32)> = Vec::with_capacity(n);
    for i in 0..n {
        exact.push((i, vecs.row(i).dot(&query_arr)));
    }
    exact.sort_by(|(_, lhs), (_, rhs)| rhs.total_cmp(lhs));

    let t1 = Instant::now();
    let qs = codec.prepare_query(&query);
    let prep_us = t1.elapsed().as_secs_f64() * 1e6;
    let t2 = Instant::now();
    let scores = codec.scan_corpus(&corpus, &qs);
    let scan_us = t2.elapsed().as_secs_f64() * 1e6;
    eprintln!(
        "prepare: {prep_us:.0}µs, scan {n}: {scan_us:.0}µs ({:.2}µs/vec)",
        scan_us / n as f64
    );
    eprintln!("scan throughput: {:.1}M vec/s", n as f64 / scan_us);

    // Approximate ranking from the compressed scan, best first.
    let mut approx: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
    approx.sort_by(|(_, lhs), (_, rhs)| rhs.total_cmp(lhs));
    let approx_top10: Vec<usize> = approx.iter().take(10).map(|&(i, _)| i).collect();
    let recall = exact
        .iter()
        .take(10)
        .filter(|(i, _)| approx_top10.contains(i))
        .count();
    eprintln!("Recall@10: {recall}/10");
    assert!(
        recall >= 4,
        "raw scan recall should be >= 4/10, got {recall}/10"
    );
}
/// Compares the CPU scan with the Metal scan on 10 000 synthetic vectors.
///
/// Only compiled with the `metal` feature. Timings are printed to stderr;
/// the only assertion is that CPU and GPU scores differ by less than 0.01.
#[test]
#[cfg(feature = "metal")]
fn metal_turboquant_scan() {
let dim = 768;
let n = 10_000;
let codec = PolarCodec::new(dim, 4, 42);
// Deterministic corpus: sine-valued rows, each scaled to unit length.
let mut vecs = Array2::<f32>::zeros((n, dim));
for i in 0..n {
for d in 0..dim {
vecs[[i, d]] = ((i * 17 + d * 31) as f32).sin();
}
let norm: f32 = vecs.row(i).iter().map(|x| x * x).sum::<f32>().sqrt();
for d in 0..dim {
vecs[[i, d]] /= norm;
}
}
let corpus = codec.encode_batch(&vecs);
let mut query = vec![0.0f32; dim];
for d in 0..dim {
query[d] = ((42 * 7 + d * 13) as f32).sin();
}
l2_normalize(&mut query);
let qs = codec.prepare_query(&query);
// CPU reference scores and timing.
let t0 = std::time::Instant::now();
let cpu_scores = codec.scan_corpus(&corpus, &qs);
let cpu_us = t0.elapsed().as_secs_f64() * 1e6;
// NOTE(review): the driver API is defined outside this file; the comments
// below describe only how it is called here.
let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
// Time the corpus upload separately from the scans.
let t_cold = std::time::Instant::now();
let gpu_corpus = driver
.turboquant_upload_corpus(&corpus.radii, &corpus.indices)
.unwrap();
let upload_us = t_cold.elapsed().as_secs_f64() * 1e6;
// First GPU scan; its scores are the ones compared with the CPU.
let t_warm = std::time::Instant::now();
let gpu_scores = driver
.turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
.unwrap();
let warm_us = t_warm.elapsed().as_secs_f64() * 1e6;
// Second, identical GPU scan; only its duration is used.
let t_hot = std::time::Instant::now();
let _ = driver
.turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
.unwrap();
let hot_us = t_hot.elapsed().as_secs_f64() * 1e6;
eprintln!("10K vectors:");
eprintln!(" CPU: {cpu_us:.0}µs ({:.1}M/s)", n as f64 / cpu_us);
eprintln!(" GPU upload: {upload_us:.0}µs (one-time)");
eprintln!(
" GPU warm: {warm_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
n as f64 / warm_us,
cpu_us / warm_us
);
eprintln!(
" GPU hot: {hot_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
n as f64 / hot_us,
cpu_us / hot_us
);
// Largest absolute difference between CPU and GPU scores.
let mut max_diff = 0.0f32;
for i in 0..n {
let diff = (cpu_scores[i] - gpu_scores[i]).abs();
if diff > max_diff {
max_diff = diff;
}
}
eprintln!("max CPU/GPU score diff: {max_diff:.6}");
assert!(
max_diff < 0.01,
"GPU scores should match CPU within 0.01, got {max_diff}"
);
}
}