use core::slice;
use crate::{bf16, binary16::arch, f16};
/// Extension methods for slices of half-precision floats (`[f16]` and `[bf16]`).
///
/// The trait is sealed through a private supertrait, so it cannot be implemented outside
/// this module.
pub trait HalfFloatSliceExt: private::SealedHalfFloatSlice {
    /// Views the slice as the raw `u16` bit patterns of its elements, without copying.
    #[must_use]
    fn reinterpret_cast(&self) -> &[u16];

    /// Mutable counterpart of [`reinterpret_cast`](Self::reinterpret_cast): writing a bit
    /// pattern through the returned slice changes the corresponding half-precision value.
    #[must_use]
    fn reinterpret_cast_mut(&mut self) -> &mut [u16];

    /// Overwrites `self` with the values of `src`, converted element-wise from `f32`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `src` have different lengths.
    fn convert_from_f32_slice(&mut self, src: &[f32]);

    /// Writes every element of `self`, converted to `f32`, into the matching slot of `dst`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `dst` have different lengths.
    fn convert_to_f32_slice(&self, dst: &mut [f32]);

    /// Overwrites `self` with the values of `src`, converted element-wise from `f64`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `src` have different lengths.
    fn convert_from_f64_slice(&mut self, src: &[f64]);

    /// Writes every element of `self`, converted to `f64`, into the matching slot of `dst`.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `dst` have different lengths.
    fn convert_to_f64_slice(&self, dst: &mut [f64]);
}
/// Extension methods for slices of raw `u16` bit patterns.
///
/// Sealed through a private supertrait; within this module only `[u16]` implements it.
pub trait HalfBitsSliceExt: private::SealedHalfBitsSlice {
    /// Views the bit patterns as a slice of the half-precision type `H` (for example `f16`
    /// or `bf16`), without copying.
    #[must_use]
    fn reinterpret_cast<H: crate::private::SealedHalf>(&self) -> &[H];

    /// Mutable counterpart of [`reinterpret_cast`](Self::reinterpret_cast): assigning a
    /// value through the returned slice updates the underlying `u16` bits.
    #[must_use]
    fn reinterpret_cast_mut<H: crate::private::SealedHalf>(&mut self) -> &mut [H];
}
mod private {
    use crate::{bf16, f16};

    /// Sealing supertrait for `HalfFloatSliceExt`. This module is private, so no code
    /// outside the parent module can name this trait or add implementations of it.
    pub trait SealedHalfFloatSlice {}

    impl SealedHalfFloatSlice for [f16] {}
    impl SealedHalfFloatSlice for [bf16] {}

    /// Sealing supertrait for `HalfBitsSliceExt`; `[u16]` is its only implementor.
    pub trait SealedHalfBitsSlice {}

    impl SealedHalfBitsSlice for [u16] {}
}
/// `f16` slices: bit reinterpretation is a zero-copy view; bulk conversions delegate to the
/// `binary16::arch` backend after validating slice lengths.
impl HalfFloatSliceExt for [f16] {
    #[inline]
    fn reinterpret_cast(&self) -> &[u16] {
        // `.cast()` (rather than an `as` pointer cast) matches `reinterpret_cast_mut` and can
        // only change the pointee type, never the pointer's mutability.
        let pointer = self.as_ptr().cast::<u16>();
        let length = self.len();
        // SAFETY: `f16` is assumed to be a `#[repr(transparent)]` wrapper around `u16`
        // (declared outside this file — confirm), so both types share size and alignment and
        // any `f16` bit pattern is a valid `u16`. The new slice covers exactly the same
        // `length` elements and borrows `self`, so it cannot outlive the original data.
        unsafe { slice::from_raw_parts(pointer, length) }
    }

    #[inline]
    fn reinterpret_cast_mut(&mut self) -> &mut [u16] {
        let pointer = self.as_mut_ptr().cast::<u16>();
        let length = self.len();
        // SAFETY: same layout argument as `reinterpret_cast`. In addition, every `u16` value
        // written through the result is assumed to be a valid `f16` bit pattern. The
        // exclusive borrow of `self` moves into the returned slice, so nothing else can
        // observe the memory while it is alive.
        unsafe { slice::from_raw_parts_mut(pointer, length) }
    }

    #[inline]
    fn convert_from_f32_slice(&mut self, src: &[f32]) {
        // Validated here so every implementation reports a mismatch with the same message,
        // independent of how the `arch` backend treats its inputs.
        assert_eq!(self.len(), src.len(), "destination and source slices have different lengths");
        arch::f32_to_f16_slice(src, self.reinterpret_cast_mut())
    }

