1use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9
10use entelix_core::{Error, ExecutionContext, Result, ThreadKey};
11use entelix_runnable::Runnable;
12use entelix_runnable::stream::{BoxStream, DebugEvent, RunnableEvent, StreamChunk, StreamMode};
13
14use crate::checkpoint::{Checkpoint, CheckpointId, Checkpointer};
15use crate::command::Command;
16use crate::finalizing_stream::FinalizingStream;
17use crate::state_graph::END;
18
/// Routing function for a conditional edge: inspects the state and returns
/// the key that is looked up in the edge's `mapping` to pick the next node.
pub type EdgeSelector<S> = Arc<dyn Fn(&S) -> String + Sync + Send>;
21
22pub struct ConditionalEdge<S>
28where
29 S: Clone + Send + Sync + 'static,
30{
31 pub selector: EdgeSelector<S>,
33 pub mapping: HashMap<String, String>,
36}
37
/// Fan-out function for a send edge: returns one `(target node, branch
/// state)` pair per branch to run in parallel.
pub type SendSelector<S> = Arc<dyn Fn(&S) -> Vec<(String, S)> + Sync + Send>;

/// Folds a single branch result (second argument) into the accumulated state
/// (first argument) after a fan-out completes.
pub type SendMerger<S> = Arc<dyn Fn(S, S) -> S + Sync + Send>;
53
54pub struct SendEdge<S>
70where
71 S: Clone + Send + Sync + 'static,
72{
73 targets: Vec<String>,
77 targets_set: HashSet<String>,
80 pub selector: SendSelector<S>,
83 pub merger: SendMerger<S>,
86 pub join: String,
89}
90
91impl<S> SendEdge<S>
92where
93 S: Clone + Send + Sync + 'static,
94{
95 pub fn new(
100 targets: impl IntoIterator<Item = String>,
101 selector: SendSelector<S>,
102 merger: SendMerger<S>,
103 join: String,
104 ) -> Self {
105 let mut ordered: Vec<String> = Vec::new();
106 let mut set: HashSet<String> = HashSet::new();
107 for t in targets {
108 if set.insert(t.clone()) {
109 ordered.push(t);
110 }
111 }
112 Self {
113 targets: ordered,
114 targets_set: set,
115 selector,
116 merger,
117 join,
118 }
119 }
120
121 pub fn targets(&self) -> &[String] {
125 &self.targets
126 }
127
128 pub fn has_target(&self, name: &str) -> bool {
131 self.targets_set.contains(name)
132 }
133}
134
135pub struct CompiledGraph<S>
139where
140 S: Clone + Send + Sync + 'static,
141{
142 nodes: HashMap<String, Arc<dyn Runnable<S, S>>>,
143 edges: HashMap<String, String>,
144 conditional_edges: HashMap<String, ConditionalEdge<S>>,
145 send_edges: HashMap<String, SendEdge<S>>,
146 entry_point: String,
147 finish_points: HashSet<String>,
148 recursion_limit: usize,
149 checkpointer: Option<Arc<dyn Checkpointer<S>>>,
150 checkpoint_granularity: crate::state_graph::CheckpointGranularity,
151 interrupt_before: HashSet<String>,
152 interrupt_after: HashSet<String>,
153}
154
155impl<S> std::fmt::Debug for CompiledGraph<S>
156where
157 S: Clone + Send + Sync + 'static,
158{
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("CompiledGraph")
166 .field("nodes", &sorted_keys(&self.nodes))
167 .field("edges", &sorted_pairs(&self.edges))
168 .field("conditional_edges", &sorted_keys(&self.conditional_edges))
169 .field("send_edges", &sorted_keys(&self.send_edges))
170 .field("entry_point", &self.entry_point)
171 .field("finish_points", &sorted_set(&self.finish_points))
172 .field("recursion_limit", &self.recursion_limit)
173 .field("has_checkpointer", &self.checkpointer.is_some())
174 .field("checkpoint_granularity", &self.checkpoint_granularity)
175 .field("interrupt_before", &sorted_set(&self.interrupt_before))
176 .field("interrupt_after", &sorted_set(&self.interrupt_after))
177 .finish()
178 }
179}
180
/// Keys of `m` in ascending order (map keys are unique, so an unstable sort
/// yields the same result as a stable one).
fn sorted_keys<V>(m: &HashMap<String, V>) -> Vec<&String> {
    let mut names = m.keys().collect::<Vec<_>>();
    names.sort_unstable();
    names
}
186
/// `(key, value)` pairs of `m` ordered by key.
fn sorted_pairs(m: &HashMap<String, String>) -> Vec<(&String, &String)> {
    let mut entries = m.iter().collect::<Vec<_>>();
    // Keys are unique, so ordering only by key is total and unstable is fine.
    entries.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
    entries
}
192
/// Members of `s` in ascending order.
fn sorted_set(s: &HashSet<String>) -> Vec<&String> {
    let mut members = s.iter().collect::<Vec<_>>();
    members.sort_unstable();
    members
}
198
199impl<S> CompiledGraph<S>
200where
201 S: Clone + Send + Sync + 'static,
202{
203 #[allow(clippy::too_many_arguments)]
204 pub(crate) fn new(
205 nodes: HashMap<String, Arc<dyn Runnable<S, S>>>,
206 edges: HashMap<String, String>,
207 conditional_edges: HashMap<String, ConditionalEdge<S>>,
208 send_edges: HashMap<String, SendEdge<S>>,
209 entry_point: String,
210 finish_points: HashSet<String>,
211 recursion_limit: usize,
212 checkpointer: Option<Arc<dyn Checkpointer<S>>>,
213 checkpoint_granularity: crate::state_graph::CheckpointGranularity,
214 interrupt_before: HashSet<String>,
215 interrupt_after: HashSet<String>,
216 ) -> Self {
217 Self {
218 nodes,
219 edges,
220 conditional_edges,
221 send_edges,
222 entry_point,
223 finish_points,
224 recursion_limit,
225 checkpointer,
226 checkpoint_granularity,
227 interrupt_before,
228 interrupt_after,
229 }
230 }
231
232 pub const fn checkpoint_granularity(&self) -> crate::state_graph::CheckpointGranularity {
235 self.checkpoint_granularity
236 }
237
238 pub fn entry_point(&self) -> &str {
240 &self.entry_point
241 }
242
243 pub const fn recursion_limit(&self) -> usize {
245 self.recursion_limit
246 }
247
248 pub fn finish_point_count(&self) -> usize {
250 self.finish_points.len()
251 }
252
253 pub fn has_checkpointer(&self) -> bool {
255 self.checkpointer.is_some()
256 }
257
258 pub async fn resume(&self, ctx: &ExecutionContext) -> Result<S> {
263 self.resume_with(Command::Resume, ctx).await
264 }
265
266 pub async fn resume_with(&self, command: Command<S>, ctx: &ExecutionContext) -> Result<S> {
278 let checkpointer = self
279 .checkpointer
280 .as_ref()
281 .ok_or_else(|| Error::config("CompiledGraph::resume requires a Checkpointer"))?;
282 let key = ThreadKey::from_ctx(ctx)?;
283 let latest = checkpointer.get_latest(&key).await?.ok_or_else(|| {
284 Error::invalid_request(format!(
285 "CompiledGraph::resume: no checkpoint exists for tenant '{}' thread '{}'",
286 key.tenant_id(),
287 key.thread_id()
288 ))
289 })?;
290 self.dispatch_from_checkpoint(latest, command, ctx).await
291 }
292
293 pub async fn resume_from(
298 &self,
299 checkpoint_id: &CheckpointId,
300 command: Command<S>,
301 ctx: &ExecutionContext,
302 ) -> Result<S> {
303 let checkpointer = self
304 .checkpointer
305 .as_ref()
306 .ok_or_else(|| Error::config("CompiledGraph::resume_from requires a Checkpointer"))?;
307 let key = ThreadKey::from_ctx(ctx)?;
308 let cp = checkpointer
309 .get_by_id(&key, checkpoint_id)
310 .await?
311 .ok_or_else(|| {
312 Error::invalid_request(format!(
313 "CompiledGraph::resume_from: checkpoint not found in tenant '{}' thread '{}'",
314 key.tenant_id(),
315 key.thread_id()
316 ))
317 })?;
318 self.dispatch_from_checkpoint(cp, command, ctx).await
319 }
320
321 async fn dispatch_from_checkpoint(
322 &self,
323 checkpoint: Checkpoint<S>,
324 command: Command<S>,
325 ctx: &ExecutionContext,
326 ) -> Result<S> {
327 if let Some(handle) = ctx.audit_sink() {
331 handle
332 .as_sink()
333 .record_resumed(&checkpoint.id.to_hyphenated_string());
334 }
335 let mut scoped_ctx: Option<ExecutionContext> = None;
340 let (state, next_node) = match command {
341 Command::Resume => (checkpoint.state, checkpoint.next_node),
342 Command::Update(s) => (s, checkpoint.next_node),
343 Command::GoTo(node) => (checkpoint.state, Some(node)),
344 Command::ApproveTool {
345 tool_use_id,
346 decision,
347 } => {
348 if matches!(decision, entelix_core::ApprovalDecision::AwaitExternal) {
349 return Err(Error::invalid_request(
350 "Command::ApproveTool: AwaitExternal is not a valid resume \
351 decision — pausing again on resume defeats the purpose. \
352 Supply Approve or Reject{reason}.",
353 ));
354 }
355 let mut pending = ctx
356 .extension::<entelix_core::PendingApprovalDecisions>()
357 .map(|h| (*h).clone())
358 .unwrap_or_default();
359 pending.insert(tool_use_id, decision);
360 scoped_ctx = Some(ctx.clone().add_extension(pending));
361 (checkpoint.state, checkpoint.next_node)
362 }
363 };
364 let effective_ctx = scoped_ctx.as_ref().unwrap_or(ctx);
365 match next_node {
366 None => Ok(state),
367 Some(next) => {
368 self.execute_loop_inner(
373 state,
374 next,
375 checkpoint.step.saturating_add(1),
376 effective_ctx,
377 true,
378 )
379 .await
380 }
381 }
382 }
383
384 async fn execute_loop(
385 &self,
386 state: S,
387 current: String,
388 step_offset: usize,
389 ctx: &ExecutionContext,
390 ) -> Result<S> {
391 self.execute_loop_inner(state, current, step_offset, ctx, false)
392 .await
393 }
394
395 #[allow(clippy::too_many_lines)]
396 async fn execute_loop_inner(
397 &self,
398 mut state: S,
399 mut current: String,
400 step_offset: usize,
401 ctx: &ExecutionContext,
402 mut skip_interrupt_before_for_current: bool,
403 ) -> Result<S> {
404 let effective_recursion_limit = ctx
411 .extension::<entelix_core::RunOverrides>()
412 .and_then(|o| o.max_iterations())
413 .map_or(self.recursion_limit, |n| n.min(self.recursion_limit));
414 let mut steps_in_call: usize = 0;
415 loop {
416 if ctx.is_cancelled() {
417 return Err(Error::Cancelled);
418 }
419 if steps_in_call >= effective_recursion_limit {
420 return Err(Error::invalid_request(format!(
421 "StateGraph: recursion limit ({effective_recursion_limit}) exceeded — possible infinite cycle"
422 )));
423 }
424 steps_in_call = steps_in_call.saturating_add(1);
425 let total_step = step_offset.saturating_add(steps_in_call);
426
427 let node = self.nodes.get(¤t).ok_or_else(|| {
428 Error::invalid_request(format!(
429 "StateGraph: control reached unknown node '{current}'"
430 ))
431 })?;
432
433 let pre_state = if self.checkpointer.is_some() && ctx.thread_id().is_some() {
436 Some(state.clone())
437 } else {
438 None
439 };
440
441 if self.interrupt_before.contains(¤t) && !skip_interrupt_before_for_current {
448 if let (Some(cp), Some(thread_id), Some(pre)) =
449 (&self.checkpointer, ctx.thread_id(), pre_state.clone())
450 {
451 let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
452 cp.put(Checkpoint::new(
453 &key,
454 total_step,
455 pre,
456 Some(current.clone()),
457 ))
458 .await?;
459 }
460 return Err(Error::Interrupted {
461 kind: entelix_core::InterruptionKind::ScheduledPause {
462 phase: entelix_core::InterruptionPhase::Before,
463 node: current.clone(),
464 },
465 payload: serde_json::Value::Null,
466 });
467 }
468 skip_interrupt_before_for_current = false;
470
471 match node.invoke(state, ctx).await {
472 Ok(new_state) => state = new_state,
473 Err(Error::Interrupted { kind, payload }) => {
474 if let (Some(cp), Some(thread_id), Some(pre)) =
478 (&self.checkpointer, ctx.thread_id(), pre_state)
479 {
480 let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
481 cp.put(Checkpoint::new(
482 &key,
483 total_step,
484 pre,
485 Some(current.clone()),
486 ))
487 .await?;
488 }
489 return Err(Error::Interrupted { kind, payload });
490 }
491 Err(other) => return Err(other),
492 }
493
494 if self.interrupt_after.contains(¤t) && !self.send_edges.contains_key(¤t) {
500 let next_node = self.resolve_next_node(¤t, &state)?;
501 if let (Some(cp), Some(thread_id)) = (&self.checkpointer, ctx.thread_id()) {
502 let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
503 cp.put(Checkpoint::new(
504 &key,
505 total_step,
506 state.clone(),
507 next_node.clone(),
508 ))
509 .await?;
510 }
511 return Err(Error::Interrupted {
512 kind: entelix_core::InterruptionKind::ScheduledPause {
513 phase: entelix_core::InterruptionPhase::After,
514 node: current.clone(),
515 },
516 payload: serde_json::Value::Null,
517 });
518 }
519
520 if let Some(send) = self.send_edges.get(¤t) {
530 state = self.execute_send_edge(send, state, ctx).await?;
531 if send.join == END {
532 self.emit_depth_histogram(steps_in_call, ctx);
533 return Ok(state);
534 }
535 current = send.join.clone();
536 continue;
537 }
538
539 let next_node = self.resolve_next_node(¤t, &state)?;
542
543 let granularity_writes = matches!(
550 self.checkpoint_granularity,
551 crate::state_graph::CheckpointGranularity::PerNode
552 );
553 if granularity_writes
554 && let (Some(cp), Some(thread_id)) = (&self.checkpointer, ctx.thread_id())
555 {
556 let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
557 cp.put(Checkpoint::new(
558 &key,
559 total_step,
560 state.clone(),
561 next_node.clone(),
562 ))
563 .await?;
564 }
565
566 match next_node {
567 None => {
568 self.emit_depth_histogram(steps_in_call, ctx);
569 return Ok(state);
570 }
571 Some(next) => current = next,
572 }
573 }
574 }
575
576 async fn execute_send_edge(
583 &self,
584 send: &SendEdge<S>,
585 state: S,
586 ctx: &ExecutionContext,
587 ) -> Result<S> {
588 let branches = (send.selector)(&state);
589 if branches.is_empty() {
590 return Ok(state);
591 }
592 for (target, _) in &branches {
597 if !send.has_target(target) {
598 return Err(Error::invalid_request(format!(
599 "StateGraph: send edge dispatched to '{target}' which is not in the \
600 declared target set {:?}",
601 send.targets()
602 )));
603 }
604 if !self.nodes.contains_key(target) {
608 return Err(Error::invalid_request(format!(
609 "StateGraph: send edge dispatched to unknown node '{target}'"
610 )));
611 }
612 }
613 let scope_ctx = ctx.child();
618 let futures = branches
619 .into_iter()
620 .map(|(target, branch_state)| {
621 let node = self.nodes.get(&target).map(Arc::clone).ok_or_else(|| {
622 Error::invalid_request(format!(
623 "StateGraph: send edge dispatched to unknown node '{target}'"
624 ))
625 })?;
626 let scope_ctx = scope_ctx.clone();
627 Ok::<_, Error>(async move { node.invoke(branch_state, &scope_ctx).await })
628 })
629 .collect::<Result<Vec<_>>>()?;
630 let branch_states = futures::future::try_join_all(futures).await?;
631 let mut folded = state;
639 for branch in branch_states {
640 folded = (send.merger)(folded, branch);
641 }
642 Ok(folded)
643 }
644
645 fn resolve_next_node(&self, current: &str, state: &S) -> Result<Option<String>> {
651 if self.finish_points.contains(current) {
652 return Ok(None);
653 }
654 if let Some(cond) = self.conditional_edges.get(current) {
655 let key = (cond.selector)(state);
656 let target = cond.mapping.get(&key).ok_or_else(|| {
657 Error::invalid_request(format!(
658 "StateGraph: conditional edge from '{current}' returned key '{key}' \
659 which is not present in the mapping"
660 ))
661 })?;
662 return Ok(if target == END {
663 None
664 } else {
665 Some(target.clone())
666 });
667 }
668 let target = self.edges.get(current).ok_or_else(|| {
669 Error::invalid_request(format!(
670 "StateGraph: node '{current}' has no outgoing edge and is not terminal"
671 ))
672 })?;
673 Ok(Some(target.clone()))
674 }
675
676 fn emit_depth_histogram(&self, depth: usize, ctx: &ExecutionContext) {
683 tracing::event!(
684 target: "entelix_graph::compiled",
685 tracing::Level::DEBUG,
686 entelix.graph.depth = depth,
687 entelix.graph.recursion_limit = self.recursion_limit,
688 entelix.tenant_id = ctx.tenant_id().as_str(),
689 entelix.thread_id = ctx.thread_id(),
690 entelix.run_id = ctx.run_id(),
691 "entelix.graph.run_complete"
692 );
693 }
694}
695
696#[async_trait::async_trait]
697impl<S> Runnable<S, S> for CompiledGraph<S>
698where
699 S: Clone + Send + Sync + 'static,
700{
701 async fn invoke(&self, input: S, ctx: &ExecutionContext) -> Result<S> {
702 self.execute_loop(input, self.entry_point.clone(), 0, ctx)
703 .await
704 }
705
706 async fn stream(
707 &self,
708 input: S,
709 mode: StreamMode,
710 ctx: &ExecutionContext,
711 ) -> Result<BoxStream<'_, Result<StreamChunk<S>>>> {
712 Ok(Box::pin(self.build_stream(input, mode, ctx.clone())))
713 }
714}
715
716const GRAPH_STREAM_NAME: &str = "CompiledGraph";
717
718fn finished<S>(ok: bool) -> StreamChunk<S> {
719 StreamChunk::Event(RunnableEvent::Finished {
720 name: GRAPH_STREAM_NAME.to_owned(),
721 ok,
722 })
723}
724
725impl<S> CompiledGraph<S>
726where
727 S: Clone + Send + Sync + 'static,
728{
729 #[allow(
730 clippy::too_many_lines,
731 clippy::single_match_else,
732 clippy::manual_let_else,
733 tail_expr_drop_order
734 )]
735 fn build_stream(
736 &self,
737 input: S,
738 mode: StreamMode,
739 ctx: ExecutionContext,
740 ) -> impl futures::Stream<Item = Result<StreamChunk<S>>> + Send + '_ {
741 let entry = self.entry_point.clone();
742 let finalize_tenant = ctx.tenant_id().clone();
747 let finalize_thread = ctx.thread_id().map(str::to_owned);
748 let finalize_mode = mode;
749 let effective_recursion_limit = ctx
750 .extension::<entelix_core::RunOverrides>()
751 .and_then(|o| o.max_iterations())
752 .map_or(self.recursion_limit, |n| n.min(self.recursion_limit));
753 let inner = async_stream::stream! {
754 let mut state = input;
755 let mut current = entry;
756 let mut steps_in_call: usize = 0;
757
758 if matches!(mode, StreamMode::Events) {
759 yield Ok(StreamChunk::Event(RunnableEvent::Started {
760 name: GRAPH_STREAM_NAME.to_owned(),
761 }));
762 }
763
764 loop {
765 if ctx.is_cancelled() {
766 if matches!(mode, StreamMode::Events) {
767 yield Ok(finished::<S>(false));
768 }
769 yield Err(Error::Cancelled);
770 return;
771 }
772 if steps_in_call >= effective_recursion_limit {
773 if matches!(mode, StreamMode::Events) {
774 yield Ok(finished::<S>(false));
775 }
776 yield Err(Error::invalid_request(format!(
777 "StateGraph: recursion limit ({effective_recursion_limit}) exceeded — possible infinite cycle"
778 )));
779 return;
780 }
781 steps_in_call = steps_in_call.saturating_add(1);
782
783 if matches!(mode, StreamMode::Debug) {
784 yield Ok(StreamChunk::Debug(DebugEvent::NodeStart {
785 node: current.clone(),
786 step: steps_in_call,
787 }));
788 }
789
790 let Some(node) = self.nodes.get(¤t) else {
791 yield Err(Error::invalid_request(format!(
792 "StateGraph: control reached unknown node '{current}'"
793 )));
794 return;
795 };
796
797 match node.invoke(state, &ctx).await {
798 Ok(s) => state = s,
799 Err(e) => {
800 if matches!(mode, StreamMode::Events) {
801 yield Ok(finished::<S>(false));
802 }
803 yield Err(e);
804 return;
805 }
806 }
807
808 match mode {
809 StreamMode::Values => {
810 yield Ok(StreamChunk::Value(state.clone()));
811 }
812 StreamMode::Updates => {
813 yield Ok(StreamChunk::Update {
814 node: current.clone(),
815 value: state.clone(),
816 });
817 }
818 StreamMode::Debug => {
819 yield Ok(StreamChunk::Debug(DebugEvent::NodeEnd {
820 node: current.clone(),
821 step: steps_in_call,
822 }));
823 }
824 _ => {}
825 }
826
827 if let Some(send) = self.send_edges.get(¤t) {
831 match self.execute_send_edge(send, state.clone(), &ctx).await {
832 Ok(merged) => state = merged,
833 Err(e) => {
834 if matches!(mode, StreamMode::Events) {
835 yield Ok(finished::<S>(false));
836 }
837 yield Err(e);
838 return;
839 }
840 }
841 if send.join == END {
842 self.emit_depth_histogram(steps_in_call, &ctx);
843 match mode {
844 StreamMode::Debug => {
845 yield Ok(StreamChunk::Debug(DebugEvent::Final));
846 }
847 StreamMode::Events => {
848 yield Ok(finished::<S>(true));
849 }
850 StreamMode::Messages => {
851 yield Ok(StreamChunk::Value(state));
852 }
853 _ => {}
854 }
855 return;
856 }
857 current = send.join.clone();
858 continue;
859 }
860
861 let next_node = match self.resolve_next_node(¤t, &state) {
862 Ok(n) => n,
863 Err(e) => {
864 if matches!(mode, StreamMode::Events) {
865 yield Ok(finished::<S>(false));
866 }
867 yield Err(e);
868 return;
869 }
870 };
871
872 if let Some(next) = next_node {
873 current = next;
874 } else {
875 self.emit_depth_histogram(steps_in_call, &ctx);
876 match mode {
877 StreamMode::Debug => {
878 yield Ok(StreamChunk::Debug(DebugEvent::Final));
879 }
880 StreamMode::Events => {
881 yield Ok(finished::<S>(true));
882 }
883 StreamMode::Messages => {
884 yield Ok(StreamChunk::Value(state));
885 }
886 _ => {}
887 }
888 return;
889 }
890 }
891 };
892 FinalizingStream::new(inner, move || {
893 tracing::debug!(
894 target: "entelix_graph::stream",
895 tenant_id = %finalize_tenant,
896 thread_id = ?finalize_thread,
897 mode = ?finalize_mode,
898 "graph stream dropped before completion"
899 );
900 })
901 }
902}