Skip to main content

oxigdal_gpu_advanced/kernels/
mod.rs

1//! Advanced GPU compute kernels and WGSL shaders.
2//!
3//! This module provides access to optimized WGSL shaders for various
4//! geospatial and image processing operations.
5
6/// WGSL shader for matrix operations (GEMM)
7pub const MATRIX_OPS_SHADER: &str = include_str!("advanced/matrix_ops.wgsl");
8
9/// WGSL shader for Fast Fourier Transform
10pub const FFT_SHADER: &str = include_str!("advanced/fft.wgsl");
11
12/// WGSL shader for histogram equalization
13pub const HISTOGRAM_EQ_SHADER: &str = include_str!("advanced/histogram_eq.wgsl");
14
15/// WGSL shader for morphological operations
16pub const MORPHOLOGY_SHADER: &str = include_str!("advanced/morphology.wgsl");
17
18/// WGSL shader for edge detection
19pub const EDGE_DETECTION_SHADER: &str = include_str!("advanced/edge_detection.wgsl");
20
21/// WGSL shader for texture analysis
22pub const TEXTURE_ANALYSIS_SHADER: &str = include_str!("advanced/texture_analysis.wgsl");
23
24/// Kernel registry for managing and accessing shaders
25pub struct KernelRegistry {
26    shaders: std::collections::HashMap<String, String>,
27}
28
29impl KernelRegistry {
30    /// Create a new kernel registry with built-in shaders
31    pub fn new() -> Self {
32        let mut shaders = std::collections::HashMap::new();
33
34        shaders.insert("matrix_ops".to_string(), MATRIX_OPS_SHADER.to_string());
35        shaders.insert("fft".to_string(), FFT_SHADER.to_string());
36        shaders.insert("histogram_eq".to_string(), HISTOGRAM_EQ_SHADER.to_string());
37        shaders.insert("morphology".to_string(), MORPHOLOGY_SHADER.to_string());
38        shaders.insert(
39            "edge_detection".to_string(),
40            EDGE_DETECTION_SHADER.to_string(),
41        );
42        shaders.insert(
43            "texture_analysis".to_string(),
44            TEXTURE_ANALYSIS_SHADER.to_string(),
45        );
46
47        Self { shaders }
48    }
49
50    /// Get a shader by name
51    pub fn get_shader(&self, name: &str) -> Option<&str> {
52        self.shaders.get(name).map(|s| s.as_str())
53    }
54
55    /// Register a custom shader
56    pub fn register_shader(&mut self, name: String, source: String) {
57        self.shaders.insert(name, source);
58    }
59
60    /// List all available shader names
61    pub fn list_shaders(&self) -> Vec<&str> {
62        self.shaders.keys().map(|k| k.as_str()).collect()
63    }
64
65    /// Check if a shader exists
66    pub fn has_shader(&self, name: &str) -> bool {
67        self.shaders.contains_key(name)
68    }
69
70    /// Remove a shader
71    pub fn remove_shader(&mut self, name: &str) -> bool {
72        self.shaders.remove(name).is_some()
73    }
74
75    /// Get the number of registered shaders
76    pub fn shader_count(&self) -> usize {
77        self.shaders.len()
78    }
79}
80
81impl Default for KernelRegistry {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
/// Launch configuration for a single compute-shader dispatch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelParams {
    /// Workgroup size (x, y, z): threads per workgroup along each axis.
    pub workgroup_size: (u32, u32, u32),
    /// Number of workgroups to dispatch (x, y, z).
    pub dispatch_size: (u32, u32, u32),
    /// Name of the shader entry-point function.
    pub entry_point: String,
}

impl Default for KernelParams {
    /// 8×8×1 workgroups, a 1×1×1 dispatch and the `main` entry point.
    fn default() -> Self {
        Self::new((8, 8, 1), (1, 1, 1))
    }
}