    #[inline]
    fn convert_from_f64_slice(&mut self, src: &[f64]) {
        assert_eq!(self.len(), src.len(), "destination and source slices have different lengths");
        arch::f64_to_f16_slice(src, self.reinterpret_cast_mut())
    }

    #[inline]
    fn convert_to_f32_slice(&self, dst: &mut [f32]) {
        assert_eq!(self.len(), dst.len(), "destination and source slices have different lengths");
        arch::f16_to_f32_slice(self.reinterpret_cast(), dst)
    }

    #[inline]
    fn convert_to_f64_slice(&self, dst: &mut [f64]) {
        assert_eq!(self.len(), dst.len(), "destination and source slices have different lengths");
        arch::f16_to_f64_slice(self.reinterpret_cast(), dst)
    }
}
/// `bf16` slices: bit reinterpretation is a zero-copy view; bulk conversions apply the
/// scalar `bf16` conversion to each element after validating slice lengths.
impl HalfFloatSliceExt for [bf16] {
    #[inline]
    fn reinterpret_cast(&self) -> &[u16] {
        // `.cast()` (rather than an `as` pointer cast) matches `reinterpret_cast_mut` and can
        // only change the pointee type, never the pointer's mutability.
        let pointer = self.as_ptr().cast::<u16>();
        let length = self.len();
        // SAFETY: `bf16` is assumed to be a `#[repr(transparent)]` wrapper around `u16`
        // (declared outside this file — confirm), so both types share size and alignment and
        // any `bf16` bit pattern is a valid `u16`. The new slice covers exactly the same
        // `length` elements and borrows `self`, so it cannot outlive the original data.
        unsafe { slice::from_raw_parts(pointer, length) }
    }

    #[inline]
    fn reinterpret_cast_mut(&mut self) -> &mut [u16] {
        let pointer = self.as_mut_ptr().cast::<u16>();
        let length = self.len();
        // SAFETY: same layout argument as `reinterpret_cast`. In addition, every `u16` value
        // written through the result is assumed to be a valid `bf16` bit pattern. The
        // exclusive borrow of `self` moves into the returned slice, so nothing else can
        // observe the memory while it is alive.
        unsafe { slice::from_raw_parts_mut(pointer, length) }
    }

    #[inline]
    fn convert_from_f32_slice(&mut self, src: &[f32]) {
        assert_eq!(self.len(), src.len(), "destination and source slices have different lengths");
        // With equal lengths, `zip` visits every pair without per-index bounds checks.
        for (half, &value) in self.iter_mut().zip(src) {
            *half = bf16::from_f32(value);
        }
    }

    #[inline]
    fn convert_from_f64_slice(&mut self, src: &[f64]) {
        assert_eq!(self.len(), src.len(), "destination and source slices have different lengths");
        for (half, &value) in self.iter_mut().zip(src) {
            *half = bf16::from_f64(value);
        }
    }

    #[inline]
    fn convert_to_f32_slice(&self, dst: &mut [f32]) {
        assert_eq!(self.len(), dst.len(), "destination and source slices have different lengths");
        for (out, half) in dst.iter_mut().zip(self.iter()) {
            *out = half.to_f32();
        }
    }

    #[inline]
    fn convert_to_f64_slice(&self, dst: &mut [f64]) {
        assert_eq!(self.len(), dst.len(), "destination and source slices have different lengths");
        for (out, half) in dst.iter_mut().zip(self.iter()) {
            *out = half.to_f64();
        }
    }
}
/// `u16` bit-pattern slices: zero-copy reinterpretation as any sealed half-precision type.
impl HalfBitsSliceExt for [u16] {
    #[inline]
    fn reinterpret_cast<H>(&self) -> &[H]
    where
        H: crate::private::SealedHalf,
    {
        // The cast below is sound only for 16-bit `H` with alignment no stricter than `u16`.
        // Both conditions are compile-time constants, so these checks fold away for valid
        // `H`. They stop a future `SealedHalf` implementor from causing out-of-bounds reads.
        assert!(
            core::mem::size_of::<H>() == core::mem::size_of::<u16>(),
            "half type must be exactly 16 bits wide"
        );
        assert!(
            core::mem::align_of::<H>() <= core::mem::align_of::<u16>(),
            "half type must not be more strictly aligned than u16"
        );
        let pointer = self.as_ptr().cast::<H>();
        let length = self.len();
        // SAFETY: the assertions above guarantee `H` has the size of `u16` and compatible
        // alignment. `SealedHalf` is assumed to be implemented only for transparent `u16`
        // wrappers that accept every bit pattern (declared outside this file — confirm).
        // The result spans the same `length` elements and borrows `self`.
        unsafe { slice::from_raw_parts(pointer, length) }
    }

