1use crate::storage::page::RecordId;
2use crate::transaction::{Transaction, TransactionId};
3use crate::utils::table_ref::TableReference;
4use log::{trace, warn};
5use parking_lot::{Condvar, Mutex};
6use std::collections::{hash_map::Entry, HashMap, HashSet, VecDeque};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9
/// Modes in which a transaction can hold a table or row lock.
///
/// Which modes may be held together by different transactions is decided by
/// `modes_compatible`; which in-place upgrades are permitted by `can_upgrade`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LockMode {
    /// S: shared (read) lock.
    Shared,
    /// X: exclusive (write) lock.
    Exclusive,
    /// IS: intention-shared lock.
    IntentionShared,
    /// IX: intention-exclusive lock.
    IntentionExclusive,
    /// SIX: shared plus intention-exclusive lock.
    SharedIntentionExclusive,
}
18
19#[derive(Debug, Clone)]
20struct LockRequest {
21 id: u64,
22 txn_id: TransactionId,
23 mode: LockMode,
24 table_ref: TableReference,
25 rid: Option<RecordId>,
26 granted: bool,
27}
28
29#[derive(Debug, Default)]
30struct LockQueue {
31 requests: VecDeque<LockRequest>,
32}
33
34#[derive(Debug, Default)]
35struct ResourceLock {
36 state: Mutex<LockQueue>,
37 condvar: Condvar,
38}
39
40impl ResourceLock {
41 fn new() -> Self {
43 Self {
44 state: Mutex::new(LockQueue {
45 requests: VecDeque::new(),
46 }),
47 condvar: Condvar::new(),
48 }
49 }
50}
51
52type RowKey = (TableReference, RecordId);
53
54#[derive(Debug)]
55pub struct LockManager {
56 table_lock_map: Mutex<HashMap<TableReference, Arc<ResourceLock>>>,
57 row_lock_map: Mutex<HashMap<RowKey, Arc<ResourceLock>>>,
58 request_id: AtomicU64,
59 wait_for: Mutex<HashMap<TransactionId, HashSet<TransactionId>>>,
60}
61
62impl LockManager {
63 pub fn new() -> Self {
65 Self {
66 table_lock_map: Mutex::new(HashMap::new()),
67 row_lock_map: Mutex::new(HashMap::new()),
68 request_id: AtomicU64::new(1),
69 wait_for: Mutex::new(HashMap::new()),
70 }
71 }
72
73 pub fn default_instance() -> Self {
75 Self::new()
76 }
77
78 pub fn lock_table(&self, txn: &Transaction, mode: LockMode, table_ref: TableReference) -> bool {
80 let resource = self.get_table_resource(table_ref.clone());
81 self.lock_resource(
82 resource,
83 LockRequest {
84 id: self.next_request_id(),
85 txn_id: txn.id(),
86 mode,
87 table_ref,
88 rid: None,
89 granted: false,
90 },
91 )
92 }
93
94 pub fn unlock_table(&self, txn: &Transaction, table_ref: TableReference) -> bool {
96 self.unlock_table_raw(txn.id(), table_ref)
97 }
98
99 pub fn lock_row(
101 &self,
102 txn: &Transaction,
103 mode: LockMode,
104 table_ref: TableReference,
105 rid: RecordId,
106 ) -> bool {
107 let resource = self.get_row_resource(table_ref.clone(), rid);
108 self.lock_resource(
109 resource,
110 LockRequest {
111 id: self.next_request_id(),
112 txn_id: txn.id(),
113 mode,
114 table_ref,
115 rid: Some(rid),
116 granted: false,
117 },
118 )
119 }
120
121 pub fn unlock_row(
123 &self,
124 txn: &Transaction,
125 table_ref: TableReference,
126 rid: RecordId,
127 _force: bool,
128 ) -> bool {
129 self.unlock_row_raw(txn.id(), table_ref, rid)
130 }
131
132 pub fn unlock_table_raw(&self, txn_id: TransactionId, table_ref: TableReference) -> bool {
133 self.unlock_table_internal(txn_id, table_ref)
134 }
135
136 pub fn unlock_row_raw(
137 &self,
138 txn_id: TransactionId,
139 table_ref: TableReference,
140 rid: RecordId,
141 ) -> bool {
142 self.unlock_row_internal(txn_id, table_ref, rid)
143 }
144
145 pub fn unlock_all(&self) {
147 {
148 let mut map = self.table_lock_map.lock();
149 for resource in map.values() {
150 let mut state = resource.state.lock();
151 state.requests.clear();
152 resource.condvar.notify_all();
153 }
154 map.clear();
155 }
156 {
157 let mut map = self.row_lock_map.lock();
158 for resource in map.values() {
159 let mut state = resource.state.lock();
160 state.requests.clear();
161 resource.condvar.notify_all();
162 }
163 map.clear();
164 }
165 }
166
167 fn next_request_id(&self) -> u64 {
168 self.request_id.fetch_add(1, Ordering::SeqCst)
169 }
170
171 fn get_table_resource(&self, table_ref: TableReference) -> Arc<ResourceLock> {
172 let mut map = self.table_lock_map.lock();
173 map.entry(table_ref)
174 .or_insert_with(|| Arc::new(ResourceLock::new()))
175 .clone()
176 }
177
178 fn get_row_resource(&self, table_ref: TableReference, rid: RecordId) -> Arc<ResourceLock> {
179 let key = (table_ref, rid);
180 let mut map = self.row_lock_map.lock();
181 map.entry(key)
182 .or_insert_with(|| Arc::new(ResourceLock::new()))
183 .clone()
184 }
185
186 fn lock_resource(&self, resource: Arc<ResourceLock>, request: LockRequest) -> bool {
187 let mut queue_guard = resource.state.lock();
188
189 let mut prev_mode: Option<LockMode> = None;
190 let mut txn_id = request.txn_id;
191 let request_id = if let Some(existing) = queue_guard.requests.iter_mut().find(|req| {
192 req.txn_id == request.txn_id
193 && req.rid == request.rid
194 && req.table_ref == request.table_ref
195 }) {
196 if existing.mode == request.mode {
197 return true;
198 }
199
200 if !can_upgrade(existing.mode, request.mode) {
201 return false;
202 }
203
204 prev_mode = Some(existing.mode);
205 txn_id = existing.txn_id;
206 existing.mode = request.mode;
207 existing.granted = false;
208 existing.id
209 } else {
210 queue_guard.requests.push_back(request);
211 queue_guard.requests.back().map(|req| req.id).unwrap_or(0)
212 };
213
214 loop {
215 if can_grant(&queue_guard.requests, request_id) {
216 if let Some(req) = queue_guard
217 .requests
218 .iter_mut()
219 .find(|req| req.id == request_id)
220 {
221 req.granted = true;
222 trace!(
223 "lock granted: txn={} resource={:?} mode={:?}",
224 req.txn_id,
225 (req.table_ref.clone(), req.rid),
226 req.mode
227 );
228 }
229 self.clear_wait_edges(txn_id);
230 return true;
231 }
232
233 let blockers = blockers_for(&queue_guard.requests, request_id);
234 if !blockers.is_empty() {
235 trace!("wait edge: txn={} blocking_on={:?}", txn_id, blockers);
236 if self.record_wait(txn_id, &blockers) {
237 warn!("deadlock detected: txn={}", txn_id);
238 if let Some(mode) = prev_mode {
239 if let Some(req) = queue_guard
240 .requests
241 .iter_mut()
242 .find(|req| req.id == request_id)
243 {
244 req.mode = mode;
245 req.granted = true;
246 }
247 } else {
248 queue_guard.requests.retain(|req| req.id != request_id);
249 }
250 self.clear_wait_edges(txn_id);
251 return false;
252 }
253 }
254 resource.condvar.wait(&mut queue_guard);
255 self.clear_wait_edges(txn_id);
256 }
257 }
258
259 fn record_wait(&self, txn_id: TransactionId, blockers: &[TransactionId]) -> bool {
260 let mut wait_for = self.wait_for.lock();
261 let entry = wait_for.entry(txn_id).or_default();
262 entry.clear();
263 entry.extend(blockers.iter().copied());
264 self.has_cycle(&wait_for, txn_id)
265 }
266
267 fn clear_wait_edges(&self, txn: TransactionId) {
268 let mut wait_for = self.wait_for.lock();
269 wait_for.remove(&txn);
270 for edges in wait_for.values_mut() {
271 edges.remove(&txn);
272 }
273 }
274
275 fn has_cycle(
276 &self,
277 wait_for: &HashMap<TransactionId, HashSet<TransactionId>>,
278 start: TransactionId,
279 ) -> bool {
280 fn dfs(
281 graph: &HashMap<TransactionId, HashSet<TransactionId>>,
282 node: TransactionId,
283 start: TransactionId,
284 visited: &mut HashSet<TransactionId>,
285 ) -> bool {
286 if !visited.insert(node) {
287 return false;
288 }
289 if let Some(edges) = graph.get(&node) {
290 for &next in edges {
291 if next == start || dfs(graph, next, start, visited) {
292 return true;
293 }
294 }
295 }
296 visited.remove(&node);
297 false
298 }
299 dfs(wait_for, start, start, &mut HashSet::new())
300 }
301
302 fn unlock_table_internal(&self, txn_id: TransactionId, table_ref: TableReference) -> bool {
303 let resource = {
304 let map = self.table_lock_map.lock();
305 map.get(&table_ref).cloned()
306 };
307 let Some(resource) = resource else {
308 return false;
309 };
310
311 let mut queue_guard = resource.state.lock();
312 let original_len = queue_guard.requests.len();
313 queue_guard
314 .requests
315 .retain(|req| !(req.txn_id == txn_id && req.rid.is_none()));
316 let removed = queue_guard.requests.len() != original_len;
317 if removed {
318 resource.condvar.notify_all();
319 }
320 let empty = queue_guard.requests.is_empty();
321 drop(queue_guard);
322
323 if empty {
324 let mut map = self.table_lock_map.lock();
325 if let Entry::Occupied(entry) = map.entry(table_ref) {
326 if Arc::ptr_eq(entry.get(), &resource) {
327 entry.remove();
328 }
329 }
330 }
331
332 removed
333 }
334
335 fn unlock_row_internal(
336 &self,
337 txn_id: TransactionId,
338 table_ref: TableReference,
339 rid: RecordId,
340 ) -> bool {
341 let key = (table_ref.clone(), rid);
342 let resource = {
343 let map = self.row_lock_map.lock();
344 map.get(&key).cloned()
345 };
346 let Some(resource) = resource else {
347 return false;
348 };
349
350 let mut queue_guard = resource.state.lock();
351 let original_len = queue_guard.requests.len();
352 queue_guard.requests.retain(|req| {
353 if req.txn_id != txn_id {
354 return true;
355 }
356 if let Some(lock_rid) = req.rid {
357 return !(lock_rid == rid && req.table_ref == table_ref);
358 }
359 true
360 });
361 let removed = queue_guard.requests.len() != original_len;
362 if removed {
363 resource.condvar.notify_all();
364 }
365 let empty = queue_guard.requests.is_empty();
366 drop(queue_guard);
367
368 if empty {
369 let mut map = self.row_lock_map.lock();
370 if let Entry::Occupied(entry) = map.entry(key) {
371 if Arc::ptr_eq(entry.get(), &resource) {
372 entry.remove();
373 }
374 }
375 }
376
377 removed
378 }
379}
380
381impl Default for LockManager {
382 fn default() -> Self {
383 Self::new()
384 }
385}
386
387fn can_grant(queue: &VecDeque<LockRequest>, request_id: u64) -> bool {
388 let (position, request) = match queue
389 .iter()
390 .enumerate()
391 .find(|(_, req)| req.id == request_id)
392 {
393 Some(pair) => pair,
394 None => return false,
395 };
396
397 for pending in queue.iter().take(position) {
398 if !pending.granted {
399 return false;
400 }
401 }
402
403 for granted in queue.iter().filter(|req| req.granted) {
404 if granted.id == request_id {
405 continue;
406 }
407 if granted.txn_id == request.txn_id {
408 continue;
409 }
410 if !modes_compatible(request.mode, granted.mode) {
411 return false;
412 }
413 }
414 true
415}
416
417fn blockers_for(queue: &VecDeque<LockRequest>, request_id: u64) -> Vec<TransactionId> {
418 let Some((position, target)) = queue
419 .iter()
420 .enumerate()
421 .find(|(_, req)| req.id == request_id)
422 else {
423 return Vec::new();
424 };
425 queue
426 .iter()
427 .take(position)
428 .filter(|req| req.granted && req.txn_id != target.txn_id)
429 .map(|req| req.txn_id)
430 .collect()
431}
432
433fn can_upgrade(held: LockMode, requested: LockMode) -> bool {
434 matches!(
435 (held, requested),
436 (LockMode::Shared, LockMode::Exclusive)
437 | (LockMode::Shared, LockMode::SharedIntentionExclusive)
438 | (LockMode::IntentionShared, LockMode::IntentionExclusive)
439 | (
440 LockMode::IntentionShared,
441 LockMode::SharedIntentionExclusive
442 )
443 | (
444 LockMode::IntentionExclusive,
445 LockMode::SharedIntentionExclusive
446 )
447 )
448}
449
450fn modes_compatible(requested: LockMode, held: LockMode) -> bool {
451 match requested {
452 LockMode::Shared => matches!(
453 held,
454 LockMode::Shared | LockMode::IntentionShared | LockMode::SharedIntentionExclusive
455 ),
456 LockMode::Exclusive => false,
457 LockMode::IntentionShared => matches!(
458 held,
459 LockMode::Shared
460 | LockMode::IntentionShared
461 | LockMode::IntentionExclusive
462 | LockMode::SharedIntentionExclusive
463 ),
464 LockMode::IntentionExclusive => matches!(
465 held,
466 LockMode::IntentionShared | LockMode::IntentionExclusive
467 ),
468 LockMode::SharedIntentionExclusive => {
469 matches!(held, LockMode::IntentionShared | LockMode::Shared)
470 }
471 }
472}
473
#[cfg(test)]
mod tests {
    // Unit tests for `LockManager`. The two blocking tests are timing-based: they sleep
    // briefly so the spawned thread can reach its wait, check it has not acquired the
    // lock, release the holder, then expect the waiter to have succeeded.
    use super::*;
    use crate::transaction::{IsolationLevel, Transaction};
    use crate::utils::table_ref::TableReference;
    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// Builds a `ReadCommitted`, read-write transaction with the given id.
    /// NOTE(review): meaning of the trailing `true` flag is defined by
    /// `Transaction::new` — TODO confirm.
    fn new_txn(id: TransactionId) -> Transaction {
        Transaction::new(
            id,
            IsolationLevel::ReadCommitted,
            sqlparser::ast::TransactionAccessMode::ReadWrite,
            true,
        )
    }

