1use std::f32::consts::PI;
22
23use ndarray::{Array1, Array2};
24use rand::SeedableRng;
25use rand_chacha::ChaCha8Rng;
26use rand_distr::{Distribution, StandardNormal};
27
/// Structure-of-arrays storage for a batch of encoded vectors, laid out for
/// cache-friendly linear scans.
pub struct CompressedCorpus {
    /// Number of encoded vectors.
    pub n: usize,
    /// Number of (radius, angle) pairs per vector (input dim / 2).
    pub pairs: usize,
    /// Radii, row-major: vector `v`'s pair `i` lives at `v * pairs + i`.
    pub radii: Vec<f32>,
    /// Quantized angle indices, same layout as `radii`.
    pub indices: Vec<u8>,
}
49
/// Array-of-structs encoding of a single vector: one exact f32 radius and one
/// quantized angle index per coordinate pair.
#[derive(Clone)]
pub struct CompressedCode {
    /// Exact radius of each rotated coordinate pair.
    pub radii: Vec<f32>,
    /// Quantized angle bin of each rotated coordinate pair.
    pub angle_indices: Vec<u8>,
}
58
59impl CompressedCode {
60 #[must_use]
62 pub fn encoded_bytes(&self) -> usize {
63 self.radii.len() * 4 + self.angle_indices.len()
64 }
65}
66
/// Polar quantization codec: applies a fixed seeded random orthogonal
/// rotation, then encodes each coordinate pair as an exact f32 radius plus a
/// quantized angle index.
pub struct PolarCodec {
    // Input dimensionality; asserted even and non-zero in `new`.
    dim: usize,
    #[expect(dead_code, reason = "stored for serialization / reconstruction")]
    bits: u8,
    // Number of angle quantization levels: 2^bits.
    levels: usize,
    // Coordinate pairs per vector: dim / 2.
    pairs: usize,
    // Seeded random orthogonal rotation applied before encoding.
    rotation: Array2<f32>,
    // Per-level cosine of the reconstruction angle, indexed by level.
    cos_table: Vec<f32>,
    // Per-level sine of the reconstruction angle, indexed by level.
    sin_table: Vec<f32>,
}
84
85impl PolarCodec {
86 #[must_use]
92 pub fn new(dim: usize, bits: u8, seed: u64) -> Self {
93 assert!(
94 dim > 0 && dim.is_multiple_of(2),
95 "dim must be even and non-zero"
96 );
97 assert!(bits > 0 && bits <= 8, "bits must be 1..=8");
98
99 let levels = 1usize << bits;
100 let pairs = dim / 2;
101 let rotation = generate_rotation(dim, seed);
102
103 let mut cos_table = Vec::with_capacity(levels);
104 let mut sin_table = Vec::with_capacity(levels);
105 for j in 0..levels {
106 let theta = (j as f32 / levels as f32) * 2.0 * PI - PI;
107 cos_table.push(theta.cos());
108 sin_table.push(theta.sin());
109 }
110
111 Self {
112 dim,
113 bits,
114 levels,
115 pairs,
116 rotation,
117 cos_table,
118 sin_table,
119 }
120 }
121
122 #[must_use]
124 pub fn pairs(&self) -> usize {
125 self.pairs
126 }
127
128 #[must_use]
134 pub fn encode(&self, vector: &[f32]) -> CompressedCode {
135 assert_eq!(vector.len(), self.dim);
136 let x = Array1::from_vec(vector.to_vec());
137 let rotated = self.rotation.dot(&x);
138
139 let mut radii = Vec::with_capacity(self.pairs);
140 let mut angle_indices = Vec::with_capacity(self.pairs);
141 for i in 0..self.pairs {
142 let (r, idx) = self.encode_pair(rotated[2 * i], rotated[2 * i + 1]);
143 radii.push(r);
144 angle_indices.push(idx);
145 }
146 CompressedCode {
147 radii,
148 angle_indices,
149 }
150 }
151
152 #[must_use]
161 pub fn encode_batch(&self, vectors: &Array2<f32>) -> CompressedCorpus {
162 assert_eq!(vectors.ncols(), self.dim);
163 let n = vectors.nrows();
164
165 let rotated = vectors.dot(&self.rotation.t());
167
168 let total = n * self.pairs;
169 let mut radii = Vec::with_capacity(total);
170 let mut indices = Vec::with_capacity(total);
171
172 for row in 0..n {
173 for i in 0..self.pairs {
174 let (r, idx) = self.encode_pair(rotated[[row, 2 * i]], rotated[[row, 2 * i + 1]]);
175 radii.push(r);
176 indices.push(idx);
177 }
178 }
179
180 CompressedCorpus {
181 n,
182 pairs: self.pairs,
183 radii,
184 indices,
185 }
186 }
187
188 #[must_use]
190 pub fn encode_batch_codes(&self, vectors: &Array2<f32>) -> Vec<CompressedCode> {
191 let corpus = self.encode_batch(vectors);
192 (0..corpus.n)
193 .map(|v| {
194 let off = v * corpus.pairs;
195 CompressedCode {
196 radii: corpus.radii[off..off + corpus.pairs].to_vec(),
197 angle_indices: corpus.indices[off..off + corpus.pairs].to_vec(),
198 }
199 })
200 .collect()
201 }
202
203 #[must_use]
212 pub fn prepare_query(&self, query: &[f32]) -> QueryState {
213 assert_eq!(query.len(), self.dim);
214 let q = Array1::from_vec(query.to_vec());
215 let rotated = self.rotation.dot(&q);
216
217 let mut centroid_q = vec![0.0f32; self.pairs * self.levels];
220 for i in 0..self.pairs {
221 let q_a = rotated[2 * i];
222 let q_b = rotated[2 * i + 1];
223 let base = i * self.levels;
224 for j in 0..self.levels {
225 centroid_q[base + j] = q_a * self.cos_table[j] + q_b * self.sin_table[j];
226 }
227 }
228
229 QueryState {
230 centroid_q,
231 pairs: self.pairs,
232 levels: self.levels,
233 }
234 }
235
236 #[must_use]
244 #[expect(
245 clippy::needless_range_loop,
246 reason = "index-based loop is clearer for strided SoA access"
247 )]
248 pub fn scan_corpus(&self, corpus: &CompressedCorpus, qs: &QueryState) -> Vec<f32> {
249 let n = corpus.n;
250 let pairs = corpus.pairs;
251 let mut scores = vec![0.0f32; n];
252
253 for v in 0..n {
256 let base = v * pairs;
257 let mut score = 0.0f32;
258
259 let chunks = pairs / 4;
261 let remainder = pairs % 4;
262
263 for c in 0..chunks {
264 let i = base + c * 4;
265 let i0 = corpus.indices[i] as usize;
266 let i1 = corpus.indices[i + 1] as usize;
267 let i2 = corpus.indices[i + 2] as usize;
268 let i3 = corpus.indices[i + 3] as usize;
269
270 let p = c * 4;
271 score += corpus.radii[i] * qs.centroid_q[p * qs.levels + i0];
272 score += corpus.radii[i + 1] * qs.centroid_q[(p + 1) * qs.levels + i1];
273 score += corpus.radii[i + 2] * qs.centroid_q[(p + 2) * qs.levels + i2];
274 score += corpus.radii[i + 3] * qs.centroid_q[(p + 3) * qs.levels + i3];
275 }
276 for r in 0..remainder {
277 let i = base + chunks * 4 + r;
278 let p = chunks * 4 + r;
279 let j = corpus.indices[i] as usize;
280 score += corpus.radii[i] * qs.centroid_q[p * qs.levels + j];
281 }
282
283 scores[v] = score;
284 }
285
286 scores
287 }
288
289 #[must_use]
291 pub fn batch_scan(&self, codes: &[CompressedCode], qs: &QueryState) -> Vec<f32> {
292 codes
293 .iter()
294 .map(|code| {
295 let mut score = 0.0f32;
296 for i in 0..qs.pairs {
297 let j = code.angle_indices[i] as usize;
298 score += code.radii[i] * qs.centroid_q[i * qs.levels + j];
299 }
300 score
301 })
302 .collect()
303 }
304
305 #[inline]
306 #[expect(
307 clippy::cast_possible_truncation,
308 clippy::cast_sign_loss,
309 reason = "normalized angle [0,1) × levels fits in u8 (max 16 levels)"
310 )]
311 fn encode_pair(&self, a: f32, b: f32) -> (f32, u8) {
312 let r = (a * a + b * b).sqrt();
313 let theta = b.atan2(a);
314 let normalized = (theta + PI) / (2.0 * PI);
315 let idx = ((normalized * self.levels as f32) as usize).min(self.levels - 1);
316 (r, idx as u8)
317 }
318}
319
/// Query-side lookup tables produced by `PolarCodec::prepare_query`.
pub struct QueryState {
    /// Flattened `[pairs x levels]` table: entry `i * levels + j` is the dot
    /// product of rotated query pair `i` with the unit vector of angle
    /// level `j`.
    pub centroid_q: Vec<f32>,
    /// Number of (radius, angle) pairs per vector.
    pub pairs: usize,
    /// Number of angle quantization levels.
    pub levels: usize,
}
329
330fn generate_rotation(dim: usize, seed: u64) -> Array2<f32> {
336 let mut rng = ChaCha8Rng::seed_from_u64(seed);
337 let mut data = Vec::with_capacity(dim * dim);
338 for _ in 0..(dim * dim) {
339 data.push(StandardNormal.sample(&mut rng));
340 }
341 let a = Array2::from_shape_vec((dim, dim), data).expect("shape matches data length");
342 gram_schmidt_qr(a)
343}
344
345fn gram_schmidt_qr(mut q: Array2<f32>) -> Array2<f32> {
347 let n = q.ncols();
348 for i in 0..n {
349 let norm: f32 = q.column(i).iter().map(|x| x * x).sum::<f32>().sqrt();
350 if norm < 1e-10 {
351 continue;
352 }
353 let inv = 1.0 / norm;
354 for row in 0..q.nrows() {
355 q[[row, i]] *= inv;
356 }
357 for j in (i + 1)..n {
358 let dot: f32 = (0..q.nrows()).map(|row| q[[row, i]] * q[[row, j]]).sum();
359 for row in 0..q.nrows() {
360 q[[row, j]] -= dot * q[[row, i]];
361 }
362 }
363 }
364 q
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Scales `v` to unit L2 norm in place; leaves near-zero vectors
    /// untouched to avoid dividing by ~0.
    fn l2_normalize(v: &mut [f32]) {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
    }

    // The seeded rotation must satisfy Q x Q^T ~= I.
    #[test]
    fn rotation_is_orthogonal() {
        let r = generate_rotation(8, 42);
        let eye = r.dot(&r.t());
        for i in 0..8 {
            for j in 0..8 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (eye[[i, j]] - expected).abs() < 1e-5,
                    "Q×Qᵀ[{i},{j}] = {}, expected {expected}",
                    eye[[i, j]]
                );
            }
        }
    }

    // Smoke test: an 8-dim vector encodes into dim/2 = 4 radius/angle pairs.
    #[test]
    fn encode_decode_roundtrip() {
        let codec = PolarCodec::new(8, 4, 42);
        let mut v = vec![0.3, -0.1, 0.5, 0.2, -0.4, 0.1, 0.3, -0.2];
        l2_normalize(&mut v);
        let code = codec.encode(&v);
        assert_eq!(code.radii.len(), 4);
        assert_eq!(code.angle_indices.len(), 4);
    }

    // End-to-end accuracy plus timing: encode a deterministic synthetic
    // corpus, compare the quantized scan's top-10 against the exact
    // dot-product ranking, and print throughput numbers (informational).
    #[test]
    fn corpus_scan_recall_and_throughput() {
        let dim = 768;
        let n = 1000;
        let codec = PolarCodec::new(dim, 4, 42);

        // Deterministic pseudo-random rows (no RNG), normalized to unit L2.
        let mut vecs = Array2::<f32>::zeros((n, dim));
        for i in 0..n {
            for d in 0..dim {
                vecs[[i, d]] = ((i * 17 + d * 31) as f32).sin();
            }
            let norm: f32 = vecs.row(i).iter().map(|x| x * x).sum::<f32>().sqrt();
            for d in 0..dim {
                vecs[[i, d]] /= norm;
            }
        }

        let t0 = std::time::Instant::now();
        let corpus = codec.encode_batch(&vecs);
        let encode_ms = t0.elapsed().as_secs_f64() * 1000.0;
        eprintln!(
            "encode {n} → SoA corpus: {encode_ms:.1}ms ({:.1}µs/vec)",
            encode_ms * 1000.0 / n as f64
        );

        // Query constructed the same way as the corpus rows.
        let mut query = vec![0.0f32; dim];
        for d in 0..dim {
            query[d] = ((42 * 7 + d * 13) as f32).sin();
        }
        l2_normalize(&mut query);

        // Exact ranking by true dot product — the recall reference.
        let query_arr = Array1::from_vec(query.clone());
        let mut exact: Vec<(usize, f32)> =
            (0..n).map(|i| (i, vecs.row(i).dot(&query_arr))).collect();
        exact.sort_by(|a, b| b.1.total_cmp(&a.1));

        let t1 = std::time::Instant::now();
        let qs = codec.prepare_query(&query);
        let prep_us = t1.elapsed().as_secs_f64() * 1e6;

        let t2 = std::time::Instant::now();
        let scores = codec.scan_corpus(&corpus, &qs);
        let scan_us = t2.elapsed().as_secs_f64() * 1e6;

        eprintln!(
            "prepare: {prep_us:.0}µs, scan {n}: {scan_us:.0}µs ({:.2}µs/vec)",
            scan_us / n as f64
        );
        eprintln!("scan throughput: {:.1}M vec/s", n as f64 / scan_us);

        // Recall@10 = overlap between approximate and exact top-10 sets.
        let mut approx: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
        approx.sort_by(|a, b| b.1.total_cmp(&a.1));
        let exact_top10: Vec<usize> = exact.iter().take(10).map(|(i, _)| *i).collect();
        let approx_top10: Vec<usize> = approx.iter().take(10).map(|(i, _)| *i).collect();
        let recall = exact_top10
            .iter()
            .filter(|i| approx_top10.contains(i))
            .count();
        eprintln!("Recall@10: {recall}/10");
        // Loose lower bound — 4-bit angles lose information, so perfect
        // top-10 agreement is not expected.
        assert!(
            recall >= 4,
            "raw scan recall should be >= 4/10, got {recall}/10"
        );
    }

    // CPU vs. Metal GPU parity and timing. Only compiled with the `metal`
    // feature; requires a working Metal device at runtime.
    #[test]
    #[cfg(feature = "metal")]
    fn metal_turboquant_scan() {
        let dim = 768;
        let n = 10_000;
        let codec = PolarCodec::new(dim, 4, 42);

        // Same deterministic corpus construction as the CPU recall test.
        let mut vecs = Array2::<f32>::zeros((n, dim));
        for i in 0..n {
            for d in 0..dim {
                vecs[[i, d]] = ((i * 17 + d * 31) as f32).sin();
            }
            let norm: f32 = vecs.row(i).iter().map(|x| x * x).sum::<f32>().sqrt();
            for d in 0..dim {
                vecs[[i, d]] /= norm;
            }
        }

        let corpus = codec.encode_batch(&vecs);
        let mut query = vec![0.0f32; dim];
        for d in 0..dim {
            query[d] = ((42 * 7 + d * 13) as f32).sin();
        }
        l2_normalize(&mut query);
        let qs = codec.prepare_query(&query);

        // CPU baseline.
        let t0 = std::time::Instant::now();
        let cpu_scores = codec.scan_corpus(&corpus, &qs);
        let cpu_us = t0.elapsed().as_secs_f64() * 1e6;

        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();

        // Cold path: one-time corpus upload to GPU memory.
        let t_cold = std::time::Instant::now();
        let gpu_corpus = driver
            .turboquant_upload_corpus(&corpus.radii, &corpus.indices)
            .unwrap();
        let upload_us = t_cold.elapsed().as_secs_f64() * 1e6;

        // Warm path: first scan against the already-uploaded corpus.
        let t_warm = std::time::Instant::now();
        let gpu_scores = driver
            .turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
            .unwrap();
        let warm_us = t_warm.elapsed().as_secs_f64() * 1e6;

        // Hot path: repeat scan with pipelines/caches primed.
        let t_hot = std::time::Instant::now();
        let _ = driver
            .turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
            .unwrap();
        let hot_us = t_hot.elapsed().as_secs_f64() * 1e6;

        eprintln!("10K vectors:");
        eprintln!(" CPU: {cpu_us:.0}µs ({:.1}M/s)", n as f64 / cpu_us);
        eprintln!(" GPU upload: {upload_us:.0}µs (one-time)");
        eprintln!(
            " GPU warm: {warm_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
            n as f64 / warm_us,
            cpu_us / warm_us
        );
        eprintln!(
            " GPU hot: {hot_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
            n as f64 / hot_us,
            cpu_us / hot_us
        );

        // CPU and GPU must agree on every score within float tolerance.
        let mut max_diff = 0.0f32;
        for i in 0..n {
            let diff = (cpu_scores[i] - gpu_scores[i]).abs();
            if diff > max_diff {
                max_diff = diff;
            }
        }
        eprintln!("max CPU/GPU score diff: {max_diff:.6}");
        assert!(
            max_diff < 0.01,
            "GPU scores should match CPU within 0.01, got {max_diff}"
        );
    }
}