use crate::error::{Result, TriviumError};
#[cfg(not(feature = "hnsw"))]
use crate::index::brute_force;
#[cfg(feature = "hnsw")]
use crate::index::hnsw::HnswIndex;
use crate::node::{NodeId, SearchHit};
use crate::storage::compaction::CompactionThread;
use crate::storage::file_format;
use crate::storage::memtable::MemTable;
use crate::storage::wal::{Wal, WalEntry, SyncMode};
use crate::VectorType;
use fs2::FileExt;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::Duration;
/// How the database file is accessed when it is loaded and saved.
///
/// NOTE(review): the concrete behavior of each variant lives in
/// `storage::file_format`; the variant descriptions below are inferred from
/// their names only.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum StorageMode {
    /// Presumably memory-mapped file access. This is the default variant.
    #[default]
    Mmap,
    /// Presumably a read-only image of the file — confirm in `file_format`.
    Rom,
}
/// Options used when opening a `Database` (see `Database::open_with_config`).
#[derive(Debug, Clone, Copy)]
pub struct Config {
    /// Vector dimension. Only used to create a fresh `MemTable` (and the HNSW
    /// index); an existing database file is loaded without it.
    pub dim: usize,
    /// Sync mode handed to `Wal::open_with_sync`.
    pub sync_mode: SyncMode,
    /// How the database file is loaded and saved.
    pub storage_mode: StorageMode,
}
impl Default for Config {
fn default() -> Self {
Self {
dim: 1536,
sync_mode: SyncMode::default(),
storage_mode: StorageMode::default(),
}
}
}
/// Tunables for `Database::search_hybrid`.
///
/// Several fields are clamped to safe ranges by `search_hybrid` before use;
/// the ranges are noted per field.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SearchConfig {
    /// Maximum number of hits returned (raised to at least 1 before use).
    pub top_k: usize,
    /// Depth passed to `graph::traversal::expand_graph`.
    pub expand_depth: usize,
    /// Minimum score for vector hits and for accumulated seed scores.
    pub min_score: f32,
    /// Passed to `expand_graph`; clamped to `[0, 1]`.
    pub teleport_alpha: f32,
    /// Gate for the FISTA sparse-residual stage and the DPP re-ranking stage.
    pub enable_advanced_pipeline: bool,
    /// Compute a FISTA residual of the query against the current seeds and
    /// run a "shadow" query on it when its norm exceeds `fista_threshold`
    /// (requires `enable_advanced_pipeline`).
    pub enable_sparse_residual: bool,
    /// `lambda` argument of `cognitive::fista_solve` (presumably the sparsity
    /// penalty); clamped to `[1e-5, 100]`.
    pub fista_lambda: f32,
    /// Residual norm above which the shadow query runs; clamped to `>= 0`.
    pub fista_threshold: f32,
    /// Select the final hits with greedy DPP (requires `enable_advanced_pipeline`).
    pub enable_dpp: bool,
    /// Quality weight passed to `cognitive::dpp_greedy`; clamped to `[0, 10]`.
    pub dpp_quality_weight: f32,
    /// Passed to `expand_graph`.
    pub enable_inverse_inhibition: bool,
    /// Passed to `expand_graph`.
    pub lateral_inhibition_threshold: usize,
    /// Pre-filter candidates by binary-quantized Hamming distance before exact
    /// scoring (only in builds without the `hnsw` feature).
    pub enable_bq_coarse_search: bool,
    /// Fraction of nodes kept by the BQ pre-filter (never fewer than `top_k`);
    /// clamped to `[0, 1]`.
    pub bq_candidate_ratio: f32,
    /// Multiplier applied to text-match scores.
    pub text_boost: f32,
    /// Include text matches (Aho-Corasick style `search_ac` plus BM25) when a
    /// query text is supplied.
    pub enable_text_hybrid_search: bool,
    /// BM25 `k1` parameter.
    pub bm25_k1: f32,
    /// BM25 `b` parameter.
    pub bm25_b: f32,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
