use super::{bf16, f16, slice::HalfFloatSliceExt};
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::mem;
/// Extra methods for `Vec<f16>` and `Vec<bf16>`.
///
/// The supertrait lives in a private module, so downstream crates cannot
/// implement this trait themselves.
pub trait HalfFloatVecExt: private::SealedHalfFloatVec {
    /// Consumes the vector and returns its elements' raw bit patterns as a
    /// `Vec<u16>`, reusing the existing allocation (no copy is made).
    #[must_use]
    fn reinterpret_into(self) -> Vec<u16>;

    /// Builds a new vector by converting every `f32` in `slice`.
    #[must_use]
    fn from_f32_slice(slice: &[f32]) -> Self;

    /// Builds a new vector by converting every `f64` in `slice`.
    #[must_use]
    fn from_f64_slice(slice: &[f64]) -> Self;
}
/// Extra methods for `Vec<u16>` holding raw half-precision bit patterns.
///
/// Sealed via a private supertrait, so it cannot be implemented downstream.
pub trait HalfBitsVecExt: private::SealedHalfBitsVec {
    /// Consumes the vector and reuses its allocation as a vector of the
    /// half-precision type `H`, treating each `u16` as `H`'s bit pattern.
    #[must_use]
    fn reinterpret_into<H: crate::private::SealedHalf>(self) -> Vec<H>;
}
mod private {
    #[cfg(feature = "alloc")]
    use alloc::vec::Vec;
    use crate::{bf16, f16};

    /// Sealing supertrait of `HalfFloatVecExt`; because this module is
    /// private, code outside the crate cannot name or implement it.
    pub trait SealedHalfFloatVec {}

    /// Sealing supertrait of `HalfBitsVecExt`.
    pub trait SealedHalfBitsVec {}

    impl SealedHalfFloatVec for Vec<f16> {}
    impl SealedHalfFloatVec for Vec<bf16> {}

    impl SealedHalfBitsVec for Vec<u16> {}
}
impl HalfFloatVecExt for Vec<f16> {
    /// Reuses this vector's allocation as a `Vec<u16>` of raw bit patterns.
    /// Length and capacity are preserved; nothing is copied.
    #[inline]
    fn reinterpret_into(self) -> Vec<u16> {
        // Disable the destructor up front (the pattern recommended by
        // `Vec::from_raw_parts`) so the allocation is owned solely by the
        // vector rebuilt below.
        let mut this = mem::ManuallyDrop::new(self);
        let length = this.len();
        let capacity = this.capacity();
        let pointer = this.as_mut_ptr() as *mut u16;
        // SAFETY: `pointer`, `length` and `capacity` come from a live `Vec`
        // whose destructor will never run, so ownership transfers exactly once.
        // NOTE(review): also relies on `f16` having `u16`'s size and alignment
        // (presumably a `#[repr(transparent)]` wrapper) — confirm.
        unsafe { Vec::from_raw_parts(pointer, length, capacity) }
    }

    /// Converts every `f32` in `slice` into a newly allocated `Vec<f16>`.
    fn from_f32_slice(slice: &[f32]) -> Self {
        // Zero-fill before converting. The previous `set_len` on an
        // uninitialized buffer, followed by borrowing it as `&mut [f16]`,
        // violated `Vec::set_len`'s contract (undefined behavior).
        let mut vec = Vec::new();
        vec.resize(slice.len(), f16::ZERO);
        vec.convert_from_f32_slice(slice);
        vec
    }

