use anyhow::{bail, Result};
/// Number of `f32` components in an embedding stored in the binary format.
const EMBEDDING_DIM: usize = 384;
/// Exact byte length of a binary-encoded embedding: one little-endian `f32` per component.
const BINARY_BLOB_SIZE: usize = EMBEDDING_DIM * std::mem::size_of::<f32>();
/// Serializes an embedding into its binary storage form.
///
/// Every component is emitted, in order, as its 4 little-endian bytes, so the returned
/// buffer is exactly `vec.len() * 4` bytes long. Any input length is accepted, including
/// an empty slice (which yields an empty buffer).
pub fn encode_embedding(vec: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(vec.len() * std::mem::size_of::<f32>());
    bytes.extend(vec.iter().flat_map(|component| component.to_le_bytes()));
    bytes
}
/// Decodes a stored embedding blob back into its `f32` components.
///
/// Two storage formats are accepted:
///
/// * **Binary**: exactly `BINARY_BLOB_SIZE` bytes, i.e. `EMBEDDING_DIM` little-endian
///   `f32` values as written by [`encode_embedding`].
/// * **Legacy JSON**: a JSON array of numbers, recognised by a leading `[` byte. The
///   number of elements is not checked against `EMBEDDING_DIM`.
///
/// Length alone is ambiguous, because a legacy JSON array can happen to be exactly
/// `BINARY_BLOB_SIZE` bytes long. A blob of that size that starts with `[` is therefore
/// tried as JSON first. It is decoded as binary only if that parse fails.
///
/// # Errors
///
/// Returns an error in two cases:
///
/// * The blob is neither `BINARY_BLOB_SIZE` bytes long nor starts with `[`.
/// * The blob has any other length, starts with `[`, but is not a valid JSON array of
///   numbers.
pub fn decode_embedding(blob: &[u8]) -> Result<Vec<f32>> {
    let looks_like_json = blob.first() == Some(&b'[');

    if blob.len() == BINARY_BLOB_SIZE {
        // A genuine binary embedding may begin with the byte 0x5B (`[`). It is practically
        // never a complete, valid JSON array, though, so this attempt only succeeds for
        // blobs that really are legacy JSON. Everything else falls through to binary.
        if looks_like_json {
            if let Ok(vec) = serde_json::from_slice::<Vec<f32>>(blob) {
                return Ok(vec);
            }
        }
        return Ok(blob
            .chunks_exact(4)
            .map(|bytes| f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
            .collect());
    }

    if looks_like_json {
        let vec: Vec<f32> = serde_json::from_slice(blob)?;
        return Ok(vec);
    }

    bail!(
        "Unknown embedding format: length={}, first byte={:?}",
        blob.len(),
        blob.first()
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A deterministic, non-trivial embedding with the expected number of components.
    fn sample_embedding() -> Vec<f32> {
        (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect()
    }

    #[test]
    fn test_roundtrip_binary() {
        let original = sample_embedding();
        let encoded = encode_embedding(&original);
        assert_eq!(encoded.len(), BINARY_BLOB_SIZE);
        let decoded = decode_embedding(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_decode_legacy_json() {
        let original = sample_embedding();
        let json_blob = serde_json::to_vec(&original).unwrap();
        // The fixture must not collide with the binary blob length. Otherwise this test
        // would not be exercising the plain JSON path.
        assert_ne!(json_blob.len(), BINARY_BLOB_SIZE);
        let decoded = decode_embedding(&json_blob).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_decode_malformed_json() {
        assert!(decode_embedding(b"[1.0, oops]").is_err());
    }

    #[test]
    fn test_binary_blob_starting_with_bracket_byte() {
        // 0x5B is b'['. A binary blob may legitimately begin with that byte and must
        // still decode as binary rather than being rejected as bad JSON.
        let mut original = sample_embedding();
        original[0] = f32::from_bits(0x3F80_005B);
        let encoded = encode_embedding(&original);
        assert_eq!(encoded[0], b'[');
        let decoded = decode_embedding(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_invalid_format() {
        let bad_blob = vec![0u8; 100];
        assert!(decode_embedding(&bad_blob).is_err());
    }

    #[test]
    fn test_empty_embedding() {
        let encoded = encode_embedding(&[]);
        assert!(encoded.is_empty());
        assert!(decode_embedding(&encoded).is_err());
    }

    #[test]
    fn test_special_float_values() {
        let mut vec: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32).collect();
        vec[0] = f32::NEG_INFINITY;
        vec[1] = f32::INFINITY;
        vec[2] = 0.0;
        vec[3] = -0.0;
        vec[4] = f32::NAN;
        let encoded = encode_embedding(&vec);
        let decoded = decode_embedding(&encoded).unwrap();
        assert_eq!(vec.len(), decoded.len());
        // Compare bit patterns. `==` treats 0.0 and -0.0 as equal and NaN as unequal to
        // itself, so it cannot show that the encoding is lossless.
        for (i, (expected, actual)) in vec.iter().zip(&decoded).enumerate() {
            assert_eq!(expected.to_bits(), actual.to_bits(), "component {} changed", i);
        }
        assert!(decoded[0].is_infinite() && decoded[0].is_sign_negative());
        assert!(decoded[1].is_infinite() && decoded[1].is_sign_positive());
        assert!(decoded[2].is_sign_positive());
        assert!(decoded[3].is_sign_negative());
        assert!(decoded[4].is_nan());
    }
}