use std::collections::HashMap;
use std::sync::Arc;
use super::{error::GraphError, state::State};
/// Reserved node name `"__start__"`; judging by its name, the virtual node where
/// graph execution begins. NOTE(review): confirm against the graph builder.
pub const START: &str = "__start__";
/// Reserved node name `"__end__"`; judging by its name, the virtual node that
/// marks the end of execution. NOTE(review): confirm against the graph builder.
pub const END: &str = "__end__";
/// The routing rule carried by an [`Edge`].
///
/// `Clone` is implemented by hand instead of derived. `#[derive(Clone)]` would
/// require `S: Clone`, yet no variant stores an `S`. The condition closure sits
/// behind an [`Arc`], which can be cloned for any `S`.
pub enum EdgeType<S: State> {
    /// Always continue to the node named `to`.
    Regular { to: String },
    /// Pick the next node at runtime.
    Conditional {
        /// Async function of the current state that resolves to a label;
        /// the label is then looked up in `mapping`.
        condition: Arc<
            dyn Fn(
                    &S,
                ) -> std::pin::Pin<
                    Box<dyn std::future::Future<Output = Result<String, GraphError>> + Send>,
                > + Send
                + Sync,
        >,
        /// Condition label -> target node name.
        mapping: HashMap<String, String>,
    },
}

impl<S: State> Clone for EdgeType<S> {
    fn clone(&self) -> Self {
        match self {
            Self::Regular { to } => Self::Regular { to: to.clone() },
            Self::Conditional { condition, mapping } => Self::Conditional {
                // Shares the same closure; only the reference count changes.
                condition: Arc::clone(condition),
                mapping: mapping.clone(),
            },
        }
    }
}
impl<S: State> std::fmt::Debug for EdgeType<S> {
    /// Formats the variant as a struct. The condition closure cannot be
    /// printed, so the string `"<fn>"` stands in for it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Regular { to } => {
                let mut out = f.debug_struct("Regular");
                out.field("to", to);
                out.finish()
            }
            Self::Conditional { mapping, .. } => {
                let mut out = f.debug_struct("Conditional");
                out.field("condition", &"<fn>");
                out.field("mapping", mapping);
                out.finish()
            }
        }
    }
}
/// A directed transition out of the node named `from`.
///
/// `Clone` and `Debug` are implemented by hand. Deriving them would add
/// `S: Clone` / `S: Debug` bounds, although an `Edge` never stores an `S`.
/// Cloning defers to `EdgeType<S>: Clone`, so it asks no more of `S` than
/// `EdgeType` itself does.
pub struct Edge<S: State> {
    /// Name of the source node.
    pub from: String,
    /// How the target node is selected (see [`Edge::get_target`]).
    pub edge_type: EdgeType<S>,
}

impl<S: State> Clone for Edge<S>
where
    EdgeType<S>: Clone,
{
    fn clone(&self) -> Self {
        Self {
            from: self.from.clone(),
            edge_type: self.edge_type.clone(),
        }
    }
}

impl<S: State> std::fmt::Debug for Edge<S> {
    // Produces the same text as `#[derive(Debug)]` would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Edge")
            .field("from", &self.from)
            .field("edge_type", &self.edge_type)
            .finish()
    }
}
impl<S: State> Edge<S> {
pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
Self {
from: from.into(),
edge_type: EdgeType::Regular { to: to.into() },
}
}
pub fn conditional<F, Fut>(
from: impl Into<String>,
condition: F,
mapping: HashMap<String, String>,
) -> Self
where
F: Fn(&S) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<String, GraphError>> + Send + 'static,
{
Self {
from: from.into(),
edge_type: EdgeType::Conditional {
condition: Arc::new(move |state| Box::pin(condition(state))),
mapping,
},
}
}
pub async fn get_target(&self, state: &S) -> Result<String, GraphError> {
match &self.edge_type {
EdgeType::Regular { to } => Ok(to.clone()),
EdgeType::Conditional { condition, mapping } => {
let condition_result = (condition)(state).await?;
mapping.get(&condition_result).cloned().ok_or_else(|| {
GraphError::ConditionError(format!(
"Condition returned '{}' which is not in mapping",
condition_result
))
})
}
}
}
pub fn is_regular(&self) -> bool {
matches!(self.edge_type, EdgeType::Regular { .. })
}
pub fn is_conditional(&self) -> bool {
matches!(self.edge_type, EdgeType::Conditional { .. })
}
}
/// Shorthand for [`Edge::new`]: an unconditional edge from `from` to `to`.
pub fn edge<S: State>(from: impl Into<String>, to: impl Into<String>) -> Edge<S> {
    Edge::<S>::new(from, to)
}
/// Shorthand for [`Edge::conditional`].
///
/// [`Edge::get_target`] evaluates `condition` against the state and translates
/// the label it produces into a node name through `mapping`.
pub fn conditional_edge<S: State, F, Fut>(
    from: impl Into<String>,
    condition: F,
    mapping: HashMap<String, String>,
) -> Edge<S>
where
    F: Fn(&S) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<String, GraphError>> + Send + 'static,
{
    <Edge<S>>::conditional(from, condition, mapping)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::state::MessagesState;

    #[tokio::test]
    async fn test_regular_edge() {
        let edge = Edge::new("node1", "node2");
        let state = MessagesState::new();
        let target = edge.get_target(&state).await.unwrap();
        assert_eq!(target, "node2");
        assert!(edge.is_regular());
        assert!(!edge.is_conditional());
    }

    #[tokio::test]
    async fn test_conditional_edge() {
        let mut mapping = HashMap::new();
        mapping.insert("yes".to_string(), "node_yes".to_string());
        mapping.insert("no".to_string(), "node_no".to_string());
        let edge = Edge::conditional(
            "node1",
            |_state: &MessagesState| async move { Ok("yes".to_string()) },
            mapping,
        );
        let state = MessagesState::new();
        let target = edge.get_target(&state).await.unwrap();
        assert_eq!(target, "node_yes");
        assert!(edge.is_conditional());
        assert!(!edge.is_regular());
    }

    /// A label missing from `mapping` must surface as `ConditionError`
    /// instead of resolving to some node.
    #[tokio::test]
    async fn test_conditional_edge_unmapped_label() {
        let mut mapping = HashMap::new();
        mapping.insert("yes".to_string(), "node_yes".to_string());
        let edge = Edge::conditional(
            "node1",
            |_state: &MessagesState| async move { Ok("maybe".to_string()) },
            mapping,
        );
        let state = MessagesState::new();
        let err = edge.get_target(&state).await.unwrap_err();
        assert!(
            matches!(&err, GraphError::ConditionError(msg) if msg.contains("'maybe'")),
            "expected ConditionError naming the unmapped label"
        );
    }

    #[test]
    fn test_free_function_builders() {
        let regular: Edge<MessagesState> = edge("a", "b");
        assert_eq!(regular.from, "a");
        assert!(regular.is_regular());

        let branching: Edge<MessagesState> = conditional_edge(
            "a",
            |_state: &MessagesState| async move { Ok("x".to_string()) },
            HashMap::new(),
        );
        assert_eq!(branching.from, "a");
        assert!(branching.is_conditional());
    }
}