1use crate::error::{AmateRSError, ErrorContext, Result};
23use parking_lot::RwLock;
24use std::sync::Arc;
25
26#[cfg(feature = "compute")]
27use crate::compute::operations::{
28 EncryptedBool, EncryptedU8, EncryptedU16, EncryptedU32, EncryptedU64,
29};
30
/// Compute backend used to execute (FHE) operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuBackend {
    /// NVIDIA CUDA backend (requires the `cuda` feature).
    Cuda,
    /// Apple Metal backend (requires the `metal` feature; functional on macOS only).
    Metal,
    /// CPU fallback — always available; used when no GPU backend is detected.
    Cpu,
}
41
42impl std::fmt::Display for GpuBackend {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 GpuBackend::Cuda => write!(f, "CUDA"),
46 GpuBackend::Metal => write!(f, "Metal"),
47 GpuBackend::Cpu => write!(f, "CPU"),
48 }
49 }
50}
51
/// Configuration for constructing a [`GpuExecutor`].
#[derive(Debug, Clone)]
pub struct GpuConfig {
    /// Backend to use. `None` means auto-detect (CUDA first, then Metal,
    /// then CPU fallback). A `Some` value is used verbatim without probing.
    pub preferred_backend: Option<GpuBackend>,
    /// Index of the device to target when multiple GPUs are present.
    pub device_id: usize,
    /// Whether `execute_batch` may group operations into batches; when
    /// `false`, operations are executed one at a time.
    pub enable_batch: bool,
    /// Number of operations per batch when batching is enabled.
    pub batch_size: usize,
    /// Memory pool size passed to the backend context on initialization.
    /// NOTE(review): units are not stated anywhere in this file — presumably
    /// bytes, with 0 meaning "backend default"; confirm against the contexts.
    pub memory_pool_size: usize,
}
66
67impl Default for GpuConfig {
68 fn default() -> Self {
69 Self {
70 preferred_backend: None,
71 device_id: 0,
72 enable_batch: true,
73 batch_size: 64,
74 memory_pool_size: 0,
75 }
76 }
77}
78
/// Description of a detected compute device.
#[derive(Debug, Clone)]
pub struct GpuDeviceInfo {
    /// Backend this device belongs to.
    pub backend: GpuBackend,
    /// Human-readable device name (e.g. "Apple M1", "NVIDIA GeForce RTX 4090",
    /// or "CPU" for the fallback).
    pub name: String,
    /// Capability string: "Metal 2"/"Metal 3" for Metal, "unknown" for CUDA
    /// paths in this file, or "<n> cores" for the CPU fallback.
    pub compute_capability: String,
    /// Total device memory in bytes; 0 when the probe could not determine it.
    pub total_memory: u64,
    /// Estimated available memory in bytes (probes here use 90% of total);
    /// 0 when unknown.
    pub available_memory: u64,
    /// GPU core count (or logical CPU count for the CPU backend); 0 when
    /// the probe could not determine it.
    pub compute_units: u32,
}
95
/// Dispatches operations to the selected backend (CUDA, Metal, or CPU).
///
/// `Debug` is implemented manually (below) because the cfg-gated context
/// fields are not `Debug`.
#[derive(Clone)]
pub struct GpuExecutor {
    // Backend chosen at construction time; immutable afterwards.
    backend: GpuBackend,
    // Configuration the executor was built with.
    config: GpuConfig,
    // Device description captured during construction; always `Some` after
    // a successful `with_config`.
    device_info: Option<GpuDeviceInfo>,
    // Backend contexts exist only when the matching feature pair is compiled
    // in; shared via Arc<RwLock<..>> so clones of the executor share one context.
    #[cfg(all(feature = "cuda", feature = "compute"))]
    cuda_context: Option<Arc<RwLock<cuda::CudaContext>>>,
    #[cfg(all(feature = "metal", feature = "compute"))]
    metal_context: Option<Arc<RwLock<metal::MetalContext>>>,
}
110
impl GpuExecutor {
    /// Creates an executor with [`GpuConfig::default`], auto-detecting the
    /// best available backend (CUDA, then Metal, then CPU fallback).
    pub fn new() -> Result<Self> {
        Self::with_config(GpuConfig::default())
    }

    /// Creates an executor from an explicit configuration.
    ///
    /// A `preferred_backend` in `config` is honored verbatim — no
    /// availability probe is performed for it. Otherwise the backend is
    /// auto-detected. Device info is captured before the backend context is
    /// initialized.
    ///
    /// # Errors
    /// Fails when device-info lookup fails, when the chosen backend's
    /// feature is not compiled in, or when backend initialization fails.
    pub fn with_config(config: GpuConfig) -> Result<Self> {
        let backend = if let Some(preferred) = config.preferred_backend {
            preferred
        } else {
            Self::detect_backend()?
        };

        let device_info = Self::get_device_info(backend, config.device_id)?;

        let mut executor = Self {
            backend,
            config,
            device_info: Some(device_info),
            #[cfg(all(feature = "cuda", feature = "compute"))]
            cuda_context: None,
            #[cfg(all(feature = "metal", feature = "compute"))]
            metal_context: None,
        };

        executor.initialize_backend()?;

        Ok(executor)
    }

