kowalski_data_agent/
agent.rs1use crate::config::DataAgentConfig;
2use async_trait::async_trait;
3use kowalski_agent_template::TemplateAgent;
4use kowalski_agent_template::templates::general::GeneralTemplate;
5use kowalski_core::agent::Agent;
6use kowalski_core::config::Config;
7use kowalski_core::conversation::Conversation;
8use kowalski_core::error::KowalskiError;
9use kowalski_core::role::Role;
10use kowalski_core::tools::{Tool, ToolInput};
11use kowalski_tools::data::CsvTool;
12use serde::{Deserialize, Serialize};
13
14pub struct DataAgent {
17 agent: TemplateAgent,
18 config: DataAgentConfig,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct CsvAnalysisResult {
23 pub headers: Vec<String>,
24 pub total_rows: usize,
25 pub total_columns: usize,
26 pub summary: serde_json::Value,
27}
28
29impl DataAgent {
30 pub async fn new(_config: Config) -> Result<Self, KowalskiError> {
32 let data_config = DataAgentConfig::default();
34 let csv_tool = CsvTool::new(data_config.max_rows, data_config.max_columns);
35
36 let tools: Vec<Box<dyn Tool + Send + Sync>> = vec![Box::new(csv_tool)];
37 let builder = GeneralTemplate::create_agent(
38 tools,
39 Some("You are a data analysis assistant specialized in processing and analyzing structured data. You have access to the csv_tool. Use it to answer questions about data analysis.".to_string()),
40 Some(0.7),
41 )
42 .await
43 .map_err(|e| KowalskiError::Configuration(e.to_string()))?;
44 let agent = builder.build().await?;
45
46 Ok(Self {
47 agent,
48 config: data_config,
49 })
50 }
51
52 pub async fn process_csv(&self, csv_content: &str) -> Result<CsvAnalysisResult, KowalskiError> {
54 let mut tools = self.agent.tool_chain.write().await;
55 let tool = tools.iter_mut().find(|t| t.name() == "csv_tool");
56 let tool = match tool {
57 Some(t) => t,
58 None => {
59 return Err(KowalskiError::ToolExecution(
60 "csv_tool not found".to_string(),
61 ));
62 }
63 };
64 let input = ToolInput::new(
65 "process_csv".to_string(),
66 csv_content.to_string(),
67 serde_json::json!({
68 "max_rows": self.config.max_rows,
69 "max_columns": self.config.max_columns
70 }),
71 );
72 let output = tool.execute(input).await?;
73
74 let result = output.result;
75 Ok(CsvAnalysisResult {
76 headers: result["headers"]
77 .as_array()
78 .unwrap_or(&Vec::new())
79 .iter()
80 .filter_map(|v| v.as_str().map(|s| s.to_string()))
81 .collect(),
82 total_rows: result["total_rows"].as_u64().unwrap_or_default() as usize,
83 total_columns: result["total_columns"].as_u64().unwrap_or_default() as usize,
84 summary: result["summary"].clone(),
85 })
86 }
87
88 pub async fn analyze_data(
90 &self,
91 csv_content: &str,
92 ) -> Result<serde_json::Value, KowalskiError> {
93 let mut tools = self.agent.tool_chain.write().await;
94 let tool = tools.iter_mut().find(|t| t.name() == "csv_tool");
95 let tool = match tool {
96 Some(t) => t,
97 None => {
98 return Err(KowalskiError::ToolExecution(
99 "csv_tool not found".to_string(),
100 ));
101 }
102 };
103 let input = ToolInput::new(
104 "analyze_csv".to_string(),
105 csv_content.to_string(),
106 serde_json::json!({}),
107 );
108 let output = tool.execute(input).await?;
109 Ok(output.result)
110 }
111}
112
113#[async_trait]
114impl Agent for DataAgent {
115 async fn new(config: Config) -> Result<Self, KowalskiError> {
116 DataAgent::new(config).await
117 }
118
119 fn start_conversation(&mut self, model: &str) -> String {
120 self.agent.base_mut().start_conversation(model)
121 }
122
123 fn get_conversation(&self, id: &str) -> Option<&Conversation> {
124 self.agent.base().get_conversation(id)
125 }
126
127 fn list_conversations(&self) -> Vec<&Conversation> {
128 self.agent.base().list_conversations()
129 }
130
131 fn delete_conversation(&mut self, id: &str) -> bool {
132 self.agent.base_mut().delete_conversation(id)
133 }
134
135 async fn chat_with_history(
136 &mut self,
137 conversation_id: &str,
138 content: &str,
139 role: Option<Role>,
140 ) -> Result<reqwest::Response, KowalskiError> {
141 self.agent
142 .base_mut()
143 .chat_with_history(conversation_id, content, role)
144 .await
145 }
146
147 async fn process_stream_response(
148 &mut self,
149 conversation_id: &str,
150 chunk: &[u8],
151 ) -> Result<Option<kowalski_core::conversation::Message>, KowalskiError> {
152 self.agent
153 .base_mut()
154 .process_stream_response(conversation_id, chunk)
155 .await
156 }
157
158 async fn add_message(&mut self, conversation_id: &str, role: &str, content: &str) {
159 self.agent
160 .base_mut()
161 .add_message(conversation_id, role, content)
162 .await;
163 }
164
165 fn name(&self) -> &str {
166 self.agent.base().name()
167 }
168
169 fn description(&self) -> &str {
170 self.agent.base().description()
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use kowalski_core::config::Config;

    /// Building an agent from the default configuration must succeed.
    #[tokio::test]
    async fn test_data_agent_creation() {
        let created = DataAgent::new(Config::default()).await;
        assert!(created.is_ok());
    }

    /// Starting a conversation must yield a non-empty identifier.
    #[tokio::test]
    async fn test_data_agent_conversation() {
        let mut data_agent = DataAgent::new(Config::default())
            .await
            .expect("DataAgent should build from the default config");
        let conversation_id = data_agent.start_conversation("test-model");
        assert!(!conversation_id.is_empty());
    }
}