top_k: 5,
expand_depth: 2,
min_score: 0.1,
teleport_alpha: 0.0,
enable_advanced_pipeline: false,
enable_sparse_residual: false,
fista_lambda: 0.1,
fista_threshold: 0.30,
enable_dpp: false,
dpp_quality_weight: 1.0,
enable_inverse_inhibition: false,
lateral_inhibition_threshold: 0,
enable_bq_coarse_search: false,
bq_candidate_ratio: 0.1,
text_boost: 1.5,
enable_text_hybrid_search: false,
bm25_k1: 1.2,
bm25_b: 0.75,
}
}
}
/// Locks `mutex`, recovering the guard if a previous holder panicked.
///
/// A poisoned mutex is reported with a `warn` log and then used anyway, so a
/// panic in one thread does not make the protected state permanently
/// inaccessible.
fn lock_or_recover<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => {
            tracing::warn!("Mutex was poisoned, recovering...");
            poisoned.into_inner()
        }
    }
}
/// An embedded vector + graph database: an in-memory `MemTable` backed by a
/// write-ahead log and a database file at `db_path`.
///
/// NOTE: fields are dropped in declaration order, so the compaction worker
/// is dropped before the lock-file handle.
pub struct Database<T: VectorType> {
    /// Path passed to the file-format, WAL and lock-file helpers
    /// (the lock file is `<db_path>.lock`).
    db_path: String,
    /// In-memory state, shared with the background compaction thread.
    memtable: Arc<Mutex<MemTable<T>>>,
    /// Write-ahead log, shared with the background compaction thread.
    wal: Arc<Mutex<Wal>>,
    /// Approximate-NN index. It is created empty on open and filled by
    /// `rebuild_index`.
    #[cfg(feature = "hnsw")]
    hnsw_index: HnswIndex<T>,
    /// Background compaction worker, when auto-compaction is enabled.
    compaction: Option<CompactionThread>,
    /// Kept open so the exclusive lock taken at open time stays held.
    _lock_file: std::fs::File,
    /// Auto-flush threshold in bytes; 0 disables the check.
    memory_limit: usize,
    /// Storage mode used for saves and handed to the compaction thread.
    storage_mode: StorageMode,
}
impl<T: VectorType + serde::Serialize + serde::de::DeserializeOwned> Database<T> {
pub fn open(path: &str, dim: usize) -> Result<Self> {
let config = Config { dim, ..Default::default() };
Self::open_with_config(path, config)
}
pub fn open_with_sync(path: &str, dim: usize, sync_mode: SyncMode) -> Result<Self> {
let config = Config { dim, sync_mode, ..Default::default() };
Self::open_with_config(path, config)
}
/// Opens (or creates) the database at `path` using `config`.
///
/// Steps:
/// 1. Create the parent directory when `path` has a non-empty one.
/// 2. Take an exclusive lock on `<path>.lock`; fail if it cannot be acquired.
/// 3. Load the database file if it exists, otherwise start an empty
///    `MemTable` with `config.dim` dimensions.
/// 4. Replay any entries the WAL reports as needing recovery, then open the
///    WAL for appending.
///
/// The returned database has auto-compaction off and no memory limit.
///
/// # Errors
/// Propagates I/O, load and WAL errors. Returns `TriviumError::Generic`
/// when the lock file is already locked.
pub fn open_with_config(path: &str, config: Config) -> Result<Self> {
    let dim = config.dim;
    let db_file = std::path::Path::new(path);

    match db_file.parent() {
        // A bare file name has an empty parent: nothing to create.
        Some(dir) if !dir.as_os_str().is_empty() => std::fs::create_dir_all(dir)?,
        _ => {}
    }

    let lock_path = format!("{}.lock", path);
    let lock_file = std::fs::OpenOptions::new()
        .write(true)
        .create(true)
        .open(&lock_path)?;
    if lock_file.try_lock_exclusive().is_err() {
        return Err(TriviumError::Generic(format!(
            "Database '{}' is already opened by another process. \
             If this is unexpected, delete '{}'",
            path, lock_path
        )));
    }

    let mut memtable = if db_file.exists() {
        file_format::load(path, config.storage_mode)?
    } else {
        MemTable::new(dim)
    };

    if Wal::needs_recovery(path) {
        let recovered = Wal::read_entries::<T>(path)?;
        if !recovered.is_empty() {
            tracing::info!("Recovering {} entries from WAL...", recovered.len());
        }
        for entry in recovered {
            replay_entry(&mut memtable, entry);
        }
    }

    let wal = Wal::open_with_sync(path, config.sync_mode)?;
    #[cfg(feature = "hnsw")]
    let hnsw_index = HnswIndex::<T>::new(dim);

    Ok(Self {
        db_path: path.to_owned(),
        memtable: Arc::new(Mutex::new(memtable)),
        wal: Arc::new(Mutex::new(wal)),
        #[cfg(feature = "hnsw")]
        hnsw_index,
        compaction: None,
        _lock_file: lock_file,
        memory_limit: 0,
        storage_mode: config.storage_mode,
    })
}
/// Changes the sync mode the WAL uses for subsequent appends.
pub fn set_sync_mode(&mut self, mode: SyncMode) {
    lock_or_recover(&self.wal).set_sync_mode(mode);
}
/// Sets the auto-flush threshold in bytes. Inserts trigger a flush once the
/// memtable's estimated size exceeds it; `0` disables the check.
pub fn set_memory_limit(&mut self, bytes: usize) {
    let limit = bytes;
    self.memory_limit = limit;
}
/// Returns the memtable's own estimate of its memory use, in bytes.
pub fn estimated_memory(&self) -> usize {
    let table = lock_or_recover(&self.memtable);
    table.estimated_memory_bytes()
}
/// Flushes to disk when a memory limit is set and the memtable's estimated
/// size exceeds it. A failed flush is logged, not returned.
fn check_memory_pressure(&mut self) {
    if self.memory_limit == 0 {
        return;
    }
    // The temporary guard is released at the end of this statement, before flush() re-locks.
    let usage = lock_or_recover(&self.memtable).estimated_memory_bytes();
    if usage <= self.memory_limit {
        return;
    }
    const MIB: usize = 1024 * 1024;
    tracing::info!(
        "Memory pressure: {}MB > limit {}MB. Auto-flushing...",
        usage / MIB,
        self.memory_limit / MIB
    );
    if let Err(e) = self.flush() {
        tracing::error!("Auto-flush failed: {}", e);
    }
}
/// Starts a background compaction thread that runs every `interval`.
/// Any previously running compaction thread is dropped first.
pub fn enable_auto_compaction(&mut self, interval: Duration) {
    drop(self.compaction.take());
    let shared_memtable = Arc::clone(&self.memtable);
    let shared_wal = Arc::clone(&self.wal);
    let worker = CompactionThread::spawn(
        interval,
        shared_memtable,
        shared_wal,
        self.db_path.clone(),
        self.storage_mode,
    );
    self.compaction = Some(worker);
}
/// Stops auto-compaction by dropping the background compaction thread, if any.
pub fn disable_auto_compaction(&mut self) {
    self.compaction = None;
}
/// Inserts a node with an auto-assigned id and returns that id.
///
/// The memtable is updated first, then the insert is appended to the WAL.
/// If the WAL append fails, the error is returned but the node stays in
/// memory. Afterwards the memory limit is checked, which may auto-flush.
pub fn insert(&mut self, vector: &[T], payload: serde_json::Value) -> Result<NodeId> {
    let id = lock_or_recover(&self.memtable).insert(vector, payload.clone())?;
    let entry = WalEntry::Insert {
        id,
        vector: vector.to_vec(),
        payload,
    };
    lock_or_recover(&self.wal).append(&entry)?;
    self.check_memory_pressure();
    Ok(id)
}
/// Inserts a node under a caller-chosen `id`.
///
/// Same ordering as `insert`: memtable first, then WAL, then the memory
/// limit check.
pub fn insert_with_id(&mut self, id: NodeId, vector: &[T], payload: serde_json::Value) -> Result<()> {
    lock_or_recover(&self.memtable).insert_with_id(id, vector, payload.clone())?;
    let entry = WalEntry::Insert {
        id,
        vector: vector.to_vec(),
        payload,
    };
    lock_or_recover(&self.wal).append(&entry)?;
    self.check_memory_pressure();
    Ok(())
}
pub fn link(&mut self, src: NodeId, dst: NodeId, label: &str, weight: f32) -> Result<()> {
{
let mut mt = lock_or_recover(&self.memtable);
mt.link(src, dst, label.to_string(), weight)?;
}
{
let mut w = lock_or_recover(&self.wal);
w.append(&WalEntry::Link::<T> {
src, dst,
label: label.to_string(),
weight,
})?;
}
Ok(())
}
pub fn delete(&mut self, id: NodeId) -> Result<()> {
{
let mut mt = lock_or_recover(&self.memtable);
mt.delete(id)?;
}
{
let mut w = lock_or_recover(&self.wal);
w.append(&WalEntry::Delete::<T> { id })?;
}
Ok(())
}
pub fn unlink(&mut self, src: NodeId, dst: NodeId) -> Result<()> {
{
let mut mt = lock_or_recover(&self.memtable);
mt.unlink(src, dst)?;
}
{
let mut w = lock_or_recover(&self.wal);
w.append(&WalEntry::Unlink::<T> { src, dst })?;
}
Ok(())
}
pub fn update_payload(&mut self, id: NodeId, payload: serde_json::Value) -> Result<()> {
{
let mut mt = lock_or_recover(&self.memtable);
mt.update_payload(id, payload.clone())?;
}
{
let mut w = lock_or_recover(&self.wal);
w.append(&WalEntry::UpdatePayload::<T> { id, payload })?;
}
Ok(())
}
pub fn update_vector(&mut self, id: NodeId, vector: &[T]) -> Result<()> {
{
let mut mt = lock_or_recover(&self.memtable);
mt.update_vector(id, vector)?;
}
{
let mut w = lock_or_recover(&self.wal);
w.append(&WalEntry::UpdateVector::<T> { id, vector: vector.to_vec() })?;
}
Ok(())
}
/// Registers `keyword` for node `id` in the memtable's text index.
/// Not written to the WAL. Always returns `Ok(())`.
pub fn index_keyword(&mut self, id: NodeId, keyword: &str) -> Result<()> {
    lock_or_recover(&self.memtable).index_keyword(id, keyword);
    Ok(())
}
/// Adds `text` for node `id` to the memtable's text index.
/// Not written to the WAL. Always returns `Ok(())`.
pub fn index_text(&mut self, id: NodeId, text: &str) -> Result<()> {
    lock_or_recover(&self.memtable).index_text(id, text);
    Ok(())
}
/// Asks the memtable to (re)build its text index. Always returns `Ok(())`.
pub fn build_text_index(&mut self) -> Result<()> {
    lock_or_recover(&self.memtable).build_text_index();
    Ok(())
}
/// Returns a copy of node `id`'s payload, or `None` if the memtable has none.
pub fn get_payload(&self, id: NodeId) -> Option<serde_json::Value> {
    let table = lock_or_recover(&self.memtable);
    table.get_payload(id).cloned()
}
/// Returns a copy of node `id`'s outgoing edges (empty when there are none).
pub fn get_edges(&self, id: NodeId) -> Vec<crate::node::Edge> {
    let table = lock_or_recover(&self.memtable);
    match table.get_edges(id) {
        Some(edges) => edges.to_vec(),
        None => Vec::new(),
    }
}
/// Returns the ids reported by `MemTable::get_all_ids`.
pub fn get_all_ids(&self) -> Vec<NodeId> {
    let table = lock_or_recover(&self.memtable);
    table.get_all_ids()
}
pub fn search(
&self,
query_vector: &[T],
top_k: usize,
expand_depth: usize,
min_score: f32,
) -> Result<Vec<SearchHit>> {
let config = SearchConfig {
top_k,
expand_depth,
min_score,
enable_advanced_pipeline: false,
..Default::default()
};
self.search_hybrid(None, Some(query_vector), &config)
}
/// Vector-only search with a caller-supplied `config` (no query text).
pub fn search_advanced(&self, query_vector: &[T], config: &SearchConfig) -> Result<Vec<SearchHit>> {
    let no_text: Option<&str> = None;
    self.search_hybrid(no_text, Some(query_vector), config)
}
pub fn search_hybrid(
&self,
query_text: Option<&str>,
query_vector: Option<&[T]>,
config: &SearchConfig,
) -> Result<Vec<SearchHit>> {
let mut mt = lock_or_recover(&self.memtable);
let dim = mt.dim();
if let Some(qv) = query_vector {
if qv.len() != dim {
return Err(crate::error::TriviumError::DimensionMismatch {
expected: dim,
got: qv.len(),
});
}
for item in qv {
let f = item.to_f32();
if f.is_nan() || f.is_infinite() {
return Err(crate::error::TriviumError::Generic("Query vector contains NaN or Infinity".to_string()));
}
}
}
let mut safe_cfg = *config;
safe_cfg.top_k = safe_cfg.top_k.max(1);
safe_cfg.fista_lambda = safe_cfg.fista_lambda.clamp(1e-5, 100.0); safe_cfg.teleport_alpha = safe_cfg.teleport_alpha.clamp(0.0, 1.0); safe_cfg.dpp_quality_weight = safe_cfg.dpp_quality_weight.clamp(0.0, 10.0); safe_cfg.fista_threshold = safe_cfg.fista_threshold.clamp(0.0, f32::MAX);
safe_cfg.bq_candidate_ratio = safe_cfg.bq_candidate_ratio.clamp(0.0, 1.0);
let config = &safe_cfg;
let mut anchor_hits: Vec<SearchHit> = Vec::new();
let mut seed_map: std::collections::HashMap<NodeId, f32> = std::collections::HashMap::new();
if config.enable_text_hybrid_search {
if let Some(txt) = query_text {
let text_engine = mt.text_engine();
let ac_hits = text_engine.search_ac(txt);
for (id, score) in ac_hits {
*seed_map.entry(id).or_insert(0.0) += score * config.text_boost; }
let bm25_hits = text_engine.search_bm25(txt, config.bm25_k1, config.bm25_b);
for (id, score) in bm25_hits {
let normalized_score = (score / 10.0).clamp(0.0, 1.0) * config.text_boost;
*seed_map.entry(id).or_insert(0.0) += normalized_score;
}
}
}
if let Some(query_vector) = query_vector {
#[cfg(not(feature = "hnsw"))]
let vector_hits: Vec<SearchHit> = {
let dim = mt.dim();
mt.ensure_vectors_cache();
let vectors = mt.flat_vectors();
if config.enable_bq_coarse_search {
let q_bq = crate::index::bq::BqSignature::from_vector(query_vector);
let m_count = mt.node_count(); let candidate_cnt = (((m_count as f32) * config.bq_candidate_ratio).ceil() as usize)
.max(config.top_k);
let mut bq_scores: Vec<(usize, u32)> = (0..m_count)
.filter_map(|i| mt.get_bq_signature(i).map(|sig| (i, sig.hamming_distance(&q_bq))))
.collect();
bq_scores.sort_unstable_by_key(|&(_, dist)| dist);
bq_scores.truncate(candidate_cnt);
let mut refined = Vec::with_capacity(candidate_cnt);
for (i, _dist) in bq_scores {
let offset = i * dim;
if offset + dim <= vectors.len() {
let score = T::similarity(query_vector, &vectors[offset..offset + dim]);
if score >= config.min_score {
refined.push(SearchHit {
id: mt.get_id_by_index(i),
score,
payload: serde_json::Value::Null, });
}
}
}
refined.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
refined.truncate(config.top_k);
for hit in &mut refined {
if let Some(p) = mt.get_payload(hit.id) {
hit.payload = p.clone();
}
}
refined
} else {
brute_force::search(
query_vector, vectors, dim, config.top_k, config.min_score,
|idx| mt.get_id_by_index(idx),
)
}
};
#[cfg(feature = "hnsw")]
let vector_hits: Vec<SearchHit> = {
self.hnsw_index.search(query_vector, config.top_k, config.min_score)
};
for hit in vector_hits {
*seed_map.entry(hit.id).or_insert(0.0) += hit.score;
}
if config.enable_advanced_pipeline {
if config.enable_sparse_residual && !seed_map.is_empty() {
let entity_vecs: Vec<Vec<f32>> = seed_map.keys()
.filter_map(|&id| mt.get_vector(id).map(|v| v.iter().map(|&x| x.to_f32()).collect()))
.collect();
let q_f32: Vec<f32> = query_vector.iter().map(|&x| x.to_f32()).collect();
let (_, residual, residual_norm) = crate::cognitive::fista_solve(&q_f32, &entity_vecs, config.fista_lambda, 80);
if residual_norm > config.fista_threshold {
tracing::debug!("FISTA Residual magnitude high ({} > {}). Triggering Shadow Query.", residual_norm, config.fista_threshold);
let r_orig: Vec<T> = residual.iter().map(|&x| T::from_f32(x)).collect();
let shadow_hits: Vec<SearchHit> = {
#[cfg(not(feature = "hnsw"))]
{
let dim = mt.dim();
brute_force::search(
&r_orig, mt.flat_vectors(), dim, config.top_k, config.min_score,
|idx| mt.get_id_by_index(idx),
)
}
#[cfg(feature = "hnsw")]
{
self.hnsw_index.search(&r_orig, config.top_k, config.min_score)
}
};
for sh in shadow_hits {
*seed_map.entry(sh.id).or_insert(0.0) += sh.score * 0.8; }
}
}
}
}
for (id, score) in seed_map {
if score >= config.min_score {
let payload = mt.get_payload(id).cloned().unwrap_or(serde_json::Value::Null);
anchor_hits.push(SearchHit { id, score, payload });
}
}
anchor_hits.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
anchor_hits.truncate(config.top_k.max(15));
if anchor_hits.is_empty() {
return Ok(vec![]);
}
if anchor_hits.is_empty() {
return Ok(Vec::new());
}
let mut seeds = Vec::with_capacity(anchor_hits.len());
for mut hit in anchor_hits {
if let Some(payload) = mt.get_payload(hit.id) {
hit.payload = payload.clone();
seeds.push(hit);
}
}
let mut expanded = crate::graph::traversal::expand_graph(
&mt,
seeds,
config.expand_depth,
config.teleport_alpha,
config.enable_inverse_inhibition,
config.lateral_inhibition_threshold,
);
if config.enable_advanced_pipeline && config.enable_dpp && expanded.len() > config.top_k {
let limit = config.top_k;
let dpp_pool_size = std::cmp::min(expanded.len(), limit * 3);
let mut pool_vecs = Vec::with_capacity(dpp_pool_size);
let mut pool_scores = Vec::with_capacity(dpp_pool_size);
let mut pool_valid = Vec::with_capacity(dpp_pool_size);
for i in 0..dpp_pool_size {
let hit = &expanded[i];
if let Some(v) = mt.get_vector(hit.id) {
pool_vecs.push(v.iter().map(|&x| x.to_f32()).collect());
pool_scores.push(hit.score);
pool_valid.push(hit.clone());
}
}
if pool_valid.len() > limit {
let selected_idx = crate::cognitive::dpp_greedy(
&pool_vecs,
&pool_scores,
limit,
config.dpp_quality_weight
);
let mut final_results = Vec::with_capacity(limit);
for &idx in &selected_idx {
final_results.push(pool_valid[idx].clone());
}
final_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
return Ok(final_results);
}
}
expanded.truncate(config.top_k);
Ok(expanded)
}
/// Returns a snapshot of node `id` (payload, vector and outgoing edges).
/// Returns `None` when the node has no payload or no vector; missing edges
/// become an empty list.
pub fn get(&self, id: NodeId) -> Option<crate::node::NodeView<T>> {
    let table = lock_or_recover(&self.memtable);
    let payload = table.get_payload(id).cloned()?;
    let vector = table.get_vector(id).map(|v| v.to_vec())?;
    let edges = table.get_edges(id).map(|e| e.to_vec()).unwrap_or_default();
    Some(crate::node::NodeView {
        id,
        vector,
        payload,
        edges,
    })
}
/// Returns every node reachable from `id` within `depth` hops along
/// outgoing edges, excluding `id` itself. The order is unspecified.
pub fn neighbors(&self, id: NodeId, depth: usize) -> Vec<NodeId> {
    use std::collections::{HashSet, VecDeque};
    let table = lock_or_recover(&self.memtable);
    let mut seen: HashSet<NodeId> = HashSet::from([id]);
    let mut frontier: VecDeque<(NodeId, usize)> = VecDeque::from([(id, 0)]);
    // Breadth-first walk. Nodes at the depth limit are visited but not expanded.
    while let Some((node, hops)) = frontier.pop_front() {
        if hops < depth {
            if let Some(edges) = table.get_edges(node) {
                for edge in edges {
                    if seen.insert(edge.target_id) {
                        frontier.push_back((edge.target_id, hops + 1));
                    }
                }
            }
        }
    }
    seen.remove(&id);
    seen.into_iter().collect()
}
/// Returns every node whose payload has top-level `key` equal to `value`.
///
/// Nodes without a payload are skipped. A missing vector or edge list is
/// reported as empty.
pub fn filter(&self, key: &str, value: &serde_json::Value) -> Vec<crate::node::NodeView<T>> {
    let table = lock_or_recover(&self.memtable);
    // Bound to a local so the iterator borrowing `table` ends before `table` drops.
    let matches: Vec<crate::node::NodeView<T>> = table
        .all_node_ids()
        .into_iter()
        .filter_map(|nid| {
            let payload = table.get_payload(nid)?;
            if payload.get(key) != Some(value) {
                return None;
            }
            Some(crate::node::NodeView {
                id: nid,
                vector: table.get_vector(nid).unwrap_or(&[]).to_vec(),
                payload: payload.clone(),
                edges: table.get_edges(nid).unwrap_or(&[]).to_vec(),
            })
        })
        .collect();
    matches
}
/// Returns every node whose payload satisfies `condition`.
///
/// Nodes without a payload are skipped. A missing vector or edge list is
/// reported as empty.
pub fn filter_where(&self, condition: &crate::filter::Filter) -> Vec<crate::node::NodeView<T>> {
    let table = lock_or_recover(&self.memtable);
    // Bound to a local so the iterator borrowing `table` ends before `table` drops.
    let matches: Vec<crate::node::NodeView<T>> = table
        .all_node_ids()
        .into_iter()
        .filter_map(|nid| {
            let payload = table.get_payload(nid)?;
            if !condition.matches(payload) {
                return None;
            }
            Some(crate::node::NodeView {
                id: nid,
                vector: table.get_vector(nid).unwrap_or(&[]).to_vec(),
                payload: payload.clone(),
                edges: table.get_edges(nid).unwrap_or(&[]).to_vec(),
            })
        })
        .collect();
    matches
}
/// Writes the memtable to the database file, then clears the WAL.
///
/// The WAL is cleared only if the save succeeded.
pub fn flush(&mut self) -> Result<()> {
    let storage_mode = self.storage_mode;
    {
        let mut table = lock_or_recover(&self.memtable);
        file_format::save(&mut table, &self.db_path, storage_mode)?;
    }
    let mut wal = lock_or_recover(&self.wal);
    wal.clear()?;
    Ok(())
}
/// Stops auto-compaction and flushes to disk, consuming the database.
/// The lock-file handle is dropped along with it.
pub fn close(mut self) -> Result<()> {
    // Stop the background compactor before the final flush.
    self.compaction.take();
    self.flush()
}
/// Number of nodes as reported by the memtable.
pub fn node_count(&self) -> usize {
    let table = lock_or_recover(&self.memtable);
    table.node_count()
}
/// Whether the memtable reports node `id` as present.
pub fn contains(&self, id: NodeId) -> bool {
    let table = lock_or_recover(&self.memtable);
    table.contains(id)
}
/// Vector dimension of this database (taken from the memtable).
pub fn dim(&self) -> usize {
    let table = lock_or_recover(&self.memtable);
    table.dim()
}
/// Returns the ids reported by `MemTable::all_node_ids`.
pub fn all_node_ids(&self) -> Vec<NodeId> {
    let table = lock_or_recover(&self.memtable);
    table.all_node_ids()
}
/// Rebuilds the HNSW index from the current memtable contents.
///
/// With the `hnsw` feature:
/// - refresh the memtable's flat vector cache;
/// - pass it to `HnswIndex::rebuild` with an index-to-`NodeId` mapper and a
///   liveness predicate (an index is live iff the memtable still contains its
///   node).
///
/// Without the feature, this only emits a debug log.
///
/// NOTE(review): `hnsw_index` is created empty on open and `insert` does not
/// add to it, so HNSW-backed searches presumably only see nodes present at
/// the last call to this method — confirm.
pub fn rebuild_index(&mut self) {
#[cfg(feature = "hnsw")]
{
// The memtable lock is held for the whole rebuild.
let mut mt = lock_or_recover(&self.memtable);
let dim = mt.dim();
mt.ensure_vectors_cache();
let flat = mt.flat_vectors();
// Arguments: flat vector buffer, dimension, index -> NodeId mapper,
// and a predicate telling the index which slots are still live.
self.hnsw_index.rebuild(
flat,
dim,
|idx| mt.get_id_by_index(idx),
|idx| {
let nid = mt.get_id_by_index(idx);
mt.contains(nid)
},
);
// Log text: "HNSW index rebuild complete, {} active nodes in total".
tracing::info!("HNSW 索引重建完成,共 {} 个活跃节点", mt.node_count());
}
#[cfg(not(feature = "hnsw"))]
{
// Log text: "HNSW feature not enabled; rebuild_index is a no-op".
tracing::debug!("未启用 HNSW feature,rebuild_index 为 no-op");
}
}
/// Copies this database into a new one at `new_path` with `new_dim` dimensions.
///
/// Every node with a payload is re-created under the same id with its
/// payload. Edges are re-created only when their target also has a payload.
/// Vectors are NOT copied: each node gets an all-zero placeholder of
/// `new_dim` components, so the caller presumably re-embeds the returned ids
/// afterwards. The new database is flushed before it is returned. The source
/// memtable stays locked throughout.
///
/// Returns the new database and all source node ids, sorted ascending. This
/// list includes any ids skipped for lacking a payload.
pub fn migrate_to(
    &self,
    new_path: &str,
    new_dim: usize,
) -> Result<(Database<T>, Vec<NodeId>)>
where
    T: serde::Serialize + serde::de::DeserializeOwned,
{
    let source = lock_or_recover(&self.memtable);
    let mut ids = source.all_node_ids();
    ids.sort();

    let mut target = Database::<T>::open(new_path, new_dim)?;
    let placeholder = vec![T::zero(); new_dim];

    // Pass 1: nodes, so every edge endpoint exists before linking.
    for &nid in &ids {
        if let Some(payload) = source.get_payload(nid) {
            target.insert_with_id(nid, &placeholder, payload.clone())?;
        }
    }

    // Pass 2: edges whose target survived pass 1.
    for &nid in &ids {
        if let Some(edges) = source.get_edges(nid) {
            for edge in edges.iter().filter(|e| source.get_payload(e.target_id).is_some()) {
                target.link(nid, edge.target_id, &edge.label, edge.weight)?;
            }
        }
    }

    target.flush()?;
    // Log text: "Dimension migration complete: {} -> {}, {} nodes migrated in total".
    tracing::info!(
        "维度迁移完成: {} → {},共迁移 {} 个节点",
        source.dim(),
        new_dim,
        ids.len()
    );
    Ok((target, ids))
}
/// Starts a transaction. The database stays mutably borrowed until the
/// transaction is committed, rolled back, or dropped.
pub fn begin_tx(&mut self) -> Transaction<'_, T> {
    Transaction {
        ops: Vec::new(),
        committed: false,
        db: self,
    }
}
pub fn query(&self, cypher: &str) -> Result<Vec<std::collections::HashMap<String, crate::node::NodeView<T>>>> {
let ast = crate::query::parser::parse(cypher)
.map_err(|e| crate::error::TriviumError::Generic(format!("查询语句解析失败: {}", e)))?;
let mt = lock_or_recover(&self.memtable);
Ok(crate::query::executor::execute(&ast, &mt))
}
}
/// Re-applies one recovered WAL entry to `mt`.
///
/// Failures of the underlying memtable operations are ignored (best-effort
/// recovery). Transaction begin/commit markers carry no data and are skipped.
fn replay_entry<T: VectorType>(mt: &mut MemTable<T>, entry: WalEntry<T>) {
    match entry {
        WalEntry::Insert { id: node, vector: values, payload: data } => {
            let _ = mt.raw_insert(node, &values, data);
        }
        WalEntry::Link { src: from, dst: to, label, weight } => {
            let _ = mt.link(from, to, label, weight);
        }
        WalEntry::Delete { id: node } => {
            let _ = mt.delete(node);
        }
        WalEntry::Unlink { src: from, dst: to } => {
            let _ = mt.unlink(from, to);
        }
        WalEntry::UpdatePayload { id: node, payload: data } => {
            let _ = mt.update_payload(node, data);
        }
        WalEntry::UpdateVector { id: node, vector: values } => {
            let _ = mt.update_vector(node, &values);
        }
        WalEntry::TxBegin { .. } | WalEntry::TxCommit { .. } => {}
    }
}
/// One operation queued in a `Transaction`, applied at commit time.
enum TxOp<T> {
    /// Insert with an id assigned by the memtable.
    Insert {
        vector: Vec<T>,
        payload: serde_json::Value,
    },
    /// Insert under a caller-chosen id.
    InsertWithId {
        id: NodeId,
        vector: Vec<T>,
        payload: serde_json::Value,
    },
    /// Add edge `src -> dst`.
    Link {
        src: NodeId,
        dst: NodeId,
        label: String,
        weight: f32,
    },
    /// Delete a node.
    Delete { id: NodeId },
    /// Remove edge `src -> dst`.
    Unlink { src: NodeId, dst: NodeId },
    /// Replace a node's payload.
    UpdatePayload {
        id: NodeId,
        payload: serde_json::Value,
    },
    /// Replace a node's vector.
    UpdateVector { id: NodeId, vector: Vec<T> },
}
/// A batch of operations queued against a `Database` and applied by `commit`.
/// Nothing touches the database until `commit` is called.
pub struct Transaction<'a, T: VectorType + serde::Serialize + serde::de::DeserializeOwned> {
    /// Database the operations will be applied to.
    db: &'a mut Database<T>,
    /// Operations in the order they were queued.
    ops: Vec<TxOp<T>>,
    /// Set by `commit`/`rollback`; suppresses the warning in `Drop`.
    committed: bool,
}
impl<'a, T: VectorType + serde::Serialize + serde::de::DeserializeOwned> Transaction<'a, T> {
/// Queues an insert. The id is assigned at commit time and returned by `commit`.
pub fn insert(&mut self, vector: &[T], payload: serde_json::Value) {
    let op = TxOp::Insert {
        payload,
        vector: vector.to_vec(),
    };
    self.ops.push(op);
}
/// Queues an insert under a caller-chosen `id`.
pub fn insert_with_id(&mut self, id: NodeId, vector: &[T], payload: serde_json::Value) {
    let op = TxOp::InsertWithId {
        id,
        payload,
        vector: vector.to_vec(),
    };
    self.ops.push(op);
}
/// Queues a new edge `src -> dst` with `label` and `weight`.
pub fn link(&mut self, src: NodeId, dst: NodeId, label: &str, weight: f32) {
    let label = label.to_owned();
    self.ops.push(TxOp::Link {
        src,
        dst,
        label,
        weight,
    });
}
/// Queues deletion of node `id`.
pub fn delete(&mut self, id: NodeId) {
    let op = TxOp::Delete { id };
    self.ops.push(op);
}
/// Queues removal of the edge `src -> dst`.
pub fn unlink(&mut self, src: NodeId, dst: NodeId) {
    let op = TxOp::Unlink { src, dst };
    self.ops.push(op);
}
/// Queues replacement of node `id`'s payload.
pub fn update_payload(&mut self, id: NodeId, payload: serde_json::Value) {
    let op = TxOp::UpdatePayload { id, payload };
    self.ops.push(op);
}
/// Queues replacement of node `id`'s vector.
pub fn update_vector(&mut self, id: NodeId, vector: &[T]) {
    let owned = vector.to_vec();
    self.ops.push(TxOp::UpdateVector { id, vector: owned });
}
/// Number of operations queued so far.
pub fn pending_count(&self) -> usize {
    let queued = &self.ops;
    queued.len()
}
pub fn commit(mut self) -> Result<Vec<NodeId>> {
let ops = std::mem::take(&mut self.ops);
if ops.is_empty() {
self.committed = true;
return Ok(Vec::new());
}
let mut mt = lock_or_recover(&self.db.memtable);
let mut sim_next_id = mt.next_id_value();
let dim = mt.dim();
let mut pending_ids = std::collections::HashSet::new();
let mut pending_deletes = std::collections::HashSet::new();
macro_rules! check_exists {
($id:expr) => {
!pending_deletes.contains($id) && (pending_ids.contains($id) || mt.contains(*$id))
}
}
for op in &ops {
match op {
TxOp::Insert { vector, .. } => {
if vector.len() != dim {
return Err(crate::error::TriviumError::DimensionMismatch { expected: dim, got: vector.len() });
}
pending_ids.insert(sim_next_id);
sim_next_id += 1;
}
TxOp::InsertWithId { id, vector, .. } => {
if check_exists!(id) {
return Err(crate::error::TriviumError::Generic(format!("Node {} already exists", id)));
}
if vector.len() != dim {
return Err(crate::error::TriviumError::DimensionMismatch { expected: dim, got: vector.len() });
}
pending_ids.insert(*id);
if *id >= sim_next_id { sim_next_id = *id + 1; }
}
TxOp::Link { src, dst, .. } => {
if !check_exists!(src) { return Err(crate::error::TriviumError::NodeNotFound(*src)); }
if !check_exists!(dst) { return Err(crate::error::TriviumError::NodeNotFound(*dst)); }
}
TxOp::Delete { id } => {
if !check_exists!(id) { return Err(crate::error::TriviumError::NodeNotFound(*id)); }
pending_deletes.insert(*id);
}
TxOp::Unlink { src, .. } => {
if !check_exists!(src) { return Err(crate::error::TriviumError::NodeNotFound(*src)); }
}
TxOp::UpdatePayload { id, .. } => {
if !check_exists!(id) { return Err(crate::error::TriviumError::NodeNotFound(*id)); }
}
TxOp::UpdateVector { id, vector } => {
if !check_exists!(id) { return Err(crate::error::TriviumError::NodeNotFound(*id)); }
if vector.len() != dim {
return Err(crate::error::TriviumError::DimensionMismatch { expected: dim, got: vector.len() });
}
}
}
}
let mut wal_entries: Vec<WalEntry<T>> = Vec::with_capacity(ops.len());
let mut generated_ids: Vec<NodeId> = Vec::new();
for op in ops {
match op {
TxOp::Insert { vector, payload } => {
let id = mt.insert(&vector, payload.clone())?;
wal_entries.push(WalEntry::Insert { id, vector, payload });
generated_ids.push(id);
}
TxOp::InsertWithId { id, vector, payload } => {
mt.insert_with_id(id, &vector, payload.clone())?;
wal_entries.push(WalEntry::Insert { id, vector, payload });
generated_ids.push(id);
}
TxOp::Link { src, dst, label, weight } => {
mt.link(src, dst, label.clone(), weight)?;
wal_entries.push(WalEntry::Link { src, dst, label, weight });
}
TxOp::Delete { id } => {
mt.delete(id)?;
wal_entries.push(WalEntry::Delete { id });
}
TxOp::Unlink { src, dst } => {
mt.unlink(src, dst)?;
wal_entries.push(WalEntry::Unlink { src, dst });
}
TxOp::UpdatePayload { id, payload } => {
mt.update_payload(id, payload.clone())?;
wal_entries.push(WalEntry::UpdatePayload { id, payload });
}
TxOp::UpdateVector { id, vector } => {
mt.update_vector(id, &vector)?;
wal_entries.push(WalEntry::UpdateVector { id, vector });
}
}
}
drop(mt);
{
let mut w = lock_or_recover(&self.db.wal);
let tx_id = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos() as u64;
w.append_batch(tx_id, &wal_entries)?;
}
self.committed = true;
Ok(generated_ids)
}
/// Discards all queued operations. The database is left untouched, and the
/// transaction is marked finished so dropping it does not warn.
pub fn rollback(mut self) {
    self.ops.truncate(0);
    self.committed = true;
}
}
impl<'a, T: VectorType + serde::Serialize + serde::de::DeserializeOwned> Drop for Transaction<'a, T> {
    /// Warns when a transaction that still has queued operations is dropped
    /// without `commit` or `rollback`. The operations are simply discarded.
    fn drop(&mut self) {
        let pending = self.ops.len();
        if self.committed || pending == 0 {
            return;
        }
        tracing::warn!(
            "Transaction with {} pending ops was dropped without commit/rollback. Operations discarded.",
            pending
        );
    }
}