1pub const MATRIX_OPS_SHADER: &str = include_str!("advanced/matrix_ops.wgsl");
8
9pub const FFT_SHADER: &str = include_str!("advanced/fft.wgsl");
11
12pub const HISTOGRAM_EQ_SHADER: &str = include_str!("advanced/histogram_eq.wgsl");
14
15pub const MORPHOLOGY_SHADER: &str = include_str!("advanced/morphology.wgsl");
17
18pub const EDGE_DETECTION_SHADER: &str = include_str!("advanced/edge_detection.wgsl");
20
21pub const TEXTURE_ANALYSIS_SHADER: &str = include_str!("advanced/texture_analysis.wgsl");
23
24pub struct KernelRegistry {
26 shaders: std::collections::HashMap<String, String>,
27}
28
29impl KernelRegistry {
30 pub fn new() -> Self {
32 let mut shaders = std::collections::HashMap::new();
33
34 shaders.insert("matrix_ops".to_string(), MATRIX_OPS_SHADER.to_string());
35 shaders.insert("fft".to_string(), FFT_SHADER.to_string());
36 shaders.insert("histogram_eq".to_string(), HISTOGRAM_EQ_SHADER.to_string());
37 shaders.insert("morphology".to_string(), MORPHOLOGY_SHADER.to_string());
38 shaders.insert(
39 "edge_detection".to_string(),
40 EDGE_DETECTION_SHADER.to_string(),
41 );
42 shaders.insert(
43 "texture_analysis".to_string(),
44 TEXTURE_ANALYSIS_SHADER.to_string(),
45 );
46
47 Self { shaders }
48 }
49
50 pub fn get_shader(&self, name: &str) -> Option<&str> {
52 self.shaders.get(name).map(|s| s.as_str())
53 }
54
55 pub fn register_shader(&mut self, name: String, source: String) {
57 self.shaders.insert(name, source);
58 }
59
60 pub fn list_shaders(&self) -> Vec<&str> {
62 self.shaders.keys().map(|k| k.as_str()).collect()
63 }
64
65 pub fn has_shader(&self, name: &str) -> bool {
67 self.shaders.contains_key(name)
68 }
69
70 pub fn remove_shader(&mut self, name: &str) -> bool {
72 self.shaders.remove(name).is_some()
73 }
74
75 pub fn shader_count(&self) -> usize {
77 self.shaders.len()
78 }
79}
80
81impl Default for KernelRegistry {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
/// Launch configuration for a compute kernel: workgroup shape, number of
/// workgroups to dispatch, and the entry point to invoke.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelParams {
    /// Threads per workgroup along (x, y, z).
    ///
    /// NOTE(review): presumably must match the `@workgroup_size` declared by the
    /// WGSL entry point — confirm against the shaders.
    pub workgroup_size: (u32, u32, u32),
    /// Number of workgroups dispatched along (x, y, z).
    pub dispatch_size: (u32, u32, u32),
    /// Name of the shader entry point function to invoke.
    pub entry_point: String,
}

impl Default for KernelParams {
    /// An 8x8x1 workgroup, a single 1x1x1 dispatch, and entry point `"main"`.
    fn default() -> Self {
        Self {
            workgroup_size: (8, 8, 1),
            dispatch_size: (1, 1, 1),
            entry_point: "main".to_string(),
        }
    }
}

impl KernelParams {
    /// Creates parameters with the given workgroup and dispatch sizes and the
    /// default entry point `"main"`.
    pub fn new(workgroup_size: (u32, u32, u32), dispatch_size: (u32, u32, u32)) -> Self {
        Self {
            workgroup_size,
            dispatch_size,
            ..Self::default()
        }
    }

