Skip to main content

mofa_kernel/workflow/
graph.rs

//! State Graph Traits
//!
//! Defines the core graph interfaces for building and executing workflows.
//! Inspired by LangGraph's StateGraph API.
5
6use async_trait::async_trait;
7use futures::Stream;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::pin::Pin;
11
12use crate::agent::error::AgentResult;
13
14use super::{Command, GraphConfig, GraphState, Reducer, RuntimeContext};
15
/// Reserved node ID that stands for the entry point of a graph
pub const START: &str = "__START__";

/// Reserved node ID that stands for the exit point of a graph
pub const END: &str = "__END__";
21
22/// Node function trait
23///
24/// Implement this trait to define custom node behavior.
25/// Nodes receive the current state and runtime context,
26/// and return a Command that can update state and control flow.
27///
28/// # Example
29///
30/// ```rust,ignore
31/// use mofa_kernel::workflow::{NodeFunc, Command, RuntimeContext};
32///
33/// struct ProcessNode;
34///
35/// #[async_trait]
36/// impl NodeFunc<MyState> for ProcessNode {
37///     async fn call(&self, state: &mut MyState, ctx: &RuntimeContext) -> AgentResult<Command> {
38///         // Process the state
39///         let input = state.messages.last().cloned().unwrap_or_default();
40///
41///         // Return command with state update and control flow
42///         Ok(Command::new()
43///             .update("result", json!(format!("Processed: {}", input)))
44///             .goto("next_node"))
45///     }
46///
47///     fn name(&self) -> &str {
48///         "process"
49///     }
50/// }
51/// ```
52#[async_trait]
53pub trait NodeFunc<S: GraphState>: Send + Sync {
54    /// Execute the node
55    ///
56    /// # Arguments
57    /// * `state` - Mutable reference to the current state
58    /// * `ctx` - Runtime context with execution metadata
59    ///
60    /// # Returns
61    /// A Command containing state updates and control flow directive
62    async fn call(&self, state: &mut S, ctx: &RuntimeContext) -> AgentResult<Command>;
63
64    /// Returns the node name/identifier
65    fn name(&self) -> &str;
66
67    /// Optional description of what this node does
68    fn description(&self) -> Option<&str> {
69        None
70    }
71}
72
73/// Edge target definition
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub enum EdgeTarget {
76    /// Single target node
77    Single(String),
78    /// Conditional edges with route names to node IDs
79    Conditional(HashMap<String, String>),
80    /// Multiple parallel targets
81    Parallel(Vec<String>),
82}
83
84impl EdgeTarget {
85    /// Create a single target edge
86    pub fn single(target: impl Into<String>) -> Self {
87        Self::Single(target.into())
88    }
89
90    /// Create conditional edges
91    pub fn conditional(routes: HashMap<String, String>) -> Self {
92        Self::Conditional(routes)
93    }
94
95    /// Create parallel edges
96    pub fn parallel(targets: Vec<String>) -> Self {
97        Self::Parallel(targets)
98    }
99
100    /// Check if this is a conditional edge
101    pub fn is_conditional(&self) -> bool {
102        matches!(self, Self::Conditional(_))
103    }
104
105    /// Get all target node IDs
106    pub fn targets(&self) -> Vec<&str> {
107        match self {
108            Self::Single(t) => vec![t],
109            Self::Conditional(routes) => routes.values().map(|s| s.as_str()).collect(),
110            Self::Parallel(targets) => targets.iter().map(|s| s.as_str()).collect(),
111        }
112    }
113}
114
115/// State graph builder trait
116///
117/// Defines the interface for building stateful workflow graphs.
118/// Implementations should provide a fluent API for constructing graphs.
119///
120/// # Example
121///
122/// ```rust,ignore
123/// use mofa_kernel::workflow::{StateGraph, START, END};
124///
125/// let graph = StateGraphImpl::<MyState>::new("my_workflow")
126///     // Add reducers for state keys
127///     .add_reducer("messages", Box::new(AppendReducer))
128///     .add_reducer("result", Box::new(OverwriteReducer))
129///     // Add nodes
130///     .add_node("process", Box::new(ProcessNode))
131///     .add_node("validate", Box::new(ValidateNode))
132///     // Add edges
133///     .add_edge(START, "process")
134///     .add_edge("process", "validate")
135///     .add_edge("validate", END)
136///     // Compile
137///     .compile()?;
138/// ```
139#[async_trait]
140pub trait StateGraph: Send + Sync {
141    /// The state type for this graph
142    type State: GraphState;
143
144    /// The compiled graph type produced by this builder
145    type Compiled: CompiledGraph<Self::State>;
146
147    /// Create a new graph with the given ID
148    fn new(id: impl Into<String>) -> Self;
149
150    /// Add a node to the graph
151    ///
152    /// # Arguments
153    /// * `id` - Unique node identifier
154    /// * `node` - Node function implementation
155    fn add_node(&mut self, id: impl Into<String>, node: Box<dyn NodeFunc<Self::State>>) -> &mut Self;
156
157    /// Add an edge between two nodes
158    ///
159    /// # Arguments
160    /// * `from` - Source node ID (use START for entry edge)
161    /// * `to` - Target node ID (use END for exit edge)
162    fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self;
163
164    /// Add conditional edges from a node
165    ///
166    /// # Arguments
167    /// * `from` - Source node ID
168    /// * `conditions` - Map of condition names to target node IDs
169    ///
170    /// # Example
171    /// ```rust,ignore
172    /// graph.add_conditional_edges("classify", HashMap::from([
173    ///     ("type_a".to_string(), "handle_a".to_string()),
174    ///     ("type_b".to_string(), "handle_b".to_string()),
175    /// ]));
176    /// ```
177    fn add_conditional_edges(
178        &mut self,
179        from: impl Into<String>,
180        conditions: HashMap<String, String>,
181    ) -> &mut Self;
182
183    /// Add parallel edges from a node
184    ///
185    /// # Arguments
186    /// * `from` - Source node ID
187    /// * `targets` - List of target node IDs to execute in parallel
188    fn add_parallel_edges(&mut self, from: impl Into<String>, targets: Vec<String>) -> &mut Self;
189
190    /// Set the entry point (equivalent to add_edge(START, node))
191    fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self;
192
193    /// Set a finish point (equivalent to add_edge(node, END))
194    fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self;
195
196    /// Add a reducer for a state key
197    ///
198    /// # Arguments
199    /// * `key` - State key name
200    /// * `reducer` - Reducer implementation
201    fn add_reducer(&mut self, key: impl Into<String>, reducer: Box<dyn Reducer>) -> &mut Self;
202
203    /// Set the graph configuration
204    fn with_config(&mut self, config: GraphConfig) -> &mut Self;
205
206    /// Get the graph ID
207    fn id(&self) -> &str;
208
209    /// Compile the graph into an executable form
210    ///
211    /// This validates the graph structure and prepares it for execution.
212    fn compile(self) -> AgentResult<Self::Compiled>;
213}
214
215/// Compiled graph trait for execution
216///
217/// A compiled graph can be invoked with an initial state and
218/// returns the final state after execution.
219#[async_trait]
220pub trait CompiledGraph<S: GraphState>: Send + Sync {
221    /// Get the graph ID
222    fn id(&self) -> &str;
223
224    /// Execute the graph synchronously
225    ///
226    /// # Arguments
227    /// * `input` - Initial state
228    /// * `config` - Optional runtime configuration (uses defaults if None)
229    ///
230    /// # Returns
231    /// The final state after graph execution completes
232    async fn invoke(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<S>;
233
234    /// Execute the graph with streaming output
235    ///
236    /// Returns a stream of (node_id, state) pairs as each node completes.
237    async fn stream(
238        &self,
239        input: S,
240        config: Option<RuntimeContext>,
241    ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamEvent<S>>> + Send>>>;
242
243    /// Execute a single step of the graph
244    ///
245    /// Useful for debugging or interactive execution.
246    async fn step(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<StepResult<S>>;
247
248    /// Validate that a state is valid for this graph
249    fn validate_state(&self, state: &S) -> AgentResult<()>;
250
251    /// Get the graph's state schema
252    fn state_schema(&self) -> HashMap<String, String>;
253}
254
255/// Stream event from graph execution
256#[derive(Debug, Clone)]
257pub enum StreamEvent<S: GraphState> {
258    /// A node started executing
259    NodeStart {
260        node_id: String,
261        state: S,
262    },
263    /// A node finished executing
264    NodeEnd {
265        node_id: String,
266        state: S,
267        command: Command,
268    },
269    /// Graph execution completed
270    End {
271        final_state: S,
272    },
273    /// Error occurred
274    Error {
275        node_id: Option<String>,
276        error: String,
277    },
278}
279
280/// Result of a single step execution
281#[derive(Debug, Clone)]
282pub struct StepResult<S: GraphState> {
283    /// Current state after the step
284    pub state: S,
285    /// Which node was executed
286    pub node_id: String,
287    /// Command returned by the node
288    pub command: Command,
289    /// Whether execution is complete
290    pub is_complete: bool,
291    /// Next node to execute (if any)
292    pub next_node: Option<String>,
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// A single edge exposes exactly its one target and is not conditional.
    #[test]
    fn test_edge_target_single() {
        let edge = EdgeTarget::single("node_a");
        assert_eq!(edge.targets(), vec!["node_a"]);
        assert!(!edge.is_conditional());
    }

    /// A conditional edge reports every routed node, in no particular order.
    #[test]
    fn test_edge_target_conditional() {
        let routes = HashMap::from([
            ("condition_a".to_string(), "node_a".to_string()),
            ("condition_b".to_string(), "node_b".to_string()),
        ]);

        let edge = EdgeTarget::conditional(routes);
        assert!(edge.is_conditional());

        let ids = edge.targets();
        assert_eq!(ids.len(), 2);
        for expected in ["node_a", "node_b"] {
            assert!(ids.contains(&expected));
        }
    }

    /// A parallel edge keeps its targets in the order they were given.
    #[test]
    fn test_edge_target_parallel() {
        let names: Vec<String> = ["a", "b", "c"].iter().map(|n| n.to_string()).collect();
        let edge = EdgeTarget::parallel(names);
        assert_eq!(edge.targets(), vec!["a", "b", "c"]);
        assert!(!edge.is_conditional());
    }

    /// The reserved node IDs keep their documented values.
    #[test]
    fn test_constants() {
        assert_eq!(END, "__END__");
        assert_eq!(START, "__START__");
    }
}