use super::MmapStorage;
use crate::storage::log_payload::crc32_hash;
use crate::storage::traits::VectorStorage;
use crate::storage::vector_bytes::{bytes_to_vector, vector_to_bytes};
use rustc_hash::FxHashMap;
use std::fs::File;
use std::io::{self, Write};
use std::sync::atomic::Ordering;
/// Appends a "store" record to the write-ahead log.
///
/// Record layout, in write order:
/// 1. opcode byte `1`
/// 2. `id` as little-endian `u64`
/// 3. payload length as little-endian `u32`
/// 4. the payload bytes (`data`)
/// 5. the little-endian bytes of `crc32_hash` computed over items 1-4
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidInput`] if `data` is longer than `u32::MAX`
/// bytes (the length field could not represent it), or any error produced by
/// the underlying writer. Nothing is written when the length check fails.
fn write_wal_store_entry(wal: &mut io::BufWriter<File>, id: u64, data: &[u8]) -> io::Result<()> {
    // A truncating cast here would emit a length field that disagrees with the
    // payload actually written, so reject oversized payloads up front.
    let len_u32 = u32::try_from(data.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "WAL payload of {} bytes exceeds the u32 length field",
                data.len()
            ),
        )
    })?;

    let mut frame = Vec::with_capacity(1 + 8 + 4 + data.len());
    frame.push(1u8);
    frame.extend_from_slice(&id.to_le_bytes());
    frame.extend_from_slice(&len_u32.to_le_bytes());
    frame.extend_from_slice(data);

    // The checksum covers the whole frame (opcode, id, length and payload).
    let crc = crc32_hash(&frame);
    wal.write_all(&frame)?;
    wal.write_all(&crc.to_le_bytes())
}
/// Appends a "delete" record to the write-ahead log.
///
/// The record is the opcode byte `2`, followed by `id` as a little-endian
/// `u64`, followed by the little-endian bytes of `crc32_hash` computed over
/// those preceding bytes.
///
/// # Errors
///
/// Propagates any error produced by the underlying writer.
fn write_wal_delete_entry(wal: &mut io::BufWriter<File>, id: u64) -> io::Result<()> {
    // Opcode first, then the id bytes.
    let frame: Vec<u8> = std::iter::once(2u8).chain(id.to_le_bytes()).collect();
    let checksum = crc32_hash(&frame);
    wal.write_all(&frame)
        .and_then(|()| wal.write_all(&checksum.to_le_bytes()))
}
/// Checks that a vector has the dimension the storage was configured with.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidInput`] with a message naming both values
/// when `actual` differs from `expected`.
#[inline]
fn validate_dimension(expected: usize, actual: usize) -> io::Result<()> {
    if actual == expected {
        Ok(())
    } else {
        Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("Vector dimension mismatch: expected {expected}, got {actual}"),
        ))
    }
}
/// [`VectorStorage`] implementation that keeps vector bytes in a memory map,
/// logs every mutation to a write-ahead log first, and tracks where each id
/// lives through an id → byte-offset index.
impl VectorStorage for MmapStorage {
    /// Stores `vector` under `id`, overwriting in place if the id already exists.
    ///
    /// The WAL record is written before the mapped data is touched. For a new
    /// id the slot is taken from `next_offset`; the reservation and the index
    /// entry are only committed once capacity has been secured and the bytes
    /// copied, so a failed `ensure_capacity` does not consume a slot.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` on a dimension mismatch, or any error from the
    /// WAL write or from `ensure_capacity`.
    fn store(&mut self, id: u64, vector: &[f32]) -> io::Result<()> {
        validate_dimension(self.dimension, vector.len())?;
        let vector_bytes = vector_to_bytes(vector);
        {
            let mut wal = self.wal.write();
            write_wal_store_entry(&mut wal, id, vector_bytes)?;
        }

        let vector_size = vector_bytes.len();
        let existing = self.index.get(id);
        // New ids go at the current end of the data region; the counter is
        // advanced below, only after every fallible step has succeeded.
        let offset = match existing {
            Some(existing_offset) => existing_offset,
            None => self.next_offset.load(Ordering::Acquire),
        };
        self.ensure_capacity(offset + vector_size)?;
        {
            let mut mmap = self.mmap.write();
            mmap[offset..offset + vector_size].copy_from_slice(vector_bytes);
        }
        if existing.is_none() {
            self.next_offset.fetch_add(vector_size, Ordering::AcqRel);
            self.index.insert(id, offset);
        }
        Ok(())
    }

