use jxl_grid::{AlignedGrid, AllocTracker, SharedSubgrid};
use crate::Region;
use crate::util::PaddedGrid;
/// Upsamples `grid` and scales `out_region` to match.
///
/// `factor` is split into `factor / 3` successive 8x passes followed by one
/// optional pass chosen by `factor % 3` (`1` => 2x, `2` => 4x). Since each
/// pass multiplies both dimensions by its scale, the overall scale per axis is
/// `2^factor`; in other words `factor` acts as a base-2 logarithm.
///
/// The kernel weights for each pass are read from `image_header.metadata`
/// (`up2_weight`, `up4_weight`, `up8_weight`). `tracker` is forwarded to every
/// grid allocation made along the way.
///
/// # Returns
///
/// `Ok(Some(grid))` holding the output of the last pass, or `Ok(None)` when
/// `factor` is `0` and no pass was run. `out_region` is replaced by
/// `out_region.upsample(factor)` in both cases.
///
/// # Errors
///
/// Propagates any error from a pass; `out_region` is left untouched then.
pub fn upsample(
    grid: SharedSubgrid<f32>,
    out_region: &mut Region,
    image_header: &jxl_image::ImageHeader,
    factor: u32,
    tracker: Option<&AllocTracker>,
) -> crate::Result<Option<AlignedGrid<f32>>> {
    let metadata = &image_header.metadata;
    let passes_8x = factor / 3;
    let remainder = factor % 3;

    let mut result: Option<AlignedGrid<f32>> = None;
    let mut current = grid;

    // Chain the 8x passes: each pass reads the grid produced by the previous one.
    for _ in 0..passes_8x {
        let upsampled = upsample_inner::<8, 210>(current, &metadata.up8_weight, tracker)?;
        current = result.insert(upsampled).as_subgrid();
    }

    // Finish with a single 2x or 4x pass when `factor` is not a multiple of 3.
    if remainder == 1 {
        result = Some(upsample_inner::<2, 15>(
            current,
            &metadata.up2_weight,
            tracker,
        )?);
    } else if remainder == 2 {
        result = Some(upsample_inner::<4, 55>(
            current,
            &metadata.up4_weight,
            tracker,
        )?);
    }

    *out_region = out_region.upsample(factor);
    Ok(result)
}
/// Upsamples `grid` by `K` along both axes using 5x5 kernels.
///
/// `weights` holds the upper triangle (diagonal included, row by row) of a
/// symmetric `(5 * K / 2) x (5 * K / 2)` matrix. That matrix is viewed as a
/// `(K / 2) x (K / 2)` arrangement of 5x5 kernels, one per output phase in a
/// quadrant of the `K x K` output cell; the other three quadrants reuse the
/// same kernels flipped horizontally and/or vertically.
///
/// Every output sample is the kernel-weighted sum of the 5x5 source
/// neighbourhood, clamped to the minimum and maximum of that neighbourhood.
/// If the neighbourhood minimum is not finite, `NaN` is written instead.
///
/// # Panics
///
/// Panics unless `(K, NW)` is one of `(2, 15)`, `(4, 55)` or `(8, 210)`.
///
/// # Errors
///
/// Propagates allocation failures from `PaddedGrid::with_alloc_tracker` and
/// `AlignedGrid::with_alloc_tracker`.
fn upsample_inner<const K: usize, const NW: usize>(
    grid: SharedSubgrid<f32>,
    weights: &[f32; NW],
    tracker: Option<&AllocTracker>,
) -> crate::Result<AlignedGrid<f32>> {
    assert!((K == 2 && NW == 15) || (K == 4 && NW == 55) || (K == 8 && NW == 210));

    // A 5x5 kernel reaches two samples past the centre on every side.
    const PADDING: usize = 2;

    let src_width = grid.width();
    let src_height = grid.height();
    // `K` is a power of two (checked above), so shifting scales by exactly `K`.
    let scale_shift = K.ilog2();
    let dst_width = src_width << scale_shift;
    let dst_height = src_height << scale_shift;

    // Copy the source into the interior of a padded buffer.
    let mut padded = PaddedGrid::with_alloc_tracker(src_width, src_height, PADDING, tracker)?;
    let stride = src_width + PADDING * 2;
    let interior = padded.buf_padded_mut();
    for y in 0..src_height {
        let src_row = grid.get_row(y);
        let start = (y + PADDING) * stride + PADDING;
        interior[start..][..src_width].copy_from_slice(src_row);
    }
    // NOTE(review): presumably fills the border by mirroring the edges — confirm in `PaddedGrid`.
    padded.mirror_edges_padding();

    // Expand the triangular weight list into one 5x5 kernel per quadrant phase.
    // Position `(row, col)` with `col >= row` also fills its transposed twin.
    let quarter_dim = K / 2;
    let matrix_dim = 5 * quarter_dim;
    let mut kernels = vec![[0.0f32; 25]; K * K / 4];
    let upper_triangle = (0..matrix_dim).flat_map(|row| (row..matrix_dim).map(move |col| (row, col)));
    for (&w, (row, col)) in weights.iter().zip(upper_triangle) {
        let (tile_row, in_row) = (row / 5, row % 5);
        let (tile_col, in_col) = (col / 5, col % 5);
        kernels[tile_row * quarter_dim + tile_col][in_row * 5 + in_col] = w;
        kernels[tile_col * quarter_dim + tile_row][in_col * 5 + in_row] = w;
    }

    let mut output = AlignedGrid::with_alloc_tracker(dst_width, dst_height, tracker)?;
    let src = padded.buf_padded();
    let dst = output.buf_mut();
    for out_y in 0..dst_height {
        // Padded row `src_y` is the top of the 5x5 window centred on source row `out_y / K`.
        let src_y = out_y / K;
        let phase_y = out_y % K;
        // Phases in the second half of the cell use the mirrored kernel.
        let mirror_y = phase_y >= quarter_dim;
        let tile_y = if mirror_y { K - 1 - phase_y } else { phase_y };
        let dst_row_base = out_y * dst_width;

        for out_x in 0..dst_width {
            let src_x = out_x / K;
            let phase_x = out_x % K;
            let mirror_x = phase_x >= quarter_dim;
            let tile_x = if mirror_x { K - 1 - phase_x } else { phase_x };

            let kernel = &kernels[tile_y * quarter_dim + tile_x];
            let mut acc = 0.0f32;
            let mut lo = f32::INFINITY;
            let mut hi = f32::NEG_INFINITY;
            for dy in 0..5 {
                let ky = if mirror_y { 4 - dy } else { dy };
                let src_row_base = (src_y + dy) * stride + src_x;
                for dx in 0..5 {
                    let kx = if mirror_x { 4 - dx } else { dx };
                    let sample = src[src_row_base + dx];
                    acc += kernel[ky * 5 + kx] * sample;
                    lo = lo.min(sample);
                    hi = hi.max(sample);
                }
            }

            // `clamp` panics when `min > max`, which happens if no sample ever
            // lowered `lo` from its initial +inf; emit NaN in that case (and
            // whenever the minimum is otherwise non-finite).
            dst[dst_row_base + out_x] = if lo.is_finite() {
                acc.clamp(lo, hi)
            } else {
                f32::NAN
            };
        }
    }

    Ok(output)
}