use crate::VectorType;
use crate::database::Database;
use crate::error::Result;
use crate::node::NodeId;
use crate::storage::memtable::MemTable;
use crate::storage::wal::WalEntry;
use super::lock_or_recover;
/// Applies one write-ahead-log entry to `mt` while replaying the log.
///
/// Replay is best-effort: an entry that refers to a node which is not
/// present is skipped, and the result of every memtable call is discarded.
/// Payload strings that fail to parse as JSON fall back to the default
/// `serde_json::Value`.
///
/// * `Insert` — inserts the node unless the ID is already present, and in
///   both cases advances the memtable's next-ID counter to `id + 1`.
/// * `Link` — applied only when both endpoints are present.
/// * `Delete`, `UpdatePayload`, `UpdateVector` — applied only when the node
///   is present.
/// * `Unlink` — applied only when the source node is present.
/// * `TxBegin` / `TxCommit` — markers; nothing to apply here.
pub(crate) fn replay_entry<T: VectorType>(mt: &mut MemTable<T>, entry: WalEntry<T>) {
    // Turns a stored payload string back into JSON, tolerating corrupt input.
    fn decode_payload(raw: &str) -> serde_json::Value {
        serde_json::from_str(raw).unwrap_or_default()
    }

    match entry {
        WalEntry::Insert {
            id,
            vector,
            payload,
        } => {
            let already_present = mt.contains(id);
            if already_present {
                tracing::debug!("WAL 回放跳过已存在的节点 {}", id);
            } else {
                let _ = mt.raw_insert(id, &vector, decode_payload(&payload));
            }
            // Runs for skipped inserts too, so the counter ends up past every logged ID.
            mt.advance_next_id(id + 1);
        }
        WalEntry::Link {
            src,
            dst,
            label,
            weight,
        } => {
            let endpoints_present = mt.contains(src) && mt.contains(dst);
            if !endpoints_present {
                return;
            }
            let _ = mt.link(src, dst, label, weight);
        }
        WalEntry::Delete { id } => {
            if !mt.contains(id) {
                return;
            }
            let _ = mt.delete(id);
        }
        WalEntry::Unlink { src, dst } => {
            if !mt.contains(src) {
                return;
            }
            let _ = mt.unlink(src, dst);
        }
        WalEntry::UpdatePayload { id, payload } => {
            if !mt.contains(id) {
                return;
            }
            let _ = mt.update_payload(id, decode_payload(&payload));
        }
        WalEntry::UpdateVector { id, vector } => {
            if !mt.contains(id) {
                return;
            }
            let _ = mt.update_vector(id, &vector);
        }
        // Transaction boundary markers carry no memtable change.
        WalEntry::TxBegin { .. } | WalEntry::TxCommit { .. } => {}
    }
}
/// A single operation queued in a transaction, waiting to be committed.
///
/// Operations are collected by `Transaction` / `TxBuilder` and validated
/// and applied as one batch by `Database::commit_ops`.
pub enum TxOp<T> {
/// Insert a new node; its ID is assigned during commit from the
/// memtable's next-ID counter.
Insert {
vector: Vec<T>,
payload: serde_json::Value,
},
/// Insert a new node under a caller-chosen ID. Commit rejects the batch
/// if that ID already exists at this point in the batch.
InsertWithId {
id: NodeId,
vector: Vec<T>,
payload: serde_json::Value,
},
/// Create a labelled, weighted link from `src` to `dst`. Commit requires
/// both nodes to exist.
Link {
src: NodeId,
dst: NodeId,
label: String,
weight: f32,
},
/// Delete the node `id`. Commit requires the node to exist.
Delete {
id: NodeId,
},
/// Remove the link from `src` to `dst`. Commit only checks that `src`
/// exists.
Unlink {
src: NodeId,
dst: NodeId,
},
/// Replace the JSON payload of node `id`. Commit requires the node to
/// exist.
UpdatePayload {
id: NodeId,
payload: serde_json::Value,
},
/// Replace the vector of node `id`. Commit requires the node to exist
/// and validates the new vector like an insert.
UpdateVector {
id: NodeId,
vector: Vec<T>,
},
}
/// A write transaction that buffers operations against a mutably borrowed
/// `Database` until `commit` or `rollback` is called.
///
/// Nothing touches the database while operations are being queued. Dropping
/// a transaction that still has queued operations and was neither committed
/// nor rolled back discards them and logs a warning.
pub struct Transaction<'a, T: VectorType + serde::Serialize + serde::de::DeserializeOwned> {
// Database the queued operations will be committed to.
pub(crate) db: &'a mut Database<T>,
// Operations queued so far, in call order.
pub(crate) ops: Vec<TxOp<T>>,
// Set by both `commit` and `rollback`; suppresses the warning in `Drop`.
pub(crate) committed: bool,
}
impl<'a, T: VectorType + serde::Serialize + serde::de::DeserializeOwned> Transaction<'a, T> {
    /// Queues an insert whose node ID will be assigned at commit time.
    pub fn insert(&mut self, vector: &[T], payload: serde_json::Value) {
        let vector = vector.to_owned();
        self.ops.push(TxOp::Insert { vector, payload });
    }

