use std::io::{Read, Result, Write};
use super::super::{StorageReadProvider, StorageWriteProvider};
use thiserror::Error;
/// Reads the entire file at `path` through `read_provider` and decodes it as the
/// protobuf message type `S`.
///
/// # Errors
///
/// Returns [`ProtoStorageError::IoError`] when opening or reading the file fails, and
/// [`ProtoStorageError::DecodeError`] when the bytes are not a valid encoding of `S`.
pub fn load<P, S>(read_provider: &P, path: &str) -> std::result::Result<S, ProtoStorageError>
where
    P: StorageReadProvider,
    S: prost::Message + Default,
{
    let mut source = read_provider.open_reader(path)?;
    let mut payload = Vec::new();
    source.read_to_end(&mut payload)?;
    let decoded = S::decode(payload.as_slice())?;
    Ok(decoded)
}
/// Encodes `proto_struct` and writes the bytes to `path` through `write_provider`,
/// flushing the writer before returning.
///
/// Returns the number of encoded bytes that were written.
///
/// # Errors
///
/// Propagates any I/O error from creating the writer, writing the bytes, or flushing.
pub fn save<S, P>(proto_struct: S, write_provider: &P, path: &str) -> Result<usize>
where
    P: StorageWriteProvider,
    S: prost::Message,
{
    let mut sink = write_provider.create_for_write(path)?;
    let bytes = proto_struct.encode_to_vec();
    let byte_count = bytes.len();
    sink.write_all(&bytes)?;
    // Flush explicitly so write errors surface here instead of being lost on drop.
    sink.flush()?;
    Ok(byte_count)
}
#[derive(Debug, Error)]
pub enum ProtoStorageError {
#[error("Error while creating/opening file {0:?}")]
IoError(#[from] std::io::Error),
#[error("Error while decoding bytes to proto struct: {0:?}")]
DecodeError(#[from] prost::DecodeError),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::protos::scalar_quantization::Version;
    use crate::storage::protos::ScalarQuantizer;
    use crate::storage::VirtualStorageProvider;
    use prost::Message;

    /// Round-trips a fully populated `ScalarQuantizer` through `save` and `load`,
    /// checking both the reported byte count and the decoded value.
    #[test]
    fn test_save_and_load_success() {
        let provider = VirtualStorageProvider::new_memory();
        let path = "/test.bin";
        let expected = ScalarQuantizer {
            version: Some(Version {
                major: 0,
                minor: 1,
                patch: 0,
            }),
            scale: 1.0,
            shift: vec![0.1, 1.2, 2.3],
            mean_norm: Some(0.5),
            compressed_data_file_name: "compressed_data.bin".to_string(),
        };
        let expected_len = expected.encode_to_vec().len();

        let written = save(expected.clone(), &provider, path).unwrap();
        assert_eq!(written, expected_len);

        let round_tripped: ScalarQuantizer = load(&provider, path).unwrap();
        assert_eq!(round_tripped, expected);
    }

    /// Corrupt bytes on disk must surface as `ProtoStorageError::DecodeError`.
    #[test]
    fn test_load_invalid_data_returns_decode_error() {
        // The final field header (34 = field 4, wire type 2) declares a 12-byte
        // payload, but only two bytes follow, so the buffer is truncated.
        const TRUNCATED: [u8; 16] = [
            13, 0, 0, 128, 63, 16, 4, 29, 0, 0, 128, 63, 34, 12, 205, 204,
        ];
        let provider = VirtualStorageProvider::new_memory();
        {
            let mut sink = provider.create_for_write("/bad.bin").unwrap();
            sink.write_all(&TRUNCATED).unwrap();
            sink.flush().unwrap();
        }

        let failure = load::<_, ScalarQuantizer>(&provider, "/bad.bin").unwrap_err();
        assert!(matches!(failure, ProtoStorageError::DecodeError(_)));
    }

    /// Loading a path that was never written must surface as `ProtoStorageError::IoError`.
    #[test]
    fn test_load_io_error_when_missing_file() {
        let provider = VirtualStorageProvider::new_memory();
        let failure = load::<_, ScalarQuantizer>(&provider, "/missing.bin").unwrap_err();
        let is_io_error = matches!(failure, ProtoStorageError::IoError(_));
        assert!(is_io_error, "Expected IoError, got: {:?}", failure);
    }
}