pub mod mvcc;

use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};

use arrow::array::RecordBatch;
use arrow::datatypes::Schema;

pub use mvcc::{
    RowVersion, TXN_ID_AUTO_COMMIT, TXN_ID_NONE, TransactionSnapshot, TxnId, TxnIdManager,
};

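/// Identifier for a client session; each session holds at most one active transaction.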
pub type SessionId = u64;

use llkv_expr::expr::Expr as LlkvExpr;
use llkv_plan::plans::{
    ColumnSpec, CreateIndexPlan, CreateTablePlan, DeletePlan, InsertPlan, PlanOperation, PlanValue,
    SelectPlan, UpdatePlan,
};
use llkv_result::{Error, Result as LlkvResult};
use llkv_storage::pager::Pager;
use simd_r_drive_entry_handle::EntryHandle;

use llkv_executor::SelectExecution;

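/// A materialized set of rows exported from a table: column names paired with
/// row values expressed as `PlanValue`s.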
pub struct RowBatch {
    pub columns: Vec<String>,
    pub rows: Vec<Vec<PlanValue>>,
}

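/// Which transaction-control statement produced a `TransactionResult::Transaction`.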
#[derive(Clone, Debug)]
pub enum TransactionKind {
    Begin,
    Commit,
    Rollback,
}

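/// Returns the qualified name of the single table referenced by `plan`, or `None`
/// when the plan touches zero or multiple tables.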
fn select_plan_table_name(plan: &SelectPlan) -> Option<String> {
    if plan.tables.len() == 1 {
        Some(plan.tables[0].qualified_name())
    } else {
        None
    }
}

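/// Outcome of executing a plan operation inside (or outside) a transaction,
/// parameterized over the pager type backing any returned `SelectExecution`.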
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
pub enum TransactionResult<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync,
{
    CreateTable {
        table_name: String,
    },
    Insert {
        rows_inserted: usize,
    },
    Update {
        rows_matched: usize,
        rows_updated: usize,
    },
    Delete {
        rows_deleted: usize,
    },
    CreateIndex {
        table_name: String,
        index_name: Option<String>,
    },
    Select {
        table_name: String,
        schema: Arc<Schema>,
        execution: SelectExecution<P>,
    },
    Transaction {
        kind: TransactionKind,
    },
}

impl<P> TransactionResult<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
{
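    /// Re-wraps this result for a different pager type `P2`. All variants convert
    /// except `Select`, whose execution is tied to the original pager and therefore
    /// returns an error.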
    pub fn convert_pager_type<P2>(self) -> LlkvResult<TransactionResult<P2>>
    where
        P2: Pager<Blob = EntryHandle> + Send + Sync + 'static,
    {
        match self {
            TransactionResult::CreateTable { table_name } => {
                Ok(TransactionResult::CreateTable { table_name })
            }
            TransactionResult::Insert { rows_inserted } => {
                Ok(TransactionResult::Insert { rows_inserted })
            }
            TransactionResult::Update {
                rows_matched,
                rows_updated,
            } => Ok(TransactionResult::Update {
                rows_matched,
                rows_updated,
            }),
            TransactionResult::Delete { rows_deleted } => {
                Ok(TransactionResult::Delete { rows_deleted })
            }
            TransactionResult::CreateIndex {
                table_name,
                index_name,
            } => Ok(TransactionResult::CreateIndex {
                table_name,
                index_name,
            }),
            TransactionResult::Transaction { kind } => Ok(TransactionResult::Transaction { kind }),
            TransactionResult::Select { .. } => Err(Error::Internal(
                "cannot convert SELECT TransactionResult between pager types".into(),
            )),
        }
    }
}

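/// Storage context that transactional code executes against. Both the durable base
/// context and the per-transaction staging context implement this trait, which
/// exposes the catalog, DML/DDL entry points, and MVCC snapshot plumbing.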
pub trait TransactionContext: Send + Sync {
    type Pager: Pager<Blob = EntryHandle> + Send + Sync + 'static;

    fn set_snapshot(&self, snapshot: mvcc::TransactionSnapshot);

    fn snapshot(&self) -> mvcc::TransactionSnapshot;

    fn table_column_specs(&self, table_name: &str) -> LlkvResult<Vec<ColumnSpec>>;

    fn export_table_rows(&self, table_name: &str) -> LlkvResult<RowBatch>;

    fn get_batches_with_row_ids(
        &self,
        table_name: &str,
        filter: Option<LlkvExpr<'static, String>>,
    ) -> LlkvResult<Vec<RecordBatch>>;

    fn execute_select(&self, plan: SelectPlan) -> LlkvResult<SelectExecution<Self::Pager>>;

    fn create_table_plan(
        &self,
        plan: CreateTablePlan,
    ) -> LlkvResult<TransactionResult<Self::Pager>>;

    fn insert(&self, plan: InsertPlan) -> LlkvResult<TransactionResult<Self::Pager>>;

    fn update(&self, plan: UpdatePlan) -> LlkvResult<TransactionResult<Self::Pager>>;

    fn delete(&self, plan: DeletePlan) -> LlkvResult<TransactionResult<Self::Pager>>;

    fn create_index(&self, plan: CreateIndexPlan) -> LlkvResult<TransactionResult<Self::Pager>>;

    fn append_batches_with_row_ids(
        &self,
        table_name: &str,
        batches: Vec<RecordBatch>,
    ) -> LlkvResult<usize>;

    fn table_names(&self) -> Vec<String>;

    fn table_id(&self, table_name: &str) -> LlkvResult<llkv_table::types::TableId>;

    fn catalog_snapshot(&self) -> llkv_table::catalog::TableCatalogSnapshot;

    fn validate_commit_constraints(&self, _txn_id: TxnId) -> LlkvResult<()> {
        Ok(())
    }

    fn clear_transaction_state(&self, _txn_id: TxnId) {}
}

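/// State for a single in-flight transaction: the MVCC snapshot it runs under, the
/// staging context that absorbs DDL and writes to tables created within the
/// transaction, and bookkeeping used for conflict detection and commit replay.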
pub struct SessionTransaction<BaseCtx, StagingCtx>
where
    BaseCtx: TransactionContext + 'static,
    StagingCtx: TransactionContext + 'static,
{
    snapshot: mvcc::TransactionSnapshot,
    staging: Arc<StagingCtx>,
    operations: Vec<PlanOperation>,
    staged_tables: HashSet<String>,
    new_tables: HashSet<String>,
    missing_tables: HashSet<String>,
    catalog_snapshot: llkv_table::catalog::TableCatalogSnapshot,
    base_context: Arc<BaseCtx>,
    is_aborted: bool,
    txn_manager: Arc<TxnIdManager>,
    accessed_tables: HashSet<String>,
}

impl<BaseCtx, StagingCtx> SessionTransaction<BaseCtx, StagingCtx>
where
    BaseCtx: TransactionContext + 'static,
    StagingCtx: TransactionContext + 'static,
{
    pub fn new(
        base_context: Arc<BaseCtx>,
        staging: Arc<StagingCtx>,
        txn_manager: Arc<TxnIdManager>,
    ) -> Self {
        let catalog_snapshot = base_context.catalog_snapshot();

        let snapshot = txn_manager.begin_transaction();
        tracing::debug!(
            "[SESSION_TX] new() created transaction with txn_id={}, snapshot_id={}",
            snapshot.txn_id,
            snapshot.snapshot_id
        );
        TransactionContext::set_snapshot(&*base_context, snapshot);
        TransactionContext::set_snapshot(&*staging, snapshot);

        Self {
            staging,
            operations: Vec::new(),
            staged_tables: HashSet::new(),
            new_tables: HashSet::new(),
            missing_tables: HashSet::new(),
            catalog_snapshot,
            base_context,
            is_aborted: false,
            accessed_tables: HashSet::new(),
            snapshot,
            txn_manager,
        }
    }

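    /// Verifies that `table_name` is visible to this transaction, consulting the
    /// catalog snapshot plus the tables created (or already found missing) earlier
    /// in the transaction, and caches the answer in `staged_tables` / `missing_tables`.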
    fn ensure_table_exists(&mut self, table_name: &str) -> LlkvResult<()> {
        tracing::trace!(
            "[ENSURE] ensure_table_exists called for table='{}'",
            table_name
        );

        if self.staged_tables.contains(table_name) {
            tracing::trace!("[ENSURE] table already verified to exist");
            return Ok(());
        }

        if !self.catalog_snapshot.table_exists(table_name) && !self.new_tables.contains(table_name)
        {
            self.missing_tables.insert(table_name.to_string());
            return Err(Error::CatalogError(format!(
                "Catalog Error: Table '{table_name}' does not exist"
            )));
        }

        if self.missing_tables.contains(table_name) {
            return Err(Error::CatalogError(format!(
                "Catalog Error: Table '{table_name}' does not exist"
            )));
        }

        if self.new_tables.contains(table_name) {
            tracing::trace!("[ENSURE] Table was created in this transaction");
            match self.staging.table_column_specs(table_name) {
                Ok(_) => {
                    self.staged_tables.insert(table_name.to_string());
                    return Ok(());
                }
                Err(_) => {
                    return Err(Error::CatalogError(format!(
                        "Catalog Error: Table '{table_name}' was created but not found in staging"
                    )));
                }
            }
        }

        tracing::trace!(
            "[ENSURE] Table exists in base, no copying needed (MVCC will handle visibility)"
        );
        self.staged_tables.insert(table_name.to_string());
        Ok(())
    }

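    /// Runs a single-table SELECT inside the transaction. Tables created in this
    /// transaction are read from the staging context; existing tables are read from
    /// the base context, whose MVCC filtering enforces snapshot visibility, and the
    /// resulting batches are materialized into a single `SelectExecution`.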
    pub fn execute_select(
        &mut self,
        plan: SelectPlan,
    ) -> LlkvResult<SelectExecution<StagingCtx::Pager>> {
        let table_name = select_plan_table_name(&plan).ok_or_else(|| {
            Error::InvalidArgumentError(
                "Transaction execute_select requires single-table query".into(),
            )
        })?;

        self.ensure_table_exists(&table_name)?;

        if self.new_tables.contains(&table_name) {
            tracing::trace!(
                "[SELECT] Reading from staging for new table '{}'",
                table_name
            );
            return self.staging.execute_select(plan);
        }

        self.accessed_tables.insert(table_name.clone());

        tracing::trace!(
            "[SELECT] Reading from BASE with MVCC for existing table '{}'",
            table_name
        );
        self.base_context.execute_select(plan).and_then(|exec| {
            let schema = exec.schema();
            let batches = exec.collect().unwrap_or_default();
            let combined = if batches.is_empty() {
                RecordBatch::new_empty(Arc::clone(&schema))
            } else if batches.len() == 1 {
                batches.into_iter().next().unwrap()
            } else {
                let refs: Vec<&RecordBatch> = batches.iter().collect();
                arrow::compute::concat_batches(&schema, refs).map_err(|err| {
                    Error::Internal(format!("failed to concatenate batches: {err}"))
                })?
            };
            Ok(SelectExecution::from_batch(
                table_name,
                Arc::clone(&schema),
                combined,
            ))
        })
    }

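    /// Executes a DDL/DML plan operation inside the transaction. Writes to tables
    /// created in this transaction go to the staging context and are recorded for
    /// commit-time replay; writes to pre-existing tables go straight to the base
    /// context tagged with this transaction's id. Any failure marks the transaction
    /// as aborted.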
    pub fn execute_operation(
        &mut self,
        operation: PlanOperation,
    ) -> LlkvResult<TransactionResult<StagingCtx::Pager>> {
        tracing::trace!(
            "[TX] SessionTransaction::execute_operation called, operation={:?}",
            match &operation {
                PlanOperation::Insert(p) => format!("INSERT({})", p.table),
                PlanOperation::Update(p) => format!("UPDATE({})", p.table),
                PlanOperation::Delete(p) => format!("DELETE({})", p.table),
                PlanOperation::CreateTable(p) => format!("CREATE_TABLE({})", p.name),
                _ => "OTHER".to_string(),
            }
        );
        if self.is_aborted {
            return Err(Error::TransactionContextError(
                "TransactionContext Error: transaction is aborted".into(),
            ));
        }

        let result = match operation {
            PlanOperation::CreateTable(ref plan) => {
                match self.staging.create_table_plan(plan.clone()) {
                    Ok(result) => {
                        self.new_tables.insert(plan.name.clone());
                        self.missing_tables.remove(&plan.name);
                        self.staged_tables.insert(plan.name.clone());
                        self.operations
                            .push(PlanOperation::CreateTable(plan.clone()));
                        result.convert_pager_type()?
                    }
                    Err(e) => {
                        self.is_aborted = true;
                        return Err(e);
                    }
                }
            }
            PlanOperation::Insert(ref plan) => {
                tracing::trace!(
                    "[TX] SessionTransaction::execute_operation INSERT for table='{}'",
                    plan.table
                );
                if let Err(e) = self.ensure_table_exists(&plan.table) {
                    self.is_aborted = true;
                    return Err(e);
                }

                let is_new_table = self.new_tables.contains(&plan.table);
                if !is_new_table {
                    self.accessed_tables.insert(plan.table.clone());
                }
                let result = if is_new_table {
                    tracing::trace!("[TX] INSERT into staging for new table");
                    self.staging.insert(plan.clone())
                } else {
                    tracing::trace!(
                        "[TX] INSERT directly into BASE with txn_id={}",
                        self.snapshot.txn_id
                    );
                    self.base_context
                        .insert(plan.clone())
                        .and_then(|r| r.convert_pager_type())
                };

                match result {
                    Ok(result) => {
                        if is_new_table {
                            tracing::trace!(
                                "[TX] INSERT to new table - tracking for commit replay"
                            );
                            self.operations.push(PlanOperation::Insert(plan.clone()));
                        } else {
                            tracing::trace!(
                                "[TX] INSERT to existing table - already in BASE, no replay needed"
                            );
                        }
                        result
                    }
                    Err(e) => {
                        tracing::trace!(
                            "DEBUG SessionTransaction::execute_operation INSERT failed: {:?}",
                            e
                        );
                        tracing::trace!("DEBUG setting is_aborted=true");
                        self.is_aborted = true;
                        return Err(e);
                    }
                }
            }
            PlanOperation::Update(ref plan) => {
                if let Err(e) = self.ensure_table_exists(&plan.table) {
                    self.is_aborted = true;
                    return Err(e);
                }

                let is_new_table = self.new_tables.contains(&plan.table);
                if !is_new_table {
                    self.accessed_tables.insert(plan.table.clone());
                }
                let result = if is_new_table {
                    tracing::trace!("[TX] UPDATE in staging for new table");
                    self.staging.update(plan.clone())
                } else {
                    tracing::trace!(
                        "[TX] UPDATE directly in BASE with txn_id={}",
                        self.snapshot.txn_id
                    );
                    self.base_context
                        .update(plan.clone())
                        .and_then(|r| r.convert_pager_type())
                };

                match result {
                    Ok(result) => {
                        if is_new_table {
                            tracing::trace!(
                                "[TX] UPDATE to new table - tracking for commit replay"
                            );
                            self.operations.push(PlanOperation::Update(plan.clone()));
                        } else {
                            tracing::trace!(
                                "[TX] UPDATE to existing table - already in BASE, no replay needed"
                            );
                        }
                        result
                    }
                    Err(e) => {
                        self.is_aborted = true;
                        return Err(e);
                    }
                }
            }
            PlanOperation::Delete(ref plan) => {
                tracing::debug!("[DELETE] Starting delete for table '{}'", plan.table);
                if let Err(e) = self.ensure_table_exists(&plan.table) {
                    tracing::debug!("[DELETE] ensure_table_exists failed: {}", e);
                    self.is_aborted = true;
                    return Err(e);
                }

                let is_new_table = self.new_tables.contains(&plan.table);
                tracing::debug!("[DELETE] is_new_table={}", is_new_table);
                if !is_new_table {
                    tracing::debug!(
                        "[DELETE] Tracking access to existing table '{}'",
                        plan.table
                    );
                    self.accessed_tables.insert(plan.table.clone());
                }
                let result = if is_new_table {
                    tracing::debug!("[DELETE] Deleting from staging for new table");
                    self.staging.delete(plan.clone())
                } else {
                    tracing::debug!(
                        "[DELETE] Deleting from BASE with txn_id={}",
                        self.snapshot.txn_id
                    );
                    self.base_context
                        .delete(plan.clone())
                        .and_then(|r| r.convert_pager_type())
                };

                tracing::debug!(
                    "[DELETE] Result: {:?}",
                    result.as_ref().map(|_| "Ok").map_err(|e| format!("{}", e))
                );
                match result {
                    Ok(result) => {
                        if is_new_table {
                            tracing::trace!(
                                "[TX] DELETE from new table - tracking for commit replay"
                            );
                            self.operations.push(PlanOperation::Delete(plan.clone()));
                        } else {
                            tracing::trace!(
                                "[TX] DELETE from existing table - already in BASE, no replay needed"
                            );
                        }
                        result
                    }
                    Err(e) => {
                        self.is_aborted = true;
                        return Err(e);
                    }
                }
            }
            PlanOperation::Select(ref plan) => {
                let table_name = select_plan_table_name(plan).unwrap_or_default();
                match self.execute_select(plan.clone()) {
                    Ok(staging_execution) => {
                        let schema = staging_execution.schema();
                        let batches = staging_execution.collect().unwrap_or_default();

                        let combined = if batches.is_empty() {
                            RecordBatch::new_empty(Arc::clone(&schema))
                        } else if batches.len() == 1 {
                            batches.into_iter().next().unwrap()
                        } else {
                            let refs: Vec<&RecordBatch> = batches.iter().collect();
                            arrow::compute::concat_batches(&schema, refs).map_err(|err| {
                                Error::Internal(format!("failed to concatenate batches: {err}"))
                            })?
                        };

                        let execution = SelectExecution::from_batch(
                            table_name.clone(),
                            Arc::clone(&schema),
                            combined,
                        );

                        TransactionResult::Select {
                            table_name,
                            schema,
                            execution,
                        }
                    }
                    Err(e) => {
                        return Err(e);
                    }
                }
            }
        };

        Ok(result)
    }

    pub fn operations(&self) -> &[PlanOperation] {
        &self.operations
    }
}

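/// Per-session handle used to begin, execute, commit, and roll back transactions.
/// All sessions share the same map of active transactions and the same
/// `TxnIdManager`; a session owns at most one active transaction at a time.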
pub struct TransactionSession<BaseCtx, StagingCtx>
where
    BaseCtx: TransactionContext + 'static,
    StagingCtx: TransactionContext + 'static,
{
    context: Arc<BaseCtx>,
    session_id: SessionId,
    transactions: Arc<Mutex<HashMap<SessionId, SessionTransaction<BaseCtx, StagingCtx>>>>,
    txn_manager: Arc<TxnIdManager>,
}

impl<BaseCtx, StagingCtx> TransactionSession<BaseCtx, StagingCtx>
where
    BaseCtx: TransactionContext + 'static,
    StagingCtx: TransactionContext + 'static,
{
    pub fn new(
        context: Arc<BaseCtx>,
        session_id: SessionId,
        transactions: Arc<Mutex<HashMap<SessionId, SessionTransaction<BaseCtx, StagingCtx>>>>,
        txn_manager: Arc<TxnIdManager>,
    ) -> Self {
        Self {
            context,
            session_id,
            transactions,
            txn_manager,
        }
    }

    pub fn clone_session(&self) -> Self {
        Self {
            context: Arc::clone(&self.context),
            session_id: self.session_id,
            transactions: Arc::clone(&self.transactions),
            txn_manager: Arc::clone(&self.txn_manager),
        }
    }

    pub fn session_id(&self) -> SessionId {
        self.session_id
    }

    pub fn context(&self) -> &Arc<BaseCtx> {
        &self.context
    }

    pub fn has_active_transaction(&self) -> bool {
        self.transactions
            .lock()
            .expect("transactions lock poisoned")
            .contains_key(&self.session_id)
    }

    pub fn is_aborted(&self) -> bool {
        self.transactions
            .lock()
            .expect("transactions lock poisoned")
            .get(&self.session_id)
            .map(|tx| tx.is_aborted)
            .unwrap_or(false)
    }

    pub fn abort_transaction(&self) {
        let mut guard = self
            .transactions
            .lock()
            .expect("transactions lock poisoned");
        if let Some(tx) = guard.get_mut(&self.session_id) {
            tx.is_aborted = true;
        }
    }

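    /// Starts a new transaction for this session, using `staging` as the scratch
    /// context for tables created inside the transaction. Fails if the session
    /// already has a transaction in progress.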
    pub fn begin_transaction(
        &self,
        staging: Arc<StagingCtx>,
    ) -> LlkvResult<TransactionResult<BaseCtx::Pager>> {
        tracing::debug!(
            "[BEGIN] begin_transaction called for session_id={}",
            self.session_id
        );
        let mut guard = self
            .transactions
            .lock()
            .expect("transactions lock poisoned");
        tracing::debug!(
            "[BEGIN] session_id={}, transactions map has {} entries",
            self.session_id,
            guard.len()
        );
        if guard.contains_key(&self.session_id) {
            return Err(Error::InvalidArgumentError(
                "a transaction is already in progress in this session".into(),
            ));
        }
        guard.insert(
            self.session_id,
            SessionTransaction::new(
                Arc::clone(&self.context),
                staging,
                Arc::clone(&self.txn_manager),
            ),
        );
        tracing::debug!(
            "[BEGIN] session_id={}, inserted transaction, map now has {} entries",
            self.session_id,
            guard.len()
        );
        Ok(TransactionResult::Transaction {
            kind: TransactionKind::Begin,
        })
    }

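    /// Commits the session's active transaction. An aborted transaction is turned
    /// into a rollback. Before committing, the transaction is checked for conflicts
    /// with concurrently dropped or recreated tables and for deferred constraint
    /// violations; on success the caller receives the operations that must be
    /// replayed against the durable context (DDL and writes staged for new tables).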
    pub fn commit_transaction(
        &self,
    ) -> LlkvResult<(TransactionResult<BaseCtx::Pager>, Vec<PlanOperation>)> {
        tracing::trace!(
            "[COMMIT] commit_transaction called for session {:?}",
            self.session_id
        );
        let mut guard = self
            .transactions
            .lock()
            .expect("transactions lock poisoned");
        tracing::trace!("[COMMIT] commit_transaction got lock, checking for transaction...");
        let tx_opt = guard.remove(&self.session_id);
        tracing::trace!(
            "[COMMIT] commit_transaction remove returned: {}",
            tx_opt.is_some()
        );
        let tx = tx_opt.ok_or_else(|| {
            tracing::trace!("[COMMIT] commit_transaction: no transaction found!");
            Error::InvalidArgumentError(
                "no transaction is currently in progress in this session".into(),
            )
        })?;
        tracing::trace!("DEBUG commit_transaction: is_aborted={}", tx.is_aborted);

        if tx.is_aborted {
            tx.txn_manager.mark_aborted(tx.snapshot.txn_id);
            tx.base_context.clear_transaction_state(tx.snapshot.txn_id);
            tx.staging.clear_transaction_state(tx.snapshot.txn_id);
            let auto_commit_snapshot = TransactionSnapshot {
                txn_id: TXN_ID_AUTO_COMMIT,
                snapshot_id: tx.txn_manager.last_committed(),
            };
            TransactionContext::set_snapshot(&*self.context, auto_commit_snapshot);
            tracing::trace!("DEBUG commit_transaction: returning Rollback with 0 operations");
            return Ok((
                TransactionResult::Transaction {
                    kind: TransactionKind::Rollback,
                },
                Vec::new(),
            ));
        }

        tracing::debug!(
            "[COMMIT CONFLICT CHECK] Transaction {} accessed {} tables",
            tx.snapshot.txn_id,
            tx.accessed_tables.len()
        );
        for accessed_table_name in &tx.accessed_tables {
            tracing::debug!(
                "[COMMIT CONFLICT CHECK] Checking table '{}'",
                accessed_table_name
            );
            if let Some(snapshot_table_id) = tx.catalog_snapshot.table_id(accessed_table_name) {
                match self.context.table_id(accessed_table_name) {
                    Ok(current_table_id) => {
                        if current_table_id != snapshot_table_id {
                            tx.txn_manager.mark_aborted(tx.snapshot.txn_id);
                            tx.base_context.clear_transaction_state(tx.snapshot.txn_id);
                            tx.staging.clear_transaction_state(tx.snapshot.txn_id);
                            let auto_commit_snapshot = TransactionSnapshot {
                                txn_id: TXN_ID_AUTO_COMMIT,
                                snapshot_id: tx.txn_manager.last_committed(),
                            };
                            TransactionContext::set_snapshot(&*self.context, auto_commit_snapshot);
                            return Err(Error::TransactionContextError(
                                "another transaction has dropped this table".into(),
                            ));
                        }
                    }
                    Err(_) => {
                        tx.txn_manager.mark_aborted(tx.snapshot.txn_id);
                        tx.base_context.clear_transaction_state(tx.snapshot.txn_id);
                        tx.staging.clear_transaction_state(tx.snapshot.txn_id);
                        let auto_commit_snapshot = TransactionSnapshot {
                            txn_id: TXN_ID_AUTO_COMMIT,
                            snapshot_id: tx.txn_manager.last_committed(),
                        };
                        TransactionContext::set_snapshot(&*self.context, auto_commit_snapshot);
                        return Err(Error::TransactionContextError(
                            "another transaction has dropped this table".into(),
                        ));
                    }
                }
            }
        }

        if let Err(err) = tx
            .base_context
            .validate_commit_constraints(tx.snapshot.txn_id)
        {
            tx.txn_manager.mark_aborted(tx.snapshot.txn_id);
            tx.base_context.clear_transaction_state(tx.snapshot.txn_id);
            tx.staging.clear_transaction_state(tx.snapshot.txn_id);
            let auto_commit_snapshot = TransactionSnapshot {
                txn_id: TXN_ID_AUTO_COMMIT,
                snapshot_id: tx.txn_manager.last_committed(),
            };
            TransactionContext::set_snapshot(&*self.context, auto_commit_snapshot);
            let wrapped = match err {
                Error::ConstraintError(msg) => Error::TransactionContextError(format!(
                    "TransactionContext Error: constraint violation: {msg}"
                )),
                other => other,
            };
            return Err(wrapped);
        }

        let operations = tx.operations;
        tracing::trace!(
            "DEBUG commit_transaction: returning Commit with {} operations",
            operations.len()
        );

        tx.txn_manager.mark_committed(tx.snapshot.txn_id);
        tx.base_context.clear_transaction_state(tx.snapshot.txn_id);
        tx.staging.clear_transaction_state(tx.snapshot.txn_id);
        TransactionContext::set_snapshot(&*self.context, tx.snapshot);

        Ok((
            TransactionResult::Transaction {
                kind: TransactionKind::Commit,
            },
            operations,
        ))
    }

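    /// Rolls back the session's active transaction: marks it aborted in the
    /// `TxnIdManager`, clears per-transaction state on both contexts, and restores
    /// the base context to the auto-commit snapshot.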
    pub fn rollback_transaction(&self) -> LlkvResult<TransactionResult<BaseCtx::Pager>> {
        let mut guard = self
            .transactions
            .lock()
            .expect("transactions lock poisoned");
        if let Some(tx) = guard.remove(&self.session_id) {
            tx.txn_manager.mark_aborted(tx.snapshot.txn_id);
            tx.base_context.clear_transaction_state(tx.snapshot.txn_id);
            tx.staging.clear_transaction_state(tx.snapshot.txn_id);
            let auto_commit_snapshot = TransactionSnapshot {
                txn_id: TXN_ID_AUTO_COMMIT,
                snapshot_id: tx.txn_manager.last_committed(),
            };
            TransactionContext::set_snapshot(&*self.context, auto_commit_snapshot);
        } else {
            return Err(Error::InvalidArgumentError(
                "no transaction is currently in progress in this session".into(),
            ));
        }
        Ok(TransactionResult::Transaction {
            kind: TransactionKind::Rollback,
        })
    }

    pub fn execute_operation(
        &self,
        operation: PlanOperation,
    ) -> LlkvResult<TransactionResult<StagingCtx::Pager>> {
        tracing::debug!(
            "[EXECUTE_OP] execute_operation called for session_id={}",
            self.session_id
        );
        if !self.has_active_transaction() {
            return Err(Error::InvalidArgumentError(
                "execute_operation called without active transaction".into(),
            ));
        }

        let mut guard = self
            .transactions
            .lock()
            .expect("transactions lock poisoned");
        tracing::debug!(
            "[EXECUTE_OP] session_id={}, transactions map has {} entries",
            self.session_id,
            guard.len()
        );
        let tx = guard
            .get_mut(&self.session_id)
            .ok_or_else(|| Error::Internal("transaction disappeared during execution".into()))?;
        tracing::debug!(
            "[EXECUTE_OP] session_id={}, found transaction with txn_id={}, accessed_tables={}",
            self.session_id,
            tx.snapshot.txn_id,
            tx.accessed_tables.len()
        );

        let result = tx.execute_operation(operation);
        if let Err(ref e) = result {
            tracing::trace!("DEBUG TransactionSession::execute_operation error: {:?}", e);
            tracing::trace!("DEBUG Transaction is_aborted={}", tx.is_aborted);
        }
        result
    }
}

impl<BaseCtx, StagingCtx> Drop for TransactionSession<BaseCtx, StagingCtx>
where
    BaseCtx: TransactionContext,
    StagingCtx: TransactionContext,
{
    fn drop(&mut self) {
        match self.transactions.lock() {
            Ok(mut guard) => {
                if guard.remove(&self.session_id).is_some() {
                    eprintln!(
                        "Warning: TransactionSession dropped with active transaction - auto-rolling back"
                    );
                }
            }
            Err(_) => {
                tracing::trace!(
                    "Warning: TransactionSession dropped with poisoned transaction mutex"
                );
            }
        }
    }
}

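/// Allocates session ids, owns the shared map of active transactions, and hands out
/// `TransactionSession` handles bound to a shared `TxnIdManager`.
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest):
/// `MyBaseCtx`, `MyStagingCtx`, and `insert_plan` are hypothetical placeholders for
/// concrete `TransactionContext` implementations and a prebuilt `InsertPlan`, which
/// are not provided by this module.
///
/// ```ignore
/// // Hypothetical contexts implementing `TransactionContext`.
/// let manager: TransactionManager<MyBaseCtx, MyStagingCtx> = TransactionManager::new();
/// let session = manager.create_session(Arc::new(MyBaseCtx::default()));
///
/// session.begin_transaction(Arc::new(MyStagingCtx::default()))?;
/// session.execute_operation(PlanOperation::Insert(insert_plan))?;
/// // On success, `ops` holds the operations to replay against the durable context.
/// let (_result, ops) = session.commit_transaction()?;
/// ```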
pub struct TransactionManager<BaseCtx, StagingCtx>
where
    BaseCtx: TransactionContext + 'static,
    StagingCtx: TransactionContext + 'static,
{
    transactions: Arc<Mutex<HashMap<SessionId, SessionTransaction<BaseCtx, StagingCtx>>>>,
    next_session_id: AtomicU64,
    txn_manager: Arc<TxnIdManager>,
}

impl<BaseCtx, StagingCtx> TransactionManager<BaseCtx, StagingCtx>
where
    BaseCtx: TransactionContext + 'static,
    StagingCtx: TransactionContext + 'static,
{
    pub fn new() -> Self {
        Self {
            transactions: Arc::new(Mutex::new(HashMap::new())),
            next_session_id: AtomicU64::new(1),
            txn_manager: Arc::new(TxnIdManager::new()),
        }
    }

    pub fn new_with_initial_txn_id(next_txn_id: TxnId) -> Self {
        Self {
            transactions: Arc::new(Mutex::new(HashMap::new())),
            next_session_id: AtomicU64::new(1),
            txn_manager: Arc::new(TxnIdManager::new_with_initial_txn_id(next_txn_id)),
        }
    }

    pub fn new_with_initial_state(next_txn_id: TxnId, last_committed: TxnId) -> Self {
        Self {
            transactions: Arc::new(Mutex::new(HashMap::new())),
            next_session_id: AtomicU64::new(1),
            txn_manager: Arc::new(TxnIdManager::new_with_initial_state(
                next_txn_id,
                last_committed,
            )),
        }
    }

    pub fn create_session(&self, context: Arc<BaseCtx>) -> TransactionSession<BaseCtx, StagingCtx> {
        let session_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
        tracing::debug!(
            "[TX_MANAGER] create_session: allocated session_id={}",
            session_id
        );
        TransactionSession::new(
            context,
            session_id,
            Arc::clone(&self.transactions),
            Arc::clone(&self.txn_manager),
        )
    }

    pub fn txn_manager(&self) -> Arc<TxnIdManager> {
        Arc::clone(&self.txn_manager)
    }

    pub fn has_active_transaction(&self) -> bool {
        !self
            .transactions
            .lock()
            .expect("transactions lock poisoned")
            .is_empty()
    }
}

impl<BaseCtx, StagingCtx> Default for TransactionManager<BaseCtx, StagingCtx>
where
    BaseCtx: TransactionContext + 'static,
    StagingCtx: TransactionContext + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}