/// WGSL source embedded at compile time from `advanced/matrix_ops.wgsl`
/// (path relative to this file); served by `MatrixMultiplyKernel::shader`.
pub const MATRIX_OPS_SHADER: &str = include_str!("advanced/matrix_ops.wgsl");
/// WGSL source embedded at compile time from `advanced/fft.wgsl`;
/// served by `FftKernel::shader`.
pub const FFT_SHADER: &str = include_str!("advanced/fft.wgsl");
/// WGSL source embedded at compile time from `advanced/histogram_eq.wgsl`;
/// served by `HistogramEqKernel::shader`.
pub const HISTOGRAM_EQ_SHADER: &str = include_str!("advanced/histogram_eq.wgsl");
/// WGSL source embedded at compile time from `advanced/morphology.wgsl`;
/// served by `MorphologyKernel::shader`.
pub const MORPHOLOGY_SHADER: &str = include_str!("advanced/morphology.wgsl");
/// WGSL source embedded at compile time from `advanced/edge_detection.wgsl`;
/// served by `EdgeDetectionKernel::shader`.
pub const EDGE_DETECTION_SHADER: &str = include_str!("advanced/edge_detection.wgsl");
/// WGSL source embedded at compile time from `advanced/texture_analysis.wgsl`;
/// served by `TextureAnalysisKernel::shader`.
pub const TEXTURE_ANALYSIS_SHADER: &str = include_str!("advanced/texture_analysis.wgsl");
/// Name-keyed collection of compute-shader sources.
///
/// [`KernelRegistry::new`] pre-populates the registry with the six shaders
/// embedded by this module; callers can add, replace, or remove entries at
/// runtime.
#[derive(Debug, Clone)]
pub struct KernelRegistry {
    /// Shader name -> shader source text.
    shaders: std::collections::HashMap<String, String>,
}

impl KernelRegistry {
    /// Creates a registry holding the built-in shaders under the names
    /// `matrix_ops`, `fft`, `histogram_eq`, `morphology`, `edge_detection`
    /// and `texture_analysis`.
    pub fn new() -> Self {
        let builtins: [(&str, &str); 6] = [
            ("matrix_ops", MATRIX_OPS_SHADER),
            ("fft", FFT_SHADER),
            ("histogram_eq", HISTOGRAM_EQ_SHADER),
            ("morphology", MORPHOLOGY_SHADER),
            ("edge_detection", EDGE_DETECTION_SHADER),
            ("texture_analysis", TEXTURE_ANALYSIS_SHADER),
        ];
        let shaders = builtins
            .iter()
            .map(|&(name, source)| (name.to_string(), source.to_string()))
            .collect();
        Self { shaders }
    }

    /// Returns the source registered under `name`, or `None` if there is none.
    pub fn get_shader(&self, name: &str) -> Option<&str> {
        self.shaders.get(name).map(String::as_str)
    }

    /// Registers `source` under `name`, replacing any existing entry with the
    /// same name.
    pub fn register_shader(&mut self, name: String, source: String) {
        self.shaders.insert(name, source);
    }

    /// Returns the names of all registered shaders in ascending lexicographic
    /// order.
    pub fn list_shaders(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.shaders.keys().map(String::as_str).collect();
        // `HashMap` iteration order is unspecified (and randomized per process),
        // so sort to give callers a deterministic listing.
        names.sort_unstable();
        names
    }

    /// Returns `true` if a shader is registered under `name`.
    pub fn has_shader(&self, name: &str) -> bool {
        self.shaders.contains_key(name)
    }

    /// Removes the shader registered under `name`; returns `true` if one was present.
    pub fn remove_shader(&mut self, name: &str) -> bool {
        self.shaders.remove(name).is_some()
    }

    /// Returns the number of registered shaders.
    pub fn shader_count(&self) -> usize {
        self.shaders.len()
    }
}

impl Default for KernelRegistry {
    /// Equivalent to [`KernelRegistry::new`]: a registry with the built-in shaders.
    fn default() -> Self {
        Self::new()
    }
}
/// Launch configuration for a compute kernel: the workgroup size, how many
/// workgroups to dispatch, and which shader entry point to invoke.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelParams {
    /// Threads per workgroup along (x, y, z).
    pub workgroup_size: (u32, u32, u32),
    /// Number of workgroups dispatched along (x, y, z).
    pub dispatch_size: (u32, u32, u32),
    /// Name of the shader function to run.
    pub entry_point: String,
}

impl Default for KernelParams {
    /// 8x8x1 workgroups, a single workgroup dispatched, entry point `main`.
    fn default() -> Self {
        Self {
            workgroup_size: (8, 8, 1),
            dispatch_size: (1, 1, 1),
            entry_point: "main".to_string(),
        }
    }
}

