1use std::collections::HashMap;
10use std::time::Instant;
11
12use tokio::sync::mpsc;
13
14use crate::barrier_node::BarrierDefaultAction;
15use crate::error::{GraphError, ObservedError, TerminalError};
16use crate::event::{
17 BarrierDecision, BarrierDecisionMessage, BarrierId, GraphEvent, GraphExecution, GraphHandle,
18 SpanId,
19};
20use crate::graph::{EdgeExceededStrategy, EdgePolicy, Graph};
21use crate::node::{GraphNode, NextStep, NodeKind, StreamNodeResult};
22use crate::state::{ExecutionEntry, GraphResult, State};
23
24struct DecisionRegistry {
30 pending: HashMap<BarrierId, BarrierDecision>,
31 wildcards: HashMap<String, BarrierDecision>,
32 occurrence_counter: HashMap<String, u32>,
34}
35
36impl DecisionRegistry {
37 fn new() -> Self {
38 Self {
39 pending: HashMap::new(),
40 wildcards: HashMap::new(),
41 occurrence_counter: HashMap::new(),
42 }
43 }
44
45 fn next_id(&mut self, node_id: &str) -> BarrierId {
47 let occ = self.occurrence_counter.entry(node_id.to_string()).or_insert(0);
48 *occ += 1;
49 BarrierId::new(node_id, *occ)
50 }
51
52 fn insert_exact(&mut self, barrier_id: BarrierId, decision: BarrierDecision) {
54 self.pending.insert(barrier_id, decision);
55 }
56
57 fn insert_wildcard(&mut self, node_id: String, decision: BarrierDecision) {
59 self.wildcards.insert(node_id, decision);
60 }
61
62 fn take(&mut self, target_id: &BarrierId) -> Option<BarrierDecision> {
65 if let Some(decision) = self.pending.remove(target_id) {
67 return Some(decision);
68 }
69 self.wildcards.get(&target_id.node_id).cloned()
71 }
72
73 fn process_message(
75 &mut self,
76 msg: BarrierDecisionMessage,
77 target_id: &BarrierId,
78 ) -> Option<BarrierDecision> {
79 match msg {
80 BarrierDecisionMessage::Exact {
81 barrier_id,
82 decision,
83 } => {
84 if barrier_id == *target_id {
85 Some(decision)
86 } else {
87 self.insert_exact(barrier_id, decision);
88 None
89 }
90 }
91 BarrierDecisionMessage::Wildcard { node_id, decision } => {
92 if node_id == target_id.node_id {
93 Some(decision)
94 } else {
95 self.insert_wildcard(node_id, decision);
96 None
97 }
98 }
99 }
100 }
101}
102
/// Outcome of recording one traversal of an edge against its visit policy.
#[derive(Debug)]
enum EdgeTransitionResult {
    /// The traversal is allowed.
    Ok,
    /// The visit limit was exceeded and the edge uses the strict strategy.
    PolicyExceededStrict {
        /// Edge label in `from→to` form.
        edge: String,
        /// The configured visit limit.
        limit: usize,
    },
    /// The visit limit was exceeded and the edge uses the soft-fallback strategy.
    PolicyExceededSoftFallback,
    /// The visit limit was exceeded and the edge uses the drop strategy.
    Dropped,
}
117
118#[derive(Default)]
121struct EdgeVisits(HashMap<(String, String), usize>);
122
123impl EdgeVisits {
124 fn record(
125 &mut self,
126 from: &str,
127 to: &str,
128 policy: Option<&crate::graph::EdgePolicy>,
129 ) -> EdgeTransitionResult {
130 let key = (from.to_string(), to.to_string());
131 let count = self.0.entry(key).or_insert(0);
132 *count += 1;
133
134 if let Some(EdgePolicy::MaxVisits { limit, on_exceeded }) = policy {
135 if *count > *limit {
136 return match on_exceeded {
137 EdgeExceededStrategy::Strict => EdgeTransitionResult::PolicyExceededStrict {
138 edge: format!("{from}→{to}"),
139 limit: *limit,
140 },
141 EdgeExceededStrategy::SoftFallback => {
142 EdgeTransitionResult::PolicyExceededSoftFallback
143 }
144 EdgeExceededStrategy::Drop => EdgeTransitionResult::Dropped,
145 };
146 }
147 }
148 EdgeTransitionResult::Ok
149 }
150}
151
/// Drives a graph from its start node to its end node.
#[derive(Clone, Debug)]
pub struct GraphExecutor {
    /// Upper bound on the number of node executions in a single run.
    pub max_steps: usize,
}

