Skip to main content

diskann_vector/
conversion.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::convert::{AsMut, AsRef};
7
8use diskann_wide::arch::Target2;
9#[cfg(not(target_arch = "aarch64"))]
10use diskann_wide::{Architecture, Const, Constant, SIMDCast, SIMDVector};
11use half::f16;
12
/// Perform a numeric cast on a slice of values.
///
/// `self` is the destination that receives the converted values and `from` is the
/// source they are read from. Pass a source and destination of equal length: the
/// x86_64 SIMD path in this file panics on a length mismatch.
///
/// This trait is intended to have the following numerical behavior:
///
/// 1. If a lossless conversion between types is available, use that.
/// 2. Otherwise, if the two types are floating point types, use a round-to-nearest strategy.
/// 3. Otherwise, try to behave like the Rust `as` numeric cast.
///
/// The main reason we can't just say "behave like `as`" is because Rust does not have
/// a native `f16` type, which this crate supports.
pub trait CastFromSlice<From> {
    /// Convert every element of `from` and store the results into `self`.
    fn cast_from_slice(self, from: From);
}
26
27macro_rules! use_simd_cast_from_slice {
28    ($from:ty => $to:ty) => {
29        impl CastFromSlice<&[$from]> for &mut [$to] {
30            #[inline(always)]
31            fn cast_from_slice(self, from: &[$from]) {
32                SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
33            }
34        }
35
36        impl<const N: usize> CastFromSlice<&[$from; N]> for &mut [$to; N] {
37            #[inline(always)]
38            fn cast_from_slice(self, from: &[$from; N]) {
39                SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
40            }
41        }
42    };
43}
44
45use_simd_cast_from_slice!(f32 => f16);
46use_simd_cast_from_slice!(f16 => f32);
47
/// Zero-sized dispatch tag for slice conversions from `From` elements to `To` elements.
///
/// The type carries no data; it exists so that [`diskann_wide::arch::Target2`] can be
/// implemented once per architecture and element-type pair.
#[derive(Debug, Default, Clone, Copy)]
pub struct SliceCast<To, From> {
    /// Ties the otherwise unused type parameters to the struct.
    _marker: std::marker::PhantomData<(To, From)>,
}

