adder_codec_rs/utils/cv.rs

1use crate::transcoder::source::video::SourceError;
2#[cfg(feature = "open-cv")]
3use {
4    adder_codec_core::PixelAddress,
5    opencv::prelude::KeyPointTraitConst,
6    std::collections::HashSet,
7};
8use adder_codec_core::{Coord, PlaneSize};
9use const_for::const_for;
10use ndarray::{Array3, ArrayView, Axis, Ix2};
11
12
13use serde::{Deserialize, Serialize};
14
15
16
17use std::error::Error;
18use video_rs_adder_dep::Frame;
19
// TODO: Explore optimal threshold values
/// Minimum intensity difference (in `u8` steps) between a circle pixel and the candidate pixel
/// for the circle pixel to count as brighter or darker during FAST feature detection.
pub const INTENSITY_THRESHOLD: i16 = 30;

/// The 16 `[dx, dy]` offsets of the radius-3 sampling circle used by the asynchronous FAST 9_16
/// detector, listed in order around the circle.
#[rustfmt::skip]
const CIRCLE3: [[isize; 2]; 16] = [
    [0, 3], [1, 3], [2, 2], [3, 1], [3, 0], [3, -1], [2, -2], [1, -3],
    [0, -3], [-1, -3], [-2, -2], [-3, -1], [-3, 0], [-3, 1], [-2, 2], [-1, 3],
];