    /// Returns `self` with the workgroup size replaced by `(x, y, z)`.
    pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.workgroup_size = (x, y, z);
        self
    }

    /// Returns `self` with the dispatch size replaced by `(x, y, z)`.
    pub fn with_dispatch_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.dispatch_size = (x, y, z);
        self
    }

    /// Returns `self` with the entry point replaced by `entry_point`.
    pub fn with_entry_point(mut self, entry_point: impl Into<String>) -> Self {
        self.entry_point = entry_point.into();
        self
    }

    /// Total number of invocations launched: the product over all three axes of
    /// `workgroup_size * dispatch_size`.
    ///
    /// The result saturates at `u64::MAX` instead of overflowing. Each per-axis
    /// product of two `u32` values always fits in a `u64`, but combining the
    /// three axes can exceed it for large inputs.
    pub fn total_threads(&self) -> u64 {
        let (wg_x, wg_y, wg_z) = self.workgroup_size;
        let (d_x, d_y, d_z) = self.dispatch_size;

        let axis = |wg: u32, d: u32| u64::from(wg) * u64::from(d);
        axis(wg_x, d_x)
            .saturating_mul(axis(wg_y, d_y))
            .saturating_mul(axis(wg_z, d_z))
    }

    /// Computes how many workgroups cover a `data_width x data_height` 2D
    /// domain, rounding partial workgroups up.
    ///
    /// Only the x and y workgroup dimensions are used; the returned z dispatch
    /// is always 1.
    ///
    /// # Panics
    ///
    /// Panics if the x or y workgroup dimension is zero.
    pub fn calculate_dispatch_size(
        data_width: u32,
        data_height: u32,
        workgroup_size: (u32, u32, u32),
    ) -> (u32, u32, u32) {
        let (wg_x, wg_y, _wg_z) = workgroup_size;
        assert!(
            wg_x > 0 && wg_y > 0,
            "workgroup_size x and y must be non-zero, got {workgroup_size:?}"
        );

        (data_width.div_ceil(wg_x), data_height.div_ceil(wg_y), 1)
    }
}
159
160pub struct MatrixMultiplyKernel;
162
163impl MatrixMultiplyKernel {
164 pub fn shader() -> &'static str {
166 MATRIX_OPS_SHADER
167 }
168
169 pub fn params(m: u32, n: u32, _k: u32, tiled: bool) -> KernelParams {
171 if tiled {
172 let workgroup_size = (16, 16, 1);
173 let dispatch_x = n.div_ceil(16);
174 let dispatch_y = m.div_ceil(16);
175
176 KernelParams {
177 workgroup_size,
178 dispatch_size: (dispatch_x, dispatch_y, 1),
179 entry_point: "matrix_multiply_tiled".to_string(),
180 }
181 } else {
182 let workgroup_size = (8, 8, 1);
183 let dispatch_x = n.div_ceil(8);
184 let dispatch_y = m.div_ceil(8);
185
186 KernelParams {
187 workgroup_size,
188 dispatch_size: (dispatch_x, dispatch_y, 1),
189 entry_point: "matrix_multiply_naive".to_string(),
190 }
191 }
192 }
193}
194
195pub struct FftKernel;
197
198impl FftKernel {
199 pub fn shader() -> &'static str {
201 FFT_SHADER
202 }
203
204 pub fn params(n: u32) -> KernelParams {
206 let workgroup_size = (256, 1, 1);
207 let dispatch_size = (n.div_ceil(256), 1, 1);
208
209 KernelParams {
210 workgroup_size,
211 dispatch_size,
212 entry_point: "fft_cooley_tukey".to_string(),
213 }
214 }
215
216 pub fn num_stages(n: u32) -> u32 {
218 (n as f32).log2() as u32
219 }
220}
221
222pub struct HistogramEqKernel;
224
225impl HistogramEqKernel {
226 pub fn shader() -> &'static str {
228 HISTOGRAM_EQ_SHADER
229 }
230
231 pub fn compute_histogram_params(width: u32, height: u32) -> KernelParams {
233 let workgroup_size = (16, 16, 1);
234 let dispatch_x = width.div_ceil(16);
235 let dispatch_y = height.div_ceil(16);
236
237 KernelParams {
238 workgroup_size,
239 dispatch_size: (dispatch_x, dispatch_y, 1),
240 entry_point: "compute_histogram".to_string(),
241 }
242 }
243
244 pub fn equalize_params(width: u32, height: u32) -> KernelParams {
246 let workgroup_size = (16, 16, 1);
247 let dispatch_x = width.div_ceil(16);
248 let dispatch_y = height.div_ceil(16);
249
250 KernelParams {
251 workgroup_size,
252 dispatch_size: (dispatch_x, dispatch_y, 1),
253 entry_point: "histogram_equalize".to_string(),
254 }
255 }
256}
257
258pub struct EdgeDetectionKernel;
260
261impl EdgeDetectionKernel {
262 pub fn shader() -> &'static str {
264 EDGE_DETECTION_SHADER
265 }
266
267 pub fn sobel_params(width: u32, height: u32) -> KernelParams {
269 let workgroup_size = (16, 16, 1);
270 let dispatch_x = width.div_ceil(16);
271 let dispatch_y = height.div_ceil(16);
272
273 KernelParams {
274 workgroup_size,
275 dispatch_size: (dispatch_x, dispatch_y, 1),
276 entry_point: "sobel".to_string(),
277 }
278 }
279
280 pub fn canny_gradient_params(width: u32, height: u32) -> KernelParams {
282 let workgroup_size = (16, 16, 1);
283 let dispatch_x = width.div_ceil(16);
284 let dispatch_y = height.div_ceil(16);
285
286 KernelParams {
287 workgroup_size,
288 dispatch_size: (dispatch_x, dispatch_y, 1),
289 entry_point: "canny_gradient".to_string(),
290 }
291 }
292}
293
294pub struct MorphologyKernel;
296
297impl MorphologyKernel {
298 pub fn shader() -> &'static str {
300 MORPHOLOGY_SHADER
301 }
302
303 pub fn dilate_params(width: u32, height: u32) -> KernelParams {
305 let workgroup_size = (16, 16, 1);
306 let dispatch_x = width.div_ceil(16);
307 let dispatch_y = height.div_ceil(16);
308
309 KernelParams {
310 workgroup_size,
311 dispatch_size: (dispatch_x, dispatch_y, 1),
312 entry_point: "dilate".to_string(),
313 }
314 }
315
316 pub fn erode_params(width: u32, height: u32) -> KernelParams {
318 let workgroup_size = (16, 16, 1);
319 let dispatch_x = width.div_ceil(16);
320 let dispatch_y = height.div_ceil(16);
321
322 KernelParams {
323 workgroup_size,
324 dispatch_size: (dispatch_x, dispatch_y, 1),
325 entry_point: "erode".to_string(),
326 }
327 }
328}
329
330pub struct TextureAnalysisKernel;
332
333impl TextureAnalysisKernel {
334 pub fn shader() -> &'static str {
336 TEXTURE_ANALYSIS_SHADER
337 }
338
339 pub fn glcm_params(width: u32, height: u32) -> KernelParams {
341 let workgroup_size = (16, 16, 1);
342 let dispatch_x = width.div_ceil(16);
343 let dispatch_y = height.div_ceil(16);
344
345 KernelParams {
346 workgroup_size,
347 dispatch_size: (dispatch_x, dispatch_y, 1),
348 entry_point: "compute_glcm".to_string(),
349 }
350 }
351
352 pub fn lbp_params(width: u32, height: u32) -> KernelParams {
354 let workgroup_size = (16, 16, 1);
355 let dispatch_x = width.div_ceil(16);
356 let dispatch_y = height.div_ceil(16);
357
358 KernelParams {
359 workgroup_size,
360 dispatch_size: (dispatch_x, dispatch_y, 1),
361 entry_point: "local_binary_pattern".to_string(),
362 }
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_registry() {
        let registry = KernelRegistry::new();
        assert_eq!(registry.shader_count(), 6);
        assert!(registry.has_shader("matrix_ops"));
        assert!(registry.has_shader("fft"));
        assert!(registry.has_shader("histogram_eq"));
        assert!(registry.has_shader("morphology"));
        assert!(registry.has_shader("edge_detection"));
        assert!(registry.has_shader("texture_analysis"));
    }

    #[test]
    fn test_kernel_registry_lookup() {
        let registry = KernelRegistry::new();
        assert_eq!(registry.get_shader("fft"), Some(FFT_SHADER));
        assert_eq!(registry.get_shader("matrix_ops"), Some(MATRIX_OPS_SHADER));
        assert_eq!(registry.get_shader("does_not_exist"), None);

        // Sort locally so the assertion does not depend on HashMap iteration order.
        let mut names = registry.list_shaders();
        names.sort_unstable();
        assert_eq!(
            names,
            [
                "edge_detection",
                "fft",
                "histogram_eq",
                "matrix_ops",
                "morphology",
                "texture_analysis",
            ]
        );
    }

    #[test]
    fn test_kernel_registry_custom() {
        let mut registry = KernelRegistry::new();
        let initial_count = registry.shader_count();

        registry.register_shader("custom".to_string(), "custom shader code".to_string());
        assert_eq!(registry.shader_count(), initial_count + 1);
        assert!(registry.has_shader("custom"));
        assert_eq!(registry.get_shader("custom"), Some("custom shader code"));

        // Re-registering an existing name replaces the source without adding an entry.
        registry.register_shader("custom".to_string(), "replacement".to_string());
        assert_eq!(registry.shader_count(), initial_count + 1);
        assert_eq!(registry.get_shader("custom"), Some("replacement"));

        assert!(registry.remove_shader("custom"));
        assert!(!registry.remove_shader("custom"));
        assert_eq!(registry.shader_count(), initial_count);
    }

    #[test]
    fn test_kernel_params() {
        let params = KernelParams::default();
        assert_eq!(params.workgroup_size, (8, 8, 1));
        assert_eq!(params.dispatch_size, (1, 1, 1));
        assert_eq!(params.entry_point, "main");
    }

    #[test]
    fn test_kernel_params_builders() {
        let params = KernelParams::default()
            .with_workgroup_size(16, 16, 1)
            .with_dispatch_size(4, 2, 1)
            .with_entry_point("custom_entry");
        assert_eq!(params.workgroup_size, (16, 16, 1));
        assert_eq!(params.dispatch_size, (4, 2, 1));
        assert_eq!(params.entry_point, "custom_entry");
    }

    #[test]
    fn test_kernel_params_total_threads() {
        let params = KernelParams::new((8, 8, 1), (10, 10, 1));
        assert_eq!(params.total_threads(), 8 * 8 * 10 * 10);
    }

    #[test]
    fn test_calculate_dispatch_size() {
        let (dx, dy, dz) = KernelParams::calculate_dispatch_size(1920, 1080, (16, 16, 1));
        assert_eq!(dx, 1920_u32.div_ceil(16));
        assert_eq!(dy, 1080_u32.div_ceil(16));
        assert_eq!(dz, 1);

        // Exact multiples must not dispatch an extra workgroup.
        assert_eq!(KernelParams::calculate_dispatch_size(32, 32, (16, 16, 1)), (2, 2, 1));
    }

    #[test]
    fn test_matrix_multiply_kernel() {
        let tiled = MatrixMultiplyKernel::params(1024, 1024, 1024, true);
        assert_eq!(tiled.entry_point, "matrix_multiply_tiled");
        assert_eq!(tiled.workgroup_size, (16, 16, 1));
        assert_eq!(tiled.dispatch_size, (64, 64, 1));

        // Non-square output: columns (n) map to x, rows (m) map to y.
        let naive = MatrixMultiplyKernel::params(100, 50, 7, false);
        assert_eq!(naive.entry_point, "matrix_multiply_naive");
        assert_eq!(naive.workgroup_size, (8, 8, 1));
        assert_eq!(naive.dispatch_size, (7, 13, 1));
    }

    #[test]
    fn test_fft_kernel() {
        let params = FftKernel::params(1024);
        assert_eq!(params.entry_point, "fft_cooley_tukey");
        assert_eq!(params.workgroup_size, (256, 1, 1));
        assert_eq!(params.dispatch_size, (4, 1, 1));

        assert_eq!(FftKernel::num_stages(1024), 10);
        assert_eq!(FftKernel::num_stages(1), 0);
    }

    #[test]
    fn test_image_kernel_entry_points() {
        let cases = [
            (HistogramEqKernel::compute_histogram_params(1920, 1080), "compute_histogram"),
            (HistogramEqKernel::equalize_params(1920, 1080), "histogram_equalize"),
            (EdgeDetectionKernel::sobel_params(1920, 1080), "sobel"),
            (EdgeDetectionKernel::canny_gradient_params(1920, 1080), "canny_gradient"),
            (MorphologyKernel::dilate_params(1920, 1080), "dilate"),
            (MorphologyKernel::erode_params(1920, 1080), "erode"),
            (TextureAnalysisKernel::glcm_params(1920, 1080), "compute_glcm"),
            (TextureAnalysisKernel::lbp_params(1920, 1080), "local_binary_pattern"),
        ];

        for (params, entry_point) in &cases {
            assert_eq!(params.entry_point, *entry_point);
            assert_eq!(params.workgroup_size, (16, 16, 1));
            assert_eq!(params.dispatch_size, (120, 68, 1));
        }
    }

    #[test]
    fn test_all_shaders_available() {
        assert!(!MATRIX_OPS_SHADER.is_empty());
        assert!(!FFT_SHADER.is_empty());
        assert!(!HISTOGRAM_EQ_SHADER.is_empty());
        assert!(!MORPHOLOGY_SHADER.is_empty());
        assert!(!EDGE_DETECTION_SHADER.is_empty());
        assert!(!TEXTURE_ANALYSIS_SHADER.is_empty());

        assert_eq!(MatrixMultiplyKernel::shader(), MATRIX_OPS_SHADER);
        assert_eq!(FftKernel::shader(), FFT_SHADER);
        assert_eq!(HistogramEqKernel::shader(), HISTOGRAM_EQ_SHADER);
        assert_eq!(MorphologyKernel::shader(), MORPHOLOGY_SHADER);
        assert_eq!(EdgeDetectionKernel::shader(), EDGE_DETECTION_SHADER);
        assert_eq!(TextureAnalysisKernel::shader(), TEXTURE_ANALYSIS_SHADER);
    }
}