1pub mod tools;
6
7use crate::nlq::client::NLQClient;
8use crate::persistence::tenant::{AgentConfig, NLQConfig};
9use async_trait::async_trait;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use thiserror::Error;
14
15#[derive(Error, Debug)]
16pub enum AgentError {
17 #[error("Configuration error: {0}")]
18 ConfigError(String),
19 #[error("Tool error: {0}")]
20 ToolError(String),
21 #[error("LLM error: {0}")]
22 LLMError(String),
23 #[error("Execution error: {0}")]
24 ExecutionError(String),
25}
26
27pub type AgentResult<T> = Result<T, AgentError>;
28
29#[async_trait]
31pub trait Tool: Send + Sync {
32 fn name(&self) -> &str;
33 fn description(&self) -> &str;
34 fn parameters(&self) -> Value;
35 async fn execute(&self, args: Value) -> AgentResult<Value>;
36}
37
38pub struct AgentRuntime {
40 config: AgentConfig,
41 tools: HashMap<String, Arc<dyn Tool>>,
42}
43
44impl AgentRuntime {
45 pub fn new(config: AgentConfig) -> Self {
46 Self {
47 config,
48 tools: HashMap::new(),
49 }
50 }
51
52 pub fn register_tool(&mut self, tool: Arc<dyn Tool>) {
53 self.tools.insert(tool.name().to_string(), tool);
54 }
55
56 fn to_nlq_config(config: &AgentConfig) -> NLQConfig {
58 NLQConfig {
59 enabled: config.enabled,
60 provider: config.provider.clone(),
61 model: config.model.clone(),
62 api_key: config.api_key.clone(),
63 api_base_url: config.api_base_url.clone(),
64 system_prompt: config.system_prompt.clone(),
65 }
66 }
67
68 pub async fn process_trigger(&self, prompt: &str, _context: &str) -> AgentResult<String> {
70 let nlq_config = Self::to_nlq_config(&self.config);
71 let client =
72 NLQClient::new(&nlq_config).map_err(|e| AgentError::ConfigError(e.to_string()))?;
73 let response = client
74 .generate_cypher(prompt)
75 .await
76 .map_err(|e| AgentError::LLMError(e.to_string()))?;
77 Ok(response)
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistence::tenant::LLMProvider;

    /// Builds an enabled `AgentConfig` targeting the mock LLM provider, with no
    /// credentials, base URL, system prompt, tools, or policies.
    fn mock_agent_config() -> AgentConfig {
        AgentConfig {
            enabled: true,
            provider: LLMProvider::Mock,
            model: "mock-model".to_string(),
            api_key: None,
            api_base_url: None,
            system_prompt: None,
            tools: vec![],
            policies: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_agent_runtime_new() {
        let runtime = AgentRuntime::new(mock_agent_config());
        assert!(runtime.tools.is_empty());
    }

    #[test]
    fn test_register_tool() {
        let mut runtime = AgentRuntime::new(mock_agent_config());

        let tool = Arc::new(tools::WebSearchTool::new("test-key".to_string()));
        runtime.register_tool(tool);
        assert_eq!(runtime.tools.len(), 1);
        assert!(runtime.tools.contains_key("web_search"));
    }

    #[test]
    fn test_register_tool_replaces_same_name() {
        let mut runtime = AgentRuntime::new(mock_agent_config());

        runtime.register_tool(Arc::new(tools::WebSearchTool::new("first-key".to_string())));
        runtime.register_tool(Arc::new(tools::WebSearchTool::new("second-key".to_string())));

        // Both tools report the same name, so the second registration replaces
        // the first instead of adding a new entry.
        assert_eq!(runtime.tools.len(), 1);
        assert!(runtime.tools.contains_key("web_search"));
    }

    #[test]
    fn test_to_nlq_config() {
        let config = mock_agent_config();
        let nlq_config = AgentRuntime::to_nlq_config(&config);
        assert!(nlq_config.enabled);
        assert_eq!(nlq_config.provider, LLMProvider::Mock);
        assert_eq!(nlq_config.model, "mock-model");
        // Unset optional settings must stay unset after conversion.
        assert!(nlq_config.api_key.is_none());
        assert!(nlq_config.api_base_url.is_none());
        assert!(nlq_config.system_prompt.is_none());
    }

    #[tokio::test]
    async fn test_process_trigger_mock() {
        let runtime = AgentRuntime::new(mock_agent_config());
        // `expect` reports the underlying AgentError if the call fails, which
        // `is_ok()` + `unwrap()` would hide.
        let cypher = runtime
            .process_trigger("Find all persons", "context")
            .await
            .expect("mock provider should generate a Cypher query");
        assert!(cypher.contains("MATCH"));
    }
}