    /// Queues an insert under the caller-supplied `id`.
    pub fn insert_with_id(&mut self, id: NodeId, vector: &[T], payload: serde_json::Value) {
        let vector = vector.to_owned();
        self.ops.push(TxOp::InsertWithId {
            id,
            vector,
            payload,
        });
    }

    /// Queues creation of a labelled, weighted link from `src` to `dst`.
    pub fn link(&mut self, src: NodeId, dst: NodeId, label: &str, weight: f32) {
        let label = String::from(label);
        self.ops.push(TxOp::Link {
            src,
            dst,
            label,
            weight,
        });
    }

    /// Queues deletion of node `id`.
    pub fn delete(&mut self, id: NodeId) {
        let op = TxOp::Delete { id };
        self.ops.push(op);
    }

    /// Queues removal of the link from `src` to `dst`.
    pub fn unlink(&mut self, src: NodeId, dst: NodeId) {
        let op = TxOp::Unlink { src, dst };
        self.ops.push(op);
    }

    /// Queues replacement of the payload stored on node `id`.
    pub fn update_payload(&mut self, id: NodeId, payload: serde_json::Value) {
        let op = TxOp::UpdatePayload { id, payload };
        self.ops.push(op);
    }

    /// Queues replacement of the vector stored on node `id`.
    pub fn update_vector(&mut self, id: NodeId, vector: &[T]) {
        let vector = vector.to_owned();
        self.ops.push(TxOp::UpdateVector { id, vector });
    }

    /// Number of operations queued and not yet committed.
    pub fn pending_count(&self) -> usize {
        self.ops.len()
    }

    /// Hands every queued operation to the database as one batch.
    ///
    /// Consumes the transaction. Returns whatever `Database::commit_ops`
    /// returns for the batch: the IDs of the inserted nodes on success, or
    /// its error.
    pub fn commit(mut self) -> Result<Vec<NodeId>> {
        self.committed = true;
        // `Transaction` implements `Drop`, so the queue is swapped out
        // rather than moved out of `self`.
        let queued = std::mem::take(&mut self.ops);
        self.db.commit_ops(queued)
    }

