use std::collections::HashMap;
use crate::TxnId;
/// Rule for choosing which transaction in a deadlock cycle becomes the victim.
///
/// Interpreted by `WaitForGraph::pick_victim`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum VictimPolicy {
    /// Pick the transaction with the largest `TxnId` (the default).
    ///
    /// NOTE(review): "youngest" presumes ids are handed out in increasing
    /// order, so the largest id is the most recently started — confirm.
    #[default]
    Youngest,
    /// Pick the transaction with the smallest `TxnId`.
    Oldest,
}
/// A deadlock found among waiting transactions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Deadlock {
    /// The transaction selected from `cycle` as the victim
    /// (presumably the one the caller aborts to break the deadlock).
    pub victim: TxnId,
    /// The transactions forming the wait cycle.
    ///
    /// NOTE(review): presumably in the order returned by
    /// `WaitForGraph::detect_cycle` — confirm at the construction site.
    pub cycle: Vec<TxnId>,
}
/// Directed wait-for graph over transactions.
///
/// An edge `waiter -> holder` records that `waiter` waits on `holder`.
#[derive(Clone, Debug, Default)]
pub struct WaitForGraph {
    /// Adjacency list: each waiter mapped to the holders it waits on.
    edges: HashMap<TxnId, Vec<TxnId>>,
}
/// Maintenance and cycle search for the wait-for graph.
///
/// Invariant kept by every mutator: each stored holder list is non-empty and
/// free of duplicates, so `edges.len()` equals the number of waiters that are
/// actually waiting on something.
impl WaitForGraph {
    /// Creates an empty wait-for graph.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that `waiter` waits on `holder`.
    ///
    /// Self-edges (`waiter == holder`) are ignored. An edge that is already
    /// present is not stored a second time.
    pub fn add_wait(&mut self, waiter: TxnId, holder: TxnId) {
        if waiter == holder {
            return;
        }
        let holders = self.edges.entry(waiter).or_default();
        // Without this check, repeated registrations of the same wait would
        // grow the adjacency list (and the DFS work) without bound.
        if !holders.contains(&holder) {
            holders.push(holder);
        }
    }

    /// Records that `waiter` waits on every transaction in `holders`.
    ///
    /// Equivalent to calling [`add_wait`](Self::add_wait) once per holder.
    /// An empty `holders` slice leaves the graph unchanged.
    pub fn add_waits(&mut self, waiter: TxnId, holders: &[TxnId]) {
        for &holder in holders {
            self.add_wait(waiter, holder);
        }
    }

    /// Removes every outgoing edge of `waiter`.
    ///
    /// Edges from other transactions *to* `waiter` are left intact. Use
    /// [`remove_txn`](Self::remove_txn) to drop both directions.
    pub fn clear_waiter(&mut self, waiter: TxnId) {
        let _ = self.edges.remove(&waiter);
    }

    /// Removes `txn` from the graph entirely: its outgoing edges and every
    /// edge pointing to it.
    ///
    /// Waiters left with no holders are dropped as well, so
    /// [`is_empty`](Self::is_empty) and [`waiter_count`](Self::waiter_count)
    /// only count transactions that are still waiting on something.
    pub fn remove_txn(&mut self, txn: TxnId) {
        let _ = self.edges.remove(&txn);
        self.edges.retain(|_, holders| {
            holders.retain(|&holder| holder != txn);
            !holders.is_empty()
        });
    }

