#![forbid(unsafe_code)]
#![allow(dead_code)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::similar_names)]
use super::block::BlockSize;
use super::cdef::{CdefParams, CdefStrength};
use super::loop_filter::LoopFilterParams;
/// Upper bound for loop-filter levels; the QP heuristic and every RD candidate
/// are clamped to `0..=MAX_LOOP_FILTER_LEVEL`.
const MAX_LOOP_FILTER_LEVEL: u8 = 63;
/// Largest CDEF strength value. Not referenced by the optimizers below; the unit
/// tests use it as an upper bound for the chosen strengths.
const MAX_CDEF_STRENGTH: u8 = 15;
/// Number of loop-filter level candidates examined around the QP-derived level.
/// Deltas run over `-(N / 2)..=(N / 2)`, so an even `N` still yields `N + 1` candidates.
const FILTER_LEVELS_TO_TEST: usize = 5;
/// CDEF primary and secondary strengths are each swept over `0..CDEF_STRENGTHS_TO_TEST`.
const CDEF_STRENGTHS_TO_TEST: usize = 4;
/// Chooses a loop-filter level for a frame with a small rate–distortion search
/// around a QP-derived starting level.
///
/// The chosen level is written to all four entries of `LoopFilterParams::level`
/// and can be read back with [`LoopFilterOptimizer::params`].
#[derive(Clone, Debug)]
pub struct LoopFilterOptimizer {
    /// Parameters holding the most recently chosen level.
    params: LoopFilterParams,
    /// Lagrange multiplier weighting the rate term against distortion.
    lambda: f32,
    /// When `false`, the QP heuristic level is used directly without a search.
    rd_optimization: bool,
}

impl LoopFilterOptimizer {
    /// Spacing between neighbouring candidate levels in the RD search.
    const LEVEL_STEP: i32 = 4;

    /// Creates an optimizer with default parameters, the given `lambda`, and
    /// RD search enabled.
    #[must_use]
    pub fn new(lambda: f32) -> Self {
        Self {
            params: LoopFilterParams::default(),
            lambda,
            rd_optimization: true,
        }
    }

    /// Picks a loop-filter level for the frame, stores it in the parameters and
    /// returns it.
    ///
    /// With RD optimization disabled the QP heuristic level (`qp * 3 / 2`,
    /// capped at `MAX_LOOP_FILTER_LEVEL`) is used. Otherwise
    /// `FILTER_LEVELS_TO_TEST` candidates centred on that level, spaced
    /// `LEVEL_STEP` apart and clamped to `0..=MAX_LOOP_FILTER_LEVEL`, are scored.
    /// The cheapest candidate wins; on a tie the earlier (lower-delta) one is kept.
    ///
    /// In both modes every entry of `params().level` is set to the returned level.
    ///
    /// NOTE(review): no filter is actually applied to `recon`, so the distortion
    /// term is identical for every candidate. The search therefore reduces to
    /// minimising the rate proxy `lambda * 0.1 * level`.
    pub fn optimize_filter_level(
        &mut self,
        src: &[u8],
        recon: &[u8],
        width: usize,
        height: usize,
        qp: u8,
    ) -> u8 {
        let base_level = self.filter_level_from_qp(qp);
        let chosen = if self.rd_optimization {
            self.search_filter_level(src, recon, width, height, base_level)
        } else {
            base_level
        };
        // Record the decision on both paths so `params()` never reports a stale level.
        self.params.level = [chosen; 4];
        chosen
    }

    /// Scores the candidate levels around `base_level` and returns the cheapest.
    fn search_filter_level(
        &self,
        src: &[u8],
        recon: &[u8],
        width: usize,
        height: usize,
        base_level: u8,
    ) -> u8 {
        // The distortion does not depend on the candidate level, so compute it once
        // instead of once per candidate.
        let distortion = self.compute_distortion(src, recon, width, height);
        let half_span = FILTER_LEVELS_TO_TEST as i32 / 2;
        let max_level = i32::from(MAX_LOOP_FILTER_LEVEL);

        let mut best_level = base_level;
        let mut best_cost = f32::MAX;
        for delta in -half_span..=half_span {
            let level =
                (i32::from(base_level) + delta * Self::LEVEL_STEP).clamp(0, max_level) as u8;
            let cost = self.rd_cost(distortion, level);
            if cost < best_cost {
                best_cost = cost;
                best_level = level;
            }
        }
        best_level
    }