    #[inline]
    fn reinterpret_cast_mut<H>(&mut self) -> &mut [H]
    where
        H: crate::private::SealedHalf,
    {
        assert!(
            core::mem::size_of::<H>() == core::mem::size_of::<u16>(),
            "half type must be exactly 16 bits wide"
        );
        assert!(
            core::mem::align_of::<H>() <= core::mem::align_of::<u16>(),
            "half type must not be more strictly aligned than u16"
        );
        let pointer = self.as_mut_ptr().cast::<H>();
        let length = self.len();
        // SAFETY: same layout argument as `reinterpret_cast`. Any `H` written through the
        // result is a valid `u16`, and the exclusive borrow of `self` moves into the returned
        // slice, so no aliasing access exists while it is alive.
        unsafe { slice::from_raw_parts_mut(pointer, length) }
    }
}
#[allow(clippy::float_cmp)]
#[cfg(test)]
mod test {
    use super::{HalfBitsSliceExt, HalfFloatSliceExt};
    use crate::{bf16, f16};

    /// Capacity of the stack buffers used by the round-trip helpers (the longest tested length).
    const MAX_LEN: usize = 9;

    /// Slice lengths exercised by the round-trip helpers. 8 and 9 presumably cover an `arch`
    /// fast path over whole chunks plus a remainder (TODO confirm against `arch`); 2 covers a
    /// short slice.
    const LENGTHS: [usize; 3] = [8, 9, 2];

    /// Small integers, exactly representable in both `f16` and `bf16`. Every conversion in the
    /// round-trip helpers is therefore lossless, and exact float comparison is valid.
    const VALUES_F32: [f32; MAX_LEN] = [1., 2., 3., 4., 5., 6., 7., 8., 9.];

    /// `f64` counterpart of `VALUES_F32`.
    const VALUES_F64: [f64; MAX_LEN] = [1., 2., 3., 4., 5., 6., 7., 8., 9.];

    /// Checks `convert_to_f32_slice` / `convert_from_f32_slice` for every length in `LENGTHS`.
    ///
    /// Output buffers start zero-filled rather than pre-loaded with the expected values, so a
    /// conversion that leaves elements unwritten fails the comparison.
    fn check_round_trip_f32<H>(from_f32: fn(f32) -> H)
    where
        H: Copy + PartialEq + core::fmt::Debug,
        [H]: HalfFloatSliceExt,
    {
        let zero = from_f32(0.);
        for &len in LENGTHS.iter() {
            let src = &VALUES_F32[..len];

            let mut expected = [zero; MAX_LEN];
            for (half, &value) in expected.iter_mut().zip(src) {
                *half = from_f32(value);
            }
            let expected = &expected[..len];

            let mut out_f32 = [0f32; MAX_LEN];
            expected.convert_to_f32_slice(&mut out_f32[..len]);
            assert_eq!(src, &out_f32[..len]);

            let mut out_half = [zero; MAX_LEN];
            out_half[..len].convert_from_f32_slice(src);
            assert_eq!(expected, &out_half[..len]);
        }
    }

    /// `f64` counterpart of `check_round_trip_f32`, with the same zero-filled output buffers.
    fn check_round_trip_f64<H>(from_f64: fn(f64) -> H)
    where
        H: Copy + PartialEq + core::fmt::Debug,
        [H]: HalfFloatSliceExt,
    {
        let zero = from_f64(0.);
        for &len in LENGTHS.iter() {
            let src = &VALUES_F64[..len];

            let mut expected = [zero; MAX_LEN];
            for (half, &value) in expected.iter_mut().zip(src) {
                *half = from_f64(value);
            }
            let expected = &expected[..len];

            let mut out_f64 = [0f64; MAX_LEN];
            expected.convert_to_f64_slice(&mut out_f64[..len]);
            assert_eq!(src, &out_f64[..len]);

            let mut out_half = [zero; MAX_LEN];
            out_half[..len].convert_from_f64_slice(src);
            assert_eq!(expected, &out_half[..len]);
        }
    }

    #[test]
    fn test_slice_conversions_f16() {
        let bits = &[
            f16::E.to_bits(),
            f16::PI.to_bits(),
            f16::EPSILON.to_bits(),
            f16::FRAC_1_SQRT_2.to_bits(),
        ];
        let numbers = &[f16::E, f16::PI, f16::EPSILON, f16::FRAC_1_SQRT_2];
        let from_bits = bits.reinterpret_cast::<f16>();
        assert_eq!(from_bits, numbers);
        let to_bits = from_bits.reinterpret_cast();
        assert_eq!(to_bits, bits);
    }

