1use crate::core::NodeId;
6use crate::state::GraphState;
7use crate::RGraphResult;
8use async_trait::async_trait;
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone)]
15#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
16pub enum RoutingDecision {
17 Route(NodeId),
19 Stop,
21 Continue,
23}
24
25#[async_trait]
27pub trait RoutingCondition: Send + Sync {
28 async fn evaluate(&self, state: &GraphState) -> RGraphResult<RoutingDecision>;
30}
31
32pub struct StateCondition {
34 key: String,
35 expected_value: serde_json::Value,
36 target_node: NodeId,
37}
38
39impl StateCondition {
40 pub fn new(
41 key: impl Into<String>,
42 expected_value: serde_json::Value,
43 target_node: impl Into<NodeId>,
44 ) -> Self {
45 Self {
46 key: key.into(),
47 expected_value,
48 target_node: target_node.into(),
49 }
50 }
51}
52
53#[async_trait]
54impl RoutingCondition for StateCondition {
55 async fn evaluate(&self, state: &GraphState) -> RGraphResult<RoutingDecision> {
56 match state.get(&self.key) {
57 Ok(value) => {
58 let state_json: serde_json::Value = value.into();
59 if state_json == self.expected_value {
60 Ok(RoutingDecision::Route(self.target_node.clone()))
61 } else {
62 Ok(RoutingDecision::Continue)
63 }
64 }
65 Err(_) => Ok(RoutingDecision::Continue),
66 }
67 }
68}
69
70pub struct ConditionalEdge {
72 condition: Box<dyn RoutingCondition>,
73 source: NodeId,
74}
75
76impl ConditionalEdge {
77 pub fn new(source: impl Into<NodeId>, condition: Box<dyn RoutingCondition>) -> Self {
78 Self {
79 condition,
80 source: source.into(),
81 }
82 }
83
84 pub async fn evaluate(&self, state: &GraphState) -> RGraphResult<RoutingDecision> {
85 self.condition.evaluate(state).await
86 }
87}
88
89pub struct Router {
91 conditions: Vec<ConditionalEdge>,
92}
93
94impl Router {
95 pub fn new() -> Self {
96 Self {
97 conditions: Vec::new(),
98 }
99 }
100
101 pub fn add_condition(&mut self, condition: ConditionalEdge) {
102 self.conditions.push(condition);
103 }
104
105 pub async fn route(
106 &self,
107 current_node: &NodeId,
108 state: &GraphState,
109 ) -> RGraphResult<RoutingDecision> {
110 for condition in &self.conditions {
111 if &condition.source == current_node {
112 let decision = condition.evaluate(state).await?;
113 match decision {
114 RoutingDecision::Continue => continue,
115 _ => return Ok(decision),
116 }
117 }
118 }
119
120 Ok(RoutingDecision::Continue)
121 }
122}
123
124impl Default for Router {
125 fn default() -> Self {
126 Self::new()
127 }
128}