use crate::e8::{E8Codec, E8EncodedVector};
use crate::h4::{H4Codec, H4EncodedVector};
use crate::error::{EmbedVecError, Result};
use crate::quantization::Quantization;
use serde::{Deserialize, Serialize};
/// One stored vector: either the raw `f32` components, or the output of
/// `E8Codec::encode` / `H4Codec::encode`.
// NOTE(review): this enum derives `Serialize`/`Deserialize`; renaming or
// reordering variants may change the serialized representation — confirm
// against any persisted data before editing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StoredVector {
/// Uncompressed components, stored as given.
Raw(Vec<f32>),
/// Vector encoded with an `E8Codec`.
E8(E8EncodedVector),
/// Vector encoded with an `H4Codec`.
H4(H4EncodedVector),
}
impl StoredVector {
pub fn to_f32(&self, e8_codec: Option<&E8Codec>, h4_codec: Option<&H4Codec>) -> Vec<f32> {
match self {
StoredVector::Raw(v) => v.clone(),
StoredVector::E8(encoded) => {
if let Some(c) = e8_codec {
c.decode(encoded)
} else {
vec![0.0; encoded.points.len() * 8]
}
}
StoredVector::H4(encoded) => {
if let Some(c) = h4_codec {
c.decode(encoded)
} else {
vec![0.0; encoded.indices.len() * 4]
}
}
}
}
pub fn size_bytes(&self) -> usize {
match self {
StoredVector::Raw(v) => v.len() * 4,
StoredVector::E8(encoded) => encoded.size_bytes(),
StoredVector::H4(encoded) => encoded.size_bytes(),
}
}
}
/// Collection of vectors sharing one quantization mode.
///
/// A vector's id is its insertion index into `vectors`.
#[derive(Debug)]
pub struct VectorStorage {
/// Dimension supplied to `new` and reported by `dimension()`.
dimension: usize,
/// Stored vectors; a vector's id is its index in this `Vec`.
vectors: Vec<StoredVector>,
/// Encoding applied to incoming vectors by `add` and to all vectors by
/// `set_quantization`.
quantization: Quantization,
/// Running total of `StoredVector::size_bytes()` over `vectors`.
memory_bytes: usize,
}
impl VectorStorage {
pub fn new(dimension: usize, quantization: Quantization) -> Self {
Self {
dimension,
vectors: Vec::new(),
quantization,
memory_bytes: 0,
}
}
pub fn add(
&mut self,
vector: &[f32],
e8_codec: Option<&E8Codec>,
h4_codec: Option<&H4Codec>,
) -> Result<usize> {
let stored = match &self.quantization {
Quantization::None => StoredVector::Raw(vector.to_vec()),
Quantization::E8 { .. } => {
if let Some(c) = e8_codec {
let encoded = c.encode(vector)?;
StoredVector::E8(encoded)
} else {
return Err(EmbedVecError::QuantizationError(
"E8 codec required for E8 quantization".to_string(),
));
}
}
Quantization::H4 { .. } => {
if let Some(c) = h4_codec {
let encoded = c.encode(vector)?;
StoredVector::H4(encoded)
} else {
return Err(EmbedVecError::QuantizationError(
"H4 codec required for H4 quantization".to_string(),
));
}
}
};
self.memory_bytes += stored.size_bytes();
let id = self.vectors.len();
self.vectors.push(stored);
Ok(id)
}
#[inline]
pub fn get(
&self,
id: usize,
e8_codec: Option<&E8Codec>,
h4_codec: Option<&H4Codec>,
) -> Result<Vec<f32>> {
self.vectors
.get(id)
.map(|v| v.to_f32(e8_codec, h4_codec))
.ok_or(EmbedVecError::VectorNotFound(id))
}
#[inline]
pub fn get_raw_slice(&self, id: usize) -> Option<&[f32]> {
match self.vectors.get(id) {
Some(StoredVector::Raw(v)) => Some(v.as_slice()),
_ => None,
}
}
#[inline]
pub fn get_stored(&self, id: usize) -> Option<&StoredVector> {
self.vectors.get(id)
}
pub fn get_batch(
&self,
ids: &[usize],
e8_codec: Option<&E8Codec>,
h4_codec: Option<&H4Codec>,
) -> Vec<Option<Vec<f32>>> {
ids.iter()
.map(|&id| self.vectors.get(id).map(|v| v.to_f32(e8_codec, h4_codec)))
.collect()
}
pub fn len(&self) -> usize {
self.vectors.len()
}
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
pub fn clear(&mut self) {
self.vectors.clear();
self.memory_bytes = 0;
}
pub fn memory_bytes(&self) -> usize {
self.memory_bytes
}
pub fn bytes_per_vector(&self) -> f32 {
if self.vectors.is_empty() {
0.0
} else {
self.memory_bytes as f32 / self.vectors.len() as f32
}
}
pub fn set_quantization(
&mut self,
new_quantization: Quantization,
e8_codec: Option<&E8Codec>,
h4_codec: Option<&H4Codec>,
) -> Result<()> {
if self.quantization == new_quantization {
return Ok(());
}
let mut new_vectors = Vec::with_capacity(self.vectors.len());
let mut new_memory = 0usize;
for stored in &self.vectors {
let raw = stored.to_f32(e8_codec, h4_codec);
let new_stored = match &new_quantization {
Quantization::None => StoredVector::Raw(raw),
Quantization::E8 { .. } => {
if let Some(c) = e8_codec {
let encoded = c.encode(&raw)?;
StoredVector::E8(encoded)
} else {
return Err(EmbedVecError::QuantizationError(
"E8 codec required for E8 quantization".to_string(),
));
}
}
Quantization::H4 { .. } => {
if let Some(c) = h4_codec {
let encoded = c.encode(&raw)?;
StoredVector::H4(encoded)
} else {
return Err(EmbedVecError::QuantizationError(
"H4 codec required for H4 quantization".to_string(),
));
}
}
};
new_memory += new_stored.size_bytes();
new_vectors.push(new_stored);
}
self.vectors = new_vectors;
self.memory_bytes = new_memory;
self.quantization = new_quantization;
Ok(())
}
pub fn quantization(&self) -> &Quantization {
&self.quantization
}
pub fn dimension(&self) -> usize {
self.dimension
}
pub fn compute_distance(
&self,
query: &[f32],
id: usize,
e8_codec: Option<&E8Codec>,
h4_codec: Option<&H4Codec>,
distance_fn: impl Fn(&[f32], &[f32]) -> f32,
) -> Result<f32> {
let stored = self
.vectors
.get(id)
.ok_or(EmbedVecError::VectorNotFound(id))?;
match stored {
StoredVector::Raw(v) => Ok(distance_fn(query, v)),
StoredVector::E8(encoded) => {
if let Some(c) = e8_codec {
let decoded = c.decode(encoded);
Ok(distance_fn(query, &decoded))
} else {
Err(EmbedVecError::QuantizationError(
"E8 codec required for distance computation".to_string(),
))
}
}
StoredVector::H4(encoded) => {
if let Some(c) = h4_codec {
let decoded = c.decode(encoded);
Ok(distance_fn(query, &decoded))
} else {
Err(EmbedVecError::QuantizationError(
"H4 codec required for distance computation".to_string(),
))
}
}
}
}
pub fn iter_ids(&self) -> impl Iterator<Item = usize> {
0..self.vectors.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean squared error between two equally long vectors.
    fn mse(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            / a.len() as f32
    }

    #[test]
    fn test_raw_storage() {
        let mut storage = VectorStorage::new(4, Quantization::None);
        let v1 = vec![1.0, 2.0, 3.0, 4.0];
        let id = storage.add(&v1, None, None).unwrap();
        assert_eq!(id, 0);
        let retrieved = storage.get(0, None, None).unwrap();
        assert_eq!(retrieved, v1);
        assert_eq!(storage.get_raw_slice(0), Some(v1.as_slice()));
    }

    #[test]
    fn test_e8_storage() {
        let codec = E8Codec::new(16, 10, true, 42);
        let mut storage = VectorStorage::new(16, Quantization::e8_default());
        let v1: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
        let id = storage.add(&v1, Some(&codec), None).unwrap();
        assert_eq!(id, 0);
        let retrieved = storage.get(0, Some(&codec), None).unwrap();
        assert_eq!(retrieved.len(), 16);
        let err = mse(&v1, &retrieved);
        assert!(err < 1.0, "MSE too high: {}", err);
        // Quantized vectors have no raw slice to borrow.
        assert!(storage.get_raw_slice(0).is_none());
    }

    #[test]
    fn test_h4_storage() {
        let codec = H4Codec::new(16, true, 42);
        let mut storage = VectorStorage::new(16, Quantization::h4_default());
        let v1: Vec<f32> = (0..16).map(|i| (i as f32 * 0.3).sin()).collect();
        let id = storage.add(&v1, None, Some(&codec)).unwrap();
        assert_eq!(id, 0);
        let retrieved = storage.get(0, None, Some(&codec)).unwrap();
        assert_eq!(retrieved.len(), 16);
        let err = mse(&v1, &retrieved);
        assert!(err < 1.0, "H4 MSE too high: {}", err);
    }

    #[test]
    fn test_missing_codec_on_add_is_error() {
        let mut storage = VectorStorage::new(16, Quantization::e8_default());
        let v1 = vec![0.5; 16];
        assert!(storage.add(&v1, None, None).is_err());
        assert!(storage.is_empty());
        assert_eq!(storage.memory_bytes(), 0);
    }

    #[test]
    fn test_memory_tracking() {
        let mut storage = VectorStorage::new(768, Quantization::None);
        assert_eq!(storage.bytes_per_vector(), 0.0);
        for _ in 0..10 {
            let v: Vec<f32> = vec![0.0; 768];
            storage.add(&v, None, None).unwrap();
        }
        assert_eq!(storage.memory_bytes(), 768 * 4 * 10);
        assert_eq!(storage.bytes_per_vector(), (768 * 4) as f32);
    }

    #[test]
    fn test_clear() {
        let mut storage = VectorStorage::new(4, Quantization::None);
        storage.add(&[1.0, 2.0, 3.0, 4.0], None, None).unwrap();
        assert_eq!(storage.len(), 1);
        storage.clear();
        assert_eq!(storage.len(), 0);
        assert!(storage.is_empty());
        assert_eq!(storage.memory_bytes(), 0);
    }

    #[test]
    fn test_get_unknown_id() {
        let storage = VectorStorage::new(4, Quantization::None);
        assert!(matches!(
            storage.get(3, None, None),
            Err(EmbedVecError::VectorNotFound(3))
        ));
        assert!(storage.get_stored(3).is_none());
    }

    #[test]
    fn test_get_batch_and_iter_ids() {
        let mut storage = VectorStorage::new(2, Quantization::None);
        storage.add(&[1.0, 2.0], None, None).unwrap();
        storage.add(&[3.0, 4.0], None, None).unwrap();
        let batch = storage.get_batch(&[1, 7, 0], None, None);
        assert_eq!(batch, vec![Some(vec![3.0, 4.0]), None, Some(vec![1.0, 2.0])]);
        assert_eq!(storage.iter_ids().collect::<Vec<_>>(), vec![0, 1]);
    }

    #[test]
    fn test_compute_distance_raw() {
        let mut storage = VectorStorage::new(4, Quantization::None);
        storage.add(&[1.0, 2.0, 3.0, 4.0], None, None).unwrap();
        let squared_l2 = |a: &[f32], b: &[f32]| -> f32 {
            a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>()
        };
        let d = storage
            .compute_distance(&[0.0; 4], 0, None, None, squared_l2)
            .unwrap();
        assert_eq!(d, 30.0);
        assert!(matches!(
            storage.compute_distance(&[0.0; 4], 1, None, None, squared_l2),
            Err(EmbedVecError::VectorNotFound(1))
        ));
    }

    #[test]
    fn test_set_quantization_roundtrip() {
        let codec = E8Codec::new(16, 10, true, 42);
        let mut storage = VectorStorage::new(16, Quantization::None);
        let v1: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
        storage.add(&v1, None, None).unwrap();

        storage
            .set_quantization(Quantization::e8_default(), Some(&codec), None)
            .unwrap();
        assert_eq!(storage.quantization(), &Quantization::e8_default());
        assert!(matches!(storage.get_stored(0), Some(StoredVector::E8(_))));

        storage
            .set_quantization(Quantization::None, Some(&codec), None)
            .unwrap();
        assert_eq!(storage.quantization(), &Quantization::None);
        assert_eq!(storage.get_raw_slice(0).map(<[f32]>::len), Some(16));
        assert_eq!(storage.memory_bytes(), 16 * 4);
    }

    #[test]
    fn test_set_quantization_without_codec_leaves_storage_unchanged() {
        let mut storage = VectorStorage::new(4, Quantization::None);
        storage.add(&[1.0, 2.0, 3.0, 4.0], None, None).unwrap();
        assert!(storage
            .set_quantization(Quantization::e8_default(), None, None)
            .is_err());
        assert_eq!(storage.quantization(), &Quantization::None);
        assert_eq!(storage.get(0, None, None).unwrap(), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(storage.memory_bytes(), 16);
    }
}