use std::io::Result;
use super::{StorageReadProvider, StorageWriteProvider};
use diskann_quantization::scalar::ScalarQuantizer;
use super::protos;
use crate::model::graph::provider::async_::inmem::SQError;
/// Suffix (joined to the caller's prefix with `_`) naming the compressed-data file.
const COMPRESSED_DATA_FILE_NAME_SUFFIX: &str = "sq_compressed.bin";
/// Suffix (joined to the caller's prefix with `_`) naming the serialized quantizer file.
const QUANTIZER_DATA_FILE_NAME_SUFFIX: &str = "scalar_quantizer_proto.bin";

/// Paths of the files used to persist scalar-quantization state under a common prefix.
///
/// One path points at the compressed data file. The other points at the serialized
/// [`ScalarQuantizer`]. The type is a plain pair of owned strings, so it is cheap to
/// clone and meaningful to compare for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SQStorage {
    /// Path of the compressed data file.
    compressed_data_path: String,
    /// Path of the serialized quantizer file.
    quantizer_path: String,
}
impl SQStorage {
    /// Creates the path set for `prefix`.
    ///
    /// Each path is `"{prefix}_{suffix}"`. The prefix is used verbatim, with no path
    /// normalization, so `""` yields paths that begin with `_`.
    pub fn new(prefix: &str) -> Self {
        let compressed_data_path = format!("{prefix}_{COMPRESSED_DATA_FILE_NAME_SUFFIX}");
        let quantizer_path = format!("{prefix}_{QUANTIZER_DATA_FILE_NAME_SUFFIX}");
        Self {
            compressed_data_path,
            quantizer_path,
        }
    }

    /// Borrowed path of the compressed data file.
    pub fn compressed_data_path(&self) -> &str {
        &self.compressed_data_path
    }

    /// Borrowed path of the serialized quantizer file.
    pub fn quantizer_path(&self) -> &str {
        &self.quantizer_path
    }

    /// Serializes `quantizer` and writes it to [`Self::quantizer_path`] through `write_provider`.
    ///
    /// The serialized message also records [`Self::compressed_data_path`].
    ///
    /// # Returns
    /// The `usize` produced by `protos::save`. This is presumably the number of bytes
    /// written; confirm against `protos::save`.
    ///
    /// # Errors
    /// Propagates any I/O error reported by `protos::save`.
    pub fn save_quantizer<Storage: StorageWriteProvider>(
        &self,
        quantizer: &ScalarQuantizer,
        write_provider: &Storage,
    ) -> Result<usize> {
        let proto = protos::ScalarQuantizer::from(quantizer, self.compressed_data_path.clone());
        protos::save(proto, write_provider, self.quantizer_path())
    }

    /// Reads [`Self::quantizer_path`] through `read_provider` and rebuilds the quantizer.
    ///
    /// # Errors
    /// Returns an [`SQError`] in two cases: reading or decoding the stored message
    /// fails, or the decoded message cannot be converted into a [`ScalarQuantizer`].
    pub fn load_quantizer<Storage: StorageReadProvider>(
        &self,
        read_provider: &Storage,
    ) -> std::result::Result<ScalarQuantizer, SQError> {
        let proto: protos::ScalarQuantizer = protos::load(read_provider, self.quantizer_path())?;
        let quantizer = ScalarQuantizer::try_from(proto)?;
        Ok(quantizer)
    }
}
/// Unit tests for [`SQStorage`] path construction and quantizer persistence.
#[cfg(test)]
mod tests {
    use crate::storage::VirtualStorageProvider;

    use super::*;

    #[test]
    fn new_constructs_correct_paths() {
        let storage = SQStorage::new("my_prefix");
        assert_eq!(
            storage.compressed_data_path(),
            "my_prefix_sq_compressed.bin"
        );
        assert_eq!(
            storage.quantizer_path(),
            "my_prefix_scalar_quantizer_proto.bin"
        );
    }

    #[test]
    fn getters_return_references_to_internal_strings() {
        let storage = SQStorage::new("foo");
        let comp = storage.compressed_data_path();
        let quant = storage.quantizer_path();
        // Comparing a getter's output with a second call to itself by value is
        // tautological. Pointer identity proves both calls borrow the same buffer
        // owned by `storage`.
        assert!(std::ptr::eq(comp, storage.compressed_data_path()));
        assert!(std::ptr::eq(quant, storage.quantizer_path()));
        assert_eq!(comp, "foo_sq_compressed.bin");
        assert_eq!(quant, "foo_scalar_quantizer_proto.bin");
    }

    #[test]
    fn empty_prefix_still_constructs_paths_with_leading_underscore() {
        let storage = SQStorage::new("");
        assert_eq!(storage.compressed_data_path(), "_sq_compressed.bin");
        assert_eq!(storage.quantizer_path(), "_scalar_quantizer_proto.bin");
    }

    #[test]
    fn prefix_with_slashes_is_handled_as_literal() {
        let storage = SQStorage::new("dir/subdir/file");
        assert_eq!(
            storage.compressed_data_path(),
            "dir/subdir/file_sq_compressed.bin"
        );
        assert_eq!(
            storage.quantizer_path(),
            "dir/subdir/file_scalar_quantizer_proto.bin"
        );
    }

    #[test]
    fn save_and_load_quantizer_roundtrip() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let sq_storage = SQStorage::new("/roundtrip");
        let quantizer = ScalarQuantizer::new(1.0, vec![0.0, 1.0, 2.0], None);

        let bytes_written = sq_storage
            .save_quantizer(&quantizer, &storage_provider)
            .expect("save_quantizer should succeed");
        assert!(bytes_written > 0);

        let loaded_quantizer = sq_storage
            .load_quantizer(&storage_provider)
            .expect("load_quantizer should succeed");
        assert_eq!(quantizer.compare(&loaded_quantizer), Ok(()));
    }
}