Skip to main content

mofa_foundation/workflow/
state_graph.rs

1//! StateGraph Implementation
2//!
3//! This module provides a LangGraph-inspired StateGraph implementation
4//! for building and executing stateful workflow graphs.
5
6use async_trait::async_trait;
7use futures::Stream;
8use mofa_kernel::agent::error::{AgentError, AgentResult};
9use mofa_kernel::workflow::{
10    CompiledGraph, Command, ControlFlow, EdgeTarget, GraphConfig, GraphState,
11    NodeFunc, Reducer, RuntimeContext, StateUpdate, StreamEvent, StepResult, END, START,
12};
13use serde_json::Value;
14use std::collections::{HashMap, HashSet};
15use std::pin::Pin;
16use std::sync::Arc;
17use tracing::{debug, info, warn};
18
/// Identifier of a node inside a state graph; an owned `String`.
pub type NodeId = String;
21
22/// StateGraph implementation - LangGraph-inspired API
23///
24/// # Example
25///
26/// ```rust,ignore
27/// use mofa_foundation::workflow::{StateGraphImpl, AppendReducer, OverwriteReducer};
28/// use mofa_kernel::workflow::{StateGraph, START, END};
29///
30/// let graph = StateGraphImpl::<MyState>::new("my_workflow")
31///     .add_reducer("messages", Box::new(AppendReducer))
32///     .add_node("process", Box::new(ProcessNode))
33///     .add_edge(START, "process")
34///     .add_edge("process", END)
35///     .compile()?;
36/// ```
37pub struct StateGraphImpl<S: GraphState> {
38    /// Graph ID
39    id: String,
40    /// Node functions
41    nodes: HashMap<NodeId, Box<dyn NodeFunc<S>>>,
42    /// Edges: source -> target(s)
43    edges: HashMap<NodeId, EdgeTarget>,
44    /// Reducers for state keys
45    reducers: HashMap<String, Box<dyn Reducer>>,
46    /// Entry point (first node after START)
47    entry_point: Option<NodeId>,
48    /// Finish points (nodes that connect to END)
49    finish_points: Vec<NodeId>,
50    /// Graph configuration
51    config: GraphConfig,
52}
53
54impl<S: GraphState> StateGraphImpl<S> {
55    /// Create a new StateGraph builder with the given ID
56    pub fn build(id: impl Into<String>) -> Self {
57        Self {
58            id: id.into(),
59            nodes: HashMap::new(),
60            edges: HashMap::new(),
61            reducers: HashMap::new(),
62            entry_point: None,
63            finish_points: Vec::new(),
64            config: GraphConfig::default(),
65        }
66    }
67
68    /// Get the number of nodes
69    pub fn node_count(&self) -> usize {
70        self.nodes.len()
71    }
72
73    /// Get the number of edges
74    pub fn edge_count(&self) -> usize {
75        self.edges.len()
76    }
77
78    /// Get all node IDs
79    pub fn node_ids(&self) -> Vec<&str> {
80        self.nodes.keys().map(|s| s.as_str()).collect()
81    }
82
83    /// Validate the graph structure
84    pub fn validate(&self) -> AgentResult<()> {
85        let mut errors = Vec::new();
86
87        // Check entry point
88        if self.entry_point.is_none() {
89            errors.push("No entry point set. Use set_entry_point() or add_edge(START, node).".to_string());
90        }
91
92        // Check that all nodes are reachable
93        if let Some(entry) = &self.entry_point {
94            let reachable = self.find_reachable_nodes(entry);
95            for node_id in self.nodes.keys() {
96                if !reachable.contains(node_id) && node_id != entry {
97                    errors.push(format!("Node '{}' is not reachable from entry point", node_id));
98                }
99            }
100        }
101
102        // Check that edges reference valid nodes
103        for (from, target) in &self.edges {
104            if from != START && !self.nodes.contains_key(from) {
105                errors.push(format!("Edge source '{}' does not exist", from));
106            }
107            let targets = target.targets();
108            for target_id in targets {
109                if target_id != END && !self.nodes.contains_key(target_id) {
110                    errors.push(format!("Edge target '{}' does not exist", target_id));
111                }
112            }
113        }
114
115        if errors.is_empty() {
116            Ok(())
117        } else {
118            Err(AgentError::ValidationFailed(errors.join("; ")))
119        }
120    }
121
122    /// Find all nodes reachable from a starting node
123    fn find_reachable_nodes(&self, start: &str) -> HashSet<String> {
124        let mut reachable = HashSet::new();
125        let mut stack = vec![start.to_string()];
126
127        while let Some(node_id) = stack.pop() {
128            if reachable.insert(node_id.clone()) {
129                if let Some(edge_target) = self.edges.get(&node_id) {
130                    let targets = edge_target.targets();
131                    for target in targets {
132                        if target != END && !reachable.contains(target) {
133                            stack.push(target.to_string());
134                        }
135                    }
136                }
137            }
138        }
139
140        reachable
141    }
142}
143
144#[async_trait]
145impl<S: GraphState + 'static> mofa_kernel::workflow::StateGraph for StateGraphImpl<S> {
146    type State = S;
147    type Compiled = CompiledGraphImpl<S>;
148
149    fn new(id: impl Into<String>) -> Self {
150        Self::build(id)
151    }
152
153    fn add_node(&mut self, id: impl Into<String>, node: Box<dyn NodeFunc<S>>) -> &mut Self {
154        let node_id = id.into();
155        debug!("Adding node '{}' to graph '{}'", node_id, self.id);
156        self.nodes.insert(node_id, node);
157        self
158    }
159
160    fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self {
161        let from_id = from.into();
162        let to_id = to.into();
163
164        debug!("Adding edge: {} -> {}", from_id, to_id);
165
166        // Handle START edge (entry point)
167        if from_id == START {
168            self.entry_point = Some(to_id.clone());
169            return self;
170        }
171
172        // Handle END edge (finish point)
173        if to_id == END {
174            if !self.finish_points.contains(&from_id) {
175                self.finish_points.push(from_id.clone());
176            }
177            return self;
178        }
179
180        // Regular edge
181        match self.edges.get_mut(&from_id) {
182            Some(EdgeTarget::Parallel(targets)) => {
183                targets.push(to_id);
184            }
185            Some(EdgeTarget::Single(existing)) => {
186                let existing = existing.clone();
187                self.edges.insert(from_id, EdgeTarget::parallel(vec![existing, to_id]));
188            }
189            Some(EdgeTarget::Conditional(_)) => {
190                warn!("Overwriting conditional edges with single edge for '{}'", from_id);
191                self.edges.insert(from_id, EdgeTarget::single(to_id));
192            }
193            None => {
194                self.edges.insert(from_id, EdgeTarget::single(to_id));
195            }
196        }
197
198        self
199    }
200
201    fn add_conditional_edges(
202        &mut self,
203        from: impl Into<String>,
204        conditions: HashMap<String, String>,
205    ) -> &mut Self {
206        let from_id = from.into();
207        debug!("Adding conditional edges from '{}': {:?}", from_id, conditions);
208        self.edges.insert(from_id, EdgeTarget::conditional(conditions));
209        self
210    }
211
212    fn add_parallel_edges(&mut self, from: impl Into<String>, targets: Vec<String>) -> &mut Self {
213        let from_id = from.into();
214        debug!("Adding parallel edges from '{}': {:?}", from_id, targets);
215        self.edges.insert(from_id, EdgeTarget::parallel(targets));
216        self
217    }
218
219    fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self {
220        let node_id = node.into();
221        debug!("Setting entry point to '{}'", node_id);
222        self.entry_point = Some(node_id);
223        self
224    }
225
226    fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self {
227        let node_id = node.into();
228        debug!("Setting finish point at '{}'", node_id);
229        if !self.finish_points.contains(&node_id) {
230            self.finish_points.push(node_id);
231        }
232        self
233    }
234
235    fn add_reducer(&mut self, key: impl Into<String>, reducer: Box<dyn Reducer>) -> &mut Self {
236        let key_str = key.into();
237        debug!("Adding reducer for key '{}' of type {:?}", key_str, reducer.reducer_type());
238        self.reducers.insert(key_str, reducer);
239        self
240    }
241
242    fn with_config(&mut self, config: GraphConfig) -> &mut Self {
243        self.config = config;
244        self
245    }
246
247    fn id(&self) -> &str {
248        &self.id
249    }
250
251    fn compile(self) -> AgentResult<CompiledGraphImpl<S>> {
252        info!("Compiling graph '{}'", self.id);
253
254        // Validate
255        self.validate()?;
256
257        // Create compiled graph
258        Ok(CompiledGraphImpl {
259            id: self.id,
260            nodes: Arc::new(self.nodes),
261            edges: Arc::new(self.edges),
262            reducers: Arc::new(self.reducers),
263            entry_point: self.entry_point.expect("Entry point should be validated"),
264            config: self.config,
265        })
266    }
267}
268
269/// Compiled graph ready for execution
270pub struct CompiledGraphImpl<S: GraphState> {
271    /// Graph ID
272    id: String,
273    /// Node functions
274    nodes: Arc<HashMap<NodeId, Box<dyn NodeFunc<S>>>>,
275    /// Edges
276    edges: Arc<HashMap<NodeId, EdgeTarget>>,
277    /// Reducers
278    reducers: Arc<HashMap<String, Box<dyn Reducer>>>,
279    /// Entry point
280    entry_point: NodeId,
281    /// Configuration
282    config: GraphConfig,
283}
284
285impl<S: GraphState> CompiledGraphImpl<S> {
286    /// Get the next node(s) based on the current node and command
287    fn get_next_nodes(&self, current_node: &str, command: &Command) -> Vec<String> {
288        match &command.control {
289            ControlFlow::Goto(target) => {
290                vec![target.clone()]
291            }
292            ControlFlow::Return => {
293                vec![] // End execution
294            }
295            ControlFlow::Send(sends) => {
296                // MapReduce: create branches for each send target
297                sends.iter().map(|s| s.target.clone()).collect()
298            }
299            ControlFlow::Continue => {
300                // Follow graph edges
301                match self.edges.get(current_node) {
302                    Some(EdgeTarget::Single(target)) => vec![target.clone()],
303                    Some(EdgeTarget::Parallel(targets)) => targets.clone(),
304                    Some(EdgeTarget::Conditional(routes)) => {
305                        // Find matching route based on state updates
306                        for update in &command.updates {
307                            if let Some(target) = routes.get(&update.key) {
308                                return vec![target.clone()];
309                            }
310                        }
311                        // Default to first route if no match
312                        routes.values().next()
313                            .map(|t: &String| vec![t.clone()])
314                            .unwrap_or_default()
315                    }
316                    None => vec![],
317                }
318            }
319        }
320    }
321
322    /// Apply state updates using reducers
323    async fn apply_updates(&self, state: &mut S, updates: &[StateUpdate]) -> AgentResult<()> {
324        for update in updates {
325            let current = state.get_value(&update.key);
326
327            // Get or create reducer
328            let new_value = if let Some(reducer) = self.reducers.get(&update.key) {
329                reducer.reduce(current.as_ref(), &update.value).await?
330            } else {
331                // Default: overwrite
332                update.value.clone()
333            };
334
335            state.apply_update(&update.key, new_value).await?;
336        }
337        Ok(())
338    }
339}
340
341#[async_trait]
342impl<S: GraphState + 'static> CompiledGraph<S> for CompiledGraphImpl<S> {
343    fn id(&self) -> &str {
344        &self.id
345    }
346
347    async fn invoke(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<S> {
348        let ctx = config.unwrap_or_else(|| {
349            RuntimeContext::with_config(&self.id, self.config.clone())
350        });
351
352        info!("Starting graph execution '{}' with execution_id={}", self.id, ctx.execution_id);
353
354        let mut state = input;
355        let mut current_nodes = vec![self.entry_point.clone()];
356
357        while !current_nodes.is_empty() {
358            // Check recursion limit
359            if ctx.is_recursion_limit_reached().await {
360                return Err(AgentError::Internal(
361                    "Recursion limit reached".to_string()
362                ));
363            }
364            ctx.decrement_steps().await;
365
366            // Execute nodes
367            if current_nodes.len() == 1 {
368                // Single node execution
369                let node_id = current_nodes.remove(0);
370                let node = self.nodes.get(&node_id)
371                    .ok_or_else(|| AgentError::NotFound(format!("Node '{}'", node_id)))?;
372
373                ctx.set_current_node(&node_id).await;
374                debug!("Executing node '{}' in graph '{}'", node_id, self.id);
375
376                let command = node.call(&mut state, &ctx).await?;
377
378                // Apply updates
379                self.apply_updates(&mut state, &command.updates).await?;
380
381                // Get next nodes
382                current_nodes = self.get_next_nodes(&node_id, &command);
383
384                debug!("Node '{}' completed, next nodes: {:?}", node_id, current_nodes);
385            } else {
386                // Parallel execution
387                let mut next_nodes = Vec::new();
388                let nodes_to_execute = std::mem::take(&mut current_nodes);
389
390                for node_id in nodes_to_execute {
391                    let node = self.nodes.get(&node_id)
392                        .ok_or_else(|| AgentError::NotFound(format!("Node '{}'", node_id)))?;
393
394                    ctx.set_current_node(&node_id).await;
395                    debug!("Executing node '{}' (parallel)", node_id);
396
397                    let command = node.call(&mut state, &ctx).await?;
398
399                    // Apply updates
400                    self.apply_updates(&mut state, &command.updates).await?;
401
402                    // Collect next nodes
403                    let next = self.get_next_nodes(&node_id, &command);
404                    next_nodes.extend(next);
405                }
406
407                // Deduplicate next nodes
408                let next_set: HashSet<String> = next_nodes.into_iter().collect();
409                current_nodes = next_set.into_iter().collect();
410            }
411        }
412
413        info!("Graph '{}' execution completed", self.id);
414        Ok(state)
415    }
416
417    async fn stream(
418        &self,
419        input: S,
420        config: Option<RuntimeContext>,
421    ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamEvent<S>>> + Send>>> {
422        let ctx = config.unwrap_or_else(|| {
423            RuntimeContext::with_config(&self.id, self.config.clone())
424        });
425
426        let nodes = self.nodes.clone();
427        let reducers = self.reducers.clone();
428        let entry_point = self.entry_point.clone();
429
430        // Create a channel for streaming events
431        let (tx, rx) = tokio::sync::mpsc::channel(100);
432
433        // Spawn execution task
434        tokio::spawn(async move {
435            let mut state = input;
436            let mut current_nodes = vec![entry_point];
437
438            while !current_nodes.is_empty() {
439                // Check recursion limit
440                if ctx.remaining_steps.is_exhausted().await {
441                    let _ = tx.send(Err(AgentError::Internal(
442                        "Recursion limit reached".to_string()
443                    ))).await;
444                    return;
445                }
446                ctx.remaining_steps.decrement().await;
447
448                let nodes_to_execute = std::mem::take(&mut current_nodes);
449
450                for node_id in nodes_to_execute {
451                    let node = match nodes.get(&node_id) {
452                        Some(n) => n,
453                        None => {
454                            let _ = tx.send(Err(AgentError::NotFound(format!("Node '{}'", node_id)))).await;
455                            return;
456                        }
457                    };
458
459                    ctx.set_current_node(&node_id).await;
460
461                    // Send start event
462                    let _ = tx.send(Ok(StreamEvent::NodeStart {
463                        node_id: node_id.clone(),
464                        state: state.clone(),
465                    })).await;
466
467                    // Execute node
468                    let command = match node.call(&mut state, &ctx).await {
469                        Ok(cmd) => cmd,
470                        Err(e) => {
471                            let _ = tx.send(Ok(StreamEvent::Error {
472                                node_id: Some(node_id),
473                                error: e.to_string(),
474                            })).await;
475                            return;
476                        }
477                    };
478
479                    // Apply updates
480                    for update in &command.updates {
481                        let current = state.get_value(&update.key);
482                        let new_value = if let Some(reducer) = reducers.get(&update.key) {
483                            match reducer.reduce(current.as_ref(), &update.value).await {
484                                Ok(v) => v,
485                                Err(e) => {
486                                    let _ = tx.send(Ok(StreamEvent::Error {
487                                        node_id: Some(node_id.clone()),
488                                        error: e.to_string(),
489                                    })).await;
490                                    return;
491                                }
492                            }
493                        } else {
494                            update.value.clone()
495                        };
496                        if let Err(e) = state.apply_update(&update.key, new_value).await {
497                            let _ = tx.send(Ok(StreamEvent::Error {
498                                node_id: Some(node_id.clone()),
499                                error: e.to_string(),
500                            })).await;
501                            return;
502                        }
503                    }
504
505                    // Send end event
506                    let _ = tx.send(Ok(StreamEvent::NodeEnd {
507                        node_id: node_id.clone(),
508                        state: state.clone(),
509                        command: command.clone(),
510                    })).await;
511                }
512
513                // For simplicity, break after first round
514                // TODO: Implement proper edge following
515                break;
516            }
517
518            // Send final event
519            let _ = tx.send(Ok(StreamEvent::End {
520                final_state: state,
521            })).await;
522        });
523
524        // Convert receiver to stream
525        Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
526    }
527
528    async fn step(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<StepResult<S>> {
529        let ctx = config.unwrap_or_else(|| {
530            RuntimeContext::with_config(&self.id, self.config.clone())
531        });
532
533        let mut state = input;
534
535        // Get current node from context or use entry point
536        let current_node_id = ctx.current_node().await;
537        let node_id = if current_node_id.is_empty() {
538            self.entry_point.clone()
539        } else {
540            current_node_id
541        };
542
543        let node = self.nodes.get(&node_id)
544            .ok_or_else(|| AgentError::NotFound(format!("Node '{}'", node_id)))?;
545
546        ctx.set_current_node(&node_id).await;
547        let command = node.call(&mut state, &ctx).await?;
548
549        // Apply updates
550        self.apply_updates(&mut state, &command.updates).await?;
551
552        // Get next nodes
553        let next_nodes = self.get_next_nodes(&node_id, &command);
554        let is_complete = next_nodes.is_empty();
555        let next_node = next_nodes.into_iter().next();
556
557        Ok(StepResult {
558            state,
559            node_id,
560            command,
561            is_complete,
562            next_node,
563        })
564    }
565
566    fn validate_state(&self, _state: &S) -> AgentResult<()> {
567        // Default implementation - no validation
568        Ok(())
569    }
570
571    fn state_schema(&self) -> HashMap<String, String> {
572        self.reducers.iter()
573            .map(|(k, r)| (k.clone(), r.reducer_type().to_string()))
574            .collect()
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use mofa_kernel::workflow::{JsonState, StateGraph};
    use serde_json::json;

    /// Test node that ignores the incoming state, emits a fixed list of updates and then
    /// continues along the graph's edges.
    struct TestNode {
        name: String,
        updates: Vec<StateUpdate>,
    }

    impl TestNode {
        /// Boxed constructor so a node can be handed straight to `add_node`.
        fn boxed(name: &str, updates: Vec<StateUpdate>) -> Box<Self> {
            Box::new(Self {
                name: name.to_string(),
                updates,
            })
        }
    }

    #[async_trait]
    impl NodeFunc<JsonState> for TestNode {
        async fn call(&self, _state: &mut JsonState, _ctx: &RuntimeContext) -> AgentResult<Command> {
            let command = self
                .updates
                .iter()
                .fold(Command::new(), |acc, u| acc.update(u.key.clone(), u.value.clone()));
            Ok(command.continue_())
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[tokio::test]
    async fn test_state_graph_build_and_compile() {
        let mut graph = StateGraphImpl::<JsonState>::new("test_graph");

        let first = TestNode::boxed("start", vec![StateUpdate::new("initialized", json!(true))]);
        let last = TestNode::boxed("end", vec![StateUpdate::new("completed", json!(true))]);

        graph
            .add_node("start_node", first)
            .add_node("end_node", last)
            .add_edge(START, "start_node")
            .add_edge("start_node", "end_node")
            .add_edge("end_node", END);

        assert!(graph.compile().is_ok());
    }

    #[tokio::test]
    async fn test_state_graph_no_entry_point() {
        let mut graph = StateGraphImpl::<JsonState>::new("test_graph");
        graph.add_node("node1", TestNode::boxed("node1", Vec::new()));

        assert!(graph.compile().is_err());
    }

    #[tokio::test]
    async fn test_compiled_graph_invoke() {
        let mut graph = StateGraphImpl::<JsonState>::new("test_graph");

        let process = TestNode::boxed(
            "process",
            vec![
                StateUpdate::new("processed", json!(true)),
                StateUpdate::new("count", json!(1)),
            ],
        );

        graph
            .add_node("process", process)
            .add_edge(START, "process")
            .add_edge("process", END);

        let compiled = graph.compile().unwrap();

        let final_state = compiled
            .invoke(JsonState::new(), None)
            .await
            .expect("graph execution should succeed");

        assert_eq!(final_state.get_value("processed"), Some(json!(true)));
        assert_eq!(final_state.get_value("count"), Some(json!(1)));
    }
}