//! enact-core 0.0.2
//!
//! Core agent runtime for Enact - Graph-Native AI agents.
//!
//! GraphLoader - loads graph definitions from YAML
//!
//! Supports provider injection for LLM nodes through `GraphLoaderContext`.

use super::node::LlmNode;
use super::schema::{GraphDefinition, NodeDefinition};
use super::{EdgeTarget, NodeState, StateGraph};
use crate::providers::ModelProvider;
use crate::routing::{resolve_model_precedence, DEFAULT_MODEL_ROUTER_ID};
use anyhow::{anyhow, Context, Result};
use std::sync::Arc;

/// Context for loading graphs with runtime dependencies.
///
/// Provides optional access to model providers for creating functional LLM nodes.
/// When no provider is set, LLM nodes are created as placeholders that log an
/// error and fail with `Err` when invoked (they never make actual LLM calls).
#[derive(Default, Clone)]
pub struct GraphLoaderContext {
    /// Model provider used to build functional LLM nodes; `None` yields
    /// error-returning placeholder nodes instead.
    pub provider: Option<Arc<dyn ModelProvider>>,
    /// Agent-level default model pin; lowest-priority input to
    /// `resolve_model_precedence` (node-level and graph-level pins win).
    pub default_model: Option<String>,
}

impl GraphLoaderContext {
    /// Create a new empty context
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a context with a model provider
    pub fn with_provider(provider: Arc<dyn ModelProvider>) -> Self {
        Self {
            provider: Some(provider),
            default_model: None,
        }
    }

    /// Create a context with provider and agent-level default model.
    pub fn with_provider_and_model(
        provider: Arc<dyn ModelProvider>,
        default_model: impl Into<String>,
    ) -> Self {
        Self {
            provider: Some(provider),
            default_model: Some(default_model.into()),
        }
    }
}

/// Stateless loader that turns YAML graph definitions into [`StateGraph`]s.
pub struct GraphLoader;

impl GraphLoader {
    /// Load a graph from a YAML string without runtime context.
    ///
    /// LLM nodes will be created as placeholders that log an error and fail
    /// when invoked. Use [`GraphLoader::load_from_str_with_context`] to supply
    /// a model provider for functional LLM nodes.
    ///
    /// # Errors
    /// Fails if the YAML cannot be parsed or contains an unsupported node type.
    pub fn load_from_str(yaml: &str) -> Result<StateGraph> {
        Self::load_from_str_with_context(yaml, &GraphLoaderContext::default())
    }

