1use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9
10use entelix_core::{Error, ExecutionContext, Result, ThreadKey};
11use entelix_runnable::Runnable;
12use entelix_runnable::stream::{BoxStream, DebugEvent, RunnableEvent, StreamChunk, StreamMode};
13
14use crate::checkpoint::{Checkpoint, CheckpointId, Checkpointer};
15use crate::command::Command;
16use crate::finalizing_stream::FinalizingStream;
17use crate::state_graph::END;
18
/// Shared, thread-safe routing closure for a conditional edge.
///
/// It receives the current state by reference and returns a `String` key;
/// the graph looks that key up in [`ConditionalEdge::mapping`] to find the
/// next node.
pub type EdgeSelector<S> = Arc<dyn Fn(&S) -> String + Send + Sync>;
21
/// A conditional edge: a selector plus the table that turns the selector's
/// key into a target node name.
pub struct ConditionalEdge<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Computes the routing key from the current state.
    pub selector: EdgeSelector<S>,
    /// Routing key → target node name. A key that is missing from this map
    /// is reported as an invalid-request error when the edge is resolved;
    /// a target equal to `END` terminates the run.
    pub mapping: HashMap<String, String>,
}
37
/// Shared, thread-safe fan-out closure for a send edge.
///
/// It receives the current state by reference and returns one
/// `(target node name, branch state)` pair per branch to dispatch. An empty
/// vector means "dispatch nothing".
pub type SendSelector<S> = Arc<dyn Fn(&S) -> Vec<(String, S)> + Send + Sync>;
42
/// Shared, thread-safe reducer used to fold branch results back together.
///
/// The graph calls it as `merger(accumulated, branch_result)`, starting from
/// the pre-fan-out state, once per branch result.
pub type SendMerger<S> = Arc<dyn Fn(S, S) -> S + Send + Sync>;
53
/// A fan-out ("send") edge: dispatches the state to several target nodes,
/// merges their results, then continues at `join`.
pub struct SendEdge<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Declared target node names, de-duplicated, in first-seen order.
    targets: Vec<String>,
    /// The same names as `targets`, kept as a set for membership checks.
    targets_set: HashSet<String>,
    /// Chooses which targets to run and with which branch state.
    pub selector: SendSelector<S>,
    /// Folds each branch result into the accumulated state.
    pub merger: SendMerger<S>,
    /// Node to continue at after the merge; `END` finishes the run.
    pub join: String,
}
90
91impl<S> SendEdge<S>
92where
93 S: Clone + Send + Sync + 'static,
94{
95 pub fn new(
100 targets: impl IntoIterator<Item = String>,
101 selector: SendSelector<S>,
102 merger: SendMerger<S>,
103 join: String,
104 ) -> Self {
105 let mut ordered: Vec<String> = Vec::new();
106 let mut set: HashSet<String> = HashSet::new();
107 for t in targets {
108 if set.insert(t.clone()) {
109 ordered.push(t);
110 }
111 }
112 Self {
113 targets: ordered,
114 targets_set: set,
115 selector,
116 merger,
117 join,
118 }
119 }
120
121 pub fn targets(&self) -> &[String] {
125 &self.targets
126 }
127
128 pub fn has_target(&self, name: &str) -> bool {
131 self.targets_set.contains(name)
132 }
133}
134
/// An executable graph: named nodes, the edges between them, and the
/// run-time policy (recursion limit, checkpointing, scheduled interrupts).
pub struct CompiledGraph<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Node name → runnable that transforms the state.
    nodes: HashMap<String, Arc<dyn Runnable<S, S>>>,
    /// Static edges: source node name → target node name.
    edges: HashMap<String, String>,
    /// Conditional edges, keyed by source node name.
    conditional_edges: HashMap<String, ConditionalEdge<S>>,
    /// Fan-out edges, keyed by source node name.
    send_edges: HashMap<String, SendEdge<S>>,
    /// Node at which `invoke` / `stream` start.
    entry_point: String,
    /// Nodes after which the run terminates.
    finish_points: HashSet<String>,
    /// Upper bound on node executions per call; a per-run override can only
    /// lower it (the smaller of the two is used).
    recursion_limit: usize,
    /// Optional checkpoint store; required by the `resume*` methods.
    checkpointer: Option<Arc<dyn Checkpointer<S>>>,
    /// When `PerNode`, a checkpoint is written after every node in the
    /// non-streaming loop.
    checkpoint_granularity: crate::state_graph::CheckpointGranularity,
    /// Nodes that pause the run before they execute.
    interrupt_before: HashSet<String>,
    /// Nodes that pause the run after they execute.
    interrupt_after: HashSet<String>,
}
154
155impl<S> std::fmt::Debug for CompiledGraph<S>
156where
157 S: Clone + Send + Sync + 'static,
158{
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("CompiledGraph")
166 .field("nodes", &sorted_keys(&self.nodes))
167 .field("edges", &sorted_pairs(&self.edges))
168 .field("conditional_edges", &sorted_keys(&self.conditional_edges))
169 .field("send_edges", &sorted_keys(&self.send_edges))
170 .field("entry_point", &self.entry_point)
171 .field("finish_points", &sorted_set(&self.finish_points))
172 .field("recursion_limit", &self.recursion_limit)
173 .field("has_checkpointer", &self.checkpointer.is_some())
174 .field("checkpoint_granularity", &self.checkpoint_granularity)
175 .field("interrupt_before", &sorted_set(&self.interrupt_before))
176 .field("interrupt_after", &sorted_set(&self.interrupt_after))
177 .finish()
178 }
179}
180
/// Returns references to the map's keys in ascending order, so callers get
/// a stable listing regardless of hash iteration order.
fn sorted_keys<V>(m: &HashMap<String, V>) -> Vec<&String> {
    let mut names = Vec::with_capacity(m.len());
    names.extend(m.keys());
    // Keys are unique, so an unstable sort yields the same order.
    names.sort_unstable();
    names
}
186
/// Returns `(key, value)` references for every entry, ordered by key.
fn sorted_pairs(m: &HashMap<String, String>) -> Vec<(&String, &String)> {
    let mut entries = Vec::with_capacity(m.len());
    entries.extend(m.iter());
    entries.sort_by(|left, right| left.0.cmp(right.0));
    entries
}
192
/// Returns references to the set's members in ascending order.
fn sorted_set(s: &HashSet<String>) -> Vec<&String> {
    let mut members = Vec::with_capacity(s.len());
    members.extend(s.iter());
    // Members are unique, so an unstable sort yields the same order.
    members.sort_unstable();
    members
}
198
199impl<S> CompiledGraph<S>
200where
201 S: Clone + Send + Sync + 'static,
202{
    /// Assembles a compiled graph from already-built parts.
    ///
    /// Each argument is stored verbatim in the field of the same name; no
    /// validation happens here.
    // One parameter per struct field, hence the lint allowance.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        nodes: HashMap<String, Arc<dyn Runnable<S, S>>>,
        edges: HashMap<String, String>,
        conditional_edges: HashMap<String, ConditionalEdge<S>>,
        send_edges: HashMap<String, SendEdge<S>>,
        entry_point: String,
        finish_points: HashSet<String>,
        recursion_limit: usize,
        checkpointer: Option<Arc<dyn Checkpointer<S>>>,
        checkpoint_granularity: crate::state_graph::CheckpointGranularity,
        interrupt_before: HashSet<String>,
        interrupt_after: HashSet<String>,
    ) -> Self {
        Self {
            nodes,
            edges,
            conditional_edges,
            send_edges,
            entry_point,
            finish_points,
            recursion_limit,
            checkpointer,
            checkpoint_granularity,
            interrupt_before,
            interrupt_after,
        }
    }
