use cubecl::prelude::*;
use cubecl::terminate;
use super::helpers::{accumulate_pair, clamp_coord};
/// Accumulation kernel for a single search offset `(q_x, q_y)`.
///
/// Every thread owns the pixel at its absolute 2D position. It looks up two
/// precomputed weights and forwards them, together with the buffers to be
/// updated, to `accumulate_pair`:
///
/// * the forward weight is read from `weights_fwd` at the thread's own pixel;
/// * the backward weight is read from `weights_bwd` at the pixel displaced by
///   `(-q_x, -q_y)`, after each coordinate has gone through `clamp_coord`.
///
/// Threads positioned outside the `width` x `height` rectangle terminate
/// before touching any buffer.
///
/// NOTE(review): the actual update of `accum`, `weight_sum` and `max_weight`
/// happens inside `accumulate_pair`, which is defined elsewhere; presumably it
/// adds the weighted samples from `frame_fwd` / `frame_bwd` — confirm there.
#[cube(launch_unchecked)]
pub fn nlm_accumulate<N: Size>(
    input: &Array<Vector<f32, N>>,
    accum: &mut Array<Vector<f32, N>>,
    weight_sum: &mut Array<f32>,
    weights_fwd: &Array<f32>,
    weights_bwd: &Array<f32>,
    max_weight: &mut Array<f32>,
    frame_fwd: u32,
    frame_bwd: u32,
    q_x: i32,
    q_y: i32,
    #[comptime] width: u32,
    #[comptime] height: u32,
) {
    let col = ABSOLUTE_POS_X;
    let row = ABSOLUTE_POS_Y;

    // The launch grid may be larger than the image; surplus threads stop here.
    if row >= height || col >= width {
        terminate!();
    }

    // Location of the backward weight: the current pixel shifted by the
    // negated offset, with both coordinates passed through `clamp_coord`.
    let bwd_col = clamp_coord(col as i32 - q_x, width);
    let bwd_row = clamp_coord(row as i32 - q_y, height);
    let bwd_idx = (bwd_row * width + bwd_col) as usize;

    // Location of the forward weight: the current pixel itself (row-major).
    let fwd_idx = (row * width + col) as usize;

    let w_fwd = weights_fwd[fwd_idx];
    let w_bwd = weights_bwd[bwd_idx];

    accumulate_pair(
        input, accum, weight_sum, max_weight, col, row, q_x, q_y, frame_fwd, frame_bwd, w_fwd,
        w_bwd, width, height,
    );
}
/// Normalisation kernel: turns the accumulated sums into the final pixel.
///
/// For the pixel at the thread's absolute 2D position the kernel computes,
/// for each of the first `channels` lanes,
///
/// ```text
/// m      = wref * max_weight[pixel]
/// output = (input[center_frame, pixel] * m + accum[pixel]) / (m + weight_sum[pixel])
/// ```
///
/// When the denominator is not greater than `1e-30` the pixel of
/// `center_frame` is copied through unchanged instead.
///
/// `input` is indexed as `(center_frame * height + y) * width + x`, while
/// `accum`, `weight_sum`, `max_weight` and `output` are indexed as
/// `y * width + x`. Threads outside the `width` x `height` rectangle
/// terminate without writing.
///
/// NOTE(review): only lanes `0..channels` of the result are assigned; the
/// vector starts from `Vector::empty()`, so if `channels` is smaller than the
/// vector width the remaining lanes are presumably unspecified — confirm
/// callers never rely on them.
#[cube(launch_unchecked)]
pub fn nlm_finish<N: Size>(
    input: &Array<Vector<f32, N>>,
    output: &mut Array<Vector<f32, N>>,
    accum: &Array<Vector<f32, N>>,
    weight_sum: &Array<f32>,
    max_weight: &Array<f32>,
    center_frame: u32,
    wref: f32,
    #[comptime] width: u32,
    #[comptime] height: u32,
    #[comptime] channels: u32,
) {
    let col = ABSOLUTE_POS_X;
    let row = ABSOLUTE_POS_Y;

    // The launch grid may be larger than the image; surplus threads stop here.
    if row >= height || col >= width {
        terminate!();
    }

    // Index into the single-frame buffers and into the multi-frame input.
    let plane_idx = (row * width + col) as usize;
    let center_idx = ((center_frame * height + row) * width + col) as usize;

    let center = input[center_idx];
    let summed = accum[plane_idx];

    // Weight given to the centre pixel itself, and the overall normaliser.
    let center_weight = wref * max_weight[plane_idx];
    let total_weight = center_weight + weight_sum[plane_idx];

    let mut result = Vector::<f32, N>::empty();
    if total_weight > 1e-30f32 {
        // One reciprocal, reused for every channel.
        let scale = 1.0f32 / total_weight;
        #[unroll]
        for lane in 0..channels {
            let blended = center[lane as usize] * center_weight + summed[lane as usize];
            result[lane as usize] = blended * scale;
        }
    } else {
        // Normaliser too small (or not comparable): keep the source pixel.
        #[unroll]
        for lane in 0..channels {
            result[lane as usize] = center[lane as usize];
        }
    }

    output[plane_idx] = result;
}