mofa_kernel/workflow/graph.rs
1//! State Graph Traits
2//!
3//! Defines the core graph interfaces for building and executing workflows.
4//! Inspired by LangGraph's StateGraph API.
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::pin::Pin;
11
12use crate::agent::error::AgentResult;
13
14use super::{Command, GraphConfig, GraphState, Reducer, RuntimeContext};
15
/// Sentinel node ID that stands for the graph's entry point.
///
/// Use it as the source of an entry edge.
pub const START: &str = "__START__";

/// Sentinel node ID that stands for the graph's exit point.
///
/// Use it as the target of an exit edge.
pub const END: &str = "__END__";
21
22/// Node function trait
23///
24/// Implement this trait to define custom node behavior.
25/// Nodes receive the current state and runtime context,
26/// and return a Command that can update state and control flow.
27///
28/// # Example
29///
30/// ```rust,ignore
31/// use mofa_kernel::workflow::{NodeFunc, Command, RuntimeContext};
32///
33/// struct ProcessNode;
34///
35/// #[async_trait]
36/// impl NodeFunc<MyState> for ProcessNode {
37/// async fn call(&self, state: &mut MyState, ctx: &RuntimeContext) -> AgentResult<Command> {
38/// // Process the state
39/// let input = state.messages.last().cloned().unwrap_or_default();
40///
41/// // Return command with state update and control flow
42/// Ok(Command::new()
43/// .update("result", json!(format!("Processed: {}", input)))
44/// .goto("next_node"))
45/// }
46///
47/// fn name(&self) -> &str {
48/// "process"
49/// }
50/// }
51/// ```
52#[async_trait]
53pub trait NodeFunc<S: GraphState>: Send + Sync {
54 /// Execute the node
55 ///
56 /// # Arguments
57 /// * `state` - Mutable reference to the current state
58 /// * `ctx` - Runtime context with execution metadata
59 ///
60 /// # Returns
61 /// A Command containing state updates and control flow directive
62 async fn call(&self, state: &mut S, ctx: &RuntimeContext) -> AgentResult<Command>;
63
64 /// Returns the node name/identifier
65 fn name(&self) -> &str;
66
67 /// Optional description of what this node does
68 fn description(&self) -> Option<&str> {
69 None
70 }
71}
72
73/// Edge target definition
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub enum EdgeTarget {
76 /// Single target node
77 Single(String),
78 /// Conditional edges with route names to node IDs
79 Conditional(HashMap<String, String>),
80 /// Multiple parallel targets
81 Parallel(Vec<String>),
82}
83
84impl EdgeTarget {
85 /// Create a single target edge
86 pub fn single(target: impl Into<String>) -> Self {
87 Self::Single(target.into())
88 }
89
90 /// Create conditional edges
91 pub fn conditional(routes: HashMap<String, String>) -> Self {
92 Self::Conditional(routes)
93 }
94
95 /// Create parallel edges
96 pub fn parallel(targets: Vec<String>) -> Self {
97 Self::Parallel(targets)
98 }
99
100 /// Check if this is a conditional edge
101 pub fn is_conditional(&self) -> bool {
102 matches!(self, Self::Conditional(_))
103 }
104
105 /// Get all target node IDs
106 pub fn targets(&self) -> Vec<&str> {
107 match self {
108 Self::Single(t) => vec![t],
109 Self::Conditional(routes) => routes.values().map(|s| s.as_str()).collect(),
110 Self::Parallel(targets) => targets.iter().map(|s| s.as_str()).collect(),
111 }
112 }
113}
114
115/// State graph builder trait
116///
117/// Defines the interface for building stateful workflow graphs.
118/// Implementations should provide a fluent API for constructing graphs.
119///
120/// # Example
121///
122/// ```rust,ignore
123/// use mofa_kernel::workflow::{StateGraph, START, END};
124///
125/// let graph = StateGraphImpl::<MyState>::new("my_workflow")
126/// // Add reducers for state keys
127/// .add_reducer("messages", Box::new(AppendReducer))
128/// .add_reducer("result", Box::new(OverwriteReducer))
129/// // Add nodes
130/// .add_node("process", Box::new(ProcessNode))
131/// .add_node("validate", Box::new(ValidateNode))
132/// // Add edges
133/// .add_edge(START, "process")
134/// .add_edge("process", "validate")
135/// .add_edge("validate", END)
136/// // Compile
137/// .compile()?;
138/// ```
139#[async_trait]
140pub trait StateGraph: Send + Sync {
141 /// The state type for this graph
142 type State: GraphState;
143
144 /// The compiled graph type produced by this builder
145 type Compiled: CompiledGraph<Self::State>;
146
147 /// Create a new graph with the given ID
148 fn new(id: impl Into<String>) -> Self;
149
150 /// Add a node to the graph
151 ///
152 /// # Arguments
153 /// * `id` - Unique node identifier
154 /// * `node` - Node function implementation
155 fn add_node(&mut self, id: impl Into<String>, node: Box<dyn NodeFunc<Self::State>>) -> &mut Self;
156
157 /// Add an edge between two nodes
158 ///
159 /// # Arguments
160 /// * `from` - Source node ID (use START for entry edge)
161 /// * `to` - Target node ID (use END for exit edge)
162 fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self;
163
164 /// Add conditional edges from a node
165 ///
166 /// # Arguments
167 /// * `from` - Source node ID
168 /// * `conditions` - Map of condition names to target node IDs
169 ///
170 /// # Example
171 /// ```rust,ignore
172 /// graph.add_conditional_edges("classify", HashMap::from([
173 /// ("type_a".to_string(), "handle_a".to_string()),
174 /// ("type_b".to_string(), "handle_b".to_string()),
175 /// ]));
176 /// ```
177 fn add_conditional_edges(
178 &mut self,
179 from: impl Into<String>,
180 conditions: HashMap<String, String>,
181 ) -> &mut Self;
182
183 /// Add parallel edges from a node
184 ///
185 /// # Arguments
186 /// * `from` - Source node ID
187 /// * `targets` - List of target node IDs to execute in parallel
188 fn add_parallel_edges(&mut self, from: impl Into<String>, targets: Vec<String>) -> &mut Self;
189
190 /// Set the entry point (equivalent to add_edge(START, node))
191 fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self;
192
193 /// Set a finish point (equivalent to add_edge(node, END))
194 fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self;
195
196 /// Add a reducer for a state key
197 ///
198 /// # Arguments
199 /// * `key` - State key name
200 /// * `reducer` - Reducer implementation
201 fn add_reducer(&mut self, key: impl Into<String>, reducer: Box<dyn Reducer>) -> &mut Self;
202
203 /// Set the graph configuration
204 fn with_config(&mut self, config: GraphConfig) -> &mut Self;
205
206 /// Get the graph ID
207 fn id(&self) -> &str;
208
209 /// Compile the graph into an executable form
210 ///
211 /// This validates the graph structure and prepares it for execution.
212 fn compile(self) -> AgentResult<Self::Compiled>;
213}
214
215/// Compiled graph trait for execution
216///
217/// A compiled graph can be invoked with an initial state and
218/// returns the final state after execution.
219#[async_trait]
220pub trait CompiledGraph<S: GraphState>: Send + Sync {
221 /// Get the graph ID
222 fn id(&self) -> &str;
223
224 /// Execute the graph synchronously
225 ///
226 /// # Arguments
227 /// * `input` - Initial state
228 /// * `config` - Optional runtime configuration (uses defaults if None)
229 ///
230 /// # Returns
231 /// The final state after graph execution completes
232 async fn invoke(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<S>;
233
234 /// Execute the graph with streaming output
235 ///
236 /// Returns a stream of (node_id, state) pairs as each node completes.
237 async fn stream(
238 &self,
239 input: S,
240 config: Option<RuntimeContext>,
241 ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamEvent<S>>> + Send>>>;
242
243 /// Execute a single step of the graph
244 ///
245 /// Useful for debugging or interactive execution.
246 async fn step(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<StepResult<S>>;
247
248 /// Validate that a state is valid for this graph
249 fn validate_state(&self, state: &S) -> AgentResult<()>;
250
251 /// Get the graph's state schema
252 fn state_schema(&self) -> HashMap<String, String>;
253}
254
255/// Stream event from graph execution
256#[derive(Debug, Clone)]
257pub enum StreamEvent<S: GraphState> {
258 /// A node started executing
259 NodeStart {
260 node_id: String,
261 state: S,
262 },
263 /// A node finished executing
264 NodeEnd {
265 node_id: String,
266 state: S,
267 command: Command,
268 },
269 /// Graph execution completed
270 End {
271 final_state: S,
272 },
273 /// Error occurred
274 Error {
275 node_id: Option<String>,
276 error: String,
277 },
278}
279
280/// Result of a single step execution
281#[derive(Debug, Clone)]
282pub struct StepResult<S: GraphState> {
283 /// Current state after the step
284 pub state: S,
285 /// Which node was executed
286 pub node_id: String,
287 /// Command returned by the node
288 pub command: Command,
289 /// Whether execution is complete
290 pub is_complete: bool,
291 /// Next node to execute (if any)
292 pub next_node: Option<String>,
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// A single edge exposes exactly its one target and is not conditional.
    #[test]
    fn test_edge_target_single() {
        let edge = EdgeTarget::single("node_a");

        assert_eq!(edge.targets(), vec!["node_a"]);
        assert!(!edge.is_conditional());
    }

    /// Conditional edges expose every routed node; order is not asserted.
    #[test]
    fn test_edge_target_conditional() {
        let routes = HashMap::from([
            ("condition_a".to_string(), "node_a".to_string()),
            ("condition_b".to_string(), "node_b".to_string()),
        ]);

        let edge = EdgeTarget::conditional(routes);
        assert!(edge.is_conditional());

        let ids = edge.targets();
        assert_eq!(ids.len(), 2);
        for expected in ["node_a", "node_b"] {
            assert!(ids.contains(&expected));
        }
    }

    /// Parallel edges expose their targets in the order they were given.
    #[test]
    fn test_edge_target_parallel() {
        let names: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let edge = EdgeTarget::parallel(names);

        assert_eq!(edge.targets(), vec!["a", "b", "c"]);
        assert!(!edge.is_conditional());
    }

    /// The sentinel node IDs keep their documented literal values.
    #[test]
    fn test_constants() {
        assert_eq!(START, "__START__");
        assert_eq!(END, "__END__");
    }
}
333}