231
    /// Returns the configured checkpoint granularity.
    pub const fn checkpoint_granularity(&self) -> crate::state_graph::CheckpointGranularity {
        self.checkpoint_granularity
    }
237
    /// Returns the name of the node at which execution starts.
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }
242
    /// Returns the graph-level cap on node executions per call (before any
    /// per-run override is applied).
    pub const fn recursion_limit(&self) -> usize {
        self.recursion_limit
    }
247
    /// Returns how many nodes are registered as finish points.
    pub fn finish_point_count(&self) -> usize {
        self.finish_points.len()
    }
252
    /// Returns `true` when a checkpointer is attached to this graph.
    pub fn has_checkpointer(&self) -> bool {
        self.checkpointer.is_some()
    }
257
    /// Resumes from the latest checkpoint of the context's thread without
    /// modifying it; shorthand for `resume_with(Command::Resume, ctx)`.
    ///
    /// # Errors
    ///
    /// Propagates every error of [`CompiledGraph::resume_with`].
    pub async fn resume(&self, ctx: &ExecutionContext) -> Result<S> {
        self.resume_with(Command::Resume, ctx).await
    }
265
266 pub async fn resume_with(&self, command: Command<S>, ctx: &ExecutionContext) -> Result<S> {
278 let checkpointer = self
279 .checkpointer
280 .as_ref()
281 .ok_or_else(|| Error::config("CompiledGraph::resume requires a Checkpointer"))?;
282 let key = ThreadKey::from_ctx(ctx)?;
283 let latest = checkpointer.get_latest(&key).await?.ok_or_else(|| {
284 Error::invalid_request(format!(
285 "CompiledGraph::resume: no checkpoint exists for tenant '{}' thread '{}'",
286 key.tenant_id(),
287 key.thread_id()
288 ))
289 })?;
290 self.dispatch_from_checkpoint(latest, command, ctx).await
291 }
292
293 pub async fn resume_from(
298 &self,
299 checkpoint_id: &CheckpointId,
300 command: Command<S>,
301 ctx: &ExecutionContext,
302 ) -> Result<S> {
303 let checkpointer = self
304 .checkpointer
305 .as_ref()
306 .ok_or_else(|| Error::config("CompiledGraph::resume_from requires a Checkpointer"))?;
307 let key = ThreadKey::from_ctx(ctx)?;
308 let cp = checkpointer
309 .get_by_id(&key, checkpoint_id)
310 .await?
311 .ok_or_else(|| {
312 Error::invalid_request(format!(
313 "CompiledGraph::resume_from: checkpoint not found in tenant '{}' thread '{}'",
314 key.tenant_id(),
315 key.thread_id()
316 ))
317 })?;
318 self.dispatch_from_checkpoint(cp, command, ctx).await
319 }
320
321 async fn dispatch_from_checkpoint(
322 &self,
323 checkpoint: Checkpoint<S>,
324 command: Command<S>,
325 ctx: &ExecutionContext,
326 ) -> Result<S> {
327 if let Some(handle) = ctx.audit_sink() {
331 handle
332 .as_sink()
333 .record_resumed(&checkpoint.id.to_hyphenated_string());
334 }
335 let mut scoped_ctx: Option<ExecutionContext> = None;
340 let (state, next_node) = match command {
341 Command::Resume => (checkpoint.state, checkpoint.next_node),
342 Command::Update(s) => (s, checkpoint.next_node),
343 Command::GoTo(node) => (checkpoint.state, Some(node)),
344 Command::ApproveTool {
345 tool_use_id,
346 decision,
347 } => {
348 if matches!(decision, entelix_core::ApprovalDecision::AwaitExternal) {
349 return Err(Error::invalid_request(
350 "Command::ApproveTool: AwaitExternal is not a valid resume \
351 decision — pausing again on resume defeats the purpose. \
352 Supply Approve or Reject{reason}.",
353 ));
354 }
355 let mut pending = ctx
356 .extension::<entelix_core::PendingApprovalDecisions>()
357 .map(|h| (*h).clone())
358 .unwrap_or_default();
359 pending.insert(tool_use_id, decision);
360 scoped_ctx = Some(ctx.clone().add_extension(pending));
361 (checkpoint.state, checkpoint.next_node)
362 }
363 };
364 let effective_ctx = scoped_ctx.as_ref().unwrap_or(ctx);
365 match next_node {
366 None => Ok(state),
367 Some(next) => {
368 self.execute_loop_inner(
373 state,
374 next,
375 checkpoint.step.saturating_add(1),
376 effective_ctx,
377 true,
378 )
379 .await
380 }
381 }
382 }
383
    /// Runs the graph starting at `current` with every scheduled
    /// interrupt-before check active (the skip flag is passed as `false`).
    ///
    /// `step_offset` is added to the per-call step counter when numbering
    /// checkpoints.
    async fn execute_loop(
        &self,
        state: S,
        current: String,
        step_offset: usize,
        ctx: &ExecutionContext,
    ) -> Result<S> {
        self.execute_loop_inner(state, current, step_offset, ctx, false)
            .await
    }
