use super::state::StateSchema;
use super::errors::GraphError;
use std::collections::HashMap;
use std::marker::PhantomData;
/// Destination of an edge: either a node named directly, or a router
/// referenced by name that picks the destination later.
#[derive(Debug, Clone, PartialEq)]
pub enum EdgeTarget {
    /// Names the destination node directly.
    Fixed(String),
    /// Names a router instead of a node.
    Conditional(String),
}

impl EdgeTarget {
    /// Builds an [`EdgeTarget::Fixed`] that points at `node`.
    pub fn to(node: impl Into<String>) -> Self {
        let name: String = node.into();
        EdgeTarget::Fixed(name)
    }

    /// Builds an [`EdgeTarget::Conditional`] that refers to `router`.
    pub fn conditional(router: impl Into<String>) -> Self {
        let name: String = router.into();
        EdgeTarget::Conditional(name)
    }
}
/// A directed edge between graph nodes, referenced by name.
///
/// Nodes are identified only by their `String` names; nothing in this type
/// checks that those names refer to nodes that actually exist.
#[derive(Debug, Clone)]
pub enum GraphEdge {
    /// Connects `source` directly to `target`.
    Fixed {
        source: String,
        target: String,
    },
    /// Leaves `source` along a route chosen by the router named `router_name`.
    ///
    /// `targets` maps route keys to node names and `default_target` is an
    /// optional fallback. Presumably the key is the string returned by a
    /// [`ConditionalEdge`] router — confirm against the graph executor.
    Conditional {
        source: String,
        router_name: String,
        targets: HashMap<String, String>,
        default_target: Option<String>,
    },
    /// Connects `source` to each node in `targets` (order preserved).
    /// How the targets are scheduled is decided elsewhere.
    FanOut {
        source: String,
        targets: Vec<String>,
    },
    /// Connects each node in `sources` (order preserved) to `target`.
    FanIn {
        sources: Vec<String>,
        target: String,
    },
}

impl GraphEdge {
    /// Placeholder name returned by [`GraphEdge::source`] for
    /// [`GraphEdge::FanIn`] edges, which have several sources instead of one.
    /// Use [`GraphEdge::fan_in_sources`] to get the real source list.
    pub const FAN_IN_SOURCE: &'static str = "__fanin__";

    /// Creates a [`GraphEdge::Fixed`] edge from `source` to `target`.
    pub fn fixed(source: impl Into<String>, target: impl Into<String>) -> Self {
        Self::Fixed {
            source: source.into(),
            target: target.into(),
        }
    }

    /// Creates a [`GraphEdge::Conditional`] edge.
    ///
    /// Every key and value of `targets`, and `default` when present, is
    /// converted into an owned `String`.
    pub fn conditional<R, T>(
        source: impl Into<String>,
        router_name: impl Into<String>,
        targets: HashMap<R, T>,
        default: Option<T>,
    ) -> Self
    where
        R: Into<String>,
        T: Into<String>,
    {
        Self::Conditional {
            source: source.into(),
            router_name: router_name.into(),
            targets: targets
                .into_iter()
                .map(|(key, node)| (key.into(), node.into()))
                .collect(),
            default_target: default.map(Into::into),
        }
    }

    /// Creates a [`GraphEdge::FanOut`] edge from `source` to every entry of `targets`.
    pub fn fan_out(source: impl Into<String>, targets: Vec<String>) -> Self {
        Self::FanOut {
            source: source.into(),
            targets,
        }
    }

    /// Creates a [`GraphEdge::FanIn`] edge from every entry of `sources` to `target`.
    pub fn fan_in(sources: Vec<String>, target: impl Into<String>) -> Self {
        Self::FanIn {
            sources,
            target: target.into(),
        }
    }

    /// Returns the name of the node this edge leaves from.
    ///
    /// A fan-in edge has no single source, so it reports
    /// [`GraphEdge::FAN_IN_SOURCE`] instead.
    pub fn source(&self) -> &str {
        match self {
            Self::Fixed { source, .. }
            | Self::Conditional { source, .. }
            | Self::FanOut { source, .. } => source.as_str(),
            Self::FanIn { .. } => Self::FAN_IN_SOURCE,
        }
    }

