// mistralrs_core/diagnostics.rs

1use std::path::{Path, PathBuf};
2use std::time::Instant;
3
4use hf_hub::{api::sync::ApiBuilder, Cache};
5use serde::{Deserialize, Serialize};
6use sysinfo::{Disks, System};
7
8#[cfg(any(feature = "cuda", feature = "metal"))]
9use crate::MemoryUsage;
10#[cfg(any(feature = "cuda", feature = "metal"))]
11use candle_core::Device;
12
/// CPU identification and runtime-detected SIMD capability summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CpuInfo {
    /// Brand string of the first reported CPU, if available.
    pub brand: Option<String>,
    /// Number of logical cores (hardware threads).
    pub logical_cores: usize,
    /// Number of physical cores, when the OS reports it.
    pub physical_cores: Option<usize>,
    /// AVX detected at runtime (always false on non-x86 targets).
    pub avx: bool,
    /// AVX2 detected at runtime (always false on non-x86 targets).
    pub avx2: bool,
    /// AVX-512F detected at runtime (always false on non-x86 targets).
    pub avx512: bool,
    /// FMA detected at runtime (always false on non-x86 targets).
    pub fma: bool,
}
23
/// System RAM totals, in bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryInfo {
    /// Total installed memory.
    pub total_bytes: u64,
    /// Memory currently available to new allocations.
    pub available_bytes: u64,
}
29
/// A compute device visible to the runtime (CPU, CUDA GPU, or Metal GPU).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceInfo {
    /// Device kind: "cpu", "cuda", or "metal".
    pub kind: String,
    /// GPU ordinal; `None` for the CPU entry.
    pub ordinal: Option<usize>,
    /// Device name when known (CPU brand string for the CPU entry; GPUs are unnamed here).
    pub name: Option<String>,
    /// Total device memory in bytes, when the query succeeds.
    pub total_memory_bytes: Option<u64>,
    /// Currently available device memory in bytes, when the query succeeds.
    pub available_memory_bytes: Option<u64>,
    /// CUDA compute capability (major, minor) - None for non-CUDA devices
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compute_capability: Option<(u32, u32)>,
    /// Whether this GPU supports Flash Attention v2 (compute capability >= 8.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub flash_attn_compatible: Option<bool>,
    /// Whether this GPU supports Flash Attention v3 (compute capability == 9.0, Hopper only)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub flash_attn_v3_compatible: Option<bool>,
}
47
/// Compile-time Cargo feature flags of this binary, plus its git revision.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildInfo {
    /// Built with the `cuda` feature.
    pub cuda: bool,
    /// Built with the `metal` feature.
    pub metal: bool,
    /// Built with the `cudnn` feature.
    pub cudnn: bool,
    /// Built with the `flash-attn` feature.
    pub flash_attn: bool,
    /// Built with the `flash-attn-v3` feature.
    pub flash_attn_v3: bool,
    /// Built with the `accelerate` feature.
    pub accelerate: bool,
    /// Built with the `mkl` feature.
    pub mkl: bool,
    /// Git revision recorded at build time.
    pub git_revision: String,
}
59
/// Result of probing Hugging Face Hub connectivity and gated-model access.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HfConnectivityInfo {
    /// Whether HuggingFace is reachable
    pub reachable: bool,
    /// Latency in milliseconds (if reachable)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latency_ms: Option<u64>,
    /// Whether the token is valid for gated models
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_valid_for_gated: Option<bool>,
    /// Error message if not reachable
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
74
/// Snapshot of the host system gathered by [`collect_system_info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
    /// OS name, if the platform reports one.
    pub os: Option<String>,
    /// Kernel version string, if available.
    pub kernel: Option<String>,
    /// CPU brand, core counts, and SIMD feature flags.
    pub cpu: CpuInfo,
    /// System RAM totals.
    pub memory: MemoryInfo,
    /// All detected compute devices (CPU entry first, then any GPUs).
    pub devices: Vec<DeviceInfo>,
    /// Compile-time feature flags of this binary.
    pub build: BuildInfo,
    /// Resolved Hugging Face cache directory, if determinable.
    pub hf_cache_path: Option<String>,
}
85
/// Severity of a single doctor check result (serialized lowercase).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DoctorStatus {
    /// Check passed.
    Ok,
    /// Non-fatal issue; functionality may be degraded.
    Warn,
    /// Problem that likely prevents correct operation.
    Error,
}
93
/// Outcome of one named diagnostic check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DoctorCheck {
    /// Machine-readable identifier (e.g. "cpu_extensions", "disk_space").
    pub name: String,
    /// Result severity.
    pub status: DoctorStatus,
    /// Human-readable description of the finding.
    pub message: String,
    /// Optional remediation hint, present only for actionable findings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suggestion: Option<String>,
}
102
/// Full output of [`run_doctor`]: the system snapshot plus all check results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DoctorReport {
    /// System snapshot the checks were evaluated against.
    pub system: SystemInfo,
    /// Individual check outcomes, in execution order.
    pub checks: Vec<DoctorCheck>,
}
108
109fn build_info() -> BuildInfo {
110    BuildInfo {
111        cuda: cfg!(feature = "cuda"),
112        metal: cfg!(feature = "metal"),
113        cudnn: cfg!(feature = "cudnn"),
114        flash_attn: cfg!(feature = "flash-attn"),
115        flash_attn_v3: cfg!(feature = "flash-attn-v3"),
116        accelerate: cfg!(feature = "accelerate"),
117        mkl: cfg!(feature = "mkl"),
118        git_revision: crate::MISTRALRS_GIT_REVISION.to_string(),
119    }
120}
121
/// Enumerate compute devices: always the host CPU, plus every CUDA or Metal
/// GPU that can be initialized (only when the matching feature is compiled in).
fn collect_devices(sys: &System) -> Vec<DeviceInfo> {
    let mut devices = Vec::new();

    // CPU device
    let cpu_brand = sys.cpus().first().map(|c| c.brand().to_string());
    devices.push(DeviceInfo {
        kind: "cpu".to_string(),
        ordinal: None,
        name: cpu_brand,
        total_memory_bytes: Some(sys.total_memory()),
        available_memory_bytes: Some(sys.available_memory()),
        compute_capability: None,
        flash_attn_compatible: None,
        flash_attn_v3_compatible: None,
    });

    #[cfg(feature = "cuda")]
    {
        // Probe ordinals 0, 1, 2, ... until device creation fails.
        let mut ord = 0;
        loop {
            match Device::new_cuda(ord) {
                Ok(dev) => {
                    // Memory queries are best-effort; failures simply yield `None`.
                    let total = MemoryUsage.get_total_memory(&dev).ok().map(|v| v as u64);
                    let avail = MemoryUsage
                        .get_memory_available(&dev)
                        .ok()
                        .map(|v| v as u64);

                    // Get compute capability
                    let compute_cap = get_cuda_compute_capability(ord);
                    let flash_attn_v2_ok = compute_cap.map(|(major, _minor)| {
                        // Flash Attention v2 requires compute capability >= 8.0 (Ampere+)
                        major >= 8
                    });
                    let flash_attn_v3_ok = compute_cap.map(|(major, minor)| {
                        // Flash Attention v3 requires compute capability == 9.0 (Hopper only)
                        major == 9 && minor == 0
                    });

                    devices.push(DeviceInfo {
                        kind: "cuda".to_string(),
                        ordinal: Some(ord),
                        name: None,
                        total_memory_bytes: total,
                        available_memory_bytes: avail,
                        compute_capability: compute_cap,
                        flash_attn_compatible: flash_attn_v2_ok,
                        flash_attn_v3_compatible: flash_attn_v3_ok,
                    });
                    ord += 1;
                }
                Err(_) => break,
            }
        }
    }

    #[cfg(feature = "metal")]
    {
        // Metal exposes the device count up front, unlike the CUDA probe loop.
        let total = candle_metal_kernels::metal::Device::all().len();
        for ord in 0..total {
            if let Ok(dev) = Device::new_metal(ord) {
                let total = MemoryUsage.get_total_memory(&dev).ok().map(|v| v as u64);
                let avail = MemoryUsage
                    .get_memory_available(&dev)
                    .ok()
                    .map(|v| v as u64);
                devices.push(DeviceInfo {
                    kind: "metal".to_string(),
                    ordinal: Some(ord),
                    name: None,
                    total_memory_bytes: total,
                    available_memory_bytes: avail,
                    compute_capability: None,
                    flash_attn_compatible: Some(true), // Metal always supports flash attention
                    flash_attn_v3_compatible: None,    // Flash Attn v3 is CUDA Hopper only
                });
            }
        }
    }

    devices
}
204
/// Get CUDA compute capability for a device ordinal.
///
/// Shells out to `nvidia-smi --query-gpu=compute_cap` and parses its "X.Y"
/// output. Returns `None` if `nvidia-smi` is missing, exits non-zero, or
/// prints something unparsable.
#[cfg(feature = "cuda")]
fn get_cuda_compute_capability(ordinal: usize) -> Option<(u32, u32)> {
    // Use nvidia-smi to query compute capability.
    // NOTE: the documented GPU-selection forms are `-i <n>` or `--id=<n>`;
    // the previous `-i={ordinal}` form is not documented and may be rejected.
    let output = std::process::Command::new("nvidia-smi")
        .args([
            "--query-gpu=compute_cap",
            "--format=csv,noheader",
            &format!("--id={ordinal}"),
        ])
        .output()
        .ok()?;

    if !output.status.success() {
        return None;
    }

    let stdout = String::from_utf8(output.stdout).ok()?;
    parse_compute_capability(&stdout)
}

