use std::convert::{AsMut, AsRef};
use diskann_wide::{arch::Target2, Architecture, Const, Constant, SIMDCast, SIMDVector};
use half::f16;
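/// Element-wise numeric cast from `from` into `self`.
///
/// Implementations are provided below for `f32 <-> f16` slice and fixed-size
/// array pairs; the two sides must have equal lengths.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file
/// (`Vec` coerces to the slice implementations):
///
/// ```ignore
/// let src: Vec<f32> = vec![1.0, 2.5, -3.0];
/// let mut dst = vec![half::f16::default(); src.len()];
/// dst.cast_from_slice(src.as_slice());
/// ```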
pub trait CastFromSlice<From> {
fn cast_from_slice(self, from: From);
}
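// Implements `CastFromSlice` between slices (and equal-length arrays) of the
// given scalar types by dispatching `SliceCast` through `diskann_wide::ARCH`.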
macro_rules! use_simd_cast_from_slice {
($from:ty => $to:ty) => {
impl CastFromSlice<&[$from]> for &mut [$to] {
#[inline(always)]
fn cast_from_slice(self, from: &[$from]) {
SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
}
}
impl<const N: usize> CastFromSlice<&[$from; N]> for &mut [$to; N] {
#[inline(always)]
fn cast_from_slice(self, from: &[$from; N]) {
SliceCast::<$to, $from>::new().run(diskann_wide::ARCH, self, from)
}
}
};
}
use_simd_cast_from_slice!(f32 => f16);
use_simd_cast_from_slice!(f16 => f32);
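/// Zero-sized dispatcher that casts a `[From]` slice into a `[To]` slice.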
#[derive(Debug, Default, Clone, Copy)]
pub struct SliceCast<To, From> {
_marker: std::marker::PhantomData<(To, From)>,
}
impl<To, From> SliceCast<To, From> {
pub fn new() -> Self {
Self {
_marker: std::marker::PhantomData,
}
}
}
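// Scalar fallbacks: straightforward element-wise conversion loops.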
impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f16, f32>
where
T: AsMut<[f16]>,
U: AsRef<[f32]>,
{
#[inline(always)]
    fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
        let to = to.as_mut();
        let from = from.as_ref();
        // Match the SIMD paths' length contract instead of letting `zip`
        // silently truncate on mismatched lengths.
        if to.len() != from.len() {
            emit_length_error(to.len(), from.len())
        }
        std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
            *to = diskann_wide::cast_f32_to_f16(*from)
        })
    }
}
impl<T, U> Target2<diskann_wide::arch::Scalar, (), T, U> for SliceCast<f32, f16>
where
T: AsMut<[f32]>,
U: AsRef<[f16]>,
{
#[inline(always)]
    fn run(self, _: diskann_wide::arch::Scalar, mut to: T, from: U) {
        let to = to.as_mut();
        let from = from.as_ref();
        // Match the SIMD paths' length contract instead of letting `zip`
        // silently truncate on mismatched lengths.
        if to.len() != from.len() {
            emit_length_error(to.len(), from.len())
        }
        std::iter::zip(to.iter_mut(), from.iter()).for_each(|(to, from)| {
            *to = diskann_wide::cast_f16_to_f32(*from)
        })
    }
}
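// SIMD targets: every architecture implementing `SIMDConvert` for the type
// pair forwards to the shared `simd_convert` kernel below.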
#[cfg(target_arch = "x86_64")]
impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V4, (), T, U> for SliceCast<To, From>
where
T: AsMut<[To]>,
U: AsRef<[From]>,
diskann_wide::arch::x86_64::V4: SIMDConvert<To, From>,
{
#[inline(always)]
fn run(self, arch: diskann_wide::arch::x86_64::V4, mut to: T, from: U) {
simd_convert(arch, to.as_mut(), from.as_ref())
}
}
#[cfg(target_arch = "x86_64")]
impl<T, U, To, From> Target2<diskann_wide::arch::x86_64::V3, (), T, U> for SliceCast<To, From>
where
T: AsMut<[To]>,
U: AsRef<[From]>,
diskann_wide::arch::x86_64::V3: SIMDConvert<To, From>,
{
#[inline(always)]
fn run(self, arch: diskann_wide::arch::x86_64::V3, mut to: T, from: U) {
simd_convert(arch, to.as_mut(), from.as_ref())
}
}
#[cfg(target_arch = "aarch64")]
impl<T, U, To, From> Target2<diskann_wide::arch::aarch64::Neon, (), T, U> for SliceCast<To, From>
where
T: AsMut<[To]>,
U: AsRef<[From]>,
diskann_wide::arch::aarch64::Neon: SIMDConvert<To, From>,
{
#[inline(always)]
fn run(self, arch: diskann_wide::arch::aarch64::Neon, mut to: T, from: U) {
simd_convert(arch, to.as_mut(), from.as_ref())
}
}
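/// Architecture-specific building blocks for a vectorized element-wise cast.
///
/// `Width` is the number of lanes converted per vector step, with `WideTo` and
/// `WideFrom` the matching SIMD vector types. `handle_small` covers inputs
/// shorter than one vector; every implementation below overrides the default
/// with a plain scalar loop.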
trait SIMDConvert<To, From>: Architecture {
type Width: Constant<Type = usize>;
type WideTo: SIMDVector<Arch = Self, Scalar = To, ConstLanes = Self::Width>;
type WideFrom: SIMDVector<Arch = Self, Scalar = From, ConstLanes = Self::Width>;
fn simd_convert(from: Self::WideFrom) -> Self::WideTo;
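    /// # Safety
    /// `pto` must be valid for `len` writes and `pfrom` for `len` reads.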
#[inline(always)]
unsafe fn handle_small(self, pto: *mut To, pfrom: *const From, len: usize) {
let from = Self::WideFrom::load_simd_first(self, pfrom, len);
let to = Self::simd_convert(from);
to.store_simd_first(pto, len);
}
#[inline(always)]
fn get_simd_width() -> usize {
Self::Width::value()
}
}
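// Kept out of line so the panic machinery stays off the hot path of
// `simd_convert`.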
#[inline(never)]
#[allow(clippy::panic)]
fn emit_length_error(to_len: usize, from_len: usize) -> ! {
    panic!(
        "slice lengths must be equal, instead got: to.len() = {}, from.len() = {}",
        to_len, from_len
    )
}
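/// Shared conversion kernel: inputs shorter than one vector take the
/// `handle_small` path, the bulk is processed by an 8-way unrolled loop, and
/// any ragged tail is finished with one final vector that overlaps the
/// previously stored elements.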
#[inline(always)]
fn simd_convert<A, To, From>(arch: A, to: &mut [To], from: &[From])
where
A: SIMDConvert<To, From>,
{
let len = to.len();
if len != from.len() {
emit_length_error(len, from.len())
}
let width = A::get_simd_width();
let pto = to.as_mut_ptr();
let pfrom = from.as_ptr();
if len < width {
unsafe { arch.handle_small(pto, pfrom, len) };
return;
}
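    // Bulk loop: convert UNROLL vectors per iteration.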
const UNROLL: usize = 8;
let mut i = 0;
unsafe {
while i + UNROLL * width <= len {
let s0 = A::WideFrom::load_simd(arch, pfrom.add(i));
A::simd_convert(s0).store_simd(pto.add(i));
let s1 = A::WideFrom::load_simd(arch, pfrom.add(i + width));
A::simd_convert(s1).store_simd(pto.add(i + width));
let s2 = A::WideFrom::load_simd(arch, pfrom.add(i + 2 * width));
A::simd_convert(s2).store_simd(pto.add(i + 2 * width));
let s3 = A::WideFrom::load_simd(arch, pfrom.add(i + 3 * width));
A::simd_convert(s3).store_simd(pto.add(i + 3 * width));
            let s4 = A::WideFrom::load_simd(arch, pfrom.add(i + 4 * width));
            A::simd_convert(s4).store_simd(pto.add(i + 4 * width));
            let s5 = A::WideFrom::load_simd(arch, pfrom.add(i + 5 * width));
            A::simd_convert(s5).store_simd(pto.add(i + 5 * width));
            let s6 = A::WideFrom::load_simd(arch, pfrom.add(i + 6 * width));
            A::simd_convert(s6).store_simd(pto.add(i + 6 * width));
            let s7 = A::WideFrom::load_simd(arch, pfrom.add(i + 7 * width));
            A::simd_convert(s7).store_simd(pto.add(i + 7 * width));
i += UNROLL * width;
}
}
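    // Convert any remaining full vectors one at a time.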
while i + width <= len {
let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(i)) };
let t0 = A::simd_convert(s0);
unsafe { t0.store_simd(pto.add(i)) };
i += width;
}
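    // Tail: `0 < len - i < width` elements remain. Because `len >= width` was
    // established above, we can back up to `len - width` and emit one final
    // load/store pair that overlaps the already-converted elements.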
if i != len {
        // Equivalent to `i - (width - len % width)`, since `i == len - len % width`.
        let offset = len - width;
let s0 = unsafe { A::WideFrom::load_simd(arch, pfrom.add(offset)) };
let t0 = A::simd_convert(s0);
unsafe { t0.store_simd(pto.add(offset)) };
}
}
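// x86-64 kernels. Both V3 and V4 operate on 8-lane vectors; the f16 <-> f32
// vector conversions come from `diskann_wide` (presumably F16C-based).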
#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V4 {
type Width = Const<8>;
type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;
type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;
#[inline(always)]
fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
from.into()
}
#[inline(always)]
unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
for i in 0..len {
*pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
}
}
}
#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f32, f16> for diskann_wide::arch::x86_64::V3 {
type Width = Const<8>;
type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;
type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;
#[inline(always)]
fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
from.into()
}
#[inline(always)]
unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
for i in 0..len {
*pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
}
}
}
#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V4 {
type Width = Const<8>;
type WideTo = <diskann_wide::arch::x86_64::V4 as Architecture>::f16x8;
type WideFrom = <diskann_wide::arch::x86_64::V4 as Architecture>::f32x8;
#[inline(always)]
fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
from.simd_cast()
}
#[inline(always)]
unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
for i in 0..len {
*pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
}
}
}
#[cfg(target_arch = "x86_64")]
impl SIMDConvert<f16, f32> for diskann_wide::arch::x86_64::V3 {
type Width = Const<8>;
type WideTo = <diskann_wide::arch::x86_64::V3 as Architecture>::f16x8;
type WideFrom = <diskann_wide::arch::x86_64::V3 as Architecture>::f32x8;
#[inline(always)]
fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
from.simd_cast()
}
#[inline(always)]
unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
for i in 0..len {
*pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
}
}
}
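// AArch64 NEON kernels: 4 lanes, using the 64-bit `f16x4` vector type.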
#[cfg(target_arch = "aarch64")]
impl SIMDConvert<f32, f16> for diskann_wide::arch::aarch64::Neon {
type Width = Const<4>;
type WideTo = <diskann_wide::arch::aarch64::Neon as Architecture>::f32x4;
type WideFrom = diskann_wide::arch::aarch64::f16x4;
#[inline(always)]
fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
from.into()
}
#[inline(always)]
unsafe fn handle_small(self, pto: *mut f32, pfrom: *const f16, len: usize) {
for i in 0..len {
*pto.add(i) = diskann_wide::cast_f16_to_f32(*pfrom.add(i))
}
}
}
#[cfg(target_arch = "aarch64")]
impl SIMDConvert<f16, f32> for diskann_wide::arch::aarch64::Neon {
type Width = Const<4>;
type WideTo = diskann_wide::arch::aarch64::f16x4;
type WideFrom = <diskann_wide::arch::aarch64::Neon as Architecture>::f32x4;
#[inline(always)]
fn simd_convert(from: Self::WideFrom) -> Self::WideTo {
from.simd_cast()
}
#[inline(always)]
unsafe fn handle_small(self, pto: *mut f16, pfrom: *const f32, len: usize) {
for i in 0..len {
*pto.add(i) = diskann_wide::cast_f32_to_f16(*pfrom.add(i))
}
}
}
#[cfg(test)]
mod tests {
use rand::{
distr::{Distribution, StandardUniform},
rngs::StdRng,
SeedableRng,
};
use super::*;
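    /// Scalar ground truth computed directly with the `half` crate.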
trait ReferenceConvert<From> {
fn reference_convert(self, from: &[From]);
}
impl ReferenceConvert<f32> for &mut [f16] {
fn reference_convert(self, from: &[f32]) {
assert_eq!(self.len(), from.len());
std::iter::zip(self.iter_mut(), from.iter()).for_each(|(d, s)| *d = f16::from_f32(*s));
}
}
impl ReferenceConvert<f16> for &mut [f32] {
fn reference_convert(self, from: &[f16]) {
assert_eq!(self.len(), from.len());
std::iter::zip(self.iter_mut(), from.iter()).for_each(|(d, s)| *d = (*s).into());
}
}
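    /// For every length in `0..=max_dim`, fill `src` with random data and
    /// check the dispatched cast against the scalar reference.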
fn test_cast_from_slice<To, From>(max_dim: usize, num_trials: usize, rng: &mut StdRng)
where
StandardUniform: Distribution<From>,
To: Default + PartialEq + std::fmt::Debug + Copy,
From: Default + Copy,
for<'a, 'b> &'a mut [To]: CastFromSlice<&'b [From]> + ReferenceConvert<From>,
{
        let distribution = StandardUniform;
for dim in 0..=max_dim {
let mut src = vec![From::default(); dim];
let mut dst = vec![To::default(); dim];
let mut dst_reference = vec![To::default(); dim];
for _ in 0..num_trials {
src.iter_mut().for_each(|s| *s = distribution.sample(rng));
dst.cast_from_slice(src.as_slice());
dst_reference.reference_convert(&src);
assert_eq!(dst, dst_reference);
}
}
}
#[test]
fn test_f32_to_f16_fuzz() {
let mut rng = StdRng::seed_from_u64(0x0a3bfe052a8ebf98);
test_cast_from_slice::<f16, f32>(256, 10, &mut rng);
}
#[test]
fn test_f16_to_f32_fuzz() {
let mut rng = StdRng::seed_from_u64(0x83765b2816321eca);
test_cast_from_slice::<f32, f16>(256, 10, &mut rng);
}
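    // The miri tests exercise every dispatch target across all lengths in
    // 0..256 to catch out-of-bounds accesses in the overlapping-tail path.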
#[test]
fn miri_test_f32_to_f16() {
for dim in 0..256 {
println!("processing dim {}", dim);
let src = vec![f32::default(); dim];
let mut dst = vec![f16::default(); dim];
dst.cast_from_slice(&src);
SliceCast::<f16, f32>::new().run(
diskann_wide::arch::Scalar,
dst.as_mut_slice(),
src.as_slice(),
);
#[cfg(target_arch = "x86_64")]
{
if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
}
if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
}
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
SliceCast::<f16, f32>::new().run(arch, dst.as_mut_slice(), src.as_slice())
}
}
}
#[test]
fn miri_test_f16_to_f32() {
for dim in 0..256 {
println!("processing dim {}", dim);
let src = vec![f16::default(); dim];
let mut dst = vec![f32::default(); dim];
dst.cast_from_slice(&src);
SliceCast::<f32, f16>::new().run(
diskann_wide::arch::Scalar,
dst.as_mut_slice(),
src.as_slice(),
);
#[cfg(target_arch = "x86_64")]
{
if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
}
if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
}
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
SliceCast::<f32, f16>::new().run(arch, dst.as_mut_slice(), src.as_slice())
}
}
}
}