use crate::error::{Result, TriviumError};
use crate::node::{Edge, NodeId};
use crate::storage::memtable::MemTable;
use crate::storage::vec_pool::VecPool;
use crate::database::StorageMode;
use crate::VectorType;
use memmap2::Mmap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
/// File signature stored in bytes 0..4 of every `.tdb` file; `load` rejects files without it.
const MAGIC: &[u8; 4] = b"TVDB";
/// On-disk format version, stored little-endian in header bytes 4..6.
const VERSION: u16 = 2;
/// Fixed header length in bytes: magic (4) + version (2) + dim (4) + next id (8) +
/// slot count (8) + payload, vector and edge block offsets (3 x 8). The payload block
/// starts immediately after it.
const HEADER_SIZE: u64 = 50;
/// Returns the path of the vector sidecar file that belongs to the database at
/// `db_path`: the same path with `.vec` appended (e.g. `db.tdb` -> `db.tdb.vec`).
fn vec_path_from_db(db_path: &str) -> String {
    [db_path, ".vec"].concat()
}
pub fn save<T: VectorType>(memtable: &mut MemTable<T>, path: &str, mode: StorageMode) -> Result<()> {
match mode {
StorageMode::Mmap => save_mmap(memtable, path),
StorageMode::Rom => save_rom(memtable, path),
}
}
/// Mmap-mode save: flushes the vector pool to `<path>.vec`, then writes the `.tdb` file
/// (header, payloads, edges) with no inline vector block.
///
/// The count returned by the pool flush is forwarded to `save_tdb`, which only logs it.
fn save_mmap<T: VectorType>(memtable: &mut MemTable<T>, path: &str) -> Result<()> {
    let vec_file = vec_path_from_db(path);
    let flushed = memtable.vec_pool_mut().flush(Path::new(&vec_file))?;
    save_tdb(memtable, path, flushed, true)
}
/// ROM-mode save: writes a self-contained `.tdb` file with the vectors embedded inline,
/// then calls `detach_mmap` on the vector pool and deletes any `<path>.vec` sidecar.
///
/// Deleting the sidecar is best-effort; a failure there is ignored because a ROM-layout
/// file (non-zero vector offset) is loaded without consulting the `.vec` file.
fn save_rom<T: VectorType>(memtable: &mut MemTable<T>, path: &str) -> Result<()> {
    memtable.ensure_vectors_cache();
    let slot_total = memtable.internal_indices().len();
    save_tdb(memtable, path, slot_total, false)?;
    memtable.vec_pool_mut().detach_mmap();

    let stale_vec_file = vec_path_from_db(path);
    let stale = Path::new(&stale_vec_file);
    if stale.exists() {
        let _ = std::fs::remove_file(stale);
    }
    Ok(())
}
fn save_tdb<T: VectorType>(memtable: &mut MemTable<T>, path: &str, vec_count: usize, is_mmap_mode: bool) -> Result<()> {
if !is_mmap_mode {
memtable.ensure_vectors_cache();
}
let tmp_path = format!("{}.tmp", path);
let file = File::create(&tmp_path)?;
let mut w = BufWriter::new(file);
let dim = memtable.dim();
let internal_indices = memtable.internal_indices();
let node_count = internal_indices.len() as u64;
let mut all_edges: Vec<(NodeId, &Edge)> = Vec::new();
let mut payload_size: u64 = 0;
for &nid in internal_indices {
if nid != 0 { if let Some(p) = memtable.get_payload(nid) {
let json_bytes = serde_json::to_vec(p).unwrap_or_default();
payload_size += 8 + 4 + json_bytes.len() as u64;
} else {
payload_size += 12;
}
if let Some(edges) = memtable.get_edges(nid) {
for edge in edges {
all_edges.push((nid, edge));
}
}
} else { payload_size += 12;
}
}
let payload_offset = HEADER_SIZE;
let vector_offset = if is_mmap_mode { 0 } else { payload_offset + payload_size };
let vector_size = if is_mmap_mode { 0 } else { node_count * (dim as u64) * (std::mem::size_of::<T>() as u64) };
let edge_offset = payload_offset + payload_size + vector_size;
w.write_all(MAGIC)?;
w.write_all(&VERSION.to_le_bytes())?;
w.write_all(&(dim as u32).to_le_bytes())?;
w.write_all(&memtable.next_id_value().to_le_bytes())?;
w.write_all(&node_count.to_le_bytes())?;
w.write_all(&payload_offset.to_le_bytes())?;
w.write_all(&vector_offset.to_le_bytes())?;
w.write_all(&edge_offset.to_le_bytes())?;
for &nid in internal_indices {
if nid != 0 {
if let Some(p) = memtable.get_payload(nid) {
let json_bytes = serde_json::to_vec(p).unwrap_or_default();
w.write_all(&nid.to_le_bytes())?;
w.write_all(&(json_bytes.len() as u32).to_le_bytes())?;
w.write_all(&json_bytes)?;
continue;
}
}
w.write_all(&0u64.to_le_bytes())?;
w.write_all(&0u32.to_le_bytes())?;
}
if !is_mmap_mode {
let flat = memtable.flat_vectors();
w.write_all(bytemuck::cast_slice(flat))?;
}
for (src_id, edge) in &all_edges {
w.write_all(&src_id.to_le_bytes())?;
w.write_all(&edge.target_id.to_le_bytes())?;
let label_bytes = edge.label.as_bytes();
w.write_all(&(label_bytes.len() as u16).to_le_bytes())?;
w.write_all(label_bytes)?;
w.write_all(&edge.weight.to_le_bytes())?;
}
w.flush()?;
let file = w.into_inner().map_err(|e| TriviumError::Io(e.into_error()))?;
file.sync_all()?;
drop(file);
std::fs::rename(&tmp_path, path)?;
tracing::info!(
"持久化完成: {} 个槽位(含删除), {} 个向量, Mode: {}",
node_count, vec_count, if is_mmap_mode { "Mmap" } else { "Rom" }
);
Ok(())
}
pub fn load<T: VectorType>(path: &str, _mode: StorageMode) -> Result<MemTable<T>> {
let file = File::open(path).map_err(TriviumError::Io)?;
let mmap = unsafe { Mmap::map(&file) }
.map_err(|e| TriviumError::Io(e))?;
if mmap.len() < HEADER_SIZE as usize {
return Err(TriviumError::Generic("File too small for header".into()));
}
let bytes = &mmap[..];
if &bytes[0..4] != MAGIC {
return Err(TriviumError::Generic(format!(
"Invalid file magic: expected TVDB, got {:?}", &bytes[0..4]
)));
}
let dim = u32::from_le_bytes(bytes[6..10].try_into().unwrap()) as usize;
let next_id = u64::from_le_bytes(bytes[10..18].try_into().unwrap());
let node_count = u64::from_le_bytes(bytes[18..26].try_into().unwrap()) as usize;
let payload_offset = u64::from_le_bytes(bytes[26..34].try_into().unwrap()) as usize;
let vector_offset = u64::from_le_bytes(bytes[34..42].try_into().unwrap()) as usize;
let edge_offset = u64::from_le_bytes(bytes[42..50].try_into().unwrap()) as usize;
let vec_file_path = vec_path_from_db(path);
if vector_offset == 0 && Path::new(&vec_file_path).exists() {
load_v2(bytes, dim, next_id, node_count, payload_offset, edge_offset, &vec_file_path, &mmap)
} else {
load_v1_rom(bytes, dim, next_id, node_count, payload_offset, vector_offset, edge_offset, &mmap)
}
}
/// Builds a [`MemTable`] from an Mmap-layout `.tdb` file whose vectors live in the
/// separate file at `vec_file_path`.
///
/// The vector pool is opened from that file with `dim` and `node_count`. Payload records
/// are read from `payload_offset` up to `edge_offset`, and edges from `edge_offset` to the
/// end of the mapped `.tdb` file.
fn load_v2<T: VectorType>(
    bytes: &[u8], dim: usize, next_id: u64, node_count: usize,
    payload_offset: usize, edge_offset: usize, vec_file_path: &str, tdb_mmap: &Mmap,
) -> Result<MemTable<T>> {
    let pool = VecPool::<T>::open(Path::new(vec_file_path), dim, node_count)?;
    let mut table = MemTable::new_with_vec_pool(dim, next_id, pool);
    let file_len = tdb_mmap.len();
    load_payloads(&mut table, bytes, node_count, payload_offset, edge_offset)
        .and_then(|()| load_edges(&mut table, bytes, edge_offset, file_len))?;
    Ok(table)
}
/// Builds a [`MemTable`] from a self-contained (ROM-layout) `.tdb` file, whose vectors
/// are stored inline between the payload block and the edge block.
///
/// The vector block starts at `vector_offset` and holds `node_count * dim` values of `T`.
/// When the mapped bytes are suitably aligned for `T`, they are reinterpreted in place.
/// Otherwise each element is read unaligned into a temporary `Vec`. Either way the values
/// are handed to `vec_pool_mut().push`.
///
/// # Errors
/// Fails if the vector block does not fit in the file (its size is computed with overflow
/// checks, because every input comes from the file header), or if payload or edge decoding
/// fails.
fn load_v1_rom<T: VectorType>(
    bytes: &[u8], dim: usize, next_id: u64, node_count: usize,
    payload_offset: usize, vector_offset: usize, edge_offset: usize, tdb_mmap: &Mmap,
) -> Result<MemTable<T>> {
    let mut memtable = MemTable::new_with_next_id(dim, next_id);
    let elem_size = std::mem::size_of::<T>();
    // A wrapped product/sum could otherwise pass the bounds check and lead to an
    // out-of-bounds read of the vector block.
    let vec_end = node_count
        .checked_mul(dim)
        .and_then(|elems| elems.checked_mul(elem_size))
        .and_then(|size| vector_offset.checked_add(size))
        .filter(|&end| end <= bytes.len())
        .ok_or_else(|| TriviumError::Generic("Vector block exceeds file size".into()))?;
    load_payloads(&mut memtable, bytes, node_count, payload_offset, vector_offset)?;
    let vec_block = &bytes[vector_offset..vec_end];
    // `try_cast_slice` performs the alignment and length checks that the previous
    // hand-written `from_raw_parts` relied on, without `unsafe`.
    match bytemuck::try_cast_slice::<u8, T>(vec_block) {
        Ok(aligned) => {
            memtable.vec_pool_mut().push(aligned);
        }
        Err(_) => {
            let unaligned: Vec<T> = vec_block
                .chunks_exact(elem_size)
                .map(bytemuck::pod_read_unaligned::<T>)
                .collect();
            memtable.vec_pool_mut().push(&unaligned);
        }
    }
    load_edges(&mut memtable, bytes, edge_offset, tdb_mmap.len())?;
    Ok(memtable)
}
/// Decodes `node_count` payload records starting at `offset` and registers each slot in
/// `memtable`.
///
/// Each record is `nid: u64 | json_len: u32 | json[json_len]` (little-endian). A record with
/// `nid == 0` and `json_len == 0` is registered as a tombstone; any other record is parsed
/// as JSON and registered via `register_node`. No record may extend past `end_offset`.
///
/// # Errors
/// Returns "Payload block overflow" or "JSON data overflow" when a record would run past
/// `end_offset` (or past the end of `bytes`). Also fails on invalid JSON, or when
/// registering a slot fails.
fn load_payloads<T: VectorType>(
    memtable: &mut MemTable<T>, bytes: &[u8], node_count: usize, offset: usize, end_offset: usize
) -> Result<()> {
    // `nid: u64` + `json_len: u32`.
    const RECORD_PREFIX: usize = 8 + 4;
    // `end_offset` comes from the file header; clamp it so a corrupt value produces an
    // error below instead of an out-of-bounds panic.
    let end = end_offset.min(bytes.len());
    let mut cursor = offset;
    for _ in 0..node_count {
        // `saturating_sub` also covers a cursor that starts beyond `end`.
        if end.saturating_sub(cursor) < RECORD_PREFIX {
            return Err(TriviumError::Generic("Payload block overflow".into()));
        }
        let nid = u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().expect("slice is exactly 8 bytes"));
        let json_len = u32::from_le_bytes(
            bytes[cursor + 8..cursor + RECORD_PREFIX].try_into().expect("slice is exactly 4 bytes"),
        ) as usize;
        cursor += RECORD_PREFIX;
        if nid == 0 && json_len == 0 {
            memtable.register_tombstone()?;
            continue;
        }
        // `cursor <= end` holds here because of the prefix check above.
        if end - cursor < json_len {
            return Err(TriviumError::Generic("JSON data overflow".into()));
        }
        let payload: serde_json::Value = serde_json::from_slice(&bytes[cursor..cursor + json_len])
            .map_err(|e| TriviumError::Generic(format!("JSON parse error: {}", e)))?;
        cursor += json_len;
        memtable.register_node(nid, payload)?;
    }
    Ok(())
}
fn load_edges<T: VectorType>(memtable: &mut MemTable<T>, bytes: &[u8], edge_offset: usize, file_len: usize) -> Result<()> {
let mut cursor = edge_offset;
while cursor + 18 <= file_len {
let src_id = u64::from_le_bytes(bytes[cursor..cursor+8].try_into().unwrap());
cursor += 8;
let dst_id = u64::from_le_bytes(bytes[cursor..cursor+8].try_into().unwrap());
cursor += 8;
let label_len = u16::from_le_bytes(bytes[cursor..cursor+2].try_into().unwrap()) as usize;
cursor += 2;
if cursor + label_len + 4 > file_len { break; }
let label = String::from_utf8(bytes[cursor..cursor+label_len].to_vec())
.map_err(|e| TriviumError::Generic(format!("Label decode error: {}", e)))?;
cursor += label_len;
let weight = f32::from_le_bytes(bytes[cursor..cursor+4].try_into().unwrap());
cursor += 4;
memtable.link(src_id, dst_id, label, weight)?;
}
Ok(())
}