use super::sq8::{QuantizedVector, SQ8Quantizer};
use crate::types::RowId;
use crate::{Result, StorageError};
use lru::LruCache;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Disk-backed store of SQ8-quantized vectors with two LRU caches:
/// one for dequantized `f32` vectors and one for the raw quantized form.
pub struct SQ8Vectors {
    _data_dir: PathBuf,
    // Expected length of every stored vector (taken from the quantizer).
    dimension: usize,
    quantizer: Arc<SQ8Quantizer>,
    // Bytes per on-disk entry: row_id (8) + min (4) + max (4) + `dimension` code bytes.
    _entry_size: usize,
    // Maps a row id to the byte offset of its entry inside `file_path`.
    index: Arc<RwLock<HashMap<RowId, u64>>>,
    // LRU cache of dequantized vectors.
    cache: Arc<RwLock<LruCache<RowId, Arc<Vec<f32>>>>>,
    // LRU cache of quantized vectors (sized 2x `cache` — see `create`).
    quantized_cache: Arc<RwLock<LruCache<RowId, Arc<QuantizedVector>>>>,
    // Path to the append-only data file (`vectors_sq8.bin`).
    file_path: PathBuf,
}
impl SQ8Vectors {
    /// Clamps a requested LRU capacity to at least 1, so a zero
    /// `cache_size` no longer panics inside `NonZeroUsize::new(..).unwrap()`.
    fn cache_capacity(requested: usize) -> NonZeroUsize {
        NonZeroUsize::new(requested.max(1)).expect("max(1) is always non-zero")
    }

    /// Creates a new, empty store under `data_dir`, writing the backing
    /// file `vectors_sq8.bin` with a zeroed 8-byte count header.
    ///
    /// `cache_size` caps the dequantized (f32) LRU cache; the quantized
    /// cache holds twice as many entries (quantized vectors are smaller).
    ///
    /// # Errors
    /// `StorageError::Io` if the directory or file cannot be created.
    pub fn create(
        data_dir: impl AsRef<Path>,
        quantizer: Arc<SQ8Quantizer>,
        cache_size: usize,
    ) -> Result<Self> {
        let data_dir = data_dir.as_ref().to_path_buf();
        std::fs::create_dir_all(&data_dir).map_err(StorageError::Io)?;
        let dimension = quantizer.dimension();
        // On-disk entry layout: row_id (u64) + min (f32) + max (f32) + `dimension` code bytes.
        let entry_size = 8 + 4 + 4 + dimension;
        let file_path = data_dir.join("vectors_sq8.bin");
        let mut file = File::create(&file_path).map_err(StorageError::Io)?;
        // Advisory entry-count header; `load` scans to EOF rather than trusting it.
        file.write_all(&0u64.to_le_bytes())
            .map_err(StorageError::Io)?;
        Ok(Self {
            _data_dir: data_dir,
            dimension,
            quantizer,
            _entry_size: entry_size,
            index: Arc::new(RwLock::new(HashMap::new())),
            cache: Arc::new(RwLock::new(LruCache::new(Self::cache_capacity(cache_size)))),
            quantized_cache: Arc::new(RwLock::new(LruCache::new(Self::cache_capacity(
                cache_size.saturating_mul(2),
            )))),
            file_path,
        })
    }