    /// Stores every `(id, vector)` pair and returns the number of pairs given.
    ///
    /// All dimensions are validated before anything is written. Ids already in
    /// the index are overwritten in place; each distinct new id receives
    /// exactly one fresh slot, even when it occurs several times in `vectors`
    /// (later occurrences overwrite earlier ones in that slot). `next_offset`
    /// and the index are only updated after the WAL records and the mapped
    /// bytes have been written.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` on a dimension mismatch, or any error from
    /// `ensure_capacity` or the WAL writes.
    fn store_batch(&mut self, vectors: &[(u64, &[f32])]) -> io::Result<usize> {
        if vectors.is_empty() {
            return Ok(0);
        }
        let vector_size = self.dimension * std::mem::size_of::<f32>();
        for (_, vector) in vectors {
            validate_dimension(self.dimension, vector.len())?;
        }

        // Plan slots for ids that are not indexed yet. The entry API makes sure
        // a repeated new id reuses its planned slot instead of reserving another.
        let base_offset = self.next_offset.load(Ordering::Acquire);
        let mut new_vector_offsets: FxHashMap<u64, usize> = FxHashMap::default();
        new_vector_offsets.reserve(vectors.len());
        let mut total_new_size = 0usize;
        for &(id, _) in vectors {
            if self.index.contains_key(id) {
                continue;
            }
            new_vector_offsets.entry(id).or_insert_with(|| {
                let offset = base_offset + total_new_size;
                total_new_size += vector_size;
                offset
            });
        }
        if total_new_size > 0 {
            self.ensure_capacity(base_offset + total_new_size)?;
        }

        // Log every record before modifying the mapped data.
        {
            let mut wal = self.wal.write();
            for &(id, vector) in vectors {
                let vector_bytes = vector_to_bytes(vector);
                write_wal_store_entry(&mut wal, id, vector_bytes)?;
            }
        }

        {
            let mut mmap = self.mmap.write();
            for &(id, vector) in vectors {
                let vector_bytes = vector_to_bytes(vector);
                let offset = if let Some(existing) = self.index.get(id) {
                    existing
                } else {
                    new_vector_offsets.get(&id).copied().ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::InvalidData,
                            "ID not found in new_vector_offsets",
                        )
                    })?
                };
                mmap[offset..offset + vector_size].copy_from_slice(vector_bytes);
            }
        }

        // Commit: publish the reservation and the new index entries together.
        if total_new_size > 0 {
            self.next_offset.fetch_add(total_new_size, Ordering::AcqRel);
        }
        for (id, offset) in new_vector_offsets {
            self.index.insert(id, offset);
        }
        Ok(vectors.len())
    }

    /// Returns a copy of the vector stored under `id`, or `Ok(None)` if the id
    /// is not in the index.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when the indexed offset plus the vector size
    /// overflows or reaches past the end of the mapping.
    fn retrieve(&self, id: u64) -> io::Result<Option<Vec<f32>>> {
        let Some(offset) = self.index.get(id) else {
            return Ok(None);
        };
        let mmap = self.mmap.read();
        let vector_size = self.dimension * std::mem::size_of::<f32>();
        // checked_add keeps a wrapped-around end from slipping past the bounds
        // check and panicking in the slice below.
        let end = offset
            .checked_add(vector_size)
            .filter(|&end| end <= mmap.len())
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Offset out of bounds"))?;
        let bytes = &mmap[offset..end];
        Ok(Some(bytes_to_vector(bytes, self.dimension)))
    }

    /// Removes `id` from the index after logging the deletion to the WAL.
    ///
    /// The WAL record is written whether or not the id is currently indexed.
    /// If it was, `punch_hole` is additionally called on the data file for the
    /// vector's byte range; its result is deliberately ignored (best effort).
    ///
    /// # Errors
    ///
    /// Returns any error from the WAL write.
    fn delete(&mut self, id: u64) -> io::Result<()> {
        {
            let mut wal = self.wal.write();
            write_wal_delete_entry(&mut wal, id)?;
        }
        let offset = self.index.get(id);
        self.index.remove(id);
        if let Some(offset) = offset {
            let vector_size = self.dimension * std::mem::size_of::<f32>();
            // Skip the hole punch if either value does not fit in u64.
            if let (Ok(offset_u64), Ok(size_u64)) =
                (u64::try_from(offset), u64::try_from(vector_size))
            {
                let _ =
                    crate::storage::compaction::punch_hole(&self.data_file, offset_u64, size_u64);
            }
        }
        Ok(())
    }

    /// Flushes the memory map, flushes and syncs the WAL file, then rewrites
    /// the `vectors.idx` file under `self.path` with the postcard-serialized
    /// index and syncs it.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from flushing, syncing or creating/writing the
    /// index file, and wraps serialization failures via `io::Error::other`.
    fn flush(&mut self) -> io::Result<()> {
        self.mmap.write().flush()?;
        {
            let mut wal = self.wal.write();
            wal.flush()?;
            wal.get_ref().sync_all()?;
        }
        let index_path = self.path.join("vectors.idx");
        let file = File::create(&index_path)?;
        let mut writer = io::BufWriter::new(file);
        let flat_index = self.index.to_hashmap();
        let bytes = postcard::to_allocvec(&flat_index).map_err(io::Error::other)?;
        writer.write_all(&bytes)?;
        writer.flush()?;
        // Recover the File from the BufWriter so the index can be synced.
        writer
            .into_inner()
            .map_err(std::io::IntoInnerError::into_error)?
            .sync_all()?;
        Ok(())
    }

    /// Number of ids currently in the index.
    fn len(&self) -> usize {
        self.index.len()
    }

    /// All ids currently in the index.
    fn ids(&self) -> Vec<u64> {
        self.index.keys()
    }
}