    /// Heuristic starting level: `qp * 3 / 2`, capped at `MAX_LOOP_FILTER_LEVEL`.
    fn filter_level_from_qp(&self, qp: u8) -> u8 {
        // 255 * 3 / 2 = 382 fits comfortably in i32; the result is never negative.
        ((i32::from(qp) * 3) / 2).min(i32::from(MAX_LOOP_FILTER_LEVEL)) as u8
    }

    /// RD cost of `level` for this frame: `SSE(src, recon) + lambda * 0.1 * level`.
    fn evaluate_filter_level(
        &self,
        src: &[u8],
        recon: &[u8],
        width: usize,
        height: usize,
        level: u8,
    ) -> f32 {
        self.rd_cost(self.compute_distortion(src, recon, width, height), level)
    }

    /// RD cost of `level` given an already computed distortion.
    fn rd_cost(&self, distortion: f32, level: u8) -> f32 {
        let rate = f32::from(level) * 0.1;
        distortion + self.lambda * rate
    }

    /// Sum of squared differences over the first `width * height` samples.
    ///
    /// The count is also limited by the shorter of `src` and `recon`.
    /// `saturating_mul` prevents an overflow panic for absurd dimensions; the
    /// slice lengths bound the work anyway.
    fn compute_distortion(&self, src: &[u8], recon: &[u8], width: usize, height: usize) -> f32 {
        let sample_count = width.saturating_mul(height);
        let sse: u64 = src
            .iter()
            .zip(recon)
            .take(sample_count)
            .map(|(&s, &r)| {
                let diff = (i32::from(s) - i32::from(r)).unsigned_abs();
                u64::from(diff * diff)
            })
            .sum();
        sse as f32
    }

    /// Parameters holding the most recently chosen level.
    #[must_use]
    pub const fn params(&self) -> &LoopFilterParams {
        &self.params
    }

    /// Replaces the Lagrange multiplier used by the RD search.
    pub fn set_lambda(&mut self, lambda: f32) {
        self.lambda = lambda;
    }

    /// Enables or disables the RD search (disabled = QP heuristic only).
    pub fn set_rd_optimization(&mut self, enabled: bool) {
        self.rd_optimization = enabled;
    }
}

impl Default for LoopFilterOptimizer {
    /// Equivalent to `LoopFilterOptimizer::new(1.0)`.
    fn default() -> Self {
        Self::new(1.0)
    }
}
/// Chooses CDEF primary/secondary strengths with an exhaustive RD search over
/// `0..CDEF_STRENGTHS_TO_TEST` for each component.
#[derive(Clone, Debug)]
pub struct CdefOptimizer {
    /// CDEF parameters exposed through [`CdefOptimizer::params`].
    params: CdefParams,
    /// Lagrange multiplier weighting the rate term against distortion.
    lambda: f32,
}

impl CdefOptimizer {
    /// Creates an optimizer with default CDEF parameters and the given `lambda`.
    #[must_use]
    pub fn new(lambda: f32) -> Self {
        Self {
            params: CdefParams::default(),
            lambda,
        }
    }

    /// Returns the cheapest `(primary, secondary)` strength pair.
    ///
    /// Candidates are visited primary-major; on equal cost the earlier pair is
    /// kept. The cost is `SSE(src, recon) + lambda * 0.5 * (primary + secondary)`.
    ///
    /// NOTE(review): no CDEF filtering is applied, so the distortion is identical
    /// for every candidate. For a non-negative `lambda` the winner is therefore
    /// always `(0, 0)`. The result is also not written into `self.params`, so
    /// `params()` keeps returning the defaults.
    pub fn optimize_strength(
        &mut self,
        src: &[u8],
        recon: &[u8],
        width: usize,
        height: usize,
        _block_size: BlockSize,
    ) -> CdefStrength {
        // Distortion does not depend on the candidate strength; compute it once
        // instead of once per (primary, secondary) pair.
        let distortion = Self::sse_distortion(src, recon, width, height);

        let mut best_strength = CdefStrength::default();
        let mut best_cost = f32::MAX;
        for primary in 0..CDEF_STRENGTHS_TO_TEST {
            for secondary in 0..CDEF_STRENGTHS_TO_TEST {
                let strength = CdefStrength {
                    primary: primary as u8,
                    secondary: secondary as u8,
                };
                let cost = self.strength_cost(distortion, &strength);
                if cost < best_cost {
                    best_cost = cost;
                    best_strength = strength;
                }
            }
        }
        best_strength
    }

