use super::{MotionCompensationMode, PrefilterMode};
/// 255² — numerator of the normalisation computed by `NlmParams::h2_inv_norm`
/// (presumably the squared 8-bit full-scale sample value; confirm against the kernels).
pub(super) const NLM_NORM: f32 = 255.0 * 255.0;
/// Constant divisor applied together with `strength²` and the patch area in
/// `NlmParams::h2_inv_norm`.
/// NOTE(review): the name suggests it reproduces a legacy implementation's scaling — confirm.
pub(super) const NLM_LEGACY: f32 = 3.0;
/// NOTE(review): not referenced in this section; from the name, presumably the size at
/// which a separable kernel path is selected — confirm at the use site.
pub(super) const SEPARABLE_THRESHOLD: u32 = 8;
/// Largest `patch_radius` accepted by `NlmParams::validate`; larger patches exhaust
/// on-chip SMEM in the fused/windowed kernels (per the validation error message).
pub const MAX_PATCH_RADIUS: u32 = 16;
/// Largest `search_radius` accepted by `NlmParams::validate`; it bounds the
/// `(block + 2·patch_radius + 2·search_radius)²` SMEM tile of the windowed kernel.
pub const MAX_SEARCH_RADIUS: u32 = 8;
/// Largest `temporal_radius` accepted by `NlmParams::validate`; it bounds the ring
/// buffer, which grows linearly with the temporal window size.
pub const MAX_TEMPORAL_RADIUS: u32 = 8;
/// Selects which channels the filter processes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelMode {
    /// Luma only (one channel).
    Luma,
    /// Chroma only (two channels).
    Chroma,
    /// Luma plus both chroma channels (three channels).
    Yuv,
}

impl ChannelMode {
    /// Number of channels that carry data in this mode.
    pub fn count(self) -> u32 {
        match self {
            Self::Luma => 1,
            Self::Chroma => 2,
            Self::Yuv => 3,
        }
    }

    /// Number of channel slots reserved per element in storage.
    ///
    /// Identical to [`count`](Self::count) except for `Yuv`, whose three channels
    /// occupy four slots. NOTE(review): presumably padding for 4-wide vector/texel
    /// access — confirm against the buffer layout.
    pub fn storage_count(self) -> u32 {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Luma | Self::Chroma => self.count(),
            Self::Yuv => 4,
        }
    }
}
/// Tunable parameters for the NLM filter; check them with `NlmParams::validate`
/// before use.
#[derive(Debug, Clone)]
pub struct NlmParams {
/// Temporal window radius; the window spans `1 + 2·temporal_radius` frames.
/// Must not exceed `MAX_TEMPORAL_RADIUS`.
pub temporal_radius: u32,
/// Search-window radius. Must not exceed `MAX_SEARCH_RADIUS`.
pub search_radius: u32,
/// Patch radius; a patch covers `(2·patch_radius + 1)²` samples in the
/// normalisation. Must not exceed `MAX_PATCH_RADIUS`.
pub patch_radius: u32,
/// Filter strength; enters the normalisation squared. Must be finite and > 0.
pub strength: f32,
/// Must be finite and >= 0. NOTE(review): presumably the weight given to the
/// reference patch itself — confirm against the kernels.
pub self_weight: f32,
/// Which channels are filtered.
pub channels: ChannelMode,
/// Prefilter mode (type defined in the parent module).
pub prefilter: PrefilterMode,
/// Motion-compensation mode; checked by its own `validate` during `NlmParams::validate`.
pub motion_compensation: MotionCompensationMode,
}
impl Default for NlmParams {
fn default() -> Self {
Self {
temporal_radius: 0,
search_radius: 2,
patch_radius: 4,
strength: 1.2,
self_weight: 1.0,
channels: ChannelMode::Yuv,
prefilter: PrefilterMode::None,
motion_compensation: MotionCompensationMode::None,
}
}
}
impl NlmParams {
    /// Returns `NLM_NORM / (NLM_LEGACY · strength² · S)`, where
    /// `S = (2·patch_radius + 1)²` is the number of samples in a patch.
    ///
    /// This is the normalisation factor that `validate` refers to; it is infinite
    /// when `strength == 0`, so call `validate` first.
    pub fn h2_inv_norm(&self) -> f32 {
        // Compute the patch side/area in f32 rather than u32: `patch_radius` is a
        // public field and this method is public, so an unvalidated huge radius
        // would otherwise overflow (panic in debug, silent wrap-around in release).
        // For every radius whose area fits in u32 the exact integer area is rounded
        // to f32 exactly once either way, so results are bit-identical to the
        // previous integer formulation.
        let side = 2.0 * self.patch_radius as f32 + 1.0;
        let s_size = side * side;
        NLM_NORM / (NLM_LEGACY * self.strength * self.strength * s_size)
    }

    /// Number of frames in the temporal window: `1 + 2·temporal_radius`.
    ///
    /// Assumes `temporal_radius` has been bounded by `validate`
    /// (at most `MAX_TEMPORAL_RADIUS`).
    pub(super) fn total_frames(&self) -> u32 {
        1 + 2 * self.temporal_radius
    }

    /// Checks that every parameter is within the range the kernels support.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `patch_radius`, `search_radius` or `temporal_radius` exceeds its
    ///   `MAX_*` constant;
    /// - `strength` is not finite or is `<= 0`;
    /// - `self_weight` is not finite or is negative;
    /// - `motion_compensation.validate()` fails (its error is propagated as-is).
    pub fn validate(&self) -> Result<(), anyhow::Error> {
        if self.patch_radius > MAX_PATCH_RADIUS {
            anyhow::bail!(
                "patch_radius={} exceeds the supported maximum ({}); larger patches \
                 exhaust on-chip SMEM in the fused/windowed kernels",
                self.patch_radius,
                MAX_PATCH_RADIUS,
            );
        }
        if self.search_radius > MAX_SEARCH_RADIUS {
            anyhow::bail!(
                "search_radius={} exceeds the supported maximum ({}); the windowed \
                 kernel allocates `(block + 2·patch_radius + 2·search_radius)²` of SMEM",
                self.search_radius,
                MAX_SEARCH_RADIUS,
            );
        }
        if self.temporal_radius > MAX_TEMPORAL_RADIUS {
            anyhow::bail!(
                "temporal_radius={} exceeds the supported maximum ({}); the ring \
                 buffer grows linearly with the window size",
                self.temporal_radius,
                MAX_TEMPORAL_RADIUS,
            );
        }
        // Written as a negated conjunction so NaN (for which every comparison is
        // false) is rejected too.
        if !(self.strength.is_finite() && self.strength > 0.0) {
            anyhow::bail!(
                "strength must be finite and > 0 (got {}); strength = 0 produces an \
                 infinite Welsch normalisation factor",
                self.strength,
            );
        }
        // `is_finite()` is false for NaN and ±inf; -0.0 is not `< 0.0` and passes.
        if !self.self_weight.is_finite() || self.self_weight < 0.0 {
            anyhow::bail!("self_weight must be finite and >= 0 (got {})", self.self_weight);
        }
        self.motion_compensation.validate()?;
        Ok(())
    }
}