kowalski_data_agent/
agent.rs

1use crate::config::DataAgentConfig;
2use async_trait::async_trait;
3use kowalski_agent_template::TemplateAgent;
4use kowalski_agent_template::templates::general::GeneralTemplate;
5use kowalski_core::agent::Agent;
6use kowalski_core::config::Config;
7use kowalski_core::conversation::Conversation;
8use kowalski_core::error::KowalskiError;
9use kowalski_core::role::Role;
10use kowalski_core::tools::{Tool, ToolInput};
11use kowalski_tools::data::CsvTool;
12use serde::{Deserialize, Serialize};
13
14/// DataAgent: A specialized agent for data analysis and processing tasks
15/// This agent is built on top of the TemplateAgent and provides data-specific functionality
16pub struct DataAgent {
17    agent: TemplateAgent,
18    config: DataAgentConfig,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct CsvAnalysisResult {
23    pub headers: Vec<String>,
24    pub total_rows: usize,
25    pub total_columns: usize,
26    pub summary: serde_json::Value,
27}
28
29impl DataAgent {
30    /// Creates a new DataAgent with the specified configuration
31    pub async fn new(_config: Config) -> Result<Self, KowalskiError> {
32        // TODO: Convert Config to DataAgentConfig if needed
33        let data_config = DataAgentConfig::default();
34        let csv_tool = CsvTool::new(data_config.max_rows, data_config.max_columns);
35
36        let tools: Vec<Box<dyn Tool + Send + Sync>> = vec![Box::new(csv_tool)];
37        let builder = GeneralTemplate::create_agent(
38            tools,
39            Some("You are a data analysis assistant specialized in processing and analyzing structured data. You have access to the csv_tool. Use it to answer questions about data analysis.".to_string()),
40            Some(0.7),
41        )
42        .await
43        .map_err(|e| KowalskiError::Configuration(e.to_string()))?;
44        let agent = builder.build().await?;
45
46        Ok(Self {
47            agent,
48            config: data_config,
49        })
50    }
51
52    /// Processes a CSV file
53    pub async fn process_csv(&self, csv_content: &str) -> Result<CsvAnalysisResult, KowalskiError> {
54        let mut tools = self.agent.tool_chain.write().await;
55        let tool = tools.iter_mut().find(|t| t.name() == "csv_tool");
56        let tool = match tool {
57            Some(t) => t,
58            None => {
59                return Err(KowalskiError::ToolExecution(
60                    "csv_tool not found".to_string(),
61                ));
62            }
63        };
64        let input = ToolInput::new(
65            "process_csv".to_string(),
66            csv_content.to_string(),
67            serde_json::json!({
68                "max_rows": self.config.max_rows,
69                "max_columns": self.config.max_columns
70            }),
71        );
72        let output = tool.execute(input).await?;
73
74        let result = output.result;
75        Ok(CsvAnalysisResult {
76            headers: result["headers"]
77                .as_array()
78                .unwrap_or(&Vec::new())
79                .iter()
80                .filter_map(|v| v.as_str().map(|s| s.to_string()))
81                .collect(),
82            total_rows: result["total_rows"].as_u64().unwrap_or_default() as usize,
83            total_columns: result["total_columns"].as_u64().unwrap_or_default() as usize,
84            summary: result["summary"].clone(),
85        })
86    }
87
88    /// Analyzes data statistics
89    pub async fn analyze_data(
90        &self,
91        csv_content: &str,
92    ) -> Result<serde_json::Value, KowalskiError> {
93        let mut tools = self.agent.tool_chain.write().await;
94        let tool = tools.iter_mut().find(|t| t.name() == "csv_tool");
95        let tool = match tool {
96            Some(t) => t,
97            None => {
98                return Err(KowalskiError::ToolExecution(
99                    "csv_tool not found".to_string(),
100                ));
101            }
102        };
103        let input = ToolInput::new(
104            "analyze_csv".to_string(),
105            csv_content.to_string(),
106            serde_json::json!({}),
107        );
108        let output = tool.execute(input).await?;
109        Ok(output.result)
110    }
111}
112
113#[async_trait]
114impl Agent for DataAgent {
115    async fn new(config: Config) -> Result<Self, KowalskiError> {
116        DataAgent::new(config).await
117    }
118
119    fn start_conversation(&mut self, model: &str) -> String {
120        self.agent.base_mut().start_conversation(model)
121    }
122
123    fn get_conversation(&self, id: &str) -> Option<&Conversation> {
124        self.agent.base().get_conversation(id)
125    }
126
127    fn list_conversations(&self) -> Vec<&Conversation> {
128        self.agent.base().list_conversations()
129    }
130
131    fn delete_conversation(&mut self, id: &str) -> bool {
132        self.agent.base_mut().delete_conversation(id)
133    }
134
135    async fn chat_with_history(
136        &mut self,
137        conversation_id: &str,
138        content: &str,
139        role: Option<Role>,
140    ) -> Result<reqwest::Response, KowalskiError> {
141        self.agent
142            .base_mut()
143            .chat_with_history(conversation_id, content, role)
144            .await
145    }
146
147    async fn process_stream_response(
148        &mut self,
149        conversation_id: &str,
150        chunk: &[u8],
151    ) -> Result<Option<kowalski_core::conversation::Message>, KowalskiError> {
152        self.agent
153            .base_mut()
154            .process_stream_response(conversation_id, chunk)
155            .await
156    }
157
158    async fn add_message(&mut self, conversation_id: &str, role: &str, content: &str) {
159        self.agent
160            .base_mut()
161            .add_message(conversation_id, role, content)
162            .await;
163    }
164
165    fn name(&self) -> &str {
166        self.agent.base().name()
167    }
168
169    fn description(&self) -> &str {
170        self.agent.base().description()
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use kowalski_core::config::Config;

    /// Constructing an agent from the default configuration must succeed.
    #[tokio::test]
    async fn test_data_agent_creation() {
        let outcome = DataAgent::new(Config::default()).await;
        assert!(outcome.is_ok());
    }

    /// Starting a conversation must yield a non-empty conversation id.
    #[tokio::test]
    async fn test_data_agent_conversation() {
        let mut data_agent = DataAgent::new(Config::default()).await.unwrap();
        let conversation_id = data_agent.start_conversation("test-model");
        assert!(!conversation_id.is_empty());
    }
}