use crate::VectorType;
use crate::error::{Result, TriviumError};
use std::fs::OpenOptions;
use std::io::Write;
use std::path::{Path, PathBuf};
/// Renames `from` to `to`, retrying failures that look like transient file locks.
///
/// Up to 10 attempts are made. After a failure with `ERROR_ACCESS_DENIED` (5) or
/// `ERROR_SHARING_VIOLATION` (32) the call sleeps and retries, doubling the delay
/// from 1 ms up to a 50 ms cap. Any other error, or the failure of the last
/// attempt, is returned unchanged.
#[cfg(windows)]
fn robust_rename(from: &Path, to: &Path) -> std::io::Result<()> {
    use std::time::Duration;

    const ERROR_ACCESS_DENIED: i32 = 5;
    const ERROR_SHARING_VIOLATION: i32 = 32;
    const MAX_ATTEMPTS: u32 = 10;
    const MAX_BACKOFF: Duration = Duration::from_millis(50);

    let mut backoff = Duration::from_millis(1);
    let mut attempts_left = MAX_ATTEMPTS;
    loop {
        attempts_left -= 1;
        let err = match std::fs::rename(from, to) {
            Ok(()) => return Ok(()),
            Err(err) => err,
        };
        let retryable = matches!(
            err.raw_os_error(),
            Some(ERROR_ACCESS_DENIED | ERROR_SHARING_VIOLATION)
        );
        if attempts_left == 0 || !retryable {
            return Err(err);
        }
        std::thread::sleep(backoff);
        backoff = (backoff * 2).min(MAX_BACKOFF);
    }
}
/// Renames `from` to `to`.
///
/// On non-Windows targets this is a direct call to `std::fs::rename`; no retries
/// are attempted.
#[cfg(not(windows))]
fn robust_rename(from: &Path, to: &Path) -> std::io::Result<()> {
    use std::fs;
    fs::rename(from, to)
}
/// A pool of fixed-dimension vectors stored as one flat array of `T` elements.
///
/// Vectors `0..mmap_count` (the "base") are read from a private copy-on-write
/// memory map of the vector file. Later indices live in the in-memory `delta`
/// buffer until the next flush writes them out.
pub struct VecPool<T: VectorType> {
    /// Number of `T` elements per vector.
    dim: usize,

    /// Private (copy-on-write) mapping of the base vectors, if any file is mapped.
    mmap: Option<memmap2::MmapMut>,
    /// Number of vectors covered by `mmap`.
    mmap_count: usize,
    /// Path of the file backing `mmap`; recorded by `open` and whenever a flush
    /// remaps the file.
    vec_path: Option<PathBuf>,
    /// Set when base vectors were modified in place (`zero_out` / `update`); the
    /// next flush must then rewrite the whole file instead of appending.
    has_dirty_base: bool,

    /// Vectors not backed by the mapping, flattened (`dim` elements each).
    delta: Vec<T>,

