use std::sync::Arc;
use async_trait::async_trait;
use crate::error::{GraphError, ObservedError};
use crate::event::BarrierId;
use crate::ids::SpanId;
use crate::node_context::NodeContext;
use crate::state::{State, StateMerge};
use crate::workflow_state::{MergeStrategy, WorkflowState};
pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
pub use crate::parallel_node::{
ParallelErrorStrategy, ParallelNode, ParallelNodeBuilder, ParallelNodeBuilderWithMerge,
};
/// Routing decision attached to a node's result (see `NodeOutput::next` and
/// `StreamNodeResult::Continue::next`).
///
/// NOTE(review): the exact semantics of each variant are implemented by the
/// executor, which is not visible here; the descriptions below follow the
/// variant names and should be confirmed there.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NextStep {
    /// Jump to the node identified by this string (presumably a node name).
    Goto(String),
    /// Proceed to the next node in the graph's default order (assumed from the name).
    GoToNext,
    /// Finish the workflow run (assumed from the name).
    End,
}
/// Result produced by a node: the state deltas it emitted, where to route next,
/// and optional execution metadata.
#[derive(Debug)]
pub struct NodeOutput {
    /// State deltas emitted by the node, in the order they were added
    /// (presumably applied to the workflow state by the executor — confirm).
    pub deltas: Vec<crate::delta::StateDelta>,
    /// Routing decision for the executor.
    pub next: NextStep,
    /// Optional metadata such as token cost or a side-effect flag; `None` until set.
    pub metadata: Option<crate::node_context::NodeMetadata>,
}
impl NodeOutput {
    /// Creates an output that routes to `next`, with no deltas and no metadata.
    #[must_use]
    pub fn new(next: NextStep) -> Self {
        Self {
            deltas: Vec::new(),
            next,
            metadata: None,
        }
    }

    /// Appends a single state delta after any already present.
    #[must_use]
    pub fn with_delta(mut self, delta: crate::delta::StateDelta) -> Self {
        self.deltas.push(delta);
        self
    }

    /// Appends all of `deltas`, preserving their order, after any already present.
    #[must_use]
    pub fn with_deltas(mut self, deltas: Vec<crate::delta::StateDelta>) -> Self {
        self.deltas.extend(deltas);
        self
    }

    /// Sets the metadata, replacing whatever was there before.
    ///
    /// This overwrites values set earlier through [`NodeOutput::with_token_cost`]
    /// or [`NodeOutput::with_side_effects`]; call those afterwards to layer
    /// changes on top of `metadata`.
    #[must_use]
    pub fn with_metadata(mut self, metadata: crate::node_context::NodeMetadata) -> Self {
        self.metadata = Some(metadata);
        self
    }

    /// Records `cost` as the token cost.
    ///
    /// If no metadata exists yet, `NodeMetadata::default()` is created first;
    /// other metadata fields are left untouched.
    #[must_use]
    pub fn with_token_cost(mut self, cost: f64) -> Self {
        self.metadata
            .get_or_insert_with(Default::default)
            .token_cost = cost;
        self
    }

