// amaters_core/compute/gpu.rs
1//! GPU acceleration module for FHE operations
2//!
3//! This module provides GPU-accelerated FHE operations using CUDA and Metal backends.
4//! It automatically detects available GPU hardware and falls back to CPU when needed.
5//!
6//! # Architecture
7//!
8//! - **CUDA Backend**: NVIDIA GPU acceleration on Linux/Windows (via tfhe-cuda-backend)
9//! - **Metal Backend**: Apple GPU acceleration on macOS (custom implementation)
10//! - **CPU Fallback**: Automatic fallback when GPU is unavailable
11//!
12//! # Example
13//!
14//! ```rust,ignore
15//! use amaters_core::compute::gpu::{GpuExecutor, GpuBackend};
16//!
17//! let executor = GpuExecutor::new()?;
18//! let backend = executor.backend();
19//! println!("Using backend: {:?}", backend);
20//! ```
21
22use crate::error::{AmateRSError, ErrorContext, Result};
23use parking_lot::RwLock;
24use std::sync::Arc;
25
26#[cfg(feature = "compute")]
27use crate::compute::operations::{
28    EncryptedBool, EncryptedU8, EncryptedU16, EncryptedU32, EncryptedU64,
29};
30
/// GPU backend type
///
/// Identifies which execution backend a [`GpuExecutor`] is using. The
/// backend is chosen at construction time, falling back to `Cpu` when no
/// GPU backend is compiled in or detected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// NVIDIA CUDA backend (Linux, Windows)
    Cuda,
    /// Apple Metal backend (macOS)
    Metal,
    /// CPU fallback (all platforms)
    Cpu,
}
41
42impl std::fmt::Display for GpuBackend {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        match self {
45            GpuBackend::Cuda => write!(f, "CUDA"),
46            GpuBackend::Metal => write!(f, "Metal"),
47            GpuBackend::Cpu => write!(f, "CPU"),
48        }
49    }
50}
51
/// GPU configuration options
///
/// Consumed by [`GpuExecutor::with_config`]; `GpuConfig::default()`
/// requests auto-detection with batching enabled.
#[derive(Debug, Clone)]
pub struct GpuConfig {
    /// Preferred backend (None = auto-detect)
    pub preferred_backend: Option<GpuBackend>,
    /// Device ID for multi-GPU systems
    pub device_id: usize,
    /// Enable batch processing
    pub enable_batch: bool,
    /// Batch size for operations
    pub batch_size: usize,
    /// Memory pool size in bytes (0 = auto)
    pub memory_pool_size: usize,
}
66
impl Default for GpuConfig {
    /// Auto-detected backend, first device, batching enabled in chunks of
    /// 64, auto-sized memory pool.
    fn default() -> Self {
        Self {
            preferred_backend: None, // auto-detect best available backend
            device_id: 0,
            enable_batch: true,
            batch_size: 64,
            memory_pool_size: 0, // 0 = let the backend size the pool
        }
    }
}
78
/// GPU device information
///
/// Memory sizes and compute-unit counts are reported as 0 when the
/// detection path could not determine them (e.g. CPU backend, sysfs
/// fallback, placeholder devices).
#[derive(Debug, Clone)]
pub struct GpuDeviceInfo {
    /// Backend type
    pub backend: GpuBackend,
    /// Device name
    pub name: String,
    /// Compute capability (CUDA) or Metal version
    pub compute_capability: String,
    /// Total memory in bytes
    pub total_memory: u64,
    /// Available memory in bytes
    pub available_memory: u64,
    /// Number of compute units
    pub compute_units: u32,
}
95
/// GPU executor for FHE operations
///
/// Manages GPU resources and executes FHE operations on the selected backend.
/// Automatically handles memory management, batch processing, and fallback to CPU.
#[derive(Clone)]
pub struct GpuExecutor {
    // Selected at construction (auto-detected or taken from the config).
    backend: GpuBackend,
    config: GpuConfig,
    device_info: Option<GpuDeviceInfo>,
    // Backend contexts exist only when the matching feature set is compiled
    // in; they stay `None` until `initialize_backend` runs.
    #[cfg(all(feature = "cuda", feature = "compute"))]
    cuda_context: Option<Arc<RwLock<cuda::CudaContext>>>,
    #[cfg(all(feature = "metal", feature = "compute"))]
    metal_context: Option<Arc<RwLock<metal::MetalContext>>>,
}
110
111impl GpuExecutor {
    /// Create a new GPU executor with default configuration
    ///
    /// Automatically detects available GPU hardware and selects the best backend.
    ///
    /// # Errors
    ///
    /// Propagates any error from device-info retrieval or backend
    /// initialization (see [`Self::with_config`]).
    pub fn new() -> Result<Self> {
        Self::with_config(GpuConfig::default())
    }
118
    /// Create a new GPU executor with custom configuration
    ///
    /// # Errors
    ///
    /// Fails if device information cannot be obtained for the chosen backend
    /// or if backend context initialization fails.
    pub fn with_config(config: GpuConfig) -> Result<Self> {
        let backend = if let Some(preferred) = config.preferred_backend {
            // Use preferred backend if specified
            preferred
        } else {
            // Auto-detect best available backend
            Self::detect_backend()?
        };

        let device_info = Self::get_device_info(backend, config.device_id)?;

        let mut executor = Self {
            backend,
            config,
            device_info: Some(device_info),
            #[cfg(all(feature = "cuda", feature = "compute"))]
            cuda_context: None,
            #[cfg(all(feature = "metal", feature = "compute"))]
            metal_context: None,
        };

        // Initialize backend-specific context
        executor.initialize_backend()?;

        Ok(executor)
    }
146
    /// Detect the best available GPU backend
    ///
    /// Preference order: CUDA, then Metal, then CPU. The CPU fallback always
    /// succeeds, so this currently never returns `Err`; the `Result` keeps
    /// the signature uniform with the other detection paths.
    fn detect_backend() -> Result<GpuBackend> {
        #[cfg(feature = "cuda")]
        {
            if Self::is_cuda_available() {
                return Ok(GpuBackend::Cuda);
            }
        }

        #[cfg(feature = "metal")]
        {
            if Self::is_metal_available() {
                return Ok(GpuBackend::Metal);
            }
        }

        // Fallback to CPU
        Ok(GpuBackend::Cpu)
    }
166
    /// Check if CUDA backend is available
    #[cfg(feature = "cuda")]
    fn is_cuda_available() -> bool {
        // Check for CUDA runtime and devices
        #[cfg(feature = "compute")]
        {
            cuda::detect_cuda_devices().is_ok()
        }
        // Without `compute` the `cuda` module is not compiled, so there is
        // nothing to probe.
        #[cfg(not(feature = "compute"))]
        {
            false
        }
    }

    /// Stub: always false when the `cuda` feature is disabled.
    #[cfg(not(feature = "cuda"))]
    fn is_cuda_available() -> bool {
        false
    }
185
    /// Check if Metal backend is available
    #[cfg(feature = "metal")]
    fn is_metal_available() -> bool {
        // Metal is only available on macOS
        #[cfg(target_os = "macos")]
        {
            #[cfg(feature = "compute")]
            {
                metal::detect_metal_devices().is_ok()
            }
            // Without `compute` the `metal` module is not compiled.
            #[cfg(not(feature = "compute"))]
            {
                false
            }
        }
        #[cfg(not(target_os = "macos"))]
        {
            false
        }
    }

    /// Stub: always false when the `metal` feature is disabled.
    #[cfg(not(feature = "metal"))]
    fn is_metal_available() -> bool {
        false
    }
211
    /// Get device information for the selected backend
    ///
    /// For the CPU backend, memory sizes are reported as 0 and the core
    /// count doubles as both `compute_capability` text and `compute_units`.
    ///
    /// # Errors
    ///
    /// Returns `FeatureNotEnabled` when a GPU backend was selected without
    /// the `compute` feature, or `Configuration` for a backend whose feature
    /// is not compiled in at all.
    fn get_device_info(backend: GpuBackend, device_id: usize) -> Result<GpuDeviceInfo> {
        match backend {
            // NOTE: unlike `execute_operation`, this function is not itself
            // gated on `compute`, so the nested cfg blocks here are required.
            #[cfg(feature = "cuda")]
            GpuBackend::Cuda => {
                #[cfg(feature = "compute")]
                {
                    cuda::get_device_info(device_id)
                }
                #[cfg(not(feature = "compute"))]
                {
                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
                        "CUDA backend requires 'compute' feature".to_string(),
                    )))
                }
            }

            #[cfg(feature = "metal")]
            GpuBackend::Metal => {
                #[cfg(feature = "compute")]
                {
                    metal::get_device_info(device_id)
                }
                #[cfg(not(feature = "compute"))]
                {
                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
                        "Metal backend requires 'compute' feature".to_string(),
                    )))
                }
            }

            GpuBackend::Cpu => {
                let cpus = std::thread::available_parallelism()
                    .map(|n| n.get())
                    .unwrap_or(1);
                Ok(GpuDeviceInfo {
                    backend: GpuBackend::Cpu,
                    name: "CPU".to_string(),
                    compute_capability: format!("{} cores", cpus),
                    total_memory: 0,
                    available_memory: 0,
                    compute_units: cpus as u32,
                })
            }

            // Reached only when a backend variant's feature is disabled.
            #[allow(unreachable_patterns)]
            _ => Err(AmateRSError::Configuration(ErrorContext::new(format!(
                "Backend {} is not available (feature not enabled)",
                backend
            )))),
        }
    }
