use std::sync::Arc;
use super::constants::{MAX_VECTOR_DIMENSIONS, TAG_SPARSE_VECTOR, TAG_VECTOR};
use super::sparse::SparseVec;
use crate::core::error::{Error, Result, StorageError, VectorError};
#[inline]
pub(crate) fn validate_vector_dimensions(len: usize) -> Result<()> {
if len > MAX_VECTOR_DIMENSIONS {
return Err(Error::Vector(VectorError::DimensionTooLarge {
dimension: len,
max_allowed: MAX_VECTOR_DIMENSIONS,
}));
}
Ok(())
}
/// Serializes a dense vector into a freshly allocated, exactly-sized buffer.
///
/// # Panics
/// Panics (via `serialize_vector_into`) if `v.len()` exceeds
/// `MAX_VECTOR_DIMENSIONS`.
pub fn serialize_vector(v: &[f32]) -> Vec<u8> {
    // Exact size: tag (1) + dimension (4) + 4 bytes per element.
    let capacity = 1 + 4 + v.len() * 4;
    let mut out = Vec::with_capacity(capacity);
    serialize_vector_into(v, &mut out);
    out
}
/// Appends the serialized form of `v` to `buffer`.
///
/// Infallible wrapper around [`try_serialize_vector_into`].
///
/// # Panics
/// Panics with the underlying error's message if `v.len()` exceeds
/// `MAX_VECTOR_DIMENSIONS`.
pub fn serialize_vector_into(v: &[f32], buffer: &mut Vec<u8>) {
    if let Err(err) = try_serialize_vector_into(v, buffer) {
        panic!("{}", err);
    }
}
/// Serializes `v` into `buffer` as: 1-byte `TAG_VECTOR`, u32 LE dimension,
/// then the f32 elements in little-endian byte order.
///
/// # Errors
/// Returns `VectorError::DimensionTooLarge` if `v.len()` exceeds
/// `MAX_VECTOR_DIMENSIONS`; `buffer` is left untouched in that case
/// (validation happens before any write).
pub fn try_serialize_vector_into(v: &[f32], buffer: &mut Vec<u8>) -> Result<()> {
    validate_vector_dimensions(v.len())?;
    // Reserve once up front: tag (1) + dimension (4) + payload.
    let required_size = 1 + 4 + std::mem::size_of_val(v);
    buffer.reserve(required_size);
    buffer.push(TAG_VECTOR);
    buffer.extend_from_slice(&(v.len() as u32).to_le_bytes());
    // Fast path: on little-endian targets the native f32 memory layout is
    // already the on-disk (LE) layout, so the payload can be copied as raw
    // bytes in one memcpy.
    #[cfg(target_endian = "little")]
    {
        // SAFETY: `v` is a valid slice, so `v.as_ptr()` points to
        // `size_of_val(v)` initialized, contiguous bytes; every byte
        // pattern is a valid `u8`, and the temporary byte slice does not
        // outlive the borrow of `v`.
        let byte_slice = unsafe {
            std::slice::from_raw_parts(v.as_ptr() as *const u8, std::mem::size_of_val(v))
        };
        buffer.extend_from_slice(byte_slice);
    }
    // Portable path: convert each element explicitly on big-endian targets.
    #[cfg(not(target_endian = "little"))]
    {
        for &value in v {
            buffer.extend_from_slice(&value.to_le_bytes());
        }
    }
    Ok(())
}
/// Deserializes a dense vector from `bytes`.
///
/// Expected layout: 1-byte `TAG_VECTOR`, u32 LE dimension, then
/// `dimension` f32 values in little-endian byte order. Extra trailing
/// bytes are allowed; the returned `usize` is the number of bytes
/// consumed, so callers can decode from a larger buffer.
///
/// # Errors
/// Returns `StorageError::CorruptedData` for a short buffer, wrong tag, or
/// size-arithmetic overflow, and `VectorError::DimensionTooLarge` if the
/// stored dimension exceeds `MAX_VECTOR_DIMENSIONS`.
pub fn deserialize_vector(bytes: &[u8]) -> Result<(Arc<[f32]>, usize)> {
    // Header = tag (1) + dimension (4).
    if bytes.len() < 5 {
        return Err(
            StorageError::CorruptedData("Buffer too short for vector header".to_string()).into(),
        );
    }
    let tag = bytes[0];
    if tag != TAG_VECTOR {
        return Err(StorageError::CorruptedData(format!(
            "Expected vector type tag {}, got {}",
            TAG_VECTOR, tag
        ))
        .into());
    }
    // The slice is exactly 4 bytes (length checked above), so try_into()
    // cannot fail.
    let dimension = u32::from_le_bytes(bytes[1..5].try_into().unwrap()) as usize;
    validate_vector_dimensions(dimension)?;
    let data_start: usize = 5;
    // Checked arithmetic so a corrupt header cannot cause overflow.
    let data_len = dimension
        .checked_mul(4)
        .ok_or_else(|| StorageError::CorruptedData("Vector dimension overflow".to_string()))?;
    let total_len = data_start
        .checked_add(data_len)
        .ok_or_else(|| StorageError::CorruptedData("Vector size overflow".to_string()))?;
    if bytes.len() < total_len {
        return Err(StorageError::CorruptedData(format!(
            "Buffer too short for vector data: need {} bytes, have {}",
            total_len,
            bytes.len()
        ))
        .into());
    }
    let data_slice = &bytes[data_start..total_len];
    // Fast path: on little-endian targets the stored LE bytes already match
    // the native f32 layout, so the payload is bulk-copied via raw u8
    // pointers (which also makes source alignment irrelevant).
    #[cfg(target_endian = "little")]
    let values = {
        let mut values = Vec::with_capacity(dimension);
        if dimension > 0 {
            // SAFETY: `values` has capacity for `dimension` f32s, i.e.
            // exactly `data_slice.len()` (= dimension * 4) writable bytes
            // at `as_mut_ptr()`; the freshly allocated destination cannot
            // overlap `bytes`; any bit pattern is a valid f32; and
            // `set_len` runs only after the copy initializes every byte.
            unsafe {
                let src_ptr = data_slice.as_ptr();
                let dst_ptr = values.as_mut_ptr() as *mut u8;
                std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, data_slice.len());
                values.set_len(dimension);
            }
        }
        values
    };
    // Portable path: decode each f32 explicitly on big-endian targets.
    #[cfg(not(target_endian = "little"))]
    let values = {
        let mut values = Vec::with_capacity(dimension);
        for chunk in data_slice.chunks_exact(4) {
            values.push(f32::from_le_bytes(chunk.try_into().unwrap()));
        }
        values
    };
    Ok((Arc::from(values.into_boxed_slice()), total_len))
}
/// Serializes a sparse vector into a freshly allocated buffer sized for
/// the header plus one u32 index and one f32 value per non-zero entry.
pub fn serialize_sparse_vector(sv: &SparseVec) -> Vec<u8> {
    // Exact size: tag (1) + dimension (4) + nnz (4) + 8 bytes per entry.
    let capacity = 1 + 4 + 4 + sv.nnz() * 8;
    let mut out = Vec::with_capacity(capacity);
    serialize_sparse_vector_into(sv, &mut out);
    out
}
/// Appends the serialized form of `sv` to `buffer`.
///
/// Layout: 1-byte `TAG_SPARSE_VECTOR`, u32 LE dimension, u32 LE nnz, then
/// all indices as LE u32 followed by all values as LE f32.
pub fn serialize_sparse_vector_into(sv: &SparseVec, buffer: &mut Vec<u8>) {
    buffer.reserve(1 + 4 + 4 + sv.nnz() * 8);
    buffer.push(TAG_SPARSE_VECTOR);
    // Header fields: dimension, then number of non-zero entries.
    for header_field in [sv.dimension() as u32, sv.nnz() as u32] {
        buffer.extend_from_slice(&header_field.to_le_bytes());
    }
    // Index block first, then the value block.
    sv.indices()
        .iter()
        .for_each(|idx| buffer.extend_from_slice(&idx.to_le_bytes()));
    sv.values()
        .iter()
        .for_each(|val| buffer.extend_from_slice(&val.to_le_bytes()));
}
/// Deserializes a sparse vector from `bytes`.
///
/// Expected layout: 1-byte `TAG_SPARSE_VECTOR`, u32 LE dimension, u32 LE
/// nnz, then `nnz` LE u32 indices followed by `nnz` LE f32 values. Extra
/// trailing bytes are allowed; the returned `usize` is the number of
/// bytes consumed.
///
/// # Errors
/// Returns `StorageError::CorruptedData` for a short buffer, wrong tag,
/// nnz greater than dimension, or size-arithmetic overflow;
/// `VectorError::DimensionTooLarge` if nnz exceeds
/// `MAX_VECTOR_DIMENSIONS`; and propagates any validation error from
/// `SparseVec::new` on the decoded payload.
pub fn deserialize_sparse_vector(bytes: &[u8]) -> Result<(Arc<SparseVec>, usize)> {
    // Header = tag (1) + dimension (4) + nnz (4).
    if bytes.len() < 9 {
        return Err(StorageError::CorruptedData(
            "Buffer too short for sparse vector header".to_string(),
        )
        .into());
    }
    let tag = bytes[0];
    if tag != TAG_SPARSE_VECTOR {
        return Err(StorageError::CorruptedData(format!(
            "Expected sparse vector type tag {}, got {}",
            TAG_SPARSE_VECTOR, tag
        ))
        .into());
    }
    let dimension = u32::from_le_bytes(bytes[1..5].try_into().unwrap());
    let nnz = u32::from_le_bytes(bytes[5..9].try_into().unwrap()) as usize;
    // A well-formed sparse vector has at most `dimension` distinct indices.
    if nnz > dimension as usize {
        return Err(StorageError::CorruptedData(format!(
            "Sparse vector nnz {} exceeds dimension {}",
            nnz, dimension
        ))
        .into());
    }
    validate_vector_dimensions(nnz)?;
    let data_start: usize = 9;
    // Checked arithmetic so a corrupt header cannot cause overflow.
    let indices_len = nnz
        .checked_mul(4)
        .ok_or_else(|| StorageError::CorruptedData("Sparse vector nnz overflow".to_string()))?;
    // Values occupy the same number of bytes as indices (4 bytes each).
    let values_len = indices_len;
    let total_len = data_start
        .checked_add(indices_len)
        .and_then(|x: usize| x.checked_add(values_len))
        .ok_or_else(|| StorageError::CorruptedData("Sparse vector size overflow".to_string()))?;
    if bytes.len() < total_len {
        return Err(StorageError::CorruptedData(format!(
            "Buffer too short for sparse vector data: need {} bytes, have {}",
            total_len,
            bytes.len()
        ))
        .into());
    }
    let indices_end = data_start + indices_len;
    let indices_slice = &bytes[data_start..indices_end];
    // Fast path: on little-endian targets the stored LE bytes already match
    // the native u32 layout, so the index block can be bulk-copied.
    #[cfg(target_endian = "little")]
    let indices = {
        let mut indices = Vec::with_capacity(nnz);
        if nnz > 0 {
            // SAFETY: `indices` has capacity for `nnz` u32s, i.e. exactly
            // `indices_slice.len()` (= nnz * 4) writable bytes; the fresh
            // destination cannot overlap `bytes`; any bit pattern is a
            // valid u32; and `set_len` runs only after the copy has
            // initialized every byte.
            unsafe {
                let src_ptr = indices_slice.as_ptr();
                let dst_ptr = indices.as_mut_ptr() as *mut u8;
                std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, indices_slice.len());
                indices.set_len(nnz);
            }
        }
        indices
    };
    // Portable path: decode each u32 explicitly on big-endian targets.
    #[cfg(not(target_endian = "little"))]
    let indices = {
        let mut indices = Vec::with_capacity(nnz);
        for chunk in indices_slice.chunks_exact(4) {
            indices.push(u32::from_le_bytes(chunk.try_into().unwrap()));
        }
        indices
    };
    let values_end = indices_end + values_len;
    let values_slice = &bytes[indices_end..values_end];
    #[cfg(target_endian = "little")]
    let values = {
        let mut values = Vec::with_capacity(nnz);
        if nnz > 0 {
            // SAFETY: same reasoning as for `indices` above — the capacity
            // covers exactly `values_slice.len()` bytes, the regions do
            // not overlap, and any bit pattern is a valid f32.
            unsafe {
                let src_ptr = values_slice.as_ptr();
                let dst_ptr = values.as_mut_ptr() as *mut u8;
                std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, values_slice.len());
                values.set_len(nnz);
            }
        }
        values
    };
    #[cfg(not(target_endian = "little"))]
    let values = {
        let mut values = Vec::with_capacity(nnz);
        for chunk in values_slice.chunks_exact(4) {
            values.push(f32::from_le_bytes(chunk.try_into().unwrap()));
        }
        values
    };
    // SparseVec::new performs the structural validation: it rejects
    // duplicate indices and explicit zeros, and sorts unsorted index/value
    // pairs (behaviors pinned by the tests in this file).
    let sparse_vec = SparseVec::new(indices, values, dimension)?;
    Ok((Arc::new(sparse_vec), total_len))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips a small dense vector; checks tag byte, exact serialized
    // size, and the consumed-byte count.
    #[test]
    fn test_serialize_vector_basic() {
        let data = [1.0f32, 2.0, 3.0];
        let bytes = serialize_vector(&data);
        assert_eq!(bytes[0], TAG_VECTOR);
        assert_eq!(bytes.len(), 1 + 4 + 3 * 4);
        let (deserialized, consumed) = deserialize_vector(&bytes).unwrap();
        assert_eq!(deserialized.as_ref(), &data[..]);
        assert_eq!(consumed, bytes.len());
    }

    // Round-trips a 100-element vector.
    #[test]
    fn test_serialize_vector_round_trip() {
        let data: Vec<f32> = (0..100).map(|i| i as f32 * 0.01).collect();
        let bytes = serialize_vector(&data);
        let (deserialized, consumed) = deserialize_vector(&bytes).unwrap();
        assert_eq!(deserialized.as_ref(), &data[..]);
        assert_eq!(consumed, bytes.len());
    }

    // An empty vector serializes to just the 5-byte header.
    #[test]
    fn test_serialize_vector_empty() {
        let empty: Vec<f32> = vec![];
        let bytes = serialize_vector(&empty);
        assert_eq!(bytes.len(), 5);
        assert_eq!(bytes[0], TAG_VECTOR);
        let (deserialized, consumed) = deserialize_vector(&bytes).unwrap();
        assert!(deserialized.is_empty());
        assert_eq!(consumed, 5);
    }

    // Round-trips an embedding-sized (1536-dim) vector element by element.
    #[test]
    fn test_serialize_vector_large() {
        let large_vector: Vec<f32> = (0..1536).map(|i| (i as f32) / 1536.0).collect();
        let bytes = serialize_vector(&large_vector);
        assert_eq!(bytes.len(), 1 + 4 + 1536 * 4);
        let (deserialized, consumed) = deserialize_vector(&bytes).unwrap();
        assert_eq!(deserialized.len(), 1536);
        assert_eq!(consumed, bytes.len());
        for (i, &val) in deserialized.iter().enumerate() {
            assert!((val - (i as f32) / 1536.0).abs() < f32::EPSILON);
        }
    }

    // Infinities, signed zeros, and NaN survive a round trip.
    #[test]
    fn test_serialize_vector_special_values() {
        let data = [f32::INFINITY, f32::NEG_INFINITY, 0.0, -0.0, f32::NAN];
        let bytes = serialize_vector(&data);
        let (deserialized, _) = deserialize_vector(&bytes).unwrap();
        assert_eq!(deserialized[0], f32::INFINITY);
        assert_eq!(deserialized[1], f32::NEG_INFINITY);
        assert_eq!(deserialized[2], 0.0);
        assert_eq!(deserialized[3], 0.0);
        assert!(deserialized[4].is_nan());
    }

    // Empty buffer, truncated header, wrong tag, and truncated payload
    // must all be rejected.
    #[test]
    fn test_deserialize_vector_errors() {
        let result = deserialize_vector(&[]);
        assert!(result.is_err());
        let result = deserialize_vector(&[TAG_VECTOR, 1, 0, 0]);
        assert!(result.is_err());
        let result = deserialize_vector(&[TAG_VECTOR + 1, 3, 0, 0, 0]);
        assert!(result.is_err());
        // Header claims 3 elements but only 1 payload byte follows.
        let mut bytes = vec![TAG_VECTOR, 3, 0, 0, 0];
        bytes.extend_from_slice(&[1.0f32.to_le_bytes()[0]]);
        let result = deserialize_vector(&bytes);
        assert!(result.is_err());
    }

    // Exercises the endian-specialized byte-copy paths across a range of
    // lengths, including the empty and single-element cases.
    #[test]
    fn test_vector_serialization_optimization_correctness() {
        let test_cases = vec![
            vec![],
            vec![1.0f32],
            vec![1.0f32, 2.0, 3.0],
            (0..100).map(|i| i as f32 * 0.01).collect(),
            (0..384).map(|i| (i as f32) / 384.0).collect(),
            (0..1536).map(|i| (i as f32) / 1536.0).collect(),
        ];
        for test_vector in test_cases {
            let bytes = serialize_vector(&test_vector);
            assert_eq!(bytes[0], TAG_VECTOR);
            let dimension = u32::from_le_bytes(bytes[1..5].try_into().unwrap()) as usize;
            assert_eq!(dimension, test_vector.len());
            assert_eq!(bytes.len(), 1 + 4 + test_vector.len() * 4);
            let (deserialized, consumed) = deserialize_vector(&bytes).unwrap();
            assert_eq!(deserialized.len(), test_vector.len());
            assert_eq!(consumed, bytes.len());
            for (i, (&original, &recovered)) in
                test_vector.iter().zip(deserialized.iter()).enumerate()
            {
                assert_eq!(
                    original,
                    recovered,
                    "Mismatch at index {} for vector of length {}",
                    i,
                    test_vector.len()
                );
            }
        }
    }

    // Boundary float values survive the optimized round trip exactly.
    #[test]
    fn test_vector_serialization_special_values_optimization() {
        let special_values = vec![
            f32::INFINITY,
            f32::NEG_INFINITY,
            0.0,
            -0.0,
            f32::MAX,
            f32::MIN,
            f32::MIN_POSITIVE,
            1.0,
            -1.0,
            std::f32::consts::PI,
            f32::NAN,
        ];
        let bytes = serialize_vector(&special_values);
        let (deserialized, _) = deserialize_vector(&bytes).unwrap();
        assert_eq!(deserialized[0], f32::INFINITY);
        assert_eq!(deserialized[1], f32::NEG_INFINITY);
        assert_eq!(deserialized[2], 0.0);
        assert_eq!(deserialized[3], 0.0);
        assert_eq!(deserialized[4], f32::MAX);
        assert_eq!(deserialized[5], f32::MIN);
        assert_eq!(deserialized[6], f32::MIN_POSITIVE);
        assert_eq!(deserialized[7], 1.0);
        assert_eq!(deserialized[8], -1.0);
        assert!((deserialized[9] - std::f32::consts::PI).abs() < f32::EPSILON);
        assert!(deserialized[10].is_nan());
    }

    // Serialization must be byte-for-byte deterministic across calls.
    #[test]
    fn test_vector_serialization_deterministic() {
        let vector: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
        let bytes1 = serialize_vector(&vector);
        let bytes2 = serialize_vector(&vector);
        let bytes3 = serialize_vector(&vector);
        assert_eq!(bytes1, bytes2);
        assert_eq!(bytes2, bytes3);
        let (v1, _) = deserialize_vector(&bytes1).unwrap();
        let (v2, _) = deserialize_vector(&bytes2).unwrap();
        let (v3, _) = deserialize_vector(&bytes3).unwrap();
        assert_eq!(v1.as_ref(), v2.as_ref());
        assert_eq!(v2.as_ref(), v3.as_ref());
    }

    // Deserialization must work when the input slice is not 4-byte aligned
    // (the copy goes through raw u8 pointers, so source alignment must not
    // matter).
    #[test]
    fn test_vector_deserialization_unaligned() {
        let vector: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let serialized = serialize_vector(&vector);
        // Offset by one byte.
        let mut padded_buffer = vec![0xFF];
        padded_buffer.extend_from_slice(&serialized);
        let (deserialized, consumed) = deserialize_vector(&padded_buffer[1..]).unwrap();
        assert_eq!(deserialized.len(), vector.len());
        assert_eq!(consumed, serialized.len());
        for (original, recovered) in vector.iter().zip(deserialized.iter()) {
            assert_eq!(original, recovered);
        }
        // Offset by three bytes.
        let mut padded_buffer3 = vec![0xFF, 0xFF, 0xFF];
        padded_buffer3.extend_from_slice(&serialized);
        let (deserialized3, consumed3) = deserialize_vector(&padded_buffer3[3..]).unwrap();
        assert_eq!(deserialized3.len(), vector.len());
        assert_eq!(consumed3, serialized.len());
        for (original, recovered) in vector.iter().zip(deserialized3.iter()) {
            assert_eq!(original, recovered);
        }
    }

    // Round-trips a small sparse vector and checks all accessors.
    #[test]
    fn test_serialize_sparse_vector_basic() {
        let sparse = SparseVec::new(vec![0, 2, 4], vec![1.0, 2.0, 3.0], 5).unwrap();
        let bytes = serialize_sparse_vector(&sparse);
        assert_eq!(bytes[0], TAG_SPARSE_VECTOR);
        let (deserialized, consumed) = deserialize_sparse_vector(&bytes).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(deserialized.nnz(), 3);
        assert_eq!(deserialized.dimension(), 5);
        assert_eq!(deserialized.indices(), &[0, 2, 4]);
        assert_eq!(deserialized.values(), &[1.0, 2.0, 3.0]);
    }

    // A sparse vector with no entries serializes to just the 9-byte header
    // while still preserving its dimension.
    #[test]
    fn test_serialize_sparse_vector_empty() {
        let sparse = SparseVec::new(vec![], vec![], 100).unwrap();
        let bytes = serialize_sparse_vector(&sparse);
        assert_eq!(bytes.len(), 9);
        assert_eq!(bytes[0], TAG_SPARSE_VECTOR);
        let (deserialized, consumed) = deserialize_sparse_vector(&bytes).unwrap();
        assert!(deserialized.indices().is_empty());
        assert_eq!(deserialized.dimension(), 100);
        assert_eq!(consumed, 9);
    }

    // Full round trip preserves nnz, dimension, indices, and values.
    #[test]
    fn test_serialize_sparse_vector_round_trip() {
        let sparse = SparseVec::new(
            vec![1, 10, 42, 99, 256],
            vec![1.5, 2.3, 0.8, 4.2, 1.1],
            1000,
        )
        .unwrap();
        let bytes = serialize_sparse_vector(&sparse);
        let (deserialized, consumed) = deserialize_sparse_vector(&bytes).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(deserialized.nnz(), sparse.nnz());
        assert_eq!(deserialized.dimension(), sparse.dimension());
        assert_eq!(deserialized.indices(), sparse.indices());
        assert_eq!(deserialized.values(), sparse.values());
    }

    // Empty buffer, truncated header, wrong tag, and nnz > dimension must
    // all be rejected.
    #[test]
    fn test_deserialize_sparse_vector_errors() {
        let result = deserialize_sparse_vector(&[]);
        assert!(result.is_err());
        let result = deserialize_sparse_vector(&[TAG_SPARSE_VECTOR, 1, 0, 0]);
        assert!(result.is_err());
        let result = deserialize_sparse_vector(&[TAG_SPARSE_VECTOR + 1, 5, 0, 0, 0, 2, 0, 0, 0]);
        assert!(result.is_err());
        // nnz (10) larger than dimension (5) is corrupt data.
        let mut bytes = vec![TAG_SPARSE_VECTOR];
        bytes.extend_from_slice(&5u32.to_le_bytes());
        bytes.extend_from_slice(&10u32.to_le_bytes());
        let result = deserialize_sparse_vector(&bytes);
        assert!(result.is_err());
    }

    // BM25-style sparse vector: serialized size is
    // 1 (tag) + 4 (dim) + 4 (nnz) + 5 * 4 (indices) + 5 * 4 (values) = 49.
    #[test]
    fn test_serialize_sparse_vector_bm25_like() {
        let sparse = SparseVec::new(
            vec![42, 157, 891, 1023, 5000],
            vec![2.3, 1.8, 0.9, 3.1, 1.5],
            10000,
        )
        .unwrap();
        let bytes = serialize_sparse_vector(&sparse);
        assert_eq!(bytes.len(), 49);
        let (deserialized, _) = deserialize_sparse_vector(&bytes).unwrap();
        assert_eq!(deserialized.nnz(), 5);
        assert_eq!(deserialized.dimension(), 10000);
    }

    // The infallible wrapper panics when the dimension limit is exceeded.
    #[test]
    #[should_panic(expected = "Vector dimension")]
    fn test_serialize_vector_into_panics_on_overflow() {
        let large_vector = vec![0.0; MAX_VECTOR_DIMENSIONS + 1];
        let mut buffer = Vec::new();
        serialize_vector_into(&large_vector, &mut buffer);
    }

    // serialize_vector_into appends: pre-existing buffer bytes survive and
    // the payload starts right after them.
    #[test]
    fn test_serialize_vector_into_buffer_appending() {
        let mut buffer = vec![0xAA, 0xBB, 0xCC];
        let vector = vec![1.0f32, 2.0, 3.0];
        serialize_vector_into(&vector, &mut buffer);
        assert_eq!(buffer[0], 0xAA);
        assert_eq!(buffer[1], 0xBB);
        assert_eq!(buffer[2], 0xCC);
        assert_eq!(buffer[3], TAG_VECTOR);
        let (deserialized, consumed) = deserialize_vector(&buffer[3..]).unwrap();
        assert_eq!(deserialized.as_ref(), &vector[..]);
        assert_eq!(consumed, buffer.len() - 3);
    }

    // Duplicate indices in the payload are rejected by SparseVec
    // validation during deserialization.
    #[test]
    fn test_deserialize_sparse_vector_validates_duplicates() {
        let mut buffer = Vec::new();
        buffer.push(TAG_SPARSE_VECTOR);
        let dim = 10u32;
        let nnz = 2u32;
        buffer.extend_from_slice(&dim.to_le_bytes());
        buffer.extend_from_slice(&nnz.to_le_bytes());
        // Index 5 appears twice.
        buffer.extend_from_slice(&5u32.to_le_bytes());
        buffer.extend_from_slice(&5u32.to_le_bytes());
        buffer.extend_from_slice(&1.0f32.to_le_bytes());
        buffer.extend_from_slice(&2.0f32.to_le_bytes());
        let result = deserialize_sparse_vector(&buffer);
        assert!(result.is_err());
        match result {
            Err(Error::Vector(VectorError::InvalidSparseVector { reason })) => {
                assert!(reason.contains("Duplicate index"));
            }
            _ => panic!("Expected InvalidSparseVector error, got {:?}", result),
        }
    }

    // Explicit zero values in the payload are rejected.
    #[test]
    fn test_deserialize_sparse_vector_validates_zeros() {
        let mut buffer = Vec::new();
        buffer.push(TAG_SPARSE_VECTOR);
        let dim = 10u32;
        let nnz = 1u32;
        buffer.extend_from_slice(&dim.to_le_bytes());
        buffer.extend_from_slice(&nnz.to_le_bytes());
        buffer.extend_from_slice(&0u32.to_le_bytes());
        buffer.extend_from_slice(&0.0f32.to_le_bytes());
        let result = deserialize_sparse_vector(&buffer);
        assert!(result.is_err());
        match result {
            Err(Error::Vector(VectorError::InvalidSparseVector { reason })) => {
                assert!(reason.contains("zero value"));
            }
            _ => panic!("Expected InvalidSparseVector error, got {:?}", result),
        }
    }

    // Out-of-order indices in the payload are sorted (with their paired
    // values) by SparseVec construction rather than rejected.
    #[test]
    fn test_deserialize_sparse_vector_sorts_indices() {
        let mut buffer = Vec::new();
        buffer.push(TAG_SPARSE_VECTOR);
        let dim = 10u32;
        let nnz = 2u32;
        buffer.extend_from_slice(&dim.to_le_bytes());
        buffer.extend_from_slice(&nnz.to_le_bytes());
        // Indices written as [5, 2] with values [1.0, 2.0].
        buffer.extend_from_slice(&5u32.to_le_bytes());
        buffer.extend_from_slice(&2u32.to_le_bytes());
        buffer.extend_from_slice(&1.0f32.to_le_bytes());
        buffer.extend_from_slice(&2.0f32.to_le_bytes());
        let (sv_arc, _) =
            deserialize_sparse_vector(&buffer).expect("Should succeed and sort indices");
        assert_eq!(sv_arc.indices(), &[2, 5]);
        assert_eq!(sv_arc.values(), &[2.0, 1.0]);
    }

    // Every f32 bit pattern — including distinct NaN payloads and the sign
    // bit of zero — must be preserved exactly.
    #[test]
    fn test_vector_bitwise_preservation() {
        let pos_zero = 0.0f32;
        let neg_zero = -0.0f32;
        let inf = f32::INFINITY;
        let neg_inf = f32::NEG_INFINITY;
        // Two NaNs with different payload bits.
        let nan1 = f32::from_bits(0x7fc00001);
        let nan2 = f32::from_bits(0x7fc00002);
        let data = vec![pos_zero, neg_zero, inf, neg_inf, nan1, nan2];
        let bytes = serialize_vector(&data);
        let (deserialized, _) = deserialize_vector(&bytes).unwrap();
        assert_eq!(deserialized.len(), data.len());
        for (i, &val) in data.iter().enumerate() {
            assert_eq!(
                val.to_bits(),
                deserialized[i].to_bits(),
                "Bitwise mismatch at index {}: expected {:08x}, got {:08x}",
                i,
                val.to_bits(),
                deserialized[i].to_bits()
            );
        }
    }

    // Serializing a sub-slice behaves the same as a standalone vector.
    #[test]
    fn test_serialize_vector_slice_offsets() {
        let full_vec: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let slice = &full_vec[1..4];
        let bytes = serialize_vector(slice);
        let (deserialized, _) = deserialize_vector(&bytes).unwrap();
        assert_eq!(&*deserialized, &[2.0, 3.0, 4.0]);
    }

    // Zero-dimension vectors round-trip (the dimension > 0 guard skips the
    // unsafe copy entirely).
    #[test]
    fn test_deserialize_vector_zero_dim() {
        let empty: Vec<f32> = Vec::new();
        let bytes = serialize_vector(&empty);
        let (deserialized, _) =
            deserialize_vector(&bytes).expect("Should deserialize empty vector");
        assert!(deserialized.is_empty());
    }
}