use async_trait::async_trait;
use dagrs::node::action::Action;
use dagrs::node::default_node::DefaultNode;
use dagrs::node::loop_node::{CountLoopCondition, LoopNode};
use dagrs::{EnvVar, Graph, InChannels, Node, NodeTable, OutChannels, Output};
use std::sync::{Arc, Mutex};
/// Test action that increments a shared counter every time its node runs.
///
/// Cloning an `IncAction` clones the `Arc`, so all clones bump the same counter.
#[derive(Clone)]
struct IncAction {
    /// Execution counter shared with the test body, which reads and resets it.
    counter: Arc<Mutex<usize>>,
}

#[async_trait]
impl Action for IncAction {
    /// Adds one to the shared counter and returns an empty output.
    ///
    /// The input/output channels and the environment are ignored.
    ///
    /// # Panics
    /// Panics if the counter's mutex is poisoned.
    async fn run(&self, _: &mut InChannels, _: &mut OutChannels, _: Arc<EnvVar>) -> Output {
        {
            // Confine the guard to this scope so the lock is released before returning.
            let mut count = self.counter.lock().unwrap();
            *count += 1;
        }
        Output::empty()
    }
}
/// Runs a graph containing a `LoopNode` twice, calling `Graph::reset` in between, and
/// checks that node `A` executes the same number of times in each run.
///
/// Node `A` increments a shared counter on every execution. The loop node is built with
/// `A` as its target and `CountLoopCondition::new(3)`, and an edge runs from `A` to the
/// loop node. Per the assertion messages, each run is expected to execute `A` 4 times
/// (1 initial + 3 loops).
#[tokio::test]
async fn test_loop_reset() {
    let mut graph = Graph::new();
    let mut table = NodeTable::new();
    let counter = Arc::new(Mutex::new(0_usize));

    // Node `A`: bumps the shared counter whenever it runs.
    let inc = IncAction {
        counter: Arc::clone(&counter),
    };
    let node_a = DefaultNode::with_action("A".to_string(), inc, &mut table);
    let id_a = node_a.id();

    // Loop node targeting `A`, bounded by a count condition of 3.
    let loop_node = LoopNode::new(
        "Loop".to_string(),
        id_a,
        CountLoopCondition::new(3),
        &mut table,
    );
    let id_loop = loop_node.id();

    graph.add_node(node_a);
    graph.add_node(loop_node);
    graph.add_edge(id_a, vec![id_loop]);

    // Snapshot the counter; the guard is dropped before the value is returned.
    let read_count = || *counter.lock().unwrap();

    println!("First Run");
    graph.async_start().await.unwrap();
    let first_run_count = read_count();
    println!("First run count: {}", first_run_count);
    assert_eq!(
        first_run_count, 4,
        "Counter should be 4 (1 initial + 3 loops)"
    );

    // Reset the graph, then zero the counter so the second run is measured from scratch.
    graph.reset().await;
    *counter.lock().unwrap() = 0;

    println!("Second Run");
    graph.async_start().await.unwrap();
    let second_run_count = read_count();
    println!("Second run count: {}", second_run_count);
    assert_eq!(
        second_run_count, 4,
        "Counter should be 4 after reset (1 initial + 3 loops)"
    );

    assert_eq!(
        first_run_count, second_run_count,
        "Loop execution count should be consistent after reset"
    );
}