1use std::collections::HashSet;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4
5use crate::error::{QuillSQLError, QuillSQLResult};
6use crate::storage::record::RecordId;
7use crate::transaction::{
8 IsolationLevel, LockManager, LockMode, Transaction, TransactionId, TransactionSnapshot,
9 TransactionState, TransactionStatus,
10};
11use crate::utils::table_ref::TableReference;
12use dashmap::{DashMap, DashSet};
13use serde::Serialize;
14use sqlparser::ast::TransactionAccessMode;
15
16#[derive(Debug, Default)]
17struct HeldLocks {
18 tables: Vec<(TableReference, LockMode)>,
19 rows: Vec<(TableReference, RecordId, LockMode)>,
20 row_keys: HashSet<(TableReference, RecordId)>,
21}
22
23pub struct TransactionManager {
24 next_txn_id: AtomicU64,
25 active_txns: DashSet<TransactionId>,
26 lock_manager: Arc<LockManager>,
27 held_locks: DashMap<TransactionId, HeldLocks>,
28 txn_statuses: DashMap<TransactionId, TransactionStatus>,
29}
30
31impl Default for TransactionManager {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37#[derive(Debug, Clone, Serialize)]
38pub struct TxnDebugEntry {
39 pub txn_id: TransactionId,
40 pub status: TransactionStatus,
41 pub held_tables: usize,
42 pub held_rows: usize,
43}
44
45#[derive(Debug, Clone, Serialize)]
46pub struct TxnDebugSnapshot {
47 pub active: Vec<TxnDebugEntry>,
48 pub oldest_active: Option<TransactionId>,
49 pub next_txn_id: TransactionId,
50}
51
52impl TransactionManager {
53 pub fn new() -> Self {
54 Self {
55 next_txn_id: AtomicU64::new(1),
56 active_txns: DashSet::new(),
57 lock_manager: Arc::new(LockManager::new()),
58 held_locks: DashMap::new(),
59 txn_statuses: DashMap::new(),
60 }
61 }
62
63 pub fn with_lock_manager(lock_manager: Arc<LockManager>) -> Self {
64 Self {
65 next_txn_id: AtomicU64::new(1),
66 active_txns: DashSet::new(),
67 lock_manager,
68 held_locks: DashMap::new(),
69 txn_statuses: DashMap::new(),
70 }
71 }
72
73 pub fn lock_manager_arc(&self) -> Arc<LockManager> {
74 self.lock_manager.clone()
75 }
76
77 pub fn begin(
78 &self,
79 isolation_level: IsolationLevel,
80 access_mode: TransactionAccessMode,
81 ) -> QuillSQLResult<Transaction> {
82 let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);
83 if txn_id == 0 {
84 return Err(QuillSQLError::Internal(
85 "Transaction ID wrapped around".to_string(),
86 ));
87 }
88 let txn = Transaction::new(txn_id, isolation_level, access_mode);
89 self.active_txns.insert(txn_id);
90 self.txn_statuses
91 .insert(txn_id, TransactionStatus::InProgress);
92 self.held_locks.insert(txn_id, HeldLocks::default());
93 Ok(txn)
94 }
95
96 pub fn acquire_table_lock(
97 &self,
98 txn: &Transaction,
99 table: TableReference,
100 mode: LockMode,
101 ) -> QuillSQLResult<()> {
102 if self.lock_manager.lock_table(txn, mode, table.clone()) {
103 if let Some(mut entry) = self.held_locks.get_mut(&txn.id()) {
104 entry.tables.push((table, mode));
105 } else {
106 let mut new_entry = HeldLocks::default();
107 new_entry.tables.push((table, mode));
108 self.held_locks.insert(txn.id(), new_entry);
109 }
110 Ok(())
111 } else {
112 Err(QuillSQLError::Internal(format!(
113 "Failed to acquire table lock for txn {}",
114 txn.id()
115 )))
116 }
117 }
118
119 pub fn try_acquire_row_lock(
120 &self,
121 txn: &Transaction,
122 table: TableReference,
123 rid: RecordId,
124 mode: LockMode,
125 ) -> QuillSQLResult<bool> {
126 let key = (table.clone(), rid);
127 if let Some(entry) = self.held_locks.get(&txn.id()) {
128 if entry.row_keys.contains(&key) {
129 return Ok(true);
130 }
131 }
132 if self.lock_manager.lock_row(txn, mode, table.clone(), rid) {
133 self.record_row_lock(txn.id(), table, rid, mode);
134 Ok(true)
135 } else {
136 Ok(false)
137 }
138 }
139
140 pub fn acquire_row_lock(
141 &self,
142 txn: &Transaction,
143 table: TableReference,
144 rid: RecordId,
145 mode: LockMode,
146 ) -> QuillSQLResult<()> {
147 if !self.try_acquire_row_lock(txn, table.clone(), rid, mode)? {
148 return Err(QuillSQLError::Internal(format!(
149 "Failed to acquire row lock for txn {}",
150 txn.id()
151 )));
152 }
153 Ok(())
154 }
155
156 pub fn commit(&self, txn: &mut Transaction) -> QuillSQLResult<()> {
157 match txn.state() {
158 TransactionState::Running | TransactionState::Tainted => {}
159 TransactionState::Committed => {
160 return Err(QuillSQLError::Internal(format!(
161 "Transaction {} already committed",
162 txn.id()
163 )))
164 }
165 TransactionState::Aborted => {
166 return Err(QuillSQLError::Internal(format!(
167 "Transaction {} already aborted",
168 txn.id()
169 )))
170 }
171 }
172
173 let txn_id = txn.id();
174 txn.set_state(TransactionState::Committed);
175
176 self.active_txns.remove(&txn_id);
177 self.txn_statuses
178 .insert(txn_id, TransactionStatus::Committed);
179 self.release_all_locks(txn_id);
180 txn.clear_undo();
181 txn.clear_snapshot();
182 Ok(())
183 }
184
185 pub fn abort(&self, txn: &mut Transaction) -> QuillSQLResult<()> {
186 match txn.state() {
187 TransactionState::Committed => {
188 return Err(QuillSQLError::Internal(format!(
189 "Transaction {} already committed",
190 txn.id()
191 )))
192 }
193 TransactionState::Aborted => return Ok(()),
194 TransactionState::Running | TransactionState::Tainted => {}
195 }
196
197 let txn_id = txn.id();
198 while let Some(action) = txn.pop_undo_action() {
199 action.undo(txn_id)?;
200 }
201
202 txn.set_state(TransactionState::Aborted);
203
204 self.active_txns.remove(&txn_id);
205 self.txn_statuses.insert(txn_id, TransactionStatus::Aborted);
206 self.release_all_locks(txn_id);
207 txn.clear_undo();
208 txn.clear_snapshot();
209 Ok(())
210 }
211
212 pub fn active_transactions(&self) -> Vec<TransactionId> {
213 self.active_txns.iter().map(|txn| *txn).collect()
214 }
215
216 pub fn snapshot(&self, txn_id: TransactionId) -> TransactionSnapshot {
217 let active: Vec<TransactionId> = self
218 .active_txns
219 .iter()
220 .map(|id| *id)
221 .filter(|id| *id != txn_id)
222 .collect();
223 let xmax = self.next_txn_id.load(Ordering::SeqCst);
224 let xmin = active.iter().copied().min().unwrap_or(xmax);
225 TransactionSnapshot::new(txn_id, xmin, xmax, active)
226 }
227
228 pub fn transaction_status(&self, txn_id: TransactionId) -> TransactionStatus {
229 if txn_id == 0 {
230 return TransactionStatus::Committed;
231 }
232 self.txn_statuses
233 .get(&txn_id)
234 .map(|entry| *entry.value())
235 .unwrap_or(TransactionStatus::Unknown)
236 }
237
238 pub fn record_recovered_status(&self, txn_id: TransactionId, status: TransactionStatus) {
239 if txn_id == 0 {
240 return;
241 }
242 self.txn_statuses.insert(txn_id, status);
243 }
244
245 pub fn ensure_next_txn_id_at_least(&self, next_txn_id: TransactionId) {
246 let mut current = self.next_txn_id.load(Ordering::SeqCst);
247 while current < next_txn_id {
248 match self.next_txn_id.compare_exchange(
249 current,
250 next_txn_id,
251 Ordering::SeqCst,
252 Ordering::SeqCst,
253 ) {
254 Ok(_) => break,
255 Err(observed) => current = observed,
256 }
257 }
258 }
259
260 pub fn oldest_active_txn(&self) -> Option<TransactionId> {
261 self.active_txns.iter().map(|txn| *txn).min()
262 }
263
264 pub fn next_txn_id_hint(&self) -> TransactionId {
265 self.next_txn_id.load(Ordering::SeqCst)
266 }
267
268 pub fn debug_snapshot(&self) -> TxnDebugSnapshot {
269 let active_ids = self.active_transactions();
270 let mut active = Vec::with_capacity(active_ids.len());
271 for txn_id in active_ids {
272 let status = self.transaction_status(txn_id);
273 let held = self.held_locks.get(&txn_id);
274 let (held_tables, held_rows) = held
275 .map(|locks| (locks.tables.len(), locks.rows.len()))
276 .unwrap_or((0, 0));
277 active.push(TxnDebugEntry {
278 txn_id,
279 status,
280 held_tables,
281 held_rows,
282 });
283 }
284 TxnDebugSnapshot {
285 active,
286 oldest_active: self.oldest_active_txn(),
287 next_txn_id: self.next_txn_id_hint(),
288 }
289 }
290
291 pub fn record_row_lock(
292 &self,
293 txn_id: TransactionId,
294 table: TableReference,
295 rid: RecordId,
296 mode: LockMode,
297 ) {
298 let mut entry = self.held_locks.entry(txn_id).or_default();
299 if entry.row_keys.insert((table.clone(), rid)) {
300 entry.rows.push((table, rid, mode));
301 }
302 }
303
304 pub fn unlock_row(&self, txn_id: TransactionId, table: &TableReference, rid: RecordId) {
305 if self.lock_manager.unlock_row_raw(txn_id, table.clone(), rid) {
306 if let Some(mut entry) = self.held_locks.get_mut(&txn_id) {
307 entry.row_keys.remove(&(table.clone(), rid));
308 entry.rows.retain(|(t, r, _)| !(t == table && *r == rid));
309 }
310 }
311 }
312
313 fn release_all_locks(&self, txn_id: TransactionId) {
314 if let Some((_, mut held)) = self.held_locks.remove(&txn_id) {
315 for (table, rid, _) in held.rows.drain(..).rev() {
316 let _ = self.lock_manager.unlock_row_raw(txn_id, table, rid);
317 }
318 for (table, _) in held.tables.drain(..).rev() {
319 let _ = self.lock_manager.unlock_table_raw(txn_id, table);
320 }
321 }
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::record::TupleMeta;

    /// Starts a read-write transaction at `isolation_level`, panicking with
    /// `label` if the manager refuses.
    fn begin_read_write(
        manager: &TransactionManager,
        isolation_level: IsolationLevel,
        label: &str,
    ) -> Transaction {
        manager
            .begin(isolation_level, TransactionAccessMode::ReadWrite)
            .expect(label)
    }

    /// Committing updates both the transaction object and the manager's
    /// status table.
    #[test]
    fn commit_marks_state_and_status() {
        let manager = TransactionManager::new();
        let mut txn = begin_read_write(&manager, IsolationLevel::ReadUncommitted, "begin txn");

        manager.commit(&mut txn).expect("commit");

        assert_eq!(txn.state(), TransactionState::Committed);
        let recorded = manager.transaction_status(txn.id());
        assert_eq!(recorded, TransactionStatus::Committed);
    }

    /// Aborting updates both the transaction object and the manager's status
    /// table.
    #[test]
    fn abort_marks_state_and_status() {
        let manager = TransactionManager::new();
        let mut txn = begin_read_write(&manager, IsolationLevel::ReadUncommitted, "begin txn");

        manager.abort(&mut txn).expect("abort");

        assert_eq!(txn.state(), TransactionState::Aborted);
        let recorded = manager.transaction_status(txn.id());
        assert_eq!(recorded, TransactionStatus::Aborted);
    }

    /// A tuple inserted by a still-running writer is hidden from a reader's
    /// snapshot and becomes visible to a snapshot taken after the commit.
    #[test]
    fn snapshot_excludes_running_insert_until_commit() {
        let manager = TransactionManager::new();

        let mut writer = begin_read_write(&manager, IsolationLevel::ReadCommitted, "writer");
        let meta = TupleMeta::new(writer.id(), 0);
        let mut reader = begin_read_write(&manager, IsolationLevel::ReadCommitted, "reader");

        let snapshot = manager.snapshot(reader.id());
        let visible_while_running =
            snapshot.is_visible(&meta, 0, |tid| manager.transaction_status(tid));
        assert!(
            !visible_while_running,
            "running writer should not be visible",
        );

        manager.commit(&mut writer).expect("commit writer");
        let snapshot_after_commit = manager.snapshot(reader.id());
        let visible_after_commit =
            snapshot_after_commit.is_visible(&meta, 0, |tid| manager.transaction_status(tid));
        assert!(visible_after_commit);

        manager.abort(&mut reader).expect("abort reader");
    }

    /// A tuple stays visible while its deleter is running and disappears from
    /// snapshots taken once the delete has committed.
    #[test]
    fn snapshot_treats_committed_delete_as_invisible() {
        let manager = TransactionManager::new();

        let mut inserter = begin_read_write(&manager, IsolationLevel::ReadCommitted, "insert txn");
        let mut meta = TupleMeta::new(inserter.id(), 0);
        manager.commit(&mut inserter).expect("commit insert");

        let mut deleter = begin_read_write(&manager, IsolationLevel::ReadCommitted, "delete txn");
        meta.mark_deleted(deleter.id(), 0);

        let mut reader = begin_read_write(&manager, IsolationLevel::ReadCommitted, "reader txn");

        let before_commit = manager.snapshot(reader.id());
        let visible_before = before_commit.is_visible(&meta, 0, |tid| manager.transaction_status(tid));
        assert!(visible_before);

        manager.commit(&mut deleter).expect("commit delete");
        let after_commit = manager.snapshot(reader.id());
        let visible_after = after_commit.is_visible(&meta, 0, |tid| manager.transaction_status(tid));
        assert!(!visible_after, "committed delete should hide tuple",);

        manager.abort(&mut reader).expect("abort reader");
    }
}