mofa_kernel/workflow/graph.rs
1//! State Graph Traits
2//!
3//! Defines the core graph interfaces for building and executing workflows.
4//! Inspired by LangGraph's StateGraph API.
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::pin::Pin;
11
12use crate::agent::error::AgentResult;
13
14use super::{Command, GraphConfig, GraphState, Reducer, RuntimeContext};
15
/// Reserved node identifier marking where graph execution begins.
///
/// Use it as the `from` side of an edge (for example `add_edge(START, "first")`)
/// to declare the graph's entry node.
pub const START: &str = "__START__";

/// Reserved node identifier marking where graph execution finishes.
///
/// Use it as the `to` side of an edge (for example `add_edge("last", END)`)
/// to declare that a node terminates the run.
pub const END: &str = "__END__";
21
22/// Node function trait
23///
24/// Implement this trait to define custom node behavior.
25/// Nodes receive the current state and runtime context,
26/// and return a Command that can update state and control flow.
27///
28/// # Example
29///
30/// ```rust,ignore
31/// use mofa_kernel::workflow::{NodeFunc, Command, RuntimeContext};
32///
33/// struct ProcessNode;
34///
35/// #[async_trait]
36/// impl NodeFunc<MyState> for ProcessNode {
37/// async fn call(&self, state: &mut MyState, ctx: &RuntimeContext) -> AgentResult<Command> {
38/// // Process the state
39/// let input = state.messages.last().cloned().unwrap_or_default();
40///
41/// // Return command with state update and control flow
42/// Ok(Command::new()
43/// .update("result", json!(format!("Processed: {}", input)))
44/// .goto("next_node"))
45/// }
46///
47/// fn name(&self) -> &str {
48/// "process"
49/// }
50/// }
51/// ```
52#[async_trait]
53pub trait NodeFunc<S: GraphState>: Send + Sync {
54 /// Execute the node
55 ///
56 /// # Arguments
57 /// * `state` - Mutable reference to the current state
58 /// * `ctx` - Runtime context with execution metadata
59 ///
60 /// # Returns
61 /// A Command containing state updates and control flow directive
62 async fn call(&self, state: &mut S, ctx: &RuntimeContext) -> AgentResult<Command>;
63
64 /// Returns the node name/identifier
65 fn name(&self) -> &str;
66
67 /// Optional description of what this node does
68 fn description(&self) -> Option<&str> {
69 None
70 }
71}
72
73/// Edge target definition
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub enum EdgeTarget {
76 /// Single target node
77 Single(String),
78 /// Conditional edges with route names to node IDs
79 Conditional(HashMap<String, String>),
80 /// Multiple parallel targets
81 Parallel(Vec<String>),
82}
83
84impl EdgeTarget {
85 /// Create a single target edge
86 pub fn single(target: impl Into<String>) -> Self {
87 Self::Single(target.into())
88 }
89
90 /// Create conditional edges
91 pub fn conditional(routes: HashMap<String, String>) -> Self {
92 Self::Conditional(routes)
93 }
94
95 /// Create parallel edges
96 pub fn parallel(targets: Vec<String>) -> Self {
97 Self::Parallel(targets)
98 }
99
100 /// Check if this is a conditional edge
101 pub fn is_conditional(&self) -> bool {
102 matches!(self, Self::Conditional(_))
103 }
104
105 /// Get all target node IDs
106 pub fn targets(&self) -> Vec<&str> {
107 match self {
108 Self::Single(t) => vec![t],
109 Self::Conditional(routes) => routes.values().map(|s| s.as_str()).collect(),
110 Self::Parallel(targets) => targets.iter().map(|s| s.as_str()).collect(),
111 }
112 }
113}
114
115/// State graph builder trait
116///
117/// Defines the interface for building stateful workflow graphs.
118/// Implementations should provide a fluent API for constructing graphs.
119///
120/// # Example
121///
122/// ```rust,ignore
123/// use mofa_kernel::workflow::{StateGraph, START, END};
124///
125/// let graph = StateGraphImpl::<MyState>::new("my_workflow")
126/// // Add reducers for state keys
127/// .add_reducer("messages", Box::new(AppendReducer))
128/// .add_reducer("result", Box::new(OverwriteReducer))
129/// // Add nodes
130/// .add_node("process", Box::new(ProcessNode))
131/// .add_node("validate", Box::new(ValidateNode))
132/// // Add edges
133/// .add_edge(START, "process")
134/// .add_edge("process", "validate")
135/// .add_edge("validate", END)
136/// // Compile
137/// .compile()?;
138/// ```
139#[async_trait]
140pub trait StateGraph: Send + Sync {
141 /// The state type for this graph
142 type State: GraphState;
143
144 /// The compiled graph type produced by this builder
145 type Compiled: CompiledGraph<Self::State>;
146
147 /// Create a new graph with the given ID
148 fn new(id: impl Into<String>) -> Self;
149
150 /// Add a node to the graph
151 ///
152 /// # Arguments
153 /// * `id` - Unique node identifier
154 /// * `node` - Node function implementation
155 fn add_node(
156 &mut self,
157 id: impl Into<String>,
158 node: Box<dyn NodeFunc<Self::State>>,
159 ) -> &mut Self;
160
161 /// Add an edge between two nodes
162 ///
163 /// # Arguments
164 /// * `from` - Source node ID (use START for entry edge)
165 /// * `to` - Target node ID (use END for exit edge)
166 fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self;
167
168 /// Add conditional edges from a node
169 ///
170 /// # Arguments
171 /// * `from` - Source node ID
172 /// * `conditions` - Map of condition names to target node IDs
173 ///
174 /// # Example
175 /// ```rust,ignore
176 /// graph.add_conditional_edges("classify", HashMap::from([
177 /// ("type_a".to_string(), "handle_a".to_string()),
178 /// ("type_b".to_string(), "handle_b".to_string()),
179 /// ]));
180 /// ```
181 fn add_conditional_edges(
182 &mut self,
183 from: impl Into<String>,
184 conditions: HashMap<String, String>,
185 ) -> &mut Self;
186
187 /// Add parallel edges from a node
188 ///
189 /// # Arguments
190 /// * `from` - Source node ID
191 /// * `targets` - List of target node IDs to execute in parallel
192 fn add_parallel_edges(&mut self, from: impl Into<String>, targets: Vec<String>) -> &mut Self;
193
194 /// Set the entry point (equivalent to add_edge(START, node))
195 fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self;
196
197 /// Set a finish point (equivalent to add_edge(node, END))
198 fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self;
199
200 /// Add a reducer for a state key
201 ///
202 /// # Arguments
203 /// * `key` - State key name
204 /// * `reducer` - Reducer implementation
205 fn add_reducer(&mut self, key: impl Into<String>, reducer: Box<dyn Reducer>) -> &mut Self;
206
207 /// Set the graph configuration
208 fn with_config(&mut self, config: GraphConfig) -> &mut Self;
209
210 /// Get the graph ID
211 fn id(&self) -> &str;
212
213 /// Compile the graph into an executable form
214 ///
215 /// This validates the graph structure and prepares it for execution.
216 fn compile(self) -> AgentResult<Self::Compiled>;
217}
218
219/// Compiled graph trait for execution
220///
221/// A compiled graph can be invoked with an initial state and
222/// returns the final state after execution.
223#[async_trait]
224pub trait CompiledGraph<S: GraphState>: Send + Sync {
225 /// Get the graph ID
226 fn id(&self) -> &str;
227
228 /// Execute the graph synchronously
229 ///
230 /// # Arguments
231 /// * `input` - Initial state
232 /// * `config` - Optional runtime configuration (uses defaults if None)
233 ///
234 /// # Returns
235 /// The final state after graph execution completes
236 async fn invoke(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<S>;
237
238 /// Execute the graph with streaming output
239 ///
240 /// Returns a stream of (node_id, state) pairs as each node completes.
241 async fn stream(
242 &self,
243 input: S,
244 config: Option<RuntimeContext>,
245 ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamEvent<S>>> + Send>>>;
246
247 /// Execute a single step of the graph
248 ///
249 /// Useful for debugging or interactive execution.
250 async fn step(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<StepResult<S>>;
251
252 /// Validate that a state is valid for this graph
253 fn validate_state(&self, state: &S) -> AgentResult<()>;
254
255 /// Get the graph's state schema
256 fn state_schema(&self) -> HashMap<String, String>;
257}
258
259/// Stream event from graph execution
260#[derive(Debug, Clone)]
261pub enum StreamEvent<S: GraphState> {
262 /// A node started executing
263 NodeStart { node_id: String, state: S },
264 /// A node finished executing
265 NodeEnd {
266 node_id: String,
267 state: S,
268 command: Command,
269 },
270 /// Graph execution completed
271 End { final_state: S },
272 /// Error occurred
273 Error {
274 node_id: Option<String>,
275 error: String,
276 },
277}
278
279/// Result of a single step execution
280#[derive(Debug, Clone)]
281pub struct StepResult<S: GraphState> {
282 /// Current state after the step
283 pub state: S,
284 /// Which node was executed
285 pub node_id: String,
286 /// Command returned by the node
287 pub command: Command,
288 /// Whether execution is complete
289 pub is_complete: bool,
290 /// Next node to execute (if any)
291 pub next_node: Option<String>,
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_edge_target_single() {
        let edge = EdgeTarget::single("node_a");

        assert!(!edge.is_conditional());
        assert_eq!(edge.targets(), ["node_a"]);
    }

    #[test]
    fn test_edge_target_conditional() {
        let routes = HashMap::from([
            ("condition_a".to_string(), "node_a".to_string()),
            ("condition_b".to_string(), "node_b".to_string()),
        ]);
        let edge = EdgeTarget::conditional(routes);
        assert!(edge.is_conditional());

        // Only membership is checked here; ordering is not part of this test.
        let found = edge.targets();
        assert_eq!(found.len(), 2);
        for expected in ["node_a", "node_b"] {
            assert!(found.contains(&expected));
        }
    }

    #[test]
    fn test_edge_target_parallel() {
        let edge = EdgeTarget::parallel(Vec::from(["a", "b", "c"].map(String::from)));

        assert!(!edge.is_conditional());
        assert_eq!(edge.targets(), ["a", "b", "c"]);
    }

    #[test]
    fn test_constants() {
        assert_eq!(START, "__START__");
        assert_eq!(END, "__END__");
    }
}