use crate::state::State;
use std::sync::Arc;
/// Where an edge leads: a node identified by its id, or the `End` marker.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EdgeTarget {
    /// The node with the given identifier.
    Node(String),
    /// The `End` marker; it carries no node id.
    End,
}

impl EdgeTarget {
    /// Creates an [`EdgeTarget::Node`] from any value convertible into a `String`.
    pub fn node(id: impl Into<String>) -> Self {
        let owned: String = id.into();
        EdgeTarget::Node(owned)
    }

    /// Returns `true` exactly when this is [`EdgeTarget::End`].
    pub fn is_end(&self) -> bool {
        // `End` is the only variant without a node id.
        self.node_id().is_none()
    }

    /// Borrows the node identifier, or returns `None` for [`EdgeTarget::End`].
    pub fn node_id(&self) -> Option<&str> {
        match self {
            EdgeTarget::End => None,
            EdgeTarget::Node(name) => Some(name.as_str()),
        }
    }
}
/// Chooses the next [`EdgeTarget`] for a conditional edge from the current state.
///
/// Implementors must also be `Send + Sync`.
pub trait Router<S: State>: Send + Sync {
    /// Picks the destination for `state`.
    fn route(&self, state: &S) -> EdgeTarget;

    /// Lists the targets [`Router::route`] may return, when the implementor declares them.
    ///
    /// The default implementation declares none and returns an empty list.
    fn possible_targets(&self) -> Vec<EdgeTarget> {
        Vec::new()
    }
}
/// A [`Router`] backed by a plain function or closure.
///
/// [`Router::possible_targets`] reports whatever list was supplied through
/// [`FnRouter::with_targets`]. That list is empty by default and is not checked
/// against what `func` actually returns.
pub struct FnRouter<S, F> {
    /// Routing function invoked by [`Router::route`].
    func: F,
    /// Targets reported by [`Router::possible_targets`].
    possible_targets: Vec<EdgeTarget>,
    // `fn() -> S` ties the router to `S` without claiming to own an `S` value,
    // which is never stored here. As a result, `Send`/`Sync`/`Unpin` depend only
    // on `F`. `S` stays covariant, exactly as with `PhantomData<S>`.
    _phantom: std::marker::PhantomData<fn() -> S>,
}

impl<S, F> FnRouter<S, F>
where
    S: State,
    F: Fn(&S) -> EdgeTarget + Send + Sync,
{
    /// Wraps `func` with an empty list of declared targets.
    pub fn new(func: F) -> Self {
        Self {
            func,
            possible_targets: vec![],
            _phantom: std::marker::PhantomData,
        }
    }

    /// Replaces the declared possible targets and returns the router (builder style).
    pub fn with_targets(mut self, targets: Vec<EdgeTarget>) -> Self {
        self.possible_targets = targets;
        self
    }
}

impl<S, F> Router<S> for FnRouter<S, F>
where
    S: State,
    F: Fn(&S) -> EdgeTarget + Send + Sync,
{
    /// Delegates to the wrapped function.
    fn route(&self, state: &S) -> EdgeTarget {
        (self.func)(state)
    }

    /// Returns a copy of the targets declared via [`FnRouter::with_targets`].
    fn possible_targets(&self) -> Vec<EdgeTarget> {
        self.possible_targets.clone()
    }
}
/// An outgoing edge: either a fixed [`EdgeTarget`] or a [`Router`] that picks one from the state.
pub enum Edge<S: State> {
    /// Always resolves to the stored target.
    Direct(EdgeTarget),
    /// Resolves by calling the boxed router's [`Router::route`].
    Conditional(Box<dyn Router<S>>),
}

impl<S: State> Edge<S> {
    /// Direct edge to the node identified by `id`.
    pub fn to_node(id: impl Into<String>) -> Self {
        let target = EdgeTarget::Node(id.into());
        Edge::Direct(target)
    }

    /// Direct edge to [`EdgeTarget::End`].
    pub fn to_end() -> Self {
        Edge::Direct(EdgeTarget::End)
    }

    /// Conditional edge routed by `f`.
    ///
    /// `f` is wrapped in an [`FnRouter`] with no declared targets. For this edge,
    /// [`Edge::possible_targets`] therefore returns an empty list.
    pub fn conditional_fn<F>(f: F) -> Self
    where
        F: Fn(&S) -> EdgeTarget + Send + Sync + 'static,
    {
        Self::conditional(FnRouter::new(f))
    }

