use cubecl::prelude::*;
use cubecl::terminate;
/// Exhaustive SAD block matching on a coarse pyramid level.
///
/// The cube at `(CUBE_POS_X, CUBE_POS_Y)` handles the coarse block whose
/// top-left pixel is `(CUBE_POS_X * step, CUBE_POS_Y * step)`. For every
/// displacement in the `(2 * search_radius + 1)^2` window it computes the sum
/// of absolute differences between the `blksize x blksize` block in `centre`
/// and the displaced block in `neighbour` (coordinates are clamped to the
/// `level_width x level_height` image). The displacement with the smallest
/// SAD is multiplied by `level_scale` and written to every fine block covered
/// by this coarse block.
///
/// `mv_field` stores interleaved `(x, y)` pairs: the vector for fine block
/// `(fbx, fby)` lives at `(fby * fine_blocks_x + fbx) * 2` and the next slot.
/// Fine blocks outside `fine_blocks_x x fine_blocks_y` are skipped.
#[cube(launch_unchecked)]
pub fn nlm_mc_block_match_coarse(
    centre: &Array<f32>,
    neighbour: &Array<f32>,
    mv_field: &mut Array<i32>,
    #[comptime] level_width: u32,
    #[comptime] level_height: u32,
    #[comptime] blksize: u32,
    #[comptime] step: u32,
    #[comptime] search_radius: u32,
    #[comptime] level_scale: u32,
    #[comptime] fine_blocks_x: u32,
    #[comptime] fine_blocks_y: u32,
) {
    let bx = CUBE_POS_X;
    let by = CUBE_POS_Y;
    let block_origin_x = bx as i32 * step as i32;
    let block_origin_y = by as i32 * step as i32;

    let threads = CUBE_DIM_X * CUBE_DIM_Y;
    let thread_id = UNIT_POS_Y * CUBE_DIM_X + UNIT_POS_X;

    let window_side = comptime!(2 * search_radius + 1);
    let candidates = comptime!(window_side * window_side);
    let mut sad_scratch = SharedMemory::<f32>::new(candidates as usize);

    // Work is split by candidate, not by pixel: each unit accumulates the full
    // SAD of its candidates in a private register and stores it once. Every
    // scratch slot therefore has exactly one writer, which avoids the
    // unsynchronised `+=` on shared memory that a pixel-parallel split needs.
    let mut cand = thread_id;
    while cand < candidates {
        let dy = cand / window_side;
        let dx = cand % window_side;
        let mvx = dx as i32 - search_radius as i32;
        let mvy = dy as i32 - search_radius as i32;

        let mut sad = 0.0f32;
        for py in 0..blksize {
            let cy = block_origin_y + py as i32;
            let cy_c = clamp_i32(cy, level_height as i32);
            let ny = clamp_i32(cy + mvy, level_height as i32);
            for px in 0..blksize {
                let cx = block_origin_x + px as i32;
                let cx_c = clamp_i32(cx, level_width as i32);
                let nx = clamp_i32(cx + mvx, level_width as i32);
                let centre_val = centre[(cy_c * level_width as i32 + cx_c) as usize];
                let neighbour_val = neighbour[(ny * level_width as i32 + nx) as usize];
                let diff = centre_val - neighbour_val;
                let abs_diff = if diff < 0.0f32 { -diff } else { diff };
                sad += abs_diff;
            }
        }
        sad_scratch[cand as usize] = sad;
        cand += threads;
    }

    // All SADs must be visible before unit 0 scans them.
    sync_cube();

    // Only unit 0 performs the arg-min and the write-back.
    if thread_id != 0 {
        terminate!();
    }

    let mut best_sad = 1.0e30f32;
    let mut best_dx = 0i32;
    let mut best_dy = 0i32;
    for dy in 0..window_side {
        for dx in 0..window_side {
            let s = sad_scratch[(dy * window_side + dx) as usize];
            // Strict `<` keeps the first candidate in scan order on ties.
            if s < best_sad {
                best_sad = s;
                best_dx = dx as i32 - search_radius as i32;
                best_dy = dy as i32 - search_radius as i32;
            }
        }
    }

    // Scale the coarse vector to fine-level units and fan it out to the
    // `level_scale x level_scale` fine blocks under this coarse block.
    let mvx_fine = best_dx * level_scale as i32;
    let mvy_fine = best_dy * level_scale as i32;
    let fine_bx_origin = bx * level_scale;
    let fine_by_origin = by * level_scale;
    for fy in 0..level_scale {
        let fby = fine_by_origin + fy;
        if fby < fine_blocks_y {
            for fx in 0..level_scale {
                let fbx = fine_bx_origin + fx;
                if fbx < fine_blocks_x {
                    let idx = ((fby * fine_blocks_x + fbx) * 2) as usize;
                    mv_field[idx] = mvx_fine;
                    mv_field[idx + 1] = mvy_fine;
                }
            }
        }
    }
}
/// Exhaustive SAD block matching on the fine level, optionally refining a
/// seed vector already stored in `mv_field`.
///
/// The cube at `(CUBE_POS_X, CUBE_POS_Y)` handles the block whose top-left
/// pixel is `(CUBE_POS_X * step, CUBE_POS_Y * step)` and whose vector lives at
/// `mv_field[(CUBE_POS_Y * blocks_x + CUBE_POS_X) * 2]` (x) and the following
/// slot (y). When `use_seed == 1` the search window of side
/// `2 * search_radius + 1` is centred on the stored vector; otherwise it is
/// centred on `(0, 0)`. The displacement with the smallest sum of absolute
/// differences over the `blksize x blksize` block (coordinates clamped to
/// `width x height`) is written back to the same `mv_field` slot.
#[cube(launch_unchecked)]
pub fn nlm_mc_block_match_fine(
    centre: &Array<f32>,
    neighbour: &Array<f32>,
    mv_field: &mut Array<i32>,
    #[comptime] width: u32,
    #[comptime] height: u32,
    #[comptime] blksize: u32,
    #[comptime] step: u32,
    #[comptime] search_radius: u32,
    use_seed: u32,
    #[comptime] blocks_x: u32,
    #[comptime] _blocks_y: u32,
) {
    let bx = CUBE_POS_X;
    let by = CUBE_POS_Y;
    let mv_slot = ((by * blocks_x + bx) * 2) as usize;

    // Every unit reads the seed here, before the barrier; the only write to
    // this slot happens after the barrier, from unit 0.
    // NOTE(review): the `.into()` is presumably needed so both branches have
    // the same type after `#[cube]` expansion — confirm before removing.
    #[allow(clippy::useless_conversion)]
    let seed_dx = if use_seed == 1u32 {
        mv_field[mv_slot]
    } else {
        0i32.into()
    };
    #[allow(clippy::useless_conversion)]
    let seed_dy = if use_seed == 1u32 {
        mv_field[mv_slot + 1]
    } else {
        0i32.into()
    };

    let block_origin_x = bx as i32 * step as i32;
    let block_origin_y = by as i32 * step as i32;

    let threads = CUBE_DIM_X * CUBE_DIM_Y;
    let thread_id = UNIT_POS_Y * CUBE_DIM_X + UNIT_POS_X;

    let window_side = comptime!(2 * search_radius + 1);
    let candidates = comptime!(window_side * window_side);
    let mut sad_scratch = SharedMemory::<f32>::new(candidates as usize);

    // Work is split by candidate, not by pixel: each unit accumulates the full
    // SAD of its candidates in a private register and stores it once. Every
    // scratch slot therefore has exactly one writer, which avoids the
    // unsynchronised `+=` on shared memory that a pixel-parallel split needs.
    let mut cand = thread_id;
    while cand < candidates {
        let dy = cand / window_side;
        let dx = cand % window_side;
        let mvx = seed_dx + (dx as i32 - search_radius as i32);
        let mvy = seed_dy + (dy as i32 - search_radius as i32);

        let mut sad = 0.0f32;
        for py in 0..blksize {
            let cy = block_origin_y + py as i32;
            let cy_c = clamp_i32(cy, height as i32);
            let ny = clamp_i32(cy + mvy, height as i32);
            for px in 0..blksize {
                let cx = block_origin_x + px as i32;
                let cx_c = clamp_i32(cx, width as i32);
                let nx = clamp_i32(cx + mvx, width as i32);
                let centre_val = centre[(cy_c * width as i32 + cx_c) as usize];
                let neighbour_val = neighbour[(ny * width as i32 + nx) as usize];
                let diff = centre_val - neighbour_val;
                let abs_diff = if diff < 0.0f32 { -diff } else { diff };
                sad += abs_diff;
            }
        }
        sad_scratch[cand as usize] = sad;
        cand += threads;
    }

    // All SADs must be visible before unit 0 scans them.
    sync_cube();

    // Only unit 0 performs the arg-min and the write-back.
    if thread_id != 0 {
        terminate!();
    }

    let mut best_sad = 1.0e30f32;
    let mut best_dx = seed_dx;
    let mut best_dy = seed_dy;
    for dy in 0..window_side {
        for dx in 0..window_side {
            let s = sad_scratch[(dy * window_side + dx) as usize];
            // Strict `<` keeps the first candidate in scan order on ties.
            if s < best_sad {
                best_sad = s;
                best_dx = seed_dx + (dx as i32 - search_radius as i32);
                best_dy = seed_dy + (dy as i32 - search_radius as i32);
            }
        }
    }

    mv_field[mv_slot] = best_dx;
    mv_field[mv_slot + 1] = best_dy;
}
/// Clamps `value` into `[0, limit - 1]`.
///
/// Negative inputs map to `0`; inputs at or above `limit` map to `limit - 1`.
/// The negative check takes precedence when both conditions hold.
#[cube]
fn clamp_i32(value: i32, limit: i32) -> i32 {
    let mut clamped = value;
    if value >= limit {
        clamped = limit - 1;
    }
    // Applied last so the lower bound wins when both conditions hold.
    if value < 0 {
        clamped = 0;
    }
    clamped
}