1#[allow(dead_code)]
8use std::collections::HashMap;
9use std::ffi::c_void;
10use std::ptr::NonNull;
11use std::sync::{Arc, Mutex};
12use std::time::{Duration, Instant};
13
/// Simulated ROCm/HIP memory backend.
///
/// Owns the per-memory-type allocation pools, the HIP stream manager, the
/// contexts created through it, and running allocation statistics. In this
/// file every allocation path is served by the host allocator (`std::alloc`);
/// no GPU runtime is called.
pub struct RocmMemoryBackend {
    /// Configuration supplied to `new`.
    config: RocmConfig,
    /// Device description produced by `query_device_properties` at construction.
    device_properties: RocmDeviceProperties,
    /// Contexts created via `create_context`, keyed by context id.
    contexts: HashMap<u32, HipContext>,
    /// One pool per enabled memory type; only populated when
    /// `config.enable_memory_pools` is set.
    memory_pools: HashMap<RocmMemoryType, RocmMemoryPool>,
    /// Running allocation/transfer statistics.
    stats: RocmStats,
    /// Owns the streams and their queued operations.
    stream_manager: HipStreamManager,
}
29
/// Construction-time options for `RocmMemoryBackend`.
///
/// NOTE(review): `enable_async_ops`, `pool_growth_size`,
/// `enable_device_coherent`, `max_streams` and `enable_profiling` are not read
/// anywhere else in this file — confirm whether they are meant to be wired up.
#[derive(Debug, Clone)]
pub struct RocmConfig {
    /// Index of the device the backend targets.
    pub device_id: u32,
    /// Create a coarse-grained pool (only when pools are enabled).
    pub enable_coarse_memory: bool,
    /// Create a fine-grained pool (only when pools are enabled).
    pub enable_fine_memory: bool,
    /// Serve allocations from per-type pools instead of allocating directly.
    pub enable_memory_pools: bool,
    /// Asynchronous-operation flag.
    pub enable_async_ops: bool,
    /// Pool growth step, in bytes.
    pub pool_growth_size: usize,
    /// Create a host-visible pool (only when pools are enabled).
    pub enable_host_visible: bool,
    /// Device-coherent memory flag.
    pub enable_device_coherent: bool,
    /// Stream-count limit.
    pub max_streams: u32,
    /// Profiling flag.
    pub enable_profiling: bool,
}

impl Default for RocmConfig {
    /// Device 0, every pool kind enabled except device-coherent memory,
    /// a 64 MiB growth step, 16 streams, and profiling off.
    fn default() -> Self {
        const MIB: usize = 1 << 20;

        RocmConfig {
            device_id: 0,
            enable_memory_pools: true,
            enable_coarse_memory: true,
            enable_fine_memory: true,
            enable_host_visible: true,
            enable_device_coherent: false,
            enable_async_ops: true,
            enable_profiling: false,
            pool_growth_size: 64 * MIB,
            max_streams: 16,
        }
    }
}
71
/// Static device description returned by `RocmMemoryBackend::get_device_properties`.
///
/// In this file the values come from `query_device_properties`, which returns
/// hard-coded figures rather than querying a runtime.
/// NOTE(review): units of the clock/bus/cache fields are not established here —
/// confirm against HIP's `hipDeviceProp_t` before relying on them.
#[derive(Debug, Clone)]
pub struct RocmDeviceProperties {
    /// Device index this description belongs to.
    pub device_id: u32,
    /// Display name (e.g. `"AMD GPU 0"`).
    pub name: String,
    /// Architecture identifier (e.g. `"gfx906"`).
    pub arch: String,
    /// Architecture family name (e.g. `"Vega20"`).
    pub gcn_arch_name: String,
    /// Total device memory in bytes; the backend derives its pool caps from it.
    pub total_global_memory: usize,
    pub local_memory_size: usize,
    pub max_work_group_size: u32,
    pub max_work_item_dimensions: u32,
    /// Per-dimension work-item limits (one entry per dimension).
    pub max_work_item_sizes: [u32; 3],
    pub compute_units: u32,
    pub wavefront_size: u32,
    pub memory_clock_frequency: u32,
    pub memory_bus_width: u32,
    pub l2_cache_size: usize,
    pub max_constant_buffer_size: usize,
    pub pci_bus_id: u32,
    pub pci_device_id: u32,
    pub supports_cooperative_launch: bool,
    pub supports_dynamic_parallelism: bool,
}
95
/// Kind of memory an allocation is served from; also the key of the
/// backend's per-type pool map.
///
/// The enum is fieldless, so it derives `Copy`: values can be passed and
/// stored by value without explicit `.clone()` calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RocmMemoryType {
    /// Device memory.
    Device,
    /// Host memory.
    Host,
    /// Memory whose pool is marked accessible from both host and device.
    HostVisible,
    /// Device-coherent memory; no allocation path in this file accepts it.
    DeviceCoherent,
    /// Coarse-grained memory.
    CoarseGrained,
    /// Fine-grained memory.
    FineGrained,
}
106
/// A HIP context as created by `RocmMemoryBackend::create_context`.
pub struct HipContext {
    /// Native context handle; `create_context` stores a null placeholder.
    pub handle: *mut c_void,
    /// Device the context was created for.
    pub device_id: u32,
    /// Flags supplied by the caller of `create_context`.
    pub flags: HipContextFlags,
    /// Creation time.
    pub created_at: Instant,
    /// Streams attached to the context (`create_context` starts it empty).
    pub streams: Vec<HipStream>,
    /// Memory-usage snapshot taken from the backend statistics at creation time.
    pub memory_info: HipMemoryInfo,
}
122
/// Scheduling and host-mapping flags attached to a `HipContext`.
#[derive(Debug, Clone)]
pub struct HipContextFlags {
    pub sched_auto: bool,
    pub sched_spin: bool,
    pub sched_yield: bool,
    pub sched_blocking_sync: bool,
    pub map_host: bool,
}