    /// Returns `true` when no transaction is waiting on another.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.edges.is_empty()
    }

    /// Returns the number of transactions that currently wait on at least one
    /// holder.
    #[must_use]
    pub fn waiter_count(&self) -> usize {
        self.edges.len()
    }

    /// Returns some cycle in the graph, or `None` if the graph is acyclic.
    ///
    /// The cycle is listed in wait order: each element waits on the next, and
    /// the last waits on the first. Visit marks are shared between DFS roots,
    /// so each transaction and edge is processed at most once. When several
    /// cycles exist, which one is returned depends on `HashMap` iteration
    /// order and is unspecified.
    #[must_use]
    pub fn detect_cycle(&self) -> Option<Vec<TxnId>> {
        let mut marks = HashMap::with_capacity(self.edges.len());
        self.edges
            .keys()
            .find_map(|&start| self.find_cycle_with(start, &mut marks))
    }

    /// Searches for a cycle reachable from `start`.
    ///
    /// The returned cycle need not pass through `start`. If `start` waits,
    /// directly or transitively, on a cycle, that cycle is returned. The result
    /// is in wait order, as for [`detect_cycle`](Self::detect_cycle). Returns
    /// `None` if no cycle is reachable, including when `start` waits on nothing.
    #[must_use]
    pub fn cycle_from(&self, start: TxnId) -> Option<Vec<TxnId>> {
        self.find_cycle_with(start, &mut HashMap::new())
    }

    /// Iterative DFS from `start` that records progress in `marks`.
    ///
    /// `marks` maps a visited transaction to `true` while it is on the current
    /// DFS path, and to `false` once it has been fully explored without
    /// finding a cycle. Because already-marked roots are skipped, one `marks`
    /// map can be reused across many starts. After a `None` return, every node
    /// this call marked is `false`.
    fn find_cycle_with(
        &self,
        start: TxnId,
        marks: &mut HashMap<TxnId, bool>,
    ) -> Option<Vec<TxnId>> {
        if marks.contains_key(&start) {
            return None;
        }
        let _ = marks.insert(start, true);
        // Frames are (node, index of the next holder to examine). The frames
        // double as the current DFS path, so a back edge's cycle is a suffix.
        let mut stack: Vec<(TxnId, usize)> = vec![(start, 0)];
        while let Some(frame) = stack.last_mut() {
            let node = frame.0;
            let holders: &[TxnId] = self.edges.get(&node).map_or(&[], Vec::as_slice);
            if let Some(&next) = holders.get(frame.1) {
                frame.1 += 1;
                match marks.get(&next).copied() {
                    None => {
                        let _ = marks.insert(next, true);
                        stack.push((next, 0));
                    }
                    // Back edge: `next` is on the current path, so the path
                    // from `next` to `node` plus `node -> next` is a cycle.
                    Some(true) => {
                        if let Some(pos) = stack.iter().position(|&(n, _)| n == next) {
                            return Some(stack[pos..].iter().map(|&(n, _)| n).collect());
                        }
                    }
                    // Already fully explored: no cycle reachable through it.
                    Some(false) => {}
                }
            } else {
                // All holders of `node` explored without a cycle: retire it.
                let _ = marks.insert(node, false);
                let _ = stack.pop();
            }
        }
        None
    }

    /// Chooses the victim from `cycle` according to `policy`.
    ///
    /// [`VictimPolicy::Youngest`] picks the greatest `TxnId`;
    /// [`VictimPolicy::Oldest`] picks the smallest. Returns `None` for an
    /// empty slice.
    ///
    /// NOTE(review): treating the greatest id as "youngest" assumes ids are
    /// allocated in increasing order — confirm against `TxnId` allocation.
    #[must_use]
    pub fn pick_victim(cycle: &[TxnId], policy: VictimPolicy) -> Option<TxnId> {
        let ids = cycle.iter().copied();
        match policy {
            VictimPolicy::Youngest => ids.max(),
            VictimPolicy::Oldest => ids.min(),
        }
    }
}
#[cfg(test)]
// Panicking via `unwrap` is the intended failure mode inside tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::{VictimPolicy, WaitForGraph};
    use crate::TxnId;

    /// Shorthand for building a `TxnId` from a raw number.
    fn txn(raw: u64) -> TxnId {
        TxnId::new(raw)
    }

    /// Builds a graph by calling `add_wait` for each `(waiter, holder)` pair in order.
    fn graph_of(pairs: &[(u64, u64)]) -> WaitForGraph {
        let mut graph = WaitForGraph::new();
        for &(waiter, holder) in pairs {
            graph.add_wait(txn(waiter), txn(holder));
        }
        graph
    }

    #[test]
    fn test_empty_graph_has_no_cycle() {
        let graph = WaitForGraph::new();
        assert!(graph.is_empty());
        assert!(graph.detect_cycle().is_none());
    }

    #[test]
    fn test_self_edge_ignored() {
        let graph = graph_of(&[(1, 1)]);
        assert!(graph.is_empty());
        assert!(graph.detect_cycle().is_none());
    }

    #[test]
    fn test_chain_has_no_cycle() {
        let graph = graph_of(&[(1, 2), (2, 3), (3, 4)]);
        assert!(graph.detect_cycle().is_none());
    }

    #[test]
    fn test_two_cycle_detected() {
        let graph = graph_of(&[(1, 2), (2, 1)]);
        let found = graph.detect_cycle().unwrap();
        assert_eq!(found.len(), 2);
        assert!(found.contains(&txn(1)) && found.contains(&txn(2)));
    }

    #[test]
    fn test_three_cycle_detected() {
        let graph = graph_of(&[(1, 2), (2, 3), (3, 1)]);
        let found = graph.detect_cycle().unwrap();
        assert_eq!(found.len(), 3);
    }

    #[test]
    fn test_cycle_from_unknown_txn_is_none() {
        let graph = graph_of(&[(1, 2), (2, 1)]);
        assert!(graph.cycle_from(txn(99)).is_none());
    }

    /// Despite the name, the cycle reached from txn 1 (2 <-> 3) does not
    /// include txn 1 itself; `cycle_from` returns any cycle reachable from the start.
    #[test]
    fn test_cycle_from_finds_cycle_containing_start() {
        let graph = graph_of(&[(1, 2), (2, 3), (3, 2)]);
        let found = graph.cycle_from(txn(1)).unwrap();
        assert!(found.contains(&txn(2)) && found.contains(&txn(3)));
        assert!(!found.contains(&txn(1)));
    }

    #[test]
    fn test_clear_waiter_breaks_cycle() {
        let mut graph = graph_of(&[(1, 2), (2, 1)]);
        graph.clear_waiter(txn(1));
        assert!(graph.detect_cycle().is_none());
    }

    #[test]
    fn test_remove_txn_drops_incoming_and_outgoing() {
        let mut graph = graph_of(&[(1, 2), (2, 1), (3, 2)]);
        graph.remove_txn(txn(2));
        assert!(graph.detect_cycle().is_none());
        assert!(graph.cycle_from(txn(3)).is_none());
    }

    #[test]
    fn test_diamond_no_cycle() {
        let graph = graph_of(&[(1, 2), (1, 3), (2, 4), (3, 4)]);
        assert!(graph.detect_cycle().is_none());
    }

    #[test]
    fn test_pick_victim_policies() {
        let members = [txn(3), txn(7), txn(5)];
        let youngest = WaitForGraph::pick_victim(&members, VictimPolicy::Youngest);
        let oldest = WaitForGraph::pick_victim(&members, VictimPolicy::Oldest);
        assert_eq!(youngest, Some(txn(7)));
        assert_eq!(oldest, Some(txn(3)));
        assert_eq!(WaitForGraph::pick_victim(&[], VictimPolicy::Youngest), None);
    }

    #[test]
    fn test_detected_cycle_is_an_actual_cycle() {
        let graph = graph_of(&[(1, 2), (2, 3), (3, 1)]);
        let found = graph.detect_cycle().unwrap();
        // Pair each member with its successor, wrapping the last back to the first.
        let successors = found.iter().cycle().skip(1);
        for (from, to) in found.iter().zip(successors) {
            let holders = graph.edges.get(from).unwrap();
            assert!(holders.contains(to), "missing edge {from:?} -> {to:?}");
        }
    }

    #[test]
    fn test_default_policy_is_youngest() {
        assert_eq!(VictimPolicy::default(), VictimPolicy::Youngest);
    }
}