use crate::DagrsResult;
use crate::connection::{in_channel::InChannels, out_channel::OutChannels};
use crate::node::{Node, NodeId, NodeName, NodeTable};
use crate::utils::{env::EnvVar, output::FlowControl, output::Output};
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
/// Predicate consulted by a [`LoopNode`] each time it runs, deciding whether
/// execution should jump back to the loop target.
///
/// Implementors must be `Send + Sync`.
pub trait LoopCondition: Send + Sync {
    /// Returns `true` to request another iteration and `false` to let
    /// execution fall through.
    ///
    /// The owning node passes in its input/output channels and the shared
    /// environment; implementations are free to ignore them.
    fn should_continue(
        &mut self,
        input: &InChannels,
        out: &OutChannels,
        env: Arc<EnvVar>,
    ) -> bool;

    /// Clears per-run state. The default implementation does nothing.
    fn reset(&mut self) {
        // Stateless conditions have nothing to clear.
    }

    /// Re-synchronizes internal state when resuming from a checkpoint.
    ///
    /// `_completed_iterations` is the loop count handed over by the owning
    /// node. The default implementation ignores it and reports success.
    fn restore_from_checkpoint(&mut self, _completed_iterations: usize) -> DagrsResult<()> {
        Ok(())
    }
}
/// A [`LoopCondition`] that grants a fixed number of loop-backs.
///
/// Every call to `should_continue` that returns `true` consumes one iteration
/// from the budget; once `max_iterations` have been granted it keeps returning
/// `false` until `reset` rewinds the counter.
#[derive(Debug, Clone)]
pub struct CountLoopCondition {
    /// Total number of iterations this condition will grant.
    max_iterations: usize,
    /// Iterations granted so far; kept at or below `max_iterations`.
    current_iteration: usize,
}

impl CountLoopCondition {
    /// Creates a condition that grants exactly `max` iterations, with the
    /// counter starting at zero.
    pub fn new(max: usize) -> Self {
        Self {
            max_iterations: max,
            current_iteration: 0,
        }
    }
}
impl LoopCondition for CountLoopCondition {
    /// Grants one more iteration while the budget is not yet exhausted.
    ///
    /// The channels and environment are ignored; only the internal counter
    /// drives the decision.
    fn should_continue(
        &mut self,
        _input: &InChannels,
        _out: &OutChannels,
        _env: Arc<EnvVar>,
    ) -> bool {
        // Budget used up: stop looping and leave the counter where it is.
        if self.current_iteration >= self.max_iterations {
            return false;
        }
        self.current_iteration += 1;
        true
    }

    /// Rewinds the counter so the full budget is available again.
    fn reset(&mut self) {
        self.current_iteration = 0;
    }

    /// Resumes counting from a checkpoint, clamping the restored value so the
    /// counter never exceeds `max_iterations`.
    fn restore_from_checkpoint(&mut self, completed_iterations: usize) -> DagrsResult<()> {
        self.current_iteration = std::cmp::min(completed_iterations, self.max_iterations);
        Ok(())
    }
}
/// Control-flow node that, each time it runs, asks its [`LoopCondition`]
/// whether execution should jump back to `target_node` or continue onward.
pub struct LoopNode {
    /// Identifier allocated from the node table when the node is built.
    id: NodeId,
    /// Name the node was registered under.
    name: NodeName,
    /// Input channels; passed to the condition on every run.
    in_channels: InChannels,
    /// Output channels; passed to the condition on every run.
    out_channels: OutChannels,
    /// Node that execution loops back to while the condition returns `true`.
    target_node: NodeId,
    /// The loop predicate, boxed so any `LoopCondition` implementation fits.
    /// A poisoned mutex is recovered from rather than propagated.
    condition: Mutex<Box<dyn LoopCondition>>,
}
impl LoopNode {
    /// Builds a loop node that jumps back to `target_node` while `condition`
    /// keeps returning `true`.
    ///
    /// The node id is allocated from `node_table` under `name`, and the node
    /// starts with default-constructed input and output channels.
    pub fn new(
        name: NodeName,
        target_node: NodeId,
        condition: impl LoopCondition + 'static,
        node_table: &mut NodeTable,
    ) -> Self {
        // Allocate the id first: it only borrows `name`, which is moved into
        // the node right after.
        let id = node_table.alloc_id_for(&name);
        Self {
            id,
            name,
            in_channels: InChannels::default(),
            out_channels: OutChannels::default(),
            target_node,
            condition: Mutex::new(Box::new(condition)),
        }
    }
}
#[async_trait]
impl Node for LoopNode {
fn id(&self) -> NodeId {
self.id
}
fn name(&self) -> NodeName {
self.name.clone()
}
fn input_channels(&mut self) -> &mut InChannels {
&mut self.in_channels
}
fn output_channels(&mut self) -> &mut OutChannels {
&mut self.out_channels
}
async fn run(&mut self, env: Arc<EnvVar>) -> Output {
let should_continue = self
.condition
.lock()
.unwrap_or_else(|poisoned| {
log::warn!("LoopNode condition mutex was poisoned, recovering");
poisoned.into_inner()
})
.should_continue(&self.in_channels, &self.out_channels, env);
if should_continue {
Output::Flow(FlowControl::loop_to_node(self.target_node.as_usize()))
} else {
Output::Flow(FlowControl::Continue)
}
}
fn reset(&mut self) {
self.condition
.lock()
.unwrap_or_else(|poisoned| {
log::warn!("LoopNode condition mutex was poisoned during reset, recovering");
poisoned.into_inner()
})
.reset();
}
fn restore_from_checkpoint(&mut self, loop_count: usize) -> DagrsResult<()> {
self.condition
.lock()
.unwrap_or_else(|poisoned| {
log::warn!(
"LoopNode condition mutex was poisoned during checkpoint restore, recovering"
);
poisoned.into_inner()
})
.restore_from_checkpoint(loop_count)
}
}