    /// RD cost of `strength` for this block.
    fn evaluate_cdef_strength(
        &self,
        src: &[u8],
        recon: &[u8],
        width: usize,
        height: usize,
        strength: &CdefStrength,
    ) -> f32 {
        self.strength_cost(Self::sse_distortion(src, recon, width, height), strength)
    }

    /// RD cost of `strength` given an already computed distortion.
    fn strength_cost(&self, distortion: f32, strength: &CdefStrength) -> f32 {
        // Sum in f32: adding the two u8 strengths directly can overflow.
        let rate = (f32::from(strength.primary) + f32::from(strength.secondary)) * 0.5;
        distortion + self.lambda * rate
    }

    /// Sum of squared differences over the first `width * height` samples.
    ///
    /// The count is also limited by the shorter of `src` and `recon`.
    /// `saturating_mul` prevents an overflow panic for absurd dimensions.
    fn sse_distortion(src: &[u8], recon: &[u8], width: usize, height: usize) -> f32 {
        let sample_count = width.saturating_mul(height);
        let sse: u64 = src
            .iter()
            .zip(recon)
            .take(sample_count)
            .map(|(&s, &r)| {
                let diff = (i32::from(s) - i32::from(r)).unsigned_abs();
                u64::from(diff * diff)
            })
            .sum();
        sse as f32
    }

    /// Current CDEF parameters (see the note on [`Self::optimize_strength`]).
    #[must_use]
    pub const fn params(&self) -> &CdefParams {
        &self.params
    }
}

impl Default for CdefOptimizer {
    /// Equivalent to `CdefOptimizer::new(1.0)`.
    fn default() -> Self {
        Self::new(1.0)
    }
}
/// Restoration filter choices considered by [`RestorationOptimizer`].
///
/// The explicit discriminants allow conversion with `as u8`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RestorationType {
    /// No restoration filter.
    None = 0,
    /// Wiener restoration filter.
    Wiener = 1,
    /// Sgrproj restoration filter.
    Sgrproj = 2,
}

/// Chooses a frame-level restoration type with a simple RD model.
#[derive(Clone, Debug)]
pub struct RestorationOptimizer {
    /// Most recently chosen restoration type (`None` until the first search).
    restoration_type: RestorationType,
    /// Lagrange multiplier weighting the rate term against distortion.
    lambda: f32,
}

impl RestorationOptimizer {
    /// Candidates in evaluation order; on equal cost the earlier one wins.
    const CANDIDATES: [RestorationType; 3] = [
        RestorationType::None,
        RestorationType::Wiener,
        RestorationType::Sgrproj,
    ];

    /// Creates an optimizer with the given `lambda` and no restoration selected.
    #[must_use]
    pub fn new(lambda: f32) -> Self {
        Self {
            restoration_type: RestorationType::None,
            lambda,
        }
    }

    /// Scores every restoration type, stores the cheapest and returns it.
    ///
    /// With `D = SSE(src, recon)` the model costs are:
    /// `None = D`, `Wiener = 0.95 * D + 100 * lambda`, and
    /// `Sgrproj = 0.97 * D + 80 * lambda`.
    /// The distortion reductions and rates are fixed estimates, not measured by
    /// running the filters.
    pub fn optimize_restoration(
        &mut self,
        src: &[u8],
        recon: &[u8],
        width: usize,
        height: usize,
    ) -> RestorationType {
        // The base distortion is the same for every candidate; compute it once.
        let base_distortion = self.compute_distortion(src, recon, width, height);

        let mut best_type = RestorationType::None;
        let mut best_cost = f32::MAX;
        for rtype in Self::CANDIDATES {
            let cost = self.restoration_cost(base_distortion, rtype);
            if cost < best_cost {
                best_cost = cost;
                best_type = rtype;
            }
        }
        self.restoration_type = best_type;
        best_type
    }

