use crate::task::TaskId;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Opaque, copyable identifier for a tracked resource.
///
/// Fresh values are minted by [`ResourceId::new`]; the wrapped integer can be
/// read back with [`ResourceId::as_u64`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct ResourceId(u64);
impl ResourceId {
pub fn new() -> Self {
static COUNTER: AtomicU64 = AtomicU64::new(1);
Self(COUNTER.fetch_add(1, Ordering::Relaxed))
}
#[must_use]
pub fn as_u64(&self) -> u64 {
self.0
}
}
impl Default for ResourceId {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for ResourceId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Resource#{}", self.0)
}
}
/// Category of a tracked synchronization resource.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum ResourceKind {
    /// A mutex.
    Mutex,
    /// A reader-writer lock.
    RwLock,
    /// A semaphore.
    Semaphore,
    /// A channel.
    Channel,
    /// Any other kind, identified by a caller-supplied name that is displayed verbatim.
    Other(String),
}
impl fmt::Display for ResourceKind {
    /// Writes the variant name, or the custom name for [`ResourceKind::Other`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label: &str = match self {
            Self::Mutex => "Mutex",
            Self::RwLock => "RwLock",
            Self::Semaphore => "Semaphore",
            Self::Channel => "Channel",
            Self::Other(custom) => custom.as_str(),
        };
        f.write_str(label)
    }
}
/// Snapshot of one tracked resource: what it is, who holds it, and who waits on it.
#[derive(Clone, Debug)]
pub struct ResourceInfo {
    /// Unique identifier of the resource.
    pub id: ResourceId,
    /// Category of the resource.
    pub kind: ResourceKind,
    /// Human-readable name, shown by the `Display` impl.
    pub name: String,
    /// Task currently recorded as holding the resource, if any.
    pub holder: Option<TaskId>,
    /// Tasks currently recorded as waiting for the resource.
    pub waiters: Vec<TaskId>,
    /// Optional address of the resource, shown in hex by the `Display` impl.
    pub address: Option<usize>,
}
impl ResourceInfo {
    /// Creates a description with a freshly minted id, no holder, no waiters
    /// and no address.
    #[must_use]
    pub fn new(kind: ResourceKind, name: String) -> Self {
        let id = ResourceId::new();
        Self {
            id,
            kind,
            name,
            holder: None,
            waiters: vec![],
            address: None,
        }
    }

    /// Returns the same description with its address set to `address`.
    #[must_use]
    pub fn with_address(self, address: usize) -> Self {
        Self {
            address: Some(address),
            ..self
        }
    }

    /// Returns `true` when a holder is recorded.
    #[must_use]
    pub fn is_held(&self) -> bool {
        self.holder.is_some()
    }