    /// Picks the first available backend in priority order CUDA > Metal,
    /// falling back to CPU. Always returns `Ok` as written.
    fn detect_backend() -> Result<GpuBackend> {
        #[cfg(feature = "cuda")]
        {
            if Self::is_cuda_available() {
                return Ok(GpuBackend::Cuda);
            }
        }

        #[cfg(feature = "metal")]
        {
            if Self::is_metal_available() {
                return Ok(GpuBackend::Metal);
            }
        }

        Ok(GpuBackend::Cpu)
    }

    /// True when CUDA device probing succeeds. Requires both the `cuda`
    /// and `compute` features; otherwise compiles to `false`.
    /// NOTE(review): `cuda::detect_cuda_devices` below never returns `Err`,
    /// so with both features enabled this is effectively always `true` —
    /// confirm that is intended.
    #[cfg(feature = "cuda")]
    fn is_cuda_available() -> bool {
        #[cfg(feature = "compute")]
        {
            cuda::detect_cuda_devices().is_ok()
        }
        #[cfg(not(feature = "compute"))]
        {
            false
        }
    }

    /// Stub when the `cuda` feature is off.
    #[cfg(not(feature = "cuda"))]
    fn is_cuda_available() -> bool {
        false
    }

    /// True when Metal device probing succeeds. Requires the `metal` and
    /// `compute` features on macOS; `false` in every other combination.
    #[cfg(feature = "metal")]
    fn is_metal_available() -> bool {
        #[cfg(target_os = "macos")]
        {
            #[cfg(feature = "compute")]
            {
                metal::detect_metal_devices().is_ok()
            }
            #[cfg(not(feature = "compute"))]
            {
                false
            }
        }
        #[cfg(not(target_os = "macos"))]
        {
            false
        }
    }

    /// Stub when the `metal` feature is off.
    #[cfg(not(feature = "metal"))]
    fn is_metal_available() -> bool {
        false
    }

