lance_linalg/distance/
norm_l2.rs1use std::{iter::Sum, ops::AddAssign};
5
6use half::{bf16, f16};
7#[cfg(feature = "fp16kernels")]
8use lance_core::utils::cpu::SimdSupport;
9#[allow(unused_imports)]
10use lance_core::utils::cpu::FP16_SIMD_SUPPORT;
11use num_traits::{AsPrimitive, Float, Num};
12
13pub trait Normalize: Num {
15 fn norm_l2(vector: &[Self]) -> f32;
17}
18
#[cfg(feature = "fp16kernels")]
mod kernel {
    use super::*;

    // Hand-written `f16` L2-norm kernels declared with the C ABI; their
    // definitions live outside this file. Each takes a pointer to the first
    // element plus the element count and returns the norm as `f32`. Only the
    // declarations matching the current target architecture are compiled.
    extern "C" {
        // x86_64: AVX2 is always declared; AVX-512 additionally requires the
        // `kernel_support = "avx512"` cfg.
        #[cfg(target_arch = "x86_64")]
        pub fn norm_l2_f16_avx2(ptr: *const f16, len: u32) -> f32;
        #[cfg(all(kernel_support = "avx512", target_arch = "x86_64"))]
        pub fn norm_l2_f16_avx512(ptr: *const f16, len: u32) -> f32;

        // aarch64: NEON.
        #[cfg(target_arch = "aarch64")]
        pub fn norm_l2_f16_neon(ptr: *const f16, len: u32) -> f32;

        // loongarch64: LSX and LASX.
        #[cfg(target_arch = "loongarch64")]
        pub fn norm_l2_f16_lsx(ptr: *const f16, len: u32) -> f32;
        #[cfg(target_arch = "loongarch64")]
        pub fn norm_l2_f16_lasx(ptr: *const f16, len: u32) -> f32;
    }
}
38
39impl Normalize for u8 {
40 #[inline]
41 fn norm_l2(vector: &[Self]) -> f32 {
42 norm_l2_impl::<Self, f32, 16>(vector)
43 }
44}
45
46impl Normalize for f16 {
47 #[inline]
48 fn norm_l2(vector: &[Self]) -> f32 {
49 match *FP16_SIMD_SUPPORT {
50 #[cfg(all(feature = "fp16kernels", target_arch = "aarch64"))]
51 SimdSupport::Neon => unsafe {
52 kernel::norm_l2_f16_neon(vector.as_ptr(), vector.len() as u32)
53 },
54 #[cfg(all(
55 feature = "fp16kernels",
56 kernel_support = "avx512",
57 target_arch = "x86_64"
58 ))]
59 SimdSupport::Avx512 => unsafe {
60 kernel::norm_l2_f16_avx512(vector.as_ptr(), vector.len() as u32)
61 },
62 #[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))]
63 SimdSupport::Avx2 => unsafe {
64 kernel::norm_l2_f16_avx2(vector.as_ptr(), vector.len() as u32)
65 },
66 #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
67 SimdSupport::Lasx => unsafe {
68 kernel::norm_l2_f16_lasx(vector.as_ptr(), vector.len() as u32)
69 },
70 #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
71 SimdSupport::Lsx => unsafe {
72 kernel::norm_l2_f16_lsx(vector.as_ptr(), vector.len() as u32)
73 },
74 _ => norm_l2_impl::<Self, f32, 32>(vector),
75 }
76 }
77}
78
79impl Normalize for bf16 {
80 #[inline]
81 fn norm_l2(vector: &[Self]) -> f32 {
82 norm_l2_impl::<Self, f32, 32>(vector)
83 }
84}
85
86impl Normalize for f32 {
87 #[inline]
88 fn norm_l2(vector: &[Self]) -> f32 {
89 norm_l2_impl::<Self, Self, 16>(vector)
90 }
91}
92
93impl Normalize for f64 {
94 #[inline]
95 fn norm_l2(vector: &[Self]) -> f32 {
96 norm_l2_impl::<Self, Self, 8>(vector) as f32
97 }
98}
99
100#[inline]
102pub fn norm_l2_impl<
103 T: AsPrimitive<Output>,
104 Output: Float + Sum + 'static + AddAssign,
105 const LANES: usize,
106>(
107 vector: &[T],
108) -> Output {
109 let chunks = vector.chunks_exact(LANES);
110 let sum = if chunks.remainder().is_empty() {
111 Output::zero()
112 } else {
113 chunks
114 .remainder()
115 .iter()
116 .map(|&v| v.as_().powi(2))
117 .sum::<Output>()
118 };
119 let mut sums = [Output::zero(); LANES];
120 for chunk in chunks {
121 for i in 0..LANES {
122 sums[i] += chunk[i].as_().powi(2);
123 }
124 }
125 (sum + sums.iter().copied().sum::<Output>()).sqrt()
126}
127
128#[inline]
133pub fn norm_l2<T: Normalize>(vector: &[T]) -> f32 {
134 T::norm_l2(vector)
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{arbitrary_bf16, arbitrary_f16, arbitrary_f32, arbitrary_f64};
    use num_traits::ToPrimitive;
    use proptest::prelude::*;

    /// Ground-truth L2 norm: accumulated in `f64`, narrowed to `f32` only at
    /// the end.
    fn norm_l2_reference(data: &[f64]) -> f32 {
        data.iter().map(|v| v * v).sum::<f64>().sqrt() as f32
    }

    /// Checks `norm_l2(data)` against [`norm_l2_reference`] on the same values
    /// widened to `f64`, allowing a relative error of `1e-6`.
    fn do_norm_l2_test<T: Normalize + ToPrimitive>(
        data: &[T],
    ) -> std::result::Result<(), TestCaseError> {
        let f64_data = data
            .iter()
            .map(|v| v.to_f64().unwrap())
            .collect::<Vec<f64>>();

        let result = norm_l2(data);
        let reference = norm_l2_reference(&f64_data);

        prop_assert!(
            approx::relative_eq!(result, reference, max_relative = 1e-6),
            "norm_l2 = {}, reference = {}",
            result,
            reference
        );
        Ok(())
    }

    proptest::proptest! {
        // The `u8` impl previously had no coverage.
        #[test]
        fn test_l2_norm_u8(data in prop::collection::vec(any::<u8>(), 4..4048)) {
            do_norm_l2_test(&data)?;
        }

        #[test]
        fn test_l2_norm_f16(data in prop::collection::vec(arbitrary_f16(), 4..4048)) {
            do_norm_l2_test(&data)?;
        }

        #[test]
        fn test_l2_norm_bf16(data in prop::collection::vec(arbitrary_bf16(), 4..4048)) {
            do_norm_l2_test(&data)?;
        }

        #[test]
        fn test_l2_norm_f32(data in prop::collection::vec(arbitrary_f32(), 4..4048)) {
            do_norm_l2_test(&data)?;
        }

        #[test]
        fn test_l2_norm_f64(data in prop::collection::vec(arbitrary_f64(), 4..4048)) {
            do_norm_l2_test(&data)?;
        }
    }
}