impl Default for GraphExecutor {
    /// Builds an executor limited to 50 steps.
    fn default() -> Self {
        const DEFAULT_MAX_STEPS: usize = 50;
        GraphExecutor {
            max_steps: DEFAULT_MAX_STEPS,
        }
    }
}
167
168impl GraphExecutor {
169 pub fn new(max_steps: usize) -> Self {
170 Self { max_steps }
171 }
172
173 pub async fn execute(
184 &self,
185 graph: std::sync::Arc<Graph>,
186 initial_state: State,
187 ) -> Result<GraphResult, GraphError> {
188 for (name, node) in &graph.nodes {
190 if matches!(node, NodeKind::Barrier(_)) {
191 return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
192 "BarrierNode '{}' requires stream mode. Use GraphExecutor::execute_stream() for human-in-the-loop.",
193 name
194 ))));
195 }
196 }
197
198 let GraphExecution { mut stream, handle } =
199 self.execute_stream(graph, initial_state);
200
201 drop(handle);
204
205 let mut result = None;
206
207 while let Some(event) = stream.recv().await {
208 match event {
209 GraphEvent::GraphComplete { result: r } => {
210 result = Some(Ok(r));
211 }
212 GraphEvent::GraphError { error, .. } => {
213 result = Some(Err(error));
214 }
215 _ => {}
216 }
217 }
218
219 result.unwrap_or_else(|| {
220 Err(GraphError::Terminal(
221 TerminalError::InvalidGraph("stream ended without completion".into()),
222 ))
223 })
224 }
225
226 pub fn execute_stream(
232 &self,
233 graph: std::sync::Arc<Graph>,
234 initial_state: State,
235 ) -> GraphExecution {
236 let executor = self.clone();
237 let (event_tx, event_rx) = mpsc::channel(32);
238 let (decision_tx, mut decision_rx) = mpsc::channel(16);
239 let (cancel_tx, mut cancel_rx) = mpsc::channel(1);
240
241 let handle = GraphHandle::new(decision_tx, cancel_tx);
242
243 tokio::spawn(async move {
244 let start_time = Instant::now();
245 let mut state = initial_state;
246 let mut execution_log = Vec::new();
247 let mut edge_visits = EdgeVisits::default();
248 let mut decision_registry = DecisionRegistry::new();
249
250 let mut current = graph.start_node().to_string();
251 let mut step: usize = 0;
252
253 let send = |event: GraphEvent| async {
254 if event_tx.send(event).await.is_err() {
255 tracing::warn!("graph event consumer dropped");
256 }
257 };
258
259 let mut completed = false;
261
262 loop {
263 if cancel_rx.try_recv().is_ok() {
267 let _ = send(
268 GraphEvent::GraphError {
269 error: GraphError::Terminal(TerminalError::BarrierCancelled {
270 node: "execution cancelled by handle".into(),
271 }),
272 state: state.clone(),
273 },
274 )
275 .await;
276 break;
277 }
278
279 step += 1;
280
281 if step > executor.max_steps {
283 let _ = send(
284 GraphEvent::GraphError {
285 error: GraphError::Terminal(TerminalError::StepsExceeded {
286 limit: executor.max_steps,
287 }),
288 state: state.clone(),
289 },
290 )
291 .await;
292 break;
293 }
294
295 let node = match graph.nodes.get(¤t) {
296 Some(n) => n,
297 None => {
298 let _ = send(
299 GraphEvent::GraphError {
300 error: GraphError::Terminal(TerminalError::NodeNotFound(
301 current.clone(),
302 )),
303 state: state.clone(),
304 },
305 )
306 .await;
307 break;
308 }
309 };
310
311 let node_name = current.clone();
312 let span_id = SpanId::new();
313
314 let _ = send(
315 GraphEvent::NodeStart {
316 node_name: node_name.clone(),
317 span_id,
318 step,
319 },
320 )
321 .await;
322
323 let node_start = Instant::now();
324 let result = node.execute_stream(&mut state, &event_tx, span_id).await;
325 let node_end = Instant::now();
326 let duration = node_end.duration_since(node_start);
327
328 match result {
329 Ok(StreamNodeResult::Done { next, span_id: _ }) => {
330 execution_log.push(ExecutionEntry {
331 node_name: node_name.clone(),
332 start_time: node_start,
333 end_time: node_end,
334 success: true,
335 });
336
337 let _ = send(
338 GraphEvent::NodeEnd {
339 node_name: node_name.clone(),
340 span_id,
341 success: true,
342 duration,
343 },
344 )
345 .await;
346
347 if current == graph.end_node() {
349 completed = true;
350 break;
351 }
352
353 match executor.resolve_next(
354 &graph,
355 ¤t,
356 &mut state,
357 &mut edge_visits,
358 next,
359 ) {
360 Ok(target) => current = target,
361 Err(e) => {
362 let _ = send(GraphEvent::GraphError { error: e, state: state.clone() }).await;
363 break;
364 }
365 }
366 }
367
368 Ok(StreamNodeResult::Observed {
369 error,
370 next,
371 span_id: _,
372 }) => {
373 execution_log.push(ExecutionEntry {
374 node_name: node_name.clone(),
375 start_time: node_start,
376 end_time: node_end,
377 success: true,
378 });
379
380 let _ = send(
381 GraphEvent::NodeEnd {
382 node_name: node_name.clone(),
383 span_id,
384 success: true,
385 duration,
386 },
387 )
388 .await;
389
390 let _ = send(GraphEvent::ObservedError {
391 error,
392 node_name: node_name.clone(),
393 })
394 .await;
395
396 if current == graph.end_node() {
398 completed = true;
399 break;
400 }
401
402 match executor.resolve_next(
403 &graph,
404 ¤t,
405 &mut state,
406 &mut edge_visits,
407 next,
408 ) {
409 Ok(target) => current = target,
410 Err(e) => {
411 let _ = send(GraphEvent::GraphError { error: e, state: state.clone() }).await;
412 break;
413 }
414 }
415 }
416
417 Ok(StreamNodeResult::BarrierPaused {
418 barrier_id: _, node_name: barrier_name,
420 span_id: _,
421 timeout,
422 default_action,
423 }) => {
424 let barrier_id = decision_registry.next_id(&barrier_name);
426
427 let _ = send(
429 GraphEvent::BarrierWaiting {
430 barrier_id: barrier_id.clone(),
431 node_name: barrier_name.clone(),
432 span_id,
433 },
434 )
435 .await;
436
437 let decision = executor
439 .wait_barrier_decision(
440 &mut decision_rx,
441 &mut decision_registry,
442 &barrier_id,
443 timeout,
444 &default_action,
445 &mut cancel_rx,
446 )
447 .await;
448
449 if cancel_rx.try_recv().is_ok() {
451 let _ = send(
452 GraphEvent::GraphError {
453 error: GraphError::Terminal(TerminalError::BarrierCancelled {
454 node: barrier_name.clone(),
455 }),
456 state: state.clone(),
457 },
458 )
459 .await;
460 break;
461 }
462
463 let _ = send(
465 GraphEvent::BarrierResolved {
466 barrier_id: barrier_id.clone(),
467 decision: decision.clone(),
468 },
469 )
470 .await;
471
472 let next = match node {
474 NodeKind::Barrier(b) => match b.apply_decision(decision, &mut state) {
475 Ok(ns) => ns,
476 Err(e) => {
477 let _ =
478 send(GraphEvent::GraphError { error: e, state: state.clone() })
479 .await;
480 break;
481 }
482 },
483 _ => unreachable!("expected BarrierNode for BarrierPaused"),
484 };
485
486 execution_log.push(ExecutionEntry {
487 node_name: barrier_name.clone(),
488 start_time: node_start,
489 end_time: Instant::now(),
490 success: true,
491 });
492
493 let _ = send(
494 GraphEvent::NodeEnd {
495 node_name: barrier_name.clone(),
496 span_id,
497 success: true,
498 duration: Instant::now().duration_since(node_start),
499 },
500 )
501 .await;
502
503 if current == graph.end_node() {
505 completed = true;
506 break;
507 }
508
509 match executor.resolve_next(
510 &graph,
511 ¤t,
512 &mut state,
513 &mut edge_visits,
514 next,
515 ) {
516 Ok(target) => current = target,
517 Err(e) => {
518 let _ = send(GraphEvent::GraphError { error: e, state: state.clone() }).await;
519 break;
520 }
521 }
522 }
523
524 Err(e) => {
525 execution_log.push(ExecutionEntry {
526 node_name: node_name.clone(),
527 start_time: node_start,
528 end_time: node_end,
529 success: false,
530 });
531
532 let _ = send(
533 GraphEvent::NodeEnd {
534 node_name: node_name.clone(),
535 span_id,
536 success: false,
537 duration,
538 },
539 )
540 .await;
541
542 match &e {
544 GraphError::Terminal(_) => {
545 let _ = send(GraphEvent::GraphError { error: e, state: state.clone() }).await;
546 break;
547 }
548 GraphError::Recoverable(recoverable) => {
549 tracing::warn!(
551 node = %node_name,
552 error = %recoverable,
553 "Recoverable error captured. Attempting fallback route..."
554 );
555
556 if let Some(fallback_target) = graph.find_fallback_edge(¤t) {
557 let _ = send(
559 GraphEvent::ObservedError {
560 error: ObservedError::Degraded {
561 node: node_name.clone(),
562 message: format!(
563 "fallback to '{}' due to: {}",
564 fallback_target, recoverable
565 ),
566 },
567 node_name: node_name.clone(),
568 },
569 )
570 .await;
571
572 current = fallback_target;
574 } else {
575 let _ = send(
577 GraphEvent::GraphError {
578 error: GraphError::Terminal(
579 TerminalError::NodeExecutionFailed {
580 node: node_name.clone(),
581 source: format!(
582 "Recoverable error with no fallback edge: {}",
583 recoverable
584 )
585 .into(),
586 },
587 ),
588 state: state.clone(),
589 },
590 )
591 .await;
592 break;
593 }
594 }
595 GraphError::Observed(observed) => {
596 let _ = send(
598 GraphEvent::ObservedError {
599 error: observed.clone(),
600 node_name: node_name.clone(),
601 },
602 )
603 .await;
604 if current == graph.end_node() {
609 completed = true;
610 break;
611 }
612 match executor.resolve_next(
613 &graph,
614 ¤t,
615 &mut state,
616 &mut edge_visits,
617 NextStep::GoToNext,
618 ) {
619 Ok(target) => current = target,
620 Err(e) => {
621 let _ = send(GraphEvent::GraphError { error: e, state: state.clone() }).await;
622 break;
623 }
624 }
625 }
626 }
627 }
628 }
629 }
630
631 if completed {
633 let _ = send(
634 GraphEvent::GraphComplete {
635 result: GraphResult {
636 state,
637 execution_log,
638 duration: start_time.elapsed(),
639 },
640 },
641 )
642 .await;
643 }
644 });
645
646 GraphExecution { stream: event_rx, handle }
647 }
648
649 async fn wait_barrier_decision(
651 &self,
652 decision_rx: &mut mpsc::Receiver<BarrierDecisionMessage>,
653 registry: &mut DecisionRegistry,
654 target_id: &BarrierId,
655 timeout: Option<std::time::Duration>,
656 default_action: &BarrierDefaultAction,
657 cancel_rx: &mut mpsc::Receiver<()>,
658 ) -> BarrierDecision {
659 if let Some(decision) = registry.take(target_id) {
661 return decision;
662 }
663
664 while let Ok(msg) = decision_rx.try_recv() {
666 if let Some(decision) = registry.process_message(msg, target_id) {
667 return decision;
668 }
669 }
670
671 if cancel_rx.try_recv().is_ok() {
673 return Self::default_decision(default_action);
674 }
675
676 if let Some(timeout) = timeout {
678 let start = std::time::Instant::now();
679 loop {
680 match tokio::time::timeout(
681 std::time::Duration::from_millis(50),
682 decision_rx.recv(),
683 )
684 .await
685 {
686 Ok(Some(msg)) => {
687 if let Some(decision) = registry.process_message(msg, target_id) {
688 return decision;
689 }
690 }
691 Ok(None) => return Self::default_decision(default_action),
692 Err(_) => {}
693 }
694 if cancel_rx.try_recv().is_ok() {
696 return Self::default_decision(default_action);
697 }
698 if start.elapsed() >= timeout {
699 return Self::default_decision(default_action);
700 }
701 }
702 } else {
703 loop {
704 if let Some(msg) = decision_rx.recv().await {
705 if let Some(decision) = registry.process_message(msg, target_id) {
706 return decision;
707 }
708 } else {
709 return Self::default_decision(default_action);
710 }
711 if cancel_rx.try_recv().is_ok() {
713 return Self::default_decision(default_action);
714 }
715 }
716 }
717 }
718
719 fn default_decision(action: &BarrierDefaultAction) -> BarrierDecision {
720 match action {
721 BarrierDefaultAction::Approve => BarrierDecision::Approve,
722 BarrierDefaultAction::Reject => BarrierDecision::Reject {
723 reason: "timeout — no decision received".into(),
724 },
725 BarrierDefaultAction::Skip => BarrierDecision::Approve,
726 }
727 }
728
729 fn resolve_next(
735 &self,
736 graph: &Graph,
737 current: &str,
738 state: &mut State,
739 edge_visits: &mut EdgeVisits,
740 next: NextStep,
741 ) -> Result<String, GraphError> {
742 match next {
743 NextStep::Goto(target) => {
744 match Self::transition(graph, current, &target, edge_visits)? {
745 EdgeTransitionResult::Ok => Ok(target),
746 EdgeTransitionResult::PolicyExceededStrict { edge, limit } => {
747 Err(GraphError::Terminal(TerminalError::EdgePolicyExceeded { edge, limit }))
748 }
749 EdgeTransitionResult::PolicyExceededSoftFallback { .. } => {
750 if let Some(fallback_target) = graph.find_fallback_edge(current) {
751 Ok(fallback_target)
752 } else {
753 Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
754 node: current.to_string(),
755 source:
756 "SoftFallback triggered but no fallback edge defined".into(),
757 }))
758 }
759 }
760 EdgeTransitionResult::Dropped => {
761 Err(GraphError::Terminal(TerminalError::InvalidGraph(
762 "edge transition dropped for Goto".into(),
763 )))
764 }
765 }
766 }
767 NextStep::GoToNext => {
768 let (target, policy) = Self::find_next_node(graph, current, state)?;
769 let result = edge_visits.record(current, &target, policy);
770 match result {
771 EdgeTransitionResult::Ok => Ok(target),
772 EdgeTransitionResult::PolicyExceededStrict { edge, limit } => {
773 Err(GraphError::Terminal(TerminalError::EdgePolicyExceeded { edge, limit }))
774 }
775 EdgeTransitionResult::PolicyExceededSoftFallback { .. } => {
776 if let Some(fallback_target) = graph.find_fallback_edge(current) {
777 Ok(fallback_target)
778 } else {
779 Err(GraphError::Terminal(TerminalError::NodeExecutionFailed {
780 node: current.to_string(),
781 source:
782 "SoftFallback triggered but no fallback edge defined".into(),
783 }))
784 }
785 }
786 EdgeTransitionResult::Dropped => {
787 Self::find_fallback_or_any(graph, current, state)
788 }
789 }
790 }
791 NextStep::End => {
792 Err(GraphError::Terminal(TerminalError::InvalidGraph(
793 "unexpected End next step".into(),
794 )))
795 }
796 }
797 }
798
799 fn find_fallback_or_any(
801 graph: &Graph,
802 current: &str,
803 state: &State,
804 ) -> Result<String, GraphError> {
805 let edges = graph.edges_from(current);
806
807 for edge in &edges {
809 if edge.fallback && (edge.condition.is_none() || edge.condition.as_ref().is_some_and(|c| c(state))) {
810 return Ok(edge.to.clone());
811 }
812 }
813
814 for edge in &edges {
816 if edge.fallback && edge.condition.is_none() {
817 return Ok(edge.to.clone());
818 }
819 }
820
821 for edge in &edges {
823 if !edge.fallback && (edge.condition.is_none() || edge.condition.as_ref().is_some_and(|c| c(state))) {
824 return Ok(edge.to.clone());
825 }
826 }
827
828 Err(GraphError::Terminal(TerminalError::Unrouted {
829 node: current.to_string(),
830 attempted_conditions: Vec::new(),
831 }))
832 }
833
834 fn transition(
840 graph: &Graph,
841 current: &str,
842 target: &str,
843 edge_visits: &mut EdgeVisits,
844 ) -> Result<EdgeTransitionResult, GraphError> {
845 let edge = graph.find_edge(current, target).ok_or_else(|| {
846 GraphError::Terminal(TerminalError::MissingEdge {
847 from: current.to_string(),
848 to: target.to_string(),
849 })
850 })?;
851
852 let result = edge_visits.record(current, target, edge.policy.as_ref());
853 Ok(result)
854 }
855
856 fn find_next_node<'a>(
865 graph: &'a Graph,
866 current: &str,
867 state: &State,
868 ) -> Result<(String, Option<&'a EdgePolicy>), GraphError> {
869 let edges = graph.edges_from(current);
870
871 if edges.is_empty() {
872 return Err(GraphError::Terminal(TerminalError::InvalidGraph(format!(
873 "node '{}' has no outgoing edges and is not the end node",
874 current
875 ))));
876 }
877
878 for edge in &edges {
880 if !edge.fallback
881 && edge.condition.as_ref().is_some_and(|c| c(state))
882 {
883 return Ok((edge.to.clone(), edge.policy.as_ref()));
884 }
885 }
886
887 for edge in &edges {
889 if !edge.fallback && edge.condition.is_none() {
890 return Ok((edge.to.clone(), edge.policy.as_ref()));
891 }
892 }
893
894 for edge in &edges {
896 if edge.fallback
897 && (edge.condition.is_none() || edge.condition.as_ref().is_some_and(|c| c(state)))
898 {
899 return Ok((edge.to.clone(), edge.policy.as_ref()));
900 }
901 }
902
903 for edge in &edges {
905 if edge.fallback && edge.condition.is_none() {
906 return Ok((edge.to.clone(), edge.policy.as_ref()));
907 }
908 }
909
910 let attempted: Vec<crate::error::ConditionEval> = edges
912 .iter()
913 .map(|e| crate::error::ConditionEval {
914 edge: format!("{}→{}", e.from, e.to),
915 condition: e.condition.as_ref().map(|_| "condition".to_string()),
916 matched: e.condition.as_ref().map_or(false, |c| c(state)),
917 })
918 .collect();
919
920 Err(GraphError::Terminal(TerminalError::Unrouted {
921 node: current.to_string(),
922 attempted_conditions: attempted,
923 }))
924 }
925}
926
927