use super::MmapStorage;
use crate::storage::log_payload::{crc32_hash, DurabilityMode};
use crate::storage::traits::VectorStorage;
use crate::storage::vector_bytes::{bytes_to_vector, vector_to_bytes};
use rustc_hash::FxHashMap;
use std::fs::File;
use std::io::{self, Write};
use std::sync::atomic::Ordering;
/// Appends a single "store" record to the write-ahead log.
///
/// Record layout (integers little-endian):
/// `[op tag = 1u8][id: u64][payload length: u32][payload][checksum]`,
/// where the checksum is `crc32_hash` over every byte that precedes it.
///
/// `buf` is scratch space: it is cleared and then filled with the record
/// (without the checksum), so a caller can reuse one allocation across calls.
///
/// # Errors
///
/// Returns `InvalidInput` when `data` is longer than `u32::MAX` bytes, since
/// the length field cannot represent it (a plain `as u32` cast would silently
/// truncate and produce a record whose length disagrees with its body).
/// Any error from writing to `wal` is propagated unchanged.
fn write_wal_store_entry(
    wal: &mut io::BufWriter<File>,
    id: u64,
    data: &[u8],
    buf: &mut Vec<u8>,
) -> io::Result<()> {
    // Validate before touching `buf` or the WAL so a rejected payload leaves no trace.
    let len_u32 = u32::try_from(data.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "WAL store payload too large: {} bytes does not fit in a u32 length field",
                data.len()
            ),
        )
    })?;

    buf.clear();
    buf.push(1u8);
    buf.extend_from_slice(&id.to_le_bytes());
    buf.extend_from_slice(&len_u32.to_le_bytes());
    buf.extend_from_slice(data);

    // The checksum covers tag + id + length + payload and is written right after them.
    let crc = crc32_hash(buf);
    wal.write_all(buf)?;
    wal.write_all(&crc.to_le_bytes())
}
/// Bytes a WAL store record occupies in addition to its payload:
/// op tag (1) + id (8) + payload length field (4) + trailing checksum
/// (4 — assumes `crc32_hash` yields a 32-bit value; TODO confirm).
const WAL_STORE_ENTRY_OVERHEAD: usize = 1 + 8 + 4 + 4;
/// Serializes a "store" record for every `(id, vector)` pair into one
/// contiguous buffer and hands it to the WAL writer in a single `write_all`.
///
/// Each record is whatever `serialize_wal_store_entry` produces, immediately
/// followed by the little-endian `crc32_hash` of those bytes.
///
/// * `vector_byte_size` — expected payload size per vector; only used to
///   pre-size `group_buf`.
/// * `entry_buf` — scratch buffer holding one record at a time.
/// * `group_buf` — scratch buffer that accumulates the whole batch; cleared
///   on entry.
///
/// # Errors
///
/// Propagates any error from writing to `wal`.
fn write_wal_store_entries_grouped(
    wal: &mut io::BufWriter<File>,
    vectors: &[(u64, &[f32])],
    vector_byte_size: usize,
    entry_buf: &mut Vec<u8>,
    group_buf: &mut Vec<u8>,
) -> io::Result<()> {
    let bytes_per_record = vector_byte_size + WAL_STORE_ENTRY_OVERHEAD;

    group_buf.clear();
    group_buf.reserve(vectors.len() * bytes_per_record);

    for (id, vector) in vectors.iter().copied() {
        serialize_wal_store_entry(id, vector_to_bytes(vector), entry_buf);
        let checksum = crc32_hash(entry_buf).to_le_bytes();
        group_buf.extend_from_slice(entry_buf);
        group_buf.extend_from_slice(&checksum);
    }

    wal.write_all(group_buf)
}
/// Overwrites `buf` with the body of a WAL "store" record (no checksum).
///
/// Layout (integers little-endian):
/// `[op tag = 1u8][id: u64][payload length: u32][payload]`.
///
/// Any previous contents of `buf` are discarded.
fn serialize_wal_store_entry(id: u64, data: &[u8], buf: &mut Vec<u8>) {
    // The on-disk length field is 32 bits wide; longer payloads are truncated
    // by this cast, exactly as the lint suppression acknowledges.
    #[allow(clippy::cast_possible_truncation)]
    let payload_len = data.len() as u32;

    // Fixed-size prefix: 1 tag byte + 8 id bytes + 4 length bytes.
    let mut header = [0u8; 13];
    header[0] = 1u8;
    header[1..9].copy_from_slice(&id.to_le_bytes());
    header[9..13].copy_from_slice(&payload_len.to_le_bytes());

    buf.clear();
    buf.extend_from_slice(&header);
    buf.extend_from_slice(data);
}
/// Appends a "delete" record for `id` to the write-ahead log.
///
/// Record layout (integers little-endian): `[op tag = 2u8][id: u64][checksum]`,
/// where the checksum is `crc32_hash` over the tag and id bytes.
///
/// `buf` is scratch space: it is cleared and left holding the record without
/// its checksum.
///
/// # Errors
///
/// Propagates any error from writing to `wal`.
fn write_wal_delete_entry(
    wal: &mut io::BufWriter<File>,
    id: u64,
    buf: &mut Vec<u8>,
) -> io::Result<()> {
    // 1 tag byte followed by the 8-byte id.
    let mut record = [0u8; 9];
    record[0] = 2u8;
    record[1..9].copy_from_slice(&id.to_le_bytes());

    buf.clear();
    buf.extend_from_slice(&record);

    let checksum = crc32_hash(buf).to_le_bytes();
    wal.write_all(buf)?;
    wal.write_all(&checksum)
}
/// Checks that a vector has the dimensionality the storage was configured for.
///
/// # Errors
///
/// Returns an `InvalidInput` error naming both values when `actual` differs
/// from `expected`.
#[inline]
fn validate_dimension(expected: usize, actual: usize) -> io::Result<()> {
    if expected == actual {
        Ok(())
    } else {
        Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("Vector dimension mismatch: expected {expected}, got {actual}"),
        ))
    }
}
impl VectorStorage for MmapStorage {
    /// Stores `vector` under `id`, overwriting in place when the id already exists.
    ///
    /// Steps, in order: validate the dimension, append a WAL store record
    /// (skipped when durability is `None`), make sure the mapping can hold the
    /// slot, copy the bytes into the mapping, and finally register new ids in
    /// the index.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` on a dimension mismatch, or any I/O error raised
    /// while writing the WAL or growing the mapping.
    fn store(&mut self, id: u64, vector: &[f32]) -> io::Result<()> {
        validate_dimension(self.dimension, vector.len())?;

        let vector_bytes = vector_to_bytes(vector);
        let vector_size = vector_bytes.len();

        if self.durability != DurabilityMode::None {
            let mut wal = self.wal.write();
            let mut buf = Vec::with_capacity(1 + 8 + 4 + vector_size);
            write_wal_store_entry(&mut wal, id, vector_bytes, &mut buf)?;
        }

        // Existing ids reuse their slot; new ids take the current end of the data region.
        let (offset, is_new) = match self.index.get(id) {
            Some(existing_offset) => (existing_offset, false),
            None => (self.next_offset.load(Ordering::Acquire), true),
        };

        // Grow first, reserve second (same order as `store_batch`): if growing
        // fails, `next_offset` has not moved and no slot is leaked.
        self.ensure_capacity(offset + vector_size)?;
        if is_new {
            self.next_offset.fetch_add(vector_size, Ordering::AcqRel);
        }

        {
            // Scope the write guard so it is released before the index update.
            let mut mmap = self.mmap.write();
            mmap[offset..offset + vector_size].copy_from_slice(vector_bytes);
        }

        if is_new {
            self.index.insert(id, offset);
        }
        Ok(())
    }

