1use crate::transcoder::source::video::SourceError;
2#[cfg(feature = "open-cv")]
3use {
4 adder_codec_core::PixelAddress,
5 opencv::prelude::KeyPointTraitConst,
6 std::collections::HashSet,
7};
8use adder_codec_core::{Coord, PlaneSize};
9use const_for::const_for;
10use ndarray::{Array3, ArrayView, Axis, Ix2};
11
12
13use serde::{Deserialize, Serialize};
14
15
16
17use std::error::Error;
18use video_rs_adder_dep::Frame;
19
/// Minimum absolute intensity difference between a ring pixel and the candidate pixel for
/// the ring pixel to be classified as darker or brighter (used by `threshold_table` and
/// `is_feature`).
pub const INTENSITY_THRESHOLD: i16 = 30;

/// `[x, y]` offsets of the 16 pixels on the radius-3 ring around a candidate pixel, listed in
/// consecutive order around the ring. `is_feature` adds element 0 to the column and element 1
/// to the row of the candidate.
#[rustfmt::skip]
const CIRCLE3: [[isize; 2]; 16] = [
    [0, 3], [1, 3], [2, 2], [3, 1],
    [3, 0], [3, -1], [2, -2], [1, -3],
    [0, -3], [-1, -3], [-2, -2], [-3, -1],
    [-3, 0], [-3, 1], [-2, 2], [-1, 3]
];

