Skip to main content

jxl_render/filter/
gabor.rs

1use jxl_grid::{AlignedGrid, MutableSubgrid};
2use jxl_threadpool::JxlThreadPool;
3
4use crate::{ImageWithRegion, Region};
5
6use super::impls::generic::gabor::gabor_row_edge;
7
8pub fn apply_gabor_like(
9    fb: &mut ImageWithRegion,
10    color_padded_region: Region,
11    fb_scratch: &mut [AlignedGrid<f32>; 3],
12    weights: [[f32; 2]; 3],
13    pool: &jxl_threadpool::JxlThreadPool,
14) {
15    tracing::debug!("Running gaborish");
16    let region = fb.regions_and_shifts()[0].0;
17    assert!(region.contains(color_padded_region));
18    let left = region.left.abs_diff(color_padded_region.left) as usize;
19    let top = region.top.abs_diff(color_padded_region.top) as usize;
20    let right = left + color_padded_region.width as usize;
21    let bottom = top + color_padded_region.height as usize;
22
23    let buffers = fb.as_color_floats_mut();
24    let buffers = buffers.map(|g| g.as_subgrid_mut().subgrid(left..right, top..bottom));
25
26    super::impls::apply_gabor_like(buffers, fb_scratch, weights, pool);
27
28    let left = color_padded_region.left;
29    let top = color_padded_region.top;
30    for (idx, grid) in fb_scratch.iter_mut().enumerate() {
31        let width = grid.width() as u32;
32        let height = grid.height() as u32;
33        let region = Region {
34            width,
35            height,
36            left,
37            top,
38        };
39        fb.swap_channel_f32(idx, grid, region);
40    }
41}
42
43pub(super) struct GaborRow<'buf> {
44    pub input_rows: [&'buf [f32]; 3],
45    pub output_row: &'buf mut [f32],
46    pub weights: [f32; 2],
47}
48
49pub(super) fn run_gabor_rows<'buf>(
50    input: MutableSubgrid<'buf, f32>,
51    output: &'buf mut AlignedGrid<f32>,
52    weights: [f32; 2],
53    pool: &JxlThreadPool,
54    handle_row: for<'a> fn(GaborRow<'a>),
55) {
56    unsafe { run_gabor_rows_unsafe(input, output, weights, pool, handle_row) }
57}
58
59pub(super) unsafe fn run_gabor_rows_unsafe<'buf>(
60    input: MutableSubgrid<'buf, f32>,
61    output: &'buf mut AlignedGrid<f32>,
62    weights: [f32; 2],
63    pool: &JxlThreadPool,
64    handle_row: for<'a> unsafe fn(GaborRow<'a>),
65) {
66    let width = input.width();
67    let height = input.height();
68    let output_buf = output.buf_mut();
69    assert_eq!(output_buf.len(), width * height);
70
71    if height == 1 {
72        let input_buf = input.get_row(0);
73        gabor_row_edge(input_buf, None, output_buf, weights);
74        return;
75    }
76
77    {
78        let input_buf_c = input.get_row(0);
79        let input_buf_a = input.get_row(1);
80        let output_buf = &mut output_buf[..width];
81        gabor_row_edge(input_buf_c, Some(input_buf_a), output_buf, weights);
82    }
83
84    let (inner_rows, bottom_row) = output_buf[width..].split_at_mut((height - 2) * width);
85    let output_rows = inner_rows
86        .chunks_mut(width * 8)
87        .enumerate()
88        .collect::<Vec<_>>();
89
90    pool.for_each_vec(output_rows, |(y8, output_rows)| {
91        let it = output_rows.chunks_exact_mut(width);
92        for (dy, output_row) in it.enumerate() {
93            let y_up = y8 * 8 + dy;
94            let input_rows = [
95                input.get_row(y_up),
96                input.get_row(y_up + 1),
97                input.get_row(y_up + 2),
98            ];
99            let row = GaborRow {
100                input_rows,
101                output_row,
102                weights,
103            };
104            unsafe {
105                handle_row(row);
106            }
107        }
108    });
109
110    {
111        let input_buf_c = input.get_row(height - 1);
112        let input_buf_a = input.get_row(height - 2);
113        let output_buf = bottom_row;
114        gabor_row_edge(input_buf_c, Some(input_buf_a), output_buf, weights);
115    }
116}
117
118#[allow(unused)]
119pub(crate) fn run_gabor_row_generic(row: GaborRow) {
120    super::impls::generic::gabor::run_gabor_row_generic(row)
121}