use std::collections::{HashMap, HashSet};
use crate::graph::{
edge::{Edge, EdgeType, END, START},
error::GraphError,
state::State,
};
/// Decides which graph nodes can run next, based on the graph's edges.
///
/// Holds the forward edge lists supplied by the caller, plus a reverse index
/// (target → source nodes) that is built once in [`NodeScheduler::new`].
pub struct NodeScheduler<S: State> {
    /// Outgoing edges, keyed by the name of their source node.
    pub(crate) adjacency: HashMap<String, Vec<Edge<S>>>,
    /// For every node that is the target of some edge, the distinct names of
    /// the nodes with an edge into it. A conditional edge contributes every
    /// target listed in its mapping, whichever branch would be taken at runtime.
    reverse_adjacency: HashMap<String, Vec<String>>,
}

impl<S: State> NodeScheduler<S> {
    /// Creates a scheduler from `adjacency` and precomputes the reverse index.
    pub fn new(adjacency: HashMap<String, Vec<Edge<S>>>) -> Self {
        let mut reverse_adjacency: HashMap<String, Vec<String>> = HashMap::new();
        for (from, edges) in &adjacency {
            for edge in edges {
                match &edge.edge_type {
                    EdgeType::Regular { to } => {
                        Self::record_predecessor(&mut reverse_adjacency, to, from);
                    }
                    EdgeType::Conditional { mapping, .. } => {
                        // Any mapped target may be chosen, so each one counts
                        // `from` as a predecessor.
                        for target in mapping.values() {
                            Self::record_predecessor(&mut reverse_adjacency, target, from);
                        }
                    }
                }
            }
        }
        Self {
            adjacency,
            reverse_adjacency,
        }
    }

    /// Records `from` as a predecessor of `target`, ignoring duplicates.
    /// Duplicates come from several edges, or several mapping keys, that lead
    /// to the same target.
    fn record_predecessor(
        reverse_adjacency: &mut HashMap<String, Vec<String>>,
        target: &str,
        from: &str,
    ) {
        let predecessors = reverse_adjacency.entry(target.to_owned()).or_default();
        if !predecessors.iter().any(|pred| pred == from) {
            predecessors.push(from.to_owned());
        }
    }

    /// Resolves every outgoing edge of `source` against `state` and appends
    /// each target to `out`. Targets equal to [`END`] are skipped, and so are
    /// targets already recorded in `seen`, which keeps first-seen order.
    ///
    /// # Errors
    /// Propagates the first error returned by `Edge::get_target`.
    async fn push_targets(
        &self,
        source: &str,
        state: &S,
        seen: &mut HashSet<String>,
        out: &mut Vec<String>,
    ) -> Result<(), GraphError> {
        if let Some(edges) = self.adjacency.get(source) {
            for edge in edges {
                let target = edge.get_target(state).await?;
                if target != END && seen.insert(target.clone()) {
                    out.push(target);
                }
            }
        }
        Ok(())
    }

    /// Returns the nodes that are ready to execute.
    ///
    /// * If `executed_nodes` is empty, returns the targets of the [`START`]
    ///   edges resolved against `current_state`. Order follows the START edge
    ///   list; END and duplicates are removed.
    /// * Otherwise, returns every node (other than START/END) that has not
    ///   executed yet and whose predecessors have all executed. START always
    ///   counts as executed. The result is sorted by name so that it is
    ///   deterministic.
    ///
    /// NOTE(review): the second case is structural. `current_state` is not
    /// consulted, so every target of a conditional edge becomes ready once its
    /// source has executed, even branches the condition would not select.
    /// TODO: confirm this is intended.
    ///
    /// # Errors
    /// Propagates errors from resolving START edges.
    pub async fn get_ready_nodes(
        &self,
        executed_nodes: &HashSet<String>,
        current_state: &S,
    ) -> Result<Vec<String>, GraphError> {
        if executed_nodes.is_empty() {
            let mut seen = HashSet::new();
            let mut ready_nodes = Vec::new();
            self.push_targets(START, current_state, &mut seen, &mut ready_nodes)
                .await?;
            return Ok(ready_nodes);
        }

        let mut ready_nodes: Vec<String> = self
            .reverse_adjacency
            .iter()
            .filter(|&(node, predecessors)| {
                node != START
                    && node != END
                    && !executed_nodes.contains(node)
                    && predecessors
                        .iter()
                        .all(|pred| pred == START || executed_nodes.contains(pred))
            })
            .map(|(node, _)| node.clone())
            .collect();
        // HashMap iteration order is unspecified; sort so callers get the same
        // order on every call.
        ready_nodes.sort_unstable();
        Ok(ready_nodes)
    }

