cuda_rust_wasm/memory/
unified_memory.rs1use crate::{Result, memory_error};
6use std::sync::Arc;
7use std::alloc::{alloc, dealloc, Layout};
8use std::ptr::NonNull;
9
/// Owned, fixed-size buffer of raw bytes obtained from the Rust global
/// allocator (`std::alloc`).
///
/// The buffer is freed in `Drop` using the same [`Layout`] it was allocated
/// with.
#[derive(Debug)]
pub struct UnifiedMemory {
    /// Start of the allocation; never null.
    ptr: NonNull<u8>,
    /// Requested size of the allocation in bytes.
    size: usize,
    /// Layout passed to the allocator; required again to deallocate.
    layout: Layout,
}
16
17impl UnifiedMemory {
18 pub fn new(size: usize) -> Result<Self> {
20 if size == 0 {
21 return Err(memory_error!("Cannot allocate zero-sized unified memory"));
22 }
23
24 let layout = Layout::from_size_align(size, 8)
25 .map_err(|e| memory_error!("Invalid layout: {}", e))?;
26
27 let ptr = unsafe { alloc(layout) };
28
29 let ptr = NonNull::new(ptr)
30 .ok_or_else(|| memory_error!("Failed to allocate unified memory"))?;
31
32 Ok(Self { ptr, size, layout })
33 }
34
35 pub fn as_ptr(&self) -> *const u8 {
37 self.ptr.as_ptr() as *const u8
38 }
39
40 pub fn as_mut_ptr(&mut self) -> *mut u8 {
42 self.ptr.as_ptr()
43 }
44
45 pub fn size(&self) -> usize {
47 self.size
48 }
49
50 pub fn copy_from_slice(&mut self, data: &[u8]) -> Result<()> {
52 if data.len() > self.size {
53 return Err(memory_error!(
54 "Data size {} exceeds buffer size {}",
55 data.len(),
56 self.size
57 ));
58 }
59
60 unsafe {
61 std::ptr::copy_nonoverlapping(data.as_ptr(), self.ptr.as_ptr(), data.len());
62 }
63
64 Ok(())
65 }
66
67 pub fn copy_to_slice(&self, data: &mut [u8]) -> Result<()> {
69 if data.len() > self.size {
70 return Err(memory_error!(
71 "Destination size {} exceeds buffer size {}",
72 data.len(),
73 self.size
74 ));
75 }
76
77 unsafe {
78 std::ptr::copy_nonoverlapping(self.ptr.as_ptr(), data.as_mut_ptr(), data.len());
79 }
80
81 Ok(())
82 }
83}
84
85impl Drop for UnifiedMemory {
86 fn drop(&mut self) {
87 unsafe {
88 dealloc(self.ptr.as_ptr(), self.layout);
89 }
90 }
91}
92
93unsafe impl Send for UnifiedMemory {}
95unsafe impl Sync for UnifiedMemory {}
96
97pub type SharedUnifiedMemory = Arc<UnifiedMemory>;
99
100pub fn allocate_unified(size: usize) -> Result<SharedUnifiedMemory> {
102 Ok(Arc::new(UnifiedMemory::new(size)?))
103}
104
105pub struct ManagedMemory {
111 inner: UnifiedMemory,
113 backend_registered: bool,
115}
116
117impl ManagedMemory {
118 pub fn new(size: usize) -> Result<Self> {
120 let inner = UnifiedMemory::new(size)?;
121 let backend_registered = Self::try_register_with_backend(inner.as_ptr(), size);
122 Ok(Self {
123 inner,
124 backend_registered,
125 })
126 }
127
128 pub fn is_backend_registered(&self) -> bool {
130 self.backend_registered
131 }
132
133 pub fn as_unified(&self) -> &UnifiedMemory {
135 &self.inner
136 }
137
138 pub fn as_unified_mut(&mut self) -> &mut UnifiedMemory {
140 &mut self.inner
141 }
142
143 pub fn size(&self) -> usize {
145 self.inner.size()
146 }
147
148 pub fn copy_from_slice(&mut self, data: &[u8]) -> Result<()> {
150 self.inner.copy_from_slice(data)
151 }
152
153 pub fn copy_to_slice(&self, data: &mut [u8]) -> Result<()> {
155 self.inner.copy_to_slice(data)
156 }
157
158 pub fn prefetch_to_device(&self) -> Result<()> {
160 Ok(())
162 }
163
164 pub fn prefetch_to_host(&self) -> Result<()> {
166 Ok(())
167 }
168
169 fn try_register_with_backend(_ptr: *const u8, _size: usize) -> bool {
171 let backend = crate::backend::get_backend();
173 let caps = backend.capabilities();
174 caps.supports_unified_memory
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unified_memory_allocation() {
        let mem = UnifiedMemory::new(1024).unwrap();
        assert_eq!(mem.size(), 1024);
    }

    #[test]
    fn test_unified_memory_copy() {
        let mut mem = UnifiedMemory::new(256).unwrap();

        let data = vec![42u8; 256];
        mem.copy_from_slice(&data).unwrap();

        let mut output = vec![0u8; 256];
        mem.copy_to_slice(&mut output).unwrap();

        assert_eq!(data, output);
    }

    #[test]
    fn test_zero_size_allocation() {
        let result = UnifiedMemory::new(0);
        assert!(result.is_err());
    }

    // Error path: writing more bytes than the buffer holds must be rejected.
    #[test]
    fn test_copy_from_oversized_slice_is_rejected() {
        let mut mem = UnifiedMemory::new(8).unwrap();
        assert!(mem.copy_from_slice(&[1u8; 9]).is_err());
    }

    // Error path: reading into a destination larger than the buffer must be rejected.
    #[test]
    fn test_copy_to_oversized_slice_is_rejected() {
        let mem = UnifiedMemory::new(8).unwrap();
        let mut out = [0u8; 9];
        assert!(mem.copy_to_slice(&mut out).is_err());
    }

    // Copies shorter than the buffer touch only the leading bytes.
    #[test]
    fn test_partial_copy_round_trip() {
        let mut mem = UnifiedMemory::new(16).unwrap();
        mem.copy_from_slice(&[1, 2, 3, 4]).unwrap();

        let mut out = [0u8; 4];
        mem.copy_to_slice(&mut out).unwrap();
        assert_eq!(out, [1, 2, 3, 4]);
    }

    // Both pointer accessors expose the same 8-byte-aligned allocation.
    #[test]
    fn test_pointer_accessors_agree() {
        let mut mem = UnifiedMemory::new(32).unwrap();
        let read_ptr = mem.as_ptr();
        let write_ptr = mem.as_mut_ptr();
        assert_eq!(read_ptr, write_ptr as *const u8);
        assert_eq!(read_ptr as usize % 8, 0);
    }

    #[test]
    fn test_allocate_unified_is_shareable() {
        let shared: SharedUnifiedMemory = allocate_unified(64).unwrap();
        let other = Arc::clone(&shared);
        assert_eq!(other.size(), 64);
        assert_eq!(Arc::strong_count(&shared), 2);
    }

    #[test]
    fn test_allocate_unified_rejects_zero_size() {
        assert!(allocate_unified(0).is_err());
    }

    // Compile-time check that the unsafe Send/Sync impls remain in place.
    #[test]
    fn test_unified_memory_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<UnifiedMemory>();
    }

    #[test]
    fn test_managed_memory() {
        let mem = ManagedMemory::new(512).unwrap();
        assert_eq!(mem.size(), 512);
        assert_eq!(mem.as_unified().size(), 512);
    }

    #[test]
    fn test_managed_memory_copy() {
        let mut mem = ManagedMemory::new(128).unwrap();
        let data = vec![0xAB_u8; 128];
        mem.copy_from_slice(&data).unwrap();

        let mut out = vec![0u8; 128];
        mem.copy_to_slice(&mut out).unwrap();
        assert_eq!(data, out);
    }

    #[test]
    fn test_managed_memory_rejects_oversized_copy() {
        let mut mem = ManagedMemory::new(16).unwrap();
        assert!(mem.copy_from_slice(&[0u8; 17]).is_err());

        let mut out = [0u8; 17];
        assert!(mem.copy_to_slice(&mut out).is_err());
    }

    #[test]
    fn test_managed_memory_prefetch() {
        let mem = ManagedMemory::new(64).unwrap();
        assert!(mem.prefetch_to_device().is_ok());
        assert!(mem.prefetch_to_host().is_ok());
    }
}