264
    /// Initialize backend-specific context
    ///
    /// Populates `cuda_context`/`metal_context` for GPU backends; the CPU
    /// backend needs no state.
    ///
    /// # Errors
    ///
    /// Propagates context-creation failures, or returns `Configuration` if
    /// the selected backend's feature set is not compiled in.
    fn initialize_backend(&mut self) -> Result<()> {
        match self.backend {
            #[cfg(all(feature = "cuda", feature = "compute"))]
            GpuBackend::Cuda => {
                let context =
                    cuda::CudaContext::new(self.config.device_id, self.config.memory_pool_size)?;
                self.cuda_context = Some(Arc::new(RwLock::new(context)));
                Ok(())
            }

            #[cfg(all(feature = "metal", feature = "compute"))]
            GpuBackend::Metal => {
                let context =
                    metal::MetalContext::new(self.config.device_id, self.config.memory_pool_size)?;
                self.metal_context = Some(Arc::new(RwLock::new(context)));
                Ok(())
            }

            GpuBackend::Cpu => {
                // CPU backend doesn't need initialization
                Ok(())
            }

            #[allow(unreachable_patterns)]
            _ => Err(AmateRSError::Configuration(ErrorContext::new(format!(
                "Cannot initialize backend {} (feature not enabled)",
                self.backend
            )))),
        }
    }
296
    /// Get the current backend
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Get device information
    ///
    /// Always `Some` for executors built via `with_config`, which populates
    /// it before returning.
    pub fn device_info(&self) -> Option<&GpuDeviceInfo> {
        self.device_info.as_ref()
    }

    /// Get configuration
    pub fn config(&self) -> &GpuConfig {
        &self.config
    }

    /// Check if GPU acceleration is enabled (i.e. backend is not CPU)
    pub fn is_gpu_enabled(&self) -> bool {
        !matches!(self.backend, GpuBackend::Cpu)
    }