/// Number of consecutive ring pixels that must all be darker (or all be brighter) than the
/// candidate for `is_feature` to report a feature.
const STREAK_SIZE: usize = 9;
34
35const fn threshold_table() -> [u8; 512] {
36 let mut table = [0; 512];
37 const_for!(i in -255..256 => {
38 table[(i + 255) as usize] = if i < -INTENSITY_THRESHOLD {
39 1
40 } else if i > INTENSITY_THRESHOLD {
41 2
42 } else {
43 0
44 };
45 });
46
47 table
48}
49
/// Compile-time result of `threshold_table()`: entry `diff + 255` holds the class
/// (0, 1 or 2) of the intensity difference `diff`.
const THRESHOLD_TABLE: [u8; 512] = threshold_table();
51
52pub fn is_feature(
57 coord: Coord,
58 plane: PlaneSize,
59 img: &Array3<u8>,
60) -> Result<bool, Box<dyn Error>> {
61 if coord.is_border(plane.w_usize(), plane.h_usize(), 3) || coord.c_usize() != 0 {
62 return Ok(false);
63 }
64 unsafe {
65 let candidate: i16 = *img.uget((coord.y_usize(), coord.x_usize(), 0)) as i16;
66
67 let offset = -candidate as isize + 255;
68 let tab = THRESHOLD_TABLE.as_ptr().offset(offset);
69 debug_assert!(
70 (-candidate < INTENSITY_THRESHOLD && *tab == 1)
71 || (-candidate > INTENSITY_THRESHOLD && *tab == 2)
72 || (-candidate >= -INTENSITY_THRESHOLD
73 && -candidate <= INTENSITY_THRESHOLD
74 && *tab == 0)
75 );
76 let c = plane.c_usize() as isize;
78 let width = plane.w() as isize * c;
79 let ptr = img.as_ptr();
81
82 let y = coord.y as isize;
83 let x = coord.x as isize;
84 debug_assert_eq!(candidate, *ptr.offset(y * width + x * c) as i16);
85
86 let mut d = *tab
87 .offset(*ptr.offset((y + CIRCLE3[0][1]) * width + (x + CIRCLE3[0][0]) * c) as isize)
88 | *tab.offset(
89 *ptr.offset((y + CIRCLE3[8][1]) * width + (x + CIRCLE3[8][0]) * c) as isize,
90 );
91
92 if d == 0 {
94 return Ok(false);
95 }
96
97 d &= *tab
99 .offset(*ptr.offset((y + CIRCLE3[2][1]) * width + (x + CIRCLE3[2][0]) * c) as isize)
100 | *tab.offset(
101 *ptr.offset((y + CIRCLE3[10][1]) * width + (x + CIRCLE3[10][0]) * c) as isize,
102 );
103 d &= *tab
104 .offset(*ptr.offset((y + CIRCLE3[4][1]) * width + (x + CIRCLE3[4][0]) * c) as isize)
105 | *tab.offset(
106 *ptr.offset((y + CIRCLE3[12][1]) * width + (x + CIRCLE3[12][0]) * c) as isize,
107 );
108 d &= *tab
109 .offset(*ptr.offset((y + CIRCLE3[6][1]) * width + (x + CIRCLE3[6][0]) * c) as isize)
110 | *tab.offset(
111 *ptr.offset((y + CIRCLE3[14][1]) * width + (x + CIRCLE3[14][0]) * c) as isize,
112 );
113
114 if d == 0 {
116 return Ok(false);
117 }
118
119 d &= *tab
121 .offset(*ptr.offset((y + CIRCLE3[1][1]) * width + (x + CIRCLE3[1][0]) * c) as isize)
122 | *tab.offset(
123 *ptr.offset((y + CIRCLE3[9][1]) * width + (x + CIRCLE3[9][0]) * c) as isize,
124 );
125 d &= *tab
126 .offset(*ptr.offset((y + CIRCLE3[3][1]) * width + (x + CIRCLE3[3][0]) * c) as isize)
127 | *tab.offset(
128 *ptr.offset((y + CIRCLE3[11][1]) * width + (x + CIRCLE3[11][0]) * c) as isize,
129 );
130 d &= *tab
131 .offset(*ptr.offset((y + CIRCLE3[5][1]) * width + (x + CIRCLE3[5][0]) * c) as isize)
132 | *tab.offset(
133 *ptr.offset((y + CIRCLE3[13][1]) * width + (x + CIRCLE3[13][0]) * c) as isize,
134 );
135 d &= *tab
136 .offset(*ptr.offset((y + CIRCLE3[7][1]) * width + (x + CIRCLE3[7][0]) * c) as isize)
137 | *tab.offset(
138 *ptr.offset((y + CIRCLE3[15][1]) * width + (x + CIRCLE3[15][0]) * c) as isize,
139 );
140
141 if d & 1 > 0 {
142 let vt = candidate - INTENSITY_THRESHOLD;
144 let mut count = 0;
145
146 for [a, b] in CIRCLE3 {
147 let x = *ptr.offset((y + b) * width + (x + a) * c) as i16;
148 if x < vt {
149 count += 1;
150 if count == STREAK_SIZE {
151 return Ok(true);
152 }
153 } else {
154 count = 0;
155 }
156 }
157 for k in 16..25 {
158 let x = *ptr.offset((y + CIRCLE3[k - 16][1]) * width + (x + CIRCLE3[k - 16][0]) * c)
159 as i16;
160 if x < vt {
161 count += 1;
162 if count == STREAK_SIZE {
163 return Ok(true);
164 }
165 } else {
166 count = 0;
167
168 if k == 17 {
170 return Ok(false);
171 }
172 }
173 }
174 }
175
176 if d & 2 > 0 {
177 let vt = candidate + INTENSITY_THRESHOLD;
179 let mut count = 0;
180 for [a, b] in CIRCLE3 {
181 let x = *ptr.offset((y + b) * width + (x + a) * c) as i16;
182 if x > vt {
183 count += 1;
184 if count == STREAK_SIZE {
185 return Ok(true);
186 }
187 } else {
188 count = 0;
189 }
190 }
191 for k in 16..25 {
192 let x = *ptr.offset((y + CIRCLE3[k - 16][1]) * width + (x + CIRCLE3[k - 16][0]) * c)
193 as i16;
194 if x > vt {
195 count += 1;
196 if count == STREAK_SIZE {
197 return Ok(true);
198 }
199 } else {
200 count = 0;
201
202 if k == 17 {
204 return Ok(false);
205 }
206 }
207 }
208 }
209 }
210
211 Ok(false)
212}
213
214pub fn handle_color(mut input: Frame, color: bool) -> Result<Frame, SourceError> {
216 if !color {
217 input
219 .exact_chunks_mut((1, 1, 3))
220 .into_iter()
221 .for_each(|mut v| unsafe {
222 *v.uget_mut((0, 0, 0)) = (*v.uget((0, 0, 0)) as f64 * 0.114
223 + *v.uget((0, 0, 1)) as f64 * 0.587
224 + *v.uget((0, 0, 2)) as f64 * 0.299)
225 as u8;
226 });
227
228 input.collapse_axis(Axis(2), 0);
230 }
231 Ok(input)
232}
233
/// Compares predicted feature locations against ground-truth keypoints over a whole plane.
///
/// Every `(x, y)` position of `plane` is classified as true/false positive/negative by
/// looking it up in `prediction` and in the set of ground-truth keypoints (whose floating
/// point positions are cast to `PixelAddress`).
///
/// All lookups use the channel value of an arbitrary element of `prediction` (or `None`
/// when `prediction` is empty), so that ground-truth coordinates, predicted coordinates and
/// the scanned coordinates are comparable.
///
/// # Returns
/// `(precision, recall, accuracy)`. A ratio whose denominator is zero (for example the
/// precision when nothing was predicted) is `NaN`.
#[cfg(feature = "open-cv")]
pub fn feature_precision_recall_accuracy(
    gt: &opencv::core::Vector<opencv::core::KeyPoint>,
    prediction: &HashSet<Coord>,
    plane: PlaneSize,
) -> (f64, f64, f64) {
    let (mut tp, mut fp, mut tn, mut fnn) = (0, 0, 0, 0);

    // Channel tag carried by the predictions; reused for every coordinate built below.
    let channel = prediction.iter().next().and_then(|coord| coord.c);

    // Convert the keypoints to a set for fast lookup.
    let mut gt_hash = HashSet::<Coord>::new();
    for keypoint in gt {
        gt_hash.insert(Coord::new(
            keypoint.pt().x as PixelAddress,
            keypoint.pt().y as PixelAddress,
            channel,
        ));
    }

    for y in 0..plane.h() {
        for x in 0..plane.w() {
            // Must carry the same channel as the set entries, otherwise no lookup can match
            // when the predictions have a channel set.
            let coord = Coord::new(x, y, channel);
            match (prediction.contains(&coord), gt_hash.contains(&coord)) {
                (true, true) => tp += 1,
                (true, false) => fp += 1,
                (false, true) => fnn += 1,
                (false, false) => tn += 1,
            }
        }
    }

    let precision = (tp as f64) / ((tp + fp) as f64);
    let recall = (tp as f64) / ((tp + fnn) as f64);
    let accuracy = ((tp + tn) as f64) / ((tp + tn + fp + fnn) as f64);
    (precision, recall, accuracy)
}
280
/// Selection and results of image quality metrics.
///
/// `calculate_quality_metrics` only computes a metric whose field is `Some` on input and
/// overwrites the contained value; fields that are `None` are left as `None`.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct QualityMetrics {
    /// Peak signal-to-noise ratio, as produced by `calculate_psnr`.
    pub psnr: Option<f64>,

    /// Mean squared error, as produced by `calculate_mse`.
    pub mse: Option<f64>,

    /// Structural similarity score, as produced by `calculate_ssim`.
    pub ssim: Option<f64>,
}
293
294impl Default for QualityMetrics {
295 fn default() -> Self {
296 Self {
297 psnr: Some(0.0),
298 mse: Some(0.0),
299 ssim: None,
300 }
301 }
302}
303
304pub fn calculate_quality_metrics(
307 original: &Array3<u8>,
308 reconstructed: &Array3<u8>,
309 mut results: QualityMetrics,
310) -> Result<QualityMetrics, Box<dyn Error>> {
311 if original.shape() != reconstructed.shape() {
312 return Err("Shapes of original and reconstructed images must match".into());
313 }
314
315 let mut mse = calculate_mse(original, reconstructed)?;
316 if mse == 0.0 {
317 mse = 0.0000001;
319 }
320 if results.mse.is_some() {
321 results.mse = Some(mse);
322 }
323 if results.psnr.is_some() {
324 results.psnr = Some(calculate_psnr(mse)?);
325 }
326 if results.ssim.is_some() {
327 results.ssim = Some(calculate_ssim(original, reconstructed)?);
328 }
329 Ok(results)
330}
331
332fn calculate_mse(original: &Array3<u8>, reconstructed: &Array3<u8>) -> Result<f64, Box<dyn Error>> {
334 if original.shape() != reconstructed.shape() {
335 return Err("Shapes of original and reconstructed images must match".into());
336 }
337
338 let mut error_sum = 0.0;
339 original
340 .iter()
341 .zip(reconstructed.iter())
342 .for_each(|(a, b)| {
343 error_sum += (*a as f64 - *b as f64).powi(2);
344 });
345 Ok(error_sum / (original.len() as f64))
346}
347
/// Peak signal-to-noise ratio for a peak value of 255:
/// `20 * log10(255) - 10 * log10(mse)`.
///
/// An `mse` of zero yields positive infinity.
///
/// # Errors
/// This function currently always returns `Ok`.
fn calculate_psnr(mse: f64) -> Result<f64, Box<dyn Error>> {
    let peak_term = 20.0 * 255.0_f64.log10();
    let error_term = 10.0 * mse.log10();
    Ok(peak_term - error_term)
}
352
/// Side length of the square sliding window used by `calculate_ssim`.
const DEFAULT_WINDOW_SIZE: usize = 8;
/// Factor used to derive the stabilising constant `C1`.
const K1: f64 = 0.01;
/// Factor used to derive the stabilising constant `C2`.
const K2: f64 = 0.03;
/// Largest value a `u8` pixel can take.
const L: u8 = u8::MAX;
/// `(K1 * L)^2`, added to the mean terms in `ssim_for_window`.
const C1: f64 = (K1 * L as f64) * (K1 * L as f64);
/// `(K2 * L)^2`, added to the variance/covariance terms in `ssim_for_window`.
const C2: f64 = (K2 * L as f64) * (K2 * L as f64);
360
361fn calculate_ssim(
363 original: &Array3<u8>,
364 reconstructed: &Array3<u8>,
365) -> Result<f64, Box<dyn Error>> {
366 let mut scores = vec![];
367 for channel in 0..original.shape()[2] {
368 let channel_view_original = original.index_axis(Axis(2), channel);
369 let channel_view_reconstructed = reconstructed.index_axis(Axis(2), channel);
370 let windows_original =
371 channel_view_original.windows((DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE));
372 let windows_reconstructed =
373 channel_view_reconstructed.windows((DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE));
374 let results = windows_original
375 .into_iter()
376 .zip(windows_reconstructed.into_iter())
377 .map(|(w1, w2)| ssim_for_window(w1, w2))
378 .collect::<Vec<_>>();
379 let score = results
380 .iter()
381 .map(|r| r * (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_SIZE) as f64)
382 .sum::<f64>()
383 / results
384 .iter()
385 .map(|_r| (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_SIZE) as f64)
386 .sum::<f64>();
387 scores.push(score)
388 }
389
390 let score = (scores.iter().sum::<f64>() / scores.len() as f64) * 100.0;
391
392 debug_assert!(score >= 0.0);
393 debug_assert!(score <= 100.0);
394
395 Ok(score)
396}
397
398fn ssim_for_window(source_window: ArrayView<u8, Ix2>, recon_window: ArrayView<u8, Ix2>) -> f64 {
400 let mean_x = mean(&source_window);
401 let mean_y = mean(&recon_window);
402 let variance_x = covariance(&source_window, mean_x, &source_window, mean_x);
403 let variance_y = covariance(&recon_window, mean_y, &recon_window, mean_y);
404 let covariance = covariance(&source_window, mean_x, &recon_window, mean_y);
405 let counter = (2. * mean_x * mean_y + C1) * (2. * covariance + C2);
406 let denominator = (mean_x.powi(2) + mean_y.powi(2) + C1) * (variance_x + variance_y + C2);
407 counter / denominator
408}
409
410fn covariance(
412 window_x: &ArrayView<u8, Ix2>,
413 mean_x: f64,
414 window_y: &ArrayView<u8, Ix2>,
415 mean_y: f64,
416) -> f64 {
417 window_x
418 .iter()
419 .zip(window_y.iter())
420 .map(|(x, y)| (*x as f64 - mean_x) * (*y as f64 - mean_y))
421 .sum::<f64>()
422}
423
424fn mean(window: &ArrayView<u8, Ix2>) -> f64 {
426 let sum = window.iter().map(|pixel| *pixel as f64).sum::<f64>();
427
428 sum / window.len() as f64
429}
430
/// Clamps `frame_val` to `0.0..=255.0` and, when clamping occurred, resets `last_val_ln`.
///
/// * `frame_val <= 0.0`: `frame_val` becomes `0.0` and `last_val_ln` becomes `ln(1 + 0)`.
/// * `frame_val > 255.0`: `frame_val` becomes `255.0` and `last_val_ln` becomes `ln(1 + 1)`.
/// * Otherwise (including `NaN`): both values are left untouched.
pub fn clamp_u8(frame_val: &mut f64, last_val_ln: &mut f64) {
    let (clamped, normalized) = if *frame_val <= 0.0 {
        (0.0, 0.0_f64)
    } else if *frame_val > 255.0 {
        (255.0, 1.0_f64)
    } else {
        return;
    };

    *frame_val = clamped;
    *last_val_ln = normalized.ln_1p();
}
441
/// Replaces an out-of-range `frame_val` with the mid value `128.0`.
///
/// When `frame_val` is below `0.0` or above `255.0`, it is set to `128.0` and `last_val_ln`
/// is set to `ln(1 + 128 / 255)`. In-range values (and `NaN`) leave both arguments untouched.
pub fn mid_clamp_u8(frame_val: &mut f64, last_val_ln: &mut f64) {
    let below = *frame_val < 0.0;
    let above = *frame_val > 255.0;
    if !(below || above) {
        return;
    }

    *frame_val = 128.0;
    *last_val_ln = (128.0_f64 / 255.0_f64).ln_1p();
}
449}