/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use std::convert::{AsMut, AsRef};

use diskann_wide::{arch::Target2, Architecture, Const, Constant, SIMDCast, SIMDVector};
use half::f16;

/// Perform a numeric cast on a slice of values.
///
/// This trait is intended to have the following numerical behavior:
///
/// 1. If a lossless conversion between types is available, use that.
/// 2. Otherwise, if the two types are floating point types, use a round-to-nearest strategy.
/// 3. Otherwise, try to behave like the Rust `as` numeric cast.
///
/// The main reason we can't just say "behave like `as`" is because Rust does not have
/// a native `f16` type, which this crate supports.
pub trait CastFromSlice<From> {
    /// Cast each element of `from` into the corresponding position of `self`.
    fn cast_from_slice(self, from: From);
}

/// Generate [`CastFromSlice`] implementations (for both slices and fixed-size arrays)
/// that dispatch through [`SliceCast`] using `diskann_wide::ARCH`.
macro_rules! use_simd_cast_from_slice {
    ($from:ty => $to:ty) => {
        // Slice-to-slice conversion.
        impl CastFromSlice<&[$from]> for &mut [$to] {
            #[inline(always)]
            fn cast_from_slice(self, from: &[$from]) {
                SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
            }
        }

        // Array-to-array conversion. The shared const parameter `N` guarantees at
        // compile time that source and destination have equal lengths.
        impl<const N: usize> CastFromSlice<&[$from; N]> for &mut [$to; N] {
            #[inline(always)]
            fn cast_from_slice(self, from: &[$from; N]) {
                SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
            }
        }
    };
}

// The two supported conversions: `f32 -> f16` (narrowing) and `f16 -> f32` (widening).
use_simd_cast_from_slice!(f32 => f16);
use_simd_cast_from_slice!(f16 => f32);

/// A zero-sized dispatcher providing implementations of [`diskann_wide::arch::Target2`]
/// for platform-dependent conversions between slices of the two generic types.
#[derive(Debug, Default, Clone, Copy)]
pub struct SliceCast<To, From> {
    // Zero-sized marker binding the type parameters without storing any data.
    _marker: std::marker::PhantomData<(To, From)>,
}

impl<To, From> SliceCast<To, From> {
    /// Construct the (zero-sized) dispatcher.
    pub fn new() -> Self {
        Self { _marker: std::marker::PhantomData }
    }
}

61// Non-SIMD Instantiations
62
63impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f16, f32>
64where
65    T: AsMut<[f16]>,
66    U: AsRef<[f32]>,
67{
68    #[inline(always)]
69    fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
70        let to = to.as_mut();
71        let from = from.as_ref();
72        std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
73            *to = diskann_wide::cast_f32_to_f16(*from);
74        })
75    }
76}
77
78impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f32, f16>
79where
80    T: AsMut<[f32]>,
81    U: AsRef<[f16]>,
82{
83    #[inline(always)]
84    fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
85        let to = to.as_mut();
86        let from = from.as_ref();
87        std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
88            *to = diskann_wide::cast_f16_to_f32(*from);
89        })
90    }
91}
92
// SIMD Instantiations

#[cfg(target_arch = "x86_64")]
impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V4, (), T, U> for SliceCast<To, From>
where
    T: AsMut<[To]>,
    U: AsRef<[From]>,
    // Restrict to type pairs with a SIMD conversion kernel for this architecture.
    diskann_wide::arch::x86_64::V4: SIMDConvert<To, From>,
{
    // Forward to the shared unrolled conversion loop.
    #[inline(always)]
    fn run(self, arch: diskann_wide::arch::x86_64::V4, mut to: T, from: U) {
        simd_convert(arch, to.as_mut(), from.as_ref())
    }
}

#[cfg(target_arch = "x86_64")]
impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V3, (), T, U> for SliceCast<To, From>
where
    T: AsMut<[To]>,
    U: AsRef<[From]>,
    // Restrict to type pairs with a SIMD conversion kernel for this architecture.
    diskann_wide::arch::x86_64::V3: SIMDConvert<To, From>,
{
    // Forward to the shared unrolled conversion loop.
    #[inline(always)]
    fn run(self, arch: diskann_wide::arch::x86_64::V3, mut to: T, from: U) {
        simd_convert(arch, to.as_mut(), from.as_ref())
    }
}