316
317    /// Execute FHE operation with GPU acceleration
318    ///
319    /// This method automatically routes the operation to the appropriate backend
320    /// and handles memory transfers between CPU and GPU.
321    #[cfg(feature = "compute")]
322    pub fn execute_operation<F, R>(&self, operation: F) -> Result<R>
323    where
324        F: FnOnce() -> Result<R> + Send,
325        R: Send,
326    {
327        match self.backend {
328            #[cfg(feature = "cuda")]
329            GpuBackend::Cuda => {
330                #[cfg(feature = "compute")]
331                {
332                    if let Some(context) = &self.cuda_context {
333                        let ctx = context.read();
334                        ctx.execute_operation(operation)
335                    } else {
336                        Err(AmateRSError::GpuError(ErrorContext::new(
337                            "CUDA context not initialized".to_string(),
338                        )))
339                    }
340                }
341                #[cfg(not(feature = "compute"))]
342                {
343                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
344                        "CUDA backend requires 'compute' feature".to_string(),
345                    )))
346                }
347            }
348
349            #[cfg(feature = "metal")]
350            GpuBackend::Metal => {
351                #[cfg(feature = "compute")]
352                {
353                    if let Some(context) = &self.metal_context {
354                        let ctx = context.read();
355                        ctx.execute_operation(operation)
356                    } else {
357                        Err(AmateRSError::GpuError(ErrorContext::new(
358                            "Metal context not initialized".to_string(),
359                        )))
360                    }
361                }
362                #[cfg(not(feature = "compute"))]
363                {
364                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
365                        "Metal backend requires 'compute' feature".to_string(),
366                    )))
367                }
368            }
369
370            GpuBackend::Cpu => {
371                // Execute on CPU directly
372                operation()
373            }
374
375            #[allow(unreachable_patterns)]
376            _ => Err(AmateRSError::Configuration(ErrorContext::new(format!(
377                "Backend {} is not available",
378                self.backend
379            )))),
380        }
381    }
382
    /// Stub implementation when compute feature is disabled
    ///
    /// Always fails: without the `compute` feature no FHE operation can run
    /// on any backend.
    #[cfg(not(feature = "compute"))]
    pub fn execute_operation<F, R>(&self, _operation: F) -> Result<R>
    where
        F: FnOnce() -> Result<R> + Send,
        R: Send,
    {
        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
            "FHE compute feature is not enabled".to_string(),
        )))
    }
394
395    /// Execute batch of FHE operations with GPU acceleration
396    #[cfg(feature = "compute")]
397    pub fn execute_batch<F, R>(&self, operations: Vec<F>) -> Result<Vec<R>>
398    where
399        F: FnOnce() -> Result<R> + Send,
400        R: Send,
401    {
402        if !self.config.enable_batch || operations.is_empty() {
403            return operations
404                .into_iter()
405                .map(|op| self.execute_operation(op))
406                .collect();
407        }
408
409        match self.backend {
410            #[cfg(feature = "cuda")]
411            GpuBackend::Cuda => {
412                #[cfg(feature = "compute")]
413                {
414                    if let Some(context) = &self.cuda_context {
415                        let ctx = context.read();
416                        ctx.execute_batch(operations, self.config.batch_size)
417                    } else {
418                        Err(AmateRSError::GpuError(ErrorContext::new(
419                            "CUDA context not initialized".to_string(),
420                        )))
421                    }
422                }
423                #[cfg(not(feature = "compute"))]
424                {
425                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
426                        "CUDA backend requires 'compute' feature".to_string(),
427                    )))
428                }
429            }
430
431            #[cfg(feature = "metal")]
432            GpuBackend::Metal => {
433                #[cfg(feature = "compute")]
434                {
435                    if let Some(context) = &self.metal_context {
436                        let ctx = context.read();
437                        ctx.execute_batch(operations, self.config.batch_size)
438                    } else {
439                        Err(AmateRSError::GpuError(ErrorContext::new(
440                            "Metal context not initialized".to_string(),
441                        )))
442                    }
443                }
444                #[cfg(not(feature = "compute"))]
445                {
446                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
447                        "Metal backend requires 'compute' feature".to_string(),
448                    )))
449                }
450            }
451
452            GpuBackend::Cpu => {
453                // Execute batch on CPU using rayon if available
454                #[cfg(feature = "parallel")]
455                {
456                    use rayon::prelude::*;
457                    operations.into_par_iter().map(|op| op()).collect()
458                }
459                #[cfg(not(feature = "parallel"))]
460                {
461                    operations.into_iter().map(|op| op()).collect()
462                }
463            }
464
465            #[allow(unreachable_patterns)]
466            _ => Err(AmateRSError::Configuration(ErrorContext::new(format!(
467                "Backend {} is not available",
468                self.backend
469            )))),
470        }
471    }
472
    /// Stub implementation when compute feature is disabled
    ///
    /// Always fails, mirroring the `execute_operation` stub.
    #[cfg(not(feature = "compute"))]
    pub fn execute_batch<F, R>(&self, _operations: Vec<F>) -> Result<Vec<R>>
    where
        F: FnOnce() -> Result<R> + Send,
        R: Send,
    {
        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
            "FHE compute feature is not enabled".to_string(),
        )))
    }