impl KernelParams {
    /// Creates parameters with the given sizes and the `main` entry point.
    pub fn new(workgroup_size: (u32, u32, u32), dispatch_size: (u32, u32, u32)) -> Self {
        Self {
            workgroup_size,
            dispatch_size,
            entry_point: "main".to_string(),
        }
    }

    /// Returns `self` with the workgroup size replaced by `(x, y, z)`.
    #[must_use]
    pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.workgroup_size = (x, y, z);
        self
    }

    /// Returns `self` with the dispatch size replaced by `(x, y, z)`.
    #[must_use]
    pub fn with_dispatch_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.dispatch_size = (x, y, z);
        self
    }

    /// Returns `self` with the entry point replaced by `entry_point`.
    #[must_use]
    pub fn with_entry_point(mut self, entry_point: impl Into<String>) -> Self {
        self.entry_point = entry_point.into();
        self
    }

    /// Total number of threads launched: the product of the workgroup size
    /// and the dispatch size over all three axes.
    ///
    /// Saturates at [`u64::MAX`] instead of overflowing. The product of six
    /// `u32` factors can exceed `u64`, which would otherwise panic in debug
    /// builds and silently wrap in release builds.
    pub fn total_threads(&self) -> u64 {
        let (wg_x, wg_y, wg_z) = self.workgroup_size;
        let (d_x, d_y, d_z) = self.dispatch_size;

        [wg_x, d_x, wg_y, d_y, wg_z, d_z]
            .iter()
            .map(|&factor| u64::from(factor))
            .fold(1, u64::saturating_mul)
    }

    /// Number of workgroups needed to cover a `data_width` × `data_height`
    /// grid, rounding partial workgroups up.
    ///
    /// The z dispatch is always 1; the z component of `workgroup_size` is
    /// ignored.
    ///
    /// # Panics
    ///
    /// Panics if the x or y component of `workgroup_size` is zero.
    pub fn calculate_dispatch_size(
        data_width: u32,
        data_height: u32,
        workgroup_size: (u32, u32, u32),
    ) -> (u32, u32, u32) {
        let (wg_x, wg_y, _wg_z) = workgroup_size;
        assert!(
            wg_x != 0 && wg_y != 0,
            "workgroup x/y dimensions must be non-zero, got {workgroup_size:?}"
        );

        (data_width.div_ceil(wg_x), data_height.div_ceil(wg_y), 1)
    }
}
159
160/// Matrix multiplication kernel helper
161pub struct MatrixMultiplyKernel;
162
163impl MatrixMultiplyKernel {
164    /// Get shader source
165    pub fn shader() -> &'static str {
166        MATRIX_OPS_SHADER
167    }
168
169    /// Create parameters for matrix multiplication
170    pub fn params(m: u32, n: u32, _k: u32, tiled: bool) -> KernelParams {
171        if tiled {
172            let workgroup_size = (16, 16, 1);
173            let dispatch_x = n.div_ceil(16);
174            let dispatch_y = m.div_ceil(16);
175
176            KernelParams {
177                workgroup_size,
178                dispatch_size: (dispatch_x, dispatch_y, 1),
179                entry_point: "matrix_multiply_tiled".to_string(),
180            }
181        } else {
182            let workgroup_size = (8, 8, 1);
183            let dispatch_x = n.div_ceil(8);
184            let dispatch_y = m.div_ceil(8);
185
186            KernelParams {
187                workgroup_size,
188                dispatch_size: (dispatch_x, dispatch_y, 1),
189                entry_point: "matrix_multiply_naive".to_string(),
190            }
191        }
192    }
193}
194
195/// FFT kernel helper
196pub struct FftKernel;
197
198impl FftKernel {
199    /// Get shader source
200    pub fn shader() -> &'static str {
201        FFT_SHADER
202    }
203
204    /// Create parameters for FFT
205    pub fn params(n: u32) -> KernelParams {
206        let workgroup_size = (256, 1, 1);
207        let dispatch_size = (n.div_ceil(256), 1, 1);
208
209        KernelParams {
210            workgroup_size,
211            dispatch_size,
212            entry_point: "fft_cooley_tukey".to_string(),
213        }
214    }
215
216    /// Calculate number of FFT stages needed
217    pub fn num_stages(n: u32) -> u32 {
218        (n as f32).log2() as u32
219    }
220}
221
222/// Histogram equalization kernel helper
223pub struct HistogramEqKernel;
224
225impl HistogramEqKernel {
226    /// Get shader source
227    pub fn shader() -> &'static str {
228        HISTOGRAM_EQ_SHADER
229    }
230
231    /// Create parameters for histogram computation
232    pub fn compute_histogram_params(width: u32, height: u32) -> KernelParams {
233        let workgroup_size = (16, 16, 1);
234        let dispatch_x = width.div_ceil(16);
235        let dispatch_y = height.div_ceil(16);
236
237        KernelParams {
238            workgroup_size,
239            dispatch_size: (dispatch_x, dispatch_y, 1),
240            entry_point: "compute_histogram".to_string(),
241        }
242    }
243
244    /// Create parameters for equalization
245    pub fn equalize_params(width: u32, height: u32) -> KernelParams {
246        let workgroup_size = (16, 16, 1);
247        let dispatch_x = width.div_ceil(16);
248        let dispatch_y = height.div_ceil(16);
249
250        KernelParams {
251            workgroup_size,
252            dispatch_size: (dispatch_x, dispatch_y, 1),
253            entry_point: "histogram_equalize".to_string(),
254        }
255    }
256}
257
258/// Edge detection kernel helper
259pub struct EdgeDetectionKernel;
260
261impl EdgeDetectionKernel {
262    /// Get shader source
263    pub fn shader() -> &'static str {
264        EDGE_DETECTION_SHADER
265    }
266
267    /// Create parameters for Sobel edge detection
268    pub fn sobel_params(width: u32, height: u32) -> KernelParams {
269        let workgroup_size = (16, 16, 1);
270        let dispatch_x = width.div_ceil(16);
271        let dispatch_y = height.div_ceil(16);
272
273        KernelParams {
274            workgroup_size,
275            dispatch_size: (dispatch_x, dispatch_y, 1),
276            entry_point: "sobel".to_string(),
277        }
278    }
279
280    /// Create parameters for Canny edge detection (gradient step)
281    pub fn canny_gradient_params(width: u32, height: u32) -> KernelParams {
282        let workgroup_size = (16, 16, 1);
283        let dispatch_x = width.div_ceil(16);
284        let dispatch_y = height.div_ceil(16);
285
286        KernelParams {
287            workgroup_size,
288            dispatch_size: (dispatch_x, dispatch_y, 1),
289            entry_point: "canny_gradient".to_string(),
290        }
291    }
292}
293
294/// Morphology kernel helper
295pub struct MorphologyKernel;
296
297impl MorphologyKernel {
298    /// Get shader source
299    pub fn shader() -> &'static str {
300        MORPHOLOGY_SHADER
301    }
302
303    /// Create parameters for dilation
304    pub fn dilate_params(width: u32, height: u32) -> KernelParams {
305        let workgroup_size = (16, 16, 1);
306        let dispatch_x = width.div_ceil(16);
307        let dispatch_y = height.div_ceil(16);
308
309        KernelParams {
310            workgroup_size,
311            dispatch_size: (dispatch_x, dispatch_y, 1),
312            entry_point: "dilate".to_string(),
313        }
314    }
315
316    /// Create parameters for erosion
317    pub fn erode_params(width: u32, height: u32) -> KernelParams {
318        let workgroup_size = (16, 16, 1);
319        let dispatch_x = width.div_ceil(16);
320        let dispatch_y = height.div_ceil(16);
321
322        KernelParams {
323            workgroup_size,
324            dispatch_size: (dispatch_x, dispatch_y, 1),
325            entry_point: "erode".to_string(),
326        }
327    }
328}
329
330/// Texture analysis kernel helper
331pub struct TextureAnalysisKernel;
332
333impl TextureAnalysisKernel {
334    /// Get shader source
335    pub fn shader() -> &'static str {
336        TEXTURE_ANALYSIS_SHADER
337    }
338
339    /// Create parameters for GLCM computation
340    pub fn glcm_params(width: u32, height: u32) -> KernelParams {
341        let workgroup_size = (16, 16, 1);
342        let dispatch_x = width.div_ceil(16);
343        let dispatch_y = height.div_ceil(16);
344
345        KernelParams {
346            workgroup_size,
347            dispatch_size: (dispatch_x, dispatch_y, 1),
348            entry_point: "compute_glcm".to_string(),
349        }
350    }
351
352    /// Create parameters for LBP (Local Binary Pattern)
353    pub fn lbp_params(width: u32, height: u32) -> KernelParams {
354        let workgroup_size = (16, 16, 1);
355        let dispatch_x = width.div_ceil(16);
356        let dispatch_y = height.div_ceil(16);
357
358        KernelParams {
359            workgroup_size,
360            dispatch_size: (dispatch_x, dispatch_y, 1),
361            entry_point: "local_binary_pattern".to_string(),
362        }
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Full-HD frame size used to check the 16×16 image-kernel dispatches.
    const WIDTH: u32 = 1920;
    const HEIGHT: u32 = 1080;

    /// Names of the shaders every new registry is expected to contain.
    const BUILTIN_NAMES: [&str; 6] = [
        "matrix_ops",
        "fft",
        "histogram_eq",
        "morphology",
        "edge_detection",
        "texture_analysis",
    ];

    /// Asserts the 16×16×1 layout shared by the image kernels for a
    /// `WIDTH` × `HEIGHT` frame, plus the expected entry point.
    fn assert_image_params(params: &KernelParams, entry_point: &str) {
        assert_eq!(params.workgroup_size, (16, 16, 1));
        assert_eq!(params.dispatch_size, (120, 68, 1));
        assert_eq!(params.entry_point, entry_point);
    }

    #[test]
    fn test_kernel_registry() {
        let registry = KernelRegistry::new();
        assert_eq!(registry.shader_count(), BUILTIN_NAMES.len());

        let listed = registry.list_shaders();
        for name in BUILTIN_NAMES {
            assert!(registry.has_shader(name), "missing built-in shader {name}");
            assert!(listed.contains(&name), "list_shaders omits {name}");
        }

        assert_eq!(registry.get_shader("fft"), Some(FFT_SHADER));
        assert_eq!(registry.get_shader("does_not_exist"), None);
    }

    #[test]
    fn test_kernel_registry_custom() {
        let mut registry = KernelRegistry::new();
        let initial_count = registry.shader_count();

        registry.register_shader("custom".to_string(), "custom shader code".to_string());
        assert_eq!(registry.shader_count(), initial_count + 1);
        assert_eq!(registry.get_shader("custom"), Some("custom shader code"));

        // Registering under an existing name replaces the source.
        registry.register_shader("custom".to_string(), "replacement".to_string());
        assert_eq!(registry.shader_count(), initial_count + 1);
        assert_eq!(registry.get_shader("custom"), Some("replacement"));

        assert!(registry.remove_shader("custom"));
        assert!(!registry.remove_shader("custom"));
        assert_eq!(registry.shader_count(), initial_count);
    }

    #[test]
    fn test_kernel_params() {
        let params = KernelParams::default();
        assert_eq!(params.workgroup_size, (8, 8, 1));
        assert_eq!(params.dispatch_size, (1, 1, 1));
        assert_eq!(params.entry_point, "main");

        let custom = KernelParams::default()
            .with_workgroup_size(4, 2, 1)
            .with_dispatch_size(3, 5, 7)
            .with_entry_point("run");
        assert_eq!(custom.workgroup_size, (4, 2, 1));
        assert_eq!(custom.dispatch_size, (3, 5, 7));
        assert_eq!(custom.entry_point, "run");
    }

    #[test]
    fn test_kernel_params_total_threads() {
        let params = KernelParams::new((8, 8, 1), (10, 10, 1));
        assert_eq!(params.total_threads(), 8 * 8 * 10 * 10);
    }

    #[test]
    fn test_calculate_dispatch_size() {
        assert_eq!(
            KernelParams::calculate_dispatch_size(WIDTH, HEIGHT, (16, 16, 1)),
            (120, 68, 1)
        );
        // Partial tiles round up; exact multiples do not.
        assert_eq!(KernelParams::calculate_dispatch_size(17, 16, (16, 16, 1)), (2, 1, 1));
    }

    #[test]
    fn test_matrix_multiply_kernel() {
        let tiled = MatrixMultiplyKernel::params(1024, 1024, 1024, true);
        assert_eq!(tiled.entry_point, "matrix_multiply_tiled");
        assert_eq!(tiled.workgroup_size, (16, 16, 1));
        assert_eq!(tiled.dispatch_size, (64, 64, 1));

        // The x dispatch follows `n`, the y dispatch follows `m`.
        let naive = MatrixMultiplyKernel::params(16, 24, 8, false);
        assert_eq!(naive.entry_point, "matrix_multiply_naive");
        assert_eq!(naive.workgroup_size, (8, 8, 1));
        assert_eq!(naive.dispatch_size, (3, 2, 1));
    }

    #[test]
    fn test_fft_kernel() {
        let params = FftKernel::params(1024);
        assert_eq!(params.entry_point, "fft_cooley_tukey");
        assert_eq!(params.workgroup_size, (256, 1, 1));
        assert_eq!(params.dispatch_size, (4, 1, 1));

        assert_eq!(FftKernel::num_stages(1024), 10); // log2(1024) = 10
        assert_eq!(FftKernel::num_stages(1), 0);
    }

    #[test]
    fn test_image_kernels() {
        assert_image_params(
            &HistogramEqKernel::compute_histogram_params(WIDTH, HEIGHT),
            "compute_histogram",
        );
        assert_image_params(
            &HistogramEqKernel::equalize_params(WIDTH, HEIGHT),
            "histogram_equalize",
        );
        assert_image_params(&EdgeDetectionKernel::sobel_params(WIDTH, HEIGHT), "sobel");
        assert_image_params(
            &EdgeDetectionKernel::canny_gradient_params(WIDTH, HEIGHT),
            "canny_gradient",
        );
        assert_image_params(&MorphologyKernel::dilate_params(WIDTH, HEIGHT), "dilate");
        assert_image_params(&MorphologyKernel::erode_params(WIDTH, HEIGHT), "erode");
        assert_image_params(&TextureAnalysisKernel::glcm_params(WIDTH, HEIGHT), "compute_glcm");
        assert_image_params(
            &TextureAnalysisKernel::lbp_params(WIDTH, HEIGHT),
            "local_binary_pattern",
        );
    }

    #[test]
    fn test_all_shaders_available() {
        let sources = [
            MATRIX_OPS_SHADER,
            FFT_SHADER,
            HISTOGRAM_EQ_SHADER,
            MORPHOLOGY_SHADER,
            EDGE_DETECTION_SHADER,
            TEXTURE_ANALYSIS_SHADER,
        ];
        assert!(sources.iter().all(|source| !source.is_empty()));

        // Each helper hands out the matching embedded source.
        assert_eq!(MatrixMultiplyKernel::shader(), MATRIX_OPS_SHADER);
        assert_eq!(FftKernel::shader(), FFT_SHADER);
        assert_eq!(HistogramEqKernel::shader(), HISTOGRAM_EQ_SHADER);
        assert_eq!(MorphologyKernel::shader(), MORPHOLOGY_SHADER);
        assert_eq!(EdgeDetectionKernel::shader(), EDGE_DETECTION_SHADER);
        assert_eq!(TextureAnalysisKernel::shader(), TEXTURE_ANALYSIS_SHADER);
    }
}