use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use entelix_core::{Error, ExecutionContext, Result, ThreadKey};
use entelix_runnable::Runnable;
use entelix_runnable::stream::{BoxStream, DebugEvent, RunnableEvent, StreamChunk, StreamMode};
use crate::checkpoint::{Checkpoint, CheckpointId, Checkpointer};
use crate::command::Command;
use crate::finalizing_stream::FinalizingStream;
use crate::state_graph::END;
/// Function that inspects the current state and returns a routing key,
/// which a [`ConditionalEdge`] then maps to the name of the next node.
pub type EdgeSelector<S> = Arc<dyn Fn(&S) -> String + Send + Sync>;
/// A data-dependent edge: `selector` derives a key from the state and
/// `mapping` translates that key into the next node's name.
pub struct ConditionalEdge<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Computes the routing key from the current state.
    pub selector: EdgeSelector<S>,
    /// Routing key -> target node name (a target may be `END`).
    pub mapping: HashMap<String, String>,
}
/// Fan-out function: given the current state, produces the list of
/// (target node, branch state) pairs to dispatch concurrently.
pub type SendSelector<S> = Arc<dyn Fn(&S) -> Vec<(String, S)> + Send + Sync>;
/// Fold function used to merge each completed branch state back into the
/// accumulated state, one branch at a time.
pub type SendMerger<S> = Arc<dyn Fn(S, S) -> S + Send + Sync>;
/// A map/fan-out edge: `selector` spawns branches over `targets`, results
/// are folded with `merger`, then control moves to the `join` node.
pub struct SendEdge<S>
where
    S: Clone + Send + Sync + 'static,
{
    // Declared targets in first-occurrence order (for error messages and
    // deterministic reporting).
    targets: Vec<String>,
    // Same targets as a set, for O(1) membership checks at dispatch time.
    targets_set: HashSet<String>,
    /// Produces the branches to run for a given state.
    pub selector: SendSelector<S>,
    /// Merges a branch result into the accumulated state.
    pub merger: SendMerger<S>,
    /// Node that receives control after all branches are merged
    /// (may be `END`, in which case the run finishes).
    pub join: String,
}
impl<S> SendEdge<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Builds a `SendEdge`, deduplicating `targets` while preserving the
    /// order in which each target first appears.
    pub fn new(
        targets: impl IntoIterator<Item = String>,
        selector: SendSelector<S>,
        merger: SendMerger<S>,
        join: String,
    ) -> Self {
        // `insert` returns true only on first sight, so the filter keeps
        // exactly the first occurrence of every name.
        let mut seen: HashSet<String> = HashSet::new();
        let ordered: Vec<String> = targets
            .into_iter()
            .filter(|t| seen.insert(t.clone()))
            .collect();
        Self {
            targets: ordered,
            targets_set: seen,
            selector,
            merger,
            join,
        }
    }

    /// Declared fan-out targets, in first-occurrence order.
    pub fn targets(&self) -> &[String] {
        &self.targets
    }

    /// Whether `name` is one of the declared targets (O(1) lookup).
    pub fn has_target(&self, name: &str) -> bool {
        self.targets_set.contains(name)
    }
}
/// An immutable, executable state graph produced by the builder.
///
/// Holds the node runnables plus all routing metadata; execution walks
/// from `entry_point` until a finish point / `END` is reached or the
/// recursion limit trips.
pub struct CompiledGraph<S>
where
    S: Clone + Send + Sync + 'static,
{
    // Node name -> runnable executed when control reaches that node.
    nodes: HashMap<String, Arc<dyn Runnable<S, S>>>,
    // Unconditional edges: source node -> next node.
    edges: HashMap<String, String>,
    // Data-dependent edges, keyed by source node.
    conditional_edges: HashMap<String, ConditionalEdge<S>>,
    // Fan-out/join edges, keyed by source node.
    send_edges: HashMap<String, SendEdge<S>>,
    // Node where every run starts.
    entry_point: String,
    // Nodes after which the run terminates successfully.
    finish_points: HashSet<String>,
    // Hard cap on steps per call, guarding against infinite cycles.
    recursion_limit: usize,
    // Optional persistence layer for interrupts / resume.
    checkpointer: Option<Arc<dyn Checkpointer<S>>>,
    // How often checkpoints are written (e.g. per node).
    checkpoint_granularity: crate::state_graph::CheckpointGranularity,
    // Nodes that trigger a scheduled pause before they run.
    interrupt_before: HashSet<String>,
    // Nodes that trigger a scheduled pause after they run.
    interrupt_after: HashSet<String>,
}
impl<S> std::fmt::Debug for CompiledGraph<S>
where
    S: Clone + Send + Sync + 'static,
{
    // Deterministic debug output: map keys and set members are sorted so
    // two equivalent graphs always format identically regardless of hash
    // iteration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CompiledGraph");
        out.field("nodes", &sorted_keys(&self.nodes));
        out.field("edges", &sorted_pairs(&self.edges));
        out.field("conditional_edges", &sorted_keys(&self.conditional_edges));
        out.field("send_edges", &sorted_keys(&self.send_edges));
        out.field("entry_point", &self.entry_point);
        out.field("finish_points", &sorted_set(&self.finish_points));
        out.field("recursion_limit", &self.recursion_limit);
        out.field("has_checkpointer", &self.checkpointer.is_some());
        out.field("checkpoint_granularity", &self.checkpoint_granularity);
        out.field("interrupt_before", &sorted_set(&self.interrupt_before));
        out.field("interrupt_after", &sorted_set(&self.interrupt_after));
        out.finish()
    }
}
/// Collects a map's keys into a lexicographically sorted vector, giving
/// `Debug` output that is independent of hash iteration order.
fn sorted_keys<V>(m: &HashMap<String, V>) -> Vec<&String> {
    let mut keys: Vec<&String> = m.keys().collect();
    // Keys are unique, so an unstable sort yields the same result.
    keys.sort_unstable();
    keys
}
/// Collects a map's entries into a vector ordered by key, giving
/// deterministic `Debug` output for edge tables.
fn sorted_pairs(m: &HashMap<String, String>) -> Vec<(&String, &String)> {
    let mut entries: Vec<(&String, &String)> = m.iter().collect();
    // Keys are unique, so an unstable sort yields the same result.
    entries.sort_unstable_by_key(|&(k, _)| k);
    entries
}
/// Collects a set's members into a lexicographically sorted vector,
/// giving deterministic `Debug` output for node-name sets.
fn sorted_set(s: &HashSet<String>) -> Vec<&String> {
    let mut members: Vec<&String> = s.iter().collect();
    // Set members are unique, so an unstable sort is equivalent.
    members.sort_unstable();
    members
}
impl<S> CompiledGraph<S>
where
S: Clone + Send + Sync + 'static,
{
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
nodes: HashMap<String, Arc<dyn Runnable<S, S>>>,
edges: HashMap<String, String>,
conditional_edges: HashMap<String, ConditionalEdge<S>>,
send_edges: HashMap<String, SendEdge<S>>,
entry_point: String,
finish_points: HashSet<String>,
recursion_limit: usize,
checkpointer: Option<Arc<dyn Checkpointer<S>>>,
checkpoint_granularity: crate::state_graph::CheckpointGranularity,
interrupt_before: HashSet<String>,
interrupt_after: HashSet<String>,
) -> Self {
Self {
nodes,
edges,
conditional_edges,
send_edges,
entry_point,
finish_points,
recursion_limit,
checkpointer,
checkpoint_granularity,
interrupt_before,
interrupt_after,
}
}
pub const fn checkpoint_granularity(&self) -> crate::state_graph::CheckpointGranularity {
self.checkpoint_granularity
}
pub fn entry_point(&self) -> &str {
&self.entry_point
}
pub const fn recursion_limit(&self) -> usize {
self.recursion_limit
}
pub fn finish_point_count(&self) -> usize {
self.finish_points.len()
}
pub fn has_checkpointer(&self) -> bool {
self.checkpointer.is_some()
}
pub async fn resume(&self, ctx: &ExecutionContext) -> Result<S> {
self.resume_with(Command::Resume, ctx).await
}
pub async fn resume_with(&self, command: Command<S>, ctx: &ExecutionContext) -> Result<S> {
let checkpointer = self
.checkpointer
.as_ref()
.ok_or_else(|| Error::config("CompiledGraph::resume requires a Checkpointer"))?;
let key = ThreadKey::from_ctx(ctx)?;
let latest = checkpointer.get_latest(&key).await?.ok_or_else(|| {
Error::invalid_request(format!(
"CompiledGraph::resume: no checkpoint exists for tenant '{}' thread '{}'",
key.tenant_id(),
key.thread_id()
))
})?;
self.dispatch_from_checkpoint(latest, command, ctx).await
}
pub async fn resume_from(
&self,
checkpoint_id: &CheckpointId,
command: Command<S>,
ctx: &ExecutionContext,
) -> Result<S> {
let checkpointer = self
.checkpointer
.as_ref()
.ok_or_else(|| Error::config("CompiledGraph::resume_from requires a Checkpointer"))?;
let key = ThreadKey::from_ctx(ctx)?;
let cp = checkpointer
.get_by_id(&key, checkpoint_id)
.await?
.ok_or_else(|| {
Error::invalid_request(format!(
"CompiledGraph::resume_from: checkpoint not found in tenant '{}' thread '{}'",
key.tenant_id(),
key.thread_id()
))
})?;
self.dispatch_from_checkpoint(cp, command, ctx).await
}
async fn dispatch_from_checkpoint(
&self,
checkpoint: Checkpoint<S>,
command: Command<S>,
ctx: &ExecutionContext,
) -> Result<S> {
if let Some(handle) = ctx.audit_sink() {
handle
.as_sink()
.record_resumed(&checkpoint.id.to_hyphenated_string());
}
let mut scoped_ctx: Option<ExecutionContext> = None;
let (state, next_node) = match command {
Command::Resume => (checkpoint.state, checkpoint.next_node),
Command::Update(s) => (s, checkpoint.next_node),
Command::GoTo(node) => (checkpoint.state, Some(node)),
Command::ApproveTool {
tool_use_id,
decision,
} => {
if matches!(decision, entelix_core::ApprovalDecision::AwaitExternal) {
return Err(Error::invalid_request(
"Command::ApproveTool: AwaitExternal is not a valid resume \
decision — pausing again on resume defeats the purpose. \
Supply Approve or Reject{reason}.",
));
}
let mut pending = ctx
.extension::<entelix_core::PendingApprovalDecisions>()
.map(|h| (*h).clone())
.unwrap_or_default();
pending.insert(tool_use_id, decision);
scoped_ctx = Some(ctx.clone().add_extension(pending));
(checkpoint.state, checkpoint.next_node)
}
};
let effective_ctx = scoped_ctx.as_ref().unwrap_or(ctx);
match next_node {
None => Ok(state),
Some(next) => {
self.execute_loop_inner(
state,
next,
checkpoint.step.saturating_add(1),
effective_ctx,
true,
)
.await
}
}
}
async fn execute_loop(
&self,
state: S,
current: String,
step_offset: usize,
ctx: &ExecutionContext,
) -> Result<S> {
self.execute_loop_inner(state, current, step_offset, ctx, false)
.await
}
#[allow(clippy::too_many_lines)]
async fn execute_loop_inner(
&self,
mut state: S,
mut current: String,
step_offset: usize,
ctx: &ExecutionContext,
mut skip_interrupt_before_for_current: bool,
) -> Result<S> {
let effective_recursion_limit = ctx
.extension::<entelix_core::RunOverrides>()
.and_then(|o| o.max_iterations())
.map_or(self.recursion_limit, |n| n.min(self.recursion_limit));
let mut steps_in_call: usize = 0;
loop {
if ctx.is_cancelled() {
return Err(Error::Cancelled);
}
if steps_in_call >= effective_recursion_limit {
return Err(Error::invalid_request(format!(
"StateGraph: recursion limit ({effective_recursion_limit}) exceeded — possible infinite cycle"
)));
}
steps_in_call = steps_in_call.saturating_add(1);
let total_step = step_offset.saturating_add(steps_in_call);
let node = self.nodes.get(¤t).ok_or_else(|| {
Error::invalid_request(format!(
"StateGraph: control reached unknown node '{current}'"
))
})?;
let pre_state = if self.checkpointer.is_some() && ctx.thread_id().is_some() {
Some(state.clone())
} else {
None
};
if self.interrupt_before.contains(¤t) && !skip_interrupt_before_for_current {
if let (Some(cp), Some(thread_id), Some(pre)) =
(&self.checkpointer, ctx.thread_id(), pre_state.clone())
{
let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
cp.put(Checkpoint::new(
&key,
total_step,
pre,
Some(current.clone()),
))
.await?;
}
return Err(Error::Interrupted {
kind: entelix_core::InterruptionKind::ScheduledPause {
phase: entelix_core::InterruptionPhase::Before,
node: current.clone(),
},
payload: serde_json::Value::Null,
});
}
skip_interrupt_before_for_current = false;
match node.invoke(state, ctx).await {
Ok(new_state) => state = new_state,
Err(Error::Interrupted { kind, payload }) => {
if let (Some(cp), Some(thread_id), Some(pre)) =
(&self.checkpointer, ctx.thread_id(), pre_state)
{
let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
cp.put(Checkpoint::new(
&key,
total_step,
pre,
Some(current.clone()),
))
.await?;
}
return Err(Error::Interrupted { kind, payload });
}
Err(other) => return Err(other),
}
if self.interrupt_after.contains(¤t) && !self.send_edges.contains_key(¤t) {
let next_node = self.resolve_next_node(¤t, &state)?;
if let (Some(cp), Some(thread_id)) = (&self.checkpointer, ctx.thread_id()) {
let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
cp.put(Checkpoint::new(
&key,
total_step,
state.clone(),
next_node.clone(),
))
.await?;
}
return Err(Error::Interrupted {
kind: entelix_core::InterruptionKind::ScheduledPause {
phase: entelix_core::InterruptionPhase::After,
node: current.clone(),
},
payload: serde_json::Value::Null,
});
}
if let Some(send) = self.send_edges.get(¤t) {
state = self.execute_send_edge(send, state, ctx).await?;
if send.join == END {
self.emit_depth_histogram(steps_in_call, ctx);
return Ok(state);
}
current = send.join.clone();
continue;
}
let next_node = self.resolve_next_node(¤t, &state)?;
let granularity_writes = matches!(
self.checkpoint_granularity,
crate::state_graph::CheckpointGranularity::PerNode
);
if granularity_writes
&& let (Some(cp), Some(thread_id)) = (&self.checkpointer, ctx.thread_id())
{
let key = ThreadKey::new(ctx.tenant_id().clone(), thread_id);
cp.put(Checkpoint::new(
&key,
total_step,
state.clone(),
next_node.clone(),
))
.await?;
}
match next_node {
None => {
self.emit_depth_histogram(steps_in_call, ctx);
return Ok(state);
}
Some(next) => current = next,
}
}
}
async fn execute_send_edge(
&self,
send: &SendEdge<S>,
state: S,
ctx: &ExecutionContext,
) -> Result<S> {
let branches = (send.selector)(&state);
if branches.is_empty() {
return Ok(state);
}
for (target, _) in &branches {
if !send.has_target(target) {
return Err(Error::invalid_request(format!(
"StateGraph: send edge dispatched to '{target}' which is not in the \
declared target set {:?}",
send.targets()
)));
}
if !self.nodes.contains_key(target) {
return Err(Error::invalid_request(format!(
"StateGraph: send edge dispatched to unknown node '{target}'"
)));
}
}
let scope_ctx = ctx.child();
let futures = branches
.into_iter()
.map(|(target, branch_state)| {
let node = self.nodes.get(&target).map(Arc::clone).ok_or_else(|| {
Error::invalid_request(format!(
"StateGraph: send edge dispatched to unknown node '{target}'"
))
})?;
let scope_ctx = scope_ctx.clone();
Ok::<_, Error>(async move { node.invoke(branch_state, &scope_ctx).await })
})
.collect::<Result<Vec<_>>>()?;
let branch_states = futures::future::try_join_all(futures).await?;
let mut folded = state;
for branch in branch_states {
folded = (send.merger)(folded, branch);
}
Ok(folded)
}
fn resolve_next_node(&self, current: &str, state: &S) -> Result<Option<String>> {
if self.finish_points.contains(current) {
return Ok(None);
}
if let Some(cond) = self.conditional_edges.get(current) {
let key = (cond.selector)(state);
let target = cond.mapping.get(&key).ok_or_else(|| {
Error::invalid_request(format!(
"StateGraph: conditional edge from '{current}' returned key '{key}' \
which is not present in the mapping"
))
})?;
return Ok(if target == END {
None
} else {
Some(target.clone())
});
}
let target = self.edges.get(current).ok_or_else(|| {
Error::invalid_request(format!(
"StateGraph: node '{current}' has no outgoing edge and is not terminal"
))
})?;
Ok(Some(target.clone()))
}
fn emit_depth_histogram(&self, depth: usize, ctx: &ExecutionContext) {
tracing::event!(
target: "entelix_graph::compiled",
tracing::Level::DEBUG,
entelix.graph.depth = depth,
entelix.graph.recursion_limit = self.recursion_limit,
entelix.tenant_id = ctx.tenant_id().as_str(),
entelix.thread_id = ctx.thread_id(),
entelix.run_id = ctx.run_id(),
"entelix.graph.run_complete"
);
}
}
#[async_trait::async_trait]
impl<S> Runnable<S, S> for CompiledGraph<S>
where
    S: Clone + Send + Sync + 'static,
{
    // Runs the whole graph to completion from the configured entry point.
    async fn invoke(&self, input: S, ctx: &ExecutionContext) -> Result<S> {
        let start = self.entry_point.clone();
        self.execute_loop(input, start, 0, ctx).await
    }

    // Streams execution progress in the requested mode; the stream owns a
    // clone of the context so it is independent of the caller's borrow.
    async fn stream(
        &self,
        input: S,
        mode: StreamMode,
        ctx: &ExecutionContext,
    ) -> Result<BoxStream<'_, Result<StreamChunk<S>>>> {
        let stream = self.build_stream(input, mode, ctx.clone());
        Ok(Box::pin(stream))
    }
}
/// Name reported in stream lifecycle events emitted by the graph itself.
const GRAPH_STREAM_NAME: &str = "CompiledGraph";
/// Builds the terminal `Finished` event chunk with the given success flag.
fn finished<S>(ok: bool) -> StreamChunk<S> {
    let name = GRAPH_STREAM_NAME.to_owned();
    StreamChunk::Event(RunnableEvent::Finished { name, ok })
}
impl<S> CompiledGraph<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Builds the lazily-evaluated stream backing `Runnable::stream`.
    ///
    /// The stream walks the graph like `execute_loop_inner` (without
    /// checkpointing or scheduled interrupts) and yields chunks according
    /// to `mode`: per-step values, per-node updates, debug node events, or
    /// started/finished lifecycle events. Errors are yielded as the final
    /// item, preceded by a `Finished { ok: false }` event in Events mode.
    /// The stream is wrapped in `FinalizingStream` so a drop before
    /// completion is logged.
    #[allow(
        clippy::too_many_lines,
        clippy::single_match_else,
        clippy::manual_let_else,
        tail_expr_drop_order
    )]
    fn build_stream(
        &self,
        input: S,
        mode: StreamMode,
        ctx: ExecutionContext,
    ) -> impl futures::Stream<Item = Result<StreamChunk<S>>> + Send + '_ {
        let entry = self.entry_point.clone();
        // Captured up front for the drop-side log closure, which must not
        // borrow `ctx` (it moves into the stream below).
        let finalize_tenant = ctx.tenant_id().clone();
        let finalize_thread = ctx.thread_id().map(str::to_owned);
        let finalize_mode = mode;
        // A per-run override may tighten, but never loosen, the limit.
        let effective_recursion_limit = ctx
            .extension::<entelix_core::RunOverrides>()
            .and_then(|o| o.max_iterations())
            .map_or(self.recursion_limit, |n| n.min(self.recursion_limit));
        let inner = async_stream::stream! {
            let mut state = input;
            let mut current = entry;
            let mut steps_in_call: usize = 0;
            if matches!(mode, StreamMode::Events) {
                yield Ok(StreamChunk::Event(RunnableEvent::Started {
                    name: GRAPH_STREAM_NAME.to_owned(),
                }));
            }
            loop {
                // Cooperative cancellation check before every step.
                if ctx.is_cancelled() {
                    if matches!(mode, StreamMode::Events) {
                        yield Ok(finished::<S>(false));
                    }
                    yield Err(Error::Cancelled);
                    return;
                }
                if steps_in_call >= effective_recursion_limit {
                    if matches!(mode, StreamMode::Events) {
                        yield Ok(finished::<S>(false));
                    }
                    yield Err(Error::invalid_request(format!(
                        "StateGraph: recursion limit ({effective_recursion_limit}) exceeded — possible infinite cycle"
                    )));
                    return;
                }
                steps_in_call = steps_in_call.saturating_add(1);
                if matches!(mode, StreamMode::Debug) {
                    yield Ok(StreamChunk::Debug(DebugEvent::NodeStart {
                        node: current.clone(),
                        step: steps_in_call,
                    }));
                }
                let Some(node) = self.nodes.get(&current) else {
                    yield Err(Error::invalid_request(format!(
                        "StateGraph: control reached unknown node '{current}'"
                    )));
                    return;
                };
                match node.invoke(state, &ctx).await {
                    Ok(s) => state = s,
                    Err(e) => {
                        if matches!(mode, StreamMode::Events) {
                            yield Ok(finished::<S>(false));
                        }
                        yield Err(e);
                        return;
                    }
                }
                // Per-step chunk after a successful node invocation.
                match mode {
                    StreamMode::Values => {
                        yield Ok(StreamChunk::Value(state.clone()));
                    }
                    StreamMode::Updates => {
                        yield Ok(StreamChunk::Update {
                            node: current.clone(),
                            value: state.clone(),
                        });
                    }
                    StreamMode::Debug => {
                        yield Ok(StreamChunk::Debug(DebugEvent::NodeEnd {
                            node: current.clone(),
                            step: steps_in_call,
                        }));
                    }
                    _ => {}
                }
                // Fan-out edge: run all branches, merge, jump to join.
                if let Some(send) = self.send_edges.get(&current) {
                    match self.execute_send_edge(send, state.clone(), &ctx).await {
                        Ok(merged) => state = merged,
                        Err(e) => {
                            if matches!(mode, StreamMode::Events) {
                                yield Ok(finished::<S>(false));
                            }
                            yield Err(e);
                            return;
                        }
                    }
                    if send.join == END {
                        self.emit_depth_histogram(steps_in_call, &ctx);
                        match mode {
                            StreamMode::Debug => {
                                yield Ok(StreamChunk::Debug(DebugEvent::Final));
                            }
                            StreamMode::Events => {
                                yield Ok(finished::<S>(true));
                            }
                            StreamMode::Messages => {
                                yield Ok(StreamChunk::Value(state));
                            }
                            _ => {}
                        }
                        return;
                    }
                    current = send.join.clone();
                    continue;
                }
                let next_node = match self.resolve_next_node(&current, &state) {
                    Ok(n) => n,
                    Err(e) => {
                        if matches!(mode, StreamMode::Events) {
                            yield Ok(finished::<S>(false));
                        }
                        yield Err(e);
                        return;
                    }
                };
                if let Some(next) = next_node {
                    current = next;
                } else {
                    // Terminal: record depth, emit the mode's final
                    // chunk, and end the stream.
                    self.emit_depth_histogram(steps_in_call, &ctx);
                    match mode {
                        StreamMode::Debug => {
                            yield Ok(StreamChunk::Debug(DebugEvent::Final));
                        }
                        StreamMode::Events => {
                            yield Ok(finished::<S>(true));
                        }
                        StreamMode::Messages => {
                            yield Ok(StreamChunk::Value(state));
                        }
                        _ => {}
                    }
                    return;
                }
            }
        };
        // Log if the consumer drops the stream before it completes.
        FinalizingStream::new(inner, move || {
            tracing::debug!(
                target: "entelix_graph::stream",
                tenant_id = %finalize_tenant,
                thread_id = ?finalize_thread,
                mode = ?finalize_mode,
                "graph stream dropped before completion"
            );
        })
    }
}