484}
485
impl Default for GpuExecutor {
    /// # Panics
    ///
    /// Panics if [`GpuExecutor::new`] fails (device-info retrieval or backend
    /// initialization error). Prefer `GpuExecutor::new()` when you need to
    /// handle that error.
    fn default() -> Self {
        Self::new().expect("Failed to create default GPU executor")
    }
}
491
impl std::fmt::Debug for GpuExecutor {
    // Manual impl: the cfg-gated backend context fields (CudaContext /
    // MetalContext do not derive Debug) are intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuExecutor")
            .field("backend", &self.backend)
            .field("config", &self.config)
            .field("device_info", &self.device_info)
            .finish()
    }
}
501
502/// GPU device detection module
503///
504/// Provides real hardware detection using platform-specific tools:
505/// - macOS: `system_profiler SPDisplaysDataType` and `sysctl hw.memsize`
506/// - Linux: `nvidia-smi` and sysfs fallbacks
507mod detection {
508    use super::{GpuBackend, GpuDeviceInfo};
509    use std::process::Command;
510
    /// Detect macOS GPU via system_profiler
    ///
    /// Returns `None` if the tool cannot be run, exits with failure, or its
    /// output contains no "Chipset Model" entry.
    #[cfg(target_os = "macos")]
    pub fn detect_macos_gpu() -> Option<GpuDeviceInfo> {
        let output = Command::new("system_profiler")
            .arg("SPDisplaysDataType")
            .output()
            .ok()?;
        if !output.status.success() {
            return None;
        }
        let text = String::from_utf8_lossy(&output.stdout);
        let mut info = parse_system_profiler(&text)?;

        // For Apple Silicon (unified memory), get total system memory via sysctl
        if info.name.starts_with("Apple") {
            if let Some(mem) = get_macos_system_memory() {
                info.total_memory = mem;
                // Estimate ~90% available (conservative)
                info.available_memory = mem * 9 / 10;
            }
        }

        Some(info)
    }
535
536    /// Get macOS system memory via `sysctl hw.memsize`
537    #[cfg(target_os = "macos")]
538    fn get_macos_system_memory() -> Option<u64> {
539        let output = Command::new("sysctl")
540            .arg("-n")
541            .arg("hw.memsize")
542            .output()
543            .ok()?;
544        if !output.status.success() {
545            return None;
546        }
547        let text = String::from_utf8_lossy(&output.stdout);
548        text.trim().parse::<u64>().ok()
549    }
550
    /// Detect NVIDIA GPU on Linux via nvidia-smi
    ///
    /// Queries all GPUs in one nvidia-smi call and returns the `device_id`-th
    /// parsed row; `None` when nvidia-smi is missing, fails, or lists fewer
    /// devices than `device_id + 1`.
    #[cfg(target_os = "linux")]
    pub fn detect_nvidia_gpu(device_id: usize) -> Option<GpuDeviceInfo> {
        let output = Command::new("nvidia-smi")
            .args([
                "--query-gpu=name,memory.total,memory.free",
                "--format=csv,noheader,nounits",
            ])
            .output()
            .ok()?;
        if !output.status.success() {
            return None;
        }
        let text = String::from_utf8_lossy(&output.stdout);
        let devices = parse_nvidia_smi(&text);
        devices.into_iter().nth(device_id)
    }
568
569    /// Detect NVIDIA GPU count on Linux via nvidia-smi
570    #[cfg(target_os = "linux")]
571    pub fn detect_nvidia_device_count() -> Option<usize> {
572        let output = Command::new("nvidia-smi")
573            .args(["--query-gpu=name", "--format=csv,noheader"])
574            .output()
575            .ok()?;
576        if !output.status.success() {
577            return None;
578        }
579        let text = String::from_utf8_lossy(&output.stdout);
580        let count = text.lines().filter(|l| !l.trim().is_empty()).count();
581        if count > 0 { Some(count) } else { None }
582    }
583
    /// Linux sysfs fallback: detect NVIDIA devices via /sys/class/drm
    ///
    /// Used when nvidia-smi is unavailable. Memory sizes and compute-unit
    /// counts cannot be read this way and are reported as 0.
    #[cfg(target_os = "linux")]
    pub fn detect_nvidia_sysfs() -> Vec<GpuDeviceInfo> {
        let mut devices = Vec::new();
        let drm_path = std::path::Path::new("/sys/class/drm");
        if !drm_path.exists() {
            return devices;
        }

        let entries = match std::fs::read_dir(drm_path) {
            Ok(e) => e,
            Err(_) => return devices,
        };

        for entry in entries.flatten() {
            let name = entry.file_name();
            let name_str = name.to_string_lossy();
            // Keep primary nodes like "card0"; skip connector entries such as
            // "card0-DP-1" (they contain '-').
            if !name_str.starts_with("card") || name_str.contains('-') {
                continue;
            }

            let vendor_path = entry.path().join("device/vendor");
            if let Ok(vendor) = std::fs::read_to_string(&vendor_path) {
                let vendor_trimmed = vendor.trim();
                // 0x10de = NVIDIA
                if vendor_trimmed == "0x10de" {
                    // Best-effort name from the proprietary driver's procfs.
                    let device_name = read_nvidia_proc_name(devices.len())
                        .unwrap_or_else(|| format!("NVIDIA GPU (card {})", name_str));

                    devices.push(GpuDeviceInfo {
                        backend: GpuBackend::Cuda,
                        name: device_name,
                        compute_capability: "unknown".to_string(),
                        total_memory: 0,
                        available_memory: 0,
                        compute_units: 0,
                    });
                }
            }
        }

        devices
    }
627
628    /// Try to read NVIDIA GPU name from /proc/driver/nvidia/gpus/*/information
629    #[cfg(target_os = "linux")]
630    fn read_nvidia_proc_name(index: usize) -> Option<String> {
631        let nvidia_path = std::path::Path::new("/proc/driver/nvidia/gpus");
632        if !nvidia_path.exists() {
633            return None;
634        }
635
636        let entries: Vec<_> = std::fs::read_dir(nvidia_path).ok()?.flatten().collect();
637
638        let entry = entries.get(index)?;
639        let info_path = entry.path().join("information");
640        let content = std::fs::read_to_string(info_path).ok()?;
641
642        for line in content.lines() {
643            if let Some(stripped) = line.strip_prefix("Model:") {
644                return Some(stripped.trim().to_string());
645            }
646        }
647
648        None
649    }
650
    /// Parse `system_profiler SPDisplaysDataType` output
    ///
    /// Extracts GPU name, compute units, and memory information from the
    /// macOS system_profiler output.
    ///
    /// NOTE(review): each matching line overwrites the previous value, so on
    /// multi-GPU machines the result mixes fields from the last occurrences —
    /// confirm whether per-GPU sectioning is needed.
    pub fn parse_system_profiler(text: &str) -> Option<GpuDeviceInfo> {
        let mut chipset_model: Option<String> = None;
        let mut total_cores: Option<u32> = None;
        let mut vram_bytes: Option<u64> = None;
        let mut is_apple_silicon = false;

        for line in text.lines() {
            let trimmed = line.trim();

            // Extract chipset model name
            if let Some(value) = trimmed.strip_prefix("Chipset Model:") {
                chipset_model = Some(value.trim().to_string());
                if value.trim().starts_with("Apple") {
                    is_apple_silicon = true;
                }
            }

            // Extract total number of GPU cores
            if let Some(value) = trimmed.strip_prefix("Total Number of Cores:") {
                total_cores = value.trim().parse::<u32>().ok();
            }

            // Extract VRAM (for discrete GPUs)
            if trimmed.starts_with("VRAM") {
                // Formats: "VRAM (Total): 8 GB", "VRAM (Dynamic, Max): 1536 MB"
                if let Some(colon_pos) = trimmed.find(':') {
                    let value_part = trimmed[colon_pos + 1..].trim();
                    vram_bytes = parse_memory_string(value_part);
                }
            }
        }

        // No "Chipset Model" line means nothing usable was parsed.
        let name = chipset_model?;

        let compute_capability = if is_apple_silicon {
            "Metal 3".to_string()
        } else if name.contains("Intel") {
            "Metal 2".to_string()
        } else {
            "Metal".to_string()
        };

        let compute_units = total_cores.unwrap_or(0);

        // For Apple Silicon, memory will be set later from sysctl
        // For discrete GPUs, use VRAM
        let total_memory = vram_bytes.unwrap_or(0);
        let available_memory = if total_memory > 0 {
            total_memory * 9 / 10
        } else {
            0
        };

        Some(GpuDeviceInfo {
            backend: GpuBackend::Metal,
            name,
            compute_capability,
            total_memory,
            available_memory,
            compute_units,
        })
    }
717
718    /// Parse a memory string like "8 GB", "1536 MB", "16384 MB" into bytes
719    fn parse_memory_string(s: &str) -> Option<u64> {
720        let s = s.trim();
721        let parts: Vec<&str> = s.split_whitespace().collect();
722        if parts.len() < 2 {
723            return None;
724        }
725
726        let value = parts[0].replace(',', "").parse::<u64>().ok()?;
727        let unit = parts[1].to_uppercase();
728
729        match unit.as_str() {
730            "GB" => Some(value * 1_073_741_824),
731            "MB" => Some(value * 1_048_576),
732            "KB" => Some(value * 1024),
733            "TB" => Some(value * 1_099_511_627_776),
734            _ => None,
735        }
736    }
737
738    /// Parse nvidia-smi CSV output
739    ///
740    /// Expected input format (from `--format=csv,noheader,nounits`):
741    /// ```text
742    /// NVIDIA GeForce RTX 4090, 24564, 23456
743    /// ```
744    ///
745    /// Each line: name, total_memory_mb, free_memory_mb
746    pub fn parse_nvidia_smi(text: &str) -> Vec<GpuDeviceInfo> {
747        let mut devices = Vec::new();
748
749        for line in text.lines() {
750            let trimmed = line.trim();
751            if trimmed.is_empty() {
752                continue;
753            }
754
755            let parts: Vec<&str> = trimmed.splitn(3, ',').collect();
756            if parts.len() < 3 {
757                continue;
758            }
759
760            let name = parts[0].trim().to_string();
761            let total_mb = match parts[1].trim().parse::<u64>() {
762                Ok(v) => v,
763                Err(_) => continue,
764            };
765            let free_mb = match parts[2].trim().parse::<u64>() {
766                Ok(v) => v,
767                Err(_) => continue,
768            };
769
770            // Convert MB to bytes
771            let total_memory = total_mb * 1_048_576;
772            let available_memory = free_mb * 1_048_576;
773
774            devices.push(GpuDeviceInfo {
775                backend: GpuBackend::Cuda,
776                name,
777                compute_capability: "unknown".to_string(),
778                total_memory,
779                available_memory,
780                compute_units: 0,
781            });
782        }
783
784        devices
785    }
786}
787
788/// CUDA backend implementation
789#[cfg(all(feature = "cuda", feature = "compute"))]
790mod cuda {
791    use super::*;
792
    /// CUDA context for GPU operations
    pub struct CudaContext {
        // Stored for the future full CUDA integration; not read yet in this
        // module (tfhe-cuda-backend manages the real context internally).
        device_id: usize,
        memory_pool_size: usize,
    }
798
799    impl CudaContext {
800        pub fn new(device_id: usize, memory_pool_size: usize) -> Result<Self> {
801            // Initialize CUDA context
802            // Note: tfhe-cuda-backend handles context initialization internally
803            Ok(Self {
804                device_id,
805                memory_pool_size,
806            })
807        }
808
809        pub fn execute_operation<F, R>(&self, operation: F) -> Result<R>
810        where
811            F: FnOnce() -> Result<R> + Send,
812            R: Send,
813        {
814            // CUDA operations are executed directly by tfhe-rs when GPU is enabled
815            // The tfhe-cuda-backend is automatically used when available
816            operation()
817        }
818
819        pub fn execute_batch<F, R>(&self, operations: Vec<F>, batch_size: usize) -> Result<Vec<R>>
820        where
821            F: FnOnce() -> Result<R> + Send,
822            R: Send,
823        {
824            // Process operations in batches, consuming the Vec
825            let mut results = Vec::with_capacity(operations.len());
826            let mut iter = operations.into_iter().peekable();
827
828            while iter.peek().is_some() {
829                let batch: Vec<F> = iter.by_ref().take(batch_size).collect();
830
831                #[cfg(feature = "parallel")]
832                {
833                    use rayon::prelude::*;
834                    let chunk_results: Result<Vec<_>> =
835                        batch.into_par_iter().map(|op| op()).collect();
836                    results.extend(chunk_results?);
837                }
838                #[cfg(not(feature = "parallel"))]
839                {
840                    for op in batch {
841                        results.push(op()?);
842                    }
843                }
844            }
845
846            Ok(results)
847        }
848    }
849
    /// Detect available CUDA devices
    ///
    /// On Linux, tries nvidia-smi first, then falls back to sysfs detection.
    /// On other platforms — and when both probes find nothing — it
    /// optimistically returns device 0, since the `cuda` feature was
    /// compiled in. Currently never returns `Err`.
    pub fn detect_cuda_devices() -> Result<Vec<usize>> {
        #[cfg(target_os = "linux")]
        {
            // Try nvidia-smi first
            if let Some(count) = detection::detect_nvidia_device_count() {
                return Ok((0..count).collect());
            }

            // Fallback to sysfs detection
            let sysfs_devices = detection::detect_nvidia_sysfs();
            if !sysfs_devices.is_empty() {
                return Ok((0..sysfs_devices.len()).collect());
            }
        }

        // Fallback: assume device 0 is available when cuda feature is enabled
        Ok(vec![0])
    }
871
    /// Get CUDA device information
    ///
    /// Attempts real detection via nvidia-smi on Linux, with sysfs fallback.
    /// Returns a placeholder if all detection methods fail, so this
    /// currently never returns `Err`.
    pub fn get_device_info(device_id: usize) -> Result<GpuDeviceInfo> {
        // Try real detection on Linux
        #[cfg(target_os = "linux")]
        {
            // Try nvidia-smi first
            if let Some(info) = detection::detect_nvidia_gpu(device_id) {
                return Ok(info);
            }

            // Try sysfs fallback
            let sysfs_devices = detection::detect_nvidia_sysfs();
            if let Some(info) = sysfs_devices.into_iter().nth(device_id) {
                return Ok(info);
            }
        }

        // Fallback to placeholder
        Ok(cuda_placeholder(device_id))
    }
895
    /// Generate a placeholder CUDA device info when detection fails
    ///
    /// Memory and compute-unit fields are 0 and the capability is
    /// "unknown" — consumers should treat these as "not detected".
    fn cuda_placeholder(device_id: usize) -> GpuDeviceInfo {
        GpuDeviceInfo {
            backend: GpuBackend::Cuda,
            name: format!("CUDA Device {}", device_id),
            compute_capability: "unknown".to_string(),
            total_memory: 0,
            available_memory: 0,
            compute_units: 0,
        }
    }
907}
908
/// Metal backend implementation
#[cfg(all(feature = "metal", feature = "compute", target_os = "macos"))]
mod metal {
    use super::*;

