use crate::{
shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
GpuDevice, GpuError, Result,
};
use bytemuck::{Pod, Zeroable};
use once_cell::sync::OnceCell;
use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
use super::utils;
/// Parameter block handed to the filter shaders through a uniform buffer.
///
/// The struct is `#[repr(C)]` and consists of eight 4-byte fields (32 bytes);
/// it is converted to raw bytes with `bytemuck::bytes_of` before upload.
///
/// NOTE(review): the field order presumably mirrors a struct declared in the
/// embedded filter shader — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct FilterParams {
/// Image width as supplied by the caller.
width: u32,
/// Image height as supplied by the caller.
height: u32,
/// Row stride; every constructor in this file sets it equal to `width`.
stride: u32,
/// Kernel size (side length for the square kernels accepted by `convolve`).
kernel_size: u32,
/// Normalisation flag: `1` for the built-in filters, `u32::from(normalize)` for `convolve`.
normalize: u32,
/// Filter selector: this file writes `1` (Gaussian), `2` (sharpen), `3` (edge detect)
/// and `0` (custom convolution).
filter_type: u32,
/// Always written as `0` in this file.
padding: u32,
/// Gaussian sigma; `sharpen` forwards its `amount` here and the other filters write `0.0`.
sigma: f32,
}
/// Stateless namespace type whose associated functions (`gaussian_blur`, `sharpen`,
/// `edge_detect`, `convolve`) run image filters through GPU compute pipelines.
pub struct FilterOperation;
impl FilterOperation {
/// Applies a Gaussian blur to `input` on the GPU and writes the result to `output`.
///
/// The kernel size is derived from `sigma` by `calculate_kernel_size`; the work is
/// then delegated to `execute_filter` with filter type `1`.
///
/// # Errors
///
/// Propagates any error from dimension/buffer validation, from pipeline or
/// bind-group-layout creation, and from `execute_filter`.
///
/// NOTE(review): the `4` passed to `validate_buffer_size` is presumably the number
/// of bytes per pixel (RGBA8) — confirm against `utils`.
#[allow(clippy::too_many_arguments)]
pub fn gaussian_blur(
    device: &GpuDevice,
    input: &[u8],
    output: &mut [u8],
    width: u32,
    height: u32,
    sigma: f32,
) -> Result<()> {
    // Value written to `FilterParams::filter_type` for this filter.
    const FILTER_TYPE: u32 = 1;

    utils::validate_dimensions(width, height)?;
    utils::validate_buffer_size(input, width, height, 4)?;
    utils::validate_buffer_size(output, width, height, 4)?;

    let kernel_size = Self::calculate_kernel_size(sigma);
    Self::execute_filter(
        device,
        Self::get_gaussian_pipeline(device)?,
        Self::get_bind_group_layout(device)?,
        input,
        output,
        width,
        height,
        kernel_size,
        FILTER_TYPE,
        sigma,
    )
}
/// Sharpens `input` on the GPU and writes the result to `output`.
///
/// Runs the sharpen pipeline through `execute_filter` with a kernel size of `5`
/// and filter type `2`; `amount` is forwarded in the `sigma` slot of the
/// parameter block.
///
/// # Errors
///
/// Propagates any error from dimension/buffer validation, from pipeline or
/// bind-group-layout creation, and from `execute_filter`.
#[allow(clippy::too_many_arguments)]
pub fn sharpen(
    device: &GpuDevice,
    input: &[u8],
    output: &mut [u8],
    width: u32,
    height: u32,
    amount: f32,
) -> Result<()> {
    // Values written to `FilterParams::kernel_size` / `FilterParams::filter_type`.
    const KERNEL_SIZE: u32 = 5;
    const FILTER_TYPE: u32 = 2;

    utils::validate_dimensions(width, height)?;
    utils::validate_buffer_size(input, width, height, 4)?;
    utils::validate_buffer_size(output, width, height, 4)?;

    Self::execute_filter(
        device,
        Self::get_sharpen_pipeline(device)?,
        Self::get_bind_group_layout(device)?,
        input,
        output,
        width,
        height,
        KERNEL_SIZE,
        FILTER_TYPE,
        amount,
    )
}
/// Runs the edge-detection filter on the GPU and writes the result to `output`.
///
/// Delegates to `execute_filter` with a kernel size of `3`, filter type `3`
/// and a sigma of `0.0`.
///
/// # Errors
///
/// Propagates any error from dimension/buffer validation, from pipeline or
/// bind-group-layout creation, and from `execute_filter`.
pub fn edge_detect(
    device: &GpuDevice,
    input: &[u8],
    output: &mut [u8],
    width: u32,
    height: u32,
) -> Result<()> {
    // Values written to `FilterParams::kernel_size` / `FilterParams::filter_type`.
    const KERNEL_SIZE: u32 = 3;
    const FILTER_TYPE: u32 = 3;

    utils::validate_dimensions(width, height)?;
    utils::validate_buffer_size(input, width, height, 4)?;
    utils::validate_buffer_size(output, width, height, 4)?;

    Self::execute_filter(
        device,
        Self::get_edge_detect_pipeline(device)?,
        Self::get_bind_group_layout(device)?,
        input,
        output,
        width,
        height,
        KERNEL_SIZE,
        FILTER_TYPE,
        0.0,
    )
}
/// Convolves `input` with a caller-supplied square kernel on the GPU.
///
/// `kernel` is the flattened kernel; its length must be an odd perfect square
/// (1, 9, 25, ...). `normalize` is forwarded to the shader as a 0/1 flag.
///
/// # Errors
///
/// Returns `GpuError::Internal` when the kernel is empty, not square, or has an
/// even side length, and propagates any error from dimension/buffer validation,
/// pipeline or layout creation, and `execute_convolve`.
#[allow(clippy::too_many_arguments)]
pub fn convolve(
    device: &GpuDevice,
    input: &[u8],
    output: &mut [u8],
    width: u32,
    height: u32,
    kernel: &[f32],
    normalize: bool,
) -> Result<()> {
    utils::validate_dimensions(width, height)?;
    utils::validate_buffer_size(input, width, height, 4)?;
    utils::validate_buffer_size(output, width, height, 4)?;

    // Report an empty kernel explicitly instead of as an "odd size" violation.
    if kernel.is_empty() {
        return Err(GpuError::Internal("Kernel must not be empty".to_string()));
    }

    // Estimate the side length in f64 (exact for any realistic length), then
    // confirm it with overflow-checked integer arithmetic.
    let len = kernel.len();
    let side = (len as f64).sqrt().round() as usize;
    if side.checked_mul(side) != Some(len) {
        return Err(GpuError::Internal("Kernel must be square".to_string()));
    }
    if side % 2 == 0 {
        return Err(GpuError::Internal("Kernel size must be odd".to_string()));
    }
    // `side * side == len <= usize::MAX`, so `side` always fits in a `u32`.
    let kernel_size = side as u32;

    let pipeline = Self::get_convolve_pipeline(device)?;
    let layout = Self::get_bind_group_layout_with_kernel(device)?;
    Self::execute_convolve(
        device,
        pipeline,
        layout,
        input,
        output,
        width,
        height,
        kernel,
        kernel_size,
        normalize,
    )
}
/// Runs one of the built-in filters (no kernel buffer) and reads the result back.
///
/// Steps: upload `input` to a storage buffer, upload a `FilterParams` block as a
/// uniform buffer, bind input/output/params at bindings 0/1/2, dispatch the
/// compute pass, copy the output buffer into a readback buffer, wait for the
/// device and copy the bytes into `output`.
///
/// # Errors
///
/// Propagates errors from buffer creation, the compute dispatch, the
/// buffer-to-buffer copy and the readback.
///
/// # Panics
///
/// `copy_from_slice` panics if the readback returns a different number of bytes
/// than `output.len()`.
#[allow(clippy::too_many_arguments)]
fn execute_filter(
    device: &GpuDevice,
    pipeline: &ComputePipeline,
    layout: &BindGroupLayout,
    input: &[u8],
    output: &mut [u8],
    width: u32,
    height: u32,
    kernel_size: u32,
    filter_type: u32,
    sigma: f32,
) -> Result<()> {
    let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
    let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
    device.queue().write_buffer(input_buffer.buffer(), 0, input);

    let params = FilterParams {
        width,
        height,
        stride: width,
        kernel_size,
        normalize: 1,
        filter_type,
        padding: 0,
        sigma,
    };
    let params_bytes = bytemuck::bytes_of(&params);
    let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;

    let compiler = ShaderCompiler::new(device);
    let bind_group = compiler.create_bind_group(
        "Filter Bind Group",
        layout,
        &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: input_buffer.buffer().as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: output_buffer.buffer().as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: params_buffer.buffer().as_entire_binding(),
            },
        ],
    );
    Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;

    // The copy is submitted after the compute work on the same queue.
    let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
    let mut encoder = device
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("Filter Copy Encoder"),
        });
    output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
    device.queue().submit(Some(encoder.finish()));
    device.wait();

    let result = readback_buffer.read(device, 0, output.len() as u64)?;
    output.copy_from_slice(&result);
    Ok(())
}
/// Runs the custom-kernel convolution and reads the result back.
///
/// Same flow as the built-in filters, plus a storage buffer holding `kernel`
/// (as raw `f32` bytes) bound at binding 3. The parameter block is written with
/// filter type `0`, sigma `0.0` and `normalize` converted to a 0/1 flag.
///
/// # Errors
///
/// Propagates errors from buffer creation, the compute dispatch, the
/// buffer-to-buffer copy and the readback.
///
/// # Panics
///
/// `copy_from_slice` panics if the readback returns a different number of bytes
/// than `output.len()`.
#[allow(clippy::too_many_arguments)]
fn execute_convolve(
    device: &GpuDevice,
    pipeline: &ComputePipeline,
    layout: &BindGroupLayout,
    input: &[u8],
    output: &mut [u8],
    width: u32,
    height: u32,
    kernel: &[f32],
    kernel_size: u32,
    normalize: bool,
) -> Result<()> {
    let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
    let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
    device.queue().write_buffer(input_buffer.buffer(), 0, input);

    let kernel_bytes = bytemuck::cast_slice(kernel);
    let kernel_buffer = utils::create_storage_buffer(device, kernel_bytes.len() as u64)?;
    device
        .queue()
        .write_buffer(kernel_buffer.buffer(), 0, kernel_bytes);

    let params = FilterParams {
        width,
        height,
        stride: width,
        kernel_size,
        normalize: u32::from(normalize),
        filter_type: 0,
        padding: 0,
        sigma: 0.0,
    };
    let params_bytes = bytemuck::bytes_of(&params);
    let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;

    let compiler = ShaderCompiler::new(device);
    let bind_group = compiler.create_bind_group(
        "Filter Bind Group",
        layout,
        &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: input_buffer.buffer().as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: output_buffer.buffer().as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: params_buffer.buffer().as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: kernel_buffer.buffer().as_entire_binding(),
            },
        ],
    );
    Self::dispatch_compute(device, pipeline, &bind_group, width, height)?;

    // The copy is submitted after the compute work on the same queue.
    let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
    let mut encoder = device
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("Filter Copy Encoder"),
        });
    output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
    device.queue().submit(Some(encoder.finish()));
    device.wait();

    let result = readback_buffer.read(device, 0, output.len() as u64)?;
    output.copy_from_slice(&result);
    Ok(())
}
/// Records a single compute pass for `pipeline` with `bind_group` at index 0 and
/// submits it to the device queue.
///
/// The workgroup counts come from `utils::calculate_dispatch_size` using a
/// `(16, 16)` tile. The function does not wait for the GPU; it always returns `Ok(())`.
///
/// NOTE(review): `(16, 16)` presumably matches the workgroup size declared in the
/// embedded shader — confirm.
fn dispatch_compute(
    device: &GpuDevice,
    pipeline: &ComputePipeline,
    bind_group: &BindGroup,
    width: u32,
    height: u32,
) -> Result<()> {
    let encoder_descriptor = wgpu::CommandEncoderDescriptor {
        label: Some("Filter Compute Encoder"),
    };
    let mut encoder = device.device().create_command_encoder(&encoder_descriptor);

    {
        // The pass borrows the encoder, so it must be dropped before `finish()`.
        let pass_descriptor = wgpu::ComputePassDescriptor {
            label: Some("Filter Compute Pass"),
            timestamp_writes: None,
        };
        let mut pass = encoder.begin_compute_pass(&pass_descriptor);
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        let (groups_x, groups_y) = utils::calculate_dispatch_size(width, height, (16, 16));
        pass.dispatch_workgroups(groups_x, groups_y, 1);
    }

    let command_buffer = encoder.finish();
    device.queue().submit(Some(command_buffer));
    Ok(())
}
/// Returns the odd kernel size `2 * ceil(3 * sigma) + 1` for a Gaussian of the given sigma.
///
/// Non-positive and NaN sigmas yield `1`, because the float-to-integer cast
/// saturates at zero. The arithmetic saturates at `u32::MAX` (which is odd)
/// instead of overflowing for very large or infinite sigmas.
fn calculate_kernel_size(sigma: f32) -> u32 {
    // `as u32` saturates: negative/NaN -> 0, huge/infinite -> u32::MAX.
    let radius = (3.0 * sigma).ceil() as u32;
    radius.saturating_mul(2).saturating_add(1)
}
/// Returns the process-wide bind group layout used by the built-in filters:
/// read-only storage at binding 0, storage at binding 1, uniform at binding 2.
///
/// The layout is created on the first call, with the device passed to that call,
/// and cached in a static; later calls return the cached layout regardless of
/// `device`. This function always returns `Ok`.
fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
    static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();

    let create_layout = || {
        let compiler = ShaderCompiler::new(device);
        let entries = BindGroupLayoutBuilder::new()
            .add_storage_buffer_read_only(0)
            .add_storage_buffer(1)
            .add_uniform_buffer(2)
            .build();
        compiler.create_bind_group_layout("Filter Bind Group Layout", &entries)
    };
    let layout = LAYOUT.get_or_init(create_layout);
    Ok(layout)
}
/// Returns the process-wide bind group layout used by `convolve`: read-only
/// storage at binding 0, storage at binding 1, uniform at binding 2 and a
/// read-only storage buffer for the kernel at binding 3.
///
/// The layout is created on the first call, with the device passed to that call,
/// and cached in a static; later calls return the cached layout regardless of
/// `device`. This function always returns `Ok`.
fn get_bind_group_layout_with_kernel(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
    static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();

    let create_layout = || {
        let compiler = ShaderCompiler::new(device);
        let entries = BindGroupLayoutBuilder::new()
            .add_storage_buffer_read_only(0)
            .add_storage_buffer(1)
            .add_uniform_buffer(2)
            .add_storage_buffer_read_only(3)
            .build();
        compiler.create_bind_group_layout("Filter Bind Group Layout (with kernel)", &entries)
    };
    let layout = LAYOUT.get_or_init(create_layout);
    Ok(layout)
}
/// Compiles the embedded filter shader and builds a compute pipeline for `entry_point`.
///
/// `layout_fn` supplies the bind group layout the pipeline is created with.
/// Every failure is flattened into a `String` so the result can be stored in a
/// static cache by the `get_*_pipeline` helpers.
fn init_pipeline(
    device: &GpuDevice,
    name: &str,
    entry_point: &str,
    layout_fn: fn(&GpuDevice) -> Result<&'static BindGroupLayout>,
) -> std::result::Result<ComputePipeline, String> {
    let compiler = ShaderCompiler::new(device);

    let source = ShaderSource::Embedded(crate::shader::embedded::FILTER_SHADER);
    let shader = match compiler.compile("Filter Shader", source) {
        Ok(module) => module,
        Err(e) => return Err(format!("Failed to compile filter shader: {e}")),
    };

    let layout = match layout_fn(device) {
        Ok(layout) => layout,
        Err(e) => return Err(format!("Failed to create bind group layout: {e}")),
    };

    match compiler.create_pipeline(name, &shader, entry_point, layout) {
        Ok(pipeline) => Ok(pipeline),
        Err(e) => Err(format!("Failed to create pipeline: {e}")),
    }
}
fn get_gaussian_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
PIPELINE
.get_or_init(|| {
FilterOperation::init_pipeline(
device,
"Gaussian Blur Pipeline",
"convolve_main",
Self::get_bind_group_layout,
)
})
.as_ref()
.map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
}
fn get_sharpen_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
PIPELINE
.get_or_init(|| {
FilterOperation::init_pipeline(
device,
"Sharpen Pipeline",
"unsharp_mask",
Self::get_bind_group_layout,
)
})
.as_ref()
.map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
}
fn get_edge_detect_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
PIPELINE
.get_or_init(|| {
FilterOperation::init_pipeline(
device,
"Edge Detect Pipeline",
"edge_detect",
Self::get_bind_group_layout,
)
})
.as_ref()
.map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
}
fn get_convolve_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
PIPELINE
.get_or_init(|| {
FilterOperation::init_pipeline(
device,
"Convolve Pipeline",
"convolve_main",
Self::get_bind_group_layout_with_kernel,
)
})
.as_ref()
.map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
}
}
/// Builds a normalised 1-D Gaussian kernel with radius `ceil(3 * sigma)`.
///
/// The returned vector has `2 * radius + 1` taps that sum to 1. A non-positive
/// or NaN `sigma` yields the identity kernel `[1.0]`.
///
/// # Panics
///
/// Panics (or aborts on allocation failure) when `sigma` is so large, or
/// infinite, that a kernel of the required length cannot be allocated.
#[must_use]
pub fn gaussian_kernel_1d(sigma: f32) -> Vec<f32> {
    // NaN fails every comparison, so it has to be rejected explicitly.
    if sigma.is_nan() || sigma <= 0.0 {
        return vec![1.0_f32];
    }
    let radius = (3.0 * sigma).ceil() as usize;
    // Saturate rather than wrap: an absurd radius must not silently become a short kernel.
    let len = radius.saturating_mul(2).saturating_add(1);
    let mut kernel = Vec::with_capacity(len);
    let two_sigma_sq = 2.0 * sigma * sigma;
    let mut sum = 0.0_f32;
    for i in 0..len {
        let v = if i == radius {
            // exp(0) is exactly 1; computing it directly avoids 0/0 = NaN when
            // `two_sigma_sq` underflows to zero for a tiny sigma.
            1.0
        } else {
            let x = i as f32 - radius as f32;
            (-x * x / two_sigma_sq).exp()
        };
        kernel.push(v);
        sum += v;
    }
    for k in &mut kernel {
        *k /= sum;
    }
    kernel
}
/// CPU Gaussian blur over a 4-channel, 8-bit image using two 1-D passes.
///
/// A horizontal pass writes into an intermediate `f32` buffer, then a vertical
/// pass rounds and clamps into `output`. Kernel taps that fall outside the image
/// are left out and the remaining weights are renormalised, so borders are not
/// darkened.
///
/// # Errors
///
/// Propagates any error from `validate_dimensions` and `validate_buffer_size`.
pub fn gaussian_blur_separable(
    input: &[u8],
    output: &mut [u8],
    width: u32,
    height: u32,
    sigma: f32,
) -> crate::Result<()> {
    const CHANNELS: usize = 4;

    utils::validate_dimensions(width, height)?;
    utils::validate_buffer_size(input, width, height, 4)?;
    utils::validate_buffer_size(output, width, height, 4)?;

    let (cols, rows) = (width as usize, height as usize);
    let weights = gaussian_kernel_1d(sigma);
    let reach = weights.len() / 2;

    // Horizontal pass: u8 input -> f32 intermediate.
    let mut intermediate = vec![0.0_f32; cols * rows * CHANNELS];
    for y in 0..rows {
        for x in 0..cols {
            // Range of taps whose source column lies inside [0, cols).
            let first_tap = reach.saturating_sub(x);
            let end_tap = weights.len().min(cols + reach - x);
            let first_col = x + first_tap - reach;

            let mut sums = [0.0_f32; CHANNELS];
            let mut weight_total = 0.0_f32;
            for (offset, &weight) in weights[first_tap..end_tap].iter().enumerate() {
                let src = (y * cols + first_col + offset) * CHANNELS;
                for (sum, &sample) in sums.iter_mut().zip(&input[src..src + CHANNELS]) {
                    *sum += weight * f32::from(sample);
                }
                weight_total += weight;
            }

            let scale = if weight_total > 0.0 {
                1.0 / weight_total
            } else {
                1.0
            };
            let dst = (y * cols + x) * CHANNELS;
            for (slot, &sum) in intermediate[dst..dst + CHANNELS].iter_mut().zip(&sums) {
                *slot = sum * scale;
            }
        }
    }

    // Vertical pass: f32 intermediate -> u8 output.
    for y in 0..rows {
        // Range of taps whose source row lies inside [0, rows).
        let first_tap = reach.saturating_sub(y);
        let end_tap = weights.len().min(rows + reach - y);
        let first_row = y + first_tap - reach;

        for x in 0..cols {
            let mut sums = [0.0_f32; CHANNELS];
            let mut weight_total = 0.0_f32;
            for (offset, &weight) in weights[first_tap..end_tap].iter().enumerate() {
                let src = ((first_row + offset) * cols + x) * CHANNELS;
                for (sum, &sample) in sums.iter_mut().zip(&intermediate[src..src + CHANNELS]) {
                    *sum += weight * sample;
                }
                weight_total += weight;
            }

            let scale = if weight_total > 0.0 {
                1.0 / weight_total
            } else {
                1.0
            };
            let dst = (y * cols + x) * CHANNELS;
            for (slot, &sum) in output[dst..dst + CHANNELS].iter_mut().zip(&sums) {
                *slot = (sum * scale).round().clamp(0.0, 255.0) as u8;
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte length of a 4-channel image with the given dimensions.
    fn rgba_len(w: u32, h: u32) -> usize {
        (w * h * 4) as usize
    }

    /// Asserts that `sigma` produces the single-tap identity kernel.
    fn assert_identity_kernel(sigma: f32) {
        let k = gaussian_kernel_1d(sigma);
        assert_eq!(k.len(), 1);
        assert!((k[0] - 1.0).abs() < 1e-6);
    }

    /// The kernel weights are normalised.
    #[test]
    fn test_kernel_sums_to_one() {
        let sum = gaussian_kernel_1d(1.0).iter().sum::<f32>();
        assert!((sum - 1.0).abs() < 1e-5, "kernel sum = {sum}");
    }

    /// Mirrored taps carry the same weight.
    #[test]
    fn test_kernel_is_symmetric() {
        let k = gaussian_kernel_1d(2.0);
        let half = k.len() / 2;
        for (i, (front, back)) in k.iter().zip(k.iter().rev()).take(half).enumerate() {
            assert!(
                (front - back).abs() < 1e-6,
                "asymmetric at index {i}: {} vs {}",
                front,
                back
            );
        }
    }

    /// No tap outweighs the centre tap.
    #[test]
    fn test_kernel_center_is_largest() {
        let k = gaussian_kernel_1d(1.5);
        let center = k[k.len() / 2];
        k.iter()
            .copied()
            .for_each(|v| assert!(center >= v, "center {center} not >= {v}"));
    }

    #[test]
    fn test_kernel_zero_sigma_returns_identity() {
        assert_identity_kernel(0.0);
    }

    #[test]
    fn test_kernel_negative_sigma_returns_identity() {
        assert_identity_kernel(-1.0);
    }

    /// Blurring a constant image changes each byte by at most one (rounding).
    #[test]
    fn test_blur_uniform_image_unchanged() {
        let (w, h) = (8u32, 8u32);
        let input = [128u8, 128, 128, 255].repeat((w * h) as usize);
        let mut output = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable(&input, &mut output, w, h, 1.5).expect("blur should succeed");
        for (i, (&inp, &out)) in input.iter().zip(&output).enumerate() {
            let delta = (i32::from(inp) - i32::from(out)).abs();
            assert!(delta <= 1, "pixel {i}: input={inp} output={out}");
        }
    }

    /// A black/white checkerboard no longer contains pure white after blurring.
    #[test]
    fn test_blur_reduces_contrast() {
        let (w, h) = (4u32, 4u32);
        let mut input = vec![0u8; rgba_len(w, h)];
        for (index, px) in input.chunks_exact_mut(4).enumerate() {
            let (row, col) = (index / w as usize, index % w as usize);
            let v = if (row + col) % 2 == 0 { 255u8 } else { 0u8 };
            px.copy_from_slice(&[v, v, v, 255]);
        }
        let mut output = vec![0u8; rgba_len(w, h)];
        gaussian_blur_separable(&input, &mut output, w, h, 1.0).expect("blur should succeed");
        // Alpha (every fourth byte) is excluded from the maximum.
        let max_rgb = output
            .chunks(4)
            .map(|px| px[..3].iter().copied().max().unwrap_or(0))
            .max()
            .unwrap_or(0);
        assert!(
            max_rgb < 255,
            "max_rgb after blur = {max_rgb}; expected < 255"
        );
    }

    /// An output buffer of the wrong size is rejected.
    #[test]
    fn test_blur_size_mismatch_returns_error() {
        let (w, h) = (4u32, 4u32);
        let input = vec![0u8; rgba_len(w, h)];
        let mut output = vec![0u8; 10];
        assert!(gaussian_blur_separable(&input, &mut output, w, h, 1.0).is_err());
    }

    /// A 1x1 image has no neighbours, so the pixel passes through untouched.
    #[test]
    fn test_blur_single_pixel_passthrough() {
        let input = vec![100u8, 150u8, 200u8, 255u8];
        let mut output = vec![0u8; 4];
        gaussian_blur_separable(&input, &mut output, 1, 1, 1.0).expect("blur should succeed");
        assert_eq!(output, vec![100u8, 150, 200, 255]);
    }
}