use std::sync::Arc;
use async_trait::async_trait;
use crate::error::{GraphError, ObservedError, TerminalError};
use crate::event::{BarrierId, GraphEvent, SpanId};
use crate::graph::Edge;
use crate::state::State;
pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
pub use crate::llm_node::{AgentNode, LLMNode};
pub use crate::tool_node::ToolNode;
/// Routing decision a [`GraphNode`] returns after it has run.
///
/// `SubGraph::execute` continues with the next node on `GoToNext`, stops on
/// `End`, and rejects `Goto` with an `InvalidGraph` error.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NextStep {
    /// Jump to the node identified by this target name.
    Goto(String),
    /// Proceed to whatever comes next in sequence.
    GoToNext,
    /// Stop executing.
    End,
}
/// Outcome of [`GraphNode::execute_stream`].
///
/// Every variant carries the `span_id` that was passed into the call.
#[derive(Debug)]
pub enum StreamNodeResult {
    /// The node finished normally; `next` is its routing decision.
    Done {
        next: NextStep,
        span_id: SpanId,
    },
    /// The node paused at a barrier identified by `barrier_id`.
    ///
    /// NOTE(review): `timeout` and `default_action` presumably describe what
    /// the runtime does if the barrier is not resumed in time — confirm
    /// against `crate::barrier_node`.
    BarrierPaused {
        barrier_id: BarrierId,
        node_name: String,
        span_id: SpanId,
        timeout: Option<std::time::Duration>,
        default_action: crate::barrier_node::BarrierDefaultAction,
    },
    /// The node reported an [`ObservedError`] together with a routing
    /// decision; presumably execution continues with `next` — confirm with
    /// the graph runner.
    Observed {
        error: ObservedError,
        next: NextStep,
        span_id: SpanId,
    },
}
/// A unit of work in a graph: reads/mutates [`State`] and decides where to go
/// next.
#[async_trait]
pub trait GraphNode: Send + Sync {
    /// Runs the node against `state` and returns the routing decision.
    ///
    /// # Errors
    /// Returns a [`GraphError`] if the node fails.
    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError>;

