1use std::{
6 collections::HashMap,
7 sync::{
8 Arc, Mutex,
9 atomic::{AtomicU64, Ordering},
10 },
11};
12
13use wgpu;
14
15use crate::{
16 device::WebGpuDevice,
17 error::{WebGpuError, WebGpuResult},
18};
19
20pub struct WebGpuBufferInfo {
24 pub buffer: wgpu::Buffer,
26 pub size: u64,
28}
29
30pub struct WebGpuMemoryManager {
37 device: Arc<WebGpuDevice>,
38 buffers: Mutex<HashMap<u64, WebGpuBufferInfo>>,
39 next_handle: AtomicU64,
40}
41
42impl WebGpuMemoryManager {
43 pub fn new(device: Arc<WebGpuDevice>) -> Self {
45 Self {
46 device,
47 buffers: Mutex::new(HashMap::new()),
48 next_handle: AtomicU64::new(1),
49 }
50 }
51
52 pub fn alloc(&self, bytes: usize) -> WebGpuResult<u64> {
56 let size = bytes as u64;
57 let buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
58 label: Some("oxicuda-webgpu-buffer"),
59 size,
60 usage: wgpu::BufferUsages::STORAGE
61 | wgpu::BufferUsages::COPY_SRC
62 | wgpu::BufferUsages::COPY_DST,
63 mapped_at_creation: false,
64 });
65
66 let handle = self.next_handle.fetch_add(1, Ordering::Relaxed);
67
68 self.buffers
69 .lock()
70 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
71 .insert(handle, WebGpuBufferInfo { buffer, size });
72
73 Ok(handle)
74 }
75
76 pub fn free(&self, handle: u64) -> WebGpuResult<()> {
80 self.buffers
81 .lock()
82 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
83 .remove(&handle);
84 Ok(())
85 }
86
87 pub fn copy_to_device(&self, handle: u64, src: &[u8]) -> WebGpuResult<()> {
89 let buffers = self
90 .buffers
91 .lock()
92 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
93
94 let buf_info = buffers
95 .get(&handle)
96 .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
97
98 self.device.queue.write_buffer(&buf_info.buffer, 0, src);
99 Ok(())
100 }
101
102 pub fn copy_from_device(&self, dst: &mut [u8], handle: u64) -> WebGpuResult<()> {
107 let staging = {
111 let buffers = self
112 .buffers
113 .lock()
114 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
115
116 let buf_info = buffers
117 .get(&handle)
118 .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
119
120 let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
121 label: Some("oxicuda-webgpu-staging"),
122 size: buf_info.size,
123 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
124 mapped_at_creation: false,
125 });
126
127 let mut encoder =
128 self.device
129 .device
130 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
131 label: Some("oxicuda-webgpu-readback"),
132 });
133
134 encoder.copy_buffer_to_buffer(&buf_info.buffer, 0, &staging, 0, buf_info.size);
135 self.device.queue.submit(std::iter::once(encoder.finish()));
136
137 staging
138 };
140
141 let slice = staging.slice(..);
143 let (tx, rx) = std::sync::mpsc::channel();
144 slice.map_async(wgpu::MapMode::Read, move |result| {
145 let _ = tx.send(result);
147 });
148
149 let _ = self.device.device.poll(wgpu::PollType::wait_indefinitely());
152
153 rx.recv()
154 .map_err(|_| WebGpuError::BufferMapping("channel closed before map completed".into()))?
155 .map_err(|e| WebGpuError::BufferMapping(format!("{e:?}")))?;
156
157 let data = slice.get_mapped_range();
158 let copy_len = dst.len().min(data.len());
159 dst[..copy_len].copy_from_slice(&data[..copy_len]);
160 drop(data);
161 staging.unmap();
162
163 Ok(())
164 }
165}
166
167impl std::fmt::Debug for WebGpuMemoryManager {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
170 write!(f, "WebGpuMemoryManager(buffers={})", count)
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::WebGpuDevice;

    /// Returns a shared device, or `None` when no WebGPU device can be created.
    /// In the `None` case the calling test returns early and passes.
    fn device_or_skip() -> Option<Arc<WebGpuDevice>> {
        match WebGpuDevice::new() {
            Ok(device) => Some(Arc::new(device)),
            Err(_) => None,
        }
    }

    #[test]
    fn alloc_and_free_requires_device() {
        let Some(device) = device_or_skip() else {
            return;
        };
        let manager = WebGpuMemoryManager::new(device);

        let handle = manager.alloc(256).expect("alloc 256 bytes");
        assert!(handle > 0);

        manager.free(handle).expect("free");
        manager.free(handle).expect("double-free is a no-op");
    }

    #[test]
    fn copy_roundtrip_requires_device() {
        let Some(device) = device_or_skip() else {
            return;
        };
        let manager = WebGpuMemoryManager::new(device);

        let payload: Vec<u8> = (0..64).collect();
        let handle = manager.alloc(payload.len()).expect("alloc");
        manager
            .copy_to_device(handle, &payload)
            .expect("copy_to_device");

        let mut readback = vec![0u8; payload.len()];
        manager
            .copy_from_device(&mut readback, handle)
            .expect("copy_from_device");

        assert_eq!(payload, readback);
        manager.free(handle).expect("free");
    }

    #[test]
    fn unknown_handle_returns_error() {
        let Some(device) = device_or_skip() else {
            return;
        };
        let manager = WebGpuMemoryManager::new(device);

        let outcome = manager.copy_to_device(9999, b"hello");
        assert!(matches!(outcome, Err(WebGpuError::InvalidArgument(_))));
    }
}