use std::fmt;
use bytemuck::{Pod, cast_slice};
use half::f16;
/// Failures reported while parsing a vector type name or validating a vector blob.
#[derive(Debug, Clone, PartialEq)]
pub enum VectorTypeError {
    /// The given type name is not one of the recognised names; carries the rejected name.
    UnknownType(String),
    /// The number of elements found (`got`) differs from the number requested (`expected`).
    DimensionMismatch { expected: usize, got: usize },
    /// A floating-point element was NaN or infinite.
    NonFiniteValue,
}

impl fmt::Display for VectorTypeError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Fixed text: no formatting machinery needed.
            Self::NonFiniteValue => f.write_str("vector contains NaN or Inf"),
            Self::UnknownType(type_name) => {
                f.write_fmt(format_args!("unknown vector type: {}", type_name))
            }
            Self::DimensionMismatch {
                expected: wanted,
                got: found,
            } => f.write_fmt(format_args!(
                "expected {} dimensions, got {}",
                wanted, found
            )),
        }
    }
}

// Marker impl: the default `source()` (none) is all this error needs.
impl std::error::Error for VectorTypeError {}
/// Element encoding of a vector stored as a raw byte blob.
///
/// The numeric suffix of each variant is the per-element size in bytes, as
/// returned by `element_size`. Each variant has a lower-case textual name
/// (`"float2"`, `"int4"`, ...) accepted by `from_name` and produced by `name`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VectorType {
/// 2-byte floating-point elements (checked as `half::f16` by `validate_finite`).
Float2,
/// 4-byte floating-point elements (checked as `f32` by `validate_finite`).
Float4,
/// 8-byte floating-point elements (checked as `f64` by `validate_finite`).
Float8,
/// 1-byte integer elements.
Int1,
/// 2-byte integer elements.
Int2,
/// 4-byte integer elements.
Int4,
}
impl VectorType {
    /// Parses a textual type name into a [`VectorType`].
    ///
    /// Matching is exact and case-sensitive; the accepted names are
    /// `"float2"`, `"float4"`, `"float8"`, `"int1"`, `"int2"` and `"int4"`.
    ///
    /// # Errors
    ///
    /// Returns [`VectorTypeError::UnknownType`] carrying `name` when it is not
    /// one of the accepted names.
    pub fn from_name(name: &str) -> Result<Self, VectorTypeError> {
        match name {
            "float2" => Ok(Self::Float2),
            "float4" => Ok(Self::Float4),
            "float8" => Ok(Self::Float8),
            "int1" => Ok(Self::Int1),
            "int2" => Ok(Self::Int2),
            "int4" => Ok(Self::Int4),
            other => Err(VectorTypeError::UnknownType(other.to_string())),
        }
    }

    /// Returns the size in bytes of a single element of this type.
    pub fn element_size(&self) -> usize {
        match self {
            Self::Float2 => 2,
            Self::Float4 => 4,
            Self::Float8 => 8,
            Self::Int1 => 1,
            Self::Int2 => 2,
            Self::Int4 => 4,
        }
    }

    /// Returns the number of bytes needed to store `dim` elements of this type.
    ///
    /// The multiplication saturates at `usize::MAX` instead of overflowing, so
    /// an absurdly large `dim` yields a size no real buffer can have rather
    /// than a panic (debug builds) or a silently wrapped value (release builds).
    pub fn blob_size(&self, dim: usize) -> usize {
        dim.saturating_mul(self.element_size())
    }

    /// Checks that `blob` holds exactly `dim` elements of this type.
    ///
    /// # Errors
    ///
    /// Returns [`VectorTypeError::DimensionMismatch`] when the byte length of
    /// `blob` is not `dim * element_size()`. The reported `got` is the number
    /// of *whole* elements in `blob` (byte length divided by the element size,
    /// rounded down), so a blob with trailing partial bytes can report a `got`
    /// equal to `expected`.
    pub fn validate_blob(&self, blob: &[u8], dim: usize) -> Result<(), VectorTypeError> {
        let element_size = self.element_size();

        // `checked_mul` returns `None` on overflow, which can never equal a
        // real slice length; a wrapped product could otherwise match a short
        // blob and wrongly accept a huge `dim`.
        if dim.checked_mul(element_size) == Some(blob.len()) {
            return Ok(());
        }

        Err(VectorTypeError::DimensionMismatch {
            expected: dim,
            got: blob.len() / element_size,
        })
    }

