1use crate::core::LuciError;
6#[cfg(target_arch = "aarch64")]
7use std::arch::aarch64::{vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vsubq_f32};
8
9pub mod global;
10pub mod hnsw;
11pub mod quantize;
12pub mod query;
13
14#[cfg(test)]
15mod distance_tests;
16
/// Distance function used to compare two vectors.
///
/// The `#[repr(u8)]` discriminants are the byte form decoded by
/// [`DistanceMetric::from_byte`] / [`DistanceMetric::try_from_byte`]; they are
/// pinned by tests and must never be renumbered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum DistanceMetric {
    /// Cosine distance `1 - a·b`; inputs are expected to be unit-normalized
    /// (see [`normalize_in_place`]).
    Cosine = 0,
    /// Negated inner product `-(a·b)`, so that smaller means more similar.
    DotProduct = 1,
    /// Euclidean distance `‖a - b‖₂`.
    L2 = 2,
}

impl DistanceMetric {
    /// Decodes a metric from its byte discriminant, returning `None` for any
    /// value this version does not know.
    ///
    /// Prefer this over [`DistanceMetric::from_byte`] when the byte comes from
    /// storage that may be corrupted or written by a newer version, so the
    /// caller can report an error instead of aborting the process.
    pub fn try_from_byte(byte: u8) -> Option<Self> {
        match byte {
            0 => Some(Self::Cosine),
            1 => Some(Self::DotProduct),
            2 => Some(Self::L2),
            _ => None,
        }
    }

