mod util;
use std::ops::{Add, AddAssign};
use anyhow::anyhow;
use arrayvec::ArrayVec;
use v_frame::{chroma::ChromaSubsampling, frame::Frame, plane::Plane};
use self::util::{extract_ar_row, get_block_mean, get_noise_var, linsolve, multiply_mat};
use super::{BLOCK_SIZE, BLOCK_SIZE_SQUARED, NoiseStatus};
use crate::{
DEFAULT_GRAIN_SEED, GrainTableSegment, NUM_UV_COEFFS, NUM_UV_POINTS, NUM_Y_COEFFS,
NUM_Y_POINTS,
diff::solver::util::normalized_cross_correlation,
util::{get_dbg, get_dbg_mut},
};
const LOW_POLY_NUM_PARAMS: usize = 3;
const NOISE_MODEL_LAG: usize = 3;
const BLOCK_NORMALIZATION: f64 = 255.0f64;
#[derive(Debug, Clone)]
pub(super) struct FlatBlockFinder {
a: Box<[f64]>,
a_t_a_inv: [f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS],
}
impl FlatBlockFinder {
#[must_use]
pub fn new() -> Self {
let mut eqns = EquationSystem::new(LOW_POLY_NUM_PARAMS);
let mut a_t_a_inv = [0.0f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS];
let mut a = vec![0.0f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED];
let bs_half = (BLOCK_SIZE / 2) as f64;
(0..BLOCK_SIZE).for_each(|y| {
let yd = (y as f64 - bs_half) / bs_half;
(0..BLOCK_SIZE).for_each(|x| {
let xd = (x as f64 - bs_half) / bs_half;
let coords = [yd, xd, 1.0f64];
let row = y * BLOCK_SIZE + x;
*get_dbg_mut(&mut a, LOW_POLY_NUM_PARAMS * row) = yd;
*get_dbg_mut(&mut a, LOW_POLY_NUM_PARAMS * row + 1) = xd;
*get_dbg_mut(&mut a, LOW_POLY_NUM_PARAMS * row + 2) = 1.0f64;
(0..LOW_POLY_NUM_PARAMS).for_each(|i| {
(0..LOW_POLY_NUM_PARAMS).for_each(|j| {
*get_dbg_mut(&mut eqns.a, LOW_POLY_NUM_PARAMS * i + j) +=
*get_dbg(&coords, i) * *get_dbg(&coords, j);
});
});
});
});
(0..LOW_POLY_NUM_PARAMS).for_each(|i| {
eqns.b.fill(0.0f64);
*get_dbg_mut(&mut eqns.b, i) = 1.0f64;
eqns.solve();
(0..LOW_POLY_NUM_PARAMS).for_each(|j| {
*get_dbg_mut(&mut a_t_a_inv, j * LOW_POLY_NUM_PARAMS + i) = *get_dbg(&eqns.x, j);
});
});
FlatBlockFinder {
a: a.into_boxed_slice(),
a_t_a_inv,
}
}
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn run(&self, plane: &Plane<u8>) -> (Vec<u8>, usize) {
const TRACE_THRESHOLD: f64 = 0.15f64 / BLOCK_SIZE_SQUARED as f64;
const RATIO_THRESHOLD: f64 = 1.25f64;
const NORM_THRESHOLD: f64 = 0.08f64 / BLOCK_SIZE_SQUARED as f64;
const VAR_THRESHOLD: f64 = 0.005f64 / BLOCK_SIZE_SQUARED as f64;
const VAR_WEIGHT: f64 = -6682f64;
const RATIO_WEIGHT: f64 = -0.2056f64;
const TRACE_WEIGHT: f64 = 13087f64;
const NORM_WEIGHT: f64 = -12434f64;
const OFFSET: f64 = 2.5694f64;
let num_blocks_w = plane.width().get().div_ceil(BLOCK_SIZE);
let num_blocks_h = plane.height().get().div_ceil(BLOCK_SIZE);
let num_blocks = num_blocks_w * num_blocks_h;
let mut flat_blocks = vec![0u8; num_blocks];
let mut num_flat = 0;
let mut plane_result = [0.0f64; BLOCK_SIZE_SQUARED];
let mut block_result = [0.0f64; BLOCK_SIZE_SQUARED];
let mut scores = vec![IndexAndScore::default(); num_blocks];
for by in 0..num_blocks_h {
for bx in 0..num_blocks_w {
let mut gxx = 0f64;
let mut gxy = 0f64;
let mut gyy = 0f64;
let mut var = 0f64;
let mut mean = 0f64;
self.extract_block(
plane,
bx * BLOCK_SIZE,
by * BLOCK_SIZE,
&mut plane_result,
&mut block_result,
);
for yi in 1..(BLOCK_SIZE - 1) {
for xi in 1..(BLOCK_SIZE - 1) {
unsafe {
let result_ptr = block_result.as_ptr().add(yi * BLOCK_SIZE + xi);
let gx = (*result_ptr.add(1) - *result_ptr.sub(1)) / 2f64;
let gy =
(*result_ptr.add(BLOCK_SIZE) - *result_ptr.sub(BLOCK_SIZE)) / 2f64;
gxx += gx * gx;
gxy += gx * gy;
gyy += gy * gy;
let block_val = *result_ptr;
mean += block_val;
var += block_val * block_val;
}
}
}
let block_size_norm_factor = (BLOCK_SIZE - 2).pow(2) as f64;
mean /= block_size_norm_factor;
gxx /= block_size_norm_factor;
gxy /= block_size_norm_factor;
gyy /= block_size_norm_factor;
var = mean.mul_add(-mean, var / block_size_norm_factor);
let trace = gxx + gyy;
let det = gxx.mul_add(gyy, -gxy.powi(2));
let e_sub = (trace.mul_add(trace, -4f64 * det)).max(0.).sqrt();
let e1 = f64::midpoint(trace, e_sub);
let e2 = (trace - e_sub) / 2.0f64;
let norm = e1;
let ratio = e1 / e2.max(1.0e-6_f64);
let is_flat = trace < TRACE_THRESHOLD
&& ratio < RATIO_THRESHOLD
&& norm < NORM_THRESHOLD
&& var > VAR_THRESHOLD;
let sum_weights = NORM_WEIGHT.mul_add(
norm,
TRACE_WEIGHT.mul_add(
trace,
VAR_WEIGHT.mul_add(var, RATIO_WEIGHT.mul_add(ratio, OFFSET)),
),
);
let sum_weights = sum_weights.clamp(-25.0f64, 100.0f64);
let score = (1.0f64 / (1.0f64 + (-sum_weights).exp())) as f32;
let index = by * num_blocks_w + bx;
*get_dbg_mut(&mut flat_blocks, index) = if is_flat { 255 } else { 0 };
*get_dbg_mut(&mut scores, index) = IndexAndScore {
score: if var > VAR_THRESHOLD { score } else { 0f32 },
index,
};
if is_flat {
num_flat += 1;
}
}
}
scores.sort_unstable_by(|a, b| a.score.partial_cmp(&b.score).expect("Shouldn't be NaN"));
let top_nth_percentile = num_blocks * 90 / 100;
let score_threshold = get_dbg(&scores, top_nth_percentile).score;
for score in &scores {
if score.score >= score_threshold {
let block_ref = get_dbg_mut(&mut flat_blocks, score.index);
if *block_ref == 0 {
num_flat += 1;
}
*block_ref |= 1;
}
}
(flat_blocks, num_flat)
}
fn extract_block(
&self,
plane: &Plane<u8>,
offset_x: usize,
offset_y: usize,
plane_result: &mut [f64; BLOCK_SIZE_SQUARED],
block_result: &mut [f64; BLOCK_SIZE_SQUARED],
) {
let mut plane_coords = [0f64; LOW_POLY_NUM_PARAMS];
let mut a_t_a_inv_b = [0f64; LOW_POLY_NUM_PARAMS];
let plane_origin = get_dbg(plane.data(), plane.data_origin()..);
for yi in 0..BLOCK_SIZE {
let y = (offset_y + yi).clamp(0, plane.height().get() - 1);
for xi in 0..BLOCK_SIZE {
let x = (offset_x + xi).clamp(0, plane.width().get() - 1);
*get_dbg_mut(block_result, yi * BLOCK_SIZE + xi) = f64::from(*get_dbg(
plane_origin,
y * plane.geometry().stride.get() + x,
)) / BLOCK_NORMALIZATION;
}
}
multiply_mat(
block_result,
&self.a,
&mut a_t_a_inv_b,
1,
BLOCK_SIZE_SQUARED,
LOW_POLY_NUM_PARAMS,
);
multiply_mat(
&self.a_t_a_inv,
&a_t_a_inv_b,
&mut plane_coords,
LOW_POLY_NUM_PARAMS,
LOW_POLY_NUM_PARAMS,
1,
);
multiply_mat(
&self.a,
&plane_coords,
plane_result,
BLOCK_SIZE_SQUARED,
LOW_POLY_NUM_PARAMS,
1,
);
for (block_res, plane_res) in block_result.iter_mut().zip(plane_result.iter()) {
*block_res -= *plane_res;
}
}
}
#[derive(Debug, Clone, Copy, Default)]
struct IndexAndScore {
pub index: usize,
pub score: f32,
}
#[derive(Debug, Clone)]
struct EquationSystem {
a: Vec<f64>,
b: Vec<f64>,
x: Vec<f64>,
n: usize,
}
impl EquationSystem {
#[must_use]
pub fn new(n: usize) -> Self {
Self {
a: vec![0.0f64; n * n],
b: vec![0.0f64; n],
x: vec![0.0f64; n],
n,
}
}
pub fn solve(&mut self) -> bool {
let n = self.n;
let mut a = self.a.clone();
let mut b = self.b.clone();
linsolve(n, &mut a, self.n, &mut b, &mut self.x)
}
pub fn set_chroma_coefficient_fallback_solution(&mut self) {
const TOLERANCE: f64 = 1.0e-6f64;
let last = self.n - 1;
self.x.fill(0f64);
if get_dbg(&self.a, last * self.n + last).abs() > TOLERANCE {
*get_dbg_mut(&mut self.x, last) =
*get_dbg(&self.b, last) / *get_dbg(&self.a, last * self.n + last);
}
}
pub fn copy_from(&mut self, other: &Self) {
assert_eq!(self.n, other.n);
self.a.copy_from_slice(&other.a);
self.x.copy_from_slice(&other.x);
self.b.copy_from_slice(&other.b);
}
pub fn clear(&mut self) {
self.a.fill(0f64);
self.b.fill(0f64);
self.x.fill(0f64);
}
}
impl Add<&EquationSystem> for EquationSystem {
type Output = EquationSystem;
fn add(self, addend: &EquationSystem) -> Self::Output {
let mut dest = self.clone();
let n = self.n;
for i in 0..n {
for j in 0..n {
*get_dbg_mut(&mut dest.a, i * n + j) += *get_dbg(&addend.a, i * n + j);
}
*get_dbg_mut(&mut dest.b, i) += *get_dbg(&addend.b, i);
}
dest
}
}
impl AddAssign<&EquationSystem> for EquationSystem {
fn add_assign(&mut self, rhs: &EquationSystem) {
*self = self.clone() + rhs;
}
}
struct NoiseStrengthLut {
points: Vec<[f64; 2]>,
}
impl NoiseStrengthLut {
#[must_use]
pub fn new(num_bins: usize) -> Self {
assert!(num_bins > 0);
Self {
points: vec![[0f64; 2]; num_bins],
}
}
}
#[derive(Debug, Clone)]
pub(super) struct NoiseModel {
combined_state: [NoiseModelState; 3],
latest_state: [NoiseModelState; 3],
n: usize,
coords: Vec<[isize; 2]>,
}
impl NoiseModel {
#[must_use]
pub fn new() -> Self {
let n = Self::num_coeffs();
let combined_state = [
NoiseModelState::new(n),
NoiseModelState::new(n + 1),
NoiseModelState::new(n + 1),
];
let latest_state = [
NoiseModelState::new(n),
NoiseModelState::new(n + 1),
NoiseModelState::new(n + 1),
];
let mut coords = Vec::new();
let neg_lag = -(NOISE_MODEL_LAG as isize);
for y in neg_lag..=0 {
let max_x = if y == 0 {
-1isize
} else {
NOISE_MODEL_LAG as isize
};
for x in neg_lag..=max_x {
coords.push([x, y]);
}
}
assert!(n == coords.len());
Self {
combined_state,
latest_state,
n,
coords,
}
}
pub fn update(
&mut self,
source: &Frame<u8>,
denoised: &Frame<u8>,
flat_blocks: &[u8],
) -> NoiseStatus {
let num_blocks_w = source.y_plane.width().get().div_ceil(BLOCK_SIZE);
let num_blocks_h = source.y_plane.height().get().div_ceil(BLOCK_SIZE);
let mut y_model_different = false;
for i in 0..3 {
let state = get_dbg_mut(&mut self.latest_state, i);
state.eqns.clear();
state.num_observations = 0;
state.strength_solver.clear();
}
let num_blocks = flat_blocks.iter().filter(|b| **b > 0).count();
if num_blocks <= 1 {
return NoiseStatus::Error(anyhow!("Not enough flat blocks to update noise estimate"));
}
let frame_dims = (source.y_plane.width().get(), source.y_plane.height().get());
for channel in 0..(if source.subsampling == ChromaSubsampling::Monochrome {
1
} else {
3
}) {
let source_plane = match channel {
0 => &source.y_plane,
1 => source
.u_plane
.as_ref()
.expect("unreachable due to loop bounds"),
2 => source
.v_plane
.as_ref()
.expect("unreachable due to loop bounds"),
_ => unreachable!(),
};
let denoised_plane = match channel {
0 => &denoised.y_plane,
1 => denoised
.u_plane
.as_ref()
.expect("unreachable due to loop bounds"),
2 => denoised
.v_plane
.as_ref()
.expect("unreachable due to loop bounds"),
_ => unreachable!(),
};
let is_chroma = channel > 0;
let alt_source = (channel > 0).then_some(&source.y_plane);
let alt_denoised = (channel > 0).then_some(&denoised.y_plane);
self.add_block_observations(
channel,
source_plane,
denoised_plane,
alt_source,
alt_denoised,
frame_dims,
flat_blocks,
num_blocks_w,
num_blocks_h,
);
if !get_dbg_mut(&mut self.latest_state, channel).ar_equation_system_solve(is_chroma) {
if is_chroma {
get_dbg_mut(&mut self.latest_state, channel)
.eqns
.set_chroma_coefficient_fallback_solution();
} else {
return NoiseStatus::Error(anyhow!(
"Solving latest noise equation system failed on plane {}",
channel
));
}
}
self.add_noise_std_observations(
channel,
source_plane,
denoised_plane,
alt_source,
frame_dims,
flat_blocks,
num_blocks_w,
num_blocks_h,
);
if !get_dbg_mut(&mut self.latest_state, channel)
.strength_solver
.solve()
{
return NoiseStatus::Error(anyhow!(
"Failed to solve strength solver for latest state"
));
}
let is_different = self.is_different();
let combined_state = get_dbg_mut(&mut self.combined_state, channel);
if channel == 0 && combined_state.strength_solver.num_equations > 0 && is_different {
y_model_different = true;
}
if y_model_different {
continue;
}
combined_state.num_observations +=
get_dbg(&self.latest_state, channel).num_observations;
combined_state.eqns += &get_dbg(&self.latest_state, channel).eqns;
if !combined_state.ar_equation_system_solve(is_chroma) {
if is_chroma {
combined_state
.eqns
.set_chroma_coefficient_fallback_solution();
} else {
return NoiseStatus::Error(anyhow!(
"Solving combined noise equation system failed on plane {}",
channel
));
}
}
combined_state.strength_solver += &get_dbg(&self.latest_state, channel).strength_solver;
if !combined_state.strength_solver.solve() {
return NoiseStatus::Error(anyhow!(
"Failed to solve strength solver for combined state"
));
};
}
if y_model_different {
return NoiseStatus::DifferentType;
}
NoiseStatus::Ok
}
#[allow(clippy::too_many_lines)]
#[must_use]
pub fn get_grain_parameters(&self, start_ts: u64, end_ts: u64) -> GrainTableSegment {
let scaling_points_y = self.combined_state[0]
.strength_solver
.fit_piecewise(NUM_Y_POINTS)
.points;
let scaling_points_cb = self.combined_state[1]
.strength_solver
.fit_piecewise(NUM_UV_POINTS)
.points;
let scaling_points_cr = self.combined_state[2]
.strength_solver
.fit_piecewise(NUM_UV_POINTS)
.points;
let mut max_scaling_value: f64 = 1.0e-4f64;
for p in scaling_points_y
.iter()
.chain(scaling_points_cb.iter())
.chain(scaling_points_cr.iter())
.map(|p| p[1])
{
if p > max_scaling_value {
max_scaling_value = p;
}
}
let max_scaling_value_log2 =
((max_scaling_value.log2() + 1f64).floor() as u8).clamp(2u8, 5u8);
let scale_factor = f64::from(1u32 << (8u8 - max_scaling_value_log2));
let map_scaling_point = |p: [f64; 2]| {
[
(p[0] + 0.5f64) as u8,
(scale_factor.mul_add(p[1], 0.5f64) as i32).clamp(0i32, 255i32) as u8,
]
};
let scaling_points_y: ArrayVec<_, NUM_Y_POINTS> = scaling_points_y
.into_iter()
.map(map_scaling_point)
.collect();
let scaling_points_cb: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cb
.into_iter()
.map(map_scaling_point)
.collect();
let scaling_points_cr: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cr
.into_iter()
.map(map_scaling_point)
.collect();
let n_coeff = self.combined_state[0].eqns.n;
let mut max_coeff = 1.0e-4f64;
let mut min_coeff = 1.0e-4f64;
let mut y_corr = [0f64; 2];
let mut avg_luma_strength = 0f64;
for c in 0..3 {
let eqns = &get_dbg(&self.combined_state, c).eqns;
for i in 0..n_coeff {
let xi = *get_dbg(&eqns.x, i);
if xi > max_coeff {
max_coeff = xi;
}
if xi < min_coeff {
min_coeff = xi;
}
}
let solver = &get_dbg(&self.combined_state, c).strength_solver;
let mut average_strength = 0f64;
let mut total_weight = 0f64;
for i in 0..solver.eqns.n {
let mut w = 0f64;
for j in 0..solver.eqns.n {
w += *get_dbg(&solver.eqns.a, i * solver.eqns.n + j);
}
w = w.sqrt();
average_strength += *get_dbg(&solver.eqns.x, i) * w;
total_weight += w;
}
if total_weight.abs() < f64::EPSILON {
average_strength = 1f64;
} else {
average_strength /= total_weight;
}
if c == 0 {
avg_luma_strength = average_strength;
} else {
let y_corr_cur = get_dbg_mut(&mut y_corr, c - 1);
*y_corr_cur = avg_luma_strength * *get_dbg(&eqns.x, n_coeff) / average_strength;
max_coeff = max_coeff.max(*y_corr_cur);
min_coeff = min_coeff.min(*y_corr_cur);
}
}
let ar_coeff_shift = (7i32
- (1.0f64 + max_coeff.log2().floor()).max((-min_coeff).log2().ceil()) as i32)
.clamp(6i32, 9i32) as u8;
let scale_ar_coeff = f64::from(1u16 << ar_coeff_shift);
let ar_coeffs_y = self.get_ar_coeffs_y(n_coeff, scale_ar_coeff);
let ar_coeffs_cb = self.get_ar_coeffs_uv(1, n_coeff, scale_ar_coeff, y_corr);
let ar_coeffs_cr = self.get_ar_coeffs_uv(2, n_coeff, scale_ar_coeff, y_corr);
GrainTableSegment {
random_seed: if start_ts == 0 { DEFAULT_GRAIN_SEED } else { 0 },
start_time: start_ts,
end_time: end_ts,
ar_coeff_lag: NOISE_MODEL_LAG as u8,
scaling_points_y,
scaling_points_cb,
scaling_points_cr,
scaling_shift: 5 + (8 - max_scaling_value_log2),
ar_coeff_shift,
ar_coeffs_y,
ar_coeffs_cb,
ar_coeffs_cr,
cb_mult: 128,
cb_luma_mult: 192,
cb_offset: 256,
cr_mult: 128,
cr_luma_mult: 192,
cr_offset: 256,
chroma_scaling_from_luma: false,
grain_scale_shift: 0,
overlap_flag: true,
}
}
pub fn save_latest(&mut self) {
for c in 0..3 {
let latest_state = get_dbg(&self.latest_state, c);
let combined_state = get_dbg_mut(&mut self.combined_state, c);
combined_state.eqns.copy_from(&latest_state.eqns);
combined_state
.strength_solver
.eqns
.copy_from(&latest_state.strength_solver.eqns);
combined_state.strength_solver.num_equations =
latest_state.strength_solver.num_equations;
combined_state.num_observations = latest_state.num_observations;
combined_state.ar_gain = latest_state.ar_gain;
}
}
#[must_use]
const fn num_coeffs() -> usize {
let n = 2 * NOISE_MODEL_LAG + 1;
(n * n) / 2
}
#[must_use]
fn get_ar_coeffs_y(&self, n_coeff: usize, scale_ar_coeff: f64) -> ArrayVec<i8, NUM_Y_COEFFS> {
assert!(n_coeff <= NUM_Y_COEFFS);
let mut coeffs = ArrayVec::new();
let eqns = &self.combined_state[0].eqns;
for i in 0..n_coeff {
let xi = *get_dbg(&eqns.x, i);
coeffs.push(((scale_ar_coeff * xi).round() as i32).clamp(-128i32, 127i32) as i8);
}
coeffs
}
#[must_use]
fn get_ar_coeffs_uv(
&self,
channel: usize,
n_coeff: usize,
scale_ar_coeff: f64,
y_corr: [f64; 2],
) -> ArrayVec<i8, NUM_UV_COEFFS> {
assert!(n_coeff <= NUM_Y_COEFFS);
let mut coeffs = ArrayVec::new();
let eqns = &get_dbg(&self.combined_state, channel).eqns;
for i in 0..n_coeff {
let xi = *get_dbg(&eqns.x, i);
coeffs.push(((scale_ar_coeff * xi).round() as i32).clamp(-128i32, 127i32) as i8);
}
coeffs.push(
((scale_ar_coeff * *get_dbg(&y_corr, channel - 1)).round() as i32)
.clamp(-128i32, 127i32) as i8,
);
coeffs
}
#[must_use]
fn is_different(&self) -> bool {
const COEFF_THRESHOLD: f64 = 0.9f64;
const STRENGTH_THRESHOLD: f64 = 0.005f64;
let latest = &self.latest_state[0];
let combined = &self.combined_state[0];
let corr = normalized_cross_correlation(&latest.eqns.x, &combined.eqns.x, combined.eqns.n);
if corr < COEFF_THRESHOLD {
return true;
}
let dx = 1.0f64 / latest.strength_solver.num_bins as f64;
let latest_eqns = &latest.strength_solver.eqns;
let combined_eqns = &combined.strength_solver.eqns;
let mut diff = 0.0f64;
let mut total_weight = 0.0f64;
for j in 0..latest_eqns.n {
let mut weight = 0.0f64;
for i in 0..latest_eqns.n {
weight += *get_dbg(&latest_eqns.a, i * latest_eqns.n + j);
}
weight = weight.sqrt();
diff += weight * (*get_dbg(&latest_eqns.x, j) - *get_dbg(&combined_eqns.x, j)).abs();
total_weight += weight;
}
diff * dx / total_weight > STRENGTH_THRESHOLD
}
#[allow(clippy::too_many_arguments)]
fn add_block_observations(
&mut self,
channel: usize,
source: &Plane<u8>,
denoised: &Plane<u8>,
alt_source: Option<&Plane<u8>>,
alt_denoised: Option<&Plane<u8>>,
frame_dims: (usize, usize),
flat_blocks: &[u8],
num_blocks_w: usize,
num_blocks_h: usize,
) {
let num_coords = self.n;
let state = get_dbg_mut(&mut self.latest_state, channel);
let a = &mut state.eqns.a;
let b = &mut state.eqns.b;
let mut buffer = vec![0f64; num_coords + 1].into_boxed_slice();
let n = state.eqns.n;
let block_w = BLOCK_SIZE / source.geometry().subsampling_x.get() as usize;
let block_h = BLOCK_SIZE / source.geometry().subsampling_y.get() as usize;
let dec = (
source.geometry().subsampling_x.get() as usize >> 1,
source.geometry().subsampling_y.get() as usize >> 1,
);
let stride = source.geometry().stride.get();
let source_origin = get_dbg(source.data(), source.data_origin()..);
let denoised_origin = get_dbg(denoised.data(), denoised.data_origin()..);
let alt_stride = alt_source.map_or(0, |s| s.geometry().stride.get());
let alt_source_origin = alt_source.map(|s| get_dbg(s.data(), s.data_origin()..));
let alt_denoised_origin = alt_denoised.map(|s| get_dbg(s.data(), s.data_origin()..));
for by in 0..num_blocks_h {
let y_o = by * block_h;
for bx in 0..num_blocks_w {
unsafe {
let flat_block_ptr = flat_blocks.as_ptr().add(by * num_blocks_w + bx);
let x_o = bx * block_w;
if *flat_block_ptr == 0 {
continue;
}
let y_start = if by > 0 && *flat_block_ptr.sub(num_blocks_w) > 0 {
0
} else {
NOISE_MODEL_LAG
};
let x_start = if bx > 0 && *flat_block_ptr.sub(1) > 0 {
0
} else {
NOISE_MODEL_LAG
};
let y_end = ((frame_dims.1 >> dec.1) - by * block_h).min(block_h);
let x_end = ((frame_dims.0 >> dec.0) - bx * block_w - NOISE_MODEL_LAG).min(
if bx + 1 < num_blocks_w && *flat_block_ptr.add(1) > 0 {
block_w
} else {
block_w - NOISE_MODEL_LAG
},
);
for y in y_start..y_end {
for x in x_start..x_end {
let val = extract_ar_row(
&self.coords,
num_coords,
source_origin,
denoised_origin,
stride,
dec,
alt_source_origin,
alt_denoised_origin,
alt_stride,
x + x_o,
y + y_o,
&mut buffer,
);
for i in 0..n {
for j in 0..n {
*get_dbg_mut(a, i * n + j) += (*get_dbg(&buffer, i)
* *get_dbg(&buffer, j))
/ BLOCK_NORMALIZATION.powi(2);
}
*get_dbg_mut(b, i) +=
(*get_dbg(&buffer, i) * val) / BLOCK_NORMALIZATION.powi(2);
}
state.num_observations += 1;
}
}
}
}
}
}
#[allow(clippy::too_many_arguments)]
fn add_noise_std_observations(
&mut self,
channel: usize,
source: &Plane<u8>,
denoised: &Plane<u8>,
alt_source: Option<&Plane<u8>>,
frame_dims: (usize, usize),
flat_blocks: &[u8],
num_blocks_w: usize,
num_blocks_h: usize,
) {
let num_coords = self.n;
let luma_gain = self.latest_state[0].ar_gain;
let noise_gain = get_dbg(&self.latest_state, channel).ar_gain;
let block_w = BLOCK_SIZE / source.geometry().subsampling_x.get() as usize;
let block_h = BLOCK_SIZE / source.geometry().subsampling_y.get() as usize;
for by in 0..num_blocks_h {
let y_o = by * block_h;
for bx in 0..num_blocks_w {
let x_o = bx * block_w;
if *get_dbg(flat_blocks, by * num_blocks_w + bx) == 0 {
continue;
}
let num_samples_h = ((frame_dims.1
/ source.geometry().subsampling_y.get() as usize)
- by * block_h)
.min(block_h);
let num_samples_w = ((frame_dims.0
/ source.geometry().subsampling_x.get() as usize)
- bx * block_w)
.min(block_w);
if num_samples_w * num_samples_h > BLOCK_SIZE {
let block_mean = get_block_mean(
alt_source.unwrap_or(source),
frame_dims,
x_o << (source.geometry().subsampling_x.get() >> 1),
y_o << (source.geometry().subsampling_y.get() >> 1),
);
let noise_var = get_noise_var(
source,
denoised,
(
frame_dims.0 >> (source.geometry().subsampling_x.get() >> 1),
frame_dims.1 >> (source.geometry().subsampling_y.get() >> 1),
),
x_o,
y_o,
block_w,
block_h,
);
let luma_strength = if channel > 0 {
luma_gain * self.latest_state[0].strength_solver.get_value(block_mean)
} else {
0f64
};
let corr = if channel > 0 {
let coeffs = &get_dbg(&self.latest_state, channel).eqns.x;
*get_dbg(coeffs, num_coords)
} else {
0f64
};
let uncorr_std = (noise_var / 16f64)
.max((corr * luma_strength).mul_add(-(corr * luma_strength), noise_var))
.sqrt();
let adjusted_strength = uncorr_std / noise_gain;
get_dbg_mut(&mut self.latest_state, channel)
.strength_solver
.add_measurement(block_mean, adjusted_strength);
}
}
}
}
}
#[derive(Debug, Clone)]
struct NoiseModelState {
eqns: EquationSystem,
ar_gain: f64,
num_observations: usize,
strength_solver: StrengthSolver,
}
impl NoiseModelState {
#[must_use]
pub fn new(n: usize) -> Self {
const NUM_BINS: usize = 20;
Self {
eqns: EquationSystem::new(n),
ar_gain: 1.0f64,
num_observations: 0usize,
strength_solver: StrengthSolver::new(NUM_BINS),
}
}
pub fn ar_equation_system_solve(&mut self, is_chroma: bool) -> bool {
let ret = self.eqns.solve();
self.ar_gain = 1.0f64;
if !ret {
return ret;
}
let mut var = 0f64;
let n_adjusted = self.eqns.n - usize::from(is_chroma);
for i in 0..n_adjusted {
var += *get_dbg(&self.eqns.a, i * self.eqns.n + i) / self.num_observations as f64;
}
var /= n_adjusted as f64;
let mut sum_covar = 0f64;
for i in 0..n_adjusted {
let mut bi = *get_dbg(&self.eqns.b, i);
if is_chroma {
bi -= *get_dbg(&self.eqns.a, i * self.eqns.n + n_adjusted)
* *get_dbg(&self.eqns.x, n_adjusted);
}
sum_covar += (bi * *get_dbg(&self.eqns.x, i)) / self.num_observations as f64;
}
let noise_var = (var - sum_covar).max(1e-6f64);
self.ar_gain = 1f64.max((var / noise_var).max(1e-6f64).sqrt());
ret
}
}
#[derive(Debug, Clone)]
struct StrengthSolver {
eqns: EquationSystem,
num_bins: usize,
num_equations: usize,
total: f64,
}
impl StrengthSolver {
#[must_use]
pub fn new(num_bins: usize) -> Self {
Self {
eqns: EquationSystem::new(num_bins),
num_bins,
num_equations: 0usize,
total: 0f64,
}
}
pub fn add_measurement(&mut self, block_mean: f64, noise_std: f64) {
let bin = self.get_bin_index(block_mean);
let bin_i0 = bin.floor() as usize;
let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
let a = bin - bin_i0 as f64;
let n = self.num_bins;
let eqns = &mut self.eqns;
*get_dbg_mut(&mut eqns.a, bin_i0 * n + bin_i0) += (1f64 - a).powi(2);
*get_dbg_mut(&mut eqns.a, bin_i1 * n + bin_i0) += a * (1f64 - a);
*get_dbg_mut(&mut eqns.a, bin_i1 * n + bin_i1) += a.powi(2);
*get_dbg_mut(&mut eqns.a, bin_i0 * n + bin_i1) += (1f64 - a) * a;
*get_dbg_mut(&mut eqns.b, bin_i0) += (1f64 - a) * noise_std;
*get_dbg_mut(&mut eqns.b, bin_i1) += a * noise_std;
self.total += noise_std;
self.num_equations += 1;
}
pub fn solve(&mut self) -> bool {
let n = self.num_bins;
let alpha = 2f64 * self.num_equations as f64 / n as f64;
let old_a = self.eqns.a.clone();
for i in 0..n {
let i_lo = i.saturating_sub(1);
let i_hi = (n - 1).min(i + 1);
*get_dbg_mut(&mut self.eqns.a, i * n + i_lo) -= alpha;
*get_dbg_mut(&mut self.eqns.a, i * n + i) += 2f64 * alpha;
*get_dbg_mut(&mut self.eqns.a, i * n + i_hi) -= alpha;
}
let mean = self.total / self.num_equations as f64;
for i in 0..n {
*get_dbg_mut(&mut self.eqns.a, i * n + i) += 1f64 / 8192f64;
*get_dbg_mut(&mut self.eqns.b, i) += mean / 8192f64;
}
let result = self.eqns.solve();
self.eqns.a = old_a;
result
}
#[must_use]
pub fn fit_piecewise(&self, max_output_points: usize) -> NoiseStrengthLut {
const TOLERANCE: f64 = 0.00625f64;
let mut lut = NoiseStrengthLut::new(self.num_bins);
for i in 0..self.num_bins {
get_dbg_mut(&mut lut.points, i)[0] = self.get_center(i);
get_dbg_mut(&mut lut.points, i)[1] = *get_dbg(&self.eqns.x, i);
}
let mut residual = vec![0.0f64; self.num_bins];
self.update_piecewise_linear_residual(&lut, &mut residual, 0, self.num_bins);
while lut.points.len() > 2 {
let mut min_index = 1usize;
for j in 1..(lut.points.len() - 1) {
if *get_dbg(&residual, j) < *get_dbg(&residual, min_index) {
min_index = j;
}
}
let dx =
get_dbg(&lut.points, min_index + 1)[0] - get_dbg(&lut.points, min_index - 1)[0];
let avg_residual = get_dbg(&residual, min_index) / dx;
if lut.points.len() <= max_output_points && avg_residual > TOLERANCE {
break;
}
lut.points.remove(min_index);
self.update_piecewise_linear_residual(
&lut,
&mut residual,
min_index - 1,
min_index + 1,
);
}
lut
}
#[must_use]
pub fn get_value(&self, x: f64) -> f64 {
let bin = self.get_bin_index(x);
let bin_i0 = bin.floor() as usize;
let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
let a = bin - bin_i0 as f64;
(1f64 - a).mul_add(
*get_dbg(&self.eqns.x, bin_i0),
a * *get_dbg(&self.eqns.x, bin_i1),
)
}
pub fn clear(&mut self) {
self.eqns.clear();
self.num_equations = 0;
self.total = 0f64;
}
#[must_use]
fn get_bin_index(&self, value: f64) -> f64 {
let max = 255f64;
let val = value.clamp(0f64, max);
(self.num_bins - 1) as f64 * val / max
}
fn update_piecewise_linear_residual(
&self,
lut: &NoiseStrengthLut,
residual: &mut [f64],
start: usize,
end: usize,
) {
let dx = 255f64 / self.num_bins as f64;
#[allow(clippy::needless_range_loop)]
for i in start.max(1)..end.min(lut.points.len() - 1) {
let lower =
0usize.max(self.get_bin_index(get_dbg(&lut.points, i - 1)[0]).floor() as usize);
let upper = (self.num_bins - 1)
.min(self.get_bin_index(get_dbg(&lut.points, i + 1)[0]).ceil() as usize);
let mut r = 0f64;
for j in lower..=upper {
let x = self.get_center(j);
if x < get_dbg(&lut.points, i - 1)[0] || x >= get_dbg(&lut.points, i + 1)[0] {
continue;
}
let y = *get_dbg(&self.eqns.x, j);
let a = (x - get_dbg(&lut.points, i - 1)[0])
/ (get_dbg(&lut.points, i + 1)[0] - get_dbg(&lut.points, i - 1)[0]);
let estimate_y = get_dbg(&lut.points, i - 1)[1]
.mul_add(1f64 - a, get_dbg(&lut.points, i + 1)[1] * a);
r += (y - estimate_y).abs();
}
*get_dbg_mut(residual, i) = r * dx;
}
}
#[must_use]
fn get_center(&self, i: usize) -> f64 {
let range = 255f64;
let n = self.num_bins;
i as f64 / (n - 1) as f64 * range
}
}
impl Add<&StrengthSolver> for StrengthSolver {
type Output = StrengthSolver;
fn add(self, addend: &StrengthSolver) -> Self::Output {
let mut dest = self;
dest.eqns += &addend.eqns;
dest.num_equations += addend.num_equations;
dest.total += addend.total;
dest
}
}
impl AddAssign<&StrengthSolver> for StrengthSolver {
fn add_assign(&mut self, rhs: &StrengthSolver) {
*self = self.clone() + rhs;
}
}