use std::sync::Arc;
use async_trait::async_trait;
use crate::delta::StateDelta;
use crate::error::{GraphError, ObservedError};
use crate::event::{BarrierId, GraphEvent};
use crate::ids::SpanId;
use crate::state::State;
pub use crate::barrier_node::{BarrierDefaultAction, BarrierNode};
pub use crate::parallel_node::{ParallelErrorStrategy, ParallelNode, ParallelNodeBuilder};
/// Where control flows after a node finishes.
///
/// Carried by [`NodeOutput::next`] and [`StreamNodeResult::Continue`]; how each
/// variant is ultimately resolved is decided by the graph executor, which is not
/// part of this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NextStep {
    /// Jump to the node with the given name. [`ConditionNode`] produces this with
    /// the target of its first matching branch.
    Goto(String),
    /// Continue with the default successor (presumably the next node in graph
    /// order — confirm against the executor).
    GoToNext,
    /// Stop here (presumably terminates the run — confirm against the executor).
    End,
}
/// Result of a non-streaming node execution ([`FlowNode::execute`]).
#[derive(Debug)]
pub struct NodeOutput {
    /// State changes produced by the node, in the order they were added.
    pub deltas: Vec<StateDelta>,
    /// Where control should go after this node.
    pub next: NextStep,
    /// Optional cost / side-effect information reported by the node.
    pub metadata: Option<NodeMetadata>,
}
/// Execution metadata a node can attach to its output, or advertise ahead of time
/// through [`FlowNode::metadata_hint`].
///
/// The `Default` value has a zero cost and no side effects.
#[derive(Clone, Debug, Default)]
pub struct NodeMetadata {
    /// Cost attributed to running the node. The unit is not enforced here
    /// (presumably LLM tokens, per the name — confirm with consumers).
    pub token_cost: f64,
    /// Whether running the node has effects outside the deltas it returns
    /// (interpretation is up to consumers of this flag).
    pub has_side_effects: bool,
}
impl NodeOutput {
    /// Creates an output that continues with `next`, carrying no deltas and no
    /// metadata.
    #[must_use]
    pub fn new(next: NextStep) -> Self {
        Self {
            deltas: Vec::new(),
            next,
            metadata: None,
        }
    }

    /// Appends one state delta after any already present.
    ///
    /// Consumes and returns `self`; the result must be used or the delta is lost.
    #[must_use]
    pub fn with_delta(mut self, delta: StateDelta) -> Self {
        self.deltas.push(delta);
        self
    }

    /// Appends all of `deltas`, in order, after any already present.
    #[must_use]
    pub fn with_deltas(mut self, deltas: Vec<StateDelta>) -> Self {
        self.deltas.extend(deltas);
        self
    }

    /// Replaces the metadata wholesale.
    ///
    /// Any cost or side-effect flag set earlier through
    /// [`NodeOutput::with_token_cost`] / [`NodeOutput::with_side_effects`] is
    /// discarded.
    #[must_use]
    pub fn with_metadata(mut self, metadata: NodeMetadata) -> Self {
        self.metadata = Some(metadata);
        self
    }

    /// Sets the token cost. If no metadata exists yet, default metadata is created
    /// first. The side-effect flag is left unchanged.
    #[must_use]
    pub fn with_token_cost(mut self, cost: f64) -> Self {
        self.metadata
            .get_or_insert_with(Default::default)
            .token_cost = cost;
        self
    }

    /// Marks the output as having side effects. If no metadata exists yet, default
    /// metadata is created first. The token cost is left unchanged.
    #[must_use]
    pub fn with_side_effects(mut self) -> Self {
        self.metadata
            .get_or_insert_with(Default::default)
            .has_side_effects = true;
        self
    }
}
/// Outcome of a streaming node execution ([`FlowNode::execute_stream`]).
#[derive(Debug)]
pub enum StreamNodeResult {
    /// The node finished and control moves on.
    Continue {
        /// State changes produced by the node.
        deltas: Vec<StateDelta>,
        /// Where control goes next.
        next: NextStep,
        /// Span the node ran under.
        span_id: SpanId,
        /// Error observed during execution, if any. The default
        /// `FlowNode::execute_stream` always leaves this `None`. Presumably it
        /// records non-fatal errors — confirm with producers.
        observed: Option<ObservedError>,
        /// Optional cost / side-effect information reported by the node.
        metadata: Option<NodeMetadata>,
    },
    /// Execution is suspended at a barrier.
    Pause {
        /// State changes produced before pausing.
        deltas: Vec<StateDelta>,
        /// Barrier being waited on.
        barrier_id: BarrierId,
        /// Name of the node that paused.
        node_name: String,
        /// Span the node ran under.
        span_id: SpanId,
        /// Optional wait limit. Presumably `default_action` applies once it
        /// elapses — confirm in the barrier implementation.
        timeout: Option<std::time::Duration>,
        /// Action associated with the barrier when no explicit decision arrives
        /// (semantics defined in `barrier_node`).
        default_action: BarrierDefaultAction,
    },
    /// The node requested fallback handling instead of a normal continuation.
    Fallback {
        /// State changes produced before falling back.
        deltas: Vec<StateDelta>,
        /// Human-readable explanation of why the fallback was taken.
        reason: String,
        /// Name of the node that fell back.
        node_name: String,
    },
}
#[async_trait]
pub trait FlowNode: Send + Sync {
async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError>;
async fn execute_stream(
&self,
state: &State,
_sink: &tokio::sync::mpsc::Sender<GraphEvent>,
span_id: SpanId,
) -> Result<StreamNodeResult, GraphError> {
let output = self.execute(state).await?;
Ok(StreamNodeResult::Continue {
deltas: output.deltas,
next: output.next,
span_id,
observed: None,
metadata: output.metadata,
})
}
fn metadata_hint(&self) -> NodeMetadata {
NodeMetadata::default()
}
}
#[derive(Clone)]
pub enum NodeKind {
Task(TaskNode),
Condition(ConditionNode),
Barrier(BarrierNode),
Parallel(ParallelNode),
External(std::sync::Arc<dyn FlowNode>),
}
/// Body of a [`TaskNode`]: reads the current state and returns the deltas to
/// apply, or a [`GraphError`].
pub type TaskFn = Arc<dyn Fn(&State) -> Result<Vec<StateDelta>, GraphError> + Sync + Send>;
/// Predicate guarding one branch of a [`ConditionNode`].
pub type BranchCondition = Arc<dyn Fn(&State) -> bool + Sync + Send>;
/// A node that runs a synchronous closure over the current state.
#[derive(Clone)]
pub struct TaskNode {
    /// Node name.
    pub name: String,
    /// Closure invoked on execution; its deltas become the node's output.
    pub func: TaskFn,
}