    /// Load a graph from a YAML string with runtime context.
    ///
    /// If the context contains a model provider, LLM nodes are created as
    /// functional nodes that execute actual LLM calls; otherwise they become
    /// placeholders that return an error at execution time.
    ///
    /// # Errors
    /// Fails if the YAML cannot be parsed into a `GraphDefinition`, or if the
    /// definition contains a node type this loader does not support.
    pub fn load_from_str_with_context(yaml: &str, ctx: &GraphLoaderContext) -> Result<StateGraph> {
        let def: GraphDefinition =
            serde_yaml::from_str(yaml).context("Failed to parse graph definition YAML")?;

        let mut graph = StateGraph::new();

        // 1. Add all nodes
        for (name, node_def) in &def.nodes {
            match node_def {
                NodeDefinition::Llm {
                    model,
                    system_prompt,
                    ..
                } => {
                    // Precedence: node-level pin > graph-level pin > agent-level
                    // default carried in the context.
                    let (resolved_model, selection_source) = resolve_model_precedence(
                        model.as_deref(),
                        def.model.as_deref(),
                        ctx.default_model.as_deref(),
                    );

                    if let Some(provider) = &ctx.provider {
                        // Create functional LLM node with real provider
                        tracing::debug!(
                            node = %name,
                            selected_model = %resolved_model,
                            source = ?selection_source,
                            "Resolved LLM node model using precedence"
                        );

                        let llm_node = LlmNode::with_model(
                            name.clone(),
                            system_prompt.clone(),
                            resolved_model,
                            provider.clone(),
                        );
                        graph = graph.add_node_impl(llm_node);
                    } else {
                        // No provider configured: install a placeholder that
                        // fails loudly at execution time instead of silently
                        // producing nothing.
                        let name_clone = name.clone();
                        let model = if resolved_model.is_empty() {
                            DEFAULT_MODEL_ROUTER_ID.to_string()
                        } else {
                            resolved_model
                        };
                        let prompt = system_prompt.clone();

                        graph = graph.add_node(name, move |_state: NodeState| {
                            let n = name_clone.clone();
                            let m = model.clone();
                            let p = prompt.clone();
                            async move {
                                tracing::error!(
                                    node = %n,
                                    model = %m,
                                    prompt = %p,
                                    "LLM node requires a model provider but none was configured"
                                );
                                Err(anyhow!(
                                    "LLM node '{}' requires a model provider. \
                                     Use GraphLoaderContext::with_provider() when loading the graph.",
                                    n
                                ))
                            }
                        });
                    }
                }
                NodeDefinition::Function { action, .. } => {
                    let name_clone = name.clone();
                    let action = action.clone();

                    graph = graph.add_node(name, move |state: NodeState| {
                        let n = name_clone.clone();
                        let a = action.clone();
                        async move {
                            println!("⚙️ [Function Node: {}] Action: {}", n, a);
                            // In a real implementation, this would execute the
                            // action command. For now, allow simple "echo" for
                            // testing. `strip_prefix` removes exactly one
                            // "echo " prefix; the previous `trim_start_matches`
                            // stripped it repeatedly, mangling payloads such as
                            // "echo echo hi" (yielded "hi" instead of "echo hi").
                            if let Some(output) = a.strip_prefix("echo ") {
                                return Ok(NodeState::from_string(output));
                            }
                            // Non-echo actions pass the state through unchanged.
                            Ok(state)
                        }
                    });
                }
                NodeDefinition::Condition { expr, .. } => {
                    let name_clone = name.clone();
                    let expr = expr.clone();

                    // Condition node evaluates its expression and returns a
                    // result key, which is then matched against this node's
                    // edge keys by the conditional router below.
                    graph = graph.add_node(name, move |state: NodeState| {
                        let n = name_clone.clone();
                        let e = expr.clone();
                        async move {
                            println!("❓ [Condition Node: {}] Expr: {}", n, e);
                            // Simple mock evaluation: only the
                            // "contains('error')" expression form is supported.
                            // If the input contains "error", emit "error",
                            // otherwise "ok"; unknown expressions emit "default".
                            let input = state.as_str().unwrap_or("");
                            if e.contains("contains('error')") {
                                if input.contains("error") {
                                    return Ok(NodeState::from_string("error"));
                                } else {
                                    return Ok(NodeState::from_string("ok"));
                                }
                            }
                            Ok(NodeState::from_string("default"))
                        }
                    });
                }
                _ => {
                    return Err(anyhow!("Unsupported node type in yaml"));
                }
            }
        }

        // 2. Add edges
        for (name, node_def) in &def.nodes {
            let edges = node_def.edges();

            // Condition nodes get a conditional (routed) edge keyed on the
            // node's output string; all other nodes follow their "_default"
            // edge. The schema implies simple string matching on output.
            let is_conditional = matches!(node_def, NodeDefinition::Condition { .. });

            if is_conditional {
                // Conditional edges: route by output key, fall back to
                // "_default", and end the graph when nothing matches.
                let edges_clone = edges.clone();
                let router = move |output: &str| -> EdgeTarget {
                    if let Some(target) = edges_clone.get(output) {
                        if target == "END" {
                            EdgeTarget::End
                        } else {
                            EdgeTarget::Node(target.clone())
                        }
                    } else if let Some(default) = edges_clone.get("_default") {
                        if default == "END" {
                            EdgeTarget::End
                        } else {
                            EdgeTarget::Node(default.clone())
                        }
                    } else {
                        EdgeTarget::End
                    }
                };

                graph = graph.add_conditional_edge(name, router);
            } else {
                // Standard edges
                // TODO: Support branching from non-condition nodes?
                // For now, assume "_default" is the main edge
                if let Some(target) = edges.get("_default") {
                    if target == "END" {
                        graph = graph.add_edge_to_end(name);
                    } else {
                        graph = graph.add_edge(name, target);
                    }
                }
            }
        }

        // 3. Set entry point.
        // The schema does not declare one explicitly and `def.nodes` is an
        // unordered map, so "first node in the YAML" is not reliable.
        // Convention: use a node named "start", then "input", if present;
        // otherwise leave StateGraph's own default in place.
        if def.nodes.contains_key("start") {
            graph = graph.set_entry_point("start");
        } else if def.nodes.contains_key("input") {
            graph = graph.set_entry_point("input");
        }

        Ok(graph)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatMessage, ChatRequest, ChatResponse};
    use async_trait::async_trait;

