use async_trait::async_trait;
use dagrs::node::action::Action;
use dagrs::node::default_node::DefaultNode;
use dagrs::node::router_node::{Router, RouterNode};
use dagrs::{EnvVar, Graph, InChannels, Node, NodeTable, OutChannels, Output};
use std::sync::{Arc, Mutex};
/// Test action that is identified in log output by a human-readable name.
///
/// Its `Action` implementation prints progress lines tagged with `name` and
/// relays a received message to its output channels.
#[derive(Clone)]
struct PassthroughAction {
    /// Label printed in square brackets at the start of every log line.
    name: String,
}
#[async_trait]
impl Action for PassthroughAction {
    /// Logs that the node is running, then relays at most one message.
    ///
    /// When the node has no registered senders nothing is awaited. Otherwise
    /// the first message obtained from `recv_any` is logged and broadcast on
    /// `out`; a receive error is logged as a closed input. The environment
    /// argument is unused, and every path yields `Output::empty()`.
    async fn run(&self, input: &mut InChannels, out: &mut OutChannels, _: Arc<EnvVar>) -> Output {
        println!("[{}] Running", self.name);

        // A node without upstream senders has nothing to wait for.
        if input.get_sender_ids().is_empty() {
            return Output::empty();
        }

        let received = input.recv_any().await;
        match received {
            Err(_) => println!("[{}] Input closed", self.name),
            Ok((_, payload)) => {
                println!("[{}] Received {:?}", self.name, payload);
                out.broadcast(payload).await;
            }
        }

        Output::empty()
    }
}
/// Router whose destination is read from a shared, lock-protected node id.
struct SwitchRouter {
    /// Id of the node the router sends to and reports as its only selected
    /// branch; shared so that code outside the router can hold a handle to it.
    target: Arc<Mutex<dagrs::NodeId>>,
}
#[async_trait]
impl Router for SwitchRouter {
    /// Sends a `"ping"` message to the current target and selects only it.
    ///
    /// The input channels and environment are ignored. Returns a one-element
    /// vector holding the target id converted with `as_usize`.
    ///
    /// # Panics
    ///
    /// Panics if the `target` mutex is poisoned.
    async fn route(&self, _: &mut InChannels, out: &OutChannels, _: Arc<EnvVar>) -> Vec<usize> {
        // Copy the id out inside its own scope so the lock is released before
        // anything is awaited.
        let destination = {
            let guard = self.target.lock().unwrap();
            *guard
        };

        let ping = dagrs::Content::new("ping".to_string());
        // The send result is explicitly discarded — presumably because the
        // target may not be a connected receiver; confirm against the caller.
        let _ = out.send_to(&destination, ping).await;

        vec![destination.as_usize()]
    }
}
/// Runs a `Router -> B -> C -> D` chain whose router points at a node that was
/// never added to the graph, and requires the run to finish within one second.
///
/// Judging by the test name, the intent is that the router selects none of its
/// real successors, so the downstream chain is skipped and must not hang.
/// That skip behaviour belongs to dagrs and is not visible in this file.
///
/// # Panics
///
/// Panics if building the graph fails, if the graph reports an error, or if
/// execution does not complete before the timeout.
#[tokio::test]
async fn test_chain_skip_deadlock() {
    let mut graph = Graph::new();
    let mut table = NodeTable::new();

    // Passthrough nodes are registered in the table in the order B, C, D.
    let [node_b, node_c, node_d] = ["B", "C", "D"].map(|label| {
        DefaultNode::with_action(
            label.to_string(),
            PassthroughAction {
                name: label.to_string(),
            },
            &mut table,
        )
    });
    let (id_b, id_c, id_d) = (node_b.id(), node_c.id(), node_d.id());

    // The dummy node only supplies an id; it is deliberately left out of the graph.
    let dummy_id = DefaultNode::new("Dummy".to_string(), &mut table).id();
    let target = Arc::new(Mutex::new(dummy_id));

    let router = RouterNode::new(
        "Router".to_string(),
        SwitchRouter {
            target: target.clone(),
        },
        &mut table,
    );
    let id_router = router.id();

    graph.add_node(router).unwrap();
    for node in [node_b, node_c, node_d] {
        graph.add_node(node).unwrap();
    }

    // Wire the linear chain Router -> B -> C -> D.
    for (from, to) in [(id_router, id_b), (id_b, id_c), (id_c, id_d)] {
        graph.add_edge(from, vec![to]).unwrap();
    }

    println!("Starting Graph...");
    let timed = tokio::time::timeout(std::time::Duration::from_secs(1), graph.async_start()).await;

    // An elapsed timeout means the graph never completed.
    let Ok(run_outcome) = timed else {
        panic!("Graph timed out! Deadlock detected.");
    };
    match run_outcome {
        Err(e) => panic!("Graph failed: {:?}", e),
        Ok(_) => println!("Graph finished successfully"),
    }
}