impl Default for HipContextFlags {
    /// Automatic scheduling only; every other flag is cleared.
    fn default() -> Self {
        HipContextFlags {
            map_host: false,
            sched_blocking_sync: false,
            sched_yield: false,
            sched_spin: false,
            sched_auto: true,
        }
    }
}
144
/// Memory-usage snapshot stored in a `HipContext`.
///
/// `create_context` fills it from the device properties and the backend's
/// running statistics at the moment the context is created.
#[derive(Debug, Clone)]
pub struct HipMemoryInfo {
    /// Total device memory (copied from `total_global_memory`).
    pub total_memory: usize,
    /// Total memory minus the device-memory usage counter.
    pub free_memory: usize,
    /// Device-memory usage counter.
    pub used_memory: usize,
    /// Coarse-grained usage counter.
    pub coarse_memory: usize,
    /// Fine-grained usage counter.
    pub fine_memory: usize,
}
154
/// A stream owned by `HipStreamManager`, holding a FIFO of queued operations.
pub struct HipStream {
    /// Native stream handle; `create_stream` stores a null placeholder.
    pub handle: *mut c_void,
    /// Manager-assigned id used to look the stream up.
    pub id: u32,
    /// Priority given at creation (or the manager's default priority).
    pub priority: i32,
    /// Stream flags (`HipStreamFlags::default()` at creation).
    pub flags: HipStreamFlags,
    /// Creation time.
    pub created_at: Instant,
    /// Pending operations, executed front-to-back by `synchronize_stream`.
    pub operations: std::collections::VecDeque<HipOperation>,
}
170
/// Creation flags recorded on a `HipStream`.
#[derive(Debug, Clone)]
pub struct HipStreamFlags {
    pub default: bool,
    pub non_blocking: bool,
    pub per_thread: bool,
}

