1use std::collections::HashMap;
2
3use crate::{Database, DbResultExt};
4use roboticus_core::Result;
5
/// One row of the `tool_calls` table, as returned by the read helpers in this module.
///
/// Optional fields correspond to nullable columns. The `skill_*` fields are only
/// populated when a skill was supplied to [`record_tool_call_with_skill`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallRecord {
    /// Row id; rows inserted by this module get a freshly generated UUID v4 string.
    pub id: String,
    /// Id of the turn this tool call belongs to.
    pub turn_id: String,
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Tool input exactly as passed by the caller (format is not enforced here).
    pub input: String,
    /// Tool output, if one was recorded.
    pub output: Option<String>,
    /// Id of the skill the call is attributed to, if any.
    pub skill_id: Option<String>,
    /// Name of the skill the call is attributed to, if any.
    pub skill_name: Option<String>,
    /// Hash of the skill the call is attributed to, if any.
    pub skill_hash: Option<String>,
    /// Free-form status string as stored (callers in this file use e.g. "success", "error", "pending").
    pub status: String,
    /// Duration in milliseconds, if one was recorded.
    pub duration_ms: Option<i64>,
    /// Creation timestamp as produced by the table's column default (format set by the schema).
    pub created_at: String,
}
20
21pub fn record_tool_call(
22 db: &Database,
23 turn_id: &str,
24 tool_name: &str,
25 input: &str,
26 output: Option<&str>,
27 status: &str,
28 duration_ms: Option<i64>,
29) -> Result<String> {
30 record_tool_call_with_skill(
31 db,
32 turn_id,
33 tool_name,
34 input,
35 output,
36 status,
37 duration_ms,
38 None,
39 None,
40 None,
41 )
42}
43
44#[allow(clippy::too_many_arguments)]
45pub fn record_tool_call_with_skill(
46 db: &Database,
47 turn_id: &str,
48 tool_name: &str,
49 input: &str,
50 output: Option<&str>,
51 status: &str,
52 duration_ms: Option<i64>,
53 skill_id: Option<&str>,
54 skill_name: Option<&str>,
55 skill_hash: Option<&str>,
56) -> Result<String> {
57 let conn = db.conn();
58 let id = uuid::Uuid::new_v4().to_string();
59 conn.execute(
60 "INSERT INTO tool_calls (id, turn_id, tool_name, input, output, skill_id, skill_name, \
61 skill_hash, status, duration_ms) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
62 rusqlite::params![
63 id,
64 turn_id,
65 tool_name,
66 input,
67 output,
68 skill_id,
69 skill_name,
70 skill_hash,
71 status,
72 duration_ms
73 ],
74 )
75 .db_err()?;
76 Ok(id)
77}
78
79pub fn get_tool_calls_for_turn(db: &Database, turn_id: &str) -> Result<Vec<ToolCallRecord>> {
80 let conn = db.conn();
81 let mut stmt = conn
82 .prepare(
83 "SELECT id, turn_id, tool_name, input, output, skill_id, skill_name, skill_hash, \
84 status, duration_ms, created_at \
85 FROM tool_calls WHERE turn_id = ?1 ORDER BY created_at ASC",
86 )
87 .db_err()?;
88
89 let rows = stmt
90 .query_map([turn_id], |row| {
91 Ok(ToolCallRecord {
92 id: row.get(0)?,
93 turn_id: row.get(1)?,
94 tool_name: row.get(2)?,
95 input: row.get(3)?,
96 output: row.get(4)?,
97 skill_id: row.get(5)?,
98 skill_name: row.get(6)?,
99 skill_hash: row.get(7)?,
100 status: row.get(8)?,
101 duration_ms: row.get(9)?,
102 created_at: row.get(10)?,
103 })
104 })
105 .db_err()?;
106
107 rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
108}
109
110pub fn get_tool_calls_for_session(
113 db: &Database,
114 session_id: &str,
115) -> Result<HashMap<String, Vec<ToolCallRecord>>> {
116 let conn = db.conn();
117 let mut stmt = conn
118 .prepare(
119 "SELECT tc.id, tc.turn_id, tc.tool_name, tc.input, tc.output, tc.skill_id, \
120 tc.skill_name, tc.skill_hash, tc.status, tc.duration_ms, tc.created_at \
121 FROM tool_calls tc \
122 INNER JOIN turns t ON tc.turn_id = t.id \
123 WHERE t.session_id = ?1 \
124 ORDER BY tc.created_at ASC",
125 )
126 .db_err()?;
127
128 let rows = stmt
129 .query_map([session_id], |row| {
130 Ok(ToolCallRecord {
131 id: row.get(0)?,
132 turn_id: row.get(1)?,
133 tool_name: row.get(2)?,
134 input: row.get(3)?,
135 output: row.get(4)?,
136 skill_id: row.get(5)?,
137 skill_name: row.get(6)?,
138 skill_hash: row.get(7)?,
139 status: row.get(8)?,
140 duration_ms: row.get(9)?,
141 created_at: row.get(10)?,
142 })
143 })
144 .db_err()?;
145
146 let mut map: HashMap<String, Vec<ToolCallRecord>> = HashMap::new();
147 for row in rows {
148 let record = row.db_err()?;
149 map.entry(record.turn_id.clone()).or_default().push(record);
150 }
151 Ok(map)
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory database holding session `s1` with a single turn `t1`.
    fn test_db() -> Database {
        let db = Database::new(":memory:").unwrap();
        {
            // Scoped so the connection handle is dropped before `db` is returned.
            let conn = db.conn();
            conn.execute(
                "INSERT INTO sessions (id, agent_id) VALUES ('s1', 'agent-1')",
                [],
            )
            .unwrap();
            conn.execute("INSERT INTO turns (id, session_id) VALUES ('t1', 's1')", [])
                .unwrap();
        }
        db
    }

    #[test]
    fn record_and_retrieve_tool_call() {
        let db = test_db();
        let new_id = record_tool_call(
            &db,
            "t1",
            "bash",
            r#"{"cmd":"ls"}"#,
            Some("file1\nfile2"),
            "success",
            Some(42),
        )
        .unwrap();
        assert!(!new_id.is_empty());

        let fetched = get_tool_calls_for_turn(&db, "t1").unwrap();
        assert_eq!(fetched.len(), 1);
        let call = &fetched[0];
        assert_eq!(call.tool_name, "bash");
        assert_eq!(call.duration_ms, Some(42));
    }

    #[test]
    fn empty_turn_returns_empty_vec() {
        let fetched = get_tool_calls_for_turn(&test_db(), "t1").unwrap();
        assert!(fetched.is_empty());
    }

    #[test]
    fn multiple_calls_ordered_by_time() {
        let db = test_db();
        for (tool, ms) in [("read", 10), ("write", 20)] {
            record_tool_call(&db, "t1", tool, "{}", None, "success", Some(ms)).unwrap();
        }

        let names: Vec<String> = get_tool_calls_for_turn(&db, "t1")
            .unwrap()
            .into_iter()
            .map(|call| call.tool_name)
            .collect();
        assert_eq!(names, ["read", "write"]);
    }

    #[test]
    fn record_tool_call_no_output_no_duration() {
        let db = test_db();
        let new_id =
            record_tool_call(&db, "t1", "search", r#"{"q":"test"}"#, None, "pending", None)
                .unwrap();
        assert!(!new_id.is_empty());

        let fetched = get_tool_calls_for_turn(&db, "t1").unwrap();
        let call = &fetched[0];
        assert!(call.output.is_none());
        assert!(call.duration_ms.is_none());
        assert_eq!(call.status, "pending");
    }

    #[test]
    fn record_tool_call_error_status() {
        let db = test_db();
        record_tool_call(
            &db,
            "t1",
            "bash",
            r#"{"cmd":"rm -rf /"}"#,
            Some("permission denied"),
            "error",
            Some(5),
        )
        .unwrap();

        let fetched = get_tool_calls_for_turn(&db, "t1").unwrap();
        let call = &fetched[0];
        assert_eq!(call.status, "error");
        assert_eq!(call.output.as_deref(), Some("permission denied"));
    }

    #[test]
    fn get_tool_calls_nonexistent_turn() {
        let db = test_db();
        assert!(get_tool_calls_for_turn(&db, "nonexistent")
            .unwrap()
            .is_empty());
    }

    #[test]
    fn batch_get_tool_calls_for_session() {
        let db = test_db();
        db.conn()
            .execute("INSERT INTO turns (id, session_id) VALUES ('t2', 's1')", [])
            .unwrap();

        let seeded = [
            ("t1", "read", "success", 10),
            ("t1", "write", "success", 20),
            ("t2", "bash", "error", 5),
        ];
        for (turn, tool, status, ms) in seeded {
            record_tool_call(&db, turn, tool, "{}", None, status, Some(ms)).unwrap();
        }

        let by_turn = get_tool_calls_for_session(&db, "s1").unwrap();
        assert_eq!(by_turn.len(), 2);
        assert_eq!(by_turn["t1"].len(), 2);
        assert_eq!(by_turn["t2"].len(), 1);
        assert_eq!(by_turn["t2"][0].tool_name, "bash");
    }

    #[test]
    fn batch_get_empty_session() {
        let db = test_db();
        assert!(get_tool_calls_for_session(&db, "s1").unwrap().is_empty());
    }

    #[test]
    fn tool_call_fields_populated() {
        let db = test_db();
        record_tool_call(
            &db,
            "t1",
            "bash",
            r#"{"cmd":"echo hi"}"#,
            Some("hi"),
            "success",
            Some(100),
        )
        .unwrap();

        let fetched = get_tool_calls_for_turn(&db, "t1").unwrap();
        let call = &fetched[0];
        assert!(!call.id.is_empty());
        assert_eq!(call.turn_id, "t1");
        assert!(!call.created_at.is_empty());
    }

    #[test]
    fn record_tool_call_with_skill_attribution() {
        let db = test_db();
        record_tool_call_with_skill(
            &db,
            "t1",
            "run_script",
            r#"{"path":"deploy.sh"}"#,
            Some("ok"),
            "success",
            Some(33),
            Some("skill-123"),
            Some("deploy"),
            Some("hash-abc"),
        )
        .unwrap();

        let fetched = get_tool_calls_for_turn(&db, "t1").unwrap();
        assert_eq!(fetched.len(), 1);
        let call = &fetched[0];
        assert_eq!(call.skill_id.as_deref(), Some("skill-123"));
        assert_eq!(call.skill_name.as_deref(), Some("deploy"));
        assert_eq!(call.skill_hash.as_deref(), Some("hash-abc"));
    }
}