// quill_sql/transaction/lock_manager.rs

1use crate::storage::page::RecordId;
2use crate::transaction::{Transaction, TransactionId};
3use crate::utils::table_ref::TableReference;
4use log::{trace, warn};
5use parking_lot::{Condvar, Mutex};
6use serde::Serialize;
7use std::collections::{hash_map::Entry, HashMap, HashSet, VecDeque};
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
13pub enum LockMode {
14    Shared,
15    Exclusive,
16    IntentionShared,
17    IntentionExclusive,
18    SharedIntentionExclusive,
19}
20
21#[derive(Debug, Clone)]
22struct LockRequest {
23    id: u64,
24    txn_id: TransactionId,
25    mode: LockMode,
26    table_ref: TableReference,
27    rid: Option<RecordId>,
28    granted: bool,
29}
30
31#[derive(Debug, Default)]
32struct LockQueue {
33    requests: VecDeque<LockRequest>,
34}
35
36#[derive(Debug, Default)]
37struct ResourceLock {
38    state: Mutex<LockQueue>,
39    condvar: Condvar,
40}
41
42impl ResourceLock {
43    /// Create a new empty lock queue for a resource.
44    fn new() -> Self {
45        Self {
46            state: Mutex::new(LockQueue {
47                requests: VecDeque::new(),
48            }),
49            condvar: Condvar::new(),
50        }
51    }
52}
53
54type RowKey = (TableReference, RecordId);
55
56#[derive(Debug)]
57pub struct LockManager {
58    table_lock_map: Mutex<HashMap<TableReference, Arc<ResourceLock>>>,
59    row_lock_map: Mutex<HashMap<RowKey, Arc<ResourceLock>>>,
60    request_id: AtomicU64,
61    wait_for: Mutex<HashMap<TransactionId, HashSet<TransactionId>>>,
62    deadlock_timeout: Duration,
63}
64
65#[derive(Debug, Clone, Serialize)]
66pub struct LockGraphEdge {
67    pub from: TransactionId,
68    pub to: TransactionId,
69}
70
71#[derive(Debug, Clone, Serialize)]
72pub struct LockRequestDebug {
73    pub txn_id: TransactionId,
74    pub mode: LockMode,
75    pub granted: bool,
76    pub table: String,
77    pub rid: Option<String>,
78}
79
80#[derive(Debug, Clone, Serialize)]
81pub struct LockResourceDebug {
82    pub resource: String,
83    pub requests: Vec<LockRequestDebug>,
84}
85
86#[derive(Debug, Clone, Serialize)]
87pub struct LockDebugSnapshot {
88    pub wait_for: Vec<LockGraphEdge>,
89    pub tables: Vec<LockResourceDebug>,
90    pub rows: Vec<LockResourceDebug>,
91}
92
93impl LockManager {
94    /// Create a new lock manager.
95    pub fn new() -> Self {
96        Self {
97            table_lock_map: Mutex::new(HashMap::new()),
98            row_lock_map: Mutex::new(HashMap::new()),
99            request_id: AtomicU64::new(1),
100            wait_for: Mutex::new(HashMap::new()),
101            deadlock_timeout: Duration::from_millis(1000),
102        }
103    }
104
105    /// Create a lock manager using default configuration.
106    pub fn default_instance() -> Self {
107        Self::new()
108    }
109
110    /// Acquire a table level lock for the given transaction.
111    pub fn lock_table(&self, txn: &Transaction, mode: LockMode, table_ref: TableReference) -> bool {
112        let resource = self.get_table_resource(table_ref.clone());
113        self.lock_resource(
114            resource,
115            LockRequest {
116                id: self.next_request_id(),
117                txn_id: txn.id(),
118                mode,
119                table_ref,
120                rid: None,
121                granted: false,
122            },
123        )
124    }
125
126    /// Release a table level lock held by the transaction.
127    pub fn unlock_table(&self, txn: &Transaction, table_ref: TableReference) -> bool {
128        self.unlock_table_raw(txn.id(), table_ref)
129    }
130
131    /// Acquire a row level lock for the given transaction.
132    pub fn lock_row(
133        &self,
134        txn: &Transaction,
135        mode: LockMode,
136        table_ref: TableReference,
137        rid: RecordId,
138    ) -> bool {
139        let resource = self.get_row_resource(table_ref.clone(), rid);
140        self.lock_resource(
141            resource,
142            LockRequest {
143                id: self.next_request_id(),
144                txn_id: txn.id(),
145                mode,
146                table_ref,
147                rid: Some(rid),
148                granted: false,
149            },
150        )
151    }
152
153    /// Release a row level lock held by the transaction.
154    pub fn unlock_row(
155        &self,
156        txn: &Transaction,
157        table_ref: TableReference,
158        rid: RecordId,
159        _force: bool,
160    ) -> bool {
161        self.unlock_row_raw(txn.id(), table_ref, rid)
162    }
163
164    pub fn unlock_table_raw(&self, txn_id: TransactionId, table_ref: TableReference) -> bool {
165        self.unlock_table_internal(txn_id, table_ref)
166    }
167
168    pub fn unlock_row_raw(
169        &self,
170        txn_id: TransactionId,
171        table_ref: TableReference,
172        rid: RecordId,
173    ) -> bool {
174        self.unlock_row_internal(txn_id, table_ref, rid)
175    }
176
177    /// Force release of all locks (used during shutdown/testing).
178    pub fn unlock_all(&self) {
179        {
180            let mut map = self.table_lock_map.lock();
181            for resource in map.values() {
182                let mut state = resource.state.lock();
183                state.requests.clear();
184                resource.condvar.notify_all();
185            }
186            map.clear();
187        }
188        {
189            let mut map = self.row_lock_map.lock();
190            for resource in map.values() {
191                let mut state = resource.state.lock();
192                state.requests.clear();
193                resource.condvar.notify_all();
194            }
195            map.clear();
196        }
197    }
198
199    pub fn debug_snapshot(&self) -> LockDebugSnapshot {
200        let wait_for = {
201            let guard = self.wait_for.lock();
202            guard
203                .iter()
204                .flat_map(|(from, blocked)| {
205                    blocked.iter().map(move |to| LockGraphEdge {
206                        from: *from,
207                        to: *to,
208                    })
209                })
210                .collect()
211        };
212
213        let table_resources = {
214            let map = self.table_lock_map.lock();
215            map.iter()
216                .map(|(table_ref, resource)| {
217                    let state = resource.state.lock();
218                    let requests = state
219                        .requests
220                        .iter()
221                        .map(|req| LockRequestDebug {
222                            txn_id: req.txn_id,
223                            mode: req.mode,
224                            granted: req.granted,
225                            table: req.table_ref.to_string(),
226                            rid: req.rid.map(|rid| rid.to_string()),
227                        })
228                        .collect();
229                    LockResourceDebug {
230                        resource: table_ref.to_string(),
231                        requests,
232                    }
233                })
234                .collect()
235        };
236
237        let row_resources = {
238            let map = self.row_lock_map.lock();
239            map.iter()
240                .map(|((table_ref, rid), resource)| {
241                    let state = resource.state.lock();
242                    let requests = state
243                        .requests
244                        .iter()
245                        .map(|req| LockRequestDebug {
246                            txn_id: req.txn_id,
247                            mode: req.mode,
248                            granted: req.granted,
249                            table: req.table_ref.to_string(),
250                            rid: req.rid.map(|rid| rid.to_string()),
251                        })
252                        .collect();
253                    LockResourceDebug {
254                        resource: format!("{}#{}", table_ref, rid),
255                        requests,
256                    }
257                })
258                .collect()
259        };
260
261        LockDebugSnapshot {
262            wait_for,
263            tables: table_resources,
264            rows: row_resources,
265        }
266    }
267
268    fn next_request_id(&self) -> u64 {
269        self.request_id.fetch_add(1, Ordering::SeqCst)
270    }
271
272    fn get_table_resource(&self, table_ref: TableReference) -> Arc<ResourceLock> {
273        let mut map = self.table_lock_map.lock();
274        map.entry(table_ref)
275            .or_insert_with(|| Arc::new(ResourceLock::new()))
276            .clone()
277    }
278
279    fn get_row_resource(&self, table_ref: TableReference, rid: RecordId) -> Arc<ResourceLock> {
280        let key = (table_ref, rid);
281        let mut map = self.row_lock_map.lock();
282        map.entry(key)
283            .or_insert_with(|| Arc::new(ResourceLock::new()))
284            .clone()
285    }
286
287    fn lock_resource(&self, resource: Arc<ResourceLock>, request: LockRequest) -> bool {
288        let mut queue_guard = resource.state.lock();
289
290        let mut prev_mode: Option<LockMode> = None;
291        let mut txn_id = request.txn_id;
292        let request_id = if let Some(existing) = queue_guard.requests.iter_mut().find(|req| {
293            req.txn_id == request.txn_id
294                && req.rid == request.rid
295                && req.table_ref == request.table_ref
296        }) {
297            if existing.mode == request.mode {
298                return true;
299            }
300
301            if !can_upgrade(existing.mode, request.mode) {
302                return false;
303            }
304
305            prev_mode = Some(existing.mode);
306            txn_id = existing.txn_id;
307            existing.mode = request.mode;
308            existing.granted = false;
309            existing.id
310        } else {
311            queue_guard.requests.push_back(request);
312            queue_guard.requests.back().map(|req| req.id).unwrap_or(0)
313        };
314
315        let mut wait_since = Instant::now();
316        loop {
317            if can_grant(&queue_guard.requests, request_id) {
318                if let Some(req) = queue_guard
319                    .requests
320                    .iter_mut()
321                    .find(|req| req.id == request_id)
322                {
323                    req.granted = true;
324                    trace!(
325                        "lock granted: txn={} resource={:?} mode={:?}",
326                        req.txn_id,
327                        (req.table_ref.clone(), req.rid),
328                        req.mode
329                    );
330                }
331                self.clear_wait_edges(txn_id);
332                return true;
333            }
334
335            let blockers = blockers_for(&queue_guard.requests, request_id);
336            if !blockers.is_empty() {
337                trace!("wait edge: txn={} blocking_on={:?}", txn_id, blockers);
338                let should_check = wait_since.elapsed() >= self.deadlock_timeout;
339                if should_check && self.record_wait(txn_id, &blockers) {
340                    warn!("deadlock detected: txn={}", txn_id);
341                    if let Some(mode) = prev_mode {
342                        if let Some(req) = queue_guard
343                            .requests
344                            .iter_mut()
345                            .find(|req| req.id == request_id)
346                        {
347                            req.mode = mode;
348                            req.granted = true;
349                        }
350                    } else {
351                        queue_guard.requests.retain(|req| req.id != request_id);
352                    }
353                    self.clear_wait_edges(txn_id);
354                    return false;
355                }
356            }
357            let remaining = self.deadlock_timeout.saturating_sub(wait_since.elapsed());
358            resource.condvar.wait_for(&mut queue_guard, remaining);
359            self.clear_wait_edges(txn_id);
360            wait_since = Instant::now();
361        }
362    }
363
364    fn record_wait(&self, txn_id: TransactionId, blockers: &[TransactionId]) -> bool {
365        let mut wait_for = self.wait_for.lock();
366        let entry = wait_for.entry(txn_id).or_default();
367        entry.clear();
368        entry.extend(blockers.iter().copied());
369        self.has_cycle(&wait_for, txn_id)
370    }
371
372    fn clear_wait_edges(&self, txn: TransactionId) {
373        let mut wait_for = self.wait_for.lock();
374        wait_for.remove(&txn);
375        for edges in wait_for.values_mut() {
376            edges.remove(&txn);
377        }
378    }
379
380    fn has_cycle(
381        &self,
382        wait_for: &HashMap<TransactionId, HashSet<TransactionId>>,
383        start: TransactionId,
384    ) -> bool {
385        fn dfs(
386            graph: &HashMap<TransactionId, HashSet<TransactionId>>,
387            node: TransactionId,
388            start: TransactionId,
389            visited: &mut HashSet<TransactionId>,
390        ) -> bool {
391            if !visited.insert(node) {
392                return false;
393            }
394            if let Some(edges) = graph.get(&node) {
395                for &next in edges {
396                    if next == start || dfs(graph, next, start, visited) {
397                        return true;
398                    }
399                }
400            }
401            visited.remove(&node);
402            false
403        }
404        dfs(wait_for, start, start, &mut HashSet::new())
405    }
406
407    fn unlock_table_internal(&self, txn_id: TransactionId, table_ref: TableReference) -> bool {
408        let resource = {
409            let map = self.table_lock_map.lock();
410            map.get(&table_ref).cloned()
411        };
412        let Some(resource) = resource else {
413            return false;
414        };
415
416        let mut queue_guard = resource.state.lock();
417        let original_len = queue_guard.requests.len();
418        queue_guard
419            .requests
420            .retain(|req| !(req.txn_id == txn_id && req.rid.is_none()));
421        let removed = queue_guard.requests.len() != original_len;
422        if removed {
423            resource.condvar.notify_all();
424        }
425        let empty = queue_guard.requests.is_empty();
426        drop(queue_guard);
427
428        if empty {
429            let mut map = self.table_lock_map.lock();
430            if let Entry::Occupied(entry) = map.entry(table_ref) {
431                if Arc::ptr_eq(entry.get(), &resource) {
432                    entry.remove();
433                }
434            }
435        }
436
437        removed
438    }
439
440    fn unlock_row_internal(
441        &self,
442        txn_id: TransactionId,
443        table_ref: TableReference,
444        rid: RecordId,
445    ) -> bool {
446        let key = (table_ref.clone(), rid);
447        let resource = {
448            let map = self.row_lock_map.lock();
449            map.get(&key).cloned()
450        };
451        let Some(resource) = resource else {
452            return false;
453        };
454
455        let mut queue_guard = resource.state.lock();
456        let original_len = queue_guard.requests.len();
457        queue_guard.requests.retain(|req| {
458            if req.txn_id != txn_id {
459                return true;
460            }
461            if let Some(lock_rid) = req.rid {
462                return !(lock_rid == rid && req.table_ref == table_ref);
463            }
464            true
465        });
466        let removed = queue_guard.requests.len() != original_len;
467        if removed {
468            resource.condvar.notify_all();
469        }
470        let empty = queue_guard.requests.is_empty();
471        drop(queue_guard);
472
473        if empty {
474            let mut map = self.row_lock_map.lock();
475            if let Entry::Occupied(entry) = map.entry(key) {
476                if Arc::ptr_eq(entry.get(), &resource) {
477                    entry.remove();
478                }
479            }
480        }
481
482        removed
483    }
484}
485
486impl Default for LockManager {
487    fn default() -> Self {
488        Self::new()
489    }
490}
491
492fn can_grant(queue: &VecDeque<LockRequest>, request_id: u64) -> bool {
493    let (position, request) = match queue
494        .iter()
495        .enumerate()
496        .find(|(_, req)| req.id == request_id)
497    {
498        Some(pair) => pair,
499        None => return false,
500    };
501
502    for pending in queue.iter().take(position) {
503        if !pending.granted {
504            return false;
505        }
506    }
507
508    for granted in queue.iter().filter(|req| req.granted) {
509        if granted.id == request_id {
510            continue;
511        }
512        if granted.txn_id == request.txn_id {
513            continue;
514        }
515        if !modes_compatible(request.mode, granted.mode) {
516            return false;
517        }
518    }
519    true
520}
521
522fn blockers_for(queue: &VecDeque<LockRequest>, request_id: u64) -> Vec<TransactionId> {
523    let Some((position, target)) = queue
524        .iter()
525        .enumerate()
526        .find(|(_, req)| req.id == request_id)
527    else {
528        return Vec::new();
529    };
530    queue
531        .iter()
532        .take(position)
533        .filter(|req| req.granted && req.txn_id != target.txn_id)
534        .map(|req| req.txn_id)
535        .collect()
536}
537
538fn can_upgrade(held: LockMode, requested: LockMode) -> bool {
539    matches!(
540        (held, requested),
541        (LockMode::Shared, LockMode::Exclusive)
542            | (LockMode::Shared, LockMode::SharedIntentionExclusive)
543            | (LockMode::IntentionShared, LockMode::IntentionExclusive)
544            | (
545                LockMode::IntentionShared,
546                LockMode::SharedIntentionExclusive
547            )
548            | (
549                LockMode::IntentionExclusive,
550                LockMode::SharedIntentionExclusive
551            )
552    )
553}
554
555fn modes_compatible(requested: LockMode, held: LockMode) -> bool {
556    match requested {
557        LockMode::Shared => matches!(
558            held,
559            LockMode::Shared | LockMode::IntentionShared | LockMode::SharedIntentionExclusive
560        ),
561        LockMode::Exclusive => false,
562        LockMode::IntentionShared => matches!(
563            held,
564            LockMode::Shared
565                | LockMode::IntentionShared
566                | LockMode::IntentionExclusive
567                | LockMode::SharedIntentionExclusive
568        ),
569        LockMode::IntentionExclusive => matches!(
570            held,
571            LockMode::IntentionShared | LockMode::IntentionExclusive
572        ),
573        LockMode::SharedIntentionExclusive => {
574            matches!(held, LockMode::IntentionShared | LockMode::Shared)
575        }
576    }
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transaction::{IsolationLevel, Transaction};
    use crate::utils::table_ref::TableReference;
    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// How long the main thread lets a contending thread run before asserting it is blocked.
    const SETTLE: Duration = Duration::from_millis(20);

    /// Build a read-committed, read-write transaction with the given id.
    fn new_txn(id: TransactionId) -> Transaction {
        Transaction::new(
            id,
            IsolationLevel::ReadCommitted,
            sqlparser::ast::TransactionAccessMode::ReadWrite,
            true,
        )
    }

    /// Build an unqualified table reference.
    fn bare_table(name: &str) -> TableReference {
        TableReference::Bare {
            table: name.to_string(),
        }
    }

    #[test]
    fn shared_locks_are_compatible() {
        let manager = LockManager::new();
        let table = bare_table("t_shared");
        let (first, second) = (new_txn(1), new_txn(2));

        for txn in [&first, &second] {
            assert!(manager.lock_table(txn, LockMode::Shared, table.clone()));
        }
        for txn in [&first, &second] {
            assert!(manager.unlock_table(txn, table.clone()));
        }
    }

    #[test]
    fn exclusive_waits_for_shared() {
        let manager = Arc::new(LockManager::new());
        let table = bare_table("t_block");
        let reader = new_txn(10);
        let writer = new_txn(20);

        assert!(manager.lock_table(&reader, LockMode::Shared, table.clone()));

        let acquired = Arc::new(AtomicBool::new(false));
        let contender = {
            let manager = Arc::clone(&manager);
            let acquired = Arc::clone(&acquired);
            let table = table.clone();
            thread::spawn(move || {
                let ok = manager.lock_table(&writer, LockMode::Exclusive, table.clone());
                acquired.store(ok, AtomicOrdering::SeqCst);
                if ok {
                    manager.unlock_table(&writer, table);
                }
            })
        };

        thread::sleep(SETTLE);
        assert!(!acquired.load(AtomicOrdering::SeqCst));

        assert!(manager.unlock_table(&reader, table));
        contender.join().expect("contending thread panicked");
        assert!(acquired.load(AtomicOrdering::SeqCst));
    }

    #[test]
    fn row_lock_conflict_blocks() {
        let manager = Arc::new(LockManager::new());
        let table = bare_table("t_row");
        let rid = RecordId::new(1, 1);
        let writer = new_txn(100);
        let reader = new_txn(200);

        assert!(manager.lock_row(&writer, LockMode::Exclusive, table.clone(), rid));

        let proceed = Arc::new(AtomicBool::new(false));
        let contender = {
            let manager = Arc::clone(&manager);
            let proceed = Arc::clone(&proceed);
            let table = table.clone();
            thread::spawn(move || {
                let ok = manager.lock_row(&reader, LockMode::Shared, table.clone(), rid);
                proceed.store(ok, AtomicOrdering::SeqCst);
                if ok {
                    manager.unlock_row(&reader, table, rid, false);
                }
            })
        };

        thread::sleep(SETTLE);
        assert!(!proceed.load(AtomicOrdering::SeqCst));

        assert!(manager.unlock_row(&writer, table, rid, false));
        contender.join().expect("contending thread panicked");
        assert!(proceed.load(AtomicOrdering::SeqCst));
    }
}