Skip to main content

roboticus_db/
tools.rs

1use std::collections::HashMap;
2
3use crate::{Database, DbResultExt};
4use roboticus_core::Result;
5
/// One row of the `tool_calls` table: a single tool invocation made during a turn.
///
/// Rows are written by [`record_tool_call`] / [`record_tool_call_with_skill`] and read
/// back by [`get_tool_calls_for_turn`] / [`get_tool_calls_for_session`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallRecord {
    /// Row id; a freshly generated UUID v4 string when inserted by this module.
    pub id: String,
    /// Id of the `turns` row this call belongs to.
    pub turn_id: String,
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Tool input exactly as stored (callers in this module's tests pass JSON text).
    pub input: String,
    /// Tool output, if one was recorded.
    pub output: Option<String>,
    /// Id of the skill the call is attributed to, if any.
    pub skill_id: Option<String>,
    /// Name of the skill the call is attributed to, if any.
    pub skill_name: Option<String>,
    /// Hash of the skill the call is attributed to, if any.
    pub skill_hash: Option<String>,
    /// Free-form status string (the tests use `"success"`, `"error"`, `"pending"`).
    pub status: String,
    /// Call duration in milliseconds, if measured.
    pub duration_ms: Option<i64>,
    /// Creation timestamp as stored by the database. The insert in this module does not
    /// supply it, so it comes from the column default (format: TODO confirm vs schema).
    pub created_at: String,
}
20
21pub fn record_tool_call(
22    db: &Database,
23    turn_id: &str,
24    tool_name: &str,
25    input: &str,
26    output: Option<&str>,
27    status: &str,
28    duration_ms: Option<i64>,
29) -> Result<String> {
30    record_tool_call_with_skill(
31        db,
32        turn_id,
33        tool_name,
34        input,
35        output,
36        status,
37        duration_ms,
38        None,
39        None,
40        None,
41    )
42}
43
44#[allow(clippy::too_many_arguments)]
45pub fn record_tool_call_with_skill(
46    db: &Database,
47    turn_id: &str,
48    tool_name: &str,
49    input: &str,
50    output: Option<&str>,
51    status: &str,
52    duration_ms: Option<i64>,
53    skill_id: Option<&str>,
54    skill_name: Option<&str>,
55    skill_hash: Option<&str>,
56) -> Result<String> {
57    let conn = db.conn();
58    let id = uuid::Uuid::new_v4().to_string();
59    conn.execute(
60        "INSERT INTO tool_calls (id, turn_id, tool_name, input, output, skill_id, skill_name, \
61         skill_hash, status, duration_ms) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
62        rusqlite::params![
63            id,
64            turn_id,
65            tool_name,
66            input,
67            output,
68            skill_id,
69            skill_name,
70            skill_hash,
71            status,
72            duration_ms
73        ],
74    )
75    .db_err()?;
76    Ok(id)
77}
78
79pub fn get_tool_calls_for_turn(db: &Database, turn_id: &str) -> Result<Vec<ToolCallRecord>> {
80    let conn = db.conn();
81    let mut stmt = conn
82        .prepare(
83            "SELECT id, turn_id, tool_name, input, output, skill_id, skill_name, skill_hash, \
84             status, duration_ms, created_at \
85             FROM tool_calls WHERE turn_id = ?1 ORDER BY created_at ASC",
86        )
87        .db_err()?;
88
89    let rows = stmt
90        .query_map([turn_id], |row| {
91            Ok(ToolCallRecord {
92                id: row.get(0)?,
93                turn_id: row.get(1)?,
94                tool_name: row.get(2)?,
95                input: row.get(3)?,
96                output: row.get(4)?,
97                skill_id: row.get(5)?,
98                skill_name: row.get(6)?,
99                skill_hash: row.get(7)?,
100                status: row.get(8)?,
101                duration_ms: row.get(9)?,
102                created_at: row.get(10)?,
103            })
104        })
105        .db_err()?;
106
107    rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
108}
109
110/// Batch-fetch all tool calls for every turn in a session, grouped by turn_id.
111/// Eliminates the N+1 query pattern when analyzing sessions with many turns.
112pub fn get_tool_calls_for_session(
113    db: &Database,
114    session_id: &str,
115) -> Result<HashMap<String, Vec<ToolCallRecord>>> {
116    let conn = db.conn();
117    let mut stmt = conn
118        .prepare(
119            "SELECT tc.id, tc.turn_id, tc.tool_name, tc.input, tc.output, tc.skill_id, \
120                    tc.skill_name, tc.skill_hash, tc.status, tc.duration_ms, tc.created_at \
121             FROM tool_calls tc \
122             INNER JOIN turns t ON tc.turn_id = t.id \
123             WHERE t.session_id = ?1 \
124             ORDER BY tc.created_at ASC",
125        )
126        .db_err()?;
127
128    let rows = stmt
129        .query_map([session_id], |row| {
130            Ok(ToolCallRecord {
131                id: row.get(0)?,
132                turn_id: row.get(1)?,
133                tool_name: row.get(2)?,
134                input: row.get(3)?,
135                output: row.get(4)?,
136                skill_id: row.get(5)?,
137                skill_name: row.get(6)?,
138                skill_hash: row.get(7)?,
139                status: row.get(8)?,
140                duration_ms: row.get(9)?,
141                created_at: row.get(10)?,
142            })
143        })
144        .db_err()?;
145
146    let mut map: HashMap<String, Vec<ToolCallRecord>> = HashMap::new();
147    for row in rows {
148        let record = row.db_err()?;
149        map.entry(record.turn_id.clone()).or_default().push(record);
150    }
151    Ok(map)
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database seeded with session `s1` and its turn `t1`.
    fn test_db() -> Database {
        let db = Database::new(":memory:").unwrap();
        // tool_calls references turns, which references sessions, so the parent rows
        // must exist before any tool call can be inserted. The scope releases the
        // connection handle before `db` is returned.
        {
            let conn = db.conn();
            conn.execute(
                "INSERT INTO sessions (id, agent_id) VALUES ('s1', 'agent-1')",
                [],
            )
            .unwrap();
            conn.execute("INSERT INTO turns (id, session_id) VALUES ('t1', 's1')", [])
                .unwrap();
        }
        db
    }

    /// Fetches every call recorded against the seeded turn `t1`.
    fn calls_for_t1(db: &Database) -> Vec<ToolCallRecord> {
        get_tool_calls_for_turn(db, "t1").unwrap()
    }

    #[test]
    fn record_and_retrieve_tool_call() {
        let db = test_db();
        let id = record_tool_call(
            &db,
            "t1",
            "bash",
            r#"{"cmd":"ls"}"#,
            Some("file1\nfile2"),
            "success",
            Some(42),
        )
        .unwrap();
        assert!(!id.is_empty());

        let calls = calls_for_t1(&db);
        assert_eq!(calls.len(), 1);
        let only = &calls[0];
        assert_eq!(only.tool_name, "bash");
        assert_eq!(only.duration_ms, Some(42));
    }

    #[test]
    fn empty_turn_returns_empty_vec() {
        let calls = calls_for_t1(&test_db());
        assert!(calls.is_empty());
    }

    #[test]
    fn multiple_calls_ordered_by_time() {
        let db = test_db();
        for (tool, ms) in [("read", 10), ("write", 20)] {
            record_tool_call(&db, "t1", tool, "{}", None, "success", Some(ms)).unwrap();
        }

        let calls = calls_for_t1(&db);
        let names: Vec<&str> = calls.iter().map(|c| c.tool_name.as_str()).collect();
        assert_eq!(names, ["read", "write"]);
    }

    #[test]
    fn record_tool_call_no_output_no_duration() {
        let db = test_db();
        let id = record_tool_call(
            &db,
            "t1",
            "search",
            r#"{"q":"test"}"#,
            None,
            "pending",
            None,
        )
        .unwrap();
        assert!(!id.is_empty());

        let calls = calls_for_t1(&db);
        assert!(calls[0].output.is_none());
        assert!(calls[0].duration_ms.is_none());
        assert_eq!(calls[0].status, "pending");
    }

    #[test]
    fn record_tool_call_error_status() {
        let db = test_db();
        record_tool_call(
            &db,
            "t1",
            "bash",
            r#"{"cmd":"rm -rf /"}"#,
            Some("permission denied"),
            "error",
            Some(5),
        )
        .unwrap();

        let failed = &calls_for_t1(&db)[0];
        assert_eq!(failed.status, "error");
        assert_eq!(failed.output.as_deref(), Some("permission denied"));
    }

    #[test]
    fn get_tool_calls_nonexistent_turn() {
        let db = test_db();
        let calls = get_tool_calls_for_turn(&db, "nonexistent").unwrap();
        assert!(calls.is_empty());
    }

    #[test]
    fn batch_get_tool_calls_for_session() {
        let db = test_db();
        {
            let conn = db.conn();
            conn.execute("INSERT INTO turns (id, session_id) VALUES ('t2', 's1')", [])
                .unwrap();
        }
        for (turn, tool, status, ms) in [
            ("t1", "read", "success", 10),
            ("t1", "write", "success", 20),
            ("t2", "bash", "error", 5),
        ] {
            record_tool_call(&db, turn, tool, "{}", None, status, Some(ms)).unwrap();
        }

        let by_turn = get_tool_calls_for_session(&db, "s1").unwrap();
        assert_eq!(by_turn.len(), 2);
        assert_eq!(by_turn["t1"].len(), 2);
        assert_eq!(by_turn["t2"].len(), 1);
        assert_eq!(by_turn["t2"][0].tool_name, "bash");
    }

    #[test]
    fn batch_get_empty_session() {
        let db = test_db();
        let map = get_tool_calls_for_session(&db, "s1").unwrap();
        assert!(map.is_empty());
    }

    #[test]
    fn tool_call_fields_populated() {
        let db = test_db();
        record_tool_call(
            &db,
            "t1",
            "bash",
            r#"{"cmd":"echo hi"}"#,
            Some("hi"),
            "success",
            Some(100),
        )
        .unwrap();

        let calls = calls_for_t1(&db);
        assert!(!calls[0].id.is_empty());
        assert_eq!(calls[0].turn_id, "t1");
        assert!(!calls[0].created_at.is_empty());
    }

    #[test]
    fn record_tool_call_with_skill_attribution() {
        let db = test_db();
        record_tool_call_with_skill(
            &db,
            "t1",
            "run_script",
            r#"{"path":"deploy.sh"}"#,
            Some("ok"),
            "success",
            Some(33),
            Some("skill-123"),
            Some("deploy"),
            Some("hash-abc"),
        )
        .unwrap();

        let calls = calls_for_t1(&db);
        assert_eq!(calls.len(), 1);
        let attributed = &calls[0];
        assert_eq!(attributed.skill_id.as_deref(), Some("skill-123"));
        assert_eq!(attributed.skill_name.as_deref(), Some("deploy"));
        assert_eq!(attributed.skill_hash.as_deref(), Some("hash-abc"));
    }
}