1use std::convert::{AsMut, AsRef};
7
8use diskann_wide::arch::Target2;
9#[cfg(not(target_arch = "aarch64"))]
10use diskann_wide::{Architecture, Const, Constant, SIMDCast, SIMDVector};
11use half::f16;
12
/// Element-wise numeric cast from a source container into `self`.
///
/// Implemented below (via `use_simd_cast_from_slice!`) for `&mut [T]` and
/// `&mut [T; N]` destinations, forwarding to a [`SliceCast`] kernel so that
/// SIMD-accelerated paths are used where available.
pub trait CastFromSlice<From> {
    // NOTE(review): length-mismatch behavior depends on the selected backend —
    // the scalar path truncates to the shorter input while the x86_64 SIMD
    // path panics. Callers should pass equal-length buffers.
    fn cast_from_slice(self, from: From);
}
26
/// Generates `CastFromSlice` impls (for slices and fixed-size arrays) that
/// forward to a `SliceCast` kernel dispatched over `diskann_wide::ARCH`.
macro_rules! use_simd_cast_from_slice {
    ($from:ty => $to:ty) => {
        // Slice destination: lengths are only checked at runtime.
        impl CastFromSlice<&[$from]> for &mut [$to] {
            #[inline(always)]
            fn cast_from_slice(self, from: &[$from]) {
                SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
            }
        }

        // Array destination: the shared const `N` guarantees equal lengths
        // at compile time.
        impl<const N: usize> CastFromSlice<&[$from; N]> for &mut [$to; N] {
            #[inline(always)]
            fn cast_from_slice(self, from: &[$from; N]) {
                SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
            }
        }
    };
}
44
// Concrete `CastFromSlice` impls for the two supported conversions.
use_simd_cast_from_slice!(f32 => f16);
use_simd_cast_from_slice!(f16 => f32);
47
/// Zero-sized dispatcher that converts a slice of `From` into a slice of `To`.
///
/// The actual conversion strategy is chosen by the `Target2` impls below:
/// a scalar fallback on all architectures, plus SIMD kernels on x86_64.
#[derive(Debug, Default, Clone, Copy)]
pub struct SliceCast<To, From> {
    // Carries the (To, From) type pair without storing any data.
    _marker: std::marker::PhantomData<(To, From)>,
}
54
55impl<To, From> SliceCast<To, From> {
56 pub fn new() -> Self {
57 Self {
58 _marker: std::marker::PhantomData,
59 }
60 }
61}
62
63impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f16, f32>
66where
67 T: AsMut<[f16]>,
68 U: AsRef<[f32]>,
69{
70 #[inline(always)]
71 fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
72 let to = to.as_mut();
73 let from = from.as_ref();
74 std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
75 *to = diskann_wide::cast_f32_to_f16(*from);
76 })
77 }
78}
79
80impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f32, f16>
81where
82 T: AsMut<[f32]>,
83 U: AsRef<[f16]>,
84{
85 #[inline(always)]
86 fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
87 let to = to.as_mut();
88 let from = from.as_ref();
89 std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
90 *to = diskann_wide::cast_f16_to_f32(*from);
91 })
92 }
93}
94
95#[cfg(target_arch = "x86_64")]
97impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V4, (), T, U> for SliceCast<To, From>
98where
99 T: AsMut<[To]>,
100 U: AsRef<[From]>,
101 diskann_wide::arch::x86_64::V4: SIMDConvert<To, From>,
102{
103 #[inline(always)]
104 fn run(self, arch: diskann_wide::arch::x86_64::V4, mut to: T, from: U) {
105 simd_convert(arch, to.as_mut(), from.as_ref())
106 }
107}
108
109#[cfg(target_arch = "x86_64")]
110impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V3, (), T, U> for SliceCast<To, From>
111where
112 T: AsMut<[To]>,
113 U: AsRef<[From]>,
114 diskann_wide::arch::x86_64::V3: SIMDConvert<To, From>,
115{
116 #[inline(always)]
117 fn run(self, arch: diskann_wide::arch::x86_64::V3, mut to: T, from: U) {
118 simd_convert(arch, to.as_mut(), from.as_ref())
119 }
120}
121
/// Internal trait describing a SIMD conversion kernel for one (To, From)
/// element-type pair on a specific x86_64 feature level.
#[cfg(target_arch = "x86_64")]
trait SIMDConvert<To, From>: Architecture {
    /// Number of lanes processed per SIMD register.
    type Width: Constant<Type = usize>;