    /// Returns `true` when at least one waiter is recorded.
    #[must_use]
    pub fn has_waiters(&self) -> bool {
        !self.waiters.is_empty()
    }
}
impl fmt::Display for ResourceInfo {
    /// Renders `<kind> '<name>' (<id>)`, followed by ` @ 0x<hex>` when an
    /// address is recorded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            kind,
            name,
            id,
            address,
            ..
        } = self;
        write!(f, "{kind} '{name}' ({id})")?;
        match address {
            Some(addr) => write!(f, " @ 0x{addr:x}"),
            None => Ok(()),
        }
    }
}
/// One cycle in the wait-for graph, as reported by the detector.
#[derive(Clone, Debug)]
pub struct DeadlockCycle {
    /// Tasks on the cycle, in the order they were traversed.
    pub tasks: Vec<TaskId>,
    /// Resource awaited by each task in `chain`, in the same order.
    pub resources: Vec<ResourceId>,
    /// Wait edges that close the cycle.
    pub chain: Vec<WaitEdge>,
}
/// A single wait relationship: `task` waits for `resource`, which `holder` holds.
#[derive(Clone, Debug)]
pub struct WaitEdge {
    /// The blocked task.
    pub task: TaskId,
    /// The resource the blocked task is waiting for.
    pub resource: ResourceId,
    /// The task recorded as holding `resource`.
    pub holder: TaskId,
}
impl DeadlockCycle {
    /// Produces a multi-line, human-readable report of the cycle.
    ///
    /// One line is emitted per wait edge (the first marked with `→`), followed
    /// by a blank line and a summary with the task and resource counts.
    #[must_use]
    pub fn describe(&self) -> String {
        let edge_lines: String = self
            .chain
            .iter()
            .enumerate()
            .map(|(idx, edge)| {
                let marker = if idx == 0 { "→" } else { " " };
                format!(
                    " {marker} Task {} → {} → Task {}\n",
                    edge.task, edge.resource, edge.holder
                )
            })
            .collect();
        format!(
            "Deadlock cycle detected:\n{edge_lines}\n{} tasks and {} resources involved",
            self.tasks.len(),
            self.resources.len()
        )
    }
}
/// Records which task holds each registered resource and which resource each
/// task waits on, and searches the resulting wait-for graph for cycles.
///
/// Cloning is cheap: clones share the same state through an [`Arc`].
#[derive(Clone)]
pub struct DeadlockDetector {
    /// Shared, lock-protected bookkeeping.
    state: Arc<RwLock<DetectorState>>,
}
/// Mutable bookkeeping kept behind the detector's lock.
struct DetectorState {
    /// Every registered resource, keyed by its id.
    resources: HashMap<ResourceId, ResourceInfo>,
    /// The one resource each blocked task is currently recorded as waiting on.
    task_waiting: HashMap<TaskId, ResourceId>,
    /// When `false`, registrations and lock events are ignored.
    enabled: bool,
}
impl DeadlockDetector {
#[must_use]
pub fn new() -> Self {
Self {
state: Arc::new(RwLock::new(DetectorState {
resources: HashMap::new(),
task_waiting: HashMap::new(),
enabled: true,
})),
}
}
pub fn enable(&self) {
self.state.write().enabled = true;
}
pub fn disable(&self) {
self.state.write().enabled = false;
}
#[must_use]
pub fn is_enabled(&self) -> bool {
self.state.read().enabled
}
#[must_use]
pub fn register_resource(&self, info: ResourceInfo) -> ResourceId {
if !self.is_enabled() {
return info.id;
}
let resource_id = info.id;
self.state.write().resources.insert(resource_id, info);
resource_id
}
pub fn acquire(&self, task_id: TaskId, resource_id: ResourceId) {
if !self.is_enabled() {
return;
}
let mut state = self.state.write();
state.task_waiting.remove(&task_id);
if let Some(resource) = state.resources.get_mut(&resource_id) {
resource.holder = Some(task_id);
resource.waiters.retain(|&t| t != task_id);
}
}
pub fn release(&self, task_id: TaskId, resource_id: ResourceId) {
if !self.is_enabled() {
return;
}
let mut state = self.state.write();
if let Some(resource) = state.resources.get_mut(&resource_id) {
if resource.holder == Some(task_id) {
resource.holder = None;
}
}
}
pub fn wait_for(&self, task_id: TaskId, resource_id: ResourceId) {
if !self.is_enabled() {
return;
}
let mut state = self.state.write();
state.task_waiting.insert(task_id, resource_id);
if let Some(resource) = state.resources.get_mut(&resource_id) {
if !resource.waiters.contains(&task_id) {
resource.waiters.push(task_id);
}
}
}
#[must_use]
pub fn detect_deadlocks(&self) -> Vec<DeadlockCycle> {
let state = self.state.read();
let mut graph: HashMap<TaskId, Vec<TaskId>> = HashMap::new();
let mut task_to_resource: HashMap<TaskId, ResourceId> = HashMap::new();
for (&waiting_task, &resource_id) in &state.task_waiting {
if let Some(resource) = state.resources.get(&resource_id) {
if let Some(holder_task) = resource.holder {
graph.entry(waiting_task).or_default().push(holder_task);
task_to_resource.insert(waiting_task, resource_id);
}
}
}
let mut cycles = Vec::new();
let mut visited = HashSet::new();
let mut rec_stack = HashSet::new();
for &task in graph.keys() {
if !visited.contains(&task) {
if let Some(cycle) = self.find_cycle_dfs(
task,
&graph,
&task_to_resource,
&mut visited,
&mut rec_stack,
&mut Vec::new(),
) {
cycles.push(cycle);
}
}
}
cycles
}
fn find_cycle_dfs(
&self,
task: TaskId,
graph: &HashMap<TaskId, Vec<TaskId>>,
task_to_resource: &HashMap<TaskId, ResourceId>,
visited: &mut HashSet<TaskId>,
rec_stack: &mut HashSet<TaskId>,
path: &mut Vec<TaskId>,
) -> Option<DeadlockCycle> {
visited.insert(task);
rec_stack.insert(task);
path.push(task);
if let Some(neighbors) = graph.get(&task) {
for &neighbor in neighbors {
if !visited.contains(&neighbor) {
if let Some(cycle) = self.find_cycle_dfs(
neighbor,
graph,
task_to_resource,
visited,
rec_stack,
path,
) {
return Some(cycle);
}
} else if rec_stack.contains(&neighbor) {
return Some(self.build_cycle(neighbor, path, task_to_resource));
}
}
}
rec_stack.remove(&task);
path.pop();
None
}
fn build_cycle(
&self,
start_task: TaskId,
path: &[TaskId],
task_to_resource: &HashMap<TaskId, ResourceId>,
) -> DeadlockCycle {
let cycle_start = path.iter().position(|&t| t == start_task).unwrap_or(0);
let cycle_tasks: Vec<TaskId> = path[cycle_start..].to_vec();
let mut resources = Vec::new();
let mut chain = Vec::new();
for i in 0..cycle_tasks.len() {
let waiting_task = cycle_tasks[i];
let holder_task = cycle_tasks[(i + 1) % cycle_tasks.len()];
if let Some(&resource_id) = task_to_resource.get(&waiting_task) {
resources.push(resource_id);
chain.push(WaitEdge {
task: waiting_task,
resource: resource_id,
holder: holder_task,
});
}
}
DeadlockCycle {
tasks: cycle_tasks,
resources,
chain,
}
}
#[must_use]
pub fn get_resources(&self) -> Vec<ResourceInfo> {
self.state.read().resources.values().cloned().collect()
}
#[must_use]
pub fn get_resource(&self, id: ResourceId) -> Option<ResourceInfo> {
self.state.read().resources.get(&id).cloned()
}
pub fn clear(&self) {
let mut state = self.state.write();
state.resources.clear();
state.task_waiting.clear();
}
}
impl Default for DeadlockDetector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a fresh mutex called `name` and returns the id the detector reports.
    fn register_mutex(detector: &DeadlockDetector, name: &str) -> ResourceId {
        detector.register_resource(ResourceInfo::new(ResourceKind::Mutex, name.to_string()))
    }

    #[test]
    fn test_resource_creation() {
        let resource = ResourceInfo::new(ResourceKind::Mutex, "test_mutex".to_string());
        assert_eq!(resource.kind, ResourceKind::Mutex);
        assert_eq!(resource.name, "test_mutex");
        assert!(!resource.is_held());
        assert!(!resource.has_waiters());
    }

    #[test]
    fn test_detector_registration() {
        let detector = DeadlockDetector::new();
        let resource = ResourceInfo::new(ResourceKind::Mutex, "test".to_string());
        let expected_id = resource.id;
        // `register_resource` is `#[must_use]`; use its result rather than discarding it.
        let resource_id = detector.register_resource(resource);
        assert_eq!(resource_id, expected_id);
        let retrieved = detector.get_resource(resource_id).unwrap();
        assert_eq!(retrieved.name, "test");
    }

    #[test]
    fn test_simple_deadlock_detection() {
        let detector = DeadlockDetector::new();
        let res1_id = register_mutex(&detector, "mutex_a");
        let res2_id = register_mutex(&detector, "mutex_b");
        let task1 = TaskId::new();
        let task2 = TaskId::new();
        detector.acquire(task1, res1_id);
        detector.wait_for(task1, res2_id);
        detector.acquire(task2, res2_id);
        detector.wait_for(task2, res1_id);
        let deadlocks = detector.detect_deadlocks();
        assert_eq!(deadlocks.len(), 1);
        let cycle = &deadlocks[0];
        assert_eq!(cycle.tasks.len(), 2);
        assert!(cycle.tasks.contains(&task1));
        assert!(cycle.tasks.contains(&task2));
        assert_eq!(cycle.chain.len(), 2);
        assert_eq!(cycle.resources.len(), 2);
    }

    #[test]
    fn test_no_deadlock() {
        let detector = DeadlockDetector::new();
        let res_id = register_mutex(&detector, "mutex");
        let task1 = TaskId::new();
        let task2 = TaskId::new();
        detector.acquire(task1, res_id);
        detector.release(task1, res_id);
        detector.acquire(task2, res_id);
        let deadlocks = detector.detect_deadlocks();
        assert_eq!(deadlocks.len(), 0);
    }

    #[test]
    fn test_self_deadlock() {
        let detector = DeadlockDetector::new();
        let res_id = register_mutex(&detector, "reentrant");
        let task = TaskId::new();
        detector.acquire(task, res_id);
        detector.wait_for(task, res_id);
        let deadlocks = detector.detect_deadlocks();
        assert_eq!(deadlocks.len(), 1);
        assert_eq!(deadlocks[0].tasks, vec![task]);
        assert_eq!(deadlocks[0].chain.len(), 1);
    }

    #[test]
    fn test_waiting_on_free_resource_is_not_deadlock() {
        let detector = DeadlockDetector::new();
        let res_id = register_mutex(&detector, "free");
        let task = TaskId::new();
        detector.wait_for(task, res_id);
        assert!(detector.detect_deadlocks().is_empty());
        assert!(detector.get_resource(res_id).unwrap().has_waiters());
    }

    #[test]
    fn test_release_by_non_holder_is_ignored() {
        let detector = DeadlockDetector::new();
        let res_id = register_mutex(&detector, "owned");
        let owner = TaskId::new();
        let other = TaskId::new();
        detector.acquire(owner, res_id);
        detector.release(other, res_id);
        assert_eq!(detector.get_resource(res_id).unwrap().holder, Some(owner));
        detector.release(owner, res_id);
        assert!(!detector.get_resource(res_id).unwrap().is_held());
    }

    #[test]
    fn test_disabled_detector_ignores_events() {
        let detector = DeadlockDetector::new();
        detector.disable();
        assert!(!detector.is_enabled());
        let resource = ResourceInfo::new(ResourceKind::Semaphore, "sem".to_string());
        let expected_id = resource.id;
        assert_eq!(detector.register_resource(resource), expected_id);
        assert!(detector.get_resource(expected_id).is_none());
        detector.enable();
        assert!(detector.is_enabled());
        assert!(detector.get_resources().is_empty());
    }

    #[test]
    fn test_display_formats() {
        assert_eq!(ResourceKind::RwLock.to_string(), "RwLock");
        assert_eq!(ResourceKind::Other("Condvar".to_string()).to_string(), "Condvar");
        let info =
            ResourceInfo::new(ResourceKind::Channel, "jobs".to_string()).with_address(0x1f);
        assert_eq!(info.id.to_string(), format!("Resource#{}", info.id.as_u64()));
        assert_eq!(info.to_string(), format!("Channel 'jobs' ({}) @ 0x1f", info.id));
    }

    #[test]
    fn test_describe_summarises_cycle() {
        let detector = DeadlockDetector::new();
        let res1_id = register_mutex(&detector, "left");
        let res2_id = register_mutex(&detector, "right");
        let task1 = TaskId::new();
        let task2 = TaskId::new();
        detector.acquire(task1, res1_id);
        detector.acquire(task2, res2_id);
        detector.wait_for(task1, res2_id);
        detector.wait_for(task2, res1_id);
        let deadlocks = detector.detect_deadlocks();
        assert_eq!(deadlocks.len(), 1);
        let report = deadlocks[0].describe();
        assert!(report.starts_with("Deadlock cycle detected:\n"));
        assert!(report.ends_with("\n2 tasks and 2 resources involved"));
    }

    #[test]
    fn test_clear_forgets_everything() {
        let detector = DeadlockDetector::new();
        let res_id = register_mutex(&detector, "gone");
        let task = TaskId::new();
        detector.acquire(task, res_id);
        detector.wait_for(task, res_id);
        detector.clear();
        assert!(detector.get_resources().is_empty());
        assert!(detector.detect_deadlocks().is_empty());
    }
}