    /// Discards every queued operation without touching the database.
    pub fn rollback(mut self) {
        self.committed = true;
        self.ops.clear();
    }
}
impl<'a, T: VectorType + serde::Serialize + serde::de::DeserializeOwned> Drop
    for Transaction<'a, T>
{
    /// Warns when a transaction is dropped with work still queued.
    ///
    /// Nothing is logged after `commit`/`rollback`, or when no operation was
    /// ever queued.
    fn drop(&mut self) {
        let pending = self.ops.len();
        if self.committed || pending == 0 {
            return;
        }
        tracing::warn!(
            "Transaction with {} pending ops was dropped without commit/rollback. Operations discarded.",
            pending
        );
    }
}
/// Collects operations without borrowing a `Database`.
///
/// The collected batch is committed later by passing the builder to
/// `Database::commit_tx`.
pub struct TxBuilder<T> {
// Operations queued so far, in call order.
ops: Vec<TxOp<T>>,
}
impl<T: VectorType> TxBuilder<T> {
pub fn new() -> Self {
Self { ops: Vec::new() }
}
pub fn insert(&mut self, vector: &[T], payload: serde_json::Value) {
self.ops.push(TxOp::Insert { vector: vector.to_vec(), payload });
}
pub fn insert_with_id(&mut self, id: NodeId, vector: &[T], payload: serde_json::Value) {
self.ops.push(TxOp::InsertWithId { id, vector: vector.to_vec(), payload });
}
pub fn link(&mut self, src: NodeId, dst: NodeId, label: &str, weight: f32) {
self.ops.push(TxOp::Link { src, dst, label: label.to_string(), weight });
}
pub fn delete(&mut self, id: NodeId) {
self.ops.push(TxOp::Delete { id });
}
pub fn unlink(&mut self, src: NodeId, dst: NodeId) {
self.ops.push(TxOp::Unlink { src, dst });
}
pub fn update_payload(&mut self, id: NodeId, payload: serde_json::Value) {
self.ops.push(TxOp::UpdatePayload { id, payload });
}
pub fn update_vector(&mut self, id: NodeId, vector: &[T]) {
self.ops.push(TxOp::UpdateVector { id, vector: vector.to_vec() });
}
pub fn pending_count(&self) -> usize {
self.ops.len()
}
pub(crate) fn into_ops(self) -> Vec<TxOp<T>> {
self.ops
}
}
impl<T: VectorType> Default for TxBuilder<T> {
    /// An empty builder, the same value `TxBuilder::new` produces.
    fn default() -> Self {
        Self { ops: Vec::new() }
    }
}
impl<T: VectorType + serde::Serialize + serde::de::DeserializeOwned> Database<T> {
    /// Validates a batch of operations, appends it to the WAL as one batch,
    /// and then applies it to the memtable.
    ///
    /// The work happens in four phases while the memtable lock is held:
    ///
    /// 1. **Validate** every operation against a simulated view of the
    ///    memtable (existing nodes plus the inserts and deletes seen so far
    ///    in this batch). Auto-assigned IDs are reserved here.
    /// 2. **Encode** the operations as WAL entries, serializing payloads and
    ///    enforcing the payload size limit.
    /// 3. **Log** the entries with a single `append_batch` call.
    /// 4. **Apply** the entries to the memtable with quiver sync paused, then
    ///    resume it and sync the whole batch once.
    ///
    /// An empty batch returns `Ok(vec![])` without taking any lock.
    ///
    /// # Returns
    ///
    /// The IDs of all inserted nodes (auto-assigned and caller-supplied), in
    /// the order the insert operations appear in `ops`.
    ///
    /// # Errors
    ///
    /// Any error below is returned before anything is appended to the WAL or
    /// applied to the memtable:
    ///
    /// * `DimensionMismatch` — a vector's length differs from `mt.dim()`.
    /// * `InvalidVector` — a vector contains NaN or an infinity.
    /// * `NodeAlreadyExists` — `InsertWithId` targets a live ID.
    /// * `NodeNotFound` — an operation references a node that is not live
    ///   at that point in the batch.
    /// * `PayloadTooLarge` — a serialized payload exceeds 8 MiB.
    ///
    /// An error from `append_batch` is propagated as well; the memtable is
    /// not touched in that case. Results of the individual memtable calls in
    /// the apply phase are discarded.
    ///
    /// # Panics
    ///
    /// Panics if an `Insert` has no pre-assigned ID, which would be an
    /// internal bug in the validation phase.
    pub(crate) fn commit_ops(&mut self, ops: Vec<TxOp<T>>) -> Result<Vec<NodeId>> {
        /// Largest accepted serialized payload, in bytes.
        const MAX_PAYLOAD_BYTES: usize = 8 * 1024 * 1024;

        // Checks the vector length first, then rejects NaN and infinities.
        fn validate_vector<V: VectorType>(vector: &[V], dim: usize) -> Result<()> {
            if vector.len() != dim {
                return Err(crate::error::TriviumError::DimensionMismatch {
                    expected: dim,
                    got: vector.len(),
                });
            }
            if vector.iter().any(|item| !item.to_f32().is_finite()) {
                return Err(crate::error::TriviumError::InvalidVector {
                    reason: "Vector contains NaN or Infinity".into(),
                });
            }
            Ok(())
        }

        // Serializes a payload for the WAL and enforces the size limit.
        fn encode_payload(payload: &serde_json::Value) -> Result<String> {
            let encoded = payload.to_string();
            if encoded.len() > MAX_PAYLOAD_BYTES {
                return Err(crate::error::TriviumError::PayloadTooLarge {
                    size_bytes: encoded.len(),
                    max_bytes: MAX_PAYLOAD_BYTES,
                });
            }
            Ok(encoded)
        }

        if ops.is_empty() {
            return Ok(Vec::new());
        }

        let mut mt = lock_or_recover(&self.memtable);

        // ---- Phase 1: validate against a simulated view of the memtable ----
        let mut sim_next_id = mt.next_id_value();
        let dim = mt.dim();
        // IDs inserted earlier in this batch.
        let mut pending_ids = std::collections::HashSet::new();
        // IDs deleted earlier in this batch and not re-inserted since.
        let mut pending_deletes = std::collections::HashSet::new();
        // Parallel to `ops`: the ID each insert will use, `None` for non-inserts.
        let mut pre_assigned_ids: Vec<Option<NodeId>> = Vec::with_capacity(ops.len());

        // True when `$id` is live at the current point of the batch. A macro
        // rather than a closure so the sets can still be mutated between uses.
        macro_rules! check_exists {
            ($id:expr) => {
                !pending_deletes.contains($id) && (pending_ids.contains($id) || mt.contains(*$id))
            };
        }

        for op in &ops {
            match op {
                TxOp::Insert { vector, .. } => {
                    validate_vector(vector, dim)?;
                    pre_assigned_ids.push(Some(sim_next_id));
                    pending_ids.insert(sim_next_id);
                    sim_next_id += 1;
                }
                TxOp::InsertWithId { id, vector, .. } => {
                    if check_exists!(id) {
                        return Err(crate::error::TriviumError::NodeAlreadyExists(*id));
                    }
                    validate_vector(vector, dim)?;
                    pre_assigned_ids.push(Some(*id));
                    pending_ids.insert(*id);
                    // Re-inserting an ID deleted earlier in this batch makes
                    // it live again. Without this, later operations on the
                    // ID were rejected as `NodeNotFound` and a duplicate
                    // `InsertWithId` slipped through.
                    pending_deletes.remove(id);
                    if *id >= sim_next_id {
                        sim_next_id = *id + 1;
                    }
                }
                TxOp::Link { src, dst, .. } => {
                    if !check_exists!(src) {
                        return Err(crate::error::TriviumError::NodeNotFound(*src));
                    }
                    if !check_exists!(dst) {
                        return Err(crate::error::TriviumError::NodeNotFound(*dst));
                    }
                    pre_assigned_ids.push(None);
                }
                TxOp::Delete { id } => {
                    if !check_exists!(id) {
                        return Err(crate::error::TriviumError::NodeNotFound(*id));
                    }
                    pending_deletes.insert(*id);
                    pre_assigned_ids.push(None);
                }
                // Only the source node is checked for an unlink.
                TxOp::Unlink { src, .. } => {
                    if !check_exists!(src) {
                        return Err(crate::error::TriviumError::NodeNotFound(*src));
                    }
                    pre_assigned_ids.push(None);
                }
                TxOp::UpdatePayload { id, .. } => {
                    if !check_exists!(id) {
                        return Err(crate::error::TriviumError::NodeNotFound(*id));
                    }
                    pre_assigned_ids.push(None);
                }
                TxOp::UpdateVector { id, vector } => {
                    if !check_exists!(id) {
                        return Err(crate::error::TriviumError::NodeNotFound(*id));
                    }
                    validate_vector(vector, dim)?;
                    pre_assigned_ids.push(None);
                }
            }
        }

        // ---- Phase 2: encode as WAL entries ----
        // `ops` is owned, so vectors and labels are moved, not cloned.
        let mut wal_entries: Vec<WalEntry<T>> = Vec::with_capacity(ops.len());
        let mut generated_ids: Vec<NodeId> = Vec::new();
        for (op, assigned) in ops.into_iter().zip(pre_assigned_ids) {
            let entry = match op {
                TxOp::Insert { vector, payload } => {
                    let id = assigned.expect("BUG: Insert op must have pre-assigned ID");
                    let payload = encode_payload(&payload)?;
                    generated_ids.push(id);
                    WalEntry::Insert {
                        id,
                        vector,
                        payload,
                    }
                }
                TxOp::InsertWithId {
                    id,
                    vector,
                    payload,
                } => {
                    let payload = encode_payload(&payload)?;
                    generated_ids.push(id);
                    WalEntry::Insert {
                        id,
                        vector,
                        payload,
                    }
                }
                TxOp::Link {
                    src,
                    dst,
                    label,
                    weight,
                } => WalEntry::Link {
                    src,
                    dst,
                    label,
                    weight,
                },
                TxOp::Delete { id } => WalEntry::Delete { id },
                TxOp::Unlink { src, dst } => WalEntry::Unlink { src, dst },
                TxOp::UpdatePayload { id, payload } => WalEntry::UpdatePayload {
                    id,
                    payload: encode_payload(&payload)?,
                },
                TxOp::UpdateVector { id, vector } => WalEntry::UpdateVector { id, vector },
            };
            wal_entries.push(entry);
        }

        // ---- Phase 3: append to the WAL before touching the memtable ----
        {
            let mut w = lock_or_recover(&self.wal);
            // Transaction ID: nanoseconds since the Unix epoch, truncated to
            // u64; 0 if the system clock is before the epoch.
            let tx_id = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos() as u64;
            w.append_batch(tx_id, &wal_entries)?;
        }

        // ---- Phase 4: apply to the memtable ----
        // Quiver sync is paused during the per-entry calls and the whole
        // batch is synced once afterwards.
        mt.set_quiver_sync_paused(true);
        for entry in &wal_entries {
            match entry {
                WalEntry::Insert {
                    id,
                    vector,
                    payload,
                } => {
                    // Parsed back from the logged string, the same way
                    // `replay_entry` decodes it.
                    let payload_val: serde_json::Value =
                        serde_json::from_str(payload).unwrap_or_default();
                    let _ = mt.insert_with_id(*id, vector, payload_val);
                }
                WalEntry::Link {
                    src,
                    dst,
                    label,
                    weight,
                } => {
                    let _ = mt.link(*src, *dst, label.clone(), *weight);
                }
                WalEntry::Delete { id } => {
                    let _ = mt.delete(*id);
                }
                WalEntry::Unlink { src, dst } => {
                    let _ = mt.unlink(*src, *dst);
                }
                WalEntry::UpdatePayload { id, payload } => {
                    let payload_val: serde_json::Value =
                        serde_json::from_str(payload).unwrap_or_default();
                    let _ = mt.update_payload(*id, payload_val);
                }
                WalEntry::UpdateVector { id, vector } => {
                    let _ = mt.update_vector(*id, vector);
                }
                // Never produced by phase 2; listed so the match stays exhaustive.
                WalEntry::TxBegin { .. } | WalEntry::TxCommit { .. } => {}
            }
        }
        mt.set_quiver_sync_paused(false);
        mt.quiver_sync_tx_entries(&wal_entries);
        drop(mt);

        Ok(generated_ids)
    }

    /// Commits every operation collected in `builder` as one batch.
    ///
    /// Returns the IDs of the inserted nodes, or the error `commit_ops`
    /// returned for the batch.
    pub fn commit_tx(&mut self, builder: TxBuilder<T>) -> Result<Vec<NodeId>> {
        self.commit_ops(builder.into_ops())
    }
}