use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;
use crate::agent::error::AgentResult;
use super::{Command, GraphConfig, GraphState, Reducer, RuntimeContext};
/// Reserved node identifier `"__START__"`.
///
/// NOTE(review): presumably names the graph's virtual entry point; its use is
/// not visible in this file — confirm against the graph implementation.
pub const START: &str = "__START__";
/// Reserved node identifier `"__END__"`.
///
/// NOTE(review): presumably names the graph's virtual terminal node; its use
/// is not visible in this file — confirm against the graph implementation.
pub const END: &str = "__END__";
/// An asynchronous unit of work that operates on a graph state of type `S`.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait NodeFunc<S: GraphState>: Send + Sync {
/// Runs the node against `state`, which it may mutate in place.
///
/// # Parameters
/// - `state`: mutable reference to the graph state.
/// - `ctx`: runtime context, borrowed immutably.
///
/// # Returns
/// A [`Command`] wrapped in [`AgentResult`].
///
/// NOTE(review): the command presumably tells the executor how to proceed
/// after this node, and failures are presumably reported through the
/// `AgentResult` error side — confirm against the executor.
async fn call(&self, state: &mut S, ctx: &RuntimeContext) -> AgentResult<Command>;
/// Returns the node's name.
fn name(&self) -> &str;
/// Returns an optional description of the node.
///
/// The default implementation returns `None`.
fn description(&self) -> Option<&str> {
None
}
}
/// Destination(s) of an edge, identified by strings.
///
/// Use `targets()` to obtain every destination regardless of the variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EdgeTarget {
/// Exactly one target.
Single(String),
/// A map whose values are the possible targets (`targets()` returns the values).
///
/// NOTE(review): keys are presumably condition/route labels resolved at
/// runtime — confirm against the executor.
Conditional(HashMap<String, String>),
/// A list of targets.
///
/// NOTE(review): presumably all of them are followed concurrently — confirm
/// against the executor.
Parallel(Vec<String>),
}
impl EdgeTarget {
pub fn single(target: impl Into<String>) -> Self {
Self::Single(target.into())
}
pub fn conditional(routes: HashMap<String, String>) -> Self {
Self::Conditional(routes)
}
pub fn parallel(targets: Vec<String>) -> Self {
Self::Parallel(targets)
}
pub fn is_conditional(&self) -> bool {
matches!(self, Self::Conditional(_))
}
pub fn targets(&self) -> Vec<&str> {
match self {
Self::Single(t) => vec![t],
Self::Conditional(routes) => routes.values().map(|s| s.as_str()).collect(),
Self::Parallel(targets) => targets.iter().map(|s| s.as_str()).collect(),
}
}
}
/// Builder-style interface for assembling a graph and compiling it.
///
/// The mutating methods return `&mut Self` so calls can be chained, while
/// `compile` takes `self` by value and therefore consumes the builder.
#[async_trait]
pub trait StateGraph: Send + Sync {
/// State type the graph's nodes operate on.
type State: GraphState;
/// Type produced by `compile`; it must implement [`CompiledGraph`] for `Self::State`.
type Compiled: CompiledGraph<Self::State>;
/// Creates a new graph with the given identifier.
fn new(id: impl Into<String>) -> Self;
/// Registers `node` under the identifier `id`.
///
/// Returns `&mut Self` for chaining.
fn add_node(
&mut self,
id: impl Into<String>,
node: Box<dyn NodeFunc<Self::State>>,
) -> &mut Self;
/// Adds an edge from `from` to `to`. Returns `&mut Self` for chaining.
fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self;
/// Adds conditional edges leaving `from`. Returns `&mut Self` for chaining.
///
/// NOTE(review): `conditions` presumably maps a condition/route key to a
/// target node id, mirroring [`EdgeTarget::Conditional`] — confirm against
/// the implementation.
fn add_conditional_edges(
&mut self,
from: impl Into<String>,
conditions: HashMap<String, String>,
) -> &mut Self;
/// Adds edges from `from` to every entry of `targets`. Returns `&mut Self` for chaining.
///
/// NOTE(review): the targets are presumably executed in parallel — confirm
/// against the implementation.
fn add_parallel_edges(&mut self, from: impl Into<String>, targets: Vec<String>) -> &mut Self;
/// Sets `node` as the graph's entry point. Returns `&mut Self` for chaining.
fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self;
/// Sets `node` as the graph's finish point. Returns `&mut Self` for chaining.
fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self;
/// Registers `reducer` for `key`. Returns `&mut Self` for chaining.
///
/// NOTE(review): `key` presumably names a state field whose updates the
/// reducer merges — confirm against the `Reducer` documentation.
fn add_reducer(&mut self, key: impl Into<String>, reducer: Box<dyn Reducer>) -> &mut Self;
/// Applies `config` to the graph. Returns `&mut Self` for chaining.
fn with_config(&mut self, config: GraphConfig) -> &mut Self;
/// Returns the graph's identifier.
fn id(&self) -> &str;
/// Consumes the builder and produces the compiled graph.
///
/// NOTE(review): presumably fails when the graph definition is invalid
/// (e.g. missing entry point or dangling edges) — confirm which checks the
/// implementation performs.
fn compile(self) -> AgentResult<Self::Compiled>;
}
/// Executable form of a graph over state type `S`.
///
/// This is the bound placed on `StateGraph::Compiled`, i.e. the type that
/// `StateGraph::compile` produces.
#[async_trait]
pub trait CompiledGraph<S: GraphState>: Send + Sync {
/// Returns the graph's identifier.
fn id(&self) -> &str;
/// Takes ownership of `input` and returns a state of the same type.
///
/// `config` is an optional [`RuntimeContext`].
///
/// NOTE(review): presumably runs the graph to completion and returns the
/// final state; behavior when `config` is `None` is implementation-defined
/// — confirm.
async fn invoke(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<S>;
/// Takes ownership of `input` and returns a pinned, boxed, `Send` stream
/// whose items are `AgentResult<StreamEvent<S>>`.
///
/// NOTE(review): presumably events are emitted as execution progresses —
/// confirm ordering and termination guarantees against the implementation.
async fn stream(
&self,
input: S,
config: Option<RuntimeContext>,
) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamEvent<S>>> + Send>>>;
/// Takes ownership of `input` and returns a [`StepResult`].
///
/// NOTE(review): presumably executes a single node and reports where
/// execution would go next — confirm against the implementation.
async fn step(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<StepResult<S>>;
/// Checks `state`, returning `Ok(())` or an error via [`AgentResult`].
///
/// NOTE(review): the validation rules are implementation-defined — confirm.
fn validate_state(&self, state: &S) -> AgentResult<()>;
/// Returns a string-to-string map describing the state.
///
/// NOTE(review): presumably field name -> type description — confirm.
fn state_schema(&self) -> HashMap<String, String>;
}
/// Event type yielded (wrapped in `AgentResult`) by `CompiledGraph::stream`.
#[derive(Debug, Clone)]
pub enum StreamEvent<S: GraphState> {
/// A node identified by `node_id` is starting; carries a state value.
///
/// NOTE(review): `state` is presumably a snapshot taken before the node runs — confirm.
NodeStart { node_id: String, state: S },
/// A node identified by `node_id` has finished; carries a state value and
/// the [`Command`] associated with the node.
///
/// NOTE(review): `state` is presumably the state after the node ran and
/// `command` the value the node returned — confirm.
NodeEnd {
node_id: String,
state: S,
command: Command,
},
/// Execution ended; carries the final state.
End { final_state: S },
/// An error occurred; `node_id` is optional and `error` is a message string.
///
/// NOTE(review): `node_id` is presumably `None` when the error is not
/// attributable to a specific node — confirm.
Error {
node_id: Option<String>,
error: String,
},
}
/// Value returned (wrapped in `AgentResult`) by `CompiledGraph::step`.
#[derive(Debug, Clone)]
pub struct StepResult<S: GraphState> {
/// State carried by this result.
///
/// NOTE(review): presumably the state after the step executed — confirm.
pub state: S,
/// Identifier of the node this result refers to.
pub node_id: String,
/// Command associated with the step.
///
/// NOTE(review): presumably the value returned by the node — confirm.
pub command: Command,
/// Whether execution is complete.
pub is_complete: bool,
/// Identifier of the next node, if any.
///
/// NOTE(review): presumably `None` when `is_complete` is `true` — confirm.
pub next_node: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single edge is not conditional and exposes exactly its one target.
    #[test]
    fn test_edge_target_single() {
        let edge = EdgeTarget::single("node_a");
        assert!(!edge.is_conditional());
        assert_eq!(edge.targets(), vec!["node_a"]);
    }

    /// A conditional edge reports every map value as a target (order unspecified).
    #[test]
    fn test_edge_target_conditional() {
        let routes: HashMap<String, String> =
            [("condition_a", "node_a"), ("condition_b", "node_b")]
                .iter()
                .map(|&(key, node)| (key.to_string(), node.to_string()))
                .collect();
        let edge = EdgeTarget::conditional(routes);
        assert!(edge.is_conditional());

        let found = edge.targets();
        assert_eq!(found.len(), 2);
        for expected in &["node_a", "node_b"] {
            assert!(found.contains(expected));
        }
    }

    /// A parallel edge is not conditional and keeps its targets in insertion order.
    #[test]
    fn test_edge_target_parallel() {
        let names = ["a", "b", "c"];
        let edge = EdgeTarget::parallel(names.iter().map(|name| name.to_string()).collect());
        assert!(!edge.is_conditional());
        assert_eq!(edge.targets(), names.to_vec());
    }

    /// The reserved identifiers keep their exact string values.
    #[test]
    fn test_constants() {
        assert_eq!(START, "__START__");
        assert_eq!(END, "__END__");
    }
}