1use std::convert::{AsMut, AsRef};
7
8use diskann_wide::{arch::Target2, Architecture, Const, Constant, SIMDCast, SIMDVector};
9use half::f16;
10
/// Element-wise cast of the contents of `from` into `self`.
///
/// Implemented below (via `use_simd_cast_from_slice!`) for `&mut [f16]` from
/// `&[f32]` and `&mut [f32]` from `&[f16]`, plus the fixed-size array forms.
pub trait CastFromSlice<From> {
    /// Converts every element of `from` and stores the result into `self`.
    ///
    /// The SIMD dispatch path panics when the two lengths differ (see
    /// `emit_length_error`).
    fn cast_from_slice(self, from: From);
}
24
/// Implements [`CastFromSlice`] for a `$from -> $to` slice pair (and the
/// corresponding fixed-size array references) by dispatching a [`SliceCast`]
/// through `diskann_wide::ARCH`.
macro_rules! use_simd_cast_from_slice {
    ($from:ty => $to:ty) => {
        impl CastFromSlice<&[$from]> for &mut [$to] {
            #[inline(always)]
            fn cast_from_slice(self, from: &[$from]) {
                SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
            }
        }

        // Array flavor: both sides are length `N`, so the lengths are equal
        // by construction.
        impl<const N: usize> CastFromSlice<&[$from; N]> for &mut [$to; N] {
            #[inline(always)]
            fn cast_from_slice(self, from: &[$from; N]) {
                SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
            }
        }
    };
}
42
// Wire up the two supported conversions: f32 -> f16 (narrowing) and
// f16 -> f32 (widening).
use_simd_cast_from_slice!(f32 => f16);
use_simd_cast_from_slice!(f16 => f32);
45
/// Dispatch object converting a slice of `From` values into a slice of `To`
/// values on a particular architecture (via the `Target2` impls below).
///
/// Zero-sized: the type parameters only select which conversion and which
/// architecture-specific code path is used.
#[derive(Debug, Default, Clone, Copy)]
pub struct SliceCast<To, From> {
    // Carries the `(To, From)` pair in the type without storing any data.
    _marker: std::marker::PhantomData<(To, From)>,
}

impl<To, From> SliceCast<To, From> {
    /// Creates a new (zero-sized) `SliceCast` dispatcher.
    ///
    /// `const` so dispatchers can also be built in constant contexts; this is
    /// a strict, backward-compatible generalization of a non-`const`
    /// constructor.
    pub const fn new() -> Self {
        Self {
            _marker: std::marker::PhantomData,
        }
    }
}
60
61impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f16, f32>
64where
65 T: AsMut<[f16]>,
66 U: AsRef<[f32]>,
67{
68 #[inline(always)]
69 fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
70 let to = to.as_mut();
71 let from = from.as_ref();
72 std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
73 *to = diskann_wide::cast_f32_to_f16(*from);
74 })
75 }
76}
77
78impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f32, f16>
79where
80 T: AsMut<[f32]>,
81 U: AsRef<[f16]>,
82{
83 #[inline(always)]
84 fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
85 let to = to.as_mut();
86 let from = from.as_ref();
87 std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
88 *to = diskann_wide::cast_f16_to_f32(*from);
89 })
90 }
91}
92
// x86-64 `V4` dispatch: all the real work happens in the shared
// `simd_convert` driver, parameterized by this architecture's `SIMDConvert`
// impl.
#[cfg(target_arch = "x86_64")]
impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V4, (), T, U> for SliceCast<To, From>
where
    T: AsMut<[To]>,
    U: AsRef<[From]>,
    diskann_wide::arch::x86_64::V4: SIMDConvert<To, From>,
{
    #[inline(always)]
    fn run(self, arch: diskann_wide::arch::x86_64::V4, mut to: T, from: U) {
        simd_convert(arch, to.as_mut(), from.as_ref())
    }
}
106
// x86-64 `V3` dispatch: forwards to the shared `simd_convert` driver.
#[cfg(target_arch = "x86_64")]
impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V3, (), T, U> for SliceCast<To, From>
where
    T: AsMut<[To]>,
    U: AsRef<[From]>,
    diskann_wide::arch::x86_64::V3: SIMDConvert<To, From>,
{
    #[inline(always)]
    fn run(self, arch: diskann_wide::arch::x86_64::V3, mut to: T, from: U) {
        simd_convert(arch, to.as_mut(), from.as_ref())
    }
}
119
// AArch64 Neon dispatch: forwards to the shared `simd_convert` driver.
#[cfg(target_arch = "aarch64")]
impl<T, U, To, From> Target2<diskann_wide::arch::aarch64::Neon, (), T, U> for SliceCast<To, From>
where
    T: AsMut<[To]>,
    U: AsRef<[From]>,
    diskann_wide::arch::aarch64::Neon: SIMDConvert<To, From>,
{
    #[inline(always)]
    fn run(self, arch: diskann_wide::arch::aarch64::Neon, mut to: T, from: U) {
        simd_convert(arch, to.as_mut(), from.as_ref())
    }
}
132
/// Describes how a given [`Architecture`] converts one SIMD register of
/// `From` values into one register of `To` values. Consumed by
/// `simd_convert`.
trait SIMDConvert<To, From>: Architecture {
    /// Number of lanes processed per SIMD operation.
    type Width: Constant<Type = usize>;

