lance_linalg/distance/norm_l2.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::{iter::Sum, ops::AddAssign};

use half::{bf16, f16};
#[cfg(feature = "fp16kernels")]
use lance_core::utils::cpu::SimdSupport;
#[allow(unused_imports)]
use lance_core::utils::cpu::FP16_SIMD_SUPPORT;
use num_traits::{AsPrimitive, Float, Num};

/// L2 normalization
pub trait Normalize: Num {
    /// Compute the L2 norm of a vector.
    fn norm_l2(vector: &[Self]) -> f32;
}

#[cfg(feature = "fp16kernels")]
mod kernel {
    use super::*;

    // These correspond to the `norm_l2_f16` function in f16.c. Our build.rs
    // script compiles a version of that file for each SIMD level, each with
    // a different suffix.
    extern "C" {
        #[cfg(target_arch = "aarch64")]
        pub fn norm_l2_f16_neon(ptr: *const f16, len: u32) -> f32;
        #[cfg(all(kernel_support = "avx512", target_arch = "x86_64"))]
        pub fn norm_l2_f16_avx512(ptr: *const f16, len: u32) -> f32;
        #[cfg(target_arch = "x86_64")]
        pub fn norm_l2_f16_avx2(ptr: *const f16, len: u32) -> f32;
        #[cfg(target_arch = "loongarch64")]
        pub fn norm_l2_f16_lsx(ptr: *const f16, len: u32) -> f32;
        #[cfg(target_arch = "loongarch64")]
        pub fn norm_l2_f16_lasx(ptr: *const f16, len: u32) -> f32;
    }
}
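
// A sketch of the build.rs pattern described above. The file path, flag,
// define name, and library name here are assumptions for illustration, not
// the actual build script:
//
//     cc::Build::new()
//         .file("src/simd/f16.c")
//         .flag("-mavx2")
//         .define("SUFFIX", "_avx2")
//         .compile("f16_avx2");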

impl Normalize for u8 {
    #[inline]
    fn norm_l2(vector: &[Self]) -> f32 {
        norm_l2_impl::<Self, f32, 16>(vector)
    }
}

impl Normalize for f16 {
    #[inline]
    fn norm_l2(vector: &[Self]) -> f32 {
        // Dispatch at runtime on the SIMD support detected for this CPU; each
        // arm is only compiled when the matching kernel exists for the target.
        match *FP16_SIMD_SUPPORT {
            #[cfg(all(feature = "fp16kernels", target_arch = "aarch64"))]
            SimdSupport::Neon => unsafe {
                kernel::norm_l2_f16_neon(vector.as_ptr(), vector.len() as u32)
            },
            #[cfg(all(
                feature = "fp16kernels",
                kernel_support = "avx512",
                target_arch = "x86_64"
            ))]
            SimdSupport::Avx512 => unsafe {
                kernel::norm_l2_f16_avx512(vector.as_ptr(), vector.len() as u32)
            },
            #[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))]
            SimdSupport::Avx2 => unsafe {
                kernel::norm_l2_f16_avx2(vector.as_ptr(), vector.len() as u32)
            },
            #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
            SimdSupport::Lasx => unsafe {
                kernel::norm_l2_f16_lasx(vector.as_ptr(), vector.len() as u32)
            },
            #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
            SimdSupport::Lsx => unsafe {
                kernel::norm_l2_f16_lsx(vector.as_ptr(), vector.len() as u32)
            },
            _ => norm_l2_impl::<Self, f32, 32>(vector),
        }
    }
}

impl Normalize for bf16 {
    #[inline]
    fn norm_l2(vector: &[Self]) -> f32 {
        norm_l2_impl::<Self, f32, 32>(vector)
    }
}

impl Normalize for f32 {
    #[inline]
    fn norm_l2(vector: &[Self]) -> f32 {
        norm_l2_impl::<Self, Self, 16>(vector)
    }
}

impl Normalize for f64 {
    #[inline]
    fn norm_l2(vector: &[Self]) -> f32 {
        norm_l2_impl::<Self, Self, 8>(vector) as f32
    }
}

/// NOTE: this is only pub for benchmarking purposes
#[inline]
pub fn norm_l2_impl<
    T: AsPrimitive<Output>,
    Output: Float + Sum + 'static + AddAssign,
    const LANES: usize,
>(
    vector: &[T],
) -> Output {
    let chunks = vector.chunks_exact(LANES);
    // Sum the squares of the tail elements that don't fill a full LANES chunk.
    let sum = if chunks.remainder().is_empty() {
        Output::zero()
    } else {
        chunks
            .remainder()
            .iter()
            .map(|&v| v.as_().powi(2))
            .sum::<Output>()
    };
    // Accumulate squares into LANES independent partial sums so the compiler
    // can keep them in vector registers and auto-vectorize the loop.
    let mut sums = [Output::zero(); LANES];
    for chunk in chunks {
        for i in 0..LANES {
            sums[i] += chunk[i].as_().powi(2);
        }
    }
    (sum + sums.iter().copied().sum::<Output>()).sqrt()
}
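
// For reference, `norm_l2_impl` computes the same value as this plain scalar
// version (a hypothetical helper, shown only for illustration):
//
//     fn norm_l2_scalar(v: &[f32]) -> f32 {
//         v.iter().map(|x| x * x).sum::<f32>().sqrt()
//     }
//
// The lane-unrolled loop above differs only in keeping LANES independent
// accumulators, which lets the compiler auto-vectorize the sum of squares.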

/// Compute the L2 norm of a vector.
///
/// The input must be cache-line aligned, e.g. a slice taken from an Arrow
/// array such as `Float32Array`.
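///
/// # Example
///
/// A minimal usage sketch; the import path assumes this module is reachable
/// as `lance_linalg::distance`, which may differ in practice.
///
/// ```ignore
/// use lance_linalg::distance::norm_l2;
///
/// // sqrt(3^2 + 4^2) == 5, exactly representable in f32.
/// let v: Vec<f32> = vec![3.0, 4.0];
/// assert_eq!(norm_l2(&v), 5.0);
/// ```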
#[inline]
pub fn norm_l2<T: Normalize>(vector: &[T]) -> f32 {
    T::norm_l2(vector)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{arbitrary_bf16, arbitrary_f16, arbitrary_f32, arbitrary_f64};
    use num_traits::ToPrimitive;
    use proptest::prelude::*;

    /// Reference implementation of L2 norm.
    fn norm_l2_reference(data: &[f64]) -> f32 {
        data.iter().map(|v| (*v * *v)).sum::<f64>().sqrt() as f32
    }

    fn do_norm_l2_test<T: Normalize + ToPrimitive>(
        data: &[T],
    ) -> std::result::Result<(), TestCaseError> {
        let f64_data = data
            .iter()
            .map(|v| v.to_f64().unwrap())
            .collect::<Vec<f64>>();

        let result = norm_l2(data);
        let reference = norm_l2_reference(&f64_data);

        prop_assert!(approx::relative_eq!(result, reference, max_relative = 1e-6));
        Ok(())
    }

    proptest::proptest! {
        #[test]
        fn test_l2_norm_f16(data in prop::collection::vec(arbitrary_f16(), 4..4048)) {
            do_norm_l2_test(&data)?;
        }

        #[test]
        fn test_l2_norm_bf16(data in prop::collection::vec(arbitrary_bf16(), 4..4048)) {
            do_norm_l2_test(&data)?;
        }

        #[test]
        fn test_l2_norm_f32(data in prop::collection::vec(arbitrary_f32(), 4..4048)) {
            do_norm_l2_test(&data)?;
        }

        #[test]
        fn test_l2_norm_f64(data in prop::collection::vec(arbitrary_f64(), 4..4048)) {
            do_norm_l2_test(&data)?;
        }
    }
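
    // A simple known-answer check to complement the property tests above
    // (a minimal sketch; [3, 4] has an exact L2 norm of 5 in f32).
    #[test]
    fn test_l2_norm_known_answer() {
        assert_eq!(norm_l2(&[3.0f32, 4.0]), 5.0);
    }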
}