    /// Queries device information for `backend`/`device_id`.
    ///
    /// The CPU arm always succeeds and reports logical-core count via
    /// `available_parallelism` (memory fields are 0). GPU arms delegate to
    /// their module when `compute` is on; otherwise they return
    /// `FeatureNotEnabled`. The trailing `_` arm fires when the variant's
    /// feature is compiled out entirely.
    fn get_device_info(backend: GpuBackend, device_id: usize) -> Result<GpuDeviceInfo> {
        match backend {
            #[cfg(feature = "cuda")]
            GpuBackend::Cuda => {
                #[cfg(feature = "compute")]
                {
                    cuda::get_device_info(device_id)
                }
                #[cfg(not(feature = "compute"))]
                {
                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
                        "CUDA backend requires 'compute' feature".to_string(),
                    )))
                }
            }

            #[cfg(feature = "metal")]
            GpuBackend::Metal => {
                #[cfg(feature = "compute")]
                {
                    metal::get_device_info(device_id)
                }
                #[cfg(not(feature = "compute"))]
                {
                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
                        "Metal backend requires 'compute' feature".to_string(),
                    )))
                }
            }

            GpuBackend::Cpu => {
                // Logical CPU count; defaults to 1 if the query fails.
                let cpus = std::thread::available_parallelism()
                    .map(|n| n.get())
                    .unwrap_or(1);
                Ok(GpuDeviceInfo {
                    backend: GpuBackend::Cpu,
                    name: "CPU".to_string(),
                    compute_capability: format!("{} cores", cpus),
                    total_memory: 0,
                    available_memory: 0,
                    compute_units: cpus as u32,
                })
            }

            // Reached only when a GPU variant exists but its feature is off.
            #[allow(unreachable_patterns)]
            _ => Err(AmateRSError::Configuration(ErrorContext::new(format!(
                "Backend {} is not available (feature not enabled)",
                backend
            )))),
        }
    }

    /// Creates and stores the backend context. The CPU backend needs no
    /// context and is a no-op.
    fn initialize_backend(&mut self) -> Result<()> {
        match self.backend {
            #[cfg(all(feature = "cuda", feature = "compute"))]
            GpuBackend::Cuda => {
                let context =
                    cuda::CudaContext::new(self.config.device_id, self.config.memory_pool_size)?;
                self.cuda_context = Some(Arc::new(RwLock::new(context)));
                Ok(())
            }

            #[cfg(all(feature = "metal", feature = "compute"))]
            GpuBackend::Metal => {
                let context =
                    metal::MetalContext::new(self.config.device_id, self.config.memory_pool_size)?;
                self.metal_context = Some(Arc::new(RwLock::new(context)));
                Ok(())
            }

            GpuBackend::Cpu => {
                Ok(())
            }

            #[allow(unreachable_patterns)]
            _ => Err(AmateRSError::Configuration(ErrorContext::new(format!(
                "Cannot initialize backend {} (feature not enabled)",
                self.backend
            )))),
        }
    }

    /// Backend selected at construction time.
    pub fn backend(&self) -> GpuBackend {
        self.backend
    }

    /// Device info captured at construction; `Some` after a successful build.
    pub fn device_info(&self) -> Option<&GpuDeviceInfo> {
        self.device_info.as_ref()
    }

    /// Configuration this executor was built with.
    pub fn config(&self) -> &GpuConfig {
        &self.config
    }

    /// True for any backend other than the CPU fallback.
    pub fn is_gpu_enabled(&self) -> bool {
        !matches!(self.backend, GpuBackend::Cpu)
    }

    /// Executes a single operation on the active backend.
    ///
    /// CPU runs the closure directly; GPU backends forward it to their
    /// context (which, per the modules below, also just runs it on the host
    /// for now). Errors with `GpuError` if the context was never initialized.
    ///
    /// NOTE(review): this method is already gated on `compute`, so the inner
    /// `#[cfg(not(feature = "compute"))]` arms below are dead code.
    #[cfg(feature = "compute")]
    pub fn execute_operation<F, R>(&self, operation: F) -> Result<R>
    where
        F: FnOnce() -> Result<R> + Send,
        R: Send,
    {
        match self.backend {
            #[cfg(feature = "cuda")]
            GpuBackend::Cuda => {
                #[cfg(feature = "compute")]
                {
                    if let Some(context) = &self.cuda_context {
                        let ctx = context.read();
                        ctx.execute_operation(operation)
                    } else {
                        Err(AmateRSError::GpuError(ErrorContext::new(
                            "CUDA context not initialized".to_string(),
                        )))
                    }
                }
                #[cfg(not(feature = "compute"))]
                {
                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
                        "CUDA backend requires 'compute' feature".to_string(),
                    )))
                }
            }

            #[cfg(feature = "metal")]
            GpuBackend::Metal => {
                #[cfg(feature = "compute")]
                {
                    if let Some(context) = &self.metal_context {
                        let ctx = context.read();
                        ctx.execute_operation(operation)
                    } else {
                        Err(AmateRSError::GpuError(ErrorContext::new(
                            "Metal context not initialized".to_string(),
                        )))
                    }
                }
                #[cfg(not(feature = "compute"))]
                {
                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
                        "Metal backend requires 'compute' feature".to_string(),
                    )))
                }
            }

            GpuBackend::Cpu => {
                operation()
            }

            #[allow(unreachable_patterns)]
            _ => Err(AmateRSError::Configuration(ErrorContext::new(format!(
                "Backend {} is not available",
                self.backend
            )))),
        }
    }

    /// Stub when `compute` is off: always `FeatureNotEnabled`.
    #[cfg(not(feature = "compute"))]
    pub fn execute_operation<F, R>(&self, _operation: F) -> Result<R>
    where
        F: FnOnce() -> Result<R> + Send,
        R: Send,
    {
        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
            "FHE compute feature is not enabled".to_string(),
        )))
    }

    /// Executes many operations, batching per `config.batch_size`.
    ///
    /// With batching disabled (or an empty list) each operation goes through
    /// `execute_operation` individually. GPU backends delegate to their
    /// context's `execute_batch`; the CPU path runs all operations, in
    /// parallel via rayon when the `parallel` feature is on. The first `Err`
    /// short-circuits the whole batch.
    #[cfg(feature = "compute")]
    pub fn execute_batch<F, R>(&self, operations: Vec<F>) -> Result<Vec<R>>
    where
        F: FnOnce() -> Result<R> + Send,
        R: Send,
    {
        if !self.config.enable_batch || operations.is_empty() {
            return operations
                .into_iter()
                .map(|op| self.execute_operation(op))
                .collect();
        }

        match self.backend {
            #[cfg(feature = "cuda")]
            GpuBackend::Cuda => {
                #[cfg(feature = "compute")]
                {
                    if let Some(context) = &self.cuda_context {
                        let ctx = context.read();
                        ctx.execute_batch(operations, self.config.batch_size)
                    } else {
                        Err(AmateRSError::GpuError(ErrorContext::new(
                            "CUDA context not initialized".to_string(),
                        )))
                    }
                }
                #[cfg(not(feature = "compute"))]
                {
                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
                        "CUDA backend requires 'compute' feature".to_string(),
                    )))
                }
            }

            #[cfg(feature = "metal")]
            GpuBackend::Metal => {
                #[cfg(feature = "compute")]
                {
                    if let Some(context) = &self.metal_context {
                        let ctx = context.read();
                        ctx.execute_batch(operations, self.config.batch_size)
                    } else {
                        Err(AmateRSError::GpuError(ErrorContext::new(
                            "Metal context not initialized".to_string(),
                        )))
                    }
                }
                #[cfg(not(feature = "compute"))]
                {
                    Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
                        "Metal backend requires 'compute' feature".to_string(),
                    )))
                }
            }

            GpuBackend::Cpu => {
                #[cfg(feature = "parallel")]
                {
                    use rayon::prelude::*;
                    operations.into_par_iter().map(|op| op()).collect()
                }
                #[cfg(not(feature = "parallel"))]
                {
                    operations.into_iter().map(|op| op()).collect()
                }
            }

            #[allow(unreachable_patterns)]
            _ => Err(AmateRSError::Configuration(ErrorContext::new(format!(
                "Backend {} is not available",
                self.backend
            )))),
        }
    }

    /// Stub when `compute` is off: always `FeatureNotEnabled`.
    #[cfg(not(feature = "compute"))]
    pub fn execute_batch<F, R>(&self, _operations: Vec<F>) -> Result<Vec<R>>
    where
        F: FnOnce() -> Result<R> + Send,
        R: Send,
    {
        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
            "FHE compute feature is not enabled".to_string(),
        )))
    }
}
485
486impl Default for GpuExecutor {
487 fn default() -> Self {
488 Self::new().expect("Failed to create default GPU executor")
489 }
490}
491
492impl std::fmt::Debug for GpuExecutor {
493 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494 f.debug_struct("GpuExecutor")
495 .field("backend", &self.backend)
496 .field("config", &self.config)
497 .field("device_info", &self.device_info)
498 .finish()
499 }
500}
501
502mod detection {
508 use super::{GpuBackend, GpuDeviceInfo};
509 use std::process::Command;
510
511 #[cfg(target_os = "macos")]
513 pub fn detect_macos_gpu() -> Option<GpuDeviceInfo> {
514 let output = Command::new("system_profiler")
515 .arg("SPDisplaysDataType")
516 .output()
517 .ok()?;
518 if !output.status.success() {
519 return None;
520 }
521 let text = String::from_utf8_lossy(&output.stdout);
522 let mut info = parse_system_profiler(&text)?;
523
524 if info.name.starts_with("Apple") {
526 if let Some(mem) = get_macos_system_memory() {
527 info.total_memory = mem;
528 info.available_memory = mem * 9 / 10;
530 }
531 }
532
533 Some(info)
534 }
535
536 #[cfg(target_os = "macos")]
538 fn get_macos_system_memory() -> Option<u64> {
539 let output = Command::new("sysctl")
540 .arg("-n")
541 .arg("hw.memsize")
542 .output()
543 .ok()?;
544 if !output.status.success() {
545 return None;
546 }
547 let text = String::from_utf8_lossy(&output.stdout);
548 text.trim().parse::<u64>().ok()
549 }
550
551 #[cfg(target_os = "linux")]
553 pub fn detect_nvidia_gpu(device_id: usize) -> Option<GpuDeviceInfo> {
554 let output = Command::new("nvidia-smi")
555 .args([
556 "--query-gpu=name,memory.total,memory.free",
557 "--format=csv,noheader,nounits",
558 ])
559 .output()
560 .ok()?;
561 if !output.status.success() {
562 return None;
563 }
564 let text = String::from_utf8_lossy(&output.stdout);
565 let devices = parse_nvidia_smi(&text);
566 devices.into_iter().nth(device_id)
567 }
568
569 #[cfg(target_os = "linux")]
571 pub fn detect_nvidia_device_count() -> Option<usize> {
572 let output = Command::new("nvidia-smi")
573 .args(["--query-gpu=name", "--format=csv,noheader"])
574 .output()
575 .ok()?;
576 if !output.status.success() {
577 return None;
578 }
579 let text = String::from_utf8_lossy(&output.stdout);
580 let count = text.lines().filter(|l| !l.trim().is_empty()).count();
581 if count > 0 { Some(count) } else { None }
582 }
583
584 #[cfg(target_os = "linux")]
586 pub fn detect_nvidia_sysfs() -> Vec<GpuDeviceInfo> {
587 let mut devices = Vec::new();
588 let drm_path = std::path::Path::new("/sys/class/drm");
589 if !drm_path.exists() {
590 return devices;
591 }
592
593 let entries = match std::fs::read_dir(drm_path) {
594 Ok(e) => e,
595 Err(_) => return devices,
596 };
597
598 for entry in entries.flatten() {
599 let name = entry.file_name();
600 let name_str = name.to_string_lossy();
601 if !name_str.starts_with("card") || name_str.contains('-') {
602 continue;
603 }
604
605 let vendor_path = entry.path().join("device/vendor");
606 if let Ok(vendor) = std::fs::read_to_string(&vendor_path) {
607 let vendor_trimmed = vendor.trim();
608 if vendor_trimmed == "0x10de" {
610 let device_name = read_nvidia_proc_name(devices.len())
611 .unwrap_or_else(|| format!("NVIDIA GPU (card {})", name_str));
612
613 devices.push(GpuDeviceInfo {
614 backend: GpuBackend::Cuda,
615 name: device_name,
616 compute_capability: "unknown".to_string(),
617 total_memory: 0,
618 available_memory: 0,
619 compute_units: 0,
620 });
621 }
622 }
623 }
624
625 devices
626 }
627
628 #[cfg(target_os = "linux")]
630 fn read_nvidia_proc_name(index: usize) -> Option<String> {
631 let nvidia_path = std::path::Path::new("/proc/driver/nvidia/gpus");
632 if !nvidia_path.exists() {
633 return None;
634 }
635
636 let entries: Vec<_> = std::fs::read_dir(nvidia_path).ok()?.flatten().collect();
637
638 let entry = entries.get(index)?;
639 let info_path = entry.path().join("information");
640 let content = std::fs::read_to_string(info_path).ok()?;
641
642 for line in content.lines() {
643 if let Some(stripped) = line.strip_prefix("Model:") {
644 return Some(stripped.trim().to_string());
645 }
646 }
647
648 None
649 }
650
651 pub fn parse_system_profiler(text: &str) -> Option<GpuDeviceInfo> {
656 let mut chipset_model: Option<String> = None;
657 let mut total_cores: Option<u32> = None;
658 let mut vram_bytes: Option<u64> = None;
659 let mut is_apple_silicon = false;
660
661 for line in text.lines() {
662 let trimmed = line.trim();
663
664 if let Some(value) = trimmed.strip_prefix("Chipset Model:") {
666 chipset_model = Some(value.trim().to_string());
667 if value.trim().starts_with("Apple") {
668 is_apple_silicon = true;
669 }
670 }
671
672 if let Some(value) = trimmed.strip_prefix("Total Number of Cores:") {
674 total_cores = value.trim().parse::<u32>().ok();
675 }
676
677 if trimmed.starts_with("VRAM") {
679 if let Some(colon_pos) = trimmed.find(':') {
681 let value_part = trimmed[colon_pos + 1..].trim();
682 vram_bytes = parse_memory_string(value_part);
683 }
684 }
685 }
686
687 let name = chipset_model?;
688
689 let compute_capability = if is_apple_silicon {
690 "Metal 3".to_string()
691 } else if name.contains("Intel") {
692 "Metal 2".to_string()
693 } else {
694 "Metal".to_string()
695 };
696
697 let compute_units = total_cores.unwrap_or(0);
698
699 let total_memory = vram_bytes.unwrap_or(0);
702 let available_memory = if total_memory > 0 {
703 total_memory * 9 / 10
704 } else {
705 0
706 };
707
708 Some(GpuDeviceInfo {
709 backend: GpuBackend::Metal,
710 name,
711 compute_capability,
712 total_memory,
713 available_memory,
714 compute_units,
715 })
716 }
717
718 fn parse_memory_string(s: &str) -> Option<u64> {
720 let s = s.trim();
721 let parts: Vec<&str> = s.split_whitespace().collect();
722 if parts.len() < 2 {
723 return None;
724 }
725
726 let value = parts[0].replace(',', "").parse::<u64>().ok()?;
727 let unit = parts[1].to_uppercase();
728
729 match unit.as_str() {
730 "GB" => Some(value * 1_073_741_824),
731 "MB" => Some(value * 1_048_576),
732 "KB" => Some(value * 1024),
733 "TB" => Some(value * 1_099_511_627_776),
734 _ => None,
735 }
736 }
737
738 pub fn parse_nvidia_smi(text: &str) -> Vec<GpuDeviceInfo> {
747 let mut devices = Vec::new();
748
749 for line in text.lines() {
750 let trimmed = line.trim();
751 if trimmed.is_empty() {
752 continue;
753 }
754
755 let parts: Vec<&str> = trimmed.splitn(3, ',').collect();
756 if parts.len() < 3 {
757 continue;
758 }
759
760 let name = parts[0].trim().to_string();
761 let total_mb = match parts[1].trim().parse::<u64>() {
762 Ok(v) => v,
763 Err(_) => continue,
764 };
765 let free_mb = match parts[2].trim().parse::<u64>() {
766 Ok(v) => v,
767 Err(_) => continue,
768 };
769
770 let total_memory = total_mb * 1_048_576;
772 let available_memory = free_mb * 1_048_576;
773
774 devices.push(GpuDeviceInfo {
775 backend: GpuBackend::Cuda,
776 name,
777 compute_capability: "unknown".to_string(),
778 total_memory,
779 available_memory,
780 compute_units: 0,
781 });
782 }
783
784 devices
785 }
786}
787
#[cfg(all(feature = "cuda", feature = "compute"))]
mod cuda {
    use super::*;

