trueno/backends/gpu/batch/mod.rs
1//! Async GPU command batching for reduced transfer overhead
2//!
3//! This module provides an async API for GPU operations that batches multiple
4//! operations together to minimize CPU↔GPU data transfers.
5//!
6//! # Motivation
7//!
8//! The synchronous GPU API transfers data for each operation:
9//! ```text
10//! vec.relu() // Upload → GPU compute → Download
11//! vec.scale(2.0) // Upload → GPU compute → Download
12//! vec.add(&other) // Upload → GPU compute → Download
13//! Total: 6 transfers (3 up, 3 down)
14//! ```
15//!
16//! The async batch API queues operations and executes them together:
17//! ```text
18//! batch.relu(input)
19//! batch.scale(relu_out, 2.0)
20//! batch.add(scaled, other)
21//! batch.execute() // Upload once → 3 GPU computes → Download once
22//! Total: 2 transfers (1 up, 1 down) // 3x reduction!
23//! ```
24//!
25//! # Example
26//!
27//! ```rust,no_run
28//! use trueno::backends::gpu::{GpuDevice, GpuCommandBatch};
29//!
30//! # async fn example() -> Result<(), String> {
31//! let device = GpuDevice::new()?;
32//! let mut batch = GpuCommandBatch::new(device);
33//!
34//! // Queue operations (no GPU execution yet)
35//! let input = batch.upload(&[1.0, 2.0, -3.0, 4.0]);
36//! let relu_out = batch.relu(input);
37//! let scaled = batch.scale(relu_out, 2.0);
38//! let other = batch.upload(&[0.5, 0.5, 0.5, 0.5]);
39//! let final_out = batch.add(scaled, other);
40//!
41//! // Execute all operations in single batch
42//! batch.execute().await?;
43//!
44//! // Read final result
45//! let result = batch.read(final_out).await?;
46//! assert_eq!(result, vec![2.5, 4.5, 0.5, 8.5]);
47//! # Ok(())
48//! # }
49//! ```
50
51mod execute;
52
53pub use execute::dispatch::PipelineCache;
54
55#[cfg(test)]
56mod tests;
57
58use super::GpuDevice;
59use std::collections::HashMap;
60use std::sync::Arc;
61use wgpu;
62
/// Unique identifier for a buffer in a batch.
///
/// Handed out sequentially by [`GpuCommandBatch`] (`upload`, `import_buffer`,
/// or internally for operation outputs). An ID is only meaningful within the
/// batch that created it — it indexes that batch's buffer table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BufferId(pub(crate) usize);
66
/// GPU operation to be executed in a batch.
///
/// Each variant records its operand buffer IDs and the pre-allocated output
/// buffer ID; the actual dispatch happens later in `execute()` (see the
/// `execute` submodule). Element-wise variants assume operands of equal
/// length — the queueing methods on `GpuCommandBatch` enforce this before
/// an op is pushed.
#[derive(Debug)]
pub(crate) enum GpuOp {
    /// ReLU activation: max(0, x)
    Relu { input: BufferId, output: BufferId },

    /// Scalar multiplication: x * scalar
    Scale { input: BufferId, output: BufferId, scalar: f32 },

    /// Element-wise addition: a + b
    Add { a: BufferId, b: BufferId, output: BufferId },

    /// Element-wise multiplication: a * b
    Mul { a: BufferId, b: BufferId, output: BufferId },

    /// Dot product: sum(a[i] * b[i])
    Dot {
        a: BufferId,
        b: BufferId,
        output: BufferId, // Single-element buffer for result
    },

    /// Sigmoid activation: 1 / (1 + exp(-x))
    Sigmoid { input: BufferId, output: BufferId },

    /// Hyperbolic tangent: tanh(x)
    Tanh { input: BufferId, output: BufferId },

    /// Swish activation: x * sigmoid(x)
    Swish { input: BufferId, output: BufferId },

    /// GELU activation: x * Φ(x) where Φ is cumulative distribution function
    Gelu { input: BufferId, output: BufferId },

    /// Element-wise subtraction: a - b
    Sub { a: BufferId, b: BufferId, output: BufferId },

    /// Matrix multiplication: C = A × B
    /// A is M×K, B is K×N, C is M×N (all row-major)
    Matmul { a: BufferId, b: BufferId, output: BufferId, m: u32, k: u32, n: u32 },
}
108
/// Command batch for async GPU execution.
///
/// Accumulates GPU operations and executes them together to minimize
/// CPU↔GPU data transfers. Queueing methods (`relu`, `add`, `matmul`, …)
/// only record work; nothing touches the GPU until `execute()` is awaited.
pub struct GpuCommandBatch {
    // Shared device handle; Arc lets execute()/read() futures hold a reference.
    pub(crate) device: Arc<GpuDevice>,
    // Queued ops in submission order — replayed as one batch by execute().
    pub(crate) operations: Vec<GpuOp>,
    // Buffer table keyed by the IDs handed out to callers.
    pub(crate) buffers: HashMap<BufferId, BufferInfo>,
    // Monotonic counter backing BufferId allocation.
    pub(crate) next_buffer_id: usize,
}
119
/// Information about a buffer in the batch.
///
/// A buffer is either host-seeded (`data: Some`, uploaded during `execute()`),
/// an operation output (`data: None`, allocated on the GPU during `execute()`),
/// or imported (`gpu_buffer: Some`, already GPU-resident).
#[derive(Debug)]
pub(crate) struct BufferInfo {
    /// Size in elements (f32)
    pub(crate) size: usize,