impl Default for HipStreamFlags {
    /// Marks the stream as a default stream; the other flags are cleared.
    fn default() -> Self {
        HipStreamFlags {
            per_thread: false,
            non_blocking: false,
            default: true,
        }
    }
}
188
/// One queued or immediately-executed stream operation.
#[derive(Debug, Clone)]
pub struct HipOperation {
    /// What the operation does; drives the simulated latency in `execute_operation`.
    pub op_type: HipOperationType,
    /// Source pointer, if the operation reads memory (`memcpy` sets it).
    pub src_ptr: Option<*mut c_void>,
    /// Destination pointer, if the operation writes memory (`memcpy` sets it).
    pub dst_ptr: Option<*mut c_void>,
    /// Size argument passed by the caller (byte count for copies — presumably; confirm for other kinds).
    pub size: usize,
    /// Time the operation record was created.
    pub timestamp: Instant,
}
198
/// Kind of a `HipOperation`.
///
/// Fieldless, so it derives `Copy`, `PartialEq` and `Eq`: operation kinds can be
/// passed by value and compared directly (e.g. in tests or dispatch tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HipOperationType {
    /// Host-to-device copy.
    MemcpyHostToDevice,
    /// Device-to-host copy.
    MemcpyDeviceToHost,
    /// Device-to-device copy.
    MemcpyDeviceToDevice,
    /// Copy queued on a stream (also used for host-to-host `memcpy`).
    MemcpyAsync,
    MemsetAsync,
    KernelLaunch,
    EventRecord,
    EventSynchronize,
}
211
/// Free-list sub-allocator serving one `RocmMemoryType`.
///
/// NOTE(review): no `Drop` impl for this type is visible in this file, so the
/// chunks obtained by `hip_malloc` appear never to be returned to the
/// allocator — confirm whether that leak is intended.
pub struct RocmMemoryPool {
    /// Memory type of every block in this pool.
    memory_type: RocmMemoryType,
    /// Native pool handle; `new` sets it to null and nothing else in this file reads it.
    handle: *mut c_void,
    /// Bytes obtained from the allocator so far; never allowed to exceed `max_size`.
    current_size: usize,
    /// Upper bound for `current_size`.
    max_size: usize,
    /// Bytes charged to live allocations.
    used_size: usize,
    /// Blocks available for reuse; re-sorted and coalesced by address on every `free`.
    free_blocks: std::collections::VecDeque<RocmMemoryBlock>,
    /// Live allocations, keyed by the pointer returned to the caller.
    allocated_blocks: HashMap<*mut c_void, RocmMemoryBlock>,
    /// Access/coherence attributes chosen from `memory_type` in `new`.
    attributes: RocmMemoryAttributes,
}
231
/// A contiguous region tracked by a `RocmMemoryPool`, either live or free.
#[derive(Debug, Clone)]
pub struct RocmMemoryBlock {
    /// Start address of the region.
    pub ptr: *mut c_void,
    /// Length of the region in bytes.
    pub size: usize,
    /// Memory type of the owning pool.
    pub memory_type: RocmMemoryType,
    /// When the underlying chunk was obtained; split-off remainders inherit it.
    pub allocated_at: Instant,
    /// Last time the block was handed out; `None` while it sits on the free list.
    pub last_access: Option<Instant>,
    /// 1 while allocated, 0 while free.
    pub ref_count: u32,
    /// Copied from the pool's `is_device_accessible` attribute.
    pub agent_accessible: bool,
}
243
/// Access and coherence properties of a pool's memory.
#[derive(Debug, Clone)]
pub struct RocmMemoryAttributes {
    pub is_coarse_grained: bool,
    pub is_fine_grained: bool,
    pub is_host_accessible: bool,
    pub is_device_accessible: bool,
    pub is_coherent: bool,
    /// NUMA node the memory is bound to, if any.
    pub numa_node: Option<u32>,
}

