use std::iter::Sum;
use std::ops::AddAssign;
use std::sync::Arc;
use crate::{Error, Result};
use arrow_array::{
Array, FixedSizeListArray, Float32Array,
cast::AsArray,
types::{Float16Type, Float32Type, Float64Type, Int8Type},
};
use arrow_schema::DataType;
use half::{bf16, f16};
use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray};
use lance_core::assume_eq;
use lance_core::utils::cpu::SIMD_SUPPORT;
#[cfg(feature = "fp16kernels")]
use lance_core::utils::cpu::SimdSupport;
use num_traits::{AsPrimitive, Num};
/// L2 distance kernel for a numeric element type.
///
/// The implementations in this file return the sum of squared element-wise
/// differences; no square root is taken.
pub trait L2: Num {
    /// Computes the L2 distance between the vectors `x` and `y`, returned as `f32`.
    fn l2(x: &[Self], y: &[Self]) -> f32;

    /// Lazily computes the distance from `x` to every consecutive
    /// `dimension`-sized vector stored back to back in `y`.
    ///
    /// `y` is split with `chunks_exact`, so trailing elements that do not fill a
    /// whole vector are skipped.
    ///
    /// # Panics
    ///
    /// Panics if `dimension` is zero (`chunks_exact` rejects a zero chunk size).
    fn l2_batch(x: &[Self], y: &[Self], dimension: usize) -> impl Iterator<Item = f32> {
        let targets = y.chunks_exact(dimension);
        targets.map(move |target| Self::l2(x, target))
    }
}
#[inline]
pub fn l2<T: L2>(from: &[T], to: &[T]) -> f32 {
T::l2(from, to)
}
/// Squared L2 distance between two `u8` vectors, computed with scalar code.
///
/// Elements are paired positionally; if the slices differ in length, only the
/// common prefix is compared. The result is the sum of squared absolute
/// differences converted to `f32`.
///
/// Partial sums are kept in `u32` per fixed-size chunk and folded into a `u64`
/// total, so arbitrarily long inputs cannot overflow the accumulator.
#[inline]
pub fn l2_distance_uint_scalar(key: &[u8], target: &[u8]) -> f32 {
    // A single squared difference is at most 255^2 = 65_025, so a `u32` partial
    // sum is safe for up to u32::MAX / 65_025 = 66_051 elements. 32_768 stays
    // well below that bound while keeping the inner loop in 32-bit arithmetic.
    const CHUNK: usize = 32_768;

    key.chunks(CHUNK)
        .zip(target.chunks(CHUNK))
        .map(|(key_chunk, target_chunk)| {
            let partial = key_chunk
                .iter()
                .zip(target_chunk.iter())
                .map(|(&x, &y)| u32::from(x.abs_diff(y)).pow(2))
                .sum::<u32>();
            u64::from(partial)
        })
        .sum::<u64>() as f32
}
/// Generic scalar L2 kernel: the sum of squared element-wise differences.
///
/// Every element is converted to `Output` with `as_()` before subtracting.
/// Full `LANES`-sized blocks are accumulated into `LANES` independent partial
/// sums (one per position inside a block); the elements left over after the
/// last full block are summed separately. The returned value is the leftover
/// sum plus the sum of the lane accumulators.
///
/// Blocks and leftovers of `from` and `to` are paired with `zip`, so the
/// shorter side bounds how many pairs are visited.
///
/// # Panics
///
/// Panics if `LANES` is zero (`chunks_exact` rejects a zero chunk size).
#[inline]
pub fn l2_scalar<
    T: AsPrimitive<Output>,
    Output: Num + Copy + Sum + AddAssign + 'static,
    const LANES: usize,
