1use async_trait::async_trait;
7use futures::Stream;
8use mofa_kernel::agent::error::{AgentError, AgentResult};
9use mofa_kernel::workflow::{
10 CompiledGraph, Command, ControlFlow, EdgeTarget, GraphConfig, GraphState,
11 NodeFunc, Reducer, RuntimeContext, StateUpdate, StreamEvent, StepResult, END, START,
12};
13use serde_json::Value;
14use std::collections::{HashMap, HashSet};
15use std::pin::Pin;
16use std::sync::Arc;
17use tracing::{debug, info, warn};
18
/// Identifier of a node inside a state graph (an owned string).
pub type NodeId = std::string::String;
21
22pub struct StateGraphImpl<S: GraphState> {
38 id: String,
40 nodes: HashMap<NodeId, Box<dyn NodeFunc<S>>>,
42 edges: HashMap<NodeId, EdgeTarget>,
44 reducers: HashMap<String, Box<dyn Reducer>>,
46 entry_point: Option<NodeId>,
48 finish_points: Vec<NodeId>,
50 config: GraphConfig,
52}
53
54impl<S: GraphState> StateGraphImpl<S> {
55 pub fn build(id: impl Into<String>) -> Self {
57 Self {
58 id: id.into(),
59 nodes: HashMap::new(),
60 edges: HashMap::new(),
61 reducers: HashMap::new(),
62 entry_point: None,
63 finish_points: Vec::new(),
64 config: GraphConfig::default(),
65 }
66 }
67
68 pub fn node_count(&self) -> usize {
70 self.nodes.len()
71 }
72
73 pub fn edge_count(&self) -> usize {
75 self.edges.len()
76 }
77
78 pub fn node_ids(&self) -> Vec<&str> {
80 self.nodes.keys().map(|s| s.as_str()).collect()
81 }
82
83 pub fn validate(&self) -> AgentResult<()> {
85 let mut errors = Vec::new();
86
87 if self.entry_point.is_none() {
89 errors.push("No entry point set. Use set_entry_point() or add_edge(START, node).".to_string());
90 }
91
92 if let Some(entry) = &self.entry_point {
94 let reachable = self.find_reachable_nodes(entry);
95 for node_id in self.nodes.keys() {
96 if !reachable.contains(node_id) && node_id != entry {
97 errors.push(format!("Node '{}' is not reachable from entry point", node_id));
98 }
99 }
100 }
101
102 for (from, target) in &self.edges {
104 if from != START && !self.nodes.contains_key(from) {
105 errors.push(format!("Edge source '{}' does not exist", from));
106 }
107 let targets = target.targets();
108 for target_id in targets {
109 if target_id != END && !self.nodes.contains_key(target_id) {
110 errors.push(format!("Edge target '{}' does not exist", target_id));
111 }
112 }
113 }
114
115 if errors.is_empty() {
116 Ok(())
117 } else {
118 Err(AgentError::ValidationFailed(errors.join("; ")))
119 }
120 }
121
122 fn find_reachable_nodes(&self, start: &str) -> HashSet<String> {
124 let mut reachable = HashSet::new();
125 let mut stack = vec![start.to_string()];
126
127 while let Some(node_id) = stack.pop() {
128 if reachable.insert(node_id.clone()) {
129 if let Some(edge_target) = self.edges.get(&node_id) {
130 let targets = edge_target.targets();
131 for target in targets {
132 if target != END && !reachable.contains(target) {
133 stack.push(target.to_string());
134 }
135 }
136 }
137 }
138 }
139
140 reachable
141 }
142}
143
144#[async_trait]
145impl<S: GraphState + 'static> mofa_kernel::workflow::StateGraph for StateGraphImpl<S> {
146 type State = S;
147 type Compiled = CompiledGraphImpl<S>;
148
149 fn new(id: impl Into<String>) -> Self {
150 Self::build(id)
151 }
152
153 fn add_node(&mut self, id: impl Into<String>, node: Box<dyn NodeFunc<S>>) -> &mut Self {
154 let node_id = id.into();
155 debug!("Adding node '{}' to graph '{}'", node_id, self.id);
156 self.nodes.insert(node_id, node);
157 self
158 }
159
160 fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self {
161 let from_id = from.into();
162 let to_id = to.into();
163
164 debug!("Adding edge: {} -> {}", from_id, to_id);
165
166 if from_id == START {
168 self.entry_point = Some(to_id.clone());
169 return self;
170 }
171
172 if to_id == END {
174 if !self.finish_points.contains(&from_id) {
175 self.finish_points.push(from_id.clone());
176 }
177 return self;
178 }
179
180 match self.edges.get_mut(&from_id) {
182 Some(EdgeTarget::Parallel(targets)) => {
183 targets.push(to_id);
184 }
185 Some(EdgeTarget::Single(existing)) => {
186 let existing = existing.clone();
187 self.edges.insert(from_id, EdgeTarget::parallel(vec![existing, to_id]));
188 }
189 Some(EdgeTarget::Conditional(_)) => {
190 warn!("Overwriting conditional edges with single edge for '{}'", from_id);
191 self.edges.insert(from_id, EdgeTarget::single(to_id));
192 }
193 None => {
194 self.edges.insert(from_id, EdgeTarget::single(to_id));
195 }
196 }
197
198 self
199 }
200
201 fn add_conditional_edges(
202 &mut self,
203 from: impl Into<String>,
204 conditions: HashMap<String, String>,
205 ) -> &mut Self {
206 let from_id = from.into();
207 debug!("Adding conditional edges from '{}': {:?}", from_id, conditions);
208 self.edges.insert(from_id, EdgeTarget::conditional(conditions));
209 self
210 }
211
212 fn add_parallel_edges(&mut self, from: impl Into<String>, targets: Vec<String>) -> &mut Self {
213 let from_id = from.into();
214 debug!("Adding parallel edges from '{}': {:?}", from_id, targets);
215 self.edges.insert(from_id, EdgeTarget::parallel(targets));
216 self
217 }
218
219 fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self {
220 let node_id = node.into();
221 debug!("Setting entry point to '{}'", node_id);
222 self.entry_point = Some(node_id);
223 self
224 }
225
226 fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self {
227 let node_id = node.into();
228 debug!("Setting finish point at '{}'", node_id);
229 if !self.finish_points.contains(&node_id) {
230 self.finish_points.push(node_id);
231 }
232 self
233 }
234
235 fn add_reducer(&mut self, key: impl Into<String>, reducer: Box<dyn Reducer>) -> &mut Self {
236 let key_str = key.into();
237 debug!("Adding reducer for key '{}' of type {:?}", key_str, reducer.reducer_type());
238 self.reducers.insert(key_str, reducer);
239 self
240 }
241
242 fn with_config(&mut self, config: GraphConfig) -> &mut Self {
243 self.config = config;
244 self
245 }
246
247 fn id(&self) -> &str {
248 &self.id
249 }
250
251 fn compile(self) -> AgentResult<CompiledGraphImpl<S>> {
252 info!("Compiling graph '{}'", self.id);
253
254 self.validate()?;
256
257 Ok(CompiledGraphImpl {
259 id: self.id,
260 nodes: Arc::new(self.nodes),
261 edges: Arc::new(self.edges),
262 reducers: Arc::new(self.reducers),
263 entry_point: self.entry_point.expect("Entry point should be validated"),
264 config: self.config,
265 })
266 }
267}
268
269pub struct CompiledGraphImpl<S: GraphState> {
271 id: String,
273 nodes: Arc<HashMap<NodeId, Box<dyn NodeFunc<S>>>>,
275 edges: Arc<HashMap<NodeId, EdgeTarget>>,
277 reducers: Arc<HashMap<String, Box<dyn Reducer>>>,
279 entry_point: NodeId,
281 config: GraphConfig,
283}
284
285impl<S: GraphState> CompiledGraphImpl<S> {
286 fn get_next_nodes(&self, current_node: &str, command: &Command) -> Vec<String> {
288 match &command.control {
289 ControlFlow::Goto(target) => {
290 vec![target.clone()]
291 }
292 ControlFlow::Return => {
293 vec![] }
295 ControlFlow::Send(sends) => {
296 sends.iter().map(|s| s.target.clone()).collect()
298 }
299 ControlFlow::Continue => {
300 match self.edges.get(current_node) {
302 Some(EdgeTarget::Single(target)) => vec![target.clone()],
303 Some(EdgeTarget::Parallel(targets)) => targets.clone(),
304 Some(EdgeTarget::Conditional(routes)) => {
305 for update in &command.updates {
307 if let Some(target) = routes.get(&update.key) {
308 return vec![target.clone()];
309 }
310 }
311 routes.values().next()
313 .map(|t: &String| vec![t.clone()])
314 .unwrap_or_default()
315 }
316 None => vec![],
317 }
318 }
319 }
320 }
321
322 async fn apply_updates(&self, state: &mut S, updates: &[StateUpdate]) -> AgentResult<()> {
324 for update in updates {
325 let current = state.get_value(&update.key);
326
327 let new_value = if let Some(reducer) = self.reducers.get(&update.key) {
329 reducer.reduce(current.as_ref(), &update.value).await?
330 } else {
331 update.value.clone()
333 };
334
335 state.apply_update(&update.key, new_value).await?;
336 }
337 Ok(())
338 }
339}
340
341#[async_trait]
342impl<S: GraphState + 'static> CompiledGraph<S> for CompiledGraphImpl<S> {
343 fn id(&self) -> &str {
344 &self.id
345 }
346
347 async fn invoke(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<S> {
348 let ctx = config.unwrap_or_else(|| {
349 RuntimeContext::with_config(&self.id, self.config.clone())
350 });
351
352 info!("Starting graph execution '{}' with execution_id={}", self.id, ctx.execution_id);
353
354 let mut state = input;
355 let mut current_nodes = vec![self.entry_point.clone()];
356
357 while !current_nodes.is_empty() {
358 if ctx.is_recursion_limit_reached().await {
360 return Err(AgentError::Internal(
361 "Recursion limit reached".to_string()
362 ));
363 }
364 ctx.decrement_steps().await;
365
366 if current_nodes.len() == 1 {
368 let node_id = current_nodes.remove(0);
370 let node = self.nodes.get(&node_id)
371 .ok_or_else(|| AgentError::NotFound(format!("Node '{}'", node_id)))?;
372
373 ctx.set_current_node(&node_id).await;
374 debug!("Executing node '{}' in graph '{}'", node_id, self.id);
375
376 let command = node.call(&mut state, &ctx).await?;
377
378 self.apply_updates(&mut state, &command.updates).await?;
380
381 current_nodes = self.get_next_nodes(&node_id, &command);
383
384 debug!("Node '{}' completed, next nodes: {:?}", node_id, current_nodes);
385 } else {
386 let mut next_nodes = Vec::new();
388 let nodes_to_execute = std::mem::take(&mut current_nodes);
389
390 for node_id in nodes_to_execute {
391 let node = self.nodes.get(&node_id)
392 .ok_or_else(|| AgentError::NotFound(format!("Node '{}'", node_id)))?;
393
394 ctx.set_current_node(&node_id).await;
395 debug!("Executing node '{}' (parallel)", node_id);
396
397 let command = node.call(&mut state, &ctx).await?;
398
399 self.apply_updates(&mut state, &command.updates).await?;
401
402 let next = self.get_next_nodes(&node_id, &command);
404 next_nodes.extend(next);
405 }
406
407 let next_set: HashSet<String> = next_nodes.into_iter().collect();
409 current_nodes = next_set.into_iter().collect();
410 }
411 }
412
413 info!("Graph '{}' execution completed", self.id);
414 Ok(state)
415 }
416
417 async fn stream(
418 &self,
419 input: S,
420 config: Option<RuntimeContext>,
421 ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamEvent<S>>> + Send>>> {
422 let ctx = config.unwrap_or_else(|| {
423 RuntimeContext::with_config(&self.id, self.config.clone())
424 });
425
426 let nodes = self.nodes.clone();
427 let reducers = self.reducers.clone();
428 let entry_point = self.entry_point.clone();
429
430 let (tx, rx) = tokio::sync::mpsc::channel(100);
432
433 tokio::spawn(async move {
435 let mut state = input;
436 let mut current_nodes = vec![entry_point];
437
438 while !current_nodes.is_empty() {
439 if ctx.remaining_steps.is_exhausted().await {
441 let _ = tx.send(Err(AgentError::Internal(
442 "Recursion limit reached".to_string()
443 ))).await;
444 return;
445 }
446 ctx.remaining_steps.decrement().await;
447
448 let nodes_to_execute = std::mem::take(&mut current_nodes);
449
450 for node_id in nodes_to_execute {
451 let node = match nodes.get(&node_id) {
452 Some(n) => n,
453 None => {
454 let _ = tx.send(Err(AgentError::NotFound(format!("Node '{}'", node_id)))).await;
455 return;
456 }
457 };
458
459 ctx.set_current_node(&node_id).await;
460
461 let _ = tx.send(Ok(StreamEvent::NodeStart {
463 node_id: node_id.clone(),
464 state: state.clone(),
465 })).await;
466
467 let command = match node.call(&mut state, &ctx).await {
469 Ok(cmd) => cmd,
470 Err(e) => {
471 let _ = tx.send(Ok(StreamEvent::Error {
472 node_id: Some(node_id),
473 error: e.to_string(),
474 })).await;
475 return;
476 }
477 };
478
479 for update in &command.updates {
481 let current = state.get_value(&update.key);
482 let new_value = if let Some(reducer) = reducers.get(&update.key) {
483 match reducer.reduce(current.as_ref(), &update.value).await {
484 Ok(v) => v,
485 Err(e) => {
486 let _ = tx.send(Ok(StreamEvent::Error {
487 node_id: Some(node_id.clone()),
488 error: e.to_string(),
489 })).await;
490 return;
491 }
492 }
493 } else {
494 update.value.clone()
495 };
496 if let Err(e) = state.apply_update(&update.key, new_value).await {
497 let _ = tx.send(Ok(StreamEvent::Error {
498 node_id: Some(node_id.clone()),
499 error: e.to_string(),
500 })).await;
501 return;
502 }
503 }
504
505 let _ = tx.send(Ok(StreamEvent::NodeEnd {
507 node_id: node_id.clone(),
508 state: state.clone(),
509 command: command.clone(),
510 })).await;
511 }
512
513 break;
516 }
517
518 let _ = tx.send(Ok(StreamEvent::End {
520 final_state: state,
521 })).await;
522 });
523
524 Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
526 }
527
528 async fn step(&self, input: S, config: Option<RuntimeContext>) -> AgentResult<StepResult<S>> {
529 let ctx = config.unwrap_or_else(|| {
530 RuntimeContext::with_config(&self.id, self.config.clone())
531 });
532
533 let mut state = input;
534
535 let current_node_id = ctx.current_node().await;
537 let node_id = if current_node_id.is_empty() {
538 self.entry_point.clone()
539 } else {
540 current_node_id
541 };
542
543 let node = self.nodes.get(&node_id)
544 .ok_or_else(|| AgentError::NotFound(format!("Node '{}'", node_id)))?;
545
546 ctx.set_current_node(&node_id).await;
547 let command = node.call(&mut state, &ctx).await?;
548
549 self.apply_updates(&mut state, &command.updates).await?;
551
552 let next_nodes = self.get_next_nodes(&node_id, &command);
554 let is_complete = next_nodes.is_empty();
555 let next_node = next_nodes.into_iter().next();
556
557 Ok(StepResult {
558 state,
559 node_id,
560 command,
561 is_complete,
562 next_node,
563 })
564 }
565
566 fn validate_state(&self, _state: &S) -> AgentResult<()> {
567 Ok(())
569 }
570
571 fn state_schema(&self) -> HashMap<String, String> {
572 self.reducers.iter()
573 .map(|(k, r)| (k.clone(), r.reducer_type().to_string()))
574 .collect()
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use mofa_kernel::workflow::{JsonState, StateGraph};
    use serde_json::json;

    /// Node stub that emits a fixed list of state updates and then continues along its edges.
    struct TestNode {
        name: String,
        updates: Vec<StateUpdate>,
    }

    impl TestNode {
        /// Build a boxed stub ready to hand to `add_node`.
        fn boxed(name: &str, updates: Vec<StateUpdate>) -> Box<dyn NodeFunc<JsonState>> {
            Box::new(Self {
                name: name.to_string(),
                updates,
            })
        }
    }

    #[async_trait]
    impl NodeFunc<JsonState> for TestNode {
        async fn call(&self, _state: &mut JsonState, _ctx: &RuntimeContext) -> AgentResult<Command> {
            let command = self
                .updates
                .iter()
                .fold(Command::new(), |acc, u| acc.update(u.key.clone(), u.value.clone()));
            Ok(command.continue_())
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[tokio::test]
    async fn test_state_graph_build_and_compile() {
        let mut graph = StateGraphImpl::<JsonState>::new("test_graph");

        let first = TestNode::boxed("start", vec![StateUpdate::new("initialized", json!(true))]);
        let last = TestNode::boxed("end", vec![StateUpdate::new("completed", json!(true))]);

        graph
            .add_node("start_node", first)
            .add_node("end_node", last)
            .add_edge(START, "start_node")
            .add_edge("start_node", "end_node")
            .add_edge("end_node", END);

        assert!(graph.compile().is_ok());
    }

    #[tokio::test]
    async fn test_state_graph_no_entry_point() {
        let mut graph = StateGraphImpl::<JsonState>::new("test_graph");
        graph.add_node("node1", TestNode::boxed("node1", Vec::new()));

        assert!(graph.compile().is_err());
    }

    #[tokio::test]
    async fn test_compiled_graph_invoke() {
        let mut graph = StateGraphImpl::<JsonState>::new("test_graph");

        let updates = vec![
            StateUpdate::new("processed", json!(true)),
            StateUpdate::new("count", json!(1)),
        ];
        graph
            .add_node("process", TestNode::boxed("process", updates))
            .add_edge(START, "process")
            .add_edge("process", END);

        let compiled = graph.compile().unwrap();

        let outcome = compiled.invoke(JsonState::new(), None).await;
        assert!(outcome.is_ok());

        let final_state = outcome.unwrap();
        assert_eq!(final_state.get_value("processed"), Some(json!(true)));
        assert_eq!(final_state.get_value("count"), Some(json!(1)));
    }
}