1use std::path::{Path, PathBuf};
5
6use schemars::JsonSchema;
7use serde::Deserialize;
8
9use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
10use crate::registry::{InvocationHint, ToolDef};
11
12#[derive(Debug, Default, Deserialize, JsonSchema, PartialEq, Eq)]
14#[serde(rename_all = "snake_case")]
15pub enum DiagnosticsLevel {
16 #[default]
18 Check,
19 Clippy,
21}
22
// Parameters of a `diagnostics` tool call, deserialized from `call.params`
// in `execute_tool_call`.
#[derive(Debug, Deserialize, JsonSchema)]
struct DiagnosticsParams {
    // Workspace directory to run cargo in. `None` means the current working
    // directory. Either way the directory must resolve inside one of the
    // executor's allowed paths (checked by `validate_path`).
    path: Option<String>,
    // Which cargo subcommand to run; when the field is absent it falls back to
    // `DiagnosticsLevel::default()`, i.e. `Check`.
    #[serde(default)]
    level: DiagnosticsLevel,
}
31
/// Tool executor that runs `cargo check` / `cargo clippy` in a directory
/// restricted to a set of allowed roots and reports the compiler diagnostics.
#[derive(Debug, Clone)]
pub struct DiagnosticsExecutor {
    /// Sandbox roots (canonicalized where possible); a requested workspace
    /// must lie under one of them.
    allowed_paths: Vec<PathBuf>,
    /// Maximum number of diagnostics returned per invocation.
    max_diagnostics: usize,
}
39
40impl DiagnosticsExecutor {
41 #[must_use]
42 pub fn new(allowed_paths: Vec<PathBuf>) -> Self {
43 let paths = if allowed_paths.is_empty() {
44 vec![std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))]
45 } else {
46 allowed_paths
47 };
48 Self {
49 allowed_paths: paths
50 .into_iter()
51 .map(|p| p.canonicalize().unwrap_or(p))
52 .collect(),
53 max_diagnostics: 50,
54 }
55 }
56
57 #[must_use]
58 pub fn with_max_diagnostics(mut self, max: usize) -> Self {
59 self.max_diagnostics = max;
60 self
61 }
62
63 fn validate_path(&self, path: &Path) -> Result<PathBuf, ToolError> {
64 let resolved = if path.is_absolute() {
65 path.to_path_buf()
66 } else {
67 std::env::current_dir()
68 .unwrap_or_else(|_| PathBuf::from("."))
69 .join(path)
70 };
71 let canonical = resolved.canonicalize().map_err(|e| {
72 ToolError::Execution(std::io::Error::new(
73 std::io::ErrorKind::NotFound,
74 format!("path not found: {}: {e}", resolved.display()),
75 ))
76 })?;
77 if !self.allowed_paths.iter().any(|a| canonical.starts_with(a)) {
78 return Err(ToolError::SandboxViolation {
79 path: canonical.display().to_string(),
80 });
81 }
82 Ok(canonical)
83 }
84}
85
86impl ToolExecutor for DiagnosticsExecutor {
87 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
88 Ok(None)
89 }
90
91 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
92 if call.tool_id != "diagnostics" {
93 return Ok(None);
94 }
95 let p: DiagnosticsParams = deserialize_params(&call.params)?;
96 let work_dir = if let Some(path) = &p.path {
97 self.validate_path(Path::new(path))?
98 } else {
99 let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
100 self.validate_path(&cwd)?
101 };
102
103 let subcmd = match p.level {
104 DiagnosticsLevel::Check => "check",
105 DiagnosticsLevel::Clippy => "clippy",
106 };
107
108 let cargo = which_cargo()?;
109
110 let output = tokio::process::Command::new(&cargo)
111 .arg(subcmd)
112 .arg("--message-format=json")
113 .current_dir(&work_dir)
114 .output()
115 .await
116 .map_err(|e| {
117 ToolError::Execution(std::io::Error::new(
118 std::io::ErrorKind::NotFound,
119 format!("failed to run cargo: {e}"),
120 ))
121 })?;
122
123 let stdout = String::from_utf8_lossy(&output.stdout);
124 let diagnostics = parse_cargo_json(&stdout, self.max_diagnostics);
125
126 let summary = if diagnostics.is_empty() {
127 "No diagnostics".to_owned()
128 } else {
129 diagnostics.join("\n")
130 };
131
132 Ok(Some(ToolOutput {
133 tool_name: "diagnostics".to_owned(),
134 summary,
135 blocks_executed: 1,
136 filter_stats: None,
137 diff: None,
138 streamed: false,
139 terminal_id: None,
140 locations: None,
141 raw_response: None,
142 claim_source: Some(crate::executor::ClaimSource::Diagnostics),
143 }))
144 }
145
146 fn tool_definitions(&self) -> Vec<ToolDef> {
147 vec![ToolDef {
148 id: "diagnostics".into(),
149 description: "Run cargo check or cargo clippy on a Rust workspace and return compiler diagnostics.\n\nParameters: path (string, optional) - workspace directory (default: cwd); level (string, optional) - \"check\" or \"clippy\" (default: \"check\")\nReturns: structured diagnostics with file paths, line numbers, severity, and messages; capped at 50 results\nErrors: SandboxViolation if path outside allowed dirs; Execution if cargo is not found\nExample: {\"path\": \".\", \"level\": \"clippy\"}".into(),
150 schema: schemars::schema_for!(DiagnosticsParams),
151 invocation: InvocationHint::ToolCall,
152 }]
153 }
154}
155
156fn which_cargo() -> Result<PathBuf, ToolError> {
163 if let Ok(cargo) = std::env::var("CARGO") {
165 let p = PathBuf::from(&cargo);
166 if p.is_file() {
167 return Ok(p.canonicalize().unwrap_or(p));
168 }
169 }
170 for dir in std::env::var("PATH").unwrap_or_default().split(':') {
172 let candidate = PathBuf::from(dir).join("cargo");
173 if candidate.is_file() {
174 return Ok(candidate.canonicalize().unwrap_or(candidate));
175 }
176 }
177 Err(ToolError::Execution(std::io::Error::new(
178 std::io::ErrorKind::NotFound,
179 "cargo not found in PATH",
180 )))
181}
182
183pub(crate) fn parse_cargo_json(output: &str, max: usize) -> Vec<String> {
188 let mut results = Vec::new();
189 for line in output.lines() {
190 if results.len() >= max {
191 break;
192 }
193 let Ok(val) = serde_json::from_str::<serde_json::Value>(line) else {
194 continue;
195 };
196 if val.get("reason").and_then(|r| r.as_str()) != Some("compiler-message") {
197 continue;
198 }
199 let Some(msg) = val.get("message") else {
200 continue;
201 };
202 let level = msg
203 .get("level")
204 .and_then(|l| l.as_str())
205 .unwrap_or("unknown");
206 let text = msg
207 .get("message")
208 .and_then(|m| m.as_str())
209 .unwrap_or("")
210 .trim();
211 if text.is_empty() {
212 continue;
213 }
214
215 let spans = msg
217 .get("spans")
218 .and_then(serde_json::Value::as_array)
219 .map_or(&[] as &[_], Vec::as_slice);
220
221 let primary = spans.iter().find(|s| {
222 s.get("is_primary")
223 .and_then(serde_json::Value::as_bool)
224 .unwrap_or(false)
225 });
226
227 if let Some(span) = primary {
228 let file = span
229 .get("file_name")
230 .and_then(|f| f.as_str())
231 .unwrap_or("?");
232 let line = span
233 .get("line_start")
234 .and_then(serde_json::Value::as_u64)
235 .unwrap_or(0);
236 let col = span
237 .get("column_start")
238 .and_then(serde_json::Value::as_u64)
239 .unwrap_or(0);
240 results.push(format!("{file}:{line}:{col}: {level}: {text}"));
241 } else {
242 results.push(format!("{level}: {text}"));
243 }
244 }
245 results
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    fn make_params(
        pairs: &[(&str, serde_json::Value)],
    ) -> serde_json::Map<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_owned(), v.clone()))
            .collect()
    }

    #[test]
    fn parse_cargo_json_empty_input() {
        let result = parse_cargo_json("", 50);
        assert!(result.is_empty());
    }

    #[test]
    fn parse_cargo_json_non_compiler_message_ignored() {
        let line = r#"{"reason":"build-script-executed","package_id":"foo"}"#;
        let result = parse_cargo_json(line, 50);
        assert!(result.is_empty());
    }

    #[test]
    fn parse_cargo_json_compiler_message_with_span() {
        let line = r#"{"reason":"compiler-message","message":{"level":"error","message":"cannot find value `foo` in this scope","spans":[{"file_name":"src/main.rs","line_start":10,"column_start":5,"is_primary":true}]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("src/main.rs"));
        assert!(result[0].contains("10"));
        assert!(result[0].contains("error"));
        assert!(result[0].contains("cannot find value"));
    }

    #[test]
    fn parse_cargo_json_warning_with_span() {
        let line = r#"{"reason":"compiler-message","message":{"level":"warning","message":"unused variable: `x`","spans":[{"file_name":"src/lib.rs","line_start":3,"column_start":9,"is_primary":true}]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert!(result[0].starts_with("src/lib.rs:3:9: warning:"));
    }

    #[test]
    fn parse_cargo_json_no_primary_span_uses_message_only() {
        let line = r#"{"reason":"compiler-message","message":{"level":"error","message":"aborting due to previous error","spans":[]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "error: aborting due to previous error");
    }

    #[test]
    fn parse_cargo_json_max_cap_respected() {
        let single = r#"{"reason":"compiler-message","message":{"level":"warning","message":"unused","spans":[]}}"#;
        let input: String = (0..20).map(|_| single).collect::<Vec<_>>().join("\n");
        let result = parse_cargo_json(&input, 5);
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn parse_cargo_json_empty_message_skipped() {
        let line = r#"{"reason":"compiler-message","message":{"level":"note","message":" ","spans":[]}}"#;
        let result = parse_cargo_json(line, 50);
        assert!(result.is_empty());
    }

    #[test]
    fn parse_cargo_json_non_primary_span_skipped_for_location() {
        let line = r#"{"reason":"compiler-message","message":{"level":"warning","message":"some warning","spans":[{"file_name":"src/foo.rs","line_start":1,"column_start":1,"is_primary":false}]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "warning: some warning");
    }

    #[test]
    fn parse_cargo_json_invalid_json_line_skipped() {
        let input = "not json\n{\"reason\":\"build-script-executed\"}";
        let result = parse_cargo_json(input, 50);
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn diagnostics_sandbox_violation() {
        let dir = tempfile::tempdir().unwrap();
        // A second, existing directory guarantees the call fails on the
        // sandbox check itself rather than on path resolution (a hard-coded
        // path such as `/etc` does not exist on every platform).
        let outside = tempfile::tempdir().unwrap();
        let exec = DiagnosticsExecutor::new(vec![dir.path().to_path_buf()]);

        let call = ToolCall {
            tool_id: "diagnostics".to_owned(),
            params: make_params(&[(
                "path",
                serde_json::json!(outside.path().to_str().unwrap()),
            )]),
            caller_id: None,
        };
        let result = exec.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::SandboxViolation { .. })));
    }

    #[tokio::test]
    async fn diagnostics_missing_path_is_execution_error() {
        let dir = tempfile::tempdir().unwrap();
        let exec = DiagnosticsExecutor::new(vec![dir.path().to_path_buf()]);
        let missing = dir.path().join("does-not-exist");

        let call = ToolCall {
            tool_id: "diagnostics".to_owned(),
            params: make_params(&[("path", serde_json::json!(missing.to_str().unwrap()))]),
            caller_id: None,
        };
        let result = exec.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    #[tokio::test]
    async fn diagnostics_unknown_tool_returns_none() {
        let exec = DiagnosticsExecutor::new(vec![]);
        let call = ToolCall {
            tool_id: "other".to_owned(),
            params: serde_json::Map::new(),
            caller_id: None,
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn diagnostics_tool_definition() {
        let exec = DiagnosticsExecutor::new(vec![]);
        let defs = exec.tool_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id, "diagnostics");
        assert_eq!(defs[0].invocation, InvocationHint::ToolCall);
    }

    #[test]
    fn diagnostics_level_default_is_check() {
        assert_eq!(DiagnosticsLevel::default(), DiagnosticsLevel::Check);
    }

    #[test]
    fn diagnostics_level_deserialize_check() {
        let p: DiagnosticsParams = serde_json::from_str(r#"{"level":"check"}"#).unwrap();
        assert_eq!(p.level, DiagnosticsLevel::Check);
    }

    #[test]
    fn diagnostics_level_deserialize_clippy() {
        let p: DiagnosticsParams = serde_json::from_str(r#"{"level":"clippy"}"#).unwrap();
        assert_eq!(p.level, DiagnosticsLevel::Clippy);
    }

    #[test]
    fn diagnostics_params_path_optional() {
        let p: DiagnosticsParams = serde_json::from_str(r"{}").unwrap();
        assert!(p.path.is_none());
        assert_eq!(p.level, DiagnosticsLevel::Check);
    }

    // Replaces the former `*_subcmd_string` tests, which re-implemented the
    // level -> subcommand match inside the test instead of exercising code.
    #[test]
    fn diagnostics_level_rejects_unknown_value() {
        let result = serde_json::from_str::<DiagnosticsParams>(r#"{"level":"fmt"}"#);
        assert!(result.is_err());
    }
}