adder_codec_rs/utils/cv.rs

1use crate::transcoder::source::video::SourceError;
2#[cfg(feature = "open-cv")]
3use adder_codec_core::PixelAddress;
4use adder_codec_core::{Coord, PlaneSize};
5use const_for::const_for;
6use ndarray::{Array3, ArrayView, Axis, Ix2};
7#[cfg(feature = "open-cv")]
8use opencv::prelude::KeyPointTraitConst;
9use serde::{Deserialize, Serialize};
10#[cfg(feature = "open-cv")]
11use std::collections::HashSet;
12
13use std::error::Error;
14use video_rs_adder_dep::Frame;
15
16// TODO: Explore optimal threshold values
17/// The threshold for feature detection
18pub const INTENSITY_THRESHOLD: i16 = 30;
19
20/// Indices for the asynchronous FAST 9_16 algorithm
21#[rustfmt::skip]
22const CIRCLE3: [[isize; 2]; 16] = [
23    [0, 3], [1, 3], [2, 2], [3, 1],
24    [3, 0], [3, -1], [2, -2], [1, -3],
25    [0, -3], [-1, -3], [-2, -2], [-3, -1],
26    [-3, 0], [-3, 1], [-2, 2], [-1, 3]
27];
28
29const STREAK_SIZE: usize = 9;
30
31const fn threshold_table() -> [u8; 512] {
32    let mut table = [0; 512];
33    const_for!(i in -255..256 => {
34        table[(i + 255) as usize] = if i < -INTENSITY_THRESHOLD {
35            1
36        } else if i > INTENSITY_THRESHOLD {
37            2
38        } else {
39            0
40        };
41    });
42
43    table
44}
45
46const THRESHOLD_TABLE: [u8; 512] = threshold_table();
47
48/// Check if the given event is a feature
49///
50/// This implementation is a direct port/adaptation of the OpenCV reference implementation at
51/// https://github.com/opencv/opencv_attic/blob/master/opencv/modules/features2d/src/fast.cpp
52pub fn is_feature(
53    coord: Coord,
54    plane: PlaneSize,
55    img: &Array3<u8>,
56) -> Result<bool, Box<dyn Error>> {
57    if coord.is_border(plane.w_usize(), plane.h_usize(), 3) || coord.c_usize() != 0 {
58        return Ok(false);
59    }
60    unsafe {
61        let candidate: i16 = *img.uget((coord.y_usize(), coord.x_usize(), 0)) as i16;
62
63        let offset = -candidate as isize + 255;
64        let tab = THRESHOLD_TABLE.as_ptr().offset(offset);
65        debug_assert!(
66            (-candidate < INTENSITY_THRESHOLD && *tab == 1)
67                || (-candidate > INTENSITY_THRESHOLD && *tab == 2)
68                || (-candidate >= -INTENSITY_THRESHOLD
69                    && -candidate <= INTENSITY_THRESHOLD
70                    && *tab == 0)
71        );
72        // const uchar* tab = &threshold_tab[0] - v + 2
73        let c = plane.c_usize() as isize;
74        let width = plane.w() as isize * c;
75        // Get a raw pointer to the intensities
76        let ptr = img.as_ptr();
77
78        let y = coord.y as isize;
79        let x = coord.x as isize;
80        debug_assert_eq!(candidate, *ptr.offset(y * width + x * c) as i16);
81
82        let mut d = *tab
83            .offset(*ptr.offset((y + CIRCLE3[0][1]) * width + (x + CIRCLE3[0][0]) * c) as isize)
84            | *tab.offset(
85                *ptr.offset((y + CIRCLE3[8][1]) * width + (x + CIRCLE3[8][0]) * c) as isize,
86            );
87
88        // If both check pixels are within the intensity threshold range, it's not a feature
89        if d == 0 {
90            return Ok(false);
91        }
92
93        // Check other pixels that are on opposite sides of the circle
94        d &= *tab
95            .offset(*ptr.offset((y + CIRCLE3[2][1]) * width + (x + CIRCLE3[2][0]) * c) as isize)
96            | *tab.offset(
97                *ptr.offset((y + CIRCLE3[10][1]) * width + (x + CIRCLE3[10][0]) * c) as isize,
98            );
99        d &= *tab
100            .offset(*ptr.offset((y + CIRCLE3[4][1]) * width + (x + CIRCLE3[4][0]) * c) as isize)
101            | *tab.offset(
102                *ptr.offset((y + CIRCLE3[12][1]) * width + (x + CIRCLE3[12][0]) * c) as isize,
103            );
104        d &= *tab
105            .offset(*ptr.offset((y + CIRCLE3[6][1]) * width + (x + CIRCLE3[6][0]) * c) as isize)
106            | *tab.offset(
107                *ptr.offset((y + CIRCLE3[14][1]) * width + (x + CIRCLE3[14][0]) * c) as isize,
108            );
109
110        // Not a feature
111        if d == 0 {
112            return Ok(false);
113        }
114
115        // Check other pixels that are on opposite sides of the circle
116        d &= *tab
117            .offset(*ptr.offset((y + CIRCLE3[1][1]) * width + (x + CIRCLE3[1][0]) * c) as isize)
118            | *tab.offset(
119                *ptr.offset((y + CIRCLE3[9][1]) * width + (x + CIRCLE3[9][0]) * c) as isize,
120            );
121        d &= *tab
122            .offset(*ptr.offset((y + CIRCLE3[3][1]) * width + (x + CIRCLE3[3][0]) * c) as isize)
123            | *tab.offset(
124                *ptr.offset((y + CIRCLE3[11][1]) * width + (x + CIRCLE3[11][0]) * c) as isize,
125            );
126        d &= *tab
127            .offset(*ptr.offset((y + CIRCLE3[5][1]) * width + (x + CIRCLE3[5][0]) * c) as isize)
128            | *tab.offset(
129                *ptr.offset((y + CIRCLE3[13][1]) * width + (x + CIRCLE3[13][0]) * c) as isize,
130            );
131        d &= *tab
132            .offset(*ptr.offset((y + CIRCLE3[7][1]) * width + (x + CIRCLE3[7][0]) * c) as isize)
133            | *tab.offset(
134                *ptr.offset((y + CIRCLE3[15][1]) * width + (x + CIRCLE3[15][0]) * c) as isize,
135            );
136
137        if d & 1 > 0 {
138            // It's a dark streak
139            let vt = candidate - INTENSITY_THRESHOLD;
140            let mut count = 0;
141
142            for k in 0..16 {
143                let x = *ptr.offset((y + CIRCLE3[k][1]) * width + (x + CIRCLE3[k][0]) * c) as i16;
144                if x < vt {
145                    count += 1;
146                    if count == STREAK_SIZE {
147                        return Ok(true);
148                    }
149                } else {
150                    count = 0;
151                }
152            }
153            for k in 16..25 {
154                let x = *ptr.offset((y + CIRCLE3[k - 16][1]) * width + (x + CIRCLE3[k - 16][0]) * c)
155                    as i16;
156                if x < vt {
157                    count += 1;
158                    if count == STREAK_SIZE {
159                        return Ok(true);
160                    }
161                } else {
162                    count = 0;
163
164                    // Then we don't need to check the rest of the circle; can't be a streak long enough
165                    if k == 17 {
166                        return Ok(false);
167                    }
168                }
169            }
170        }
171
172        if d & 2 > 0 {
173            // It's a bright streak
174            let vt = candidate + INTENSITY_THRESHOLD;
175            let mut count = 0;
176            for k in 0..16 {
177                let x = *ptr.offset((y + CIRCLE3[k][1]) * width + (x + CIRCLE3[k][0]) * c) as i16;
178                if x > vt {
179                    count += 1;
180                    if count == STREAK_SIZE {
181                        return Ok(true);
182                    }
183                } else {
184                    count = 0;
185                }
186            }
187            for k in 16..25 {
188                let x = *ptr.offset((y + CIRCLE3[k - 16][1]) * width + (x + CIRCLE3[k - 16][0]) * c)
189                    as i16;
190                if x > vt {
191                    count += 1;
192                    if count == STREAK_SIZE {
193                        return Ok(true);
194                    }
195                } else {
196                    count = 0;
197
198                    // Then we don't need to check the rest of the circle; can't be a streak long enough
199                    if k == 17 {
200                        return Ok(false);
201                    }
202                }
203            }
204        }
205    }
206
207    Ok(false)
208}
209
210/// If the input is a color image and we want a gray image, convert it to grayscale
211pub fn handle_color(mut input: Frame, color: bool) -> Result<Frame, SourceError> {
212    if !color {
213        // Map the three color channels to a single grayscale channel
214        input
215            .exact_chunks_mut((1, 1, 3))
216            .into_iter()
217            .for_each(|mut v| unsafe {
218                *v.uget_mut((0, 0, 0)) = (*v.uget((0, 0, 0)) as f64 * 0.114
219                    + *v.uget((0, 0, 1)) as f64 * 0.587
220                    + *v.uget((0, 0, 2)) as f64 * 0.299)
221                    as u8;
222            });
223
224        // Remove the color channels
225        input.collapse_axis(Axis(2), 0);
226    }
227    Ok(input)
228}
229
230#[cfg(feature = "open-cv")]
231pub fn feature_precision_recall_accuracy(
232    gt: &opencv::core::Vector<opencv::core::KeyPoint>,
233    prediction: &HashSet<Coord>,
234    plane: PlaneSize,
235) -> (f64, f64, f64) {
236    let (mut tp, mut fp, mut tn, mut fnn) = (0, 0, 0, 0);
237
238    // Channel of first pred event:
239    let channel = match prediction.iter().next() {
240        None => None,
241        Some(coord) => coord.c,
242    };
243
244    // convert the keypoints vec to a hashset for convenience
245    let mut gt_hash = HashSet::<Coord>::new();
246    for keypoint in gt {
247        gt_hash.insert(Coord::new(
248            keypoint.pt().x as PixelAddress,
249            keypoint.pt().y as PixelAddress,
250            channel,
251        ));
252    }
253
254    for y in 0..plane.h() {
255        for x in 0..plane.w() {
256            let coord = Coord::new(x, y, None);
257            if prediction.contains(&coord) {
258                if gt_hash.contains(&coord) {
259                    tp += 1;
260                } else {
261                    fp += 1;
262                }
263            } else if gt_hash.contains(&coord) {
264                fnn += 1;
265            } else {
266                tn += 1;
267            }
268        }
269    }
270
271    let precision = (tp as f64) / ((tp + fp) as f64);
272    let recall = (tp as f64) / ((tp + fnn) as f64);
273    let accuracy = ((tp + tn) as f64) / ((tp + tn + fp + fnn) as f64);
274    (precision, recall, accuracy)
275}
276
277/// Container for quality metric results
278#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
279pub struct QualityMetrics {
280    /// Peak signal-to-noise ratio
281    pub psnr: Option<f64>,
282
283    /// Mean squared error
284    pub mse: Option<f64>,
285
286    /// Structural similarity index measure
287    pub ssim: Option<f64>,
288}
289
290impl Default for QualityMetrics {
291    fn default() -> Self {
292        Self {
293            psnr: Some(0.0),
294            mse: Some(0.0),
295            ssim: None,
296        }
297    }
298}
299
300/// Pass in the options for which metrics you want to evaluate by making them Some() in the `results`
301/// that you pass in
302pub fn calculate_quality_metrics(
303    original: &Array3<u8>,
304    reconstructed: &Array3<u8>,
305    mut results: QualityMetrics,
306) -> Result<QualityMetrics, Box<dyn Error>> {
307    if original.shape() != reconstructed.shape() {
308        return Err("Shapes of original and reconstructed images must match".into());
309    }
310
311    let mut mse = calculate_mse(original, reconstructed)?;
312    if mse == 0.0 {
313        // Make sure that PSNR isn't undefined
314        mse = 0.0000001;
315    }
316    if results.mse.is_some() {
317        results.mse = Some(mse);
318    }
319    if results.psnr.is_some() {
320        results.psnr = Some(calculate_psnr(mse)?);
321    }
322    if results.ssim.is_some() {
323        results.ssim = Some(calculate_ssim(original, reconstructed)?);
324    }
325    Ok(results)
326}
327
328/// Calculate the mean squared error
329fn calculate_mse(original: &Array3<u8>, reconstructed: &Array3<u8>) -> Result<f64, Box<dyn Error>> {
330    if original.shape() != reconstructed.shape() {
331        return Err("Shapes of original and reconstructed images must match".into());
332    }
333
334    let mut error_sum = 0.0;
335    original
336        .iter()
337        .zip(reconstructed.iter())
338        .for_each(|(a, b)| {
339            error_sum += (*a as f64 - *b as f64).powi(2);
340        });
341    Ok(error_sum / (original.len() as f64))
342}
343
344/// Calculate the peak signal-to-noise ratio from the given MSE
345fn calculate_psnr(mse: f64) -> Result<f64, Box<dyn Error>> {
346    Ok(20.0 * (255.0_f64).log10() - 10.0 * mse.log10())
347}
348
349// Below is adapted from https://github.com/ChrisRega/image-compare/blob/main/src/ssim.rs
350const DEFAULT_WINDOW_SIZE: usize = 8;
351const K1: f64 = 0.01;
352const K2: f64 = 0.03;
353const L: u8 = u8::MAX;
354const C1: f64 = (K1 * L as f64) * (K1 * L as f64);
355const C2: f64 = (K2 * L as f64) * (K2 * L as f64);
356
357/// Calculate the SSIM score
358fn calculate_ssim(
359    original: &Array3<u8>,
360    reconstructed: &Array3<u8>,
361) -> Result<f64, Box<dyn Error>> {
362    let mut scores = vec![];
363    for channel in 0..original.shape()[2] {
364        let channel_view_original = original.index_axis(Axis(2), channel);
365        let channel_view_reconstructed = reconstructed.index_axis(Axis(2), channel);
366        let windows_original =
367            channel_view_original.windows((DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE));
368        let windows_reconstructed =
369            channel_view_reconstructed.windows((DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE));
370        let results = windows_original
371            .into_iter()
372            .zip(windows_reconstructed.into_iter())
373            .map(|(w1, w2)| ssim_for_window(w1, w2))
374            .collect::<Vec<_>>();
375        let score = results
376            .iter()
377            .map(|r| r * (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_SIZE) as f64)
378            .sum::<f64>()
379            / results
380                .iter()
381                .map(|_r| (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_SIZE) as f64)
382                .sum::<f64>();
383        scores.push(score)
384    }
385
386    let score = (scores.iter().sum::<f64>() / scores.len() as f64) * 100.0;
387
388    debug_assert!(score >= 0.0);
389    debug_assert!(score <= 100.0);
390
391    Ok(score)
392}
393
394/// Calculate the SSIM for the given windows
395fn ssim_for_window(source_window: ArrayView<u8, Ix2>, recon_window: ArrayView<u8, Ix2>) -> f64 {
396    let mean_x = mean(&source_window);
397    let mean_y = mean(&recon_window);
398    let variance_x = covariance(&source_window, mean_x, &source_window, mean_x);
399    let variance_y = covariance(&recon_window, mean_y, &recon_window, mean_y);
400    let covariance = covariance(&source_window, mean_x, &recon_window, mean_y);
401    let counter = (2. * mean_x * mean_y + C1) * (2. * covariance + C2);
402    let denominator = (mean_x.powi(2) + mean_y.powi(2) + C1) * (variance_x + variance_y + C2);
403    counter / denominator
404}
405
406/// Calculate the covariance of the given windows for SSIM
407fn covariance(
408    window_x: &ArrayView<u8, Ix2>,
409    mean_x: f64,
410    window_y: &ArrayView<u8, Ix2>,
411    mean_y: f64,
412) -> f64 {
413    window_x
414        .iter()
415        .zip(window_y.iter())
416        .map(|(x, y)| (*x as f64 - mean_x) * (*y as f64 - mean_y))
417        .sum::<f64>()
418}
419
420/// Calculate the mean of the given window for SSIM
421fn mean(window: &ArrayView<u8, Ix2>) -> f64 {
422    let sum = window.iter().map(|pixel| *pixel as f64).sum::<f64>();
423
424    sum / window.len() as f64
425}
426
427/// Clamp the value to the range [0, 255].
428pub fn clamp_u8(frame_val: &mut f64, last_val_ln: &mut f64) {
429    if *frame_val <= 0.0 {
430        *frame_val = 0.0;
431        *last_val_ln = 0.0_f64.ln_1p();
432    } else if *frame_val > 255.0 {
433        *frame_val = 255.0;
434        *last_val_ln = 1.0_f64.ln_1p();
435    }
436}
437
438/// Clamp the value to the range [0, 255]. If the value is outside of this range, set it to a mid-
439/// point gray value.
440pub fn mid_clamp_u8(frame_val: &mut f64, last_val_ln: &mut f64) {
441    if *frame_val < 0.0 || *frame_val > 255.0 {
442        *frame_val = 128.0;
443        *last_val_ln = (128.0_f64 / 255.0_f64).ln_1p();
444    }
445}