Skip to main content

sapient_generate/
device.rs

//! Device capability detection and backend recommendation.
//!
//! Detects CPUs, GPUs, and memory on macOS (Apple Silicon, Intel), Windows
//! (display adapters via `wmic`), and Linux (`nvidia-smi`, `lspci`, `/proc`).
//! Recommends the fastest backend for a given model size and reports
//! estimated throughput.
6
7use std::process::Command;
8
9// ── Public types ──────────────────────────────────────────────────────────────
10
11/// Full picture of the host device's compute capabilities.
12#[derive(Debug, Clone)]
13pub struct DeviceProfile {
14    pub cpu: CpuInfo,
15    pub gpus: Vec<GpuInfo>,
16    pub ram_bytes: u64,
17    /// True when CPU and GPU share the same physical memory (Apple Silicon UMA).
18    pub unified_memory: bool,
19}
20
/// Description of the host CPU.
///
/// Derives `PartialEq` so values can be compared directly (for example in
/// tests). `Eq` is not derivable because `bandwidth_gbps` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct CpuInfo {
    /// Marketing / brand name of the CPU.
    pub name: String,
    /// Number of logical cores.
    pub logical_cores: usize,
    /// Apple Silicon performance cores (0 on non-Apple or unknown).
    pub performance_cores: usize,
    /// Apple Silicon efficiency cores (0 on non-Apple or unknown).
    pub efficiency_cores: usize,
    /// Rough memory bandwidth estimate in GB/s (0 = unknown).
    pub bandwidth_gbps: f64,
}
32
33#[derive(Debug, Clone)]
34pub struct GpuInfo {
35    pub name: String,
36    /// Dedicated VRAM in bytes. None = shared/unified memory (Apple Silicon).
37    pub vram_bytes: Option<u64>,
38    pub apis: Vec<ComputeApi>,
39}
40
/// A GPU compute API that an adapter may expose.
///
/// The enum carries no data, so it is `Copy`; `Hash` lets it be used as a
/// set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeApi {
    Metal,
    Cuda,
    Vulkan,
    DirectX12,
    OpenCL,
}
49
/// The backend(s) to use for a given model.
///
/// Derives `PartialEq`/`Eq` so a recommended plan can be compared against an
/// expected one (for example `assert_eq!(plan, BackendPlan::Metal)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendPlan {
    /// Run everything on the CPU using NEON/AVX2 kernels.
    Cpu,
    /// Run everything on the Metal GPU via MLX (Apple Silicon only).
    Metal,
    /// Run the first `gpu_layers` transformer layers on Metal, the rest on CPU.
    /// Only used when the model doesn't fit entirely in the Metal memory budget.
    MetalCpuSplit { gpu_layers: usize, total_layers: usize },
    /// CUDA backend (future — not yet implemented).
    Cuda,
}
63
64impl BackendPlan {
65    /// Human-readable label for display.
66    pub fn label(&self) -> String {
67        match self {
68            Self::Cpu => "CPU (NEON/AVX2)".into(),
69            Self::Metal => "Metal GPU (full model)".into(),
70            Self::MetalCpuSplit { gpu_layers, total_layers } => {
71                format!("Metal+CPU hybrid  ({gpu_layers}/{total_layers} layers on GPU)")
72            }
73            Self::Cuda => "CUDA GPU".into(),
74        }
75    }
76
77    /// Rough tok/s estimate for a model with `model_bytes` weights.
78    pub fn estimated_tps(&self, model_bytes: u64, profile: &DeviceProfile) -> f64 {
79        // Very rough bandwidth-based estimate: tok/s ≈ bandwidth / bytes_per_token
80        // Weight bytes per decode step ≈ model_bytes (reading all weights once).
81        let bw_gbps = match self {
82            Self::Cpu => profile.cpu.bandwidth_gbps,
83            Self::Metal => {
84                // Apple Silicon memory bandwidth (much higher via GPU fabric)
85                // M1: 68, M1 Pro: 200, M1 Max: 400, M2: 100, M2 Pro: 200, M3/M4: similar
86                // We use RAM as a proxy: more RAM → higher-tier chip → higher bandwidth
87                let gb = profile.ram_bytes as f64 / 1e9;
88                if gb >= 96.0 {
89                    400.0
90                } else if gb >= 36.0 {
91                    300.0
92                } else if gb >= 24.0 {
93                    200.0
94                } else {
95                    100.0
96                }
97            }
98            Self::MetalCpuSplit { gpu_layers, total_layers } => {
99                let gpu_frac = *gpu_layers as f64 / *total_layers as f64;
100                let metal_bw = 150.0f64.min(profile.ram_bytes as f64 / 1e9 * 8.0);
101                let cpu_bw = profile.cpu.bandwidth_gbps;
102                gpu_frac * metal_bw + (1.0 - gpu_frac) * cpu_bw
103            }
104            Self::Cuda => {
105                profile.gpus.iter()
106                    .find(|g| g.apis.contains(&ComputeApi::Cuda))
107                    .and_then(|g| g.vram_bytes)
108                    .map(|v| (v as f64 / 1e9) * 50.0)  // rough: 50 tok/s per GB VRAM
109                    .unwrap_or(50.0)
110            }
111        };
112
113        if bw_gbps <= 0.0 || model_bytes == 0 {
114            return 0.0;
115        }
116        (bw_gbps * 1e9 / model_bytes as f64).max(0.1)
117    }
118}
119
120// ── Detection ─────────────────────────────────────────────────────────────────
121
122/// Detect all device capabilities on the current host.
123pub fn detect() -> DeviceProfile {
124    let cpu = detect_cpu();
125    let gpus = detect_gpus();
126    let ram_bytes = detect_ram();
127    let unified_memory = is_unified_memory();
128    DeviceProfile { cpu, gpus, ram_bytes, unified_memory }
129}
130
131/// Recommend the best backend plan for a model of the given size.
132/// `total_layers` is the transformer depth (used for the hybrid split).
133pub fn recommend(profile: &DeviceProfile, model_bytes: u64, total_layers: usize) -> BackendPlan {
134    // ── Metal (Apple Silicon) ─────────────────────────────────────────────────
135    let has_metal = profile.gpus.iter().any(|g| g.apis.contains(&ComputeApi::Metal));
136    if has_metal && profile.unified_memory {
137        // Reserve 2 GB for OS, require 1.5× KV-cache headroom.
138        let budget = profile.ram_bytes.saturating_sub(2 * 1024 * 1024 * 1024);
139        let needed = (model_bytes as f64 * 1.5) as u64;
140
141        if needed <= budget {
142            return BackendPlan::Metal;
143        }
144
145        // Partial fit: calculate how many layers we can run on Metal.
146        if total_layers > 0 && model_bytes < profile.ram_bytes {
147            let bytes_per_layer = model_bytes / total_layers as u64;
148            let gpu_layers =
149                ((budget as f64 / (bytes_per_layer as f64 * 1.5)) as usize).min(total_layers);
150            if gpu_layers > total_layers / 4 {
151                return BackendPlan::MetalCpuSplit { gpu_layers, total_layers };
152            }
153        }
154    }
155
156    // ── CUDA (NVIDIA) ─────────────────────────────────────────────────────────
157    let cuda_gpu = profile.gpus.iter().find(|g| g.apis.contains(&ComputeApi::Cuda));
158    if let Some(gpu) = cuda_gpu {
159        if let Some(vram) = gpu.vram_bytes {
160            if model_bytes < vram * 9 / 10 {
161                return BackendPlan::Cuda;
162            }
163        }
164    }
165
166    BackendPlan::Cpu
167}
168
169// ── OS-specific detection helpers ─────────────────────────────────────────────
170
171fn detect_cpu() -> CpuInfo {
172    let name = cpu_name();
173    let logical_cores = logical_cpu_count();
174    let (perf_cores, eff_cores) = apple_core_split();
175    let bandwidth_gbps = estimate_cpu_bandwidth_gbps(logical_cores, &name);
176
177    CpuInfo { name, logical_cores, performance_cores: perf_cores, efficiency_cores: eff_cores, bandwidth_gbps }
178}
179
/// Best-effort CPU name for the current platform.
///
/// Falls back to `"Unknown CPU"` when no platform-specific source yields a name.
fn cpu_name() -> String {
    #[cfg(target_os = "macos")]
    {
        // Prefer the brand string; fall back to `hw.model` when it is unavailable.
        let found = sysctl_str("machdep.cpu.brand_string").or_else(|| sysctl_str("hw.model"));
        if let Some(name) = found {
            return name;
        }
    }
    #[cfg(target_os = "windows")]
    {
        let found = wmic_query("cpu get name /value")
            .and_then(|raw| parse_wmic_value(&raw, "Name"));
        if let Some(name) = found {
            return name;
        }
    }
    #[cfg(target_os = "linux")]
    {
        if let Ok(cpuinfo) = std::fs::read_to_string("/proc/cpuinfo") {
            // First `model name : <value>` line wins.
            let found = cpuinfo
                .lines()
                .filter_map(|line| line.strip_prefix("model name"))
                .find_map(|rest| rest.split(':').nth(1));
            if let Some(model) = found {
                return model.trim().to_string();
            }
        }
    }
    "Unknown CPU".to_string()
}
213
/// Number of logical CPUs on the host.
///
/// Uses `hw.logicalcpu` on macOS and `NUMBER_OF_PROCESSORS` on Windows, then
/// falls back to `std::thread::available_parallelism`, and finally to 4.
fn logical_cpu_count() -> usize {
    #[cfg(target_os = "macos")]
    {
        if let Some(count) = sysctl_u64("hw.logicalcpu") {
            return count as usize;
        }
    }
    #[cfg(target_os = "windows")]
    {
        let from_env = std::env::var("NUMBER_OF_PROCESSORS")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok());
        if let Some(count) = from_env {
            return count;
        }
    }
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 4,
    }
}
233
/// Performance/efficiency core counts as `(performance, efficiency)`.
///
/// Reads `hw.perflevel0.logicalcpu` and `hw.perflevel1.logicalcpu` on macOS.
/// Yields `(0, 0)` on other platforms or when neither value could be read.
fn apple_core_split() -> (usize, usize) {
    #[cfg(target_os = "macos")]
    {
        let read = |key: &str| sysctl_u64(key).map_or(0, |n| n as usize);
        let perf = read("hw.perflevel0.logicalcpu");
        let eff = read("hw.perflevel1.logicalcpu");
        if perf + eff != 0 {
            return (perf, eff);
        }
    }
    (0, 0)
}
247
/// Rough memory-bandwidth estimate in GB/s derived from the CPU name.
///
/// * Apple Silicon: tier taken from the chip name — base ≈ 100, Pro ≈ 200,
///   Max ≈ 400, Ultra ≈ 800.
/// * Everything else: 5 GB/s per core for up to 8 cores, floored at 20
///   (so the result is between 20 and 40).
///
/// Matching is done on whole alphanumeric words rather than substrings, so
/// e.g. "Intel(R) Core(TM) m3-7Y30" is not mistaken for an Apple M3 and
/// "processor" is not mistaken for the "Pro" tier.
fn estimate_cpu_bandwidth_gbps(cores: usize, name: &str) -> f64 {
    let lowered = name.to_lowercase();
    let has_word = |wanted: &str| {
        lowered
            .split(|c: char| !c.is_alphanumeric())
            .any(|word| word == wanted)
    };

    // An explicit x86 vendor always wins over an "m<N>" look-alike.
    let x86_vendor = has_word("intel") || has_word("amd");
    let apple_silicon = !x86_vendor
        && (lowered.contains("apple")
            || ["m1", "m2", "m3", "m4"].iter().any(|&chip| has_word(chip)));

    if apple_silicon {
        return if has_word("ultra") {
            800.0
        } else if has_word("max") {
            400.0
        } else if has_word("pro") {
            200.0
        } else {
            100.0
        };
    }

    // Non-Apple: rough per-core figure, capped at 8 cores.
    (cores.min(8) as f64 * 5.0).max(20.0)
}
268
269fn detect_gpus() -> Vec<GpuInfo> {
270    let mut gpus = Vec::new();
271
272    // ── macOS: Metal via system_profiler + sysctl ─────────────────────────────
273    #[cfg(target_os = "macos")]
274    {
275        let metal_available = {
276            #[cfg(all(target_os = "macos", feature = "mlx"))]
277            { true }
278            #[cfg(not(all(target_os = "macos", feature = "mlx")))]
279            {
280                // Check if we're on Apple Silicon (Metal always available there)
281                cfg!(target_arch = "aarch64")
282            }
283        };
284
285        if metal_available {
286            let name = macos_gpu_name();
287            gpus.push(GpuInfo {
288                name,
289                vram_bytes: None, // Apple UMA — shared with CPU RAM
290                apis: vec![ComputeApi::Metal],
291            });
292        }
293
294        // Check for NVIDIA eGPU (rare but possible)
295        if let Some(nvidia) = detect_nvidia_macos() {
296            gpus.push(nvidia);
297        }
298    }
299
300    // ── Windows: DXGI / wmic ──────────────────────────────────────────────────
301    #[cfg(target_os = "windows")]
302    {
303        for gpu in detect_windows_gpus() {
304            gpus.push(gpu);
305        }
306    }
307
308    // ── Linux: detect NVIDIA/AMD via /proc or nvidia-smi ─────────────────────
309    #[cfg(target_os = "linux")]
310    {
311        for gpu in detect_linux_gpus() {
312            gpus.push(gpu);
313        }
314    }
315
316    gpus
317}
318
/// Name of the built-in GPU on macOS.
///
/// Parses the first `Chipset Model:` (or `GPU:`) line printed by
/// `system_profiler SPDisplaysDataType`; falls back to the `hw.model` sysctl
/// and finally to `"Apple Silicon GPU"`.
///
/// Compiled on macOS only: it calls `sysctl_str`, which exists only on that
/// target, so leaving this function ungated breaks builds for every other OS.
#[cfg(target_os = "macos")]
fn macos_gpu_name() -> String {
    let out = Command::new("system_profiler")
        .args(["SPDisplaysDataType"])
        .output();
    if let Ok(out) = out {
        let text = String::from_utf8_lossy(&out.stdout);
        for line in text.lines() {
            let trimmed = line.trim();
            if let Some(val) = trimmed
                .strip_prefix("Chipset Model:")
                .or_else(|| trimmed.strip_prefix("GPU:"))
            {
                return val.trim().to_string();
            }
        }
    }
    // Fallback: derive from chip model
    sysctl_str("hw.model").unwrap_or_else(|| "Apple Silicon GPU".to_string())
}
336
/// Detect an NVIDIA (e)GPU on macOS by querying `nvidia-smi`.
///
/// Returns `None` when `nvidia-smi` cannot be started, exits unsuccessfully,
/// or prints no usable GPU line. Checking the exit status matters: a failing
/// `nvidia-smi` run would otherwise have its output parsed as a GPU name.
///
/// Only the first reported GPU is returned. VRAM is parsed from the
/// `"<n> MiB"` column and is `None` when that column is missing or malformed.
#[cfg(target_os = "macos")]
fn detect_nvidia_macos() -> Option<GpuInfo> {
    let out = Command::new("nvidia-smi")
        .args(["--query-gpu=name,memory.total", "--format=csv,noheader"])
        .output()
        .ok()?;
    if !out.status.success() {
        return None;
    }

    let text = String::from_utf8_lossy(&out.stdout);
    let line = text.lines().map(str::trim).find(|l| !l.is_empty())?;

    let (name, memory) = match line.split_once(',') {
        Some((name, memory)) => (name.trim(), Some(memory.trim())),
        None => (line, None),
    };
    if name.is_empty() {
        return None;
    }

    let vram_bytes = memory
        .and_then(|m| m.strip_suffix("MiB"))
        .and_then(|m| m.trim().parse::<u64>().ok())
        .map(|mib| mib * 1024 * 1024);

    Some(GpuInfo {
        name: name.to_string(),
        vram_bytes,
        apis: vec![ComputeApi::Cuda],
    })
}
354
/// Enumerate display adapters on Windows via `wmic`.
///
/// Columns are located through the CSV header row rather than by fixed
/// position: wmic's CSV output does not follow the order in which properties
/// were requested (it emits `Node` first and then the properties
/// alphabetically), so a hard-coded index reads the wrong column.
///
/// CUDA is only advertised when `nvidia-smi` starts *and* exits successfully;
/// that probe runs at most once, and only if an NVIDIA adapter is present.
///
/// NOTE: `AdapterRAM` is a 32-bit WMI property, so adapters with 4 GiB or
/// more of VRAM are under-reported here.
#[cfg(target_os = "windows")]
fn detect_windows_gpus() -> Vec<GpuInfo> {
    let mut gpus = Vec::new();

    // Query all display adapters via wmic
    let out = match Command::new("wmic")
        .args(["path", "win32_VideoController", "get",
               "Name,AdapterRAM,DriverVersion", "/format:csv"])
        .output()
    {
        Ok(out) => out,
        Err(_) => return gpus,
    };

    let text = String::from_utf8_lossy(&out.stdout);
    // wmic pads its output with blank lines and stray carriage returns.
    let mut rows = text.lines().map(str::trim).filter(|l| !l.is_empty());

    let header: Vec<&str> = match rows.next() {
        Some(h) => h.split(',').map(str::trim).collect(),
        None => return gpus,
    };
    let column = |wanted: &str| header.iter().position(|h| h.eq_ignore_ascii_case(wanted));
    let name_col = match column("Name") {
        Some(i) => i,
        None => return gpus,
    };
    let ram_col = column("AdapterRAM");
    // When Name is the last column, let it absorb any commas in the value.
    let name_is_last = name_col + 1 == header.len();

    // Result of the nvidia-smi probe, filled in lazily.
    let mut cuda_usable: Option<bool> = None;

    for line in rows {
        let cols: Vec<&str> = if name_is_last {
            line.splitn(header.len(), ',').collect()
        } else {
            line.split(',').collect()
        };

        let name = match cols.get(name_col).map(|n| n.trim()) {
            Some(n) if !n.is_empty() => n.to_string(),
            _ => continue,
        };

        let vram = ram_col
            .and_then(|i| cols.get(i))
            .and_then(|v| v.trim().parse::<u64>().ok())
            .filter(|&v| v > 0);

        let name_lower = name.to_lowercase();
        let mut apis = Vec::new();
        if name_lower.contains("nvidia") {
            let usable = *cuda_usable.get_or_insert_with(|| {
                Command::new("nvidia-smi")
                    .output()
                    .map(|o| o.status.success())
                    .unwrap_or(false)
            });
            if usable {
                apis.push(ComputeApi::Cuda);
            }
            apis.push(ComputeApi::Vulkan);
            apis.push(ComputeApi::DirectX12);
        } else if name_lower.contains("amd") || name_lower.contains("radeon") {
            apis.push(ComputeApi::Vulkan);
            apis.push(ComputeApi::DirectX12);
        } else if name_lower.contains("intel") {
            apis.push(ComputeApi::DirectX12);
            apis.push(ComputeApi::Vulkan);
        }

        // Adapters from unrecognised vendors are skipped.
        if !apis.is_empty() {
            gpus.push(GpuInfo { name, vram_bytes: vram, apis });
        }
    }
    gpus
}
401
402#[cfg(target_os = "linux")]
403fn detect_linux_gpus() -> Vec<GpuInfo> {
404    let mut gpus = Vec::new();
405
406    // NVIDIA via nvidia-smi
407    let out = Command::new("nvidia-smi")
408        .args(["--query-gpu=name,memory.total", "--format=csv,noheader,nounits"])
409        .output();
410    if let Ok(out) = out {
411        let text = String::from_utf8_lossy(&out.stdout);
412        for line in text.lines() {
413            let parts: Vec<&str> = line.splitn(2, ',').collect();
414            let name = parts.first().map(|s| s.trim().to_string()).unwrap_or_default();
415            if name.is_empty() { continue; }
416            let vram = parts.get(1)
417                .and_then(|v| v.trim().parse::<u64>().ok())
418                .map(|mb| mb * 1024 * 1024);
419            gpus.push(GpuInfo {
420                name,
421                vram_bytes: vram,
422                apis: vec![ComputeApi::Cuda, ComputeApi::Vulkan],
423            });
424        }
425    }
426
427    // AMD/Intel via lspci (fallback)
428    if gpus.is_empty() {
429        let out = Command::new("lspci").output();
430        if let Ok(out) = out {
431            let text = String::from_utf8_lossy(&out.stdout);
432            for line in text.lines() {
433                let low = line.to_lowercase();
434                if low.contains("vga") || low.contains("3d controller") || low.contains("display") {
435                    let name = line.splitn(2, ':').last().unwrap_or(line).trim().to_string();
436                    let mut apis = vec![ComputeApi::Vulkan];
437                    if low.contains("nvidia") { apis.push(ComputeApi::Cuda); }
438                    gpus.push(GpuInfo { name, vram_bytes: None, apis });
439                }
440            }
441        }
442    }
443
444    gpus
445}
446
/// Total system RAM in bytes, or 0 when it cannot be determined.
///
/// Sources: `hw.memsize` (macOS), the `MemTotal:` line of `/proc/meminfo`
/// (Linux, reported in kB), `TotalVisibleMemorySize` via wmic (Windows, kB).
fn detect_ram() -> u64 {
    #[cfg(target_os = "macos")]
    {
        if let Some(bytes) = sysctl_u64("hw.memsize") {
            return bytes;
        }
    }
    #[cfg(target_os = "linux")]
    {
        if let Ok(meminfo) = std::fs::read_to_string("/proc/meminfo") {
            let total_kb = meminfo
                .lines()
                .filter_map(|line| line.strip_prefix("MemTotal:"))
                .find_map(|rest| {
                    rest.trim().trim_end_matches(" kB").trim().parse::<u64>().ok()
                });
            if let Some(kb) = total_kb {
                return kb * 1024;
            }
        }
    }
    #[cfg(target_os = "windows")]
    {
        let total_kb = wmic_query("OS get TotalVisibleMemorySize /value")
            .and_then(|raw| parse_wmic_value(&raw, "TotalVisibleMemorySize"))
            .and_then(|kb| kb.parse::<u64>().ok());
        if let Some(kb) = total_kb {
            return kb * 1024;
        }
    }
    0
}
472
/// Whether CPU and GPU share one physical memory pool.
///
/// Decided at compile time: true exactly for macOS builds targeting aarch64
/// (Apple Silicon), false for every other target.
fn is_unified_memory() -> bool {
    cfg!(all(target_os = "macos", target_arch = "aarch64"))
}
480
481// ── Utility helpers ───────────────────────────────────────────────────────────
482
/// Run `sysctl -n <key>` and return its trimmed standard output.
///
/// Yields `None` when the command cannot be started, the output is not valid
/// UTF-8, or the trimmed output is empty.
#[cfg(target_os = "macos")]
fn sysctl_str(key: &str) -> Option<String> {
    let output = Command::new("sysctl").arg("-n").arg(key).output().ok()?;
    let value = std::str::from_utf8(&output.stdout).ok()?.trim();
    (!value.is_empty()).then(|| value.to_string())
}
489
/// Read a sysctl value via `sysctl -n <key>` and parse it as a `u64`.
///
/// Yields `None` when the command cannot be run, the output is not valid
/// UTF-8, or the trimmed output is not an unsigned integer.
#[cfg(any(target_os = "macos", target_os = "linux"))]
fn sysctl_u64(key: &str) -> Option<u64> {
    #[cfg(target_os = "macos")]
    let raw = sysctl_str(key)?;

    #[cfg(target_os = "linux")]
    let raw = {
        let output = Command::new("sysctl").args(["-n", key]).output().ok()?;
        std::str::from_utf8(&output.stdout).ok()?.trim().to_string()
    };

    raw.parse().ok()
}
502
/// Run `wmic` with the whitespace-separated `args` and return its standard
/// output, decoded lossily as UTF-8.
///
/// Yields `None` only when the process cannot be started.
#[cfg(target_os = "windows")]
fn wmic_query(args: &str) -> Option<String> {
    Command::new("wmic")
        .args(args.split_whitespace())
        .output()
        .ok()
        .map(|out| String::from_utf8_lossy(&out.stdout).into_owned())
}
509
/// Find the first non-empty `key=value` entry in wmic `/value` output.
///
/// A line matches when it starts with `"<key>="`; the value is trimmed, and
/// lines whose value is empty after trimming are skipped.
#[cfg(target_os = "windows")]
fn parse_wmic_value(text: &str, key: &str) -> Option<String> {
    let prefix = format!("{key}=");
    text.lines()
        .filter_map(|line| line.strip_prefix(prefix.as_str()))
        .map(str::trim)
        .find(|value| !value.is_empty())
        .map(str::to_string)
}
520
521// ── Formatting ────────────────────────────────────────────────────────────────
522
523impl DeviceProfile {
524    /// Render a human-readable device report.
525    pub fn report(&self) -> String {
526        let mut out = String::new();
527        let gb = |b: u64| b as f64 / 1e9;
528
529        // CPU
530        out.push_str(&format!(
531            "  CPU   {}\n",
532            self.cpu.name
533        ));
534        if self.cpu.performance_cores > 0 {
535            out.push_str(&format!(
536                "        {} cores  ({} performance + {} efficiency)\n",
537                self.cpu.logical_cores,
538                self.cpu.performance_cores,
539                self.cpu.efficiency_cores
540            ));
541        } else {
542            out.push_str(&format!("        {} logical cores\n", self.cpu.logical_cores));
543        }
544        if self.ram_bytes > 0 {
545            out.push_str(&format!(
546                "        {:.0} GB RAM{}  ·  est. {:.0} GB/s bandwidth\n",
547                gb(self.ram_bytes),
548                if self.unified_memory { " unified" } else { "" },
549                self.cpu.bandwidth_gbps
550            ));
551        }
552
553        // GPUs
554        for gpu in &self.gpus {
555            let apis: Vec<&str> = gpu.apis.iter().map(|a| match a {
556                ComputeApi::Metal    => "Metal",
557                ComputeApi::Cuda     => "CUDA",
558                ComputeApi::Vulkan   => "Vulkan",
559                ComputeApi::DirectX12 => "DX12",
560                ComputeApi::OpenCL   => "OpenCL",
561            }).collect();
562            let vram_str = match gpu.vram_bytes {
563                Some(b) => format!("  {:.0} GB VRAM", gb(b)),
564                None    => "  shared memory".to_string(),
565            };
566            out.push_str(&format!(
567                "\n  GPU   {}\n        {}  ·  [{}]\n",
568                gpu.name,
569                vram_str.trim(),
570                apis.join(", ")
571            ));
572        }
573
574        out
575    }
576
577    /// Render model-size-specific backend recommendations.
578    pub fn recommendations(&self) -> String {
579        let mut out = String::new();
580        out.push_str("\n  Recommended backends:\n");
581
582        let scenarios = [
583            ("0.5B Q8",  0.5e9 * 1.06, 24),
584            ("1.5B Q8",  1.5e9 * 1.06, 28),
585            ("3B Q4",    3.0e9 * 0.5625, 32),
586            ("7B Q4",    7.0e9 * 0.5625, 32),
587            ("14B Q4",  14.0e9 * 0.5625, 40),
588            ("32B Q4",  32.0e9 * 0.5625, 64),
589        ];
590
591        for (label, bytes, layers) in scenarios {
592            let model_bytes = bytes as u64;
593            let plan = recommend(self, model_bytes, layers);
594            let tps = plan.estimated_tps(model_bytes, self);
595            let tps_str = if tps > 0.5 {
596                format!("~{:.0} tok/s", tps)
597            } else {
598                "slow (model too large)".to_string()
599            };
600            out.push_str(&format!(
601                "    {label:<12}  →  {:<36}  {tps_str}\n",
602                plan.label()
603            ));
604        }
605        out
606    }
607}