    /// Resolves the outgoing edges of every node in `current_nodes` against
    /// `state` and returns the distinct targets, excluding [`END`].
    ///
    /// Targets are returned in the order they are first discovered, walking
    /// `current_nodes` and each node's edge list in order.
    ///
    /// # Errors
    /// Propagates the first error returned by `Edge::get_target`.
    pub async fn get_next_nodes(
        &self,
        current_nodes: &[String],
        state: &S,
    ) -> Result<Vec<String>, GraphError> {
        let mut seen = HashSet::new();
        let mut next_nodes = Vec::new();
        for node in current_nodes {
            self.push_targets(node, state, &mut seen, &mut next_nodes)
                .await?;
        }
        Ok(next_nodes)
    }

    /// Returns `true` as soon as any outgoing edge of any node in
    /// `current_nodes` resolves to [`END`] under `state`. Edges after that one
    /// are not evaluated. Returns `false` if no edge resolves to END.
    ///
    /// # Errors
    /// Propagates the first error returned by `Edge::get_target`.
    pub async fn is_complete(
        &self,
        current_nodes: &[String],
        state: &S,
    ) -> Result<bool, GraphError> {
        for node in current_nodes {
            if let Some(edges) = self.adjacency.get(node) {
                for edge in edges {
                    if edge.get_target(state).await? == END {
                        return Ok(true);
                    }
                }
            }
        }
        Ok(false)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::state::MessagesState;

    /// Builds the linear graph `START -> node1 -> node2 -> END`.
    fn linear_scheduler() -> NodeScheduler<MessagesState> {
        let mut adjacency = HashMap::new();
        adjacency.insert(START.to_string(), vec![Edge::new(START, "node1")]);
        adjacency.insert("node1".to_string(), vec![Edge::new("node1", "node2")]);
        adjacency.insert("node2".to_string(), vec![Edge::new("node2", END)]);
        NodeScheduler::new(adjacency)
    }

    #[tokio::test]
    async fn test_get_ready_nodes() {
        let mut adjacency = HashMap::new();
        adjacency.insert(START.to_string(), vec![Edge::new(START, "node1")]);
        adjacency.insert("node1".to_string(), vec![Edge::new("node1", "node2")]);
        let scheduler = NodeScheduler::<MessagesState>::new(adjacency);
        let executed = HashSet::new();
        let state = MessagesState::new();
        let ready = scheduler.get_ready_nodes(&executed, &state).await.unwrap();
        assert_eq!(ready, vec!["node1"]);
    }

    #[tokio::test]
    async fn test_get_ready_nodes_after_execution() {
        let scheduler = linear_scheduler();
        let state = MessagesState::new();
        let mut executed = HashSet::new();
        executed.insert("node1".to_string());
        let ready = scheduler.get_ready_nodes(&executed, &state).await.unwrap();
        // node1 already ran; node2's only predecessor is done; END is never ready.
        assert_eq!(ready, vec!["node2"]);
    }

    #[tokio::test]
    async fn test_get_next_nodes_filters_end() {
        let scheduler = linear_scheduler();
        let state = MessagesState::new();
        let next = scheduler
            .get_next_nodes(&["node1".to_string()], &state)
            .await
            .unwrap();
        assert_eq!(next, vec!["node2"]);
        let after_last = scheduler
            .get_next_nodes(&["node2".to_string()], &state)
            .await
            .unwrap();
        assert!(after_last.is_empty());
    }

    #[tokio::test]
    async fn test_is_complete() {
        let scheduler = linear_scheduler();
        let state = MessagesState::new();
        assert!(!scheduler
            .is_complete(&["node1".to_string()], &state)
            .await
            .unwrap());
        assert!(scheduler
            .is_complete(&["node2".to_string()], &state)
            .await
            .unwrap());
    }
}