394
395 #[allow(clippy::too_many_lines)]
396 async fn execute_loop_inner(
397 &self,
398 mut state: S,
399 mut current: String,
400 step_offset: usize,
401 ctx: &ExecutionContext,
402 mut skip_interrupt_before_for_current: bool,
403 ) -> Result<S> {
404 let effective_recursion_limit = ctx
411 .extension::<entelix_core::RunOverrides>()
412 .and_then(|o| o.max_iterations())
413 .map_or(self.recursion_limit, |n| n.min(self.recursion_limit));
414 let mut steps_in_call: usize = 0;
415 loop {
416 if ctx.is_cancelled() {
417 return Err(Error::Cancelled);
418 }
419 if steps_in_call >= effective_recursion_limit {
420 return Err(Error::invalid_request(format!(
421 "StateGraph: recursion limit ({effective_recursion_limit}) exceeded — possible infinite cycle"
422 )));
423 }
424 steps_in_call = steps_in_call.saturating_add(1);
425 let total_step = step_offset.saturating_add(steps_in_call);
426
427 let node = self.nodes.get(¤t).ok_or_else(|| {
428 Error::invalid_request(format!(
429 "StateGraph: control reached unknown node '{current}'"
430 ))
431 })?;
432
433 let pre_state = if self.checkpointer.is_some() && ctx.thread_id().is_some() {
436 Some(state.clone())
437 } else {
438 None
439 };
440
441 if self.interrupt_before.contains(¤t) && !skip_interrupt_before_for_current {
448 if let (Some(cp), Some(thread_id), Some(pre)) =
449 (&self.checkpointer, ctx.thread_id(), pre_state.clone())
450 {
451 let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
452 cp.put(Checkpoint::new(
453 &key,
454 total_step,
455 pre,
456 Some(current.clone()),
457 ))
458 .await?;
459 }
460 let kind = entelix_core::InterruptionKind::ScheduledPause {
461 phase: entelix_core::InterruptionPhase::Before,
462 node: current.clone(),
463 };
464 let payload = serde_json::Value::Null;
465 if let Some(handle) = ctx.audit_sink() {
466 handle.as_sink().record_interrupted(&kind, &payload);
467 }
468 return Err(Error::Interrupted { kind, payload });
469 }
470 skip_interrupt_before_for_current = false;
472
473 match node.invoke(state, ctx).await {
474 Ok(new_state) => state = new_state,
475 Err(Error::Interrupted { kind, payload }) => {
476 if let (Some(cp), Some(thread_id), Some(pre)) =
490 (&self.checkpointer, ctx.thread_id(), pre_state)
491 {
492 let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
493 cp.put(Checkpoint::new(
494 &key,
495 total_step,
496 pre,
497 Some(current.clone()),
498 ))
499 .await?;
500 }
501 return Err(Error::Interrupted { kind, payload });
502 }
503 Err(other) => return Err(other),
504 }
505
506 if self.interrupt_after.contains(¤t) && !self.send_edges.contains_key(¤t) {
512 let next_node = self.resolve_next_node(¤t, &state)?;
513 if let (Some(cp), Some(thread_id)) = (&self.checkpointer, ctx.thread_id()) {
514 let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
515 cp.put(Checkpoint::new(
516 &key,
517 total_step,
518 state.clone(),
519 next_node.clone(),
520 ))
521 .await?;
522 }
523 let kind = entelix_core::InterruptionKind::ScheduledPause {
524 phase: entelix_core::InterruptionPhase::After,
525 node: current.clone(),
526 };
527 let payload = serde_json::Value::Null;
528 if let Some(handle) = ctx.audit_sink() {
529 handle.as_sink().record_interrupted(&kind, &payload);
530 }
531 return Err(Error::Interrupted { kind, payload });
532 }
533
534 if let Some(send) = self.send_edges.get(¤t) {
544 state = self.execute_send_edge(send, state, ctx).await?;
545 if send.join == END {
546 self.emit_depth_histogram(steps_in_call, ctx);
547 return Ok(state);
548 }
549 current = send.join.clone();
550 continue;
551 }
552
553 let next_node = self.resolve_next_node(¤t, &state)?;
556
557 let granularity_writes = matches!(
564 self.checkpoint_granularity,
565 crate::state_graph::CheckpointGranularity::PerNode
566 );
567 if granularity_writes
568 && let (Some(cp), Some(thread_id)) = (&self.checkpointer, ctx.thread_id())
569 {
570 let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
571 cp.put(Checkpoint::new(
572 &key,
573 total_step,
574 state.clone(),
575 next_node.clone(),
576 ))
577 .await?;
578 }
579
580 match next_node {
581 None => {
582 self.emit_depth_histogram(steps_in_call, ctx);
583 return Ok(state);
584 }
585 Some(next) => current = next,
586 }
587 }
588 }
589
590 async fn execute_send_edge(
597 &self,
598 send: &SendEdge<S>,
599 state: S,
600 ctx: &ExecutionContext,
601 ) -> Result<S> {
602 let branches = (send.selector)(&state);
603 if branches.is_empty() {
604 return Ok(state);
605 }
606 for (target, _) in &branches {
611 if !send.has_target(target) {
612 return Err(Error::invalid_request(format!(
613 "StateGraph: send edge dispatched to '{target}' which is not in the \
614 declared target set {:?}",
615 send.targets()
616 )));
617 }
618 if !self.nodes.contains_key(target) {
622 return Err(Error::invalid_request(format!(
623 "StateGraph: send edge dispatched to unknown node '{target}'"
624 )));
625 }
626 }
627 let scope_ctx = ctx.child();
632 let futures = branches
633 .into_iter()
634 .map(|(target, branch_state)| {
635 let node = self.nodes.get(&target).map(Arc::clone).ok_or_else(|| {
636 Error::invalid_request(format!(
637 "StateGraph: send edge dispatched to unknown node '{target}'"
638 ))
639 })?;
640 let scope_ctx = scope_ctx.clone();
641 Ok::<_, Error>(async move { node.invoke(branch_state, &scope_ctx).await })
642 })
643 .collect::<Result<Vec<_>>>()?;
644 let branch_states = futures::future::try_join_all(futures).await?;
645 let mut folded = state;
653 for branch in branch_states {
654 folded = (send.merger)(folded, branch);
655 }
656 Ok(folded)
657 }
658
659 fn resolve_next_node(&self, current: &str, state: &S) -> Result<Option<String>> {
665 if self.finish_points.contains(current) {
666 return Ok(None);
667 }
668 if let Some(cond) = self.conditional_edges.get(current) {
669 let key = (cond.selector)(state);
670 let target = cond.mapping.get(&key).ok_or_else(|| {
671 Error::invalid_request(format!(
672 "StateGraph: conditional edge from '{current}' returned key '{key}' \
673 which is not present in the mapping"
674 ))
675 })?;
676 return Ok(if target == END {
677 None
678 } else {
679 Some(target.clone())
680 });
681 }
682 let target = self.edges.get(current).ok_or_else(|| {
683 Error::invalid_request(format!(
684 "StateGraph: node '{current}' has no outgoing edge and is not terminal"
685 ))
686 })?;
687 Ok(Some(target.clone()))
688 }
689
    /// Emits a DEBUG-level `tracing` event describing a completed run.
    ///
    /// `depth` is the number of node executions in this call. The recorded
    /// recursion limit is the graph's configured limit, not a per-run
    /// override. Tenant, thread and run identifiers come from `ctx`.
    fn emit_depth_histogram(&self, depth: usize, ctx: &ExecutionContext) {
        tracing::event!(
            target: "entelix_graph::compiled",
            tracing::Level::DEBUG,
            entelix.graph.depth = depth,
            entelix.graph.recursion_limit = self.recursion_limit,
            entelix.tenant_id = ctx.tenant_id().as_str(),
            entelix.thread_id = ctx.thread_id(),
            entelix.run_id = ctx.run_id(),
            "entelix.graph.run_complete"
        );
    }
