// oxicuda_webgpu/memory.rs

//! WebGPU buffer manager — allocates, copies, and frees `wgpu::Buffer` objects
//! through an opaque `u64` handle interface that mirrors the CUDA device-pointer
//! model used by the rest of OxiCUDA.

use std::{
    collections::HashMap,
    sync::{
        Arc, Mutex,
        atomic::{AtomicU64, Ordering},
    },
};

use wgpu;

use crate::{
    device::WebGpuDevice,
    error::{WebGpuError, WebGpuResult},
};

// ─── Buffer bookkeeping ──────────────────────────────────────────────────────

/// Internal record for a single allocated `wgpu::Buffer`.
pub struct WebGpuBufferInfo {
    /// The GPU-resident buffer.
    pub buffer: wgpu::Buffer,
    /// Byte size of the allocation, as passed to `create_buffer`.
    pub size: u64,
}

// ─── Memory manager ──────────────────────────────────────────────────────────

/// Manages a pool of device-resident `wgpu::Buffer` objects, returning opaque
/// `u64` handles to callers.
///
/// All public methods are `&self` to allow shared references from the backend.
pub struct WebGpuMemoryManager {
    // Device/queue pair used to create buffers and issue copies.
    device: Arc<WebGpuDevice>,
    // Live allocations, keyed by the opaque handle handed out by `alloc`.
    buffers: Mutex<HashMap<u64, WebGpuBufferInfo>>,
    // Monotonic handle source; starts at 1 (see `new`), so 0 never appears.
    next_handle: AtomicU64,
}

42impl WebGpuMemoryManager {
43    /// Create a new memory manager backed by `device`.
44    pub fn new(device: Arc<WebGpuDevice>) -> Self {
45        Self {
46            device,
47            buffers: Mutex::new(HashMap::new()),
48            next_handle: AtomicU64::new(1),
49        }
50    }
51
52    /// Allocate a new device buffer of `bytes` bytes.
53    ///
54    /// Returns an opaque handle that identifies the buffer.
55    pub fn alloc(&self, bytes: usize) -> WebGpuResult<u64> {
56        let size = bytes as u64;
57        let buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
58            label: Some("oxicuda-webgpu-buffer"),
59            size,
60            usage: wgpu::BufferUsages::STORAGE
61                | wgpu::BufferUsages::COPY_SRC
62                | wgpu::BufferUsages::COPY_DST,
63            mapped_at_creation: false,
64        });
65
66        let handle = self.next_handle.fetch_add(1, Ordering::Relaxed);
67
68        self.buffers
69            .lock()
70            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
71            .insert(handle, WebGpuBufferInfo { buffer, size });
72
73        Ok(handle)
74    }
75
76    /// Release the buffer associated with `handle`.
77    ///
78    /// The handle is silently ignored if it is unknown (already freed).
79    pub fn free(&self, handle: u64) -> WebGpuResult<()> {
80        self.buffers
81            .lock()
82            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
83            .remove(&handle);
84        Ok(())
85    }
86
87    /// Upload `src` (host bytes) into the device buffer identified by `handle`.
88    pub fn copy_to_device(&self, handle: u64, src: &[u8]) -> WebGpuResult<()> {
89        let buffers = self
90            .buffers
91            .lock()
92            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
93
94        let buf_info = buffers
95            .get(&handle)
96            .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
97
98        self.device.queue.write_buffer(&buf_info.buffer, 0, src);
99        Ok(())
100    }
101
102    /// Download the device buffer identified by `handle` into `dst` (host bytes).
103    ///
104    /// Uses a temporary `MAP_READ` staging buffer and blocks until the GPU
105    /// work completes.
106    pub fn copy_from_device(&self, dst: &mut [u8], handle: u64) -> WebGpuResult<()> {
107        // Phase 1: acquire the lock, build a staging buffer + command encoder,
108        // and submit the copy.  The lock is dropped at the end of this block so
109        // that `device.poll()` (Phase 2) does not hold the mutex.
110        let staging = {
111            let buffers = self
112                .buffers
113                .lock()
114                .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
115
116            let buf_info = buffers
117                .get(&handle)
118                .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
119
120            let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
121                label: Some("oxicuda-webgpu-staging"),
122                size: buf_info.size,
123                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
124                mapped_at_creation: false,
125            });
126
127            let mut encoder =
128                self.device
129                    .device
130                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
131                        label: Some("oxicuda-webgpu-readback"),
132                    });
133
134            encoder.copy_buffer_to_buffer(&buf_info.buffer, 0, &staging, 0, buf_info.size);
135            self.device.queue.submit(std::iter::once(encoder.finish()));
136
137            staging
138            // Mutex guard dropped here — lock released before poll.
139        };
140
141        // Phase 2: map the staging buffer and read the data back to the host.
142        let slice = staging.slice(..);
143        let (tx, rx) = std::sync::mpsc::channel();
144        slice.map_async(wgpu::MapMode::Read, move |result| {
145            // Ignore send errors — the receiver may have been dropped.
146            let _ = tx.send(result);
147        });
148
149        // Block the calling thread until all submitted GPU work (including the
150        // copy) is complete.
151        let _ = self.device.device.poll(wgpu::PollType::wait_indefinitely());
152
153        rx.recv()
154            .map_err(|_| WebGpuError::BufferMapping("channel closed before map completed".into()))?
155            .map_err(|e| WebGpuError::BufferMapping(format!("{e:?}")))?;
156
157        let data = slice.get_mapped_range();
158        let copy_len = dst.len().min(data.len());
159        dst[..copy_len].copy_from_slice(&data[..copy_len]);
160        drop(data);
161        staging.unmap();
162
163        Ok(())
164    }
165}
166
167impl std::fmt::Debug for WebGpuMemoryManager {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
170        write!(f, "WebGpuMemoryManager(buffers={})", count)
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::WebGpuDevice;

    // Try to acquire a real device; `None` means no usable GPU is present,
    // in which case each test returns early (silent skip).
    fn try_get_device() -> Option<Arc<WebGpuDevice>> {
        WebGpuDevice::new().ok().map(Arc::new)
    }

    #[test]
    fn alloc_and_free_requires_device() {
        let Some(dev) = try_get_device() else {
            // No GPU — skip.
            return;
        };
        let mm = WebGpuMemoryManager::new(dev);
        let h = mm.alloc(256).expect("alloc 256 bytes");
        // Handles start at 1, so 0 is never returned.
        assert!(h > 0);
        mm.free(h).expect("free");
        // Double-free is silently ignored.
        mm.free(h).expect("double-free is a no-op");
    }

    #[test]
    fn copy_roundtrip_requires_device() {
        let Some(dev) = try_get_device() else {
            return;
        };
        let mm = WebGpuMemoryManager::new(dev);

        // 64 distinct bytes uploaded and read back verbatim.
        let src: Vec<u8> = (0u8..64).collect();
        let h = mm.alloc(src.len()).expect("alloc");
        mm.copy_to_device(h, &src).expect("copy_to_device");

        let mut dst = vec![0u8; src.len()];
        mm.copy_from_device(&mut dst, h).expect("copy_from_device");

        assert_eq!(src, dst);
        mm.free(h).expect("free");
    }

    #[test]
    fn unknown_handle_returns_error() {
        let Some(dev) = try_get_device() else {
            return;
        };
        let mm = WebGpuMemoryManager::new(dev);
        // Handle 9999 was never allocated; the copy must fail loudly.
        let err = mm.copy_to_device(9999, b"hello").unwrap_err();
        assert!(matches!(err, WebGpuError::InvalidArgument(_)));
    }
}