use crate::core::LuciError;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::{vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vsubq_f32};
pub mod global;
pub mod hnsw;
pub mod quantize;
pub mod query;
#[cfg(test)]
mod distance_tests;
/// Distance function used to compare vectors.
///
/// The explicit `u8` discriminants are the values decoded by
/// `DistanceMetric::from_byte` / `DistanceMetric::try_from_byte` (the byte
/// stored in a segment), so existing variants must never be renumbered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum DistanceMetric {
    /// Cosine distance (`1 - a·b`) on vectors expected to be unit length.
    Cosine = 0,
    /// Negated inner product (`-(a·b)`), so a larger product ranks closer.
    DotProduct = 1,
    /// Euclidean distance.
    L2 = 2,
}

impl DistanceMetric {
    /// Decodes a metric byte, returning `None` for any value this version of
    /// the code does not recognise.
    ///
    /// Prefer this over `from_byte` wherever a corrupted or newer-format
    /// segment should be reported as an error rather than abort the process.
    pub fn try_from_byte(byte: u8) -> Option<Self> {
        match byte {
            0 => Some(Self::Cosine),
            1 => Some(Self::DotProduct),
            2 => Some(Self::L2),
            _ => None,
        }
    }

    /// Decodes a metric byte read from a segment.
    ///
    /// # Panics
    ///
    /// Panics if `byte` is not a known discriminant, which indicates a
    /// corrupted segment or one written by a newer version. Use
    /// `try_from_byte` to handle that case without panicking.
    pub fn from_byte(byte: u8) -> Self {
        Self::try_from_byte(byte).unwrap_or_else(|| {
            panic!(
                "unknown distance metric byte {byte}: segment is corrupted \
                 or was written by a newer version of Luci"
            )
        })
    }
}
/// Raw distance between `a` and `b` under `metric`; for every metric a
/// smaller value means "more similar".
///
/// - `Cosine`: `1 - a·b`, assuming both inputs already have unit length.
/// - `DotProduct`: `-(a·b)`, negated so a larger inner product ranks closer.
/// - `L2`: Euclidean distance.
///
/// `a` and `b` are expected to have the same length; this is only checked in
/// debug builds.
pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    match metric {
        DistanceMetric::L2 => l2_distance(a, b),
        DistanceMetric::DotProduct => {
            let inner = dot_product(a, b);
            -inner
        }
        DistanceMetric::Cosine => cosine_distance_normalized(a, b),
    }
}
/// Inner product `Σ a[i] * b[i]` over the common prefix of `a` and `b`.
///
/// Callers are expected to pass equal-length slices (asserted in debug
/// builds). In release builds a mismatch is tolerated on every target by
/// ignoring the excess elements of the longer slice. This is what the `zip`
/// scalar path always did, and it also keeps the SIMD path memory-safe.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // The NEON kernel walks raw pointers, so hand it slices of identical
    // length; otherwise a release-mode mismatch would read past the end.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: `a` and `b` have identical lengths, so every load the
        // kernel performs is in bounds. Assumes the target enables NEON (the
        // default for standard aarch64 targets), as the original code did.
        unsafe { dot_product_neon(a, b) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }
}
/// NEON kernel for `dot_product`.
///
/// It accumulates `a[i] * b[i]` in four independent 4-lane fused
/// multiply-add accumulators (16 elements per iteration). A 4-wide cleanup
/// loop and a scalar tail handle the leftover elements.
///
/// Only the first `min(a.len(), b.len())` elements are read, matching the
/// `zip`-based scalar fallback. This makes the kernel memory-safe for any
/// pair of slices, whatever the caller checks.
///
/// # Safety
///
/// The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    // Clamp to the shorter slice: every load below is bounded by `n`.
    let n = a.len().min(b.len());
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    // SAFETY: every offset dereferenced below is `< n`, and `n` does not
    // exceed the length of either slice, so all loads stay inside `a` and `b`.
    unsafe {
        let mut acc0 = vdupq_n_f32(0.0);
        let mut acc1 = vdupq_n_f32(0.0);
        let mut acc2 = vdupq_n_f32(0.0);
        let mut acc3 = vdupq_n_f32(0.0);
        let mut i = 0;
        // Main loop: four independent accumulators hide FMA latency.
        while i + 16 <= n {
            let a0 = vld1q_f32(a_ptr.add(i));
            let a1 = vld1q_f32(a_ptr.add(i + 4));
            let a2 = vld1q_f32(a_ptr.add(i + 8));
            let a3 = vld1q_f32(a_ptr.add(i + 12));
            let b0 = vld1q_f32(b_ptr.add(i));
            let b1 = vld1q_f32(b_ptr.add(i + 4));
            let b2 = vld1q_f32(b_ptr.add(i + 8));
            let b3 = vld1q_f32(b_ptr.add(i + 12));
            acc0 = vfmaq_f32(acc0, a0, b0);
            acc1 = vfmaq_f32(acc1, a1, b1);
            acc2 = vfmaq_f32(acc2, a2, b2);
            acc3 = vfmaq_f32(acc3, a3, b3);
            i += 16;
        }
        // Remaining full 4-lane chunks.
        while i + 4 <= n {
            let av = vld1q_f32(a_ptr.add(i));
            let bv = vld1q_f32(b_ptr.add(i));
            acc0 = vfmaq_f32(acc0, av, bv);
            i += 4;
        }
        let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
        let mut sum = vaddvq_f32(acc);
        // Scalar tail (fewer than 4 elements).
        while i < n {
            sum += *a_ptr.add(i) * *b_ptr.add(i);
            i += 1;
        }
        sum
    }
}
/// Cosine distance for vectors that are assumed to already have unit length
/// (for example via `normalize_in_place`). With unit inputs, cosine
/// similarity is just the dot product, so no norms are computed here.
fn cosine_distance_normalized(a: &[f32], b: &[f32]) -> f32 {
    let similarity = dot_product(a, b);
    1.0 - similarity
}
/// Rescales `v` in place to unit Euclidean length.
///
/// If the squared norm is already within `1e-4` of one, the vector is left
/// untouched. Re-normalizing an already-normalized vector is therefore a
/// bit-exact no-op.
///
/// # Errors
///
/// Returns `LuciError::InvalidQuery` when the squared norm (accumulated in
/// `f32`) is zero or not finite. This includes components that underflow to
/// zero when squared, components that overflow, and NaN or infinite entries.
pub fn normalize_in_place(v: &mut [f32]) -> Result<(), LuciError> {
    let squared_norm = v.iter().map(|&x| x * x).sum::<f32>();
    let usable = squared_norm.is_finite() && squared_norm != 0.0;
    if !usable {
        return Err(LuciError::InvalidQuery(
            "zero-length / non-finite vector not supported with cosine \
             metric — use metric: dot_product to bypass normalization"
                .into(),
        ));
    }
    // Only rescale when the vector is meaningfully off unit length.
    if (squared_norm - 1.0).abs() >= 1e-4 {
        let scale = squared_norm.sqrt().recip();
        v.iter_mut().for_each(|x| *x *= scale);
    }
    Ok(())
}
/// Euclidean distance `sqrt(Σ (a[i] - b[i])²)` over the common prefix of `a`
/// and `b`.
///
/// Callers are expected to pass equal-length slices (asserted in debug
/// builds). In release builds a mismatch is tolerated on every target by
/// ignoring the excess elements of the longer slice. This matches the `zip`
/// scalar path and keeps the SIMD path memory-safe.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // The NEON kernel walks raw pointers, so hand it slices of identical
    // length; otherwise a release-mode mismatch would read past the end.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: `a` and `b` have identical lengths, so every load the
        // kernel performs is in bounds. Assumes the target enables NEON (the
        // default for standard aarch64 targets), as the original code did.
        unsafe { l2_distance_neon(a, b) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }
}
/// NEON kernel for `l2_distance`.
///
/// It accumulates `(a[i] - b[i])²` in four independent 4-lane fused
/// multiply-add accumulators (16 elements per iteration). A 4-wide cleanup
/// loop and a scalar tail handle the rest, and the square root of the total
/// is returned.
///
/// Only the first `min(a.len(), b.len())` elements are read, matching the
/// `zip`-based scalar fallback. This makes the kernel memory-safe for any
/// pair of slices, whatever the caller checks.
///
/// # Safety
///
/// The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l2_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    // Clamp to the shorter slice: every load below is bounded by `n`.
    let n = a.len().min(b.len());
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    // SAFETY: every offset dereferenced below is `< n`, and `n` does not
    // exceed the length of either slice, so all loads stay inside `a` and `b`.
    unsafe {
        let mut acc0 = vdupq_n_f32(0.0);
        let mut acc1 = vdupq_n_f32(0.0);
        let mut acc2 = vdupq_n_f32(0.0);
        let mut acc3 = vdupq_n_f32(0.0);
        let mut i = 0;
        // Main loop: four independent accumulators hide FMA latency.
        while i + 16 <= n {
            let a0 = vld1q_f32(a_ptr.add(i));
            let a1 = vld1q_f32(a_ptr.add(i + 4));
            let a2 = vld1q_f32(a_ptr.add(i + 8));
            let a3 = vld1q_f32(a_ptr.add(i + 12));
            let b0 = vld1q_f32(b_ptr.add(i));
            let b1 = vld1q_f32(b_ptr.add(i + 4));
            let b2 = vld1q_f32(b_ptr.add(i + 8));
            let b3 = vld1q_f32(b_ptr.add(i + 12));
            let d0 = vsubq_f32(a0, b0);
            let d1 = vsubq_f32(a1, b1);
            let d2 = vsubq_f32(a2, b2);
            let d3 = vsubq_f32(a3, b3);
            acc0 = vfmaq_f32(acc0, d0, d0);
            acc1 = vfmaq_f32(acc1, d1, d1);
            acc2 = vfmaq_f32(acc2, d2, d2);
            acc3 = vfmaq_f32(acc3, d3, d3);
            i += 16;
        }
        // Remaining full 4-lane chunks.
        while i + 4 <= n {
            let av = vld1q_f32(a_ptr.add(i));
            let bv = vld1q_f32(b_ptr.add(i));
            let d = vsubq_f32(av, bv);
            acc0 = vfmaq_f32(acc0, d, d);
            i += 4;
        }
        let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
        let mut sum = vaddvq_f32(acc);
        // Scalar tail (fewer than 4 elements).
        while i < n {
            let d = *a_ptr.add(i) - *b_ptr.add(i);
            sum += d * d;
            i += 1;
        }
        sum.sqrt()
    }
}
/// Converts a raw distance (as produced by `distance`) into a similarity
/// score where larger means more similar.
///
/// - `L2`: `1 / (1 + d²)`, which lies in `[0, 1]` and equals 1 at `d = 0`.
/// - `Cosine`: `(2 - d) / 2`, i.e. `(1 + cos) / 2`, clamped below at 0.
/// - `DotProduct`: `(1 - d) / 2`, i.e. `(1 + a·b) / 2`, clamped below at 0
///   but not above, so unnormalized vectors can score above 1.
///
/// Because `f32::max` ignores a NaN operand, a NaN distance scores 0 for
/// `Cosine` and `DotProduct`.
pub fn distance_to_score(raw_distance: f32, metric: DistanceMetric) -> f32 {
    let d = raw_distance;
    match metric {
        DistanceMetric::L2 => 1.0 / (1.0 + d * d),
        DistanceMetric::Cosine => f32::max((2.0 - d) / 2.0, 0.0),
        DistanceMetric::DotProduct => f32::max((1.0 - d) / 2.0, 0.0),
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the raw distance kernels, the distance-to-score mapping,
    // the metric byte encoding, and `normalize_in_place`.
    use super::*;

    // ---- Raw distances via `distance` ----

    #[test]
    fn cosine_identical() {
        let mut v = vec![1.0, 2.0, 3.0];
        normalize_in_place(&mut v).unwrap();
        let d = distance(&v, &v, DistanceMetric::Cosine);
        assert!(
            d.abs() < 1e-5,
            "identical vectors should have cosine distance ~0, got {d}"
        );
    }
    #[test]
    fn cosine_orthogonal() {
        // Already unit length, so no normalization is needed before `distance`.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let d = distance(&a, &b, DistanceMetric::Cosine);
        assert!(
            (d - 1.0).abs() < 1e-5,
            "orthogonal vectors should have cosine distance ~1, got {d}"
        );
    }
    #[test]
    fn cosine_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let d = distance(&a, &b, DistanceMetric::Cosine);
        assert!(
            (d - 2.0).abs() < 1e-5,
            "opposite vectors should have cosine distance ~2, got {d}"
        );
    }
    #[test]
    fn dot_product_metric() {
        // DotProduct distance is the negated inner product: -(1*3 + 2*4).
        let a = vec![1.0, 2.0];
        let b = vec![3.0, 4.0];
        let d = distance(&a, &b, DistanceMetric::DotProduct);
        assert_eq!(d, -11.0);
    }
    #[test]
    fn l2_distance_metric() {
        // 3-4-5 right triangle.
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let d = distance(&a, &b, DistanceMetric::L2);
        assert!((d - 5.0).abs() < 1e-5, "L2 distance should be 5.0, got {d}");
    }
    #[test]
    fn l2_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let d = distance(&v, &v, DistanceMetric::L2);
        assert!(d.abs() < 1e-5);
    }
    #[test]
    fn unit_vectors() {
        // Orthogonal unit vectors: cosine distance 1, Euclidean distance sqrt(2).
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let d_cos = distance(&a, &b, DistanceMetric::Cosine);
        let d_l2 = distance(&a, &b, DistanceMetric::L2);
        assert!((d_cos - 1.0).abs() < 1e-5);
        assert!((d_l2 - std::f32::consts::SQRT_2).abs() < 1e-5);
    }

    // ---- `distance_to_score` ----
    // Cosine: (2 - d) / 2 clamped at 0; L2: 1 / (1 + d^2);
    // DotProduct: (1 - d) / 2 clamped at 0 only from below.

    #[test]
    fn cosine_score_identical() {
        let s = distance_to_score(0.0, DistanceMetric::Cosine);
        assert!(
            (s - 1.0).abs() < 1e-5,
            "identical vectors: score={s}, expected 1.0"
        );
    }
    #[test]
    fn cosine_score_orthogonal() {
        let s = distance_to_score(1.0, DistanceMetric::Cosine);
        assert!(
            (s - 0.5).abs() < 1e-5,
            "orthogonal vectors: score={s}, expected 0.5"
        );
    }
    #[test]
    fn cosine_score_opposite() {
        let s = distance_to_score(2.0, DistanceMetric::Cosine);
        assert!(s.abs() < 1e-5, "opposite vectors: score={s}, expected 0.0");
    }
    #[test]
    fn l2_score_identical() {
        let s = distance_to_score(0.0, DistanceMetric::L2);
        assert!((s - 1.0).abs() < 1e-5, "identical: score={s}, expected 1.0");
    }
    #[test]
    fn l2_score_unit_distance() {
        let s = distance_to_score(1.0, DistanceMetric::L2);
        assert!(
            (s - 0.5).abs() < 1e-5,
            "unit distance: score={s}, expected 0.5"
        );
    }
    #[test]
    fn l2_score_far() {
        // 1 / (1 + 2^2) = 0.2
        let s = distance_to_score(2.0, DistanceMetric::L2);
        assert!((s - 0.2).abs() < 1e-5, "far: score={s}, expected 0.2");
    }
    #[test]
    fn dot_product_score_high_similarity() {
        // Raw distance -1.0 corresponds to an inner product of +1.0.
        let s = distance_to_score(-1.0, DistanceMetric::DotProduct);
        assert!((s - 1.0).abs() < 1e-5, "high sim: score={s}, expected 1.0");
    }
    #[test]
    fn dot_product_score_zero() {
        let s = distance_to_score(0.0, DistanceMetric::DotProduct);
        assert!((s - 0.5).abs() < 1e-5, "zero dot: score={s}, expected 0.5");
    }
    #[test]
    fn dot_product_score_negative() {
        let s = distance_to_score(1.0, DistanceMetric::DotProduct);
        assert!(s.abs() < 1e-5, "negative dot: score={s}, expected 0.0");
    }
    #[test]
    fn all_scores_non_negative() {
        // Sweeps non-negative distances across every metric; the lower clamp
        // (Cosine/DotProduct) and the L2 formula must never go below zero.
        for dist in [0.0, 0.5, 1.0, 2.0, 5.0, 10.0] {
            for metric in [
                DistanceMetric::Cosine,
                DistanceMetric::L2,
                DistanceMetric::DotProduct,
            ] {
                let s = distance_to_score(dist, metric);
                assert!(
                    s >= 0.0,
                    "score should be non-negative: metric={metric:?}, dist={dist}, score={s}"
                );
            }
        }
    }
    #[test]
    fn l2_scores_bounded_unit() {
        for dist in [0.0, 0.1, 1.0, 10.0, 100.0] {
            let s = distance_to_score(dist, DistanceMetric::L2);
            assert!(
                s > 0.0 && s <= 1.0,
                "L2 score out of (0,1]: dist={dist}, score={s}"
            );
        }
    }
    #[test]
    fn dot_product_unnormalized_can_exceed_one() {
        // DotProduct has no upper clamp: an inner product of 2 scores 1.5.
        let s = distance_to_score(-2.0, DistanceMetric::DotProduct);
        assert!(
            s > 1.0,
            "unnormalized dot product should produce score > 1: {s}"
        );
    }

    // ---- `DistanceMetric` byte encoding ----

    #[test]
    fn from_byte_round_trips_known_metrics() {
        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::DotProduct,
            DistanceMetric::L2,
        ] {
            let byte = metric as u8;
            assert_eq!(DistanceMetric::from_byte(byte), metric);
        }
    }
    #[test]
    fn from_byte_discriminants_are_pinned() {
        // Guards the stored byte values against accidental renumbering.
        assert_eq!(DistanceMetric::Cosine as u8, 0);
        assert_eq!(DistanceMetric::DotProduct as u8, 1);
        assert_eq!(DistanceMetric::L2 as u8, 2);
    }
    #[test]
    #[should_panic(expected = "unknown distance metric byte 3")]
    fn from_byte_panics_on_unknown_metric() {
        let _ = DistanceMetric::from_byte(3);
    }
    #[test]
    #[should_panic(expected = "unknown distance metric byte 255")]
    fn from_byte_panics_on_garbage() {
        let _ = DistanceMetric::from_byte(255);
    }

    // ---- `normalize_in_place` ----

    #[test]
    fn normalize_in_place_unit_length() {
        let mut v = vec![3.0_f32, 4.0];
        normalize_in_place(&mut v).unwrap();
        let norm = (v[0] * v[0] + v[1] * v[1]).sqrt();
        assert!((norm - 1.0).abs() < 1e-6, "norm after normalize: {norm}");
        assert!((v[0] - 0.6).abs() < 1e-6 && (v[1] - 0.8).abs() < 1e-6);
    }
    #[test]
    fn normalize_in_place_idempotent_on_unit_input() {
        // 0.6^2 + 0.8^2 is within 1e-4 of 1, so the early return leaves the
        // values bit-for-bit unchanged (hence exact `assert_eq!`).
        let mut v = vec![0.6_f32, 0.8];
        let before = v.clone();
        normalize_in_place(&mut v).unwrap();
        for (a, b) in v.iter().zip(before.iter()) {
            assert_eq!(a, b);
        }
    }
    #[test]
    fn normalize_in_place_zero_errors() {
        let mut v = vec![0.0_f32, 0.0, 0.0];
        let err = normalize_in_place(&mut v).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("zero-length / non-finite vector"),
            "unexpected message: {msg}",
        );
    }
    #[test]
    fn normalize_in_place_subnormal_errors() {
        // Each component is subnormal; its square underflows to 0 in f32, so
        // the squared norm is exactly zero.
        let mut v = vec![f32::MIN_POSITIVE * 1e-2; 3];
        let err = normalize_in_place(&mut v).unwrap_err();
        assert!(format!("{err}").contains("zero-length / non-finite vector"));
    }
    #[test]
    fn normalize_in_place_overflow_errors() {
        // (1e20)^2 = 1e40 exceeds f32::MAX, so the squared norm is infinite.
        let mut v = vec![1e20_f32; 3];
        let err = normalize_in_place(&mut v).unwrap_err();
        assert!(format!("{err}").contains("zero-length / non-finite vector"));
    }
    #[test]
    fn normalize_in_place_nan_errors() {
        let mut v = vec![1.0_f32, f32::NAN, 2.0];
        let err = normalize_in_place(&mut v).unwrap_err();
        assert!(format!("{err}").contains("zero-length / non-finite vector"));
    }

    // ---- Cosine on normalized inputs ----

    #[test]
    fn cosine_score_unchanged_after_normalize() {
        // Compares the f32 normalize-then-dot pipeline against an f64 oracle
        // that computes cosine similarity directly on the raw inputs.
        let cases: &[(Vec<f32>, Vec<f32>)] = &[
            (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]),
            (vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]),
            (vec![0.1; 100], vec![0.2; 100]),
        ];
        for (a_raw, b_raw) in cases {
            let dot64: f64 = a_raw
                .iter()
                .zip(b_raw.iter())
                .map(|(x, y)| (*x as f64) * (*y as f64))
                .sum();
            let na64: f64 = a_raw
                .iter()
                .map(|x| (*x as f64).powi(2))
                .sum::<f64>()
                .sqrt();
            let nb64: f64 = b_raw
                .iter()
                .map(|x| (*x as f64).powi(2))
                .sum::<f64>()
                .sqrt();
            let oracle_dist = 1.0 - dot64 / (na64 * nb64);
            // Same mapping as `distance_to_score` for Cosine, evaluated in f64.
            let oracle_score = ((2.0 - oracle_dist) / 2.0).max(0.0);
            let mut a = a_raw.clone();
            let mut b = b_raw.clone();
            normalize_in_place(&mut a).unwrap();
            normalize_in_place(&mut b).unwrap();
            let d = distance(&a, &b, DistanceMetric::Cosine);
            let s = distance_to_score(d, DistanceMetric::Cosine);
            assert!(
                ((s as f64) - oracle_score).abs() < 1e-3,
                "score drift > 1e-3: post={s}, oracle={oracle_score}",
            );
        }
    }
    #[test]
    fn cosine_distance_orthogonal_after_normalize() {
        let mut a = vec![3.0, 0.0];
        let mut b = vec![0.0, 7.0];
        normalize_in_place(&mut a).unwrap();
        normalize_in_place(&mut b).unwrap();
        let d = distance(&a, &b, DistanceMetric::Cosine);
        assert!((d - 1.0).abs() < 1e-6, "orthogonal cosine distance: {d}");
    }
    #[test]
    fn cosine_distance_identical_after_normalize() {
        let mut a = vec![1.0, 2.0, 3.0];
        let mut b = a.clone();
        normalize_in_place(&mut a).unwrap();
        normalize_in_place(&mut b).unwrap();
        let d = distance(&a, &b, DistanceMetric::Cosine);
        assert!(d.abs() < 1e-6, "identical cosine distance: {d}");
    }
}