1use std::fmt;
2
/// Old and new contents of a single file, attached to tool output when a tool
/// changed (or proposed changing) that file.
#[derive(Clone, Debug)]
pub struct DiffData {
    /// Path of the file the two contents belong to.
    pub file_path: String,
    /// Contents on the "old" side of the diff.
    pub old_content: String,
    /// Contents on the "new" side of the diff.
    pub new_content: String,
}
10
11#[derive(Debug, Clone)]
13pub struct ToolCall {
14 pub tool_id: String,
15 pub params: serde_json::Map<String, serde_json::Value>,
16}
17
18#[derive(Debug, Clone, Default)]
20pub struct FilterStats {
21 pub raw_chars: usize,
22 pub filtered_chars: usize,
23 pub raw_lines: usize,
24 pub filtered_lines: usize,
25 pub confidence: Option<crate::FilterConfidence>,
26 pub command: Option<String>,
27 pub kept_lines: Vec<usize>,
28}
29
30impl FilterStats {
31 #[must_use]
32 #[allow(clippy::cast_precision_loss)]
33 pub fn savings_pct(&self) -> f64 {
34 if self.raw_chars == 0 {
35 return 0.0;
36 }
37 (1.0 - self.filtered_chars as f64 / self.raw_chars as f64) * 100.0
38 }
39
40 #[must_use]
41 pub fn estimated_tokens_saved(&self) -> usize {
42 self.raw_chars.saturating_sub(self.filtered_chars) / 4
43 }
44
45 #[must_use]
46 pub fn format_inline(&self, tool_name: &str) -> String {
47 let cmd_label = self
48 .command
49 .as_deref()
50 .map(|c| {
51 let trimmed = c.trim();
52 if trimmed.len() > 60 {
53 format!(" `{}…`", &trimmed[..57])
54 } else {
55 format!(" `{trimmed}`")
56 }
57 })
58 .unwrap_or_default();
59 format!(
60 "[{tool_name}]{cmd_label} {} lines \u{2192} {} lines, {:.1}% filtered",
61 self.raw_lines,
62 self.filtered_lines,
63 self.savings_pct()
64 )
65 }
66}
67
68#[derive(Debug, Clone)]
70pub struct ToolOutput {
71 pub tool_name: String,
72 pub summary: String,
73 pub blocks_executed: u32,
74 pub filter_stats: Option<FilterStats>,
75 pub diff: Option<DiffData>,
76 pub streamed: bool,
78}
79
80impl fmt::Display for ToolOutput {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 f.write_str(&self.summary)
83 }
84}
85
/// Size limit up to which [`truncate_tool_output`] returns output unchanged.
///
/// The limit is compared against `str::len`, so it is a UTF-8 byte budget.
pub const MAX_TOOL_OUTPUT_CHARS: usize = 30_000;