    /// Marks the node as having side effects.
    ///
    /// If no metadata exists yet, `NodeMetadata::default()` is created first;
    /// other metadata fields are left untouched.
    #[must_use]
    pub fn with_side_effects(mut self) -> Self {
        self.metadata
            .get_or_insert_with(Default::default)
            .has_side_effects = true;
        self
    }
}
pub use crate::node_context::NodeMetadata;
/// Outcome of executing a node, carrying tracing and barrier information
/// alongside the emitted deltas.
///
/// NOTE(review): the meaning of each variant is inferred from field names;
/// confirm against the executor that consumes this type.
#[derive(Debug)]
pub enum StreamNodeResult {
    /// The node completed; execution presumably continues according to `next`.
    Continue {
        /// State deltas emitted by the node.
        deltas: Vec<crate::delta::StateDelta>,
        /// Routing decision.
        next: NextStep,
        /// Span identifier for this node's execution (presumably for tracing).
        span_id: SpanId,
        /// An error observed during execution that did not abort the node,
        /// if any (assumed from the name — confirm).
        observed: Option<ObservedError>,
        /// Optional metadata reported by the node.
        metadata: Option<NodeMetadata>,
    },
    /// The node stopped at a barrier; the run presumably waits until it is resolved.
    Pause {
        /// State deltas emitted before pausing.
        deltas: Vec<crate::delta::StateDelta>,
        /// Barrier the run is paused on.
        barrier_id: BarrierId,
        /// Name of the node that paused.
        node_name: String,
        /// Span identifier for this node's execution.
        span_id: SpanId,
        /// Optional wait limit; `None` presumably means no limit — confirm.
        timeout: Option<std::time::Duration>,
        /// Action presumably applied when the barrier is not resolved in time — confirm.
        default_action: BarrierDefaultAction,
    },
    /// The node fell back instead of completing normally (assumed from the name).
    Fallback {
        /// State deltas emitted before falling back.
        deltas: Vec<crate::delta::StateDelta>,
        /// Explanation for the fallback (presumably human-readable).
        reason: String,
        /// Name of the node that fell back.
        node_name: String,
    },
}
/// A node that can be executed as part of a workflow graph.
///
/// Implementations act on the run through the mutable [`NodeContext`] (for
/// example, `ConditionNode` calls `ctx.goto`). The trait requires
/// `Send + Sync` and is used as a trait object (`Arc<dyn FlowNode<S>>` in
/// `NodeKind::External`).
///
/// # Errors
///
/// `execute` returns a [`GraphError`] when the node fails.
#[async_trait]
pub trait FlowNode<S: WorkflowState = State>: Send + Sync {
    /// Runs the node against `ctx`.
    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError>;
}
/// The node variants a graph can hold, plus `External` for any other
/// [`FlowNode`] implementation.
///
/// `S` is the workflow state type. `M` is the merge strategy, which only the
/// `Parallel` variant uses.
pub enum NodeKind<S: WorkflowState = State, M: MergeStrategy<S> = StateMerge> {
    /// Node backed by a synchronous closure.
    Task(TaskNode<S>),
    /// Routes to the first branch whose predicate holds.
    Condition(ConditionNode<S>),
    /// Barrier node (see [`BarrierNode`]).
    Barrier(BarrierNode<S>),
    /// Parallel node parameterized by the merge strategy `M` (see [`ParallelNode`]).
    Parallel(ParallelNode<S, M>),
    /// Arbitrary node implementation behind a shared trait object.
    External(Arc<dyn FlowNode<S>>),
}
impl<S: WorkflowState, M: MergeStrategy<S>> Clone for NodeKind<S, M> {
fn clone(&self) -> Self {
match self {
Self::Task(n) => Self::Task(n.clone()),
Self::Condition(n) => Self::Condition(n.clone()),
Self::Barrier(n) => Self::Barrier(n.clone()),
Self::Parallel(n) => Self::Parallel(n.clone()),
Self::External(n) => Self::External(n.clone()),
}
}
}
/// Shared, type-erased closure that implements a [`TaskNode`].
pub type TaskFn<S> = Arc<dyn Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync>;
/// Node whose execution calls a synchronous closure with the node context.
///
/// Cloning copies the name and clones the `Arc`, so clones share one closure.
#[derive(Clone)]
pub struct TaskNode<S: WorkflowState = State> {
    /// Node name.
    pub name: String,
    /// Closure invoked when the node executes.
    pub func: TaskFn<S>,
}
impl<S: WorkflowState> TaskNode<S> {
pub fn new(
name: impl Into<String>,
func: impl Fn(&mut NodeContext<'_, S>) -> Result<(), GraphError> + Send + Sync + 'static,
) -> Self {
Self {
name: name.into(),
func: Arc::new(func),
}
}
}
#[async_trait]
impl<S: WorkflowState> FlowNode<S> for TaskNode<S> {
    /// Calls the wrapped closure synchronously and returns its result unchanged.
    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
        let body = &*self.func;
        body(ctx)
    }
}
/// Predicate over the workflow state that selects a [`ConditionNode`] branch.
pub type BranchCondition<S> = Arc<dyn Fn(&S) -> bool + Send + Sync>;
/// Routing node. When executed, it calls `goto` with the target of the first
/// branch whose predicate returns `true`. If no predicate matches, it returns
/// without calling `goto`.
#[derive(Clone)]
pub struct ConditionNode<S: WorkflowState = State> {
    /// Node name.
    pub name: String,
    /// `(target, predicate)` pairs, checked in order.
    pub branches: Vec<(String, BranchCondition<S>)>,
}
impl<S: WorkflowState> ConditionNode<S> {
    /// Starts a builder for a condition node named `name`, with no branches yet.
    pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder<S> {
        let name = name.into();
        ConditionNodeBuilder {
            name,
            branches: Vec::default(),
        }
    }
}
/// Builder for [`ConditionNode`]; obtained from [`ConditionNode::builder`].
///
/// Marked `#[must_use]`: a builder that is dropped without calling `.build()`
/// never produces a node, which is always a mistake.
#[must_use = "a ConditionNodeBuilder does nothing until `.build()` is called"]
pub struct ConditionNodeBuilder<S: WorkflowState = State> {
    /// Name the built node will carry.
    name: String,
    /// Branches added so far, in insertion order.
    branches: Vec<(String, BranchCondition<S>)>,
}
impl<S: WorkflowState> ConditionNodeBuilder<S> {
pub fn branch(
mut self,
target: impl Into<String>,
condition: impl Fn(&S) -> bool + Send + Sync + 'static,
) -> Self {
self.branches.push((target.into(), Arc::new(condition)));
self
}
pub fn build(self) -> ConditionNode<S> {
ConditionNode {
name: self.name,
branches: self.branches,
}
}
}
#[async_trait]
impl<S: WorkflowState> FlowNode<S> for ConditionNode<S> {
    /// Routes to the target of the first branch whose predicate holds for the
    /// current state.
    ///
    /// Branches are tried in order, and evaluation stops at the first match.
    /// If no predicate returns `true`, `ctx.goto` is not called and the method
    /// still returns `Ok(())`.
    async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
        let current = ctx.state();
        let chosen = self
            .branches
            .iter()
            .find(|(_, predicate)| predicate(current))
            .map(|(target, _)| target);
        if let Some(target) = chosen {
            ctx.goto(target);
        }
        Ok(())
    }
}
#[async_trait]
impl<S: WorkflowState, M: MergeStrategy<S>> FlowNode<S> for NodeKind<S, M> {
async fn execute(&self, ctx: &mut NodeContext<'_, S>) -> Result<(), GraphError> {
match self {
Self::Task(n) => n.execute(ctx).await,
Self::Condition(n) => n.execute(ctx).await,
Self::Barrier(n) => n.execute(ctx).await,
Self::Parallel(n) => n.execute(ctx).await,
Self::External(n) => n.execute(ctx).await,
}
}
}
pub type GraphNode<S> = dyn FlowNode<S>;