1use std::path::{Path, PathBuf};
2use std::time::Instant;
3
4use hf_hub::{api::sync::ApiBuilder, Cache};
5use serde::{Deserialize, Serialize};
6use sysinfo::{Disks, System};
7
8#[cfg(any(feature = "cuda", feature = "metal"))]
9use crate::MemoryUsage;
10#[cfg(any(feature = "cuda", feature = "metal"))]
11use candle_core::Device;
12
/// CPU description gathered by `collect_system_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CpuInfo {
    /// Brand string of the first logical CPU reported by `sysinfo`, if any CPU was listed.
    pub brand: Option<String>,
    /// Number of logical CPUs listed by `sysinfo`.
    pub logical_cores: usize,
    /// Number of physical cores, when `sysinfo` can determine it.
    pub physical_cores: Option<usize>,
    /// Runtime-detected x86 AVX support; always `false` on non-x86 targets.
    pub avx: bool,
    /// Runtime-detected x86 AVX2 support; always `false` on non-x86 targets.
    pub avx2: bool,
    /// Runtime-detected x86 AVX-512 Foundation (`avx512f`) support; always `false` on non-x86 targets.
    pub avx512: bool,
    /// Runtime-detected x86 FMA support; always `false` on non-x86 targets.
    pub fma: bool,
}
23
/// System-wide memory figures taken from `sysinfo` (`total_memory` / `available_memory`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryInfo {
    /// Total installed memory, in bytes.
    pub total_bytes: u64,
    /// Memory currently available for new allocations, in bytes.
    pub available_bytes: u64,
}
29
/// One compute device found by `collect_devices`.
///
/// The host CPU is always present. CUDA and Metal entries appear only when the
/// binary was built with the corresponding feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceInfo {
    /// Device kind: `"cpu"`, `"cuda"` or `"metal"`.
    pub kind: String,
    /// Backend device ordinal; `None` for the CPU entry.
    pub ordinal: Option<usize>,
    /// Human-readable name; only filled in for the CPU (its brand string).
    pub name: Option<String>,
    /// Total device memory in bytes, if it could be queried.
    pub total_memory_bytes: Option<u64>,
    /// Currently available device memory in bytes, if it could be queried.
    pub available_memory_bytes: Option<u64>,
    /// CUDA compute capability as `(major, minor)`; only set for CUDA devices
    /// when `nvidia-smi` reported it. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compute_capability: Option<(u32, u32)>,
    /// Whether the hardware can run Flash Attention v2 (CUDA: compute major >= 8;
    /// Metal: always `Some(true)`). Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub flash_attn_compatible: Option<bool>,
    /// Whether the hardware can run Flash Attention v3 (CUDA compute 9.0 only).
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub flash_attn_v3_compatible: Option<bool>,
}
47
/// Compile-time configuration of this binary, captured with `cfg!(feature = ...)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildInfo {
    /// Built with the `cuda` feature.
    pub cuda: bool,
    /// Built with the `metal` feature.
    pub metal: bool,
    /// Built with the `cudnn` feature.
    pub cudnn: bool,
    /// Built with the `flash-attn` feature.
    pub flash_attn: bool,
    /// Built with the `flash-attn-v3` feature.
    pub flash_attn_v3: bool,
    /// Built with the `accelerate` feature.
    pub accelerate: bool,
    /// Built with the `mkl` feature.
    pub mkl: bool,
    /// Git revision the binary was built from (`crate::MISTRALRS_GIT_REVISION`).
    pub git_revision: String,
}
59
/// Result of the Hugging Face Hub probe performed by `check_hf_gated_access`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HfConnectivityInfo {
    /// `true` if the Hub answered, including answers that rejected the token.
    pub reachable: bool,
    /// Round-trip time of the probe in milliseconds; `None` when unreachable.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latency_ms: Option<u64>,
    /// `Some(true)` if the gated-model request succeeded, `Some(false)` if it
    /// failed with what looks like an auth error, `None` when unreachable.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_valid_for_gated: Option<bool>,
    /// Error description for failed probes. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