    /// Stores many vectors at once and returns how many entries were processed
    /// (`vectors.len()`).
    ///
    /// All vectors are validated before anything is written. Space for ids not
    /// yet in the index is reserved in one step, the WAL records are written
    /// as one grouped write (skipped when durability is `None`), and then the
    /// vector bytes are copied into the mapping.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if any vector has the wrong dimension, or any
    /// I/O error raised while growing the mapping or writing the WAL.
    fn store_batch(&mut self, vectors: &[(u64, &[f32])]) -> io::Result<usize> {
        if vectors.is_empty() {
            return Ok(0);
        }

        let vector_size = self.dimension * std::mem::size_of::<f32>();

        // Reject the whole batch up front so a bad entry cannot leave it half-applied.
        for (_, vector) in vectors {
            validate_dimension(self.dimension, vector.len())?;
        }

        let (new_vector_offsets, total_new_size) = self.compute_new_offsets(vectors, vector_size);

        if total_new_size > 0 {
            let start_offset = self.next_offset.load(Ordering::Acquire);
            // Only advance `next_offset` once the mapping is known to be large enough.
            self.ensure_capacity(start_offset + total_new_size)?;
            self.next_offset.fetch_add(total_new_size, Ordering::AcqRel);
        }

        if self.durability != DurabilityMode::None {
            let mut wal = self.wal.write();
            let mut entry_buf = Vec::with_capacity(1 + 8 + 4 + vector_size);
            let mut group_buf = Vec::new();
            write_wal_store_entries_grouped(
                &mut wal,
                vectors,
                vector_size,
                &mut entry_buf,
                &mut group_buf,
            )?;
        }

        self.write_vectors_to_mmap(vectors, vector_size, &new_vector_offsets)?;

        // Publish new ids only after their bytes are in the mapping.
        for (id, offset) in new_vector_offsets {
            self.index.insert(id, offset);
        }

        Ok(vectors.len())
    }

