Skip to main content

vibesql_server/
transaction.rs

//! Transaction isolation support for server sessions.
//!
//! This module provides per-session transaction state management, enabling the
//! READ COMMITTED isolation level for databases shared between sessions.
//!
//! # Architecture
//!
//! When multiple sessions share a database via `DatabaseRegistry`, each session
//! needs its own transaction state to provide proper isolation:
//!
//! - **READ COMMITTED**: Uncommitted changes made in one transaction are NOT
//!   visible to other sessions. Only committed changes propagate to other sessions.
//!
//! # Implementation
//!
//! Writes made inside a transaction are buffered per session instead of being
//! applied to the shared database immediately:
//! 1. On BEGIN: A fresh, empty change buffer is created for the session.
//! 2. During the transaction: Inserts, updates, deletes, and DDL operations are
//!    recorded in the buffer, in the order they happen.
//! 3. On COMMIT: The buffered changes are handed back to the caller, which is
//!    responsible for applying them to the shared database.
//! 4. On ROLLBACK: The buffer is discarded; the shared database is never touched.
//!
//! This provides READ COMMITTED semantics where:
//! - Other sessions always read committed data
//! - A session in a transaction sees its own uncommitted changes
//! - Committed changes become visible to all sessions
26
27use std::collections::HashMap;
28use vibesql_storage::Row;
29
30/// A change made during a transaction that needs to be applied on commit.
31#[derive(Debug, Clone, PartialEq)]
32#[allow(clippy::large_enum_variant)]
33pub enum TransactionChange {
34    /// A row was inserted
35    Insert {
36        table_name: String,
37        row: Row,
38    },
39    /// A row was updated (old values for rollback reference)
40    Update {
41        table_name: String,
42        row_index: usize,
43        old_row: Row,
44        new_row: Row,
45    },
46    /// A row was deleted
47    Delete {
48        table_name: String,
49        row_index: usize,
50        row: Row,
51    },
52    /// A table was created
53    CreateTable {
54        table_name: String,
55    },
56    /// A table was dropped
57    DropTable {
58        table_name: String,
59    },
60    /// An index was created
61    CreateIndex {
62        index_name: String,
63        table_name: String,
64    },
65    /// An index was dropped
66    DropIndex {
67        index_name: String,
68    },
69}
70
71/// Transaction state for a session.
72///
73/// Tracks uncommitted changes during an active transaction, providing
74/// READ COMMITTED isolation when multiple sessions share a database.
75#[derive(Debug)]
76pub struct TransactionState {
77    /// Transaction ID (monotonically increasing)
78    pub id: u64,
79    /// Whether we're in an active transaction block
80    pub active: bool,
81    /// Changes made during this transaction (in order)
82    changes: Vec<TransactionChange>,
83    /// Inserted rows indexed by table name (for reads during transaction)
84    inserted_rows: HashMap<String, Vec<Row>>,
85    /// Deleted row indices indexed by table name (to filter from reads)
86    deleted_indices: HashMap<String, Vec<usize>>,
87    /// Updated rows: table_name -> (row_index -> new_row)
88    updated_rows: HashMap<String, HashMap<usize, Row>>,
89}
90
91impl TransactionState {
92    /// Create a new transaction state with the given ID.
93    pub fn new(id: u64) -> Self {
94        Self {
95            id,
96            active: true,
97            changes: Vec::new(),
98            inserted_rows: HashMap::new(),
99            deleted_indices: HashMap::new(),
100            updated_rows: HashMap::new(),
101        }
102    }
103
104    /// Record an insert operation.
105    pub fn record_insert(&mut self, table_name: String, row: Row) {
106        self.changes.push(TransactionChange::Insert {
107            table_name: table_name.clone(),
108            row: row.clone(),
109        });
110        self.inserted_rows.entry(table_name).or_default().push(row);
111    }
112
113    /// Record an update operation.
114    pub fn record_update(
115        &mut self,
116        table_name: String,
117        row_index: usize,
118        old_row: Row,
119        new_row: Row,
120    ) {
121        self.changes.push(TransactionChange::Update {
122            table_name: table_name.clone(),
123            row_index,
124            old_row,
125            new_row: new_row.clone(),
126        });
127        self.updated_rows
128            .entry(table_name)
129            .or_default()
130            .insert(row_index, new_row);
131    }
132
133    /// Record a delete operation.
134    pub fn record_delete(&mut self, table_name: String, row_index: usize, row: Row) {
135        self.changes.push(TransactionChange::Delete {
136            table_name: table_name.clone(),
137            row_index,
138            row,
139        });
140        self.deleted_indices.entry(table_name).or_default().push(row_index);
141    }
142
143    /// Record a table creation.
144    pub fn record_create_table(&mut self, table_name: String) {
145        self.changes.push(TransactionChange::CreateTable { table_name });
146    }
147
148    /// Record a table drop.
149    pub fn record_drop_table(&mut self, table_name: String) {
150        self.changes.push(TransactionChange::DropTable { table_name });
151    }
152
153    /// Record an index creation.
154    pub fn record_create_index(&mut self, index_name: String, table_name: String) {
155        self.changes.push(TransactionChange::CreateIndex { index_name, table_name });
156    }
157
158    /// Record an index drop.
159    pub fn record_drop_index(&mut self, index_name: String) {
160        self.changes.push(TransactionChange::DropIndex { index_name });
161    }
162
163    /// Get rows inserted in this transaction for a table.
164    pub fn get_inserted_rows(&self, table_name: &str) -> Option<&Vec<Row>> {
165        self.inserted_rows.get(table_name)
166    }
167
168    /// Get indices of rows deleted in this transaction for a table.
169    pub fn get_deleted_indices(&self, table_name: &str) -> Option<&Vec<usize>> {
170        self.deleted_indices.get(table_name)
171    }
172
173    /// Get updated rows for a table (index -> new_row).
174    pub fn get_updated_rows(&self, table_name: &str) -> Option<&HashMap<usize, Row>> {
175        self.updated_rows.get(table_name)
176    }
177
178    /// Check if a row at a given index was deleted in this transaction.
179    pub fn is_deleted(&self, table_name: &str, row_index: usize) -> bool {
180        self.deleted_indices
181            .get(table_name)
182            .is_some_and(|indices| indices.contains(&row_index))
183    }
184
185    /// Get the updated version of a row if it was updated in this transaction.
186    pub fn get_updated_row(&self, table_name: &str, row_index: usize) -> Option<&Row> {
187        self.updated_rows
188            .get(table_name)
189            .and_then(|updates| updates.get(&row_index))
190    }
191
192    /// Consume the transaction state and return all changes for commit.
193    pub fn take_changes(self) -> Vec<TransactionChange> {
194        self.changes
195    }
196
197    /// Get all changes (for inspection without consuming).
198    pub fn changes(&self) -> &[TransactionChange] {
199        &self.changes
200    }
201
202    /// Check if there are any uncommitted changes.
203    pub fn has_changes(&self) -> bool {
204        !self.changes.is_empty()
205    }
206
207    /// Clear all changes (for rollback).
208    pub fn clear(&mut self) {
209        self.changes.clear();
210        self.inserted_rows.clear();
211        self.deleted_indices.clear();
212        self.updated_rows.clear();
213    }
214}
215
216/// Manager for session transaction state.
217///
218/// Each session has its own `SessionTransactionManager` to track
219/// its transaction state independently of other sessions.
220#[derive(Debug, Default)]
221pub struct SessionTransactionManager {
222    /// Current transaction state (None if no active transaction)
223    current: Option<TransactionState>,
224    /// Next transaction ID to assign
225    next_id: u64,
226}
227
228impl SessionTransactionManager {
229    /// Create a new session transaction manager.
230    pub fn new() -> Self {
231        Self { current: None, next_id: 1 }
232    }
233
234    /// Begin a new transaction.
235    ///
236    /// Returns an error if a transaction is already active.
237    pub fn begin(&mut self) -> Result<u64, TransactionError> {
238        if self.current.is_some() {
239            return Err(TransactionError::AlreadyInTransaction);
240        }
241
242        let id = self.next_id;
243        self.next_id += 1;
244        self.current = Some(TransactionState::new(id));
245        Ok(id)
246    }
247
248    /// Commit the current transaction.
249    ///
250    /// Returns the changes to be applied to the shared database.
251    pub fn commit(&mut self) -> Result<Vec<TransactionChange>, TransactionError> {
252        let state = self.current.take().ok_or(TransactionError::NoActiveTransaction)?;
253        Ok(state.take_changes())
254    }
255
256    /// Rollback the current transaction.
257    ///
258    /// Discards all uncommitted changes.
259    pub fn rollback(&mut self) -> Result<(), TransactionError> {
260        self.current.take().ok_or(TransactionError::NoActiveTransaction)?;
261        Ok(())
262    }
263
264    /// Check if a transaction is currently active.
265    pub fn in_transaction(&self) -> bool {
266        self.current.as_ref().is_some_and(|s| s.active)
267    }
268
269    /// Get the current transaction ID, if any.
270    pub fn transaction_id(&self) -> Option<u64> {
271        self.current.as_ref().map(|s| s.id)
272    }
273
274    /// Get mutable access to the current transaction state.
275    pub fn current_mut(&mut self) -> Option<&mut TransactionState> {
276        self.current.as_mut()
277    }
278
279    /// Get read access to the current transaction state.
280    pub fn current(&self) -> Option<&TransactionState> {
281        self.current.as_ref()
282    }
283
284    /// Record an insert in the current transaction.
285    ///
286    /// No-op if not in a transaction.
287    pub fn record_insert(&mut self, table_name: String, row: Row) {
288        if let Some(state) = &mut self.current {
289            state.record_insert(table_name, row);
290        }
291    }
292
293    /// Record an update in the current transaction.
294    ///
295    /// No-op if not in a transaction.
296    pub fn record_update(
297        &mut self,
298        table_name: String,
299        row_index: usize,
300        old_row: Row,
301        new_row: Row,
302    ) {
303        if let Some(state) = &mut self.current {
304            state.record_update(table_name, row_index, old_row, new_row);
305        }
306    }
307
308    /// Record a delete in the current transaction.
309    ///
310    /// No-op if not in a transaction.
311    pub fn record_delete(&mut self, table_name: String, row_index: usize, row: Row) {
312        if let Some(state) = &mut self.current {
313            state.record_delete(table_name, row_index, row);
314        }
315    }
316}
317
/// Failures reported by session transaction management.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionError {
    /// `begin` was called while a transaction was already open.
    AlreadyInTransaction,
    /// `commit` or `rollback` was called with no open transaction.
    NoActiveTransaction,
    /// Applying a transaction's changes hit a conflict; the payload describes it.
    CommitConflict(String),
}