    /// RD cost of applying `rtype` to this frame.
    fn evaluate_restoration(
        &self,
        src: &[u8],
        recon: &[u8],
        width: usize,
        height: usize,
        rtype: RestorationType,
    ) -> f32 {
        self.restoration_cost(self.compute_distortion(src, recon, width, height), rtype)
    }

    /// RD cost of `rtype` given the unrestored distortion `base_distortion`.
    fn restoration_cost(&self, base_distortion: f32, rtype: RestorationType) -> f32 {
        let (rate, distortion_reduction): (f32, f32) = match rtype {
            RestorationType::None => (0.0, 0.0),
            RestorationType::Wiener => (100.0, base_distortion * 0.05),
            RestorationType::Sgrproj => (80.0, base_distortion * 0.03),
        };
        (base_distortion - distortion_reduction) + self.lambda * rate
    }

    /// Sum of squared differences over the first `width * height` samples.
    ///
    /// The count is also limited by the shorter of `src` and `recon`.
    /// `saturating_mul` prevents an overflow panic for absurd dimensions.
    fn compute_distortion(&self, src: &[u8], recon: &[u8], width: usize, height: usize) -> f32 {
        let sample_count = width.saturating_mul(height);
        let sse: u64 = src
            .iter()
            .zip(recon)
            .take(sample_count)
            .map(|(&s, &r)| {
                let diff = (i32::from(s) - i32::from(r)).unsigned_abs();
                u64::from(diff * diff)
            })
            .sum();
        sse as f32
    }

    /// Restoration type chosen by the most recent search.
    #[must_use]
    pub const fn restoration_type(&self) -> RestorationType {
        self.restoration_type
    }
}

impl Default for RestorationOptimizer {
    /// Equivalent to `RestorationOptimizer::new(1.0)`.
    fn default() -> Self {
        Self::new(1.0)
    }
}
/// Film-grain parameters. A fresh value (from `new` or `Default`) has grain
/// disabled, seed 0 and no scaling points.
#[derive(Clone, Debug, Default)]
pub struct FilmGrainParams {
    /// Whether film grain is switched on.
    pub enabled: bool,
    /// Seed recorded by [`FilmGrainParams::enable`].
    pub grain_seed: u16,
    /// Luma scaling points as `(u8, u8)` pairs — presumably (value, scale); confirm.
    pub luma_points: Vec<(u8, u8)>,
    /// Chroma scaling points, same pair layout as `luma_points`.
    pub chroma_points: Vec<(u8, u8)>,
}

impl FilmGrainParams {
    /// Builds a disabled configuration with seed 0 and empty point lists.
    ///
    /// Usable in `const` contexts, unlike `Default::default`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            luma_points: Vec::new(),
            chroma_points: Vec::new(),
            grain_seed: 0,
            enabled: false,
        }
    }

    /// Switches grain on with the given `seed`; the scaling points are untouched.
    pub fn enable(&mut self, seed: u16) {
        self.grain_seed = seed;
        self.enabled = true;
    }

    /// Switches grain off; the stored seed and scaling points are kept as-is.
    pub fn disable(&mut self) {
        self.enabled = false;
    }
}
/// Frame-level driver that runs the loop-filter, CDEF and restoration
/// optimizers with a shared Lagrange multiplier.
#[derive(Clone, Debug)]
pub struct LoopOptimizer {
    /// Loop-filter level search.
    loop_filter: LoopFilterOptimizer,
    /// CDEF strength search.
    cdef: CdefOptimizer,
    /// Restoration-type selection.
    restoration: RestorationOptimizer,
    /// Film-grain parameters (only exposed read-only by this type).
    film_grain: FilmGrainParams,
}

impl LoopOptimizer {
    /// Side length of the square block used for the CDEF strength search.
    const CDEF_BLOCK_DIM: usize = 64;

    /// Creates all sub-optimizers with the same `lambda`; film grain starts disabled.
    #[must_use]
    pub fn new(lambda: f32) -> Self {
        Self {
            loop_filter: LoopFilterOptimizer::new(lambda),
            cdef: CdefOptimizer::new(lambda),
            restoration: RestorationOptimizer::new(lambda),
            film_grain: FilmGrainParams::new(),
        }
    }

