optirs_gpu/memory/vendors/
mod.rs

1// Vendor-specific GPU memory backends
2//
3// This module provides vendor-specific GPU memory management implementations
4// for different GPU architectures and platforms.
5
6pub mod cuda_backend;
7pub mod metal_backend;
8pub mod oneapi_backend;
9pub mod rocm_backend;
10
11use std::ffi::c_void;
12use std::time::Duration;
13
14pub use cuda_backend::{
15    CudaConfig, CudaError, CudaMemoryBackend, CudaMemoryType, ThreadSafeCudaBackend,
16};
17pub use metal_backend::{
18    MetalConfig, MetalError, MetalMemoryBackend, MetalMemoryType, ThreadSafeMetalBackend,
19};
20pub use oneapi_backend::{
21    OneApiConfig, OneApiError, OneApiMemoryBackend, OneApiMemoryType, ThreadSafeOneApiBackend,
22};
23pub use rocm_backend::{
24    RocmConfig, RocmError, RocmMemoryBackend, RocmMemoryType, ThreadSafeRocmBackend,
25};
26
/// Unified GPU vendor types
///
/// Unit-only enum identifying the GPU vendor family. `Copy`, `Eq`, and
/// `Hash` are derived so values can be passed by value cheaply, compared
/// exhaustively, and used as keys in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuVendor {
    /// NVIDIA GPUs (CUDA backend)
    Nvidia,
    /// AMD GPUs (ROCm backend)
    Amd,
    /// Intel GPUs (oneAPI backend)
    Intel,
    /// Apple GPUs (Metal backend)
    Apple,
    /// Vendor could not be determined
    Unknown,
}
36
/// Unified memory backend trait for all GPU vendors
///
/// Each vendor backend (CUDA, ROCm, oneAPI, Metal) implements this trait with
/// its own error, memory-placement, and statistics types, so callers can
/// manage GPU memory through one interface regardless of vendor.
pub trait GpuMemoryBackend {
    /// Vendor-specific error type returned by fallible operations.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Vendor-specific memory-placement type (e.g. device vs. host-visible).
    type MemoryType: Clone + PartialEq;
    /// Vendor-specific statistics snapshot type.
    type Stats: Clone;

    /// Allocate GPU memory
    ///
    /// Returns a raw pointer to `size` bytes placed according to
    /// `memory_type`, or a vendor-specific error on failure.
    fn allocate(
        &mut self,
        size: usize,
        memory_type: Self::MemoryType,
    ) -> Result<*mut c_void, Self::Error>;

    /// Free GPU memory
    ///
    /// `memory_type` should match the type used when `ptr` was allocated.
    fn free(&mut self, ptr: *mut c_void, memory_type: Self::MemoryType) -> Result<(), Self::Error>;

    /// Get memory statistics
    fn get_stats(&self) -> Self::Stats;

    /// Synchronize all operations
    fn synchronize(&mut self) -> Result<(), Self::Error>;

    /// Get GPU vendor
    fn get_vendor(&self) -> GpuVendor;

    /// Get device name
    fn get_device_name(&self) -> &str;

