use std::sync::Arc;
use std::time::Duration;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::{
error::{ClawDBError, ClawDBResult},
events::{bus::EventBus, types::ClawEvent},
lifecycle::manager::ComponentLifecycleManager,
session::context::SessionContext,
transaction::{
coordinator::TransactionCoordinator,
log::{TransactionLog, TransactionLogEntry},
},
};
/// A vector-store upsert buffered inside a full transaction.
///
/// Ops are queued by `TransactionManager::buffer_vector_upsert` and applied, in
/// insertion order, during phase 2 of `TransactionManager::commit` through the
/// vector component's `upsert(collection, id, text, metadata)`.
#[derive(Debug, Clone)]
pub struct VectorUpsertOp {
/// Target collection name.
pub collection: String,
/// Document id; also recorded as the transaction's write-set key.
pub id: String,
/// Text passed to the vector component (presumably embedded there — TODO confirm).
pub text: String,
/// JSON metadata passed through to `upsert`.
pub metadata: serde_json::Value,
/// Optional expected embedding dimension.
///
/// NOTE(review): not forwarded to `upsert` and not checked against any configured
/// dimension.
pub dimensions: Option<usize>,
}
/// Lifecycle state of a full transaction managed by `TransactionManager`.
///
/// `Active` → `Committing` → `Committed` on success. An explicit rollback goes
/// `RollingBack` → `RolledBack`. A failed phase-2 commit goes
/// `Committing` → `RollingBack` → `RolledBack`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TxStatus {
    /// Open; accepts buffered writes and may be committed or rolled back.
    Active,
    /// Phase-2 commit in progress.
    Committing,
    /// Rollback in progress (not terminal: a rollback may be retried).
    RollingBack,
    /// Successfully committed (terminal).
    Committed,
    /// Rolled back (terminal).
    RolledBack,
}
/// Per-transaction state of a full transaction, stored behind a
/// `tokio::sync::Mutex` in `TransactionManager::active`.
#[derive(Debug)]
pub struct TxState {
/// Transaction id (random v4 UUID).
pub id: Uuid,
/// Session that began the transaction.
pub session_id: Uuid,
/// Agent that began the transaction.
pub agent_id: Uuid,
/// Wall-clock start time (UTC).
pub started_at: DateTime<Utc>,
/// Current lifecycle state.
pub status: TxStatus,
/// `true` while a `BEGIN IMMEDIATE` issued on the core component is still open.
pub core_tx_begun: bool,
/// Vector upserts deferred until phase 2 of commit; cleared on rollback.
pub vector_buffer: Vec<VectorUpsertOp>,
/// Branch snapshot (named `tx-<id>`) taken at begin. It is restored on rollback and
/// deleted after a successful commit; `None` if no snapshot was taken.
pub branch_snapshot_id: Option<Uuid>,
}
impl TxState {
    /// Builds the initial state for a freshly begun transaction.
    ///
    /// The transaction gets a random id and the current UTC time, starts `Active`
    /// with no core transaction open, an empty vector buffer, and no snapshot.
    fn new(session_id: Uuid, agent_id: Uuid) -> Self {
        let id = Uuid::new_v4();
        let started_at = Utc::now();
        Self {
            id,
            session_id,
            agent_id,
            started_at,
            status: TxStatus::Active,
            core_tx_begun: false,
            vector_buffer: Vec::default(),
            branch_snapshot_id: None,
        }
    }
}
/// Coordinates transactions that span the core, vector and branch components.
///
/// Two APIs are exposed:
/// - Full transactions (`begin`/`commit`/`rollback`), whose state lives in `active`.
/// - Lightweight "simple" transactions (`begin_simple`/`commit_simple`), which only
///   touch the coordinator and the log.
pub struct TransactionManager {
/// Source of the core / vector / branch component handles (each may be unavailable).
lifecycle: Arc<ComponentLifecycleManager>,
/// Registration, write sets, conflict checks and staleness tracking.
coordinator: Arc<TransactionCoordinator>,
/// Log of committed transactions (built with `8_192`, presumably its capacity).
log: Arc<TransactionLog>,
/// Used to emit `ShutdownInitiated` when a phase-2 commit fails.
event_bus: Arc<EventBus>,
/// Open full transactions keyed by id; the per-entry mutex serializes operations
/// on the same transaction.
active: DashMap<Uuid, Arc<Mutex<TxState>>>,
}
impl TransactionManager {
pub fn new(
lifecycle: Arc<ComponentLifecycleManager>,
event_bus: Arc<EventBus>,
) -> Self {
Self {
lifecycle,
coordinator: Arc::new(TransactionCoordinator::new()),
log: Arc::new(TransactionLog::new(8_192)),
event_bus,
active: DashMap::new(),
}
}
#[tracing::instrument(skip(self, session), fields(session_id = %session.session_id))]
pub async fn begin(&self, session: &SessionContext) -> ClawDBResult<Uuid> {
let mut state = TxState::new(session.session_id, session.agent_id);
let tx_id = state.id;
if let Ok(core) = self.lifecycle.core() {
core.execute_raw_write("BEGIN IMMEDIATE", &[]).await.map_err(|e| {
ClawDBError::TransactionFailed {
tx_id,
reason: format!("core BEGIN failed: {e}"),
}
})?;
state.core_tx_begun = true;
}
if let Ok(branch) = self.lifecycle.branch() {
let snapshot_name = format!("tx-{tx_id}");
match branch.create_snapshot(&snapshot_name).await {
Ok(snap_id) => {
state.branch_snapshot_id = Some(snap_id);
}
Err(e) => {
tracing::warn!(tx_id = %tx_id, "branch snapshot failed (non-fatal): {e}");
}
}
}
self.coordinator.register(tx_id, session.session_id);
self.active
.insert(tx_id, Arc::new(Mutex::new(state)));
tracing::info!(tx_id = %tx_id, agent_id = %session.agent_id, "transaction begun");
Ok(tx_id)
}
#[tracing::instrument(skip(self), fields(tx_id = %tx_id))]
pub async fn commit(&self, tx_id: Uuid) -> ClawDBResult<()> {
let handle = self.get_handle(tx_id)?;
let mut state = handle.lock().await;
if state.status != TxStatus::Active {
return Err(ClawDBError::TransactionFailed {
tx_id,
reason: format!("cannot commit; status is {:?}", state.status),
});
}
self.coordinator.prepare(tx_id, &[])?;
let vector_cfg_dims = self
.lifecycle
.core()
.map(|_| 0usize) .unwrap_or(0);
for op in &state.vector_buffer {
if let Some(dims) = op.dimensions {
if vector_cfg_dims > 0 && dims != vector_cfg_dims {
return Err(ClawDBError::TransactionFailed {
tx_id,
reason: format!(
"vector op '{}' has {dims} dims; expected {vector_cfg_dims}",
op.id
),
});
}
}
}
if state.branch_snapshot_id.is_none() {
tracing::debug!(tx_id = %tx_id, "no branch snapshot; skipping snapshot verify");
}
state.status = TxStatus::Committing;
let result = self.apply_phase2(&mut state).await;
if let Err(ref e) = result {
tracing::error!(tx_id = %tx_id, err = %e, "phase-2 failed; attempting rollback");
state.status = TxStatus::RollingBack;
if let Err(rb_err) = self.do_rollback(&mut state).await {
tracing::error!(tx_id = %tx_id, err = %rb_err, "rollback also failed");
}
state.status = TxStatus::RolledBack;
self.coordinator.deregister(tx_id);
self.active.remove(&tx_id);
self.event_bus.emit(ClawEvent::ShutdownInitiated {
reason: format!("transaction {tx_id} failed during phase-2 commit: {e}"),
});
return Err(ClawDBError::TransactionFailed {
tx_id,
reason: e.to_string(),
});
}
state.status = TxStatus::Committed;
self.coordinator.deregister(tx_id);
self.active.remove(&tx_id);
self.log.append(TransactionLogEntry {
tx_id,
session_id: state.session_id,
committed_at: Utc::now().timestamp(),
write_set: state
.vector_buffer
.iter()
.map(|op| op.id.clone())
.collect(),
});
tracing::info!(tx_id = %tx_id, "transaction committed");
Ok(())
}
async fn apply_phase2(&self, state: &mut TxState) -> ClawDBResult<()> {
let tx_id = state.id;
if !state.vector_buffer.is_empty() {
if let Ok(vector) = self.lifecycle.vector() {
let buffer = std::mem::take(&mut state.vector_buffer);
for op in buffer {
vector
.upsert(&op.collection, &op.id, &op.text, &op.metadata)
.await
.map_err(|e| ClawDBError::TransactionFailed {
tx_id,
reason: format!("vector upsert '{}' failed: {e}", op.id),
})?;
}
}
}
if state.core_tx_begun {
if let Ok(core) = self.lifecycle.core() {
core.execute_raw_write("COMMIT", &[])
.await
.map_err(|e| ClawDBError::TransactionFailed {
tx_id,
reason: format!("core COMMIT failed: {e}"),
})?;
state.core_tx_begun = false;
}
}
if let Some(snap_id) = state.branch_snapshot_id.take() {
if let Ok(branch) = self.lifecycle.branch() {
if let Err(e) = branch.delete_snapshot(snap_id).await {
tracing::warn!(tx_id = %tx_id, snap_id = %snap_id, "snapshot delete failed (non-fatal): {e}");
}
}
}
Ok(())
}
#[tracing::instrument(skip(self), fields(tx_id = %tx_id))]
pub async fn rollback(&self, tx_id: Uuid) -> ClawDBResult<()> {
let handle = self.get_handle(tx_id)?;
let mut state = handle.lock().await;
if matches!(state.status, TxStatus::Committed | TxStatus::RolledBack) {
return Err(ClawDBError::TransactionFailed {
tx_id,
reason: format!("cannot rollback; status is {:?}", state.status),
});
}
state.status = TxStatus::RollingBack;
self.do_rollback(&mut state).await?;
state.status = TxStatus::RolledBack;
self.coordinator.deregister(tx_id);
self.active.remove(&tx_id);
tracing::info!(tx_id = %tx_id, "transaction rolled back");
Ok(())
}
async fn do_rollback(&self, state: &mut TxState) -> ClawDBResult<()> {
let tx_id = state.id;
if state.core_tx_begun {
if let Ok(core) = self.lifecycle.core() {
if let Err(e) = core.execute_raw_write("ROLLBACK", &[]).await {
tracing::error!(tx_id = %tx_id, "core ROLLBACK failed: {e}");
}
}
state.core_tx_begun = false;
}
state.vector_buffer.clear();
if let Some(snap_id) = state.branch_snapshot_id.take() {
if let Ok(branch) = self.lifecycle.branch() {
if let Err(e) = branch.restore_snapshot(snap_id).await {
tracing::error!(tx_id = %tx_id, snap_id = %snap_id, "snapshot restore failed: {e}");
return Err(ClawDBError::TransactionFailed {
tx_id,
reason: format!("branch snapshot restore failed: {e}"),
});
}
}
}
Ok(())
}
pub async fn timeout_stale(&self, older_than_secs: u64) -> ClawDBResult<u32> {
let threshold = Duration::from_secs(older_than_secs);
let stale = self.coordinator.stale_transactions(threshold);
let mut count = 0u32;
for (tx_id, _session_id) in stale {
tracing::warn!(tx_id = %tx_id, older_than_secs, "timing out stale transaction");
if let Err(e) = self.rollback(tx_id).await {
tracing::error!(tx_id = %tx_id, err = %e, "stale transaction rollback failed");
} else {
count += 1;
}
}
Ok(count)
}
pub async fn buffer_vector_upsert(
&self,
tx_id: Uuid,
op: VectorUpsertOp,
) -> ClawDBResult<()> {
let handle = self.get_handle(tx_id)?;
let mut state = handle.lock().await;
if state.status != TxStatus::Active {
return Err(ClawDBError::TransactionFailed {
tx_id,
reason: "transaction is not active".to_string(),
});
}
let key = op.id.clone();
state.vector_buffer.push(op);
self.coordinator.extend_write_set(tx_id, [key]);
Ok(())
}
fn get_handle(&self, tx_id: Uuid) -> ClawDBResult<Arc<Mutex<TxState>>> {
self.active
.get(&tx_id)
.map(|r| Arc::clone(&r))
.ok_or(ClawDBError::TransactionFailed {
tx_id,
reason: "transaction not found".to_string(),
})
}
pub fn active_count(&self) -> usize {
self.active.len()
}
pub fn begin_simple(
&self,
session_id: Uuid,
isolation: crate::transaction::context::IsolationLevel,
) -> crate::transaction::context::TransactionContext {
let ctx = crate::transaction::context::TransactionContext::new(session_id, isolation);
self.coordinator.register(ctx.tx_id, session_id);
ctx
}
pub fn commit_simple(
&self,
ctx: &mut crate::transaction::context::TransactionContext,
) -> ClawDBResult<()> {
use crate::transaction::context::TransactionStatus;
if !ctx.is_active() {
return Err(ClawDBError::TransactionFailed {
tx_id: ctx.tx_id,
reason: "transaction is not active".to_string(),
});
}
self.coordinator.check_conflicts(ctx.tx_id, &ctx.write_set)?;
ctx.status = TransactionStatus::Committed;
self.coordinator.deregister(ctx.tx_id);
self.log.append(TransactionLogEntry {
tx_id: ctx.tx_id,
session_id: ctx.session_id,
committed_at: Utc::now().timestamp(),
write_set: ctx.write_set.clone(),
});
Ok(())
}
}
Ok(())
}
pub fn rollback(&self, ctx: &mut TransactionContext) {
ctx.status = TransactionStatus::RolledBack;
self.coordinator.deregister(ctx.tx_id);
}
pub fn active_count(&self) -> usize {
self.coordinator.active_count()
}
pub fn log(&self) -> &TransactionLog {
&self.log
}
}
// `Default` is intentionally not implemented for `TransactionManager`:
// `TransactionManager::new` requires an `Arc<ComponentLifecycleManager>` and an
// `Arc<EventBus>`, which callers must supply explicitly.