    /// Converts every `f64` in `slice` into a newly allocated `Vec<f16>`.
    fn from_f64_slice(slice: &[f64]) -> Self {
        // Zero-fill first; see `from_f32_slice` for why `set_len` is not used.
        let mut vec = Vec::new();
        vec.resize(slice.len(), f16::ZERO);
        vec.convert_from_f64_slice(slice);
        vec
    }
}
impl HalfFloatVecExt for Vec<bf16> {
#[inline]
fn reinterpret_into(mut self) -> Vec<u16> {
let length = self.len();
let capacity = self.capacity();
let pointer = self.as_mut_ptr() as *mut u16;
mem::forget(self);
unsafe { Vec::from_raw_parts(pointer, length, capacity) }
}
#[allow(clippy::uninit_vec)]
fn from_f32_slice(slice: &[f32]) -> Self {
let mut vec = Vec::with_capacity(slice.len());
unsafe { vec.set_len(slice.len()) };
vec.convert_from_f32_slice(slice);
vec
}
#[allow(clippy::uninit_vec)]
fn from_f64_slice(slice: &[f64]) -> Self {
let mut vec = Vec::with_capacity(slice.len());
unsafe { vec.set_len(slice.len()) };
vec.convert_from_f64_slice(slice);
vec
}
}
impl HalfBitsVecExt for Vec<u16> {
    /// Hands this vector's buffer over to a `Vec<H>`, treating each `u16` as
    /// the bit pattern of an `H`. Length and capacity carry over unchanged.
    #[inline]
    fn reinterpret_into<H>(self) -> Vec<H>
    where
        H: crate::private::SealedHalf,
    {
        // Wrapping in `ManuallyDrop` guarantees the original `Vec` never frees
        // the buffer that the returned vector now owns.
        let mut bits = mem::ManuallyDrop::new(self);
        let (len, cap) = (bits.len(), bits.capacity());
        let raw = bits.as_mut_ptr().cast::<H>();
        // SAFETY: `raw`, `len` and `cap` describe the allocation of a `Vec`
        // whose destructor is suppressed. NOTE(review): soundness also needs
        // every `SealedHalf` type to share `u16`'s size and alignment
        // (presumably only the crate's transparent `f16`/`bf16`) — confirm.
        unsafe { Vec::from_raw_parts(raw, len, cap) }
    }
}
#[cfg(test)]
mod test {
    use super::{HalfBitsVecExt, HalfFloatVecExt};
    use crate::{bf16, f16};
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    #[test]
    fn test_vec_conversions_f16() {
        let numbers = vec![f16::E, f16::PI, f16::EPSILON, f16::FRAC_1_SQRT_2];
        let bits = vec![
            f16::E.to_bits(),
            f16::PI.to_bits(),
            f16::EPSILON.to_bits(),
            f16::FRAC_1_SQRT_2.to_bits(),
        ];
        let bits_cloned = bits.clone();
        let from_bits = bits.reinterpret_into::<f16>();
        assert_eq!(&from_bits[..], &numbers[..]);
        let to_bits = from_bits.reinterpret_into();
        assert_eq!(&to_bits[..], &bits_cloned[..]);
    }

    #[test]
    fn test_vec_conversions_bf16() {
        let numbers = vec![bf16::E, bf16::PI, bf16::EPSILON, bf16::FRAC_1_SQRT_2];
        let bits = vec![
            bf16::E.to_bits(),
            bf16::PI.to_bits(),
            bf16::EPSILON.to_bits(),
            bf16::FRAC_1_SQRT_2.to_bits(),
        ];
        let bits_cloned = bits.clone();
        let from_bits = bits.reinterpret_into::<bf16>();
        assert_eq!(&from_bits[..], &numbers[..]);
        let to_bits = from_bits.reinterpret_into();
        assert_eq!(&to_bits[..], &bits_cloned[..]);
    }

    // The inputs below are exactly representable in both f16 and bf16, so the
    // bulk conversions must agree bit-for-bit with the scalar conversions.
    #[test]
    fn test_from_float_slices_f16() {
        let singles = [1.0f32, -2.5, 0.0, 0.5, 0.125, 1024.0];
        let doubles = [1.0f64, -2.5, 0.0, 0.5, 0.125, 1024.0];

        let from_singles = Vec::<f16>::from_f32_slice(&singles);
        assert!(from_singles
            .iter()
            .copied()
            .eq(singles.iter().map(|&x| f16::from_f32(x))));

        let from_doubles = Vec::<f16>::from_f64_slice(&doubles);
        assert!(from_doubles
            .iter()
            .copied()
            .eq(doubles.iter().map(|&x| f16::from_f64(x))));

        assert!(Vec::<f16>::from_f32_slice(&[]).is_empty());
        assert!(Vec::<f16>::from_f64_slice(&[]).is_empty());
    }

    #[test]
    fn test_from_float_slices_bf16() {
        let singles = [1.0f32, -2.5, 0.0, 0.5, 0.125, 1024.0];
        let doubles = [1.0f64, -2.5, 0.0, 0.5, 0.125, 1024.0];

        let from_singles = Vec::<bf16>::from_f32_slice(&singles);
        assert!(from_singles
            .iter()
            .copied()
            .eq(singles.iter().map(|&x| bf16::from_f32(x))));

        let from_doubles = Vec::<bf16>::from_f64_slice(&doubles);
        assert!(from_doubles
            .iter()
            .copied()
            .eq(doubles.iter().map(|&x| bf16::from_f64(x))));

        assert!(Vec::<bf16>::from_f32_slice(&[]).is_empty());
        assert!(Vec::<bf16>::from_f64_slice(&[]).is_empty());
    }
}