optirs_gpu/memory/vendors/
mod.rs1pub mod cuda_backend;
7pub mod metal_backend;
8pub mod oneapi_backend;
9pub mod rocm_backend;
10
11use std::ffi::c_void;
12use std::time::Duration;
13
14pub use cuda_backend::{
15 CudaConfig, CudaError, CudaMemoryBackend, CudaMemoryType, ThreadSafeCudaBackend,
16};
17pub use metal_backend::{
18 MetalConfig, MetalError, MetalMemoryBackend, MetalMemoryType, ThreadSafeMetalBackend,
19};
20pub use oneapi_backend::{
21 OneApiConfig, OneApiError, OneApiMemoryBackend, OneApiMemoryType, ThreadSafeOneApiBackend,
22};
23pub use rocm_backend::{
24 RocmConfig, RocmError, RocmMemoryBackend, RocmMemoryType, ThreadSafeRocmBackend,
25};
26
/// GPU hardware vendors this module knows how to target.
///
/// The enum is a plain tag with no payload, so it derives `Copy`, `Eq` and
/// `Hash` in addition to `Debug`/`Clone`/`PartialEq`; this lets callers pass it
/// by value and use it as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuVendor {
    /// NVIDIA hardware (mapped to the CUDA backend in this module).
    Nvidia,
    /// AMD hardware (mapped to the ROCm backend in this module).
    Amd,
    /// Intel hardware (mapped to the OneAPI backend in this module).
    Intel,
    /// Apple hardware (mapped to the Metal backend in this module).
    Apple,
    /// No recognised vendor.
    Unknown,
}
36
37pub trait GpuMemoryBackend {
39 type Error: std::error::Error + Send + Sync + 'static;
40 type MemoryType: Clone + PartialEq;
41 type Stats: Clone;
42
43 fn allocate(
45 &mut self,
46 size: usize,
47 memory_type: Self::MemoryType,
48 ) -> Result<*mut c_void, Self::Error>;
49
50 fn free(&mut self, ptr: *mut c_void, memory_type: Self::MemoryType) -> Result<(), Self::Error>;
52
53 fn get_stats(&self) -> Self::Stats;
55
56 fn synchronize(&mut self) -> Result<(), Self::Error>;
58
59 fn get_vendor(&self) -> GpuVendor;
61
62 fn get_device_name(&self) -> &str;
64
65 fn get_total_memory(&self) -> usize;
67}
68
/// Stateless entry point for choosing a GPU vendor and building its default
/// configuration.
///
/// The type carries no data; the usual value traits are derived so it can be
/// printed, copied and default-constructed like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct GpuBackendFactory;
71
72impl GpuBackendFactory {
73 pub fn detect_available_vendors() -> Vec<GpuVendor> {
75 let mut vendors = Vec::new();
76
77 #[cfg(target_os = "linux")]
79 {
80 vendors.push(GpuVendor::Nvidia);
81 vendors.push(GpuVendor::Amd);
82 vendors.push(GpuVendor::Intel);
83 }
84
85 #[cfg(target_os = "windows")]
86 {
87 vendors.push(GpuVendor::Nvidia);
88 vendors.push(GpuVendor::Amd);
89 vendors.push(GpuVendor::Intel);
90 }
91
92 #[cfg(target_os = "macos")]
93 {
94 vendors.push(GpuVendor::Apple);
95 vendors.push(GpuVendor::Intel); }
97
98 vendors
99 }
100
101 pub fn get_preferred_vendor() -> GpuVendor {
103 #[cfg(target_os = "macos")]
104 {
105 GpuVendor::Apple
106 }
107
108 #[cfg(any(target_os = "linux", target_os = "windows"))]
109 {
110 let vendors = Self::detect_available_vendors();
112 if vendors.contains(&GpuVendor::Nvidia) {
113 return GpuVendor::Nvidia;
114 } else if vendors.contains(&GpuVendor::Amd) {
115 return GpuVendor::Amd;
116 } else if vendors.contains(&GpuVendor::Intel) {
117 return GpuVendor::Intel;
118 }
119 GpuVendor::Unknown
120 }
121
122 #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
123 {
124 GpuVendor::Unknown
125 }
126 }
127
128 pub fn create_default_config(vendor: GpuVendor) -> VendorConfig {
130 match vendor {
131 GpuVendor::Nvidia => VendorConfig::Cuda(CudaConfig::default()),
132 GpuVendor::Amd => VendorConfig::Rocm(RocmConfig::default()),
133 GpuVendor::Intel => VendorConfig::OneApi(OneApiConfig::default()),
134 GpuVendor::Apple => VendorConfig::Metal(MetalConfig::default()),
135 GpuVendor::Unknown => VendorConfig::Cuda(CudaConfig::default()), }
137 }
138}
139
140#[derive(Debug, Clone)]
142pub enum VendorConfig {
143 Cuda(CudaConfig),
144 Rocm(RocmConfig),
145 OneApi(OneApiConfig),
146 Metal(MetalConfig),
147}
148
149pub enum UnifiedGpuBackend {
151 Cuda(CudaMemoryBackend),
152 Rocm(RocmMemoryBackend),
153 OneApi(OneApiMemoryBackend),
154 Metal(MetalMemoryBackend),
155}
156
157impl UnifiedGpuBackend {
158 pub fn new(config: VendorConfig) -> Result<Self, UnifiedGpuError> {
160 match config {
161 VendorConfig::Cuda(config) => {
162 let backend = CudaMemoryBackend::new(config)?;
163 Ok(UnifiedGpuBackend::Cuda(backend))
164 }
165 VendorConfig::Rocm(config) => {
166 let backend = RocmMemoryBackend::new(config)?;
167 Ok(UnifiedGpuBackend::Rocm(backend))
168 }
169 VendorConfig::OneApi(config) => {
170 let backend = OneApiMemoryBackend::new(config)?;
171 Ok(UnifiedGpuBackend::OneApi(backend))
172 }
173 VendorConfig::Metal(config) => {
174 let backend = MetalMemoryBackend::new(config)?;
175 Ok(UnifiedGpuBackend::Metal(backend))
176 }
177 }
178 }
179
180 pub fn auto_create() -> Result<Self, UnifiedGpuError> {
182 let vendor = GpuBackendFactory::get_preferred_vendor();
183 let config = GpuBackendFactory::create_default_config(vendor);
184 Self::new(config)
185 }
186
187 pub fn get_vendor(&self) -> GpuVendor {
189 match self {
190 UnifiedGpuBackend::Cuda(_) => GpuVendor::Nvidia,
191 UnifiedGpuBackend::Rocm(_) => GpuVendor::Amd,
192 UnifiedGpuBackend::OneApi(_) => GpuVendor::Intel,
193 UnifiedGpuBackend::Metal(_) => GpuVendor::Apple,
194 }
195 }
196
197 pub fn allocate(&mut self, size: usize) -> Result<*mut c_void, UnifiedGpuError> {
199 match self {
200 UnifiedGpuBackend::Cuda(backend) => backend
201 .allocate(size, CudaMemoryType::Device)
202 .map_err(UnifiedGpuError::Cuda),
203 UnifiedGpuBackend::Rocm(backend) => backend
204 .allocate(size, RocmMemoryType::Device)
205 .map_err(UnifiedGpuError::Rocm),
206 UnifiedGpuBackend::OneApi(backend) => backend
207 .allocate(size, OneApiMemoryType::Device)
208 .map_err(UnifiedGpuError::OneApi),
209 UnifiedGpuBackend::Metal(backend) => backend
210 .allocate(size, MetalMemoryType::Private)
211 .map_err(UnifiedGpuError::Metal),
212 }
213 }
214
215 pub fn free(&mut self, ptr: *mut c_void) -> Result<(), UnifiedGpuError> {
217 match self {
218 UnifiedGpuBackend::Cuda(backend) => backend
219 .free(ptr, CudaMemoryType::Device)
220 .map_err(UnifiedGpuError::Cuda),
221 UnifiedGpuBackend::Rocm(backend) => backend
222 .free(ptr, RocmMemoryType::Device)
223 .map_err(UnifiedGpuError::Rocm),
224 UnifiedGpuBackend::OneApi(backend) => backend
225 .free(ptr, OneApiMemoryType::Device)
226 .map_err(UnifiedGpuError::OneApi),
227 UnifiedGpuBackend::Metal(backend) => backend
228 .free(ptr, MetalMemoryType::Private)
229 .map_err(UnifiedGpuError::Metal),
230 }
231 }
232
233 pub fn get_total_memory(&self) -> usize {
236 match self {
239 UnifiedGpuBackend::Cuda(_) => 8 * 1024 * 1024 * 1024, UnifiedGpuBackend::Rocm(_) => 8 * 1024 * 1024 * 1024, UnifiedGpuBackend::OneApi(_) => 8 * 1024 * 1024 * 1024, UnifiedGpuBackend::Metal(_) => 8 * 1024 * 1024 * 1024, }
244 }
245
246 pub fn get_memory_stats(&self) -> UnifiedMemoryStats {
247 match self {
248 UnifiedGpuBackend::Cuda(backend) => {
249 let stats = backend.get_stats();
250 UnifiedMemoryStats {
251 total_allocations: stats.total_allocations,
252 bytes_allocated: stats.bytes_allocated,
253 peak_memory_usage: stats.peak_memory_usage,
254 average_allocation_time: stats.average_allocation_time,
255 }
256 }
257 UnifiedGpuBackend::Rocm(backend) => {
258 let stats = backend.get_stats();
259 UnifiedMemoryStats {
260 total_allocations: stats.total_allocations,
261 bytes_allocated: stats.bytes_allocated,
262 peak_memory_usage: stats.peak_memory_usage,
263 average_allocation_time: stats.average_allocation_time,
264 }
265 }
266 UnifiedGpuBackend::OneApi(backend) => {
267 let stats = backend.get_stats();
268 UnifiedMemoryStats {
269 total_allocations: stats.total_allocations,
270 bytes_allocated: stats.bytes_allocated,
271 peak_memory_usage: stats.peak_memory_usage,
272 average_allocation_time: stats.average_allocation_time,
273 }
274 }
275 UnifiedGpuBackend::Metal(backend) => {
276 let stats = backend.get_stats();
277 UnifiedMemoryStats {
278 total_allocations: stats.total_allocations,
279 bytes_allocated: stats.bytes_allocated,
280 peak_memory_usage: stats.peak_memory_usage,
281 average_allocation_time: stats.average_allocation_time,
282 }
283 }
284 }
285 }
286}
287
/// Vendor-independent snapshot of a backend's memory statistics.
///
/// `PartialEq`/`Eq` are derived so two snapshots can be compared directly,
/// for example in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UnifiedMemoryStats {
    /// Allocation count as reported by the backend.
    pub total_allocations: u64,
    /// Allocated byte count as reported by the backend.
    pub bytes_allocated: u64,
    /// Peak memory usage as reported by the backend.
    pub peak_memory_usage: usize,
    /// Average allocation time as reported by the backend.
    pub average_allocation_time: Duration,
}
296
297#[derive(Debug)]
299pub enum UnifiedGpuError {
300 Cuda(CudaError),
301 Rocm(RocmError),
302 OneApi(OneApiError),
303 Metal(MetalError),
304 VendorNotSupported(String),
305 InitializationFailed(String),
306}
307
308impl std::fmt::Display for UnifiedGpuError {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 match self {
311 UnifiedGpuError::Cuda(err) => write!(f, "CUDA Error: {}", err),
312 UnifiedGpuError::Rocm(err) => write!(f, "ROCm Error: {}", err),
313 UnifiedGpuError::OneApi(err) => write!(f, "OneAPI Error: {}", err),
314 UnifiedGpuError::Metal(err) => write!(f, "Metal Error: {}", err),
315 UnifiedGpuError::VendorNotSupported(msg) => write!(f, "Vendor not supported: {}", msg),
316 UnifiedGpuError::InitializationFailed(msg) => {
317 write!(f, "Initialization failed: {}", msg)
318 }
319 }
320 }
321}
322
323impl std::error::Error for UnifiedGpuError {}
324
325impl From<CudaError> for UnifiedGpuError {
326 fn from(err: CudaError) -> Self {
327 UnifiedGpuError::Cuda(err)
328 }
329}
330
331impl From<RocmError> for UnifiedGpuError {
332 fn from(err: RocmError) -> Self {
333 UnifiedGpuError::Rocm(err)
334 }
335}
336
337impl From<OneApiError> for UnifiedGpuError {
338 fn from(err: OneApiError) -> Self {
339 UnifiedGpuError::OneApi(err)
340 }
341}
342
343impl From<MetalError> for UnifiedGpuError {
344 fn from(err: MetalError) -> Self {
345 UnifiedGpuError::Metal(err)
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// True when the build target is one of the operating systems the factory
    /// has an explicit vendor list for. On any other target the factory
    /// deliberately reports no vendors, so the expectations below flip.
    const SUPPORTED_PLATFORM: bool =
        cfg!(any(target_os = "linux", target_os = "windows", target_os = "macos"));

    /// The vendor list is non-empty exactly on the supported platforms.
    #[test]
    fn test_vendor_detection() {
        let vendors = GpuBackendFactory::detect_available_vendors();
        if SUPPORTED_PLATFORM {
            assert!(!vendors.is_empty());
        } else {
            assert!(vendors.is_empty());
        }
    }

    /// A concrete vendor is preferred on supported platforms; everywhere else
    /// the factory falls back to `Unknown`.
    #[test]
    fn test_preferred_vendor() {
        let vendor = GpuBackendFactory::get_preferred_vendor();
        if SUPPORTED_PLATFORM {
            assert_ne!(vendor, GpuVendor::Unknown);
        } else {
            assert_eq!(vendor, GpuVendor::Unknown);
        }
    }

    /// A backend can be built from the preferred vendor's default config.
    #[test]
    fn test_unified_backend_creation() {
        let vendor = GpuBackendFactory::get_preferred_vendor();
        let config = GpuBackendFactory::create_default_config(vendor);
        let backend = UnifiedGpuBackend::new(config);
        // Include the error in the failure message so a broken backend is diagnosable.
        assert!(
            backend.is_ok(),
            "backend creation failed: {:?}",
            backend.err()
        );
    }

    /// `auto_create` succeeds end to end.
    #[test]
    fn test_auto_create() {
        let backend = UnifiedGpuBackend::auto_create();
        assert!(backend.is_ok(), "auto_create failed: {:?}", backend.err());
    }
}