    /// Metal context for GPU operations.
    ///
    /// Currently a placeholder: operations execute on the CPU until the real
    /// Metal device/command-queue integration lands. The stored fields record
    /// the configuration the eventual Metal setup will need.
    pub struct MetalContext {
        /// Index of the Metal device this context targets.
        device_id: usize,
        /// Requested GPU memory pool size in bytes (0 = default/unlimited).
        memory_pool_size: usize,
    }

    impl MetalContext {
        /// Create a context for the given device with the requested pool size.
        ///
        /// This will eventually create the Metal device and command queue;
        /// for now it only records the configuration and cannot fail.
        pub fn new(device_id: usize, memory_pool_size: usize) -> Result<Self> {
            Ok(Self {
                device_id,
                memory_pool_size,
            })
        }

        /// Execute a single operation.
        ///
        /// Runs on the CPU for now; a real Metal dispatch will replace this.
        pub fn execute_operation<F, R>(&self, operation: F) -> Result<R>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            operation()
        }

        /// Execute a set of operations in batches of `batch_size`.
        ///
        /// With the `parallel` feature each batch runs on the rayon thread
        /// pool; otherwise batches execute sequentially. The first error
        /// aborts the remaining work (`?` / fallible collect).
        pub fn execute_batch<F, R>(&self, operations: Vec<F>, batch_size: usize) -> Result<Vec<R>>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            // Guard against a zero batch size, which would otherwise spin
            // forever taking empty batches while elements remain.
            let batch_size = batch_size.max(1);