    /// Reads the vector stored under `id`, or `Ok(None)` when the id is not
    /// in the index.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` ("Offset out of bounds") when the indexed slot
    /// does not lie entirely inside the current mapping, including the case
    /// where `offset + size` would overflow.
    fn retrieve(&self, id: u64) -> io::Result<Option<Vec<f32>>> {
        let Some(offset) = self.index.get(id) else {
            return Ok(None);
        };

        let mmap = self.mmap.read();
        let vector_size = self.dimension * std::mem::size_of::<f32>();

        // `checked_add` keeps a corrupt offset from wrapping past the bounds check.
        let end = offset
            .checked_add(vector_size)
            .filter(|&end| end <= mmap.len())
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Offset out of bounds"))?;

        let bytes = &mmap[offset..end];
        Ok(Some(bytes_to_vector(bytes, self.dimension)))
    }

    /// Removes `id` from the index and asks `punch_hole` to act on the byte
    /// range its slot occupied in the data file.
    ///
    /// A WAL delete record is written first (skipped when durability is
    /// `None`), even if the id turns out not to be present.
    ///
    /// # Errors
    ///
    /// Propagates WAL write errors. A failing `punch_hole` is ignored.
    fn delete(&mut self, id: u64) -> io::Result<()> {
        if self.durability != DurabilityMode::None {
            let mut wal = self.wal.write();
            let mut buf = Vec::with_capacity(1 + 8);
            write_wal_delete_entry(&mut wal, id, &mut buf)?;
        }

        let offset = self.index.get(id);
        self.index.remove(id);

        if let Some(offset) = offset {
            let vector_size = self.dimension * std::mem::size_of::<f32>();
            // Skip hole punching if either value cannot be expressed as u64.
            if let (Ok(offset_u64), Ok(size_u64)) =
                (u64::try_from(offset), u64::try_from(vector_size))
            {
                // Best-effort (presumably a space reclamation hint — confirm against
                // `compaction::punch_hole`): the delete already succeeded logically,
                // so a failure here is deliberately not reported.
                let _ =
                    crate::storage::compaction::punch_hole(&self.data_file, offset_u64, size_u64);
            }
        }
        Ok(())
    }

