1use wgpu::util::DeviceExt;
10
11use super::GpuContext;
12use super::curve::GpuCurve;
13
14impl<C: GpuCurve> GpuContext<C> {
15 pub fn execute_to_montgomery(
16 &self,
17 buffer: &wgpu::Buffer,
18 num_elements: u32,
19 ) {
20 let bind_group =
21 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
22 label: Some("To Montgomery Bind Group"),
23 layout: &self.montgomery_bind_group_layout,
24 entries: &[wgpu::BindGroupEntry {
25 binding: 0,
26 resource: buffer.as_entire_binding(),
27 }],
28 });
29 let mut encoder = self.device.create_command_encoder(
30 &wgpu::CommandEncoderDescriptor { label: None },
31 );
32 {
33 let mut cpass =
34 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
35 label: None,
36 timestamp_writes: None,
37 });
38 cpass.set_pipeline(&self.to_montgomery_pipeline);
39 cpass.set_bind_group(0, &bind_group, &[]);
40 cpass.dispatch_workgroups(
41 num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
42 1,
43 1,
44 );
45 }
46 self.queue.submit(Some(encoder.finish()));
47 }
48
49 pub fn execute_from_montgomery(
50 &self,
51 buffer: &wgpu::Buffer,
52 num_elements: u32,
53 ) {
54 let bind_group =
55 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
56 label: Some("From Montgomery Bind Group"),
57 layout: &self.montgomery_bind_group_layout,
58 entries: &[wgpu::BindGroupEntry {
59 binding: 0,
60 resource: buffer.as_entire_binding(),
61 }],
62 });
63 let mut encoder = self.device.create_command_encoder(
64 &wgpu::CommandEncoderDescriptor { label: None },
65 );
66 {
67 let mut cpass =
68 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
69 label: None,
70 timestamp_writes: None,
71 });
72 cpass.set_pipeline(&self.from_montgomery_pipeline);
73 cpass.set_bind_group(0, &bind_group, &[]);
74 cpass.dispatch_workgroups(
75 num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
76 1,
77 1,
78 );
79 }
80 self.queue.submit(Some(encoder.finish()));
81 }
82
83 pub fn execute_ntt(
84 &self,
85 data_buffer: &wgpu::Buffer,
86 twiddles_buffer: &wgpu::Buffer,
87 num_elements: u32,
88 ) {
89 if num_elements > C::NTT_TILE_SIZE {
90 self.execute_ntt_global(data_buffer, twiddles_buffer, num_elements);
91 return;
92 }
93
94 let bind_group =
95 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
96 label: Some("NTT Bind Group"),
97 layout: &self.ntt_bind_group_layout,
98 entries: &[
99 wgpu::BindGroupEntry {
100 binding: 0,
101 resource: data_buffer.as_entire_binding(),
102 },
103 wgpu::BindGroupEntry {
104 binding: 1,
105 resource: twiddles_buffer.as_entire_binding(),
106 },
107 ],
108 });
109 let mut encoder = self.device.create_command_encoder(
110 &wgpu::CommandEncoderDescriptor {
111 label: Some("NTT Encoder"),
112 },
113 );
114 {
115 let mut cpass =
116 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
117 label: Some("NTT Pass"),
118 timestamp_writes: None,
119 });
120 cpass.set_pipeline(&self.ntt_pipeline);
121 cpass.set_bind_group(0, &bind_group, &[]);
122 cpass.dispatch_workgroups(
123 num_elements.div_ceil(C::NTT_TILE_SIZE),
124 1,
125 1,
126 );
127 }
128 self.queue.submit(Some(encoder.finish()));
129 }
130
131 pub fn execute_ntt_global(
142 &self,
143 data_buffer: &wgpu::Buffer,
144 twiddles_buffer: &wgpu::Buffer,
145 num_elements: u32,
146 ) {
147 let mut log_n = 0u32;
148 let mut m = num_elements;
149 while m > 1 {
150 log_n += 1;
151 m >>= 1;
152 }
153
154 let mut encoder = self.device.create_command_encoder(
155 &wgpu::CommandEncoderDescriptor {
156 label: Some("NTT Global Encoder"),
157 },
158 );
159
160 let mut stage_params = [0u32; 4];
161 stage_params[0] = num_elements;
162 stage_params[2] = log_n;
163 let params_buf =
164 self.device
165 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
166 label: Some("NTT Params Buffer"),
167 contents: bytemuck::cast_slice(&stage_params),
168 usage: wgpu::BufferUsages::UNIFORM
169 | wgpu::BufferUsages::COPY_DST,
170 });
171
172 let make_bind_group = |params_buf: &wgpu::Buffer| {
173 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
174 label: Some("NTT Global Bind Group"),
175 layout: &self.ntt_params_bind_group_layout,
176 entries: &[
177 wgpu::BindGroupEntry {
178 binding: 0,
179 resource: data_buffer.as_entire_binding(),
180 },
181 wgpu::BindGroupEntry {
182 binding: 1,
183 resource: twiddles_buffer.as_entire_binding(),
184 },
185 wgpu::BindGroupEntry {
186 binding: 2,
187 resource: params_buf.as_entire_binding(),
188 },
189 ],
190 })
191 };
192
193 {
195 let bg = make_bind_group(¶ms_buf);
196 let mut pass =
197 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
198 label: Some("NTT BitReverse Pass"),
199 timestamp_writes: None,
200 });
201 pass.set_pipeline(&self.ntt_bitreverse_pipeline);
202 pass.set_bind_group(0, &bg, &[]);
203 pass.dispatch_workgroups(
204 num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
205 1,
206 1,
207 );
208 }
209
210 let mut half_len = 1u32;
212 let mut param_updates: Vec<wgpu::Buffer> = Vec::new();
213
214 if (log_n & 1) == 1 {
215 stage_params[1] = half_len;
216 let update_buf = self.device.create_buffer_init(
217 &wgpu::util::BufferInitDescriptor {
218 label: Some("NTT Params Update"),
219 contents: bytemuck::cast_slice(&stage_params),
220 usage: wgpu::BufferUsages::COPY_SRC,
221 },
222 );
223 encoder.copy_buffer_to_buffer(&update_buf, 0, ¶ms_buf, 0, 16);
224 param_updates.push(update_buf);
225
226 let bg = make_bind_group(¶ms_buf);
227 let mut pass =
228 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
229 label: Some("NTT Global Stage R2 Pass"),
230 timestamp_writes: None,
231 });
232 pass.set_pipeline(&self.ntt_global_stage_pipeline);
233 pass.set_bind_group(0, &bg, &[]);
234 pass.dispatch_workgroups(
235 (num_elements / 2).div_ceil(C::SCALAR_WORKGROUP_SIZE),
236 1,
237 1,
238 );
239
240 half_len = 2;
241 }
242
243 while half_len < num_elements {
244 stage_params[1] = half_len;
245 let update_buf = self.device.create_buffer_init(
246 &wgpu::util::BufferInitDescriptor {
247 label: Some("NTT Params Update"),
248 contents: bytemuck::cast_slice(&stage_params),
249 usage: wgpu::BufferUsages::COPY_SRC,
250 },
251 );
252 encoder.copy_buffer_to_buffer(&update_buf, 0, ¶ms_buf, 0, 16);
253 param_updates.push(update_buf);
254
255 let bg = make_bind_group(¶ms_buf);
256 let mut pass =
257 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
258 label: Some("NTT Global Stage R4 Pass"),
259 timestamp_writes: None,
260 });
261 pass.set_pipeline(&self.ntt_global_stage_radix4_pipeline);
262 pass.set_bind_group(0, &bg, &[]);
263 pass.dispatch_workgroups(
264 (num_elements / 4).div_ceil(C::SCALAR_WORKGROUP_SIZE),
265 1,
266 1,
267 );
268
269 half_len <<= 2;
270 }
271
272 self.queue.submit(Some(encoder.finish()));
273 }
274
275 pub fn execute_coset_shift(
276 &self,
277 data_buffer: &wgpu::Buffer,
278 shifts_buffer: &wgpu::Buffer,
279 num_elements: u32,
280 ) {
281 let bind_group =
282 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
283 label: Some("Coset Shift Bind Group"),
284 layout: &self.coset_shift_bind_group_layout,
285 entries: &[
286 wgpu::BindGroupEntry {
287 binding: 0,
288 resource: data_buffer.as_entire_binding(),
289 },
290 wgpu::BindGroupEntry {
291 binding: 1,
292 resource: shifts_buffer.as_entire_binding(),
293 },
294 ],
295 });
296 let mut encoder = self.device.create_command_encoder(
297 &wgpu::CommandEncoderDescriptor { label: None },
298 );
299 {
300 let mut cpass =
301 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
302 label: None,
303 timestamp_writes: None,
304 });
305 cpass.set_pipeline(&self.coset_shift_pipeline);
306 cpass.set_bind_group(0, &bind_group, &[]);
307 cpass.dispatch_workgroups(
308 num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
309 1,
310 1,
311 );
312 }
313 self.queue.submit(Some(encoder.finish()));
314 }
315
316 pub fn execute_pointwise_poly(
317 &self,
318 a_buf: &wgpu::Buffer,
319 b_buf: &wgpu::Buffer,
320 c_buf: &wgpu::Buffer,
321 h_buf: &wgpu::Buffer,
322 z_invs_buf: &wgpu::Buffer,
323 num_elements: u32,
324 ) {
325 let bind_group =
326 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
327 label: Some("Pointwise Poly Bind Group"),
328 layout: &self.pointwise_poly_bind_group_layout,
329 entries: &[
330 wgpu::BindGroupEntry {
331 binding: 0,
332 resource: a_buf.as_entire_binding(),
333 },
334 wgpu::BindGroupEntry {
335 binding: 1,
336 resource: b_buf.as_entire_binding(),
337 },
338 wgpu::BindGroupEntry {
339 binding: 2,
340 resource: c_buf.as_entire_binding(),
341 },
342 wgpu::BindGroupEntry {
343 binding: 3,
344 resource: h_buf.as_entire_binding(),
345 },
346 wgpu::BindGroupEntry {
347 binding: 4,
348 resource: z_invs_buf.as_entire_binding(),
349 },
350 ],
351 });
352 let mut encoder = self.device.create_command_encoder(
353 &wgpu::CommandEncoderDescriptor { label: None },
354 );
355 {
356 let mut cpass =
357 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
358 label: None,
359 timestamp_writes: None,
360 });
361 cpass.set_pipeline(&self.pointwise_poly_pipeline);
362 cpass.set_bind_group(0, &bind_group, &[]);
363 cpass.dispatch_workgroups(
364 num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
365 1,
366 1,
367 );
368 }
369 self.queue.submit(Some(encoder.finish()));
370 }
371}