    /// Opens an existing store created by [`SQ8Vectors::create`].
    ///
    /// Rebuilds the index by scanning every *complete* entry in the data
    /// file; a trailing partial entry (interrupted append) is ignored.
    /// Later entries win over earlier ones with the same row id, so
    /// entries appended by `update` take effect after a reload.
    /// NOTE(review): deletions are not persisted to the data file, so
    /// deleted rows reappear after a reload — confirm whether compaction
    /// is handled elsewhere.
    ///
    /// # Errors
    /// `StorageError::InvalidData` if the file is missing or shorter than
    /// its 8-byte header; `StorageError::Io` on read failures.
    pub fn load(
        data_dir: impl AsRef<Path>,
        quantizer: Arc<SQ8Quantizer>,
        cache_size: usize,
    ) -> Result<Self> {
        let data_dir = data_dir.as_ref().to_path_buf();
        let dimension = quantizer.dimension();
        let entry_size = 8 + 4 + 4 + dimension;
        let file_path = data_dir.join("vectors_sq8.bin");
        if !file_path.exists() {
            return Err(StorageError::InvalidData(
                "SQ8 vectors file not found".to_string(),
            ));
        }
        let mut file = File::open(&file_path).map_err(StorageError::Io)?;
        let file_len = file.metadata().map_err(StorageError::Io)?.len();
        if file_len < 8 {
            return Err(StorageError::InvalidData(
                "SQ8 vectors file header truncated".to_string(),
            ));
        }
        let mut index = HashMap::new();
        let mut offset = 8u64;
        // Iterate only over entries that fit entirely within the file.
        while offset + entry_size as u64 <= file_len {
            file.seek(SeekFrom::Start(offset)).map_err(StorageError::Io)?;
            let mut row_id_bytes = [0u8; 8];
            file.read_exact(&mut row_id_bytes)
                .map_err(StorageError::Io)?;
            let row_id = u64::from_le_bytes(row_id_bytes);
            // Last write wins: `update` appends a fresh entry for the same id.
            index.insert(row_id, offset);
            offset += entry_size as u64;
        }
        Ok(Self {
            _data_dir: data_dir,
            dimension,
            quantizer,
            _entry_size: entry_size,
            index: Arc::new(RwLock::new(index)),
            cache: Arc::new(RwLock::new(LruCache::new(Self::cache_capacity(cache_size)))),
            quantized_cache: Arc::new(RwLock::new(LruCache::new(Self::cache_capacity(
                cache_size.saturating_mul(2),
            )))),
            file_path,
        })
    }

    /// Returns the dequantized vector for `row_id`, or `None` if the row
    /// is unknown or its on-disk entry cannot be read.
    pub fn get(&self, row_id: RowId) -> Option<Arc<Vec<f32>>> {
        {
            // `LruCache::get` updates recency, so even a lookup needs the write lock.
            let mut cache = self.cache.write();
            if let Some(vec) = cache.get(&row_id) {
                return Some(Arc::clone(vec));
            }
        }
        let offset = {
            let index = self.index.read();
            *index.get(&row_id)?
        };
        // Cache miss: read the quantized entry from disk and dequantize it.
        let qvec = self.read_quantized(offset).ok()?;
        let arc_vec = Arc::new(self.quantizer.dequantize(&qvec));
        self.cache.write().put(row_id, Arc::clone(&arc_vec));
        Some(arc_vec)
    }

    /// Returns the quantized vector for `row_id`, or `None` if the row
    /// is unknown or its on-disk entry cannot be read.
    pub fn get_quantized(&self, row_id: RowId) -> Option<Arc<QuantizedVector>> {
        {
            let mut cache = self.quantized_cache.write();
            if let Some(qvec) = cache.get(&row_id) {
                return Some(Arc::clone(qvec));
            }
        }
        let offset = {
            let index = self.index.read();
            *index.get(&row_id)?
        };
        let arc_qvec = Arc::new(self.read_quantized(offset).ok()?);
        self.quantized_cache.write().put(row_id, Arc::clone(&arc_qvec));
        Some(arc_qvec)
    }

    /// Fetches quantized vectors for many rows at once. Unknown or
    /// unreadable rows are silently omitted from the result.
    pub fn batch_get_quantized(&self, row_ids: &[RowId]) -> HashMap<RowId, Arc<QuantizedVector>> {
        let mut result = HashMap::with_capacity(row_ids.len());
        let mut uncached_ids = Vec::new();
        {
            // Collect all cache hits under a single lock acquisition.
            let mut cache = self.quantized_cache.write();
            for &row_id in row_ids {
                if let Some(qvec) = cache.get(&row_id) {
                    result.insert(row_id, Arc::clone(qvec));
                } else {
                    uncached_ids.push(row_id);
                }
            }
        }
        // Misses fall back to individual disk reads (which also re-cache).
        for row_id in uncached_ids {
            if let Some(qvec) = self.get_quantized(row_id) {
                result.insert(row_id, qvec);
            }
        }
        result
    }