impl std::fmt::Display for TransactionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::AlreadyInTransaction => "Transaction already in progress",
            Self::NoActiveTransaction => "No transaction in progress",
            Self::CommitConflict(detail) => return write!(f, "Commit conflict: {detail}"),
        };
        f.write_str(message)
    }
}

impl std::error::Error for TransactionError {}
346
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_types::SqlValue;

    fn make_row(values: Vec<SqlValue>) -> Row {
        Row::new(values)
    }

    #[test]
    fn test_begin_transaction() {
        let mut mgr = SessionTransactionManager::new();

        assert!(!mgr.in_transaction());
        assert_eq!(mgr.transaction_id(), None);

        let id = mgr.begin().unwrap();
        assert_eq!(id, 1);
        assert!(mgr.in_transaction());
        assert_eq!(mgr.transaction_id(), Some(1));
    }

    #[test]
    fn test_double_begin_fails() {
        let mut mgr = SessionTransactionManager::new();

        mgr.begin().unwrap();
        let result = mgr.begin();
        assert_eq!(result, Err(TransactionError::AlreadyInTransaction));
    }

    #[test]
    fn test_commit_without_transaction_fails() {
        let mut mgr = SessionTransactionManager::new();

        let result = mgr.commit();
        assert_eq!(result, Err(TransactionError::NoActiveTransaction));
    }

    #[test]
    fn test_rollback_without_transaction_fails() {
        let mut mgr = SessionTransactionManager::new();

        let result = mgr.rollback();
        assert_eq!(result, Err(TransactionError::NoActiveTransaction));
    }

    #[test]
    fn test_record_insert() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row = make_row(vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("test"))]);
        mgr.record_insert("users".to_string(), row.clone());

        let state = mgr.current().unwrap();
        assert!(state.has_changes());

        let inserted = state.get_inserted_rows("users").unwrap();
        assert_eq!(inserted.len(), 1);
        assert_eq!(inserted[0].values, row.values);
    }

    #[test]
    fn test_record_delete() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row = make_row(vec![SqlValue::Integer(1)]);
        mgr.record_delete("users".to_string(), 5, row);

        let state = mgr.current().unwrap();
        assert!(state.is_deleted("users", 5));
        assert!(!state.is_deleted("users", 6));
        assert!(!state.is_deleted("other_table", 5));
    }

    #[test]
    fn test_record_update() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let old_row = make_row(vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("old"))]);
        let new_row = make_row(vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("new"))]);
        mgr.record_update("users".to_string(), 3, old_row, new_row.clone());

        let state = mgr.current().unwrap();
        let updated = state.get_updated_row("users", 3).unwrap();
        assert_eq!(updated.values, new_row.values);
        assert!(state.get_updated_row("users", 4).is_none());
    }

    #[test]
    fn test_commit_returns_changes() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row1 = make_row(vec![SqlValue::Integer(1)]);
        let row2 = make_row(vec![SqlValue::Integer(2)]);
        mgr.record_insert("users".to_string(), row1);
        mgr.record_insert("users".to_string(), row2);

        let changes = mgr.commit().unwrap();
        assert_eq!(changes.len(), 2);
        assert!(!mgr.in_transaction());
    }

    #[test]
    fn test_rollback_discards_changes() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row = make_row(vec![SqlValue::Integer(1)]);
        mgr.record_insert("users".to_string(), row);

        mgr.rollback().unwrap();
        assert!(!mgr.in_transaction());

        // Can start a new transaction after rollback
        mgr.begin().unwrap();
        assert!(mgr.in_transaction());
        assert_eq!(mgr.transaction_id(), Some(2)); // ID incremented
    }

    #[test]
    fn test_transaction_id_increments() {
        let mut mgr = SessionTransactionManager::new();

        let id1 = mgr.begin().unwrap();
        mgr.commit().unwrap();

        let id2 = mgr.begin().unwrap();
        mgr.rollback().unwrap();

        let id3 = mgr.begin().unwrap();

        assert_eq!(id1, 1);
        assert_eq!(id2, 2);
        assert_eq!(id3, 3);
    }

    #[test]
    fn test_no_op_when_not_in_transaction() {
        let mut mgr = SessionTransactionManager::new();

        // These should not panic and should be no-ops
        let row = make_row(vec![SqlValue::Integer(1)]);
        mgr.record_insert("users".to_string(), row.clone());
        mgr.record_delete("users".to_string(), 0, row.clone());
        mgr.record_update("users".to_string(), 0, row.clone(), row);

        // Should not be in transaction
        assert!(!mgr.in_transaction());
    }

    #[test]
    fn test_new_state_is_active_and_empty() {
        let state = TransactionState::new(42);

        assert_eq!(state.id, 42);
        assert!(state.active);
        assert!(!state.has_changes());
        assert!(state.changes().is_empty());
        assert!(state.get_inserted_rows("users").is_none());
        assert!(state.get_deleted_indices("users").is_none());
        assert!(state.get_updated_rows("users").is_none());
    }

    #[test]
    fn test_ddl_changes_recorded_in_order() {
        let mut state = TransactionState::new(1);
        state.record_create_table("users".to_string());
        state.record_create_index("idx_users".to_string(), "users".to_string());
        state.record_drop_index("idx_users".to_string());
        state.record_drop_table("users".to_string());

        let expected = vec![
            TransactionChange::CreateTable { table_name: "users".to_string() },
            TransactionChange::CreateIndex {
                index_name: "idx_users".to_string(),
                table_name: "users".to_string(),
            },
            TransactionChange::DropIndex { index_name: "idx_users".to_string() },
            TransactionChange::DropTable { table_name: "users".to_string() },
        ];
        assert!(state.has_changes());
        assert_eq!(state.changes(), expected.as_slice());
        // DDL entries only go to the change log, not to the per-table row buffers.
        assert!(state.get_inserted_rows("users").is_none());
        assert!(state.get_deleted_indices("users").is_none());
        assert!(state.get_updated_rows("users").is_none());
        assert_eq!(state.take_changes(), expected);
    }

    #[test]
    fn test_commit_preserves_mixed_change_order() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let old_row = make_row(vec![SqlValue::Integer(1)]);
        let new_row = make_row(vec![SqlValue::Integer(2)]);
        mgr.record_insert("t".to_string(), old_row.clone());
        mgr.record_update("t".to_string(), 0, old_row.clone(), new_row.clone());
        mgr.record_delete("t".to_string(), 0, new_row.clone());

        let changes = mgr.commit().unwrap();
        assert_eq!(
            changes,
            vec![
                TransactionChange::Insert { table_name: "t".to_string(), row: old_row.clone() },
                TransactionChange::Update {
                    table_name: "t".to_string(),
                    row_index: 0,
                    old_row,
                    new_row: new_row.clone(),
                },
                TransactionChange::Delete { table_name: "t".to_string(), row_index: 0, row: new_row },
            ]
        );
        assert!(!mgr.in_transaction());
    }

    #[test]
    fn test_clear_discards_buffered_state_but_stays_active() {
        let mut mgr = SessionTransactionManager::new();
        mgr.begin().unwrap();

        let row = make_row(vec![SqlValue::Integer(1)]);
        mgr.record_insert("users".to_string(), row.clone());
        mgr.record_update("users".to_string(), 0, row.clone(), row.clone());
        mgr.record_delete("users".to_string(), 0, row);

        let state = mgr.current_mut().unwrap();
        state.clear();
        assert!(!state.has_changes());
        assert!(state.get_inserted_rows("users").is_none());
        assert!(state.get_deleted_indices("users").is_none());
        assert!(!state.is_deleted("users", 0));
        assert!(state.get_updated_row("users", 0).is_none());

        // clear() only empties the buffers; the transaction itself is still open.
        assert!(mgr.in_transaction());
        assert_eq!(mgr.transaction_id(), Some(1));
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            TransactionError::AlreadyInTransaction.to_string(),
            "Transaction already in progress"
        );
        assert_eq!(
            TransactionError::NoActiveTransaction.to_string(),
            "No transaction in progress"
        );
        assert_eq!(
            TransactionError::CommitConflict("row 3".to_string()).to_string(),
            "Commit conflict: row 3"
        );
    }
}