impl KernelParams {
    /// Creates params with the given workgroup and dispatch sizes and the
    /// default `main` entry point.
    pub fn new(workgroup_size: (u32, u32, u32), dispatch_size: (u32, u32, u32)) -> Self {
        Self {
            workgroup_size,
            dispatch_size,
            ..Self::default()
        }
    }

    /// Returns `self` with the workgroup size replaced by `(x, y, z)`.
    #[must_use]
    pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.workgroup_size = (x, y, z);
        self
    }

    /// Returns `self` with the dispatch size replaced by `(x, y, z)`.
    #[must_use]
    pub fn with_dispatch_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.dispatch_size = (x, y, z);
        self
    }

    /// Returns `self` with the entry point replaced by `entry_point`.
    #[must_use]
    pub fn with_entry_point(mut self, entry_point: impl Into<String>) -> Self {
        self.entry_point = entry_point.into();
        self
    }

    /// Total number of invocations launched: the product over all three axes
    /// of `workgroup_size * dispatch_size`.
    ///
    /// Saturates at `u64::MAX` instead of overflowing.
    pub fn total_threads(&self) -> u64 {
        let (wg_x, wg_y, wg_z) = self.workgroup_size;
        let (d_x, d_y, d_z) = self.dispatch_size;
        // A single axis (u32 * u32) always fits in u64, but the product of
        // three axes can exceed it, so combine the axes with saturation.
        let per_axis = |wg: u32, d: u32| u64::from(wg) * u64::from(d);
        per_axis(wg_x, d_x)
            .saturating_mul(per_axis(wg_y, d_y))
            .saturating_mul(per_axis(wg_z, d_z))
    }

    /// Number of workgroups needed to cover a `data_width x data_height` grid
    /// with one invocation per element, rounding partial tiles up.
    ///
    /// The z component of `workgroup_size` is ignored and the returned z
    /// dispatch is always 1.
    ///
    /// # Panics
    ///
    /// Panics if the x or y component of `workgroup_size` is zero.
    pub fn calculate_dispatch_size(
        data_width: u32,
        data_height: u32,
        workgroup_size: (u32, u32, u32),
    ) -> (u32, u32, u32) {
        let (wg_x, wg_y, _wg_z) = workgroup_size;
        assert!(
            wg_x > 0 && wg_y > 0,
            "workgroup x/y dimensions must be non-zero, got {workgroup_size:?}"
        );
        (data_width.div_ceil(wg_x), data_height.div_ceil(wg_y), 1)
    }
}
pub struct MatrixMultiplyKernel;
impl MatrixMultiplyKernel {
pub fn shader() -> &'static str {
MATRIX_OPS_SHADER
}
pub fn params(m: u32, n: u32, _k: u32, tiled: bool) -> KernelParams {
if tiled {
let workgroup_size = (16, 16, 1);
let dispatch_x = n.div_ceil(16);
let dispatch_y = m.div_ceil(16);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "matrix_multiply_tiled".to_string(),
}
} else {
let workgroup_size = (8, 8, 1);
let dispatch_x = n.div_ceil(8);
let dispatch_y = m.div_ceil(8);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "matrix_multiply_naive".to_string(),
}
}
}
}
/// Launch helper for the `fft_cooley_tukey` entry point in [`FFT_SHADER`].
pub struct FftKernel;