impl TaskNode {
    /// Wraps `func` as a task node named `name`.
    pub fn new(
        name: impl Into<String>,
        func: impl Fn(&State) -> Result<Vec<StateDelta>, GraphError> + Send + Sync + 'static,
    ) -> Self {
        let name = name.into();
        // Erase the concrete closure type so every TaskNode shares one field type.
        let func: TaskFn = Arc::new(func);
        Self { name, func }
    }
}
#[async_trait]
impl FlowNode for TaskNode {
async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError> {
let deltas = (self.func)(state)?;
Ok(NodeOutput {
deltas,
next: NextStep::GoToNext,
metadata: None,
})
}
fn metadata_hint(&self) -> NodeMetadata {
NodeMetadata {
token_cost: 0.0,
has_side_effects: false,
}
}
}
/// A routing node.
///
/// Control jumps to the target of the first branch whose condition holds. If no
/// branch matches, the node falls through with [`NextStep::GoToNext`].
#[derive(Clone)]
pub struct ConditionNode {
    /// Node name.
    pub name: String,
    /// `(target node name, condition)` pairs, checked in order.
    pub branches: Vec<(String, BranchCondition)>,
}

impl ConditionNode {
    /// Starts building a condition node named `name` with no branches.
    ///
    /// The returned builder must be finished with `build()`; dropping it
    /// discards any branches added to it.
    #[must_use]
    pub fn builder(name: impl Into<String>) -> ConditionNodeBuilder {
        ConditionNodeBuilder {
            name: name.into(),
            branches: Vec::new(),
        }
    }
}
/// Incremental builder for a [`ConditionNode`].
pub struct ConditionNodeBuilder {
    name: String,
    branches: Vec<(String, BranchCondition)>,
}

impl ConditionNodeBuilder {
    /// Appends a branch that routes to `target` when `condition` holds.
    ///
    /// Branches are checked in insertion order, and the first match wins.
    /// Consumes and returns the builder; ignoring the result loses the branch.
    #[must_use]
    pub fn branch(
        mut self,
        target: impl Into<String>,
        condition: impl Fn(&State) -> bool + Send + Sync + 'static,
    ) -> Self {
        self.branches.push((target.into(), Arc::new(condition)));
        self
    }

    /// Finishes the builder, producing the node with all branches in the order
    /// they were added.
    #[must_use]
    pub fn build(self) -> ConditionNode {
        ConditionNode {
            name: self.name,
            branches: self.branches,
        }
    }
}
#[async_trait]
impl FlowNode for ConditionNode {
async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError> {
for (target, condition) in &self.branches {
if condition(state) {
return Ok(NodeOutput::new(NextStep::Goto(target.clone())));
}
}
Ok(NodeOutput::new(NextStep::GoToNext))
}
fn metadata_hint(&self) -> NodeMetadata {
NodeMetadata {
token_cost: 0.0,
has_side_effects: false,
}
}
}
#[async_trait]
impl FlowNode for NodeKind {
async fn execute(&self, state: &State) -> Result<NodeOutput, GraphError> {
match self {
Self::Task(n) => n.execute(state).await,
Self::Condition(n) => n.execute(state).await,
Self::Barrier(n) => n.execute(state).await,
Self::Parallel(n) => n.execute_sequential(state).await,
Self::External(n) => n.execute(state).await,
}
}
async fn execute_stream(
&self,
state: &State,
sink: &tokio::sync::mpsc::Sender<GraphEvent>,
span_id: SpanId,
) -> Result<StreamNodeResult, GraphError> {
match self {
Self::Task(n) => n.execute_stream(state, sink, span_id).await,
Self::Condition(n) => n.execute_stream(state, sink, span_id).await,
Self::Barrier(n) => n.execute_stream(state, sink, span_id).await,
Self::Parallel(_) => {
let output = self.execute(state).await?;
Ok(StreamNodeResult::Continue {
deltas: output.deltas,
next: output.next,
span_id,
observed: None,
metadata: output.metadata,
})
}
Self::External(n) => n.execute_stream(state, sink, span_id).await,
}
}
}
/// Trait-object alias for any graph node.
///
/// The `'static` bound is spelled out here; a bare `dyn FlowNode` in a type
/// alias already defaults to it.
pub type GraphNode = dyn FlowNode + 'static;