Skip to main content

trueno/backends/gpu/batch/execute/
operations.rs

//! Operation dispatch for GPU batch execution
//!
//! Contains `encode_operation()` which routes each `GpuOp` variant to the
//! appropriate cached pipeline via `encode_unary_op`, `encode_binary_op`,
//! or `encode_matmul_op`.
6
7use super::super::{GpuCommandBatch, GpuOp};
8use super::dispatch::PipelineCache;
9
10impl GpuCommandBatch {
11    /// Encode a single GPU operation into the command encoder.
12    ///
13    /// Uses the pipeline cache to avoid recompiling shaders for repeated operations.
14    pub(crate) fn encode_operation(
15        &self,
16        op: &GpuOp,
17        encoder: &mut wgpu::CommandEncoder,
18        cache: &mut PipelineCache,
19    ) -> Result<(), String> {
20        use crate::backends::gpu::shaders;
21
22        match op {
23            GpuOp::Relu { input, output } => {
24                let input_info = self.buffers.get(input).ok_or("Invalid input buffer ID")?;
25                let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
26
27                let input_buffer =
28                    input_info.gpu_buffer.as_ref().ok_or("Input buffer not created")?;
29                let output_buffer =
30                    output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
31
32                self.encode_unary_op::<()>(
33                    encoder,
34                    cache,
35                    shaders::RELU_SHADER,
36                    "ReLU",
37                    input_buffer,
38                    output_buffer,
39                    input_info.size,
40                    None,
41                )?;
42            }
43
44            GpuOp::Scale { input, output, scalar } => {
45                let input_info = self.buffers.get(input).ok_or("Invalid input buffer ID")?;
46                let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
47
48                let input_buffer =
49                    input_info.gpu_buffer.as_ref().ok_or("Input buffer not created")?;
50                let output_buffer =
51                    output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
52
53                #[repr(C)]
54                #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
55                struct ScaleParams {
56                    scalar: f32,
57                    _padding: [f32; 3],
58                }
59
60                let params = ScaleParams { scalar: *scalar, _padding: [0.0; 3] };
61
62                self.encode_unary_op(
63                    encoder,
64                    cache,
65                    shaders::SCALE_SHADER,
66                    "Scale",
67                    input_buffer,
68                    output_buffer,
69                    input_info.size,
70                    Some(&params),
71                )?;
72            }
73
74            GpuOp::Add { a, b, output } => {
75                self.encode_binary_op_for(
76                    encoder,
77                    cache,
78                    shaders::VEC_ADD_SHADER,
79                    "Add",
80                    a,
81                    b,
82                    output,
83                )?;
84            }
85
86            GpuOp::Mul { a, b, output } => {
87                self.encode_binary_op_for(
88                    encoder,
89                    cache,
90                    shaders::VEC_MUL_SHADER,
91                    "Mul",
92                    a,
93                    b,
94                    output,
95                )?;
96            }
97
98            GpuOp::Dot { a, b, output } => {
99                self.encode_binary_op_for(
100                    encoder,
101                    cache,
102                    shaders::DOT_PRODUCT_SHADER,
103                    "Dot",
104                    a,
105                    b,
106                    output,
107                )?;
108            }
109
110            GpuOp::Sigmoid { input, output } => {
111                self.encode_unary_op_for(
112                    encoder,
113                    cache,
114                    shaders::SIGMOID_SHADER,
115                    "Sigmoid",
116                    input,
117                    output,
118                )?;
119            }
120
121            GpuOp::Tanh { input, output } => {
122                self.encode_unary_op_for(
123                    encoder,
124                    cache,
125                    shaders::TANH_SHADER,
126                    "Tanh",
127                    input,
128                    output,
129                )?;
130            }
131
132            GpuOp::Swish { input, output } => {
133                self.encode_unary_op_for(
134                    encoder,
135                    cache,
136                    shaders::SWISH_SHADER,
137                    "Swish",
138                    input,
139                    output,
140                )?;
141            }
142
143            GpuOp::Gelu { input, output } => {
144                self.encode_unary_op_for(
145                    encoder,
146                    cache,
147                    shaders::GELU_SHADER,
148                    "GELU",
149                    input,
150                    output,
151                )?;
152            }
153
154            GpuOp::Sub { a, b, output } => {
155                self.encode_binary_op_for(
156                    encoder,
157                    cache,
158                    shaders::VEC_SUB_SHADER,
159                    "Sub",
160                    a,
161                    b,
162                    output,
163                )?;
164            }
165
166            GpuOp::Matmul { a, b, output, m, k, n } => {
167                self.encode_matmul_op(
168                    encoder,
169                    cache,
170                    shaders::MATMUL_SHADER,
171                    "Matmul",
172                    a,
173                    b,
174                    output,
175                    *m,
176                    *k,
177                    *n,
178                )?;
179            }
180        }
181
182        Ok(())
183    }
184
185    /// Helper to extract buffers and encode a unary operation (no params)
186    fn encode_unary_op_for(
187        &self,
188        encoder: &mut wgpu::CommandEncoder,
189        cache: &mut PipelineCache,
190        shader_source: &str,
191        label: &str,
192        input: &super::super::BufferId,
193        output: &super::super::BufferId,
194    ) -> Result<(), String> {
195        let input_info = self.buffers.get(input).ok_or("Invalid input buffer ID")?;
196        let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
197
198        let input_buffer = input_info.gpu_buffer.as_ref().ok_or("Input buffer not created")?;
199        let output_buffer = output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
200
201        self.encode_unary_op::<()>(
202            encoder,
203            cache,
204            shader_source,
205            label,
206            input_buffer,
207            output_buffer,
208            input_info.size,
209            None,
210        )
211    }
212
213    /// Helper to extract buffers and encode a binary operation
214    fn encode_binary_op_for(
215        &self,
216        encoder: &mut wgpu::CommandEncoder,
217        cache: &mut PipelineCache,
218        shader_source: &str,
219        label: &str,
220        a: &super::super::BufferId,
221        b: &super::super::BufferId,
222        output: &super::super::BufferId,
223    ) -> Result<(), String> {
224        let a_info = self.buffers.get(a).ok_or("Invalid buffer A ID")?;
225        let b_info = self.buffers.get(b).ok_or("Invalid buffer B ID")?;
226        let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
227
228        let a_buffer = a_info.gpu_buffer.as_ref().ok_or("Buffer A not created")?;
229        let b_buffer = b_info.gpu_buffer.as_ref().ok_or("Buffer B not created")?;
230        let output_buffer = output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
231
232        self.encode_binary_op(
233            encoder,
234            cache,
235            shader_source,
236            label,
237            a_buffer,
238            b_buffer,
239            output_buffer,
240            a_info.size,
241        )
242    }
243}