msm_webgpu/cuzk/
shader_manager.rs

1use handlebars::Handlebars;
2use once_cell::sync::Lazy;
3use serde_json::json;
4
5/// Decompose scalars shader
6pub static DECOMPOSE_SCALARS_SHADER: Lazy<String> =
7    Lazy::new(|| include_str!("wgsl/cuzk/decompose_scalars.template.wgsl").to_string());
8/// Extract word from bytes least significant end shader
9pub static EXTRACT_WORD_FROM_BYTES_LE_FUNCS: Lazy<String> =
10    Lazy::new(|| include_str!("wgsl/cuzk/extract_word_from_bytes_le.template.wgsl").to_string());
11/// Montgomery product shader
12pub static MONTGOMERY_PRODUCT_FUNCS: Lazy<String> =
13    Lazy::new(|| include_str!("wgsl/montgomery/mont_pro_product.template.wgsl").to_string());
14/// Barrett reduction shader
15pub static BARRETT_FUNCS: Lazy<String> =
16    Lazy::new(|| include_str!("wgsl/field/barrett.template.wgsl").to_string());
17/// Curve operations shader
18pub static EC_FUNCS: Lazy<String> =
19    Lazy::new(|| include_str!("wgsl/curve/ec.template.wgsl").to_string());
20/// Field operations shader
21pub static FIELD_FUNCS: Lazy<String> =
22    Lazy::new(|| include_str!("wgsl/field/field.template.wgsl").to_string());
23/// Big integer operations shader
24pub static BIGINT_FUNCS: Lazy<String> =
25    Lazy::new(|| include_str!("wgsl/bigint/bigint.template.wgsl").to_string());
26/// Structs shader
27pub static STRUCTS: Lazy<String> =
28    Lazy::new(|| include_str!("wgsl/struct/structs.template.wgsl").to_string());
29
30/// Transpose shader
31pub static TRANSPOSE_SHADER: Lazy<String> =
32    Lazy::new(|| include_str!("wgsl/cuzk/transpose.template.wgsl").to_string());
33/// Sparse matrix-vector product shader
34pub static SMVP_SHADER: Lazy<String> =
35    Lazy::new(|| include_str!("wgsl/cuzk/smvp.template.wgsl").to_string());
36/// Batch product reduction shader
37pub static BPR_SHADER: Lazy<String> =
38    Lazy::new(|| include_str!("wgsl/cuzk/bpr.template.wgsl").to_string());
39/// Test field shader
40pub static TEST_FIELD_SHADER: Lazy<String> =
41    Lazy::new(|| include_str!("wgsl/test/test_field.wgsl").to_string());
42/// Test point shader
43pub static TEST_POINT_SHADER: Lazy<String> =
44    Lazy::new(|| include_str!("wgsl/test/test_point.wgsl").to_string());
45
46use crate::cuzk::utils::{calc_bitwidth, gen_mu_limbs, gen_one_limbs, gen_p_limbs, gen_rinv_limbs};
47
48use super::{
49    msm::{P, PARAMS},
50    utils::{gen_p_limbs_plus_one, gen_r_limbs, gen_zero_limbs},
51};
52
/// Holds the parameters that are substituted into the WGSL shader templates.
///
/// All values are computed once by `ShaderManager::new` and then reused by
/// every `gen_*_shader` method.
pub struct ShaderManager {
    // --- Sizes supplied by the caller or taken from `PARAMS` ---
    /// Limb width in bits, as passed to `new`.
    word_size: usize,
    /// Chunk size, as passed to `new`.
    chunk_size: usize,
    /// Input size, as passed to `new`.
    input_size: usize,
    /// Number of limbs, copied from `PARAMS.num_words`.
    num_words: usize,

    // --- Values derived from the sizes above ---
    /// `1 << (chunk_size - 1)`.
    index_shift: usize,
    /// `num_words * word_size` minus the bit width of `P`.
    slack: usize,
    /// `(1 << word_size) - 1`.
    w_mask: usize,
    /// Copied from `PARAMS.n0`.
    n0: u32,