    #[test]
    fn test_mutability_f16() {
        let mut bits_array = [f16::PI.to_bits()];
        let bits = &mut bits_array[..];
        {
            let numbers = bits.reinterpret_cast_mut();
            numbers[0] = f16::E;
        }
        assert_eq!(bits, &[f16::E.to_bits()]);
        bits[0] = f16::LN_2.to_bits();
        assert_eq!(bits, &[f16::LN_2.to_bits()]);
    }

    #[test]
    fn test_slice_conversions_bf16() {
        let bits = &[
            bf16::E.to_bits(),
            bf16::PI.to_bits(),
            bf16::EPSILON.to_bits(),
            bf16::FRAC_1_SQRT_2.to_bits(),
        ];
        let numbers = &[bf16::E, bf16::PI, bf16::EPSILON, bf16::FRAC_1_SQRT_2];
        let from_bits = bits.reinterpret_cast::<bf16>();
        assert_eq!(from_bits, numbers);
        let to_bits = from_bits.reinterpret_cast();
        assert_eq!(to_bits, bits);
    }

    #[test]
    fn test_mutability_bf16() {
        let mut bits_array = [bf16::PI.to_bits()];
        let bits = &mut bits_array[..];
        {
            let numbers = bits.reinterpret_cast_mut();
            numbers[0] = bf16::E;
        }
        assert_eq!(bits, &[bf16::E.to_bits()]);
        bits[0] = bf16::LN_2.to_bits();
        assert_eq!(bits, &[bf16::LN_2.to_bits()]);
    }

    #[test]
    fn slice_convert_f16_f32() {
        check_round_trip_f32(f16::from_f32);
    }

    #[test]
    fn slice_convert_bf16_f32() {
        check_round_trip_f32(bf16::from_f32);
    }

    #[test]
    fn slice_convert_f16_f64() {
        check_round_trip_f64(f16::from_f64);
    }

    #[test]
    fn slice_convert_bf16_f64() {
        check_round_trip_f64(bf16::from_f64);
    }

    #[test]
    #[should_panic(expected = "destination and source slices have different lengths")]
    fn convert_from_f32_slice_len_mismatch_panics() {
        let mut slice1 = [f16::ZERO; 3];
        let slice2 = [0f32; 4];
        slice1.convert_from_f32_slice(&slice2);
    }

    #[test]
    #[should_panic(expected = "destination and source slices have different lengths")]
    fn convert_from_f64_slice_len_mismatch_panics() {
        let mut slice1 = [f16::ZERO; 3];
        let slice2 = [0f64; 4];
        slice1.convert_from_f64_slice(&slice2);
    }

    #[test]
    #[should_panic(expected = "destination and source slices have different lengths")]
    fn convert_to_f32_slice_len_mismatch_panics() {
        let slice1 = [f16::ZERO; 3];
        let mut slice2 = [0f32; 4];
        slice1.convert_to_f32_slice(&mut slice2);
    }

    #[test]
    #[should_panic(expected = "destination and source slices have different lengths")]
    fn convert_to_f64_slice_len_mismatch_panics() {
        let slice1 = [f16::ZERO; 3];
        let mut slice2 = [0f64; 4];
        slice1.convert_to_f64_slice(&mut slice2);
    }

    #[test]
    #[should_panic(expected = "destination and source slices have different lengths")]
    fn bf16_convert_from_f32_slice_len_mismatch_panics() {
        let mut slice1 = [bf16::from_f32(0.); 3];
        let slice2 = [0f32; 4];
        slice1.convert_from_f32_slice(&slice2);
    }

    #[test]
    #[should_panic(expected = "destination and source slices have different lengths")]
    fn bf16_convert_from_f64_slice_len_mismatch_panics() {
        let mut slice1 = [bf16::from_f32(0.); 3];
        let slice2 = [0f64; 4];
        slice1.convert_from_f64_slice(&slice2);
    }

    #[test]
    #[should_panic(expected = "destination and source slices have different lengths")]
    fn bf16_convert_to_f32_slice_len_mismatch_panics() {
        let slice1 = [bf16::from_f32(0.); 3];
        let mut slice2 = [0f32; 4];
        slice1.convert_to_f32_slice(&mut slice2);
    }

    #[test]
    #[should_panic(expected = "destination and source slices have different lengths")]
    fn bf16_convert_to_f64_slice_len_mismatch_panics() {
        let slice1 = [bf16::from_f32(0.); 3];
        let mut slice2 = [0f64; 4];
        slice1.convert_to_f64_slice(&mut slice2);
    }
}