1use crate::analyze::FilterCurve;
2use crate::filters::KnownFilter;
3
/// Result of comparing a measured filter curve against one reference filter.
#[derive(Debug, Clone)]
pub struct FilterScore {
    /// Reference filter the curve was compared against.
    pub filter: KnownFilter,
    /// Pearson correlation between the measured weights and the reference
    /// filter evaluated at the same x positions.
    pub correlation: f64,
    /// Root-mean-square difference between measured and reference weights.
    pub rms_error: f64,
    /// Largest absolute difference between measured and reference weights.
    pub max_error: f64,
    /// Largest `|x|` at which the measured weight magnitude exceeds the
    /// detection threshold used by `score_against`.
    pub detected_support: f64,
    /// Support reported by the reference filter itself (`filter.support()`).
    pub expected_support: f64,
}
14
15impl std::fmt::Display for FilterScore {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 write!(
18 f,
19 "{}: r={:.4} rms={:.4} max={:.4} support={:.1}/{:.1}",
20 self.filter,
21 self.correlation,
22 self.rms_error,
23 self.max_error,
24 self.detected_support,
25 self.expected_support
26 )
27 }
28}
29
/// Averages scattered `(x, y)` samples into fixed-width bins along x.
///
/// Bins start at the smallest x found in `points`. A sample lands in bin
/// `floor((x - min_x) / bin_width)`, clamped to the last bin. The result has
/// one `(bin_center, mean_y)` entry per non-empty bin, in increasing bin
/// order. An empty input yields an empty vector.
fn bin_scatter(points: &[(f64, f64)], bin_width: f64) -> Vec<(f64, f64)> {
    if points.is_empty() {
        return Vec::new();
    }

    // Find the x extent of the samples.
    let mut lo = f64::INFINITY;
    let mut hi = f64::NEG_INFINITY;
    for &(x, _) in points {
        lo = lo.min(x);
        hi = hi.max(x);
    }

    // One extra bin so a sample sitting exactly on the upper edge has a slot.
    let bin_count = ((hi - lo) / bin_width).ceil() as usize + 1;
    let last_bin = bin_count - 1;

    let mut totals = vec![0.0_f64; bin_count];
    let mut hits = vec![0_u32; bin_count];
    for &(x, y) in points {
        let slot = (((x - lo) / bin_width).floor() as usize).min(last_bin);
        totals[slot] += y;
        hits[slot] += 1;
    }

    // Emit the mean of every bin that received at least one sample.
    let mut binned = Vec::new();
    for (idx, (&total, &n)) in totals.iter().zip(&hits).enumerate() {
        if n > 0 {
            let center = lo + (idx as f64 + 0.5) * bin_width;
            binned.push((center, total / f64::from(n)));
        }
    }
    binned
}
60
/// Pearson correlation coefficient of the paired samples `(a[i], b[i])`.
///
/// Only the first `min(a.len(), b.len())` elements are paired, so slices of
/// different lengths neither panic nor skew the mean of the longer slice.
///
/// Returns `0.0` when fewer than two pairs are available, or when the
/// product of the two spreads is so small (`< 1e-15`) that the ratio would be
/// dominated by rounding noise (for example, a constant input). NaN samples
/// propagate to a NaN result.
fn pearson(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len().min(b.len());
    if n < 2 {
        return 0.0;
    }
    // Restrict both inputs to the paired prefix so the means and the
    // deviations are computed over exactly the same samples.
    let (a, b) = (&a[..n], &b[..n]);

    let mean_a = a.iter().sum::<f64>() / n as f64;
    let mean_b = b.iter().sum::<f64>() / n as f64;

    // Unnormalised sums: the common 1/n factor cancels in the final ratio.
    let mut cross = 0.0;
    let mut spread_a = 0.0;
    let mut spread_b = 0.0;
    for (&x, &y) in a.iter().zip(b) {
        let da = x - mean_a;
        let db = y - mean_b;
        cross += da * db;
        spread_a += da * da;
        spread_b += db * db;
    }

    let denom = (spread_a * spread_b).sqrt();
    if denom < 1e-15 {
        return 0.0;
    }
    cross / denom
}
89
/// Returns the largest `|x|` among points whose weight magnitude is strictly
/// greater than `threshold`, or `0.0` when no point qualifies.
fn detect_support(points: &[(f64, f64)], threshold: f64) -> f64 {
    let mut widest = 0.0_f64;
    for &(offset, weight) in points {
        if weight.abs() > threshold {
            widest = widest.max(offset.abs());
        }
    }
    widest
}
98
99pub fn score_against(curve: &FilterCurve, filter: KnownFilter) -> FilterScore {
101 let comparison_points = if curve.is_scatter {
103 bin_scatter(&curve.points, 0.02)
104 } else {
105 curve.points.clone()
106 };
107
108 if comparison_points.is_empty() {
109 return FilterScore {
110 filter,
111 correlation: 0.0,
112 rms_error: f64::INFINITY,
113 max_error: f64::INFINITY,
114 detected_support: 0.0,
115 expected_support: filter.support(),
116 };
117 }
118
119 let actual: Vec<f64> = comparison_points.iter().map(|p| p.1).collect();
121 let reference: Vec<f64> = comparison_points
122 .iter()
123 .map(|p| filter.evaluate(p.0))
124 .collect();
125
126 let correlation = pearson(&actual, &reference);
127
128 let n = actual.len() as f64;
130 let rms_error = (actual
131 .iter()
132 .zip(reference.iter())
133 .map(|(a, r)| (a - r).powi(2))
134 .sum::<f64>()
135 / n)
136 .sqrt();
137
138 let max_error = actual
140 .iter()
141 .zip(reference.iter())
142 .map(|(a, r)| (a - r).abs())
143 .fold(0.0_f64, f64::max);
144
145 let detected_support = detect_support(&comparison_points, 0.005);
146
147 FilterScore {
148 filter,
149 correlation,
150 rms_error,
151 max_error,
152 detected_support,
153 expected_support: filter.support(),
154 }
155}
156
157pub fn score_against_all(curve: &FilterCurve) -> Vec<FilterScore> {
159 let mut scores: Vec<FilterScore> = KnownFilter::all_named()
160 .iter()
161 .map(|&f| score_against(curve, f))
162 .collect();
163
164 scores.sort_by(|a, b| {
165 b.correlation
166 .partial_cmp(&a.correlation)
167 .unwrap_or(std::cmp::Ordering::Equal)
168 });
169
170 scores
171}
172
/// Mean structural similarity (SSIM) between two 8-bit single-channel images.
///
/// `a` and `b` are row-major buffers of `width * height` samples. The image
/// is tiled into non-overlapping 8×8 blocks; a score is computed for each
/// block from its means, variances and covariance (constants `K1 = 0.01`,
/// `K2 = 0.03`, dynamic range 255) and the block scores are averaged. Rows
/// and columns beyond the last full block are ignored.
///
/// Images with a side shorter than 8 are scored as one window covering the
/// whole image. Empty images (`width * height == 0`) are reported as
/// identical (`1.0`) rather than dividing by zero.
///
/// # Panics
///
/// Panics if `a.len()` or `b.len()` differs from `width * height`.
pub fn ssim(a: &[u8], b: &[u8], width: usize, height: usize) -> f64 {
    const K1: f64 = 0.01;
    const K2: f64 = 0.03;
    const DYNAMIC_RANGE: f64 = 255.0;
    const BLOCK: usize = 8;

    assert_eq!(a.len(), width * height);
    assert_eq!(b.len(), width * height);

    // Nothing to compare: two empty images are trivially the same. Without
    // this guard the whole-image branch below would compute 0.0 / 0.0.
    if a.is_empty() {
        return 1.0;
    }

    let c1 = (K1 * DYNAMIC_RANGE) * (K1 * DYNAMIC_RANGE);
    let c2 = (K2 * DYNAMIC_RANGE) * (K2 * DYNAMIC_RANGE);

    // SSIM index of one window, given its first- and second-order statistics.
    let window_index = |mean_a: f64, mean_b: f64, var_a: f64, var_b: f64, cov: f64| -> f64 {
        let num = (2.0 * mean_a * mean_b + c1) * (2.0 * cov + c2);
        let den = (mean_a.powi(2) + mean_b.powi(2) + c1) * (var_a + var_b + c2);
        num / den
    };

    if width < BLOCK || height < BLOCK {
        // Too small for even one block: use the whole image as the window.
        let n = a.len() as f64;
        let mean_a = a.iter().map(|&v| f64::from(v)).sum::<f64>() / n;
        let mean_b = b.iter().map(|&v| f64::from(v)).sum::<f64>() / n;
        let var_a = a
            .iter()
            .map(|&v| (f64::from(v) - mean_a).powi(2))
            .sum::<f64>()
            / n;
        let var_b = b
            .iter()
            .map(|&v| (f64::from(v) - mean_b).powi(2))
            .sum::<f64>()
            / n;
        let cov = a
            .iter()
            .zip(b.iter())
            .map(|(&va, &vb)| (f64::from(va) - mean_a) * (f64::from(vb) - mean_b))
            .sum::<f64>()
            / n;
        return window_index(mean_a, mean_b, var_a, var_b, cov);
    }

    let blocks_x = width / BLOCK;
    let blocks_y = height / BLOCK;
    let samples_per_block = (BLOCK * BLOCK) as f64;
    let mut total_ssim = 0.0;

    for by in 0..blocks_y {
        for bx in 0..blocks_x {
            let mut sum_a = 0.0_f64;
            let mut sum_b = 0.0_f64;
            let mut sum_aa = 0.0_f64;
            let mut sum_bb = 0.0_f64;
            let mut sum_ab = 0.0_f64;

            for dy in 0..BLOCK {
                // Slice out this block's row once so the inner loop runs over
                // plain iterators instead of per-pixel index arithmetic.
                let row_start = (by * BLOCK + dy) * width + bx * BLOCK;
                let row_a = &a[row_start..row_start + BLOCK];
                let row_b = &b[row_start..row_start + BLOCK];
                for (&pa, &pb) in row_a.iter().zip(row_b) {
                    let va = f64::from(pa);
                    let vb = f64::from(pb);
                    sum_a += va;
                    sum_b += vb;
                    sum_aa += va * va;
                    sum_bb += vb * vb;
                    sum_ab += va * vb;
                }
            }

            let mean_a = sum_a / samples_per_block;
            let mean_b = sum_b / samples_per_block;
            let var_a = sum_aa / samples_per_block - mean_a * mean_a;
            let var_b = sum_bb / samples_per_block - mean_b * mean_b;
            let cov = sum_ab / samples_per_block - mean_a * mean_b;

            total_ssim += window_index(mean_a, mean_b, var_a, var_b, cov);
        }
    }

    // Both sides are at least BLOCK here, so there is at least one block.
    total_ssim / (blocks_x * blocks_y) as f64
}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Side length of the square images used by the SSIM tests.
    const SIDE: usize = 64;

    /// Samples `filter` at `x = i / 100` for every integer `i` in
    /// `-half_steps..=half_steps` and wraps the samples in a non-scatter curve.
    fn sampled_curve(filter: KnownFilter, half_steps: i32) -> FilterCurve {
        let points = (-half_steps..=half_steps)
            .map(|step| {
                let x = f64::from(step) / 100.0;
                (x, filter.evaluate(x))
            })
            .collect();

        FilterCurve {
            points,
            area: 1.0,
            scale_factor: 37.0,
            is_scatter: false,
        }
    }

    #[test]
    fn pearson_perfect_correlation() {
        let rising = [1.0, 2.0, 3.0, 4.0, 5.0];
        let doubled = [2.0, 4.0, 6.0, 8.0, 10.0];
        let r = pearson(&rising, &doubled);
        assert!((r - 1.0).abs() < 1e-10);
    }

    #[test]
    fn pearson_negative_correlation() {
        let rising = [1.0, 2.0, 3.0, 4.0, 5.0];
        let falling = [10.0, 8.0, 6.0, 4.0, 2.0];
        let r = pearson(&rising, &falling);
        assert!((r + 1.0).abs() < 1e-10);
    }

    #[test]
    fn ssim_identical_images() {
        let gray = vec![128u8; SIDE * SIDE];
        let s = ssim(&gray, &gray, SIDE, SIDE);
        assert!((s - 1.0).abs() < 1e-6, "ssim of identical images: {s}");
    }

    #[test]
    fn ssim_different_images() {
        let black = vec![0u8; SIDE * SIDE];
        let white = vec![255u8; SIDE * SIDE];
        let s = ssim(&black, &white, SIDE, SIDE);
        assert!(s < 0.1, "ssim of opposite images should be low: {s}");
    }

    #[test]
    fn score_triangle_against_triangle() {
        // A curve sampled from the triangle filter must match that filter.
        let curve = sampled_curve(KnownFilter::Triangle, 100);

        let score = score_against(&curve, KnownFilter::Triangle);
        assert!(
            score.correlation > 0.999,
            "correlation: {}",
            score.correlation
        );
        assert!(score.rms_error < 0.001, "rms: {}", score.rms_error);
    }

    #[test]
    fn score_all_sorts_correctly() {
        // The best-ranked filter must be the one the curve was sampled from.
        let curve = sampled_curve(KnownFilter::Lanczos3, 300);

        let ranked = score_against_all(&curve);
        let best = &ranked[0];
        assert_eq!(best.filter, KnownFilter::Lanczos3);
        assert!(best.correlation > 0.999);
    }
}