    /// Runs the loop-filter, CDEF and restoration optimizers on one frame.
    ///
    /// Assumes `src` and `recon` are packed row-major planes of `width` x `height`
    /// samples with a row stride of `width` — TODO confirm against callers.
    /// The loop filter and restoration optimizers see the whole frame. CDEF sees
    /// only the top-left 64x64 region, clipped to the frame size.
    ///
    /// NOTE(review): the CDEF strength returned by the search is discarded, so it
    /// is not reflected in [`Self::cdef_params`].
    pub fn optimize_frame(
        &mut self,
        src: &[u8],
        recon: &[u8],
        width: usize,
        height: usize,
        qp: u8,
    ) {
        self.loop_filter
            .optimize_filter_level(src, recon, width, height, qp);

        let cdef_width = width.min(Self::CDEF_BLOCK_DIM);
        let cdef_height = height.min(Self::CDEF_BLOCK_DIM);
        // Gather the block row by row. Passing the first `cdef_width * cdef_height`
        // bytes of the frame would, for frames wider than the block, cover whole
        // frame rows instead of a 2-D block.
        let src_block = Self::top_left_block(src, width, cdef_width, cdef_height);
        let recon_block = Self::top_left_block(recon, width, cdef_width, cdef_height);
        self.cdef.optimize_strength(
            &src_block,
            &recon_block,
            cdef_width,
            cdef_height,
            BlockSize::Block64x64,
        );

        self.restoration
            .optimize_restoration(src, recon, width, height);
    }

    /// Copies the top-left `block_width` x `block_height` region of a packed
    /// plane whose rows are `stride` samples apart.
    ///
    /// Rows that fall (partly) outside `plane` are truncated, so a short buffer
    /// yields fewer samples instead of panicking.
    fn top_left_block(
        plane: &[u8],
        stride: usize,
        block_width: usize,
        block_height: usize,
    ) -> Vec<u8> {
        let mut block = Vec::with_capacity(block_width * block_height);
        // `max(1)` keeps `chunks` from panicking on a zero-width frame.
        for row in plane.chunks(stride.max(1)).take(block_height) {
            block.extend_from_slice(&row[..block_width.min(row.len())]);
        }
        block
    }

    /// Loop-filter parameters chosen by the most recent frame optimization.
    #[must_use]
    pub const fn loop_filter_params(&self) -> &LoopFilterParams {
        self.loop_filter.params()
    }

    /// Parameters held by the CDEF optimizer.
    #[must_use]
    pub const fn cdef_params(&self) -> &CdefParams {
        self.cdef.params()
    }

    /// Restoration type chosen by the most recent frame optimization.
    #[must_use]
    pub const fn restoration_type(&self) -> RestorationType {
        self.restoration.restoration_type()
    }

    /// Current film-grain parameters.
    #[must_use]
    pub const fn film_grain_params(&self) -> &FilmGrainParams {
        &self.film_grain
    }

    /// Propagates `lambda` to the loop-filter, CDEF and restoration optimizers.
    pub fn set_lambda(&mut self, lambda: f32) {
        self.loop_filter.set_lambda(lambda);
        self.cdef.lambda = lambda;
        self.restoration.lambda = lambda;
    }
}

