// diskann_vector/norm.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
6use diskann_wide::{
7    arch::{Target1, Target2},
8    Architecture,
9};
10
11use crate::{
12    distance::{implementations::L1NormFunctor, InnerProduct},
13    Half, MathematicalValue, Norm,
14};
15
/// Computes the square of the L2 (Euclidean) norm of its argument.
///
/// # Implementation
///
/// The underlying kernels take a naive summation approach: fast, but less
/// accurate than compensated (e.g. Kahan-style) methods.
#[derive(Debug, Clone, Copy)]
pub struct FastL2NormSquared;
24
25impl<T, To> Norm<T, To> for FastL2NormSquared
26where
27    Self: Target1<diskann_wide::arch::Current, To, T>,
28    T: Copy,
29    To: Copy,
30{
31    #[inline]
32    fn evaluate(&self, x: T) -> To {
33        // As an implementation note: if the implementation of `InnerProduct` is inlined
34        // into the callsite, then LLVM recognizes that the two ranges overlap and optimizes
35        // out half the loads.
36        //
37        // This means we don't need to reimplement *all* the different unrolling strategies.
38        //
39        // The down-side is that perhaps the best unrolling strategy is slightly different
40        // for norm calculations. It is at least a start though.
41        self.run(diskann_wide::ARCH, x)
42    }
43}
44
45impl<A, T, To> Target1<A, To, T> for FastL2NormSquared
46where
47    A: Architecture,
48    InnerProduct: Target2<A, MathematicalValue<To>, T, T>,
49    T: Copy,
50    To: Copy,
51{
52    #[inline(always)]
53    fn run(self, arch: A, x: T) -> To {
54        (InnerProduct {}).run(arch, x, x).into_inner()
55    }
56}
57
/// Computes the L2 (Euclidean) norm of its argument.
///
/// # Implementation
///
/// Delegates to [`FastL2NormSquared`] and takes the square root, so it inherits
/// the same speed/accuracy trade-off: naive summation rather than a more
/// precise compensated scheme.
#[derive(Debug, Clone, Copy)]
pub struct FastL2Norm;
66
67impl<T> Norm<T, f32> for FastL2Norm
68where
69    Self: Target1<diskann_wide::arch::Current, f32, T>,
70{
71    #[inline]
72    fn evaluate(&self, x: T) -> f32 {
73        self.run(diskann_wide::ARCH, x)
74    }
75}
76
77impl<A, T> Target1<A, f32, T> for FastL2Norm
78where
79    A: Architecture,
80    FastL2NormSquared: Target1<A, f32, T>,
81    T: Copy,
82{
83    #[inline(always)]
84    fn run(self, arch: A, x: T) -> f32 {
85        (FastL2NormSquared).run(arch, x).sqrt()
86    }
87}
88
/// Computes the L1 (Manhattan) norm of its argument: `sum_i |x_i|`.
///
/// # Implementation
///
/// Backed by the SIMD-optimized [`L1NormFunctor`].
///
/// Note that although the L1 norm is logically a *unary* operation, it is wired
/// through the shared binary `SIMDSchema`/`simd_op` execution machinery, which
/// expects two input slices of equal length. The right-hand operand is
/// completely ignored and exists only to satisfy that shared infrastructure
/// (loop tiling, epilogue handling, etc.).
#[derive(Debug, Clone, Copy)]
pub struct L1Norm;
105
106impl<T> Norm<T, f32> for L1Norm
107where
108    Self: Target1<diskann_wide::arch::Current, f32, T>,
109{
110    #[inline]
111    fn evaluate(&self, x: T) -> f32 {
112        self.run(diskann_wide::ARCH, x)
113    }
114}
115
116impl<A, T, To> Target1<A, To, T> for L1Norm
117where
118    A: Architecture,
119    L1NormFunctor: Target2<A, To, T, T>,
120    T: Copy,
121    To: Copy,
122{
123    #[inline(always)]
124    fn run(self, arch: A, x: T) -> To {
125        (L1NormFunctor {}).run(arch, x, x)
126    }
127}
128
/// Computes the L-infinity (max) norm of its argument: `max_i |x_i|`.
///
/// # Implementation
///
/// Closed implementation supporting `&[f32]` and `&[Half]` inputs:
/// * `f32`: a simple scalar reduction using `abs` and `max`.
/// * `Half`: each element is widened via `diskann_wide::cast_f16_to_f32`
///   before the same `abs`/`max` reduction.
///
/// # Performance
///
/// The per-element `Half` widening does not auto-vectorize well, so evaluating
/// large `Half` slices may be a throughput bottleneck relative to an explicit
/// SIMD reduction (e.g. load `f16x8`, convert once, then lane-wise abs/max in
/// `f32`). The current behavior is correct, just potentially slower than
/// expected for large `Half` inputs.
#[derive(Debug, Clone, Copy)]
pub struct LInfNorm;
150
151impl Norm<&[f32], f32> for LInfNorm {
152    #[inline]
153    fn evaluate(&self, x: &[f32]) -> f32 {
154        self.run(diskann_wide::ARCH, x)
155    }
156}
157
158impl Norm<&[Half], f32> for LInfNorm {
159    #[inline]
160    fn evaluate(&self, x: &[Half]) -> f32 {
161        self.run(diskann_wide::ARCH, x)
162    }
163}
164
165impl<A> Target1<A, f32, &[f32]> for LInfNorm
166where
167    A: Architecture,
168{
169    #[inline(always)]
170    fn run(self, _: A, x: &[f32]) -> f32 {
171        let mut m = 0.0f32;
172        for &v in x {
173            m = m.max(v.abs());
174        }
175        m
176    }
177}
178
179impl<A> Target1<A, f32, &[Half]> for LInfNorm
180where
181    A: Architecture,
182{
183    #[inline(always)]
184    fn run(self, _: A, x: &[Half]) -> f32 {
185        let mut m = 0.0f32;
186        for &v in x {
187            m = m.max(diskann_wide::cast_f16_to_f32(v).abs());
188        }
189        m
190    }
191}
192
193///////////
194// Tests //
195///////////
196
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform, Uniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;
    use crate::Half;

    /// Scalar reference implementation of the squared L2 norm, used to check
    /// the SIMD-backed implementations element type by element type.
    trait ReferenceL2NormSquared {
        fn reference_l2_norm_squared(self) -> f32;
    }

    impl ReferenceL2NormSquared for &[f32] {
        fn reference_l2_norm_squared(self) -> f32 {
            self.iter().map(|x| x * x).sum()
        }
    }
    impl ReferenceL2NormSquared for &[Half] {
        fn reference_l2_norm_squared(self) -> f32 {
            self.iter()
                .map(|x| {
                    // Widen to f32 before squaring so the reference accumulates
                    // in full precision.
                    let x = x.to_f32();
                    x * x
                })
                .sum()
        }
    }
    impl ReferenceL2NormSquared for &[i8] {
        fn reference_l2_norm_squared(self) -> f32 {
            self.iter()
                .map(|x| {
                    // Widen to i32 so squaring cannot overflow (max 127^2).
                    let x: i32 = (*x).into();
                    x * x
                })
                .sum::<i32>() as f32
        }
    }
    impl ReferenceL2NormSquared for &[u8] {
        fn reference_l2_norm_squared(self) -> f32 {
            self.iter()
                .map(|x| {
                    // Widen to i32 so squaring cannot overflow (max 255^2).
                    let x: i32 = (*x).into();
                    x * x
                })
                .sum::<i32>() as f32
        }
    }

    // For testing the fast L2 norm, we are less concerned about numerical accuracy and more
    // that the right sequence of operations are being performed.
    //
    // To that end, try to keep the input distribution "nice" to avoid dealing with rounding
    // issues.
    //
    // Checks, for every dimension in `0..max_dim` and `num_trials` random fills:
    //   * `FastL2NormSquared` exactly matches the scalar reference, and
    //   * `FastL2Norm` equals the square root of the squared norm.
    fn test_fast_l2_norm<T>(generator: &mut dyn FnMut(&mut [T]), max_dim: usize, num_trials: usize)
    where
        T: Copy + Default + std::fmt::Debug,
        for<'a> &'a [T]: ReferenceL2NormSquared,
        FastL2NormSquared: for<'a> Norm<&'a [T], f32>,
        FastL2Norm: for<'a> Norm<&'a [T], f32>,
    {
        for dim in 0..max_dim {
            let mut v = vec![T::default(); dim];
            for _ in 0..num_trials {
                // Generate the test case.
                generator(&mut v);
                let reference = v.reference_l2_norm_squared();
                let fast = (FastL2NormSquared).evaluate(&*v);

                // We should keep the distribution nice enough that this is exact.
                assert_eq!(reference, fast, "failed on dim {} with input: {:?}", dim, v);

                let norm = (FastL2Norm).evaluate(&*v);
                assert_eq!(
                    norm,
                    fast.sqrt(),
                    "failed on dim {} with input: {:?}",
                    dim,
                    v
                );
            }
        }
    }

    const MAX_DIM: usize = 256;
    // Miri is orders of magnitude slower than native execution, so run fewer trials there.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const NUM_TRIALS: usize = 1;
        } else {
            const NUM_TRIALS: usize = 16;
        }
    }

    #[test]
    fn test_fast_l2_norm_f32() {
        let mut rng = StdRng::seed_from_u64(0x4033f5b85e3513f3);
        // Small integer-valued f32s keep all arithmetic exact.
        let distribution = Uniform::<i64>::new(-16, 16).unwrap();
        let mut generator = |v: &mut [f32]| {
            v.iter_mut().for_each(|v| {
                *v = distribution.sample(&mut rng) as f32;
            });
        };
        test_fast_l2_norm(&mut generator, MAX_DIM, NUM_TRIALS);
    }

    #[test]
    fn test_fast_l2_norm_f16() {
        let mut rng = StdRng::seed_from_u64(0xfb0cf009aaa309f8);
        // Small integers are exactly representable in f16, keeping results exact.
        let distribution = Uniform::<i64>::new(-16, 16).unwrap();
        let mut generator = |v: &mut [Half]| {
            v.iter_mut().for_each(|v| {
                *v = Half::from_f32(distribution.sample(&mut rng) as f32);
            });
        };
        test_fast_l2_norm(&mut generator, MAX_DIM, NUM_TRIALS);
    }

    #[test]
    fn test_fast_l2_norm_u8() {
        let mut rng = StdRng::seed_from_u64(0xa119d2f91656ae35);
        // Full u8 range: integer accumulation is exact.
        let distribution = StandardUniform {};
        let mut generator = |v: &mut [u8]| {
            v.iter_mut().for_each(|v| {
                *v = distribution.sample(&mut rng);
            });
        };
        test_fast_l2_norm(&mut generator, MAX_DIM, NUM_TRIALS);
    }

    #[test]
    fn test_fast_l2_norm_i8() {
        let mut rng = StdRng::seed_from_u64(0x9d96fbf7c321886d);
        // Full i8 range: integer accumulation is exact.
        let distribution = StandardUniform {};
        let mut generator = |v: &mut [i8]| {
            v.iter_mut().for_each(|v| {
                *v = distribution.sample(&mut rng);
            });
        };
        test_fast_l2_norm(&mut generator, MAX_DIM, NUM_TRIALS);
    }

    #[test]
    fn test_linf_norm_f16() {
        let mut rng = StdRng::seed_from_u64(0xfb0cf009aaa309f8);
        let distribution = Uniform::<i64>::new(-16, 16).unwrap();
        let mut generator = |v: &mut [Half]| {
            v.iter_mut().for_each(|v| {
                *v = Half::from_f32(distribution.sample(&mut rng) as f32);
            });
        };

        for dim in 0..MAX_DIM {
            let mut dst = vec![Half::default(); dim];
            for _ in 0..NUM_TRIALS {
                generator(&mut dst);
                let got = (LInfNorm).evaluate(&*dst);
                // Scalar reference: widen, abs, then max-fold (0.0 for empty input).
                let expected = dst
                    .iter()
                    .map(|v| diskann_wide::cast_f16_to_f32(*v).abs())
                    .fold(0.0f32, f32::max);

                assert_eq!(
                    got, expected,
                    "LInf(f16) expected {}, got {} - dim {}",
                    expected, got, dim
                );
            }
        }
    }

    #[test]
    fn test_linf_norm_f32() {
        let mut rng = StdRng::seed_from_u64(0x4033f5b85e3513f3);
        let distribution = Uniform::<i64>::new(-16, 16).unwrap();
        let mut generator = |v: &mut [f32]| {
            v.iter_mut().for_each(|v| {
                *v = distribution.sample(&mut rng) as f32;
            });
        };

        for dim in 0..MAX_DIM {
            let mut dst = vec![f32::default(); dim];
            for _ in 0..NUM_TRIALS {
                generator(&mut dst);
                let got = (LInfNorm).evaluate(&*dst);
                // Scalar reference: abs then max-fold (0.0 for empty input).
                let expected = dst.iter().map(|v| v.abs()).fold(0.0f32, f32::max);

                assert_eq!(
                    got, expected,
                    "LInf(f32) expected {}, got {} - dim {}",
                    expected, got, dim
                );
            }
        }
    }
}