    /// Placeholder CUDA execution context; operations currently run on the
    /// host. Fields are retained for the eventual real implementation.
    pub struct CudaContext {
        device_id: usize,
        memory_pool_size: usize,
    }

    impl CudaContext {
        /// Creates a context for `device_id`; never fails in this
        /// placeholder implementation.
        pub fn new(device_id: usize, memory_pool_size: usize) -> Result<Self> {
            Ok(Self {
                device_id,
                memory_pool_size,
            })
        }

        /// Runs `operation` directly on the host.
        pub fn execute_operation<F, R>(&self, operation: F) -> Result<R>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            operation()
        }

        /// Runs `operations` in chunks of `batch_size`, in parallel within
        /// each chunk when the `parallel` feature is enabled. The first
        /// `Err` aborts the whole batch.
        pub fn execute_batch<F, R>(&self, operations: Vec<F>, batch_size: usize) -> Result<Vec<R>>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            let mut collected = Vec::with_capacity(operations.len());
            let mut pending = operations.into_iter().peekable();

            while pending.peek().is_some() {
                let chunk: Vec<F> = pending.by_ref().take(batch_size).collect();

                #[cfg(feature = "parallel")]
                {
                    use rayon::prelude::*;
                    let outputs: Result<Vec<_>> =
                        chunk.into_par_iter().map(|op| op()).collect();
                    collected.extend(outputs?);
                }
                #[cfg(not(feature = "parallel"))]
                {
                    for op in chunk {
                        collected.push(op()?);
                    }
                }
            }

