use half::{bf16, f16};
use nodedb_types::vector_dtype::VectorStorageDtype;
/// Errors reported while validating or decoding raw vector bytes against a
/// declared [`VectorStorageDtype`].
#[derive(thiserror::Error, Debug)]
pub enum DtypeError {
/// The supplied byte slice does not have the length that the dtype requires
/// for the requested dimension.
#[error(
"dtype byte-length mismatch for {dtype}: expected {expected} bytes for dim {dim}, got {actual}"
)]
BadByteLen {
/// Storage dtype the bytes were checked against.
dtype: VectorStorageDtype,
/// Number of vector elements the caller asked for.
dim: usize,
/// Byte count computed by `dtype.bytes_for_dim(dim)`.
expected: usize,
/// Length of the byte slice that was actually supplied.
actual: usize,
},
}
/// Checks that `bytes` has exactly the length `dtype` requires for a vector of
/// `dim` elements.
///
/// # Errors
///
/// Returns [`DtypeError::BadByteLen`] carrying the dtype, the dimension, the
/// expected byte count (`dtype.bytes_for_dim(dim)`) and the actual slice length
/// when the two lengths differ.
pub fn validate_byte_len(
    bytes: &[u8],
    dtype: VectorStorageDtype,
    dim: usize,
) -> Result<(), DtypeError> {
    let expected = dtype.bytes_for_dim(dim);
    let actual = bytes.len();

    if actual == expected {
        Ok(())
    } else {
        Err(DtypeError::BadByteLen {
            dtype,
            dim,
            expected,
            actual,
        })
    }
}
/// Decodes a little-endian byte buffer stored as `dtype` into a vector of `f32`.
///
/// The buffer length is validated against `dim` first; each fixed-width chunk
/// (4 bytes for `F32`, 2 bytes for `F16` / `BF16`) is then decoded and, for the
/// half-precision formats, widened to `f32`.
///
/// # Errors
///
/// Returns [`DtypeError::BadByteLen`] when `src.len()` does not match the byte
/// length `dtype` requires for `dim` elements.
///
/// # Panics
///
/// Panics if `dtype` is a variant other than `F32`, `F16` or `BF16` and the
/// length validation passed.
pub fn cast_to_f32(
    src: &[u8],
    dtype: VectorStorageDtype,
    dim: usize,
) -> Result<Vec<f32>, DtypeError> {
    validate_byte_len(src, dtype, dim)?;

    let values: Vec<f32> = match dtype {
        VectorStorageDtype::F32 => src
            .chunks_exact(4)
            .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
            .collect(),
        VectorStorageDtype::F16 => src
            .chunks_exact(2)
            .map(|pair| f16::from_le_bytes([pair[0], pair[1]]).to_f32())
            .collect(),
        VectorStorageDtype::BF16 => src
            .chunks_exact(2)
            .map(|pair| bf16::from_le_bytes([pair[0], pair[1]]).to_f32())
            .collect(),
        _ => unreachable!("unrecognised VectorStorageDtype variant in cast_to_f32"),
    };

    Ok(values)
}
/// Encodes `src` as a little-endian byte buffer in the storage format `dtype`.
///
/// `F32` values are written as-is (4 bytes each). For `F16` and `BF16` each
/// value is first converted with `f16::from_f32` / `bf16::from_f32` and then
/// written as 2 bytes.
///
/// # Panics
///
/// Panics if `dtype` is a variant other than `F32`, `F16` or `BF16`.
pub fn cast_from_f32(src: &[f32], dtype: VectorStorageDtype) -> Vec<u8> {
    match dtype {
        VectorStorageDtype::F32 => {
            let mut encoded = Vec::with_capacity(src.len() * 4);
            encoded.extend(src.iter().flat_map(|value| value.to_le_bytes()));
            encoded
        }
        VectorStorageDtype::F16 => {
            let mut encoded = Vec::with_capacity(src.len() * 2);
            encoded.extend(
                src.iter()
                    .flat_map(|&value| f16::from_f32(value).to_le_bytes()),
            );
            encoded
        }
        VectorStorageDtype::BF16 => {
            let mut encoded = Vec::with_capacity(src.len() * 2);
            encoded.extend(
                src.iter()
                    .flat_map(|&value| bf16::from_f32(value).to_le_bytes()),
            );
            encoded
        }
        _ => unreachable!("unrecognised VectorStorageDtype variant in cast_from_f32"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding to F32 and decoding again must return the exact input values.
    #[test]
    fn f32_round_trip_identity() {
        let input = [1.0_f32, 2.0, 3.0];
        let encoded = cast_from_f32(&input, VectorStorageDtype::F32);
        let decoded = cast_to_f32(&encoded, VectorStorageDtype::F32, 3).unwrap();
        assert_eq!(decoded, vec![1.0_f32, 2.0, 3.0]);
    }

    /// An empty F32 vector encodes to no bytes and decodes to no values.
    #[test]
    fn f32_empty_round_trip() {
        let encoded = cast_from_f32(&[], VectorStorageDtype::F32);
        assert!(encoded.is_empty());

        let decoded = cast_to_f32(&[], VectorStorageDtype::F32, 0).unwrap();
        assert!(decoded.is_empty());
    }

    /// F16 round-trips the sample values to within 1e-3.
    #[test]
    fn f16_round_trip_within_tolerance() {
        let src = [0.5_f32, 1.0, 2.5, 100.0];
        let encoded = cast_from_f32(&src, VectorStorageDtype::F16);
        let decoded = cast_to_f32(&encoded, VectorStorageDtype::F16, 4).unwrap();

        for (orig, recovered) in src.iter().zip(&decoded) {
            assert!(
                (orig - recovered).abs() < 1e-3,
                "F16 round-trip: {orig} → {recovered}, diff too large"
            );
        }
    }

    /// An empty F16 vector encodes to no bytes and decodes to no values.
    #[test]
    fn f16_empty_round_trip() {
        let encoded = cast_from_f32(&[], VectorStorageDtype::F16);
        assert!(encoded.is_empty());

        let decoded = cast_to_f32(&[], VectorStorageDtype::F16, 0).unwrap();
        assert!(decoded.is_empty());
    }

    /// BF16 round-trips the sample values to within 1e-2.
    #[test]
    fn bf16_round_trip_within_tolerance() {
        let src = [0.5_f32, 1.0, 2.5, 100.0];
        let encoded = cast_from_f32(&src, VectorStorageDtype::BF16);
        let decoded = cast_to_f32(&encoded, VectorStorageDtype::BF16, 4).unwrap();

        for (orig, recovered) in src.iter().zip(&decoded) {
            assert!(
                (orig - recovered).abs() < 1e-2,
                "BF16 round-trip: {orig} → {recovered}, diff too large"
            );
        }
    }

    /// An empty BF16 vector encodes to no bytes and decodes to no values.
    #[test]
    fn bf16_empty_round_trip() {
        let encoded = cast_from_f32(&[], VectorStorageDtype::BF16);
        assert!(encoded.is_empty());

        let decoded = cast_to_f32(&[], VectorStorageDtype::BF16, 0).unwrap();
        assert!(decoded.is_empty());
    }

    /// 1e30 stays finite through BF16 but becomes infinite through F16.
    #[test]
    fn bf16_can_represent_large_values_f16_cannot() {
        let large = [1.0e30_f32];
        // Encode then decode the single large value through the given dtype.
        let round_trip = |dtype: VectorStorageDtype| {
            let encoded = cast_from_f32(&large, dtype);
            cast_to_f32(&encoded, dtype, 1).unwrap()
        };

        let bf16_back = round_trip(VectorStorageDtype::BF16);
        assert!(
            bf16_back[0].is_finite(),
            "BF16 should represent 1e30 as finite"
        );

        let f16_back = round_trip(VectorStorageDtype::F16);
        assert!(
            f16_back[0].is_infinite(),
            "F16 should overflow 1e30 to infinity"
        );
    }

    /// 7 bytes cannot hold two F32 elements (8 bytes expected).
    #[test]
    fn bad_byte_len_f32_mismatch() {
        let err = cast_to_f32(&[0u8; 7], VectorStorageDtype::F32, 2).unwrap_err();
        let DtypeError::BadByteLen {
            dtype,
            dim,
            expected,
            actual,
        } = err;

        assert_eq!(dtype, VectorStorageDtype::F32);
        assert_eq!(dim, 2);
        assert_eq!(expected, 8);
        assert_eq!(actual, 7);
    }

    /// 3 bytes cannot hold two F16 elements (4 bytes expected).
    #[test]
    fn bad_byte_len_f16_odd_byte_count() {
        let err = cast_to_f32(&[0u8; 3], VectorStorageDtype::F16, 2).unwrap_err();
        let DtypeError::BadByteLen {
            dtype,
            dim,
            expected,
            actual,
        } = err;

        assert_eq!(dtype, VectorStorageDtype::F16);
        assert_eq!(dim, 2);
        assert_eq!(expected, 4);
        assert_eq!(actual, 3);
    }

    /// 5 bytes cannot hold three BF16 elements (6 bytes expected).
    #[test]
    fn bad_byte_len_bf16_mismatch() {
        let err = cast_to_f32(&[0u8; 5], VectorStorageDtype::BF16, 3).unwrap_err();
        let DtypeError::BadByteLen {
            dtype,
            dim,
            expected,
            actual,
        } = err;

        assert_eq!(dtype, VectorStorageDtype::BF16);
        assert_eq!(dim, 3);
        assert_eq!(expected, 6);
        assert_eq!(actual, 5);
    }

    /// 12 bytes is the exact length for three F32 elements.
    #[test]
    fn validate_byte_len_correct_passes() {
        let outcome = validate_byte_len(&[0u8; 12], VectorStorageDtype::F32, 3);
        assert!(outcome.is_ok());
    }

    /// One byte short of the required 12 must be rejected.
    #[test]
    fn validate_byte_len_off_by_one_fails() {
        let err = validate_byte_len(&[0u8; 11], VectorStorageDtype::F32, 3).unwrap_err();
        let DtypeError::BadByteLen {
            expected, actual, ..
        } = err;

        assert_eq!(expected, 12);
        assert_eq!(actual, 11);
    }
}