    /// SIMD vector type holding `Width` destination elements.
    type WideTo: SIMDVector<Arch = Self, Scalar = To, ConstLanes = Self::Width>;

    /// SIMD vector type holding `Width` source elements.
    type WideFrom: SIMDVector<Arch = Self, Scalar = From, ConstLanes = Self::Width>;

    /// Converts one full SIMD register of source elements.
    fn simd_convert(from: Self::WideFrom) -> Self::WideTo;

    /// Converts `len` elements where `len < Width`, using the partial
    /// (`*_first`) load/store operations.
    ///
    /// # Safety
    /// `pfrom` must be valid for reads of `len` elements and `pto` must be
    /// valid for writes of `len` elements.
    #[inline(always)]
    unsafe fn handle_small(self, pto: *mut To, pfrom: *const From, len: usize) {
        let from = Self::WideFrom::load_simd_first(self, pfrom, len);
        let to = Self::simd_convert(from);
        to.store_simd_first(pto, len);
    }

    /// Returns `Width` as a runtime `usize`.
    #[inline(always)]
    fn get_simd_width() -> usize {
        Self::Width::value()
    }
}
171
/// Panics with a diagnostic describing mismatched slice lengths.
///
/// Kept out-of-line (`inline(never)`) and marked `#[cold]` so the panic
/// formatting machinery stays off the hot path of `simd_convert` and the
/// branch predictor treats the mismatch branch as unlikely.
#[inline(never)]
#[cold]
#[allow(clippy::panic)]
#[cfg(target_arch = "x86_64")]
fn emit_length_error(xlen: usize, ylen: usize) -> ! {
    panic!(
        "lengths must be equal, instead got: xlen = {}, ylen = {}",
        xlen, ylen
    )
}
181
/// Converts `from` into `to` element-wise using the SIMD kernel provided by
/// architecture `A`.
///
/// Strategy:
/// 1. panic (via `emit_length_error`) on length mismatch;
/// 2. for inputs shorter than one SIMD register, defer to `handle_small`;
/// 3. otherwise run an 8x-unrolled main loop, then drain remaining full
///    registers, and finish the tail with one overlapping register load/store.
#[inline(always)]
#[cfg(target_arch = "x86_64")]
fn simd_convert<A, To, From>(arch: A, to: &mut [To], from: &[From])
where
    A: SIMDConvert<To, From>,
{
    let len = to.len();

    if len != from.len() {
        emit_length_error(len, from.len())
    }

    let width = A::get_simd_width();

    let pto = to.as_mut_ptr();
    let pfrom = from.as_ptr();

    if len < width {
        // Not even one full register of data: use the partial-register path.
        // SAFETY: `pto` and `pfrom` are valid for `len` elements — both come
        // from slices whose lengths were checked equal to `len` above.
        unsafe { arch.handle_small(pto, pfrom, len) };
        return;
    }

    const UNROLL: usize = 8;

    let mut i = 0;
    // Main loop: 8 registers per iteration to expose instruction-level
    // parallelism.
    // SAFETY: the loop condition `i + UNROLL * width <= len` guarantees every
    // offset touched below (up to `i + 7 * width + width`) is within `len`
    // elements of both buffers.
    unsafe {
        while i + UNROLL * width <= len {
            let s0 = A::WideFrom::load_simd(arch, pfrom.add(i));
            A::simd_convert(s0).store_simd(pto.add(i));

            let s1 = A::WideFrom::load_simd(arch, pfrom.add(i + width));
            A::simd_convert(s1).store_simd(pto.add(i + width));

            let s2 = A::WideFrom::load_simd(arch, pfrom.add(i + 2 * width));
            A::simd_convert(s2).store_simd(pto.add(i + 2 * width));

            let s3 = A::WideFrom::load_simd(arch, pfrom.add(i + 3 * width));
            A::simd_convert(s3).store_simd(pto.add(i + 3 * width));

            let s0 = A::WideFrom::load_simd(arch, pfrom.add(i + 4 * width));
            A::simd_convert(s0).store_simd(pto.add(i + 4 * width));

            let s1 = A::WideFrom::load_simd(arch, pfrom.add(i + 5 * width));
            A::simd_convert(s1).store_simd(pto.add(i + 5 * width));

            let s2 = A::WideFrom::load_simd(arch, pfrom.add(i + 6 * width));
            A::simd_convert(s2).store_simd(pto.add(i + 6 * width));

            let s3 = A::WideFrom::load_simd(arch, pfrom.add(i + 7 * width));
            A::simd_convert(s3).store_simd(pto.add(i + 7 * width));

            i += UNROLL * width;
        }
    }

    // Drain remaining full registers one at a time.
    while i + width <= len {
        // SAFETY: `i + width <= len`, so the register-wide load is in bounds.
        let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(i)) };
        let t0 = A::simd_convert(s0);
        // SAFETY: same bound applies to the destination buffer.
        unsafe { t0.store_simd(pto.add(i)) };
        i += width;
    }

    if i != len {
        // Fewer than `width` elements remain. Re-process the last full
        // register of the buffers: since `len % width == len - i` here,
        // `offset == len - width`, which is non-negative because `len >=
        // width` on this path. The overlap rewrites already-converted
        // elements with identical values, which is fine because `to` and
        // `from` cannot alias (`&mut` vs `&`).
        let offset = i - (width - len % width);

        // SAFETY: `offset + width == len`, so the load is in bounds.
        let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(offset)) };
        let t0 = A::simd_convert(s0);

        // SAFETY: `offset + width == len`, so the store is in bounds.
        unsafe { t0.store_simd(pto.add(offset)) };
    }
}
301
302#[cfg(target_arch = "x86_64")]
303impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V4 {
304 type Width = Const<8>;
305 type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;
306 type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;
307
308 #[inline(always)]
309 fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
310 from.into()
311 }
312
313 #[inline(always)]
315 unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
316 for i in 0..len {
317 *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
318 }
319 }
320}
321
322#[cfg(target_arch = "x86_64")]
323impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V3 {
324 type Width = Const<8>;
325 type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;
326 type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;
327
328 #[inline(always)]
329 fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
330 from.into()
331 }
332
333 #[inline(always)]
335 unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
336 for i in 0..len {
337 *pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
338 }
339 }
340}
341
342#[cfg(target_arch = "x86_64")]
343impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V4 {
344 type Width = Const<8>;
345 type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;
346 type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;
347
348 #[inline(always)]
349 fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
350 from.simd_cast()
351 }
352
353 #[inline(always)]
355 unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
356 for i in 0..len {
357 *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
358 }
359 }
360}
361
362#[cfg(target_arch = "x86_64")]
363impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V3 {
364 type Width = Const<8>;
365 type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;
366 type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;
367
368 #[inline(always)]
369 fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
370 from.simd_cast()
371 }
372
373 #[inline(always)]
375 unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
376 for i in 0..len {
377 *pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
378 }
379 }
380}
381
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;

    /// Trivially-correct scalar conversion used as ground truth when checking
    /// the dispatched (possibly SIMD) paths.
    trait ReferenceConvert<From> {
        fn reference_convert(self, from: &[From]);
    }

    impl ReferenceConvert<f32> for &mut [f16] {
        fn reference_convert(self, from: &[f32]) {
            assert_eq!(self.len(), from.len());
            std::iter::zip(self.iter_mut(), from.iter()).for_each(|(d, s)| *d = f16::from_f32(*s));
        }
    }

    impl ReferenceConvert<f16> for &mut [f32] {
        fn reference_convert(self, from: &[f16]) {
            assert_eq!(self.len(), from.len());
            std::iter::zip(self.iter_mut(), from.iter()).for_each(|(d, s)| *d = (*s).into());
        }
    }

    /// Fuzz-style check: for every length in `0..=max_dim`, fill a random
    /// source buffer `num_trials` times, convert via `CastFromSlice`, and
    /// compare against the scalar reference implementation. Sweeping every
    /// small dimension exercises the small-input, unrolled, and overlapping
    /// tail paths of the SIMD kernel.
    fn test_cast_from_slice<To, From>(max_dim: usize, num_trials: usize, rng: &mut StdRng)
    where
        StandardUniform: Distribution<From>,
        To: Default + PartialEq + std::fmt::Debug + Copy,
        From: Default + Copy,
        for<'a, 'b> &'a mut [To]: CastFromSlice<&'b [From]> + ReferenceConvert<From>,
    {
        let distribution = StandardUniform {};
        for dim in 0..=max_dim {
            let mut src = vec![From::default(); dim];
            let mut dst = vec![To::default(); dim];
            let mut dst_reference = vec![To::default(); dim];

            for _ in 0..num_trials {
                src.iter_mut().for_each(|s| *s = distribution.sample(rng));
                dst.cast_from_slice(src.as_slice());
                dst_reference.reference_convert(&src);

                assert_eq!(dst, dst_reference);
            }
        }
    }

    #[test]
    fn test_f32_to_f16_fuzz() {
        // Fixed seed keeps the fuzz deterministic and reproducible.
        let mut rng = StdRng::seed_from_u64(0x0a3bfe052a8ebf98);
        test_cast_from_slice::<f16, f32>(256, 10, &mut rng);
    }

    #[test]
    fn test_f16_to_f32_fuzz() {
        // Fixed seed keeps the fuzz deterministic and reproducible.
        let mut rng = StdRng::seed_from_u64(0x83765b2816321eca);
        test_cast_from_slice::<f32, f16>(256, 10, &mut rng);
    }

    /// Runs every dispatch path over all small dimensions with default-valued
    /// inputs so that the raw-pointer arithmetic can be checked for
    /// out-of-bounds access (e.g. when executed under Miri — the arch is
    /// limited to V3 here, presumably because wider feature levels are not
    /// supported in that environment; TODO confirm).
    #[test]
    fn miri_test_f32_to_f16() {
        for dim in 0..256 {
            println!("processing dim {}", dim);

            let src = vec![f32::default(); dim];
            let mut dst = vec![f16::default(); dim];

            // Default dispatch (best available arch).
            dst.cast_from_slice(&src);

            // Explicit scalar path.
            SliceCast::<f16, f32>::new().run(
                diskann_wide::arch::Scalar,
                dst.as_mut_slice(),
                src.as_slice(),
            );

            // Explicit V3 SIMD path, when the CPU supports it.
            #[cfg(target_arch = "x86_64")]
            {
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }
            }
        }
    }

    /// Mirror of `miri_test_f32_to_f16` for the widening direction.
    #[test]
    fn miri_test_f16_to_f32() {
        for dim in 0..256 {
            println!("processing dim {}", dim);

            let src = vec![f16::default(); dim];
            let mut dst = vec![f32::default(); dim];

            // Default dispatch (best available arch).
            dst.cast_from_slice(&src);

            // Explicit scalar path.
            SliceCast::<f32, f16>::new().run(
                diskann_wide::arch::Scalar,
                dst.as_mut_slice(),
                src.as_slice(),
            );

            // Explicit V3 SIMD path, when the CPU supports it.
            #[cfg(target_arch = "x86_64")]
            {
                if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
                    SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
                }
            }
        }
    }
}