1use anyhow::{ensure, Result};
2use num_rational::Rational64;
3use v_frame::{frame::Frame, pixel::Pixel};
4
5use self::solver::{FlatBlockFinder, NoiseModel};
6use crate::{util::frame_into_u8, GrainTableSegment};
7
8mod solver;
9
/// Side length of the square analysis blocks. Not referenced in this file's
/// visible code — presumably consumed by the `solver` submodule (confirm).
const BLOCK_SIZE: usize = 32;
/// Number of elements in one `BLOCK_SIZE` x `BLOCK_SIZE` block.
const BLOCK_SIZE_SQUARED: usize = BLOCK_SIZE.pow(2);
12
/// Builds a grain table by feeding pairs of source and denoised frames into a
/// noise model; the table is retrieved with `DiffGenerator::finish`.
pub struct DiffGenerator {
    /// Frame rate used to turn the frame counter into grain-table timestamps.
    fps: Rational64,
    /// Bit depth of the `source` frames handed to `diff_frame`.
    source_bit_depth: usize,
    /// Bit depth of the `denoised` frames handed to `diff_frame`.
    denoised_bit_depth: usize,
    /// Number of frames fed to the noise model so far (incremented once per
    /// frame pair that passes the dimension check).
    frame_count: usize,
    /// Start timestamp of the grain-table segment currently being accumulated.
    prev_timestamp: u64,
    /// Locates flat blocks; it is run on the luma plane of each source frame.
    flat_block_finder: FlatBlockFinder,
    /// Noise model updated with every frame pair; grain parameters are read
    /// from it whenever a segment is closed.
    noise_model: NoiseModel,
    /// Segments emitted so far, in timestamp order.
    grain_table: Vec<GrainTableSegment>,
}
23
24impl DiffGenerator {
25    #[must_use]
26    #[inline]
27    pub fn new(fps: Rational64, source_bit_depth: usize, denoised_bit_depth: usize) -> Self {
28        Self {
29            frame_count: 0,
30            fps,
31            flat_block_finder: FlatBlockFinder::new(),
32            noise_model: NoiseModel::new(),
33            grain_table: Vec::new(),
34            prev_timestamp: 0,
35            source_bit_depth,
36            denoised_bit_depth,
37        }
38    }
39
40    #[inline]
47    pub fn diff_frame<T: Pixel, U: Pixel>(
48        &mut self,
49        source: &Frame<T>,
50        denoised: &Frame<U>,
51    ) -> Result<()> {
52        self.diff_frame_internal(
53            &frame_into_u8(source, self.source_bit_depth),
54            &frame_into_u8(denoised, self.denoised_bit_depth),
55        )
56    }
57
58    #[must_use]
61    #[inline]
62    pub fn finish(mut self) -> Vec<GrainTableSegment> {
63        log::debug!("Updating final parameters");
64        self.grain_table.push(
65            self.noise_model
66                .get_grain_parameters(self.prev_timestamp, i64::MAX as u64),
67        );
68
69        self.grain_table
70    }
71
72    fn diff_frame_internal(&mut self, source: &Frame<u8>, denoised: &Frame<u8>) -> Result<()> {
73        verify_dimensions_match(source, denoised)?;
74
75        let (flat_blocks, num_flat_blocks) = self.flat_block_finder.run(&source.planes[0]);
76        log::debug!("Num flat blocks: {num_flat_blocks}");
77
78        log::debug!("Updating noise model");
79        let status = self.noise_model.update(source, denoised, &flat_blocks);
80
81        if status == NoiseStatus::DifferentType {
82            let cur_timestamp = self.frame_count as u64 * 10_000_000u64 * *self.fps.denom() as u64
83                / *self.fps.numer() as u64;
84            log::debug!(
85                "Updating parameters for times {} to {}",
86                self.prev_timestamp,
87                cur_timestamp
88            );
89            self.grain_table.push(
90                self.noise_model
91                    .get_grain_parameters(self.prev_timestamp, cur_timestamp),
92            );
93            self.noise_model.save_latest();
94            self.prev_timestamp = cur_timestamp;
95        }
96        log::debug!("Noise model updated for frame {}", self.frame_count);
97        self.frame_count += 1;
98
99        Ok(())
100    }
101}
102
/// Outcome of one noise-model update.
///
/// `PartialEq` is implemented manually for this type and compares variants
/// only, since the `Error` payload (`anyhow::Error`) is not comparable.
#[derive(Debug)]
enum NoiseStatus {
    /// Presumably: the update succeeded without a noise-type change.
    /// `diff_frame_internal` takes no action for this variant.
    Ok,
    /// The update reported a different noise type; `diff_frame_internal`
    /// responds by emitting a grain-table segment and saving the latest model.
    DifferentType,
    // `dead_code` is allowed because the payload is carried for diagnostics
    // and may never be read.
    #[allow(dead_code)]
    Error(anyhow::Error),
}
110
111impl PartialEq for NoiseStatus {
112    fn eq(&self, other: &Self) -> bool {
113        match (self, other) {
114            (&Self::Error(_), &Self::Error(_)) => true,
115            _ => core::mem::discriminant(self) == core::mem::discriminant(other),
116        }
117    }
118}
119
120fn verify_dimensions_match(source: &Frame<u8>, denoised: &Frame<u8>) -> Result<()> {
121    let res_1 = (source.planes[0].cfg.width, source.planes[0].cfg.height);
122    let res_2 = (denoised.planes[0].cfg.width, denoised.planes[0].cfg.height);
123    ensure!(
124        res_1 == res_2,
125        "Luma resolutions were not equal, {}x{} != {}x{}",
126        res_1.0,
127        res_1.1,
128        res_2.0,
129        res_2.1
130    );
131
132    let res_1 = (source.planes[1].cfg.width, source.planes[1].cfg.height);
133    let res_2 = (denoised.planes[1].cfg.width, denoised.planes[1].cfg.height);
134    ensure!(
135        res_1 == res_2,
136        "Chroma resolutions were not equal, {}x{} != {}x{}",
137        res_1.0,
138        res_1.1,
139        res_2.0,
140        res_2.1
141    );
142
143    Ok(())
144}