74
/// Snapshot of the host environment built by `collect_system_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
    /// OS name as reported by `sysinfo::System::name`.
    pub os: Option<String>,
    /// Kernel version as reported by `sysinfo::System::kernel_version`.
    pub kernel: Option<String>,
    /// CPU model and SIMD extension support.
    pub cpu: CpuInfo,
    /// System-wide memory totals.
    pub memory: MemoryInfo,
    /// Compute devices: the CPU first, then any CUDA / Metal devices.
    pub devices: Vec<DeviceInfo>,
    /// Compile-time feature flags of this binary.
    pub build: BuildInfo,
    /// Hugging Face cache directory resolved from the environment
    /// (always `Some` when produced by `collect_system_info`).
    pub hf_cache_path: Option<String>,
}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
87#[serde(rename_all = "lowercase")]
88pub enum DoctorStatus {
89 Ok,
90 Warn,
91 Error,
92}
93
/// A single diagnostic produced by `run_doctor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DoctorCheck {
    /// Machine-readable identifier of the check (e.g. `"cpu_extensions"`, `"hf_connectivity"`).
    pub name: String,
    /// Outcome of the check.
    pub status: DoctorStatus,
    /// Human-readable description of what was found.
    pub message: String,
    /// Optional remediation hint. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suggestion: Option<String>,
}
102
/// Full output of `run_doctor`: the host snapshot plus every check, in the order they ran.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DoctorReport {
    /// Host / build snapshot the checks were evaluated against.
    pub system: SystemInfo,
    /// Individual check results.
    pub checks: Vec<DoctorCheck>,
}
108
109fn build_info() -> BuildInfo {
110 BuildInfo {
111 cuda: cfg!(feature = "cuda"),
112 metal: cfg!(feature = "metal"),
113 cudnn: cfg!(feature = "cudnn"),
114 flash_attn: cfg!(feature = "flash-attn"),
115 flash_attn_v3: cfg!(feature = "flash-attn-v3"),
116 accelerate: cfg!(feature = "accelerate"),
117 mkl: cfg!(feature = "mkl"),
118 git_revision: crate::MISTRALRS_GIT_REVISION.to_string(),
119 }
120}
121
122fn collect_devices(sys: &System) -> Vec<DeviceInfo> {
123 let mut devices = Vec::new();
124
125 let cpu_brand = sys.cpus().first().map(|c| c.brand().to_string());
127 devices.push(DeviceInfo {
128 kind: "cpu".to_string(),
129 ordinal: None,
130 name: cpu_brand,
131 total_memory_bytes: Some(sys.total_memory()),
132 available_memory_bytes: Some(sys.available_memory()),
133 compute_capability: None,
134 flash_attn_compatible: None,
135 flash_attn_v3_compatible: None,
136 });
137
138 #[cfg(feature = "cuda")]
139 {
140 let mut ord = 0;
141 loop {
142 match Device::new_cuda(ord) {
143 Ok(dev) => {
144 let total = MemoryUsage.get_total_memory(&dev).ok().map(|v| v as u64);
145 let avail = MemoryUsage
146 .get_memory_available(&dev)
147 .ok()
148 .map(|v| v as u64);
149
150 let compute_cap = get_cuda_compute_capability(ord);
152 let flash_attn_v2_ok = compute_cap.map(|(major, _minor)| {
153 major >= 8
155 });
156 let flash_attn_v3_ok = compute_cap.map(|(major, minor)| {
157 major == 9 && minor == 0
159 });
160
161 devices.push(DeviceInfo {
162 kind: "cuda".to_string(),
163 ordinal: Some(ord),
164 name: None,
165 total_memory_bytes: total,
166 available_memory_bytes: avail,
167 compute_capability: compute_cap,
168 flash_attn_compatible: flash_attn_v2_ok,
169 flash_attn_v3_compatible: flash_attn_v3_ok,
170 });
171 ord += 1;
172 }
173 Err(_) => break,
174 }
175 }
176 }
177
178 #[cfg(feature = "metal")]
179 {
180 let total = candle_metal_kernels::metal::Device::all().len();
181 for ord in 0..total {
182 if let Ok(dev) = Device::new_metal(ord) {
183 let total = MemoryUsage.get_total_memory(&dev).ok().map(|v| v as u64);
184 let avail = MemoryUsage
185 .get_memory_available(&dev)
186 .ok()
187 .map(|v| v as u64);
188 devices.push(DeviceInfo {
189 kind: "metal".to_string(),
190 ordinal: Some(ord),
191 name: None,
192 total_memory_bytes: total,
193 available_memory_bytes: avail,
194 compute_capability: None,
195 flash_attn_compatible: Some(true), flash_attn_v3_compatible: None, });
198 }
199 }
200 }
201
202 devices
203}
204
/// Queries the compute capability (`major.minor`) of CUDA device `ordinal` by
/// running `nvidia-smi --query-gpu=compute_cap`.
///
/// `ordinal` is a CUDA runtime ordinal, as passed to `Device::new_cuda`. When
/// `CUDA_VISIBLE_DEVICES` is set, runtime ordinals index into that list rather
/// than into nvidia-smi's own enumeration. The ordinal is therefore translated
/// through the list first (see [`nvidia_smi_gpu_id`]).
///
/// Returns `None` in any of these cases:
/// - `nvidia-smi` cannot be spawned;
/// - it exits unsuccessfully;
/// - it prints something other than `<major>.<minor>`.
///
/// NOTE(review): when `CUDA_VISIBLE_DEVICES` is unset, the CUDA runtime's default
/// device order (`CUDA_DEVICE_ORDER=FASTEST_FIRST`) may differ from nvidia-smi's
/// PCI-bus order on mixed-GPU machines. That case is not handled here.
#[cfg(feature = "cuda")]
fn get_cuda_compute_capability(ordinal: usize) -> Option<(u32, u32)> {
    let visible_devices = std::env::var("CUDA_VISIBLE_DEVICES").ok();
    let gpu_id = nvidia_smi_gpu_id(ordinal, visible_devices.as_deref());

    let output = std::process::Command::new("nvidia-smi")
        .arg("--query-gpu=compute_cap")
        .arg("--format=csv,noheader")
        // `--id=` is the documented long form; `-i=N` is not.
        .arg(format!("--id={gpu_id}"))
        .output()
        .ok()?;

    if !output.status.success() {
        return None;
    }

    let stdout = String::from_utf8(output.stdout).ok()?;
    parse_compute_capability(&stdout)
}