    /// Checks that `blob` holds exactly `dim` elements and, for floating-point
    /// types, that none of them is NaN or infinite.
    ///
    /// Elements are decoded in native byte order from copies of their bytes,
    /// so `blob` does not need to be aligned for the element type. Integer
    /// types only undergo the length check.
    ///
    /// # Errors
    ///
    /// * [`VectorTypeError::DimensionMismatch`] when the length check of
    ///   [`validate_blob`](Self::validate_blob) fails.
    /// * [`VectorTypeError::NonFiniteValue`] when a floating-point element is
    ///   NaN, `+Inf` or `-Inf`.
    pub fn validate_finite(&self, blob: &[u8], dim: usize) -> Result<(), VectorTypeError> {
        // Copies one `chunks_exact(N)` chunk into an owned array; the chunk is
        // always exactly `N` bytes long, so `copy_from_slice` cannot panic.
        fn to_array<const N: usize>(raw: &[u8]) -> [u8; N] {
            let mut bytes = [0u8; N];
            bytes.copy_from_slice(raw);
            bytes
        }

        self.validate_blob(blob, dim)?;

        // Decode element by element instead of reinterpreting the whole slice:
        // a `&[u8]` carries no alignment guarantee, and a slice cast to a
        // wider type panics on a misaligned buffer.
        let all_finite = match self {
            Self::Float2 => blob
                .chunks_exact(2)
                .all(|raw| f16::from_bits(u16::from_ne_bytes(to_array(raw))).is_finite()),
            Self::Float4 => blob
                .chunks_exact(4)
                .all(|raw| f32::from_ne_bytes(to_array(raw)).is_finite()),
            Self::Float8 => blob
                .chunks_exact(8)
                .all(|raw| f64::from_ne_bytes(to_array(raw)).is_finite()),
            Self::Int1 | Self::Int2 | Self::Int4 => true,
        };

        if all_finite {
            Ok(())
        } else {
            Err(VectorTypeError::NonFiniteValue)
        }
    }

    /// Copies the raw bytes of `values` into a new byte vector.
    ///
    /// The receiver is not consulted: `T` is not checked against this
    /// variant's element type, so choosing a matching `T` is up to the caller.
    pub fn slice_to_blob<T: Pod>(&self, values: &[T]) -> Vec<u8> {
        cast_slice::<T, u8>(values).to_vec()
    }