impl Default for RocmMemoryAttributes {
    /// Coarse-grained, device-only, non-coherent memory with no NUMA binding.
    fn default() -> Self {
        RocmMemoryAttributes {
            numa_node: None,
            is_coherent: false,
            is_device_accessible: true,
            is_host_accessible: false,
            is_fine_grained: false,
            is_coarse_grained: true,
        }
    }
}
267
268impl RocmMemoryPool {
269 pub fn new(memory_type: RocmMemoryType, max_size: usize) -> Self {
270 let attributes = match memory_type {
271 RocmMemoryType::CoarseGrained => RocmMemoryAttributes {
272 is_coarse_grained: true,
273 is_fine_grained: false,
274 is_host_accessible: false,
275 is_device_accessible: true,
276 is_coherent: false,
277 numa_node: None,
278 },
279 RocmMemoryType::FineGrained => RocmMemoryAttributes {
280 is_coarse_grained: false,
281 is_fine_grained: true,
282 is_host_accessible: true,
283 is_device_accessible: true,
284 is_coherent: true,
285 numa_node: Some(0),
286 },
287 RocmMemoryType::HostVisible => RocmMemoryAttributes {
288 is_coarse_grained: false,
289 is_fine_grained: false,
290 is_host_accessible: true,
291 is_device_accessible: true,
292 is_coherent: false,
293 numa_node: None,
294 },
295 _ => RocmMemoryAttributes::default(),
296 };
297
298 Self {
299 memory_type,
300 handle: std::ptr::null_mut(),
301 current_size: 0,
302 max_size,
303 used_size: 0,
304 free_blocks: std::collections::VecDeque::new(),
305 allocated_blocks: HashMap::new(),
306 attributes,
307 }
308 }
309
310 pub fn allocate(&mut self, size: usize) -> Result<*mut c_void, RocmError> {
312 for i in 0..self.free_blocks.len() {
314 if self.free_blocks[i].size >= size {
315 let mut block = self.free_blocks.remove(i).expect("unwrap failed");
316
317 if block.size > size * 2 {
319 let remaining_block = RocmMemoryBlock {
320 ptr: unsafe { block.ptr.add(size) },
321 size: block.size - size,
322 memory_type: block.memory_type.clone(),
323 allocated_at: block.allocated_at,
324 last_access: None,
325 ref_count: 0,
326 agent_accessible: block.agent_accessible,
327 };
328 self.free_blocks.push_back(remaining_block);
329 block.size = size;
330 }
331
332 block.last_access = Some(Instant::now());
333 block.ref_count = 1;
334
335 let ptr = block.ptr;
336 self.allocated_blocks.insert(ptr, block);
337 self.used_size += size;
338
339 return Ok(ptr);
340 }
341 }
342
343 if self.current_size + size > self.max_size {
345 return Err(RocmError::OutOfMemory(
346 "Pool size limit exceeded".to_string(),
347 ));
348 }
349
350 let ptr = self.hip_malloc(size)?;
351 let block = RocmMemoryBlock {
352 ptr,
353 size,
354 memory_type: self.memory_type.clone(),
355 allocated_at: Instant::now(),
356 last_access: Some(Instant::now()),
357 ref_count: 1,
358 agent_accessible: self.attributes.is_device_accessible,
359 };
360
361 self.allocated_blocks.insert(ptr, block);
362 self.current_size += size;
363 self.used_size += size;
364
365 Ok(ptr)
366 }
367
368 pub fn free(&mut self, ptr: *mut c_void) -> Result<(), RocmError> {
370 if let Some(block) = self.allocated_blocks.remove(&ptr) {
371 self.used_size -= block.size;
372
373 self.free_blocks.push_back(RocmMemoryBlock {
375 ptr: block.ptr,
376 size: block.size,
377 memory_type: block.memory_type,
378 allocated_at: block.allocated_at,
379 last_access: None,
380 ref_count: 0,
381 agent_accessible: block.agent_accessible,
382 });
383
384 self.coalesce_free_blocks();
386
387 Ok(())
388 } else {
389 Err(RocmError::InvalidPointer(
390 "Pointer not found in pool".to_string(),
391 ))
392 }
393 }
394
395 fn coalesce_free_blocks(&mut self) {
396 let mut blocks: Vec<RocmMemoryBlock> = self.free_blocks.drain(..).collect();
398 blocks.sort_by_key(|block| block.ptr as usize);
399
400 let mut coalesced = Vec::new();
401 let mut current_block: Option<RocmMemoryBlock> = None;
402
403 for block in blocks {
404 match current_block.take() {
405 None => current_block = Some(block),
406 Some(mut prev_block) => {
407 let prev_end = prev_block.ptr as usize + prev_block.size;
408 let block_start = block.ptr as usize;
409
410 if prev_end == block_start && prev_block.memory_type == block.memory_type {
411 prev_block.size += block.size;
413 current_block = Some(prev_block);
414 } else {
415 coalesced.push(prev_block);
416 current_block = Some(block);
417 }
418 }
419 }
420 }
421
422 if let Some(block) = current_block {
423 coalesced.push(block);
424 }
425
426 self.free_blocks = coalesced.into();
427 }
428
429 fn hip_malloc(&self, size: usize) -> Result<*mut c_void, RocmError> {
430 match self.memory_type {
432 RocmMemoryType::Device => {
433 Ok(unsafe {
435 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
436 as *mut c_void
437 })
438 }
439 RocmMemoryType::Host => {
440 Ok(unsafe {
442 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
443 as *mut c_void
444 })
445 }
446 RocmMemoryType::CoarseGrained => {
447 Ok(unsafe {
449 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
450 as *mut c_void
451 })
452 }
453 RocmMemoryType::FineGrained => {
454 Ok(unsafe {
456 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
457 as *mut c_void
458 })
459 }
460 RocmMemoryType::HostVisible => {
461 Ok(unsafe {
463 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
464 as *mut c_void
465 })
466 }
467 _ => Err(RocmError::UnsupportedOperation(
468 "Unsupported memory type for allocation".to_string(),
469 )),
470 }
471 }
472}
473
/// Owns the backend's streams and their queued operations.
pub struct HipStreamManager {
    /// Live streams, searched linearly by id.
    streams: Vec<HipStream>,
    /// Initialised empty by `new`.
    /// NOTE(review): not otherwise referenced in this file — confirm intended use.
    stream_pool: std::collections::VecDeque<HipStream>,
    /// Id given to the next stream; incremented on every `create_stream`.
    next_stream_id: u32,
    /// Default priority and per-stream queue limit.
    config: HipStreamConfig,
}
485
/// Tuning knobs for `HipStreamManager`.
#[derive(Debug, Clone)]
pub struct HipStreamConfig {
    /// Priority given to streams created without an explicit priority.
    pub default_priority: i32,
    /// Priority-support flag.
    /// NOTE(review): not read by `HipStreamManager` in this file.
    pub enable_priorities: bool,
    /// Queue length at which `add_operation` starts rejecting operations.
    pub max_operations_per_stream: usize,
}