impl FftKernel {
    /// Returns the shader source containing the FFT entry point.
    pub fn shader() -> &'static str {
        FFT_SHADER
    }

    /// Builds launch parameters for a transform of length `n`: 256-wide
    /// workgroups along x, with enough of them to cover `n` invocations.
    pub fn params(n: u32) -> KernelParams {
        const WORKGROUP_X: u32 = 256;
        KernelParams {
            workgroup_size: (WORKGROUP_X, 1, 1),
            dispatch_size: (n.div_ceil(WORKGROUP_X), 1, 1),
            entry_point: "fft_cooley_tukey".to_string(),
        }
    }

    /// Returns `floor(log2(n))`, or 0 when `n == 0`.
    ///
    /// For a power-of-two `n` this is the exact base-2 logarithm, i.e. the
    /// number of radix-2 passes (presumably what the shader expects — confirm).
    pub fn num_stages(n: u32) -> u32 {
        // Integer log avoids `f32` rounding: `(n as f32).log2()` rounds values
        // just below a power of two up (e.g. `u32::MAX` became 32, not 31).
        n.checked_ilog2().unwrap_or(0)
    }
}
pub struct HistogramEqKernel;
impl HistogramEqKernel {
pub fn shader() -> &'static str {
HISTOGRAM_EQ_SHADER
}
pub fn compute_histogram_params(width: u32, height: u32) -> KernelParams {
let workgroup_size = (16, 16, 1);
let dispatch_x = width.div_ceil(16);
let dispatch_y = height.div_ceil(16);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "compute_histogram".to_string(),
}
}
pub fn equalize_params(width: u32, height: u32) -> KernelParams {
let workgroup_size = (16, 16, 1);
let dispatch_x = width.div_ceil(16);
let dispatch_y = height.div_ceil(16);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "histogram_equalize".to_string(),
}
}
}
pub struct EdgeDetectionKernel;
impl EdgeDetectionKernel {
pub fn shader() -> &'static str {
EDGE_DETECTION_SHADER
}
pub fn sobel_params(width: u32, height: u32) -> KernelParams {
let workgroup_size = (16, 16, 1);
let dispatch_x = width.div_ceil(16);
let dispatch_y = height.div_ceil(16);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "sobel".to_string(),
}
}
pub fn canny_gradient_params(width: u32, height: u32) -> KernelParams {
let workgroup_size = (16, 16, 1);
let dispatch_x = width.div_ceil(16);
let dispatch_y = height.div_ceil(16);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "canny_gradient".to_string(),
}
}
}
pub struct MorphologyKernel;
impl MorphologyKernel {
pub fn shader() -> &'static str {
MORPHOLOGY_SHADER
}
pub fn dilate_params(width: u32, height: u32) -> KernelParams {
let workgroup_size = (16, 16, 1);
let dispatch_x = width.div_ceil(16);
let dispatch_y = height.div_ceil(16);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "dilate".to_string(),
}
}
pub fn erode_params(width: u32, height: u32) -> KernelParams {
let workgroup_size = (16, 16, 1);
let dispatch_x = width.div_ceil(16);
let dispatch_y = height.div_ceil(16);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "erode".to_string(),
}
}
}
pub struct TextureAnalysisKernel;
impl TextureAnalysisKernel {
pub fn shader() -> &'static str {
TEXTURE_ANALYSIS_SHADER
}
pub fn glcm_params(width: u32, height: u32) -> KernelParams {
let workgroup_size = (16, 16, 1);
let dispatch_x = width.div_ceil(16);
let dispatch_y = height.div_ceil(16);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "compute_glcm".to_string(),
}
}
pub fn lbp_params(width: u32, height: u32) -> KernelParams {
let workgroup_size = (16, 16, 1);
let dispatch_x = width.div_ceil(16);
let dispatch_y = height.div_ceil(16);
KernelParams {
workgroup_size,
dispatch_size: (dispatch_x, dispatch_y, 1),
entry_point: "local_binary_pattern".to_string(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_registry() {
        let registry = KernelRegistry::new();
        assert_eq!(registry.shader_count(), 6);
        for name in [
            "matrix_ops",
            "fft",
            "histogram_eq",
            "morphology",
            "edge_detection",
            "texture_analysis",
        ] {
            assert!(registry.has_shader(name), "missing built-in shader {name}");
        }
        assert_eq!(registry.get_shader("fft"), Some(FFT_SHADER));
        assert_eq!(registry.get_shader("does_not_exist"), None);
    }

    #[test]
    fn test_kernel_registry_default_matches_new() {
        assert_eq!(
            KernelRegistry::default().shader_count(),
            KernelRegistry::new().shader_count()
        );
    }

    #[test]
    fn test_list_shaders_contains_builtins() {
        let registry = KernelRegistry::new();
        let names = registry.list_shaders();
        assert_eq!(names.len(), 6);
        assert!(names.contains(&"matrix_ops"));
        assert!(names.contains(&"texture_analysis"));
    }

    #[test]
    fn test_kernel_registry_custom() {
        let mut registry = KernelRegistry::new();
        let initial_count = registry.shader_count();
        registry.register_shader("custom".to_string(), "custom shader code".to_string());
        assert_eq!(registry.shader_count(), initial_count + 1);
        assert!(registry.has_shader("custom"));
        assert_eq!(registry.get_shader("custom"), Some("custom shader code"));
        // Re-registering under the same name replaces the source in place.
        registry.register_shader("custom".to_string(), "replacement".to_string());
        assert_eq!(registry.shader_count(), initial_count + 1);
        assert_eq!(registry.get_shader("custom"), Some("replacement"));
        assert!(registry.remove_shader("custom"));
        assert!(!registry.remove_shader("custom"));
        assert_eq!(registry.shader_count(), initial_count);
    }

    #[test]
    fn test_kernel_params() {
        let params = KernelParams::default();
        assert_eq!(params.workgroup_size, (8, 8, 1));
        assert_eq!(params.dispatch_size, (1, 1, 1));
        assert_eq!(params.entry_point, "main");
    }

    #[test]
    fn test_kernel_params_builders() {
        let params = KernelParams::new((8, 8, 1), (1, 1, 1))
            .with_workgroup_size(4, 2, 1)
            .with_dispatch_size(3, 5, 7)
            .with_entry_point("custom_entry");
        assert_eq!(params.workgroup_size, (4, 2, 1));
        assert_eq!(params.dispatch_size, (3, 5, 7));
        assert_eq!(params.entry_point, "custom_entry");
        assert_eq!(params.total_threads(), 4 * 2 * 3 * 5 * 7);
    }

    #[test]
    fn test_kernel_params_total_threads() {
        let params = KernelParams::new((8, 8, 1), (10, 10, 1));
        assert_eq!(params.total_threads(), 8 * 8 * 10 * 10);
    }

    #[test]
    fn test_calculate_dispatch_size() {
        let (dx, dy, dz) = KernelParams::calculate_dispatch_size(1920, 1080, (16, 16, 1));
        assert_eq!(dx, 1920_u32.div_ceil(16));
        assert_eq!(dy, 1080_u32.div_ceil(16));
        assert_eq!(dz, 1);
        assert_eq!((dx, dy), (120, 68));
        assert_eq!(
            KernelParams::calculate_dispatch_size(32, 16, (16, 16, 1)),
            (2, 1, 1)
        );
    }

    #[test]
    fn test_matrix_multiply_kernel() {
        let tiled = MatrixMultiplyKernel::params(1024, 1024, 1024, true);
        assert_eq!(tiled.entry_point, "matrix_multiply_tiled");
        assert_eq!(tiled.workgroup_size, (16, 16, 1));
        assert_eq!(tiled.dispatch_size, (64, 64, 1));

        // Columns (n) drive x, rows (m) drive y.
        let naive = MatrixMultiplyKernel::params(100, 40, 7, false);
        assert_eq!(naive.entry_point, "matrix_multiply_naive");
        assert_eq!(naive.workgroup_size, (8, 8, 1));
        assert_eq!(naive.dispatch_size, (5, 13, 1));
    }

    #[test]
    fn test_fft_kernel() {
        let params = FftKernel::params(1024);
        assert_eq!(params.entry_point, "fft_cooley_tukey");
        assert_eq!(params.workgroup_size, (256, 1, 1));
        assert_eq!(params.dispatch_size, (4, 1, 1));
        assert_eq!(FftKernel::num_stages(1024), 10);
        assert_eq!(FftKernel::num_stages(1), 0);
    }

    #[test]
    fn test_image_kernels_cover_full_image() {
        let cases = [
            (HistogramEqKernel::compute_histogram_params(33, 17), "compute_histogram"),
            (HistogramEqKernel::equalize_params(33, 17), "histogram_equalize"),
            (EdgeDetectionKernel::sobel_params(33, 17), "sobel"),
            (EdgeDetectionKernel::canny_gradient_params(33, 17), "canny_gradient"),
            (MorphologyKernel::dilate_params(33, 17), "dilate"),
            (MorphologyKernel::erode_params(33, 17), "erode"),
            (TextureAnalysisKernel::glcm_params(33, 17), "compute_glcm"),
            (TextureAnalysisKernel::lbp_params(33, 17), "local_binary_pattern"),
        ];
        for (params, entry_point) in cases {
            assert_eq!(params.entry_point, entry_point);
            assert_eq!(params.workgroup_size, (16, 16, 1));
            assert_eq!(params.dispatch_size, (3, 2, 1), "entry point {entry_point}");
        }
    }

    #[test]
    fn test_kernel_shader_accessors() {
        assert_eq!(MatrixMultiplyKernel::shader(), MATRIX_OPS_SHADER);
        assert_eq!(FftKernel::shader(), FFT_SHADER);
        assert_eq!(HistogramEqKernel::shader(), HISTOGRAM_EQ_SHADER);
        assert_eq!(EdgeDetectionKernel::shader(), EDGE_DETECTION_SHADER);
        assert_eq!(MorphologyKernel::shader(), MORPHOLOGY_SHADER);
        assert_eq!(TextureAnalysisKernel::shader(), TEXTURE_ANALYSIS_SHADER);
    }

    #[test]
    fn test_all_shaders_available() {
        assert!(!MATRIX_OPS_SHADER.is_empty());
        assert!(!FFT_SHADER.is_empty());
        assert!(!HISTOGRAM_EQ_SHADER.is_empty());
        assert!(!MORPHOLOGY_SHADER.is_empty());
        assert!(!EDGE_DETECTION_SHADER.is_empty());
        assert!(!TEXTURE_ANALYSIS_SHADER.is_empty());
    }
}