    /// Flushes the memory mapping, then the WAL according to the durability mode:
    /// `Fsync` flushes the buffered writer and calls `sync_all` on the file,
    /// `FlushOnly` flushes the buffered writer only, and `None` leaves the WAL
    /// untouched.
    ///
    /// # Errors
    ///
    /// Propagates the first I/O error encountered.
    fn flush(&mut self) -> io::Result<()> {
        self.mmap.write().flush()?;

        match self.durability {
            DurabilityMode::Fsync => {
                let mut wal = self.wal.write();
                wal.flush()?;
                wal.get_ref().sync_all()?;
            }
            DurabilityMode::FlushOnly => {
                self.wal.write().flush()?;
            }
            DurabilityMode::None => {}
        }
        Ok(())
    }

    /// Number of ids currently present in the index.
    fn len(&self) -> usize {
        self.index.len()
    }

    /// All ids currently present in the index, as returned by `index.keys()`.
    fn ids(&self) -> Vec<u64> {
        self.index.keys()
    }
}
impl MmapStorage {
    /// Plans where the not-yet-indexed ids of a batch will live.
    ///
    /// Returns a map from each new id to its byte offset, plus the total
    /// number of bytes those new slots need. Offsets are assigned
    /// consecutively starting at the current `next_offset`, in order of first
    /// appearance in `vectors`. Ids already in the index are skipped.
    ///
    /// An id that appears more than once in the batch gets exactly one slot;
    /// reserving one per occurrence would leave every slot but the last
    /// unused and unreachable.
    ///
    /// This only reads state; the caller is responsible for actually
    /// advancing `next_offset`.
    fn compute_new_offsets(
        &self,
        vectors: &[(u64, &[f32])],
        vector_size: usize,
    ) -> (FxHashMap<u64, usize>, usize) {
        let mut new_vector_offsets: FxHashMap<u64, usize> = FxHashMap::default();
        new_vector_offsets.reserve(vectors.len());

        // Read the base once so every planned offset is relative to the same value.
        let base_offset = self.next_offset.load(Ordering::Acquire);
        let mut total_new_size = 0usize;

        for &(id, _) in vectors {
            let already_placed =
                self.index.contains_key(id) || new_vector_offsets.contains_key(&id);
            if already_placed {
                continue;
            }
            new_vector_offsets.insert(id, base_offset + total_new_size);
            total_new_size += vector_size;
        }

        (new_vector_offsets, total_new_size)
    }

    /// Copies every vector of a batch into the memory mapping.
    ///
    /// The destination is the id's existing index offset when there is one,
    /// otherwise the offset planned in `new_vector_offsets`. When an id occurs
    /// several times, later entries overwrite earlier ones in the same slot.
    ///
    /// Callers must have validated dimensions already: `copy_from_slice`
    /// panics if a vector's byte length differs from `vector_size`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if an id is in neither the index nor
    /// `new_vector_offsets`.
    fn write_vectors_to_mmap(
        &self,
        vectors: &[(u64, &[f32])],
        vector_size: usize,
        new_vector_offsets: &FxHashMap<u64, usize>,
    ) -> io::Result<()> {
        // One write guard for the whole batch.
        let mut mmap = self.mmap.write();

        for &(id, vector) in vectors {
            let vector_bytes = vector_to_bytes(vector);
            let offset = match self.index.get(id) {
                Some(existing) => existing,
                None => new_vector_offsets.get(&id).copied().ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        "ID not found in new_vector_offsets",
                    )
                })?,
            };
            mmap[offset..offset + vector_size].copy_from_slice(vector_bytes);
        }
        Ok(())
    }
}