708}
709
/// Lets a compiled graph be used wherever a `Runnable<S, S>` is expected.
#[async_trait::async_trait]
impl<S> Runnable<S, S> for CompiledGraph<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Runs the graph from its entry point with a step offset of 0 and
    /// returns the final state.
    async fn invoke(&self, input: S, ctx: &ExecutionContext) -> Result<S> {
        self.execute_loop(input, self.entry_point.clone(), 0, ctx)
            .await
    }

    /// Returns a boxed stream of chunks for a run from the entry point.
    ///
    /// The context is cloned into the stream; `mode` selects which chunk
    /// kinds are emitted. Building the stream itself never fails — errors
    /// are delivered as stream items.
    async fn stream(
        &self,
        input: S,
        mode: StreamMode,
        ctx: &ExecutionContext,
    ) -> Result<BoxStream<'_, Result<StreamChunk<S>>>> {
        Ok(Box::pin(self.build_stream(input, mode, ctx.clone())))
    }
}
729
/// Name reported in the `Started` / `Finished` events of a graph stream.
const GRAPH_STREAM_NAME: &str = "CompiledGraph";
731
/// Builds the `Finished` event chunk for a graph stream; `ok` says whether
/// the run completed successfully.
fn finished<S>(ok: bool) -> StreamChunk<S> {
    StreamChunk::Event(RunnableEvent::Finished {
        name: GRAPH_STREAM_NAME.to_owned(),
        ok,
    })
}
738
739impl<S> CompiledGraph<S>
740where
741 S: Clone + Send + Sync + 'static,
742{
743 #[allow(
744 clippy::too_many_lines,
745 clippy::single_match_else,
746 clippy::manual_let_else,
747 tail_expr_drop_order
748 )]
749 fn build_stream(
750 &self,
751 input: S,
752 mode: StreamMode,
753 ctx: ExecutionContext,
754 ) -> impl futures::Stream<Item = Result<StreamChunk<S>>> + Send + '_ {
755 let entry = self.entry_point.clone();
756 let finalize_tenant = ctx.tenant_id().clone();
761 let finalize_thread = ctx.thread_id().map(str::to_owned);
762 let finalize_mode = mode;
763 let effective_recursion_limit = ctx
764 .extension::<entelix_core::RunOverrides>()
765 .and_then(|o| o.max_iterations())
766 .map_or(self.recursion_limit, |n| n.min(self.recursion_limit));
767 let inner = async_stream::stream! {
768 let mut state = input;
769 let mut current = entry;
770 let mut steps_in_call: usize = 0;
771
772 if matches!(mode, StreamMode::Events) {
773 yield Ok(StreamChunk::Event(RunnableEvent::Started {
774 name: GRAPH_STREAM_NAME.to_owned(),
775 }));
776 }
777
778 loop {
779 if ctx.is_cancelled() {
780 if matches!(mode, StreamMode::Events) {
781 yield Ok(finished::<S>(false));
782 }
783 yield Err(Error::Cancelled);
784 return;
785 }
786 if steps_in_call >= effective_recursion_limit {
787 if matches!(mode, StreamMode::Events) {
788 yield Ok(finished::<S>(false));
789 }
790 yield Err(Error::invalid_request(format!(
791 "StateGraph: recursion limit ({effective_recursion_limit}) exceeded — possible infinite cycle"
792 )));
793 return;
794 }
795 steps_in_call = steps_in_call.saturating_add(1);
796
797 if matches!(mode, StreamMode::Debug) {
798 yield Ok(StreamChunk::Debug(DebugEvent::NodeStart {
799 node: current.clone(),
800 step: steps_in_call,
801 }));
802 }
803
804 let Some(node) = self.nodes.get(¤t) else {
805 yield Err(Error::invalid_request(format!(
806 "StateGraph: control reached unknown node '{current}'"
807 )));
808 return;
809 };
810
811 match node.invoke(state, &ctx).await {
812 Ok(s) => state = s,
813 Err(e) => {
814 if matches!(mode, StreamMode::Events) {
815 yield Ok(finished::<S>(false));
816 }
817 yield Err(e);
818 return;
819 }
820 }
821
822 match mode {
823 StreamMode::Values => {
824 yield Ok(StreamChunk::Value(state.clone()));
825 }
826 StreamMode::Updates => {
827 yield Ok(StreamChunk::Update {
828 node: current.clone(),
829 value: state.clone(),
830 });
831 }
832 StreamMode::Debug => {
833 yield Ok(StreamChunk::Debug(DebugEvent::NodeEnd {
834 node: current.clone(),
835 step: steps_in_call,
836 }));
837 }
838 _ => {}
839 }
840
841 if let Some(send) = self.send_edges.get(¤t) {
845 match self.execute_send_edge(send, state.clone(), &ctx).await {
846 Ok(merged) => state = merged,
847 Err(e) => {
848 if matches!(mode, StreamMode::Events) {
849 yield Ok(finished::<S>(false));
850 }
851 yield Err(e);
852 return;
853 }
854 }
855 if send.join == END {
856 self.emit_depth_histogram(steps_in_call, &ctx);
857 match mode {
858 StreamMode::Debug => {
859 yield Ok(StreamChunk::Debug(DebugEvent::Final));
860 }
861 StreamMode::Events => {
862 yield Ok(finished::<S>(true));
863 }
864 StreamMode::Messages => {
865 yield Ok(StreamChunk::Value(state));
866 }
867 _ => {}
868 }
869 return;
870 }
871 current = send.join.clone();
872 continue;
873 }
874
875 let next_node = match self.resolve_next_node(¤t, &state) {
876 Ok(n) => n,
877 Err(e) => {
878 if matches!(mode, StreamMode::Events) {
879 yield Ok(finished::<S>(false));
880 }
881 yield Err(e);
882 return;
883 }
884 };
885
886 if let Some(next) = next_node {
887 current = next;
888 } else {
889 self.emit_depth_histogram(steps_in_call, &ctx);
890 match mode {
891 StreamMode::Debug => {
892 yield Ok(StreamChunk::Debug(DebugEvent::Final));
893 }
894 StreamMode::Events => {
895 yield Ok(finished::<S>(true));
896 }
897 StreamMode::Messages => {
898 yield Ok(StreamChunk::Value(state));
899 }
900 _ => {}
901 }
902 return;
903 }
904 }
905 };
906 FinalizingStream::new(inner, move || {
907 tracing::debug!(
908 target: "entelix_graph::stream",
909 tenant_id = %finalize_tenant,
910 thread_id = ?finalize_thread,
911 mode = ?finalize_mode,
912 "graph stream dropped before completion"
913 );
914 })
915 }
916}