            Ok(collected)
        }
    }

    /// Probes for CUDA devices via nvidia-smi / sysfs on Linux, always
    /// falling back to a single placeholder device index.
    pub fn detect_cuda_devices() -> Result<Vec<usize>> {
        #[cfg(target_os = "linux")]
        {
            if let Some(n) = detection::detect_nvidia_device_count() {
                return Ok((0..n).collect());
            }

            let found = detection::detect_nvidia_sysfs();
            if !found.is_empty() {
                return Ok((0..found.len()).collect());
            }
        }

        Ok(vec![0])
    }

    /// Returns info for the device at `device_id`, preferring live probes
    /// on Linux and falling back to a zeroed placeholder.
    pub fn get_device_info(device_id: usize) -> Result<GpuDeviceInfo> {
        #[cfg(target_os = "linux")]
        {
            if let Some(probed) = detection::detect_nvidia_gpu(device_id) {
                return Ok(probed);
            }

            if let Some(probed) = detection::detect_nvidia_sysfs().into_iter().nth(device_id) {
                return Ok(probed);
            }
        }

        Ok(cuda_placeholder(device_id))
    }

    /// Zeroed descriptor used when no live probe succeeds.
    fn cuda_placeholder(device_id: usize) -> GpuDeviceInfo {
        GpuDeviceInfo {
            backend: GpuBackend::Cuda,
            name: format!("CUDA Device {}", device_id),
            compute_capability: "unknown".to_string(),
            total_memory: 0,
            available_memory: 0,
            compute_units: 0,
        }
    }
}
908
#[cfg(all(feature = "metal", feature = "compute", target_os = "macos"))]
mod metal {
    use super::*;