            // Process operations in batches, consuming the Vec.
            let mut results = Vec::with_capacity(operations.len());
            let mut iter = operations.into_iter().peekable();

            while iter.peek().is_some() {
                let batch: Vec<F> = iter.by_ref().take(batch_size).collect();

                #[cfg(feature = "parallel")]
                {
                    use rayon::prelude::*;
                    let chunk_results: Result<Vec<_>> =
                        batch.into_par_iter().map(|op| op()).collect();
                    results.extend(chunk_results?);
                }
                #[cfg(not(feature = "parallel"))]
                {
                    for op in batch {
                        results.push(op()?);
                    }
                }
            }

            Ok(results)
        }
    }

    /// Detect available Metal devices.
    ///
    /// Metal is assumed available on macOS, so this unconditionally reports
    /// device 0. (A previous version ran system_profiler here and then
    /// returned the same value on both branches — the probe was pure cost.)
    pub fn detect_metal_devices() -> Result<Vec<usize>> {
        Ok(vec![0])
    }

    /// Get Metal device information.
    ///
    /// Attempts real detection via system_profiler on macOS for device 0.
    /// Returns a placeholder if detection fails or for other device IDs.
    pub fn get_device_info(device_id: usize) -> Result<GpuDeviceInfo> {
        if device_id == 0 {
            if let Some(info) = detection::detect_macos_gpu() {
                return Ok(info);
            }
        }

        // Fallback to placeholder
        Ok(metal_placeholder(device_id))
    }

    /// Generate a placeholder Metal device info when detection fails.
    fn metal_placeholder(device_id: usize) -> GpuDeviceInfo {
        GpuDeviceInfo {
            backend: GpuBackend::Metal,
            name: format!("Apple Metal Device {}", device_id),
            compute_capability: "Metal".to_string(),
            total_memory: 0,
            available_memory: 0,
            compute_units: 0,
        }
    }
}
1010
/// Stub Metal module for non-macOS platforms
///
/// Keeps the `metal` API surface compilable when the feature is enabled off
/// macOS; every entry point returns the same descriptive error instead of
/// attempting GPU work.
#[cfg(all(feature = "metal", feature = "compute", not(target_os = "macos")))]
mod metal {
    use super::*;

    /// Build the uniform "wrong platform" error every stub returns.
    fn unavailable() -> AmateRSError {
        AmateRSError::GpuError(ErrorContext::new(
            "Metal is only available on macOS".to_string(),
        ))
    }

    pub struct MetalContext;