    /// Flattened copy of base followed by delta, rebuilt on demand by `ensure_cache`.
    merged: Vec<T>,
    /// Whether `merged` reflects the current contents of the pool.
    merged_valid: bool,
}
/// Passes an access-pattern hint for `mmap` to the kernel.
///
/// The hint is advisory only, so a rejected hint is ignored.
#[cfg(unix)]
#[inline]
fn madvise(mmap: &memmap2::MmapMut, advice: memmap2::Advice) {
    mmap.advise(advice).ok();
}
impl<T: VectorType> VecPool<T> {
/// Creates an empty, purely in-memory pool for vectors of `dim` elements.
///
/// No file is associated with the pool until `open` or a flush records one.
pub fn new(dim: usize) -> Self {
    let delta = Vec::new();
    let merged = Vec::new();
    Self {
        dim,
        mmap: None,
        mmap_count: 0,
        vec_path: None,
        has_dirty_base: false,
        delta,
        merged,
        merged_valid: false,
    }
}
/// Opens the vector file at `vec_path` and maps its first `expected_count` vectors.
///
/// The mapping is private (copy-on-write). In-place edits made through
/// `zero_out` / `update` stay in memory until a flush rewrites the file. If the
/// file does not exist or `expected_count` is 0, the pool starts empty but still
/// records `vec_path`.
///
/// # Errors
/// - I/O errors from opening, inspecting or mapping the file.
/// - [`TriviumError::CorruptedFile`] if the file is shorter than
///   `expected_count * dim` elements, or if that byte size overflows `usize`.
pub fn open(vec_path: &Path, dim: usize, expected_count: usize) -> Result<Self> {
    let mut pool = Self::new(dim);
    pool.vec_path = Some(vec_path.to_path_buf());
    if vec_path.exists() && expected_count > 0 {
        let file = std::fs::File::open(vec_path)?;
        // A length that does not fit in `usize` is certainly large enough.
        let file_len = usize::try_from(file.metadata()?.len()).unwrap_or(usize::MAX);
        let elem_size = std::mem::size_of::<T>();
        // `expected_count` is caller-supplied. An absurd value must surface as an
        // error instead of wrapping to a small (wrong) mapping length.
        let expected_size = expected_count
            .checked_mul(dim)
            .and_then(|elems| elems.checked_mul(elem_size))
            .ok_or_else(|| {
                TriviumError::CorruptedFile(format!(
                    "向量文件大小溢出: {} 个向量 × {} 维",
                    expected_count, dim
                ))
            })?;
        if file_len < expected_size {
            return Err(TriviumError::CorruptedFile(format!(
                "向量文件大小不匹配: 文件 {} 字节, 预期最少 {} 字节",
                file_len, expected_size
            )));
        }
        if file_len > 0 {
            // SAFETY: `map_copy` creates a private copy-on-write mapping, so writes
            // through it never reach the file. This assumes no other process
            // truncates the file while it is mapped.
            let mmap = unsafe {
                memmap2::MmapOptions::new()
                    .len(expected_size)
                    .map_copy(&file)
            }
            .map_err(TriviumError::Io)?;
            #[cfg(unix)]
            madvise(&mmap, memmap2::Advice::Sequential);
            pool.mmap = Some(mmap);
            pool.mmap_count = expected_count;
        }
    }
    pool.invalidate_cache();
    Ok(pool)
}
/// Total number of vectors: the mapped base plus the in-memory delta.
#[inline]
pub fn total_count(&self) -> usize {
    let base = self.mmap_count;
    base + self.delta_count()
}
/// Number of whole vectors held in the delta buffer (0 when `dim` is 0).
#[inline]
pub fn delta_count(&self) -> usize {
    // `checked_div` yields `None` for a zero dimension, which counts as no vectors.
    self.delta.len().checked_div(self.dim).unwrap_or(0)
}
/// Number of vectors served from the memory-mapped base file.
#[inline]
pub fn mmap_count(&self) -> usize {
    self.mmap_count
}
/// Appends vector data to the in-memory delta buffer.
///
/// `vector` must contain whole vectors (`dim` elements each; several may be
/// pushed at once). A length that is not a multiple of `dim` would shift every
/// later vector, so debug builds assert on it. The data reaches disk on the
/// next `flush`.
pub fn push(&mut self, vector: &[T]) {
    debug_assert!(
        vector.len().is_multiple_of(self.dim),
        "VecPool::push: length {} is not a multiple of dim {}",
        vector.len(),
        self.dim
    );
    self.delta.extend_from_slice(vector);
    self.invalidate_cache();
}
/// Overwrites vector `index` with zeros.
///
/// A base vector is zeroed inside the private mapping, and the base is marked
/// dirty so the next flush rewrites the file. A delta vector is zeroed in place.
///
/// # Panics
/// Panics if `index` lies beyond the delta buffer.
pub fn zero_out(&mut self, index: usize) {
    let dim = self.dim;
    if index >= self.mmap_count {
        let start = (index - self.mmap_count) * dim;
        self.delta[start..start + dim].fill(T::zero());
    } else {
        let elem_size = std::mem::size_of::<T>();
        if let Some(mapping) = self.mmap.as_mut() {
            let start = index * dim * elem_size;
            mapping[start..start + dim * elem_size].fill(0);
        }
        self.has_dirty_base = true;
    }
    self.invalidate_cache();
}
/// Replaces vector `index` with `vector`.
///
/// A base vector is overwritten inside the private mapping, and the base is
/// marked dirty so the next flush rewrites the file. A delta vector is
/// overwritten in place.
///
/// # Panics
/// Panics if `vector.len() != dim`, or if `index` lies beyond the last vector.
pub fn update(&mut self, index: usize, vector: &[T]) {
    // Without this check, the mapped branch would let a longer vector spill into
    // the neighbouring vector and a shorter one leave a partial update. The delta
    // branch already panicked on a mismatch (via `copy_from_slice`).
    assert_eq!(
        vector.len(),
        self.dim,
        "VecPool::update: vector length does not match dim"
    );
    let offset = index * self.dim;
    if index < self.mmap_count {
        if let Some(ref mut mmap) = self.mmap {
            let byte_offset = offset * std::mem::size_of::<T>();
            let src_bytes: &[u8] = bytemuck::cast_slice(vector);
            mmap[byte_offset..byte_offset + src_bytes.len()].copy_from_slice(src_bytes);
        }
        self.has_dirty_base = true;
    } else {
        let delta_offset = (index - self.mmap_count) * self.dim;
        self.delta[delta_offset..delta_offset + self.dim].copy_from_slice(vector);
    }
    self.invalidate_cache();
}
pub fn get(&self, index: usize) -> Option<&[T]> {
if index < self.mmap_count {
self.mmap.as_ref().map(|m| {
let elem_size = std::mem::size_of::<T>();
let byte_offset = index * self.dim * elem_size;
let byte_len = self.dim * elem_size;
let bytes = &m[byte_offset..byte_offset + byte_len];
let ptr = bytes.as_ptr();
if (ptr as usize).is_multiple_of(std::mem::align_of::<T>()) {
unsafe { std::slice::from_raw_parts(ptr as *const T, self.dim) }
} else {
panic!("mmap 对齐异常,这不应该发生在正常的 OS 页映射中")
}
})
} else {
let delta_index = index - self.mmap_count;
let delta_offset = delta_index * self.dim;
if delta_offset + self.dim <= self.delta.len() {
Some(&self.delta[delta_offset..delta_offset + self.dim])
} else {
None
}
}
}
/// Brings the merged base+delta cache up to date when a mapped base exists.
///
/// `flat_vectors` returns that cache whenever mapped vectors are present, so
/// call this first. Without a mapped base there is nothing to merge.
pub fn ensure_cache(&mut self) {
    let has_mapped_base = self.mmap.is_some() && self.mmap_count > 0;
    if has_mapped_base && !self.merged_valid {
        self.rebuild_merged_cache();
    }
}
/// Returns every vector as one flat slice of `T` elements.
///
/// Without a mapped base this is the delta buffer itself. With one, it is the
/// merged cache, which is only current after `ensure_cache` has run since the
/// last modification.
pub fn flat_vectors(&self) -> &[T] {
    match self.mmap {
        Some(_) if self.mmap_count > 0 => &self.merged,
        _ => &self.delta,
    }
}
/// Borrows the raw, flattened delta buffer (vectors not backed by the mapping).
#[inline]
pub fn delta_raw(&self) -> &[T] {
    self.delta.as_slice()
}
/// Moves every vector onto the heap and drops the file mapping.
///
/// Afterwards the pool is purely in-memory:
/// - `mmap_count` is 0;
/// - `delta` holds all vectors, including in-place edits of the former base
///   that were never flushed;
/// - no backing path is recorded.
///
/// Does nothing if no mapping is held.
pub fn detach_mmap(&mut self) {
    if self.mmap.is_none() {
        return;
    }
    self.ensure_cache();
    // `ensure_cache` builds the merged copy (base + delta) only when mapped vectors
    // exist. In that case the copy becomes the new delta. Moving the buffer instead
    // of cloning it avoids holding two full copies of the data set. If no copy was
    // built, the base is empty and `delta` already holds everything.
    if self.merged_valid {
        self.delta = std::mem::take(&mut self.merged);
    }
    self.mmap = None;
    self.vec_path = None;
    self.mmap_count = 0;
    // Release any stale cache allocation rather than keeping it alive via `clear`.
    self.merged = Vec::new();
    self.merged_valid = false;
    self.has_dirty_base = false;
}
/// Persists all vectors to `vec_path`, remaps it, and returns the vector count.
///
/// - Empty pool: drops the mapping and deletes `vec_path` if it exists.
/// - Clean base already mapped from `vec_path`: appends the delta, or does
///   nothing if the delta is empty.
/// - Otherwise (in-place edits, no mapping, or a different target path):
///   rewrites the whole file.
///
/// # Errors
/// Propagates I/O and mapping errors from the chosen write strategy.
pub fn flush(&mut self, vec_path: &Path) -> Result<usize> {
    let total = self.total_count();
    if total == 0 {
        self.mmap = None;
        self.mmap_count = 0;
        if vec_path.exists() {
            // Removal failures are ignored.
            std::fs::remove_file(vec_path).ok();
        }
        self.delta.clear();
        self.has_dirty_base = false;
        self.invalidate_cache();
        return Ok(0);
    }
    // The no-op and append shortcuts are only valid when `vec_path` is the file
    // the current base was mapped from. When flushing elsewhere, the target would
    // otherwise be left without the base, or get the delta appended to unrelated
    // contents.
    let base_backed_by_target =
        self.mmap.is_some() && self.vec_path.as_deref() == Some(vec_path);
    if !self.has_dirty_base && base_backed_by_target {
        if self.delta.is_empty() {
            return Ok(self.mmap_count);
        }
        self.flush_append(vec_path)
    } else {
        self.flush_rewrite(vec_path)
    }
}
/// Appends the delta to `vec_path` and remaps the grown file.
///
/// Appending is only correct when `vec_path` is the file backing the current
/// mapping and ends exactly where the mapped base ends; only then do the new
/// bytes land at the offsets the new mapping expects. Otherwise this delegates
/// to `flush_rewrite`.
fn flush_append(&mut self, vec_path: &Path) -> Result<usize> {
    let append_count = self.delta_count();
    let elem_size = std::mem::size_of::<T>();
    let base_bytes = self.mmap_count * self.dim * elem_size;
    // `open` accepts files longer than the mapped base, and an interrupted append
    // leaves a partial tail. Appending after such bytes would misplace every new
    // vector, and a different target path would not contain the base at all.
    let backs_mapping = self.vec_path.as_deref() == Some(vec_path);
    let file_len = std::fs::metadata(vec_path).ok().map(|meta| meta.len());
    if !backs_mapping || file_len != Some(base_bytes as u64) {
        return self.flush_rewrite(vec_path);
    }
    {
        let mut file = OpenOptions::new()
            .append(true)
            .open(vec_path)
            .map_err(TriviumError::Io)?;
        file.write_all(bytemuck::cast_slice::<T, u8>(&self.delta))
            .map_err(TriviumError::Io)?;
        file.sync_all().map_err(TriviumError::Io)?;
    }
    let new_total = self.mmap_count + append_count;
    let expected_bytes = new_total * self.dim * elem_size;
    let file = std::fs::File::open(vec_path).map_err(TriviumError::Io)?;
    // Map the grown file before releasing the current mapping. If this fails, the
    // pool still holds its base and delta, and the length check above sends the
    // next flush through a full rewrite instead of appending the delta twice.
    // SAFETY: `map_copy` creates a private copy-on-write mapping, so writes through
    // it never reach the file. This assumes no other process truncates the file
    // while it is mapped.
    let new_mmap = unsafe {
        memmap2::MmapOptions::new()
            .len(expected_bytes)
            .map_copy(&file)
    }
    .map_err(TriviumError::Io)?;
    #[cfg(unix)]
    madvise(&new_mmap, memmap2::Advice::Sequential);
    self.mmap = Some(new_mmap);
    self.mmap_count = new_total;
    self.delta.clear();
    self.delta.shrink_to_fit();
    self.vec_path = Some(vec_path.to_path_buf());
    self.has_dirty_base = false;
    self.invalidate_cache();
    tracing::debug!(
        "[VecPool] 追加写入: +{} 向量, 累计 {} 向量",
        append_count,
        new_total
    );
    Ok(new_total)
}
/// Writes base + delta to a temporary file, renames it over `vec_path`, and
/// remaps the result.
///
/// The mapping is released before the rename, because a mapped target can block
/// the replacement on Windows. If the rename or the remap fails after that
/// point, the complete synced data set is reloaded onto the heap. This keeps the
/// pool consistent, and a retried flush does not lose the former base vectors
/// or their in-place edits.
///
/// # Errors
/// I/O or mapping errors. The pool's contents are preserved whenever the
/// reload succeeds.
fn flush_rewrite(&mut self, vec_path: &Path) -> Result<usize> {
    let total = self.total_count();
    let elem_size = std::mem::size_of::<T>();
    let tmp_path = vec_path.with_extension("vec.tmp");
    if let Err(e) = self.write_snapshot(&tmp_path) {
        // Nothing has been released yet; only the partial temporary file needs cleanup.
        std::fs::remove_file(&tmp_path).ok();
        return Err(TriviumError::Io(e));
    }
    self.mmap = None;
    if let Err(e) = robust_rename(&tmp_path, vec_path) {
        // `vec_path` is untouched, but the mapping (with any in-place edits) is gone.
        // The synced temporary file is now the only complete copy.
        self.reload_into_delta(&tmp_path, total)?;
        std::fs::remove_file(&tmp_path).ok();
        return Err(TriviumError::Io(e));
    }
    let expected_bytes = total * self.dim * elem_size;
    let remapped = std::fs::File::open(vec_path).and_then(|file| {
        // SAFETY: `map_copy` creates a private copy-on-write mapping, so writes
        // through it never reach the file. This assumes no other process
        // truncates the file while it is mapped.
        unsafe {
            memmap2::MmapOptions::new()
                .len(expected_bytes)
                .map_copy(&file)
        }
    });
    let new_mmap = match remapped {
        Ok(mapping) => mapping,
        Err(e) => {
            // The rename succeeded, so `vec_path` holds everything. Reload it rather
            // than keep a stale `mmap_count` with no mapping behind it.
            self.reload_into_delta(vec_path, total)?;
            return Err(TriviumError::Io(e));
        }
    };
    #[cfg(unix)]
    madvise(&new_mmap, memmap2::Advice::Sequential);
    self.mmap = Some(new_mmap);
    self.mmap_count = total;
    self.delta.clear();
    self.delta.shrink_to_fit();
    self.vec_path = Some(vec_path.to_path_buf());
    self.has_dirty_base = false;
    self.invalidate_cache();
    tracing::debug!("[VecPool] 全量重写: {} 向量", total);
    Ok(total)
}
/// Writes the mapped base (including in-place edits) followed by the delta to
/// `path`, then syncs the file.
fn write_snapshot(&self, path: &Path) -> std::io::Result<()> {
    let mut file = std::fs::File::create(path)?;
    if let Some(ref mmap) = self.mmap {
        let base_bytes = self.mmap_count * self.dim * std::mem::size_of::<T>();
        file.write_all(&mmap[..base_bytes])?;
    }
    if !self.delta.is_empty() {
        file.write_all(bytemuck::cast_slice::<T, u8>(&self.delta))?;
    }
    file.sync_all()
}
/// Replaces the pool contents with the first `total` vectors stored in `path`.
///
/// All vectors are held on the heap afterwards: there is no mapping and nothing
/// is marked dirty.
fn reload_into_delta(&mut self, path: &Path, total: usize) -> Result<()> {
    let elem_size = std::mem::size_of::<T>();
    let wanted = total * self.dim * elem_size;
    let bytes = std::fs::read(path)?;
    if bytes.len() < wanted {
        return Err(TriviumError::CorruptedFile(format!(
            "向量文件大小不匹配: 文件 {} 字节, 预期最少 {} 字节",
            bytes.len(),
            wanted
        )));
    }
    // A `Vec<u8>` gives no alignment guarantee for `T`, so decode element-wise.
    self.delta = bytes[..wanted]
        .chunks_exact(elem_size)
        .map(bytemuck::pod_read_unaligned::<T>)
        .collect();
    self.mmap = None;
    self.mmap_count = 0;
    self.has_dirty_base = false;
    self.invalidate_cache();
    Ok(())
}
/// Asks the kernel to drop the mapped base pages (`MADV_DONTNEED`) to free memory.
///
/// The mapping is private copy-on-write. For such mappings DONTNEED discards
/// modified pages, and later reads see the file contents again. The hint is
/// therefore skipped while the base holds unflushed in-place edits
/// (`zero_out` / `update`). No-op on non-Unix targets or without a mapping.
pub fn advise_dontneed(&self) {
    #[cfg(unix)]
    if let Some(ref m) = self.mmap {
        if self.has_dirty_base {
            tracing::debug!("[VecPool] 跳过 madvise(DONTNEED): 基础映射存在未落盘的修改");
            return;
        }
        // SAFETY: `has_dirty_base` is false, so no page of this private mapping has
        // been written through `zero_out` / `update`. Dropping the pages therefore
        // only forces identical data to be re-read from the file.
        let _ = unsafe { m.unchecked_advise(memmap2::UncheckedAdvice::DontNeed) };
        tracing::debug!(
            "[VecPool] madvise(DONTNEED):释放 {} MB 冷页",
            self.mmap_count * self.dim * std::mem::size_of::<T>() / (1024 * 1024)
        );
    }
}
/// Hints that the mapped base will be accessed in random order.
///
/// Only effective on Unix and when a mapping is held; otherwise does nothing.
pub fn advise_random(&self) {
    #[cfg(unix)]
    if let Some(mapping) = self.mmap.as_ref() {
        madvise(mapping, memmap2::Advice::Random);
    }
}
/// Hints that the mapped base will be read sequentially.
///
/// Only effective on Unix and when a mapping is held; otherwise does nothing.
pub fn advise_sequential(&self) {
    #[cfg(unix)]
    if let Some(mapping) = self.mmap.as_ref() {
        madvise(mapping, memmap2::Advice::Sequential);
    }
}
/// Marks the merged base+delta cache as stale so `ensure_cache` rebuilds it.
#[inline]
fn invalidate_cache(&mut self) {
    self.merged_valid = false;
}
/// Rebuilds `merged` as the mapped base vectors followed by the delta buffer,
/// then marks the cache valid.
fn rebuild_merged_cache(&mut self) {
    let total_elements = self.total_count() * self.dim;
    self.merged.clear();
    self.merged.reserve(total_elements);
    if let Some(ref mmap) = self.mmap {
        let elem_size = std::mem::size_of::<T>();
        let bytes = &mmap[..self.mmap_count * self.dim * elem_size];
        // A checked cast borrows the mapped bytes as `&[T]` without `unsafe`. The
        // mapping is normally aligned for `T`; if it is not, fall back to reading
        // each element unaligned.
        match bytemuck::try_cast_slice::<u8, T>(bytes) {
            Ok(base) => self.merged.extend_from_slice(base),
            Err(_) => self.merged.extend(
                bytes
                    .chunks_exact(elem_size)
                    .map(bytemuck::pod_read_unaligned::<T>),
            ),
        }
    }
    self.merged.extend_from_slice(&self.delta);
    self.merged_valid = true;
}
/// Heap bytes currently allocated by the pool's owned buffers (delta and merged cache).
///
/// Counts allocated capacity rather than length, because cleared or partially
/// filled buffers keep their whole allocation. Memory-mapped pages are not
/// included.
pub fn heap_memory_bytes(&self) -> usize {
    let elem_size = std::mem::size_of::<T>();
    (self.delta.capacity() + self.merged.capacity()) * elem_size
}
/// Raw size in bytes of every vector in the pool (mapped base plus delta).
pub fn total_data_bytes(&self) -> usize {
    let elements = self.total_count() * self.dim;
    elements * std::mem::size_of::<T>()
}
}