// rrag_graph/nodes/condition.rs
use std::fmt;

use async_trait::async_trait;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::core::{ExecutionContext, ExecutionResult, Node, NodeId};
use crate::state::GraphState;
use crate::RGraphResult;

13#[derive(Debug, Clone)]
15#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
16pub struct ConditionNodeConfig {
17 pub condition_key: String,
18 pub condition_value: serde_json::Value,
19 pub true_route: String,
20 pub false_route: String,
21}
22
23pub struct ConditionNode {
25 id: NodeId,
26 name: String,
27 config: ConditionNodeConfig,
28}
29
30impl ConditionNode {
31 pub fn new(
32 id: impl Into<NodeId>,
33 name: impl Into<String>,
34 config: ConditionNodeConfig,
35 ) -> Self {
36 Self {
37 id: id.into(),
38 name: name.into(),
39 config,
40 }
41 }
42}
43
44#[async_trait]
45impl Node for ConditionNode {
46 async fn execute(
47 &self,
48 state: &mut GraphState,
49 _context: &ExecutionContext,
50 ) -> RGraphResult<ExecutionResult> {
51 let state_value = state.get(&self.config.condition_key)?;
53 let state_json: serde_json::Value = state_value.into();
54
55 let route = if state_json == self.config.condition_value {
56 &self.config.true_route
57 } else {
58 &self.config.false_route
59 };
60
61 Ok(ExecutionResult::Route(route.clone()))
62 }
63
64 fn id(&self) -> &NodeId {
65 &self.id
66 }
67
68 fn name(&self) -> &str {
69 &self.name
70 }
71
72 fn input_keys(&self) -> Vec<&str> {
73 vec![&self.config.condition_key]
74 }
75
76 fn output_keys(&self) -> Vec<&str> {
77 vec![]
78 }
79}