1use crate::transcoder::source::video::SourceError;
2#[cfg(feature = "open-cv")]
3use adder_codec_core::PixelAddress;
4use adder_codec_core::{Coord, PlaneSize};
5use const_for::const_for;
6use ndarray::{Array3, ArrayView, Axis, Ix2};
7#[cfg(feature = "open-cv")]
8use opencv::prelude::KeyPointTraitConst;
9use serde::{Deserialize, Serialize};
10#[cfg(feature = "open-cv")]
11use std::collections::HashSet;
12
13use std::error::Error;
14use video_rs_adder_dep::Frame;
15
/// Minimum absolute intensity difference between the candidate pixel and a
/// circle pixel for the circle pixel to be classified as darker or brighter.
pub const INTENSITY_THRESHOLD: i16 = 30;

/// `[x, y]` offsets of the 16 pixels on the radius-3 circle that is sampled
/// around a candidate pixel. Entry `k + 8` is diametrically opposite entry `k`.
#[rustfmt::skip]
const CIRCLE3: [[isize; 2]; 16] = [
    [0, 3], [1, 3], [2, 2], [3, 1],
    [3, 0], [3, -1], [2, -2], [1, -3],
    [0, -3], [-1, -3], [-2, -2], [-3, -1],
    [-3, 0], [-3, 1], [-2, 2], [-1, 3]
];