    /// SIMD vector type holding `Width` destination (`To`) elements.
    type WideTo: SIMDVector<Arch = Self, Scalar = To, ConstLanes = Self::Width>;

    /// SIMD vector type holding `Width` source (`From`) elements.
    type WideFrom: SIMDVector<Arch = Self, Scalar = From, ConstLanes = Self::Width>;

    /// Converts one full SIMD register of `From` values into `To` values.
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo;

    /// Converts a short run of `len` elements (callers only use this for
    /// `len < Width`).
    ///
    /// # Safety
    ///
    /// `pfrom` must be valid for reads of `len` elements and `pto` for
    /// writes of `len` elements. The default implementation relies on the
    /// partial load/store primitives (`load_simd_first` /
    /// `store_simd_first`) honoring `len` -- confirm against their contracts
    /// in `diskann_wide`.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut To, pfrom: *const From, len: usize) {
        let from = Self::WideFrom::load_simd_first(self, pfrom, len);
        let to = Self::simd_convert(from);
        to.store_simd_first(pto, len);
    }

    /// Returns the lane count `Width` as a runtime `usize`.
    #[inline(always)]
    fn get_simd_width() -> usize {
        Self::Width::value()
    }
}
181
/// Panics with a diagnostic message for mismatched slice lengths.
///
/// Kept `#[inline(never)]` so the cold panic/formatting machinery stays out
/// of the hot conversion loops that call it.
#[inline(never)]
#[allow(clippy::panic)]
fn emit_length_error(xlen: usize, ylen: usize) -> ! {
    // Inline format args (idiomatic since Rust 2021; output is identical to
    // the positional form).
    panic!("lengths must be equal, instead got: xlen = {xlen}, ylen = {ylen}")
}
190
/// Converts `from` into `to` element-wise using the SIMD lane configuration
/// that `A` provides via [`SIMDConvert`].
///
/// # Panics
///
/// Panics if `to.len() != from.len()`.
#[inline(always)]
fn simd_convert<A, To, From>(arch: A, to: &mut [To], from: &[From])
where
    A: SIMDConvert<To, From>,
{
    let len = to.len();

    if len != from.len() {
        emit_length_error(len, from.len())
    }

    let width = A::get_simd_width();

    let pto = to.as_mut_ptr();
    let pfrom = from.as_ptr();

    // Inputs shorter than one SIMD register: delegate to the architecture's
    // short-run handler and bail out. Everything after this point may assume
    // `len >= width`.
    if len < width {
        // SAFETY: `pto` and `pfrom` are valid for exactly `len` elements of
        // their slices, as required by `handle_small`.
        unsafe { arch.handle_small(pto, pfrom, len) };
        return;
    }

    // The main loop is unrolled 8x, keeping several independent
    // load/convert/store chains in flight per iteration.
    const UNROLL: usize = 8;

    let mut i = 0;
    // SAFETY: the loop condition guarantees `i + UNROLL * width <= len`, so
    // every offset `i + k * width` (k < UNROLL) loaded or stored below is in
    // bounds for both slices.
    unsafe {
        while i + UNROLL * width <= len {
            let s0 = A::WideFrom::load_simd(arch, pfrom.add(i));
            A::simd_convert(s0).store_simd(pto.add(i));

            let s1 = A::WideFrom::load_simd(arch, pfrom.add(i + width));
            A::simd_convert(s1).store_simd(pto.add(i + width));

            let s2 = A::WideFrom::load_simd(arch, pfrom.add(i + 2 * width));
            A::simd_convert(s2).store_simd(pto.add(i + 2 * width));

            let s3 = A::WideFrom::load_simd(arch, pfrom.add(i + 3 * width));
            A::simd_convert(s3).store_simd(pto.add(i + 3 * width));

            let s0 = A::WideFrom::load_simd(arch, pfrom.add(i + 4 * width));
            A::simd_convert(s0).store_simd(pto.add(i + 4 * width));

            let s1 = A::WideFrom::load_simd(arch, pfrom.add(i + 5 * width));
            A::simd_convert(s1).store_simd(pto.add(i + 5 * width));

            let s2 = A::WideFrom::load_simd(arch, pfrom.add(i + 6 * width));
            A::simd_convert(s2).store_simd(pto.add(i + 6 * width));

            let s3 = A::WideFrom::load_simd(arch, pfrom.add(i + 7 * width));
            A::simd_convert(s3).store_simd(pto.add(i + 7 * width));

            i += UNROLL * width;
        }
    }

    // Drain the remaining full-width chunks one register at a time.
    while i + width <= len {
        // SAFETY: `i + width <= len` bounds both the load and the store.
        let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(i)) };
        let t0 = A::simd_convert(s0);
        unsafe { t0.store_simd(pto.add(i)) };
        i += width;
    }

    // Tail of fewer than `width` elements: reprocess the *last* full
    // register of the slices. The overlap with already-converted elements is
    // harmless because the conversion is element-wise.
    if i != len {
        // `i` is a multiple of `width` with `len - i < width`, so
        // `len % width == len - i` and this simplifies to
        // `offset = len - width`, which is >= 0 because the `len < width`
        // case returned early above.
        let offset = i - (width - len % width);

        // SAFETY: `offset + width == len`, so the overlapping load and store
        // both stay in bounds.
        let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(offset)) };
        let t0 = A::simd_convert(s0);

        unsafe { t0.store_simd(pto.add(offset)) };
    }
}
309
// f16 -> f32 widening on x86-64 `V4`: 8 lanes per step, converting through
// the `f16x8 -> f32x8` `From` impl provided by `diskann_wide`.
#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V4 {
    type Width = Const<8>;
    type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;
    type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;

    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.into()
    }

    // Scalar override for short runs: converts element by element.
    // Safety: caller guarantees `pto`/`pfrom` are valid for `len` elements.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
        }
    }
}
329
// f16 -> f32 widening on x86-64 `V3`: 8 lanes per step, converting through
// the `f16x8 -> f32x8` `From` impl provided by `diskann_wide`.
#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V3 {
    type Width = Const<8>;
    type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;
    type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;

    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.into()
    }

    // Scalar override for short runs: converts element by element.
    // Safety: caller guarantees `pto`/`pfrom` are valid for `len` elements.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
        }
    }
}
349
// f32 -> f16 narrowing on x86-64 `V4`: 8 lanes per step via `SIMDCast`.
#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V4 {
    type Width = Const<8>;
    type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;
    type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;

    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.simd_cast()
    }

    // Scalar override for short runs: converts element by element.
    // Safety: caller guarantees `pto`/`pfrom` are valid for `len` elements.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
        }
    }
}
369
// f32 -> f16 narrowing on x86-64 `V3`: 8 lanes per step via `SIMDCast`.
#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V3 {
    type Width = Const<8>;
    type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;
    type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;

    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.simd_cast()
    }

    // Scalar override for short runs: converts element by element.
    // Safety: caller guarantees `pto`/`pfrom` are valid for `len` elements.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
        }
    }
}
389
// f16 -> f32 widening on AArch64 Neon: 4 lanes per step. Note the source
// type is the free-standing `aarch64::f16x4`, not an `Architecture`
// associated type.
#[cfg(target_arch = "aarch64")]
impl SIMDConvert<f32, f16> for diskann_wide::arch::aarch64::Neon {
    type Width = Const<4>;
    type WideTo = <diskann_wide::arch::aarch64::Neon as Architecture>::f32x4;
    type WideFrom = diskann_wide::arch::aarch64::f16x4;

    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.into()
    }

    // Scalar override for short runs: converts element by element.
    // Safety: caller guarantees `pto`/`pfrom` are valid for `len` elements.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
        }
    }
}
413
// f32 -> f16 narrowing on AArch64 Neon: 4 lanes per step via `SIMDCast`.
#[cfg(target_arch = "aarch64")]
impl SIMDConvert<f16, f32> for diskann_wide::arch::aarch64::Neon {
    type Width = Const<4>;
    type WideTo = diskann_wide::arch::aarch64::f16x4;
    type WideFrom = <diskann_wide::arch::aarch64::Neon as Architecture>::f32x4;

