1use diskann_wide::{
7 arch::{Target1, Target2},
8 Architecture,
9};
10
11use crate::{
12 distance::{implementations::L1NormFunctor, InnerProduct},
13 Half, MathematicalValue, Norm,
14};
15
16#[derive(Debug, Clone, Copy)]
23pub struct FastL2NormSquared;
24
25impl<T, To> Norm<T, To> for FastL2NormSquared
26where
27 Self: Target1<diskann_wide::arch::Current, To, T>,
28 T: Copy,
29 To: Copy,
30{
31 #[inline]
32 fn evaluate(&self, x: T) -> To {
33 self.run(diskann_wide::ARCH, x)
42 }
43}
44
45impl<A, T, To> Target1<A, To, T> for FastL2NormSquared
46where
47 A: Architecture,
48 InnerProduct: Target2<A, MathematicalValue<To>, T, T>,
49 T: Copy,
50 To: Copy,
51{
52 #[inline(always)]
53 fn run(self, arch: A, x: T) -> To {
54 (InnerProduct {}).run(arch, x, x).into_inner()
55 }
56}
57
58#[derive(Debug, Clone, Copy)]
65pub struct FastL2Norm;
66
67impl<T> Norm<T, f32> for FastL2Norm
68where
69 Self: Target1<diskann_wide::arch::Current, f32, T>,
70{
71 #[inline]
72 fn evaluate(&self, x: T) -> f32 {
73 self.run(diskann_wide::ARCH, x)
74 }
75}
76
77impl<A, T> Target1<A, f32, T> for FastL2Norm
78where
79 A: Architecture,
80 FastL2NormSquared: Target1<A, f32, T>,
81 T: Copy,
82{
83 #[inline(always)]
84 fn run(self, arch: A, x: T) -> f32 {
85 (FastL2NormSquared).run(arch, x).sqrt()
86 }
87}
88
89#[derive(Debug, Clone, Copy)]
104pub struct L1Norm;
105
106impl<T> Norm<T, f32> for L1Norm
107where
108 Self: Target1<diskann_wide::arch::Current, f32, T>,
109{
110 #[inline]
111 fn evaluate(&self, x: T) -> f32 {
112 self.run(diskann_wide::ARCH, x)
113 }
114}
115
116impl<A, T, To> Target1<A, To, T> for L1Norm
117where
118 A: Architecture,
119 L1NormFunctor: Target2<A, To, T, T>,
120 T: Copy,
121 To: Copy,
122{
123 #[inline(always)]
124 fn run(self, arch: A, x: T) -> To {
125 (L1NormFunctor {}).run(arch, x, x)
126 }
127}
128
129#[derive(Debug, Clone, Copy)]
149pub struct LInfNorm;
150
151impl Norm<&[f32], f32> for LInfNorm {
152 #[inline]
153 fn evaluate(&self, x: &[f32]) -> f32 {
154 self.run(diskann_wide::ARCH, x)
155 }
156}
157
158impl Norm<&[Half], f32> for LInfNorm {
159 #[inline]
160 fn evaluate(&self, x: &[Half]) -> f32 {
161 self.run(diskann_wide::ARCH, x)
162 }
163}
164
165impl<A> Target1<A, f32, &[f32]> for LInfNorm
166where
167 A: Architecture,
168{
169 #[inline(always)]
170 fn run(self, _: A, x: &[f32]) -> f32 {
171 let mut m = 0.0f32;
172 for &v in x {
173 m = m.max(v.abs());
174 }
175 m
176 }
177}
178
179impl<A> Target1<A, f32, &[Half]> for LInfNorm
180where
181 A: Architecture,
182{
183 #[inline(always)]
184 fn run(self, _: A, x: &[Half]) -> f32 {
185 let mut m = 0.0f32;
186 for &v in x {
187 m = m.max(diskann_wide::cast_f16_to_f32(v).abs());
188 }
189 m
190 }
191}
192
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform, Uniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;
    use crate::Half;

    /// Exclusive upper bound on the dimensions swept by every test below.
    const MAX_DIM: usize = 256;

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // Keep Miri runs short: a single random fill per dimension.
            const NUM_TRIALS: usize = 1;
        } else {
            const NUM_TRIALS: usize = 16;
        }
    }

    /// Straightforward scalar computation of `sum_i x_i^2`, used as ground truth for
    /// [`FastL2NormSquared`].
    trait ReferenceL2NormSquared {
        fn reference_l2_norm_squared(self) -> f32;
    }

    impl ReferenceL2NormSquared for &[f32] {
        fn reference_l2_norm_squared(self) -> f32 {
            self.iter().map(|&x| x * x).sum()
        }
    }

    impl ReferenceL2NormSquared for &[Half] {
        fn reference_l2_norm_squared(self) -> f32 {
            self.iter().map(|h| h.to_f32()).map(|x| x * x).sum()
        }
    }

    impl ReferenceL2NormSquared for &[i8] {
        fn reference_l2_norm_squared(self) -> f32 {
            // Accumulate in i32 so the reference is exact before the final conversion.
            let total: i32 = self.iter().map(|&x| i32::from(x) * i32::from(x)).sum();
            total as f32
        }
    }

    impl ReferenceL2NormSquared for &[u8] {
        fn reference_l2_norm_squared(self) -> f32 {
            // Accumulate in i32 so the reference is exact before the final conversion.
            let total: i32 = self.iter().map(|&x| i32::from(x) * i32::from(x)).sum();
            total as f32
        }
    }

    /// For every dimension in `0..max_dim`, refills a vector `num_trials` times with
    /// `generator` and checks that:
    ///
    /// * [`FastL2NormSquared`] equals the scalar reference exactly, and
    /// * [`FastL2Norm`] equals the square root of that squared norm.
    ///
    /// Exact comparison is used, so generators are expected to produce inputs whose
    /// `f32` accumulation is exact (the generators in this module use small integers).
    fn test_fast_l2_norm<T>(generator: &mut dyn FnMut(&mut [T]), max_dim: usize, num_trials: usize)
    where
        T: Copy + Default + std::fmt::Debug,
        for<'a> &'a [T]: ReferenceL2NormSquared,
        FastL2NormSquared: for<'a> Norm<&'a [T], f32>,
        FastL2Norm: for<'a> Norm<&'a [T], f32>,
    {
        for dim in 0..max_dim {
            let mut data = vec![T::default(); dim];
            for _ in 0..num_trials {
                generator(&mut data);
                let slice: &[T] = &data;

                let expected = slice.reference_l2_norm_squared();
                let squared = FastL2NormSquared.evaluate(slice);
                assert_eq!(
                    expected,
                    squared,
                    "failed on dim {} with input: {:?}",
                    dim,
                    data
                );

                let norm = FastL2Norm.evaluate(slice);
                assert_eq!(
                    norm,
                    squared.sqrt(),
                    "failed on dim {} with input: {:?}",
                    dim,
                    data
                );
            }
        }
    }

    #[test]
    fn test_fast_l2_norm_f32() {
        let mut rng = StdRng::seed_from_u64(0x4033f5b85e3513f3);
        let distribution = Uniform::<i64>::new(-16, 16).unwrap();
        let mut generator = |values: &mut [f32]| {
            for slot in values.iter_mut() {
                *slot = distribution.sample(&mut rng) as f32;
            }
        };
        test_fast_l2_norm(&mut generator, MAX_DIM, NUM_TRIALS);
    }

    #[test]
    fn test_fast_l2_norm_f16() {
        let mut rng = StdRng::seed_from_u64(0xfb0cf009aaa309f8);
        let distribution = Uniform::<i64>::new(-16, 16).unwrap();
        let mut generator = |values: &mut [Half]| {
            for slot in values.iter_mut() {
                *slot = Half::from_f32(distribution.sample(&mut rng) as f32);
            }
        };
        test_fast_l2_norm(&mut generator, MAX_DIM, NUM_TRIALS);
    }

    #[test]
    fn test_fast_l2_norm_u8() {
        let mut rng = StdRng::seed_from_u64(0xa119d2f91656ae35);
        let distribution = StandardUniform {};
        let mut generator = |values: &mut [u8]| {
            for slot in values.iter_mut() {
                *slot = distribution.sample(&mut rng);
            }
        };
        test_fast_l2_norm(&mut generator, MAX_DIM, NUM_TRIALS);
    }

    #[test]
    fn test_fast_l2_norm_i8() {
        let mut rng = StdRng::seed_from_u64(0x9d96fbf7c321886d);
        let distribution = StandardUniform {};
        let mut generator = |values: &mut [i8]| {
            for slot in values.iter_mut() {
                *slot = distribution.sample(&mut rng);
            }
        };
        test_fast_l2_norm(&mut generator, MAX_DIM, NUM_TRIALS);
    }

    #[test]
    fn test_linf_norm_f16() {
        let mut rng = StdRng::seed_from_u64(0xfb0cf009aaa309f8);
        let distribution = Uniform::<i64>::new(-16, 16).unwrap();
        let mut generator = |values: &mut [Half]| {
            for slot in values.iter_mut() {
                *slot = Half::from_f32(distribution.sample(&mut rng) as f32);
            }
        };

        for dim in 0..MAX_DIM {
            let mut data = vec![Half::default(); dim];
            for _ in 0..NUM_TRIALS {
                generator(&mut data);
                let got = LInfNorm.evaluate(&*data);

                // Scalar reference: running maximum of widened absolute values, from 0.
                let mut expected = 0.0f32;
                for &h in &data {
                    expected = f32::max(expected, diskann_wide::cast_f16_to_f32(h).abs());
                }

                assert_eq!(
                    got, expected,
                    "LInf(f16) expected {}, got {} - dim {}",
                    expected, got, dim
                );
            }
        }
    }

    #[test]
    fn test_linf_norm_f32() {
        let mut rng = StdRng::seed_from_u64(0x4033f5b85e3513f3);
        let distribution = Uniform::<i64>::new(-16, 16).unwrap();
        let mut generator = |values: &mut [f32]| {
            for slot in values.iter_mut() {
                *slot = distribution.sample(&mut rng) as f32;
            }
        };

        for dim in 0..MAX_DIM {
            let mut data = vec![f32::default(); dim];
            for _ in 0..NUM_TRIALS {
                generator(&mut data);
                let got = LInfNorm.evaluate(&*data);

                // Scalar reference: running maximum of absolute values, from 0.
                let mut expected = 0.0f32;
                for &x in &data {
                    expected = f32::max(expected, x.abs());
                }

                assert_eq!(
                    got, expected,
                    "LInf(f32) expected {}, got {} - dim {}",
                    expected, got, dim
                );
            }
        }
    }
}