    /// Streaming variant of [`GraphNode::execute`].
    ///
    /// The default implementation ignores `_sink` (no events are emitted),
    /// delegates to [`GraphNode::execute`], and wraps its result in
    /// [`StreamNodeResult::Done`] tagged with `span_id`. Errors from
    /// `execute` are returned unchanged.
    async fn execute_stream(
        &self,
        state: &mut State,
        _sink: &tokio::sync::mpsc::Sender<GraphEvent>,
        span_id: SpanId,
    ) -> Result<StreamNodeResult, GraphError> {
        self.execute(state)
            .await
            .map(|next| StreamNodeResult::Done { next, span_id })
    }
}
/// Closed set of built-in node implementations.
///
/// `NodeKind` itself implements [`GraphNode`] by delegating to the wrapped
/// node. `Agent` and `Loop` are boxed — presumably to keep the enum small,
/// since those payloads are likely larger (TODO confirm with `size_of`).
pub enum NodeKind {
    /// Synchronous closure node.
    Task(TaskNode),
    /// LLM-backed agent node (see `crate::llm_node`).
    Agent(Box<AgentNode>),
    /// Tool-invocation node (see `crate::tool_node`).
    Tool(ToolNode),
    /// Branching node that picks a target by predicate.
    Condition(ConditionNode),
    /// Node that repeatedly runs a [`SubGraph`] body.
    Loop(Box<LoopNode>),
    /// Barrier node (see `crate::barrier_node`).
    Barrier(BarrierNode),
}
/// Shared, thread-safe synchronous closure run by a [`TaskNode`]; it may
/// mutate the state and can fail with a [`GraphError`].
pub type TaskFn = Arc<dyn Fn(&mut State) -> Result<(), GraphError> + Send + Sync>;
/// Shared, thread-safe predicate over the state, used by
/// [`ConditionNode`] branches.
pub type BranchCondition = Arc<dyn Fn(&State) -> bool + Send + Sync>;
/// Node that runs a synchronous closure against the state and then always
/// routes to [`NextStep::GoToNext`] on success.
pub struct TaskNode {
    /// Node name.
    pub name: String,
    /// Closure invoked once per execution.
    pub func: TaskFn,
}
impl TaskNode {
    /// Creates a task node named `name` that runs `func`.
    ///
    /// The closure is moved into an `Arc`, so the node can be shared cheaply.
    pub fn new(
        name: impl Into<String>,
        func: impl Fn(&mut State) -> Result<(), GraphError> + Send + Sync + 'static,
    ) -> Self {
        let name: String = name.into();
        let func: Arc<dyn Fn(&mut State) -> Result<(), GraphError> + Send + Sync> =
            Arc::new(func);
        Self { name, func }
    }
}
/// Runs the wrapped closure once; on success routes to
/// [`NextStep::GoToNext`], otherwise returns the closure's error.
#[async_trait]
impl GraphNode for TaskNode {
    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
        let task = &self.func;
        task(state).map(|()| NextStep::GoToNext)
    }
}
/// Branching node: evaluates `branches` in order and jumps to the target of
/// the first predicate that returns `true`.
///
/// If no predicate matches, it jumps to `otherwise_target`. If there is no
/// `otherwise_target` either, execution fails with
/// `TerminalError::NodeExecutionFailed`. Build one with
/// [`ConditionNode::builder`].
pub struct ConditionNode {
    /// Node name; used in the error reported when nothing matches.
    pub name: String,
    /// `(target, predicate)` pairs, checked in insertion order.
    pub branches: Vec<(String, BranchCondition)>,
    /// Fallback target when no branch predicate matches.
    pub otherwise_target: Option<String>,
}
impl ConditionNode {
    /// Starts a builder for a condition node named `name`, with no branches
    /// and no fallback target.
    pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder {
        let name: String = name.into();
        ConditionNodeBuilder {
            name,
            otherwise_target: None,
            branches: Vec::new(),
        }
    }
}
/// Fluent builder for [`ConditionNode`], obtained from
/// [`ConditionNode::builder`].
///
/// Every configuring method consumes the builder and returns it, so dropping
/// an intermediate result silently loses configuration. `#[must_use]` turns
/// that mistake into a compiler warning.
#[must_use = "a ConditionNodeBuilder does nothing until `.build()` is called"]
pub struct ConditionNodeBuilder {
    /// Name of the node being built.
    name: String,
    /// `(target, predicate)` pairs in insertion order — the same order in
    /// which the built node evaluates them.
    branches: Vec<(String, BranchCondition)>,
    /// Fallback target used when no branch predicate matches.
    otherwise_target: Option<String>,
}
impl ConditionNodeBuilder {
pub fn branch(
mut self,
target: impl Into<String>,
condition: impl Fn(&State) -> bool + Send + Sync + 'static,
) -> Self {
self.branches.push((target.into(), Arc::new(condition)));
self
}
pub fn otherwise(mut self, target: impl Into<String>) -> Self {
self.otherwise_target = Some(target.into());
self
}
pub fn build(self) -> ConditionNode {
ConditionNode {
name: self.name,
branches: self.branches,
otherwise_target: self.otherwise_target,
}
}
}
/// Picks the target of the first branch whose predicate holds, falling back
/// to `otherwise_target`.
///
/// # Errors
/// Returns `TerminalError::NodeExecutionFailed` if no branch matches and
/// there is no fallback target. The state is only read, never mutated.
#[async_trait]
impl GraphNode for ConditionNode {
    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
        let view: &State = state;
        self.branches
            .iter()
            // `find` short-circuits, so later predicates are not evaluated
            // once one matches.
            .find(|(_, predicate)| predicate(view))
            .map(|(target, _)| target)
            .or(self.otherwise_target.as_ref())
            .map(|target| NextStep::Goto(target.clone()))
            .ok_or_else(|| {
                GraphError::Terminal(TerminalError::NodeExecutionFailed {
                    node: self.name.clone(),
                    source: "no matching branch and no otherwise target".into(),
                })
            })
    }
}
/// Linear sequence of nodes, used as the body of a [`LoopNode`].
///
/// `SubGraph::execute` runs `nodes` strictly in order. NOTE(review): `edges`
/// is stored but is not consulted by `SubGraph::execute`.
#[derive(Default)]
pub struct SubGraph {
    /// Nodes run front to back.
    pub nodes: Vec<Arc<dyn GraphNode>>,
    /// Edge list; currently unused by `SubGraph::execute`.
    pub edges: Vec<Edge>,
}
impl SubGraph {
    /// Creates an empty sub-graph (no nodes, no edges).
    pub fn new() -> Self {
        Self::default()
    }