impl<To, From> SliceCast<To, From> {
    /// Create the conversion tag. Unlike the derived `Default`, this places no trait
    /// bounds on `To` or `From`.
    pub fn new() -> Self {
        let _marker = std::marker::PhantomData::<(To, From)>;
        Self { _marker }
    }
}
62
63// Non-SIMD Instantiations
64
65impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f16, f32>
66where
67    T: AsMut<[f16]>,
68    U: AsRef<[f32]>,
69{
70    #[inline(always)]
71    fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
72        let to = to.as_mut();
73        let from = from.as_ref();
74        std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
75            *to = diskann_wide::cast_f32_to_f16(*from);
76        })
77    }
78}
79
80impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f32, f16>
81where
82    T: AsMut<[f32]>,
83    U: AsRef<[f16]>,
84{
85    #[inline(always)]
86    fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
87        let to = to.as_mut();
88        let from = from.as_ref();
89        std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
90            *to = diskann_wide::cast_f16_to_f32(*from);
91        })
92    }
93}
94
95// SIMD Instantiations
96#[cfg(target_arch = "x86_64")]
97impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V4, (), T, U> for SliceCast<To, From>
98where
99    T: AsMut<[To]>,
100    U: AsRef<[From]>,
101    diskann_wide::arch::x86_64::V4: SIMDConvert<To, From>,
102{
103    #[inline(always)]
104    fn run(self, arch: diskann_wide::arch::x86_64::V4, mut to: T, from: U) {
105        simd_convert(arch, to.as_mut(), from.as_ref())
106    }
107}
108
109#[cfg(target_arch = "x86_64")]
110impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V3, (), T, U> for SliceCast<To, From>
111where
112    T: AsMut<[To]>,
113    U: AsRef<[From]>,
114    diskann_wide::arch::x86_64::V3: SIMDConvert<To, From>,
115{
116    #[inline(always)]
117    fn run(self, arch: diskann_wide::arch::x86_64::V3, mut to: T, from: U) {
118        simd_convert(arch, to.as_mut(), from.as_ref())
119    }
120}
121
122/////////////////////////////
123// General SIMD Conversion //
124/////////////////////////////
125
126/// A helper trait to fill in the gaps for the unrolled `simd_convert` method.
127#[cfg(target_arch = "x86_64")]
128trait SIMDConvert<To, From>: Architecture {
129    /// A constant encoding the the SIMD width of the underlying schema.
130    type Width: Constant<Type = usize>;
131
132    /// The SIMD Vector for the converted-to type.
133    type WideTo: SIMDVector<Arch = Self, Scalar = To, ConstLanes = Self::Width>;
134
135    /// The SIMD Vector for the converted-from type.
136    type WideFrom: SIMDVector<Arch = Self, Scalar = From, ConstLanes = Self::Width>;
137
138    /// The method that actually does the vector-wide conversion.
139    fn simd_convert(from: Self::WideFrom) -> Self::WideTo;
140
141    /// Delegate routing for handling conversion lengths less than the vector width.
142    ///
143    /// The canonical implementation uses predicated loads, but implementations may wish
144    /// to use a scalar loop instead.
145    ///
146    /// # Safety
147    ///
148    /// This trait will only be called when the following guarantees are made:
149    ///
150    /// * `pto` will point to properly aligned memory that is valid for writes on the
151    ///   range `[pto, pto + len)`.
152    /// * `pfrom` will point to properly aligned memory that is valid for reads on the
153    ///   range `[pfrom, pfrom + len)`.
154    /// * The memory ranges covered by `pto` and `pfrom` must not alias.
155    #[inline(always)]
156    unsafe fn handle_small(self, pto: *mut To, pfrom: *const From, len: usize) {
157        let from = Self::WideFrom::load_simd_first(self, pfrom, len);
158        let to = Self::simd_convert(from);
159        to.store_simd_first(pto, len);
160    }
161
162    /// !! Do not extend this function !!
163    ///
164    /// Due to limitations on how associated constants can be used, we need a function
165    /// to access the SIMD width and rely on the compiler to constant propagate the result.
166    #[inline(always)]
167    fn get_simd_width() -> usize {
168        Self::Width::value()
169    }
170}
171
/// Outlined panic used when the destination and source lengths disagree.
///
/// Kept out of line (`#[inline(never)]`) so the formatting machinery stays out of the
/// conversion routine that calls it. Never returns.
#[inline(never)]
#[allow(clippy::panic)]
#[cfg(target_arch = "x86_64")]
fn emit_length_error(xlen: usize, ylen: usize) -> ! {
    panic!("lengths must be equal, instead got: xlen = {}, ylen = {}", xlen, ylen)
}
181
182/// Convert each element of `from` into its corresponding position in `to` using the
183/// conversion rule applied by `S`.
184///
185/// # Panics
186///
187/// Panics if `to.len() != from.len()`.
188///
189/// # Implementation Notes
190///
191/// This function will only call `A::handle_small` if the total length of the processed
192/// slices is less that the underlying SIMD width.
193///
194/// Otherwise, we take advantage of unaligned operations to avoid dealing with
195/// non-full-width chunks.
196///
197/// For example, if the SIMD width was 4 and the total length was 7, then it would be
198/// processed in two chunks of 4 like so:
199/// ```text
200///      Chunk 0
201/// |---------------|
202///   0   1   2   3   4   5   6
203///             |---------------|
204///                   Chunk 1
205/// ```
206/// This overlapping can only happen at the very end of the slice and only if the length
207/// of the slice is not a multiple of the SIMD width used.
208#[inline(always)]
209#[cfg(target_arch = "x86_64")]
210fn simd_convert<A, To, From>(arch: A, to: &mut [To], from: &[From])
211where
212    A: SIMDConvert<To, From>,
213{
214    let len = to.len();
215
216    // Keep stack writes to a minimum by explicitly outlining error handling.
217    if len != from.len() {
218        emit_length_error(len, from.len())
219    }
220
221    // Get the SIMD width.
222    //
223    // We're relying on the compiler to constant propagate this.
224    let width = A::get_simd_width();
225
226    let pto = to.as_mut_ptr();
227    let pfrom = from.as_ptr();
228
229    // Too short, deal with the small case and return.
230    if len < width {
231        // SAFETY: We know `pto` and `pfrom` do not alias because of Rust's aliasing
232        // rules on `to` and `from.
233        //
234        // Additionally, we've checked that both spans are valid for `len`.
235        unsafe { arch.handle_small(pto, pfrom, len) };
236        return;
237    }
238
239    const UNROLL: usize = 8;
240
241    let mut i = 0;
242    // SAFETY: We emit a bunch of unrolled load and store operations in this loop.
243    //
244    // All of these operations are safe because the bound `i + UNROLL * width <= len`
245    // is checked.
246    unsafe {
247        while i + UNROLL * width <= len {
248            let s0 = A::WideFrom::load_simd(arch, pfrom.add(i));
249            A::simd_convert(s0).store_simd(pto.add(i));
250
251            let s1 = A::WideFrom::load_simd(arch, pfrom.add(i + width));
252            A::simd_convert(s1).store_simd(pto.add(i + width));
253
254            let s2 = A::WideFrom::load_simd(arch, pfrom.add(i + 2 * width));
255            A::simd_convert(s2).store_simd(pto.add(i + 2 * width));
256
257            let s3 = A::WideFrom::load_simd(arch, pfrom.add(i + 3 * width));
258            A::simd_convert(s3).store_simd(pto.add(i + 3 * width));
259
260            let s0 = A::WideFrom::load_simd(arch, pfrom.add(i + 4 * width));
261            A::simd_convert(s0).store_simd(pto.add(i + 4 * width));
262
263            let s1 = A::WideFrom::load_simd(arch, pfrom.add(i + 5 * width));
264            A::simd_convert(s1).store_simd(pto.add(i + 5 * width));
265
266            let s2 = A::WideFrom::load_simd(arch, pfrom.add(i + 6 * width));
267            A::simd_convert(s2).store_simd(pto.add(i + 6 * width));
268
269            let s3 = A::WideFrom::load_simd(arch, pfrom.add(i + 7 * width));
270            A::simd_convert(s3).store_simd(pto.add(i + 7 * width));
271
272            i += UNROLL * width;
273        }
274    }
275
276    while i + width <= len {
277        // SAFETY: `i + width <= len` ensure that this read is in-bounds.
278        let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(i)) };
279        let t0 = A::simd_convert(s0);
280        // SAFETY: `i + width <= len` ensure that this write is in-bounds.
281        unsafe { t0.store_simd(pto.add(i)) };
282        i += width;
283    }
284
285    // Check if we need to deal with any remaining elements.
286    // If so, bump back `i` so we can process a whole chunk.
287    if i != len {
288        let offset = i - (width - len % width);
289
290        // SAFETY: At this point, we know that `len >= width`, `i < len`, and
291        // `len - i == len % width != 0`.
292        //
293        // Therefore, `offset` is inbounds and `offset + width == len`.
294        let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(offset)) };
295        let t0 = A::simd_convert(s0);
296
297        // SAFETY: This write is safe for the same reason that the preceeding read is safe.
298        unsafe { t0.store_simd(pto.add(offset)) };
299    }
300}
301
302#[cfg(target_arch = "x86_64")]
303impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V4 {
304    type Width = Const<8>;
305    type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;
306    type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;
307
308    #[inline(always)]
309    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
310        from.into()
311    }
312
313    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
314    #[inline(always)]
315    unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
316        for i in 0..len {
317            *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
318        }
319    }
320}
321
322#[cfg(target_arch = "x86_64")]
323impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V3 {
324    type Width = Const<8>;
325    type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;
326    type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;
327
328    #[inline(always)]
329    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
330        from.into()
331    }
332
333    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
334    #[inline(always)]
335    unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
336        for i in 0..len {
337            *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
338        }
339    }
340}
341
342#[cfg(target_arch = "x86_64")]
343impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V4 {
344    type Width = Const<8>;
345    type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;
346    type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;
347
348    #[inline(always)]
349    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
350        from.simd_cast()
351    }
352
353    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
354    #[inline(always)]
355    unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
356        for i in 0..len {
357            *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
358        }
359    }
360}
361
362#[cfg(target_arch = "x86_64")]
363impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V3 {
364    type Width = Const<8>;
365    type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;
366    type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;
367
368    #[inline(always)]
369    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
370        from.simd_cast()
371    }
372
373    // SAFETY: We only access data in the valid range for `pto` and `pfrom`.
374    #[inline(always)]
375    unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
376        for i in 0..len {
377            *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
378        }
379    }
380}
381
382///////////
383// Tests //
384///////////
385
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;

    ////////////////
    // Fuzz tests //
    ////////////////

    /// Element-by-element conversion built on the `half` crate, used as the ground truth
    /// that `CastFromSlice` results are compared against.
    trait ReferenceConvert<From> {
        fn reference_convert(self, from: &[From]);
    }

    impl ReferenceConvert<f32> for &mut [f16] {
        fn reference_convert(self, from: &[f32]) {
            assert_eq!(self.len(), from.len());
            for (dst, src) in self.iter_mut().zip(from) {
                *dst = f16::from_f32(*src);
            }
        }
    }

    impl ReferenceConvert<f16> for &mut [f32] {
        fn reference_convert(self, from: &[f16]) {
            assert_eq!(self.len(), from.len());
            for (dst, src) in self.iter_mut().zip(from) {
                *dst = f32::from(*src);
            }
        }
    }

    /// For every length in `0..=max_dim`, run `num_trials` rounds of: fill the source
    /// with random values, convert it with both `CastFromSlice` and the reference
    /// implementation, and require identical results.
    fn test_cast_from_slice<To, From>(max_dim: usize, num_trials: usize, rng: &mut StdRng)
    where
        StandardUniform: Distribution<From>,
        To: Default + PartialEq + std::fmt::Debug + Copy,
        From: Default + Copy,
        for<'a, 'b> &'a mut [To]: CastFromSlice<&'b [From]> + ReferenceConvert<From>,
    {
        let distribution = StandardUniform {};
        for dim in 0..=max_dim {
            let mut src = vec![From::default(); dim];
            let mut dst = vec![To::default(); dim];
            let mut dst_reference = vec![To::default(); dim];

            for _ in 0..num_trials {
                for slot in src.iter_mut() {
                    *slot = distribution.sample(rng);
                }

                dst.cast_from_slice(src.as_slice());
                dst_reference.reference_convert(&src);

                assert_eq!(dst, dst_reference);
            }
        }
    }

    #[test]
    fn test_f32_to_f16_fuzz() {
        let mut rng = StdRng::seed_from_u64(0x0a3bfe052a8ebf98);
        test_cast_from_slice::<f16, f32>(256, 10, &mut rng);
    }

    #[test]
    fn test_f16_to_f32_fuzz() {
        let mut rng = StdRng::seed_from_u64(0x83765b2816321eca);
        test_cast_from_slice::<f32, f16>(256, 10, &mut rng);
    }

    ////////////////
    // Miri Tests //
    ////////////////

    // With Miri, we need to be really careful that we do not hit methods that call
    // `cvtph_ps`...
    //
    // These tests only use default-valued inputs: their purpose is to check that every
    // memory access is in-bounds for each length, not to check numerical results.
    #[test]
    fn miri_test_f32_to_f16() {
        let caster = SliceCast::<f16, f32>::new();

        for dim in 0..256 {
            println!("processing dim {}", dim);

            let src = vec![f32::default(); dim];
            let mut dst = vec![f16::default(); dim];

            // Public entry point: just check that all accesses are inbounds.
            dst.cast_from_slice(&src);

            // Explicit scalar conversion.
            caster.run(
                diskann_wide::arch::Scalar,
                dst.as_mut_slice(),
                src.as_slice(),
            );

            // Explicit SIMD conversion, when the running CPU supports `V3`.
            #[cfg(target_arch = "x86_64")]
            {
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    caster.run(arch, dst.as_mut_slice(), src.as_slice())
                }
            }
        }
    }

    #[test]
    fn miri_test_f16_to_f32() {
        let caster = SliceCast::<f32, f16>::new();

        for dim in 0..256 {
            println!("processing dim {}", dim);

            let src = vec![f16::default(); dim];
            let mut dst = vec![f32::default(); dim];

            // Public entry point: just check that all accesses are inbounds.
            dst.cast_from_slice(&src);

            // Explicit scalar conversion.
            caster.run(
                diskann_wide::arch::Scalar,
                dst.as_mut_slice(),
                src.as_slice(),
            );

            // Explicit SIMD conversion, when the running CPU supports `V3`.
            #[cfg(target_arch = "x86_64")]
            {
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    caster.run(arch, dst.as_mut_slice(), src.as_slice())
                }
            }
        }
    }
}