1use std::{
6 collections::HashMap,
7 sync::{
8 Arc, Mutex,
9 atomic::{AtomicU64, Ordering},
10 },
11};
12
13use wgpu;
14
15use crate::{
16 device::WebGpuDevice,
17 error::{WebGpuError, WebGpuResult},
18};
19
20pub struct WebGpuBufferInfo {
24 pub buffer: wgpu::Buffer,
26 pub size: u64,
28}
29
30pub struct WebGpuMemoryManager {
37 device: Arc<WebGpuDevice>,
38 buffers: Mutex<HashMap<u64, WebGpuBufferInfo>>,
39 next_handle: AtomicU64,
40}
41
42impl WebGpuMemoryManager {
43 pub fn new(device: Arc<WebGpuDevice>) -> Self {
45 Self {
46 device,
47 buffers: Mutex::new(HashMap::new()),
48 next_handle: AtomicU64::new(1),
49 }
50 }
51
52 pub fn alloc(&self, bytes: usize) -> WebGpuResult<u64> {
56 let size = bytes as u64;
57 let buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
58 label: Some("oxicuda-webgpu-buffer"),
59 size,
60 usage: wgpu::BufferUsages::STORAGE
61 | wgpu::BufferUsages::COPY_SRC
62 | wgpu::BufferUsages::COPY_DST,
63 mapped_at_creation: false,
64 });
65
66 let handle = self.next_handle.fetch_add(1, Ordering::Relaxed);
67
68 self.buffers
69 .lock()
70 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
71 .insert(handle, WebGpuBufferInfo { buffer, size });
72
73 Ok(handle)
74 }
75
76 pub fn free(&self, handle: u64) -> WebGpuResult<()> {
80 self.buffers
81 .lock()
82 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
83 .remove(&handle);
84 Ok(())
85 }
86
87 pub fn copy_to_device(&self, handle: u64, src: &[u8]) -> WebGpuResult<()> {
89 let buffers = self
90 .buffers
91 .lock()
92 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
93
94 let buf_info = buffers
95 .get(&handle)
96 .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
97
98 self.device.queue.write_buffer(&buf_info.buffer, 0, src);
99 Ok(())
100 }
101
102 pub(crate) fn lock_buffers(
107 &self,
108 ) -> WebGpuResult<std::sync::MutexGuard<'_, HashMap<u64, WebGpuBufferInfo>>> {
109 self.buffers
110 .lock()
111 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))
112 }
113
114 pub fn copy_from_device(&self, dst: &mut [u8], handle: u64) -> WebGpuResult<()> {
119 let staging = {
123 let buffers = self
124 .buffers
125 .lock()
126 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
127
128 let buf_info = buffers
129 .get(&handle)
130 .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
131
132 let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
133 label: Some("oxicuda-webgpu-staging"),
134 size: buf_info.size,
135 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
136 mapped_at_creation: false,
137 });
138
139 let mut encoder =
140 self.device
141 .device
142 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
143 label: Some("oxicuda-webgpu-readback"),
144 });
145
146 encoder.copy_buffer_to_buffer(&buf_info.buffer, 0, &staging, 0, buf_info.size);
147 self.device.queue.submit(std::iter::once(encoder.finish()));
148
149 staging
150 };
152
153 let slice = staging.slice(..);
155 let (tx, rx) = std::sync::mpsc::channel();
156 slice.map_async(wgpu::MapMode::Read, move |result| {
157 let _ = tx.send(result);
159 });
160
161 let _ = self.device.device.poll(wgpu::PollType::wait_indefinitely());
164
165 rx.recv()
166 .map_err(|_| WebGpuError::BufferMapping("channel closed before map completed".into()))?
167 .map_err(|e| WebGpuError::BufferMapping(format!("{e:?}")))?;
168
169 let data = slice.get_mapped_range();
170 let copy_len = dst.len().min(data.len());
171 dst[..copy_len].copy_from_slice(&data[..copy_len]);
172 drop(data);
173 staging.unmap();
174
175 Ok(())
176 }
177}
178
179impl std::fmt::Debug for WebGpuMemoryManager {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
182 write!(f, "WebGpuMemoryManager(buffers={})", count)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::WebGpuDevice;

    /// Returns a shared device, or `None` when no WebGPU device can be created.
    /// In that case the device-dependent tests below return early and pass.
    fn try_get_device() -> Option<Arc<WebGpuDevice>> {
        match WebGpuDevice::new() {
            Ok(device) => Some(Arc::new(device)),
            Err(_) => None,
        }
    }

    #[test]
    fn alloc_and_free_requires_device() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = WebGpuMemoryManager::new(device);

        let handle = manager.alloc(256).expect("alloc 256 bytes");
        assert_ne!(handle, 0);

        manager.free(handle).expect("free");
        // Freeing the same handle again must succeed without error.
        manager.free(handle).expect("double-free is a no-op");
    }

    #[test]
    fn copy_roundtrip_requires_device() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = WebGpuMemoryManager::new(device);

        let src = (0..64).collect::<Vec<u8>>();
        let handle = manager.alloc(src.len()).expect("alloc");
        manager.copy_to_device(handle, &src).expect("copy_to_device");

        let mut dst = vec![0u8; src.len()];
        manager
            .copy_from_device(&mut dst, handle)
            .expect("copy_from_device");

        assert_eq!(src, dst);
        manager.free(handle).expect("free");
    }

    #[test]
    fn unknown_handle_returns_error() {
        let Some(device) = try_get_device() else {
            return;
        };
        let manager = WebGpuMemoryManager::new(device);

        let outcome = manager.copy_to_device(9999, b"hello");
        assert!(matches!(outcome, Err(WebGpuError::InvalidArgument(_))));
    }
}