jxl_render/features/noise.rs

use std::num::Wrapping;

use jxl_frame::{FrameHeader, data::NoiseParameters};
use jxl_grid::{AlignedGrid, AllocTracker, SharedSubgrid};
use jxl_threadpool::JxlThreadPool;

use crate::{ImageWithRegion, Region, Result};

// Padding for 5x5 kernel convolution step
const PADDING: usize = 2;

pub fn render_noise(
    header: &FrameHeader,
    visible_frames_num: usize,
    invisible_frames_num: usize,
    base_correlations_xb: Option<(f32, f32)>,
    grid: &mut ImageWithRegion,
    params: &NoiseParameters,
    pool: &JxlThreadPool,
) -> Result<()> {
    let (region, shift) = grid.regions_and_shifts()[0];
    let tracker = grid.alloc_tracker().cloned();
    let [grid_x, grid_y, grid_b] = grid.as_color_floats_mut();

    let full_frame_region = Region::with_size(header.width, header.height);
    let actual_region = region
        .intersection(full_frame_region)
        .downsample_with_shift(shift);

    let left = actual_region.left as usize;
    let top = actual_region.top as usize;
    let width = actual_region.width as usize;
    let height = actual_region.height as usize;
    let (corr_x, corr_b) = base_correlations_xb.unwrap_or((0.0, 1.0));

    let noise_buffer = init_noise(
        visible_frames_num,
        invisible_frames_num,
        header,
        tracker.as_ref(),
        pool,
    )?;

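    // The 8-entry noise strength LUT from the bitstream is extended to 9 entries
    // by repeating the last value, so the interpolation below can always read
    // `lut[idx + 1]` for a clamped index in 0..=7.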
    let mut lut = [0f32; 9];
    lut[..8].copy_from_slice(&params.lut);
    lut[8] = params.lut[7];
    for fy in 0..height {
        let y = fy + top;
        let row_x = grid_x.get_row_mut(fy);
        let row_y = grid_y.get_row_mut(fy);
        let row_b = grid_b.get_row_mut(fy);
        let row_noise_x = noise_buffer[0].get_row(y);
        let row_noise_y = noise_buffer[1].get_row(y);
        let row_noise_b = noise_buffer[2].get_row(y);

        for fx in 0..width {
            let x = fx + left;

            let grid_x = row_x[fx];
            let grid_y = row_y[fx];
            let noise_x = row_noise_x[x];
            let noise_y = row_noise_y[x];
            let noise_b = row_noise_b[x];

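            // `in_x`/`in_y` combine the X and Y channels (Y + X and Y - X, roughly
            // red- and green-weighted intensities). The noise strength for each is
            // read from the LUT with linear interpolation after scaling by 3 and
            // clamping to the valid index range.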
            let in_x = grid_x + grid_y;
            let in_y = grid_y - grid_x;
            let in_scaled_x = f32::max(0.0, in_x * 3.0);
            let in_scaled_y = f32::max(0.0, in_y * 3.0);

            let in_x_int = (in_scaled_x as usize).min(7);
            let in_x_frac = in_scaled_x - in_x_int as f32;
            let in_y_int = (in_scaled_y as usize).min(7);
            let in_y_frac = in_scaled_y - in_y_int as f32;

            let sx = (lut[in_x_int + 1] - lut[in_x_int]) * in_x_frac + lut[in_x_int];
            let sy = (lut[in_y_int + 1] - lut[in_y_int]) * in_y_frac + lut[in_y_int];
            let nx = 0.22 * sx * (0.0078125 * noise_x + 0.9921875 * noise_b);
            let ny = 0.22 * sy * (0.0078125 * noise_y + 0.9921875 * noise_b);
            row_x[fx] += corr_x * (nx + ny) + nx - ny;
            row_y[fx] += nx + ny;
            row_b[fx] += corr_b * (nx + ny);
        }
    }

    Ok(())
}

fn init_noise(
    visible_frames: usize,
    invisible_frames: usize,
    header: &FrameHeader,
    tracker: Option<&AllocTracker>,
    pool: &JxlThreadPool,
) -> Result<[AlignedGrid<f32>; 3]> {
    let seed0 = rng_seed0(visible_frames, invisible_frames);

    // header.width and header.height are the frame dimensions after upsampling
    // (the "actual" frame size); noise is synthesized on the upsampled frame.
    let width = header.width as usize;
    let height = header.height as usize;

    let group_dim = header.group_dim() as usize;
    let groups_per_row = width.div_ceil(group_dim);
    let num_groups = groups_per_row * height.div_ceil(group_dim);

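    // Generate raw noise group by group. Each group gets its own XorShift128+
    // stream, seeded from the frame counters (seed0) and the group's top-left
    // coordinates (seed1), so groups are deterministic and order-independent.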
    let mut noise_groups = Vec::with_capacity(num_groups);
    for group_idx in 0..num_groups {
        let group_x = group_idx % groups_per_row;
        let group_y = group_idx / groups_per_row;
        let x0 = group_x * group_dim;
        let y0 = group_y * group_dim;
        let seed1 = rng_seed1(x0, y0);

        let group_width = group_dim.min(width - x0);
        let group_height = group_dim.min(height - y0);
        let noise_group = NoiseGroup::new(group_width, group_height, seed0, seed1, tracker)?;
        noise_groups.push(noise_group);
    }

    let mut convolved: [AlignedGrid<f32>; 3] = [
        AlignedGrid::with_alloc_tracker(width, height, tracker)?,
        AlignedGrid::with_alloc_tracker(width, height, tracker)?,
        AlignedGrid::with_alloc_tracker(width, height, tracker)?,
    ];

    // Each channel is convolved with the 5×5 kernel; each group needs its eight
    // neighbors to pad the convolution input across group boundaries.
    let mut jobs = Vec::with_capacity(num_groups * 3);
    for (channel_idx, out) in convolved.iter_mut().enumerate() {
        for (group_idx, out_subgrid) in out
            .as_subgrid_mut()
            .into_groups(group_dim, group_dim)
            .into_iter()
            .enumerate()
        {
            let group_x = group_idx % groups_per_row;
            let group_y = group_idx / groups_per_row;

            // `adjacent_groups[4] == this`
            let adjacent_groups: [_; 9] = std::array::from_fn(|idx| {
                let offset_x = (idx % 3) as isize - 1;
                let offset_y = (idx / 3) as isize - 1;
                if let (Some(x), Some(y)) = (
                    group_x.checked_add_signed(offset_x),
                    group_y.checked_add_signed(offset_y),
                ) {
                    let group_idx = y * groups_per_row + x;
                    if x < groups_per_row {
                        noise_groups
                            .get(group_idx)
                            .map(|group| group.as_subgrid(channel_idx))
                    } else {
                        None
                    }
                } else {
                    None
                }
            });

            jobs.push((out_subgrid, adjacent_groups));
        }
    }

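    // Run the convolution jobs on the thread pool; any error is stashed behind a
    // mutex and propagated once all jobs have finished.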
    let result = std::sync::Mutex::new(Ok(()));
    pool.for_each_vec(jobs, |job| {
        let (out_subgrid, adjacent_groups) = job;
        let r = convolve_fill(out_subgrid, adjacent_groups, tracker);
        if r.is_err() {
            *result.lock().unwrap() = r;
        }
    });
    result.into_inner().unwrap()?;

    Ok(convolved)
}

/// Seed for [`XorShift128Plus`] from the number of ‘visible’ frames decoded so far
/// and the number of ‘invisible’ frames since the previous visible frame.
#[inline]
fn rng_seed0(visible_frames: usize, invisible_frames: usize) -> u64 {
    ((visible_frames as u64) << 32) + invisible_frames as u64
}

/// Seed for [`XorShift128Plus`] from the coordinates of the top-left pixel of the
/// group within the frame.
#[inline]
fn rng_seed1(x0: usize, y0: usize) -> u64 {
    ((x0 as u64) << 32) + y0 as u64
}

struct NoiseGroup {
    buf: [Vec<f32>; 3],
    width: usize,
    height: usize,
    stride: usize,
    _alloc_handle: Option<jxl_grid::AllocHandle>,
}

impl NoiseGroup {
    fn new(
        width: usize,
        height: usize,
        seed0: u64,
        seed1: u64,
        tracker: Option<&AllocTracker>,
    ) -> Result<Self> {
        let width_n2 = width.div_ceil(N * 2);
        let stride = width_n2 * N * 2;
        let elems = stride * height * 3;
        let alloc_handle = tracker
            .map(|tracker| tracker.alloc::<f32>(elems))
            .transpose()?;

        let mut rng = XorShift128Plus::new(seed0, seed1);

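        // Fill each channel with pseudorandom f32 values in [1.0, 2.0):
        // `(bits >> 9) | 0x3f800000` places 23 random mantissa bits under a fixed
        // exponent, the usual bit trick for uniform floats. Rows are padded to a
        // multiple of N * 2 samples (`stride`).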
        let buf: [_; 3] = std::array::from_fn(|_| {
            let num_iters = width_n2 * height;
            let mut buf = Vec::with_capacity(num_iters * N * 2);
            for _ in 0..num_iters {
                let bits = rng.get_u32_bits();
                let bits = bits.map(|x| f32::from_bits((x >> 9) | 0x3f800000));
                buf.extend_from_slice(&bits);
            }
            buf
        });

        Ok(Self {
            buf,
            width,
            height,
            stride,
            _alloc_handle: alloc_handle,
        })
    }

    #[inline]
    fn as_subgrid(&self, channel_idx: usize) -> SharedSubgrid<f32> {
        SharedSubgrid::from_buf(&self.buf[channel_idx], self.width, self.height, self.stride)
    }
}

fn convolve_fill(
    mut out: jxl_grid::MutableSubgrid<'_, f32>,
    adjacent_groups: [Option<SharedSubgrid<f32>>; 9],
    tracker: Option<&AllocTracker>,
) -> Result<()> {
    let this = adjacent_groups[4].unwrap();
    let width = out.width();
    let height = out.height();
    assert_eq!(this.width(), width);
    assert_eq!(this.height(), height);

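    // `rows` is a rolling window of 1 + 2 * PADDING horizontally padded input
    // rows. The two rows above the group come from the group on top when it
    // exists; otherwise they are mirrored from this group (or repeated when the
    // group is a single row tall).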
    let mut rows = AlignedGrid::with_alloc_tracker(width + PADDING * 2, 1 + PADDING * 2, tracker)?;
    if let Some(c) = adjacent_groups[1] {
        let l = adjacent_groups[0];
        let r = adjacent_groups[2];
        for offset_y in -2..0 {
            let out = rows.get_row_mut(2usize.wrapping_add_signed(offset_y));
            let c = c.get_row(c.height().wrapping_add_signed(offset_y));
            let l = l
                .as_ref()
                .map(|l| l.get_row(l.height().wrapping_add_signed(offset_y)));
            let r = r
                .as_ref()
                .map(|r| r.get_row(r.height().wrapping_add_signed(offset_y)));
            fill_padded_row(out, c, l, r);
        }
    } else if height >= 2 {
        let c = this;
        let l = adjacent_groups[3];
        let r = adjacent_groups[5];
        for offset_y in -2..0 {
            let y = (-(offset_y + 1)) as usize;
            let out = rows.get_row_mut(2usize.wrapping_add_signed(offset_y));
            let c = c.get_row(y);
            let l = l.as_ref().map(|l| l.get_row(y));
            let r = r.as_ref().map(|r| r.get_row(y));
            fill_padded_row(out, c, l, r);
        }
    } else {
        let c = this;
        let l = adjacent_groups[3];
        let r = adjacent_groups[5];

        let c = c.get_row(0);
        let l = l.as_ref().map(|l| l.get_row(0));
        let r = r.as_ref().map(|r| r.get_row(0));
        for y in 0..2 {
            let out = rows.get_row_mut(y);
            fill_padded_row(out, c, l, r);
        }
    }

    for y in 0..3 {
        let out = rows.get_row_mut(2 + y);
        fill_once(out, y, adjacent_groups);
    }

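    // Convolve. Every tap of the 5x5 window is weighted 0.16 and 4x the center
    // sample is then subtracted, for an effective center weight of -3.84; the
    // kernel sums to zero, so the convolved noise has no DC component.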
    let input_width = rows.width();
    for y in 0..height {
        let center_y = (y + 2) % 5;

        let input_buf = rows.buf();
        let out_buf = out.get_row_mut(y);
        for (x, out) in out_buf.iter_mut().enumerate() {
            let mut sum = 0f32;
            for dy in 0..5 {
                let input_row = &input_buf[dy * input_width..][..input_width];
                for dx in 0..5 {
                    sum += input_row[x + dx] * 0.16;
                }
            }
            *out = sum - input_buf[center_y * input_width + x + 2] * 4.0;
        }

        if y != height - 1 {
            let next_y = y + 3;
            let fill_y = (next_y + 2) % 5;
            fill_once(rows.get_row_mut(fill_y), next_y, adjacent_groups);
        }
    }

    Ok(())
}

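/// Fills one horizontally padded row of the rolling window with logical row
/// `fill_y` of the current group. Rows past the bottom edge are taken from the
/// group below when it exists; otherwise rows of this group (or of the group
/// above) are mirrored.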
fn fill_once(out: &mut [f32], fill_y: usize, adjacent_groups: [Option<SharedSubgrid<f32>>; 9]) {
    let this = adjacent_groups[4].unwrap();
    let height = this.height();

    let (source_y, c, l, r) = if let Some(fill_y) = fill_y.checked_sub(height) {
        (
            fill_y,
            adjacent_groups[7],
            adjacent_groups[6],
            adjacent_groups[8],
        )
    } else {
        (
            fill_y,
            adjacent_groups[4],
            adjacent_groups[3],
            adjacent_groups[5],
        )
    };

    let (source_y, c, l, r) = if let Some(c) = c {
        (source_y, c, l, r)
    } else if let Some(y) = (height - 1).checked_sub(source_y) {
        let c = this;
        let l = adjacent_groups[3];
        let r = adjacent_groups[5];
        (y, c, l, r)
    } else {
        let dy = source_y - height + 1;
        if let Some(c) = adjacent_groups[1] {
            let l = adjacent_groups[0];
            let r = adjacent_groups[2];
            (c.height() - dy, c, l, r)
        } else {
            let c = this;
            let l = adjacent_groups[3];
            let r = adjacent_groups[5];
            (0, c, l, r)
        }
    };
    let c = c.get_row(source_y);
    let l = l.as_ref().map(|l| l.get_row(source_y));
    let r = r.as_ref().map(|r| r.get_row(source_y));

    fill_padded_row(out, c, l, r);
}

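/// Copies `this` into the center of `out` and fills `PADDING` columns on each
/// side, taking them from the neighboring rows when available and mirroring
/// (or repeating, for very narrow rows) otherwise.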
fn fill_padded_row(out: &mut [f32], this: &[f32], left: Option<&[f32]>, right: Option<&[f32]>) {
    assert_eq!(out.len(), this.len() + PADDING * 2);

    if let Some(left) = left {
        out[0] = left[left.len() - 2];
        out[1] = left[left.len() - 1];
    } else if this.len() >= PADDING {
        out[0] = this[1];
        out[1] = this[0];
    } else {
        out[0] = this[0];
        out[1] = this[0];
    }

    out[2..][..this.len()].copy_from_slice(this);

    if let Some(right) = right {
        if right.len() >= PADDING {
            out[out.len() - 2] = right[0];
            out[out.len() - 1] = right[1];
        } else {
            out[out.len() - 2] = right[0];
            out[out.len() - 1] = right[0];
        }
    } else {
        out[out.len() - 2] = out[out.len() - 3];
        out[out.len() - 1] = out[out.len() - 4];
    }
}

const N: usize = 8;

/// Shift-register pseudo-random number generator
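///
/// Runs [`N`] independent xorshift128+ lanes in parallel; each call to
/// `get_u32_bits` yields a batch of N * 2 pseudorandom values.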
struct XorShift128Plus {
    s0: [Wrapping<u64>; N],
    s1: [Wrapping<u64>; N],
}

impl XorShift128Plus {
    /// Initialize a new XorShift128+ PRNG.
    fn new(seed0: u64, seed1: u64) -> Self {
        let seed0 = Wrapping(seed0);
        let seed1 = Wrapping(seed1);
        let mut s0 = [Wrapping(0u64); N];
        let mut s1 = [Wrapping(0u64); N];
        s0[0] = split_mix_64(seed0 + Wrapping(0x9E3779B97F4A7C15));
        s1[0] = split_mix_64(seed1 + Wrapping(0x9E3779B97F4A7C15));
        for i in 1..N {
            s0[i] = split_mix_64(s0[i - 1]);
            s1[i] = split_mix_64(s1[i - 1]);
        }
        Self { s0, s1 }
    }

    /// Returns N * 2 [`u32`] pseudorandom numbers
    #[inline]
    pub fn get_u32_bits(&mut self) -> [u32; N * 2] {
        let batch = self.fill_batch();
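        // Constant endianness check: on big-endian targets the two 32-bit halves
        // of each u64 are swapped before the cast, so the low half always comes
        // first in the returned array.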
        if 1u64.to_le() == 1u64 {
            bytemuck::cast(batch)
        } else {
            bytemuck::cast(batch.map(|x| x.rotate_left(32)))
        }
    }

    #[inline]
    fn fill_batch(&mut self) -> [u64; N] {
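        // One xorshift128+ step per lane: emit s0 + s1, then update the lane
        // state with the 23/18/5 shift-xor sequence.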
        std::array::from_fn(|i| {
            let mut s1 = self.s0[i];
            let s0 = self.s1[i];
            let ret = (s1 + s0).0;
            self.s0[i] = s0;
            s1 ^= s1 << 23;
            self.s1[i] = s1 ^ (s0 ^ (s1 >> 18) ^ (s0 >> 5));
            ret
        })
    }
}

/// SplitMix64 generator used to calculate the initial state of [`XorShift128Plus`]
#[inline]
fn split_mix_64(z: Wrapping<u64>) -> Wrapping<u64> {
    let z = (z ^ (z >> 30)) * Wrapping(0xBF58476D1CE4E5B9);
    let z = (z ^ (z >> 27)) * Wrapping(0x94D049BB133111EB);
    z ^ (z >> 31)
}