Skip to main content

oxicuda_webgpu/
memory.rs

1//! WebGPU buffer manager — allocates, copies, and frees `wgpu::Buffer` objects
2//! through an opaque `u64` handle interface that mirrors the CUDA device-pointer
3//! model used by the rest of OxiCUDA.
4
5use std::{
6    collections::HashMap,
7    sync::{
8        Arc, Mutex,
9        atomic::{AtomicU64, Ordering},
10    },
11};
12
13use wgpu;
14
15use crate::{
16    device::WebGpuDevice,
17    error::{WebGpuError, WebGpuResult},
18};
19
20// ─── Buffer bookkeeping ──────────────────────────────────────────────────────
21
22/// Internal record for a single allocated `wgpu::Buffer`.
23pub struct WebGpuBufferInfo {
24    /// The GPU-resident buffer.
25    pub buffer: wgpu::Buffer,
26    /// Byte size of the buffer.
27    pub size: u64,
28}
29
30// ─── Memory manager ──────────────────────────────────────────────────────────
31
32/// Manages a pool of device-resident `wgpu::Buffer` objects, returning opaque
33/// `u64` handles to callers.
34///
35/// All public methods are `&self` to allow shared references from the backend.
36pub struct WebGpuMemoryManager {
37    device: Arc<WebGpuDevice>,
38    buffers: Mutex<HashMap<u64, WebGpuBufferInfo>>,
39    next_handle: AtomicU64,
40}
41
42impl WebGpuMemoryManager {
43    /// Create a new memory manager backed by `device`.
44    pub fn new(device: Arc<WebGpuDevice>) -> Self {
45        Self {
46            device,
47            buffers: Mutex::new(HashMap::new()),
48            next_handle: AtomicU64::new(1),
49        }
50    }
51
52    /// Allocate a new device buffer of `bytes` bytes.
53    ///
54    /// Returns an opaque handle that identifies the buffer.
55    pub fn alloc(&self, bytes: usize) -> WebGpuResult<u64> {
56        let size = bytes as u64;
57        let buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
58            label: Some("oxicuda-webgpu-buffer"),
59            size,
60            usage: wgpu::BufferUsages::STORAGE
61                | wgpu::BufferUsages::COPY_SRC
62                | wgpu::BufferUsages::COPY_DST,
63            mapped_at_creation: false,
64        });
65
66        let handle = self.next_handle.fetch_add(1, Ordering::Relaxed);
67
68        self.buffers
69            .lock()
70            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
71            .insert(handle, WebGpuBufferInfo { buffer, size });
72
73        Ok(handle)
74    }
75
76    /// Release the buffer associated with `handle`.
77    ///
78    /// The handle is silently ignored if it is unknown (already freed).
79    pub fn free(&self, handle: u64) -> WebGpuResult<()> {
80        self.buffers
81            .lock()
82            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
83            .remove(&handle);
84        Ok(())
85    }
86
87    /// Upload `src` (host bytes) into the device buffer identified by `handle`.
88    pub fn copy_to_device(&self, handle: u64, src: &[u8]) -> WebGpuResult<()> {
89        let buffers = self
90            .buffers
91            .lock()
92            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
93
94        let buf_info = buffers
95            .get(&handle)
96            .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
97
98        self.device.queue.write_buffer(&buf_info.buffer, 0, src);
99        Ok(())
100    }
101
102    /// Lock the internal buffer map and return a guard for direct access.
103    ///
104    /// Used by the backend to look up multiple buffers within a single lock scope
105    /// (e.g. when building wgpu bind groups for compute passes).
106    pub(crate) fn lock_buffers(
107        &self,
108    ) -> WebGpuResult<std::sync::MutexGuard<'_, HashMap<u64, WebGpuBufferInfo>>> {
109        self.buffers
110            .lock()
111            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))
112    }
113
114    /// Download the device buffer identified by `handle` into `dst` (host bytes).
115    ///
116    /// Uses a temporary `MAP_READ` staging buffer and blocks until the GPU
117    /// work completes.
118    pub fn copy_from_device(&self, dst: &mut [u8], handle: u64) -> WebGpuResult<()> {
119        // Phase 1: acquire the lock, build a staging buffer + command encoder,
120        // and submit the copy.  The lock is dropped at the end of this block so
121        // that `device.poll()` (Phase 2) does not hold the mutex.
122        let staging = {
123            let buffers = self
124                .buffers
125                .lock()
126                .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
127
128            let buf_info = buffers
129                .get(&handle)
130                .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
131
132            let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
133                label: Some("oxicuda-webgpu-staging"),
134                size: buf_info.size,
135                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
136                mapped_at_creation: false,
137            });
138
139            let mut encoder =
140                self.device
141                    .device
142                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
143                        label: Some("oxicuda-webgpu-readback"),
144                    });
145
146            encoder.copy_buffer_to_buffer(&buf_info.buffer, 0, &staging, 0, buf_info.size);
147            self.device.queue.submit(std::iter::once(encoder.finish()));
148
149            staging
150            // Mutex guard dropped here — lock released before poll.
151        };
152
153        // Phase 2: map the staging buffer and read the data back to the host.
154        let slice = staging.slice(..);
155        let (tx, rx) = std::sync::mpsc::channel();
156        slice.map_async(wgpu::MapMode::Read, move |result| {
157            // Ignore send errors — the receiver may have been dropped.
158            let _ = tx.send(result);
159        });
160
161        // Block the calling thread until all submitted GPU work (including the
162        // copy) is complete.
163        let _ = self.device.device.poll(wgpu::PollType::wait_indefinitely());
164
165        rx.recv()
166            .map_err(|_| WebGpuError::BufferMapping("channel closed before map completed".into()))?
167            .map_err(|e| WebGpuError::BufferMapping(format!("{e:?}")))?;
168
169        let data = slice.get_mapped_range();
170        let copy_len = dst.len().min(data.len());
171        dst[..copy_len].copy_from_slice(&data[..copy_len]);
172        drop(data);
173        staging.unmap();
174
175        Ok(())
176    }
177}
178
179impl std::fmt::Debug for WebGpuMemoryManager {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
182        write!(f, "WebGpuMemoryManager(buffers={})", count)
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::WebGpuDevice;

    /// Returns a shared device, or `None` when no WebGPU adapter is available.
    /// In that case the calling test returns early and counts as skipped.
    fn try_get_device() -> Option<Arc<WebGpuDevice>> {
        match WebGpuDevice::new() {
            Ok(device) => Some(Arc::new(device)),
            Err(_) => None,
        }
    }

    #[test]
    fn alloc_and_free_requires_device() {
        let dev = match try_get_device() {
            Some(dev) => dev,
            // No GPU — skip.
            None => return,
        };
        let manager = WebGpuMemoryManager::new(dev);
        let handle = manager.alloc(256).expect("alloc 256 bytes");
        assert_ne!(handle, 0);
        manager.free(handle).expect("free");
        // Releasing an already-freed handle must be a silent no-op.
        manager.free(handle).expect("double-free is a no-op");
    }

    #[test]
    fn copy_roundtrip_requires_device() {
        let dev = match try_get_device() {
            Some(dev) => dev,
            None => return,
        };
        let manager = WebGpuMemoryManager::new(dev);

        let payload: Vec<u8> = (0..64u8).collect();
        let handle = manager.alloc(payload.len()).expect("alloc");
        manager
            .copy_to_device(handle, &payload)
            .expect("copy_to_device");

        let mut readback = vec![0u8; payload.len()];
        manager
            .copy_from_device(&mut readback, handle)
            .expect("copy_from_device");

        assert_eq!(payload, readback);
        manager.free(handle).expect("free");
    }

    #[test]
    fn unknown_handle_returns_error() {
        let dev = match try_get_device() {
            Some(dev) => dev,
            None => return,
        };
        let manager = WebGpuMemoryManager::new(dev);
        let outcome = manager.copy_to_device(9999, b"hello");
        let err = outcome.unwrap_err();
        assert!(matches!(err, WebGpuError::InvalidArgument(_)));
    }
}