use crate::core::types::ThreadId;
use fxhash::{FxHashMap, FxHashSet};
use std::collections::VecDeque;
/// Directed graph over `ThreadId` nodes that keeps both a forward and a
/// reverse adjacency index, plus scratch buffers for breadth-first search.
///
/// NOTE(review): judging by the type and method names, an edge `from -> to`
/// presumably means "thread `from` is waiting on thread `to`" — confirm
/// against the callers.
pub struct WaitForGraph {
/// Forward adjacency: source thread -> set of threads it has an edge to.
pub(crate) edges: FxHashMap<ThreadId, FxHashSet<ThreadId>>,
/// Reverse adjacency: target thread -> set of threads that have an edge to it.
/// `add_edge` and the removal methods update it together with `edges`.
pub(crate) incoming_edges: FxHashMap<ThreadId, FxHashSet<ThreadId>>,
/// Scratch BFS frontier, cleared and reused by every `find_path` call.
bfs_queue: VecDeque<ThreadId>,
/// Scratch set of nodes already discovered by the current `find_path` call.
bfs_visited: FxHashSet<ThreadId>,
/// Scratch map from a discovered node to the node it was reached from;
/// `find_path` walks it backwards to rebuild the path.
bfs_parent: FxHashMap<ThreadId, ThreadId>,
}
impl Default for WaitForGraph {
fn default() -> Self {
Self::new()
}
}
impl WaitForGraph {
pub fn new() -> Self {
Self {
edges: FxHashMap::default(),
incoming_edges: FxHashMap::default(),
bfs_queue: VecDeque::with_capacity(64),
bfs_visited: FxHashSet::default(),
bfs_parent: FxHashMap::default(),
}
}
pub fn add_edge(&mut self, from: ThreadId, to: ThreadId) -> Option<Vec<ThreadId>> {
if let Some(targets) = self.edges.get(&from)
&& targets.contains(&to)
{
return None;
}
if let Some(path) = self.find_path(to, from) {
return Some(path);
}
self.edges.entry(from).or_default().insert(to);
self.incoming_edges.entry(to).or_default().insert(from);
None
}
pub fn clear_wait_edges(&mut self, thread_id: ThreadId) {
if let Some(targets) = self.edges.remove(&thread_id) {
for target in targets {
if let Some(waiters) = self.incoming_edges.get_mut(&target) {
waiters.remove(&thread_id);
if waiters.is_empty() {
self.incoming_edges.remove(&target);
}
}
}
}
}
pub fn remove_edge(&mut self, from: ThreadId, to: ThreadId) {
if let Some(neighbors) = self.edges.get_mut(&from)
&& neighbors.remove(&to)
{
if neighbors.is_empty() {
self.edges.remove(&from);
}
if let Some(waiters) = self.incoming_edges.get_mut(&to) {
waiters.remove(&from);
if waiters.is_empty() {
self.incoming_edges.remove(&to);
}
}
}
}
pub fn remove_thread(&mut self, thread_id: ThreadId) {
self.clear_wait_edges(thread_id);
if let Some(waiters) = self.incoming_edges.remove(&thread_id) {
for waiter in waiters {
if let Some(forward_set) = self.edges.get_mut(&waiter) {
forward_set.remove(&thread_id);
if forward_set.is_empty() {
self.edges.remove(&waiter);
}
}
}
}
}
fn find_path(&mut self, start: ThreadId, target: ThreadId) -> Option<Vec<ThreadId>> {
if start == target {
return Some(vec![start]);
}
self.bfs_queue.clear();
self.bfs_visited.clear();
self.bfs_parent.clear();
self.bfs_queue.push_back(start);
self.bfs_visited.insert(start);
while let Some(current) = self.bfs_queue.pop_front() {
if current == target {
let mut path = Vec::with_capacity(self.bfs_parent.len() + 1);
let mut curr = target;
path.push(curr);
while let Some(&p) = self.bfs_parent.get(&curr) {
path.push(p);
curr = p;
}
path.reverse();
return Some(path);
}
if let Some(neighbors) = self.edges.get(¤t) {
for &neighbor in neighbors {
if self.bfs_visited.insert(neighbor) {
self.bfs_parent.insert(neighbor, current);
self.bfs_queue.push_back(neighbor);
}
}
}
}
None
}
}