/// Parse nvidia-smi compute-capability output like "8.9" into `(major, minor)`.
///
/// Leading/trailing whitespace (including the trailing newline nvidia-smi
/// emits) is tolerated; anything else yields `None`.
#[allow(dead_code)] // only called from the cfg(feature = "cuda") path
fn parse_compute_capability(s: &str) -> Option<(u32, u32)> {
    let (major, minor) = s.trim().split_once('.')?;
    let major: u32 = major.parse().ok()?;
    let minor: u32 = minor.parse().ok()?;
    Some((major, minor))
}
235
/// Fallback when CUDA support is not compiled in: capability is unknowable.
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
fn get_cuda_compute_capability(_ordinal: usize) -> Option<(u32, u32)> {
    None
}
241
/// Detect CPU SIMD extensions at runtime.
///
/// Returns `(avx, avx2, avx512f, fma)`; all `false` on non-x86 targets.
fn detect_cpu_extensions() -> (bool, bool, bool, bool) {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        (
            std::arch::is_x86_feature_detected!("avx"),
            std::arch::is_x86_feature_detected!("avx2"),
            std::arch::is_x86_feature_detected!("avx512f"),
            std::arch::is_x86_feature_detected!("fma"),
        )
    }
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    {
        (false, false, false, false)
    }
}
257
258pub fn collect_system_info() -> SystemInfo {
259    let mut sys = System::new_all();
260    sys.refresh_all();
261
262    let (avx, avx2, avx512, fma) = detect_cpu_extensions();
263
264    let cpu = CpuInfo {
265        brand: sys.cpus().first().map(|c| c.brand().to_string()),
266        logical_cores: sys.cpus().len(),
267        physical_cores: System::physical_core_count(),
268        avx,
269        avx2,
270        avx512,
271        fma,
272    };
273
274    let memory = MemoryInfo {
275        total_bytes: sys.total_memory(),
276        available_bytes: sys.available_memory(),
277    };
278
279    let hf_cache = Cache::from_env();
280    let hf_cache_path = hf_cache.path().to_string_lossy().to_string();
281
282    SystemInfo {
283        os: System::name(),
284        kernel: System::kernel_version(),
285        cpu,
286        memory,
287        devices: collect_devices(&sys),
288        build: build_info(),
289        hf_cache_path: Some(hf_cache_path),
290    }
291}
292
293/// Check HuggingFace connectivity and token validity by accessing a gated model
294#[allow(clippy::cast_possible_truncation)]
295pub fn check_hf_gated_access() -> HfConnectivityInfo {
296    let start = Instant::now();
297
298    // Try to access a gated model (google/gemma-3-4b-it)
299    let api_result = ApiBuilder::from_env()
300        .with_progress(false)
301        .build()
302        .and_then(|api| api.model("google/gemma-3-4b-it".to_string()).info());
303
304    let latency_ms = start.elapsed().as_millis() as u64;
305
306    match api_result {
307        Ok(_) => HfConnectivityInfo {
308            reachable: true,
309            latency_ms: Some(latency_ms),
310            token_valid_for_gated: Some(true),
311            error: None,
312        },
313        Err(e) => {
314            let error_str = e.to_string();
315            // Check if it's an auth error vs network error
316            let is_auth_error = error_str.contains("401")
317                || error_str.contains("403")
318                || error_str.contains("unauthorized")
319                || error_str.contains("Unauthorized")
320                || error_str.contains("Access denied")
321                || error_str.contains("gated");
322
323            if is_auth_error {
324                // Network works, but token is invalid/missing
325                HfConnectivityInfo {
326                    reachable: true,
327                    latency_ms: Some(latency_ms),
328                    token_valid_for_gated: Some(false),
329                    error: Some("Token invalid or missing for gated models".to_string()),
330                }
331            } else {
332                // Network/other error
333                HfConnectivityInfo {
334                    reachable: false,
335                    latency_ms: None,
336                    token_valid_for_gated: None,
337                    error: Some(error_str),
338                }
339            }
340        }
341    }
342}
343
344fn disk_usage_for(path: &Path) -> Option<(u64, u64)> {
345    let disks = Disks::new_with_refreshed_list();
346    let mut best: Option<(usize, u64, u64)> = None;
347    for disk in disks.list() {
348        let mount = disk.mount_point();
349        if path.starts_with(mount) {
350            let len = mount.as_os_str().len();
351            let avail = disk.available_space();
352            let total = disk.total_space();
353            if best.map(|b| len > b.0).unwrap_or(true) {
354                best = Some((len, avail, total));
355            }
356        }
357    }
358    best.map(|(_, avail, total)| (avail, total))
359}
360
/// Run the full diagnostic suite and return a [`DoctorReport`].
///
/// Collects a [`SystemInfo`] snapshot, then evaluates checks in order:
/// CPU SIMD extensions, binary-vs-hardware feature match, per-GPU CUDA
/// compute capability (CUDA builds only), HF cache writability, HF
/// connectivity/token validity, free disk space, and CUDA device presence.
pub fn run_doctor() -> DoctorReport {
    let system = collect_system_info();
    let mut checks = Vec::new();

    // CPU extensions check (ARM-aware)
    {
        let is_arm = cfg!(any(target_arch = "aarch64", target_arch = "arm"));

        if is_arm {
            // ARM CPUs use NEON, not AVX - no warning needed
            checks.push(DoctorCheck {
                name: "cpu_extensions".to_string(),
                status: DoctorStatus::Ok,
                message: "CPU: ARM architecture (uses NEON)".to_string(),
                suggestion: None,
            });
        } else {
            // x86/x86_64 - check for AVX extensions
            let mut extensions = Vec::new();
            if system.cpu.avx {
                extensions.push("AVX");
            }
            if system.cpu.avx2 {
                extensions.push("AVX2");
            }
            if system.cpu.fma {
                extensions.push("FMA");
            }
            if system.cpu.avx512 {
                extensions.push("AVX-512");
            }

            // AVX2 is the threshold for Ok vs Warn status.
            let has_avx2 = system.cpu.avx2;
            let ext_str = if extensions.is_empty() {
                "none detected".to_string()
            } else {
                extensions.join(", ")
            };

            checks.push(DoctorCheck {
                name: "cpu_extensions".to_string(),
                status: if has_avx2 {
                    DoctorStatus::Ok
                } else {
                    DoctorStatus::Warn
                },
                message: format!("CPU extensions: {ext_str}"),
                suggestion: if !has_avx2 {
                    Some("AVX2 is recommended for optimal GGML performance on x86.".to_string())
                } else {
                    None
                },
            });
        }
    }

    // Binary vs hardware mismatch check: GPU present but support not compiled in.
    {
        let has_cuda_device = system.devices.iter().any(|d| d.kind == "cuda");
        let has_metal_device = system.devices.iter().any(|d| d.kind == "metal");

        if has_cuda_device && !system.build.cuda {
            checks.push(DoctorCheck {
                name: "binary_hardware_match".to_string(),
                status: DoctorStatus::Error,
                message: "NVIDIA GPU detected but binary compiled without CUDA support."
                    .to_string(),
                suggestion: Some("Reinstall with CUDA: cargo install --features cuda".to_string()),
            });
        } else if has_metal_device && !system.build.metal {
            checks.push(DoctorCheck {
                name: "binary_hardware_match".to_string(),
                status: DoctorStatus::Error,
                message: "Apple GPU detected but binary compiled without Metal support."
                    .to_string(),
                suggestion: Some(
                    "Reinstall with Metal: cargo install --features metal".to_string(),
                ),
            });
        } else {
            checks.push(DoctorCheck {
                name: "binary_hardware_match".to_string(),
                status: DoctorStatus::Ok,
                message: "Binary features match detected hardware.".to_string(),
                suggestion: None,
            });
        }
    }

    // CUDA compute capability + Flash Attention v2/v3 check
    #[cfg(feature = "cuda")]
    {
        for dev in system.devices.iter().filter(|d| d.kind == "cuda") {
            // Only report GPUs whose capability query succeeded.
            if let (Some(ord), Some((major, minor))) = (dev.ordinal, dev.compute_capability) {
                let fa_v2_ok = dev.flash_attn_compatible.unwrap_or(false);
                let fa_v3_ok = dev.flash_attn_v3_compatible.unwrap_or(false);

                // Build status strings with emojis
                let fa_v2_str = if fa_v2_ok { "✅" } else { "❌" };
                let fa_v3_str = if fa_v3_ok {
                    "✅"
                } else {
                    "❌ (requires Hopper/Compute 9.0)"
                };

                checks.push(DoctorCheck {
                    name: format!("cuda_{}_compute", ord),
                    status: DoctorStatus::Ok,
                    message: format!(
                        "GPU {}: compute {}.{} - Flash Attn v2 {}, v3 {}",
                        ord, major, minor, fa_v2_str, fa_v3_str
                    ),
                    suggestion: None,
                });

                // Warn if hardware supports flash attn v2 but binary doesn't have it
                if fa_v2_ok && !system.build.flash_attn {
                    checks.push(DoctorCheck {
                        name: format!("cuda_{}_flash_attn_v2_missing", ord),
                        status: DoctorStatus::Warn,
                        message: format!(
                            "GPU {} supports Flash Attention v2 but binary compiled without it.",
                            ord
                        ),
                        suggestion: Some(
                            "Reinstall with: cargo install --features flash-attn".to_string(),
                        ),
                    });
                }

                // Warn if hardware supports flash attn v3 but binary doesn't have it
                if fa_v3_ok && !system.build.flash_attn_v3 {
                    checks.push(DoctorCheck {
                        name: format!("cuda_{}_flash_attn_v3_missing", ord),
                        status: DoctorStatus::Warn,
                        message: format!(
                            "GPU {} supports Flash Attention v3 but binary compiled without it.",
                            ord
                        ),
                        suggestion: Some(
                            "Reinstall with: cargo install --features flash-attn-v3".to_string(),
                        ),
                    });
                }
            }
        }
    }

    // Fall back to re-resolving the cache path from the environment if the
    // snapshot did not record one.
    let hf_cache_path = system
        .hf_cache_path
        .as_ref()
        .map(PathBuf::from)
        .unwrap_or_else(|| Cache::from_env().path().clone());

    // HF cache writability: creating the directory also proves permissions.
    if std::fs::create_dir_all(&hf_cache_path).is_err() {
        checks.push(DoctorCheck {
            name: "hf_cache_writable".to_string(),
            status: DoctorStatus::Error,
            message: format!(
                "Cannot create or access Hugging Face cache dir at {}",
                hf_cache_path.display()
            ),
            suggestion: Some("Set HF_HOME or fix permissions.".to_string()),
        });
    } else {
        checks.push(DoctorCheck {
            name: "hf_cache_writable".to_string(),
            status: DoctorStatus::Ok,
            message: format!(
                "Hugging Face cache dir is writable: {}",
                hf_cache_path.display()
            ),
            suggestion: None,
        });
    }

    // HuggingFace connectivity + gated model access check
    {
        let hf_info = check_hf_gated_access();
        if hf_info.reachable {
            if hf_info.token_valid_for_gated == Some(true) {
                checks.push(DoctorCheck {
                    name: "hf_connectivity".to_string(),
                    status: DoctorStatus::Ok,
                    message: format!(
                        "Hugging Face: connected ({}ms), token valid for allowed gated models.",
                        hf_info.latency_ms.unwrap_or(0)
                    ),
                    suggestion: None,
                });
            } else {
                checks.push(DoctorCheck {
                    name: "hf_connectivity".to_string(),
                    status: DoctorStatus::Warn,
                    message: format!(
                        "Hugging Face: connected ({}ms), but token invalid/missing.",
                        hf_info.latency_ms.unwrap_or(0)
                    ),
                    suggestion: Some(
                        "Run `huggingface-cli login` or set HF_TOKEN to access gated models."
                            .to_string(),
                    ),
                });
            }
        } else {
            checks.push(DoctorCheck {
                name: "hf_connectivity".to_string(),
                status: DoctorStatus::Error,
                message: format!(
                    "Hugging Face: unreachable - {}",
                    hf_info.error.unwrap_or_else(|| "unknown error".to_string())
                ),
                suggestion: Some(
                    "Check your internet connection and firewall settings.".to_string(),
                ),
            });
        }
    }

    // Disk space on the volume holding the HF cache; skipped if no disk matches.
    if let Some((avail, total)) = disk_usage_for(&hf_cache_path) {
        // Warn when under 10 GiB free.
        let min_free = 10_u64 * 1024 * 1024 * 1024;
        let status = if avail < min_free {
            DoctorStatus::Warn
        } else {
            DoctorStatus::Ok
        };
        checks.push(DoctorCheck {
            name: "disk_space".to_string(),
            status,
            #[allow(clippy::cast_precision_loss)]
            message: format!(
                "Disk free: {:.1} GB / {:.1} GB on the volume containing the HF cache at {}.",
                avail as f64 / 1e9,
                total as f64 / 1e9,
                hf_cache_path.display()
            ),
            suggestion: if avail < min_free {
                Some("Free up disk space or move HF cache.".to_string())
            } else {
                None
            },
        });
    }

    // CUDA build with no visible devices usually indicates a driver problem.
    let has_cuda = system.devices.iter().any(|d| d.kind == "cuda");

    if system.build.cuda && !has_cuda {
        checks.push(DoctorCheck {
            name: "cuda_devices".to_string(),
            status: DoctorStatus::Warn,
            message: "CUDA support is enabled but no CUDA devices were found.".to_string(),
            suggestion: Some("Check NVIDIA driver installation.".to_string()),
        });
    }

    DoctorReport { system, checks }
}