/// Number of contiguous circle pixels that must all be darker (or all brighter) than the
/// candidate for it to be reported as a feature.
const STREAK_SIZE: usize = 9;
34
35const fn threshold_table() -> [u8; 512] {
36    let mut table = [0; 512];
37    const_for!(i in -255..256 => {
38        table[(i + 255) as usize] = if i < -INTENSITY_THRESHOLD {
39            1
40        } else if i > INTENSITY_THRESHOLD {
41            2
42        } else {
43            0
44        };
45    });
46
47    table
48}
49
50const THRESHOLD_TABLE: [u8; 512] = threshold_table();
51
52/// Check if the given event is a feature
53///
54/// This implementation is a direct port/adaptation of the OpenCV reference implementation at
55/// https://github.com/opencv/opencv_attic/blob/master/opencv/modules/features2d/src/fast.cpp
56pub fn is_feature(
57    coord: Coord,
58    plane: PlaneSize,
59    img: &Array3<u8>,
60) -> Result<bool, Box<dyn Error>> {
61    if coord.is_border(plane.w_usize(), plane.h_usize(), 3) || coord.c_usize() != 0 {
62        return Ok(false);
63    }
64    unsafe {
65        let candidate: i16 = *img.uget((coord.y_usize(), coord.x_usize(), 0)) as i16;
66
67        let offset = -candidate as isize + 255;
68        let tab = THRESHOLD_TABLE.as_ptr().offset(offset);
69        debug_assert!(
70            (-candidate < INTENSITY_THRESHOLD && *tab == 1)
71                || (-candidate > INTENSITY_THRESHOLD && *tab == 2)
72                || (-candidate >= -INTENSITY_THRESHOLD
73                    && -candidate <= INTENSITY_THRESHOLD
74                    && *tab == 0)
75        );
76        // const uchar* tab = &threshold_tab[0] - v + 2
77        let c = plane.c_usize() as isize;
78        let width = plane.w() as isize * c;
79        // Get a raw pointer to the intensities
80        let ptr = img.as_ptr();
81
82        let y = coord.y as isize;
83        let x = coord.x as isize;
84        debug_assert_eq!(candidate, *ptr.offset(y * width + x * c) as i16);
85
86        let mut d = *tab
87            .offset(*ptr.offset((y + CIRCLE3[0][1]) * width + (x + CIRCLE3[0][0]) * c) as isize)
88            | *tab.offset(
89                *ptr.offset((y + CIRCLE3[8][1]) * width + (x + CIRCLE3[8][0]) * c) as isize,
90            );
91
92        // If both check pixels are within the intensity threshold range, it's not a feature
93        if d == 0 {
94            return Ok(false);
95        }
96
97        // Check other pixels that are on opposite sides of the circle
98        d &= *tab
99            .offset(*ptr.offset((y + CIRCLE3[2][1]) * width + (x + CIRCLE3[2][0]) * c) as isize)
100            | *tab.offset(
101                *ptr.offset((y + CIRCLE3[10][1]) * width + (x + CIRCLE3[10][0]) * c) as isize,
102            );
103        d &= *tab
104            .offset(*ptr.offset((y + CIRCLE3[4][1]) * width + (x + CIRCLE3[4][0]) * c) as isize)
105            | *tab.offset(
106                *ptr.offset((y + CIRCLE3[12][1]) * width + (x + CIRCLE3[12][0]) * c) as isize,
107            );
108        d &= *tab
109            .offset(*ptr.offset((y + CIRCLE3[6][1]) * width + (x + CIRCLE3[6][0]) * c) as isize)
110            | *tab.offset(
111                *ptr.offset((y + CIRCLE3[14][1]) * width + (x + CIRCLE3[14][0]) * c) as isize,
112            );
113
114        // Not a feature
115        if d == 0 {
116            return Ok(false);
117        }
118
119        // Check other pixels that are on opposite sides of the circle
120        d &= *tab
121            .offset(*ptr.offset((y + CIRCLE3[1][1]) * width + (x + CIRCLE3[1][0]) * c) as isize)
122            | *tab.offset(
123                *ptr.offset((y + CIRCLE3[9][1]) * width + (x + CIRCLE3[9][0]) * c) as isize,
124            );
125        d &= *tab
126            .offset(*ptr.offset((y + CIRCLE3[3][1]) * width + (x + CIRCLE3[3][0]) * c) as isize)
127            | *tab.offset(
128                *ptr.offset((y + CIRCLE3[11][1]) * width + (x + CIRCLE3[11][0]) * c) as isize,
129            );
130        d &= *tab
131            .offset(*ptr.offset((y + CIRCLE3[5][1]) * width + (x + CIRCLE3[5][0]) * c) as isize)
132            | *tab.offset(
133                *ptr.offset((y + CIRCLE3[13][1]) * width + (x + CIRCLE3[13][0]) * c) as isize,
134            );
135        d &= *tab
136            .offset(*ptr.offset((y + CIRCLE3[7][1]) * width + (x + CIRCLE3[7][0]) * c) as isize)
137            | *tab.offset(
138                *ptr.offset((y + CIRCLE3[15][1]) * width + (x + CIRCLE3[15][0]) * c) as isize,
139            );
140
141        if d & 1 > 0 {
142            // It's a dark streak
143            let vt = candidate - INTENSITY_THRESHOLD;
144            let mut count = 0;
145
146            for [a, b] in CIRCLE3 {
147                let x = *ptr.offset((y + b) * width + (x + a) * c) as i16;
148                if x < vt {
149                    count += 1;
150                    if count == STREAK_SIZE {
151                        return Ok(true);
152                    }
153                } else {
154                    count = 0;
155                }
156            }
157            for k in 16..25 {
158                let x = *ptr.offset((y + CIRCLE3[k - 16][1]) * width + (x + CIRCLE3[k - 16][0]) * c)
159                    as i16;
160                if x < vt {
161                    count += 1;
162                    if count == STREAK_SIZE {
163                        return Ok(true);
164                    }
165                } else {
166                    count = 0;
167
168                    // Then we don't need to check the rest of the circle; can't be a streak long enough
169                    if k == 17 {
170                        return Ok(false);
171                    }
172                }
173            }
174        }
175
176        if d & 2 > 0 {
177            // It's a bright streak
178            let vt = candidate + INTENSITY_THRESHOLD;
179            let mut count = 0;
180            for [a, b] in CIRCLE3 {
181                let x = *ptr.offset((y + b) * width + (x + a) * c) as i16;
182                if x > vt {
183                    count += 1;
184                    if count == STREAK_SIZE {
185                        return Ok(true);
186                    }
187                } else {
188                    count = 0;
189                }
190            }
191            for k in 16..25 {
192                let x = *ptr.offset((y + CIRCLE3[k - 16][1]) * width + (x + CIRCLE3[k - 16][0]) * c)
193                    as i16;
194                if x > vt {
195                    count += 1;
196                    if count == STREAK_SIZE {
197                        return Ok(true);
198                    }
199                } else {
200                    count = 0;
201
202                    // Then we don't need to check the rest of the circle; can't be a streak long enough
203                    if k == 17 {
204                        return Ok(false);
205                    }
206                }
207            }
208        }
209    }
210
211    Ok(false)
212}
213
214/// If the input is a color image and we want a gray image, convert it to grayscale
215pub fn handle_color(mut input: Frame, color: bool) -> Result<Frame, SourceError> {
216    if !color {
217        // Map the three color channels to a single grayscale channel
218        input
219            .exact_chunks_mut((1, 1, 3))
220            .into_iter()
221            .for_each(|mut v| unsafe {
222                *v.uget_mut((0, 0, 0)) = (*v.uget((0, 0, 0)) as f64 * 0.114
223                    + *v.uget((0, 0, 1)) as f64 * 0.587
224                    + *v.uget((0, 0, 2)) as f64 * 0.299)
225                    as u8;
226            });
227
228        // Remove the color channels
229        input.collapse_axis(Axis(2), 0);
230    }
231    Ok(input)
232}
233
/// Score predicted feature coordinates against ground-truth OpenCV keypoints.
///
/// Every pixel of `plane` is classified as a true positive, false positive, true negative, or
/// false negative, depending on whether it appears in `prediction` and/or in `gt`. Keypoint
/// positions are truncated to integer pixel coordinates. They are tagged with the channel of the
/// first coordinate yielded by iterating `prediction` (or `None` if it is empty), and the same
/// channel is used for every lookup, so predictions and ground truth are compared consistently.
///
/// Returns `(precision, recall, accuracy)`. Precision is NaN when no predicted coordinate lies in
/// the plane, recall is NaN when no ground-truth keypoint does, and accuracy is NaN for an empty
/// plane.
#[cfg(feature = "open-cv")]
pub fn feature_precision_recall_accuracy(
    gt: &opencv::core::Vector<opencv::core::KeyPoint>,
    prediction: &HashSet<Coord>,
    plane: PlaneSize,
) -> (f64, f64, f64) {
    // u64 so very large planes cannot overflow the counters
    let (mut tp, mut fp, mut tn, mut fnn) = (0u64, 0u64, 0u64, 0u64);

    // Channel of first pred event:
    let channel = prediction.iter().next().and_then(|coord| coord.c);

    // convert the keypoints vec to a hashset for convenience
    let mut gt_hash = HashSet::<Coord>::new();
    for keypoint in gt {
        gt_hash.insert(Coord::new(
            keypoint.pt().x as PixelAddress,
            keypoint.pt().y as PixelAddress,
            channel,
        ));
    }

    for y in 0..plane.h() {
        for x in 0..plane.w() {
            // Must use the same channel the ground truth was tagged with; using `None` here
            // would never match coordinates that carry `Some(c)`.
            let coord = Coord::new(x, y, channel);
            match (prediction.contains(&coord), gt_hash.contains(&coord)) {
                (true, true) => tp += 1,
                (true, false) => fp += 1,
                (false, true) => fnn += 1,
                (false, false) => tn += 1,
            }
        }
    }

    let precision = (tp as f64) / ((tp + fp) as f64);
    let recall = (tp as f64) / ((tp + fnn) as f64);
    let accuracy = ((tp + tn) as f64) / ((tp + tn + fp + fnn) as f64);
    (precision, recall, accuracy)
}
280
281/// Container for quality metric results
282#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
283pub struct QualityMetrics {
284    /// Peak signal-to-noise ratio
285    pub psnr: Option<f64>,
286
287    /// Mean squared error
288    pub mse: Option<f64>,
289
290    /// Structural similarity index measure
291    pub ssim: Option<f64>,
292}
293
294impl Default for QualityMetrics {
295    fn default() -> Self {
296        Self {
297            psnr: Some(0.0),
298            mse: Some(0.0),
299            ssim: None,
300        }
301    }
302}
303
304/// Pass in the options for which metrics you want to evaluate by making them Some() in the `results`
305/// that you pass in
306pub fn calculate_quality_metrics(
307    original: &Array3<u8>,
308    reconstructed: &Array3<u8>,
309    mut results: QualityMetrics,
310) -> Result<QualityMetrics, Box<dyn Error>> {
311    if original.shape() != reconstructed.shape() {
312        return Err("Shapes of original and reconstructed images must match".into());
313    }
314
315    let mut mse = calculate_mse(original, reconstructed)?;
316    if mse == 0.0 {
317        // Make sure that PSNR isn't undefined
318        mse = 0.0000001;
319    }
320    if results.mse.is_some() {
321        results.mse = Some(mse);
322    }
323    if results.psnr.is_some() {
324        results.psnr = Some(calculate_psnr(mse)?);
325    }
326    if results.ssim.is_some() {
327        results.ssim = Some(calculate_ssim(original, reconstructed)?);
328    }
329    Ok(results)
330}
331
332/// Calculate the mean squared error
333fn calculate_mse(original: &Array3<u8>, reconstructed: &Array3<u8>) -> Result<f64, Box<dyn Error>> {
334    if original.shape() != reconstructed.shape() {
335        return Err("Shapes of original and reconstructed images must match".into());
336    }
337
338    let mut error_sum = 0.0;
339    original
340        .iter()
341        .zip(reconstructed.iter())
342        .for_each(|(a, b)| {
343            error_sum += (*a as f64 - *b as f64).powi(2);
344        });
345    Ok(error_sum / (original.len() as f64))
346}
347
/// Convert a mean squared error into a peak signal-to-noise ratio in decibels.
///
/// Uses a peak value of 255: `PSNR = 20·log10(255) − 10·log10(mse)`. This never returns an
/// error. An MSE of exactly 0 yields `+inf`, and a negative MSE yields NaN.
fn calculate_psnr(mse: f64) -> Result<f64, Box<dyn Error>> {
    let peak_db = 20.0 * 255.0_f64.log10();
    let noise_db = 10.0 * mse.log10();
    Ok(peak_db - noise_db)
}
352
// Below is adapted from https://github.com/ChrisRega/image-compare/blob/main/src/ssim.rs
/// Side length of the square sliding window used for SSIM.
const DEFAULT_WINDOW_SIZE: usize = 8;
/// Stabilising coefficient for the SSIM mean term.
const K1: f64 = 0.01;
/// Stabilising coefficient for the SSIM variance/covariance term.
const K2: f64 = 0.03;
/// Dynamic range of the pixel values.
const L: u8 = u8::MAX;
/// `(K1 * L)^2`
const C1: f64 = {
    let scaled = K1 * L as f64;
    scaled * scaled
};
/// `(K2 * L)^2`
const C2: f64 = {
    let scaled = K2 * L as f64;
    scaled * scaled
};
360
361/// Calculate the SSIM score
362fn calculate_ssim(
363    original: &Array3<u8>,
364    reconstructed: &Array3<u8>,
365) -> Result<f64, Box<dyn Error>> {
366    let mut scores = vec![];
367    for channel in 0..original.shape()[2] {
368        let channel_view_original = original.index_axis(Axis(2), channel);
369        let channel_view_reconstructed = reconstructed.index_axis(Axis(2), channel);
370        let windows_original =
371            channel_view_original.windows((DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE));
372        let windows_reconstructed =
373            channel_view_reconstructed.windows((DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE));
374        let results = windows_original
375            .into_iter()
376            .zip(windows_reconstructed.into_iter())
377            .map(|(w1, w2)| ssim_for_window(w1, w2))
378            .collect::<Vec<_>>();
379        let score = results
380            .iter()
381            .map(|r| r * (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_SIZE) as f64)
382            .sum::<f64>()
383            / results
384                .iter()
385                .map(|_r| (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_SIZE) as f64)
386                .sum::<f64>();
387        scores.push(score)
388    }
389
390    let score = (scores.iter().sum::<f64>() / scores.len() as f64) * 100.0;
391
392    debug_assert!(score >= 0.0);
393    debug_assert!(score <= 100.0);
394
395    Ok(score)
396}
397
398/// Calculate the SSIM for the given windows
399fn ssim_for_window(source_window: ArrayView<u8, Ix2>, recon_window: ArrayView<u8, Ix2>) -> f64 {
400    let mean_x = mean(&source_window);
401    let mean_y = mean(&recon_window);
402    let variance_x = covariance(&source_window, mean_x, &source_window, mean_x);
403    let variance_y = covariance(&recon_window, mean_y, &recon_window, mean_y);
404    let covariance = covariance(&source_window, mean_x, &recon_window, mean_y);
405    let counter = (2. * mean_x * mean_y + C1) * (2. * covariance + C2);
406    let denominator = (mean_x.powi(2) + mean_y.powi(2) + C1) * (variance_x + variance_y + C2);
407    counter / denominator
408}
409
410/// Calculate the covariance of the given windows for SSIM
411fn covariance(
412    window_x: &ArrayView<u8, Ix2>,
413    mean_x: f64,
414    window_y: &ArrayView<u8, Ix2>,
415    mean_y: f64,
416) -> f64 {
417    window_x
418        .iter()
419        .zip(window_y.iter())
420        .map(|(x, y)| (*x as f64 - mean_x) * (*y as f64 - mean_y))
421        .sum::<f64>()
422}
423
424/// Calculate the mean of the given window for SSIM
425fn mean(window: &ArrayView<u8, Ix2>) -> f64 {
426    let sum = window.iter().map(|pixel| *pixel as f64).sum::<f64>();
427
428    sum / window.len() as f64
429}
430
/// Clamp the value to the range [0, 255].
///
/// Values `<= 0` become `0` with `last_val_ln = ln(1 + 0)`. Values `> 255` become `255` with
/// `last_val_ln = ln(1 + 1)`. In-range values and NaN leave both arguments untouched.
pub fn clamp_u8(frame_val: &mut f64, last_val_ln: &mut f64) {
    let (pinned, fraction_of_max): (f64, f64) = match *frame_val {
        v if v <= 0.0 => (0.0, 0.0),
        v if v > 255.0 => (255.0, 1.0),
        _ => return,
    };
    *frame_val = pinned;
    *last_val_ln = fraction_of_max.ln_1p();
}
441
/// Replace an out-of-range intensity with mid-gray.
///
/// If `frame_val` is strictly below 0 or strictly above 255, it becomes 128, and `last_val_ln`
/// becomes `ln(1 + 128/255)`. In-range values (including exactly 0 and 255) and NaN leave both
/// arguments untouched.
pub fn mid_clamp_u8(frame_val: &mut f64, last_val_ln: &mut f64) {
    const MID_GRAY: f64 = 128.0;
    let below = *frame_val < 0.0;
    let above = *frame_val > 255.0;
    // Written as a negated OR (not a range check) so NaN keeps passing through unchanged.
    if !(below || above) {
        return;
    }
    *frame_val = MID_GRAY;
    *last_val_ln = (MID_GRAY / 255.0_f64).ln_1p();
}