Skip to main content

trueno/backends/gpu/batch/execute/
mod.rs

//! GPU execution engine for batched operations
//!
//! Contains the `execute()`, `execute_with_cache()` and `read()` public entry
//! points, plus sub-modules for operation dispatch and shader pipeline
//! infrastructure.
//!
//! - [`dispatch`]: Pipeline-cached shader dispatch (`encode_unary_op`, `encode_binary_op`, etc.)
//! - [`operations`]: Per-operation routing (`encode_operation`)
//!
//! # KAIZEN-022: Pipeline caching + single encoder
//!
//! All operations in a batch share a single command encoder (one GPU submission)
//! and a pipeline cache (shader compiled once, reused for all operations using
//! that shader). For Qwen3-4B FFN this reduces 5 pipeline compilations + 5
//! submissions per layer to 3 compilations + 1 submission per layer. With a
//! persistent cache (`execute_with_cache()`, KAIZEN-023) the 3 compilations
//! happen for the first layer only.

16pub(crate) mod dispatch;
17mod operations;
18
19use super::{BufferId, GpuCommandBatch};
20use std::sync::Arc;
21
22impl GpuCommandBatch {
23    /// Execute all queued operations on GPU
24    ///
25    /// Uses a single command encoder for all operations (one GPU submission)
26    /// and caches pipelines per shader source to avoid redundant compilation.
27    /// The pipeline cache is local to this call — see `execute_with_cache()`
28    /// for persistent caching across multiple batch executions.
29    ///
30    /// # Contract (C-BATCH-EXEC-001)
31    ///
32    /// - **Precondition**: Operations queued via `matmul()`, `relu()`, etc.
33    /// - **Postcondition**: All operations executed, results in GPU buffers
34    /// - **Invariant**: Pipeline compiled at most once per unique shader source
35    /// - **Invariant**: Single `queue.submit()` per `execute()` call
36    pub async fn execute(&mut self) -> Result<(), String> {
37        contract_pre_single_encoder_batch!();
38        let mut local_cache = dispatch::PipelineCache::new();
39        let result = self.execute_inner(&mut local_cache);
40        contract_post_single_encoder_batch!(result);
41        result
42    }
43
44    /// Execute with a persistent pipeline cache (KAIZEN-023).
45    ///
46    /// Same as `execute()` but uses a caller-provided pipeline cache that
47    /// persists across multiple batch executions.  Shaders compiled in a
48    /// previous batch are reused without recompilation.
49    ///
50    /// For Qwen3-4B FFN (36 layers × 3 unique shaders per batch):
51    /// - `execute()`: 3 compilations per layer × 36 = 108 total
52    /// - `execute_with_cache()`: 3 compilations (layer 1) + 0 (layers 2-36) = 3 total
53    pub async fn execute_with_cache(
54        &mut self,
55        cache: &mut dispatch::PipelineCache,
56    ) -> Result<(), String> {
57        self.execute_inner(cache)
58    }
59
60    /// Shared implementation for execute() and execute_with_cache().
61    fn execute_inner(
62        &mut self,
63        pipeline_cache: &mut dispatch::PipelineCache,
64    ) -> Result<(), String> {
65        // Step 1: Create GPU buffers for all BufferIds
66        // Skip imported buffers — already GPU-resident (KAIZEN-015)
67        for (buffer_id, buffer_info) in &mut self.buffers {
68            if buffer_info.gpu_buffer.is_some() {
69                continue;
70            }
71
72            let size_bytes = (buffer_info.size * std::mem::size_of::<f32>()) as u64;
73
74            let gpu_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
75                label: Some(&format!("Buffer {:?}", buffer_id)),
76                size: size_bytes,
77                usage: wgpu::BufferUsages::STORAGE
78                    | wgpu::BufferUsages::COPY_SRC
79                    | wgpu::BufferUsages::COPY_DST,
80                mapped_at_creation: false,
81            });
82
83            buffer_info.gpu_buffer = Some(Arc::new(gpu_buffer));
84        }
85
86        // Step 2: Upload initial data to buffers that have it
87        for buffer_info in self.buffers.values() {
88            if let Some(data) = &buffer_info.data {
89                if let Some(gpu_buffer) = &buffer_info.gpu_buffer {
90                    self.device.queue.write_buffer(gpu_buffer, 0, bytemuck::cast_slice(data));
91                }
92            }
93        }
94
95        // Step 3: Encode all operations into a single command encoder
96        // with cached pipelines (KAIZEN-022)
97        let mut encoder =
98            self.device.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
99                label: Some("Batch Encoder"),
100            });
101
102        for op in &self.operations {
103            self.encode_operation(op, &mut encoder, pipeline_cache)?;
104        }
105
106        // Step 4: Single GPU submission for all operations
107        self.device.queue.submit(Some(encoder.finish()));
108
109        Ok(())
110    }
111
112    /// Read buffer data back from GPU
113    ///
114    /// Must call `execute()` first.
115    pub async fn read(&self, buffer_id: BufferId) -> Result<Vec<f32>, String> {
116        contract_pre_read!();
117        let buffer_info = self.buffers.get(&buffer_id).ok_or("Invalid buffer ID")?;
118
119        let gpu_buffer = buffer_info
120            .gpu_buffer
121            .as_ref()
122            .ok_or("Buffer not executed yet - call execute() first")?;
123
124        let size_bytes = (buffer_info.size * std::mem::size_of::<f32>()) as u64;
125
126        // Create staging buffer for reading
127        let staging_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
128            label: Some("Staging Buffer"),
129            size: size_bytes,
130            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
131            mapped_at_creation: false,
132        });
133
134        // Copy from GPU buffer to staging buffer
135        let mut encoder =
136            self.device.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
137                label: Some("Read Encoder"),
138            });
139
140        encoder.copy_buffer_to_buffer(gpu_buffer, 0, &staging_buffer, 0, size_bytes);
141
142        self.device.queue.submit(Some(encoder.finish()));
143
144        // Map the staging buffer for reading
145        let buffer_slice = staging_buffer.slice(..);
146        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
147
148        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
149            sender.send(result).ok();
150        });
151
152        // Drive GPU work to completion — wgpu requires explicit polling
153        // for map_async callbacks to fire
154        self.device
155            .device
156            .poll(wgpu::PollType::Wait { submission_index: None, timeout: None })
157            .map_err(|e| format!("GPU poll failed: {:?}", e))?;
158
159        // Wait for mapping to complete
160        receiver
161            .receive()
162            .await
163            .ok_or("Failed to receive mapping result")?
164            .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
165
166        // Read data from mapped buffer
167        let data = {
168            let mapped_range = buffer_slice.get_mapped_range();
169            let float_data: &[f32] = bytemuck::cast_slice(&mapped_range);
170            float_data.to_vec()
171        };
172
173        staging_buffer.unmap();
174
175        Ok(data)
176    }
177}