    /// Runs every node in order against `state`.
    ///
    /// A node returning `GoToNext` continues with the following node. `End`
    /// stops the sub-graph early and returns `Ok(())`.
    ///
    /// # Errors
    /// Propagates the first node error. A node returning `Goto` yields
    /// `TerminalError::InvalidGraph`, since a linear sub-graph has no jump
    /// targets.
    pub async fn execute(&self, state: &mut State) -> Result<(), GraphError> {
        for node in self.nodes.iter() {
            let step = node.execute(state).await?;
            match step {
                NextStep::GoToNext => continue,
                NextStep::End => return Ok(()),
                NextStep::Goto(target) => {
                    let message = format!(
                        "SubGraph does not support Goto(\"{}\"). Use Graph::edge_if for conditional jumps.",
                        target
                    );
                    return Err(GraphError::Terminal(TerminalError::InvalidGraph(message)));
                }
            }
        }
        Ok(())
    }
}
/// Node that runs `body` repeatedly with do-while semantics.
///
/// The body runs first, and then `continue_condition` is checked. If it
/// returns `true`, the body runs again. If it returns `false`, the loop exits
/// and routes to [`NextStep::GoToNext`].
pub struct LoopNode {
    /// Node name; used in trace logs.
    pub name: String,
    /// Sub-graph executed once per iteration.
    pub body: SubGraph,
    /// Checked after each body run; `true` means "run another iteration".
    pub continue_condition: Arc<dyn Fn(&State) -> bool + Send + Sync>,
    /// Maximum number of body runs. If the condition still holds after this
    /// many runs, execution fails with `TerminalError::LoopLimitExceeded`.
    /// With `0`, the body never runs and that error is returned immediately.
    pub max_iterations: usize,
}
impl LoopNode {
    /// Creates a loop node that runs `body` until `continue_condition`
    /// returns `false`, allowing at most `max_iterations` body runs.
    pub fn new(
        name: impl Into<String>,
        body: SubGraph,
        continue_condition: impl Fn(&State) -> bool + Send + Sync + 'static,
        max_iterations: usize,
    ) -> Self {
        let name: String = name.into();
        let continue_condition: Arc<dyn Fn(&State) -> bool + Send + Sync> =
            Arc::new(continue_condition);
        Self {
            name,
            body,
            continue_condition,
            max_iterations,
        }
    }
}
/// Runs the loop body until the continue condition is false or the
/// iteration cap is hit.
///
/// # Errors
/// Propagates any error from the body. Returns
/// `TerminalError::LoopLimitExceeded` when the body has run
/// `max_iterations` times and the condition still holds.
#[async_trait]
impl GraphNode for LoopNode {
    async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
        // `round` is 1-based; the empty range for `max_iterations == 0`
        // falls straight through to the limit error.
        for round in 1..=self.max_iterations {
            tracing::debug!(
                loop_name = %self.name,
                iteration = round,
                max = self.max_iterations,
                "executing loop body"
            );
            self.body.execute(state).await?;
            if (self.continue_condition)(state) {
                continue;
            }
            tracing::debug!(
                loop_name = %self.name,
                iterations = round,
                "loop condition met, exiting"
            );
            return Ok(NextStep::GoToNext);
        }
        Err(GraphError::Terminal(TerminalError::LoopLimitExceeded {
            limit: self.max_iterations,
        }))
    }
}
#[async_trait]
impl GraphNode for NodeKind {
async fn execute(&self, state: &mut State) -> Result<NextStep, GraphError> {
match self {
Self::Task(n) => n.execute(state).await,
Self::Agent(n) => n.execute(state).await,
Self::Tool(n) => n.execute(state).await,
Self::Condition(n) => n.execute(state).await,
Self::Loop(n) => n.execute(state).await,
Self::Barrier(n) => n.execute(state).await,
}
}
async fn execute_stream(
&self,
state: &mut State,
sink: &tokio::sync::mpsc::Sender<GraphEvent>,
span_id: SpanId,
) -> Result<StreamNodeResult, GraphError> {
match self {
Self::Task(n) => n.execute_stream(state, sink, span_id).await,
Self::Agent(n) => n.execute_stream(state, sink, span_id).await,
Self::Tool(n) => n.execute_stream(state, sink, span_id).await,
Self::Condition(n) => n.execute_stream(state, sink, span_id).await,
Self::Loop(n) => n.execute_stream(state, sink, span_id).await,
Self::Barrier(n) => n.execute_stream(state, sink, span_id).await,
}
}
}