pub mod config;
pub(crate) mod pipeline;
pub mod transaction;
pub use config::{Config, SearchConfig, StorageMode};
pub use transaction::Transaction;
use crate::VectorType;
use crate::error::{Result, TriviumError};
use crate::hook::{HookContext, NoopHook, SearchHook};
use crate::node::{NodeId, SearchHit};
use crate::storage::compaction::CompactionThread;
use crate::storage::file_format;
use crate::storage::memtable::MemTable;
use crate::storage::wal::{SyncMode, Wal, WalEntry};
use fs2::FileExt;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::Duration;
/// Locks `mutex` and returns its guard, tolerating poisoning.
///
/// When a previous holder panicked and poisoned the mutex, a warning is
/// logged and the guard is extracted from the poison error instead of
/// propagating the failure.
pub(crate) fn lock_or_recover<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poison_err) => {
            tracing::warn!("Mutex was poisoned, recovering...");
            poison_err.into_inner()
        }
    }
}
/// Handle to one on-disk database: an in-memory table guarded by a mutex,
/// a write-ahead log, and an optional background compaction worker.
pub struct Database<T: VectorType> {
/// Path of the main database file, as passed to `open_with_config`.
pub(crate) db_path: String,
/// In-memory table; the `Arc` is cloned into the compaction thread.
pub(crate) memtable: Arc<Mutex<MemTable<T>>>,
/// Write-ahead log; the `Arc` is cloned into the compaction thread.
pub(crate) wal: Arc<Mutex<Wal>>,
/// Background compaction worker, `None` until `enable_auto_compaction` is called.
pub(crate) compaction: Option<CompactionThread>,
/// The `<path>.lock` file on which `open_with_config` took an exclusive lock.
/// Stored here so the file handle lives as long as this value.
_lock_file: std::fs::File,
/// Auto-flush threshold in bytes; `0` disables the check.
memory_limit: usize,
/// Storage mode passed to `file_format::save` and to the compaction thread.
pub(crate) storage_mode: StorageMode,
/// Search hook handed to the search pipeline; `NoopHook` by default.
hook: Arc<dyn SearchHook>,
}
impl<T: VectorType + serde::Serialize + serde::de::DeserializeOwned> Database<T> {
/// Opens the database at `path` with vector dimension `dim`, leaving every
/// other [`Config`] field at its default value.
///
/// # Errors
/// Returns whatever `open_with_config` returns.
pub fn open(path: &str, dim: usize) -> Result<Self> {
    Self::open_with_config(
        path,
        Config {
            dim,
            ..Default::default()
        },
    )
}
/// Opens the database at `path` with vector dimension `dim` and the given
/// WAL `sync_mode`; all remaining [`Config`] fields take their defaults.
///
/// # Errors
/// Returns whatever `open_with_config` returns.
pub fn open_with_sync(path: &str, dim: usize, sync_mode: SyncMode) -> Result<Self> {
    Self::open_with_config(
        path,
        Config {
            dim,
            sync_mode,
            ..Default::default()
        },
    )
}
/// Opens the database at `path` using `config`, creating it when absent.
///
/// Steps, in order:
/// 1. Create the parent directory of `path` if it has one.
/// 2. Take an exclusive, non-blocking lock on `<path>.lock`.
/// 3. Load the table from `path` if the file exists; otherwise start with an
///    empty table of dimension `config.dim` (`config.dim` is only used in
///    this second case).
/// 4. If the WAL reports that recovery is needed, truncate `<path>.wal` to the
///    last valid offset and replay the recovered entries into the table.
/// 5. Rebuild the text index from payloads and open the WAL with
///    `config.sync_mode`.
///
/// The returned handle has no compaction thread, a memory limit of `0`, and
/// a `NoopHook`.
///
/// # Errors
/// Propagates I/O, load and WAL errors; returns `TriviumError::Generic` when
/// the exclusive lock cannot be taken.
pub fn open_with_config(path: &str, config: Config) -> Result<Self> {
let dim = config.dim;
// Ensure the containing directory exists (skipped when `path` has no directory part).
if let Some(parent_dir) = std::path::Path::new(path).parent()
&& !parent_dir.as_os_str().is_empty() {
std::fs::create_dir_all(parent_dir)?;
}
// Guard against concurrent opens via an exclusive lock on a sidecar lock file.
let lock_path = format!("{}.lock", path);
let lock_file = std::fs::OpenOptions::new()
.create(true)
.write(true)
.open(&lock_path)?;
// Any failure of the non-blocking lock attempt is reported as "already opened".
lock_file.try_lock_exclusive().map_err(|_| {
TriviumError::Generic(format!(
"Database '{}' is already opened by another process. \
If this is unexpected, delete '{}'",
path, lock_path
))
})?;
// Start from the persisted state when the main file exists, else from an empty table.
let mut memtable = if std::path::Path::new(path).exists() {
file_format::load(path, config.storage_mode)?
} else {
MemTable::new(dim)
};
if Wal::needs_recovery(path) {
let (entries, valid_offset) = Wal::read_entries::<T>(path)?;
// Cut the WAL file back to `valid_offset` and sync before replaying.
let wal_path = format!("{}.wal", path);
let wal_file = std::fs::OpenOptions::new().write(true).open(&wal_path)?;
wal_file.set_len(valid_offset)?;
wal_file.sync_all()?;
if !entries.is_empty() {
tracing::info!("Recovering {} entries from WAL, safely truncated at offset {}...", entries.len(), valid_offset);
// Re-apply each recovered operation on top of the loaded table.
for entry in entries {
transaction::replay_entry(&mut memtable, entry);
}
} else {
tracing::info!("Cleared purely corrupt/uncommitted WAL data, truncated back to {}.", valid_offset);
}
}
// Runs after load and replay, so the index is built from the final payloads.
memtable.rebuild_text_index_from_payloads();
let wal = Wal::open_with_sync(path, config.sync_mode)?;
Ok(Self {
db_path: path.to_string(),
memtable: Arc::new(Mutex::new(memtable)),
wal: Arc::new(Mutex::new(wal)),
compaction: None,
_lock_file: lock_file,
memory_limit: 0,
storage_mode: config.storage_mode,
hook: Arc::new(NoopHook),
})
}
/// Changes the sync mode of the write-ahead log to `mode`.
pub fn set_sync_mode(&mut self, mode: SyncMode) {
    lock_or_recover(&self.wal).set_sync_mode(mode);
}
/// Installs `hook` as the search hook, replacing the current one.
pub fn set_hook(&mut self, hook: impl SearchHook + 'static) {
    let installed: Arc<dyn SearchHook> = Arc::new(hook);
    self.hook = installed;
}
pub fn clear_hook(&mut self) {
self.hook = Arc::new(NoopHook);
}
/// Returns a reference to the currently installed search hook.
pub fn hook(&self) -> &dyn SearchHook {
    &*self.hook
}
/// Sets the auto-flush threshold in bytes.
///
/// `check_memory_pressure` only acts when this value is greater than zero,
/// so passing `0` disables automatic flushing.
pub fn set_memory_limit(&mut self, bytes: usize) {
self.memory_limit = bytes;
}
/// Returns the memtable's own estimate of its memory usage, in bytes.
pub fn estimated_memory(&self) -> usize {
    let table = lock_or_recover(&self.memtable);
    table.estimated_memory_bytes()
}
/// Flushes to disk when the estimated memtable size exceeds the configured
/// memory limit.
///
/// Does nothing when the limit is `0`. A failed flush is logged and
/// otherwise ignored, so callers never see an error from this check.
fn check_memory_pressure(&mut self) {
    const MIB: usize = 1024 * 1024;

    let limit = self.memory_limit;
    if limit == 0 {
        return;
    }

    let usage = lock_or_recover(&self.memtable).estimated_memory_bytes();
    if usage <= limit {
        return;
    }

    tracing::info!(
        "Memory pressure: {}MB > limit {}MB. Auto-flushing...",
        usage / MIB,
        limit / MIB
    );
    if let Err(e) = self.flush() {
        tracing::error!("Auto-flush failed: {}", e);
    }
}
/// Starts a background compaction thread that runs with the given
/// `interval`, replacing any thread that was already installed.
pub fn enable_auto_compaction(&mut self, interval: Duration) {
    // Drop the previous worker (if any) before spawning its replacement.
    drop(self.compaction.take());

    let memtable = Arc::clone(&self.memtable);
    let wal = Arc::clone(&self.wal);
    let worker = CompactionThread::spawn(
        interval,
        memtable,
        wal,
        self.db_path.clone(),
        self.storage_mode,
    );
    self.compaction = Some(worker);
}
/// Drops the background compaction thread handle, if one is installed.
pub fn disable_auto_compaction(&mut self) {
    drop(self.compaction.take());
}
/// Runs a manual compaction: ensures the vector cache, saves the memtable to
/// the database file, then clears the WAL.
///
/// The memtable lock is taken twice: once for the cache step and once for
/// the save, during which the WAL is cleared while the memtable lock is
/// still held.
///
/// # Errors
/// Propagates errors from saving the file or clearing the WAL.
pub fn compact(&mut self) -> Result<()> {
    {
        let mut table = lock_or_recover(&self.memtable);
        tracing::info!("Manual compaction started for {}", self.db_path);
        table.ensure_vectors_cache();
    }

    {
        let mut table = lock_or_recover(&self.memtable);
        file_format::save(&mut table, &self.db_path, self.storage_mode)?;
        // The WAL guard is a temporary released at the end of this
        // statement, while the memtable guard is still alive.
        lock_or_recover(&self.wal).clear()?;
    }

    tracing::info!("Manual compaction completed for {}", self.db_path);
    Ok(())
}
/// Inserts a new node with `vector` and `payload` and returns its id.
///
/// The payload is serialized once; that string is used both for the size
/// check and for the WAL entry. The memtable is updated first and the WAL
/// entry is appended afterwards; if the append fails the error is returned
/// but the in-memory insert is not rolled back by this method.
///
/// # Errors
/// * `TriviumError::Generic` when the serialized payload exceeds 8 MiB.
/// * Any error from the memtable insert or the WAL append.
pub fn insert(&mut self, vector: &[T], payload: serde_json::Value) -> Result<NodeId> {
    /// Upper bound on the serialized payload size, in bytes.
    const MAX_PAYLOAD_BYTES: usize = 8 * 1024 * 1024;

    let payload_str = payload.to_string();
    if payload_str.len() > MAX_PAYLOAD_BYTES {
        return Err(crate::error::TriviumError::Generic("Payload size exceeds maximum allowed limit (8MB)".into()));
    }

    let id = {
        let mut mt = lock_or_recover(&self.memtable);
        // `payload` is not used again, so move it instead of deep-cloning
        // a JSON tree that may be several megabytes large.
        mt.insert(vector, payload)?
    };

    {
        let mut w = lock_or_recover(&self.wal);
        w.append(&WalEntry::Insert {
            id,
            vector: vector.to_vec(),
            payload: payload_str,
        })?;
    }

    self.check_memory_pressure();
    Ok(id)
}
/// Inserts a node under the caller-chosen `id` with `vector` and `payload`.
///
/// The payload is serialized once; that string is used both for the size
/// check and for the WAL entry. The memtable is updated first and the WAL
/// entry is appended afterwards; if the append fails the error is returned
/// but the in-memory insert is not rolled back by this method.
///
/// # Errors
/// * `TriviumError::Generic` when the serialized payload exceeds 8 MiB.
/// * Any error from the memtable insert or the WAL append.
pub fn insert_with_id(
    &mut self,
    id: NodeId,
    vector: &[T],
    payload: serde_json::Value,
) -> Result<()> {
    /// Upper bound on the serialized payload size, in bytes.
    const MAX_PAYLOAD_BYTES: usize = 8 * 1024 * 1024;

    let payload_str = payload.to_string();
    if payload_str.len() > MAX_PAYLOAD_BYTES {
        return Err(crate::error::TriviumError::Generic("Payload size exceeds maximum allowed limit (8MB)".into()));
    }

    {
        let mut mt = lock_or_recover(&self.memtable);
        // `payload` is not used again, so move it instead of deep-cloning
        // a JSON tree that may be several megabytes large.
        mt.insert_with_id(id, vector, payload)?;
    }

    {
        let mut w = lock_or_recover(&self.wal);
        w.append(&WalEntry::Insert {
            id,
            vector: vector.to_vec(),
            payload: payload_str,
        })?;
    }

    self.check_memory_pressure();
    Ok(())
}
/// Adds an edge from `src` to `dst` with the given `label` and `weight`,
/// first in the memtable and then as a WAL entry.
///
/// # Errors
/// Propagates errors from the memtable update or the WAL append.
pub fn link(&mut self, src: NodeId, dst: NodeId, label: &str, weight: f32) -> Result<()> {
    lock_or_recover(&self.memtable).link(src, dst, label.to_string(), weight)?;

    let entry = WalEntry::Link::<T> {
        src,
        dst,
        label: label.to_string(),
        weight,
    };
    lock_or_recover(&self.wal).append(&entry)?;
    Ok(())
}
/// Deletes node `id` from the memtable and records the deletion in the WAL.
///
/// # Errors
/// Propagates errors from the memtable update or the WAL append.
pub fn delete(&mut self, id: NodeId) -> Result<()> {
    lock_or_recover(&self.memtable).delete(id)?;

    let entry = WalEntry::Delete::<T> { id };
    lock_or_recover(&self.wal).append(&entry)?;
    Ok(())
}
/// Removes the edge from `src` to `dst` in the memtable and records the
/// removal in the WAL.
///
/// # Errors
/// Propagates errors from the memtable update or the WAL append.
pub fn unlink(&mut self, src: NodeId, dst: NodeId) -> Result<()> {
    lock_or_recover(&self.memtable).unlink(src, dst)?;

    let entry = WalEntry::Unlink::<T> { src, dst };
    lock_or_recover(&self.wal).append(&entry)?;
    Ok(())
}
/// Replaces the payload of node `id` with `payload`.
///
/// The payload is serialized once; that string is used both for the size
/// check and for the WAL entry. The memtable is updated first and the WAL
/// entry is appended afterwards; if the append fails the error is returned
/// but the in-memory update is not rolled back by this method.
///
/// # Errors
/// * `TriviumError::Generic` when the serialized payload exceeds 8 MiB.
/// * Any error from the memtable update or the WAL append.
pub fn update_payload(&mut self, id: NodeId, payload: serde_json::Value) -> Result<()> {
    /// Upper bound on the serialized payload size, in bytes.
    const MAX_PAYLOAD_BYTES: usize = 8 * 1024 * 1024;

    let payload_str = payload.to_string();
    if payload_str.len() > MAX_PAYLOAD_BYTES {
        return Err(crate::error::TriviumError::Generic("Payload size exceeds maximum allowed limit (8MB)".into()));
    }

    {
        let mut mt = lock_or_recover(&self.memtable);
        // `payload` is not used again, so move it instead of deep-cloning
        // a JSON tree that may be several megabytes large.
        mt.update_payload(id, payload)?;
    }

    {
        let mut w = lock_or_recover(&self.wal);
        w.append(&WalEntry::UpdatePayload::<T> {
            id,
            payload: payload_str,
        })?;
    }
    Ok(())
}
/// Replaces the vector of node `id` in the memtable and records the change
/// in the WAL.
///
/// # Errors
/// Propagates errors from the memtable update or the WAL append.
pub fn update_vector(&mut self, id: NodeId, vector: &[T]) -> Result<()> {
    lock_or_recover(&self.memtable).update_vector(id, vector)?;

    let entry = WalEntry::UpdateVector::<T> {
        id,
        vector: vector.to_vec(),
    };
    lock_or_recover(&self.wal).append(&entry)?;
    Ok(())
}
pub fn leiden_cluster(
&self,
min_community_size: usize,
max_iterations: Option<usize>,
with_centroids: Option<bool>,
) -> Result<crate::graph::leiden::LeidenResult> {
let config = crate::graph::leiden::LeidenConfig {
min_community_size,
max_iterations: max_iterations.unwrap_or(15),
compute_centroids: with_centroids.unwrap_or(true),
};
let (snapshot, dim) = {
let mt = lock_or_recover(&self.memtable);
let node_ids = mt.all_node_ids();
let mut edges = std::collections::HashMap::new();
for &id in &node_ids {
if let Some(e) = mt.get_edges(id) {
edges.insert(id, e.iter().map(|edge| (edge.target_id, edge.weight)).collect());
}
}
(
crate::graph::leiden::AdjacencySnapshot { edges, node_ids },
mt.dim(),
)
};
let mut result = crate::graph::leiden::run_leiden(&snapshot, &config);
if config.compute_centroids && !result.node_to_cluster.is_empty() {
let vectors = {
let mt = lock_or_recover(&self.memtable);
let mut vecs = std::collections::HashMap::new();
for &node_id in result.node_to_cluster.keys() {
if let Some(v) = mt.get_vector(node_id) {
vecs.insert(node_id, v.iter().map(|x| x.to_f32()).collect::<Vec<f32>>());
}
}
vecs
};
crate::graph::leiden::compute_centroids(&mut result, &vectors, dim);
}
Ok(result)
}
/// Indexes `keyword` for node `id` in the memtable.
///
/// This method does not append anything to the WAL and always returns `Ok`.
pub fn index_keyword(&mut self, id: NodeId, keyword: &str) -> Result<()> {
    lock_or_recover(&self.memtable).index_keyword(id, keyword);
    Ok(())
}
/// Indexes `text` for node `id` in the memtable.
///
/// This method does not append anything to the WAL and always returns `Ok`.
pub fn index_text(&mut self, id: NodeId, text: &str) -> Result<()> {
    lock_or_recover(&self.memtable).index_text(id, text);
    Ok(())
}
/// Asks the memtable to build its text index. Always returns `Ok`.
pub fn build_text_index(&mut self) -> Result<()> {
    lock_or_recover(&self.memtable).build_text_index();
    Ok(())
}
/// Returns a clone of the payload stored for node `id`, or `None` when the
/// memtable has no payload for it.
pub fn get_payload(&self, id: NodeId) -> Option<serde_json::Value> {
    lock_or_recover(&self.memtable).get_payload(id).cloned()
}
/// Returns a copy of the edges stored for node `id`; an empty vector when
/// the memtable has none.
pub fn get_edges(&self, id: NodeId) -> Vec<crate::node::Edge> {
    let table = lock_or_recover(&self.memtable);
    match table.get_edges(id) {
        Some(edges) => edges.to_vec(),
        None => Vec::new(),
    }
}
/// Returns the ids reported by the memtable's `get_all_ids`.
pub fn get_all_ids(&self) -> Vec<NodeId> {
    lock_or_recover(&self.memtable).get_all_ids()
}
/// Vector-only search with the advanced pipeline disabled.
///
/// Builds a [`SearchConfig`] from `top_k`, `expand_depth` and `min_score`
/// (all other fields default) and forwards to `search_hybrid` without a text
/// query.
pub fn search(
    &self,
    query_vector: &[T],
    top_k: usize,
    expand_depth: usize,
    min_score: f32,
) -> Result<Vec<SearchHit>> {
    let basic = SearchConfig {
        enable_advanced_pipeline: false,
        top_k,
        expand_depth,
        min_score,
        ..Default::default()
    };
    self.search_hybrid(None, Some(query_vector), &basic)
}
/// Vector-only search using a caller-supplied [`SearchConfig`]; forwards to
/// `search_hybrid` without a text query.
pub fn search_advanced(
    &self,
    query_vector: &[T],
    config: &SearchConfig,
) -> Result<Vec<SearchHit>> {
    let vector = Some(query_vector);
    self.search_hybrid(None, vector, config)
}
/// Runs the search pipeline with an optional text query and an optional
/// vector query, returning the hits together with the [`HookContext`] that
/// was passed through the pipeline.
///
/// # Errors
/// Propagates any error from the pipeline.
pub fn search_hybrid_with_context(
    &self,
    query_text: Option<&str>,
    query_vector: Option<&[T]>,
    config: &SearchConfig,
) -> Result<(Vec<SearchHit>, HookContext)> {
    let mut ctx = HookContext::new();
    pipeline::execute_pipeline(
        &self.memtable,
        &self.hook,
        query_text,
        query_vector,
        config,
        &mut ctx,
    )
    .map(|hits| (hits, ctx))
}
/// Runs the search pipeline with an optional text query and an optional
/// vector query and returns only the hits; the hook context created for the
/// run is discarded.
///
/// # Errors
/// Propagates any error from the pipeline.
pub fn search_hybrid(
    &self,
    query_text: Option<&str>,
    query_vector: Option<&[T]>,
    config: &SearchConfig,
) -> Result<Vec<SearchHit>> {
    self.search_hybrid_with_context(query_text, query_vector, config)
        .map(|(hits, _ctx)| hits)
}
/// Returns an owned view (vector, payload, edges) of node `id`.
///
/// Yields `None` when the memtable has no payload or no vector for `id`;
/// missing edges become an empty list.
pub fn get(&self, id: NodeId) -> Option<crate::node::NodeView<T>> {
    let table = lock_or_recover(&self.memtable);
    let payload = table.get_payload(id)?.clone();
    let vector = table.get_vector(id)?.to_vec();
    let edges = table
        .get_edges(id)
        .map_or_else(Vec::new, |found| found.to_vec());
    Some(crate::node::NodeView {
        id,
        vector,
        payload,
        edges,
    })
}
/// Collects every node reachable from `id` by following at most `depth`
/// outgoing edges, using a breadth-first walk.
///
/// The start node itself is excluded. The order of the returned ids is
/// unspecified because they are collected from a `HashSet`.
pub fn neighbors(&self, id: NodeId, depth: usize) -> Vec<NodeId> {
    use std::collections::{HashSet, VecDeque};

    let table = lock_or_recover(&self.memtable);
    let mut seen = HashSet::from([id]);
    let mut frontier = VecDeque::from([(id, 0usize)]);

    while let Some((node, hops)) = frontier.pop_front() {
        // Nodes already at the depth limit are not expanded further.
        if hops >= depth {
            continue;
        }
        let Some(outgoing) = table.get_edges(node) else {
            continue;
        };
        for edge in outgoing {
            if seen.insert(edge.target_id) {
                frontier.push_back((edge.target_id, hops + 1));
            }
        }
    }

    seen.remove(&id);
    seen.into_iter().collect()
}
/// Returns views of all nodes whose payload has `key` equal to `value`.
///
/// Nodes without a payload are skipped; a missing vector or edge list is
/// returned as an empty one.
pub fn filter(&self, key: &str, value: &serde_json::Value) -> Vec<crate::node::NodeView<T>> {
    let table = lock_or_recover(&self.memtable);
    let mut matches = Vec::new();
    for nid in table.all_node_ids() {
        let Some(payload) = table.get_payload(nid) else {
            continue;
        };
        if payload.get(key) != Some(value) {
            continue;
        }
        matches.push(crate::node::NodeView {
            id: nid,
            vector: table.get_vector(nid).unwrap_or(&[]).to_vec(),
            payload: payload.clone(),
            edges: table.get_edges(nid).unwrap_or(&[]).to_vec(),
        });
    }
    matches
}
/// Returns views of all nodes whose payload satisfies `condition`.
///
/// Nodes without a payload are skipped; a missing vector or edge list is
/// returned as an empty one.
pub fn filter_where(&self, condition: &crate::filter::Filter) -> Vec<crate::node::NodeView<T>> {
    let table = lock_or_recover(&self.memtable);
    let mut matches = Vec::new();
    for nid in table.all_node_ids() {
        let Some(payload) = table.get_payload(nid) else {
            continue;
        };
        if !condition.matches(payload) {
            continue;
        }
        matches.push(crate::node::NodeView {
            id: nid,
            vector: table.get_vector(nid).unwrap_or(&[]).to_vec(),
            payload: payload.clone(),
            edges: table.get_edges(nid).unwrap_or(&[]).to_vec(),
        });
    }
    matches
}
/// Saves the memtable to the database file and then clears the WAL.
///
/// The memtable lock is released before the WAL lock is taken.
///
/// # Errors
/// Propagates errors from saving the file or clearing the WAL; the WAL is
/// left untouched when the save fails.
pub fn flush(&mut self) -> Result<()> {
    {
        let mut table = lock_or_recover(&self.memtable);
        file_format::save(&mut table, &self.db_path, self.storage_mode)?;
    }
    lock_or_recover(&self.wal).clear()?;
    Ok(())
}
/// Consumes the handle: drops the compaction thread handle, then flushes.
///
/// # Errors
/// Returns the result of the final `flush`.
pub fn close(mut self) -> Result<()> {
    // Same effect as `disable_auto_compaction`.
    drop(self.compaction.take());
    self.flush()
}
/// Returns the node count reported by the memtable.
pub fn node_count(&self) -> usize {
    let table = lock_or_recover(&self.memtable);
    table.node_count()
}
/// Returns whether the memtable contains node `id`.
pub fn contains(&self, id: NodeId) -> bool {
    let table = lock_or_recover(&self.memtable);
    table.contains(id)
}
/// Returns the vector dimension reported by the memtable.
pub fn dim(&self) -> usize {
    let table = lock_or_recover(&self.memtable);
    table.dim()
}
/// Returns the ids reported by the memtable's `all_node_ids`.
pub fn all_node_ids(&self) -> Vec<NodeId> {
    let table = lock_or_recover(&self.memtable);
    table.all_node_ids()
}
/// Copies payloads and edges into a new database at `new_path` that uses
/// vector dimension `new_dim`.
///
/// Vectors are not carried over: every migrated node gets an all-zero vector
/// of length `new_dim`. Nodes without a payload are skipped, and an edge is
/// copied only when its target has a payload in the source. The source
/// memtable lock is held for the whole migration. The new database is
/// flushed before returning.
///
/// Returns the new database together with the sorted list of source node ids.
///
/// # Errors
/// Propagates errors from opening, writing to, or flushing the new database.
pub fn migrate_to(&self, new_path: &str, new_dim: usize) -> Result<(Database<T>, Vec<NodeId>)>
where
    T: serde::Serialize + serde::de::DeserializeOwned,
{
    let source = lock_or_recover(&self.memtable);
    let mut node_ids = source.all_node_ids();
    node_ids.sort();

    let mut target = Database::<T>::open(new_path, new_dim)?;
    let placeholder = vec![T::zero(); new_dim];

    // First pass: recreate every node so edge endpoints exist before linking.
    for &nid in &node_ids {
        let Some(payload) = source.get_payload(nid) else {
            continue;
        };
        target.insert_with_id(nid, &placeholder, payload.clone())?;
    }

    // Second pass: copy edges whose target has a payload in the source.
    for &nid in &node_ids {
        let Some(outgoing) = source.get_edges(nid) else {
            continue;
        };
        let live_edges = outgoing
            .iter()
            .filter(|edge| source.get_payload(edge.target_id).is_some());
        for edge in live_edges {
            target.link(nid, edge.target_id, &edge.label, edge.weight)?;
        }
    }

    target.flush()?;
    tracing::info!(
        "维度迁移完成: {} → {},共迁移 {} 个节点",
        source.dim(),
        new_dim,
        node_ids.len()
    );
    Ok((target, node_ids))
}
/// Creates a [`Transaction`] that mutably borrows this database, with an
/// empty operation list and `committed` set to `false`.
pub fn begin_tx(&mut self) -> Transaction<'_, T> {
    let ops = Vec::new();
    Transaction {
        db: self,
        ops,
        committed: false,
    }
}
pub fn query(
&self,
cypher: &str,
) -> Result<Vec<std::collections::HashMap<String, crate::node::NodeView<T>>>> {
let ast = crate::query::parser::parse(cypher)
.map_err(|e| crate::error::TriviumError::Generic(format!("查询语句解析失败: {}", e)))?;
let mt = lock_or_recover(&self.memtable);
Ok(crate::query::executor::execute(&ast, &mt)?)
}
}
impl<T: VectorType> Drop for Database<T> {
fn drop(&mut self) {
self.compaction.take();
if let Ok(mut w) = self.wal.lock() {
w.flush_writer();
}
}
}