use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use entelix_core::{Error, Result};
use entelix_runnable::Runnable;
use crate::checkpoint::Checkpointer;
use crate::compiled::{
CompiledGraph, ConditionalEdge, EdgeSelector, SendEdge, SendMerger, SendSelector,
};
use crate::contributing_node::ContributingNodeAdapter;
use crate::merge_node::MergeNodeAdapter;
use crate::reducer::StateMerge;
/// Recursion limit a freshly constructed [`StateGraph`] starts with
/// (`StateGraph::new` stores it in `recursion_limit`).
pub const DEFAULT_RECURSION_LIMIT: usize = 25;
/// Reserved name that validation accepts in place of a registered node as a
/// conditional-edge target or as a send-edge join.
///
/// NOTE(review): presumably tells the executor to stop the run — confirm
/// against `CompiledGraph`.
pub const END: &str = "__entelix_graph_end__";
/// Checkpointing granularity stored on a [`StateGraph`] and handed to the
/// compiled graph.
///
/// NOTE(review): the exact runtime meaning of each variant lives in the
/// executor, which is not visible in this file — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum CheckpointGranularity {
    /// Presumably disables checkpoint writes — TODO confirm.
    Off,
    /// Default variant; presumably one checkpoint per executed node — TODO confirm.
    #[default]
    PerNode,
}
/// Builder that accumulates nodes, edges and execution settings for a graph
/// over state type `S`; `compile` validates the collected topology and hands
/// every field to `CompiledGraph::new`.
pub struct StateGraph<S>
where
S: Clone + Send + Sync + 'static,
{
/// Registered nodes, keyed by node name.
nodes: HashMap<String, Arc<dyn Runnable<S, S>>>,
/// Static edges: source node name -> target node name (one per source).
edges: HashMap<String, String>,
/// Conditional edges keyed by source node name.
conditional_edges: HashMap<String, ConditionalEdge<S>>,
/// Send edges keyed by source node name.
send_edges: HashMap<String, SendEdge<S>>,
/// Name of the entry node; `compile` fails while this is `None`.
entry_point: Option<String>,
/// Names of nodes registered as finish points; `compile` requires at least one.
finish_points: HashSet<String>,
/// Recursion limit passed to the compiled graph.
recursion_limit: usize,
/// Optional checkpointer passed to the compiled graph.
checkpointer: Option<Arc<dyn Checkpointer<S>>>,
/// Checkpoint granularity passed to the compiled graph.
checkpoint_granularity: CheckpointGranularity,
/// Node names registered via `interrupt_before`; each must be a registered node.
interrupt_before: HashSet<String>,
/// Node names registered via `interrupt_after`; each must be a registered node.
interrupt_after: HashSet<String>,
}
/// Diagnostic formatting with stable output: every name collection is printed
/// in sorted order, and the checkpointer is reduced to a `has_checkpointer`
/// flag (node runnables, selectors and the checkpointer itself are not printed).
impl<S> std::fmt::Debug for StateGraph<S>
where
    S: Clone + Send + Sync + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Collects borrowed names into a lexicographically ordered `Vec`.
        fn sorted<'a>(names: impl Iterator<Item = &'a String>) -> Vec<&'a String> {
            let mut ordered: Vec<&'a String> = names.collect();
            // Names come from map keys / set members, so they are unique and
            // an unstable sort yields the same order as a stable one.
            ordered.sort_unstable();
            ordered
        }

        // Static edges are printed as (source, target) pairs ordered by source.
        let mut static_edges: Vec<(&String, &String)> = self.edges.iter().collect();
        static_edges.sort_unstable_by(|lhs, rhs| lhs.0.cmp(rhs.0));

        let mut out = f.debug_struct("StateGraph");
        out.field("nodes", &sorted(self.nodes.keys()));
        out.field("edges", &static_edges);
        out.field("conditional_edges", &sorted(self.conditional_edges.keys()));
        out.field("send_edges", &sorted(self.send_edges.keys()));
        out.field("entry_point", &self.entry_point);
        out.field("finish_points", &sorted(self.finish_points.iter()));
        out.field("recursion_limit", &self.recursion_limit);
        out.field("has_checkpointer", &self.checkpointer.is_some());
        out.field("checkpoint_granularity", &self.checkpoint_granularity);
        out.field("interrupt_before", &sorted(self.interrupt_before.iter()));
        out.field("interrupt_after", &sorted(self.interrupt_after.iter()));
        out.finish()
    }
}
impl<S> StateGraph<S>
where
S: Clone + Send + Sync + 'static,
{
/// Creates an empty builder: no nodes, edges, entry point, finish points,
/// interrupt points or checkpointer; the recursion limit is
/// [`DEFAULT_RECURSION_LIMIT`] and the checkpoint granularity is the enum default.
pub fn new() -> Self {
    Self {
        // Settings with non-empty defaults.
        recursion_limit: DEFAULT_RECURSION_LIMIT,
        checkpoint_granularity: CheckpointGranularity::default(),
        // Optional pieces start unset.
        entry_point: None,
        checkpointer: None,
        // Topology starts empty.
        nodes: HashMap::default(),
        edges: HashMap::default(),
        conditional_edges: HashMap::default(),
        send_edges: HashMap::default(),
        finish_points: HashSet::default(),
        interrupt_before: HashSet::default(),
        interrupt_after: HashSet::default(),
    }
}
#[must_use]
pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer<S>>) -> Self {
self.checkpointer = Some(checkpointer);
self
}
/// Stores `g` as the builder's checkpoint granularity and returns the builder.
#[must_use]
pub const fn with_checkpoint_granularity(mut self, g: CheckpointGranularity) -> Self {
self.checkpoint_granularity = g;
self
}
/// Registers `runnable` as the node called `name` and returns the builder.
///
/// Nodes live in a map keyed by name, so registering a second node under an
/// existing name replaces the earlier one.
#[must_use]
pub fn add_node<R>(mut self, name: impl Into<String>, runnable: R) -> Self
where
    R: Runnable<S, S> + 'static,
{
    let key: String = name.into();
    let node: Arc<dyn Runnable<S, S>> = Arc::new(runnable);
    self.nodes.insert(key, node);
    self
}
/// Registers a node whose runnable produces a `U` rather than a full state.
///
/// `runnable` and `merger` (a `Fn(S, U) -> Result<S>`) are wrapped in a
/// [`MergeNodeAdapter`], which is then registered through `add_node`.
#[must_use]
pub fn add_node_with<R, U, F>(self, name: impl Into<String>, runnable: R, merger: F) -> Self
where
    R: Runnable<S, U> + 'static,
    U: Send + Sync + 'static,
    F: Fn(S, U) -> Result<S> + Send + Sync + 'static,
{
    let adapter = MergeNodeAdapter::new(runnable, merger);
    self.add_node(name, adapter)
}
/// Registers a node whose runnable produces an `S::Contribution`.
///
/// Requires `S: StateMerge`; the runnable is wrapped in a
/// [`ContributingNodeAdapter`] and registered through `add_node`.
#[must_use]
pub fn add_contributing_node<R>(self, name: impl Into<String>, runnable: R) -> Self
where
    R: Runnable<S, S::Contribution> + 'static,
    S: StateMerge,
{
    let adapter = ContributingNodeAdapter::new(runnable);
    self.add_node(name, adapter)
}
/// Records a static edge `from -> to` and returns the builder.
///
/// Static edges are keyed by source, so a second edge from the same source
/// replaces the first. Node names are only checked by `compile`.
#[must_use]
pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
    let source: String = from.into();
    let target: String = to.into();
    self.edges.insert(source, target);
    self
}
/// Records a conditional edge leaving `from` and returns the builder.
///
/// `selector` maps the state to a `String`; `mapping` is a set of
/// `(key, target)` pairs stored as a `String -> String` map (a repeated key
/// keeps the last pair). Targets are validated by `compile`, where each must
/// be a registered node or [`END`]. A second conditional edge from the same
/// source replaces the first.
///
/// NOTE(review): presumably the selector's return value is looked up in
/// `mapping` at run time — confirm against `CompiledGraph`.
#[must_use]
pub fn add_conditional_edges<F, K, V>(
    mut self,
    from: impl Into<String>,
    selector: F,
    mapping: impl IntoIterator<Item = (K, V)>,
) -> Self
where
    F: Fn(&S) -> String + Send + Sync + 'static,
    K: Into<String>,
    V: Into<String>,
{
    let mut routes: HashMap<String, String> = HashMap::new();
    for (key, target) in mapping {
        routes.insert(key.into(), target.into());
    }
    let selector: EdgeSelector<S> = Arc::new(selector);
    let source: String = from.into();
    self.conditional_edges.insert(
        source,
        ConditionalEdge {
            selector,
            mapping: routes,
        },
    );
    self
}
/// Records a send edge leaving `from` and returns the builder.
///
/// The edge is built with `SendEdge::new` from the declared `targets`, the
/// `selector` (state -> list of `(String, S)` pairs), `S`'s
/// [`StateMerge::merge`] as the merger, and the `join` name. `compile`
/// requires every target to be a registered node and `join` to be a
/// registered node or [`END`]. A second send edge from the same source
/// replaces the first.
///
/// NOTE(review): presumably each selector pair is (target node, state to
/// dispatch) and branch results are merged before continuing at `join` —
/// confirm against `CompiledGraph`.
#[must_use]
pub fn add_send_edges<F, I, T>(
    mut self,
    from: impl Into<String>,
    targets: I,
    selector: F,
    join: impl Into<String>,
) -> Self
where
    F: Fn(&S) -> Vec<(String, S)> + Send + Sync + 'static,
    I: IntoIterator<Item = T>,
    T: Into<String>,
    S: StateMerge,
{
    let declared_targets = targets.into_iter();
    let selector: SendSelector<S> = Arc::new(selector);
    let merger: SendMerger<S> = Arc::new(<S as StateMerge>::merge);
    let edge = SendEdge::new(
        declared_targets.map(Into::into),
        selector,
        merger,
        join.into(),
    );
    let source: String = from.into();
    self.send_edges.insert(source, edge);
    self
}
#[must_use]
pub fn set_entry_point(mut self, name: impl Into<String>) -> Self {
self.entry_point = Some(name.into());
self
}
/// Adds `name` to the set of finish points and returns the builder.
///
/// Adding the same name twice has no further effect. `compile` requires at
/// least one finish point and each must be a registered node.
#[must_use]
pub fn add_finish_point(mut self, name: impl Into<String>) -> Self {
    let finish: String = name.into();
    self.finish_points.insert(finish);
    self
}
/// Stores `n` as the builder's recursion limit (initially
/// [`DEFAULT_RECURSION_LIMIT`]) and returns the builder.
///
/// NOTE(review): no range check happens here or in `compile`; how a limit of
/// `0` behaves depends on the executor — confirm.
#[must_use]
pub const fn with_recursion_limit(mut self, n: usize) -> Self {
self.recursion_limit = n;
self
}
/// Adds every name in `nodes` to the `interrupt_before` set and returns the
/// builder. Calls accumulate; `compile` requires each name to be a registered
/// node.
///
/// NOTE(review): presumably execution pauses before these nodes run —
/// confirm against `CompiledGraph`.
#[must_use]
pub fn interrupt_before<I, T>(mut self, nodes: I) -> Self
where
    I: IntoIterator<Item = T>,
    T: Into<String>,
{
    for node in nodes {
        self.interrupt_before.insert(node.into());
    }
    self
}
/// Adds every name in `nodes` to the `interrupt_after` set and returns the
/// builder. Calls accumulate; `compile` requires each name to be a registered
/// node.
///
/// NOTE(review): presumably execution pauses after these nodes run —
/// confirm against `CompiledGraph`.
#[must_use]
pub fn interrupt_after<I, T>(mut self, nodes: I) -> Self
where
    I: IntoIterator<Item = T>,
    T: Into<String>,
{
    for node in nodes {
        self.interrupt_after.insert(node.into());
    }
    self
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn edge_count(&self) -> usize {
self.edges.len()
}
pub fn conditional_edge_count(&self) -> usize {
self.conditional_edges.len()
}
pub fn compile(self) -> Result<CompiledGraph<S>> {
let entry = self
.entry_point
.as_ref()
.ok_or_else(|| Error::config("StateGraph: no entry point set"))?
.clone();
if !self.nodes.contains_key(&entry) {
return Err(Error::config(format!(
"StateGraph: entry point '{entry}' is not a registered node"
)));
}
self.validate_finish_points()?;
self.validate_static_edges()?;
self.validate_conditional_edges()?;
let send_branch_targets = self.validate_send_edges()?;
self.validate_node_termination(&send_branch_targets)?;
self.validate_interrupt_points()?;
Ok(CompiledGraph::new(
self.nodes,
self.edges,
self.conditional_edges,
self.send_edges,
entry,
self.finish_points,
self.recursion_limit,
self.checkpointer,
self.checkpoint_granularity,
self.interrupt_before,
self.interrupt_after,
))
}
/// Checks that every `interrupt_before` / `interrupt_after` name refers to a
/// registered node.
///
/// `interrupt_before` is checked first. Within each set the lexicographically
/// smallest unknown name is reported, so the error does not depend on hash-set
/// iteration order.
///
/// # Errors
///
/// Returns a configuration error naming the offending interrupt point.
fn validate_interrupt_points(&self) -> Result<()> {
    let unknown_before = self
        .interrupt_before
        .iter()
        .filter(|name| !self.nodes.contains_key(*name))
        .min();
    if let Some(name) = unknown_before {
        return Err(Error::config(format!(
            "StateGraph: interrupt_before names '{name}' which is not a registered node"
        )));
    }

    let unknown_after = self
        .interrupt_after
        .iter()
        .filter(|name| !self.nodes.contains_key(*name))
        .min();
    if let Some(name) = unknown_after {
        return Err(Error::config(format!(
            "StateGraph: interrupt_after names '{name}' which is not a registered node"
        )));
    }

    Ok(())
}
fn validate_finish_points(&self) -> Result<()> {
if self.finish_points.is_empty() {
return Err(Error::config(
"StateGraph: no finish points registered (graph would never terminate)",
));
}
for fp in &self.finish_points {
if !self.nodes.contains_key(fp) {
return Err(Error::config(format!(
"StateGraph: finish point '{fp}' is not a registered node"
)));
}
}
Ok(())
}
/// Checks that both endpoints of every static edge are registered nodes.
///
/// Edges are visited in sorted order of their source name, and for each edge
/// the source is checked before the target, so the reported error does not
/// depend on hash-map iteration order.
///
/// # Errors
///
/// Returns a configuration error naming the first unknown source or target.
fn validate_static_edges(&self) -> Result<()> {
    let mut ordered: Vec<(&String, &String)> = self.edges.iter().collect();
    // Sources are map keys and therefore unique; tuple ordering sorts by source.
    ordered.sort_unstable();

    for (from, to) in ordered {
        if !self.nodes.contains_key(from) {
            return Err(Error::config(format!(
                "StateGraph: edge source '{from}' is not a registered node"
            )));
        }
        if !self.nodes.contains_key(to) {
            return Err(Error::config(format!(
                "StateGraph: edge target '{to}' is not a registered node"
            )));
        }
    }
    Ok(())
}
/// Checks every conditional edge: the source must be a registered node, must
/// not also have a static edge, and every mapping target must be a registered
/// node or [`END`].
///
/// Sources are visited in sorted order and, per source, the lexicographically
/// smallest invalid target is reported, so the error does not depend on
/// hash-map iteration order.
///
/// # Errors
///
/// Returns a configuration error describing the first violated rule.
fn validate_conditional_edges(&self) -> Result<()> {
    let mut ordered: Vec<_> = self.conditional_edges.iter().collect();
    ordered.sort_unstable_by(|lhs, rhs| lhs.0.cmp(rhs.0));

    for (from, cond) in ordered {
        if !self.nodes.contains_key(from) {
            return Err(Error::config(format!(
                "StateGraph: conditional edge source '{from}' is not a registered node"
            )));
        }
        if self.edges.contains_key(from) {
            return Err(Error::config(format!(
                "StateGraph: node '{from}' has both a static edge and a conditional edge \
                 — pick one"
            )));
        }

        let bad_target = cond
            .mapping
            .values()
            .filter(|target| target.as_str() != END && !self.nodes.contains_key(*target))
            .min();
        if let Some(target) = bad_target {
            return Err(Error::config(format!(
                "StateGraph: conditional edge from '{from}' maps to '{target}' which is \
                 neither a registered node nor END"
            )));
        }
    }
    Ok(())
}
/// Checks every send edge and returns the set of all declared send targets.
///
/// For each source: it must be a registered node, must not also have a static
/// or conditional edge, its join must be a registered node or [`END`], and
/// every declared target must be a registered node.
///
/// Sources are visited in sorted order so that source-level errors do not
/// depend on hash-map iteration order; targets are checked in whatever order
/// `SendEdge::targets` yields them.
///
/// # Errors
///
/// Returns a configuration error describing the first violated rule.
fn validate_send_edges(&self) -> Result<HashSet<String>> {
    let mut ordered: Vec<_> = self.send_edges.iter().collect();
    ordered.sort_unstable_by(|lhs, rhs| lhs.0.cmp(rhs.0));

    let mut send_branch_targets: HashSet<String> = HashSet::new();
    for (from, send) in ordered {
        if !self.nodes.contains_key(from) {
            return Err(Error::config(format!(
                "StateGraph: send edge source '{from}' is not a registered node"
            )));
        }
        if self.edges.contains_key(from) || self.conditional_edges.contains_key(from) {
            return Err(Error::config(format!(
                "StateGraph: node '{from}' has more than one outgoing edge type — \
                 send edges are mutually exclusive with static and conditional edges"
            )));
        }
        if send.join != END && !self.nodes.contains_key(&send.join) {
            return Err(Error::config(format!(
                "StateGraph: send edge from '{from}' joins on '{}' which is \
                 neither a registered node nor END",
                send.join
            )));
        }
        for target in send.targets() {
            if !self.nodes.contains_key(target) {
                return Err(Error::config(format!(
                    "StateGraph: send edge from '{from}' lists target '{target}' \
                     which is not a registered node"
                )));
            }
            // Collected so the termination check can exempt send branches.
            send_branch_targets.insert(target.clone());
        }
    }
    Ok(send_branch_targets)
}
fn validate_node_termination(&self, send_branch_targets: &HashSet<String>) -> Result<()> {
for name in self.nodes.keys() {
if !self.finish_points.contains(name)
&& !self.edges.contains_key(name)
&& !self.conditional_edges.contains_key(name)
&& !self.send_edges.contains_key(name)
&& !send_branch_targets.contains(name)
{
return Err(Error::config(format!(
"StateGraph: node '{name}' has no outgoing edge and is not a finish point"
)));
}
}
Ok(())
}
}
impl<S> Default for StateGraph<S>
where
S: Clone + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}