use crate::VectorType;
use crate::database::StorageMode;
use crate::error::{Result, TriviumError};
use crate::index::bq::BqSignature;
use crate::node::{Edge, NodeId};
use crate::storage::memtable::MemTable;
use crate::storage::vec_pool::VecPool;
use memmap2::Mmap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
/// Renames `from` to `to`, replacing `to` if it already exists.
///
/// On non-Windows targets this is exactly `std::fs::rename`.
///
/// On Windows, a rename can fail transiently with OS error 5 (`ERROR_ACCESS_DENIED`)
/// or 32 (`ERROR_SHARING_VIOLATION`) while another handle is still open on one of the
/// files. Those two errors are retried, up to 10 attempts in total, sleeping between
/// attempts with a backoff that starts at 1 ms, doubles each time and is capped at 50 ms.
/// Any other error, or the error from the final attempt, is returned unchanged.
fn robust_rename(from: &Path, to: &Path) -> std::io::Result<()> {
    #[cfg(not(windows))]
    {
        std::fs::rename(from, to)
    }
    #[cfg(windows)]
    {
        const MAX_ATTEMPTS: u32 = 10;
        let backoff_cap = std::time::Duration::from_millis(50);
        let mut backoff = std::time::Duration::from_millis(1);
        let mut attempts_made: u32 = 0;
        loop {
            let err = match std::fs::rename(from, to) {
                Ok(()) => return Ok(()),
                Err(err) => err,
            };
            attempts_made += 1;
            let code = err.raw_os_error();
            let transient = matches!(code, Some(5) | Some(32));
            if !transient || attempts_made >= MAX_ATTEMPTS {
                return Err(err);
            }
            tracing::debug!(
                "robust_rename: attempt {} failed (os_error={:?}), retrying in {:?}",
                attempts_made,
                code,
                backoff
            );
            std::thread::sleep(backoff);
            backoff = (backoff * 2).min(backoff_cap);
        }
    }
}
/// Magic bytes at the start of every `.tdb` file.
const MAGIC: &[u8; 4] = b"TVDB";
/// On-disk format version written by `save_tdb`.
const VERSION: u16 = 5;
/// Fixed header length in bytes:
/// magic (4) + version (u16) + dim (u32) + six u64 fields
/// (next_id, node_count, payload/vector/edge/BQ offsets).
const HEADER_SIZE: u64 = 4 + 2 + 4 + 6 * 8;
/// Returns the `N` bytes starting at `offset`.
///
/// Fails with `CorruptedFile("<field> at offset <offset>")` when the range is out of
/// bounds. Offsets read from a corrupted file can be arbitrary, so `offset + N` is
/// computed with `checked_add` rather than risking an overflow panic.
#[inline]
fn read_le_bytes<const N: usize>(bytes: &[u8], offset: usize, field: &str) -> Result<[u8; N]> {
    offset
        .checked_add(N)
        .and_then(|end| bytes.get(offset..end))
        .and_then(|s| <[u8; N]>::try_from(s).ok())
        .ok_or_else(|| TriviumError::CorruptedFile(format!("{} at offset {}", field, offset)))
}
/// Reads a little-endian `u16` at `offset` (see `read_le_bytes` for errors).
#[inline]
fn read_u16_le(bytes: &[u8], offset: usize, field: &str) -> Result<u16> {
    read_le_bytes(bytes, offset, field).map(u16::from_le_bytes)
}
/// Reads a little-endian `u32` at `offset` (see `read_le_bytes` for errors).
#[inline]
fn read_u32_le(bytes: &[u8], offset: usize, field: &str) -> Result<u32> {
    read_le_bytes(bytes, offset, field).map(u32::from_le_bytes)
}
/// Reads a little-endian `u64` at `offset` (see `read_le_bytes` for errors).
#[inline]
fn read_u64_le(bytes: &[u8], offset: usize, field: &str) -> Result<u64> {
    read_le_bytes(bytes, offset, field).map(u64::from_le_bytes)
}
/// Reads a little-endian `f32` at `offset` (see `read_le_bytes` for errors).
#[inline]
fn read_f32_le(bytes: &[u8], offset: usize, field: &str) -> Result<f32> {
    read_le_bytes(bytes, offset, field).map(f32::from_le_bytes)
}
/// Path of the raw vector sidecar (`<db_path>.vec`) that mmap-mode saves write
/// and ROM-mode saves delete.
fn vec_path_from_db(db_path: &str) -> String {
    [db_path, ".vec"].concat()
}
/// Path of the `<db_path>.flush_ok` marker, which records the byte sizes of the
/// `.tdb` and `.vec` files after a completed mmap-mode save.
fn flush_ok_path_from_db(db_path: &str) -> String {
    [db_path, ".flush_ok"].concat()
}
/// Path of the persisted QuIVer index sidecar (`<db_path>.quiver`).
fn quiver_path_from_db(db_path: &str) -> String {
    [db_path, ".quiver"].concat()
}
pub fn save<T: VectorType>(
memtable: &mut MemTable<T>,
path: &str,
mode: StorageMode,
) -> Result<()> {
match mode {
StorageMode::Mmap => save_mmap(memtable, path),
StorageMode::Rom => save_rom(memtable, path),
}?;
let quiver_path = quiver_path_from_db(path);
if let Some(quiver) = memtable.quiver() {
if let Err(e) = quiver.save_to_file(std::path::Path::new(&quiver_path)) {
tracing::warn!("QuIVer 索引持久化失败(不影响主数据): {}", e);
}
} else {
let qp = std::path::Path::new(&quiver_path);
if qp.exists() {
std::fs::remove_file(qp).ok();
}
}
Ok(())
}
fn save_mmap<T: VectorType>(memtable: &mut MemTable<T>, path: &str) -> Result<()> {
let vec_file_path = vec_path_from_db(path);
let vec_count = memtable.vec_pool_mut().flush(Path::new(&vec_file_path))?;
save_tdb(memtable, path, vec_count, true)?;
let tdb_size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
let vec_size = std::fs::metadata(&vec_file_path)
.map(|m| m.len())
.unwrap_or(0);
let marker_path = flush_ok_path_from_db(path);
let marker_tmp = format!("{}.tmp", &marker_path);
{
let mut f = File::create(&marker_tmp)?;
f.write_all(&tdb_size.to_le_bytes())?;
f.write_all(&vec_size.to_le_bytes())?;
f.sync_all()?;
}
robust_rename(Path::new(&marker_tmp), Path::new(&marker_path))?;
Ok(())
}
/// Persists a ROM-mode database, where the vectors are embedded directly in the `.tdb`.
///
/// After the `.tdb` is written, the vector pool is detached from any `.vec` map. The
/// now-obsolete `.vec` sidecar and `.flush_ok` marker are then removed, in that order,
/// on a best-effort basis (removal errors are ignored).
///
/// # Errors
/// Propagates errors from `save_tdb`.
fn save_rom<T: VectorType>(memtable: &mut MemTable<T>, path: &str) -> Result<()> {
    memtable.ensure_vectors_cache();
    let slot_count = memtable.internal_indices().len();
    save_tdb(memtable, path, slot_count, false)?;
    memtable.vec_pool_mut().detach_mmap();

    let obsolete_sidecars = [vec_path_from_db(path), flush_ok_path_from_db(path)];
    for sidecar in &obsolete_sidecars {
        if Path::new(sidecar).exists() {
            let _ = std::fs::remove_file(sidecar);
        }
    }
    Ok(())
}
fn save_tdb<T: VectorType>(
memtable: &mut MemTable<T>,
path: &str,
vec_count: usize,
is_mmap_mode: bool,
) -> Result<()> {
memtable.ensure_vectors_cache();
let tmp_path = format!("{}.tmp", path);
let file = File::create(&tmp_path)?;
let mut w = BufWriter::new(file);
let dim = memtable.dim();
let internal_indices = memtable.internal_indices();
let node_count = internal_indices.len() as u64;
let mut all_edges: Vec<(NodeId, &Edge)> = Vec::new();
let mut payload_size: u64 = 0;
for &nid in internal_indices {
if nid != 0 {
if let Some(p) = memtable.get_payload(nid) {
let json_bytes = serde_json::to_vec(p).unwrap_or_default();
payload_size += 8 + 4 + json_bytes.len() as u64;
} else {
payload_size += 12;
}
if let Some(edges) = memtable.get_edges(nid) {
for edge in edges {
all_edges.push((nid, edge));
}
}
} else {
payload_size += 12;
}
}
let payload_offset = HEADER_SIZE;
let vector_offset = if is_mmap_mode {
0
} else {
payload_offset + payload_size
};
let vector_size = if is_mmap_mode {
0
} else {
node_count * (dim as u64) * (std::mem::size_of::<T>() as u64)
};
let edge_offset = payload_offset + payload_size + vector_size;
let mut edge_block_size: u64 = 0;
for (_src_id, edge) in &all_edges {
edge_block_size += 8 + 8 + 2 + edge.label.len() as u64 + 4;
}
let bq_offset = edge_offset + edge_block_size;
w.write_all(MAGIC)?;
w.write_all(&VERSION.to_le_bytes())?;
w.write_all(&(dim as u32).to_le_bytes())?;
w.write_all(&memtable.next_id_value().to_le_bytes())?;
w.write_all(&node_count.to_le_bytes())?;
w.write_all(&payload_offset.to_le_bytes())?;
w.write_all(&vector_offset.to_le_bytes())?;
w.write_all(&edge_offset.to_le_bytes())?;
w.write_all(&bq_offset.to_le_bytes())?;
for &nid in internal_indices {
if nid != 0
&& let Some(p) = memtable.get_payload(nid)
{
let json_bytes = serde_json::to_vec(p).unwrap_or_default();
w.write_all(&nid.to_le_bytes())?;
w.write_all(&(json_bytes.len() as u32).to_le_bytes())?;
w.write_all(&json_bytes)?;
continue;
}
w.write_all(&0u64.to_le_bytes())?;
w.write_all(&0u32.to_le_bytes())?;
}
if !is_mmap_mode {
let flat = memtable.flat_vectors();
w.write_all(bytemuck::cast_slice(flat))?;
}
for (src_id, edge) in &all_edges {
w.write_all(&src_id.to_le_bytes())?;
w.write_all(&edge.target_id.to_le_bytes())?;
let label_bytes = edge.label.as_bytes();
w.write_all(&(label_bytes.len() as u16).to_le_bytes())?;
w.write_all(label_bytes)?;
w.write_all(&edge.weight.to_le_bytes())?;
}
let bq_sigs = memtable.bq_signatures_slice();
let bq_count = bq_sigs.len() as u64;
w.write_all(&bq_count.to_le_bytes())?; if !bq_sigs.is_empty() {
w.write_all(bytemuck::cast_slice(bq_sigs))?;
}
w.flush()?;
let file = w
.into_inner()
.map_err(|e| TriviumError::Io(e.into_error()))?;
file.sync_all()?;
drop(file);
robust_rename(Path::new(&tmp_path), Path::new(path))?;
tracing::info!(
"持久化完成: {} 个槽位(含删除), {} 个向量, {} 个 BQ 签名, Mode: {}",
node_count,
vec_count,
bq_count,
if is_mmap_mode { "Mmap" } else { "Rom" }
);
Ok(())
}
pub fn load<T: VectorType>(path: &str, _mode: StorageMode) -> Result<MemTable<T>> {
let file = File::open(path).map_err(TriviumError::Io)?;
let mmap = unsafe { Mmap::map(&file) }.map_err(TriviumError::Io)?;
if mmap.len() < HEADER_SIZE as usize {
return Err(TriviumError::CorruptedFile(
"File too small for header".into(),
));
}
let bytes = &mmap[..];
if &bytes[0..4] != MAGIC {
return Err(TriviumError::CorruptedFile(format!(
"Invalid file magic: expected TVDB, got {:?}",
&bytes[0..4]
)));
}
let version = read_u16_le(bytes, 4, "header version")?;
let dim = read_u32_le(bytes, 6, "header dim")? as usize;
let next_id = read_u64_le(bytes, 10, "header next_id")?;
let node_count = read_u64_le(bytes, 18, "header node_count")? as usize;
let payload_offset = read_u64_le(bytes, 26, "header payload_offset")? as usize;
let vector_offset = read_u64_le(bytes, 34, "header vector_offset")? as usize;
let edge_offset = read_u64_le(bytes, 42, "header edge_offset")? as usize;
let bq_offset = if version >= 5 && mmap.len() >= 58 {
read_u64_le(bytes, 50, "header bq_offset")? as usize
} else {
0 };
let file_len = mmap.len();
if payload_offset > file_len {
return Err(TriviumError::CorruptedFile(format!(
"payload_offset ({}) exceeds file size ({}), file truncated",
payload_offset, file_len
)));
}
if edge_offset > file_len {
return Err(TriviumError::CorruptedFile(format!(
"edge_offset ({}) exceeds file size ({}), file truncated",
edge_offset, file_len
)));
}
if bq_offset > 0 {
if bq_offset > file_len {
return Err(TriviumError::CorruptedFile(format!(
"bq_offset ({}) exceeds file size ({}), file truncated",
bq_offset, file_len
)));
}
if bq_offset + 8 <= file_len {
let bq_count = u64::from_le_bytes(
bytes[bq_offset..bq_offset + 8].try_into().unwrap_or([0; 8]),
) as usize;
if bq_count > 0 {
let sig_size = std::mem::size_of::<BqSignature>();
let expected_bq_end = bq_offset + 8 + bq_count * sig_size;
if expected_bq_end > file_len {
return Err(TriviumError::CorruptedFile(format!(
"BQ block truncated: expected {} bytes (offset {} + 8 + {} × {}), \
actual file size {} bytes",
expected_bq_end, bq_offset, bq_count, sig_size, file_len
)));
}
}
}
}
let edge_limit_offset = if version >= 4 {
if version >= 5 && bq_offset > 0 {
bq_offset
} else {
mmap.len()
}
} else if mmap.len() >= 58 {
read_u64_le(bytes, 50, "header edge_limit_offset")? as usize
} else {
mmap.len()
};
let vec_file_path = vec_path_from_db(path);
if vector_offset == 0 && Path::new(&vec_file_path).exists() {
let marker_path = flush_ok_path_from_db(path);
let flush_ok_valid = (|| -> Option<bool> {
let marker_bytes = std::fs::read(&marker_path).ok()?;
if marker_bytes.len() < 16 {
return Some(false);
}
let stored_tdb = u64::from_le_bytes(marker_bytes[0..8].try_into().ok()?);
let stored_vec = u64::from_le_bytes(marker_bytes[8..16].try_into().ok()?);
let actual_tdb = std::fs::metadata(path).ok()?.len();
let actual_vec = std::fs::metadata(&vec_file_path).ok()?.len();
Some(stored_tdb == actual_tdb && stored_vec == actual_vec)
})()
.unwrap_or(false);
if flush_ok_valid {
let mut mt = load_v2(
bytes,
dim,
next_id,
node_count,
payload_offset,
edge_offset,
edge_limit_offset,
&vec_file_path,
&mmap,
)?;
load_bq_block(&mut mt, bytes, bq_offset, mmap.len());
load_quiver_index(&mut mt, path);
Ok(mt)
} else {
tracing::warn!(
"检测到 .tdb/.vec 跨文件撕裂(.flush_ok 标记缺失或不匹配),\
将尝试按当前文件恢复,失败后再降级为仅加载 .tdb 元数据"
);
match load_v2(
bytes,
dim,
next_id,
node_count,
payload_offset,
edge_offset,
edge_limit_offset,
&vec_file_path,
&mmap,
) {
Ok(mut mt) => {
load_bq_block(&mut mt, bytes, bq_offset, mmap.len());
load_quiver_index(&mut mt, path);
Ok(mt)
}
Err(e) => {
tracing::warn!("当前 .tdb/.vec 组合不可用,进入安全降级恢复: {}", e);
let mut mt = load_v2_metadata_only(
bytes,
dim,
next_id,
node_count,
payload_offset,
edge_offset,
edge_limit_offset,
)?;
load_bq_block(&mut mt, bytes, bq_offset, mmap.len());
load_quiver_index(&mut mt, path);
Ok(mt)
}
}
}
} else {
let mut mt = load_v1_rom(
bytes,
dim,
next_id,
node_count,
payload_offset,
vector_offset,
edge_offset,
edge_limit_offset,
&mmap,
)?;
load_bq_block(&mut mt, bytes, bq_offset, mmap.len());
load_quiver_index(&mut mt, path);
Ok(mt)
}
}
/// Loads an mmap-mode database.
///
/// Vectors come from the `.vec` file at `vec_file_path`, opened through `VecPool::open`.
/// Payloads (`payload_offset..edge_offset`) and edges
/// (`edge_offset..edge_limit_offset`) are parsed from the `.tdb` bytes.
///
/// `_tdb_mmap` is accepted but not used.
///
/// # Errors
/// Propagates errors from opening the vector pool and from the payload/edge loaders.
fn load_v2<T: VectorType>(
    bytes: &[u8],
    dim: usize,
    next_id: u64,
    node_count: usize,
    payload_offset: usize,
    edge_offset: usize,
    edge_limit_offset: usize,
    vec_file_path: &str,
    _tdb_mmap: &Mmap,
) -> Result<MemTable<T>> {
    let pool = VecPool::<T>::open(Path::new(vec_file_path), dim, node_count)?;
    let mut table = MemTable::new_with_vec_pool(dim, next_id, pool);
    load_payloads(&mut table, bytes, node_count, payload_offset, edge_offset)?;
    load_edges(&mut table, bytes, edge_offset, edge_limit_offset).map(|()| table)
}
/// Degraded loader for an mmap-mode `.tdb` whose `.vec` file could not be used.
///
/// Payloads and edges are restored from the `.tdb` alone. Every slot with a non-zero id
/// and a payload gets an all-zero vector of length `dim` via `raw_insert`. The node set
/// and graph survive, but the real vectors are lost.
///
/// # Errors
/// Propagates errors from the payload/edge loaders and from `raw_insert`.
fn load_v2_metadata_only<T: VectorType>(
    bytes: &[u8],
    dim: usize,
    next_id: u64,
    node_count: usize,
    payload_offset: usize,
    edge_offset: usize,
    edge_limit_offset: usize,
) -> Result<MemTable<T>> {
    let mut table = MemTable::new_with_next_id(dim, next_id);
    load_payloads(&mut table, bytes, node_count, payload_offset, edge_offset)?;

    let zero_vector = vec![T::zero(); dim];
    // Snapshot the slot ids: `raw_insert` needs `&mut table` while we iterate.
    let slot_ids = table.internal_indices().to_vec();
    for id in slot_ids.into_iter().filter(|&id| id != 0) {
        let Some(payload) = table.get_payload(id).cloned() else {
            continue;
        };
        table.raw_insert(id, &zero_vector, payload)?;
    }

    load_edges(&mut table, bytes, edge_offset, edge_limit_offset)?;
    Ok(table)
}
/// Loads a ROM-mode database, whose vector block is embedded in the `.tdb` itself.
///
/// Layout: payloads at `payload_offset..vector_offset`, then `node_count * dim` raw `T`
/// values, then edges at `edge_offset..edge_limit_offset`.
///
/// The size of the vector block comes from untrusted header fields, so it is computed
/// with checked arithmetic. A wrapped product could otherwise pass the bounds check
/// and make the vector view extend past the mapped file.
///
/// # Errors
/// - `CorruptedFile("Vector block exceeds file size")` when the vector block is out of
///   bounds or its size overflows `usize`.
/// - Errors from the payload/edge loaders.
fn load_v1_rom<T: VectorType>(
    bytes: &[u8],
    dim: usize,
    next_id: u64,
    node_count: usize,
    payload_offset: usize,
    vector_offset: usize,
    edge_offset: usize,
    edge_limit_offset: usize,
    tdb_mmap: &Mmap,
) -> Result<MemTable<T>> {
    let mut memtable = MemTable::new_with_next_id(dim, next_id);
    let elem_size = std::mem::size_of::<T>();
    let limit = tdb_mmap.len().min(bytes.len());
    let vec_end = node_count
        .checked_mul(dim)
        .and_then(|elems| elems.checked_mul(elem_size))
        .and_then(|size| vector_offset.checked_add(size));
    let vec_end = match vec_end {
        Some(end) if end <= limit => end,
        _ => {
            return Err(TriviumError::CorruptedFile(
                "Vector block exceeds file size".into(),
            ));
        }
    };

    load_payloads(
        &mut memtable,
        bytes,
        node_count,
        payload_offset,
        vector_offset,
    )?;

    let vec_block = &bytes[vector_offset..vec_end];
    if (vec_block.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
        // Aligned: reinterpret in place. `cast_slice` re-checks alignment and length.
        let vectors: &[T] = bytemuck::cast_slice(vec_block);
        memtable.vec_pool_mut().push(vectors);
    } else {
        // Misaligned: decode element by element into an owned buffer.
        let vectors: Vec<T> = vec_block
            .chunks_exact(elem_size)
            .map(bytemuck::pod_read_unaligned)
            .collect();
        memtable.vec_pool_mut().push(&vectors);
    }

    load_edges(&mut memtable, bytes, edge_offset, edge_limit_offset)?;
    Ok(memtable)
}
/// Parses `node_count` payload records from `bytes[offset..end_offset]` into `memtable`.
///
/// Each record is a node id (u64 LE), a JSON length (u32 LE) and the JSON bytes.
/// The record `(0, 0)` marks an empty slot and is registered as a tombstone. Every
/// other record is parsed as JSON and registered under its id.
///
/// # Errors
/// `CorruptedFile` when:
/// - a record header does not fit before `end_offset` ("Payload block overflow");
/// - the JSON body does not fit before `end_offset` ("JSON data overflow");
/// - the JSON body fails to parse.
///
/// Errors from the memtable registration calls are propagated.
fn load_payloads<T: VectorType>(
    memtable: &mut MemTable<T>,
    bytes: &[u8],
    node_count: usize,
    offset: usize,
    end_offset: usize,
) -> Result<()> {
    let mut pos = offset;
    for _slot in 0..node_count {
        // Fixed 12-byte record header: node id (u64) + JSON length (u32).
        if pos.saturating_add(12) > end_offset {
            return Err(TriviumError::CorruptedFile("Payload block overflow".into()));
        }
        let node_id = read_u64_le(bytes, pos, "payload node_id")?;
        let json_len = read_u32_le(bytes, pos + 8, "payload json_len")? as usize;
        pos += 12;

        if node_id == 0 && json_len == 0 {
            memtable.register_tombstone()?;
            continue;
        }

        if pos.saturating_add(json_len) > end_offset {
            return Err(TriviumError::CorruptedFile("JSON data overflow".into()));
        }
        let json_end = pos + json_len;
        let payload: serde_json::Value = serde_json::from_slice(&bytes[pos..json_end])
            .map_err(|e| TriviumError::CorruptedFile(format!("JSON parse error: {}", e)))?;
        pos = json_end;
        memtable.register_node(node_id, payload)?;
    }
    Ok(())
}
fn load_edges<T: VectorType>(
memtable: &mut MemTable<T>,
bytes: &[u8],
edge_offset: usize,
file_len: usize,
) -> Result<()> {
let mut cursor = edge_offset;
while cursor.saturating_add(18) <= file_len {
let src_id = read_u64_le(bytes, cursor, "edge src_id")?;
cursor += 8;
let dst_id = read_u64_le(bytes, cursor, "edge dst_id")?;
cursor += 8;
let label_len = read_u16_le(bytes, cursor, "edge label_len")? as usize;
cursor += 2;
if cursor.saturating_add(label_len).saturating_add(4) > file_len {
break;
}
let label = String::from_utf8(bytes[cursor..cursor + label_len].to_vec())
.map_err(|e| TriviumError::CorruptedFile(format!("Label decode error: {}", e)))?;
cursor += label_len;
let weight = read_f32_le(bytes, cursor, "edge weight")?;
cursor += 4;
memtable.link(src_id, dst_id, label, weight)?;
}
Ok(())
}
fn load_bq_block<T: VectorType>(
memtable: &mut MemTable<T>,
bytes: &[u8],
bq_offset: usize,
file_len: usize,
) {
if bq_offset == 0 || bq_offset + 8 > file_len {
return; }
let bq_count =
u64::from_le_bytes(bytes[bq_offset..bq_offset + 8].try_into().unwrap_or([0; 8])) as usize;
if bq_count == 0 {
return;
}
let sig_size = std::mem::size_of::<BqSignature>();
let data_start = bq_offset + 8;
let data_end = data_start + bq_count * sig_size;
if data_end > file_len {
tracing::warn!(
"BQ Block 数据不完整(需要 {} 字节,文件仅剩 {} 字节),跳过恢复",
bq_count * sig_size,
file_len.saturating_sub(data_start)
);
return;
}
let bq_bytes = &bytes[data_start..data_end];
let is_aligned =
(bq_bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<BqSignature>());
let sigs: Vec<BqSignature> = if is_aligned {
let slice: &[BqSignature] = bytemuck::cast_slice(bq_bytes);
slice.to_vec()
} else {
let mut v = Vec::with_capacity(bq_count);
for i in 0..bq_count {
let off = i * sig_size;
let sig: BqSignature = bytemuck::pod_read_unaligned(&bq_bytes[off..off + sig_size]);
v.push(sig);
}
v
};
memtable.set_bq_signatures(sigs);
tracing::info!("从 .tdb 恢复了 {} 个 BQ 签名(零拷贝加载)", bq_count);
}
fn load_quiver_index<T: VectorType>(memtable: &mut MemTable<T>, db_path: &str) {
use crate::index::quiver::QuIVer;
let quiver_path = quiver_path_from_db(db_path);
let qp = std::path::Path::new(&quiver_path);
if !qp.exists() {
return;
}
match QuIVer::load_from_file(qp) {
Ok(quiver) => {
memtable.set_quiver_index(quiver);
}
Err(e) => {
tracing::warn!(
"QuIVer 索引加载失败(将在首次查询时自动重建): {}",
e
);
std::fs::remove_file(qp).ok();
}
}
}