//! enact-core 0.0.2
//!
//! Core agent runtime for Enact — Graph-Native AI agents.
//!
//! LLM Node - wraps LlmCallable for use in graphs

use super::{Node, NodeState};
use crate::callable::LlmCallable;
use crate::providers::ModelProvider;
use async_trait::async_trait;
use std::sync::Arc;

/// LLM Node - executes an LLM callable within a graph
///
/// This node wraps an `LlmCallable` and adapts it to the `Node` interface.
/// The input `NodeState` is converted to a prompt, and the LLM response
/// is converted back to a `NodeState` for downstream nodes.
pub struct LlmNode {
    // Node identifier reported via `Node::name`; also used as the callable's name.
    name: String,
    // The wrapped LLM callable that performs the actual model invocation.
    callable: LlmCallable,
}

impl LlmNode {
    /// Construct an LLM node backed by the given provider.
    ///
    /// The node's name doubles as the name of the underlying callable.
    pub fn new(
        name: impl Into<String>,
        system_prompt: impl Into<String>,
        provider: Arc<dyn ModelProvider>,
    ) -> Self {
        let node_name = name.into();
        Self {
            callable: LlmCallable::with_provider(node_name.clone(), system_prompt, provider),
            name: node_name,
        }
    }

    /// Create with explicit model pin.
    pub fn with_model(
        name: impl Into<String>,
        system_prompt: impl Into<String>,
        model: impl Into<String>,
        provider: Arc<dyn ModelProvider>,
    ) -> Self {
        // Reuse the primary constructor, then pin the requested model
        // on the freshly built callable.
        let mut node = Self::new(name, system_prompt, provider);
        node.callable = node.callable.with_model(model);
        node
    }

    /// Attach tools to the underlying callable (builder-style, consumes self).
    pub fn with_tools(mut self, tools: Vec<crate::tool::DynTool>) -> Self {
        self.callable = self.callable.add_tools(tools);
        self
    }
}

#[async_trait]
impl Node for LlmNode {
    fn name(&self) -> &str {
        &self.name
    }

    async fn execute(&self, state: NodeState) -> anyhow::Result<NodeState> {
        use crate::callable::Callable;

        // Flatten the incoming state into a plain prompt string:
        // null → empty prompt, strings pass through, anything else is
        // serialized to its JSON text form.
        let prompt = match &state.data {
            serde_json::Value::Null => String::new(),
            serde_json::Value::String(text) => text.clone(),
            value => serde_json::to_string(value)?,
        };

        // Run the wrapped callable and wrap its reply for downstream nodes.
        let reply = self.callable.run(&prompt).await?;
        Ok(NodeState::from_string(&reply))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{ChatChoice, ChatMessage, ChatRequest, ChatResponse};
    use async_trait::async_trait;

    /// Test double that echoes back a canned assistant message.
    struct MockProvider {
        canned: String,
    }

    impl MockProvider {
        fn new(canned: impl Into<String>) -> Self {
            Self {
                canned: canned.into(),
            }
        }
    }

    #[async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(&self, _request: ChatRequest) -> anyhow::Result<ChatResponse> {
            // Single-choice response with a "stop" finish reason and no usage stats.
            let choice = ChatChoice {
                index: 0,
                message: ChatMessage::assistant(&self.canned),
                finish_reason: Some("stop".to_string()),
            };
            Ok(ChatResponse {
                id: "mock-id".to_string(),
                choices: vec![choice],
                usage: None,
            })
        }
    }

    #[tokio::test]
    async fn test_llm_node_execute() {
        // String input should pass straight through to the callable.
        let node = LlmNode::new(
            "test_node",
            "You are a helpful assistant",
            Arc::new(MockProvider::new("Hello, world!")),
        );
        assert_eq!(node.name(), "test_node");

        let output = node
            .execute(NodeState::from_string("Say hello"))
            .await
            .unwrap();
        assert_eq!(output.as_str(), Some("Hello, world!"));
    }

    #[tokio::test]
    async fn test_llm_node_with_json_input() {
        // Non-string state data is serialized to JSON before being sent.
        let node = LlmNode::new(
            "json_node",
            "Process the input",
            Arc::new(MockProvider::new("Processed JSON")),
        );

        let output = node
            .execute(NodeState::from_value(serde_json::json!({"key": "value"})))
            .await
            .unwrap();
        assert_eq!(output.as_str(), Some("Processed JSON"));
    }

    #[tokio::test]
    async fn test_llm_node_with_empty_input() {
        // A null/empty state becomes an empty prompt and still succeeds.
        let node = LlmNode::new(
            "empty_node",
            "Handle empty input",
            Arc::new(MockProvider::new("Default response")),
        );

        let output = node.execute(NodeState::new()).await.unwrap();
        assert_eq!(output.as_str(), Some("Default response"));
    }
}