/// Shortens `output` to roughly [`MAX_TOOL_OUTPUT_CHARS`] bytes, keeping its head and tail.
///
/// Output at or below the limit is returned unchanged. Longer output keeps at most
/// `MAX_TOOL_OUTPUT_CHARS / 2` bytes from each end, each cut on a UTF-8 character
/// boundary. The two ends are separated by a notice giving the number of
/// characters omitted.
#[must_use]
pub fn truncate_tool_output(output: &str) -> String {
    if output.len() <= MAX_TOOL_OUTPUT_CHARS {
        return output.to_string();
    }

    let half = MAX_TOOL_OUTPUT_CHARS / 2;
    // Round both cut points inward to char boundaries so the slices are valid UTF-8.
    // `output.len() > 2 * half` here, so `head_end <= tail_start` always holds.
    let head_end = (0..=half)
        .rev()
        .find(|&i| output.is_char_boundary(i))
        .unwrap_or(0);
    let tail_start = ((output.len() - half)..=output.len())
        .find(|&i| output.is_char_boundary(i))
        .unwrap_or(output.len());
    let head = &output[..head_end];
    let tail = &output[tail_start..];
    // The notice says "chars", so count characters; a byte difference would
    // overstate the omitted span for multi-byte text.
    let truncated = output[head_end..tail_start].chars().count();

    format!(
        "{head}\n\n... [truncated {truncated} chars, showing first and last ~{half} chars] ...\n\n{tail}"
    )
}
106
107#[derive(Debug, Clone)]
109pub enum ToolEvent {
110 Started {
111 tool_name: String,
112 command: String,
113 },
114 OutputChunk {
115 tool_name: String,
116 command: String,
117 chunk: String,
118 },
119 Completed {
120 tool_name: String,
121 command: String,
122 output: String,
123 success: bool,
124 filter_stats: Option<FilterStats>,
125 diff: Option<DiffData>,
126 },
127}
128
/// Sending half of the unbounded tokio channel over which [`ToolEvent`]s are reported.
pub type ToolEventTx = tokio::sync::mpsc::UnboundedSender<ToolEvent>;
130
131#[derive(Debug, thiserror::Error)]
133pub enum ToolError {
134 #[error("command blocked by policy: {command}")]
135 Blocked { command: String },
136
137 #[error("path not allowed by sandbox: {path}")]
138 SandboxViolation { path: String },
139
140 #[error("command requires confirmation: {command}")]
141 ConfirmationRequired { command: String },
142
143 #[error("command timed out after {timeout_secs}s")]
144 Timeout { timeout_secs: u64 },
145
146 #[error("operation cancelled")]
147 Cancelled,
148
149 #[error("invalid tool parameters: {message}")]
150 InvalidParams { message: String },
151
152 #[error("execution failed: {0}")]
153 Execution(#[from] std::io::Error),
154}
155
156pub fn deserialize_params<T: serde::de::DeserializeOwned>(
162 params: &serde_json::Map<String, serde_json::Value>,
163) -> Result<T, ToolError> {
164 let obj = serde_json::Value::Object(params.clone());
165 serde_json::from_value(obj).map_err(|e| ToolError::InvalidParams {
166 message: e.to_string(),
167 })
168}
169
170pub trait ToolExecutor: Send + Sync {
175 fn execute(
176 &self,
177 response: &str,
178 ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send;
179
180 fn execute_confirmed(
183 &self,
184 response: &str,
185 ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
186 self.execute(response)
187 }
188
189 fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
191 vec![]
192 }
193
194 fn execute_tool_call(
196 &self,
197 _call: &ToolCall,
198 ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
199 std::future::ready(Ok(None))
200 }
201
202 fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
204}
205
/// Object-safe counterpart of [`ToolExecutor`], usable as `dyn ErasedToolExecutor`.
///
/// `ToolExecutor` returns `impl Future`, which prevents it from being used behind
/// `dyn`. This trait returns pinned, boxed futures instead. The `'a` lifetime ties
/// each returned future to the borrows of `self` and the input it was called with.
pub trait ErasedToolExecutor: Send + Sync {
    /// Type-erased form of [`ToolExecutor::execute`].
    fn execute_erased<'a>(
        &'a self,
        response: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Type-erased form of [`ToolExecutor::execute_confirmed`].
    fn execute_confirmed_erased<'a>(
        &'a self,
        response: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Type-erased form of [`ToolExecutor::tool_definitions`].
    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef>;

    /// Type-erased form of [`ToolExecutor::execute_tool_call`].
    fn execute_tool_call_erased<'a>(
        &'a self,
        call: &'a ToolCall,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Sets (or clears, with `None`) skill environment variables; a no-op by default.
    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
}
231
232impl<T: ToolExecutor> ErasedToolExecutor for T {
233 fn execute_erased<'a>(
234 &'a self,
235 response: &'a str,
236 ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
237 {
238 Box::pin(self.execute(response))
239 }
240
241 fn execute_confirmed_erased<'a>(
242 &'a self,
243 response: &'a str,
244 ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
245 {
246 Box::pin(self.execute_confirmed(response))
247 }
248
249 fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef> {
250 self.tool_definitions()
251 }
252
253 fn execute_tool_call_erased<'a>(
254 &'a self,
255 call: &'a ToolCall,
256 ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
257 {
258 Box::pin(self.execute_tool_call(call))
259 }
260
261 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
262 ToolExecutor::set_skill_env(self, env);
263 }
264}
265
/// Returns the trimmed contents of every fenced code block tagged with `lang`.
///
/// A block opens with three backticks immediately followed by `lang`. It closes at
/// the next run of three backticks. If an opening fence has no closing fence,
/// scanning stops there.
///
/// The tag must match as a whole word. With `lang = "c"`, a `cpp`, `c++` or `c#`
/// fence is not a match. When `lang` is empty, every fence matches and any
/// language tag is kept as part of the returned contents.
#[must_use]
pub fn extract_fenced_blocks<'a>(text: &'a str, lang: &str) -> Vec<&'a str> {
    // Characters that extend a language tag (e.g. "c" -> "cpp", "c++", "c#").
    fn continues_tag(c: char) -> bool {
        c.is_alphanumeric() || matches!(c, '-' | '_' | '+' | '#')
    }

    let marker = format!("```{lang}");
    let marker_len = marker.len();
    let mut blocks = Vec::new();
    let mut rest = text;

    while let Some(start) = rest.find(&marker) {
        let after = &rest[start + marker_len..];
        // A fence whose tag merely starts with `lang` belongs to another language;
        // resume scanning right after it. Not applied to an empty `lang`, where
        // every fence is wanted and tags are returned as content.
        if !lang.is_empty() && after.chars().next().is_some_and(continues_tag) {
            rest = after;
            continue;
        }
        let Some(end) = after.find("```") else {
            break;
        };
        blocks.push(after[..end].trim());
        rest = &after[end + 3..];
    }

    blocks
}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_output_display() {
        let output = ToolOutput {
            tool_name: "bash".to_owned(),
            summary: "$ echo hello\nhello".to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
        };
        assert_eq!(output.to_string(), "$ echo hello\nhello");
    }

    #[test]
    fn tool_error_blocked_display() {
        let err = ToolError::Blocked {
            command: "rm -rf /".to_owned(),
        };
        assert_eq!(err.to_string(), "command blocked by policy: rm -rf /");
    }

    #[test]
    fn tool_error_sandbox_violation_display() {
        let err = ToolError::SandboxViolation {
            path: "/etc/shadow".to_owned(),
        };
        assert_eq!(err.to_string(), "path not allowed by sandbox: /etc/shadow");
    }

    #[test]
    fn tool_error_confirmation_required_display() {
        let err = ToolError::ConfirmationRequired {
            command: "rm -rf /tmp".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "command requires confirmation: rm -rf /tmp"
        );
    }

    #[test]
    fn tool_error_timeout_display() {
        let err = ToolError::Timeout { timeout_secs: 30 };
        assert_eq!(err.to_string(), "command timed out after 30s");
    }

    #[test]
    fn tool_error_cancelled_display() {
        assert_eq!(ToolError::Cancelled.to_string(), "operation cancelled");
    }

    #[test]
    fn tool_error_invalid_params_display() {
        let err = ToolError::InvalidParams {
            message: "missing field `command`".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "invalid tool parameters: missing field `command`"
        );
    }

    #[test]
    fn deserialize_params_valid() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("count".to_owned(), serde_json::json!(42));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned(),
                count: 42
            }
        );
    }

    #[test]
    fn deserialize_params_missing_required_field() {
        #[derive(Debug, serde::Deserialize)]
        struct P {
            #[allow(dead_code)]
            name: String,
        }
        let map = serde_json::Map::new();
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_wrong_type() {
        #[derive(Debug, serde::Deserialize)]
        struct P {
            #[allow(dead_code)]
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("count".to_owned(), serde_json::json!("not a number"));
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_all_optional_empty() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: Option<String>,
        }
        let map = serde_json::Map::new();
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(p, P { name: None });
    }

    #[test]
    fn deserialize_params_ignores_extra_fields() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("extra".to_owned(), serde_json::json!(true));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned()
            }
        );
    }

    #[test]
    fn tool_error_execution_display() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "bash not found");
        let err = ToolError::Execution(io_err);
        assert!(err.to_string().starts_with("execution failed:"));
        assert!(err.to_string().contains("bash not found"));
    }

    #[test]
    fn truncate_tool_output_short_passthrough() {
        let short = "hello world";
        assert_eq!(truncate_tool_output(short), short);
    }

    #[test]
    fn truncate_tool_output_exact_limit() {
        let exact = "a".repeat(MAX_TOOL_OUTPUT_CHARS);
        assert_eq!(truncate_tool_output(&exact), exact);
    }

    #[test]
    fn truncate_tool_output_long_split() {
        let long = "x".repeat(MAX_TOOL_OUTPUT_CHARS + 1000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.len() < long.len());
    }

    #[test]
    fn truncate_tool_output_notice_contains_count() {
        let long = "y".repeat(MAX_TOOL_OUTPUT_CHARS + 2000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.contains("chars"));
    }

    #[test]
    fn truncate_tool_output_keeps_head_and_tail() {
        let half = MAX_TOOL_OUTPUT_CHARS / 2;
        let long = format!("{}{}{}", "h".repeat(half), "m".repeat(1000), "t".repeat(half));
        let result = truncate_tool_output(&long);
        assert!(result.starts_with(&"h".repeat(half)));
        assert!(result.ends_with(&"t".repeat(half)));
        assert!(result.contains("[truncated 1000 chars"));
    }

    #[test]
    fn truncate_tool_output_does_not_split_multibyte_chars() {
        let long = "é".repeat(MAX_TOOL_OUTPUT_CHARS);
        let result = truncate_tool_output(&long);
        assert!(result.starts_with('é'));
        assert!(result.ends_with('é'));
    }

    #[test]
    fn extract_fenced_blocks_finds_all_matching_blocks() {
        let text = "intro\n```bash\necho one\n```\nmiddle\n```bash\necho two\n```\n";
        assert_eq!(
            extract_fenced_blocks(text, "bash"),
            vec!["echo one", "echo two"]
        );
    }

    #[test]
    fn extract_fenced_blocks_ignores_other_languages() {
        let text = "```python\nprint(1)\n```";
        assert!(extract_fenced_blocks(text, "bash").is_empty());
    }

    #[test]
    fn extract_fenced_blocks_unterminated_block_is_dropped() {
        assert!(extract_fenced_blocks("```bash\necho hi", "bash").is_empty());
    }

    #[derive(Debug)]
    struct DefaultExecutor;
    impl ToolExecutor for DefaultExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    #[tokio::test]
    async fn execute_tool_call_default_returns_none() {
        let exec = DefaultExecutor;
        let call = ToolCall {
            tool_id: "anything".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_confirmed_default_delegates_to_execute() {
        let exec = DefaultExecutor;
        assert!(exec.execute_confirmed("anything").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn erased_executor_forwards_to_tool_executor() {
        let exec: &dyn ErasedToolExecutor = &DefaultExecutor;
        assert!(exec.execute_erased("anything").await.unwrap().is_none());
        assert!(exec
            .execute_confirmed_erased("anything")
            .await
            .unwrap()
            .is_none());
        assert!(exec.tool_definitions_erased().is_empty());
        let call = ToolCall {
            tool_id: "anything".to_owned(),
            params: serde_json::Map::new(),
        };
        assert!(exec.execute_tool_call_erased(&call).await.unwrap().is_none());
    }

    #[test]
    fn filter_stats_savings_pct() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert!((fs.savings_pct() - 80.0).abs() < 0.01);
    }

    #[test]
    fn filter_stats_savings_pct_zero() {
        let fs = FilterStats::default();
        assert!((fs.savings_pct()).abs() < 0.01);
    }

    #[test]
    fn filter_stats_estimated_tokens_saved() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        // (1000 - 200) / 4
        assert_eq!(fs.estimated_tokens_saved(), 200);
    }

    #[test]
    fn filter_stats_format_inline() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            raw_lines: 342,
            filtered_lines: 28,
            ..Default::default()
        };
        let line = fs.format_inline("shell");
        assert_eq!(line, "[shell] 342 lines \u{2192} 28 lines, 80.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_zero() {
        let fs = FilterStats::default();
        let line = fs.format_inline("bash");
        assert_eq!(line, "[bash] 0 lines \u{2192} 0 lines, 0.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_with_command() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            raw_lines: 342,
            filtered_lines: 28,
            command: Some("  cargo test  ".to_owned()),
            ..Default::default()
        };
        assert_eq!(
            fs.format_inline("shell"),
            "[shell] `cargo test` 342 lines \u{2192} 28 lines, 80.0% filtered"
        );
    }

    #[test]
    fn filter_stats_format_inline_long_command_is_shortened() {
        let fs = FilterStats {
            command: Some("a".repeat(70)),
            ..Default::default()
        };
        let expected = format!(
            "[bash] `{}…` 0 lines \u{2192} 0 lines, 0.0% filtered",
            "a".repeat(57)
        );
        assert_eq!(fs.format_inline("bash"), expected);
    }
}
524}