use crate::VectorType;
use crate::database::StorageMode;
use crate::error::{Result, TriviumError};
use crate::node::{Edge, NodeId};
use crate::storage::memtable::MemTable;
use crate::storage::vec_pool::VecPool;
use memmap2::Mmap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
/// Renames `from` to `to`, replacing `to` if it already exists.
///
/// On non-Windows targets this is exactly [`std::fs::rename`]. On Windows a rename that fails
/// with OS error 5 (`ERROR_ACCESS_DENIED`) or 32 (`ERROR_SHARING_VIOLATION`) is retried, up to
/// 10 attempts in total, sleeping between attempts with a back-off that starts at 1 ms, doubles
/// each time and is capped at 50 ms. Any other error, or the error of the last attempt, is
/// returned unchanged.
fn robust_rename(from: &Path, to: &Path) -> std::io::Result<()> {
    #[cfg(not(windows))]
    {
        std::fs::rename(from, to)
    }
    #[cfg(windows)]
    {
        const ATTEMPTS: u32 = 10;
        const MAX_BACKOFF: std::time::Duration = std::time::Duration::from_millis(50);
        let mut backoff = std::time::Duration::from_millis(1);
        let mut attempt: u32 = 0;
        loop {
            let err = match std::fs::rename(from, to) {
                Ok(()) => return Ok(()),
                Err(err) => err,
            };
            let code = err.raw_os_error();
            let transient = matches!(code, Some(5) | Some(32));
            attempt += 1;
            // Give up on non-transient errors, or once the final attempt has failed.
            if !transient || attempt >= ATTEMPTS {
                return Err(err);
            }
            tracing::debug!(
                "robust_rename: attempt {} failed (os_error={:?}), retrying in {:?}",
                attempt,
                code,
                backoff
            );
            std::thread::sleep(backoff);
            backoff = (backoff * 2).min(MAX_BACKOFF);
        }
    }
}
/// Signature stored in the first four bytes of every `.tdb` file.
const MAGIC: &[u8; 4] = b"TVDB";
/// On-disk format version, stored as a little-endian `u16` right after `MAGIC`.
const VERSION: u16 = 3;
/// Fixed header length: magic (4) + version (2) + dim (4) + next id (8) + slot count (8)
/// + payload / vector / edge / eRPC offsets (4 x 8) = 58 bytes.
const HEADER_SIZE: u64 = 4 + 2 + 4 + 8 + 8 + 4 * 8;
use crate::index::erpc::{ErpcIndex, ErpcParams, SeqEntry};
/// Path of the vector sidecar that belongs to the database file at `db_path` (`<db_path>.vec`).
fn vec_path_from_db(db_path: &str) -> String {
    [db_path, ".vec"].concat()
}
/// Path of the flush-completion marker for the database file at `db_path`
/// (`<db_path>.flush_ok`).
fn flush_ok_path_from_db(db_path: &str) -> String {
    let mut marker = String::from(db_path);
    marker.push_str(".flush_ok");
    marker
}
pub fn save<T: VectorType>(
memtable: &mut MemTable<T>,
path: &str,
mode: StorageMode,
) -> Result<()> {
match mode {
StorageMode::Mmap => save_mmap(memtable, path),
StorageMode::Rom => save_rom(memtable, path),
}
}
/// Saves in the mmap layout: flushes vectors to `<path>.vec`, writes the `.tdb` file without
/// an embedded vector block, then publishes `<path>.flush_ok`.
///
/// The marker holds the byte sizes of the `.tdb` and `.vec` files (two little-endian `u64`s,
/// 0 if a size could not be read). `load` only trusts the `.vec` file when both sizes still
/// match. The marker is written to a temp file, synced and renamed into place.
fn save_mmap<T: VectorType>(memtable: &mut MemTable<T>, path: &str) -> Result<()> {
    let vec_path = vec_path_from_db(path);
    let flushed_vectors = memtable.vec_pool_mut().flush(Path::new(&vec_path))?;
    save_tdb(memtable, path, flushed_vectors, true)?;

    let file_len = |p: &str| std::fs::metadata(p).map_or(0, |meta| meta.len());
    let mut record = [0u8; 16];
    record[..8].copy_from_slice(&file_len(path).to_le_bytes());
    record[8..].copy_from_slice(&file_len(vec_path.as_str()).to_le_bytes());

    let marker = flush_ok_path_from_db(path);
    let staging = format!("{}.tmp", &marker);
    let mut out = File::create(&staging)?;
    out.write_all(&record)?;
    out.sync_all()?;
    // Close the handle before renaming.
    drop(out);
    robust_rename(Path::new(&staging), Path::new(&marker))?;
    Ok(())
}
/// Saves in the Rom layout: vectors are embedded in the `.tdb` file itself.
///
/// After the `.tdb` file is written, the vector pool is detached from its mmap. Any `.vec`
/// sidecar and `.flush_ok` marker from an earlier mmap-mode save are then removed on a
/// best-effort basis (removal errors are ignored).
fn save_rom<T: VectorType>(memtable: &mut MemTable<T>, path: &str) -> Result<()> {
    memtable.ensure_vectors_cache();
    let slot_total = memtable.internal_indices().len();
    save_tdb(memtable, path, slot_total, false)?;
    memtable.vec_pool_mut().detach_mmap();
    for stale in [vec_path_from_db(path), flush_ok_path_from_db(path)] {
        if Path::new(&stale).exists() {
            let _ = std::fs::remove_file(&stale);
        }
    }
    Ok(())
}
fn save_tdb<T: VectorType>(
memtable: &mut MemTable<T>,
path: &str,
vec_count: usize,
is_mmap_mode: bool,
) -> Result<()> {
if !is_mmap_mode {
memtable.ensure_vectors_cache();
}
let tmp_path = format!("{}.tmp", path);
let file = File::create(&tmp_path)?;
let mut w = BufWriter::new(file);
let dim = memtable.dim();
let internal_indices = memtable.internal_indices();
let node_count = internal_indices.len() as u64;
let mut all_edges: Vec<(NodeId, &Edge)> = Vec::new();
let mut payload_size: u64 = 0;
for &nid in internal_indices {
if nid != 0 {
if let Some(p) = memtable.get_payload(nid) {
let json_bytes = serde_json::to_vec(p).unwrap_or_default();
payload_size += 8 + 4 + json_bytes.len() as u64;
} else {
payload_size += 12;
}
if let Some(edges) = memtable.get_edges(nid) {
for edge in edges {
all_edges.push((nid, edge));
}
}
} else {
payload_size += 12;
}
}
let payload_offset = HEADER_SIZE;
let vector_offset = if is_mmap_mode {
0
} else {
payload_offset + payload_size
};
let vector_size = if is_mmap_mode {
0
} else {
node_count * (dim as u64) * (std::mem::size_of::<T>() as u64)
};
let edge_offset = payload_offset + payload_size + vector_size;
let edge_size = all_edges.iter().map(|(_, e)| (8 + 8 + 2 + e.label.len() + 4) as u64).sum::<u64>();
let erpc_offset = edge_offset + edge_size;
w.write_all(MAGIC)?;
w.write_all(&VERSION.to_le_bytes())?;
w.write_all(&(dim as u32).to_le_bytes())?;
w.write_all(&memtable.next_id_value().to_le_bytes())?;
w.write_all(&node_count.to_le_bytes())?;
w.write_all(&payload_offset.to_le_bytes())?;
w.write_all(&vector_offset.to_le_bytes())?;
w.write_all(&edge_offset.to_le_bytes())?;
w.write_all(&erpc_offset.to_le_bytes())?;
for &nid in internal_indices {
if nid != 0
&& let Some(p) = memtable.get_payload(nid) {
let json_bytes = serde_json::to_vec(p).unwrap_or_default();
w.write_all(&nid.to_le_bytes())?;
w.write_all(&(json_bytes.len() as u32).to_le_bytes())?;
w.write_all(&json_bytes)?;
continue;
}
w.write_all(&0u64.to_le_bytes())?;
w.write_all(&0u32.to_le_bytes())?;
}
if !is_mmap_mode {
let flat = memtable.flat_vectors();
w.write_all(bytemuck::cast_slice(flat))?;
}
for (src_id, edge) in &all_edges {
w.write_all(&src_id.to_le_bytes())?;
w.write_all(&edge.target_id.to_le_bytes())?;
let label_bytes = edge.label.as_bytes();
w.write_all(&(label_bytes.len() as u16).to_le_bytes())?;
w.write_all(label_bytes)?;
w.write_all(&edge.weight.to_le_bytes())?;
}
if let Some(erpc) = &memtable.erpc_index {
w.write_all(bytemuck::bytes_of(&erpc.params))?;
for pc in &erpc.pca_basis {
for &fv in pc { w.write_all(&fv.to_le_bytes())?; }
}
for c in &erpc.centers {
for &fv in c { w.write_all(&fv.to_le_bytes())?; }
}
w.write_all(bytemuck::cast_slice(&erpc.sequence))?;
}
w.flush()?;
let file = w
.into_inner()
.map_err(|e| TriviumError::Io(e.into_error()))?;
file.sync_all()?;
drop(file);
robust_rename(Path::new(&tmp_path), Path::new(path))?;
tracing::info!(
"持久化完成: {} 个槽位(含删除), {} 个向量, Mode: {}",
node_count,
vec_count,
if is_mmap_mode { "Mmap" } else { "Rom" }
);
Ok(())
}
pub fn load<T: VectorType>(path: &str, _mode: StorageMode) -> Result<MemTable<T>> {
let file = File::open(path).map_err(TriviumError::Io)?;
let mmap = unsafe { Mmap::map(&file) }.map_err(TriviumError::Io)?;
if mmap.len() < HEADER_SIZE as usize {
return Err(TriviumError::Generic("File too small for header".into()));
}
let bytes = &mmap[..];
if &bytes[0..4] != MAGIC {
return Err(TriviumError::Generic(format!(
"Invalid file magic: expected TVDB, got {:?}",
&bytes[0..4]
)));
}
let dim = u32::from_le_bytes(bytes[6..10].try_into().unwrap()) as usize;
let next_id = u64::from_le_bytes(bytes[10..18].try_into().unwrap());
let node_count = u64::from_le_bytes(bytes[18..26].try_into().unwrap()) as usize;
let payload_offset = u64::from_le_bytes(bytes[26..34].try_into().unwrap()) as usize;
let vector_offset = u64::from_le_bytes(bytes[34..42].try_into().unwrap()) as usize;
let edge_offset = u64::from_le_bytes(bytes[42..50].try_into().unwrap()) as usize;
let erpc_offset = if mmap.len() >= 58 {
u64::from_le_bytes(bytes[50..58].try_into().unwrap()) as usize
} else {
mmap.len()
};
let vec_file_path = vec_path_from_db(path);
if vector_offset == 0 && Path::new(&vec_file_path).exists() {
let marker_path = flush_ok_path_from_db(path);
let flush_ok_valid = (|| -> Option<bool> {
let marker_bytes = std::fs::read(&marker_path).ok()?;
if marker_bytes.len() < 16 {
return Some(false);
}
let stored_tdb = u64::from_le_bytes(marker_bytes[0..8].try_into().ok()?);
let stored_vec = u64::from_le_bytes(marker_bytes[8..16].try_into().ok()?);
let actual_tdb = std::fs::metadata(path).ok()?.len();
let actual_vec = std::fs::metadata(&vec_file_path).ok()?.len();
Some(stored_tdb == actual_tdb && stored_vec == actual_vec)
})()
.unwrap_or(false);
if flush_ok_valid {
load_v2(
bytes,
dim,
next_id,
node_count,
payload_offset,
edge_offset,
erpc_offset,
&vec_file_path,
&mmap,
)
} else {
tracing::warn!(
"检测到 .tdb/.vec 跨文件撕裂(.flush_ok 标记缺失或不匹配),\
降级为忽略 .vec 的安全模式加载,增量数据将由 WAL 回放恢复"
);
load_v1_rom(
bytes,
dim,
next_id,
node_count,
payload_offset,
vector_offset,
edge_offset,
erpc_offset,
&mmap,
)
}
} else {
load_v1_rom(
bytes,
dim,
next_id,
node_count,
payload_offset,
vector_offset,
edge_offset,
erpc_offset,
&mmap,
)
}
}
/// Builds a `MemTable` for the mmap layout. Vectors come from the `.vec` sidecar at
/// `vec_file_path` (opened for `node_count` slots of `dim` elements). Payloads, edges and the
/// optional eRPC index come from the `.tdb` bytes.
///
/// `_tdb_mmap` is not used by this function.
fn load_v2<T: VectorType>(
    bytes: &[u8],
    dim: usize,
    next_id: u64,
    node_count: usize,
    payload_offset: usize,
    edge_offset: usize,
    erpc_offset: usize,
    vec_file_path: &str,
    _tdb_mmap: &Mmap,
) -> Result<MemTable<T>> {
    let pool = VecPool::<T>::open(Path::new(vec_file_path), dim, node_count)?;
    let mut table = MemTable::new_with_vec_pool(dim, next_id, pool);
    // Without an embedded vector block, the payload records run straight up to the edge block.
    let payload_end = edge_offset;
    load_payloads(&mut table, bytes, node_count, payload_offset, payload_end)?;
    load_edges(&mut table, bytes, edge_offset, erpc_offset)?;
    let erpc_index = load_erpc(bytes, erpc_offset, dim)?;
    table.erpc_index = erpc_index;
    Ok(table)
}
/// Builds a `MemTable` whose vectors are embedded in the `.tdb` bytes.
///
/// The vector block starts at `vector_offset` and holds `node_count * dim` elements of `T`.
/// Payload records span `payload_offset..vector_offset`, edges span
/// `edge_offset..erpc_offset`, and the eRPC index (if any) follows.
///
/// # Errors
/// Returns "Vector block exceeds file size" if the vector block does not fit in `tdb_mmap`,
/// including when its size computation would overflow. Also propagates any error from the
/// payload, edge and eRPC loaders.
fn load_v1_rom<T: VectorType>(
    bytes: &[u8],
    dim: usize,
    next_id: u64,
    node_count: usize,
    payload_offset: usize,
    vector_offset: usize,
    edge_offset: usize,
    erpc_offset: usize,
    tdb_mmap: &Mmap,
) -> Result<MemTable<T>> {
    let mut memtable = MemTable::new_with_next_id(dim, next_id);
    let elem_size = std::mem::size_of::<T>();
    // Checked arithmetic: with a corrupt header, plain `*`/`+` can wrap in release builds into a
    // small size that passes the bounds check while the real element count does not fit.
    let vec_end = node_count
        .checked_mul(dim)
        .and_then(|elems| elems.checked_mul(elem_size))
        .and_then(|len| vector_offset.checked_add(len))
        .filter(|&end| end <= tdb_mmap.len());
    let Some(vec_end) = vec_end else {
        return Err(TriviumError::Generic(
            "Vector block exceeds file size".into(),
        ));
    };

    load_payloads(
        &mut memtable,
        bytes,
        node_count,
        payload_offset,
        vector_offset,
    )?;

    let vec_block = &bytes[vector_offset..vec_end];
    match bytemuck::try_cast_slice::<u8, T>(vec_block) {
        // Fast path: the mapped bytes are suitably aligned for `T` and can be pushed directly.
        Ok(vectors) => {
            memtable.vec_pool_mut().push(vectors);
        }
        // The block's file position may leave it misaligned for `T`; decode element by element.
        // (Assumes `T` is not zero-sized, as `chunks_exact` rejects a chunk size of 0.)
        Err(_) => {
            let vectors: Vec<T> = vec_block
                .chunks_exact(elem_size)
                .map(bytemuck::pod_read_unaligned)
                .collect();
            memtable.vec_pool_mut().push(&vectors);
        }
    }

    load_edges(&mut memtable, bytes, edge_offset, erpc_offset)?;
    memtable.erpc_index = load_erpc(bytes, erpc_offset, dim)?;
    Ok(memtable)
}
/// Reads `node_count` payload records starting at `offset` and registers them in `memtable`.
///
/// Each record is `nid: u64 LE`, `json_len: u32 LE`, then `json_len` bytes of JSON.
/// A `(0, 0)` record is registered as a tombstone.
///
/// `end_offset` is the exclusive end of the payload block. Reads never go past it or past the
/// end of `bytes`, so a corrupt header produces an error instead of a slicing panic.
///
/// # Errors
/// Returns "Payload block overflow" or "JSON data overflow" when a record would cross the
/// block end, or a JSON parse error. Also propagates errors from `register_tombstone` /
/// `register_node`.
fn load_payloads<T: VectorType>(
    memtable: &mut MemTable<T>,
    bytes: &[u8],
    node_count: usize,
    offset: usize,
    end_offset: usize,
) -> Result<()> {
    const RECORD_HEADER: usize = 8 + 4;
    let end = end_offset.min(bytes.len());
    let mut cursor = offset;
    for _ in 0..node_count {
        let header_end = cursor
            .checked_add(RECORD_HEADER)
            .filter(|&e| e <= end)
            .ok_or_else(|| TriviumError::Generic("Payload block overflow".into()))?;
        let nid = u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().expect("8-byte slice"));
        let json_len = u32::from_le_bytes(
            bytes[cursor + 8..header_end]
                .try_into()
                .expect("4-byte slice"),
        ) as usize;
        cursor = header_end;
        if nid == 0 && json_len == 0 {
            memtable.register_tombstone()?;
            continue;
        }
        let json_end = cursor
            .checked_add(json_len)
            .filter(|&e| e <= end)
            .ok_or_else(|| TriviumError::Generic("JSON data overflow".into()))?;
        let payload: serde_json::Value = serde_json::from_slice(&bytes[cursor..json_end])
            .map_err(|e| TriviumError::Generic(format!("JSON parse error: {}", e)))?;
        cursor = json_end;
        memtable.register_node(nid, payload)?;
    }
    Ok(())
}
/// Reads edge records from `edge_offset` up to `end_offset` (exclusive) and links them in
/// `memtable`.
///
/// Each record is `src: u64 LE`, `dst: u64 LE`, `label_len: u16 LE`, `label_len` bytes of
/// UTF-8 label, then `weight: f32 LE`. Callers pass the start of the eRPC block as
/// `end_offset`. The bound is also clamped to `bytes.len()`, so an out-of-range offset cannot
/// cause a slicing panic. Parsing stops silently at a trailing record that does not fit.
///
/// # Errors
/// Returns a label decode error for a non-UTF-8 label, and propagates errors from `link`.
fn load_edges<T: VectorType>(
    memtable: &mut MemTable<T>,
    bytes: &[u8],
    edge_offset: usize,
    end_offset: usize,
) -> Result<()> {
    // src (8) + dst (8) + label_len (2); the label and weight (4) follow.
    const FIXED_PART: usize = 8 + 8 + 2;
    let end = end_offset.min(bytes.len());
    let mut cursor = edge_offset;
    while cursor.checked_add(FIXED_PART).is_some_and(|e| e <= end) {
        let src_id = u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().unwrap());
        cursor += 8;
        let dst_id = u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().unwrap());
        cursor += 8;
        let label_len = u16::from_le_bytes(bytes[cursor..cursor + 2].try_into().unwrap()) as usize;
        cursor += 2;
        if cursor + label_len + 4 > end {
            break;
        }
        let label = String::from_utf8(bytes[cursor..cursor + label_len].to_vec())
            .map_err(|e| TriviumError::Generic(format!("Label decode error: {}", e)))?;
        cursor += label_len;
        let weight = f32::from_le_bytes(bytes[cursor..cursor + 4].try_into().unwrap());
        cursor += 4;
        memtable.link(src_id, dst_id, label, weight)?;
    }
    Ok(())
}
fn load_erpc(bytes: &[u8], erpc_offset: usize, dim: usize) -> Result<Option<ErpcIndex>> {
if erpc_offset >= bytes.len() || bytes.len() - erpc_offset < 32 {
return Ok(None);
}
let mut cursor = erpc_offset;
let params: ErpcParams = bytemuck::pod_read_unaligned(&bytes[cursor..cursor + 32]);
cursor += 32;
let pca_dims = crate::index::erpc::PCA_DIMS;
let mut pca_basis = Vec::with_capacity(pca_dims);
for _ in 0..pca_dims {
let size = dim * 4;
let mut pc = Vec::with_capacity(dim);
for i in 0..dim {
let offset = cursor + i * 4;
let val = f32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap());
pc.push(val);
}
pca_basis.push(pc);
cursor += size;
}
let k_clusters = params.k_clusters as usize;
let mut centers = Vec::with_capacity(k_clusters);
for _ in 0..k_clusters {
let size = dim * 4;
let mut c = Vec::with_capacity(dim);
for i in 0..dim {
let offset = cursor + i * 4;
let val = f32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap());
c.push(val);
}
centers.push(c);
cursor += size;
}
let sequence_bytes = &bytes[cursor..];
let expected_entry_size = std::mem::size_of::<SeqEntry>();
let count = sequence_bytes.len() / expected_entry_size;
let mut sequence: Vec<SeqEntry> = vec![bytemuck::Zeroable::zeroed(); count];
unsafe {
std::ptr::copy_nonoverlapping(
sequence_bytes.as_ptr(),
sequence.as_mut_ptr() as *mut u8,
count * expected_entry_size,
);
}
Ok(Some(ErpcIndex {
pca_basis,
centers,
sequence,
dim,
params,
}))
}