#[cfg(target_arch = "aarch64")]
impl<T, U, To, From> Target2<diskann_wide::arch::aarch64::Neon, (), T, U> for SliceCast<To, From>
where
    T: AsMut<[To]>,
    U: AsRef<[From]>,
    // Restrict to type pairs with a SIMD conversion kernel for this architecture.
    diskann_wide::arch::aarch64::Neon: SIMDConvert<To, From>,
{
    // Forward to the shared unrolled conversion loop.
    #[inline(always)]
    fn run(self, arch: diskann_wide::arch::aarch64::Neon, mut to: T, from: U) {
        simd_convert(arch, to.as_mut(), from.as_ref())
    }
}

/////////////////////////////
// General SIMD Conversion //
/////////////////////////////

/// A helper trait to fill in the gaps for the unrolled `simd_convert` method.
trait SIMDConvert<To, From>: Architecture {
    /// A constant encoding the SIMD width of the underlying schema.
    type Width: Constant<Type = usize>;

    /// The SIMD Vector for the converted-to type.
    type WideTo: SIMDVector<Arch = Self, Scalar = To, ConstLanes = Self::Width>;

    /// The SIMD Vector for the converted-from type.
    type WideFrom: SIMDVector<Arch = Self, Scalar = From, ConstLanes = Self::Width>;

    /// The method that actually does the vector-wide conversion.
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo;

    /// Delegate routing for handling conversion lengths less than the vector width.
    ///
    /// The canonical implementation uses predicated loads, but implementations may wish
    /// to use a scalar loop instead.
    ///
    /// # Safety
    ///
    /// This trait will only be called when the following guarantees are made:
    ///
    /// * `pto` will point to properly aligned memory that is valid for writes on the
    ///   range `[pto, pto + len)`.
    /// * `pfrom` will point to properly aligned memory that is valid for reads on the
    ///   range `[pfrom, pfrom + len)`.
    /// * The memory ranges covered by `pto` and `pfrom` must not alias.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut To, pfrom: *const From, len: usize) {
        // Predicated ("first `len` lanes") load/store keeps everything in-bounds.
        let from = Self::WideFrom::load_simd_first(self, pfrom, len);
        let to = Self::simd_convert(from);
        to.store_simd_first(pto, len);
    }

    /// !! Do not extend this function !!
    ///
    /// Due to limitations on how associated constants can be used, we need a function
    /// to access the SIMD width and rely on the compiler to constant propagate the result.
    #[inline(always)]
    fn get_simd_width() -> usize {
        Self::Width::value()
    }
}

/// Panic with a diagnostic message for mismatched slice lengths.
///
/// This is explicitly outlined (`#[inline(never)]`) so the panic-formatting machinery
/// stays out of the hot conversion loop; `#[cold]` additionally tells the compiler the
/// call is unlikely, improving branch layout at the call site.
#[inline(never)]
#[cold]
#[allow(clippy::panic)]
fn emit_length_error(xlen: usize, ylen: usize) -> ! {
    panic!(
        "lengths must be equal, instead got: xlen = {}, ylen = {}",
        xlen, ylen
    )
}

