use crate::error::CodecError;
pub fn encode(data: &[f32], dims: usize, vectors: usize) -> Result<Vec<u8>, CodecError> {
if data.is_empty() || dims == 0 {
return Ok(build_header(dims as u32, 0, 0, &[]));
}
if data.len() != vectors * dims {
return Err(CodecError::Corrupt {
detail: format!(
"expected {} elements ({vectors} vectors × {dims} dims), got {}",
vectors * dims,
data.len()
),
});
}
let mut transformed = Vec::with_capacity(data.len());
for v in 0..vectors {
let start = v * dims;
let vec_data = &data[start..start + dims];
let spherical = cartesian_to_spherical(vec_data);
transformed.extend_from_slice(&spherical);
}
let raw_bytes: Vec<u8> = transformed.iter().flat_map(|f| f.to_le_bytes()).collect();
let compressed = crate::lz4::encode(&raw_bytes);
Ok(build_header(
dims as u32,
vectors as u32,
1, &compressed,
))
}
pub fn encode_raw(data: &[f32], dims: usize, vectors: usize) -> Result<Vec<u8>, CodecError> {
let raw_bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let compressed = crate::lz4::encode(&raw_bytes);
Ok(build_header(dims as u32, vectors as u32, 0, &compressed))
}
/// Decodes a blob made of a 9-byte header followed by an LZ4 payload.
///
/// Returns `(values, dims, vectors)`, where `values` holds
/// `vectors * dims` floats laid out one vector after another.
///
/// Header layout: bytes 0..4 are `dims` (LE `u32`), bytes 4..8 are the vector
/// count (LE `u32`), and byte 8 is the transform byte. A transform byte of
/// `0` means the payload floats are returned unchanged; any other value
/// causes each vector to be converted from spherical back to Cartesian
/// coordinates.
///
/// If the header records zero vectors or zero dims, the payload is ignored
/// and an empty vector is returned with a vector count of `0`.
///
/// # Errors
///
/// * [`CodecError::Truncated`] when `data` is shorter than the 9-byte header.
/// * [`CodecError::DecompressFailed`] when the LZ4 payload cannot be decoded.
/// * [`CodecError::Corrupt`] when the size implied by the header overflows
///   `usize` or does not match the decompressed payload length.
pub fn decode(data: &[u8]) -> Result<(Vec<f32>, usize, usize), CodecError> {
    if data.len() < 9 {
        return Err(CodecError::Truncated {
            expected: 9,
            actual: data.len(),
        });
    }
    let dims = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
    let vectors = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
    let transform = data[8];
    if vectors == 0 || dims == 0 {
        return Ok((Vec::new(), dims, 0));
    }

    // The header is untrusted input. An unchecked `vectors * dims * 4` can
    // wrap around (e.g. to 0), slip past the length check below, and then
    // cause an out-of-bounds slice in the per-vector loop.
    let expected_bytes = vectors
        .checked_mul(dims)
        .and_then(|elements| elements.checked_mul(4))
        .ok_or_else(|| CodecError::Corrupt {
            detail: format!("header size of {vectors} vectors × {dims} dims overflows usize"),
        })?;

    let compressed = &data[9..];
    let raw_bytes = crate::lz4::decode(compressed).map_err(|e| CodecError::DecompressFailed {
        detail: format!("spherical lz4: {e}"),
    })?;
    if raw_bytes.len() != expected_bytes {
        return Err(CodecError::Corrupt {
            detail: format!("expected {expected_bytes} bytes, got {}", raw_bytes.len()),
        });
    }

    let floats: Vec<f32> = raw_bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect();
    if transform == 0 {
        return Ok((floats, dims, vectors));
    }

    // `dims` is non-zero here, and `floats.len() == vectors * dims` was
    // verified above, so `chunks_exact` yields exactly `vectors` chunks.
    let mut result = Vec::with_capacity(floats.len());
    for spherical in floats.chunks_exact(dims) {
        result.extend_from_slice(&spherical_to_cartesian(spherical));
    }
    Ok((result, dims, vectors))
}
/// Returns the fraction of vectors in `data` whose Euclidean norm lies in
/// `0.99..=1.01`.
///
/// `data` is read as consecutive vectors of `dims` components; a trailing
/// partial vector is ignored. The result is `0.0` when `dims == 0` or when
/// `data` does not contain at least one complete vector.
pub fn normalization_ratio(data: &[f32], dims: usize) -> f64 {
    if dims == 0 {
        return 0.0;
    }
    let vectors = data.len() / dims;
    // Covers empty input and inputs shorter than one vector; without this
    // guard the final division would be 0.0 / 0.0 and yield NaN.
    if vectors == 0 {
        return 0.0;
    }
    let normalized = data
        .chunks_exact(dims)
        .filter(|vector| {
            let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
            (0.99..=1.01).contains(&norm)
        })
        .count();
    normalized as f64 / vectors as f64
}
/// Converts one Cartesian vector into hyperspherical form.
///
/// For inputs of two or more components the output has the same length as
/// the input: element 0 is the Euclidean norm, followed by one angle per
/// remaining component. Every angle except the last is
/// `acos(cart[i] / norm(cart[i..]))`; the last is
/// `atan2(cart[n - 1], cart[n - 2])`. An angle is stored as `0.0` when the
/// norm of the remaining tail is below `1e-30`.
///
/// Inputs with zero or one component are returned unchanged.
fn cartesian_to_spherical(cart: &[f32]) -> Vec<f32> {
    if cart.len() < 2 {
        return cart.to_vec();
    }

    let last = cart.len() - 1;
    // Euclidean norm of the tail starting at `from`.
    let tail_norm = |from: usize| cart[from..].iter().map(|x| x * x).sum::<f32>().sqrt();

    let mut out = Vec::with_capacity(cart.len());
    out.push(tail_norm(0));
    out.extend((0..last).map(|i| {
        let denom = tail_norm(i);
        if denom < 1e-30 {
            // Degenerate tail: the direction is undefined, so use a fixed angle.
            0.0
        } else if i + 1 < last {
            (cart[i] / denom).acos()
        } else {
            // The final angle keeps the sign of the last component.
            cart[last].atan2(cart[last - 1])
        }
    }));
    out
}
/// Converts one hyperspherical vector back to Cartesian coordinates.
///
/// `sph[0]` is the radius `r` and `sph[1..]` are the angles `a[0..n-1]`.
/// The output has the same length as the input:
///
/// * `cart[i]   = r * sin(a[0]) * ... * sin(a[i-1]) * cos(a[i])` for `i < n - 1`
/// * `cart[n-1] = r * sin(a[0]) * ... * sin(a[n-2])`
///
/// The last angle therefore satisfies `a[n-2] = atan2(cart[n-1], cart[n-2])`:
/// component `n - 2` takes the cosine and component `n - 1` the sine.
///
/// Inputs with zero or one component are returned unchanged.
fn spherical_to_cartesian(sph: &[f32]) -> Vec<f32> {
    if sph.len() < 2 {
        return sph.to_vec();
    }

    let angles = &sph[1..];
    let mut cart = Vec::with_capacity(sph.len());
    // Running product r * sin(a[0]) * ... * sin(a[i-1]); carrying it forward
    // avoids recomputing the whole sine prefix for every component.
    let mut sin_prefix = sph[0];
    for angle in angles {
        cart.push(sin_prefix * angle.cos());
        sin_prefix *= angle.sin();
    }
    // The final component is the product of all sines.
    cart.push(sin_prefix);
    cart
}
/// Assembles an encoded blob: `dims` and `vectors` as little-endian `u32`s,
/// then the single `transform` byte, then the `compressed` payload verbatim.
fn build_header(dims: u32, vectors: u32, transform: u8, compressed: &[u8]) -> Vec<u8> {
    let dims_le = dims.to_le_bytes();
    let vectors_le = vectors.to_le_bytes();
    let transform_byte = [transform];
    // Concatenating the four segments allocates the output once at its final size.
    let segments: [&[u8]; 4] = [&dims_le, &vectors_le, &transform_byte, compressed];
    segments.concat()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` deterministic vectors of `dims` components each, flattened
    /// one after another. Each vector is scaled to unit length unless its
    /// norm is not strictly positive.
    fn make_normalized_vectors(n: usize, dims: usize) -> Vec<f32> {
        let mut flat = Vec::with_capacity(n * dims);
        for row in 0..n {
            let offset = row * dims;
            let raw: Vec<f32> = (offset..offset + dims)
                .map(|k| (k as f32 * 0.1).sin())
                .collect();
            let length = raw.iter().map(|x| x * x).sum::<f32>().sqrt();
            if length > 0.0 {
                flat.extend(raw.iter().map(|x| x / length));
            } else {
                flat.extend_from_slice(&raw);
            }
        }
        flat
    }

    /// Encoding no data yields a blob that decodes to no vectors but keeps `dims`.
    #[test]
    fn empty_roundtrip() {
        let blob = encode(&[], 32, 0).unwrap();
        let (values, dims, vectors) = decode(&blob).unwrap();
        assert!(values.is_empty());
        assert_eq!(vectors, 0);
        assert_eq!(dims, 32);
    }

    /// The spherical path reconstructs unit vectors to within a coarse tolerance.
    #[test]
    fn normalized_roundtrip() {
        let original = make_normalized_vectors(100, 32);
        let blob = encode(&original, 32, 100).unwrap();
        let (restored, dims, vectors) = decode(&blob).unwrap();
        assert_eq!(dims, 32);
        assert_eq!(vectors, 100);
        assert_eq!(restored.len(), original.len());

        let mut max_error = 0.0f32;
        for (a, b) in original.iter().zip(restored.iter()) {
            max_error = f32::max(max_error, (a - b).abs());
        }
        // NOTE(review): every component of a unit vector has magnitude <= 1,
        // so a 0.1 bound only catches gross reconstruction errors.
        assert!(
            max_error < 0.1,
            "max reconstruction error {max_error} exceeds threshold"
        );
    }

    /// The untransformed path is bit-exact.
    #[test]
    fn raw_roundtrip() {
        let original: Vec<f32> = (0..320).map(|i| i as f32 * 0.01).collect();
        let blob = encode_raw(&original, 32, 10).unwrap();
        let (restored, dims, vectors) = decode(&blob).unwrap();
        assert_eq!(dims, 32);
        assert_eq!(vectors, 10);
        for (a, b) in original.iter().zip(restored.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    /// The encoded form must not be more than about 10% larger than the raw floats.
    #[test]
    fn compression_ratio() {
        let original = make_normalized_vectors(1000, 128);
        let blob = encode(&original, 128, 1000).unwrap();
        let raw_size = original.len() * 4;
        let ratio = raw_size as f64 / blob.len() as f64;
        assert!(ratio > 0.9, "should not expand >10%, got {ratio:.2}x");
    }

    /// Unit vectors are reported as normalized; a plain ramp of integers is not.
    #[test]
    fn normalization_check() {
        let unit = make_normalized_vectors(100, 32);
        assert!(normalization_ratio(&unit, 32) > 0.95);

        let ramp: Vec<f32> = (0..3200).map(|i| i as f32).collect();
        assert!(normalization_ratio(&ramp, 32) < 0.1);
    }

    /// Inputs shorter than the header are rejected.
    #[test]
    fn truncated_error() {
        assert!(decode(&[]).is_err());
        assert!(decode(&[0; 5]).is_err());
    }
}