    /// Quantizes `vector` and appends it to the data file under `row_id`.
    ///
    /// # Errors
    /// `StorageError::InvalidData` if the dimension is wrong or the row
    /// already exists; `StorageError::Io` on write failures.
    pub fn insert(&self, row_id: RowId, vector: Vec<f32>) -> Result<()> {
        if vector.len() != self.dimension {
            return Err(StorageError::InvalidData(format!(
                "Vector dimension mismatch: expected {}, got {}",
                self.dimension,
                vector.len()
            )));
        }
        let qvec = self.quantizer.quantize(&vector)?;
        {
            // Hold the write lock across check + append + insert so two
            // concurrent inserts of the same row_id cannot both pass the
            // existence check (TOCTOU), and so each entry's bytes stay
            // contiguous in the append-only file.
            let mut index = self.index.write();
            if index.contains_key(&row_id) {
                return Err(StorageError::InvalidData(format!(
                    "Vector {} already exists",
                    row_id
                )));
            }
            let offset = self.append_quantized(row_id, &qvec)?;
            index.insert(row_id, offset);
        }
        // Warm both caches with the freshly inserted value.
        self.cache.write().put(row_id, Arc::new(vector));
        self.quantized_cache.write().put(row_id, Arc::new(qvec));
        Ok(())
    }

    /// Best-effort bulk insert; returns how many vectors were accepted.
    /// Individual failures (duplicates, bad dimensions) are skipped.
    pub fn batch_insert(&self, batch: Vec<(RowId, Vec<f32>)>) -> Result<usize> {
        let mut inserted = 0;
        for (row_id, vector) in batch {
            if self.insert(row_id, vector).is_ok() {
                inserted += 1;
            }
        }
        Ok(inserted)
    }

    /// Replaces the vector stored under `row_id`; returns `Ok(false)` if
    /// the row does not exist.
    ///
    /// Fix: the previous implementation only refreshed the f32 cache, so
    /// `get_quantized`, any read after cache eviction, and any reload all
    /// kept returning the old value. The new value is now quantized,
    /// appended to the data file, and the index repointed at the fresh
    /// entry (last-write-wins on reload).
    ///
    /// # Errors
    /// `StorageError::InvalidData` on a dimension mismatch;
    /// `StorageError::Io` on write failures.
    pub fn update(&self, row_id: RowId, vector: Vec<f32>) -> Result<bool> {
        if vector.len() != self.dimension {
            return Err(StorageError::InvalidData(format!(
                "Vector dimension mismatch: expected {}, got {}",
                self.dimension,
                vector.len()
            )));
        }
        // Fast path: skip quantization work for unknown rows.
        if !self.index.read().contains_key(&row_id) {
            return Ok(false);
        }
        let qvec = self.quantizer.quantize(&vector)?;
        {
            let mut index = self.index.write();
            // Re-check under the write lock: the row may have been
            // deleted between the fast-path check and here.
            if !index.contains_key(&row_id) {
                return Ok(false);
            }
            let offset = self.append_quantized(row_id, &qvec)?;
            index.insert(row_id, offset);
        }
        self.cache.write().put(row_id, Arc::new(vector));
        self.quantized_cache.write().put(row_id, Arc::new(qvec));
        Ok(true)
    }

    /// Removes `row_id` from the index and caches; returns whether it existed.
    /// NOTE(review): the on-disk entry is left in place (no compaction),
    /// so a later `load` resurrects the row — confirm this is acceptable.
    pub fn delete(&self, row_id: RowId) -> Result<bool> {
        let removed = {
            let mut index = self.index.write();
            index.remove(&row_id).is_some()
        };
        if removed {
            self.invalidate_single(row_id);
        }
        Ok(removed)
    }

    /// Drops any cached copies (f32 and quantized) of `row_id`.
    fn invalidate_single(&self, row_id: RowId) {
        self.cache.write().pop(&row_id);
        self.quantized_cache.write().pop(&row_id);
    }

    /// Drops any cached copies of every id in `row_ids`, taking each
    /// cache lock once for the whole batch.
    pub fn invalidate_batch(&self, row_ids: &[RowId]) {
        if row_ids.is_empty() {
            return;
        }
        {
            let mut cache = self.cache.write();
            for &row_id in row_ids {
                cache.pop(&row_id);
            }
        }
        let mut qcache = self.quantized_cache.write();
        for &row_id in row_ids {
            qcache.pop(&row_id);
        }
    }

    /// Rewrites the count header and syncs the data file to disk.
    ///
    /// # Errors
    /// `StorageError::Io` if the file cannot be opened, written, or synced.
    pub fn flush(&self) -> Result<()> {
        let count = self.index.read().len() as u64;
        let mut file = OpenOptions::new()
            .write(true)
            .open(&self.file_path)
            .map_err(StorageError::Io)?;
        file.seek(SeekFrom::Start(0)).map_err(StorageError::Io)?;
        file.write_all(&count.to_le_bytes())
            .map_err(StorageError::Io)?;
        // Durability: force the header and appended entries to stable storage.
        file.sync_all().map_err(StorageError::Io)?;
        Ok(())
    }

