// Source: mofa_kernel/workflow/graph.rs

//! State Graph Traits
//!
//! Defines the core traits and types for building and executing stateful
//! workflow graphs. Inspired by LangGraph's `StateGraph` API.
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::pin::Pin;
11
12use crate::agent::error::AgentResult;
13
14use super::{Command, GraphConfig, GraphState, Reducer, RuntimeContext};
15
/// Reserved node ID marking where graph execution begins.
///
/// Use it as the `from` side of an edge, e.g. `add_edge(START, "first_node")`.
pub const START: &str = "__START__";

/// Reserved node ID marking where graph execution finishes.
///
/// Use it as the `to` side of an edge, e.g. `add_edge("last_node", END)`.
pub const END: &str = "__END__";
21
22/// Node function trait
23///
24/// Implement this trait to define custom node behavior.
25/// Nodes receive the current state and runtime context,
26/// and return a Command that can update state and control flow.
27///
28/// # Example
29///
30/// ```rust,ignore
31/// use mofa_kernel::workflow::{NodeFunc, Command, RuntimeContext};
32///
33/// struct ProcessNode;
34///
35/// #[async_trait]
36/// impl NodeFunc<MyState> for ProcessNode {
37///     async fn call(&self, state: &mut MyState, ctx: &RuntimeContext) -> AgentResult<Command> {
38///         // Process the state
39///         let input = state.messages.last().cloned().unwrap_or_default();
40///
41///         // Return command with state update and control flow
42///         Ok(Command::new()
43///             .update("result", json!(format!("Processed: {}", input)))
44///             .goto("next_node"))
45///     }
46///
47///     fn name(&self) -> &str {
48///         "process"
49///     }
50/// }
51/// ```
52#[async_trait]
53pub trait NodeFunc<S: GraphState>: Send + Sync {
54    /// Execute the node
55    ///
56    /// # Arguments
57    /// * `state` - Mutable reference to the current state
58    /// * `ctx` - Runtime context with execution metadata
59    ///
60    /// # Returns
61    /// A Command containing state updates and control flow directive
62    async fn call(&self, state: &mut S, ctx: &RuntimeContext) -> AgentResult<Command>;
63
64    /// Returns the node name/identifier
65    fn name(&self) -> &str;
66
67    /// Optional description of what this node does
68    fn description(&self) -> Option<&str> {
69        None
70    }
71}
72
73/// Edge target definition
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub enum EdgeTarget {
76    /// Single target node
77    Single(String),
78    /// Conditional edges with route names to node IDs
79    Conditional(HashMap<String, String>),
80    /// Multiple parallel targets
81    Parallel(Vec<String>),
82}
83
84impl EdgeTarget {
85    /// Create a single target edge
86    pub fn single(target: impl Into<String>) -> Self {
87        Self::Single(target.into())
88    }
89
90    /// Create conditional edges
91    pub fn conditional(routes: HashMap<String, String>) -> Self {
92        Self::Conditional(routes)
93    }
94
95    /// Create parallel edges
96    pub fn parallel(targets: Vec<String>) -> Self {
97        Self::Parallel(targets)
98    }
99
100    /// Check if this is a conditional edge
101    pub fn is_conditional(&self) -> bool {
102        matches!(self, Self::Conditional(_))
103    }
104
105    /// Get all target node IDs
106    pub fn targets(&self) -> Vec<&str> {
107        match self {
108            Self::Single(t) => vec![t],
109            Self::Conditional(routes) => routes.values().map(|s| s.as_str()).collect(),
110            Self::Parallel(targets) => targets.iter().map(|s| s.as_str()).collect(),
111        }
112    }
113}
114
115/// State graph builder trait
116///
117/// Defines the interface for building stateful workflow graphs.
118/// Implementations should provide a fluent API for constructing graphs.
119///
120/// # Example
121///
122/// ```rust,ignore
123/// use mofa_kernel::workflow::{StateGraph, START, END};
124///
125/// let graph = StateGraphImpl::<MyState>::new("my_workflow")
126///     // Add reducers for state keys
127///     .add_reducer("messages", Box::new(AppendReducer))
128///     .add_reducer("result", Box::new(OverwriteReducer))
129///     // Add nodes
130///     .add_node("process", Box::new(ProcessNode))
131///     .add_node("validate", Box::new(ValidateNode))
132///     // Add edges
133///     .add_edge(START, "process")
134///     .add_edge("process", "validate")
135///     .add_edge("validate", END)
136///     // Compile
137///     .compile()?;
138/// ```
139#[async_trait]
140pub trait StateGraph: Send + Sync {
141    /// The state type for this graph
142    type State: GraphState;
143
144    /// The compiled graph type produced by this builder
145    type Compiled: CompiledGraph<Self::State>;
146
147    /// Create a new graph with the given ID
148    fn new(id: impl Into<String>) -> Self;
149
150    /// Add a node to the graph
151    ///
152    /// # Arguments
153    /// * `id` - Unique node identifier
154    /// * `node` - Node function implementation
155    fn add_node(
156        &mut self,
157        id: impl Into<String>,
158        node: Box<dyn NodeFunc<Self::State>>,
159    ) -> &mut Self;
160
161    /// Add an edge between two nodes
162    ///
163    /// # Arguments
164    /// * `from` - Source node ID (use START for entry edge)
165    /// * `to` - Target node ID (use END for exit edge)
166    fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self;
167
168    /// Add conditional edges from a node
169    ///
170    /// # Arguments
171    /// * `from` - Source node ID
172    /// * `conditions` - Map of condition names to target node IDs
173    ///
174    /// # Example
175    /// ```rust,ignore
176    /// graph.add_conditional_edges("classify", HashMap::from([
177    ///     ("type_a".to_string(), "handle_a".to_string()),
178    ///     ("type_b".to_string(), "handle_b".to_string()),
179    /// ]));
180    /// ```
181    fn add_conditional_edges(
182        &mut self,
183        from: impl Into<String>,
184        conditions: HashMap<String, String>,
185    ) -> &mut Self;
186
187    /// Add parallel edges from a node
188    ///
189    /// # Arguments
190    /// * `from` - Source node ID
191    /// * `targets` - List of target node IDs to execute in parallel
192    fn add_parallel_edges(&mut self, from: impl Into<String>, targets: Vec<String>) -> &mut Self;
193
194    /// Set the entry point (equivalent to add_edge(START, node))
195    fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self;
196
197    /// Set a finish point (equivalent to add_edge(node, END))
198    fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self;
199
200    /// Add a reducer for a state key
201    ///
202    /// # Arguments
203    /// * `key` - State key name
204    /// * `reducer` - Reducer implementation
205    fn add_reducer(&mut self, key: impl Into<String>, reducer: Box<dyn Reducer>) -> &mut Self;
206
207    /// Set the graph configuration
208    fn with_config(&mut self, config: GraphConfig) -> &mut Self;
209
210    /// Get the graph ID
211    fn id(&self) -> &str;
212
213    /// Compile the graph into an executable form
214    ///
215    /// This validates the graph structure and prepares it for execution.
216    fn compile(self) -> AgentResult<Self::Compiled>;
217}
218
219/// Compiled graph trait for execution
220///
221/// A compiled graph can be invoked with an initial state and
222/// returns the final state after execution.
223#[async_trait]
224pub trait CompiledGraph<S: GraphState>: Send + Sync {
225    /// Get the graph ID
226    fn id(&self) -> &str;
227
228    /// Execute the graph synchronously
229    ///
230    /// # Arguments
231    /// * `input` - Initial state
232    /// * `config` - Optional runtime configuration (uses defaults if None)
233    ///
234    /// # Returns
235    /// The final state after graph execution completes
236    async fn invoke(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<S>;
237
238    /// Execute the graph with streaming output
239    ///
240    /// Returns a stream of (node_id, state) pairs as each node completes.
241    async fn stream(
242        &self,
243        input: S,
244        config: Option<RuntimeContext>,
245    ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamEvent<S>>> + Send>>>;
246
247    /// Execute a single step of the graph
248    ///
249    /// Useful for debugging or interactive execution.
250    async fn step(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<StepResult<S>>;
251
252    /// Validate that a state is valid for this graph
253    fn validate_state(&self, state: &S) -> AgentResult<()>;
254
255    /// Get the graph's state schema
256    fn state_schema(&self) -> HashMap<String, String>;
257}
258
259/// Stream event from graph execution
260#[derive(Debug, Clone)]
261pub enum StreamEvent<S: GraphState> {
262    /// A node started executing
263    NodeStart { node_id: String, state: S },
264    /// A node finished executing
265    NodeEnd {
266        node_id: String,
267        state: S,
268        command: Command,
269    },
270    /// Graph execution completed
271    End { final_state: S },
272    /// Error occurred
273    Error {
274        node_id: Option<String>,
275        error: String,
276    },
277}
278
279/// Result of a single step execution
280#[derive(Debug, Clone)]
281pub struct StepResult<S: GraphState> {
282    /// Current state after the step
283    pub state: S,
284    /// Which node was executed
285    pub node_id: String,
286    /// Command returned by the node
287    pub command: Command,
288    /// Whether execution is complete
289    pub is_complete: bool,
290    /// Next node to execute (if any)
291    pub next_node: Option<String>,
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_edge_target_single() {
        let edge = EdgeTarget::single("node_a");

        assert!(!edge.is_conditional());
        assert_eq!(edge.targets(), ["node_a"]);
    }

    #[test]
    fn test_edge_target_conditional() {
        let routes = HashMap::from([
            ("condition_a".to_string(), "node_a".to_string()),
            ("condition_b".to_string(), "node_b".to_string()),
        ]);
        let edge = EdgeTarget::conditional(routes);

        assert!(edge.is_conditional());

        // HashMap iteration order is unspecified, so compare a sorted copy.
        let mut targets = edge.targets();
        targets.sort_unstable();
        assert_eq!(targets, ["node_a", "node_b"]);
    }

    #[test]
    fn test_edge_target_parallel() {
        let names = ["a", "b", "c"];
        let edge = EdgeTarget::parallel(names.iter().map(ToString::to_string).collect());

        assert!(!edge.is_conditional());
        assert_eq!(edge.targets(), names);
    }

    #[test]
    fn test_constants() {
        assert_eq!(START, "__START__");
        assert_eq!(END, "__END__");
    }
}