    /// Mock provider for testing LLM node creation.
    ///
    /// Ignores the request and always replies with the canned `response`.
    struct MockProvider {
        // Text returned as the assistant message of every chat() call.
        response: String,
    }

    impl MockProvider {
        fn new(response: impl Into<String>) -> Self {
            Self {
                response: response.into(),
            }
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        // Returns a single-choice response wrapping the canned text.
        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            Ok(ChatResponse {
                id: "mock-id".to_string(),
                choices: vec![ChatChoice {
                    index: 0,
                    message: ChatMessage::assistant(&self.response),
                    finish_reason: Some("stop".to_string()),
                }],
                usage: None,
            })
        }
    }

    // Minimal graph: one LLM node named "start" (also the entry point by
    // loader convention) that goes straight to END.
    const SIMPLE_GRAPH_YAML: &str = r#"
name: test-graph
version: "1.0"
nodes:
  start:
    type: llm
    system_prompt: "You are a helpful assistant"
    edges:
      _default: END
"#;

    // Without a provider the LLM node still loads (as a placeholder).
    #[test]
    fn test_load_without_context_creates_placeholder() {
        let graph = GraphLoader::load_from_str(SIMPLE_GRAPH_YAML).unwrap();
        assert!(graph.nodes.contains_key("start"));
    }

    // With a provider the LLM node loads as a functional node.
    #[test]
    fn test_load_with_context_creates_functional_node() {
        let provider = Arc::new(MockProvider::new("Hello!"));
        let ctx = GraphLoaderContext::with_provider(provider);

        let graph = GraphLoader::load_from_str_with_context(SIMPLE_GRAPH_YAML, &ctx).unwrap();
        assert!(graph.nodes.contains_key("start"));
    }

    // End-to-end: compiling and running the graph surfaces the mock
    // provider's canned response as the final state.
    #[tokio::test]
    async fn test_functional_llm_node_executes() {
        let provider = Arc::new(MockProvider::new("LLM Response"));
        let ctx = GraphLoaderContext::with_provider(provider);

        let graph = GraphLoader::load_from_str_with_context(SIMPLE_GRAPH_YAML, &ctx).unwrap();
        let compiled = graph.compile().unwrap();

        // Execute the graph
        let result = compiled.run("User input").await.unwrap();

        assert_eq!(result.as_str(), Some("LLM Response"));
    }

    // Constructors: new() is empty; with_provider() sets only the provider.
    #[test]
    fn test_context_builder() {
        let ctx = GraphLoaderContext::new();
        assert!(ctx.provider.is_none());
        assert!(ctx.default_model.is_none());

        let provider = Arc::new(MockProvider::new("test"));
        let ctx = GraphLoaderContext::with_provider(provider);
        assert!(ctx.provider.is_some());
        assert!(ctx.default_model.is_none());
    }

    // with_provider_and_model() sets both provider and default model pin.
    #[test]
    fn test_context_builder_with_default_model() {
        let provider = Arc::new(MockProvider::new("test"));
        let ctx = GraphLoaderContext::with_provider_and_model(provider, "agent/default-model");
        assert!(ctx.provider.is_some());
        assert_eq!(ctx.default_model.as_deref(), Some("agent/default-model"));
    }

    // Three-node graph exercising all supported node types:
    // llm -> condition -> (function on "error" | END on "ok").
    const MULTI_NODE_YAML: &str = r#"
name: multi-node-graph
version: "1.0"
nodes:
  start:
    type: llm
    model: gpt-4
    system_prompt: "Process the input"
    edges:
      _default: check
  check:
    type: condition
    expr: "input.contains('error')"
    edges:
      error: handle_error
      ok: END
  handle_error:
    type: function
    action: "echo error handled"
    edges:
      _default: END
"#;

    // All three node kinds load and are registered under their YAML names.
    #[test]
    fn test_multi_node_graph_loading() {
        let provider = Arc::new(MockProvider::new("processed"));
        let ctx = GraphLoaderContext::with_provider(provider);

        let graph = GraphLoader::load_from_str_with_context(MULTI_NODE_YAML, &ctx).unwrap();

        assert!(graph.nodes.contains_key("start"));
        assert!(graph.nodes.contains_key("check"));
        assert!(graph.nodes.contains_key("handle_error"));
    }
}
}