1use crate::error::Result;
7use agentdb::AgentDB;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15#[serde(rename_all = "lowercase")]
16pub enum ToolCallStatus {
17 Pending,
18 Success,
19 Error,
20}
21
22impl ToString for ToolCallStatus {
23 fn to_string(&self) -> String {
24 match self {
25 ToolCallStatus::Pending => "pending".to_string(),
26 ToolCallStatus::Success => "success".to_string(),
27 ToolCallStatus::Error => "error".to_string(),
28 }
29 }
30}
31
32impl From<&str> for ToolCallStatus {
33 fn from(s: &str) -> Self {
34 match s {
35 "success" => ToolCallStatus::Success,
36 "error" => ToolCallStatus::Error,
37 _ => ToolCallStatus::Pending,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ToolCall {
45 pub id: i64,
46 pub name: String,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 pub parameters: Option<serde_json::Value>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 pub result: Option<serde_json::Value>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub error: Option<String>,
53 pub status: ToolCallStatus,
54 pub started_at: i64,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub completed_at: Option<i64>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub duration_ms: Option<i64>,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ToolCallStats {
64 pub name: String,
65 pub total_calls: i64,
66 pub successful: i64,
67 pub failed: i64,
68 pub avg_duration_ms: f64,
69}
70
71#[async_trait]
73pub trait ToolRecorder: Send + Sync {
74 async fn start(&self, name: &str, parameters: Option<serde_json::Value>) -> Result<i64>;
77
78 async fn success(&self, id: i64, result: Option<serde_json::Value>) -> Result<()>;
80
81 async fn error(&self, id: i64, error: &str) -> Result<()>;
83
84 async fn get(&self, id: i64) -> Result<Option<ToolCall>>;
86
87 async fn stats_for(&self, tool_name: &str) -> Result<Option<ToolCallStats>>;
89
90 async fn record(
94 &self,
95 name: &str,
96 started_at: i64,
97 completed_at: i64,
98 parameters: Option<serde_json::Value>,
99 result: Option<serde_json::Value>,
100 error: Option<&str>,
101 ) -> Result<i64>;
102
103 async fn list(&self, limit: Option<usize>) -> Result<Vec<ToolCall>>;
105}
106
107pub struct DbToolRecorder {
109 db: Arc<Box<dyn AgentDB>>,
110}
111
112impl DbToolRecorder {
113 pub fn new(db: Arc<Box<dyn AgentDB>>) -> Self {
115 Self { db }
116 }
117
118 fn now() -> i64 {
120 SystemTime::now()
121 .duration_since(UNIX_EPOCH)
122 .unwrap()
123 .as_secs() as i64
124 }
125
126 fn parse_tool_call(&self, row: &agentdb::Row) -> Result<ToolCall> {
128 let id = self.extract_i64(row, "id")?;
129 let name = self.extract_string(row, "name")?;
130
131 let parameters_str = self.extract_string_opt(row, "parameters")?;
132 let parameters = parameters_str
133 .filter(|s| !s.is_empty())
134 .and_then(|s| serde_json::from_str(&s).ok());
135
136 let result_str = self.extract_string_opt(row, "result")?;
137 let result = result_str
138 .filter(|s| !s.is_empty())
139 .and_then(|s| serde_json::from_str(&s).ok());
140
141 let error = self.extract_string_opt(row, "error")?
142 .filter(|s| !s.is_empty());
143
144 let status_str = self.extract_string(row, "status")?;
145 let status = ToolCallStatus::from(status_str.as_str());
146
147 let started_at = self.extract_i64(row, "started_at")?;
148 let completed_at = self.extract_i64_opt(row, "completed_at")?;
149 let duration_ms = self.extract_i64_opt(row, "duration_ms")?;
150
151 Ok(ToolCall {
152 id,
153 name,
154 parameters,
155 result,
156 error,
157 status,
158 started_at,
159 completed_at,
160 duration_ms,
161 })
162 }
163
164 fn extract_i64(&self, row: &agentdb::Row, column: &str) -> Result<i64> {
166 row.get(column)
167 .ok_or_else(|| crate::error::AgentFsError::Database(
168 agentdb::AgentDbError::Backend(format!("Missing column: {}", column))
169 ))
170 .and_then(|v| {
171 let s = String::from_utf8_lossy(v.as_bytes());
172 s.parse::<i64>().map_err(|e| {
173 crate::error::AgentFsError::Database(
174 agentdb::AgentDbError::Backend(format!("Invalid i64 for {}: {}", column, e))
175 )
176 })
177 })
178 }
179
180 fn extract_i64_opt(&self, row: &agentdb::Row, column: &str) -> Result<Option<i64>> {
182 match row.get(column) {
183 None => Ok(None),
184 Some(v) => {
185 if v.as_bytes().is_empty() {
187 return Ok(None);
188 }
189 let s = String::from_utf8_lossy(v.as_bytes());
190 if s.is_empty() || s == "NULL" {
191 Ok(None)
192 } else {
193 s.parse::<i64>()
194 .map(Some)
195 .map_err(|e| crate::error::AgentFsError::Database(
196 agentdb::AgentDbError::Backend(format!("Invalid i64 for {}: {}", column, e))
197 ))
198 }
199 }
200 }
201 }
202
203 fn extract_string(&self, row: &agentdb::Row, column: &str) -> Result<String> {
205 row.get(column)
206 .ok_or_else(|| crate::error::AgentFsError::Database(
207 agentdb::AgentDbError::Backend(format!("Missing column: {}", column))
208 ))
209 .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string())
210 }
211
212 fn extract_string_opt(&self, row: &agentdb::Row, column: &str) -> Result<Option<String>> {
214 Ok(row.get(column).and_then(|v| {
215 if v.as_bytes().is_empty() {
217 None
218 } else {
219 Some(String::from_utf8_lossy(v.as_bytes()).to_string())
220 }
221 }))
222 }
223}
224
225#[async_trait]
226impl ToolRecorder for DbToolRecorder {
227 async fn start(&self, name: &str, parameters: Option<serde_json::Value>) -> Result<i64> {
228 let serialized_params = parameters
229 .map(|p| serde_json::to_string(&p))
230 .transpose()?
231 .unwrap_or_default();
232
233 let started_at = Self::now();
234
235 let query = format!(
236 "INSERT INTO tool_calls (name, parameters, status, started_at) VALUES ('{}', '{}', 'pending', {})",
237 name.replace('\'', "''"),
238 serialized_params.replace('\'', "''"),
239 started_at
240 );
241
242 self.db.query(&query, vec![]).await?;
243
244 let result = self.db.query(
247 "SELECT id FROM tool_calls WHERE rowid = last_insert_rowid()",
248 vec![]
249 ).await?;
250
251 if let Some(row) = result.rows.first() {
252 self.extract_i64(row, "id")
253 } else {
254 let result = self.db.query("SELECT MAX(id) as id FROM tool_calls", vec![]).await?;
256 if let Some(row) = result.rows.first() {
257 self.extract_i64(row, "id")
258 } else {
259 Err(crate::error::AgentFsError::Database(
260 agentdb::AgentDbError::Backend("Failed to get tool call ID".to_string())
261 ))
262 }
263 }
264 }
265
266 async fn success(&self, id: i64, result: Option<serde_json::Value>) -> Result<()> {
267 let serialized_result = result
268 .map(|r| serde_json::to_string(&r))
269 .transpose()?
270 .unwrap_or_default();
271
272 let completed_at = Self::now();
273
274 let query = format!("SELECT started_at FROM tool_calls WHERE id = {}", id);
276 let res = self.db.query(&query, vec![]).await?;
277
278 let started_at = if let Some(row) = res.rows.first() {
279 self.extract_i64(row, "started_at")?
280 } else {
281 return Err(crate::error::AgentFsError::Database(
282 agentdb::AgentDbError::Backend("Tool call not found".to_string())
283 ));
284 };
285
286 let duration_ms = (completed_at - started_at) * 1000;
287
288 let query = format!(
289 "UPDATE tool_calls SET result = '{}', status = 'success', completed_at = {}, duration_ms = {} WHERE id = {}",
290 serialized_result.replace('\'', "''"),
291 completed_at,
292 duration_ms,
293 id
294 );
295
296 self.db.query(&query, vec![]).await?;
297 Ok(())
298 }
299
300 async fn error(&self, id: i64, error: &str) -> Result<()> {
301 let completed_at = Self::now();
302
303 let query = format!("SELECT started_at FROM tool_calls WHERE id = {}", id);
305 let res = self.db.query(&query, vec![]).await?;
306
307 let started_at = if let Some(row) = res.rows.first() {
308 self.extract_i64(row, "started_at")?
309 } else {
310 return Err(crate::error::AgentFsError::Database(
311 agentdb::AgentDbError::Backend("Tool call not found".to_string())
312 ));
313 };
314
315 let duration_ms = (completed_at - started_at) * 1000;
316
317 let query = format!(
318 "UPDATE tool_calls SET error = '{}', status = 'error', completed_at = {}, duration_ms = {} WHERE id = {}",
319 error.replace('\'', "''"),
320 completed_at,
321 duration_ms,
322 id
323 );
324
325 self.db.query(&query, vec![]).await?;
326 Ok(())
327 }
328
329 async fn get(&self, id: i64) -> Result<Option<ToolCall>> {
330 let query = format!(
331 "SELECT id, name, parameters, result, error, status, started_at, completed_at, duration_ms FROM tool_calls WHERE id = {}",
332 id
333 );
334
335 let result = self.db.query(&query, vec![]).await?;
336
337 if let Some(row) = result.rows.first() {
338 Ok(Some(self.parse_tool_call(row)?))
339 } else {
340 Ok(None)
341 }
342 }
343
344 async fn stats_for(&self, tool_name: &str) -> Result<Option<ToolCallStats>> {
345 let query = format!(
346 "SELECT
347 COUNT(*) as total_calls,
348 SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as successful,
349 SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as failed,
350 AVG(CASE WHEN duration_ms IS NOT NULL THEN duration_ms ELSE 0 END) as avg_duration_ms
351 FROM tool_calls
352 WHERE name = '{}'",
353 tool_name.replace('\'', "''")
354 );
355
356 let result = self.db.query(&query, vec![]).await?;
357
358 if let Some(row) = result.rows.first() {
359 let total_calls = self.extract_i64(row, "total_calls")?;
360
361 if total_calls == 0 {
362 return Ok(None);
363 }
364
365 let successful = self.extract_i64(row, "successful")?;
366 let failed = self.extract_i64(row, "failed")?;
367
368 let avg_duration_str = self.extract_string(row, "avg_duration_ms")?;
369 let avg_duration_ms = avg_duration_str.parse::<f64>().unwrap_or(0.0);
370
371 Ok(Some(ToolCallStats {
372 name: tool_name.to_string(),
373 total_calls,
374 successful,
375 failed,
376 avg_duration_ms,
377 }))
378 } else {
379 Ok(None)
380 }
381 }
382
383 async fn record(
384 &self,
385 name: &str,
386 started_at: i64,
387 completed_at: i64,
388 parameters: Option<serde_json::Value>,
389 result: Option<serde_json::Value>,
390 error: Option<&str>,
391 ) -> Result<i64> {
392 let serialized_params = parameters
393 .map(|p| serde_json::to_string(&p))
394 .transpose()?
395 .unwrap_or_default();
396
397 let serialized_result = result
398 .map(|r| serde_json::to_string(&r))
399 .transpose()?
400 .unwrap_or_default();
401
402 let duration_ms = (completed_at - started_at) * 1000;
403 let status = if error.is_some() { "error" } else { "success" };
404
405 let query = format!(
406 "INSERT INTO tool_calls (name, parameters, result, error, status, started_at, completed_at, duration_ms)
407 VALUES ('{}', '{}', '{}', '{}', '{}', {}, {}, {})",
408 name.replace('\'', "''"),
409 serialized_params.replace('\'', "''"),
410 serialized_result.replace('\'', "''"),
411 error.unwrap_or("").replace('\'', "''"),
412 status,
413 started_at,
414 completed_at,
415 duration_ms
416 );
417
418 self.db.query(&query, vec![]).await?;
419
420 let result = self.db.query(
422 "SELECT id FROM tool_calls WHERE rowid = last_insert_rowid()",
423 vec![]
424 ).await?;
425
426 if let Some(row) = result.rows.first() {
427 self.extract_i64(row, "id")
428 } else {
429 let result = self.db.query("SELECT MAX(id) as id FROM tool_calls", vec![]).await?;
431 if let Some(row) = result.rows.first() {
432 self.extract_i64(row, "id")
433 } else {
434 Err(crate::error::AgentFsError::Database(
435 agentdb::AgentDbError::Backend("Failed to get tool call ID".to_string())
436 ))
437 }
438 }
439 }
440
441 async fn list(&self, limit: Option<usize>) -> Result<Vec<ToolCall>> {
442 let limit_clause = limit
443 .map(|l| format!(" LIMIT {}", l))
444 .unwrap_or_default();
445
446 let query = format!(
447 "SELECT id, name, parameters, result, error, status, started_at, completed_at, duration_ms
448 FROM tool_calls
449 ORDER BY started_at DESC{}",
450 limit_clause
451 );
452
453 let result = self.db.query(&query, vec![]).await?;
454
455 let mut tool_calls = Vec::new();
456 for row in &result.rows {
457 tool_calls.push(self.parse_tool_call(row)?);
458 }
459
460 Ok(tool_calls)
461 }
462}