1use std::f32::consts::PI;
22
23use ndarray::{Array1, Array2};
24use rand::SeedableRng;
25use rand_chacha::ChaCha8Rng;
26use rand_distr::{Distribution, StandardNormal};
27
/// Structure-of-arrays encoding of a whole batch of vectors.
///
/// The entry for vector `v`, pair `p` lives at flat offset `v * pairs + p`
/// in both `radii` and `indices`.
#[derive(Debug, Clone)]
pub struct CompressedCorpus {
    /// Number of encoded vectors.
    pub n: usize,
    /// Number of coordinate pairs stored per vector.
    pub pairs: usize,
    /// Radius of every pair, `n * pairs` entries, grouped by vector.
    pub radii: Vec<f32>,
    /// Quantized angle index of every pair, laid out exactly like `radii`.
    pub indices: Vec<u8>,
}
49
/// Compressed form of a single vector: one radius and one quantized angle
/// index per coordinate pair.
#[derive(Debug, Clone)]
pub struct CompressedCode {
    /// Radius of each coordinate pair.
    pub radii: Vec<f32>,
    /// Quantized angle index of each coordinate pair.
    pub angle_indices: Vec<u8>,
}

impl CompressedCode {
    /// Payload size in bytes: one `f32` per radius plus one byte per angle index.
    #[must_use]
    pub fn encoded_bytes(&self) -> usize {
        self.radii.len() * std::mem::size_of::<f32>() + self.angle_indices.len()
    }
}
66
67pub struct PolarCodec {
73 dim: usize,
74 #[expect(dead_code, reason = "stored for serialization / reconstruction")]
75 bits: u8,
76 levels: usize,
77 pairs: usize,
78 rotation: Array2<f32>,
80 cos_table: Vec<f32>,
82 sin_table: Vec<f32>,
83}
84
85impl PolarCodec {
86 #[must_use]
92 pub fn new(dim: usize, bits: u8, seed: u64) -> Self {
93 assert!(
94 dim > 0 && dim.is_multiple_of(2),
95 "dim must be even and non-zero"
96 );
97 assert!(bits > 0 && bits <= 8, "bits must be 1..=8");
98
99 let levels = 1usize << bits;
100 let pairs = dim / 2;
101 let rotation = generate_rotation(dim, seed);
102
103 let mut cos_table = Vec::with_capacity(levels);
104 let mut sin_table = Vec::with_capacity(levels);
105 for j in 0..levels {
106 let theta = (j as f32 / levels as f32) * 2.0 * PI - PI;
107 cos_table.push(theta.cos());
108 sin_table.push(theta.sin());
109 }
110
111 Self {
112 dim,
113 bits,
114 levels,
115 pairs,
116 rotation,
117 cos_table,
118 sin_table,
119 }
120 }
121
122 #[must_use]
124 pub fn pairs(&self) -> usize {
125 self.pairs
126 }
127
128 #[must_use]
130 pub fn encode(&self, vector: &[f32]) -> CompressedCode {
131 assert_eq!(vector.len(), self.dim);
132 let x = Array1::from_vec(vector.to_vec());
133 let rotated = self.rotation.dot(&x);
134
135 let mut radii = Vec::with_capacity(self.pairs);
136 let mut angle_indices = Vec::with_capacity(self.pairs);
137 for i in 0..self.pairs {
138 let (r, idx) = self.encode_pair(rotated[2 * i], rotated[2 * i + 1]);
139 radii.push(r);
140 angle_indices.push(idx);
141 }
142 CompressedCode {
143 radii,
144 angle_indices,
145 }
146 }
147
148 #[must_use]
153 pub fn encode_batch(&self, vectors: &Array2<f32>) -> CompressedCorpus {
154 assert_eq!(vectors.ncols(), self.dim);
155 let n = vectors.nrows();
156
157 let rotated = vectors.dot(&self.rotation.t());
159
160 let total = n * self.pairs;
161 let mut radii = Vec::with_capacity(total);
162 let mut indices = Vec::with_capacity(total);
163
164 for row in 0..n {
165 for i in 0..self.pairs {
166 let (r, idx) = self.encode_pair(rotated[[row, 2 * i]], rotated[[row, 2 * i + 1]]);
167 radii.push(r);
168 indices.push(idx);
169 }
170 }
171
172 CompressedCorpus {
173 n,
174 pairs: self.pairs,
175 radii,
176 indices,
177 }
178 }
179
180 #[must_use]
182 pub fn encode_batch_codes(&self, vectors: &Array2<f32>) -> Vec<CompressedCode> {
183 let corpus = self.encode_batch(vectors);
184 (0..corpus.n)
185 .map(|v| {
186 let off = v * corpus.pairs;
187 CompressedCode {
188 radii: corpus.radii[off..off + corpus.pairs].to_vec(),
189 angle_indices: corpus.indices[off..off + corpus.pairs].to_vec(),
190 }
191 })
192 .collect()
193 }
194
195 #[must_use]
200 pub fn prepare_query(&self, query: &[f32]) -> QueryState {
201 assert_eq!(query.len(), self.dim);
202 let q = Array1::from_vec(query.to_vec());
203 let rotated = self.rotation.dot(&q);
204
205 let mut centroid_q = vec![0.0f32; self.pairs * self.levels];
208 for i in 0..self.pairs {
209 let q_a = rotated[2 * i];
210 let q_b = rotated[2 * i + 1];
211 let base = i * self.levels;
212 for j in 0..self.levels {
213 centroid_q[base + j] = q_a * self.cos_table[j] + q_b * self.sin_table[j];
214 }
215 }
216
217 QueryState {
218 centroid_q,
219 pairs: self.pairs,
220 levels: self.levels,
221 }
222 }
223
224 #[must_use]
232 pub fn scan_corpus(&self, corpus: &CompressedCorpus, qs: &QueryState) -> Vec<f32> {
233 let n = corpus.n;
234 let pairs = corpus.pairs;
235 let mut scores = vec![0.0f32; n];
236
237 for v in 0..n {
240 let base = v * pairs;
241 let mut score = 0.0f32;
242
243 let chunks = pairs / 4;
245 let remainder = pairs % 4;
246
247 for c in 0..chunks {
248 let i = base + c * 4;
249 let i0 = corpus.indices[i] as usize;
250 let i1 = corpus.indices[i + 1] as usize;
251 let i2 = corpus.indices[i + 2] as usize;
252 let i3 = corpus.indices[i + 3] as usize;
253
254 let p = c * 4;
255 score += corpus.radii[i] * qs.centroid_q[p * qs.levels + i0];
256 score += corpus.radii[i + 1] * qs.centroid_q[(p + 1) * qs.levels + i1];
257 score += corpus.radii[i + 2] * qs.centroid_q[(p + 2) * qs.levels + i2];
258 score += corpus.radii[i + 3] * qs.centroid_q[(p + 3) * qs.levels + i3];
259 }
260 for r in 0..remainder {
261 let i = base + chunks * 4 + r;
262 let p = chunks * 4 + r;
263 let j = corpus.indices[i] as usize;
264 score += corpus.radii[i] * qs.centroid_q[p * qs.levels + j];
265 }
266
267 scores[v] = score;
268 }
269
270 scores
271 }
272
273 #[must_use]
275 pub fn batch_scan(&self, codes: &[CompressedCode], qs: &QueryState) -> Vec<f32> {
276 codes
277 .iter()
278 .map(|code| {
279 let mut score = 0.0f32;
280 for i in 0..qs.pairs {
281 let j = code.angle_indices[i] as usize;
282 score += code.radii[i] * qs.centroid_q[i * qs.levels + j];
283 }
284 score
285 })
286 .collect()
287 }
288
289 #[inline]
290 fn encode_pair(&self, a: f32, b: f32) -> (f32, u8) {
291 let r = (a * a + b * b).sqrt();
292 let theta = b.atan2(a);
293 let normalized = (theta + PI) / (2.0 * PI);
294 let idx = ((normalized * self.levels as f32) as usize).min(self.levels - 1);
295 (r, idx as u8)
296 }
297}
298
/// Per-query lookup table used when scoring compressed vectors.
#[derive(Debug, Clone)]
pub struct QueryState {
    /// Flat `pairs × levels` table; entry `p * levels + j` is the contribution
    /// factor of pair `p` when its stored angle index is `j`.
    pub centroid_q: Vec<f32>,
    /// Number of coordinate pairs covered by the table.
    pub pairs: usize,
    /// Number of angle levels per pair.
    pub levels: usize,
}
308
309fn generate_rotation(dim: usize, seed: u64) -> Array2<f32> {
315 let mut rng = ChaCha8Rng::seed_from_u64(seed);
316 let mut data = Vec::with_capacity(dim * dim);
317 for _ in 0..(dim * dim) {
318 data.push(StandardNormal.sample(&mut rng));
319 }
320 let a = Array2::from_shape_vec((dim, dim), data).expect("shape matches data length");
321 gram_schmidt_qr(a)
322}
323
324fn gram_schmidt_qr(mut q: Array2<f32>) -> Array2<f32> {
326 let n = q.ncols();
327 for i in 0..n {
328 let norm: f32 = q.column(i).iter().map(|x| x * x).sum::<f32>().sqrt();
329 if norm < 1e-10 {
330 continue;
331 }
332 let inv = 1.0 / norm;
333 for row in 0..q.nrows() {
334 q[[row, i]] *= inv;
335 }
336 for j in (i + 1)..n {
337 let dot: f32 = (0..q.nrows()).map(|row| q[[row, i]] * q[[row, j]]).sum();
338 for row in 0..q.nrows() {
339 q[[row, j]] -= dot * q[[row, i]];
340 }
341 }
342 }
343 q
344}
345
346#[cfg(test)]
351mod tests {
352 use super::*;
353
/// Scales `v` in place to unit Euclidean length; vectors whose norm is not
/// above `1e-10` are left unchanged.
fn l2_normalize(v: &mut [f32]) {
    let length: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if length > 1e-10 {
        v.iter_mut().for_each(|x| *x /= length);
    }
}
362
/// The generated 8×8 rotation times its transpose must be the identity to
/// within `1e-5` in every entry.
#[test]
fn rotation_is_orthogonal() {
    let r = generate_rotation(8, 42);
    let eye = r.dot(&r.t());
    for ((i, j), &actual) in eye.indexed_iter() {
        let expected = if i == j { 1.0 } else { 0.0 };
        let deviation = (actual - expected).abs();
        assert!(
            deviation < 1e-5,
            "Q×Qᵀ[{i},{j}] = {}, expected {expected}",
            actual
        );
    }
}
378
/// Encodes a unit vector, decodes every pair back into the rotated frame and
/// checks that the reconstruction stays within one angular bin of the exact
/// rotated coordinates.
#[test]
fn encode_decode_roundtrip() {
    let codec = PolarCodec::new(8, 4, 42);
    let mut v = vec![0.3, -0.1, 0.5, 0.2, -0.4, 0.1, 0.3, -0.2];
    l2_normalize(&mut v);
    let code = codec.encode(&v);
    assert_eq!(code.radii.len(), 4);
    assert_eq!(code.angle_indices.len(), 4);

    // Exact rotated coordinates the codes are meant to approximate.
    let rotated = codec.rotation.dot(&Array1::from_vec(v.clone()));
    let bin_width = 2.0 * PI / codec.levels as f32;

    for (i, (&radius, &level)) in code.radii.iter().zip(&code.angle_indices).enumerate() {
        let j = usize::from(level);
        assert!(j < codec.levels, "pair {i}: level {j} out of range");

        let (a, b) = (rotated[2 * i], rotated[2 * i + 1]);
        assert!(
            (radius - (a * a + b * b).sqrt()).abs() < 1e-6,
            "pair {i}: radius {radius} does not match the rotated pair"
        );

        // An angular error of at most one bin moves the point by at most
        // `radius * bin_width` (chord length is bounded by arc length).
        let da = radius * codec.cos_table[j] - a;
        let db = radius * codec.sin_table[j] - b;
        let error = (da * da + db * db).sqrt();
        assert!(
            error <= radius * bin_width + 1e-5,
            "pair {i}: reconstruction error {error} exceeds one bin"
        );
    }
}
388
/// Encodes 1000 deterministic unit vectors, scans them against one query and
/// checks that at least 4 of the exact top-10 neighbours appear in the
/// approximate top-10. Timings are printed to stderr for information only.
#[test]
fn corpus_scan_recall_and_throughput() {
    let dim = 768;
    let n = 1000;
    let codec = PolarCodec::new(dim, 4, 42);

    // Deterministic corpus: sine-generated rows, each scaled to unit length.
    let mut vecs = Array2::<f32>::zeros((n, dim));
    for (i, mut row) in vecs.outer_iter_mut().enumerate() {
        for (d, x) in row.iter_mut().enumerate() {
            *x = ((i * 17 + d * 31) as f32).sin();
        }
        let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
        row.mapv_inplace(|x| x / norm);
    }

    let encode_timer = std::time::Instant::now();
    let corpus = codec.encode_batch(&vecs);
    let encode_ms = encode_timer.elapsed().as_secs_f64() * 1000.0;
    eprintln!(
        "encode {n} → SoA corpus: {encode_ms:.1}ms ({:.1}µs/vec)",
        encode_ms * 1000.0 / n as f64
    );

    let mut query: Vec<f32> = (0..dim)
        .map(|d| ((42 * 7 + d * 13) as f32).sin())
        .collect();
    l2_normalize(&mut query);

    // Ground truth: exact dot products, best first.
    let query_arr = Array1::from_vec(query.clone());
    let mut exact: Vec<(usize, f32)> = vecs
        .outer_iter()
        .enumerate()
        .map(|(i, row)| (i, row.dot(&query_arr)))
        .collect();
    exact.sort_by(|a, b| b.1.total_cmp(&a.1));

    let prep_timer = std::time::Instant::now();
    let qs = codec.prepare_query(&query);
    let prep_us = prep_timer.elapsed().as_secs_f64() * 1e6;

    let scan_timer = std::time::Instant::now();
    let scores = codec.scan_corpus(&corpus, &qs);
    let scan_us = scan_timer.elapsed().as_secs_f64() * 1e6;

    eprintln!(
        "prepare: {prep_us:.0}µs, scan {n}: {scan_us:.0}µs ({:.2}µs/vec)",
        scan_us / n as f64
    );
    eprintln!("scan throughput: {:.1}M vec/s", n as f64 / scan_us);

    let mut approx: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
    approx.sort_by(|a, b| b.1.total_cmp(&a.1));

    let top10 = |ranked: &[(usize, f32)]| -> Vec<usize> {
        ranked.iter().take(10).map(|&(i, _)| i).collect()
    };
    let exact_top10 = top10(&exact);
    let approx_top10 = top10(&approx);

    let mut recall = 0;
    for id in &exact_top10 {
        if approx_top10.contains(id) {
            recall += 1;
        }
    }
    eprintln!("Recall@10: {recall}/10");
    assert!(
        recall >= 4,
        "raw scan recall should be >= 4/10, got {recall}/10"
    );
}
461
/// Compares the CPU scan with the Metal GPU scan on 10 000 vectors: prints
/// timings for the CPU scan, the corpus upload and two consecutive GPU scans,
/// then requires every CPU/GPU score pair to differ by less than `0.01`.
#[test]
#[cfg(feature = "metal")]
fn metal_turboquant_scan() {
    let dim = 768;
    let n = 10_000;
    let codec = PolarCodec::new(dim, 4, 42);

    // Deterministic corpus: sine-generated rows, each scaled to unit length.
    let mut vecs = Array2::<f32>::zeros((n, dim));
    for (i, mut row) in vecs.outer_iter_mut().enumerate() {
        for (d, x) in row.iter_mut().enumerate() {
            *x = ((i * 17 + d * 31) as f32).sin();
        }
        let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
        row.mapv_inplace(|x| x / norm);
    }

    let corpus = codec.encode_batch(&vecs);
    let mut query: Vec<f32> = (0..dim)
        .map(|d| ((42 * 7 + d * 13) as f32).sin())
        .collect();
    l2_normalize(&mut query);
    let qs = codec.prepare_query(&query);

    // Microseconds elapsed since the given instant.
    let micros_since = |start: std::time::Instant| start.elapsed().as_secs_f64() * 1e6;

    let cpu_timer = std::time::Instant::now();
    let cpu_scores = codec.scan_corpus(&corpus, &qs);
    let cpu_us = micros_since(cpu_timer);

    let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();

    let upload_timer = std::time::Instant::now();
    let gpu_corpus = driver
        .turboquant_upload_corpus(&corpus.radii, &corpus.indices)
        .unwrap();
    let upload_us = micros_since(upload_timer);

    // First scan after the upload.
    let warm_timer = std::time::Instant::now();
    let gpu_scores = driver
        .turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
        .unwrap();
    let warm_us = micros_since(warm_timer);

    // Second, identical scan; only its timing is used.
    let hot_timer = std::time::Instant::now();
    let _ = driver
        .turboquant_scan_gpu(&gpu_corpus, &qs.centroid_q, n, corpus.pairs, qs.levels)
        .unwrap();
    let hot_us = micros_since(hot_timer);

    eprintln!("10K vectors:");
    eprintln!(" CPU: {cpu_us:.0}µs ({:.1}M/s)", n as f64 / cpu_us);
    eprintln!(" GPU upload: {upload_us:.0}µs (one-time)");
    eprintln!(
        " GPU warm: {warm_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
        n as f64 / warm_us,
        cpu_us / warm_us
    );
    eprintln!(
        " GPU hot: {hot_us:.0}µs ({:.1}M/s, {:.1}× vs CPU)",
        n as f64 / hot_us,
        cpu_us / hot_us
    );

    let max_diff = (0..n)
        .map(|i| (cpu_scores[i] - gpu_scores[i]).abs())
        .fold(0.0f32, |worst, diff| if diff > worst { diff } else { worst });
    eprintln!("max CPU/GPU score diff: {max_diff:.6}");
    assert!(
        max_diff < 0.01,
        "GPU scores should match CPU within 0.01, got {max_diff}"
    );
}
546}