impl Default for HipStreamConfig {
    /// Priority 0, priorities enabled, at most 1000 queued operations per stream.
    fn default() -> Self {
        let max_operations_per_stream = 1000;
        HipStreamConfig {
            max_operations_per_stream,
            enable_priorities: true,
            default_priority: 0,
        }
    }
}
503
504impl HipStreamManager {
505 pub fn new(config: HipStreamConfig) -> Self {
506 Self {
507 streams: Vec::new(),
508 stream_pool: std::collections::VecDeque::new(),
509 next_stream_id: 0,
510 config,
511 }
512 }
513
514 pub fn create_stream(&mut self, priority: Option<i32>) -> Result<u32, RocmError> {
516 let stream_id = self.next_stream_id;
517 self.next_stream_id += 1;
518
519 let stream = HipStream {
520 handle: std::ptr::null_mut(), id: stream_id,
522 priority: priority.unwrap_or(self.config.default_priority),
523 flags: HipStreamFlags::default(),
524 created_at: Instant::now(),
525 operations: std::collections::VecDeque::new(),
526 };
527
528 self.streams.push(stream);
529 Ok(stream_id)
530 }
531
532 pub fn destroy_stream(&mut self, stream_id: u32) -> Result<(), RocmError> {
534 if let Some(pos) = self.streams.iter().position(|s| s.id == stream_id) {
535 let stream = self.streams.remove(pos);
536 Ok(())
538 } else {
539 Err(RocmError::InvalidStream("Stream not found".to_string()))
540 }
541 }
542
543 pub fn add_operation(
545 &mut self,
546 stream_id: u32,
547 operation: HipOperation,
548 ) -> Result<(), RocmError> {
549 if let Some(stream) = self.streams.iter_mut().find(|s| s.id == stream_id) {
550 if stream.operations.len() >= self.config.max_operations_per_stream {
551 return Err(RocmError::StreamFull(
552 "Stream operation queue is full".to_string(),
553 ));
554 }
555
556 stream.operations.push_back(operation);
557 Ok(())
558 } else {
559 Err(RocmError::InvalidStream("Stream not found".to_string()))
560 }
561 }
562
563 pub fn synchronize_stream(&mut self, stream_id: u32) -> Result<(), RocmError> {
565 let mut operations = Vec::new();
567 if let Some(stream) = self.streams.iter_mut().find(|s| s.id == stream_id) {
568 while let Some(operation) = stream.operations.pop_front() {
569 operations.push(operation);
570 }
571 } else {
572 return Err(RocmError::InvalidStream("Stream not found".to_string()));
573 }
574
575 for operation in operations {
577 self.execute_operation(operation)?;
578 }
579
580 Ok(())
581 }
582
583 fn execute_operation(&self, operation: HipOperation) -> Result<(), RocmError> {
584 match operation.op_type {
586 HipOperationType::MemcpyHostToDevice => {
587 std::thread::sleep(Duration::from_micros(120));
589 }
590 HipOperationType::MemcpyDeviceToHost => {
591 std::thread::sleep(Duration::from_micros(120));
593 }
594 HipOperationType::MemcpyDeviceToDevice => {
595 std::thread::sleep(Duration::from_micros(60));
597 }
598 HipOperationType::MemcpyAsync => {
599 std::thread::sleep(Duration::from_micros(15));
601 }
602 _ => {
603 }
605 }
606 Ok(())
607 }
608}
609
/// Running counters maintained by `RocmMemoryBackend`.
#[derive(Debug, Clone, Default)]
pub struct RocmStats {
    /// Successful `allocate` calls.
    pub total_allocations: u64,
    /// Successful `free` calls.
    pub total_deallocations: u64,
    /// Cumulative bytes recorded by `allocate`.
    pub bytes_allocated: u64,
    /// Cumulative bytes recorded by `free`.
    pub bytes_deallocated: u64,
    /// Usage counter for `RocmMemoryType::Device`.
    pub device_memory_used: usize,
    /// Usage counter for `RocmMemoryType::Host`.
    pub host_memory_used: usize,
    /// Usage counter for `RocmMemoryType::CoarseGrained`.
    pub coarse_grained_used: usize,
    /// Usage counter for `RocmMemoryType::FineGrained`.
    pub fine_grained_used: usize,
    /// NOTE(review): not updated anywhere in this file.
    pub stream_operations: u64,
    /// NOTE(review): not updated anywhere in this file.
    pub kernel_launches: u64,
    /// Incremented by the synchronous `memcpy` path.
    pub memory_transfers: u64,
    /// Running mean of the time spent inside `allocate`.
    pub average_allocation_time: Duration,
    /// Highest observed sum of the four usage counters above.
    pub peak_memory_usage: usize,
}
627
628impl RocmMemoryBackend {
629 pub fn new(config: RocmConfig) -> Result<Self, RocmError> {
631 let device_properties = Self::query_device_properties(config.device_id)?;
633
634 let mut memory_pools = HashMap::new();
636 if config.enable_memory_pools {
637 let pool_size = device_properties.total_global_memory / 4; memory_pools.insert(
640 RocmMemoryType::Device,
641 RocmMemoryPool::new(RocmMemoryType::Device, pool_size),
642 );
643 memory_pools.insert(
644 RocmMemoryType::Host,
645 RocmMemoryPool::new(RocmMemoryType::Host, pool_size),
646 );
647
648 if config.enable_coarse_memory {
649 memory_pools.insert(
650 RocmMemoryType::CoarseGrained,
651 RocmMemoryPool::new(RocmMemoryType::CoarseGrained, pool_size),
652 );
653 }
654
655 if config.enable_fine_memory {
656 memory_pools.insert(
657 RocmMemoryType::FineGrained,
658 RocmMemoryPool::new(RocmMemoryType::FineGrained, pool_size / 2),
659 );
660 }
661
662 if config.enable_host_visible {
663 memory_pools.insert(
664 RocmMemoryType::HostVisible,
665 RocmMemoryPool::new(RocmMemoryType::HostVisible, pool_size / 4),
666 );
667 }
668 }
669
670 let stream_manager = HipStreamManager::new(HipStreamConfig::default());
671
672 Ok(Self {
673 config,
674 device_properties,
675 contexts: HashMap::new(),
676 memory_pools,
677 stats: RocmStats::default(),
678 stream_manager,
679 })
680 }
681
682 fn query_device_properties(device_id: u32) -> Result<RocmDeviceProperties, RocmError> {
684 Ok(RocmDeviceProperties {
686 device_id,
687 name: format!("AMD GPU {}", device_id),
688 arch: "gfx906".to_string(), gcn_arch_name: "Vega20".to_string(),
690 total_global_memory: 16 * 1024 * 1024 * 1024, local_memory_size: 64 * 1024, max_work_group_size: 1024,
693 max_work_item_dimensions: 3,
694 max_work_item_sizes: [1024, 1024, 1024],
695 compute_units: 64,
696 wavefront_size: 64,
697 memory_clock_frequency: 1000000, memory_bus_width: 4096,
699 l2_cache_size: 4 * 1024 * 1024, max_constant_buffer_size: 64 * 1024, pci_bus_id: 0x03,
702 pci_device_id: 0x66AF,
703 supports_cooperative_launch: true,
704 supports_dynamic_parallelism: false,
705 })
706 }
707
708 pub fn allocate(
710 &mut self,
711 size: usize,
712 memory_type: RocmMemoryType,
713 ) -> Result<*mut c_void, RocmError> {
714 let start_time = Instant::now();
715
716 let ptr = if self.config.enable_memory_pools {
717 if let Some(pool) = self.memory_pools.get_mut(&memory_type) {
718 pool.allocate(size)?
719 } else {
720 return Err(RocmError::UnsupportedMemoryType(
721 "Memory type not supported".to_string(),
722 ));
723 }
724 } else {
725 self.direct_allocate(size, memory_type.clone())?
727 };
728
729 self.stats.total_allocations += 1;
731 self.stats.bytes_allocated += size as u64;
732
733 match memory_type {
734 RocmMemoryType::Device => self.stats.device_memory_used += size,
735 RocmMemoryType::Host => self.stats.host_memory_used += size,
736 RocmMemoryType::CoarseGrained => self.stats.coarse_grained_used += size,
737 RocmMemoryType::FineGrained => self.stats.fine_grained_used += size,
738 _ => {}
739 }
740
741 let allocation_time = start_time.elapsed();
742 let total_time = self.stats.average_allocation_time.as_nanos() as u64
743 * (self.stats.total_allocations - 1)
744 + allocation_time.as_nanos() as u64;
745 self.stats.average_allocation_time =
746 Duration::from_nanos(total_time / self.stats.total_allocations);
747
748 let current_usage = self.stats.device_memory_used
749 + self.stats.host_memory_used
750 + self.stats.coarse_grained_used
751 + self.stats.fine_grained_used;
752 if current_usage > self.stats.peak_memory_usage {
753 self.stats.peak_memory_usage = current_usage;
754 }
755
756 Ok(ptr)
757 }
758
759 fn direct_allocate(
760 &self,
761 size: usize,
762 memory_type: RocmMemoryType,
763 ) -> Result<*mut c_void, RocmError> {
764 match memory_type {
766 RocmMemoryType::Device => {
767 Ok(unsafe {
769 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
770 as *mut c_void
771 })
772 }
773 RocmMemoryType::Host => {
774 Ok(unsafe {
776 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
777 as *mut c_void
778 })
779 }
780 RocmMemoryType::CoarseGrained => {
781 Ok(unsafe {
783 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
784 as *mut c_void
785 })
786 }
787 RocmMemoryType::FineGrained => {
788 Ok(unsafe {
790 std::alloc::alloc(std::alloc::Layout::from_size_align_unchecked(size, 256))
791 as *mut c_void
792 })
793 }
794 _ => Err(RocmError::UnsupportedMemoryType(
795 "Unsupported memory type".to_string(),
796 )),
797 }
798 }
799
800 pub fn free(&mut self, ptr: *mut c_void, memory_type: RocmMemoryType) -> Result<(), RocmError> {
802 if self.config.enable_memory_pools {
803 if let Some(pool) = self.memory_pools.get_mut(&memory_type) {
804 pool.free(ptr)?;
805 } else {
806 return Err(RocmError::UnsupportedMemoryType(
807 "Memory type not supported".to_string(),
808 ));
809 }
810 } else {
811 unsafe {
813 std::alloc::dealloc(
814 ptr as *mut u8,
815 std::alloc::Layout::from_size_align_unchecked(1, 1),
816 );
817 }
818 }
819
820 self.stats.total_deallocations += 1;
821 Ok(())
822 }
823
824 pub fn memcpy(
826 &mut self,
827 dst: *mut c_void,
828 src: *const c_void,
829 size: usize,
830 kind: RocmMemcpyKind,
831 ) -> Result<(), RocmError> {
832 let operation = HipOperation {
833 op_type: match kind {
834 RocmMemcpyKind::HostToDevice => HipOperationType::MemcpyHostToDevice,
835 RocmMemcpyKind::DeviceToHost => HipOperationType::MemcpyDeviceToHost,
836 RocmMemcpyKind::DeviceToDevice => HipOperationType::MemcpyDeviceToDevice,
837 RocmMemcpyKind::HostToHost => HipOperationType::MemcpyAsync,
838 },
839 src_ptr: Some(src as *mut c_void),
840 dst_ptr: Some(dst),
841 size,
842 timestamp: Instant::now(),
843 };
844
845 self.stream_manager.execute_operation(operation)?;
847 self.stats.memory_transfers += 1;
848
849 Ok(())
850 }
851
852 pub fn memcpy_async(
854 &mut self,
855 dst: *mut c_void,
856 src: *const c_void,
857 size: usize,
858 kind: RocmMemcpyKind,
859 stream_id: u32,
860 ) -> Result<(), RocmError> {
861 let operation = HipOperation {
862 op_type: HipOperationType::MemcpyAsync,
863 src_ptr: Some(src as *mut c_void),
864 dst_ptr: Some(dst),
865 size,
866 timestamp: Instant::now(),
867 };
868
869 self.stream_manager.add_operation(stream_id, operation)?;
870 Ok(())
871 }
872
873 pub fn create_context(&mut self, flags: HipContextFlags) -> Result<u32, RocmError> {
875 let context_id = self.contexts.len() as u32;
876
877 let memory_info = HipMemoryInfo {
878 total_memory: self.device_properties.total_global_memory,
879 free_memory: self.device_properties.total_global_memory - self.stats.device_memory_used,
880 used_memory: self.stats.device_memory_used,
881 coarse_memory: self.stats.coarse_grained_used,
882 fine_memory: self.stats.fine_grained_used,
883 };
884
885 let context = HipContext {
886 handle: std::ptr::null_mut(), device_id: self.config.device_id,
888 flags,
889 created_at: Instant::now(),
890 streams: Vec::new(),
891 memory_info,
892 };
893
894 self.contexts.insert(context_id, context);
895 Ok(context_id)
896 }
897
898 pub fn get_device_properties(&self) -> &RocmDeviceProperties {
900 &self.device_properties
901 }
902
903 pub fn get_stats(&self) -> &RocmStats {
905 &self.stats
906 }
907
908 pub fn device_synchronize(&mut self) -> Result<(), RocmError> {
910 let stream_ids: Vec<u32> = self.stream_manager.streams.iter().map(|s| s.id).collect();
912 for stream_id in stream_ids {
913 self.stream_manager.synchronize_stream(stream_id)?;
914 }
915 Ok(())
916 }
917
918 pub fn create_stream(&mut self, priority: Option<i32>) -> Result<u32, RocmError> {
920 self.stream_manager.create_stream(priority)
921 }
922
923 pub fn destroy_stream(&mut self, stream_id: u32) -> Result<(), RocmError> {
925 self.stream_manager.destroy_stream(stream_id)
926 }
927
928 pub fn query_memory_attributes(
930 &self,
931 ptr: *mut c_void,
932 ) -> Result<RocmMemoryAttributes, RocmError> {
933 Ok(RocmMemoryAttributes::default())
936 }
937}
938
// SAFETY: the auto traits are suppressed only by raw-pointer fields
// (`HipContext::handle`, `HipStream::handle`, `RocmMemoryPool::handle`, and the
// block pointers/map keys inside the pools). In this file the handles are only
// ever set to null, and the block pointers are plain addresses obtained from
// `std::alloc`. None of them is tied to the creating thread, so moving the
// whole backend to another thread is sound.
unsafe impl Send for RocmMemoryBackend {}
// SAFETY: the backend has no interior mutability, and its `&self` methods only
// read plain data (properties, stats, pool bookkeeping). Concurrent shared
// access therefore cannot race.
// NOTE(review): re-audit both impls once real HIP handles replace the null
// placeholders — native handles may carry thread-affinity requirements.
unsafe impl Sync for RocmMemoryBackend {}
947
/// Direction of a memory copy requested through `memcpy` / `memcpy_async`.
///
/// Fieldless, so it derives `Copy`, `PartialEq` and `Eq`: callers can pass it
/// by value and compare kinds directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RocmMemcpyKind {
    HostToDevice,
    DeviceToHost,
    DeviceToDevice,
    HostToHost,
}
956
/// Errors reported by the ROCm memory backend. Every variant carries a
/// human-readable detail message.
#[derive(Debug, Clone)]
pub enum RocmError {
    DeviceNotFound(String),
    OutOfMemory(String),
    InvalidPointer(String),
    InvalidStream(String),
    StreamFull(String),
    UnsupportedOperation(String),
    UnsupportedMemoryType(String),
    ContextCreationFailed(String),
    KernelLaunchFailed(String),
    SynchronizationFailed(String),
    InternalError(String),
}