/// Convert each element of `from` into its corresponding position in `to` using the
/// conversion rule applied by `A`.
///
/// # Panics
///
/// Panics if `to.len() != from.len()`.
///
/// # Implementation Notes
///
/// This function will only call `A::handle_small` if the total length of the processed
/// slices is less than the underlying SIMD width.
///
/// Otherwise, we take advantage of unaligned operations to avoid dealing with
/// non-full-width chunks.
///
/// For example, if the SIMD width was 4 and the total length was 7, then it would be
/// processed in two chunks of 4 like so:
/// ```text
///      Chunk 0
/// |---------------|
///   0   1   2   3   4   5   6
///             |---------------|
///                   Chunk 1
/// ```
/// This overlapping can only happen at the very end of the slice and only if the length
/// of the slice is not a multiple of the SIMD width used.
#[inline(always)]
fn simd_convert<A, To, From>(arch: A, to: &mut [To], from: &[From])
where
    A: SIMDConvert<To, From>,
{
    let len = to.len();

    // Keep stack writes to a minimum by explicitly outlining error handling.
    if len != from.len() {
        emit_length_error(len, from.len())
    }

    // Get the SIMD width.
    //
    // We're relying on the compiler to constant propagate this.
    let width = A::get_simd_width();

    let pto = to.as_mut_ptr();
    let pfrom = from.as_ptr();

    // Too short, deal with the small case and return.
    if len < width {
        // SAFETY: We know `pto` and `pfrom` do not alias because of Rust's aliasing
        // rules on `to` and `from`.
        //
        // Additionally, we've checked that both spans are valid for `len`.
        unsafe { arch.handle_small(pto, pfrom, len) };
        return;
    }

    // Number of SIMD vectors processed per iteration of the main loop below.
    const UNROLL: usize = 8;

    let mut i = 0;
    // SAFETY: We emit a bunch of unrolled load and store operations in this loop.
    //
    // All of these operations are safe because the bound `i + UNROLL * width <= len`
    // is checked.
    unsafe {
        while i + UNROLL * width <= len {
            let s0 = A::WideFrom::load_simd(arch, pfrom.add(i));
            A::simd_convert(s0).store_simd(pto.add(i));

            let s1 = A::WideFrom::load_simd(arch, pfrom.add(i + width));
            A::simd_convert(s1).store_simd(pto.add(i + width));

            let s2 = A::WideFrom::load_simd(arch, pfrom.add(i + 2 * width));
            A::simd_convert(s2).store_simd(pto.add(i + 2 * width));

            let s3 = A::WideFrom::load_simd(arch, pfrom.add(i + 3 * width));
            A::simd_convert(s3).store_simd(pto.add(i + 3 * width));

            let s0 = A::WideFrom::load_simd(arch, pfrom.add(i + 4 * width));
            A::simd_convert(s0).store_simd(pto.add(i + 4 * width));

            let s1 = A::WideFrom::load_simd(arch, pfrom.add(i + 5 * width));
            A::simd_convert(s1).store_simd(pto.add(i + 5 * width));

            let s2 = A::WideFrom::load_simd(arch, pfrom.add(i + 6 * width));
            A::simd_convert(s2).store_simd(pto.add(i + 6 * width));

            let s3 = A::WideFrom::load_simd(arch, pfrom.add(i + 7 * width));
            A::simd_convert(s3).store_simd(pto.add(i + 7 * width));

            i += UNROLL * width;
        }
    }

    // Drain any remaining full-width chunks, one vector at a time.
    while i + width <= len {
        // SAFETY: `i + width <= len` ensures that this read is in-bounds.
        let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(i)) };
        let t0 = A::simd_convert(s0);
        // SAFETY: `i + width <= len` ensures that this write is in-bounds.
        unsafe { t0.store_simd(pto.add(i)) };
        i += width;
    }

    // Check if we need to deal with any remaining elements.
    // If so, bump back `i` so we can process a whole chunk.
    //
    // The final chunk may overlap the previous one (see the diagram above); the
    // overlap is harmless because overlapping destination lanes receive the same
    // converted values from the same source elements.
    if i != len {
        // Since `len - i == len % width` at this point, this simplifies to
        // `offset = len - width`.
        let offset = i - (width - len % width);

        // SAFETY: At this point, we know that `len >= width`, `i < len`, and
        // `len - i == len % width != 0`.
        //
        // Therefore, `offset` is inbounds and `offset + width == len`.
        let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(offset)) };
        let t0 = A::simd_convert(s0);

        // SAFETY: This write is safe for the same reason that the preceding read is safe.
        unsafe { t0.store_simd(pto.add(offset)) };
    }
}

