1use handlebars::Handlebars;
2use once_cell::sync::Lazy;
3use serde_json::json;
4
5pub static DECOMPOSE_SCALARS_SHADER: Lazy<String> =
7 Lazy::new(|| include_str!("wgsl/cuzk/decompose_scalars.template.wgsl").to_string());
8pub static EXTRACT_WORD_FROM_BYTES_LE_FUNCS: Lazy<String> =
10 Lazy::new(|| include_str!("wgsl/cuzk/extract_word_from_bytes_le.template.wgsl").to_string());
11pub static MONTGOMERY_PRODUCT_FUNCS: Lazy<String> =
13 Lazy::new(|| include_str!("wgsl/montgomery/mont_pro_product.template.wgsl").to_string());
14pub static BARRETT_FUNCS: Lazy<String> =
16 Lazy::new(|| include_str!("wgsl/field/barrett.template.wgsl").to_string());
17pub static EC_FUNCS: Lazy<String> =
19 Lazy::new(|| include_str!("wgsl/curve/ec.template.wgsl").to_string());
20pub static FIELD_FUNCS: Lazy<String> =
22 Lazy::new(|| include_str!("wgsl/field/field.template.wgsl").to_string());
23pub static BIGINT_FUNCS: Lazy<String> =
25 Lazy::new(|| include_str!("wgsl/bigint/bigint.template.wgsl").to_string());
26pub static STRUCTS: Lazy<String> =
28 Lazy::new(|| include_str!("wgsl/struct/structs.template.wgsl").to_string());
29
30pub static TRANSPOSE_SHADER: Lazy<String> =
32 Lazy::new(|| include_str!("wgsl/cuzk/transpose.template.wgsl").to_string());
33pub static SMVP_SHADER: Lazy<String> =
35 Lazy::new(|| include_str!("wgsl/cuzk/smvp.template.wgsl").to_string());
36pub static BPR_SHADER: Lazy<String> =
38 Lazy::new(|| include_str!("wgsl/cuzk/bpr.template.wgsl").to_string());
39pub static TEST_FIELD_SHADER: Lazy<String> =
41 Lazy::new(|| include_str!("wgsl/test/test_field.wgsl").to_string());
42pub static TEST_POINT_SHADER: Lazy<String> =
44 Lazy::new(|| include_str!("wgsl/test/test_point.wgsl").to_string());
45
46use crate::cuzk::utils::{calc_bitwidth, gen_mu_limbs, gen_one_limbs, gen_p_limbs, gen_rinv_limbs};
47
48use super::{
49 msm::{P, PARAMS},
50 utils::{gen_p_limbs_plus_one, gen_r_limbs, gen_zero_limbs},
51};
52
/// Holds the values substituted into the shader templates.
///
/// All fields are filled in by `ShaderManager::new`; the limb fields are the
/// strings returned by the corresponding `gen_*_limbs` helpers and are passed
/// to the templates unchanged.
#[derive(Debug, Clone)]
pub struct ShaderManager {
    /// Word size passed to the limb generators; also used for `w_mask` and `slack`.
    word_size: usize,
    /// Chunk size passed to the scalar-decomposition template.
    chunk_size: usize,
    /// Input size passed to the SMVP and BPR templates.
    input_size: usize,
    /// Number of words per big integer, taken from `PARAMS.num_words`.
    num_words: usize,
    /// `1 << (chunk_size - 1)`.
    index_shift: usize,
    /// Output of `gen_p_limbs` for the modulus `P`.
    p_limbs: String,
    /// Output of `gen_p_limbs_plus_one` for the modulus `P`.
    p_limbs_plus_one: String,
    /// Output of `gen_zero_limbs`.
    zero_limbs: String,
    /// Output of `gen_one_limbs`.
    one_limbs: String,
    /// Output of `gen_r_limbs` for `PARAMS.r`.
    r_limbs: String,
    /// `num_words * word_size` minus the bit width of `P`.
    slack: usize,
    /// `(1 << word_size) - 1`.
    w_mask: usize,
    /// Copied from `PARAMS.n0`.
    n0: u32,
    /// Output of `gen_mu_limbs` for the modulus `P`.
    mu_limbs: String,
    /// Output of `gen_rinv_limbs` for `PARAMS.rinv`.
    rinv_limbs: String,
}
71
72impl ShaderManager {
73 pub fn new(word_size: usize, chunk_size: usize, input_size: usize) -> Self {
75 let p_bit_length = calc_bitwidth(&P);
76 let num_words = PARAMS.num_words;
77 let r = PARAMS.r.clone();
78 let rinv = PARAMS.rinv.clone();
79 println!("P: {P:?}");
80 println!("P limbs: {}", gen_p_limbs(&P, num_words, word_size));
81 println!("W_MASK: {:?}", (1 << word_size) - 1);
82 println!("R limbs: {}", gen_r_limbs(&r, num_words, word_size));
83 Self {
84 word_size,
85 chunk_size,
86 input_size,
87 num_words,
88 index_shift: 1 << (chunk_size - 1),
89 p_limbs: gen_p_limbs(&P, num_words, word_size),
90 p_limbs_plus_one: gen_p_limbs_plus_one(&P, num_words, word_size),
91 zero_limbs: gen_zero_limbs(num_words),
92 one_limbs: gen_one_limbs(num_words),
93 slack: num_words * word_size - p_bit_length,
94 w_mask: (1 << word_size) - 1,
95 n0: PARAMS.n0,
96 r_limbs: gen_r_limbs(&r, num_words, word_size),
97 mu_limbs: gen_mu_limbs(&P, num_words, word_size),
98 rinv_limbs: gen_rinv_limbs(&rinv, num_words, word_size),
99 }
100 }
101
102 pub fn gen_transpose_shader(&self, workgroup_size: usize) -> String {
104 let mut handlebars = Handlebars::new();
105 handlebars
106 .register_template_string("transpose", TRANSPOSE_SHADER.as_str())
107 .unwrap();
108 let data = json!({
109 "workgroup_size": workgroup_size,
110 });
111 handlebars.render("transpose", &data).unwrap()
112 }
113
114 pub fn gen_smvp_shader(&self, workgroup_size: usize, num_csr_cols: usize) -> String {
116 println!("num_csr_cols: {num_csr_cols:?}");
117 println!("workgroup_size: {workgroup_size:?}");
118 let mut handlebars = Handlebars::new();
119 handlebars
120 .register_template_string("smvp", SMVP_SHADER.as_str())
121 .unwrap();
122
123 handlebars
124 .register_template_string("structs", STRUCTS.as_str())
125 .unwrap();
126 handlebars
127 .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
128 .unwrap();
129 handlebars
130 .register_template_string("ec_funcs", EC_FUNCS.as_str())
131 .unwrap();
132 handlebars
133 .register_template_string("field_funcs", FIELD_FUNCS.as_str())
134 .unwrap();
135 handlebars
136 .register_template_string(
137 "montgomery_product_funcs",
138 MONTGOMERY_PRODUCT_FUNCS.as_str(),
139 )
140 .unwrap();
141 handlebars
142 .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
143 .unwrap();
144
145 let data = json!({
146 "word_size": self.word_size,
147 "num_words": self.num_words,
148 "num_columns": num_csr_cols,
149 "workgroup_size": workgroup_size,
150 "n0": self.n0,
151 "p_limbs": self.p_limbs,
152 "p_limbs_plus_one": self.p_limbs_plus_one,
153 "zero_limbs": self.zero_limbs,
154 "one_limbs": self.one_limbs,
155 "r_limbs": self.r_limbs,
156 "w_mask": self.w_mask,
157 "index_shift": self.index_shift,
158 "half_num_columns": num_csr_cols / 2,
159 "num_words_mul_two": self.num_words * 2,
160 "num_words_plus_one": self.num_words + 1,
161 "mu_limbs": self.mu_limbs,
162 "slack": self.slack,
163 "rinv_limbs": self.rinv_limbs,
164 "input_size": self.input_size,
165 });
166 handlebars.render("smvp", &data).unwrap()
167 }
168
169 pub fn gen_bpr_shader(&self, workgroup_size: usize) -> String {
171 let mut handlebars = Handlebars::new();
172 handlebars
173 .register_template_string("bpr", BPR_SHADER.as_str())
174 .unwrap();
175
176 handlebars
177 .register_template_string("structs", STRUCTS.as_str())
178 .unwrap();
179 handlebars
180 .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
181 .unwrap();
182 handlebars
183 .register_template_string("ec_funcs", EC_FUNCS.as_str())
184 .unwrap();
185 handlebars
186 .register_template_string("field_funcs", FIELD_FUNCS.as_str())
187 .unwrap();
188 handlebars
189 .register_template_string(
190 "montgomery_product_funcs",
191 MONTGOMERY_PRODUCT_FUNCS.as_str(),
192 )
193 .unwrap();
194 handlebars
195 .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
196 .unwrap();
197 let data = json!({
198 "workgroup_size": workgroup_size,
199 "word_size": self.word_size,
200 "num_words": self.num_words,
201 "n0": self.n0,
202 "p_limbs": self.p_limbs,
203 "p_limbs_plus_one": self.p_limbs_plus_one,
204 "zero_limbs": self.zero_limbs,
205 "one_limbs": self.one_limbs,
206 "r_limbs": self.r_limbs,
207 "w_mask": self.w_mask,
208 "index_shift": self.index_shift,
209 "num_words_mul_two": self.num_words * 2,
210 "num_words_plus_one": self.num_words + 1,
211 "mu_limbs": self.mu_limbs,
212 "slack": self.slack,
213 "rinv_limbs": self.rinv_limbs,
214 "input_size": self.input_size,
215 });
216 handlebars.render("bpr", &data).unwrap()
217 }
218
219 pub fn gen_decomp_scalars_shader(
221 &self,
222 workgroup_size: usize,
223 num_y_workgroups: usize,
224 num_subtasks: usize,
225 num_columns: usize,
226 ) -> String {
227 println!("num_columns: {num_columns:?}");
228 println!("num_y_workgroups: {num_y_workgroups:?}");
229 println!("num_subtasks: {num_subtasks:?}");
230 println!("workgroup_size: {workgroup_size:?}");
231 let mut handlebars = Handlebars::new();
232 handlebars
233 .register_template_string("decomp_scalars", DECOMPOSE_SCALARS_SHADER.as_str())
234 .unwrap();
235
236 handlebars
237 .register_template_string("structs", STRUCTS.as_str())
238 .unwrap();
239 handlebars
240 .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
241 .unwrap();
242 handlebars
243 .register_template_string("field_funcs", FIELD_FUNCS.as_str())
244 .unwrap();
245 handlebars
246 .register_template_string(
247 "montgomery_product_funcs",
248 MONTGOMERY_PRODUCT_FUNCS.as_str(),
249 )
250 .unwrap();
251 handlebars
252 .register_template_string(
253 "extract_word_from_bytes_le_funcs",
254 EXTRACT_WORD_FROM_BYTES_LE_FUNCS.as_str(),
255 )
256 .unwrap();
257 handlebars
258 .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
259 .unwrap();
260 let data = json!({
261 "workgroup_size": workgroup_size,
262 "word_size": self.word_size,
263 "chunk_size": self.chunk_size,
264 "num_words": self.num_words,
265 "num_y_workgroups": num_y_workgroups,
266 "num_subtasks": num_subtasks,
267 "num_columns": num_columns,
268 "n0": self.n0,
269 "p_limbs": self.p_limbs,
270 "p_limbs_plus_one": self.p_limbs_plus_one,
271 "zero_limbs": self.zero_limbs,
272 "one_limbs": self.one_limbs,
273 "slack": self.slack,
274 "w_mask": self.w_mask,
275 "index_shift": self.index_shift,
276 "num_words_mul_two": self.num_words * 2,
277 "num_words_plus_one": self.num_words + 1,
278 "r_limbs": self.r_limbs,
279 "mu_limbs": self.mu_limbs,
280 "slack": self.slack,
281 "rinv_limbs": self.rinv_limbs,
282 });
283 handlebars.render("decomp_scalars", &data).unwrap()
284 }
285
286 pub fn gen_test_field_shader(&self) -> String {
288 let mut handlebars = Handlebars::new();
289 handlebars
290 .register_template_string("test_field", TEST_FIELD_SHADER.as_str())
291 .unwrap();
292
293 handlebars
294 .register_template_string("structs", STRUCTS.as_str())
295 .unwrap();
296 handlebars
297 .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
298 .unwrap();
299 handlebars
300 .register_template_string("field_funcs", FIELD_FUNCS.as_str())
301 .unwrap();
302 handlebars
303 .register_template_string(
304 "montgomery_product_funcs",
305 MONTGOMERY_PRODUCT_FUNCS.as_str(),
306 )
307 .unwrap();
308 handlebars
309 .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
310 .unwrap();
311
312 let data = json!({
313 "word_size": self.word_size,
314 "num_words": self.num_words,
315 "p_limbs": self.p_limbs,
316 "p_limbs_plus_one": self.p_limbs_plus_one,
317 "zero_limbs": self.zero_limbs,
318 "one_limbs": self.one_limbs,
319 "r_limbs": self.r_limbs,
320 "w_mask": self.w_mask,
321 "num_words_mul_two": self.num_words * 2,
322 "num_words_plus_one": self.num_words + 1,
323 "n0": self.n0,
324 "mu_limbs": self.mu_limbs,
325 "slack": self.slack,
326 "rinv_limbs": self.rinv_limbs,
327 });
328 handlebars.render("test_field", &data).unwrap()
329 }
330
331 pub fn gen_test_point_shader(&self) -> String {
333 let mut handlebars = Handlebars::new();
334 handlebars
335 .register_template_string("test_point", TEST_POINT_SHADER.as_str())
336 .unwrap();
337
338 handlebars
339 .register_template_string("structs", STRUCTS.as_str())
340 .unwrap();
341 handlebars
342 .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
343 .unwrap();
344 handlebars
345 .register_template_string("field_funcs", FIELD_FUNCS.as_str())
346 .unwrap();
347 handlebars
348 .register_template_string(
349 "montgomery_product_funcs",
350 MONTGOMERY_PRODUCT_FUNCS.as_str(),
351 )
352 .unwrap();
353 handlebars
354 .register_template_string("ec_funcs", EC_FUNCS.as_str())
355 .unwrap();
356 handlebars
357 .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
358 .unwrap();
359 let data = json!({
360 "word_size": self.word_size,
361 "num_words": self.num_words,
362 "p_limbs": self.p_limbs,
363 "p_limbs_plus_one": self.p_limbs_plus_one,
364 "zero_limbs": self.zero_limbs,
365 "one_limbs": self.one_limbs,
366 "r_limbs": self.r_limbs,
367 "w_mask": self.w_mask,
368 "num_words_mul_two": self.num_words * 2,
369 "num_words_plus_one": self.num_words + 1,
370 "n0": self.n0,
371 "mu_limbs": self.mu_limbs,
372 "slack": self.slack,
373 "rinv_limbs": self.rinv_limbs,
374 });
375 handlebars.render("test_point", &data).unwrap()
376 }
377}