    /// Two transactions can hold `Shared` on the same table simultaneously, and both
    /// unlocks report that a lock was released.
    #[test]
    fn shared_locks_are_compatible() {
        let manager = LockManager::new();
        let table = TableReference::Bare {
            table: "t_shared".to_string(),
        };
        let txn1 = new_txn(1);
        let txn2 = new_txn(2);

        assert!(manager.lock_table(&txn1, LockMode::Shared, table.clone()));
        assert!(manager.lock_table(&txn2, LockMode::Shared, table.clone()));

        assert!(manager.unlock_table(&txn1, table.clone()));
        assert!(manager.unlock_table(&txn2, table));
    }

    /// A table-level `Exclusive` request blocks while another transaction holds
    /// `Shared`, and is granted once that lock is released.
    #[test]
    fn exclusive_waits_for_shared() {
        let manager = Arc::new(LockManager::new());
        let table = TableReference::Bare {
            table: "t_block".to_string(),
        };
        let txn1 = new_txn(10);
        let txn2 = new_txn(20);

        assert!(manager.lock_table(&txn1, LockMode::Shared, table.clone()));

        // Set by the waiter thread to the result of its lock call.
        let acquired = Arc::new(AtomicBool::new(false));
        let acquired_clone = acquired.clone();
        let manager_clone = manager.clone();
        let table_clone = table.clone();

        let handle = thread::spawn(move || {
            let ok = manager_clone.lock_table(&txn2, LockMode::Exclusive, table_clone.clone());
            acquired_clone.store(ok, AtomicOrdering::SeqCst);
            if ok {
                manager_clone.unlock_table(&txn2, table_clone);
            }
        });

        // Give the waiter time to block. If the thread has not started yet this check
        // still passes, so it cannot catch every premature grant.
        thread::sleep(Duration::from_millis(20));
        assert!(!acquired.load(AtomicOrdering::SeqCst));

        // Releasing the shared lock must wake the waiter so it can acquire X.
        assert!(manager.unlock_table(&txn1, table.clone()));
        handle.join().unwrap();
        assert!(acquired.load(AtomicOrdering::SeqCst));
    }

