1use anyhow::Result;
9use serde::Serialize;
10use std::process::Command;
11use std::time::Instant;
12
13#[derive(Debug, Clone, Serialize)]
15pub struct ToolCheck {
16 pub name: String,
17 pub version: Option<String>,
18 pub status: ToolStatus,
19 pub path: Option<String>,
20}
21
22#[derive(Debug, Clone, PartialEq, Serialize)]
23pub enum ToolStatus {
24 Ok,
25 Missing,
26 VersionMismatch { expected: String, found: String },
27 Error(String),
28}
29
30#[derive(Debug, Serialize)]
32pub struct DoctorReport {
33 pub checks: Vec<ToolCheck>,
34 pub ok_count: usize,
35 pub total_required: usize,
36 pub operational: bool,
37 pub elapsed_ms: f64,
38}
39
40impl std::fmt::Display for ToolStatus {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 match self {
43 ToolStatus::Ok => write!(f, "\x1b[32m[OK]\x1b[0m"),
44 ToolStatus::Missing => write!(f, "\x1b[31m[MISSING]\x1b[0m"),
45 ToolStatus::VersionMismatch { expected, found } => {
46 write!(
47 f,
48 "\x1b[33m[VERSION] expected {expected}, found {found}\x1b[0m"
49 )
50 }
51 ToolStatus::Error(msg) => write!(f, "\x1b[31m[ERROR: {msg}]\x1b[0m"),
52 }
53 }
54}
55
56fn check_binary(
58 name: &str,
59 version_args: &[&str],
60 version_parser: fn(&str) -> Option<String>,
61) -> ToolCheck {
62 match which::which(name) {
63 Ok(path) => {
64 let version = if version_args.is_empty() {
65 None
66 } else {
67 Command::new(name)
68 .args(version_args)
69 .output()
70 .ok()
71 .and_then(|out| {
72 let stdout = String::from_utf8_lossy(&out.stdout).to_string();
73 let stderr = String::from_utf8_lossy(&out.stderr).to_string();
74 let combined = format!("{stdout}{stderr}");
75 version_parser(&combined)
76 })
77 };
78 ToolCheck {
79 name: name.to_string(),
80 version,
81 status: ToolStatus::Ok,
82 path: Some(path.display().to_string()),
83 }
84 }
85 Err(_) => ToolCheck {
86 name: name.to_string(),
87 version: None,
88 status: ToolStatus::Missing,
89 path: None,
90 },
91 }
92}
93
/// Extracts the Nsight Compute version from `ncu --version` output.
///
/// The tool prints a banner line ("... Nsight Compute Command Line Profiler")
/// before the actual "Version X.Y.Z (build ...)" line, so the version line is
/// consulted first. If no such line exists, the first numeric token on a line
/// mentioning "Nsight Compute" or "ncu" is used instead.
///
/// Returns `None` when no token starting with a digit is found.
fn parse_ncu_version(output: &str) -> Option<String> {
    /// First whitespace-separated token of `line` that starts with a digit.
    fn first_numeric_token(line: &str) -> Option<String> {
        line.split_whitespace()
            .find(|w| w.starts_with(|c: char| c.is_ascii_digit()))
            .map(String::from)
    }

    // Case-insensitive "version" prefix test without allocating.
    let is_version_line = |l: &&str| {
        l.trim_start()
            .get(..7)
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case("version"))
    };

    output
        .lines()
        .filter(is_version_line)
        .find_map(first_numeric_token)
        .or_else(|| {
            // Fallback for compact formats such as "ncu 2023.1.0".
            output
                .lines()
                .filter(|l| l.contains("Nsight Compute") || l.contains("ncu"))
                .find_map(first_numeric_token)
        })
}
101
/// Extracts the version from `nsys --version` style output.
///
/// Only the first line containing "version" or "Nsight Systems" is examined;
/// the result is that line's first token beginning with an ASCII digit, or
/// `None` when there is no such line or token.
fn parse_nsys_version(output: &str) -> Option<String> {
    let banner = output
        .lines()
        .find(|line| line.contains("version") || line.contains("Nsight Systems"))?;

    banner
        .split_whitespace()
        .find(|token| token.as_bytes().first().is_some_and(u8::is_ascii_digit))
        .map(str::to_owned)
}
113
/// Extracts the driver version from the `nvidia-smi` banner table.
///
/// Looks at the first line containing "Driver Version" and returns the first
/// whitespace-delimited token following "Driver Version:". Returns `None` when
/// the line or the token is absent.
// Not called by the code in this file, hence the lint suppression.
#[allow(dead_code)]
fn parse_nvidia_smi_version(output: &str) -> Option<String> {
    let banner = output.lines().find(|line| line.contains("Driver Version"))?;
    let after_label = banner.split("Driver Version:").nth(1)?;
    after_label.split_whitespace().next().map(str::to_owned)
}
127
/// Extracts the version from `perf --version` style output.
///
/// Only the first line is inspected; the result is its first token that begins
/// with an ASCII digit, or `None` if the output is empty or has no such token.
fn parse_perf_version(output: &str) -> Option<String> {
    let first_line = output.lines().next()?;
    first_line
        .split_whitespace()
        .find(|token| token.starts_with(|c: char| c.is_ascii_digit()))
        .map(str::to_owned)
}
136
/// Extracts a version from the first line of a generic `--version` output.
///
/// The first whitespace-separated token that starts with an ASCII digit wins
/// (e.g. `"tool 1.2.3"` gives `"1.2.3"`). When no token starts with a digit,
/// tokens that embed the version are also recognised: a leading `v`/`V`
/// (`"v1.2.3"`) or a `name-<version>` form such as valgrind's
/// `"valgrind-3.19.0"`.
///
/// Returns `None` for empty output or when nothing version-like is found.
fn parse_generic_version(output: &str) -> Option<String> {
    fn starts_with_digit(s: &str) -> bool {
        s.starts_with(|c: char| c.is_ascii_digit())
    }

    /// Version embedded in a token: after a leading `v`/`V`, or after the
    /// first `-` that is directly followed by a digit.
    fn embedded_version(token: &str) -> Option<&str> {
        if let Some(rest) = token.strip_prefix(|c: char| c == 'v' || c == 'V') {
            if starts_with_digit(rest) {
                return Some(rest);
            }
        }
        token
            .match_indices('-')
            // '-' is one byte wide, so `i + 1` is always a char boundary.
            .map(|(i, _)| &token[i + 1..])
            .find(|rest| starts_with_digit(rest))
    }

    let first_line = output.lines().next()?;
    first_line
        .split_whitespace()
        .find(|token| starts_with_digit(token))
        .or_else(|| first_line.split_whitespace().find_map(embedded_version))
        .map(String::from)
}
145
146fn detect_gpu() -> ToolCheck {
148 let result = Command::new("nvidia-smi")
149 .args(["--query-gpu=name,compute_cap", "--format=csv,noheader"])
150 .output();
151 match result {
152 Ok(out) if out.status.success() => {
153 let stdout = String::from_utf8_lossy(&out.stdout);
154 let info = stdout.trim().to_string();
155 ToolCheck {
156 name: "GPU".to_string(),
157 version: Some(info),
158 status: ToolStatus::Ok,
159 path: None,
160 }
161 }
162 _ => ToolCheck {
163 name: "GPU".to_string(),
164 version: None,
165 status: ToolStatus::Missing,
166 path: None,
167 },
168 }
169}
170
171fn detect_cpu() -> ToolCheck {
173 #[cfg(target_arch = "x86_64")]
174 {
175 let mut features = Vec::new();
176 if std::arch::is_x86_feature_detected!("avx2") {
177 features.push("AVX2");
178 }
179 if std::arch::is_x86_feature_detected!("fma") {
180 features.push("FMA");
181 }
182 if std::arch::is_x86_feature_detected!("avx512f") {
183 features.push("AVX-512F");
184 }
185 if std::arch::is_x86_feature_detected!("sse4.2") {
186 features.push("SSE4.2");
187 }
188 let cpu_model = std::fs::read_to_string("/proc/cpuinfo").ok().and_then(|s| {
189 s.lines()
190 .find(|l| l.starts_with("model name"))
191 .and_then(|l| l.split(':').nth(1))
192 .map(|s| s.trim().to_string())
193 });
194 let version = match cpu_model {
195 Some(model) => format!("{model} ({features})", features = features.join(", ")),
196 None => features.join(", "),
197 };
198 ToolCheck {
199 name: "CPU".to_string(),
200 version: Some(version),
201 status: ToolStatus::Ok,
202 path: None,
203 }
204 }
205 #[cfg(target_arch = "aarch64")]
206 {
207 ToolCheck {
208 name: "CPU".to_string(),
209 version: Some("aarch64 (NEON)".to_string()),
210 status: ToolStatus::Ok,
211 path: None,
212 }
213 }
214 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
215 {
216 ToolCheck {
217 name: "CPU".to_string(),
218 version: Some(format!("{}", std::env::consts::ARCH)),
219 status: ToolStatus::Ok,
220 path: None,
221 }
222 }
223}
224
/// Reads the kernel's `perf_event_paranoid` setting.
///
/// Returns `None` when `/proc/sys/kernel/perf_event_paranoid` cannot be read
/// or its trimmed contents do not parse as an `i32`.
fn check_perf_paranoid() -> Option<i32> {
    let raw = std::fs::read_to_string("/proc/sys/kernel/perf_event_paranoid").ok()?;
    raw.trim().parse::<i32>().ok()
}
231
232pub fn collect_checks() -> Vec<ToolCheck> {
234 vec![
235 check_binary(
236 "nvidia-smi",
237 &["--query-gpu=driver_version", "--format=csv,noheader"],
238 |s| Some(s.trim().to_string()),
239 ),
240 {
241 let mut check = check_binary("nvcc", &["--version"], |s| {
243 s.lines()
244 .find(|l| l.contains("release"))
245 .and_then(|l| l.split("release ").nth(1))
246 .and_then(|s| s.split(',').next())
247 .map(String::from)
248 });
249 check.name = "CUDA Runtime".to_string();
250 check
251 },
252 check_binary("ncu", &["--version"], parse_ncu_version),
253 check_binary("nsys", &["--version"], parse_nsys_version),
254 {
255 let cupti_paths = [
257 "/usr/local/cuda/lib64/libcupti.so",
258 "/usr/lib/x86_64-linux-gnu/libcupti.so",
259 ];
260 let found = cupti_paths
261 .iter()
262 .find(|p| std::path::Path::new(p).exists());
263 ToolCheck {
264 name: "CUPTI".to_string(),
265 version: found.map(|p| p.to_string()),
266 status: if found.is_some() {
267 ToolStatus::Ok
268 } else {
269 ToolStatus::Missing
270 },
271 path: found.map(|p| p.to_string()),
272 }
273 },
274 {
275 let mut check = check_binary("perf", &["--version"], parse_perf_version);
276 if check.status == ToolStatus::Ok {
277 if let Some(paranoid) = check_perf_paranoid() {
278 check.version = Some(format!(
279 "{} (perf_event_paranoid={})",
280 check.version.as_deref().unwrap_or("?"),
281 paranoid
282 ));
283 }
284 }
285 check
286 },
287 check_binary("valgrind", &["--version"], parse_generic_version),
289 check_binary("renacer", &["--version"], parse_generic_version),
290 check_binary("trueno-explain", &["--version"], parse_generic_version),
291 detect_gpu(),
292 detect_cpu(),
293 ]
294}
295
296pub fn build_report() -> DoctorReport {
298 let start = Instant::now();
299 let checks = collect_checks();
300
301 let optional_tools = ["renacer", "trueno-explain", "CUPTI"];
302 let mut ok_count = 0;
303 let mut total = checks.len();
304
305 for check in &checks {
306 if check.status == ToolStatus::Ok {
307 ok_count += 1;
308 } else if optional_tools.contains(&check.name.as_str()) {
309 total -= 1;
310 }
311 }
312
313 let elapsed = start.elapsed();
314 DoctorReport {
315 checks,
316 ok_count,
317 total_required: total,
318 operational: ok_count >= total,
319 elapsed_ms: elapsed.as_secs_f64() * 1000.0,
320 }
321}
322
323pub fn run_doctor(json: bool) -> Result<()> {
325 let report = build_report();
326
327 if json {
328 println!("{}", serde_json::to_string_pretty(&report)?);
329 return Ok(());
330 }
331
332 println!("\n=== cgp System Check ===\n");
333
334 for check in &report.checks {
335 let version_str = check.version.as_deref().unwrap_or("");
336 let pad_name = format!("{:18}", format!("{}:", check.name));
337 let pad_version = format!("{:30}", version_str);
338 println!(" {pad_name}{pad_version}{}", check.status);
339 }
340
341 if let Some(paranoid) = check_perf_paranoid() {
343 if paranoid > 2 {
344 println!(
345 " \x1b[33m[WARN]\x1b[0m perf_event_paranoid={paranoid} — hardware counters blocked for non-root users."
346 );
347 println!(" Fix: sudo sysctl kernel.perf_event_paranoid=2");
348 println!(" Or run cgp with sudo for perf stat features.\n");
349 }
350 }
351
352 if report.operational {
353 println!(
354 " All {} required components available. cgp is fully operational.",
355 report.ok_count
356 );
357 } else {
358 let missing = report.total_required - report.ok_count;
359 println!(
360 " {}/{} components available. {missing} missing — cgp will operate in degraded mode.",
361 report.ok_count, report.total_required
362 );
363 }
364 println!(" Completed in {:.0}ms", report.elapsed_ms);
365 println!();
366
367 Ok(())
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU detection always succeeds and yields some description.
    #[test]
    fn test_detect_cpu_features() {
        let ToolCheck {
            status, version, ..
        } = detect_cpu();
        assert_eq!(status, ToolStatus::Ok);
        assert!(version.is_some());
    }

    /// A binary that is not on the search path is reported as missing, without a path.
    #[test]
    fn test_missing_tool_graceful() {
        let probe = check_binary("nonexistent-tool-xyz", &["--version"], parse_generic_version);
        assert_eq!(probe.status, ToolStatus::Missing);
        assert_eq!(probe.path, None);
    }

    /// The hardware probes plus one failed lookup finish in under two seconds.
    #[test]
    fn test_doctor_speed() {
        let timer = Instant::now();
        let _ = detect_cpu();
        let _ = detect_gpu();
        let _ = check_binary("nonexistent", &[], parse_generic_version);
        let elapsed = timer.elapsed();
        assert!(
            elapsed < std::time::Duration::from_secs(2),
            "doctor checks took {:?}",
            elapsed
        );
    }
}