/// Maps a CUDA runtime ordinal to the identifier passed to `nvidia-smi --id=`.
///
/// If `cuda_visible_devices` (the raw `CUDA_VISIBLE_DEVICES` value) has a
/// non-empty entry at position `ordinal`, that entry is returned trimmed. It may
/// be an index or a GPU UUID, both of which `--id` accepts. Otherwise the ordinal
/// itself is used.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
fn nvidia_smi_gpu_id(ordinal: usize, cuda_visible_devices: Option<&str>) -> String {
    cuda_visible_devices
        .and_then(|list| list.split(',').nth(ordinal))
        .map(str::trim)
        .filter(|id| !id.is_empty())
        .map_or_else(|| ordinal.to_string(), |id| id.to_string())
}

/// Parses nvidia-smi `compute_cap` output such as `"8.6\n"` into `(8, 6)`.
///
/// Only the first non-empty line is considered. Anything that is not exactly
/// `<u32>.<u32>` yields `None`; this includes `"[N/A]"` and `"8.6.1"`.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
fn parse_compute_capability(text: &str) -> Option<(u32, u32)> {
    let line = text.lines().map(str::trim).find(|line| !line.is_empty())?;
    let (major, minor) = line.split_once('.')?;
    Some((major.trim().parse().ok()?, minor.trim().parse().ok()?))
}
235
/// Fallback for builds without the `cuda` feature: no compute capability can be
/// queried, so the answer is always `None` regardless of the ordinal.
#[allow(dead_code)]
#[cfg(not(feature = "cuda"))]
fn get_cuda_compute_capability(_ordinal: usize) -> Option<(u32, u32)> {
    Option::None
}
241
/// Detects at runtime which x86 SIMD extensions the CPU supports.
///
/// Returns `(avx, avx2, avx512f, fma)`. On non-x86 targets every flag is `false`,
/// because `is_x86_feature_detected!` only exists on x86 / x86_64.
fn detect_cpu_extensions() -> (bool, bool, bool, bool) {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn probe() -> (bool, bool, bool, bool) {
        (
            std::arch::is_x86_feature_detected!("avx"),
            std::arch::is_x86_feature_detected!("avx2"),
            std::arch::is_x86_feature_detected!("avx512f"),
            std::arch::is_x86_feature_detected!("fma"),
        )
    }

    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    fn probe() -> (bool, bool, bool, bool) {
        (false, false, false, false)
    }

    probe()
}
257
258pub fn collect_system_info() -> SystemInfo {
259 let mut sys = System::new_all();
260 sys.refresh_all();
261
262 let (avx, avx2, avx512, fma) = detect_cpu_extensions();
263
264 let cpu = CpuInfo {
265 brand: sys.cpus().first().map(|c| c.brand().to_string()),
266 logical_cores: sys.cpus().len(),
267 physical_cores: System::physical_core_count(),
268 avx,
269 avx2,
270 avx512,
271 fma,
272 };
273
274 let memory = MemoryInfo {
275 total_bytes: sys.total_memory(),
276 available_bytes: sys.available_memory(),
277 };
278
279 let hf_cache = Cache::from_env();
280 let hf_cache_path = hf_cache.path().to_string_lossy().to_string();
281
282 SystemInfo {
283 os: System::name(),
284 kernel: System::kernel_version(),
285 cpu,
286 memory,
287 devices: collect_devices(&sys),
288 build: build_info(),
289 hf_cache_path: Some(hf_cache_path),
290 }
291}
292
293#[allow(clippy::cast_possible_truncation)]
295pub fn check_hf_gated_access() -> HfConnectivityInfo {
296 let start = Instant::now();
297
298 let api_result = ApiBuilder::from_env()
300 .with_progress(false)
301 .build()
302 .and_then(|api| api.model("google/gemma-3-4b-it".to_string()).info());
303
304 let latency_ms = start.elapsed().as_millis() as u64;
305
306 match api_result {
307 Ok(_) => HfConnectivityInfo {
308 reachable: true,
309 latency_ms: Some(latency_ms),
310 token_valid_for_gated: Some(true),
311 error: None,
312 },
313 Err(e) => {
314 let error_str = e.to_string();
315 let is_auth_error = error_str.contains("401")
317 || error_str.contains("403")
318 || error_str.contains("unauthorized")
319 || error_str.contains("Unauthorized")
320 || error_str.contains("Access denied")
321 || error_str.contains("gated");
322
323 if is_auth_error {
324 HfConnectivityInfo {
326 reachable: true,
327 latency_ms: Some(latency_ms),
328 token_valid_for_gated: Some(false),
329 error: Some("Token invalid or missing for gated models".to_string()),
330 }
331 } else {
332 HfConnectivityInfo {
334 reachable: false,
335 latency_ms: None,
336 token_valid_for_gated: None,
337 error: Some(error_str),
338 }
339 }
340 }
341 }
342}
343
344fn disk_usage_for(path: &Path) -> Option<(u64, u64)> {
345 let disks = Disks::new_with_refreshed_list();
346 let mut best: Option<(usize, u64, u64)> = None;
347 for disk in disks.list() {
348 let mount = disk.mount_point();
349 if path.starts_with(mount) {
350 let len = mount.as_os_str().len();
351 let avail = disk.available_space();
352 let total = disk.total_space();
353 if best.map(|b| len > b.0).unwrap_or(true) {
354 best = Some((len, avail, total));
355 }
356 }
357 }
358 best.map(|(_, avail, total)| (avail, total))
359}
360
361pub fn run_doctor() -> DoctorReport {
362 let system = collect_system_info();
363 let mut checks = Vec::new();
364
365 {
367 let is_arm = cfg!(any(target_arch = "aarch64", target_arch = "arm"));
368
369 if is_arm {
370 checks.push(DoctorCheck {
372 name: "cpu_extensions".to_string(),
373 status: DoctorStatus::Ok,
374 message: "CPU: ARM architecture (uses NEON)".to_string(),
375 suggestion: None,
376 });
377 } else {
378 let mut extensions = Vec::new();
380 if system.cpu.avx {
381 extensions.push("AVX");
382 }
383 if system.cpu.avx2 {
384 extensions.push("AVX2");
385 }
386 if system.cpu.fma {
387 extensions.push("FMA");
388 }
389 if system.cpu.avx512 {
390 extensions.push("AVX-512");
391 }
392
393 let has_avx2 = system.cpu.avx2;
394 let ext_str = if extensions.is_empty() {
395 "none detected".to_string()
396 } else {
397 extensions.join(", ")
398 };
399
400 checks.push(DoctorCheck {
401 name: "cpu_extensions".to_string(),
402 status: if has_avx2 {
403 DoctorStatus::Ok
404 } else {
405 DoctorStatus::Warn
406 },
407 message: format!("CPU extensions: {ext_str}"),
408 suggestion: if !has_avx2 {
409 Some("AVX2 is recommended for optimal GGML performance on x86.".to_string())
410 } else {
411 None
412 },
413 });
414 }
415 }
416
417 {
419 let has_cuda_device = system.devices.iter().any(|d| d.kind == "cuda");
420 let has_metal_device = system.devices.iter().any(|d| d.kind == "metal");
421
422 if has_cuda_device && !system.build.cuda {
423 checks.push(DoctorCheck {
424 name: "binary_hardware_match".to_string(),
425 status: DoctorStatus::Error,
426 message: "NVIDIA GPU detected but binary compiled without CUDA support."
427 .to_string(),
428 suggestion: Some("Reinstall with CUDA: cargo install --features cuda".to_string()),
429 });
430 } else if has_metal_device && !system.build.metal {
431 checks.push(DoctorCheck {
432 name: "binary_hardware_match".to_string(),
433 status: DoctorStatus::Error,
434 message: "Apple GPU detected but binary compiled without Metal support."
435 .to_string(),
436 suggestion: Some(
437 "Reinstall with Metal: cargo install --features metal".to_string(),
438 ),
439 });
440 } else {
441 checks.push(DoctorCheck {
442 name: "binary_hardware_match".to_string(),
443 status: DoctorStatus::Ok,
444 message: "Binary features match detected hardware.".to_string(),
445 suggestion: None,
446 });
447 }
448 }
449
450 #[cfg(feature = "cuda")]
452 {
453 for dev in system.devices.iter().filter(|d| d.kind == "cuda") {
454 if let (Some(ord), Some((major, minor))) = (dev.ordinal, dev.compute_capability) {
455 let fa_v2_ok = dev.flash_attn_compatible.unwrap_or(false);
456 let fa_v3_ok = dev.flash_attn_v3_compatible.unwrap_or(false);
457
458 let fa_v2_str = if fa_v2_ok { "✅" } else { "❌" };
460 let fa_v3_str = if fa_v3_ok {
461 "✅"
462 } else {
463 "❌ (requires Hopper/Compute 9.0)"
464 };
465
466 checks.push(DoctorCheck {
467 name: format!("cuda_{}_compute", ord),
468 status: DoctorStatus::Ok,
469 message: format!(
470 "GPU {}: compute {}.{} - Flash Attn v2 {}, v3 {}",
471 ord, major, minor, fa_v2_str, fa_v3_str
472 ),
473 suggestion: None,
474 });
475
476 if fa_v2_ok && !system.build.flash_attn {
478 checks.push(DoctorCheck {
479 name: format!("cuda_{}_flash_attn_v2_missing", ord),
480 status: DoctorStatus::Warn,
481 message: format!(
482 "GPU {} supports Flash Attention v2 but binary compiled without it.",
483 ord
484 ),
485 suggestion: Some(
486 "Reinstall with: cargo install --features flash-attn".to_string(),
487 ),
488 });
489 }
490
491 if fa_v3_ok && !system.build.flash_attn_v3 {
493 checks.push(DoctorCheck {
494 name: format!("cuda_{}_flash_attn_v3_missing", ord),
495 status: DoctorStatus::Warn,
496 message: format!(
497 "GPU {} supports Flash Attention v3 but binary compiled without it.",
498 ord
499 ),
500 suggestion: Some(
501 "Reinstall with: cargo install --features flash-attn-v3".to_string(),
502 ),
503 });
504 }
505 }
506 }
507 }
508
509 let hf_cache_path = system
510 .hf_cache_path
511 .as_ref()
512 .map(PathBuf::from)
513 .unwrap_or_else(|| Cache::from_env().path().clone());
514
515 if std::fs::create_dir_all(&hf_cache_path).is_err() {
516 checks.push(DoctorCheck {
517 name: "hf_cache_writable".to_string(),
518 status: DoctorStatus::Error,
519 message: format!(
520 "Cannot create or access Hugging Face cache dir at {}",
521 hf_cache_path.display()
522 ),
523 suggestion: Some("Set HF_HOME or fix permissions.".to_string()),
524 });
525 } else {
526 checks.push(DoctorCheck {
527 name: "hf_cache_writable".to_string(),
528 status: DoctorStatus::Ok,
529 message: format!(
530 "Hugging Face cache dir is writable: {}",
531 hf_cache_path.display()
532 ),
533 suggestion: None,
534 });
535 }
536
537 {
539 let hf_info = check_hf_gated_access();
540 if hf_info.reachable {
541 if hf_info.token_valid_for_gated == Some(true) {
542 checks.push(DoctorCheck {
543 name: "hf_connectivity".to_string(),
544 status: DoctorStatus::Ok,
545 message: format!(
546 "Hugging Face: connected ({}ms), token valid for allowed gated models.",
547 hf_info.latency_ms.unwrap_or(0)
548 ),
549 suggestion: None,
550 });
551 } else {
552 checks.push(DoctorCheck {
553 name: "hf_connectivity".to_string(),
554 status: DoctorStatus::Warn,
555 message: format!(
556 "Hugging Face: connected ({}ms), but token invalid/missing.",
557 hf_info.latency_ms.unwrap_or(0)
558 ),
559 suggestion: Some(
560 "Run `huggingface-cli login` or set HF_TOKEN to access gated models."
561 .to_string(),
562 ),
563 });
564 }
565 } else {
566 checks.push(DoctorCheck {
567 name: "hf_connectivity".to_string(),
568 status: DoctorStatus::Error,
569 message: format!(
570 "Hugging Face: unreachable - {}",
571 hf_info.error.unwrap_or_else(|| "unknown error".to_string())
572 ),
573 suggestion: Some(
574 "Check your internet connection and firewall settings.".to_string(),
575 ),
576 });
577 }
578 }
579
580 if let Some((avail, total)) = disk_usage_for(&hf_cache_path) {
581 let min_free = 10_u64 * 1024 * 1024 * 1024;
582 let status = if avail < min_free {
583 DoctorStatus::Warn
584 } else {
585 DoctorStatus::Ok
586 };
587 checks.push(DoctorCheck {
588 name: "disk_space".to_string(),
589 status,
590 #[allow(clippy::cast_precision_loss)]
591 message: format!(
592 "Disk free: {:.1} GB / {:.1} GB on the volume containing the HF cache at {}.",
593 avail as f64 / 1e9,
594 total as f64 / 1e9,
595 hf_cache_path.display()
596 ),
597 suggestion: if avail < min_free {
598 Some("Free up disk space or move HF cache.".to_string())
599 } else {
600 None
601 },
602 });
603 }
604
605 let has_cuda = system.devices.iter().any(|d| d.kind == "cuda");
606
607 if system.build.cuda && !has_cuda {
608 checks.push(DoctorCheck {
609 name: "cuda_devices".to_string(),
610 status: DoctorStatus::Warn,
611 message: "CUDA support is enabled but no CUDA devices were found.".to_string(),
612 suggestion: Some("Check NVIDIA driver installation.".to_string()),
613 });
614 }
615
616 DoctorReport { system, checks }
617}