    /// All row ids currently in the index, in arbitrary order.
    pub fn ids(&self) -> Vec<RowId> {
        self.index.read().keys().copied().collect()
    }

    /// Number of live (indexed) vectors.
    pub fn len(&self) -> usize {
        self.index.read().len()
    }

    /// Whether the store holds no vectors.
    pub fn is_empty(&self) -> bool {
        self.index.read().is_empty()
    }

    /// Dimension every stored vector must have.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Rough in-memory footprint in bytes: index entries plus both caches
    /// (the quantized cache was previously not counted at all).
    pub fn memory_usage(&self) -> usize {
        let index_size = self.index.read().len() * (8 + 8);
        let cache_size = self.cache.read().len() * (8 + self.dimension * 4);
        let qcache_size = self.quantized_cache.read().len() * (8 + 4 + 4 + self.dimension);
        index_size + cache_size + qcache_size
    }

    /// Size of the backing data file in bytes (0 if it cannot be stat'ed).
    pub fn disk_usage(&self) -> usize {
        std::fs::metadata(&self.file_path)
            .map(|m| m.len() as usize)
            .unwrap_or(0)
    }

    /// Reads one quantized entry whose row_id field starts at `offset`.
    /// Opens a fresh handle so concurrent readers never share seek state.
    fn read_quantized(&self, offset: u64) -> Result<QuantizedVector> {
        let mut file = File::open(&self.file_path).map_err(StorageError::Io)?;
        // Skip the 8-byte row_id; the caller already knows it.
        file.seek(SeekFrom::Start(offset + 8))
            .map_err(StorageError::Io)?;
        let mut min_bytes = [0u8; 4];
        let mut max_bytes = [0u8; 4];
        file.read_exact(&mut min_bytes).map_err(StorageError::Io)?;
        file.read_exact(&mut max_bytes).map_err(StorageError::Io)?;
        let min = f32::from_le_bytes(min_bytes);
        let max = f32::from_le_bytes(max_bytes);
        let mut codes = vec![0u8; self.dimension];
        file.read_exact(&mut codes).map_err(StorageError::Io)?;
        Ok(QuantizedVector { codes, min, max })
    }

    /// Appends one entry and returns its starting offset. Callers must
    /// hold the index write lock so the four writes of an entry stay
    /// contiguous under concurrent appends.
    fn append_quantized(&self, row_id: RowId, qvec: &QuantizedVector) -> Result<u64> {
        let mut file = OpenOptions::new()
            .append(true)
            .open(&self.file_path)
            .map_err(StorageError::Io)?;
        let offset = file.metadata().map_err(StorageError::Io)?.len();
        file.write_all(&row_id.to_le_bytes())
            .map_err(StorageError::Io)?;
        file.write_all(&qvec.min.to_le_bytes())
            .map_err(StorageError::Io)?;
        file.write_all(&qvec.max.to_le_bytes())
            .map_err(StorageError::Io)?;
        file.write_all(&qvec.codes).map_err(StorageError::Io)?;
        Ok(offset)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips two vectors through insert/get, checks the lossy
    /// dequantized values are close to the originals, then flushes and
    /// reloads the store from disk to verify persistence.
    #[test]
    fn test_sq8_vectors_basic() {
        let dir = std::env::temp_dir().join("sq8_vectors_test");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        let quantizer = Arc::new(SQ8Quantizer::new(4));
        let store = SQ8Vectors::create(&dir, Arc::clone(&quantizer), 10).unwrap();

        store.insert(1, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        store.insert(2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();

        let roundtripped = store.get(1).unwrap();
        assert_eq!(roundtripped.len(), 4);
        // Quantization is lossy; allow a small absolute error per component.
        for (got, want) in roundtripped.iter().zip([1.0f32, 2.0, 3.0, 4.0]) {
            assert!((got - want).abs() < 0.1);
        }

        store.flush().unwrap();

        let reopened = SQ8Vectors::load(&dir, quantizer, 10).unwrap();
        assert_eq!(reopened.len(), 2);
        assert_eq!(reopened.get(1).unwrap().len(), 4);

        std::fs::remove_dir_all(&dir).ok();
    }
}