    /// A row-level `Shared` request blocks behind another transaction's `Exclusive`
    /// row lock and is granted after the writer unlocks the row.
    #[test]
    fn row_lock_conflict_blocks() {
        let manager = Arc::new(LockManager::new());
        let table = TableReference::Bare {
            table: "t_row".to_string(),
        };
        let rid = RecordId::new(1, 1);
        let writer = new_txn(100);
        let reader = new_txn(200);

        assert!(manager.lock_row(&writer, LockMode::Exclusive, table.clone(), rid));

        // Set by the reader thread to the result of its lock call.
        let proceed = Arc::new(AtomicBool::new(false));
        let proceed_clone = proceed.clone();
        let manager_clone = manager.clone();
        let table_clone = table.clone();

        let handle = thread::spawn(move || {
            let ok = manager_clone.lock_row(&reader, LockMode::Shared, table_clone.clone(), rid);
            proceed_clone.store(ok, AtomicOrdering::SeqCst);
            if ok {
                manager_clone.unlock_row(&reader, table_clone, rid, false);
            }
        });

        // Timing-based: the reader should be blocked on the writer's X lock by now.
        thread::sleep(Duration::from_millis(20));
        assert!(!proceed.load(AtomicOrdering::SeqCst));

        // Unlocking the row wakes the reader, which must then be granted S.
        assert!(manager.unlock_row(&writer, table.clone(), rid, false));
        handle.join().unwrap();
        assert!(proceed.load(AtomicOrdering::SeqCst));
    }
}