trueno/backends/gpu/batch/execute/
operations.rs1use super::super::{GpuCommandBatch, GpuOp};
8use super::dispatch::PipelineCache;
9
10impl GpuCommandBatch {
11 pub(crate) fn encode_operation(
15 &self,
16 op: &GpuOp,
17 encoder: &mut wgpu::CommandEncoder,
18 cache: &mut PipelineCache,
19 ) -> Result<(), String> {
20 use crate::backends::gpu::shaders;
21
22 match op {
23 GpuOp::Relu { input, output } => {
24 let input_info = self.buffers.get(input).ok_or("Invalid input buffer ID")?;
25 let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
26
27 let input_buffer =
28 input_info.gpu_buffer.as_ref().ok_or("Input buffer not created")?;
29 let output_buffer =
30 output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
31
32 self.encode_unary_op::<()>(
33 encoder,
34 cache,
35 shaders::RELU_SHADER,
36 "ReLU",
37 input_buffer,
38 output_buffer,
39 input_info.size,
40 None,
41 )?;
42 }
43
44 GpuOp::Scale { input, output, scalar } => {
45 let input_info = self.buffers.get(input).ok_or("Invalid input buffer ID")?;
46 let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
47
48 let input_buffer =
49 input_info.gpu_buffer.as_ref().ok_or("Input buffer not created")?;
50 let output_buffer =
51 output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
52
53 #[repr(C)]
54 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
55 struct ScaleParams {
56 scalar: f32,
57 _padding: [f32; 3],
58 }
59
60 let params = ScaleParams { scalar: *scalar, _padding: [0.0; 3] };
61
62 self.encode_unary_op(
63 encoder,
64 cache,
65 shaders::SCALE_SHADER,
66 "Scale",
67 input_buffer,
68 output_buffer,
69 input_info.size,
70 Some(¶ms),
71 )?;
72 }
73
74 GpuOp::Add { a, b, output } => {
75 self.encode_binary_op_for(
76 encoder,
77 cache,
78 shaders::VEC_ADD_SHADER,
79 "Add",
80 a,
81 b,
82 output,
83 )?;
84 }
85
86 GpuOp::Mul { a, b, output } => {
87 self.encode_binary_op_for(
88 encoder,
89 cache,
90 shaders::VEC_MUL_SHADER,
91 "Mul",
92 a,
93 b,
94 output,
95 )?;
96 }
97
98 GpuOp::Dot { a, b, output } => {
99 self.encode_binary_op_for(
100 encoder,
101 cache,
102 shaders::DOT_PRODUCT_SHADER,
103 "Dot",
104 a,
105 b,
106 output,
107 )?;
108 }
109
110 GpuOp::Sigmoid { input, output } => {
111 self.encode_unary_op_for(
112 encoder,
113 cache,
114 shaders::SIGMOID_SHADER,
115 "Sigmoid",
116 input,
117 output,
118 )?;
119 }
120
121 GpuOp::Tanh { input, output } => {
122 self.encode_unary_op_for(
123 encoder,
124 cache,
125 shaders::TANH_SHADER,
126 "Tanh",
127 input,
128 output,
129 )?;
130 }
131
132 GpuOp::Swish { input, output } => {
133 self.encode_unary_op_for(
134 encoder,
135 cache,
136 shaders::SWISH_SHADER,
137 "Swish",
138 input,
139 output,
140 )?;
141 }
142
143 GpuOp::Gelu { input, output } => {
144 self.encode_unary_op_for(
145 encoder,
146 cache,
147 shaders::GELU_SHADER,
148 "GELU",
149 input,
150 output,
151 )?;
152 }
153
154 GpuOp::Sub { a, b, output } => {
155 self.encode_binary_op_for(
156 encoder,
157 cache,
158 shaders::VEC_SUB_SHADER,
159 "Sub",
160 a,
161 b,
162 output,
163 )?;
164 }
165
166 GpuOp::Matmul { a, b, output, m, k, n } => {
167 self.encode_matmul_op(
168 encoder,
169 cache,
170 shaders::MATMUL_SHADER,
171 "Matmul",
172 a,
173 b,
174 output,
175 *m,
176 *k,
177 *n,
178 )?;
179 }
180 }
181
182 Ok(())
183 }
184
185 fn encode_unary_op_for(
187 &self,
188 encoder: &mut wgpu::CommandEncoder,
189 cache: &mut PipelineCache,
190 shader_source: &str,
191 label: &str,
192 input: &super::super::BufferId,
193 output: &super::super::BufferId,
194 ) -> Result<(), String> {
195 let input_info = self.buffers.get(input).ok_or("Invalid input buffer ID")?;
196 let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
197
198 let input_buffer = input_info.gpu_buffer.as_ref().ok_or("Input buffer not created")?;
199 let output_buffer = output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
200
201 self.encode_unary_op::<()>(
202 encoder,
203 cache,
204 shader_source,
205 label,
206 input_buffer,
207 output_buffer,
208 input_info.size,
209 None,
210 )
211 }
212
213 fn encode_binary_op_for(
215 &self,
216 encoder: &mut wgpu::CommandEncoder,
217 cache: &mut PipelineCache,
218 shader_source: &str,
219 label: &str,
220 a: &super::super::BufferId,
221 b: &super::super::BufferId,
222 output: &super::super::BufferId,
223 ) -> Result<(), String> {
224 let a_info = self.buffers.get(a).ok_or("Invalid buffer A ID")?;
225 let b_info = self.buffers.get(b).ok_or("Invalid buffer B ID")?;
226 let output_info = self.buffers.get(output).ok_or("Invalid output buffer ID")?;
227
228 let a_buffer = a_info.gpu_buffer.as_ref().ok_or("Buffer A not created")?;
229 let b_buffer = b_info.gpu_buffer.as_ref().ok_or("Buffer B not created")?;
230 let output_buffer = output_info.gpu_buffer.as_ref().ok_or("Output buffer not created")?;
231
232 self.encode_binary_op(
233 encoder,
234 cache,
235 shader_source,
236 label,
237 a_buffer,
238 b_buffer,
239 output_buffer,
240 a_info.size,
241 )
242 }
243}