agentfs/
tools.rs

1//! Tool call recording and auditing
2//!
3//! This module provides functionality for recording and tracking tool calls made by AI agents.
4//! It supports both a workflow-based API (start -> success/error) and a single-shot record API.
5
6use crate::error::Result;
7use agentdb::AgentDB;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13/// Status of a tool call
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15#[serde(rename_all = "lowercase")]
16pub enum ToolCallStatus {
17    Pending,
18    Success,
19    Error,
20}
21
22impl ToString for ToolCallStatus {
23    fn to_string(&self) -> String {
24        match self {
25            ToolCallStatus::Pending => "pending".to_string(),
26            ToolCallStatus::Success => "success".to_string(),
27            ToolCallStatus::Error => "error".to_string(),
28        }
29    }
30}
31
32impl From<&str> for ToolCallStatus {
33    fn from(s: &str) -> Self {
34        match s {
35            "success" => ToolCallStatus::Success,
36            "error" => ToolCallStatus::Error,
37            _ => ToolCallStatus::Pending,
38        }
39    }
40}
41
42/// Tool call record
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ToolCall {
45    pub id: i64,
46    pub name: String,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub parameters: Option<serde_json::Value>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub result: Option<serde_json::Value>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub error: Option<String>,
53    pub status: ToolCallStatus,
54    pub started_at: i64,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub completed_at: Option<i64>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub duration_ms: Option<i64>,
59}
60
61/// Statistics for a specific tool
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ToolCallStats {
64    pub name: String,
65    pub total_calls: i64,
66    pub successful: i64,
67    pub failed: i64,
68    pub avg_duration_ms: f64,
69}
70
71/// Tool recorder trait for auditing agent tool calls
72#[async_trait]
73pub trait ToolRecorder: Send + Sync {
74    /// Start a new tool call and mark it as pending
75    /// Returns the ID of the created tool call record
76    async fn start(&self, name: &str, parameters: Option<serde_json::Value>) -> Result<i64>;
77
78    /// Mark a tool call as successful
79    async fn success(&self, id: i64, result: Option<serde_json::Value>) -> Result<()>;
80
81    /// Mark a tool call as failed
82    async fn error(&self, id: i64, error: &str) -> Result<()>;
83
84    /// Get a specific tool call by ID
85    async fn get(&self, id: i64) -> Result<Option<ToolCall>>;
86
87    /// Get statistics for a specific tool
88    async fn stats_for(&self, tool_name: &str) -> Result<Option<ToolCallStats>>;
89
90    /// Record a completed tool call (single-shot method)
91    /// Either result or error should be provided, not both
92    /// Returns the ID of the created tool call record
93    async fn record(
94        &self,
95        name: &str,
96        started_at: i64,
97        completed_at: i64,
98        parameters: Option<serde_json::Value>,
99        result: Option<serde_json::Value>,
100        error: Option<&str>,
101    ) -> Result<i64>;
102
103    /// Get all tool calls (optionally limited)
104    async fn list(&self, limit: Option<usize>) -> Result<Vec<ToolCall>>;
105}
106
107/// Database-backed tool recorder
108pub struct DbToolRecorder {
109    db: Arc<Box<dyn AgentDB>>,
110}
111
112impl DbToolRecorder {
113    /// Create a new database-backed tool recorder
114    pub fn new(db: Arc<Box<dyn AgentDB>>) -> Self {
115        Self { db }
116    }
117
118    /// Get current Unix timestamp in seconds
119    fn now() -> i64 {
120        SystemTime::now()
121            .duration_since(UNIX_EPOCH)
122            .unwrap()
123            .as_secs() as i64
124    }
125
126    /// Parse a tool call from a database row
127    fn parse_tool_call(&self, row: &agentdb::Row) -> Result<ToolCall> {
128        let id = self.extract_i64(row, "id")?;
129        let name = self.extract_string(row, "name")?;
130
131        let parameters_str = self.extract_string_opt(row, "parameters")?;
132        let parameters = parameters_str
133            .filter(|s| !s.is_empty())
134            .and_then(|s| serde_json::from_str(&s).ok());
135
136        let result_str = self.extract_string_opt(row, "result")?;
137        let result = result_str
138            .filter(|s| !s.is_empty())
139            .and_then(|s| serde_json::from_str(&s).ok());
140
141        let error = self.extract_string_opt(row, "error")?
142            .filter(|s| !s.is_empty());
143
144        let status_str = self.extract_string(row, "status")?;
145        let status = ToolCallStatus::from(status_str.as_str());
146
147        let started_at = self.extract_i64(row, "started_at")?;
148        let completed_at = self.extract_i64_opt(row, "completed_at")?;
149        let duration_ms = self.extract_i64_opt(row, "duration_ms")?;
150
151        Ok(ToolCall {
152            id,
153            name,
154            parameters,
155            result,
156            error,
157            status,
158            started_at,
159            completed_at,
160            duration_ms,
161        })
162    }
163
164    /// Extract an i64 from a row
165    fn extract_i64(&self, row: &agentdb::Row, column: &str) -> Result<i64> {
166        row.get(column)
167            .ok_or_else(|| crate::error::AgentFsError::Database(
168                agentdb::AgentDbError::Backend(format!("Missing column: {}", column))
169            ))
170            .and_then(|v| {
171                let s = String::from_utf8_lossy(v.as_bytes());
172                s.parse::<i64>().map_err(|e| {
173                    crate::error::AgentFsError::Database(
174                        agentdb::AgentDbError::Backend(format!("Invalid i64 for {}: {}", column, e))
175                    )
176                })
177            })
178    }
179
180    /// Extract an optional i64 from a row
181    fn extract_i64_opt(&self, row: &agentdb::Row, column: &str) -> Result<Option<i64>> {
182        match row.get(column) {
183            None => Ok(None),
184            Some(v) => {
185                // Empty bytes mean NULL
186                if v.as_bytes().is_empty() {
187                    return Ok(None);
188                }
189                let s = String::from_utf8_lossy(v.as_bytes());
190                if s.is_empty() || s == "NULL" {
191                    Ok(None)
192                } else {
193                    s.parse::<i64>()
194                        .map(Some)
195                        .map_err(|e| crate::error::AgentFsError::Database(
196                            agentdb::AgentDbError::Backend(format!("Invalid i64 for {}: {}", column, e))
197                        ))
198                }
199            }
200        }
201    }
202
203    /// Extract a String from a row
204    fn extract_string(&self, row: &agentdb::Row, column: &str) -> Result<String> {
205        row.get(column)
206            .ok_or_else(|| crate::error::AgentFsError::Database(
207                agentdb::AgentDbError::Backend(format!("Missing column: {}", column))
208            ))
209            .map(|v| String::from_utf8_lossy(v.as_bytes()).to_string())
210    }
211
212    /// Extract an optional String from a row
213    fn extract_string_opt(&self, row: &agentdb::Row, column: &str) -> Result<Option<String>> {
214        Ok(row.get(column).and_then(|v| {
215            // Empty bytes mean NULL
216            if v.as_bytes().is_empty() {
217                None
218            } else {
219                Some(String::from_utf8_lossy(v.as_bytes()).to_string())
220            }
221        }))
222    }
223}
224
225#[async_trait]
226impl ToolRecorder for DbToolRecorder {
227    async fn start(&self, name: &str, parameters: Option<serde_json::Value>) -> Result<i64> {
228        let serialized_params = parameters
229            .map(|p| serde_json::to_string(&p))
230            .transpose()?
231            .unwrap_or_default();
232
233        let started_at = Self::now();
234
235        let query = format!(
236            "INSERT INTO tool_calls (name, parameters, status, started_at) VALUES ('{}', '{}', 'pending', {})",
237            name.replace('\'', "''"),
238            serialized_params.replace('\'', "''"),
239            started_at
240        );
241
242        self.db.query(&query, vec![]).await?;
243
244        // Get the ID of the just-inserted row using rowid
245        // This works across SQLite, PostgreSQL (with oid), and MySQL
246        let result = self.db.query(
247            "SELECT id FROM tool_calls WHERE rowid = last_insert_rowid()",
248            vec![]
249        ).await?;
250
251        if let Some(row) = result.rows.first() {
252            self.extract_i64(row, "id")
253        } else {
254            // Fallback: get MAX(id) which should be the just-inserted row
255            let result = self.db.query("SELECT MAX(id) as id FROM tool_calls", vec![]).await?;
256            if let Some(row) = result.rows.first() {
257                self.extract_i64(row, "id")
258            } else {
259                Err(crate::error::AgentFsError::Database(
260                    agentdb::AgentDbError::Backend("Failed to get tool call ID".to_string())
261                ))
262            }
263        }
264    }
265
266    async fn success(&self, id: i64, result: Option<serde_json::Value>) -> Result<()> {
267        let serialized_result = result
268            .map(|r| serde_json::to_string(&r))
269            .transpose()?
270            .unwrap_or_default();
271
272        let completed_at = Self::now();
273
274        // Get the started_at time to calculate duration
275        let query = format!("SELECT started_at FROM tool_calls WHERE id = {}", id);
276        let res = self.db.query(&query, vec![]).await?;
277
278        let started_at = if let Some(row) = res.rows.first() {
279            self.extract_i64(row, "started_at")?
280        } else {
281            return Err(crate::error::AgentFsError::Database(
282                agentdb::AgentDbError::Backend("Tool call not found".to_string())
283            ));
284        };
285
286        let duration_ms = (completed_at - started_at) * 1000;
287
288        let query = format!(
289            "UPDATE tool_calls SET result = '{}', status = 'success', completed_at = {}, duration_ms = {} WHERE id = {}",
290            serialized_result.replace('\'', "''"),
291            completed_at,
292            duration_ms,
293            id
294        );
295
296        self.db.query(&query, vec![]).await?;
297        Ok(())
298    }
299
300    async fn error(&self, id: i64, error: &str) -> Result<()> {
301        let completed_at = Self::now();
302
303        // Get the started_at time to calculate duration
304        let query = format!("SELECT started_at FROM tool_calls WHERE id = {}", id);
305        let res = self.db.query(&query, vec![]).await?;
306
307        let started_at = if let Some(row) = res.rows.first() {
308            self.extract_i64(row, "started_at")?
309        } else {
310            return Err(crate::error::AgentFsError::Database(
311                agentdb::AgentDbError::Backend("Tool call not found".to_string())
312            ));
313        };
314
315        let duration_ms = (completed_at - started_at) * 1000;
316
317        let query = format!(
318            "UPDATE tool_calls SET error = '{}', status = 'error', completed_at = {}, duration_ms = {} WHERE id = {}",
319            error.replace('\'', "''"),
320            completed_at,
321            duration_ms,
322            id
323        );
324
325        self.db.query(&query, vec![]).await?;
326        Ok(())
327    }
328
329    async fn get(&self, id: i64) -> Result<Option<ToolCall>> {
330        let query = format!(
331            "SELECT id, name, parameters, result, error, status, started_at, completed_at, duration_ms FROM tool_calls WHERE id = {}",
332            id
333        );
334
335        let result = self.db.query(&query, vec![]).await?;
336
337        if let Some(row) = result.rows.first() {
338            Ok(Some(self.parse_tool_call(row)?))
339        } else {
340            Ok(None)
341        }
342    }
343
344    async fn stats_for(&self, tool_name: &str) -> Result<Option<ToolCallStats>> {
345        let query = format!(
346            "SELECT
347                COUNT(*) as total_calls,
348                SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as successful,
349                SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as failed,
350                AVG(CASE WHEN duration_ms IS NOT NULL THEN duration_ms ELSE 0 END) as avg_duration_ms
351            FROM tool_calls
352            WHERE name = '{}'",
353            tool_name.replace('\'', "''")
354        );
355
356        let result = self.db.query(&query, vec![]).await?;
357
358        if let Some(row) = result.rows.first() {
359            let total_calls = self.extract_i64(row, "total_calls")?;
360
361            if total_calls == 0 {
362                return Ok(None);
363            }
364
365            let successful = self.extract_i64(row, "successful")?;
366            let failed = self.extract_i64(row, "failed")?;
367
368            let avg_duration_str = self.extract_string(row, "avg_duration_ms")?;
369            let avg_duration_ms = avg_duration_str.parse::<f64>().unwrap_or(0.0);
370
371            Ok(Some(ToolCallStats {
372                name: tool_name.to_string(),
373                total_calls,
374                successful,
375                failed,
376                avg_duration_ms,
377            }))
378        } else {
379            Ok(None)
380        }
381    }
382
383    async fn record(
384        &self,
385        name: &str,
386        started_at: i64,
387        completed_at: i64,
388        parameters: Option<serde_json::Value>,
389        result: Option<serde_json::Value>,
390        error: Option<&str>,
391    ) -> Result<i64> {
392        let serialized_params = parameters
393            .map(|p| serde_json::to_string(&p))
394            .transpose()?
395            .unwrap_or_default();
396
397        let serialized_result = result
398            .map(|r| serde_json::to_string(&r))
399            .transpose()?
400            .unwrap_or_default();
401
402        let duration_ms = (completed_at - started_at) * 1000;
403        let status = if error.is_some() { "error" } else { "success" };
404
405        let query = format!(
406            "INSERT INTO tool_calls (name, parameters, result, error, status, started_at, completed_at, duration_ms)
407             VALUES ('{}', '{}', '{}', '{}', '{}', {}, {}, {})",
408            name.replace('\'', "''"),
409            serialized_params.replace('\'', "''"),
410            serialized_result.replace('\'', "''"),
411            error.unwrap_or("").replace('\'', "''"),
412            status,
413            started_at,
414            completed_at,
415            duration_ms
416        );
417
418        self.db.query(&query, vec![]).await?;
419
420        // Get the ID of the just-inserted row using rowid
421        let result = self.db.query(
422            "SELECT id FROM tool_calls WHERE rowid = last_insert_rowid()",
423            vec![]
424        ).await?;
425
426        if let Some(row) = result.rows.first() {
427            self.extract_i64(row, "id")
428        } else {
429            // Fallback: get MAX(id) which should be the just-inserted row
430            let result = self.db.query("SELECT MAX(id) as id FROM tool_calls", vec![]).await?;
431            if let Some(row) = result.rows.first() {
432                self.extract_i64(row, "id")
433            } else {
434                Err(crate::error::AgentFsError::Database(
435                    agentdb::AgentDbError::Backend("Failed to get tool call ID".to_string())
436                ))
437            }
438        }
439    }
440
441    async fn list(&self, limit: Option<usize>) -> Result<Vec<ToolCall>> {
442        let limit_clause = limit
443            .map(|l| format!(" LIMIT {}", l))
444            .unwrap_or_default();
445
446        let query = format!(
447            "SELECT id, name, parameters, result, error, status, started_at, completed_at, duration_ms
448             FROM tool_calls
449             ORDER BY started_at DESC{}",
450            limit_clause
451        );
452
453        let result = self.db.query(&query, vec![]).await?;
454
455        let mut tool_calls = Vec::new();
456        for row in &result.rows {
457            tool_calls.push(self.parse_tool_call(row)?);
458        }
459
460        Ok(tool_calls)
461    }
462}