    impl MetalContext {
        /// Always fails: Metal contexts cannot be created off macOS.
        pub fn new(_device_id: usize, _memory_pool_size: usize) -> Result<Self> {
            Err(unavailable())
        }

        /// Always fails; the closure is never invoked.
        pub fn execute_operation<F, R>(&self, _operation: F) -> Result<R>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            Err(unavailable())
        }

        /// Always fails; the operations are never invoked.
        pub fn execute_batch<F, R>(&self, _operations: Vec<F>, _batch_size: usize) -> Result<Vec<R>>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            Err(unavailable())
        }
    }

    /// Always fails: no Metal devices exist off macOS.
    pub fn detect_metal_devices() -> Result<Vec<usize>> {
        Err(unavailable())
    }

    /// Always fails: no Metal device information is available off macOS.
    pub fn get_device_info(_device_id: usize) -> Result<GpuDeviceInfo> {
        Err(unavailable())
    }
}
1058
#[cfg(all(test, feature = "compute"))]
mod tests {
    //! Unit tests covering backend display and config defaults, executor
    //! construction (with CPU fallback), batch execution, and the parsers
    //! for nvidia-smi and system_profiler output.

    use super::*;

    #[test]
    fn test_gpu_backend_display() {
        assert_eq!(format!("{}", GpuBackend::Cuda), "CUDA");
        assert_eq!(format!("{}", GpuBackend::Metal), "Metal");
        assert_eq!(format!("{}", GpuBackend::Cpu), "CPU");
    }

    #[test]
    fn test_gpu_config_default() {
        let config = GpuConfig::default();
        assert_eq!(config.preferred_backend, None);
        assert_eq!(config.device_id, 0);
        assert!(config.enable_batch);
        assert_eq!(config.batch_size, 64);
        assert_eq!(config.memory_pool_size, 0);
    }

    // Executor creation must succeed on every platform: when no GPU is
    // available it is expected to fall back to the CPU backend.
    #[test]
    fn test_gpu_executor_creation() -> Result<()> {
        let executor = GpuExecutor::new()?;
        assert!(matches!(
            executor.backend(),
            GpuBackend::Cuda | GpuBackend::Metal | GpuBackend::Cpu
        ));
        Ok(())
    }

    #[test]
    fn test_gpu_executor_with_cpu_fallback() -> Result<()> {
        let config = GpuConfig {
            preferred_backend: Some(GpuBackend::Cpu),
            ..Default::default()
        };
        let executor = GpuExecutor::with_config(config)?;
        assert_eq!(executor.backend(), GpuBackend::Cpu);
        assert!(!executor.is_gpu_enabled());
        Ok(())
    }

    #[test]
    fn test_device_info() -> Result<()> {
        let executor = GpuExecutor::new()?;
        let info = executor.device_info();
        assert!(info.is_some());

        if let Some(info) = info {
            assert!(!info.name.is_empty());
            assert!(!info.compute_capability.is_empty());
        }

        Ok(())
    }

    #[test]
    fn test_execute_operation_cpu() -> Result<()> {
        let config = GpuConfig {
            preferred_backend: Some(GpuBackend::Cpu),
            ..Default::default()
        };
        let executor = GpuExecutor::with_config(config)?;

        let result = executor.execute_operation(|| Ok(42))?;
        assert_eq!(result, 42);

        Ok(())
    }

    // batch_size of 4 over 10 operations exercises a final partial batch,
    // and the expected vec pins that result order matches input order.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_execute_batch_cpu() -> Result<()> {
        let config = GpuConfig {
            preferred_backend: Some(GpuBackend::Cpu),
            enable_batch: true,
            batch_size: 4,
            ..Default::default()
        };
        let executor = GpuExecutor::with_config(config)?;

        let operations: Vec<_> = (0..10).map(|i| move || Ok(i * 2)).collect();

