enact-core 0.0.2

Core agent runtime for Enact - Graph-Native AI agents
Documentation
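//! StateGraph - builder for creating graphs
//!
//! Includes cycle detection over direct edges to prevent infinite loops (DAG invariant).
//! Conditional edges choose their target at runtime, so they are not covered by this check.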

use super::edge::{ConditionalEdge, Edge, EdgeTarget};
use super::node::{DynNode, FunctionNode, Node, NodeState};
use super::CompiledGraph;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;

/// StateGraph - fluent builder for creating DAGs
///
/// Nodes, edges, and the entry point are collected by the consuming builder
/// methods and validated when the graph is compiled.
pub struct StateGraph {
    /// Registered nodes, keyed by node name.
    pub nodes: HashMap<String, DynNode>,
    /// Direct edges, in the order they were added.
    pub edges: Vec<Edge>,
    /// Edges whose target is chosen at runtime by a router function.
    pub conditional_edges: Vec<ConditionalEdge>,
    /// Entry point set via `set_entry_point`; when `None`, compilation chooses one itself.
    pub entry_point: Option<String>,
}

impl StateGraph {
    /// Create a new, empty graph with no nodes, edges, or entry point.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            edges: Vec::new(),
            conditional_edges: Vec::new(),
            entry_point: None,
        }
    }

    /// Add a node backed by an async function.
    ///
    /// `func` receives the current [`NodeState`] and resolves to the updated state.
    /// Registering a second node under an existing name replaces the earlier node.
    pub fn add_node<F, Fut>(mut self, name: impl Into<String>, func: F) -> Self
    where
        F: Fn(NodeState) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = anyhow::Result<NodeState>> + Send + 'static,
    {
        let name = name.into();
        let node = Arc::new(FunctionNode::new(name.clone(), func));
        self.nodes.insert(name, node);
        self
    }

    /// Add a pre-built node, registered under the name reported by [`Node::name`].
    ///
    /// Registering a second node under an existing name replaces the earlier node.
    pub fn add_node_impl(mut self, node: impl Node + 'static) -> Self {
        let name = node.name().to_string();
        self.nodes.insert(name, Arc::new(node));
        self
    }

    /// Add a direct edge between two nodes.
    ///
    /// Both endpoints are checked for existence by [`StateGraph::compile`].
    pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
        self.edges.push(Edge::new(from, EdgeTarget::node(to)));
        self
    }

    /// Add an edge from `from` to END.
    pub fn add_edge_to_end(mut self, from: impl Into<String>) -> Self {
        self.edges.push(Edge::new(from, EdgeTarget::End));
        self
    }

    /// Add a conditional edge whose target is chosen at runtime by `router`.
    ///
    /// `router` maps a `&str` supplied during execution to the next [`EdgeTarget`].
    /// Because the target is only known at runtime, conditional edges are not
    /// part of the compile-time cycle check.
    pub fn add_conditional_edge<F>(mut self, from: impl Into<String>, router: F) -> Self
    where
        F: Fn(&str) -> EdgeTarget + Send + Sync + 'static,
    {
        self.conditional_edges.push(ConditionalEdge {
            from: from.into(),
            router: Arc::new(router),
        });
        self
    }

    /// Set the entry point node explicitly.
    pub fn set_entry_point(mut self, name: impl Into<String>) -> Self {
        self.entry_point = Some(name.into());
        self
    }

    /// Validate the graph and compile it for execution.
    ///
    /// If no entry point was set, one is inferred: a lone node is used directly.
    /// Otherwise the unique node that no direct edge points to is used.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the graph has no nodes;
    /// - no entry point was set and none can be inferred unambiguously;
    /// - the entry point, a direct-edge source or target, or a conditional-edge
    ///   source names a node that does not exist;
    /// - the direct edges form a cycle (conditional edges are not checked, since
    ///   their targets are only known at runtime).
    pub fn compile(self) -> anyhow::Result<CompiledGraph> {
        // Validate: must have at least one node
        if self.nodes.is_empty() {
            anyhow::bail!("Graph must have at least one node");
        }

        // Determine entry point (explicit, or inferred deterministically)
        let entry_point = match &self.entry_point {
            Some(name) => name.clone(),
            None => self.infer_entry_point()?,
        };

        // Validate: entry point must exist
        if !self.nodes.contains_key(&entry_point) {
            anyhow::bail!("Entry point '{}' does not exist", entry_point);
        }

        // Validate: all edge sources/targets exist
        for edge in &self.edges {
            if !self.nodes.contains_key(&edge.from) {
                anyhow::bail!("Edge source '{}' does not exist", edge.from);
            }
            if let EdgeTarget::Node(ref target) = edge.to {
                if !self.nodes.contains_key(target) {
                    anyhow::bail!("Edge target '{}' does not exist", target);
                }
            }
        }

        // Validate: conditional edges must leave from a real node, like direct edges
        for conditional in &self.conditional_edges {
            if !self.nodes.contains_key(&conditional.from) {
                anyhow::bail!(
                    "Conditional edge source '{}' does not exist",
                    conditional.from
                );
            }
        }

        // Validate: no cycles among direct edges (DAG invariant)
        let adjacency = self.build_adjacency_list();
        if let Some(cycle) = self.detect_cycle(&adjacency, &entry_point) {
            anyhow::bail!(
                "Graph contains a cycle: {}. Cycles are not allowed in DAGs.",
                cycle.join(" -> ")
            );
        }

        Ok(CompiledGraph {
            nodes: self.nodes,
            edges: self.edges,
            conditional_edges: self.conditional_edges,
            entry_point,
        })
    }

    /// Infer an entry point when none was set explicitly.
    ///
    /// `HashMap` iteration order is arbitrary, so "the first node" is not a
    /// meaningful or reproducible choice. Instead, a lone node is used as-is.
    /// Otherwise the single node that no direct edge points to is chosen.
    /// Zero or several such candidates yield an error asking for an explicit
    /// entry point.
    fn infer_entry_point(&self) -> anyhow::Result<String> {
        // A single node is unambiguous (even with a self-loop, which the
        // cycle check reports afterwards).
        if self.nodes.len() == 1 {
            if let Some(only) = self.nodes.keys().next() {
                return Ok(only.clone());
            }
        }

        // Any node that a direct edge points at cannot be a root.
        let targeted: HashSet<&str> = self
            .edges
            .iter()
            .filter_map(|edge| match edge.to {
                EdgeTarget::Node(ref target) => Some(target.as_str()),
                _ => None,
            })
            .collect();

        let mut roots: Vec<&str> = self
            .nodes
            .keys()
            .map(String::as_str)
            .filter(|name| !targeted.contains(name))
            .collect();
        // Sorted so the error message is stable across runs.
        roots.sort_unstable();

        match roots.as_slice() {
            [only] => Ok((*only).to_owned()),
            [] => Err(anyhow::anyhow!(
                "No entry point defined and none could be inferred: every node has an incoming edge; call set_entry_point"
            )),
            _ => Err(anyhow::anyhow!(
                "No entry point defined and it is ambiguous (candidates: {}); call set_entry_point",
                roots.join(", ")
            )),
        }
    }

    /// Build an adjacency list (node -> direct successor nodes) from the direct edges.
    ///
    /// Every node gets an entry, even if it has no outgoing edges. Edges to END
    /// are omitted because they cannot take part in a cycle.
    fn build_adjacency_list(&self) -> HashMap<String, Vec<String>> {
        let mut adjacency: HashMap<String, Vec<String>> = self
            .nodes
            .keys()
            .map(|name| (name.clone(), Vec::new()))
            .collect();

        for edge in &self.edges {
            if let EdgeTarget::Node(ref target) = edge.to {
                adjacency
                    .entry(edge.from.clone())
                    .or_default()
                    .push(target.clone());
            }
        }

        adjacency
    }

    /// Detect a cycle using DFS.
    ///
    /// The search starts at `entry_point`. It then sweeps all remaining nodes in
    /// sorted order so that disconnected components are checked too and the
    /// reported cycle is deterministic.
    ///
    /// Returns `Some(cycle)` listing the nodes of the cycle with the first
    /// node repeated at the end (e.g. `["B", "C", "B"]`), or `None` if the
    /// direct edges are acyclic.
    fn detect_cycle(
        &self,
        adjacency: &HashMap<String, Vec<String>>,
        entry_point: &str,
    ) -> Option<Vec<String>> {
        let mut visited: HashSet<String> = HashSet::new();
        let mut rec_stack: HashSet<String> = HashSet::new();
        let mut path: Vec<String> = Vec::new();

        let mut others: Vec<&str> = self
            .nodes
            .keys()
            .map(String::as_str)
            .filter(|name| *name != entry_point)
            .collect();
        others.sort_unstable();

        for start in std::iter::once(entry_point).chain(others) {
            if visited.contains(start) {
                continue;
            }
            path.clear();
            if Self::dfs_cycle_detect(start, adjacency, &mut visited, &mut rec_stack, &mut path) {
                // `path` runs from `start` down to the node that closed the cycle,
                // followed by the revisited node. Drop the prefix that leads into the cycle.
                let begin = path
                    .last()
                    .and_then(|closing| path.iter().position(|name| name == closing))
                    .unwrap_or(0);
                return Some(path.split_off(begin));
            }
        }

        None
    }

    /// DFS helper for cycle detection.
    ///
    /// `rec_stack` holds the nodes on the current DFS path. Reaching one of them
    /// again means a back edge, i.e. a cycle. On success `path` is left holding the
    /// current DFS path plus the revisited node. Otherwise `path` and `rec_stack` are
    /// restored to their state before the call.
    fn dfs_cycle_detect(
        node: &str,
        adjacency: &HashMap<String, Vec<String>>,
        visited: &mut HashSet<String>,
        rec_stack: &mut HashSet<String>,
        path: &mut Vec<String>,
    ) -> bool {
        visited.insert(node.to_string());
        rec_stack.insert(node.to_string());
        path.push(node.to_string());

        for neighbor in adjacency.get(node).into_iter().flatten() {
            if rec_stack.contains(neighbor) {
                // Back edge: found a cycle
                path.push(neighbor.clone());
                return true;
            }
            if !visited.contains(neighbor)
                && Self::dfs_cycle_detect(neighbor, adjacency, visited, rec_stack, path)
            {
                return true;
            }
        }

        rec_stack.remove(node);
        path.pop();
        false
    }
}

impl Default for StateGraph {
    fn default() -> Self {
        Self::new()
    }
}