    /// Get total memory size
    ///
    /// Total device memory in bytes.
    fn get_total_memory(&self) -> usize;
}
68
69/// Vendor detection and backend creation
70pub struct GpuBackendFactory;
71
72impl GpuBackendFactory {
73    /// Detect available GPU vendors
74    pub fn detect_available_vendors() -> Vec<GpuVendor> {
75        let mut vendors = Vec::new();
76
77        // Simulate vendor detection
78        #[cfg(target_os = "linux")]
79        {
80            vendors.push(GpuVendor::Nvidia);
81            vendors.push(GpuVendor::Amd);
82            vendors.push(GpuVendor::Intel);
83        }
84
85        #[cfg(target_os = "windows")]
86        {
87            vendors.push(GpuVendor::Nvidia);
88            vendors.push(GpuVendor::Amd);
89            vendors.push(GpuVendor::Intel);
90        }
91
92        #[cfg(target_os = "macos")]
93        {
94            vendors.push(GpuVendor::Apple);
95            vendors.push(GpuVendor::Intel); // Intel Macs
96        }
97
98        vendors
99    }
100
101    /// Get preferred vendor based on platform
102    pub fn get_preferred_vendor() -> GpuVendor {
103        #[cfg(target_os = "macos")]
104        {
105            GpuVendor::Apple
106        }
107
108        #[cfg(any(target_os = "linux", target_os = "windows"))]
109        {
110            // Prefer NVIDIA for CUDA support, then AMD, then Intel
111            let vendors = Self::detect_available_vendors();
112            if vendors.contains(&GpuVendor::Nvidia) {
113                return GpuVendor::Nvidia;
114            } else if vendors.contains(&GpuVendor::Amd) {
115                return GpuVendor::Amd;
116            } else if vendors.contains(&GpuVendor::Intel) {
117                return GpuVendor::Intel;
118            }
119            return GpuVendor::Unknown;
120        }
121
122        #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
123        {
124            GpuVendor::Unknown
125        }
126    }
127
128    /// Create backend configuration for vendor
129    pub fn create_default_config(vendor: GpuVendor) -> VendorConfig {
130        match vendor {
131            GpuVendor::Nvidia => VendorConfig::Cuda(CudaConfig::default()),
132            GpuVendor::Amd => VendorConfig::Rocm(RocmConfig::default()),
133            GpuVendor::Intel => VendorConfig::OneApi(OneApiConfig::default()),
134            GpuVendor::Apple => VendorConfig::Metal(MetalConfig::default()),
135            GpuVendor::Unknown => VendorConfig::Cuda(CudaConfig::default()), // Fallback
136        }
137    }
138}
139
/// Unified configuration for all vendors
///
/// Carries the vendor-specific configuration used to construct the matching
/// backend in [`UnifiedGpuBackend::new`].
#[derive(Debug, Clone)]
pub enum VendorConfig {
    /// NVIDIA CUDA backend configuration
    Cuda(CudaConfig),
    /// AMD ROCm backend configuration
    Rocm(RocmConfig),
    /// Intel oneAPI backend configuration
    OneApi(OneApiConfig),
    /// Apple Metal backend configuration
    Metal(MetalConfig),
}
148
/// Unified backend wrapper
///
/// Wraps one concrete vendor backend behind a single type so callers can
/// allocate/free GPU memory without knowing the vendor at compile time.
pub enum UnifiedGpuBackend {
    /// NVIDIA CUDA memory backend
    Cuda(CudaMemoryBackend),
    /// AMD ROCm memory backend
    Rocm(RocmMemoryBackend),
    /// Intel oneAPI memory backend
    OneApi(OneApiMemoryBackend),
    /// Apple Metal memory backend
    Metal(MetalMemoryBackend),
}
156
157impl UnifiedGpuBackend {
158    /// Create backend from configuration
159    pub fn new(config: VendorConfig) -> Result<Self, UnifiedGpuError> {
160        match config {
161            VendorConfig::Cuda(config) => {
162                let backend = CudaMemoryBackend::new(config)?;
163                Ok(UnifiedGpuBackend::Cuda(backend))
164            }
165            VendorConfig::Rocm(config) => {
166                let backend = RocmMemoryBackend::new(config)?;
167                Ok(UnifiedGpuBackend::Rocm(backend))
168            }
169            VendorConfig::OneApi(config) => {
170                let backend = OneApiMemoryBackend::new(config)?;
171                Ok(UnifiedGpuBackend::OneApi(backend))
172            }
173            VendorConfig::Metal(config) => {
174                let backend = MetalMemoryBackend::new(config)?;
175                Ok(UnifiedGpuBackend::Metal(backend))
176            }
177        }
178    }
179
180    /// Auto-detect and create best backend
181    pub fn auto_create() -> Result<Self, UnifiedGpuError> {
182        let vendor = GpuBackendFactory::get_preferred_vendor();
183        let config = GpuBackendFactory::create_default_config(vendor);
184        Self::new(config)
185    }
186
187    /// Get vendor type
188    pub fn get_vendor(&self) -> GpuVendor {
189        match self {
190            UnifiedGpuBackend::Cuda(_) => GpuVendor::Nvidia,
191            UnifiedGpuBackend::Rocm(_) => GpuVendor::Amd,
192            UnifiedGpuBackend::OneApi(_) => GpuVendor::Intel,
193            UnifiedGpuBackend::Metal(_) => GpuVendor::Apple,
194        }
195    }
196
197    /// Allocate memory with unified interface
198    pub fn allocate(&mut self, size: usize) -> Result<*mut c_void, UnifiedGpuError> {
199        match self {
200            UnifiedGpuBackend::Cuda(backend) => backend
201                .allocate(size, CudaMemoryType::Device)
202                .map_err(UnifiedGpuError::Cuda),
203            UnifiedGpuBackend::Rocm(backend) => backend
204                .allocate(size, RocmMemoryType::Device)
205                .map_err(UnifiedGpuError::Rocm),
206            UnifiedGpuBackend::OneApi(backend) => backend
207                .allocate(size, OneApiMemoryType::Device)
208                .map_err(UnifiedGpuError::OneApi),
209            UnifiedGpuBackend::Metal(backend) => backend
210                .allocate(size, MetalMemoryType::Private)
211                .map_err(UnifiedGpuError::Metal),
212        }
213    }
214
215    /// Free memory with unified interface
216    pub fn free(&mut self, ptr: *mut c_void) -> Result<(), UnifiedGpuError> {
217        match self {
218            UnifiedGpuBackend::Cuda(backend) => backend
219                .free(ptr, CudaMemoryType::Device)
220                .map_err(UnifiedGpuError::Cuda),
221            UnifiedGpuBackend::Rocm(backend) => backend
222                .free(ptr, RocmMemoryType::Device)
223                .map_err(UnifiedGpuError::Rocm),
224            UnifiedGpuBackend::OneApi(backend) => backend
225                .free(ptr, OneApiMemoryType::Device)
226                .map_err(UnifiedGpuError::OneApi),
227            UnifiedGpuBackend::Metal(backend) => backend
228                .free(ptr, MetalMemoryType::Private)
229                .map_err(UnifiedGpuError::Metal),
230        }
231    }
232
233    /// Get unified memory statistics
234    /// Get total available GPU memory
235    pub fn get_total_memory(&self) -> usize {
236        // Default to 8GB if backend doesn't provide memory info
237        // Individual backends should implement proper memory querying
238        match self {
239            UnifiedGpuBackend::Cuda(_) => 8 * 1024 * 1024 * 1024, // 8GB default for CUDA
240            UnifiedGpuBackend::Rocm(_) => 8 * 1024 * 1024 * 1024, // 8GB default for ROCm
241            UnifiedGpuBackend::OneApi(_) => 8 * 1024 * 1024 * 1024, // 8GB default for OneAPI
242            UnifiedGpuBackend::Metal(_) => 8 * 1024 * 1024 * 1024, // 8GB default for Metal
243        }
244    }
245
246    pub fn get_memory_stats(&self) -> UnifiedMemoryStats {
247        match self {
248            UnifiedGpuBackend::Cuda(backend) => {
249                let stats = backend.get_stats();
250                UnifiedMemoryStats {
251                    total_allocations: stats.total_allocations,
252                    bytes_allocated: stats.bytes_allocated,
253                    peak_memory_usage: stats.peak_memory_usage,
254                    average_allocation_time: stats.average_allocation_time,
255                }
256            }
257            UnifiedGpuBackend::Rocm(backend) => {
258                let stats = backend.get_stats();
259                UnifiedMemoryStats {
260                    total_allocations: stats.total_allocations,
261                    bytes_allocated: stats.bytes_allocated,
262                    peak_memory_usage: stats.peak_memory_usage,
263                    average_allocation_time: stats.average_allocation_time,
264                }
265            }
266            UnifiedGpuBackend::OneApi(backend) => {
267                let stats = backend.get_stats();
268                UnifiedMemoryStats {
269                    total_allocations: stats.total_allocations,
270                    bytes_allocated: stats.bytes_allocated,
271                    peak_memory_usage: stats.peak_memory_usage,
272                    average_allocation_time: stats.average_allocation_time,
273                }
274            }
275            UnifiedGpuBackend::Metal(backend) => {
276                let stats = backend.get_stats();
277                UnifiedMemoryStats {
278                    total_allocations: stats.total_allocations,
279                    bytes_allocated: stats.bytes_allocated,
280                    peak_memory_usage: stats.peak_memory_usage,
281                    average_allocation_time: stats.average_allocation_time,
282                }
283            }
284        }
285    }
286}
287
/// Unified memory statistics across all vendors
///
/// Vendor-neutral snapshot produced by `UnifiedGpuBackend::get_memory_stats`.
#[derive(Debug, Clone, Default)]
pub struct UnifiedMemoryStats {
    /// Total number of allocations performed.
    pub total_allocations: u64,
    /// Total number of bytes allocated.
    pub bytes_allocated: u64,
    /// Highest observed memory usage, in bytes.
    pub peak_memory_usage: usize,
    /// Mean wall-clock time per allocation.
    pub average_allocation_time: Duration,
}
296
/// Unified error type for all GPU backends
///
/// Wraps each vendor's error type plus two cross-vendor failure modes.
#[derive(Debug)]
pub enum UnifiedGpuError {
    /// Error from the CUDA backend
    Cuda(CudaError),
    /// Error from the ROCm backend
    Rocm(RocmError),
    /// Error from the oneAPI backend
    OneApi(OneApiError),
    /// Error from the Metal backend
    Metal(MetalError),
    /// No backend exists for the requested vendor
    VendorNotSupported(String),
    /// Backend construction/initialization failed
    InitializationFailed(String),
}
307
308impl std::fmt::Display for UnifiedGpuError {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        match self {
311            UnifiedGpuError::Cuda(err) => write!(f, "CUDA Error: {}", err),
312            UnifiedGpuError::Rocm(err) => write!(f, "ROCm Error: {}", err),
313            UnifiedGpuError::OneApi(err) => write!(f, "OneAPI Error: {}", err),
314            UnifiedGpuError::Metal(err) => write!(f, "Metal Error: {}", err),
315            UnifiedGpuError::VendorNotSupported(msg) => write!(f, "Vendor not supported: {}", msg),
316            UnifiedGpuError::InitializationFailed(msg) => {
317                write!(f, "Initialization failed: {}", msg)
318            }
319        }
320    }
321}
322
323impl std::error::Error for UnifiedGpuError {}
324
325impl From<CudaError> for UnifiedGpuError {
326    fn from(err: CudaError) -> Self {
327        UnifiedGpuError::Cuda(err)
328    }
329}
330
331impl From<RocmError> for UnifiedGpuError {
332    fn from(err: RocmError) -> Self {
333        UnifiedGpuError::Rocm(err)
334    }
335}
336
337impl From<OneApiError> for UnifiedGpuError {
338    fn from(err: OneApiError) -> Self {
339        UnifiedGpuError::OneApi(err)
340    }
341}
342
343impl From<MetalError> for UnifiedGpuError {
344    fn from(err: MetalError) -> Self {
345        UnifiedGpuError::Metal(err)
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported OS reports at least one (simulated) vendor.
    #[test]
    fn test_vendor_detection() {
        assert!(!GpuBackendFactory::detect_available_vendors().is_empty());
    }

    /// The preferred vendor resolves to a concrete vendor, never `Unknown`.
    #[test]
    fn test_preferred_vendor() {
        assert_ne!(GpuBackendFactory::get_preferred_vendor(), GpuVendor::Unknown);
    }

    /// A backend can be built from the preferred vendor's default config.
    #[test]
    fn test_unified_backend_creation() {
        let preferred = GpuBackendFactory::get_preferred_vendor();
        let config = GpuBackendFactory::create_default_config(preferred);
        assert!(UnifiedGpuBackend::new(config).is_ok());
    }

    /// End-to-end auto-detection and construction succeeds.
    #[test]
    fn test_auto_create() {
        assert!(UnifiedGpuBackend::auto_create().is_ok());
    }
}