    /// Conditional edge driven by any [`Router`] implementation.
    pub fn conditional<R: Router<S> + 'static>(router: R) -> Self {
        let boxed: Box<dyn Router<S>> = Box::new(router);
        Edge::Conditional(boxed)
    }

    /// Determines where this edge leads for `state`.
    pub fn resolve(&self, state: &S) -> EdgeTarget {
        match self {
            Edge::Conditional(router) => router.route(state),
            Edge::Direct(target) => target.clone(),
        }
    }

    /// Targets this edge can lead to.
    ///
    /// A direct edge reports its single stored target. A conditional edge reports
    /// whatever its router declares, which may be an empty list.
    pub fn possible_targets(&self) -> Vec<EdgeTarget> {
        match self {
            Edge::Conditional(router) => router.possible_targets(),
            Edge::Direct(target) => std::slice::from_ref(target).to_vec(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal state: a single flag the test routers branch on.
    #[derive(Clone, Debug, Default, Serialize, Deserialize)]
    struct TestState {
        go_to_end: bool,
    }

    impl State for TestState {
        fn schema() -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
    }

    #[test]
    fn test_edge_target() {
        let target = EdgeTarget::node("my_node");
        assert_eq!(target.node_id(), Some("my_node"));
        assert!(!target.is_end());
        let target = EdgeTarget::End;
        assert!(target.is_end());
        assert_eq!(target.node_id(), None);
    }

    #[test]
    fn test_direct_edge() {
        let edge: Edge<TestState> = Edge::to_node("next");
        let state = TestState::default();
        assert_eq!(edge.resolve(&state), EdgeTarget::node("next"));
    }

    /// A direct edge reports exactly its one target.
    #[test]
    fn test_direct_edge_possible_targets() {
        let edge: Edge<TestState> = Edge::to_node("next");
        assert_eq!(edge.possible_targets(), vec![EdgeTarget::node("next")]);
    }

    #[test]
    fn test_end_edge() {
        let edge: Edge<TestState> = Edge::to_end();
        assert_eq!(edge.resolve(&TestState::default()), EdgeTarget::End);
        assert_eq!(edge.possible_targets(), vec![EdgeTarget::End]);
    }

    #[test]
    fn test_conditional_edge() {
        let edge: Edge<TestState> = Edge::conditional_fn(|s: &TestState| {
            if s.go_to_end {
                EdgeTarget::End
            } else {
                EdgeTarget::node("continue")
            }
        });
        let state = TestState { go_to_end: false };
        assert_eq!(edge.resolve(&state), EdgeTarget::node("continue"));
        let state = TestState { go_to_end: true };
        assert_eq!(edge.resolve(&state), EdgeTarget::End);
    }

    /// `conditional_fn` wraps the closure in an `FnRouter` without declared targets.
    #[test]
    fn test_conditional_fn_declares_no_targets() {
        let edge: Edge<TestState> = Edge::conditional_fn(|_: &TestState| EdgeTarget::End);
        assert!(edge.possible_targets().is_empty());
    }

    /// `FnRouter::with_targets` is what `possible_targets` reports; routing is unaffected.
    #[test]
    fn test_fn_router_targets() {
        let route = |s: &TestState| {
            if s.go_to_end {
                EdgeTarget::End
            } else {
                EdgeTarget::node("continue")
            }
        };

        let bare: FnRouter<TestState, _> = FnRouter::new(route);
        assert!(bare.possible_targets().is_empty());

        let declared: FnRouter<TestState, _> = FnRouter::new(route)
            .with_targets(vec![EdgeTarget::node("continue"), EdgeTarget::End]);
        assert_eq!(
            declared.possible_targets(),
            vec![EdgeTarget::node("continue"), EdgeTarget::End]
        );
        assert_eq!(declared.route(&TestState { go_to_end: true }), EdgeTarget::End);

        let edge: Edge<TestState> = Edge::conditional(declared);
        assert_eq!(edge.possible_targets().len(), 2);
    }

    /// Hand-written router: `End` when the flag is set, `"process"` otherwise.
    struct TestRouter;

    impl Router<TestState> for TestRouter {
        fn route(&self, state: &TestState) -> EdgeTarget {
            if state.go_to_end {
                EdgeTarget::End
            } else {
                EdgeTarget::node("process")
            }
        }

        fn possible_targets(&self) -> Vec<EdgeTarget> {
            vec![EdgeTarget::node("process"), EdgeTarget::End]
        }
    }

    #[test]
    fn test_router_trait() {
        let edge: Edge<TestState> = Edge::conditional(TestRouter);
        assert_eq!(
            edge.possible_targets(),
            vec![EdgeTarget::node("process"), EdgeTarget::End]
        );
        assert_eq!(
            edge.resolve(&TestState { go_to_end: false }),
            EdgeTarget::node("process")
        );
        assert_eq!(edge.resolve(&TestState { go_to_end: true }), EdgeTarget::End);
    }
}