#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V4 {
    // Eight lanes per vector for this kernel.
    type Width = Const<8>;
    type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;
    type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;

    // Widening `f16 -> f32` is lossless, so a plain `From`-style conversion is used.
    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.into()
    }

    // Scalar loop instead of the default predicated-load tail handler.
    // NOTE(review): likely chosen to keep `f16` intrinsics (e.g. `cvtph_ps`) out of
    // paths Miri executes — confirm against the Miri tests below.
    //
    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
        }
    }
}

#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V3 {
    // Eight lanes per vector for this kernel.
    type Width = Const<8>;
    type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;
    type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;

    // Widening `f16 -> f32` is lossless, so a plain `From`-style conversion is used.
    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.into()
    }

    // Scalar loop instead of the default predicated-load tail handler.
    // NOTE(review): likely chosen to keep `f16` intrinsics (e.g. `cvtph_ps`) out of
    // paths Miri executes — confirm against the Miri tests below.
    //
    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
        }
    }
}

#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V4 {
    // Eight lanes per vector for this kernel.
    type Width = Const<8>;
    type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;
    type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;

    // Narrowing `f32 -> f16` is lossy; `SIMDCast` supplies the rounding conversion
    // (round-to-nearest per the `CastFromSlice` contract — confirm in `diskann_wide`).
    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.simd_cast()
    }

    // Scalar loop instead of the default predicated-load tail handler.
    // NOTE(review): likely chosen to keep `f16` intrinsics out of paths Miri
    // executes — confirm against the Miri tests below.
    //
    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
        }
    }
}

#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V3 {
    // Eight lanes per vector for this kernel.
    type Width = Const<8>;
    type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;
    type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;

    // Narrowing `f32 -> f16` is lossy; `SIMDCast` supplies the rounding conversion
    // (round-to-nearest per the `CastFromSlice` contract — confirm in `diskann_wide`).
    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.simd_cast()
    }

    // Scalar loop instead of the default predicated-load tail handler.
    // NOTE(review): likely chosen to keep `f16` intrinsics out of paths Miri
    // executes — confirm against the Miri tests below.
    //
    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
        }
    }
}

//---------//
// Aarch64 //
//---------//

#[cfg(target_arch = "aarch64")]
impl SIMDConvert<f32, f16> for diskann_wide::arch::aarch64::Neon {
    // Four lanes per vector for this kernel.
    type Width = Const<4>;
    type WideTo = <diskann_wide::arch::aarch64::Neon as Architecture>::f32x4;
    // NOTE(review): `f16x4` is named directly rather than through the `Architecture`
    // trait — presumably the trait does not expose an `f16x4` associated type.
    type WideFrom = diskann_wide::arch::aarch64::f16x4;

    // Widening `f16 -> f32` is lossless, so a plain `From`-style conversion is used.
    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.into()
    }

    // Scalar loop instead of the default predicated-load tail handler.
    //
    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
        }
    }
}

#[cfg(target_arch = "aarch64")]
impl SIMDConvert<f16, f32> for diskann_wide::arch::aarch64::Neon {
    // Four lanes per vector for this kernel.
    type Width = Const<4>;
    // NOTE(review): `f16x4` is named directly rather than through the `Architecture`
    // trait — presumably the trait does not expose an `f16x4` associated type.
    type WideTo = diskann_wide::arch::aarch64::f16x4;
    type WideFrom = <diskann_wide::arch::aarch64::Neon as Architecture>::f32x4;

    // Narrowing `f32 -> f16` is lossy; `SIMDCast` supplies the rounding conversion
    // (round-to-nearest per the `CastFromSlice` contract — confirm in `diskann_wide`).
    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.simd_cast()
    }

    // Scalar loop instead of the default predicated-load tail handler.
    //
    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
        }
    }
}

///////////
// Tests //
///////////

