use cubecl::prelude::*;
use cubecl::terminate;
use super::helpers::{read_clamped_line, read_line};
/// Edge-preserving smoothing of one frame, staged through a shared-memory tile.
///
/// Each unit produces the output pixel at `(global_x, global_y)` of `frame`. That pixel is
/// the normalized weighted average of the `(2 * radius + 1)^2` neighbours centred on it.
/// Each neighbour's weight is
/// `exp(-((dx² + dy²) * inv_two_sigma_s_sq + range_sq * inv_two_sigma_r_sq))`, where
/// `range_sq` is the squared neighbour-minus-centre difference summed over the first
/// `channels` vector lanes. Despite the `nlm` prefix, the range term compares single pixels
/// rather than patches, so this is a bilateral filter.
///
/// # Arguments
/// * `input` – source pixels, one `Vector<f32, N>` per pixel. Read only through the
///   `read_line` / `read_clamped_line` helpers. Its layout presumably matches `output`
///   (confirm in `helpers`).
/// * `output` – destination, written at index `(frame * height + y) * width + x`.
/// * `frame` – which frame of the stacked multi-frame buffer to process.
/// * `inv_two_sigma_s_sq` – multiplier on the squared spatial distance. Judging by the name,
///   presumably `1 / (2 σ_s²)`; confirm with the host code.
/// * `inv_two_sigma_r_sq` – multiplier on the squared range distance. Presumably
///   `1 / (2 σ_r²)`; confirm with the host code.
/// * `width`, `height` – frame size in pixels (comptime).
/// * `channels` – number of leading vector lanes that enter the range distance. Lanes at or
///   beyond `channels` are still averaged and written, but they do not affect the weights.
/// * `radius` – window half-width. The tile carries a halo of this width on every side.
/// * `block_x`, `block_y` – size of the tile interior, i.e. output pixels per cube.
///
/// NOTE(review): the tile load strides by `block_x * block_y` and maps `UNIT_POS_X/Y`
/// directly to tile coordinates. This assumes the kernel is launched with a cube of
/// exactly `block_x × block_y` units; confirm at the launch site.
/// NOTE(review): the output index is computed in `u32` and can wrap when
/// `(frame + 1) * height * width` exceeds `u32::MAX`.
#[cube(launch_unchecked)]
pub fn nlm_bilateral<N: Size>(
input: &Array<Vector<f32, N>>,
output: &mut Array<Vector<f32, N>>,
frame: u32,
inv_two_sigma_s_sq: f32,
inv_two_sigma_r_sq: f32,
#[comptime] width: u32,
#[comptime] height: u32,
#[comptime] channels: u32,
#[comptime] radius: u32,
#[comptime] block_x: u32,
#[comptime] block_y: u32,
) {
// Tile = block interior plus a `radius`-wide halo on each side (sizes are comptime).
let tile_width = comptime!(block_x + 2 * radius);
let tile_elems = comptime!((block_x + 2 * radius) * (block_y + 2 * radius));
// Each tile element is fetched from `input` once per cube and then reused by every
// window that overlaps it.
let mut smem = SharedMemory::<Vector<f32, N>>::new(tile_elems as usize);
let local_x = UNIT_POS_X;
let local_y = UNIT_POS_Y;
let global_x = CUBE_POS_X * block_x + local_x;
let global_y = CUBE_POS_Y * block_y + local_y;
// Tile extent in frame coordinates, half-open [start, end). Kept signed because the
// halo can extend past the left/top edge (start < 0).
let tile_start_x = CUBE_POS_X as i32 * block_x as i32 - radius as i32;
let tile_start_y = CUBE_POS_Y as i32 * block_y as i32 - radius as i32;
let tile_end_x = tile_start_x + comptime!((block_x + 2 * radius) as i32);
let tile_end_y = tile_start_y + comptime!((block_y + 2 * radius) as i32);
// True when the whole tile, halo included, lies inside the frame, so no source
// coordinate needs edge handling.
let interior =
tile_start_x >= 0 && tile_end_x <= width as i32 && tile_start_y >= 0 && tile_end_y <= height as i32;
// Cooperative load: the unit with linear id `thread_id` fills tile elements
// thread_id, thread_id + threads, thread_id + 2 * threads, ...
let threads = block_x * block_y;
let thread_id = local_y * block_x + local_x;
let mut idx = thread_id;
if interior {
// Fast path: every source coordinate is non-negative and in range, so cast to u32
// and read directly.
while idx < tile_elems {
let tile_x = idx % tile_width;
let tile_y = idx / tile_width;
let src_x = (tile_start_x + tile_x as i32) as u32;
let src_y = (tile_start_y + tile_y as i32) as u32;
smem[idx as usize] = read_line(input, src_x, src_y, frame, width, height);
idx += threads;
}
} else {
// Border tile: pass signed coordinates to `read_clamped_line`, which presumably
// clamps them into the frame (helper not visible here; confirm in `helpers`).
while idx < tile_elems {
let tile_x = idx % tile_width;
let tile_y = idx / tile_width;
let src_x = tile_start_x + tile_x as i32;
let src_y = tile_start_y + tile_y as i32;
smem[idx as usize] = read_clamped_line(input, src_x, src_y, frame, width, height);
idx += threads;
}
}
// Barrier before any unit exits. Units that map outside the frame still had to help
// load the tile, so `terminate!` must stay after `sync_cube`.
sync_cube();
// Units of a partially out-of-frame cube produce no output.
if global_x >= width || global_y >= height {
terminate!();
}
// This unit's pixel inside the tile, shifted by the halo width.
let center_tile_x = local_x + radius;
let center_tile_y = local_y + radius;
let center = smem[(center_tile_y * tile_width + center_tile_x) as usize];
let patch_size = 2 * radius + 1;
let mut weight_sum = 0.0f32;
let mut acc = Vector::<f32, N>::empty().fill(0.0f32);
// Visit every position of the (2r+1) x (2r+1) window around the centre.
for offset_y in 0..patch_size {
for offset_x in 0..patch_size {
// Signed window offsets, used only for the spatial distance.
let dy = offset_y as i32 - radius as i32;
let dx = offset_x as i32 - radius as i32;
// The offset is added before `radius` is subtracted, so this u32 expression never
// goes negative (center_tile_* already includes +radius).
let smem_idx = ((center_tile_y + offset_y - radius) * tile_width + center_tile_x + offset_x
- radius) as usize;
let neighbor = smem[smem_idx];
let diff = neighbor - center;
// Squared range distance over the first `channels` lanes only. `channels` is
// comptime, so the loop is unrolled.
let mut range_sq = 0.0f32;
#[unroll]
for c in 0..channels {
range_sq += diff[c as usize] * diff[c as usize];
}
let spatial = (dx * dx + dy * dy) as f32 * inv_two_sigma_s_sq;
let w = f32::exp(-(spatial + range_sq * inv_two_sigma_r_sq));
// Broadcast the scalar weight to every lane and accumulate the weighted sample.
let line_w = Vector::<f32, N>::empty().fill(w);
acc += neighbor * line_w;
weight_sum += w;
}
}
// The centre sample contributes exp(0) = 1 (dx = dy = 0 and diff = 0), so
// weight_sum >= 1 and this division is safe for finite inputs and multipliers.
let inv = 1.0f32 / weight_sum;
let line_inv = Vector::<f32, N>::empty().fill(inv);
output[((frame * height + global_y) * width + global_x) as usize] = acc * line_inv;
}