    /// Decodes a metric from its byte discriminant.
    ///
    /// # Panics
    ///
    /// Panics if `byte` is not a known discriminant, which indicates a
    /// corrupted segment or one written by a newer version of Luci. Use
    /// [`DistanceMetric::try_from_byte`] to handle that case without panicking.
    pub fn from_byte(byte: u8) -> Self {
        match Self::try_from_byte(byte) {
            Some(metric) => metric,
            None => panic!(
                "unknown distance metric byte {byte}: segment is corrupted \
                 or was written by a newer version of Luci"
            ),
        }
    }
}
53
54pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
56 debug_assert_eq!(a.len(), b.len());
57 match metric {
58 DistanceMetric::Cosine => cosine_distance_normalized(a, b),
59 DistanceMetric::DotProduct => -dot_product(a, b),
60 DistanceMetric::L2 => l2_distance(a, b),
61 }
62}
63
/// Inner product `a·b`.
///
/// On aarch64 this dispatches to the NEON kernel; every other target uses a
/// portable iterator chain. Lengths are expected to match (debug-checked).
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    #[cfg(not(target_arch = "aarch64"))]
    {
        a.iter().zip(b).map(|(&x, &y)| x * y).sum::<f32>()
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: the kernel's target-feature requirement is NEON, which is part
        // of the default feature set of mainstream aarch64 targets.
        unsafe { dot_product_neon(a, b) }
    }
}
88
/// NEON dot-product kernel.
///
/// Reads only the first `min(a.len(), b.len())` elements of each slice. This is
/// the same truncating behaviour as the portable `zip` fallback, so a length
/// mismatch can never become an out-of-bounds read. That matters even in
/// release builds, where the callers' `debug_assert_eq!` on the lengths is
/// compiled out.
///
/// The main loop consumes 16 floats per iteration into four independent
/// accumulators, so consecutive FMAs do not depend on each other's result. A
/// 4-wide loop and a scalar tail then handle the remainder.
///
/// # Safety
///
/// The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    // Bounding by the shorter slice is what keeps every raw load below in-bounds
    // for BOTH inputs.
    let n = a.len().min(b.len());
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // SAFETY: each `vld1q_f32(ptr.add(j))` reads elements j..j + 4, and the loop
    // guards (`i + 16 <= n`, `i + 4 <= n`) keep every such j + 4 <= n. The scalar
    // tail reads only index i < n. Since n <= a.len() and n <= b.len(), every
    // read stays inside both slices.
    unsafe {
        let mut acc0 = vdupq_n_f32(0.0);
        let mut acc1 = vdupq_n_f32(0.0);
        let mut acc2 = vdupq_n_f32(0.0);
        let mut acc3 = vdupq_n_f32(0.0);

        let mut i = 0;
        while i + 16 <= n {
            let a0 = vld1q_f32(a_ptr.add(i));
            let a1 = vld1q_f32(a_ptr.add(i + 4));
            let a2 = vld1q_f32(a_ptr.add(i + 8));
            let a3 = vld1q_f32(a_ptr.add(i + 12));
            let b0 = vld1q_f32(b_ptr.add(i));
            let b1 = vld1q_f32(b_ptr.add(i + 4));
            let b2 = vld1q_f32(b_ptr.add(i + 8));
            let b3 = vld1q_f32(b_ptr.add(i + 12));
            acc0 = vfmaq_f32(acc0, a0, b0);
            acc1 = vfmaq_f32(acc1, a1, b1);
            acc2 = vfmaq_f32(acc2, a2, b2);
            acc3 = vfmaq_f32(acc3, a3, b3);
            i += 16;
        }
        while i + 4 <= n {
            let av = vld1q_f32(a_ptr.add(i));
            let bv = vld1q_f32(b_ptr.add(i));
            acc0 = vfmaq_f32(acc0, av, bv);
            i += 4;
        }
        // Fold the four accumulators together, then reduce across lanes.
        let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
        let mut sum = vaddvq_f32(acc);
        while i < n {
            sum += *a_ptr.add(i) * *b_ptr.add(i);
            i += 1;
        }
        sum
    }
}
135
136fn cosine_distance_normalized(a: &[f32], b: &[f32]) -> f32 {
142 1.0 - dot_product(a, b)
145}
146
147pub fn normalize_in_place(v: &mut [f32]) -> Result<(), LuciError> {
158 let norm_sq: f32 = v.iter().map(|x| x * x).sum();
159 if !norm_sq.is_finite() || norm_sq == 0.0 {
160 return Err(LuciError::InvalidQuery(
161 "zero-length / non-finite vector not supported with cosine \
162 metric — use metric: dot_product to bypass normalization"
163 .into(),
164 ));
165 }
166 if (norm_sq - 1.0).abs() < 1e-4 {
167 return Ok(());
168 }
169 let inv = 1.0 / norm_sq.sqrt();
170 for x in v.iter_mut() {
171 *x *= inv;
172 }
173 Ok(())
174}
175
/// Euclidean (L2) distance `‖a - b‖₂`.
///
/// On aarch64 this dispatches to the NEON kernel; every other target uses a
/// portable iterator chain. Lengths are expected to match (debug-checked).
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    #[cfg(not(target_arch = "aarch64"))]
    {
        let sum_sq: f32 = a
            .iter()
            .zip(b)
            .map(|(&x, &y)| {
                let diff = x - y;
                diff * diff
            })
            .sum();
        sum_sq.sqrt()
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: the kernel's target-feature requirement is NEON, which is part
        // of the default feature set of mainstream aarch64 targets.
        unsafe { l2_distance_neon(a, b) }
    }
}
198
/// NEON Euclidean-distance kernel.
///
/// Reads only the first `min(a.len(), b.len())` elements of each slice. This is
/// the same truncating behaviour as the portable `zip` fallback, so a length
/// mismatch can never become an out-of-bounds read. That matters even in
/// release builds, where the callers' `debug_assert_eq!` on the lengths is
/// compiled out.
///
/// Squared differences are accumulated 16 floats per iteration across four
/// independent accumulators, then 4 at a time, then in a scalar tail. The
/// square root is taken once at the end.
///
/// # Safety
///
/// The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l2_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    // Bounding by the shorter slice is what keeps every raw load below in-bounds
    // for BOTH inputs.
    let n = a.len().min(b.len());
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // SAFETY: each `vld1q_f32(ptr.add(j))` reads elements j..j + 4, and the loop
    // guards (`i + 16 <= n`, `i + 4 <= n`) keep every such j + 4 <= n. The scalar
    // tail reads only index i < n. Since n <= a.len() and n <= b.len(), every
    // read stays inside both slices.
    unsafe {
        let mut acc0 = vdupq_n_f32(0.0);
        let mut acc1 = vdupq_n_f32(0.0);
        let mut acc2 = vdupq_n_f32(0.0);
        let mut acc3 = vdupq_n_f32(0.0);

        let mut i = 0;
        while i + 16 <= n {
            let a0 = vld1q_f32(a_ptr.add(i));
            let a1 = vld1q_f32(a_ptr.add(i + 4));
            let a2 = vld1q_f32(a_ptr.add(i + 8));
            let a3 = vld1q_f32(a_ptr.add(i + 12));
            let b0 = vld1q_f32(b_ptr.add(i));
            let b1 = vld1q_f32(b_ptr.add(i + 4));
            let b2 = vld1q_f32(b_ptr.add(i + 8));
            let b3 = vld1q_f32(b_ptr.add(i + 12));
            let d0 = vsubq_f32(a0, b0);
            let d1 = vsubq_f32(a1, b1);
            let d2 = vsubq_f32(a2, b2);
            let d3 = vsubq_f32(a3, b3);
            acc0 = vfmaq_f32(acc0, d0, d0);
            acc1 = vfmaq_f32(acc1, d1, d1);
            acc2 = vfmaq_f32(acc2, d2, d2);
            acc3 = vfmaq_f32(acc3, d3, d3);
            i += 16;
        }
        while i + 4 <= n {
            let av = vld1q_f32(a_ptr.add(i));
            let bv = vld1q_f32(b_ptr.add(i));
            let d = vsubq_f32(av, bv);
            acc0 = vfmaq_f32(acc0, d, d);
            i += 4;
        }
        // Fold the four accumulators together, then reduce across lanes.
        let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
        let mut sum = vaddvq_f32(acc);
        while i < n {
            let d = *a_ptr.add(i) - *b_ptr.add(i);
            sum += d * d;
            i += 1;
        }
        sum.sqrt()
    }
}
251
252pub fn distance_to_score(raw_distance: f32, metric: DistanceMetric) -> f32 {
265 match metric {
266 DistanceMetric::Cosine => {
267 ((2.0 - raw_distance) / 2.0).max(0.0)
270 }
271 DistanceMetric::L2 => {
272 1.0 / (1.0 + raw_distance * raw_distance)
275 }
276 DistanceMetric::DotProduct => {
277 ((1.0 - raw_distance) / 2.0).max(0.0)
281 }
282 }
283}
284
// Unit tests for raw distances, score mapping, metric byte decoding, and
// normalization. Length/tail coverage of the SIMD kernels lives in the
// separate `distance_tests` module.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Raw distances -------------------------------------------------------
    // Cosine distance is `1 - a·b` and assumes unit-length inputs: 0 for the
    // same direction, 1 for orthogonal, 2 for opposite.

    #[test]
    fn cosine_identical() {
        let mut v = vec![1.0, 2.0, 3.0];
        // Not unit length yet; the cosine metric requires normalized inputs.
        normalize_in_place(&mut v).unwrap();
        let d = distance(&v, &v, DistanceMetric::Cosine);
        assert!(
            d.abs() < 1e-5,
            "identical vectors should have cosine distance ~0, got {d}"
        );
    }

    #[test]
    fn cosine_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let d = distance(&a, &b, DistanceMetric::Cosine);
        assert!(
            (d - 1.0).abs() < 1e-5,
            "orthogonal vectors should have cosine distance ~1, got {d}"
        );
    }

    #[test]
    fn cosine_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let d = distance(&a, &b, DistanceMetric::Cosine);
        assert!(
            (d - 2.0).abs() < 1e-5,
            "opposite vectors should have cosine distance ~2, got {d}"
        );
    }

    #[test]
    fn dot_product_metric() {
        let a = vec![1.0, 2.0];
        let b = vec![3.0, 4.0];
        let d = distance(&a, &b, DistanceMetric::DotProduct);
        // The metric is the *negated* dot product (1*3 + 2*4 = 11). Exact equality
        // is safe because every intermediate value is a small integer.
        assert_eq!(d, -11.0);
    }

    #[test]
    fn l2_distance_metric() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let d = distance(&a, &b, DistanceMetric::L2);
        assert!((d - 5.0).abs() < 1e-5, "L2 distance should be 5.0, got {d}");
    }

    #[test]
    fn l2_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let d = distance(&v, &v, DistanceMetric::L2);
        assert!(d.abs() < 1e-5);
    }

    // For orthogonal unit vectors: cosine distance is 1 and L2 distance is √2.
    #[test]
    fn unit_vectors() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let d_cos = distance(&a, &b, DistanceMetric::Cosine);
        let d_l2 = distance(&a, &b, DistanceMetric::L2);
        assert!((d_cos - 1.0).abs() < 1e-5);
        assert!((d_l2 - std::f32::consts::SQRT_2).abs() < 1e-5);
    }

    // --- Score mapping (`distance_to_score`) --------------------------------
    // Cosine: (2 - d) / 2, mapping distances 0 / 1 / 2 to scores 1 / 0.5 / 0.
    #[test]
    fn cosine_score_identical() {
        let s = distance_to_score(0.0, DistanceMetric::Cosine);
        assert!(
            (s - 1.0).abs() < 1e-5,
            "identical vectors: score={s}, expected 1.0"
        );
    }

    #[test]
    fn cosine_score_orthogonal() {
        let s = distance_to_score(1.0, DistanceMetric::Cosine);
        assert!(
            (s - 0.5).abs() < 1e-5,
            "orthogonal vectors: score={s}, expected 0.5"
        );
    }

    #[test]
    fn cosine_score_opposite() {
        let s = distance_to_score(2.0, DistanceMetric::Cosine);
        assert!(s.abs() < 1e-5, "opposite vectors: score={s}, expected 0.0");
    }

    // L2: 1 / (1 + d²), mapping distances 0 / 1 / 2 to scores 1 / 0.5 / 0.2.
    #[test]
    fn l2_score_identical() {
        let s = distance_to_score(0.0, DistanceMetric::L2);
        assert!((s - 1.0).abs() < 1e-5, "identical: score={s}, expected 1.0");
    }

    #[test]
    fn l2_score_unit_distance() {
        let s = distance_to_score(1.0, DistanceMetric::L2);
        assert!(
            (s - 0.5).abs() < 1e-5,
            "unit distance: score={s}, expected 0.5"
        );
    }

    #[test]
    fn l2_score_far() {
        let s = distance_to_score(2.0, DistanceMetric::L2);
        assert!((s - 0.2).abs() < 1e-5, "far: score={s}, expected 0.2");
    }

    // Dot product: raw distance is -dot, so (1 - raw) / 2 = (1 + dot) / 2.
    #[test]
    fn dot_product_score_high_similarity() {
        let s = distance_to_score(-1.0, DistanceMetric::DotProduct);
        assert!((s - 1.0).abs() < 1e-5, "high sim: score={s}, expected 1.0");
    }

    #[test]
    fn dot_product_score_zero() {
        let s = distance_to_score(0.0, DistanceMetric::DotProduct);
        assert!((s - 0.5).abs() < 1e-5, "zero dot: score={s}, expected 0.5");
    }

    #[test]
    fn dot_product_score_negative() {
        let s = distance_to_score(1.0, DistanceMetric::DotProduct);
        assert!(s.abs() < 1e-5, "negative dot: score={s}, expected 0.0");
    }

    #[test]
    fn all_scores_non_negative() {
        for dist in [0.0, 0.5, 1.0, 2.0, 5.0, 10.0] {
            // Distances past each metric's natural range must clamp at 0
            // rather than go negative.
            for metric in [
                DistanceMetric::Cosine,
                DistanceMetric::L2,
                DistanceMetric::DotProduct,
            ] {
                let s = distance_to_score(dist, metric);
                assert!(
                    s >= 0.0,
                    "score should be non-negative: metric={metric:?}, dist={dist}, score={s}"
                );
            }
        }
    }

    #[test]
    fn l2_scores_bounded_unit() {
        for dist in [0.0, 0.1, 1.0, 10.0, 100.0] {
            // 1 / (1 + d²) is strictly positive and at most 1 for finite d.
            let s = distance_to_score(dist, DistanceMetric::L2);
            assert!(
                s > 0.0 && s <= 1.0,
                "L2 score out of (0,1]: dist={dist}, score={s}"
            );
        }
    }

    #[test]
    fn dot_product_unnormalized_can_exceed_one() {
        let s = distance_to_score(-2.0, DistanceMetric::DotProduct);
        // Only the lower bound is clamped, so a dot product above 1 (possible
        // with un-normalized inputs) yields a score above 1.
        assert!(
            s > 1.0,
            "unnormalized dot product should produce score > 1: {s}"
        );
    }

    // --- Metric byte encoding -----------------------------------------------
    #[test]
    fn from_byte_round_trips_known_metrics() {
        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::DotProduct,
            DistanceMetric::L2,
        ] {
            let byte = metric as u8;
            assert_eq!(DistanceMetric::from_byte(byte), metric);
        }
    }

    #[test]
    fn from_byte_discriminants_are_pinned() {
        // These values are the byte encoding decoded by `from_byte`; changing
        // them would misread previously written data.
        assert_eq!(DistanceMetric::Cosine as u8, 0);
        assert_eq!(DistanceMetric::DotProduct as u8, 1);
        assert_eq!(DistanceMetric::L2 as u8, 2);
    }

    #[test]
    #[should_panic(expected = "unknown distance metric byte 3")]
    fn from_byte_panics_on_unknown_metric() {
        let _ = DistanceMetric::from_byte(3);
        // 3 is the first value past the known discriminants.
    }

    #[test]
    #[should_panic(expected = "unknown distance metric byte 255")]
    fn from_byte_panics_on_garbage() {
        let _ = DistanceMetric::from_byte(255);
    }

    // --- normalize_in_place --------------------------------------------------
    // (3, 4) has norm 5, so normalization should give exactly (0.6, 0.8)
    // within f32 rounding.
    #[test]
    fn normalize_in_place_unit_length() {
        let mut v = vec![3.0_f32, 4.0];
        normalize_in_place(&mut v).unwrap();
        let norm = (v[0] * v[0] + v[1] * v[1]).sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "norm after normalize: {norm}");
        assert!((v[0] - 0.6).abs() < 1e-6 && (v[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn normalize_in_place_idempotent_on_unit_input() {
        let mut v = vec![0.6_f32, 0.8];
        let before = v.clone();
        normalize_in_place(&mut v).unwrap();
        for (a, b) in v.iter().zip(before.iter()) {
            // Bitwise equality: a near-unit input (squared norm within 1e-4 of 1)
            // is left untouched rather than rescaled.
            assert_eq!(a, b);
        }
    }

    #[test]
    fn normalize_in_place_zero_errors() {
        let mut v = vec![0.0_f32, 0.0, 0.0];
        let err = normalize_in_place(&mut v).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("zero-length / non-finite vector"),
            "unexpected message: {msg}",
        );
    }

    #[test]
    fn normalize_in_place_subnormal_errors() {
        let mut v = vec![f32::MIN_POSITIVE * 1e-2; 3];
        // Subnormal components square to 0 in f32, so the vector is rejected
        // as zero-length even though it is not literally all zeros.
        let err = normalize_in_place(&mut v).unwrap_err();
        assert!(format!("{err}").contains("zero-length / non-finite vector"));
    }

    #[test]
    fn normalize_in_place_overflow_errors() {
        let mut v = vec![1e20_f32; 3];
        // (1e20)² = 1e40 exceeds f32::MAX, so the squared norm is +∞.
        let err = normalize_in_place(&mut v).unwrap_err();
        assert!(format!("{err}").contains("zero-length / non-finite vector"));
    }

    #[test]
    fn normalize_in_place_nan_errors() {
        let mut v = vec![1.0_f32, f32::NAN, 2.0];
        let err = normalize_in_place(&mut v).unwrap_err();
        assert!(format!("{err}").contains("zero-length / non-finite vector"));
    }

    // Normalize-then-cosine must agree with a direct f64 cosine computed on the
    // raw, un-normalized vectors.
    #[test]
    fn cosine_score_unchanged_after_normalize() {
        let cases: &[(Vec<f32>, Vec<f32>)] = &[
            // Generic, orthogonal, and high-dimensional parallel inputs.
            (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]),
            (vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]),
            (vec![0.1; 100], vec![0.2; 100]),
        ];
        for (a_raw, b_raw) in cases {
            let dot64: f64 = a_raw
                // f64 oracle: dot(a, b) / (|a| |b|) on the raw inputs.
                .iter()
                .zip(b_raw.iter())
                .map(|(x, y)| (*x as f64) * (*y as f64))
                .sum();
            let na64: f64 = a_raw
                .iter()
                .map(|x| (*x as f64).powi(2))
                .sum::<f64>()
                .sqrt();
            let nb64: f64 = b_raw
                .iter()
                .map(|x| (*x as f64).powi(2))
                .sum::<f64>()
                .sqrt();
            let oracle_dist = 1.0 - dot64 / (na64 * nb64);
            let oracle_score = ((2.0 - oracle_dist) / 2.0).max(0.0);

            let mut a = a_raw.clone();
            let mut b = b_raw.clone();
            normalize_in_place(&mut a).unwrap();
            normalize_in_place(&mut b).unwrap();
            let d = distance(&a, &b, DistanceMetric::Cosine);
            let s = distance_to_score(d, DistanceMetric::Cosine);
            assert!(
                ((s as f64) - oracle_score).abs() < 1e-3,
                "score drift > 1e-3: post={s}, oracle={oracle_score}",
            );
        }
    }

    #[test]
    fn cosine_distance_orthogonal_after_normalize() {
        let mut a = vec![3.0, 0.0];
        let mut b = vec![0.0, 7.0];
        normalize_in_place(&mut a).unwrap();
        normalize_in_place(&mut b).unwrap();
        let d = distance(&a, &b, DistanceMetric::Cosine);
        assert!((d - 1.0).abs() < 1e-6, "orthogonal cosine distance: {d}");
    }

    #[test]
    fn cosine_distance_identical_after_normalize() {
        let mut a = vec![1.0, 2.0, 3.0];
        let mut b = a.clone();
        normalize_in_place(&mut a).unwrap();
        normalize_in_place(&mut b).unwrap();
        let d = distance(&a, &b, DistanceMetric::Cosine);
        assert!(d.abs() < 1e-6, "identical cosine distance: {d}");
    }
}