Skip to main content

trueno/backends/gpu/batch/
mod.rs

//! Async GPU command batching for reduced transfer overhead
//!
//! This module provides an async API for GPU operations that batches multiple
//! operations together to minimize CPU↔GPU data transfers.
//!
//! # Motivation
//!
//! The synchronous GPU API transfers data for each operation:
//! ```text
//! vec.relu()      // Upload → GPU compute → Download
//! vec.scale(2.0)  // Upload → GPU compute → Download
//! vec.add(&other) // Upload → GPU compute → Download
//! Total: 6 transfers (3 up, 3 down)
//! ```
//!
//! The async batch API queues operations and executes them together:
//! ```text
//! batch.relu(input)
//! batch.scale(relu_out, 2.0)
//! batch.add(scaled, other)
//! batch.execute()  // Upload once → 3 GPU computes → Download once
//! Total: 2 transfers (1 up, 1 down)  // 3x reduction!
//! ```
//!
//! # Example
//!
//! ```rust,no_run
//! use trueno::backends::gpu::{GpuDevice, GpuCommandBatch};
//!
//! # async fn example() -> Result<(), String> {
//! let device = GpuDevice::new()?;
//! let mut batch = GpuCommandBatch::new(device);
//!
//! // Queue operations (no GPU execution yet)
//! let input = batch.upload(&[1.0, 2.0, -3.0, 4.0]);
//! let relu_out = batch.relu(input);
//! let scaled = batch.scale(relu_out, 2.0);
//! let other = batch.upload(&[0.5, 0.5, 0.5, 0.5]);
//! let final_out = batch.add(scaled, other);
//!
//! // Execute all operations in a single batch
//! batch.execute().await?;
//!
//! // Read final result
//! let result = batch.read(final_out).await?;
//! assert_eq!(result, vec![2.5, 4.5, 0.5, 8.5]);
//! # Ok(())
//! # }
//! ```
50
51mod execute;
52
53pub use execute::dispatch::PipelineCache;
54
55#[cfg(test)]
56mod tests;
57
58use super::GpuDevice;
59use std::collections::HashMap;
60use std::sync::Arc;
61use wgpu;
62
/// Opaque handle naming one buffer registered with a `GpuCommandBatch`.
///
/// Handles are plain integer indices handed out by the batch. They are cheap
/// to copy, comparable, and usable as hash-map keys.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BufferId(pub(crate) usize);
66
67/// GPU operation to be executed in a batch
68#[derive(Debug)]
69pub(crate) enum GpuOp {
70    /// ReLU activation: max(0, x)
71    Relu { input: BufferId, output: BufferId },
72
73    /// Scalar multiplication: x * scalar
74    Scale { input: BufferId, output: BufferId, scalar: f32 },
75
76    /// Element-wise addition: a + b
77    Add { a: BufferId, b: BufferId, output: BufferId },
78
79    /// Element-wise multiplication: a * b
80    Mul { a: BufferId, b: BufferId, output: BufferId },
81
82    /// Dot product: sum(a[i] * b[i])
83    Dot {
84        a: BufferId,
85        b: BufferId,
86        output: BufferId, // Single-element buffer for result
87    },
88
89    /// Sigmoid activation: 1 / (1 + exp(-x))
90    Sigmoid { input: BufferId, output: BufferId },
91
92    /// Hyperbolic tangent: tanh(x)
93    Tanh { input: BufferId, output: BufferId },
94
95    /// Swish activation: x * sigmoid(x)
96    Swish { input: BufferId, output: BufferId },
97
98    /// GELU activation: x * Φ(x) where Φ is cumulative distribution function
99    Gelu { input: BufferId, output: BufferId },
100
101    /// Element-wise subtraction: a - b
102    Sub { a: BufferId, b: BufferId, output: BufferId },
103
104    /// Matrix multiplication: C = A × B
105    /// A is M×K, B is K×N, C is M×N (all row-major)
106    Matmul { a: BufferId, b: BufferId, output: BufferId, m: u32, k: u32, n: u32 },
107}
108
109/// Command batch for async GPU execution
110///
111/// Accumulates GPU operations and executes them together to minimize
112/// CPU↔GPU data transfers.
113pub struct GpuCommandBatch {
114    pub(crate) device: Arc<GpuDevice>,
115    pub(crate) operations: Vec<GpuOp>,
116    pub(crate) buffers: HashMap<BufferId, BufferInfo>,
117    pub(crate) next_buffer_id: usize,
118}
119
120/// Information about a buffer in the batch
121#[derive(Debug)]
122pub(crate) struct BufferInfo {
123    /// Size in elements (f32)
124    pub(crate) size: usize,
125
126    /// Initial data to upload (if any)
127    pub(crate) data: Option<Vec<f32>>,
128
129    /// GPU buffer (created during execute(), or pre-existing for imported buffers).
130    /// Wrapped in `Arc` to allow sharing across multiple batch executions (KAIZEN-015).
131    /// When `Some`, execute() skips buffer creation (already GPU-resident).
132    pub(crate) gpu_buffer: Option<Arc<wgpu::Buffer>>,
133}
134
135impl GpuCommandBatch {
136    /// Create a new command batch
137    pub fn new(device: GpuDevice) -> Self {
138        Self {
139            device: Arc::new(device),
140            operations: Vec::new(),
141            buffers: HashMap::new(),
142            next_buffer_id: 0,
143        }
144    }
145
146    /// Allocate a new buffer ID
147    fn alloc_buffer(&mut self, size: usize, data: Option<Vec<f32>>) -> BufferId {
148        let id = BufferId(self.next_buffer_id);
149        self.next_buffer_id += 1;
150
151        self.buffers.insert(id, BufferInfo { size, data, gpu_buffer: None });
152
153        id
154    }
155
156    /// Upload data to GPU (queued for batch execution)
157    ///
158    /// Returns a buffer ID that can be used in subsequent operations.
159    pub fn upload(&mut self, data: &[f32]) -> BufferId {
160        self.alloc_buffer(data.len(), Some(data.to_vec()))
161    }
162
163    /// Allocate an output buffer for an operation
164    fn alloc_output(&mut self, size: usize) -> BufferId {
165        self.alloc_buffer(size, None)
166    }
167
168    /// Queue ReLU operation: max(0, x)
169    ///
170    /// Returns buffer ID for the output.
171    pub fn relu(&mut self, input: BufferId) -> BufferId {
172        let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
173
174        let output = self.alloc_output(size);
175
176        self.operations.push(GpuOp::Relu { input, output });
177
178        output
179    }
180
181    /// Queue scalar multiplication: x * scalar
182    ///
183    /// Returns buffer ID for the output.
184    pub fn scale(&mut self, input: BufferId, scalar: f32) -> BufferId {
185        let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
186
187        let output = self.alloc_output(size);
188
189        self.operations.push(GpuOp::Scale { input, output, scalar });
190
191        output
192    }
193
194    /// Queue element-wise addition: a + b
195    ///
196    /// Returns buffer ID for the output.
197    ///
198    /// # Panics
199    ///
200    /// Panics if buffers have different sizes.
201    pub fn add(&mut self, a: BufferId, b: BufferId) -> BufferId {
202        let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
203        let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
204
205        assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
206
207        let output = self.alloc_output(size_a);
208
209        self.operations.push(GpuOp::Add { a, b, output });
210
211        output
212    }
213
214    /// Queue element-wise multiplication: a * b
215    ///
216    /// Returns buffer ID for the output.
217    ///
218    /// # Panics
219    ///
220    /// Panics if buffers have different sizes.
221    pub fn mul(&mut self, a: BufferId, b: BufferId) -> BufferId {
222        let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
223        let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
224
225        assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
226
227        let output = self.alloc_output(size_a);
228
229        self.operations.push(GpuOp::Mul { a, b, output });
230
231        output
232    }
233
234    /// Queue dot product: sum(a[i] * b[i])
235    ///
236    /// Returns buffer ID for a single-element output buffer.
237    ///
238    /// # Panics
239    ///
240    /// Panics if buffers have different sizes.
241    pub fn dot(&mut self, a: BufferId, b: BufferId) -> BufferId {
242        let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
243        let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
244
245        assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
246
247        let output = self.alloc_output(1); // Dot product returns scalar
248
249        self.operations.push(GpuOp::Dot { a, b, output });
250
251        output
252    }
253
254    /// Queue sigmoid activation: 1 / (1 + exp(-x))
255    ///
256    /// Returns buffer ID for the output.
257    pub fn sigmoid(&mut self, input: BufferId) -> BufferId {
258        let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
259
260        let output = self.alloc_output(size);
261
262        self.operations.push(GpuOp::Sigmoid { input, output });
263
264        output
265    }
266
267    /// Queue hyperbolic tangent: tanh(x)
268    ///
269    /// Returns buffer ID for the output.
270    pub fn tanh(&mut self, input: BufferId) -> BufferId {
271        let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
272
273        let output = self.alloc_output(size);
274
275        self.operations.push(GpuOp::Tanh { input, output });
276
277        output
278    }
279
280    /// Queue Swish activation: x * sigmoid(x)
281    ///
282    /// Returns buffer ID for the output.
283    pub fn swish(&mut self, input: BufferId) -> BufferId {
284        let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
285
286        let output = self.alloc_output(size);
287
288        self.operations.push(GpuOp::Swish { input, output });
289
290        output
291    }
292
293    /// Queue GELU activation: x * Φ(x)
294    ///
295    /// Returns buffer ID for the output.
296    pub fn gelu(&mut self, input: BufferId) -> BufferId {
297        let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
298
299        let output = self.alloc_output(size);
300
301        self.operations.push(GpuOp::Gelu { input, output });
302
303        output
304    }
305
306    /// Queue element-wise subtraction: a - b
307    ///
308    /// Returns buffer ID for the output.
309    ///
310    /// # Panics
311    ///
312    /// Panics if buffers have different sizes.
313    pub fn sub(&mut self, a: BufferId, b: BufferId) -> BufferId {
314        let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
315        let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
316
317        assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
318
319        let output = self.alloc_output(size_a);
320
321        self.operations.push(GpuOp::Sub { a, b, output });
322
323        output
324    }
325
326    /// Queue matrix multiplication: C = A × B
327    ///
328    /// A is M×K elements, B is K×N elements, output is M×N elements.
329    /// All matrices are row-major flat arrays.
330    ///
331    /// Returns buffer ID for the M×N output.
332    ///
333    /// # Panics
334    ///
335    /// Panics if buffer sizes don't match the declared dimensions.
336    pub fn matmul(&mut self, a: BufferId, b: BufferId, m: u32, k: u32, n: u32) -> BufferId {
337        let size_a = self.buffers.get(&a).expect("Invalid buffer A ID").size;
338        let size_b = self.buffers.get(&b).expect("Invalid buffer B ID").size;
339
340        assert_eq!(
341            size_a,
342            (m * k) as usize,
343            "Buffer A size {} doesn't match M×K = {}",
344            size_a,
345            m * k
346        );
347        assert_eq!(
348            size_b,
349            (k * n) as usize,
350            "Buffer B size {} doesn't match K×N = {}",
351            size_b,
352            k * n
353        );
354
355        let output = self.alloc_output((m * n) as usize);
356
357        self.operations.push(GpuOp::Matmul { a, b, output, m, k, n });
358
359        output
360    }
361
362    /// Import a pre-existing GPU buffer for use in batch operations.
363    ///
364    /// Unlike `upload()` which copies host data to GPU during `execute()`,
365    /// imported buffers are already GPU-resident and skip the upload step.
366    /// The `Arc` wrapper allows the same buffer to be shared across multiple
367    /// batch executions without re-uploading (KAIZEN-015: GPU-resident weights).
368    ///
369    /// # Contract (C-BATCH-IMPORT-001)
370    ///
371    /// - **Precondition**: `buffer` is a valid `wgpu::Buffer` with STORAGE | COPY_SRC usage
372    /// - **Postcondition**: Returned `BufferId` can be used in all batch operations (matmul, etc.)
373    /// - **Invariant**: Imported buffer is NOT destroyed when the batch is dropped —
374    ///   the `Arc` keeps it alive as long as the caller retains a clone
375    pub fn import_buffer(&mut self, buffer: Arc<wgpu::Buffer>, size: usize) -> BufferId {
376        let id = BufferId(self.next_buffer_id);
377        self.next_buffer_id += 1;
378        self.buffers.insert(id, BufferInfo { size, data: None, gpu_buffer: Some(buffer) });
379        id
380    }
381
382    /// Get the underlying wgpu device for creating persistent buffers.
383    ///
384    /// Used to create `wgpu::Buffer` instances that outlive individual batch executions.
385    /// Created buffers can be registered via `import_buffer()`.
386    pub fn wgpu_device(&self) -> &wgpu::Device {
387        &self.device.device
388    }
389
390    /// Get the underlying wgpu queue for writing to persistent buffers.
391    pub fn wgpu_queue(&self) -> &wgpu::Queue {
392        &self.device.queue
393    }
394
395    /// Get number of queued operations
396    pub fn num_operations(&self) -> usize {
397        self.operations.len()
398    }
399
400    /// Get number of buffers
401    pub fn num_buffers(&self) -> usize {
402        self.buffers.len()
403    }
404}