1use std::collections::{HashMap, HashSet, VecDeque};
6use std::sync::{
7 Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult,
8};
9use std::time::{Duration, Instant};
10
/// Unique identifier of a transaction, assigned by the caller of the lock manager.
pub type TxnId = u64;
13
/// Acquires a read guard on `lock`, recovering the inner data if the lock
/// was poisoned by a panicking writer (poisoning is advisory; the data is
/// still usable here).
fn read_unpoisoned<'a, T>(lock: &'a RwLock<T>) -> RwLockReadGuard<'a, T> {
    lock.read().unwrap_or_else(|poisoned| poisoned.into_inner())
}
20
/// Acquires a write guard on `lock`, recovering the inner data if the lock
/// was poisoned by a panicking holder.
fn write_unpoisoned<'a, T>(lock: &'a RwLock<T>) -> RwLockWriteGuard<'a, T> {
    lock.write().unwrap_or_else(|poisoned| poisoned.into_inner())
}
27
/// Locks `lock`, recovering the inner data if the mutex was poisoned by a
/// panicking holder.
fn mutex_unpoisoned<'a, T>(lock: &'a Mutex<T>) -> MutexGuard<'a, T> {
    lock.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
}
34
35fn wait_timeout_unpoisoned<'a, T>(
36 condvar: &Condvar,
37 guard: MutexGuard<'a, T>,
38 timeout: Duration,
39) -> (MutexGuard<'a, T>, WaitTimeoutResult) {
40 match condvar.wait_timeout(guard, timeout) {
41 Ok(result) => result,
42 Err(poisoned) => poisoned.into_inner(),
43 }
44}
45
/// Lock modes for multi-granularity (hierarchical) locking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LockMode {
    /// Read lock (S): multiple transactions may hold it concurrently.
    Shared,
    /// Write lock (X): excludes every other mode.
    Exclusive,
    /// Intention (IS) to take Shared locks at a finer granularity below this resource.
    IntentShared,
    /// Intention (IX) to take Exclusive locks at a finer granularity below this resource.
    IntentExclusive,
    /// Shared lock on this resource plus intent (SIX) to take finer-grained Exclusive locks.
    SharedIntentExclusive,
}
60
61impl LockMode {
62 pub fn is_compatible(&self, other: &LockMode) -> bool {
64 use LockMode::*;
65 matches!(
66 (self, other),
67 (Shared, Shared)
68 | (Shared, IntentShared)
69 | (IntentShared, Shared)
70 | (IntentShared, IntentShared)
71 | (IntentShared, IntentExclusive)
72 | (IntentExclusive, IntentShared)
73 | (IntentExclusive, IntentExclusive)
74 )
75 }
76
77 pub fn can_upgrade_to(&self, target: &LockMode) -> bool {
79 use LockMode::*;
80 matches!(
81 (self, target),
82 (Shared, Exclusive)
83 | (IntentShared, IntentExclusive)
84 | (IntentShared, SharedIntentExclusive)
85 | (Shared, SharedIntentExclusive)
86 )
87 }
88}
89
/// Outcome of a lock request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LockResult {
    /// The lock was granted.
    Granted,
    /// The request is queued behind incompatible holders.
    /// NOTE(review): never returned by code in this file — `acquire`
    /// blocks until Granted/Deadlock/Timeout; confirm external use.
    Waiting,
    /// Granting would close (or has closed) a cycle in the wait-for
    /// graph. The payload lists the transactions the requester was
    /// waiting on — not necessarily the complete cycle.
    Deadlock(Vec<TxnId>),
    /// The wait exceeded the request's timeout.
    Timeout,
    /// The transaction already held the lock and the mode was updated.
    Upgraded,
    /// The transaction already holds the lock in the requested mode.
    /// NOTE(review): never returned in this file; same-mode re-requests
    /// report `Upgraded`.
    AlreadyHeld,
    /// The referenced transaction is unknown.
    /// NOTE(review): never returned in this file.
    TxnNotFound,
    /// The transaction reached `max_locks_per_txn`.
    LockLimitExceeded,
}
110
/// A queued lock request, waiting for incompatible holders to release.
#[derive(Debug, Clone)]
pub struct LockWaiter {
    /// The waiting transaction.
    pub txn_id: TxnId,
    /// The lock mode being requested.
    pub mode: LockMode,
    /// When the wait began (set at construction).
    pub start_time: Instant,
    /// Maximum time this waiter is willing to wait.
    pub timeout: Duration,
}
123
124impl LockWaiter {
125 pub fn new(txn_id: TxnId, mode: LockMode, timeout: Duration) -> Self {
127 Self {
128 txn_id,
129 mode,
130 start_time: Instant::now(),
131 timeout,
132 }
133 }
134
135 pub fn is_timed_out(&self) -> bool {
137 self.start_time.elapsed() > self.timeout
138 }
139}
140
/// Per-resource lock table entry: the current holders and the wait queue.
#[derive(Debug)]
struct Lock {
    /// The resource key this entry guards.
    /// NOTE(review): only observable through the `Debug` derive — no other
    /// code in this file reads it; confirm it is intentionally retained.
    resource: Vec<u8>,
    /// Transactions currently holding this lock, with the mode each holds.
    holders: HashMap<TxnId, LockMode>,
    /// Requests that could not be granted immediately, in arrival order.
    waiters: VecDeque<LockWaiter>,
}
151
152impl Lock {
153 fn new(resource: Vec<u8>) -> Self {
154 Self {
155 resource,
156 holders: HashMap::new(),
157 waiters: VecDeque::new(),
158 }
159 }
160
161 fn can_grant(&self, txn_id: TxnId, mode: LockMode) -> bool {
163 if let Some(held_mode) = self.holders.get(&txn_id) {
165 if *held_mode == mode {
166 return true; }
168 if held_mode.can_upgrade_to(&mode) {
170 return self
172 .holders
173 .iter()
174 .all(|(id, m)| *id == txn_id || mode.is_compatible(m));
175 }
176 return false;
177 }
178
179 self.holders.values().all(|m| mode.is_compatible(m))
181 }
182
183 fn grant(&mut self, txn_id: TxnId, mode: LockMode) {
185 self.holders.insert(txn_id, mode);
186 }
187
188 fn release(&mut self, txn_id: TxnId) -> Option<LockMode> {
190 self.holders.remove(&txn_id)
191 }
192
193 fn add_waiter(&mut self, waiter: LockWaiter) {
195 self.waiters.push_back(waiter);
196 }
197
198 fn process_waiters(&mut self) -> Vec<TxnId> {
200 let mut granted = Vec::new();
201
202 self.waiters.retain(|w| !w.is_timed_out());
204
205 let mut i = 0;
207 while i < self.waiters.len() {
208 let waiter = &self.waiters[i];
209 if self.can_grant(waiter.txn_id, waiter.mode) {
210 if let Some(waiter) = self.waiters.remove(i) {
211 self.grant(waiter.txn_id, waiter.mode);
212 granted.push(waiter.txn_id);
213 }
214 } else {
215 i += 1;
216 }
217 }
218
219 granted
220 }
221}
222
/// Tunable parameters for [`LockManager`].
#[derive(Debug, Clone)]
pub struct LockConfig {
    /// Maximum time `acquire` waits before returning `Timeout`.
    pub default_timeout: Duration,
    /// Whether to build a wait-for graph and detect deadlocks.
    pub deadlock_detection: bool,
    /// Intended polling interval for deadlock detection.
    /// NOTE(review): never read anywhere in this file — the wait loop
    /// polls on a fixed 10 ms condvar timeout; confirm whether this
    /// setting is still meant to be honored.
    pub detection_interval: Duration,
    /// Maximum number of distinct resources one transaction may lock.
    pub max_locks_per_txn: usize,
}
235
236impl Default for LockConfig {
237 fn default() -> Self {
238 Self {
239 default_timeout: Duration::from_secs(30),
240 deadlock_detection: true,
241 detection_interval: Duration::from_millis(100),
242 max_locks_per_txn: 10000,
243 }
244 }
245}
246
/// Cumulative counters describing lock-manager activity.
#[derive(Debug, Clone, Default)]
pub struct LockStats {
    /// Total acquire requests received.
    pub requests: u64,
    /// Requests granted on the immediate (non-waiting) path, including
    /// upgrades. NOTE(review): waits that are eventually granted are not
    /// counted here — confirm whether that is intended.
    pub granted: u64,
    /// Requests that had to enter a wait queue.
    pub waited: u64,
    /// Deadlocks detected.
    pub deadlocks: u64,
    /// Waits abandoned because their timeout expired.
    pub timeouts: u64,
    /// Snapshot of the number of (transaction, resource) holds, refreshed
    /// on the immediate-grant path.
    pub active_locks: u64,
    /// Current number of queued waiters.
    pub waiting: u64,
}
265
/// A centralized lock manager supporting multi-granularity lock modes,
/// per-transaction bookkeeping, wait timeouts, and wait-for-graph
/// deadlock detection.
///
/// Internal lock-acquisition order (as used by the methods below) is
/// `locks` -> `txn_locks` -> `wait_graph` -> `stats`.
pub struct LockManager {
    /// Tunable behavior (timeouts, deadlock detection, lock budget).
    config: LockConfig,
    /// Lock table: resource key -> holders and wait queue.
    locks: RwLock<HashMap<Vec<u8>, Lock>>,
    /// Reverse index: transaction -> resources it holds locks on.
    txn_locks: RwLock<HashMap<TxnId, HashSet<Vec<u8>>>>,
    /// Wait-for graph (waiter -> holders) used for deadlock detection.
    wait_graph: RwLock<HashMap<TxnId, HashSet<TxnId>>>,
    /// Signaled by `release` so waiting `acquire` loops re-check promptly.
    waiter_cv: Condvar,
    /// Mutex paired with `waiter_cv`; guards no data of its own.
    waiter_mutex: Mutex<()>,
    /// Cumulative counters; see [`LockStats`].
    stats: RwLock<LockStats>,
}
283
284impl LockManager {
285 pub fn new(config: LockConfig) -> Self {
287 Self {
288 config,
289 locks: RwLock::new(HashMap::new()),
290 txn_locks: RwLock::new(HashMap::new()),
291 wait_graph: RwLock::new(HashMap::new()),
292 waiter_cv: Condvar::new(),
293 waiter_mutex: Mutex::new(()),
294 stats: RwLock::new(LockStats::default()),
295 }
296 }
297
298 pub fn with_defaults() -> Self {
300 Self::new(LockConfig::default())
301 }
302
303 pub fn acquire(&self, txn_id: TxnId, resource: &[u8], mode: LockMode) -> LockResult {
305 self.acquire_with_timeout(txn_id, resource, mode, self.config.default_timeout)
306 }
307
308 pub fn acquire_with_timeout(
310 &self,
311 txn_id: TxnId,
312 resource: &[u8],
313 mode: LockMode,
314 timeout: Duration,
315 ) -> LockResult {
316 {
318 let mut stats = write_unpoisoned(&self.stats);
319 stats.requests += 1;
320 }
321
322 let resource_key = resource.to_vec();
323
324 {
326 let txn_locks = read_unpoisoned(&self.txn_locks);
327 if let Some(locks) = txn_locks.get(&txn_id) {
328 if locks.len() >= self.config.max_locks_per_txn && !locks.contains(&resource_key) {
329 return LockResult::LockLimitExceeded;
330 }
331 }
332 }
333
334 {
336 let mut locks = write_unpoisoned(&self.locks);
337 let lock = locks
338 .entry(resource_key.clone())
339 .or_insert_with(|| Lock::new(resource_key.clone()));
340
341 if lock.can_grant(txn_id, mode) {
342 let already_held = lock.holders.contains_key(&txn_id);
343 lock.grant(txn_id, mode);
344
345 let mut txn_locks = write_unpoisoned(&self.txn_locks);
347 txn_locks.entry(txn_id).or_default().insert(resource_key);
348
349 let mut stats = write_unpoisoned(&self.stats);
350 stats.granted += 1;
351 stats.active_locks = locks.values().map(|l| l.holders.len() as u64).sum();
352
353 return if already_held {
354 LockResult::Upgraded
355 } else {
356 LockResult::Granted
357 };
358 }
359
360 let waiting_for: HashSet<TxnId> = lock.holders.keys().copied().collect();
362
363 if self.config.deadlock_detection {
364 let mut wait_graph = Self::build_wait_graph_from_locks(&locks);
365 wait_graph
366 .entry(txn_id)
367 .or_default()
368 .extend(waiting_for.iter().copied());
369
370 if self.detect_deadlock_inner(txn_id, &wait_graph) {
371 let cycle: Vec<TxnId> = waiting_for.iter().copied().collect();
372 let mut stats = write_unpoisoned(&self.stats);
373 stats.deadlocks += 1;
374 return LockResult::Deadlock(cycle);
375 }
376
377 *write_unpoisoned(&self.wait_graph) = wait_graph;
378 }
379
380 if let Some(lock) = locks.get_mut(&resource_key) {
382 lock.add_waiter(LockWaiter::new(txn_id, mode, timeout));
383 }
384
385 let mut stats = write_unpoisoned(&self.stats);
386 stats.waited += 1;
387 stats.waiting += 1;
388 }
389
390 let start = Instant::now();
392 loop {
393 let guard = mutex_unpoisoned(&self.waiter_mutex);
395 let (_guard, _wait_result) =
396 wait_timeout_unpoisoned(&self.waiter_cv, guard, Duration::from_millis(10));
397
398 let holders: Option<HashSet<TxnId>> = {
400 let locks = read_unpoisoned(&self.locks);
401 if let Some(lock) = locks.get(&resource_key) {
402 if lock.holders.contains_key(&txn_id) {
403 if self.config.deadlock_detection {
405 let mut wait_graph = write_unpoisoned(&self.wait_graph);
406 wait_graph.remove(&txn_id);
407 }
408
409 let mut stats = write_unpoisoned(&self.stats);
410 stats.waiting -= 1;
411
412 return LockResult::Granted;
413 }
414
415 Some(lock.holders.keys().copied().collect())
416 } else {
417 None
418 }
419 };
420
421 if self.config.deadlock_detection {
422 let locks = read_unpoisoned(&self.locks);
423 let wait_graph = Self::build_wait_graph_from_locks(&locks);
424 drop(locks);
425
426 if self.detect_deadlock_inner(txn_id, &wait_graph) {
427 let mut stats = write_unpoisoned(&self.stats);
428 stats.deadlocks += 1;
429 stats.waiting -= 1;
430 return LockResult::Deadlock(holders.unwrap_or_default().into_iter().collect());
431 }
432
433 *write_unpoisoned(&self.wait_graph) = wait_graph;
434 }
435
436 if start.elapsed() > timeout {
438 {
440 let mut locks = write_unpoisoned(&self.locks);
441 if let Some(lock) = locks.get_mut(&resource_key) {
442 lock.waiters.retain(|w| w.txn_id != txn_id);
443 }
444 }
445
446 if self.config.deadlock_detection {
448 let mut wait_graph = write_unpoisoned(&self.wait_graph);
449 wait_graph.remove(&txn_id);
450 }
451
452 let mut stats = write_unpoisoned(&self.stats);
453 stats.timeouts += 1;
454 stats.waiting -= 1;
455
456 return LockResult::Timeout;
457 }
458 }
459 }
460
461 pub fn release(&self, txn_id: TxnId, resource: &[u8]) -> bool {
463 let resource_key = resource.to_vec();
464
465 let granted = {
466 let mut locks = write_unpoisoned(&self.locks);
467
468 if let Some(lock) = locks.get_mut(&resource_key) {
469 if lock.release(txn_id).is_some() {
470 let mut txn_locks = write_unpoisoned(&self.txn_locks);
472 if let Some(resources) = txn_locks.get_mut(&txn_id) {
473 resources.remove(&resource_key);
474 }
475
476 let granted = lock.process_waiters();
478
479 if self.config.deadlock_detection && !granted.is_empty() {
481 let mut wait_graph = write_unpoisoned(&self.wait_graph);
482 for txn in &granted {
483 wait_graph.remove(txn);
484 }
485 }
486
487 if lock.holders.is_empty() && lock.waiters.is_empty() {
489 locks.remove(&resource_key);
490 }
491
492 self.waiter_cv.notify_all();
494
495 return true;
496 }
497 }
498
499 false
500 };
501
502 granted
503 }
504
505 pub fn release_all(&self, txn_id: TxnId) -> usize {
507 let resources: Vec<Vec<u8>> = {
508 let txn_locks = read_unpoisoned(&self.txn_locks);
509 txn_locks
510 .get(&txn_id)
511 .map(|r| r.iter().cloned().collect())
512 .unwrap_or_default()
513 };
514
515 let count = resources.len();
516
517 for resource in resources {
518 self.release(txn_id, &resource);
519 }
520
521 {
523 let mut txn_locks = write_unpoisoned(&self.txn_locks);
524 txn_locks.remove(&txn_id);
525 }
526
527 if self.config.deadlock_detection {
529 let mut wait_graph = write_unpoisoned(&self.wait_graph);
530 wait_graph.remove(&txn_id);
531 }
532
533 count
534 }
535
536 pub fn holds_lock(&self, txn_id: TxnId, resource: &[u8]) -> Option<LockMode> {
538 let locks = read_unpoisoned(&self.locks);
539 locks
540 .get(resource)
541 .and_then(|lock| lock.holders.get(&txn_id).copied())
542 }
543
544 pub fn get_locks(&self, txn_id: TxnId) -> Vec<(Vec<u8>, LockMode)> {
546 let txn_locks = read_unpoisoned(&self.txn_locks);
547 let locks = read_unpoisoned(&self.locks);
548
549 txn_locks
550 .get(&txn_id)
551 .map(|resources| {
552 resources
553 .iter()
554 .filter_map(|r| {
555 locks
556 .get(r)
557 .and_then(|l| l.holders.get(&txn_id).map(|m| (r.clone(), *m)))
558 })
559 .collect()
560 })
561 .unwrap_or_default()
562 }
563
564 fn detect_deadlock_inner(
566 &self,
567 start: TxnId,
568 wait_graph: &HashMap<TxnId, HashSet<TxnId>>,
569 ) -> bool {
570 let mut visited = HashSet::new();
571 let mut stack = HashSet::new();
572
573 Self::dfs_cycle(start, &mut visited, &mut stack, wait_graph)
574 }
575
576 fn build_wait_graph_from_locks(
577 locks: &HashMap<Vec<u8>, Lock>,
578 ) -> HashMap<TxnId, HashSet<TxnId>> {
579 let mut graph: HashMap<TxnId, HashSet<TxnId>> = HashMap::new();
580
581 for lock in locks.values() {
582 if lock.holders.is_empty() {
583 continue;
584 }
585 let holders: HashSet<TxnId> = lock.holders.keys().copied().collect();
586 for waiter in &lock.waiters {
587 graph
588 .entry(waiter.txn_id)
589 .or_default()
590 .extend(holders.iter().copied());
591 }
592 }
593
594 graph
595 }
596
597 fn dfs_cycle(
598 node: TxnId,
599 visited: &mut HashSet<TxnId>,
600 stack: &mut HashSet<TxnId>,
601 wait_graph: &HashMap<TxnId, HashSet<TxnId>>,
602 ) -> bool {
603 if stack.contains(&node) {
604 return true; }
606 if visited.contains(&node) {
607 return false; }
609
610 visited.insert(node);
611 stack.insert(node);
612
613 if let Some(waiting_for) = wait_graph.get(&node) {
614 for &next in waiting_for {
615 if Self::dfs_cycle(next, visited, stack, wait_graph) {
616 return true;
617 }
618 }
619 }
620
621 stack.remove(&node);
622 false
623 }
624
625 pub fn stats(&self) -> LockStats {
627 read_unpoisoned(&self.stats).clone()
628 }
629
630 pub fn config(&self) -> &LockConfig {
632 &self.config
633 }
634}
635
636impl Default for LockManager {
637 fn default() -> Self {
638 Self::with_defaults()
639 }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lock_mode_compatibility() {
        // Readers coexist; writers exclude everyone, including other writers.
        assert!(LockMode::Shared.is_compatible(&LockMode::Shared));
        assert!(!LockMode::Shared.is_compatible(&LockMode::Exclusive));
        assert!(!LockMode::Exclusive.is_compatible(&LockMode::Exclusive));
        assert!(LockMode::IntentShared.is_compatible(&LockMode::IntentShared));
    }

    #[test]
    fn test_lock_acquire_release() {
        let manager = LockManager::with_defaults();

        // Shared locks from two transactions coexist on one resource.
        assert_eq!(
            manager.acquire(1, b"resource1", LockMode::Shared),
            LockResult::Granted
        );
        assert_eq!(
            manager.acquire(2, b"resource1", LockMode::Shared),
            LockResult::Granted
        );

        assert!(manager.release(1, b"resource1"));
        assert!(manager.release(2, b"resource1"));
    }

    #[test]
    fn test_exclusive_lock() {
        let manager = LockManager::with_defaults();

        assert_eq!(
            manager.acquire(1, b"resource1", LockMode::Exclusive),
            LockResult::Granted
        );
        assert_eq!(manager.holds_lock(1, b"resource1"), Some(LockMode::Exclusive));

        // release_all drops every lock the transaction held.
        manager.release_all(1);
        assert_eq!(manager.holds_lock(1, b"resource1"), None);
    }

    #[test]
    fn test_release_all() {
        let manager = LockManager::with_defaults();

        manager.acquire(1, b"r1", LockMode::Shared);
        manager.acquire(1, b"r2", LockMode::Exclusive);
        manager.acquire(1, b"r3", LockMode::Shared);

        // All three tracked resources are released in one call.
        assert_eq!(manager.release_all(1), 3);
    }

    #[test]
    fn test_lock_limit_exceeded() {
        let manager = LockManager::new(LockConfig {
            max_locks_per_txn: 1,
            ..LockConfig::default()
        });

        assert_eq!(manager.acquire(1, b"r1", LockMode::Shared), LockResult::Granted);
        // A second distinct resource exceeds the budget of one.
        assert_eq!(
            manager.acquire(1, b"r2", LockMode::Shared),
            LockResult::LockLimitExceeded
        );
    }

    #[test]
    fn test_lock_limit_allows_upgrade() {
        let manager = LockManager::new(LockConfig {
            max_locks_per_txn: 1,
            ..LockConfig::default()
        });

        assert_eq!(manager.acquire(1, b"r1", LockMode::Shared), LockResult::Granted);
        // Upgrading an already-held resource does not count against the budget.
        assert_eq!(
            manager.acquire(1, b"r1", LockMode::Exclusive),
            LockResult::Upgraded
        );
    }

    #[test]
    fn test_get_locks() {
        let manager = LockManager::with_defaults();

        manager.acquire(1, b"r1", LockMode::Shared);
        manager.acquire(1, b"r2", LockMode::Exclusive);

        assert_eq!(manager.get_locks(1).len(), 2);
    }

    #[test]
    fn test_waiter_timeout() {
        let waiter = LockWaiter::new(1, LockMode::Shared, Duration::from_millis(1));
        std::thread::sleep(Duration::from_millis(5));
        assert!(waiter.is_timed_out());
    }
}