Skip to main content

cuda_rust_wasm/memory/
unified_memory.rs

//! Unified memory management for CUDA-Rust.
//!
//! Provides unified memory allocations that can be accessed by both the host and the device.
//! Allocations are currently served from the host heap through the global allocator.
4
5use crate::{Result, memory_error};
6use std::sync::Arc;
7use std::alloc::{alloc, dealloc, Layout};
8use std::ptr::NonNull;
9
/// Unified memory allocation.
///
/// Owns a heap buffer of `size` bytes obtained from the global allocator with
/// `layout`; the buffer is released when the value is dropped.
#[derive(Debug)]
pub struct UnifiedMemory {
    /// Start of the owned allocation (never null).
    ptr: NonNull<u8>,
    /// Number of usable bytes in the allocation.
    size: usize,
    /// Layout the buffer was allocated with; required again to free it.
    layout: Layout,
}
16
17impl UnifiedMemory {
18    /// Allocate unified memory
19    pub fn new(size: usize) -> Result<Self> {
20        if size == 0 {
21            return Err(memory_error!("Cannot allocate zero-sized unified memory"));
22        }
23
24        let layout = Layout::from_size_align(size, 8)
25            .map_err(|e| memory_error!("Invalid layout: {}", e))?;
26
27        let ptr = unsafe { alloc(layout) };
28        
29        let ptr = NonNull::new(ptr)
30            .ok_or_else(|| memory_error!("Failed to allocate unified memory"))?;
31
32        Ok(Self { ptr, size, layout })
33    }
34
35    /// Get a pointer to the memory
36    pub fn as_ptr(&self) -> *const u8 {
37        self.ptr.as_ptr() as *const u8
38    }
39
40    /// Get a mutable pointer to the memory
41    pub fn as_mut_ptr(&mut self) -> *mut u8 {
42        self.ptr.as_ptr()
43    }
44
45    /// Get the size of the allocation
46    pub fn size(&self) -> usize {
47        self.size
48    }
49
50    /// Copy data from host to unified memory
51    pub fn copy_from_slice(&mut self, data: &[u8]) -> Result<()> {
52        if data.len() > self.size {
53            return Err(memory_error!(
54                "Data size {} exceeds buffer size {}",
55                data.len(),
56                self.size
57            ));
58        }
59
60        unsafe {
61            std::ptr::copy_nonoverlapping(data.as_ptr(), self.ptr.as_ptr(), data.len());
62        }
63
64        Ok(())
65    }
66
67    /// Copy data from unified memory to host
68    pub fn copy_to_slice(&self, data: &mut [u8]) -> Result<()> {
69        if data.len() > self.size {
70            return Err(memory_error!(
71                "Destination size {} exceeds buffer size {}",
72                data.len(),
73                self.size
74            ));
75        }
76
77        unsafe {
78            std::ptr::copy_nonoverlapping(self.ptr.as_ptr(), data.as_mut_ptr(), data.len());
79        }
80
81        Ok(())
82    }
83}
84
85impl Drop for UnifiedMemory {
86    fn drop(&mut self) {
87        unsafe {
88            dealloc(self.ptr.as_ptr(), self.layout);
89        }
90    }
91}
92
93// Safety: UnifiedMemory can be safely sent between threads
94unsafe impl Send for UnifiedMemory {}
95unsafe impl Sync for UnifiedMemory {}
96
97/// Shared unified memory handle
98pub type SharedUnifiedMemory = Arc<UnifiedMemory>;
99
100/// Create a new shared unified memory allocation
101pub fn allocate_unified(size: usize) -> Result<SharedUnifiedMemory> {
102    Ok(Arc::new(UnifiedMemory::new(size)?))
103}
104
105/// Backend-aware unified memory that routes allocation through the active backend
106///
107/// When a GPU backend is available, this allocates memory that is accessible
108/// from both host and device via the backend's memory management. Falls back
109/// to host-only allocation when no GPU backend is present.
110pub struct ManagedMemory {
111    /// Underlying unified memory
112    inner: UnifiedMemory,
113    /// Whether this memory is registered with a backend
114    backend_registered: bool,
115}
116
117impl ManagedMemory {
118    /// Allocate managed memory (tries backend, falls back to host)
119    pub fn new(size: usize) -> Result<Self> {
120        let inner = UnifiedMemory::new(size)?;
121        let backend_registered = Self::try_register_with_backend(inner.as_ptr(), size);
122        Ok(Self {
123            inner,
124            backend_registered,
125        })
126    }
127
128    /// Check if memory is registered with a GPU backend
129    pub fn is_backend_registered(&self) -> bool {
130        self.backend_registered
131    }
132
133    /// Get the underlying unified memory
134    pub fn as_unified(&self) -> &UnifiedMemory {
135        &self.inner
136    }
137
138    /// Get a mutable reference to the underlying unified memory
139    pub fn as_unified_mut(&mut self) -> &mut UnifiedMemory {
140        &mut self.inner
141    }
142
143    /// Get size
144    pub fn size(&self) -> usize {
145        self.inner.size()
146    }
147
148    /// Copy from host slice
149    pub fn copy_from_slice(&mut self, data: &[u8]) -> Result<()> {
150        self.inner.copy_from_slice(data)
151    }
152
153    /// Copy to host slice
154    pub fn copy_to_slice(&self, data: &mut [u8]) -> Result<()> {
155        self.inner.copy_to_slice(data)
156    }
157
158    /// Prefetch to the device (hint for the runtime; no-op in CPU mode)
159    pub fn prefetch_to_device(&self) -> Result<()> {
160        // In CPU emulation, this is a no-op since all memory is host-accessible
161        Ok(())
162    }
163
164    /// Prefetch to the host (hint for the runtime; no-op in CPU mode)
165    pub fn prefetch_to_host(&self) -> Result<()> {
166        Ok(())
167    }
168
169    /// Try to register the allocation with the active GPU backend
170    fn try_register_with_backend(_ptr: *const u8, _size: usize) -> bool {
171        // Check if a GPU backend is available
172        let backend = crate::backend::get_backend();
173        let caps = backend.capabilities();
174        caps.supports_unified_memory
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unified_memory_allocation() {
        let mem = UnifiedMemory::new(1024).unwrap();
        assert_eq!(mem.size(), 1024);
    }

    #[test]
    fn test_unified_memory_copy() {
        let mut mem = UnifiedMemory::new(256).unwrap();

        let data = vec![42u8; 256];
        mem.copy_from_slice(&data).unwrap();

        let mut output = vec![0u8; 256];
        mem.copy_to_slice(&mut output).unwrap();

        assert_eq!(data, output);
    }

    #[test]
    fn test_zero_size_allocation() {
        let result = UnifiedMemory::new(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_partial_copy_roundtrip() {
        let mut mem = UnifiedMemory::new(32).unwrap();
        mem.copy_from_slice(&[7u8; 8]).unwrap();

        let mut out = [0u8; 8];
        mem.copy_to_slice(&mut out).unwrap();
        assert_eq!(out, [7u8; 8]);
    }

    #[test]
    fn test_copy_from_oversized_slice_fails() {
        let mut mem = UnifiedMemory::new(16).unwrap();
        assert!(mem.copy_from_slice(&[1u8; 17]).is_err());
    }

    #[test]
    fn test_copy_to_oversized_slice_fails() {
        let mem = UnifiedMemory::new(16).unwrap();
        let mut out = [0u8; 17];
        assert!(mem.copy_to_slice(&mut out).is_err());
        // A rejected copy must leave the destination untouched.
        assert!(out.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_allocate_unified_shared() {
        let shared = allocate_unified(64).unwrap();
        let other = Arc::clone(&shared);
        assert_eq!(other.size(), 64);
        assert_eq!(Arc::strong_count(&shared), 2);
        assert!(allocate_unified(0).is_err());
    }

    #[test]
    fn test_managed_memory() {
        let mem = ManagedMemory::new(512).unwrap();
        assert_eq!(mem.size(), 512);
    }

    #[test]
    fn test_managed_memory_zero_size() {
        assert!(ManagedMemory::new(0).is_err());
    }

    #[test]
    fn test_managed_memory_copy() {
        let mut mem = ManagedMemory::new(128).unwrap();
        let data = vec![0xAB_u8; 128];
        mem.copy_from_slice(&data).unwrap();

        let mut out = vec![0u8; 128];
        mem.copy_to_slice(&mut out).unwrap();
        assert_eq!(data, out);
    }

    #[test]
    fn test_managed_memory_prefetch() {
        let mem = ManagedMemory::new(64).unwrap();
        assert!(mem.prefetch_to_device().is_ok());
        assert!(mem.prefetch_to_host().is_ok());
    }
}