1#[allow(dead_code)]
8use std::collections::HashMap;
9use std::ffi::c_void;
10use std::ptr::NonNull;
11use std::sync::{Arc, Mutex};
12use std::time::{Duration, Instant};
13
/// Simulated CUDA memory backend: owns per-device state, memory pools,
/// streams, and allocation statistics.
///
/// NOTE(review): allocations are backed by the host allocator; no real CUDA
/// driver calls are made (see `CudaMemoryPool::cuda_malloc`).
pub struct CudaMemoryBackend {
    /// Backend configuration supplied at construction time.
    config: CudaConfig,
    /// Properties of the target device (simulated; see `query_device_properties`).
    device_properties: CudaDeviceProperties,
    /// Created contexts, keyed by sequential context id (see `create_context`).
    contexts: HashMap<u32, CudaContext>,
    /// One pool per memory type; populated only when `enable_memory_pools` is set.
    memory_pools: HashMap<CudaMemoryType, CudaMemoryPool>,
    /// Running allocation/transfer statistics.
    stats: CudaStats,
    /// Stream creation, queuing, and synchronization.
    stream_manager: CudaStreamManager,
}
29
/// Construction-time options for `CudaMemoryBackend`.
#[derive(Debug, Clone)]
pub struct CudaConfig {
    /// Index of the CUDA device to target.
    pub device_id: u32,
    /// Create a unified-memory pool (also requires device `managed_memory`).
    pub enable_unified_memory: bool,
    /// Route allocations through `CudaMemoryPool`s instead of direct allocation.
    pub enable_memory_pools: bool,
    /// Enable asynchronous operations.
    /// NOTE(review): not consulted anywhere in the visible code.
    pub enable_async_ops: bool,
    /// Increment (bytes) by which pools grow.
    /// NOTE(review): not consulted anywhere in the visible code.
    pub pool_growth_size: usize,
    /// Enable host-mapped memory.
    pub enable_mapped_memory: bool,
    /// Enable CUDA graph capture (off by default).
    pub enable_cuda_graphs: bool,
    /// Enable cooperative-group launches (off by default).
    pub enable_cooperative_groups: bool,
    /// Upper bound on concurrently managed streams.
    /// NOTE(review): not enforced by `CudaStreamManager` in the visible code.
    pub max_streams: u32,
}
52
53impl Default for CudaConfig {
54 fn default() -> Self {
55 Self {
56 device_id: 0,
57 enable_unified_memory: true,
58 enable_memory_pools: true,
59 enable_async_ops: true,
60 pool_growth_size: 64 * 1024 * 1024, enable_mapped_memory: true,
62 enable_cuda_graphs: false, enable_cooperative_groups: false,
64 max_streams: 16,
65 }
66 }
67}
68
/// Static properties of a CUDA device.
///
/// NOTE(review): values are hard-coded by `query_device_properties`, not
/// queried from a driver.
#[derive(Debug, Clone)]
pub struct CudaDeviceProperties {
    pub device_id: u32,
    /// Human-readable device name.
    pub name: String,
    /// (major, minor) compute capability.
    pub compute_capability: (u32, u32),
    /// Total device memory in bytes.
    pub total_global_memory: usize,
    /// Shared memory available per block, in bytes.
    pub shared_memory_per_block: usize,
    pub warp_size: u32,
    pub max_threads_per_block: u32,
    pub max_blocks_per_multiprocessor: u32,
    pub multiprocessor_count: u32,
    /// Memory clock rate (presumably kHz, matching CUDA conventions — TODO confirm).
    pub memory_clock_rate: u32,
    /// Memory bus width in bits.
    pub memory_bus_width: u32,
    /// L2 cache size in bytes.
    pub l2_cache_size: usize,
    /// Device shares a unified address space with the host.
    pub unified_addressing: bool,
    /// Device supports managed memory; gates the unified pool in `new`.
    pub managed_memory: bool,
    pub concurrent_kernels: bool,
    /// Number of async copy engines.
    pub async_engine_count: u32,
}
89
/// Kinds of memory an allocation can live in; also used as the pool key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudaMemoryType {
    /// Device-resident global memory.
    Device,
    /// Host memory (pinned in real CUDA; plain host allocation here).
    Host,
    /// Managed memory accessible from host and device.
    Unified,
    /// Host memory mapped into the device address space.
    Mapped,
    /// CUDA array (not allocatable via the pools in this file).
    Array,
    /// Texture memory (not allocatable via the pools in this file).
    Texture,
}
100
/// A created CUDA context.
pub struct CudaContext {
    /// Raw driver handle; always null in this simulation (see `create_context`).
    pub handle: *mut c_void,
    /// Device the context belongs to.
    pub device_id: u32,
    /// Scheduling/mapping flags the context was created with.
    pub flags: CudaContextFlags,
    /// Creation timestamp.
    pub created_at: Instant,
    /// Streams owned by this context.
    /// NOTE(review): never populated by the visible code.
    pub streams: Vec<CudaStream>,
}
114
/// Context creation flags mirroring the CUDA driver's `CU_CTX_SCHED_*` /
/// `CU_CTX_MAP_HOST` options (stored only; not acted on in the visible code).
#[derive(Debug, Clone)]
pub struct CudaContextFlags {
    /// Let the driver choose a scheduling strategy.
    pub sched_auto: bool,
    /// Spin-wait when synchronizing.
    pub sched_spin: bool,
    /// Yield the CPU when synchronizing.
    pub sched_yield: bool,
    /// Block the thread on a sync primitive when synchronizing.
    pub sched_blocking_sync: bool,
    /// Allow mapping host memory into the device address space.
    pub map_host: bool,
    /// Keep local memory resized to the maximum after launches.
    pub lmem_resize_to_max: bool,
}
125
126impl Default for CudaContextFlags {
127 fn default() -> Self {
128 Self {
129 sched_auto: true,
130 sched_spin: false,
131 sched_yield: false,
132 sched_blocking_sync: false,
133 map_host: false,
134 lmem_resize_to_max: false,
135 }
136 }
137}
138
/// A command stream: queued operations run in FIFO order on synchronize.
pub struct CudaStream {
    /// Raw stream handle; always null in this simulation.
    pub handle: *mut c_void,
    /// Manager-assigned sequential id.
    pub id: u32,
    /// Scheduling priority (stored only; semantics follow CUDA conventions —
    /// TODO confirm, nothing in the visible code reads it).
    pub priority: i32,
    pub flags: CudaStreamFlags,
    pub created_at: Instant,
    /// Pending operations, drained by `CudaStreamManager::synchronize_stream`.
    pub operations: std::collections::VecDeque<CudaOperation>,
}
154
155impl std::fmt::Debug for CudaStream {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("CudaStream")
158 .field("handle", &format!("{:p}", self.handle))
159 .field("id", &self.id)
160 .field("priority", &self.priority)
161 .field("flags", &self.flags)
162 .field("created_at", &self.created_at)
163 .field("operations", &self.operations)
164 .finish()
165 }
166}
167
/// Stream creation flags (stored only; not acted on in the visible code).
#[derive(Debug, Clone)]
pub struct CudaStreamFlags {
    /// Ordinary (legacy default) stream semantics.
    pub default: bool,
    /// Stream does not synchronize with the default stream.
    pub non_blocking: bool,
    /// Per-thread default-stream semantics.
    pub per_thread: bool,
}
175
176impl Default for CudaStreamFlags {
177 fn default() -> Self {
178 Self {
179 default: true,
180 non_blocking: false,
181 per_thread: false,
182 }
183 }
184}
185
/// One queued (or immediately executed) operation on a stream.
#[derive(Debug, Clone)]
pub struct CudaOperation {
    /// What the operation does; drives the simulated latency.
    pub op_type: CudaOperationType,
    /// Source pointer for copies, if any.
    pub src_ptr: Option<*mut c_void>,
    /// Destination pointer for copies/memsets, if any.
    pub dst_ptr: Option<*mut c_void>,
    /// Payload size in bytes.
    pub size: usize,
    /// When the operation was created/enqueued.
    pub timestamp: Instant,
}
195
/// Operation kinds. `CudaStreamManager::execute_operation` maps each copy
/// kind to a simulated latency; the remaining kinds complete instantly.
#[derive(Debug, Clone)]
pub enum CudaOperationType {
    MemcpyHostToDevice,
    MemcpyDeviceToHost,
    MemcpyDeviceToDevice,
    MemcpyAsync,
    MemsetAsync,
    KernelLaunch,
    EventRecord,
    EventSynchronize,
}
208
/// A first-fit memory pool for one `CudaMemoryType`.
///
/// Freed blocks are recycled through `free_blocks` (split on allocate,
/// coalesced on free); backing memory is never returned to the allocator by
/// the visible code.
pub struct CudaMemoryPool {
    /// Memory class this pool serves.
    memory_type: CudaMemoryType,
    /// Raw pool handle; stays null — nothing in the visible code assigns it.
    handle: *mut c_void,
    /// Total bytes ever obtained from the backing allocator.
    current_size: usize,
    /// Hard cap on `current_size`.
    max_size: usize,
    /// Bytes currently handed out to callers.
    used_size: usize,
    /// Recycled blocks available for reuse.
    free_blocks: std::collections::VecDeque<CudaMemoryBlock>,
    /// Live allocations keyed by their base pointer.
    allocated_blocks: HashMap<*mut c_void, CudaMemoryBlock>,
}
226
/// Bookkeeping record for one contiguous region managed by a pool.
#[derive(Debug, Clone)]
pub struct CudaMemoryBlock {
    /// Base address of the region.
    pub ptr: *mut c_void,
    /// Region length in bytes.
    pub size: usize,
    pub memory_type: CudaMemoryType,
    /// When the region was first carved out of the backing allocation.
    pub allocated_at: Instant,
    /// Last time the block was handed to a caller; `None` while free.
    pub last_access: Option<Instant>,
    /// Simple reference count: 1 while allocated, 0 while free.
    pub ref_count: u32,
}
237
238impl CudaMemoryPool {
239 pub fn new(memory_type: CudaMemoryType, max_size: usize) -> Self {
240 Self {
241 memory_type,
242 handle: std::ptr::null_mut(),
243 current_size: 0,
244 max_size,
245 used_size: 0,
246 free_blocks: std::collections::VecDeque::new(),
247 allocated_blocks: HashMap::new(),
248 }
249 }
250
251 pub fn allocate(&mut self, size: usize) -> Result<*mut c_void, CudaError> {
253 for i in 0..self.free_blocks.len() {
255 if self.free_blocks[i].size >= size {
256 let mut block = self.free_blocks.remove(i).unwrap();
257
258 if block.size > size * 2 {
260 let remaining_block = CudaMemoryBlock {
261 ptr: unsafe { block.ptr.add(size) },
262 size: block.size - size,
263 memory_type: block.memory_type.clone(),
264 allocated_at: block.allocated_at,
265 last_access: None,
266 ref_count: 0,
267 };
268 self.free_blocks.push_back(remaining_block);
269 block.size = size;
270 }
271
272 block.last_access = Some(Instant::now());
273 block.ref_count = 1;
274
275 let ptr = block.ptr;
276 self.allocated_blocks.insert(ptr, block);
277 self.used_size += size;
278
279 return Ok(ptr);
280 }
281 }
282
283 if self.current_size + size > self.max_size {
285 return Err(CudaError::OutOfMemory(
286 "Pool size limit exceeded".to_string(),
287 ));
288 }
289
290 let ptr = self.cuda_malloc(size)?;
291 let block = CudaMemoryBlock {
292 ptr,
293 size,
294 memory_type: self.memory_type.clone(),
295 allocated_at: Instant::now(),
296 last_access: Some(Instant::now()),
297 ref_count: 1,
298 };
299
300 self.allocated_blocks.insert(ptr, block);
301 self.current_size += size;
302 self.used_size += size;
303
304 Ok(ptr)
305 }
306
307 pub fn free(&mut self, ptr: *mut c_void) -> Result<(), CudaError> {
309 if let Some(block) = self.allocated_blocks.remove(&ptr) {
310 self.used_size -= block.size;
311
312 self.free_blocks.push_back(CudaMemoryBlock {
314 ptr: block.ptr,
315 size: block.size,
316 memory_type: block.memory_type,
317 allocated_at: block.allocated_at,
318 last_access: None,
319 ref_count: 0,
320 });
321
322 self.coalesce_free_blocks();
324
325 Ok(())
326 } else {
327 Err(CudaError::InvalidPointer(
328 "Pointer not found in pool".to_string(),
329 ))
330 }
331 }
332
333 fn coalesce_free_blocks(&mut self) {
334 let mut blocks: Vec<CudaMemoryBlock> = self.free_blocks.drain(..).collect();
336 blocks.sort_by_key(|block| block.ptr as usize);
337
338 let mut coalesced = Vec::new();
339 let mut current_block: Option<CudaMemoryBlock> = None;
340
341 for block in blocks {
342 match current_block.take() {
343 None => current_block = Some(block),
344 Some(mut prev_block) => {
345 let prev_end = prev_block.ptr as usize + prev_block.size;
346 let block_start = block.ptr as usize;
347
348 if prev_end == block_start {
349 prev_block.size += block.size;
351 current_block = Some(prev_block);
352 } else {
353 coalesced.push(prev_block);
354 current_block = Some(block);
355 }
356 }
357 }
358 }
359
360 if let Some(block) = current_block {
361 coalesced.push(block);
362 }
363
364 self.free_blocks = coalesced.into();
365 }
366
367 fn cuda_malloc(&self, size: usize) -> Result<*mut c_void, CudaError> {
368 match self.memory_type {
370 CudaMemoryType::Device => {
371 Ok(unsafe {
373 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
374 as *mut c_void
375 })
376 }
377 CudaMemoryType::Host => {
378 Ok(unsafe {
380 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
381 as *mut c_void
382 })
383 }
384 CudaMemoryType::Unified => {
385 Ok(unsafe {
387 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
388 as *mut c_void
389 })
390 }
391 CudaMemoryType::Mapped => {
392 Ok(unsafe {
394 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
395 as *mut c_void
396 })
397 }
398 _ => Err(CudaError::UnsupportedOperation(
399 "Unsupported memory type for allocation".to_string(),
400 )),
401 }
402 }
403}
404
/// Owns all streams and executes their queued operations.
pub struct CudaStreamManager {
    /// Active streams, in creation order.
    streams: Vec<CudaStream>,
    /// Pool of retired streams for reuse.
    /// NOTE(review): never read or written by the visible code.
    stream_pool: std::collections::VecDeque<CudaStream>,
    /// Monotonic id source for new streams.
    next_stream_id: u32,
    /// Queue limits and default priority.
    config: CudaStreamConfig,
}
416
/// Tunables for `CudaStreamManager`.
#[derive(Debug, Clone)]
pub struct CudaStreamConfig {
    /// Priority used when `create_stream` is given `None`.
    pub default_priority: i32,
    /// Whether stream priorities are honored.
    /// NOTE(review): not consulted by the visible code.
    pub enable_priorities: bool,
    /// Per-stream cap on queued operations (enforced in `add_operation`).
    pub max_operations_per_stream: usize,
}
424
425impl Default for CudaStreamConfig {
426 fn default() -> Self {
427 Self {
428 default_priority: 0,
429 enable_priorities: true,
430 max_operations_per_stream: 1000,
431 }
432 }
433}
434
435impl CudaStreamManager {
436 pub fn new(config: CudaStreamConfig) -> Self {
437 Self {
438 streams: Vec::new(),
439 stream_pool: std::collections::VecDeque::new(),
440 next_stream_id: 0,
441 config,
442 }
443 }
444
445 pub fn create_stream(&mut self, priority: Option<i32>) -> Result<u32, CudaError> {
447 let stream_id = self.next_stream_id;
448 self.next_stream_id += 1;
449
450 let stream = CudaStream {
451 handle: std::ptr::null_mut(), id: stream_id,
453 priority: priority.unwrap_or(self.config.default_priority),
454 flags: CudaStreamFlags::default(),
455 created_at: Instant::now(),
456 operations: std::collections::VecDeque::new(),
457 };
458
459 self.streams.push(stream);
460 Ok(stream_id)
461 }
462
463 pub fn destroy_stream(&mut self, stream_id: u32) -> Result<(), CudaError> {
465 if let Some(pos) = self.streams.iter().position(|s| s.id == stream_id) {
466 let stream = self.streams.remove(pos);
467 Ok(())
469 } else {
470 Err(CudaError::InvalidStream("Stream not found".to_string()))
471 }
472 }
473
474 pub fn add_operation(
476 &mut self,
477 stream_id: u32,
478 operation: CudaOperation,
479 ) -> Result<(), CudaError> {
480 if let Some(stream) = self.streams.iter_mut().find(|s| s.id == stream_id) {
481 if stream.operations.len() >= self.config.max_operations_per_stream {
482 return Err(CudaError::StreamFull(
483 "Stream operation queue is full".to_string(),
484 ));
485 }
486
487 stream.operations.push_back(operation);
488 Ok(())
489 } else {
490 Err(CudaError::InvalidStream("Stream not found".to_string()))
491 }
492 }
493
494 pub fn synchronize_stream(&mut self, stream_id: u32) -> Result<(), CudaError> {
496 let mut operations = Vec::new();
498 if let Some(stream) = self.streams.iter_mut().find(|s| s.id == stream_id) {
499 while let Some(operation) = stream.operations.pop_front() {
500 operations.push(operation);
501 }
502 } else {
503 return Err(CudaError::InvalidStream("Stream not found".to_string()));
504 }
505
506 for operation in operations {
508 self.execute_operation(operation)?;
509 }
510
511 Ok(())
512 }
513
514 fn execute_operation(&self, operation: CudaOperation) -> Result<(), CudaError> {
515 match operation.op_type {
517 CudaOperationType::MemcpyHostToDevice => {
518 std::thread::sleep(Duration::from_micros(100));
520 }
521 CudaOperationType::MemcpyDeviceToHost => {
522 std::thread::sleep(Duration::from_micros(100));
524 }
525 CudaOperationType::MemcpyDeviceToDevice => {
526 std::thread::sleep(Duration::from_micros(50));
528 }
529 CudaOperationType::MemcpyAsync => {
530 std::thread::sleep(Duration::from_micros(10));
532 }
533 _ => {
534 }
536 }
537 Ok(())
538 }
539}
540
/// Cumulative counters maintained by `CudaMemoryBackend`.
#[derive(Debug, Clone, Default)]
pub struct CudaStats {
    /// Number of successful `allocate` calls.
    pub total_allocations: u64,
    /// Number of successful `free` calls.
    pub total_deallocations: u64,
    /// Total bytes requested across all allocations.
    pub bytes_allocated: u64,
    /// Total bytes released.
    /// NOTE(review): never updated by the visible code.
    pub bytes_deallocated: u64,
    /// Device bytes handed out.
    /// NOTE(review): incremented on allocate but never decremented on free
    /// in the visible code (same for the host/unified counters below).
    pub device_memory_used: usize,
    pub host_memory_used: usize,
    pub unified_memory_used: usize,
    /// Stream operations executed.
    /// NOTE(review): never updated by the visible code.
    pub stream_operations: u64,
    /// Kernel launches.
    /// NOTE(review): never updated by the visible code.
    pub kernel_launches: u64,
    /// Completed synchronous `memcpy` calls.
    pub memory_transfers: u64,
    /// Running average over all allocations (updated in `allocate`).
    pub average_allocation_time: Duration,
    /// High-water mark of combined device+host+unified usage.
    pub peak_memory_usage: usize,
}
557
558impl CudaMemoryBackend {
559 pub fn new(config: CudaConfig) -> Result<Self, CudaError> {
561 let device_properties = Self::query_device_properties(config.device_id)?;
563
564 let mut memory_pools = HashMap::new();
566 if config.enable_memory_pools {
567 let pool_size = device_properties.total_global_memory / 4; memory_pools.insert(
569 CudaMemoryType::Device,
570 CudaMemoryPool::new(CudaMemoryType::Device, pool_size),
571 );
572 memory_pools.insert(
573 CudaMemoryType::Host,
574 CudaMemoryPool::new(CudaMemoryType::Host, pool_size),
575 );
576
577 if config.enable_unified_memory && device_properties.managed_memory {
578 memory_pools.insert(
579 CudaMemoryType::Unified,
580 CudaMemoryPool::new(CudaMemoryType::Unified, pool_size),
581 );
582 }
583 }
584
585 let stream_manager = CudaStreamManager::new(CudaStreamConfig::default());
586
587 Ok(Self {
588 config,
589 device_properties,
590 contexts: HashMap::new(),
591 memory_pools,
592 stats: CudaStats::default(),
593 stream_manager,
594 })
595 }
596
597 fn query_device_properties(device_id: u32) -> Result<CudaDeviceProperties, CudaError> {
599 Ok(CudaDeviceProperties {
601 device_id,
602 name: format!("CUDA Device {}", device_id),
603 compute_capability: (7, 5), total_global_memory: 8 * 1024 * 1024 * 1024, shared_memory_per_block: 48 * 1024, warp_size: 32,
607 max_threads_per_block: 1024,
608 max_blocks_per_multiprocessor: 16,
609 multiprocessor_count: 68,
610 memory_clock_rate: 7001000, memory_bus_width: 256,
612 l2_cache_size: 4 * 1024 * 1024, unified_addressing: true,
614 managed_memory: true,
615 concurrent_kernels: true,
616 async_engine_count: 2,
617 })
618 }
619
620 pub fn allocate(
622 &mut self,
623 size: usize,
624 memory_type: CudaMemoryType,
625 ) -> Result<*mut c_void, CudaError> {
626 let start_time = Instant::now();
627
628 let ptr = if self.config.enable_memory_pools {
629 if let Some(pool) = self.memory_pools.get_mut(&memory_type) {
630 pool.allocate(size)?
631 } else {
632 return Err(CudaError::UnsupportedMemoryType(
633 "Memory type not supported".to_string(),
634 ));
635 }
636 } else {
637 self.direct_allocate(size, memory_type.clone())?
639 };
640
641 self.stats.total_allocations += 1;
643 self.stats.bytes_allocated += size as u64;
644
645 match memory_type {
646 CudaMemoryType::Device => self.stats.device_memory_used += size,
647 CudaMemoryType::Host => self.stats.host_memory_used += size,
648 CudaMemoryType::Unified => self.stats.unified_memory_used += size,
649 _ => {}
650 }
651
652 let allocation_time = start_time.elapsed();
653 let total_time = self.stats.average_allocation_time.as_nanos() as u64
654 * (self.stats.total_allocations - 1)
655 + allocation_time.as_nanos() as u64;
656 self.stats.average_allocation_time =
657 Duration::from_nanos(total_time / self.stats.total_allocations);
658
659 let current_usage = self.stats.device_memory_used
660 + self.stats.host_memory_used
661 + self.stats.unified_memory_used;
662 if current_usage > self.stats.peak_memory_usage {
663 self.stats.peak_memory_usage = current_usage;
664 }
665
666 Ok(ptr)
667 }
668
669 fn direct_allocate(
670 &self,
671 size: usize,
672 memory_type: CudaMemoryType,
673 ) -> Result<*mut c_void, CudaError> {
674 match memory_type {
676 CudaMemoryType::Device => {
677 Ok(unsafe {
679 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
680 as *mut c_void
681 })
682 }
683 CudaMemoryType::Host => {
684 Ok(unsafe {
686 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
687 as *mut c_void
688 })
689 }
690 CudaMemoryType::Unified => {
691 if !self.device_properties.managed_memory {
693 return Err(CudaError::UnsupportedOperation(
694 "Unified memory not supported".to_string(),
695 ));
696 }
697 Ok(unsafe {
698 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
699 as *mut c_void
700 })
701 }
702 _ => Err(CudaError::UnsupportedMemoryType(
703 "Unsupported memory type".to_string(),
704 )),
705 }
706 }
707
708 pub fn free(&mut self, ptr: *mut c_void, memory_type: CudaMemoryType) -> Result<(), CudaError> {
710 if self.config.enable_memory_pools {
711 if let Some(pool) = self.memory_pools.get_mut(&memory_type) {
712 pool.free(ptr)?;
713 } else {
714 return Err(CudaError::UnsupportedMemoryType(
715 "Memory type not supported".to_string(),
716 ));
717 }
718 } else {
719 unsafe {
721 std::alloc::dealloc(
722 ptr as *mut u8,
723 std::alloc::Layout::from_size_align_unchecked(1, 1),
724 );
725 }
726 }
727
728 self.stats.total_deallocations += 1;
729 Ok(())
730 }
731
732 pub fn memcpy(
734 &mut self,
735 dst: *mut c_void,
736 src: *const c_void,
737 size: usize,
738 kind: CudaMemcpyKind,
739 ) -> Result<(), CudaError> {
740 let operation = CudaOperation {
741 op_type: match kind {
742 CudaMemcpyKind::HostToDevice => CudaOperationType::MemcpyHostToDevice,
743 CudaMemcpyKind::DeviceToHost => CudaOperationType::MemcpyDeviceToHost,
744 CudaMemcpyKind::DeviceToDevice => CudaOperationType::MemcpyDeviceToDevice,
745 CudaMemcpyKind::HostToHost => CudaOperationType::MemcpyAsync,
746 },
747 src_ptr: Some(src as *mut c_void),
748 dst_ptr: Some(dst),
749 size,
750 timestamp: Instant::now(),
751 };
752
753 self.stream_manager.execute_operation(operation)?;
755 self.stats.memory_transfers += 1;
756
757 Ok(())
758 }
759
760 pub fn memcpy_async(
762 &mut self,
763 dst: *mut c_void,
764 src: *const c_void,
765 size: usize,
766 kind: CudaMemcpyKind,
767 stream_id: u32,
768 ) -> Result<(), CudaError> {
769 let operation = CudaOperation {
770 op_type: CudaOperationType::MemcpyAsync,
771 src_ptr: Some(src as *mut c_void),
772 dst_ptr: Some(dst),
773 size,
774 timestamp: Instant::now(),
775 };
776
777 self.stream_manager.add_operation(stream_id, operation)?;
778 Ok(())
779 }
780
781 pub fn create_context(&mut self, flags: CudaContextFlags) -> Result<u32, CudaError> {
783 let context_id = self.contexts.len() as u32;
784
785 let context = CudaContext {
786 handle: std::ptr::null_mut(), device_id: self.config.device_id,
788 flags,
789 created_at: Instant::now(),
790 streams: Vec::new(),
791 };
792
793 self.contexts.insert(context_id, context);
794 Ok(context_id)
795 }
796
797 pub fn get_device_properties(&self) -> &CudaDeviceProperties {
799 &self.device_properties
800 }
801
802 pub fn get_stats(&self) -> &CudaStats {
804 &self.stats
805 }
806
807 pub fn device_synchronize(&mut self) -> Result<(), CudaError> {
809 let stream_ids: Vec<u32> = self.stream_manager.streams.iter().map(|s| s.id).collect();
811 for stream_id in stream_ids {
812 self.stream_manager.synchronize_stream(stream_id)?;
813 }
814 Ok(())
815 }
816
817 pub fn create_stream(&mut self, priority: Option<i32>) -> Result<u32, CudaError> {
819 self.stream_manager.create_stream(priority)
820 }
821
822 pub fn destroy_stream(&mut self, stream_id: u32) -> Result<(), CudaError> {
824 self.stream_manager.destroy_stream(stream_id)
825 }
826}
827
// SAFETY: `CudaMemoryBackend` contains raw pointers (`*mut c_void` handles
// and block pointers), which suppress the auto `Send`/`Sync` impls. These
// impls assert that cross-thread access is externally synchronized (e.g. via
// `ThreadSafeCudaBackend`'s `Mutex`).
// NOTE(review): `Sync` additionally claims `&CudaMemoryBackend` is safe to
// share without a lock — unguarded concurrent use would be unsound; confirm
// all callers go through a lock.
unsafe impl Send for CudaMemoryBackend {}
unsafe impl Sync for CudaMemoryBackend {}
836
/// Direction of a memory copy, mirroring CUDA's `cudaMemcpyKind`.
#[derive(Debug, Clone)]
pub enum CudaMemcpyKind {
    HostToDevice,
    DeviceToHost,
    DeviceToDevice,
    HostToHost,
}
845
/// Error type for all backend operations.
///
/// Every variant carries a human-readable detail string that `Display`
/// renders as "<category>: <detail>".
#[derive(Debug, Clone)]
pub enum CudaError {
    DeviceNotFound(String),
    OutOfMemory(String),
    InvalidPointer(String),
    InvalidStream(String),
    StreamFull(String),
    UnsupportedOperation(String),
    UnsupportedMemoryType(String),
    ContextCreationFailed(String),
    KernelLaunchFailed(String),
    SynchronizationFailed(String),
    InternalError(String),
}
861
862impl std::fmt::Display for CudaError {
863 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
864 match self {
865 CudaError::DeviceNotFound(msg) => write!(f, "Device not found: {}", msg),
866 CudaError::OutOfMemory(msg) => write!(f, "Out of memory: {}", msg),
867 CudaError::InvalidPointer(msg) => write!(f, "Invalid pointer: {}", msg),
868 CudaError::InvalidStream(msg) => write!(f, "Invalid stream: {}", msg),
869 CudaError::StreamFull(msg) => write!(f, "Stream full: {}", msg),
870 CudaError::UnsupportedOperation(msg) => write!(f, "Unsupported operation: {}", msg),
871 CudaError::UnsupportedMemoryType(msg) => write!(f, "Unsupported memory type: {}", msg),
872 CudaError::ContextCreationFailed(msg) => write!(f, "Context creation failed: {}", msg),
873 CudaError::KernelLaunchFailed(msg) => write!(f, "Kernel launch failed: {}", msg),
874 CudaError::SynchronizationFailed(msg) => write!(f, "Synchronization failed: {}", msg),
875 CudaError::InternalError(msg) => write!(f, "Internal error: {}", msg),
876 }
877 }
878}
879
880impl std::error::Error for CudaError {}
881
/// Thread-safe wrapper that serializes all access to a `CudaMemoryBackend`
/// behind an `Arc<Mutex<_>>`.
pub struct ThreadSafeCudaBackend {
    /// Shared, lock-guarded backend instance.
    backend: Arc<Mutex<CudaMemoryBackend>>,
}
886
887impl ThreadSafeCudaBackend {
888 pub fn new(config: CudaConfig) -> Result<Self, CudaError> {
889 let backend = CudaMemoryBackend::new(config)?;
890 Ok(Self {
891 backend: Arc::new(Mutex::new(backend)),
892 })
893 }
894
895 pub fn allocate(
896 &self,
897 size: usize,
898 memory_type: CudaMemoryType,
899 ) -> Result<*mut c_void, CudaError> {
900 let mut backend = self.backend.lock().unwrap();
901 backend.allocate(size, memory_type)
902 }
903
904 pub fn free(&self, ptr: *mut c_void, memory_type: CudaMemoryType) -> Result<(), CudaError> {
905 let mut backend = self.backend.lock().unwrap();
906 backend.free(ptr, memory_type)
907 }
908
909 pub fn get_stats(&self) -> CudaStats {
910 let backend = self.backend.lock().unwrap();
911 backend.get_stats().clone()
912 }
913}
914
#[cfg(test)]
mod tests {
    use super::*;

    /// The backend should construct cleanly from the default config.
    #[test]
    fn test_cuda_backend_creation() {
        assert!(CudaMemoryBackend::new(CudaConfig::default()).is_ok());
    }

    /// A pooled allocation within the cap succeeds and can be freed.
    #[test]
    fn test_memory_pool() {
        let mut pool = CudaMemoryPool::new(CudaMemoryType::Device, 1024 * 1024);
        let ptr = pool
            .allocate(1024)
            .expect("allocation within pool cap should succeed");
        pool.free(ptr)
            .expect("freeing a pooled pointer should succeed");
    }

    /// Streams can be created and destroyed round-trip.
    #[test]
    fn test_stream_manager() {
        let mut manager = CudaStreamManager::new(CudaStreamConfig::default());
        let stream_id = manager
            .create_stream(Some(1))
            .expect("stream creation should succeed");
        manager
            .destroy_stream(stream_id)
            .expect("destroying an existing stream should succeed");
    }

    /// The thread-safe wrapper builds and starts with zeroed stats.
    #[test]
    fn test_thread_safe_backend() {
        let backend =
            ThreadSafeCudaBackend::new(CudaConfig::default()).expect("backend should build");
        assert_eq!(backend.get_stats().total_allocations, 0);
    }
}