    // --- Pre-generated limb text, one string per template placeholder ---
    /// Output of `gen_p_limbs` for `P`.
    p_limbs: String,
    /// Output of `gen_p_limbs_plus_one` for `P`.
    p_limbs_plus_one: String,
    /// Output of `gen_zero_limbs`.
    zero_limbs: String,
    /// Output of `gen_one_limbs`.
    one_limbs: String,
    /// Output of `gen_r_limbs` for `PARAMS.r`.
    r_limbs: String,
    /// Output of `gen_mu_limbs` for `P`.
    mu_limbs: String,
    /// Output of `gen_rinv_limbs` for `PARAMS.rinv`.
    rinv_limbs: String,
}
71
72impl ShaderManager {
73    /// Create a new shader manager
74    pub fn new(word_size: usize, chunk_size: usize, input_size: usize) -> Self {
75        let p_bit_length = calc_bitwidth(&P);
76        let num_words = PARAMS.num_words;
77        let r = PARAMS.r.clone();
78        let rinv = PARAMS.rinv.clone();
79        println!("P: {P:?}");
80        println!("P limbs: {}", gen_p_limbs(&P, num_words, word_size));
81        println!("W_MASK: {:?}", (1 << word_size) - 1);
82        println!("R limbs: {}", gen_r_limbs(&r, num_words, word_size));
83        Self {
84            word_size,
85            chunk_size,
86            input_size,
87            num_words,
88            index_shift: 1 << (chunk_size - 1),
89            p_limbs: gen_p_limbs(&P, num_words, word_size),
90            p_limbs_plus_one: gen_p_limbs_plus_one(&P, num_words, word_size),
91            zero_limbs: gen_zero_limbs(num_words),
92            one_limbs: gen_one_limbs(num_words),
93            slack: num_words * word_size - p_bit_length,
94            w_mask: (1 << word_size) - 1,
95            n0: PARAMS.n0,
96            r_limbs: gen_r_limbs(&r, num_words, word_size),
97            mu_limbs: gen_mu_limbs(&P, num_words, word_size),
98            rinv_limbs: gen_rinv_limbs(&rinv, num_words, word_size),
99        }
100    }
101
102    /// Generate the transpose shader
103    pub fn gen_transpose_shader(&self, workgroup_size: usize) -> String {
104        let mut handlebars = Handlebars::new();
105        handlebars
106            .register_template_string("transpose", TRANSPOSE_SHADER.as_str())
107            .unwrap();
108        let data = json!({
109            "workgroup_size": workgroup_size,
110        });
111        handlebars.render("transpose", &data).unwrap()
112    }
113
114    /// Generate the sparse matrix-vector product shader
115    pub fn gen_smvp_shader(&self, workgroup_size: usize, num_csr_cols: usize) -> String {
116        println!("num_csr_cols: {num_csr_cols:?}");
117        println!("workgroup_size: {workgroup_size:?}");
118        let mut handlebars = Handlebars::new();
119        handlebars
120            .register_template_string("smvp", SMVP_SHADER.as_str())
121            .unwrap();
122
123        handlebars
124            .register_template_string("structs", STRUCTS.as_str())
125            .unwrap();
126        handlebars
127            .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
128            .unwrap();
129        handlebars
130            .register_template_string("ec_funcs", EC_FUNCS.as_str())
131            .unwrap();
132        handlebars
133            .register_template_string("field_funcs", FIELD_FUNCS.as_str())
134            .unwrap();
135        handlebars
136            .register_template_string(
137                "montgomery_product_funcs",
138                MONTGOMERY_PRODUCT_FUNCS.as_str(),
139            )
140            .unwrap();
141        handlebars
142            .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
143            .unwrap();
144
145        let data = json!({
146            "word_size": self.word_size,
147            "num_words": self.num_words,
148            "num_columns": num_csr_cols,
149            "workgroup_size": workgroup_size,
150            "n0": self.n0,
151            "p_limbs": self.p_limbs,
152            "p_limbs_plus_one": self.p_limbs_plus_one,
153            "zero_limbs": self.zero_limbs,
154            "one_limbs": self.one_limbs,
155            "r_limbs": self.r_limbs,
156            "w_mask": self.w_mask,
157            "index_shift": self.index_shift,
158            "half_num_columns": num_csr_cols / 2,
159            "num_words_mul_two": self.num_words * 2,
160            "num_words_plus_one": self.num_words + 1,
161            "mu_limbs": self.mu_limbs,
162            "slack": self.slack,
163            "rinv_limbs": self.rinv_limbs,
164            "input_size": self.input_size,
165        });
166        handlebars.render("smvp", &data).unwrap()
167    }
168
169    /// Generate the batch product reduction shader
170    pub fn gen_bpr_shader(&self, workgroup_size: usize) -> String {
171        let mut handlebars = Handlebars::new();
172        handlebars
173            .register_template_string("bpr", BPR_SHADER.as_str())
174            .unwrap();
175
176        handlebars
177            .register_template_string("structs", STRUCTS.as_str())
178            .unwrap();
179        handlebars
180            .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
181            .unwrap();
182        handlebars
183            .register_template_string("ec_funcs", EC_FUNCS.as_str())
184            .unwrap();
185        handlebars
186            .register_template_string("field_funcs", FIELD_FUNCS.as_str())
187            .unwrap();
188        handlebars
189            .register_template_string(
190                "montgomery_product_funcs",
191                MONTGOMERY_PRODUCT_FUNCS.as_str(),
192            )
193            .unwrap();
194        handlebars
195            .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
196            .unwrap();
197        let data = json!({
198            "workgroup_size": workgroup_size,
199            "word_size": self.word_size,
200            "num_words": self.num_words,
201            "n0": self.n0,
202            "p_limbs": self.p_limbs,
203            "p_limbs_plus_one": self.p_limbs_plus_one,
204            "zero_limbs": self.zero_limbs,
205            "one_limbs": self.one_limbs,
206            "r_limbs": self.r_limbs,
207            "w_mask": self.w_mask,
208            "index_shift": self.index_shift,
209            "num_words_mul_two": self.num_words * 2,
210            "num_words_plus_one": self.num_words + 1,
211            "mu_limbs": self.mu_limbs,
212            "slack": self.slack,
213            "rinv_limbs": self.rinv_limbs,
214            "input_size": self.input_size,
215        });
216        handlebars.render("bpr", &data).unwrap()
217    }
218
219    /// Generate the decompose scalars shader
220    pub fn gen_decomp_scalars_shader(
221        &self,
222        workgroup_size: usize,
223        num_y_workgroups: usize,
224        num_subtasks: usize,
225        num_columns: usize,
226    ) -> String {
227        println!("num_columns: {num_columns:?}");
228        println!("num_y_workgroups: {num_y_workgroups:?}");
229        println!("num_subtasks: {num_subtasks:?}");
230        println!("workgroup_size: {workgroup_size:?}");
231        let mut handlebars = Handlebars::new();
232        handlebars
233            .register_template_string("decomp_scalars", DECOMPOSE_SCALARS_SHADER.as_str())
234            .unwrap();
235
236        handlebars
237            .register_template_string("structs", STRUCTS.as_str())
238            .unwrap();
239        handlebars
240            .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
241            .unwrap();
242        handlebars
243            .register_template_string("field_funcs", FIELD_FUNCS.as_str())
244            .unwrap();
245        handlebars
246            .register_template_string(
247                "montgomery_product_funcs",
248                MONTGOMERY_PRODUCT_FUNCS.as_str(),
249            )
250            .unwrap();
251        handlebars
252            .register_template_string(
253                "extract_word_from_bytes_le_funcs",
254                EXTRACT_WORD_FROM_BYTES_LE_FUNCS.as_str(),
255            )
256            .unwrap();
257        handlebars
258            .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
259            .unwrap();
260        let data = json!({
261            "workgroup_size": workgroup_size,
262            "word_size": self.word_size,
263            "chunk_size": self.chunk_size,
264            "num_words": self.num_words,
265            "num_y_workgroups": num_y_workgroups,
266            "num_subtasks": num_subtasks,
267            "num_columns": num_columns,
268            "n0": self.n0,
269            "p_limbs": self.p_limbs,
270            "p_limbs_plus_one": self.p_limbs_plus_one,
271            "zero_limbs": self.zero_limbs,
272            "one_limbs": self.one_limbs,
273            "slack": self.slack,
274            "w_mask": self.w_mask,
275            "index_shift": self.index_shift,
276            "num_words_mul_two": self.num_words * 2,
277            "num_words_plus_one": self.num_words + 1,
278            "r_limbs": self.r_limbs,
279            "mu_limbs": self.mu_limbs,
280            "slack": self.slack,
281            "rinv_limbs": self.rinv_limbs,
282        });
283        handlebars.render("decomp_scalars", &data).unwrap()
284    }
285
286    /// Generate the test field shader
287    pub fn gen_test_field_shader(&self) -> String {
288        let mut handlebars = Handlebars::new();
289        handlebars
290            .register_template_string("test_field", TEST_FIELD_SHADER.as_str())
291            .unwrap();
292
293        handlebars
294            .register_template_string("structs", STRUCTS.as_str())
295            .unwrap();
296        handlebars
297            .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
298            .unwrap();
299        handlebars
300            .register_template_string("field_funcs", FIELD_FUNCS.as_str())
301            .unwrap();
302        handlebars
303            .register_template_string(
304                "montgomery_product_funcs",
305                MONTGOMERY_PRODUCT_FUNCS.as_str(),
306            )
307            .unwrap();
308        handlebars
309            .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
310            .unwrap();
311
312        let data = json!({
313            "word_size": self.word_size,
314            "num_words": self.num_words,
315            "p_limbs": self.p_limbs,
316            "p_limbs_plus_one": self.p_limbs_plus_one,
317            "zero_limbs": self.zero_limbs,
318            "one_limbs": self.one_limbs,
319            "r_limbs": self.r_limbs,
320            "w_mask": self.w_mask,
321            "num_words_mul_two": self.num_words * 2,
322            "num_words_plus_one": self.num_words + 1,
323            "n0": self.n0,
324            "mu_limbs": self.mu_limbs,
325            "slack": self.slack,
326            "rinv_limbs": self.rinv_limbs,
327        });
328        handlebars.render("test_field", &data).unwrap()
329    }
330
331    /// Generate the test point shader
332    pub fn gen_test_point_shader(&self) -> String {
333        let mut handlebars = Handlebars::new();
334        handlebars
335            .register_template_string("test_point", TEST_POINT_SHADER.as_str())
336            .unwrap();
337
338        handlebars
339            .register_template_string("structs", STRUCTS.as_str())
340            .unwrap();
341        handlebars
342            .register_template_string("bigint_funcs", BIGINT_FUNCS.as_str())
343            .unwrap();
344        handlebars
345            .register_template_string("field_funcs", FIELD_FUNCS.as_str())
346            .unwrap();
347        handlebars
348            .register_template_string(
349                "montgomery_product_funcs",
350                MONTGOMERY_PRODUCT_FUNCS.as_str(),
351            )
352            .unwrap();
353        handlebars
354            .register_template_string("ec_funcs", EC_FUNCS.as_str())
355            .unwrap();
356        handlebars
357            .register_template_string("barrett_funcs", BARRETT_FUNCS.as_str())
358            .unwrap();
359        let data = json!({
360            "word_size": self.word_size,
361            "num_words": self.num_words,
362            "p_limbs": self.p_limbs,
363            "p_limbs_plus_one": self.p_limbs_plus_one,
364            "zero_limbs": self.zero_limbs,
365            "one_limbs": self.one_limbs,
366            "r_limbs": self.r_limbs,
367            "w_mask": self.w_mask,
368            "num_words_mul_two": self.num_words * 2,
369            "num_words_plus_one": self.num_words + 1,
370            "n0": self.n0,
371            "mu_limbs": self.mu_limbs,
372            "slack": self.slack,
373            "rinv_limbs": self.rinv_limbs,
374        });
375        handlebars.render("test_point", &data).unwrap()
376    }
377}