        let results = executor.execute_batch(operations)?;
        assert_eq!(results.len(), 10);
        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);

        Ok(())
    }

    // ---- GPU detection parsing tests ----
    //
    // The fixture strings below imitate system_profiler / nvidia-smi output;
    // their exact formatting (indentation, blank lines, trailing newline) is
    // part of what the parsers are tested against — keep it byte-exact.

    #[test]
    fn test_parse_system_profiler_m1() {
        let output = "\
Graphics/Displays:

    Apple M1:

      Chipset Model: Apple M1
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 8
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_some());
        let info = info.expect("should parse");
        assert_eq!(info.name, "Apple M1");
        assert_eq!(info.compute_units, 8);
        assert_eq!(info.backend, GpuBackend::Metal);
        assert_eq!(info.compute_capability, "Metal 3");
    }

    #[test]
    fn test_parse_system_profiler_m2_pro() {
        let output = "\
Graphics/Displays:

    Apple M2 Pro:

      Chipset Model: Apple M2 Pro
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 19
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_some());
        let info = info.expect("should parse");
        assert_eq!(info.name, "Apple M2 Pro");
        assert_eq!(info.compute_units, 19);
        assert_eq!(info.compute_capability, "Metal 3");
    }

    #[test]
    fn test_parse_system_profiler_m3_max() {
        let output = "\
Graphics/Displays:

    Apple M3 Max:

      Chipset Model: Apple M3 Max
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 40
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_some());
        let info = info.expect("should parse");
        assert_eq!(info.name, "Apple M3 Max");
        assert_eq!(info.compute_units, 40);
        assert_eq!(info.compute_capability, "Metal 3");
    }

    // Intel fixture differs from Apple Silicon: it carries a VRAM line (the
    // parser is expected to convert MB to bytes) and no core count.
    #[test]
    fn test_parse_system_profiler_intel_gpu() {
        let output = "\
Graphics/Displays:

    Intel Iris Plus Graphics 655:

      Chipset Model: Intel Iris Plus Graphics 655
      Type: GPU
      Bus: Built-In
      VRAM (Dynamic, Max): 1536 MB
      Vendor: Intel (0x8086)
      Device ID: 0x3ea5
      Metal Support: Metal 2
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_some());
        let info = info.expect("should parse");
        assert_eq!(info.name, "Intel Iris Plus Graphics 655");
        assert_eq!(info.compute_capability, "Metal 2");
        // 1536 MB = 1536 * 1048576 = 1610612736
        assert_eq!(info.total_memory, 1_610_612_736);
        assert!(info.available_memory > 0);
        assert_eq!(info.compute_units, 0); // Intel GPUs don't report cores this way
    }

    #[test]
    fn test_parse_system_profiler_empty() {
        let output = "";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_none());
    }

    #[test]
    fn test_parse_system_profiler_no_gpu_section() {
        let output = "\
Graphics/Displays:

    No GPU found.
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_none());
    }

    // nvidia-smi CSV rows are "name, total MiB, free MiB"; the parser is
    // expected to convert both memory fields to bytes (x 1_048_576).
    #[test]
    fn test_parse_nvidia_smi_single() {
        let output = "NVIDIA GeForce RTX 4090, 24564, 23456\n";
        let devices = detection::parse_nvidia_smi(output);
        assert_eq!(devices.len(), 1);
        assert_eq!(devices[0].name, "NVIDIA GeForce RTX 4090");
        assert_eq!(devices[0].total_memory, 24564 * 1_048_576);
        assert_eq!(devices[0].available_memory, 23456 * 1_048_576);
        assert_eq!(devices[0].backend, GpuBackend::Cuda);
    }

    #[test]
    fn test_parse_nvidia_smi_multi() {
        let output = "\
NVIDIA GeForce RTX 4090, 24564, 23456
NVIDIA A100-SXM4-80GB, 81920, 79000
";
        let devices = detection::parse_nvidia_smi(output);
        assert_eq!(devices.len(), 2);
        assert_eq!(devices[0].name, "NVIDIA GeForce RTX 4090");
        assert_eq!(devices[0].total_memory, 24564 * 1_048_576);
        assert_eq!(devices[1].name, "NVIDIA A100-SXM4-80GB");
        assert_eq!(devices[1].total_memory, 81920 * 1_048_576);
        assert_eq!(devices[1].available_memory, 79000 * 1_048_576);
    }

    #[test]
    fn test_parse_nvidia_smi_empty() {
        let output = "";
        let devices = detection::parse_nvidia_smi(output);
        assert!(devices.is_empty());
    }

    // Every malformed row (wrong field count, non-numeric memory values)
    // must be skipped silently — never panic, never produce a bogus device.
    #[test]
    fn test_parse_nvidia_smi_malformed() {
        let output = "\
this is not valid csv data
also garbage
,,
just-one-field
name, not_a_number, 123
name, 123, not_a_number
";
        let devices = detection::parse_nvidia_smi(output);
        assert!(devices.is_empty());
    }

    #[test]
    fn test_gpu_device_info_fields() {
        let info = GpuDeviceInfo {
            backend: GpuBackend::Cuda,
            name: "Test GPU".to_string(),
            compute_capability: "8.9".to_string(),
            total_memory: 16_000_000_000,
            available_memory: 15_000_000_000,
            compute_units: 128,
        };
        assert_eq!(info.backend, GpuBackend::Cuda);
        assert_eq!(info.name, "Test GPU");
        assert_eq!(info.compute_capability, "8.9");
        assert_eq!(info.total_memory, 16_000_000_000);
        assert_eq!(info.available_memory, 15_000_000_000);
        assert_eq!(info.compute_units, 128);
    }

    #[test]
    fn test_fallback_to_placeholder_cuda() {
        // Parsing empty nvidia-smi output returns empty vec,
        // so the caller would fall back to placeholder
        let devices = detection::parse_nvidia_smi("");
        assert!(devices.is_empty());

        // Malformed data also falls back
        let devices = detection::parse_nvidia_smi("garbage data here");
        assert!(devices.is_empty());
    }

    #[test]
    fn test_fallback_to_placeholder_metal() {
        // Empty system_profiler output returns None,
        // so the caller would fall back to placeholder
        let info = detection::parse_system_profiler("");
        assert!(info.is_none());

        // No chipset model also returns None
        let info = detection::parse_system_profiler("Graphics/Displays:\n    No data\n");
        assert!(info.is_none());
    }

    // Smoke test against the real machine: only asserts strongly on macOS,
    // where detection is expected to succeed; on Linux it merely verifies
    // the detection calls do not panic (hardware may be absent in CI).
    #[test]
    fn test_detect_on_current_platform() {
        // This test runs real detection on the current platform
        #[cfg(target_os = "macos")]
        {
            let info = detection::detect_macos_gpu();
            // On macOS, we should always detect a GPU
            assert!(info.is_some(), "should detect GPU on macOS");
            let info = info.expect("GPU detected");
            assert!(!info.name.is_empty());
            assert_eq!(info.backend, GpuBackend::Metal);
            // Apple Silicon should have unified memory > 0
            if info.name.starts_with("Apple") {
                assert!(info.total_memory > 0, "Apple Silicon should report memory");
                assert!(info.compute_units > 0, "Apple Silicon should report cores");
            }
        }

        #[cfg(target_os = "linux")]
        {
            // On Linux, detection depends on having NVIDIA hardware
            // Just verify the functions don't panic
            let _nvidia = detection::detect_nvidia_gpu(0);
            let _count = detection::detect_nvidia_device_count();
            let _sysfs = detection::detect_nvidia_sysfs();
        }
    }

    // Fields padded with spaces must still parse; the name is trimmed.
    #[test]
    fn test_parse_nvidia_smi_whitespace_handling() {
        let output = "  NVIDIA RTX 3080 ,  10240 ,  9500  \n";
        let devices = detection::parse_nvidia_smi(output);
        assert_eq!(devices.len(), 1);
        assert_eq!(devices[0].name, "NVIDIA RTX 3080");
        assert_eq!(devices[0].total_memory, 10240 * 1_048_576);
        assert_eq!(devices[0].available_memory, 9500 * 1_048_576);
    }
}
1389}