/// Number of consecutive circle pixels that must all be darker (or all be
/// brighter) than the candidate for it to be reported as a feature.
const STREAK_SIZE: usize = 9;
30
31const fn threshold_table() -> [u8; 512] {
32 let mut table = [0; 512];
33 const_for!(i in -255..256 => {
34 table[(i + 255) as usize] = if i < -INTENSITY_THRESHOLD {
35 1
36 } else if i > INTENSITY_THRESHOLD {
37 2
38 } else {
39 0
40 };
41 });
42
43 table
44}
45
/// Compile-time result of [`threshold_table`]: maps an intensity difference
/// (shifted by 255) to `1`, `2`, or `0`.
const THRESHOLD_TABLE: [u8; 512] = threshold_table();
47
48pub fn is_feature(
53 coord: Coord,
54 plane: PlaneSize,
55 img: &Array3<u8>,
56) -> Result<bool, Box<dyn Error>> {
57 if coord.is_border(plane.w_usize(), plane.h_usize(), 3) || coord.c_usize() != 0 {
58 return Ok(false);
59 }
60 unsafe {
61 let candidate: i16 = *img.uget((coord.y_usize(), coord.x_usize(), 0)) as i16;
62
63 let offset = -candidate as isize + 255;
64 let tab = THRESHOLD_TABLE.as_ptr().offset(offset);
65 debug_assert!(
66 (-candidate < INTENSITY_THRESHOLD && *tab == 1)
67 || (-candidate > INTENSITY_THRESHOLD && *tab == 2)
68 || (-candidate >= -INTENSITY_THRESHOLD
69 && -candidate <= INTENSITY_THRESHOLD
70 && *tab == 0)
71 );
72 let c = plane.c_usize() as isize;
74 let width = plane.w() as isize * c;
75 let ptr = img.as_ptr();
77
78 let y = coord.y as isize;
79 let x = coord.x as isize;
80 debug_assert_eq!(candidate, *ptr.offset(y * width + x * c) as i16);
81
82 let mut d = *tab
83 .offset(*ptr.offset((y + CIRCLE3[0][1]) * width + (x + CIRCLE3[0][0]) * c) as isize)
84 | *tab.offset(
85 *ptr.offset((y + CIRCLE3[8][1]) * width + (x + CIRCLE3[8][0]) * c) as isize,
86 );
87
88 if d == 0 {
90 return Ok(false);
91 }
92
93 d &= *tab
95 .offset(*ptr.offset((y + CIRCLE3[2][1]) * width + (x + CIRCLE3[2][0]) * c) as isize)
96 | *tab.offset(
97 *ptr.offset((y + CIRCLE3[10][1]) * width + (x + CIRCLE3[10][0]) * c) as isize,
98 );
99 d &= *tab
100 .offset(*ptr.offset((y + CIRCLE3[4][1]) * width + (x + CIRCLE3[4][0]) * c) as isize)
101 | *tab.offset(
102 *ptr.offset((y + CIRCLE3[12][1]) * width + (x + CIRCLE3[12][0]) * c) as isize,
103 );
104 d &= *tab
105 .offset(*ptr.offset((y + CIRCLE3[6][1]) * width + (x + CIRCLE3[6][0]) * c) as isize)
106 | *tab.offset(
107 *ptr.offset((y + CIRCLE3[14][1]) * width + (x + CIRCLE3[14][0]) * c) as isize,
108 );
109
110 if d == 0 {
112 return Ok(false);
113 }
114
115 d &= *tab
117 .offset(*ptr.offset((y + CIRCLE3[1][1]) * width + (x + CIRCLE3[1][0]) * c) as isize)
118 | *tab.offset(
119 *ptr.offset((y + CIRCLE3[9][1]) * width + (x + CIRCLE3[9][0]) * c) as isize,
120 );
121 d &= *tab
122 .offset(*ptr.offset((y + CIRCLE3[3][1]) * width + (x + CIRCLE3[3][0]) * c) as isize)
123 | *tab.offset(
124 *ptr.offset((y + CIRCLE3[11][1]) * width + (x + CIRCLE3[11][0]) * c) as isize,
125 );
126 d &= *tab
127 .offset(*ptr.offset((y + CIRCLE3[5][1]) * width + (x + CIRCLE3[5][0]) * c) as isize)
128 | *tab.offset(
129 *ptr.offset((y + CIRCLE3[13][1]) * width + (x + CIRCLE3[13][0]) * c) as isize,
130 );
131 d &= *tab
132 .offset(*ptr.offset((y + CIRCLE3[7][1]) * width + (x + CIRCLE3[7][0]) * c) as isize)
133 | *tab.offset(
134 *ptr.offset((y + CIRCLE3[15][1]) * width + (x + CIRCLE3[15][0]) * c) as isize,
135 );
136
137 if d & 1 > 0 {
138 let vt = candidate - INTENSITY_THRESHOLD;
140 let mut count = 0;
141
142 for k in 0..16 {
143 let x = *ptr.offset((y + CIRCLE3[k][1]) * width + (x + CIRCLE3[k][0]) * c) as i16;
144 if x < vt {
145 count += 1;
146 if count == STREAK_SIZE {
147 return Ok(true);
148 }
149 } else {
150 count = 0;
151 }
152 }
153 for k in 16..25 {
154 let x = *ptr.offset((y + CIRCLE3[k - 16][1]) * width + (x + CIRCLE3[k - 16][0]) * c)
155 as i16;
156 if x < vt {
157 count += 1;
158 if count == STREAK_SIZE {
159 return Ok(true);
160 }
161 } else {
162 count = 0;
163
164 if k == 17 {
166 return Ok(false);
167 }
168 }
169 }
170 }
171
172 if d & 2 > 0 {
173 let vt = candidate + INTENSITY_THRESHOLD;
175 let mut count = 0;
176 for k in 0..16 {
177 let x = *ptr.offset((y + CIRCLE3[k][1]) * width + (x + CIRCLE3[k][0]) * c) as i16;
178 if x > vt {
179 count += 1;
180 if count == STREAK_SIZE {
181 return Ok(true);
182 }
183 } else {
184 count = 0;
185 }
186 }
187 for k in 16..25 {
188 let x = *ptr.offset((y + CIRCLE3[k - 16][1]) * width + (x + CIRCLE3[k - 16][0]) * c)
189 as i16;
190 if x > vt {
191 count += 1;
192 if count == STREAK_SIZE {
193 return Ok(true);
194 }
195 } else {
196 count = 0;
197
198 if k == 17 {
200 return Ok(false);
201 }
202 }
203 }
204 }
205 }
206
207 Ok(false)
208}
209
210pub fn handle_color(mut input: Frame, color: bool) -> Result<Frame, SourceError> {
212 if !color {
213 input
215 .exact_chunks_mut((1, 1, 3))
216 .into_iter()
217 .for_each(|mut v| unsafe {
218 *v.uget_mut((0, 0, 0)) = (*v.uget((0, 0, 0)) as f64 * 0.114
219 + *v.uget((0, 0, 1)) as f64 * 0.587
220 + *v.uget((0, 0, 2)) as f64 * 0.299)
221 as u8;
222 });
223
224 input.collapse_axis(Axis(2), 0);
226 }
227 Ok(input)
228}
229
/// Scores a set of predicted feature coordinates against OpenCV ground-truth
/// keypoints over every pixel of `plane`.
///
/// # Arguments
/// * `gt` - Ground-truth keypoints; their positions are truncated to integer
///   pixel addresses.
/// * `prediction` - Predicted feature coordinates. The channel of an arbitrary
///   element of the set is used as the channel for all comparisons.
/// * `plane` - Frame geometry; every `(x, y)` in `0..w` x `0..h` is classified.
///
/// # Returns
/// `(precision, recall, accuracy)`. Precision is NaN when nothing was predicted
/// and recall is NaN when there is no ground truth (0 / 0).
#[cfg(feature = "open-cv")]
pub fn feature_precision_recall_accuracy(
    gt: &opencv::core::Vector<opencv::core::KeyPoint>,
    prediction: &HashSet<Coord>,
    plane: PlaneSize,
) -> (f64, f64, f64) {
    let (mut tp, mut fp, mut tn, mut fnn) = (0_u64, 0_u64, 0_u64, 0_u64);

    // Coordinates only compare equal when their channel matches too, so adopt
    // the channel used by the predictions for both the ground truth and the scan.
    let channel = prediction.iter().next().and_then(|coord| coord.c);

    let mut gt_hash = HashSet::<Coord>::new();
    for keypoint in gt {
        gt_hash.insert(Coord::new(
            keypoint.pt().x as PixelAddress,
            keypoint.pt().y as PixelAddress,
            channel,
        ));
    }

    for y in 0..plane.h() {
        for x in 0..plane.w() {
            // Must carry the same channel as the set elements; probing with
            // `None` would never match predictions stored with `Some(channel)`.
            let coord = Coord::new(x, y, channel);
            match (prediction.contains(&coord), gt_hash.contains(&coord)) {
                (true, true) => tp += 1,
                (true, false) => fp += 1,
                (false, true) => fnn += 1,
                (false, false) => tn += 1,
            }
        }
    }

    let precision = (tp as f64) / ((tp + fp) as f64);
    let recall = (tp as f64) / ((tp + fnn) as f64);
    let accuracy = ((tp + tn) as f64) / ((tp + tn + fp + fnn) as f64);
    (precision, recall, accuracy)
}
276
/// Selection of, and results for, the image quality metrics computed by
/// `calculate_quality_metrics`.
///
/// A field that is `Some(_)` on input requests that metric and is overwritten
/// with the computed value; a field that is `None` is skipped and stays `None`.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct QualityMetrics {
    /// Peak signal-to-noise ratio, derived from the mean squared error.
    pub psnr: Option<f64>,

    /// Mean squared error between the two images.
    pub mse: Option<f64>,

    /// Structural similarity score, scaled by 100.
    pub ssim: Option<f64>,
}
289
290impl Default for QualityMetrics {
291 fn default() -> Self {
292 Self {
293 psnr: Some(0.0),
294 mse: Some(0.0),
295 ssim: None,
296 }
297 }
298}
299
300pub fn calculate_quality_metrics(
303 original: &Array3<u8>,
304 reconstructed: &Array3<u8>,
305 mut results: QualityMetrics,
306) -> Result<QualityMetrics, Box<dyn Error>> {
307 if original.shape() != reconstructed.shape() {
308 return Err("Shapes of original and reconstructed images must match".into());
309 }
310
311 let mut mse = calculate_mse(original, reconstructed)?;
312 if mse == 0.0 {
313 mse = 0.0000001;
315 }
316 if results.mse.is_some() {
317 results.mse = Some(mse);
318 }
319 if results.psnr.is_some() {
320 results.psnr = Some(calculate_psnr(mse)?);
321 }
322 if results.ssim.is_some() {
323 results.ssim = Some(calculate_ssim(original, reconstructed)?);
324 }
325 Ok(results)
326}
327
328fn calculate_mse(original: &Array3<u8>, reconstructed: &Array3<u8>) -> Result<f64, Box<dyn Error>> {
330 if original.shape() != reconstructed.shape() {
331 return Err("Shapes of original and reconstructed images must match".into());
332 }
333
334 let mut error_sum = 0.0;
335 original
336 .iter()
337 .zip(reconstructed.iter())
338 .for_each(|(a, b)| {
339 error_sum += (*a as f64 - *b as f64).powi(2);
340 });
341 Ok(error_sum / (original.len() as f64))
342}
343
/// Peak signal-to-noise ratio for 8-bit samples, computed from a mean squared
/// error as `20 * log10(255) - 10 * log10(mse)`.
///
/// An `mse` of zero yields positive infinity. This function never returns an
/// error.
fn calculate_psnr(mse: f64) -> Result<f64, Box<dyn Error>> {
    let peak_term = 20.0 * (255.0_f64).log10();
    let error_term = 10.0 * mse.log10();
    Ok(peak_term - error_term)
}
348
/// Side length, in pixels, of the square sliding window used for SSIM.
const DEFAULT_WINDOW_SIZE: usize = 8;
/// Coefficient from which the SSIM stabilising constant `C1` is derived.
const K1: f64 = 0.01;
/// Coefficient from which the SSIM stabilising constant `C2` is derived.
const K2: f64 = 0.03;
/// Largest sample value (dynamic range of a `u8` sample).
const L: u8 = u8::MAX;
/// `(K1 * L)^2`: added to the mean terms of the SSIM ratio.
const C1: f64 = (K1 * L as f64) * (K1 * L as f64);
/// `(K2 * L)^2`: added to the variance/covariance terms of the SSIM ratio.
const C2: f64 = (K2 * L as f64) * (K2 * L as f64);
356
357fn calculate_ssim(
359 original: &Array3<u8>,
360 reconstructed: &Array3<u8>,
361) -> Result<f64, Box<dyn Error>> {
362 let mut scores = vec![];
363 for channel in 0..original.shape()[2] {
364 let channel_view_original = original.index_axis(Axis(2), channel);
365 let channel_view_reconstructed = reconstructed.index_axis(Axis(2), channel);
366 let windows_original =
367 channel_view_original.windows((DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE));
368 let windows_reconstructed =
369 channel_view_reconstructed.windows((DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE));
370 let results = windows_original
371 .into_iter()
372 .zip(windows_reconstructed.into_iter())
373 .map(|(w1, w2)| ssim_for_window(w1, w2))
374 .collect::<Vec<_>>();
375 let score = results
376 .iter()
377 .map(|r| r * (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_SIZE) as f64)
378 .sum::<f64>()
379 / results
380 .iter()
381 .map(|_r| (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_SIZE) as f64)
382 .sum::<f64>();
383 scores.push(score)
384 }
385
386 let score = (scores.iter().sum::<f64>() / scores.len() as f64) * 100.0;
387
388 debug_assert!(score >= 0.0);
389 debug_assert!(score <= 100.0);
390
391 Ok(score)
392}
393
394fn ssim_for_window(source_window: ArrayView<u8, Ix2>, recon_window: ArrayView<u8, Ix2>) -> f64 {
396 let mean_x = mean(&source_window);
397 let mean_y = mean(&recon_window);
398 let variance_x = covariance(&source_window, mean_x, &source_window, mean_x);
399 let variance_y = covariance(&recon_window, mean_y, &recon_window, mean_y);
400 let covariance = covariance(&source_window, mean_x, &recon_window, mean_y);
401 let counter = (2. * mean_x * mean_y + C1) * (2. * covariance + C2);
402 let denominator = (mean_x.powi(2) + mean_y.powi(2) + C1) * (variance_x + variance_y + C2);
403 counter / denominator
404}
405
406fn covariance(
408 window_x: &ArrayView<u8, Ix2>,
409 mean_x: f64,
410 window_y: &ArrayView<u8, Ix2>,
411 mean_y: f64,
412) -> f64 {
413 window_x
414 .iter()
415 .zip(window_y.iter())
416 .map(|(x, y)| (*x as f64 - mean_x) * (*y as f64 - mean_y))
417 .sum::<f64>()
418}
419
420fn mean(window: &ArrayView<u8, Ix2>) -> f64 {
422 let sum = window.iter().map(|pixel| *pixel as f64).sum::<f64>();
423
424 sum / window.len() as f64
425}
426
/// Saturates `frame_val` to the range `0.0..=255.0`.
///
/// * Values `<= 0.0` become `0.0` and `last_val_ln` is set to `ln(1 + 0.0)`.
/// * Values `> 255.0` become `255.0` and `last_val_ln` is set to `ln(1 + 1.0)`.
/// * Anything else (including NaN) leaves both arguments untouched.
pub fn clamp_u8(frame_val: &mut f64, last_val_ln: &mut f64) {
    let (bound, normalized) = if *frame_val <= 0.0 {
        (0.0, 0.0_f64)
    } else if *frame_val > 255.0 {
        (255.0, 1.0_f64)
    } else {
        return;
    };

    *frame_val = bound;
    *last_val_ln = normalized.ln_1p();
}
437
/// Resets an out-of-range `frame_val` to the mid-level value `128.0`.
///
/// When `frame_val` is below `0.0` or above `255.0`, it becomes `128.0` and
/// `last_val_ln` is set to `ln(1 + 128/255)`. In-range values (and NaN, which
/// fails both comparisons) leave both arguments untouched.
pub fn mid_clamp_u8(frame_val: &mut f64, last_val_ln: &mut f64) {
    let too_low = *frame_val < 0.0;
    let too_high = *frame_val > 255.0;

    if too_low || too_high {
        let midpoint = 128.0_f64;
        *frame_val = midpoint;
        *last_val_ln = (midpoint / 255.0_f64).ln_1p();
    }
}
445}