Skip to main content

graphmind/agent/
mod.rs

//! Agentic Enrichment
//!
//! Implements agents that can use tools to enrich the graph.
4
5pub mod tools;
6
7use crate::nlq::client::NLQClient;
8use crate::persistence::tenant::{AgentConfig, NLQConfig};
9use async_trait::async_trait;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use thiserror::Error;
14
15#[derive(Error, Debug)]
16pub enum AgentError {
17    #[error("Configuration error: {0}")]
18    ConfigError(String),
19    #[error("Tool error: {0}")]
20    ToolError(String),
21    #[error("LLM error: {0}")]
22    LLMError(String),
23    #[error("Execution error: {0}")]
24    ExecutionError(String),
25}
26
27pub type AgentResult<T> = Result<T, AgentError>;
28
29/// Trait for agent tools
30#[async_trait]
31pub trait Tool: Send + Sync {
32    fn name(&self) -> &str;
33    fn description(&self) -> &str;
34    fn parameters(&self) -> Value;
35    async fn execute(&self, args: Value) -> AgentResult<Value>;
36}
37
38/// Runtime for executing agents
39pub struct AgentRuntime {
40    config: AgentConfig,
41    tools: HashMap<String, Arc<dyn Tool>>,
42}
43
44impl AgentRuntime {
45    pub fn new(config: AgentConfig) -> Self {
46        Self {
47            config,
48            tools: HashMap::new(),
49        }
50    }
51
52    pub fn register_tool(&mut self, tool: Arc<dyn Tool>) {
53        self.tools.insert(tool.name().to_string(), tool);
54    }
55
56    /// Convert AgentConfig to NLQConfig for reusing the NLQ client
57    fn to_nlq_config(config: &AgentConfig) -> NLQConfig {
58        NLQConfig {
59            enabled: config.enabled,
60            provider: config.provider.clone(),
61            model: config.model.clone(),
62            api_key: config.api_key.clone(),
63            api_base_url: config.api_base_url.clone(),
64            system_prompt: config.system_prompt.clone(),
65        }
66    }
67
68    /// Process a trigger (e.g., "Enrich Company node X")
69    pub async fn process_trigger(&self, prompt: &str, _context: &str) -> AgentResult<String> {
70        let nlq_config = Self::to_nlq_config(&self.config);
71        let client =
72            NLQClient::new(&nlq_config).map_err(|e| AgentError::ConfigError(e.to_string()))?;
73        let response = client
74            .generate_cypher(prompt)
75            .await
76            .map_err(|e| AgentError::LLMError(e.to_string()))?;
77        Ok(response)
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistence::tenant::LLMProvider;

    /// Builds an enabled `AgentConfig` that targets the mock LLM provider, so
    /// no test reaches a real backend.
    fn mock_agent_config() -> AgentConfig {
        AgentConfig {
            enabled: true,
            provider: LLMProvider::Mock,
            model: "mock-model".to_string(),
            api_key: None,
            api_base_url: None,
            system_prompt: None,
            tools: vec![],
            policies: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_agent_runtime_new() {
        let runtime = AgentRuntime::new(mock_agent_config());
        assert!(runtime.tools.is_empty());
    }

    #[test]
    fn test_register_tool() {
        let mut runtime = AgentRuntime::new(mock_agent_config());

        // Create and register WebSearchTool
        let tool = Arc::new(tools::WebSearchTool::new("test-key".to_string()));
        runtime.register_tool(tool);
        assert_eq!(runtime.tools.len(), 1);
        assert!(runtime.tools.contains_key("web_search"));
    }

    #[test]
    fn test_register_tool_replaces_same_name() {
        let mut runtime = AgentRuntime::new(mock_agent_config());

        let first = Arc::new(tools::WebSearchTool::new("first-key".to_string()));
        let second = Arc::new(tools::WebSearchTool::new("second-key".to_string()));
        runtime.register_tool(first);
        runtime.register_tool(second);

        // Tools are keyed by `Tool::name`, so the second registration replaces
        // the first instead of adding a duplicate entry.
        assert_eq!(runtime.tools.len(), 1);
        assert!(runtime.tools.contains_key("web_search"));
    }

    #[test]
    fn test_to_nlq_config() {
        let config = mock_agent_config();
        let nlq_config = AgentRuntime::to_nlq_config(&config);
        assert!(nlq_config.enabled);
        assert_eq!(nlq_config.provider, LLMProvider::Mock);
        assert_eq!(nlq_config.model, "mock-model");
        // Optional settings must be carried over unchanged (here: all absent).
        assert!(nlq_config.api_key.is_none());
        assert!(nlq_config.api_base_url.is_none());
        assert!(nlq_config.system_prompt.is_none());
    }

    #[test]
    fn test_to_nlq_config_propagates_disabled_flag() {
        let mut config = mock_agent_config();
        config.enabled = false;
        assert!(!AgentRuntime::to_nlq_config(&config).enabled);
    }

    #[tokio::test]
    async fn test_process_trigger_mock() {
        let runtime = AgentRuntime::new(mock_agent_config());
        // `expect` reports the underlying AgentError if the call fails, which a
        // bare `assert!(result.is_ok())` would hide.
        let cypher = runtime
            .process_trigger("Find all persons", "context")
            .await
            .expect("mock provider should generate Cypher");
        // Mock returns "MATCH (n) RETURN n LIMIT 10"
        assert!(
            cypher.contains("MATCH"),
            "unexpected mock output: {}",
            cypher
        );
    }
}