    #[inline(always)]
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
        from.simd_cast()
    }

    // Scalar override for short runs: converts element by element.
    // Safety: caller guarantees `pto`/`pfrom` are valid for `len` elements.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
        for i in 0..len {
            *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
        }
    }
}
433
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;

    /// Scalar "ground truth" conversion used to validate the SIMD paths.
    trait ReferenceConvert<From> {
        /// Converts `from` into `self` one element at a time.
        fn reference_convert(self, from: &[From]);
    }

    impl ReferenceConvert<f32> for &mut [f16] {
        fn reference_convert(self, from: &[f32]) {
            assert_eq!(self.len(), from.len());
            std::iter::zip(self.iter_mut(), from.iter()).for_each(|(d, s)| *d = f16::from_f32(*s));
        }
    }

    impl ReferenceConvert<f16> for &mut [f32] {
        fn reference_convert(self, from: &[f16]) {
            assert_eq!(self.len(), from.len());
            std::iter::zip(self.iter_mut(), from.iter()).for_each(|(d, s)| *d = (*s).into());
        }
    }

    /// Fuzz driver: for every length `0..=max_dim`, fills a source slice with
    /// random values, converts it with [`CastFromSlice`], and compares the
    /// result against the scalar [`ReferenceConvert`] implementation.
    fn test_cast_from_slice<To, From>(max_dim: usize, num_trials: usize, rng: &mut StdRng)
    where
        StandardUniform: Distribution<From>,
        To: Default + PartialEq + std::fmt::Debug + Copy,
        From: Default + Copy,
        for<'a, 'b> &'a mut [To]: CastFromSlice<&'b [From]> + ReferenceConvert<From>,
    {
        let distribution = StandardUniform {};
        // Sweep every length so all unroll/full-width/tail paths in
        // `simd_convert` get exercised.
        for dim in 0..=max_dim {
            let mut src = vec![From::default(); dim];
            let mut dst = vec![To::default(); dim];
            let mut dst_reference = vec![To::default(); dim];

            for _ in 0..num_trials {
                src.iter_mut().for_each(|s| *s = distribution.sample(rng));
                dst.cast_from_slice(src.as_slice());
                dst_reference.reference_convert(&src);

                assert_eq!(dst, dst_reference);
            }
        }
    }

    #[test]
    fn test_f32_to_f16_fuzz() {
        // Fixed seed keeps the fuzz deterministic across runs.
        let mut rng = StdRng::seed_from_u64(0x0a3bfe052a8ebf98);
        test_cast_from_slice::<f16, f32>(256, 10, &mut rng);
    }

    #[test]
    fn test_f16_to_f32_fuzz() {
        // Fixed seed keeps the fuzz deterministic across runs.
        let mut rng = StdRng::seed_from_u64(0x83765b2816321eca);
        test_cast_from_slice::<f32, f16>(256, 10, &mut rng);
    }

    /// Exercises every dispatch path (runtime-dispatched, explicit scalar,
    /// and whichever SIMD targets are available) for all lengths `0..256`,
    /// intended to let Miri validate the unsafe pointer arithmetic in
    /// `simd_convert`.
    #[test]
    fn miri_test_f32_to_f16() {
        for dim in 0..256 {
            println!("processing dim {}", dim);

            let src = vec![f32::default(); dim];
            let mut dst = vec![f16::default(); dim];

            // Runtime-dispatched path.
            dst.cast_from_slice(&src);

            // Explicit scalar path.
            SliceCast::<f16, f32>::new().run(
                diskann_wide::arch::Scalar,
                dst.as_mut_slice(),
                src.as_slice(),
            );

            #[cfg(target_arch = "x86_64")]
            {
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }

                // `new_checked_miri`: presumably a Miri-compatible feature
                // probe for V4 -- confirm against `diskann_wide`.
                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }
            }

            #[cfg(target_arch = "aarch64")]
            if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
            }
        }
    }

    /// Same structure as `miri_test_f32_to_f16`, in the widening direction.
    #[test]
    fn miri_test_f16_to_f32() {
        for dim in 0..256 {
            println!("processing dim {}", dim);

            let src = vec![f16::default(); dim];
            let mut dst = vec![f32::default(); dim];

            // Runtime-dispatched path.
            dst.cast_from_slice(&src);

            // Explicit scalar path.
            SliceCast::<f32, f16>::new().run(
                diskann_wide::arch::Scalar,
                dst.as_mut_slice(),
                src.as_slice(),
            );

            #[cfg(target_arch = "x86_64")]
            {
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }

                if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
                    SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }
            }

            #[cfg(target_arch = "aarch64")]
            if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
                SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
            }
        }
    }
}