    /// Placeholder Metal execution context; operations currently run on the
    /// host. Fields are retained for the eventual real implementation.
    pub struct MetalContext {
        device_id: usize,
        memory_pool_size: usize,
    }

    impl MetalContext {
        /// Creates a context for `device_id`; never fails in this
        /// placeholder implementation.
        pub fn new(device_id: usize, memory_pool_size: usize) -> Result<Self> {
            Ok(Self {
                device_id,
                memory_pool_size,
            })
        }

        /// Runs `operation` directly on the host.
        pub fn execute_operation<F, R>(&self, operation: F) -> Result<R>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            operation()
        }

        /// Runs `operations` in chunks of `batch_size`, in parallel within
        /// each chunk when the `parallel` feature is enabled. The first
        /// `Err` aborts the whole batch.
        pub fn execute_batch<F, R>(&self, operations: Vec<F>, batch_size: usize) -> Result<Vec<R>>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            let mut collected = Vec::with_capacity(operations.len());
            let mut pending = operations.into_iter().peekable();

            while pending.peek().is_some() {
                let chunk: Vec<F> = pending.by_ref().take(batch_size).collect();

                #[cfg(feature = "parallel")]
                {
                    use rayon::prelude::*;
                    let outputs: Result<Vec<_>> =
                        chunk.into_par_iter().map(|op| op()).collect();
                    collected.extend(outputs?);
                }
                #[cfg(not(feature = "parallel"))]
                {
                    for op in chunk {
                        collected.push(op()?);
                    }
                }
            }

            Ok(collected)
        }
    }

    /// Reports a single Metal device. The probe is still run (for its side
    /// effect of exercising detection), but device 0 is returned either way.
    pub fn detect_metal_devices() -> Result<Vec<usize>> {
        let _probe = detection::detect_macos_gpu();
        Ok(vec![0])
    }

    /// Returns info for the device at `device_id`: a live probe for device
    /// 0 when available, otherwise a placeholder descriptor.
    pub fn get_device_info(device_id: usize) -> Result<GpuDeviceInfo> {
        match device_id {
            0 => Ok(detection::detect_macos_gpu().unwrap_or_else(|| metal_placeholder(0))),
            other => Ok(metal_placeholder(other)),
        }
    }

    /// Zeroed descriptor used when no live probe succeeds.
    fn metal_placeholder(device_id: usize) -> GpuDeviceInfo {
        GpuDeviceInfo {
            backend: GpuBackend::Metal,
            name: format!("Apple Metal Device {}", device_id),
            compute_capability: "Metal".to_string(),
            total_memory: 0,
            available_memory: 0,
            compute_units: 0,
        }
    }
}
1010
#[cfg(all(feature = "metal", feature = "compute", not(target_os = "macos")))]
mod metal {
    use super::*;

    /// Builds the uniform "Metal is only available on macOS" error that
    /// every stub entry point returns.
    fn macos_only_error() -> AmateRSError {
        AmateRSError::GpuError(ErrorContext::new(
            "Metal is only available on macOS".to_string(),
        ))
    }

    /// Stub context for non-macOS targets; every call fails.
    pub struct MetalContext;

    impl MetalContext {
        pub fn new(_device_id: usize, _memory_pool_size: usize) -> Result<Self> {
            Err(macos_only_error())
        }

        pub fn execute_operation<F, R>(&self, _operation: F) -> Result<R>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            Err(macos_only_error())
        }