    /// Returns the single destination of a `Fixed` or `FanIn` edge.
    ///
    /// Returns `None` for `Conditional` and `FanOut` edges, whose destination
    /// is not one fixed node.
    pub fn fixed_target(&self) -> Option<&str> {
        match self {
            Self::Fixed { target, .. } | Self::FanIn { target, .. } => Some(target.as_str()),
            Self::Conditional { .. } | Self::FanOut { .. } => None,
        }
    }

    /// Returns the target list of a `FanOut` edge, or `None` for any other kind.
    pub fn fan_out_targets(&self) -> Option<&Vec<String>> {
        match self {
            Self::FanOut { targets, .. } => Some(targets),
            _ => None,
        }
    }

    /// Returns the source list of a `FanIn` edge, or `None` for any other kind.
    pub fn fan_in_sources(&self) -> Option<&Vec<String>> {
        match self {
            Self::FanIn { sources, .. } => Some(sources),
            _ => None,
        }
    }
}
/// A router that picks the next route for a conditional edge by looking at the
/// current graph state.
///
/// The trait requires `Send + Sync`, so routers can be shared across threads.
/// `async_trait` is used so the async `route` method can be declared here.
#[async_trait::async_trait]
pub trait ConditionalEdge<S: StateSchema>: Send + Sync {
/// Returns the chosen route as a `String`.
///
/// Presumably the returned string is looked up as a key in the `targets` map
/// of a `GraphEdge::Conditional` — confirm against the graph executor.
///
/// # Errors
/// Returns a [`GraphError`] when the implementation cannot decide on a route;
/// when that happens is up to the implementation.
async fn route(&self, state: &S) -> Result<String, GraphError>;
}
/// [`ConditionalEdge`] adapter around a synchronous closure `Fn(&S) -> String`.
///
/// The router only *borrows* states when routing and never stores one.
/// The marker is therefore `PhantomData<fn() -> S>` rather than
/// `PhantomData<S>`, for two reasons:
/// - the router's `Send`/`Sync` status depends only on `F`, not on `S`;
/// - drop-check does not assume an `S` is dropped with it.
///
/// `S` stays covariant, as before. Trait bounds live on the `impl` blocks, not
/// on the struct.
pub struct FunctionRouter<S, F> {
    func: F,
    // Mentions `S` without claiming ownership of an `S` value.
    _marker: PhantomData<fn() -> S>,
}
impl<S: StateSchema, F> FunctionRouter<S, F>
where
F: Fn(&S) -> String + Send + Sync,
{
pub fn new(func: F) -> Self {
Self { func, _marker: PhantomData }
}
}
#[async_trait::async_trait]
impl<S, F> ConditionalEdge<S> for FunctionRouter<S, F>
where
    S: StateSchema,
    F: Fn(&S) -> String + Send + Sync,
{
    /// Runs the wrapped closure on `state` and always returns its result as `Ok`.
    async fn route(&self, state: &S) -> Result<String, GraphError> {
        let chosen = (self.func)(state);
        Ok(chosen)
    }
}
/// [`ConditionalEdge`] adapter around a closure that returns a future
/// resolving to `Result<String, GraphError>`.
///
/// The router only *borrows* states when routing and never stores one.
/// The marker is therefore `PhantomData<fn() -> S>` rather than
/// `PhantomData<S>`, for two reasons:
/// - the router's `Send`/`Sync` status depends only on `F`, not on `S`;
/// - drop-check does not assume an `S` is dropped with it.
///
/// `S` stays covariant, as before. Trait bounds live on the `impl` blocks, not
/// on the struct.
pub struct AsyncFunctionRouter<S, F> {
    func: F,
    // Mentions `S` without claiming ownership of an `S` value.
    _marker: PhantomData<fn() -> S>,
}
impl<S: StateSchema, F, Fut> AsyncFunctionRouter<S, F>
where
F: Fn(&S) -> Fut + Send + Sync,
Fut: std::future::Future<Output = Result<String, GraphError>> + Send,
{
pub fn new(func: F) -> Self {
Self { func, _marker: PhantomData }
}
}
#[async_trait::async_trait]
impl<S, F, Fut> ConditionalEdge<S> for AsyncFunctionRouter<S, F>
where
    S: StateSchema,
    F: Fn(&S) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<String, GraphError>> + Send,
{
    /// Calls the wrapped closure and awaits its future.
    /// The closure's `Ok` or `Err` result is returned unchanged.
    async fn route(&self, state: &S) -> Result<String, GraphError> {
        let pending = (self.func)(state);
        pending.await
    }
}
// Shared route keys that routers can return. Their meaning is inferred from
// the names — confirm against the graph executor.