impl std::fmt::Display for RocmError {
    /// Renders as `"<category>: <detail>"`, e.g. `"Out of memory: Pool size limit exceeded"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (category, detail) = match self {
            Self::DeviceNotFound(detail) => ("Device not found", detail),
            Self::OutOfMemory(detail) => ("Out of memory", detail),
            Self::InvalidPointer(detail) => ("Invalid pointer", detail),
            Self::InvalidStream(detail) => ("Invalid stream", detail),
            Self::StreamFull(detail) => ("Stream full", detail),
            Self::UnsupportedOperation(detail) => ("Unsupported operation", detail),
            Self::UnsupportedMemoryType(detail) => ("Unsupported memory type", detail),
            Self::ContextCreationFailed(detail) => ("Context creation failed", detail),
            Self::KernelLaunchFailed(detail) => ("Kernel launch failed", detail),
            Self::SynchronizationFailed(detail) => ("Synchronization failed", detail),
            Self::InternalError(detail) => ("Internal error", detail),
        };
        write!(f, "{}: {}", category, detail)
    }
}

impl std::error::Error for RocmError {}
992
993pub struct ThreadSafeRocmBackend {
995 backend: Arc<Mutex<RocmMemoryBackend>>,
996}
997
998impl ThreadSafeRocmBackend {
999 pub fn new(config: RocmConfig) -> Result<Self, RocmError> {
1000 let backend = RocmMemoryBackend::new(config)?;
1001 Ok(Self {
1002 backend: Arc::new(Mutex::new(backend)),
1003 })
1004 }
1005
1006 pub fn allocate(
1007 &self,
1008 size: usize,
1009 memory_type: RocmMemoryType,
1010 ) -> Result<*mut c_void, RocmError> {
1011 let mut backend = self.backend.lock().expect("lock poisoned");
1012 backend.allocate(size, memory_type)
1013 }
1014
1015 pub fn free(&self, ptr: *mut c_void, memory_type: RocmMemoryType) -> Result<(), RocmError> {
1016 let mut backend = self.backend.lock().expect("lock poisoned");
1017 backend.free(ptr, memory_type)
1018 }
1019
1020 pub fn get_stats(&self) -> RocmStats {
1021 let backend = self.backend.lock().expect("lock poisoned");
1022 backend.get_stats().clone()
1023 }
1024}
1025
#[cfg(test)]
mod tests {
    use super::*;