impl Default for LoopOptimizer {
    /// Equivalent to `LoopOptimizer::new(1.0)`.
    fn default() -> Self {
        Self::new(1.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Loop filter ----------------------------------------------------

    #[test]
    fn test_loop_filter_optimizer_creation() {
        let opt = LoopFilterOptimizer::new(1.0);
        assert_eq!(opt.lambda, 1.0);
        assert!(opt.rd_optimization);
    }

    #[test]
    fn test_filter_level_from_qp() {
        let opt = LoopFilterOptimizer::new(1.0);
        let level_low = opt.filter_level_from_qp(10);
        let level_high = opt.filter_level_from_qp(50);
        assert!(level_low < level_high);
        assert!(level_low <= MAX_LOOP_FILTER_LEVEL);
        assert!(level_high <= MAX_LOOP_FILTER_LEVEL);
    }

    #[test]
    fn test_optimize_filter_level_fast() {
        let mut opt = LoopFilterOptimizer::new(1.0);
        opt.set_rd_optimization(false);
        let src = vec![128u8; 64 * 64];
        let recon = vec![128u8; 64 * 64];
        let level = opt.optimize_filter_level(&src, &recon, 64, 64, 28);
        assert!(level <= MAX_LOOP_FILTER_LEVEL);
    }

    #[test]
    fn test_compute_distortion() {
        let opt = LoopFilterOptimizer::new(1.0);
        let src = vec![100u8; 64];
        let recon = vec![100u8; 64];
        let distortion = opt.compute_distortion(&src, &recon, 8, 8);
        assert_eq!(distortion, 0.0);
        // 64 samples, each off by 10, so SSE = 64 * 100.
        let recon2 = vec![110u8; 64];
        let distortion2 = opt.compute_distortion(&src, &recon2, 8, 8);
        assert_eq!(distortion2, 6400.0);
    }

    // ---- CDEF -----------------------------------------------------------

    #[test]
    fn test_cdef_optimizer() {
        let opt = CdefOptimizer::new(1.0);
        assert_eq!(opt.lambda, 1.0);
    }

    #[test]
    fn test_cdef_optimize_strength() {
        let mut opt = CdefOptimizer::new(1.0);
        let src = vec![128u8; 32 * 32];
        let recon = vec![130u8; 32 * 32];
        let strength = opt.optimize_strength(&src, &recon, 32, 32, BlockSize::Block32x32);
        assert!(strength.primary <= MAX_CDEF_STRENGTH);
        assert!(strength.secondary <= MAX_CDEF_STRENGTH);
    }

    // ---- Restoration ----------------------------------------------------

    #[test]
    fn test_restoration_optimizer() {
        let opt = RestorationOptimizer::new(1.0);
        assert_eq!(opt.restoration_type, RestorationType::None);
    }

    #[test]
    fn test_restoration_optimize() {
        let mut opt = RestorationOptimizer::new(1.0);
        let src = vec![128u8; 64 * 64];
        let recon = vec![130u8; 64 * 64];
        // SSE = 4096 * 4 = 16384. Model costs: None = 16384,
        // Wiener = 0.95 * 16384 + 100 (about 15664.8),
        // Sgrproj = 0.97 * 16384 + 80 (about 15972.5). Wiener is cheapest.
        let rtype = opt.optimize_restoration(&src, &recon, 64, 64);
        assert_eq!(rtype, RestorationType::Wiener);
        assert_eq!(opt.restoration_type(), RestorationType::Wiener);
    }

    #[test]
    fn test_restoration_types() {
        assert_eq!(RestorationType::None as u8, 0);
        assert_eq!(RestorationType::Wiener as u8, 1);
        assert_eq!(RestorationType::Sgrproj as u8, 2);
    }

    // ---- Film grain -----------------------------------------------------

    #[test]
    fn test_film_grain_params() {
        let mut params = FilmGrainParams::new();
        assert!(!params.enabled);
        params.enable(1234);
        assert!(params.enabled);
        assert_eq!(params.grain_seed, 1234);
        params.disable();
        assert!(!params.enabled);
    }

    // ---- Combined optimizer ---------------------------------------------

    #[test]
    fn test_combined_optimizer() {
        let opt = LoopOptimizer::new(1.5);
        assert_eq!(opt.loop_filter.lambda, 1.5);
        assert_eq!(opt.cdef.lambda, 1.5);
        assert_eq!(opt.restoration.lambda, 1.5);
    }

    #[test]
    fn test_combined_optimize_frame() {
        let mut opt = LoopOptimizer::new(1.0);
        let src = vec![128u8; 128 * 128];
        let recon = vec![128u8; 128 * 128];
        opt.optimize_frame(&src, &recon, 128, 128, 28);
        let lf_params = opt.loop_filter_params();
        assert!(lf_params.level[0] <= MAX_LOOP_FILTER_LEVEL);
    }

    #[test]
    fn test_set_lambda() {
        let mut opt = LoopOptimizer::new(1.0);
        opt.set_lambda(2.5);
        assert_eq!(opt.loop_filter.lambda, 2.5);
        assert_eq!(opt.cdef.lambda, 2.5);
        assert_eq!(opt.restoration.lambda, 2.5);
    }

    // ---- Constants ------------------------------------------------------

    #[test]
    fn test_constants() {
        assert_eq!(MAX_LOOP_FILTER_LEVEL, 63);
        assert_eq!(MAX_CDEF_STRENGTH, 15);
        assert!(FILTER_LEVELS_TO_TEST > 0);
        assert!(CDEF_STRENGTHS_TO_TEST > 0);
    }
}