cuda_rust_wasm/memory/
device_memory.rs1use crate::{Result, runtime_error};
4use crate::runtime::{Device, BackendType};
5use std::marker::PhantomData;
6use std::sync::Arc;
7use std::alloc::{alloc, dealloc, Layout};
8
9pub struct DevicePtr {
11 raw: *mut u8,
12 size: usize,
13 backend: BackendType,
14}
15
16impl DevicePtr {
17 pub fn allocate(size: usize, device: &Arc<Device>) -> Result<Self> {
19 if size == 0 {
20 return Err(runtime_error!("Cannot allocate zero-sized buffer"));
21 }
22
23 let backend = device.backend();
24 let raw = match backend {
25 BackendType::Native => {
26 unsafe {
29 let layout = Layout::from_size_align(size, 8)
30 .map_err(|e| runtime_error!("Invalid layout: {}", e))?;
31 alloc(layout)
32 }
33 }
34 BackendType::WebGPU => {
35 unsafe {
37 let layout = Layout::from_size_align(size, 8)
38 .map_err(|e| runtime_error!("Invalid layout: {}", e))?;
39 alloc(layout)
40 }
41 }
42 BackendType::CPU => {
43 unsafe {
45 let layout = Layout::from_size_align(size, 8)
46 .map_err(|e| runtime_error!("Invalid layout: {}", e))?;
47 alloc(layout)
48 }
49 }
50 };
51
52 if raw.is_null() {
53 return Err(runtime_error!("Failed to allocate {} bytes of device memory", size));
54 }
55
56 Ok(Self { raw, size, backend })
57 }
58
59 pub fn as_ptr(&self) -> *const u8 {
61 self.raw
62 }
63
64 pub fn as_mut_ptr(&mut self) -> *mut u8 {
66 self.raw
67 }
68
69 pub fn size(&self) -> usize {
71 self.size
72 }
73}
74
75impl Drop for DevicePtr {
76 fn drop(&mut self) {
77 if !self.raw.is_null() {
78 match self.backend {
79 BackendType::Native => {
80 unsafe {
82 if let Ok(layout) = Layout::from_size_align(self.size, 8) {
83 dealloc(self.raw, layout);
84 }
85 }
86 }
87 BackendType::WebGPU => {
88 unsafe {
90 if let Ok(layout) = Layout::from_size_align(self.size, 8) {
91 dealloc(self.raw, layout);
92 }
93 }
94 }
95 BackendType::CPU => {
96 unsafe {
97 if let Ok(layout) = Layout::from_size_align(self.size, 8) {
98 dealloc(self.raw, layout);
99 }
100 }
101 }
102 }
103 }
104 }
105}
106
107pub struct DeviceBuffer<T> {
109 ptr: DevicePtr,
110 len: usize,
111 device: Arc<Device>,
112 phantom: PhantomData<T>,
113}
114
115impl<T: Copy> DeviceBuffer<T> {
116 pub fn new(len: usize, device: Arc<Device>) -> Result<Self> {
118 if len == 0 {
119 return Err(runtime_error!("Cannot allocate zero-length buffer"));
120 }
121
122 let size = len * std::mem::size_of::<T>();
123 let ptr = DevicePtr::allocate(size, &device)?;
124
125 Ok(Self {
126 ptr,
127 len,
128 device,
129 phantom: PhantomData,
130 })
131 }
132
133 pub fn len(&self) -> usize {
135 self.len
136 }
137
138 pub fn is_empty(&self) -> bool {
140 self.len == 0
141 }
142
143 pub fn device(&self) -> &Arc<Device> {
145 &self.device
146 }
147
148 pub unsafe fn as_ptr(&self) -> *const T {
154 self.ptr.as_ptr() as *const T
155 }
156
157 pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
163 self.ptr.as_mut_ptr() as *mut T
164 }
165
166 pub fn copy_from_host(&mut self, data: &[T]) -> Result<()> {
168 if data.len() != self.len {
169 return Err(runtime_error!(
170 "Host buffer length {} doesn't match device buffer length {}",
171 data.len(),
172 self.len
173 ));
174 }
175
176 let size = self.len * std::mem::size_of::<T>();
177
178 match self.device.backend() {
179 BackendType::Native => {
180 unsafe {
182 std::ptr::copy_nonoverlapping(
183 data.as_ptr() as *const u8,
184 self.ptr.as_mut_ptr(),
185 size
186 );
187 }
188 }
189 BackendType::WebGPU => {
190 unsafe {
192 std::ptr::copy_nonoverlapping(
193 data.as_ptr() as *const u8,
194 self.ptr.as_mut_ptr(),
195 size
196 );
197 }
198 }
199 BackendType::CPU => {
200 unsafe {
201 std::ptr::copy_nonoverlapping(
202 data.as_ptr() as *const u8,
203 self.ptr.as_mut_ptr(),
204 size
205 );
206 }
207 }
208 }
209
210 Ok(())
211 }
212
213 pub fn copy_to_host(&self, data: &mut [T]) -> Result<()> {
215 if data.len() != self.len {
216 return Err(runtime_error!(
217 "Host buffer length {} doesn't match device buffer length {}",
218 data.len(),
219 self.len
220 ));
221 }
222
223 let size = self.len * std::mem::size_of::<T>();
224
225 match self.device.backend() {
226 BackendType::Native => {
227 unsafe {
229 std::ptr::copy_nonoverlapping(
230 self.ptr.as_ptr(),
231 data.as_mut_ptr() as *mut u8,
232 size
233 );
234 }
235 }
236 BackendType::WebGPU => {
237 unsafe {
239 std::ptr::copy_nonoverlapping(
240 self.ptr.as_ptr(),
241 data.as_mut_ptr() as *mut u8,
242 size
243 );
244 }
245 }
246 BackendType::CPU => {
247 unsafe {
248 std::ptr::copy_nonoverlapping(
249 self.ptr.as_ptr(),
250 data.as_mut_ptr() as *mut u8,
251 size
252 );
253 }
254 }
255 }
256
257 Ok(())
258 }
259
260 pub fn fill(&mut self, value: T) -> Result<()> {
262 let host_data = vec![value; self.len];
265 self.copy_from_host(&host_data)
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Device;

    /// A freshly allocated buffer reports the requested length.
    #[test]
    fn test_device_buffer_allocation() {
        let device = Device::get_default().unwrap();
        let buf = DeviceBuffer::<f32>::new(1024, device).unwrap();
        assert_eq!(buf.len(), 1024);
        assert!(!buf.is_empty());
    }

    /// Host → device → host round-trip preserves the data exactly.
    #[test]
    fn test_host_device_copy() {
        let device = Device::get_default().unwrap();
        let mut buf = DeviceBuffer::<f32>::new(100, device).unwrap();

        let input: Vec<f32> = (0..100).map(|i| i as f32).collect();
        buf.copy_from_host(&input).unwrap();

        let mut output = vec![0.0_f32; 100];
        buf.copy_to_host(&mut output).unwrap();

        assert_eq!(input, output);
    }
}