use std::fs::{File, OpenOptions};
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use memmap2::{Mmap, MmapOptions};
use thiserror::Error;
use crate::vector::types::{SegmentOrdinal, VectorDimension, VectorError, VectorId};
const STORAGE_VERSION: u32 = 1;
const HEADER_SIZE: usize = 16;
const MAGIC_BYTES: &[u8; 4] = b"CVEC";
const BYTES_PER_F32: usize = 4;
const BYTES_PER_ID: usize = 4;
#[derive(Error, Debug)]
pub enum VectorStorageError {
#[error("IO error: {0}")]
Io(#[from] io::Error),
#[error("Invalid storage format: {0}")]
InvalidFormat(String),
#[error("Vector error: {0}")]
Vector(#[from] VectorError),
}
#[derive(Debug)]
pub struct MmapVectorStorage {
path: PathBuf,
mmap: Option<Mmap>,
dimension: VectorDimension,
vector_count: usize,
segment: SegmentOrdinal,
}
impl MmapVectorStorage {
pub fn new(
base_path: impl AsRef<Path>,
segment: SegmentOrdinal,
dimension: VectorDimension,
) -> Result<Self, VectorStorageError> {
let path = Self::segment_path(base_path.as_ref(), segment);
Ok(Self {
path,
mmap: None,
dimension,
vector_count: 0,
segment,
})
}
pub fn open(
base_path: impl AsRef<Path>,
segment: SegmentOrdinal,
) -> Result<Self, VectorStorageError> {
let path = Self::segment_path(base_path.as_ref(), segment);
if !path.exists() {
return Err(VectorStorageError::Io(io::Error::new(
io::ErrorKind::NotFound,
format!("Vector storage file not found: {path:?}"),
)));
}
let file = File::open(&path)?;
let mmap = unsafe { MmapOptions::new().map(&file)? };
let (version, dimension, vector_count) = Self::read_header(&mmap)?;
if version != STORAGE_VERSION {
return Err(VectorError::VersionMismatch {
expected: STORAGE_VERSION,
actual: version,
}
.into());
}
Ok(Self {
path,
mmap: Some(mmap),
dimension,
vector_count,
segment,
})
}
pub fn open_or_create(
base_path: impl AsRef<Path>,
segment: SegmentOrdinal,
dimension: VectorDimension,
) -> Result<Self, VectorStorageError> {
let path = Self::segment_path(base_path.as_ref(), segment);
if path.exists() {
Self::open(base_path, segment)
} else {
let mut storage = Self::new(base_path, segment, dimension)?;
storage.initialize()?;
Ok(storage)
}
}
pub fn write_batch(
&mut self,
vectors: &[(VectorId, &[f32])],
) -> Result<(), VectorStorageError> {
let owned_vectors: Vec<(VectorId, Vec<f32>)> = vectors
.iter()
.map(|(id, vec)| (*id, vec.to_vec()))
.collect();
self.validate_vectors(&owned_vectors)?;
self.ensure_storage_ready()?;
self.append_vectors(&owned_vectors)?;
self.update_metadata(vectors.len())?;
self.invalidate_cache();
Ok(())
}
fn validate_vectors(&self, vectors: &[(VectorId, Vec<f32>)]) -> Result<(), VectorStorageError> {
for (_, vec) in vectors {
self.dimension.validate_vector(vec)?;
}
Ok(())
}
fn ensure_storage_ready(&self) -> Result<(), VectorStorageError> {
if let Some(parent) = self.path.parent() {
std::fs::create_dir_all(parent)?;
}
Ok(())
}
fn append_vectors(&self, vectors: &[(VectorId, Vec<f32>)]) -> Result<(), VectorStorageError> {
let mut file = OpenOptions::new()
.create(true)
.append(true)
.open(&self.path)?;
if file.metadata()?.len() == 0 {
self.write_header(&mut file)?;
}
for (id, vector) in vectors {
file.write_all(&id.to_bytes())?;
for &value in vector {
file.write_all(&value.to_le_bytes())?;
}
}
file.flush()?;
Ok(())
}
fn update_metadata(&mut self, vector_count: usize) -> Result<(), VectorStorageError> {
self.vector_count += vector_count;
self.update_header_count()?;
Ok(())
}
fn invalidate_cache(&mut self) {
self.mmap = None;
}
#[must_use]
pub fn read_vector(&mut self, id: VectorId) -> Option<Vec<f32>> {
self.ensure_mapped().ok()?;
let mmap = self.mmap.as_ref()?;
let dimension = self.dimension.get();
let vector_size = BYTES_PER_ID + dimension * BYTES_PER_F32;
let mut offset = HEADER_SIZE;
while offset + vector_size <= mmap.len() {
let stored_id = u32::from_le_bytes([
mmap[offset],
mmap[offset + 1],
mmap[offset + 2],
mmap[offset + 3],
]);
if stored_id == id.get() {
let mut vector = Vec::with_capacity(dimension);
let data_offset = offset + BYTES_PER_ID;
for i in 0..dimension {
let bytes_offset = data_offset + i * BYTES_PER_F32;
let value = f32::from_le_bytes([
mmap[bytes_offset],
mmap[bytes_offset + 1],
mmap[bytes_offset + 2],
mmap[bytes_offset + 3],
]);
vector.push(value);
}
return Some(vector);
}
offset += vector_size;
}
None
}
pub fn read_all_vectors(&mut self) -> Result<Vec<(VectorId, Vec<f32>)>, VectorStorageError> {
self.ensure_mapped()?;
let mmap = self.mmap.as_ref().unwrap();
let dimension = self.dimension.get();
let vector_size = BYTES_PER_ID + dimension * BYTES_PER_F32;
let mut vectors = Vec::with_capacity(self.vector_count);
let mut offset = HEADER_SIZE;
while offset + vector_size <= mmap.len() {
let id_bytes = [
mmap[offset],
mmap[offset + 1],
mmap[offset + 2],
mmap[offset + 3],
];
let id = VectorId::from_bytes(id_bytes).ok_or_else(|| {
VectorStorageError::InvalidFormat("Invalid vector ID".to_string())
})?;
let mut vector = Vec::with_capacity(dimension);
let data_offset = offset + BYTES_PER_ID;
for i in 0..dimension {
let bytes_offset = data_offset + i * BYTES_PER_F32;
let value = f32::from_le_bytes([
mmap[bytes_offset],
mmap[bytes_offset + 1],
mmap[bytes_offset + 2],
mmap[bytes_offset + 3],
]);
vector.push(value);
}
vectors.push((id, vector));
offset += vector_size;
}
Ok(vectors)
}
#[must_use]
pub fn vector_count(&self) -> usize {
self.vector_count
}
#[must_use]
pub fn dimension(&self) -> VectorDimension {
self.dimension
}
#[must_use]
pub fn segment(&self) -> SegmentOrdinal {
self.segment
}
#[must_use]
pub fn exists(&self) -> bool {
self.path.exists()
}
pub fn file_size(&self) -> Result<u64, io::Error> {
Ok(std::fs::metadata(&self.path)?.len())
}
fn segment_path(base_path: &Path, segment: SegmentOrdinal) -> PathBuf {
base_path.join(format!("segment_{}.vec", segment.get()))
}
fn initialize(&mut self) -> Result<(), VectorStorageError> {
if let Some(parent) = self.path.parent() {
std::fs::create_dir_all(parent)?;
}
let mut file = File::create(&self.path)?;
self.write_header(&mut file)?;
file.flush()?;
Ok(())
}
fn write_header(&self, file: &mut File) -> Result<(), io::Error> {
file.write_all(MAGIC_BYTES)?;
file.write_all(&STORAGE_VERSION.to_le_bytes())?;
file.write_all(&(self.dimension.get() as u32).to_le_bytes())?;
file.write_all(&0u32.to_le_bytes())?;
Ok(())
}
fn read_header(mmap: &Mmap) -> Result<(u32, VectorDimension, usize), VectorStorageError> {
if mmap.len() < HEADER_SIZE {
return Err(VectorStorageError::InvalidFormat(
"File too small to contain header".to_string(),
));
}
if &mmap[0..4] != MAGIC_BYTES {
return Err(VectorStorageError::InvalidFormat(
"Invalid magic bytes".to_string(),
));
}
let version = u32::from_le_bytes([mmap[4], mmap[5], mmap[6], mmap[7]]);
let dim_value = u32::from_le_bytes([mmap[8], mmap[9], mmap[10], mmap[11]]);
let dimension = VectorDimension::new(dim_value as usize)?;
let vector_count = u32::from_le_bytes([mmap[12], mmap[13], mmap[14], mmap[15]]) as usize;
Ok((version, dimension, vector_count))
}
fn ensure_mapped(&mut self) -> Result<(), VectorStorageError> {
if self.mmap.is_none() {
let file = File::open(&self.path)?;
let mmap = unsafe { MmapOptions::new().map(&file)? };
let (_, _, count) = Self::read_header(&mmap)?;
self.vector_count = count;
self.mmap = Some(mmap);
}
Ok(())
}
fn update_header_count(&self) -> Result<(), VectorStorageError> {
use std::io::{Seek, SeekFrom};
let mut file = OpenOptions::new().write(true).open(&self.path)?;
file.seek(SeekFrom::Start(12))?;
file.write_all(&(self.vector_count as u32).to_le_bytes())?;
file.flush()?;
Ok(())
}
}
pub struct ConcurrentVectorStorage {
inner: Arc<parking_lot::RwLock<MmapVectorStorage>>,
}
impl ConcurrentVectorStorage {
pub fn new(storage: MmapVectorStorage) -> Self {
Self {
inner: Arc::new(parking_lot::RwLock::new(storage)),
}
}
#[must_use]
pub fn read_vector(&self, id: VectorId) -> Option<Vec<f32>> {
self.inner.write().read_vector(id)
}
pub fn write_batch(&self, vectors: &[(VectorId, &[f32])]) -> Result<(), VectorStorageError> {
self.inner.write().write_batch(vectors).map_err(|e| {
VectorStorageError::Io(io::Error::other(format!(
"Concurrent write failed for {} vectors: {}",
vectors.len(),
e
)))
})
}
}
impl Clone for MmapVectorStorage {
fn clone(&self) -> Self {
Self {
path: self.path.clone(),
mmap: None, dimension: self.dimension,
vector_count: self.vector_count,
segment: self.segment,
}
}
}
impl PartialEq for MmapVectorStorage {
fn eq(&self, other: &Self) -> bool {
self.path == other.path
&& self.dimension == other.dimension
&& self.segment == other.segment
}
}
impl std::fmt::Debug for ConcurrentVectorStorage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.inner.try_read() {
Some(storage) => write!(f, "ConcurrentVectorStorage {{ storage: {storage:?} }}"),
None => write!(f, "ConcurrentVectorStorage {{ <locked> }}"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn test_storage_create_and_open() {
let temp_dir = TempDir::new().unwrap();
let segment = SegmentOrdinal::new(0);
let dimension = VectorDimension::dimension_384();
let storage = MmapVectorStorage::new(&temp_dir, segment, dimension).unwrap();
assert_eq!(storage.vector_count(), 0);
assert_eq!(storage.dimension(), dimension);
assert!(MmapVectorStorage::open(&temp_dir, segment).is_err());
}
#[test]
fn test_write_and_read_vectors() {
let temp_dir = TempDir::new().unwrap();
let segment = SegmentOrdinal::new(0);
let dimension = VectorDimension::new(4).unwrap();
let mut storage = MmapVectorStorage::open_or_create(&temp_dir, segment, dimension).unwrap();
let test_data = vec![
(VectorId::new(1).unwrap(), vec![1.0, 2.0, 3.0, 4.0]),
(VectorId::new(2).unwrap(), vec![5.0, 6.0, 7.0, 8.0]),
(VectorId::new(3).unwrap(), vec![9.0, 10.0, 11.0, 12.0]),
];
let vectors: Vec<(VectorId, &[f32])> = test_data
.iter()
.map(|(id, vec)| (*id, vec.as_slice()))
.collect();
storage.write_batch(&vectors).unwrap();
assert_eq!(storage.vector_count(), 3);
for (id, expected_vector) in &test_data {
let read_vector = storage.read_vector(*id).unwrap();
assert_eq!(&read_vector, expected_vector);
}
assert!(storage.read_vector(VectorId::new(999).unwrap()).is_none());
}
#[test]
fn test_read_all_vectors() {
let temp_dir = TempDir::new().unwrap();
let segment = SegmentOrdinal::new(0);
let dimension = VectorDimension::new(3).unwrap();
let mut storage = MmapVectorStorage::open_or_create(&temp_dir, segment, dimension).unwrap();
let test_data = vec![
(VectorId::new(10).unwrap(), vec![1.0, 2.0, 3.0]),
(VectorId::new(20).unwrap(), vec![4.0, 5.0, 6.0]),
];
let vectors: Vec<(VectorId, &[f32])> = test_data
.iter()
.map(|(id, vec)| (*id, vec.as_slice()))
.collect();
storage.write_batch(&vectors).unwrap();
let all_vectors = storage.read_all_vectors().unwrap();
assert_eq!(all_vectors.len(), 2);
assert_eq!(all_vectors, test_data);
}
#[test]
fn test_dimension_validation() {
let temp_dir = TempDir::new().unwrap();
let segment = SegmentOrdinal::new(0);
let dimension = VectorDimension::new(3).unwrap();
let mut storage = MmapVectorStorage::open_or_create(&temp_dir, segment, dimension).unwrap();
let wrong_test_data = [(VectorId::new(1).unwrap(), vec![1.0, 2.0])];
let wrong_vectors: Vec<(VectorId, &[f32])> = wrong_test_data
.iter()
.map(|(id, vec)| (*id, vec.as_slice()))
.collect();
assert!(storage.write_batch(&wrong_vectors).is_err());
}
#[test]
fn test_persistence_across_instances() {
let temp_dir = TempDir::new().unwrap();
let segment = SegmentOrdinal::new(0);
let dimension = VectorDimension::new(2).unwrap();
{
let mut storage =
MmapVectorStorage::open_or_create(&temp_dir, segment, dimension).unwrap();
let test_data = [
(VectorId::new(1).unwrap(), vec![1.0, 2.0]),
(VectorId::new(2).unwrap(), vec![3.0, 4.0]),
];
let vectors: Vec<(VectorId, &[f32])> = test_data
.iter()
.map(|(id, vec)| (*id, vec.as_slice()))
.collect();
storage.write_batch(&vectors).unwrap();
}
{
let mut storage = MmapVectorStorage::open(&temp_dir, segment).unwrap();
assert_eq!(storage.vector_count(), 2);
let vec1 = storage.read_vector(VectorId::new(1).unwrap()).unwrap();
assert_eq!(vec1, vec![1.0, 2.0]);
let vec2 = storage.read_vector(VectorId::new(2).unwrap()).unwrap();
assert_eq!(vec2, vec![3.0, 4.0]);
}
}
#[test]
fn test_vector_access_performance() {
use std::time::Instant;
let temp_dir = TempDir::new().unwrap();
let segment = SegmentOrdinal::new(0);
let dimension = VectorDimension::new(128).unwrap();
let mut storage = MmapVectorStorage::open_or_create(&temp_dir, segment, dimension).unwrap();
let test_data: Vec<_> = (1..=1000)
.map(|i| {
let id = VectorId::new(i).unwrap();
let vector = vec![i as f32 / 1000.0; 128];
(id, vector)
})
.collect();
let vectors: Vec<(VectorId, &[f32])> = test_data
.iter()
.map(|(id, vec)| (*id, vec.as_slice()))
.collect();
storage.write_batch(&vectors).unwrap();
for i in 1..=10 {
let id = VectorId::new(i).unwrap();
let _ = storage.read_vector(id);
}
let mut timings = Vec::with_capacity(1000);
for i in 1..=1000 {
let id = VectorId::new(i).unwrap();
let start = Instant::now();
let _ = storage.read_vector(id);
let elapsed = start.elapsed();
timings.push(elapsed.as_nanos());
}
timings.sort_unstable();
let median_nanos = timings[timings.len() / 2];
println!("Median read time: {median_nanos}ns");
assert!(
median_nanos < 100_000,
"Read performance should be <100μs in test environment"
);
}
}