use std::ops::Deref;
use diskann_vector::conversion::CastFromSlice;
use half::f16;
/// Exposes a slice of numeric elements as a sequence of `f32` values.
///
/// Each implementation picks its own [`Returns`](ConvertF32::Returns) container. It can hand
/// back a borrow of its own storage when no conversion is needed, or build an owned buffer
/// when it is. Either way, the caller gets something that dereferences to `[f32]`.
pub(crate) trait ConvertF32 {
    /// Container holding the `f32` view of `self`; it may borrow from `self` for `'a`.
    type Returns<'a>: Deref<Target = [f32]>
    where
        Self: 'a;

    /// Produces the `f32` representation of `self`.
    fn convert_f32<'s>(&'s self) -> Self::Returns<'s>;
}
/// `f32` slices are already in the target representation, so "conversion" is free.
/// The slice itself is returned by reference and nothing is allocated.
impl ConvertF32 for [f32] {
    type Returns<'a>
        = &'a [f32]
    where
        Self: 'a;

    fn convert_f32(&self) -> Self::Returns<'_> {
        // Reborrow the full range: same data pointer, same length, no copy.
        &self[..]
    }
}
/// Implements [`ConvertF32`] for `[$T]`.
///
/// Each element is widened to `f32` through `$T`'s `Into<f32>` conversion, and the results
/// are collected into a freshly allocated `Vec<f32>`.
macro_rules! bulk_conversion_allocating {
    ($T:ty) => {
        impl ConvertF32 for [$T] {
            type Returns<'a>
                = Vec<f32>
            where
                Self: 'a;

            fn convert_f32(&self) -> Self::Returns<'_> {
                self.iter()
                    .copied()
                    .map(<$T as Into<f32>>::into)
                    .collect()
            }
        }
    };
}

// Every 8-bit integer is exactly representable as an `f32`, so these widenings are lossless.
bulk_conversion_allocating!(i8);
bulk_conversion_allocating!(u8);
impl ConvertF32 for [f16] {
type Returns<'a> = Vec<f32>;
fn convert_f32(&self) -> Self::Returns<'_> {
let mut output = vec![f32::default(); self.len()];
output.cast_from_slice(self);
output
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn f32_is_identity() {
        let v: Vec<f32> = vec![1.0, 2.0, 3.0];
        let ptr = v.as_ptr();
        let u = v.convert_f32();
        assert_eq!(u, v);
        assert_eq!(
            u.as_ptr(),
            ptr,
            "conversion should be the identity and return the underlying span"
        );
    }

    /// Asserts that `input.convert_f32()` yields exactly one `f32` per element.
    /// Each output value must equal that element's own `Into<f32>` conversion.
    fn test_allocating<T>(input: &[T])
    where
        T: Into<f32> + Copy,
        [T]: ConvertF32,
    {
        let converted = input.convert_f32();
        assert_eq!(converted.len(), input.len());
        for (i, (c, n)) in std::iter::zip(converted.iter(), input.iter()).enumerate() {
            assert_eq!(
                *c,
                <T as Into<f32>>::into(*n),
                "conversion failed for input {i}"
            );
        }
    }

    #[test]
    fn test_i8() {
        // Negatives and both extremes would expose a signed/unsigned mix-up.
        let v: Vec<i8> = vec![i8::MIN, -1, 0, 1, 2, 3, i8::MAX];
        test_allocating(&v);
    }

    #[test]
    fn test_u8() {
        // Values >= 128 would come out negative if the bytes were misread as `i8`.
        let v: Vec<u8> = vec![0, 1, 2, 3, 128, u8::MAX];
        test_allocating(&v);
    }

    #[test]
    fn test_f16() {
        let v: Vec<f16> = vec![
            f16::from_f32(1.0),
            f16::from_f32(2.0),
            f16::from_f32(3.0),
            f16::from_f32(-0.5),
            f16::from_f32(0.0),
            f16::MIN,
            f16::MAX,
        ];
        test_allocating(&v);
    }

    #[test]
    fn empty_inputs() {
        // Every implementation should handle an empty slice and produce an empty result.
        test_allocating::<i8>(&[]);
        test_allocating::<u8>(&[]);
        test_allocating::<f16>(&[]);

        let empty: &[f32] = &[];
        assert!(empty.convert_f32().is_empty());
    }
}