// quill_sql/transaction/lock_manager.rs

1use crate::storage::page::RecordId;
2use crate::transaction::{Transaction, TransactionId};
3use crate::utils::table_ref::TableReference;
4use log::{trace, warn};
5use parking_lot::{Condvar, Mutex};
6use std::collections::{hash_map::Entry, HashMap, HashSet, VecDeque};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
/// Lock modes understood by the lock manager.
///
/// The names follow the usual multi-granularity scheme (S, X, IS, IX, SIX);
/// which modes may coexist on one resource is decided by `modes_compatible`,
/// and which in-place changes are allowed by `can_upgrade`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LockMode {
    /// Shared (S).
    Shared,
    /// Exclusive (X); compatible with no other mode.
    Exclusive,
    /// Intention-shared (IS).
    IntentionShared,
    /// Intention-exclusive (IX).
    IntentionExclusive,
    /// Shared plus intention-exclusive (SIX).
    SharedIntentionExclusive,
}
18
19#[derive(Debug, Clone)]
20struct LockRequest {
21    id: u64,
22    txn_id: TransactionId,
23    mode: LockMode,
24    table_ref: TableReference,
25    rid: Option<RecordId>,
26    granted: bool,
27}
28
29#[derive(Debug, Default)]
30struct LockQueue {
31    requests: VecDeque<LockRequest>,
32}
33
34#[derive(Debug, Default)]
35struct ResourceLock {
36    state: Mutex<LockQueue>,
37    condvar: Condvar,
38}
39
40impl ResourceLock {
41    /// Create a new empty lock queue for a resource.
42    fn new() -> Self {
43        Self {
44            state: Mutex::new(LockQueue {
45                requests: VecDeque::new(),
46            }),
47            condvar: Condvar::new(),
48        }
49    }
50}
51
52type RowKey = (TableReference, RecordId);
53
54#[derive(Debug)]
55pub struct LockManager {
56    table_lock_map: Mutex<HashMap<TableReference, Arc<ResourceLock>>>,
57    row_lock_map: Mutex<HashMap<RowKey, Arc<ResourceLock>>>,
58    request_id: AtomicU64,
59    wait_for: Mutex<HashMap<TransactionId, HashSet<TransactionId>>>,
60}
61
62impl LockManager {
63    /// Create a new lock manager.
64    pub fn new() -> Self {
65        Self {
66            table_lock_map: Mutex::new(HashMap::new()),
67            row_lock_map: Mutex::new(HashMap::new()),
68            request_id: AtomicU64::new(1),
69            wait_for: Mutex::new(HashMap::new()),
70        }
71    }
72
73    /// Create a lock manager using default configuration.
74    pub fn default_instance() -> Self {
75        Self::new()
76    }
77
78    /// Acquire a table level lock for the given transaction.
79    pub fn lock_table(&self, txn: &Transaction, mode: LockMode, table_ref: TableReference) -> bool {
80        let resource = self.get_table_resource(table_ref.clone());
81        self.lock_resource(
82            resource,
83            LockRequest {
84                id: self.next_request_id(),
85                txn_id: txn.id(),
86                mode,
87                table_ref,
88                rid: None,
89                granted: false,
90            },
91        )
92    }
93
94    /// Release a table level lock held by the transaction.
95    pub fn unlock_table(&self, txn: &Transaction, table_ref: TableReference) -> bool {
96        self.unlock_table_raw(txn.id(), table_ref)
97    }
98
99    /// Acquire a row level lock for the given transaction.
100    pub fn lock_row(
101        &self,
102        txn: &Transaction,
103        mode: LockMode,
104        table_ref: TableReference,
105        rid: RecordId,
106    ) -> bool {
107        let resource = self.get_row_resource(table_ref.clone(), rid);
108        self.lock_resource(
109            resource,
110            LockRequest {
111                id: self.next_request_id(),
112                txn_id: txn.id(),
113                mode,
114                table_ref,
115                rid: Some(rid),
116                granted: false,
117            },
118        )
119    }
120
121    /// Release a row level lock held by the transaction.
122    pub fn unlock_row(
123        &self,
124        txn: &Transaction,
125        table_ref: TableReference,
126        rid: RecordId,
127        _force: bool,
128    ) -> bool {
129        self.unlock_row_raw(txn.id(), table_ref, rid)
130    }
131
132    pub fn unlock_table_raw(&self, txn_id: TransactionId, table_ref: TableReference) -> bool {
133        self.unlock_table_internal(txn_id, table_ref)
134    }
135
136    pub fn unlock_row_raw(
137        &self,
138        txn_id: TransactionId,
139        table_ref: TableReference,
140        rid: RecordId,
141    ) -> bool {
142        self.unlock_row_internal(txn_id, table_ref, rid)
143    }
144
145    /// Force release of all locks (used during shutdown/testing).
146    pub fn unlock_all(&self) {
147        {
148            let mut map = self.table_lock_map.lock();
149            for resource in map.values() {
150                let mut state = resource.state.lock();
151                state.requests.clear();
152                resource.condvar.notify_all();
153            }
154            map.clear();
155        }
156        {
157            let mut map = self.row_lock_map.lock();
158            for resource in map.values() {
159                let mut state = resource.state.lock();
160                state.requests.clear();
161                resource.condvar.notify_all();
162            }
163            map.clear();
164        }
165    }
166
167    fn next_request_id(&self) -> u64 {
168        self.request_id.fetch_add(1, Ordering::SeqCst)
169    }
170
171    fn get_table_resource(&self, table_ref: TableReference) -> Arc<ResourceLock> {
172        let mut map = self.table_lock_map.lock();
173        map.entry(table_ref)
174            .or_insert_with(|| Arc::new(ResourceLock::new()))
175            .clone()
176    }
177
178    fn get_row_resource(&self, table_ref: TableReference, rid: RecordId) -> Arc<ResourceLock> {
179        let key = (table_ref, rid);
180        let mut map = self.row_lock_map.lock();
181        map.entry(key)
182            .or_insert_with(|| Arc::new(ResourceLock::new()))
183            .clone()
184    }
185
186    fn lock_resource(&self, resource: Arc<ResourceLock>, request: LockRequest) -> bool {
187        let mut queue_guard = resource.state.lock();
188
189        let mut prev_mode: Option<LockMode> = None;
190        let mut txn_id = request.txn_id;
191        let request_id = if let Some(existing) = queue_guard.requests.iter_mut().find(|req| {
192            req.txn_id == request.txn_id
193                && req.rid == request.rid
194                && req.table_ref == request.table_ref
195        }) {
196            if existing.mode == request.mode {
197                return true;
198            }
199
200            if !can_upgrade(existing.mode, request.mode) {
201                return false;
202            }
203
204            prev_mode = Some(existing.mode);
205            txn_id = existing.txn_id;
206            existing.mode = request.mode;
207            existing.granted = false;
208            existing.id
209        } else {
210            queue_guard.requests.push_back(request);
211            queue_guard.requests.back().map(|req| req.id).unwrap_or(0)
212        };
213
214        loop {
215            if can_grant(&queue_guard.requests, request_id) {
216                if let Some(req) = queue_guard
217                    .requests
218                    .iter_mut()
219                    .find(|req| req.id == request_id)
220                {
221                    req.granted = true;
222                    trace!(
223                        "lock granted: txn={} resource={:?} mode={:?}",
224                        req.txn_id,
225                        (req.table_ref.clone(), req.rid),
226                        req.mode
227                    );
228                }
229                self.clear_wait_edges(txn_id);
230                return true;
231            }
232
233            let blockers = blockers_for(&queue_guard.requests, request_id);
234            if !blockers.is_empty() {
235                trace!("wait edge: txn={} blocking_on={:?}", txn_id, blockers);
236                if self.record_wait(txn_id, &blockers) {
237                    warn!("deadlock detected: txn={}", txn_id);
238                    if let Some(mode) = prev_mode {
239                        if let Some(req) = queue_guard
240                            .requests
241                            .iter_mut()
242                            .find(|req| req.id == request_id)
243                        {
244                            req.mode = mode;
245                            req.granted = true;
246                        }
247                    } else {
248                        queue_guard.requests.retain(|req| req.id != request_id);
249                    }
250                    self.clear_wait_edges(txn_id);
251                    return false;
252                }
253            }
254            resource.condvar.wait(&mut queue_guard);
255            self.clear_wait_edges(txn_id);
256        }
257    }
258
259    fn record_wait(&self, txn_id: TransactionId, blockers: &[TransactionId]) -> bool {
260        let mut wait_for = self.wait_for.lock();
261        let entry = wait_for.entry(txn_id).or_default();
262        entry.clear();
263        entry.extend(blockers.iter().copied());
264        self.has_cycle(&wait_for, txn_id)
265    }
266
267    fn clear_wait_edges(&self, txn: TransactionId) {
268        let mut wait_for = self.wait_for.lock();
269        wait_for.remove(&txn);
270        for edges in wait_for.values_mut() {
271            edges.remove(&txn);
272        }
273    }
274
275    fn has_cycle(
276        &self,
277        wait_for: &HashMap<TransactionId, HashSet<TransactionId>>,
278        start: TransactionId,
279    ) -> bool {
280        fn dfs(
281            graph: &HashMap<TransactionId, HashSet<TransactionId>>,
282            node: TransactionId,
283            start: TransactionId,
284            visited: &mut HashSet<TransactionId>,
285        ) -> bool {
286            if !visited.insert(node) {
287                return false;
288            }
289            if let Some(edges) = graph.get(&node) {
290                for &next in edges {
291                    if next == start || dfs(graph, next, start, visited) {
292                        return true;
293                    }
294                }
295            }
296            visited.remove(&node);
297            false
298        }
299        dfs(wait_for, start, start, &mut HashSet::new())
300    }
301
302    fn unlock_table_internal(&self, txn_id: TransactionId, table_ref: TableReference) -> bool {
303        let resource = {
304            let map = self.table_lock_map.lock();
305            map.get(&table_ref).cloned()
306        };
307        let Some(resource) = resource else {
308            return false;
309        };
310
311        let mut queue_guard = resource.state.lock();
312        let original_len = queue_guard.requests.len();
313        queue_guard
314            .requests
315            .retain(|req| !(req.txn_id == txn_id && req.rid.is_none()));
316        let removed = queue_guard.requests.len() != original_len;
317        if removed {
318            resource.condvar.notify_all();
319        }
320        let empty = queue_guard.requests.is_empty();
321        drop(queue_guard);
322
323        if empty {
324            let mut map = self.table_lock_map.lock();
325            if let Entry::Occupied(entry) = map.entry(table_ref) {
326                if Arc::ptr_eq(entry.get(), &resource) {
327                    entry.remove();
328                }
329            }
330        }
331
332        removed
333    }
334
335    fn unlock_row_internal(
336        &self,
337        txn_id: TransactionId,
338        table_ref: TableReference,
339        rid: RecordId,
340    ) -> bool {
341        let key = (table_ref.clone(), rid);
342        let resource = {
343            let map = self.row_lock_map.lock();
344            map.get(&key).cloned()
345        };
346        let Some(resource) = resource else {
347            return false;
348        };
349
350        let mut queue_guard = resource.state.lock();
351        let original_len = queue_guard.requests.len();
352        queue_guard.requests.retain(|req| {
353            if req.txn_id != txn_id {
354                return true;
355            }
356            if let Some(lock_rid) = req.rid {
357                return !(lock_rid == rid && req.table_ref == table_ref);
358            }
359            true
360        });
361        let removed = queue_guard.requests.len() != original_len;
362        if removed {
363            resource.condvar.notify_all();
364        }
365        let empty = queue_guard.requests.is_empty();
366        drop(queue_guard);
367
368        if empty {
369            let mut map = self.row_lock_map.lock();
370            if let Entry::Occupied(entry) = map.entry(key) {
371                if Arc::ptr_eq(entry.get(), &resource) {
372                    entry.remove();
373                }
374            }
375        }
376
377        removed
378    }
379}
380
381impl Default for LockManager {
382    fn default() -> Self {
383        Self::new()
384    }
385}
386
387fn can_grant(queue: &VecDeque<LockRequest>, request_id: u64) -> bool {
388    let (position, request) = match queue
389        .iter()
390        .enumerate()
391        .find(|(_, req)| req.id == request_id)
392    {
393        Some(pair) => pair,
394        None => return false,
395    };
396
397    for pending in queue.iter().take(position) {
398        if !pending.granted {
399            return false;
400        }
401    }
402
403    for granted in queue.iter().filter(|req| req.granted) {
404        if granted.id == request_id {
405            continue;
406        }
407        if granted.txn_id == request.txn_id {
408            continue;
409        }
410        if !modes_compatible(request.mode, granted.mode) {
411            return false;
412        }
413    }
414    true
415}
416
417fn blockers_for(queue: &VecDeque<LockRequest>, request_id: u64) -> Vec<TransactionId> {
418    let Some((position, target)) = queue
419        .iter()
420        .enumerate()
421        .find(|(_, req)| req.id == request_id)
422    else {
423        return Vec::new();
424    };
425    queue
426        .iter()
427        .take(position)
428        .filter(|req| req.granted && req.txn_id != target.txn_id)
429        .map(|req| req.txn_id)
430        .collect()
431}
432
433fn can_upgrade(held: LockMode, requested: LockMode) -> bool {
434    matches!(
435        (held, requested),
436        (LockMode::Shared, LockMode::Exclusive)
437            | (LockMode::Shared, LockMode::SharedIntentionExclusive)
438            | (LockMode::IntentionShared, LockMode::IntentionExclusive)
439            | (
440                LockMode::IntentionShared,
441                LockMode::SharedIntentionExclusive
442            )
443            | (
444                LockMode::IntentionExclusive,
445                LockMode::SharedIntentionExclusive
446            )
447    )
448}
449
450fn modes_compatible(requested: LockMode, held: LockMode) -> bool {
451    match requested {
452        LockMode::Shared => matches!(
453            held,
454            LockMode::Shared | LockMode::IntentionShared | LockMode::SharedIntentionExclusive
455        ),
456        LockMode::Exclusive => false,
457        LockMode::IntentionShared => matches!(
458            held,
459            LockMode::Shared
460                | LockMode::IntentionShared
461                | LockMode::IntentionExclusive
462                | LockMode::SharedIntentionExclusive
463        ),
464        LockMode::IntentionExclusive => matches!(
465            held,
466            LockMode::IntentionShared | LockMode::IntentionExclusive
467        ),
468        LockMode::SharedIntentionExclusive => {
469            matches!(held, LockMode::IntentionShared | LockMode::Shared)
470        }
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transaction::{IsolationLevel, Transaction};
    use crate::utils::table_ref::TableReference;
    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// How long the main thread waits before checking that a contending
    /// requester is still blocked.
    const BLOCK_CHECK_DELAY: Duration = Duration::from_millis(20);

    /// Build a read-committed, read-write transaction with the given id.
    fn new_txn(id: TransactionId) -> Transaction {
        Transaction::new(
            id,
            IsolationLevel::ReadCommitted,
            sqlparser::ast::TransactionAccessMode::ReadWrite,
            true,
        )
    }

    /// Shorthand for an unqualified table reference.
    fn bare_table(name: &str) -> TableReference {
        TableReference::Bare {
            table: name.to_string(),
        }
    }

    #[test]
    fn shared_locks_are_compatible() {
        let manager = LockManager::new();
        let table = bare_table("t_shared");
        let first = new_txn(1);
        let second = new_txn(2);

        assert!(manager.lock_table(&first, LockMode::Shared, table.clone()));
        assert!(manager.lock_table(&second, LockMode::Shared, table.clone()));

        assert!(manager.unlock_table(&first, table.clone()));
        assert!(manager.unlock_table(&second, table));
    }

    #[test]
    fn exclusive_waits_for_shared() {
        let manager = Arc::new(LockManager::new());
        let table = bare_table("t_block");
        let reader = new_txn(10);
        let writer = new_txn(20);

        assert!(manager.lock_table(&reader, LockMode::Shared, table.clone()));

        let acquired = Arc::new(AtomicBool::new(false));
        let handle = {
            let manager = Arc::clone(&manager);
            let acquired = Arc::clone(&acquired);
            let table = table.clone();
            thread::spawn(move || {
                let ok = manager.lock_table(&writer, LockMode::Exclusive, table.clone());
                acquired.store(ok, AtomicOrdering::SeqCst);
                if ok {
                    manager.unlock_table(&writer, table);
                }
            })
        };

        thread::sleep(BLOCK_CHECK_DELAY);
        assert!(
            !acquired.load(AtomicOrdering::SeqCst),
            "exclusive lock granted while a shared lock was held"
        );

        assert!(manager.unlock_table(&reader, table.clone()));
        handle.join().unwrap();
        assert!(acquired.load(AtomicOrdering::SeqCst));
    }

    #[test]
    fn row_lock_conflict_blocks() {
        let manager = Arc::new(LockManager::new());
        let table = bare_table("t_row");
        let rid = RecordId::new(1, 1);
        let writer = new_txn(100);
        let reader = new_txn(200);

        assert!(manager.lock_row(&writer, LockMode::Exclusive, table.clone(), rid));

        let proceeded = Arc::new(AtomicBool::new(false));
        let handle = {
            let manager = Arc::clone(&manager);
            let proceeded = Arc::clone(&proceeded);
            let table = table.clone();
            thread::spawn(move || {
                let ok = manager.lock_row(&reader, LockMode::Shared, table.clone(), rid);
                proceeded.store(ok, AtomicOrdering::SeqCst);
                if ok {
                    manager.unlock_row(&reader, table, rid, false);
                }
            })
        };

        thread::sleep(BLOCK_CHECK_DELAY);
        assert!(
            !proceeded.load(AtomicOrdering::SeqCst),
            "shared row lock granted while an exclusive row lock was held"
        );

        assert!(manager.unlock_row(&writer, table.clone(), rid, false));
        handle.join().unwrap();
        assert!(proceeded.load(AtomicOrdering::SeqCst));
    }
}