        pub fn execute_batch<F, R>(&self, _operations: Vec<F>, _batch_size: usize) -> Result<Vec<R>>
        where
            F: FnOnce() -> Result<R> + Send,
            R: Send,
        {
            Err(macos_only_error())
        }
    }

    pub fn detect_metal_devices() -> Result<Vec<usize>> {
        Err(macos_only_error())
    }

    pub fn get_device_info(_device_id: usize) -> Result<GpuDeviceInfo> {
        Err(macos_only_error())
    }
}
1058
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;

    // Display renders the conventional backend names.
    #[test]
    fn test_gpu_backend_display() {
        assert_eq!(format!("{}", GpuBackend::Cuda), "CUDA");
        assert_eq!(format!("{}", GpuBackend::Metal), "Metal");
        assert_eq!(format!("{}", GpuBackend::Cpu), "CPU");
    }

    // Defaults: auto-detect backend, device 0, batching on with size 64,
    // zero memory pool.
    #[test]
    fn test_gpu_config_default() {
        let config = GpuConfig::default();
        assert_eq!(config.preferred_backend, None);
        assert_eq!(config.device_id, 0);
        assert!(config.enable_batch);
        assert_eq!(config.batch_size, 64);
        assert_eq!(config.memory_pool_size, 0);
    }

    // Construction succeeds on any machine — worst case falls back to CPU.
    #[test]
    fn test_gpu_executor_creation() -> Result<()> {
        let executor = GpuExecutor::new()?;
        assert!(matches!(
            executor.backend(),
            GpuBackend::Cuda | GpuBackend::Metal | GpuBackend::Cpu
        ));
        Ok(())
    }

    // Forcing the CPU backend bypasses detection and reports no GPU.
    #[test]
    fn test_gpu_executor_with_cpu_fallback() -> Result<()> {
        let config = GpuConfig {
            preferred_backend: Some(GpuBackend::Cpu),
            ..Default::default()
        };
        let executor = GpuExecutor::with_config(config)?;
        assert_eq!(executor.backend(), GpuBackend::Cpu);
        assert!(!executor.is_gpu_enabled());
        Ok(())
    }

    // Device info is always captured at construction time.
    #[test]
    fn test_device_info() -> Result<()> {
        let executor = GpuExecutor::new()?;
        let info = executor.device_info();
        assert!(info.is_some());

        if let Some(info) = info {
            assert!(!info.name.is_empty());
            assert!(!info.compute_capability.is_empty());
        }

        Ok(())
    }

    // CPU path runs the closure directly and propagates its result.
    #[test]
    fn test_execute_operation_cpu() -> Result<()> {
        let config = GpuConfig {
            preferred_backend: Some(GpuBackend::Cpu),
            ..Default::default()
        };
        let executor = GpuExecutor::with_config(config)?;

        let result = executor.execute_operation(|| Ok(42))?;
        assert_eq!(result, 42);

        Ok(())
    }

    // CPU batching (rayon path) preserves input order in the results.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_execute_batch_cpu() -> Result<()> {
        let config = GpuConfig {
            preferred_backend: Some(GpuBackend::Cpu),
            enable_batch: true,
            batch_size: 4,
            ..Default::default()
        };
        let executor = GpuExecutor::with_config(config)?;

        let operations: Vec<_> = (0..10).map(|i| move || Ok(i * 2)).collect();

        let results = executor.execute_batch(operations)?;
        assert_eq!(results.len(), 10);
        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);

        Ok(())
    }

    // Apple Silicon fixture: name, core count, backend and Metal version.
    // (Unified-memory substitution happens in detect_macos_gpu, not here,
    // so total_memory is not asserted.)
    #[test]
    fn test_parse_system_profiler_m1() {
        let output = "\
Graphics/Displays:

    Apple M1:

      Chipset Model: Apple M1
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 8
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_some());
        let info = info.expect("should parse");
        assert_eq!(info.name, "Apple M1");
        assert_eq!(info.compute_units, 8);
        assert_eq!(info.backend, GpuBackend::Metal);
        assert_eq!(info.compute_capability, "Metal 3");
    }

    // Larger Apple part: 19 GPU cores.
    #[test]
    fn test_parse_system_profiler_m2_pro() {
        let output = "\
Graphics/Displays:

    Apple M2 Pro:

      Chipset Model: Apple M2 Pro
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 19
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_some());
        let info = info.expect("should parse");
        assert_eq!(info.name, "Apple M2 Pro");
        assert_eq!(info.compute_units, 19);
        assert_eq!(info.compute_capability, "Metal 3");
    }

    // Largest Apple fixture: 40 GPU cores.
    #[test]
    fn test_parse_system_profiler_m3_max() {
        let output = "\
Graphics/Displays:

    Apple M3 Max:

      Chipset Model: Apple M3 Max
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 40
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_some());
        let info = info.expect("should parse");
        assert_eq!(info.name, "Apple M3 Max");
        assert_eq!(info.compute_units, 40);
        assert_eq!(info.compute_capability, "Metal 3");
    }

    // Intel iGPU: "Metal 2" capability, VRAM line parsed to bytes
    // (1536 MB = 1_610_612_736), 90% reported available, no core count.
    #[test]
    fn test_parse_system_profiler_intel_gpu() {
        let output = "\
Graphics/Displays:

    Intel Iris Plus Graphics 655:

      Chipset Model: Intel Iris Plus Graphics 655
      Type: GPU
      Bus: Built-In
      VRAM (Dynamic, Max): 1536 MB
      Vendor: Intel (0x8086)
      Device ID: 0x3ea5
      Metal Support: Metal 2
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_some());
        let info = info.expect("should parse");
        assert_eq!(info.name, "Intel Iris Plus Graphics 655");
        assert_eq!(info.compute_capability, "Metal 2");
        assert_eq!(info.total_memory, 1_610_612_736);
        assert!(info.available_memory > 0);
        // Intel output has no "Total Number of Cores" line.
        assert_eq!(info.compute_units, 0);
    }

    // No "Chipset Model:" line anywhere → None.
    #[test]
    fn test_parse_system_profiler_empty() {
        let output = "";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_none());
    }

    // A header without any chipset line also yields None.
    #[test]
    fn test_parse_system_profiler_no_gpu_section() {
        let output = "\
Graphics/Displays:

    No GPU found.
";
        let info = detection::parse_system_profiler(output);
        assert!(info.is_none());
    }

    // nvidia-smi CSV: MiB columns converted to bytes.
    #[test]
    fn test_parse_nvidia_smi_single() {
        let output = "NVIDIA GeForce RTX 4090, 24564, 23456\n";
        let devices = detection::parse_nvidia_smi(output);
        assert_eq!(devices.len(), 1);
        assert_eq!(devices[0].name, "NVIDIA GeForce RTX 4090");
        assert_eq!(devices[0].total_memory, 24564 * 1_048_576);
        assert_eq!(devices[0].available_memory, 23456 * 1_048_576);
        assert_eq!(devices[0].backend, GpuBackend::Cuda);
    }

    // One device per CSV line, order preserved.
    #[test]
    fn test_parse_nvidia_smi_multi() {
        let output = "\
NVIDIA GeForce RTX 4090, 24564, 23456
NVIDIA A100-SXM4-80GB, 81920, 79000
";
        let devices = detection::parse_nvidia_smi(output);
        assert_eq!(devices.len(), 2);
        assert_eq!(devices[0].name, "NVIDIA GeForce RTX 4090");
        assert_eq!(devices[0].total_memory, 24564 * 1_048_576);
        assert_eq!(devices[1].name, "NVIDIA A100-SXM4-80GB");
        assert_eq!(devices[1].total_memory, 81920 * 1_048_576);
        assert_eq!(devices[1].available_memory, 79000 * 1_048_576);
    }

    // Empty input → no devices.
    #[test]
    fn test_parse_nvidia_smi_empty() {
        let output = "";
        let devices = detection::parse_nvidia_smi(output);
        assert!(devices.is_empty());
    }

    // Lines with too few fields or non-numeric memory columns are skipped.
    #[test]
    fn test_parse_nvidia_smi_malformed() {
        let output = "\
this is not valid csv data
also garbage
,,
just-one-field
name, not_a_number, 123
name, 123, not_a_number
";
        let devices = detection::parse_nvidia_smi(output);
        assert!(devices.is_empty());
    }

    // GpuDeviceInfo is a plain data carrier; fields round-trip as set.
    #[test]
    fn test_gpu_device_info_fields() {
        let info = GpuDeviceInfo {
            backend: GpuBackend::Cuda,
            name: "Test GPU".to_string(),
            compute_capability: "8.9".to_string(),
            total_memory: 16_000_000_000,
            available_memory: 15_000_000_000,
            compute_units: 128,
        };
        assert_eq!(info.backend, GpuBackend::Cuda);
        assert_eq!(info.name, "Test GPU");
        assert_eq!(info.compute_capability, "8.9");
        assert_eq!(info.total_memory, 16_000_000_000);
        assert_eq!(info.available_memory, 15_000_000_000);
        assert_eq!(info.compute_units, 128);
    }

    // Garbage nvidia-smi output parses to nothing (callers then fall back
    // to the placeholder descriptor).
    #[test]
    fn test_fallback_to_placeholder_cuda() {
        let devices = detection::parse_nvidia_smi("");
        assert!(devices.is_empty());

        let devices = detection::parse_nvidia_smi("garbage data here");
        assert!(devices.is_empty());
    }

    // Garbage system_profiler output parses to None (callers then fall
    // back to the placeholder descriptor).
    #[test]
    fn test_fallback_to_placeholder_metal() {
        let info = detection::parse_system_profiler("");
        assert!(info.is_none());

        let info = detection::parse_system_profiler("Graphics/Displays:\n    No data\n");
        assert!(info.is_none());
    }

    // Environment-dependent smoke test: runs the real OS probes.
    // NOTE(review): the macOS assertions depend on `system_profiler` working
    // in the test environment and may be flaky in minimal containers.
    #[test]
    fn test_detect_on_current_platform() {
        #[cfg(target_os = "macos")]
        {
            let info = detection::detect_macos_gpu();
            assert!(info.is_some(), "should detect GPU on macOS");
            let info = info.expect("GPU detected");
            assert!(!info.name.is_empty());
            assert_eq!(info.backend, GpuBackend::Metal);
            if info.name.starts_with("Apple") {
                assert!(info.total_memory > 0, "Apple Silicon should report memory");
                assert!(info.compute_units > 0, "Apple Silicon should report cores");
            }
        }

        #[cfg(target_os = "linux")]
        {
            // Probes must not panic even when no NVIDIA GPU is present.
            let _nvidia = detection::detect_nvidia_gpu(0);
            let _count = detection::detect_nvidia_device_count();
            let _sysfs = detection::detect_nvidia_sysfs();
        }
    }

    // Fields are trimmed before parsing, so padded CSV still works.
    #[test]
    fn test_parse_nvidia_smi_whitespace_handling() {
        let output = "  NVIDIA RTX 3080  ,  10240  ,  9500  \n";
        let devices = detection::parse_nvidia_smi(output);
        assert_eq!(devices.len(), 1);
        assert_eq!(devices[0].name, "NVIDIA RTX 3080");
        assert_eq!(devices[0].total_memory, 10240 * 1_048_576);
        assert_eq!(devices[0].available_memory, 9500 * 1_048_576);
    }
}