cuda_rust_wasm/memory/device_memory.rs

//! Device memory allocation and management

use crate::{Result, runtime_error};
use crate::runtime::{Device, BackendType};
use std::marker::PhantomData;
use std::sync::Arc;
use std::alloc::{alloc, dealloc, Layout};

/// Raw device memory pointer
pub struct DevicePtr {
    raw: *mut u8,
    size: usize,
    backend: BackendType,
}

impl DevicePtr {
    /// Allocate raw device memory
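    ///
    /// Example (a sketch; the import path is an assumption, so the block is
    /// `ignore`d rather than compile-tested):
    ///
    /// ```ignore
    /// use cuda_rust_wasm::runtime::Device;
    ///
    /// let device = Device::get_default().unwrap();
    /// // Allocate 1 KiB of raw device memory on the default device.
    /// let ptr = DevicePtr::allocate(1024, &device).unwrap();
    /// assert_eq!(ptr.size(), 1024);
    /// ```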
    pub fn allocate(size: usize, device: &Arc<Device>) -> Result<Self> {
        if size == 0 {
            return Err(runtime_error!("Cannot allocate zero-sized buffer"));
        }

        let backend = device.backend();
        let raw = match backend {
            // TODO: Native should use real CUDA allocation and WebGPU should
            // create a GPU buffer; every backend currently falls back to a
            // host heap allocation as a placeholder.
            BackendType::Native | BackendType::WebGPU | BackendType::CPU => unsafe {
                let layout = Layout::from_size_align(size, 8)
                    .map_err(|e| runtime_error!("Invalid layout: {}", e))?;
                alloc(layout)
            },
        };

        if raw.is_null() {
            return Err(runtime_error!("Failed to allocate {} bytes of device memory", size));
        }

        Ok(Self { raw, size, backend })
    }

    /// Get raw pointer
    pub fn as_ptr(&self) -> *const u8 {
        self.raw
    }

    /// Get mutable raw pointer
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.raw
    }

    /// Get allocation size
    pub fn size(&self) -> usize {
        self.size
    }
}

impl Drop for DevicePtr {
    fn drop(&mut self) {
        if !self.raw.is_null() {
            match self.backend {
                // TODO: Native should use real CUDA deallocation and WebGPU
                // should release its GPU buffer; every backend currently
                // frees the host placeholder allocation.
                BackendType::Native | BackendType::WebGPU | BackendType::CPU => unsafe {
                    // The layout must match the one used in `allocate`.
                    if let Ok(layout) = Layout::from_size_align(self.size, 8) {
                        dealloc(self.raw, layout);
                    }
                },
            }
        }
    }
}

/// Device memory buffer
pub struct DeviceBuffer<T> {
    ptr: DevicePtr,
    len: usize,
    device: Arc<Device>,
    phantom: PhantomData<T>,
}

impl<T: Copy> DeviceBuffer<T> {
    /// Allocate a new device buffer
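    ///
    /// Example (a sketch mirroring the tests below; the setup is assumed, so
    /// the block is `ignore`d rather than compile-tested):
    ///
    /// ```ignore
    /// use cuda_rust_wasm::runtime::Device;
    ///
    /// let device = Device::get_default().unwrap();
    /// // 1024 elements of f32, i.e. a 4 KiB allocation.
    /// let buffer = DeviceBuffer::<f32>::new(1024, device).unwrap();
    /// assert_eq!(buffer.len(), 1024);
    /// ```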
    pub fn new(len: usize, device: Arc<Device>) -> Result<Self> {
        if len == 0 {
            return Err(runtime_error!("Cannot allocate zero-length buffer"));
        }

        // `DevicePtr::allocate` uses 8-byte alignment, so this placeholder
        // backend only supports element types up to that alignment.
        debug_assert!(std::mem::align_of::<T>() <= 8);

        // Guard against overflow when computing the byte size.
        let size = len
            .checked_mul(std::mem::size_of::<T>())
            .ok_or_else(|| runtime_error!("Buffer size overflows usize"))?;
        let ptr = DevicePtr::allocate(size, &device)?;

        Ok(Self {
            ptr,
            len,
            device,
            phantom: PhantomData,
        })
    }

    /// Get buffer length
    pub fn len(&self) -> usize {
        self.len
    }

    /// Check if buffer is empty
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Get the device this buffer is allocated on
    pub fn device(&self) -> &Arc<Device> {
        &self.device
    }

    /// Get raw device pointer
    ///
    /// # Safety
    /// The caller must ensure that the returned pointer is not used after the `DeviceBuffer` is dropped.
    /// The caller must also ensure that the memory is not accessed concurrently.
    pub unsafe fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr() as *const T
    }

    /// Get mutable raw device pointer
    ///
    /// # Safety
    /// The caller must ensure that the returned pointer is not used after the `DeviceBuffer` is dropped.
    /// The caller must also ensure that the memory is not accessed concurrently.
    pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_mut_ptr() as *mut T
    }

    /// Copy from host memory
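    ///
    /// Example round trip (a sketch; setup assumed as in the tests below,
    /// `ignore`d rather than compile-tested):
    ///
    /// ```ignore
    /// let device = Device::get_default().unwrap();
    /// let mut buffer = DeviceBuffer::<f32>::new(4, device).unwrap();
    /// buffer.copy_from_host(&[1.0, 2.0, 3.0, 4.0]).unwrap();
    ///
    /// let mut out = vec![0.0f32; 4];
    /// buffer.copy_to_host(&mut out).unwrap();
    /// assert_eq!(out, [1.0, 2.0, 3.0, 4.0]);
    /// ```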
    pub fn copy_from_host(&mut self, data: &[T]) -> Result<()> {
        if data.len() != self.len {
            return Err(runtime_error!(
                "Host buffer length {} doesn't match device buffer length {}",
                data.len(),
                self.len
            ));
        }

        let size = self.len * std::mem::size_of::<T>();

        match self.device.backend() {
            // TODO: Native should use a real CUDA memcpy and WebGPU a buffer
            // write; every backend currently copies into the host placeholder.
            BackendType::Native | BackendType::WebGPU | BackendType::CPU => unsafe {
                std::ptr::copy_nonoverlapping(
                    data.as_ptr() as *const u8,
                    self.ptr.as_mut_ptr(),
                    size,
                );
            },
        }

        Ok(())
    }

    /// Copy to host memory
    pub fn copy_to_host(&self, data: &mut [T]) -> Result<()> {
        if data.len() != self.len {
            return Err(runtime_error!(
                "Host buffer length {} doesn't match device buffer length {}",
                data.len(),
                self.len
            ));
        }

        let size = self.len * std::mem::size_of::<T>();

        match self.device.backend() {
            // TODO: Native should use a real CUDA memcpy and WebGPU a buffer
            // read; every backend currently copies out of the host placeholder.
            BackendType::Native | BackendType::WebGPU | BackendType::CPU => unsafe {
                std::ptr::copy_nonoverlapping(
                    self.ptr.as_ptr(),
                    data.as_mut_ptr() as *mut u8,
                    size,
                );
            },
        }

        Ok(())
    }

    /// Fill buffer with a value
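    ///
    /// Example (a sketch; setup assumed as in the tests below, `ignore`d
    /// rather than compile-tested):
    ///
    /// ```ignore
    /// let device = Device::get_default().unwrap();
    /// let mut buffer = DeviceBuffer::<f32>::new(8, device).unwrap();
    /// buffer.fill(1.5).unwrap();
    ///
    /// let mut out = vec![0.0f32; 8];
    /// buffer.copy_to_host(&mut out).unwrap();
    /// assert!(out.iter().all(|&x| x == 1.5));
    /// ```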
    pub fn fill(&mut self, value: T) -> Result<()> {
        // For now, build the fill pattern on the host and copy it over.
        // TODO: Optimize with a kernel-based fill
        let host_data = vec![value; self.len];
        self.copy_from_host(&host_data)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Device;

    #[test]
    fn test_device_buffer_allocation() {
        let device = Device::get_default().unwrap();
        let buffer = DeviceBuffer::<f32>::new(1024, device).unwrap();
        assert_eq!(buffer.len(), 1024);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_host_device_copy() {
        let device = Device::get_default().unwrap();
        let mut buffer = DeviceBuffer::<f32>::new(100, device).unwrap();

        // Create test data
        let host_data: Vec<f32> = (0..100).map(|i| i as f32).collect();

        // Copy to device
        buffer.copy_from_host(&host_data).unwrap();

        // Copy back
        let mut result = vec![0.0; 100];
        buffer.copy_to_host(&mut result).unwrap();

        // Verify
        assert_eq!(host_data, result);
    }
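
    // Additional sketches: these tests assume the same `Device::get_default()`
    // setup as above and only exercise APIs already defined in this file.
    #[test]
    fn test_fill() {
        let device = Device::get_default().unwrap();
        let mut buffer = DeviceBuffer::<f32>::new(64, device).unwrap();
        buffer.fill(2.5).unwrap();

        let mut result = vec![0.0f32; 64];
        buffer.copy_to_host(&mut result).unwrap();
        assert!(result.iter().all(|&x| x == 2.5));
    }

    #[test]
    fn test_length_mismatch_is_rejected() {
        let device = Device::get_default().unwrap();
        let mut buffer = DeviceBuffer::<f32>::new(10, device).unwrap();

        // Host slice is shorter than the device buffer, so both copy
        // directions must fail.
        let host_data = vec![0.0f32; 5];
        assert!(buffer.copy_from_host(&host_data).is_err());

        let mut out = vec![0.0f32; 5];
        assert!(buffer.copy_to_host(&mut out).is_err());
    }

    #[test]
    fn test_zero_length_is_rejected() {
        let device = Device::get_default().unwrap();
        assert!(DeviceBuffer::<f32>::new(0, device).is_err());
    }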
}