    /// Initial data to upload (if any)
    pub(crate) data: Option<Vec<f32>>,

    /// GPU buffer (created during execute(), or pre-existing for imported buffers).
    /// Wrapped in `Arc` to allow sharing across multiple batch executions (KAIZEN-015).
    /// When `Some`, execute() skips buffer creation (already GPU-resident).
    pub(crate) gpu_buffer: Option<Arc<wgpu::Buffer>>,
}
134
135impl GpuCommandBatch {
136 /// Create a new command batch
137 pub fn new(device: GpuDevice) -> Self {
138 Self {
139 device: Arc::new(device),
140 operations: Vec::new(),
141 buffers: HashMap::new(),
142 next_buffer_id: 0,
143 }
144 }
145
146 /// Allocate a new buffer ID
147 fn alloc_buffer(&mut self, size: usize, data: Option<Vec<f32>>) -> BufferId {
148 let id = BufferId(self.next_buffer_id);
149 self.next_buffer_id += 1;
150
151 self.buffers.insert(id, BufferInfo { size, data, gpu_buffer: None });
152
153 id
154 }
155
156 /// Upload data to GPU (queued for batch execution)
157 ///
158 /// Returns a buffer ID that can be used in subsequent operations.
159 pub fn upload(&mut self, data: &[f32]) -> BufferId {
160 self.alloc_buffer(data.len(), Some(data.to_vec()))
161 }
162
163 /// Allocate an output buffer for an operation
164 fn alloc_output(&mut self, size: usize) -> BufferId {
165 self.alloc_buffer(size, None)
166 }
167
168 /// Queue ReLU operation: max(0, x)
169 ///
170 /// Returns buffer ID for the output.
171 pub fn relu(&mut self, input: BufferId) -> BufferId {
172 let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
173
174 let output = self.alloc_output(size);
175
176 self.operations.push(GpuOp::Relu { input, output });
177
178 output
179 }
180
181 /// Queue scalar multiplication: x * scalar
182 ///
183 /// Returns buffer ID for the output.
184 pub fn scale(&mut self, input: BufferId, scalar: f32) -> BufferId {
185 let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
186
187 let output = self.alloc_output(size);
188
189 self.operations.push(GpuOp::Scale { input, output, scalar });
190
191 output
192 }
193
194 /// Queue element-wise addition: a + b
195 ///
196 /// Returns buffer ID for the output.
197 ///
198 /// # Panics
199 ///
200 /// Panics if buffers have different sizes.
201 pub fn add(&mut self, a: BufferId, b: BufferId) -> BufferId {
202 let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
203 let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
204
205 assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
206
207 let output = self.alloc_output(size_a);
208
209 self.operations.push(GpuOp::Add { a, b, output });
210
211 output
212 }
213
214 /// Queue element-wise multiplication: a * b
215 ///
216 /// Returns buffer ID for the output.
217 ///
218 /// # Panics
219 ///
220 /// Panics if buffers have different sizes.
221 pub fn mul(&mut self, a: BufferId, b: BufferId) -> BufferId {
222 let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
223 let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
224
225 assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
226
227 let output = self.alloc_output(size_a);
228
229 self.operations.push(GpuOp::Mul { a, b, output });
230
231 output
232 }
233
234 /// Queue dot product: sum(a[i] * b[i])
235 ///
236 /// Returns buffer ID for a single-element output buffer.
237 ///
238 /// # Panics
239 ///
240 /// Panics if buffers have different sizes.
241 pub fn dot(&mut self, a: BufferId, b: BufferId) -> BufferId {
242 let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
243 let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
244
245 assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
246
247 let output = self.alloc_output(1); // Dot product returns scalar
248
249 self.operations.push(GpuOp::Dot { a, b, output });
250
251 output
252 }
253
254 /// Queue sigmoid activation: 1 / (1 + exp(-x))
255 ///
256 /// Returns buffer ID for the output.
257 pub fn sigmoid(&mut self, input: BufferId) -> BufferId {
258 let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
259
260 let output = self.alloc_output(size);
261
262 self.operations.push(GpuOp::Sigmoid { input, output });
263
264 output
265 }
266
267 /// Queue hyperbolic tangent: tanh(x)
268 ///
269 /// Returns buffer ID for the output.
270 pub fn tanh(&mut self, input: BufferId) -> BufferId {
271 let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
272
273 let output = self.alloc_output(size);
274
275 self.operations.push(GpuOp::Tanh { input, output });
276
277 output
278 }
279
280 /// Queue Swish activation: x * sigmoid(x)
281 ///
282 /// Returns buffer ID for the output.
283 pub fn swish(&mut self, input: BufferId) -> BufferId {
284 let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
285
286 let output = self.alloc_output(size);
287
288 self.operations.push(GpuOp::Swish { input, output });
289
290 output
291 }
292
293 /// Queue GELU activation: x * Φ(x)
294 ///
295 /// Returns buffer ID for the output.
296 pub fn gelu(&mut self, input: BufferId) -> BufferId {
297 let size = self.buffers.get(&input).expect("Invalid buffer ID").size;
298
299 let output = self.alloc_output(size);
300
301 self.operations.push(GpuOp::Gelu { input, output });
302
303 output
304 }
305
306 /// Queue element-wise subtraction: a - b
307 ///
308 /// Returns buffer ID for the output.
309 ///
310 /// # Panics
311 ///
312 /// Panics if buffers have different sizes.
313 pub fn sub(&mut self, a: BufferId, b: BufferId) -> BufferId {
314 let size_a = self.buffers.get(&a).expect("Invalid buffer ID").size;
315 let size_b = self.buffers.get(&b).expect("Invalid buffer ID").size;
316
317 assert_eq!(size_a, size_b, "Buffer size mismatch: {} vs {}", size_a, size_b);
318
319 let output = self.alloc_output(size_a);
320
321 self.operations.push(GpuOp::Sub { a, b, output });
322
323 output
324 }
325
326 /// Queue matrix multiplication: C = A × B
327 ///
328 /// A is M×K elements, B is K×N elements, output is M×N elements.
329 /// All matrices are row-major flat arrays.
330 ///
331 /// Returns buffer ID for the M×N output.
332 ///
333 /// # Panics
334 ///
335 /// Panics if buffer sizes don't match the declared dimensions.
336 pub fn matmul(&mut self, a: BufferId, b: BufferId, m: u32, k: u32, n: u32) -> BufferId {
337 let size_a = self.buffers.get(&a).expect("Invalid buffer A ID").size;
338 let size_b = self.buffers.get(&b).expect("Invalid buffer B ID").size;
339
340 assert_eq!(
341 size_a,
342 (m * k) as usize,
343 "Buffer A size {} doesn't match M×K = {}",
344 size_a,
345 m * k
346 );
347 assert_eq!(
348 size_b,
349 (k * n) as usize,
350 "Buffer B size {} doesn't match K×N = {}",
351 size_b,
352 k * n
353 );
354
355 let output = self.alloc_output((m * n) as usize);
356
357 self.operations.push(GpuOp::Matmul { a, b, output, m, k, n });
358
359 output
360 }
361
362 /// Import a pre-existing GPU buffer for use in batch operations.
363 ///
364 /// Unlike `upload()` which copies host data to GPU during `execute()`,
365 /// imported buffers are already GPU-resident and skip the upload step.
366 /// The `Arc` wrapper allows the same buffer to be shared across multiple
367 /// batch executions without re-uploading (KAIZEN-015: GPU-resident weights).
368 ///
369 /// # Contract (C-BATCH-IMPORT-001)
370 ///
371 /// - **Precondition**: `buffer` is a valid `wgpu::Buffer` with STORAGE | COPY_SRC usage
372 /// - **Postcondition**: Returned `BufferId` can be used in all batch operations (matmul, etc.)
373 /// - **Invariant**: Imported buffer is NOT destroyed when the batch is dropped —
374 /// the `Arc` keeps it alive as long as the caller retains a clone
375 pub fn import_buffer(&mut self, buffer: Arc<wgpu::Buffer>, size: usize) -> BufferId {
376 let id = BufferId(self.next_buffer_id);
377 self.next_buffer_id += 1;
378 self.buffers.insert(id, BufferInfo { size, data: None, gpu_buffer: Some(buffer) });
379 id
380 }
381
382 /// Get the underlying wgpu device for creating persistent buffers.
383 ///
384 /// Used to create `wgpu::Buffer` instances that outlive individual batch executions.
385 /// Created buffers can be registered via `import_buffer()`.
386 pub fn wgpu_device(&self) -> &wgpu::Device {
387 &self.device.device
388 }
389
390 /// Get the underlying wgpu queue for writing to persistent buffers.
391 pub fn wgpu_queue(&self) -> &wgpu::Queue {
392 &self.device.queue
393 }
394
395 /// Get number of queued operations
396 pub fn num_operations(&self) -> usize {
397 self.operations.len()
398 }
399
400 /// Get number of buffers
401 pub fn num_buffers(&self) -> usize {
402 self.buffers.len()
403 }
404}