>(
    from: &[T],
    to: &[T],
) -> Output {
    let lhs_blocks = from.chunks_exact(LANES);
    let rhs_blocks = to.chunks_exact(LANES);
    let lhs_tail = lhs_blocks.remainder();
    let rhs_tail = rhs_blocks.remainder();

    // Leftover elements that do not fill a whole block.
    let tail_sum = if lhs_tail.is_empty() {
        Output::zero()
    } else {
        lhs_tail
            .iter()
            .zip(rhs_tail)
            .map(|(&a, &b)| {
                let delta = a.as_() - b.as_();
                delta * delta
            })
            .sum::<Output>()
    };

    // One accumulator per lane keeps the additions independent of each other.
    let mut lane_sums = [Output::zero(); LANES];
    for (lhs, rhs) in lhs_blocks.zip(rhs_blocks) {
        for ((acc, &a), &b) in lane_sums.iter_mut().zip(lhs).zip(rhs) {
            let delta = a.as_() - b.as_();
            *acc += delta * delta;
        }
    }

    let lanes_total: Output = lane_sums.iter().copied().sum();
    tail_sum + lanes_total
}
/// `u8` vectors are handled by the integer scalar kernel.
impl L2 for u8 {
    #[inline]
    fn l2(x: &[u8], y: &[u8]) -> f32 {
        l2_distance_uint_scalar(x, y)
    }
}
/// `bf16` vectors use the generic scalar kernel with `f32` accumulation.
impl L2 for bf16 {
    #[inline]
    fn l2(x: &[bf16], y: &[bf16]) -> f32 {
        // Elements are converted to `f32` and accumulated over 16 lanes.
        const LANES: usize = 16;
        l2_scalar::<bf16, f32, LANES>(x, y)
    }
}
/// Declarations of externally linked C kernels for `f16` L2 distance.
///
/// The module only exists when the `fp16kernels` feature is enabled, and each
/// symbol is further gated on the target architecture it is declared for.
/// Every kernel takes two `f16` pointers plus one element count and returns
/// an `f32`.
// NOTE(review): presumably each kernel reads `len` elements from both pointers
// and returns the squared L2 distance — confirm against the C sources.
#[cfg(feature = "fp16kernels")]
mod kernel {
use super::*;
unsafe extern "C" {
// aarch64 only.
#[cfg(target_arch = "aarch64")]
pub fn l2_f16_neon(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32;
// x86_64 only, and only when the build sets `kernel_support = "avx512"`.
#[cfg(all(kernel_support = "avx512", target_arch = "x86_64"))]
pub fn l2_f16_avx512(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32;
// x86_64 only.
#[cfg(target_arch = "x86_64")]
pub fn l2_f16_avx2(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32;
// loongarch64 only (LSX variant).
#[cfg(target_arch = "loongarch64")]
pub fn l2_f16_lsx(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32;
// loongarch64 only (LASX variant).
#[cfg(target_arch = "loongarch64")]
pub fn l2_f16_lasx(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32;
}
}
/// `f16` vectors: dispatch on `SIMD_SUPPORT` to a native kernel when one is
/// compiled in for the current target, otherwise fall back to the generic
/// scalar kernel with `f32` accumulation.
impl L2 for f16 {
    #[inline]
    fn l2(x: &[Self], y: &[Self]) -> f32 {
        // SAFETY (applies to every `unsafe` arm below): both pointers come from
        // live slices, and the element count handed to the kernel is clamped to
        // the shorter slice, so the kernel is never told to read past the end of
        // either buffer. The `as u32` cast can only shrink the count further.
        // NOTE(review): this assumes the C kernels read exactly `len` elements
        // from each pointer — confirm against the kernel sources.
        match *SIMD_SUPPORT {
            #[cfg(all(feature = "fp16kernels", target_arch = "aarch64"))]
            // SAFETY: see the note above the `match`.
            SimdSupport::Neon => unsafe {
                kernel::l2_f16_neon(x.as_ptr(), y.as_ptr(), x.len().min(y.len()) as u32)
            },
            #[cfg(all(
                feature = "fp16kernels",
                kernel_support = "avx512",
                target_arch = "x86_64"
            ))]
            // SAFETY: see the note above the `match`.
            SimdSupport::Avx512FP16 => unsafe {
                kernel::l2_f16_avx512(x.as_ptr(), y.as_ptr(), x.len().min(y.len()) as u32)
            },
            #[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))]
            // SAFETY: see the note above the `match`.
            SimdSupport::Avx2 => unsafe {
                kernel::l2_f16_avx2(x.as_ptr(), y.as_ptr(), x.len().min(y.len()) as u32)
            },
            #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
            // SAFETY: see the note above the `match`.
            SimdSupport::Lasx => unsafe {
                kernel::l2_f16_lasx(x.as_ptr(), y.as_ptr(), x.len().min(y.len()) as u32)
            },
            #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
            // SAFETY: see the note above the `match`.
            SimdSupport::Lsx => unsafe {
                kernel::l2_f16_lsx(x.as_ptr(), y.as_ptr(), x.len().min(y.len()) as u32)
            },
            // No kernel compiled in (or none selected): scalar path, 16 lanes.
            _ => l2_scalar::<Self, f32, 16>(x, y),
        }
    }
}
/// `f32` vectors use the generic scalar kernel, accumulating in `f32`.
impl L2 for f32 {
    #[inline]
    fn l2(x: &[f32], y: &[f32]) -> f32 {
        const LANES: usize = 16;
        l2_scalar::<f32, f32, LANES>(x, y)
    }
}
/// `f64` vectors use the generic scalar kernel, accumulating in `f64`; only the
/// final sum is narrowed to `f32`.
impl L2 for f64 {
    #[inline]
    fn l2(x: &[f64], y: &[f64]) -> f32 {
        const LANES: usize = 8;
        let total: f64 = l2_scalar::<f64, f64, LANES>(x, y);
        total as f32
    }
}
/// Adds `(q - row[i])^2` to `result[i]` for every index present in both `row`
/// and `result`.
///
/// The two slices are zipped, so entries of `result` beyond `row.len()` are
/// left untouched.
#[inline(never)]
fn accumulate_l2_dimension(q: f32, row: &[f32], result: &mut [f32]) {
    result.iter_mut().zip(row).for_each(|(acc, &value)| {
        let delta = q - value;
        *acc += delta * delta;
    });
}
/// A fixed set of `f32` target vectors, stored transposed so that distances
/// from one query to all targets can be accumulated one dimension at a time.
#[derive(Debug, Clone)]
pub struct L2Prepared {
    /// Dimension-major copy of the targets: the value of target `t` at
    /// dimension `d` lives at index `d * num_targets + t`.
    transposed: Vec<f32>,
    /// Number of elements in each target vector.
    dimension: usize,
    /// Number of target vectors.
    num_targets: usize,
}

impl L2Prepared {
    /// Builds the transposed layout from `targets`, a row-major buffer of
    /// `dimension`-sized vectors stored back to back.
    ///
    /// An empty `targets` with `dimension == 0` yields an empty set.
    ///
    /// # Panics
    ///
    /// Panics if `dimension` is zero while `targets` is not empty. In debug
    /// builds, also panics if `targets.len()` is not a multiple of `dimension`.
    pub fn new(targets: &[f32], dimension: usize) -> Self {
        if dimension == 0 {
            // Avoid the bare divide-by-zero below; an empty input is a valid
            // (empty) target set, anything else is a caller bug.
            assert!(
                targets.is_empty(),
                "L2Prepared::new: dimension is 0 but targets is not empty"
            );
            return Self {
                transposed: Vec::new(),
                dimension,
                num_targets: 0,
            };
        }

        let num_targets = targets.len() / dimension;
        debug_assert_eq!(targets.len(), num_targets * dimension);

        let mut transposed = vec![0.0f32; targets.len()];
        for (t, target) in targets.chunks_exact(dimension).enumerate() {
            for (d, &value) in target.iter().enumerate() {
                transposed[d * num_targets + t] = value;
            }
        }

        Self {
            transposed,
            dimension,
            num_targets,
        }
    }

    /// Writes the distance from `query` to every target into `out`.
    ///
    /// `out` is zeroed first and then, for each query dimension, the squared
    /// differences against that dimension of every target are accumulated.
    ///
    /// `query.len()` is expected to equal [`Self::dimension`] and `out.len()`
    /// to equal [`Self::num_targets`]; both are checked with `debug_assert_eq!`
    /// only.
    pub fn distances_into(&self, query: &[f32], out: &mut [f32]) {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(out.len(), self.num_targets);
        out.fill(0.0);
        for (d, &q) in query.iter().enumerate() {
            // Row `d` of the transposed layout: dimension `d` of every target.
            let row = &self.transposed[d * self.num_targets..][..self.num_targets];
            accumulate_l2_dimension(q, row, out);
        }
    }

    /// Returns the distance from `query` to every target in a new vector.
    pub fn distances(&self, query: &[f32]) -> Vec<f32> {
        let mut result = vec![0.0f32; self.num_targets];
        self.distances_into(query, &mut result);
        result
    }

    /// Computes all distances into `buf` and returns the index that
    /// `crate::kernels::argmin_value_float` reports for them.
    // NOTE(review): presumably that is the index of the smallest distance, and
    // `None` when there are no targets — confirm against `argmin_value_float`.
    pub fn nearest_into(&self, query: &[f32], buf: &mut [f32]) -> Option<u32> {
        self.distances_into(query, buf);
        crate::kernels::argmin_value_float(buf.iter().copied()).map(|(idx, _)| idx)
    }

    /// Same as [`Self::nearest_into`], using a freshly allocated scratch buffer.
    pub fn nearest(&self, query: &[f32]) -> Option<u32> {
        let mut scratch = vec![0.0f32; self.num_targets];
        self.nearest_into(query, &mut scratch)
    }

    /// Number of target vectors.
    pub fn num_targets(&self) -> usize {
        self.num_targets
    }

    /// Number of elements per target vector.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Size in bytes of the transposed `f32` buffer.
    pub fn size_bytes(&self) -> usize {
        self.transposed.len() * std::mem::size_of::<f32>()
    }
}
/// Computes the L2 distance between two `f32` slices via the generic [`l2`]
/// entry point.
#[inline]
pub fn l2_distance(from: &[f32], to: &[f32]) -> f32 {
    l2::<f32>(from, to)
}
/// Lazily computes the L2 distance from the single vector `from` to each
/// `dimension`-sized vector stored back to back in `to`.
///
/// `from.len()` is expected to equal `dimension`, and `to.len()` to be a
/// multiple of it; both expectations are stated with `assume_eq!`.
// NOTE(review): `assume_eq!` comes from `lance_core`; what happens when the
// equality does not hold is not visible here — confirm before relying on it.
pub fn l2_distance_batch<'a, T: L2>(
    from: &'a [T],
    to: &'a [T],
    dimension: usize,
) -> impl Iterator<Item = f32> + 'a {
    assume_eq!(from.len(), dimension);
    assume_eq!(to.len() % dimension, 0);
    <T as L2>::l2_batch(from, to, dimension)
}
fn do_l2_distance_arrow_batch<T: ArrowFloatType>(
from: &T::ArrayType,
to: &FixedSizeListArray,
) -> Result<Arc<Float32Array>>
where
T::Native: L2,
{
let dimension = to.value_length() as usize;
debug_assert_eq!(from.len(), dimension);
let to_values =
to.values()
.as_any()
.downcast_ref::<T::ArrayType>()
.ok_or(Error::ComputeError(format!(
"Cannot downcast to the same type: {} != {}",
T::FLOAT_TYPE,
to.value_type()
)))?;
let dists = l2_distance_batch(from.as_slice(), to_values.as_slice(), dimension);
Ok(Arc::new(Float32Array::new(
dists.collect(),
to.nulls().cloned(),
)))
}
pub fn l2_distance_arrow_batch(
from: &dyn Array,
to: &FixedSizeListArray,
) -> Result<Arc<Float32Array>> {
match *from.data_type() {
DataType::Float16 => do_l2_distance_arrow_batch::<Float16Type>(from.as_primitive(), to),
DataType::Float32 => do_l2_distance_arrow_batch::<Float32Type>(from.as_primitive(), to),
DataType::Float64 => do_l2_distance_arrow_batch::<Float64Type>(from.as_primitive(), to),
DataType::Int8 => do_l2_distance_arrow_batch::<Float32Type>(
&from
.as_primitive::<Int8Type>()
.into_iter()
.map(|x| x.unwrap() as f32)
.collect(),
&to.convert_to_floating_point()?,
),
_ => Err(Error::ComputeError(format!(
"Unsupported data type: {}",
from.data_type()
))),
}
}
// Unit and property tests for the L2 kernels and for `L2Prepared`.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use num_traits::ToPrimitive;
use proptest::prelude::*;
use crate::test_utils::{
arbitrary_bf16, arbitrary_f16, arbitrary_f32, arbitrary_f64, arbitrary_vector_pair,
};
// Four 8-element rows (0..8, 1..9, 2..10, 3..11) against the query 2..10.
// Each row differs from the query by a constant per-element offset k
// (2, 1, 0, 1), so the expected distance is 8 * k^2.
#[test]
fn test_euclidean_distance() {
let mat = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
vec![
Some((0..8).map(|v| Some(v as f32)).collect::<Vec<_>>()),
Some((1..9).map(|v| Some(v as f32)).collect::<Vec<_>>()),
Some((2..10).map(|v| Some(v as f32)).collect::<Vec<_>>()),
Some((3..11).map(|v| Some(v as f32)).collect::<Vec<_>>()),
],
8,
);
let point = Float32Array::from((2..10).map(|v| Some(v as f32)).collect::<Vec<_>>());
let distances = l2_distance_batch(
point.values(),
mat.values().as_primitive::<Float32Type>().values(),
8,
)
.collect::<Vec<_>>();
assert_eq!(distances, vec![32.0, 8.0, 0.0, 8.0]);
}
// Same data as `test_euclidean_distance`, but both inputs are sub-slices that
// start at a non-zero offset into their buffers (`point[2..]`, `mat[6..]`).
#[test]
fn test_not_aligned() {
let mat = (0..6)
.chain(0..8)
.chain(1..9)
.chain(2..10)
.chain(3..11)
.map(|v| v as f32)
.collect::<Vec<_>>();
let point = Float32Array::from((0..10).map(|v| Some(v as f32)).collect::<Vec<_>>());
let distances = l2_distance_batch(&point.values()[2..], &mat[6..], 8).collect::<Vec<_>>();
assert_eq!(distances, vec![32.0, 8.0, 0.0, 8.0]);
}
// Dimension 5 with a constant per-element offset of 2: expected 5 * 2^2 = 20.
#[test]
fn test_odd_length_vector() {
let mat = Float32Array::from_iter((0..5).map(|v| Some(v as f32)));
let point = Float32Array::from((2..7).map(|v| Some(v as f32)).collect::<Vec<_>>());
let distances = l2_distance_batch(point.values(), mat.values(), 5).collect::<Vec<_>>();
assert_eq!(distances, vec![20.0]);
}
// Fixed 32-dimensional vector/query pair checked against a pinned expected value.
#[test]
fn test_l2_distance_cases() {
let values: Float32Array = vec![
0.25335717, 0.24663818, 0.26330215, 0.14988247, 0.06042378, 0.21077952, 0.26687378,
0.22145681, 0.18319066, 0.18688454, 0.05216244, 0.11470364, 0.10554603, 0.19964123,
0.06387895, 0.18992095, 0.00123718, 0.13500804, 0.09516747, 0.19508345, 0.2582458,
0.1211653, 0.21121833, 0.24809816, 0.04078768, 0.19586588, 0.16496408, 0.14766085,
0.04898421, 0.14728612, 0.21263947, 0.16763233,
]
.into();
let q: Float32Array = vec![
0.18549609,
0.29954708,
0.28318876,
0.05424477,
0.093134984,
0.21580857,
0.2951282,
0.19866848,
0.13868214,
0.19819534,
0.23271298,
0.047727287,
0.14394054,
0.023316395,
0.18589257,
0.037315924,
0.07037327,
0.32609823,
0.07344752,
0.020155912,
0.18485495,
0.32763934,
0.14296658,
0.04498596,
0.06254237,
0.24348071,
0.16009757,
0.053892266,
0.05918874,
0.040363103,
0.19913352,
0.14545348,
]
.into();
let d = l2_distance_batch(q.values(), values.values(), 32).collect::<Vec<_>>();
assert_relative_eq!(0.319_357_84, d[0]);
}
// Straightforward `f64` reference: sum of squared element-wise differences.
fn l2_distance_reference(x: &[f64], y: &[f64]) -> f64 {
x.iter()
.zip(y.iter())
.map(|(x, y)| (*x - *y).powi(2))
.sum::<f64>()
}
// Compares `l2(x, y)` with the `f64` reference (narrowed to `f32`) using a
// relative tolerance of 1e-6.
fn do_l2_test<T: L2 + ToPrimitive>(x: &[T], y: &[T]) -> std::result::Result<(), TestCaseError> {
let x_f64 = x.iter().map(|v| v.to_f64().unwrap()).collect::<Vec<f64>>();
let y_f64 = y.iter().map(|v| v.to_f64().unwrap()).collect::<Vec<f64>>();
let result = l2(x, y);
let reference = l2_distance_reference(&x_f64, &y_f64) as f32;
prop_assert!(approx::relative_eq!(result, reference, max_relative = 1e-6));
Ok(())
}
// 4048 elements of `f16::MAX` against `-f16::MAX`: the largest per-element
// difference representable with `f16` inputs.
#[test]
fn test_l2_distance_f16_max() {
let x = vec![f16::MAX; 4048];
let y = vec![-f16::MAX; 4048];
do_l2_test(&x, &y).unwrap();
}
// Property tests: vector pairs generated per element type are checked against
// the `f64` reference.
// NOTE(review): `4..4048` is presumably the range of generated vector lengths
// — confirm against `arbitrary_vector_pair`.
proptest::proptest! {
#[test]
fn test_l2_distance_f16((x, y) in arbitrary_vector_pair(arbitrary_f16, 4..4048)) {
do_l2_test(&x, &y)?;
}
#[test]
fn test_l2_distance_bf16((x, y) in arbitrary_vector_pair(arbitrary_bf16, 4..4048)){
do_l2_test(&x, &y)?;
}
#[test]
fn test_l2_distance_f32((x, y) in arbitrary_vector_pair(arbitrary_f32, 4..4048)){
do_l2_test(&x, &y)?;
}
#[test]
fn test_l2_distance_f64((x, y) in arbitrary_vector_pair(arbitrary_f64, 4..4048)){
do_l2_test(&x, &y)?;
}
}
// `u8` extremes: identical vectors give 0; all-0 against all-255 gives
// 255^2 * 2048, regardless of argument order.
#[test]
fn test_uint8_l2_edge_cases() {
let q = vec![0_u8; 2048];
let v = vec![0_u8; 2048];
assert_eq!(l2_distance_uint_scalar(&q, &v), 0.0);
let q = vec![0_u8; 2048];
let v = vec![255_u8; 2048];
assert_eq!(
l2_distance_uint_scalar(&q, &v),
(255_u32.pow(2) * 2048) as f32
);
assert_eq!(
l2_distance_uint_scalar(&v, &q),
(255_u32.pow(2) * 2048) as f32
);
}
// `L2Prepared::distances` must agree with `l2_scalar` applied to each target
// for a range of shapes.
#[test]
fn test_l2_targets_matches_scalar() {
// Each case is `(dim, num_targets)`.
let cases = vec![
(16, 8), (16, 16), (16, 256), (16, 17), (16, 31), (1, 32), (3, 20), (128, 64), ];
for (dim, num_targets) in cases {
let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.1 + 0.05).collect();
let targets: Vec<f32> = (0..dim * num_targets)
.map(|i| ((i * 7 + 3) % 100) as f32 * 0.01)
.collect();
let expected: Vec<f32> = targets
.chunks_exact(dim)
.map(|v| l2_scalar::<f32, f32, 16>(&query, v))
.collect();
let prepared = L2Prepared::new(&targets, dim);
let actual = prepared.distances(&query);
assert_eq!(
actual.len(),
expected.len(),
"length mismatch for dim={dim}, num_targets={num_targets}"
);
for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
approx::relative_eq!(a, e, max_relative = 1e-6),
"mismatch at index {i} for dim={dim}, num_targets={num_targets}: \
prepared={a}, scalar={e}"
);
}
}
}
// An all-zero query against all-zero targets yields exactly 0.0 everywhere.
#[test]
fn test_l2_targets_zeros() {
let dim = 16;
let num_targets = 32;
let query = vec![0.0f32; dim];
let targets = vec![0.0f32; dim * num_targets];
let prepared = L2Prepared::new(&targets, dim);
let distances = prepared.distances(&query);
assert_eq!(distances.len(), num_targets);
for d in &distances {
assert_eq!(*d, 0.0);
}
}
// Hand-computed 2-d case: query (1, 0) against (1, 0), (0, 1), (2, 0), (0, 0),
// followed by twelve more (0, 0) targets, for 16 targets in total.
#[test]
fn test_l2_targets_known_values() {
let dim = 2;
let query = vec![1.0f32, 0.0];
let mut targets = vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0];
for _ in 4..16 {
targets.extend_from_slice(&[0.0, 0.0]);
}
let prepared = L2Prepared::new(&targets, dim);
let distances = prepared.distances(&query);
assert_eq!(distances.len(), 16);
assert_relative_eq!(distances[0], 0.0);
assert_relative_eq!(distances[1], 2.0);
assert_relative_eq!(distances[2], 1.0);
assert_relative_eq!(distances[3], 1.0);
for d in &distances[4..] {
assert_relative_eq!(*d, 1.0);
}
}
// One `L2Prepared` is queried twice; each query equals one of the two targets,
// so its distance to that target must be 0.
#[test]
fn test_l2_targets_reuse() {
let dim = 4;
let targets = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let prepared = L2Prepared::new(&targets, dim);
let q1 = vec![1.0, 2.0, 3.0, 4.0];
let q2 = vec![5.0, 6.0, 7.0, 8.0];
let d1 = prepared.distances(&q1);
let d2 = prepared.distances(&q2);
assert_relative_eq!(d1[0], 0.0); assert_relative_eq!(d2[1], 0.0); }
}