#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;

    ////////////////
    // Fuzz tests //
    ////////////////

    /// An obviously-correct scalar conversion used as the ground truth that the
    /// SIMD implementations are fuzzed against.
    trait ReferenceConvert<From> {
        fn reference_convert(self, from: &[From]);
    }

    impl ReferenceConvert<f32> for &mut [f16] {
        fn reference_convert(self, from: &[f32]) {
            assert_eq!(self.len(), from.len());
            // Element-wise narrowing via the `half` crate.
            std::iter::zip(self.iter_mut(), from.iter()).for_each(|(d, s)| *d = f16::from_f32(*s));
        }
    }

    impl ReferenceConvert<f16> for &mut [f32] {
        fn reference_convert(self, from: &[f16]) {
            assert_eq!(self.len(), from.len());
            // Element-wise widening; `f16 -> f32` is lossless.
            std::iter::zip(self.iter_mut(), from.iter()).for_each(|(d, s)| *d = (*s).into());
        }
    }

    /// Fuzz `CastFromSlice` against the scalar reference for every length in
    /// `0..=max_dim`, with `num_trials` random inputs per length.
    ///
    /// Sweeping every length exercises the small-case, unrolled, single-vector,
    /// and overlapping-tail paths of `simd_convert`.
    fn test_cast_from_slice<To, From>(max_dim: usize, num_trials: usize, rng: &mut StdRng)
    where
        StandardUniform: Distribution<From>,
        To: Default + PartialEq + std::fmt::Debug + Copy,
        From: Default + Copy,
        for<'a, 'b> &'a mut [To]: CastFromSlice<&'b [From]> + ReferenceConvert<From>,
    {
        let distribution = StandardUniform {};
        for dim in 0..=max_dim {
            let mut src = vec![From::default(); dim];
            let mut dst = vec![To::default(); dim];
            let mut dst_reference = vec![To::default(); dim];

            for _ in 0..num_trials {
                // Fresh random input, converted both ways, must agree exactly.
                src.iter_mut().for_each(|s| *s = distribution.sample(rng));
                dst.cast_from_slice(src.as_slice());
                dst_reference.reference_convert(&src);

                assert_eq!(dst, dst_reference);
            }
        }
    }

    #[test]
    fn test_f32_to_f16_fuzz() {
        // Fixed seed for reproducibility.
        let mut rng = StdRng::seed_from_u64(0x0a3bfe052a8ebf98);
        test_cast_from_slice::<f16, f32>(256, 10, &mut rng);
    }

    #[test]
    fn test_f16_to_f32_fuzz() {
        // Fixed seed for reproducibility.
        let mut rng = StdRng::seed_from_u64(0x83765b2816321eca);
        test_cast_from_slice::<f32, f16>(256, 10, &mut rng);
    }

    ////////////////
    // Miri Tests //
    ////////////////

    // With Miri, we need to be really careful that we do not hit methods that call
    // `cvtph_ps`...
    #[test]
    fn miri_test_f32_to_f16() {
        // Sweep every length so all code paths (small case, unrolled loop, tail)
        // run under Miri's pointer-provenance and bounds checking.
        for dim in 0..256 {
            println!("processing dim {}", dim);

            let src = vec![f32::default(); dim];
            let mut dst = vec![f16::default(); dim];

            // Just check that all accesses are inbounds.
            dst.cast_from_slice(&src);

            // Scalar conversion.
            SliceCast::<f16, f32>::new().run(
                diskann_wide::arch::Scalar,
                dst.as_mut_slice(),
                src.as_slice(),
            );

            // SIMD conversion
            #[cfg(target_arch = "x86_64")]
            {
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }

                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }
            }

            #[cfg(target_arch = "aarch64")]
            if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
            }
        }
    }

    #[test]
    fn miri_test_f16_to_f32() {
        // Mirror of `miri_test_f32_to_f16` for the widening direction.
        for dim in 0..256 {
            println!("processing dim {}", dim);

            let src = vec![f16::default(); dim];
            let mut dst = vec![f32::default(); dim];

            // Just check that all accesses are inbounds.
            dst.cast_from_slice(&src);

            // Scalar conversion.
            SliceCast::<f32, f16>::new().run(
                diskann_wide::arch::Scalar,
                dst.as_mut_slice(),
                src.as_slice(),
            );

            // SIMD conversion
            #[cfg(target_arch = "x86_64")]
            {
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }

                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }
            }

            #[cfg(target_arch = "aarch64")]
            if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
            }
        }
    }
}