/// Route key `"continue"`, presumably "proceed to the next step".
pub const ROUTE_CONTINUE: &str = "continue";
/// Route key `"end"`, presumably "stop the graph run".
pub const ROUTE_END: &str = "end";
/// Route key `"error"`, presumably "take the error-handling route".
pub const ROUTE_ERROR: &str = "error";
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::state::AgentState;

    #[test]
    fn test_fixed_edge() {
        let edge = GraphEdge::fixed("start", "process");
        assert_eq!(edge.source(), "start");
        assert_eq!(edge.fixed_target(), Some("process"));
    }

    #[test]
    fn test_conditional_edge() {
        let targets = HashMap::from([
            ("continue", "next_node"),
            ("end", "__end__"),
        ]);
        let edge = GraphEdge::conditional("decision", "router", targets, None);
        assert_eq!(edge.source(), "decision");
        assert!(edge.fixed_target().is_none());
    }

    /// The generic `conditional` constructor must convert keys, values and the
    /// default target into owned strings without losing entries.
    #[test]
    fn test_conditional_edge_converts_targets() {
        let targets = HashMap::from([("continue", "next_node"), ("end", "__end__")]);
        let edge = GraphEdge::conditional("decision", "router", targets, Some("fallback"));
        match edge {
            GraphEdge::Conditional {
                router_name,
                targets,
                default_target,
                ..
            } => {
                assert_eq!(router_name, "router");
                assert_eq!(targets.len(), 2);
                assert_eq!(targets.get("continue").map(String::as_str), Some("next_node"));
                assert_eq!(targets.get("end").map(String::as_str), Some("__end__"));
                assert_eq!(default_target.as_deref(), Some("fallback"));
            }
            other => panic!("expected a Conditional edge, got {:?}", other),
        }
    }

    #[test]
    fn test_fan_out_edge() {
        let edge = GraphEdge::fan_out("split", vec!["a".to_string(), "b".to_string()]);
        assert_eq!(edge.source(), "split");
        assert!(edge.fixed_target().is_none());
        assert_eq!(
            edge.fan_out_targets(),
            Some(&vec!["a".to_string(), "b".to_string()])
        );
        assert!(edge.fan_in_sources().is_none());
    }

    #[test]
    fn test_fan_in_edge() {
        let edge = GraphEdge::fan_in(vec!["a".to_string(), "b".to_string()], "join");
        // A fan-in has no single source, so `source()` reports a placeholder name.
        assert_eq!(edge.source(), "__fanin__");
        assert_eq!(edge.fixed_target(), Some("join"));
        assert_eq!(
            edge.fan_in_sources(),
            Some(&vec!["a".to_string(), "b".to_string()])
        );
        assert!(edge.fan_out_targets().is_none());
    }

    #[test]
    fn test_edge_target_constructors() {
        assert_eq!(EdgeTarget::to("next"), EdgeTarget::Fixed("next".to_string()));
        assert_eq!(
            EdgeTarget::conditional("router"),
            EdgeTarget::Conditional("router".to_string())
        );
    }

    #[tokio::test]
    async fn test_function_router() {
        let router = FunctionRouter::new(|state: &AgentState| {
            if state.output.is_some() {
                ROUTE_END.to_string()
            } else {
                ROUTE_CONTINUE.to_string()
            }
        });
        let state = AgentState::new("test".to_string());
        let route = router.route(&state).await.unwrap();
        assert_eq!(route, ROUTE_CONTINUE);
    }
}