    /// A default configuration yields a working backend.
    #[test]
    fn test_rocm_backend_creation() {
        let outcome = RocmMemoryBackend::new(RocmConfig::default());
        assert!(outcome.is_ok());
    }

    /// A pool allocation can be released again.
    #[test]
    fn test_memory_pool() {
        let mut pool = RocmMemoryPool::new(RocmMemoryType::CoarseGrained, 1024 * 1024);

        let allocation = pool.allocate(1024);
        assert!(allocation.is_ok());
        let block_ptr = allocation.expect("unwrap failed");

        let release = pool.free(block_ptr);
        assert!(release.is_ok());
    }

    /// Streams can be created with an explicit priority and destroyed.
    #[test]
    fn test_hip_stream_manager() {
        let mut streams = HipStreamManager::new(HipStreamConfig::default());

        let created = streams.create_stream(Some(1));
        assert!(created.is_ok());
        let id = created.expect("unwrap failed");

        let destroyed = streams.destroy_stream(id);
        assert!(destroyed.is_ok());
    }

    /// A fresh thread-safe backend reports no allocations.
    #[test]
    fn test_thread_safe_backend() {
        let built = ThreadSafeRocmBackend::new(RocmConfig::default());
        assert!(built.is_ok());

        let shared = built.expect("unwrap failed");
        assert_eq!(shared.get_stats().total_allocations, 0);
    }
}