use crate::VectorType;
use crate::error::{Result, TriviumError};
use std::path::{Path, PathBuf};
/// A pool of fixed-dimension vectors stored as one flat array of `T`.
///
/// Vector `i` occupies elements `i * dim .. (i + 1) * dim`. The first `mmap_count`
/// vectors come from a private (`map_copy`) memory map of a file on disk; vectors
/// pushed afterwards accumulate in the in-memory `delta` tail until the pool is
/// flushed. `merged` optionally caches base + tail as one contiguous `Vec`.
pub struct VecPool<T: VectorType> {
    /// Number of `T` components per vector.
    dim: usize,

    /// Private mapping of the persisted base vectors, if a file is attached.
    mmap: Option<memmap2::MmapMut>,
    /// How many vectors the mapping contributes.
    mmap_count: usize,
    /// Path of the file behind `mmap` (remembered even when the file does not exist yet).
    vec_path: Option<PathBuf>,

    /// Vectors appended since the last flush, stored flat.
    delta: Vec<T>,

    /// Contiguous copy of base + delta; only meaningful while `merged_valid` is true.
    merged: Vec<T>,
    /// Whether `merged` currently reflects the mapped base and the delta tail.
    merged_valid: bool,
}
impl<T: VectorType> VecPool<T> {
    /// Creates an empty pool, with no backing file, for vectors of `dim` components.
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            vec_path: None,
            mmap: None,
            mmap_count: 0,
            delta: Vec::new(),
            merged: Vec::new(),
            merged_valid: false,
        }
    }

    /// Opens a pool whose first `expected_count` vectors are read from `vec_path`.
    ///
    /// The file is mapped privately (`map_copy`). Later `update`/`zero_out` calls on
    /// base vectors therefore change only this process's copy until [`Self::flush`]
    /// rewrites the file. Only the first `expected_count * dim` elements are mapped;
    /// any trailing bytes are ignored.
    ///
    /// If the file does not exist or `expected_count == 0`, an empty pool is
    /// returned. That pool still remembers `vec_path`.
    /// NOTE(review): a missing file with `expected_count > 0` is not treated as an
    /// error — confirm callers rely on this.
    ///
    /// # Errors
    /// - I/O errors from opening, stat-ing or mapping the file.
    /// - `TriviumError::Generic` if the expected byte size overflows `usize`, or if
    ///   the file is shorter than that size.
    pub fn open(vec_path: &Path, dim: usize, expected_count: usize) -> Result<Self> {
        let mut pool = Self::new(dim);
        pool.vec_path = Some(vec_path.to_path_buf());
        if vec_path.exists() && expected_count > 0 {
            let file = std::fs::File::open(vec_path)?;
            // Keep the length as u64: casting to usize could truncate on 32-bit targets.
            let file_len = file.metadata()?.len();
            // Checked arithmetic: a wrapped product would pass the size check below
            // and map a bogus length.
            let expected_size = expected_count
                .checked_mul(dim)
                .and_then(|elems| elems.checked_mul(std::mem::size_of::<T>()))
                .ok_or_else(|| {
                    TriviumError::Generic(format!(
                        "向量文件预期大小溢出: {} 个向量 x {} 维",
                        expected_count, dim
                    ))
                })?;
            // usize -> u64 is lossless on all 32/64-bit targets.
            if file_len < expected_size as u64 {
                return Err(TriviumError::Generic(format!(
                    "向量文件大小不匹配: 文件 {} 字节, 预期最少 {} 字节",
                    file_len, expected_size
                )));
            }
            if file_len > 0 {
                // SAFETY: the mapping is private copy-on-write, so our writes never reach
                // the file. Soundness assumes no other process truncates or rewrites the
                // file while it is mapped.
                let mmap = unsafe {
                    memmap2::MmapOptions::new()
                        .len(expected_size)
                        .map_copy(&file)
                }
                .map_err(TriviumError::Io)?;
                pool.mmap = Some(mmap);
                pool.mmap_count = expected_count;
            }
        }
        pool.invalidate_cache();
        Ok(pool)
    }

    /// Total number of vectors: mapped base plus delta tail.
    #[inline]
    pub fn total_count(&self) -> usize {
        self.mmap_count + self.delta_count()
    }

    /// Number of complete vectors held in the in-memory delta tail (0 when `dim == 0`).
    #[inline]
    pub fn delta_count(&self) -> usize {
        if self.dim == 0 {
            0
        } else {
            self.delta.len() / self.dim
        }
    }

    /// Number of vectors served from the memory-mapped base.
    #[inline]
    pub fn mmap_count(&self) -> usize {
        self.mmap_count
    }

    /// Appends one vector to the delta tail.
    ///
    /// `vector` must contain exactly `dim` elements. A wrong length would shift
    /// every later vector, so debug builds check it. Release builds keep the
    /// original unchecked behavior.
    pub fn push(&mut self, vector: &[T]) {
        debug_assert_eq!(vector.len(), self.dim, "向量长度与维度不匹配");
        self.delta.extend_from_slice(vector);
        self.invalidate_cache();
    }

    /// Overwrites vector `index` with zeros.
    ///
    /// Base vectors are zeroed byte-wise in the private mapping. Delta vectors are
    /// set to `T::zero()`. If `index` is a base index but no mapping is attached,
    /// nothing is written.
    ///
    /// # Panics
    /// If `index` is beyond the stored data.
    pub fn zero_out(&mut self, index: usize) {
        if index < self.mmap_count {
            let elem_size = std::mem::size_of::<T>();
            let byte_offset = index * self.dim * elem_size;
            let byte_len = self.dim * elem_size;
            if let Some(mmap) = self.mmap.as_mut() {
                mmap[byte_offset..byte_offset + byte_len].fill(0);
            }
        } else {
            let delta_offset = (index - self.mmap_count) * self.dim;
            self.delta[delta_offset..delta_offset + self.dim].fill(T::zero());
        }
        self.invalidate_cache();
    }

    /// Replaces vector `index` with `vector`.
    ///
    /// Base vectors are rewritten in the private mapping, and persisted only by
    /// [`Self::flush`]. Delta vectors are rewritten in place.
    ///
    /// # Panics
    /// - If `vector.len() != dim`. The delta path already panicked in that case;
    ///   the mapped path previously overwrote neighbouring vectors silently.
    /// - If `index` is beyond the stored data.
    pub fn update(&mut self, index: usize, vector: &[T]) {
        assert_eq!(vector.len(), self.dim, "向量长度与维度不匹配");
        if index < self.mmap_count {
            let byte_offset = index * self.dim * std::mem::size_of::<T>();
            if let Some(mmap) = self.mmap.as_mut() {
                let src_bytes: &[u8] = bytemuck::cast_slice(vector);
                mmap[byte_offset..byte_offset + src_bytes.len()].copy_from_slice(src_bytes);
            }
        } else {
            let delta_offset = (index - self.mmap_count) * self.dim;
            self.delta[delta_offset..delta_offset + self.dim].copy_from_slice(vector);
        }
        self.invalidate_cache();
    }

    /// Borrows vector `index`, or returns `None` if it is not present.
    ///
    /// # Panics
    /// If the mapped bytes are misaligned for `T`. This is not expected for page-aligned
    /// mappings.
    pub fn get(&self, index: usize) -> Option<&[T]> {
        if index < self.mmap_count {
            let mmap = self.mmap.as_ref()?;
            let elem_size = std::mem::size_of::<T>();
            let byte_offset = index * self.dim * elem_size;
            let bytes = &mmap[byte_offset..byte_offset + self.dim * elem_size];
            // Checked reinterpretation instead of a raw-pointer cast: fails only on misalignment.
            Some(
                bytemuck::try_cast_slice(bytes)
                    .unwrap_or_else(|_| panic!("mmap 对齐异常,这不应该发生在正常的 OS 页映射中")),
            )
        } else {
            let delta_offset = (index - self.mmap_count) * self.dim;
            self.delta.get(delta_offset..delta_offset + self.dim)
        }
    }

    /// Rebuilds the contiguous `merged` cache if a mapped base exists and the cache is stale.
    ///
    /// Call this before [`Self::flat_vectors`] whenever a mapping is attached.
    pub fn ensure_cache(&mut self) {
        if self.mmap.is_some() && self.mmap_count > 0 && !self.merged_valid {
            self.rebuild_merged_cache();
        }
    }

    /// Returns every vector as one flat slice.
    ///
    /// Without a mapped base this is the delta tail itself. With a mapped base it is
    /// the `merged` cache. That cache is refreshed only by [`Self::ensure_cache`], so
    /// calling this after a mutation without `ensure_cache` returns stale contents.
    pub fn flat_vectors(&self) -> &[T] {
        if self.mmap.is_none() || self.mmap_count == 0 {
            return &self.delta;
        }
        &self.merged
    }

    /// Borrows the raw delta tail: vectors appended since the last flush, stored flat.
    #[inline]
    pub fn delta_raw(&self) -> &[T] {
        &self.delta
    }

    /// Moves all mapped data into memory and drops the mapping and its file path.
    ///
    /// Afterwards the pool is purely in-memory: everything lives in the delta tail.
    pub fn detach_mmap(&mut self) {
        if self.mmap.is_none() {
            return;
        }
        // Rebuild whenever the cache is stale. Gating on `mmap_count > 0`, as
        // `ensure_cache` does, could move a stale cache into `delta` and lose the tail.
        if !self.merged_valid {
            self.rebuild_merged_cache();
        }
        // Take ownership of the merged buffer instead of copying it. This also frees
        // the allocation that `merged.clear()` would have kept alive.
        self.delta = std::mem::take(&mut self.merged);
        self.mmap = None;
        self.vec_path = None;
        self.mmap_count = 0;
        self.merged_valid = false;
    }

    /// Persists all vectors (mapped base plus delta tail) to `vec_path` and remaps it.
    ///
    /// The data is written to a temporary path (`vec_path.with_extension("vec.tmp")`),
    /// fsynced, then renamed over `vec_path`. The temporary file is removed on
    /// failure. On success the delta tail is emptied and the whole file becomes the
    /// new mapped base. When the pool holds no vectors, `vec_path` is deleted instead,
    /// if it exists.
    ///
    /// Returns the number of vectors written.
    ///
    /// # Errors
    /// Any I/O error from writing, renaming, reopening, mapping or deleting the file.
    pub fn flush(&mut self, vec_path: &Path) -> Result<usize> {
        let total = self.total_count();
        if total == 0 {
            // Release the mapping before deleting: some platforms refuse to delete a
            // file that is still mapped. Nothing is lost, since the pool holds no vectors.
            self.mmap = None;
            self.mmap_count = 0;
            self.delta.clear();
            self.invalidate_cache();
            if vec_path.exists() {
                std::fs::remove_file(vec_path)?;
            }
            return Ok(0);
        }

        let tmp_path = vec_path.with_extension("vec.tmp");
        if let Err(err) = self.write_snapshot(&tmp_path) {
            // Best-effort cleanup; the write error is the one worth reporting.
            let _ = std::fs::remove_file(&tmp_path);
            return Err(err);
        }
        if let Err(err) = std::fs::rename(&tmp_path, vec_path) {
            let _ = std::fs::remove_file(&tmp_path);
            return Err(err.into());
        }

        let file = std::fs::File::open(vec_path)?;
        // SAFETY: the file was just written by us. The mapping is private
        // copy-on-write. Soundness assumes no other process modifies the file while
        // it is mapped.
        let new_mmap =
            unsafe { memmap2::MmapOptions::new().map_copy(&file) }.map_err(TriviumError::Io)?;
        self.mmap = Some(new_mmap);
        self.mmap_count = total;
        self.delta.clear();
        self.delta.shrink_to_fit();
        self.vec_path = Some(vec_path.to_path_buf());
        self.invalidate_cache();
        Ok(total)
    }

    /// Writes the mapped base (including private in-memory edits) followed by the
    /// delta tail to `path`, then fsyncs the file.
    fn write_snapshot(&self, path: &Path) -> Result<()> {
        let mut file = std::fs::File::create(path)?;
        if let Some(ref mmap) = self.mmap {
            let base_bytes = self.mmap_count * self.dim * std::mem::size_of::<T>();
            std::io::Write::write_all(&mut file, &mmap[..base_bytes])?;
        }
        if !self.delta.is_empty() {
            let delta_bytes: &[u8] = bytemuck::cast_slice(&self.delta);
            std::io::Write::write_all(&mut file, delta_bytes)?;
        }
        file.sync_all()?;
        Ok(())
    }

    /// Marks the `merged` cache as stale.
    #[inline]
    fn invalidate_cache(&mut self) {
        self.merged_valid = false;
    }

    /// Refills `merged` with the mapped base followed by the delta tail.
    fn rebuild_merged_cache(&mut self) {
        let total_elements = self.total_count() * self.dim;
        self.merged.clear();
        self.merged.reserve(total_elements);
        if let Some(ref mmap) = self.mmap {
            let elem_size = std::mem::size_of::<T>();
            let base_bytes = &mmap[..self.mmap_count * self.dim * elem_size];
            match bytemuck::try_cast_slice::<u8, T>(base_bytes) {
                // Aligned: bulk copy.
                Ok(base) => self.merged.extend_from_slice(base),
                // Misaligned: decode element by element with unaligned reads.
                Err(_) => self.merged.extend(
                    base_bytes
                        .chunks_exact(elem_size)
                        .map(bytemuck::pod_read_unaligned::<T>),
                ),
            }
        }
        self.merged.extend_from_slice(&self.delta);
        self.merged_valid = true;
    }

    /// Bytes held by the delta tail and the merged cache.
    ///
    /// This counts element lengths, not `Vec` capacities. It excludes the mapping.
    pub fn heap_memory_bytes(&self) -> usize {
        let delta_bytes = self.delta.len() * std::mem::size_of::<T>();
        let merged_bytes = self.merged.len() * std::mem::size_of::<T>();
        delta_bytes + merged_bytes
    }

    /// Size in bytes of all vectors (base plus delta) as stored flat.
    pub fn total_data_bytes(&self) -> usize {
        self.total_count() * self.dim * std::mem::size_of::<T>()
    }
}