    /// Reinterprets `blob` as a slice of `T` without copying.
    ///
    /// The receiver is not consulted: `T` is not checked against this
    /// variant's element type.
    ///
    /// # Panics
    ///
    /// Panics (inside `bytemuck::cast_slice`) when `blob` is not aligned for
    /// `T` or its length is not a multiple of `size_of::<T>()`.
    pub fn blob_to_slice<'a, T: Pod>(&self, blob: &'a [u8]) -> &'a [T] {
        cast_slice(blob)
    }

    /// Returns `true` for the floating-point variants and `false` for the
    /// integer variants.
    pub fn is_float(&self) -> bool {
        matches!(self, Self::Float2 | Self::Float4 | Self::Float8)
    }

    /// Returns the textual name of this type; the inverse of
    /// [`from_name`](Self::from_name).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Float2 => "float2",
            Self::Float4 => "float4",
            Self::Float8 => "float8",
            Self::Int1 => "int1",
            Self::Int2 => "int2",
            Self::Int4 => "int4",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytemuck::cast_slice;
    use half::f16;

    /// Every variant with its textual name and per-element byte width.
    const CATALOG: [(VectorType, &str, usize); 6] = [
        (VectorType::Float2, "float2", 2),
        (VectorType::Float4, "float4", 4),
        (VectorType::Float8, "float8", 8),
        (VectorType::Int1, "int1", 1),
        (VectorType::Int2, "int2", 2),
        (VectorType::Int4, "int4", 4),
    ];

    /// Shorthand for the error `validate_blob` is expected to produce.
    fn mismatch(expected: usize, got: usize) -> VectorTypeError {
        VectorTypeError::DimensionMismatch { expected, got }
    }

    /// Generates a test that converts a typed vector to a blob and back.
    ///
    /// NOTE(review): reading the blob back goes through `blob_to_slice`, so
    /// these tests presumably rely on the `Vec<u8>` allocation being suitably
    /// aligned for the element type.
    macro_rules! round_trip_test {
        ($test_name:ident, $variant:ident, $elem:ty, $width:expr, $values:expr) => {
            #[test]
            fn $test_name() {
                let original: Vec<$elem> = $values;
                let blob = VectorType::$variant.slice_to_blob(&original);
                assert_eq!(blob.len(), original.len() * $width);
                let recovered: &[$elem] = VectorType::$variant.blob_to_slice(&blob);
                assert_eq!(recovered, original.as_slice());
            }
        };
    }

    // ---- from_name / name -------------------------------------------------

    #[test]
    fn from_name_all_valid() {
        for (variant, label, _) in CATALOG {
            assert_eq!(VectorType::from_name(label), Ok(variant));
        }
    }

    #[test]
    fn from_name_unknown_returns_error() {
        assert_eq!(
            VectorType::from_name("float3"),
            Err(VectorTypeError::UnknownType(String::from("float3")))
        );
    }

    #[test]
    fn from_name_case_sensitive() {
        for rejected in ["Float4", "INT1", ""] {
            assert!(VectorType::from_name(rejected).is_err());
        }
    }

    #[test]
    fn name_round_trips_with_from_name() {
        for (variant, _, _) in CATALOG {
            assert_eq!(VectorType::from_name(variant.name()), Ok(variant));
        }
    }

    // ---- element_size / blob_size -----------------------------------------

    #[test]
    fn element_size_correct() {
        for (variant, _, width) in CATALOG {
            assert_eq!(variant.element_size(), width);
        }
    }

    #[test]
    fn blob_size_is_element_size_times_dim() {
        for (variant, _, _) in CATALOG {
            for dim in [0, 1, 3, 128, 1536] {
                assert_eq!(variant.blob_size(dim), variant.element_size() * dim);
            }
        }
    }

    // ---- validate_blob ----------------------------------------------------

    #[test]
    fn validate_blob_correct_size_ok() {
        let vt = VectorType::Float4;
        let blob = vec![0u8; vt.blob_size(4)];
        assert!(vt.validate_blob(&blob, 4).is_ok());
    }

    #[test]
    fn validate_blob_too_short_returns_error() {
        // 12 bytes hold three f32-sized elements, one short of the four requested.
        let blob = vec![0u8; 12];
        assert_eq!(
            VectorType::Float4.validate_blob(&blob, 4),
            Err(mismatch(4, 3))
        );
    }

    #[test]
    fn validate_blob_too_long_returns_error() {
        // 20 bytes hold five f32-sized elements, one more than the four requested.
        let blob = vec![0u8; 20];
        assert_eq!(
            VectorType::Float4.validate_blob(&blob, 4),
            Err(mismatch(4, 5))
        );
    }

    #[test]
    fn validate_blob_int_types() {
        let vt = VectorType::Int2;
        let blob = vec![0u8; 6];
        assert!(vt.validate_blob(&blob, 3).is_ok());
        assert_eq!(vt.validate_blob(&blob, 4), Err(mismatch(4, 3)));
    }

    // ---- validate_finite --------------------------------------------------

    #[test]
    fn validate_finite_all_finite_f32_ok() {
        let values = vec![1.0_f32, -2.5, 0.0, f32::MAX];
        let blob = VectorType::Float4.slice_to_blob(&values);
        assert!(VectorType::Float4.validate_finite(&blob, 4).is_ok());
    }

    #[test]
    fn validate_finite_nan_f32_errors() {
        let values = vec![1.0_f32, f32::NAN, 3.0];
        let blob = VectorType::Float4.slice_to_blob(&values);
        assert_eq!(
            VectorType::Float4.validate_finite(&blob, 3),
            Err(VectorTypeError::NonFiniteValue)
        );
    }

    #[test]
    fn validate_finite_inf_f32_errors() {
        let values = vec![1.0_f32, f32::INFINITY];
        let blob = VectorType::Float4.slice_to_blob(&values);
        assert_eq!(
            VectorType::Float4.validate_finite(&blob, 2),
            Err(VectorTypeError::NonFiniteValue)
        );
    }

    #[test]
    fn validate_finite_neg_inf_f64_errors() {
        let values = vec![0.0_f64, f64::NEG_INFINITY];
        let blob = VectorType::Float8.slice_to_blob(&values);
        assert_eq!(
            VectorType::Float8.validate_finite(&blob, 2),
            Err(VectorTypeError::NonFiniteValue)
        );
    }

    #[test]
    fn validate_finite_all_finite_f64_ok() {
        let values = vec![1.0_f64, -2.5, 0.0, f64::MAX];
        let blob = VectorType::Float8.slice_to_blob(&values);
        assert!(VectorType::Float8.validate_finite(&blob, 4).is_ok());
    }

    #[test]
    fn validate_finite_nan_f16_errors() {
        let values = vec![f16::from_f32(1.0), f16::NAN];
        let blob = VectorType::Float2.slice_to_blob(&values);
        assert_eq!(
            VectorType::Float2.validate_finite(&blob, 2),
            Err(VectorTypeError::NonFiniteValue)
        );
    }

    #[test]
    fn validate_finite_inf_f16_errors() {
        let values = vec![f16::INFINITY];
        let blob = VectorType::Float2.slice_to_blob(&values);
        assert_eq!(
            VectorType::Float2.validate_finite(&blob, 1),
            Err(VectorTypeError::NonFiniteValue)
        );
    }

    #[test]
    fn validate_finite_all_finite_f16_ok() {
        let values: Vec<f16> = [1.0_f32, -0.5, 0.0]
            .iter()
            .map(|&x| f16::from_f32(x))
            .collect();
        let blob = VectorType::Float2.slice_to_blob(&values);
        assert!(VectorType::Float2.validate_finite(&blob, 3).is_ok());
    }

    #[test]
    fn validate_finite_integer_types_always_ok() {
        // Integer types have no non-finite values; only the length is checked.
        let i8_blob = VectorType::Int1.slice_to_blob::<i8>(&[i8::MIN, 0, i8::MAX]);
        assert!(VectorType::Int1.validate_finite(&i8_blob, 3).is_ok());

        let i16_blob = VectorType::Int2.slice_to_blob::<i16>(&[i16::MIN, 0, i16::MAX]);
        assert!(VectorType::Int2.validate_finite(&i16_blob, 3).is_ok());

        let i32_blob = VectorType::Int4.slice_to_blob::<i32>(&[i32::MIN, 0, i32::MAX]);
        assert!(VectorType::Int4.validate_finite(&i32_blob, 3).is_ok());
    }

    // ---- is_float ---------------------------------------------------------

    #[test]
    fn is_float_true_for_float_variants() {
        for vt in [VectorType::Float2, VectorType::Float4, VectorType::Float8] {
            assert!(vt.is_float());
        }
    }

    #[test]
    fn is_float_false_for_int_variants() {
        for vt in [VectorType::Int1, VectorType::Int2, VectorType::Int4] {
            assert!(!vt.is_float());
        }
    }

    // ---- slice <-> blob round trips ---------------------------------------

    round_trip_test!(round_trip_f32, Float4, f32, 4, vec![1.0, -2.5, 3.125, 0.0]);
    round_trip_test!(
        round_trip_f64,
        Float8,
        f64,
        8,
        vec![1.0, -2.5, f64::MAX, f64::MIN_POSITIVE]
    );
    round_trip_test!(
        round_trip_f16,
        Float2,
        f16,
        2,
        vec![f16::from_f32(1.0), f16::from_f32(-0.5), f16::from_f32(0.0)]
    );
    round_trip_test!(round_trip_i8, Int1, i8, 1, vec![i8::MIN, -1, 0, 1, i8::MAX]);
    round_trip_test!(round_trip_i16, Int2, i16, 2, vec![i16::MIN, -1, 0, 1, i16::MAX]);
    round_trip_test!(round_trip_i32, Int4, i32, 4, vec![i32::MIN, -1, 0, 1, i32::MAX]);

    #[test]
    fn slice_to_blob_matches_bytemuck_cast_slice() {
        let values = vec![1.0_f32, 2.0, 3.0];
        let got = VectorType::Float4.slice_to_blob(&values);
        let expected: &[u8] = cast_slice(&values);
        assert_eq!(got.as_slice(), expected);
    }
}