trueno/backends/gpu/batch/execute/mod.rs
//! GPU execution engine for batched operations
//!
//! Contains the `execute()` and `read()` public entry points, plus sub-modules
//! for operation dispatch and shader pipeline infrastructure.
//!
//! - [`dispatch`]: Pipeline-cached shader dispatch (`encode_unary_op`, `encode_binary_op`, etc.)
//! - [`operations`]: Per-operation routing (`encode_operation`)
//!
//! # KAIZEN-022: Pipeline caching + single encoder
//!
//! All operations in a batch share a single command encoder (one GPU submission)
//! and a pipeline cache (each shader is compiled once and reused by every
//! operation that needs it). For the Qwen3-4B FFN, this reduces 5 pipeline
//! compilations + 5 submissions per layer to 3 compilations (first layer only)
//! + 1 submission.
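//!
//! A minimal end-to-end sketch of the intended flow. The constructor and the
//! `upload` helper below are hypothetical placeholders, not part of this
//! module; `relu()`, `execute()`, and `read()` are the real entry points:
//!
//! ```ignore
//! let mut batch = GpuCommandBatch::new(device); // hypothetical constructor
//! let x = batch.upload(&input);                 // hypothetical staging helper
//! let y = batch.relu(x);                        // queue an operation
//! batch.execute().await?;                       // one encoder, one submit
//! let out = batch.read(y).await?;               // read results back
//! ```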

pub(crate) mod dispatch;
mod operations;

use super::{BufferId, GpuCommandBatch};
use std::sync::Arc;

impl GpuCommandBatch {
    /// Execute all queued operations on GPU
    ///
    /// Uses a single command encoder for all operations (one GPU submission)
    /// and caches pipelines per shader source to avoid redundant compilation.
    /// The pipeline cache is local to this call; see `execute_with_cache()`
    /// for persistent caching across multiple batch executions.
    ///
    /// # Contract (C-BATCH-EXEC-001)
    ///
    /// - **Precondition**: Operations queued via `matmul()`, `relu()`, etc.
    /// - **Postcondition**: All operations executed, results in GPU buffers
    /// - **Invariant**: Pipeline compiled at most once per unique shader source
    /// - **Invariant**: Single `queue.submit()` per `execute()` call
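    ///
    /// # Example
    ///
    /// A usage sketch; `x` is assumed to be a `BufferId` obtained when the
    /// input was staged (the staging step is hypothetical, `relu()` is named
    /// in the contract above):
    ///
    /// ```ignore
    /// let y = batch.relu(x);          // queue an operation
    /// batch.execute().await?;         // single encoder, single submission
    /// let out = batch.read(y).await?; // read the result back
    /// ```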
    pub async fn execute(&mut self) -> Result<(), String> {
        contract_pre_single_encoder_batch!();
        let mut local_cache = dispatch::PipelineCache::new();
        let result = self.execute_inner(&mut local_cache);
        contract_post_single_encoder_batch!(result);
        result
    }

    /// Execute with a persistent pipeline cache (KAIZEN-023).
    ///
    /// Same as `execute()` but uses a caller-provided pipeline cache that
    /// persists across multiple batch executions. Shaders compiled in a
    /// previous batch are reused without recompilation.
    ///
    /// For Qwen3-4B FFN (36 layers × 3 unique shaders per batch):
    /// - `execute()`: 3 compilations per layer × 36 layers = 108 total
    /// - `execute_with_cache()`: 3 compilations (layer 1) + 0 (layers 2-36) = 3 total
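    ///
    /// # Example
    ///
    /// A sketch of one cache shared across per-layer batches
    /// (`build_layer_batch` is a hypothetical helper):
    ///
    /// ```ignore
    /// let mut cache = dispatch::PipelineCache::new();
    /// for layer in 0..36 {
    ///     let mut batch = build_layer_batch(layer); // hypothetical
    ///     batch.execute_with_cache(&mut cache).await?;
    /// }
    /// ```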
    pub async fn execute_with_cache(
        &mut self,
        cache: &mut dispatch::PipelineCache,
    ) -> Result<(), String> {
        self.execute_inner(cache)
    }

    /// Shared implementation for `execute()` and `execute_with_cache()`.
    fn execute_inner(
        &mut self,
        pipeline_cache: &mut dispatch::PipelineCache,
    ) -> Result<(), String> {
        // Step 1: Create GPU buffers for all BufferIds.
        // Skip imported buffers; they are already GPU-resident (KAIZEN-015).
        for (buffer_id, buffer_info) in &mut self.buffers {
            if buffer_info.gpu_buffer.is_some() {
                continue;
            }

            let size_bytes = (buffer_info.size * std::mem::size_of::<f32>()) as u64;

            let gpu_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(&format!("Buffer {:?}", buffer_id)),
                size: size_bytes,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });

            buffer_info.gpu_buffer = Some(Arc::new(gpu_buffer));
        }

        // Step 2: Upload initial data to buffers that have it.
        // `write_buffer` only enqueues the copy; it lands on the GPU at the
        // start of the next `queue.submit()`, i.e. before the operations
        // encoded in Step 3 run.
        for buffer_info in self.buffers.values() {
            if let Some(data) = &buffer_info.data {
                if let Some(gpu_buffer) = &buffer_info.gpu_buffer {
                    self.device.queue.write_buffer(gpu_buffer, 0, bytemuck::cast_slice(data));
                }
            }
        }

        // Step 3: Encode all operations into a single command encoder
        // with cached pipelines (KAIZEN-022)
        let mut encoder =
            self.device.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Batch Encoder"),
            });

        for op in &self.operations {
            self.encode_operation(op, &mut encoder, pipeline_cache)?;
        }

        // Step 4: Single GPU submission for all operations
        self.device.queue.submit(Some(encoder.finish()));

        Ok(())
    }

    /// Read buffer data back from GPU
    ///
    /// Copies the buffer into a CPU-visible staging buffer, maps it, and
    /// returns the contents. `execute()` must be called first so the GPU
    /// buffer exists.
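    ///
    /// # Example
    ///
    /// A readback sketch; `out` is assumed to be a `BufferId` obtained when
    /// the operation producing it was queued:
    ///
    /// ```ignore
    /// batch.execute().await?;
    /// let values: Vec<f32> = batch.read(out).await?;
    /// ```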
    pub async fn read(&self, buffer_id: BufferId) -> Result<Vec<f32>, String> {
        contract_pre_read!();
        let buffer_info = self.buffers.get(&buffer_id).ok_or("Invalid buffer ID")?;

        let gpu_buffer = buffer_info
            .gpu_buffer
            .as_ref()
            .ok_or("Buffer not executed yet - call execute() first")?;

        let size_bytes = (buffer_info.size * std::mem::size_of::<f32>()) as u64;

        // Create staging buffer for reading
        let staging_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size: size_bytes,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Copy from GPU buffer to staging buffer
        let mut encoder =
            self.device.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Read Encoder"),
            });

        encoder.copy_buffer_to_buffer(gpu_buffer, 0, &staging_buffer, 0, size_bytes);

        self.device.queue.submit(Some(encoder.finish()));

        // Map the staging buffer for reading
        let buffer_slice = staging_buffer.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();

        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            sender.send(result).ok();
        });

        // Drive GPU work to completion; wgpu requires explicit polling
        // for map_async callbacks to fire
        self.device
            .device
            .poll(wgpu::PollType::Wait { submission_index: None, timeout: None })
            .map_err(|e| format!("GPU poll failed: {:?}", e))?;

        // Wait for mapping to complete
        receiver
            .receive()
            .await
            .ok_or("Failed to receive mapping result")?
            .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;

        // Read data from the mapped buffer; the scoped block drops the
        // mapped view before `unmap()` below
        let data = {
            let mapped_range = buffer_slice.get_mapped_range();
            let float_data: &[f32] = bytemuck::cast_slice(&mapped_range);
            float_data.to_vec()
        };

        staging_buffer.unmap();

        Ok(data)
    }
}
177}