resamplescope/score.rs
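
//! Scoring of reconstructed filter curves against known reference filters,
//! plus a block-based SSIM comparison between grayscale images.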

use crate::analyze::FilterCurve;
use crate::filters::KnownFilter;

/// Scoring result for one reference filter compared against a reconstructed curve.
#[derive(Debug, Clone)]
pub struct FilterScore {
    pub filter: KnownFilter,
    pub correlation: f64,
    pub rms_error: f64,
    pub max_error: f64,
    pub detected_support: f64,
    pub expected_support: f64,
}

impl std::fmt::Display for FilterScore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}: r={:.4} rms={:.4} max={:.4} support={:.1}/{:.1}",
            self.filter,
            self.correlation,
            self.rms_error,
            self.max_error,
            self.detected_support,
            self.expected_support
        )
    }
}

/// Bin scatter data into uniform-width x intervals and average the y values in each bin.
fn bin_scatter(points: &[(f64, f64)], bin_width: f64) -> Vec<(f64, f64)> {
    if points.is_empty() {
        return Vec::new();
    }

    let min_x = points.iter().map(|p| p.0).fold(f64::INFINITY, f64::min);
    let max_x = points.iter().map(|p| p.0).fold(f64::NEG_INFINITY, f64::max);

    let n_bins = ((max_x - min_x) / bin_width).ceil() as usize + 1;
    let mut sums = vec![0.0_f64; n_bins];
    let mut counts = vec![0_u32; n_bins];

    for &(x, y) in points {
        let bin = ((x - min_x) / bin_width).floor() as usize;
        let bin = bin.min(n_bins - 1);
        sums[bin] += y;
        counts[bin] += 1;
    }

    sums.iter()
        .zip(counts.iter())
        .enumerate()
        .filter(|&(_, (_, c))| *c > 0)
        .map(|(i, (&s, &c))| {
            let center = min_x + (i as f64 + 0.5) * bin_width;
            (center, s / c as f64)
        })
        .collect()
}

/// Compute Pearson correlation coefficient between two equal-length vectors.
fn pearson(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len();
    if n < 2 {
        return 0.0;
    }

    let mean_a: f64 = a.iter().sum::<f64>() / n as f64;
    let mean_b: f64 = b.iter().sum::<f64>() / n as f64;

    let mut cov = 0.0;
    let mut var_a = 0.0;
    let mut var_b = 0.0;

    for i in 0..n {
        let da = a[i] - mean_a;
        let db = b[i] - mean_b;
        cov += da * db;
        var_a += da * da;
        var_b += db * db;
    }

    let denom = (var_a * var_b).sqrt();
    if denom < 1e-15 {
        return 0.0;
    }
    cov / denom
}

/// Detect the support radius from a curve (outermost offset where |weight| > threshold).
fn detect_support(points: &[(f64, f64)], threshold: f64) -> f64 {
    points
        .iter()
        .filter(|(_, w)| w.abs() > threshold)
        .map(|(x, _)| x.abs())
        .fold(0.0_f64, f64::max)
}

/// Score a reconstructed curve against a single reference filter.
pub fn score_against(curve: &FilterCurve, filter: KnownFilter) -> FilterScore {
    // For scatter data, bin first; for connected data, use directly.
    let comparison_points = if curve.is_scatter {
        bin_scatter(&curve.points, 0.02)
    } else {
        curve.points.clone()
    };

    if comparison_points.is_empty() {
        return FilterScore {
            filter,
            correlation: 0.0,
            rms_error: f64::INFINITY,
            max_error: f64::INFINITY,
            detected_support: 0.0,
            expected_support: filter.support(),
        };
    }

    // Evaluate reference filter at the same offsets.
    let actual: Vec<f64> = comparison_points.iter().map(|p| p.1).collect();
    let reference: Vec<f64> = comparison_points
        .iter()
        .map(|p| filter.evaluate(p.0))
        .collect();

    let correlation = pearson(&actual, &reference);

    // RMS error
    let n = actual.len() as f64;
    let rms_error = (actual
        .iter()
        .zip(reference.iter())
        .map(|(a, r)| (a - r).powi(2))
        .sum::<f64>()
        / n)
        .sqrt();

    // Max error
    let max_error = actual
        .iter()
        .zip(reference.iter())
        .map(|(a, r)| (a - r).abs())
        .fold(0.0_f64, f64::max);

    let detected_support = detect_support(&comparison_points, 0.005);

    FilterScore {
        filter,
        correlation,
        rms_error,
        max_error,
        detected_support,
        expected_support: filter.support(),
    }
}

/// Score a curve against all built-in filters, returning results sorted by correlation (best first).
pub fn score_against_all(curve: &FilterCurve) -> Vec<FilterScore> {
    let mut scores: Vec<FilterScore> = KnownFilter::all_named()
        .iter()
        .map(|&f| score_against(curve, f))
        .collect();

    scores.sort_by(|a, b| {
        b.correlation
            .partial_cmp(&a.correlation)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    scores
}

/// Compute SSIM between two equal-sized grayscale images.
/// Uses 8x8 block-based comparison with standard SSIM constants.
pub fn ssim(a: &[u8], b: &[u8], width: usize, height: usize) -> f64 {
    const K1: f64 = 0.01;
    const K2: f64 = 0.03;
    const L: f64 = 255.0;
    let c1 = (K1 * L) * (K1 * L);
    let c2 = (K2 * L) * (K2 * L);
    const BLOCK: usize = 8;

    assert_eq!(a.len(), width * height);
    assert_eq!(b.len(), width * height);

    if width < BLOCK || height < BLOCK {
        // Fall back to global SSIM for very small images.
        let mean_a: f64 = a.iter().map(|&v| v as f64).sum::<f64>() / a.len() as f64;
        let mean_b: f64 = b.iter().map(|&v| v as f64).sum::<f64>() / b.len() as f64;
        let var_a: f64 =
            a.iter().map(|&v| (v as f64 - mean_a).powi(2)).sum::<f64>() / a.len() as f64;
        let var_b: f64 =
            b.iter().map(|&v| (v as f64 - mean_b).powi(2)).sum::<f64>() / b.len() as f64;
        let cov: f64 = a
            .iter()
            .zip(b.iter())
            .map(|(&va, &vb)| (va as f64 - mean_a) * (vb as f64 - mean_b))
            .sum::<f64>()
            / a.len() as f64;

        let num = (2.0 * mean_a * mean_b + c1) * (2.0 * cov + c2);
        let den = (mean_a.powi(2) + mean_b.powi(2) + c1) * (var_a + var_b + c2);
        return num / den;
    }

    let blocks_x = width / BLOCK;
    let blocks_y = height / BLOCK;
    let mut total_ssim = 0.0;
    let mut count = 0;

    for by in 0..blocks_y {
        for bx in 0..blocks_x {
            let mut sum_a = 0.0_f64;
            let mut sum_b = 0.0_f64;
            let mut sum_aa = 0.0_f64;
            let mut sum_bb = 0.0_f64;
            let mut sum_ab = 0.0_f64;
            let n = (BLOCK * BLOCK) as f64;

            for dy in 0..BLOCK {
                for dx in 0..BLOCK {
                    let y = by * BLOCK + dy;
                    let x = bx * BLOCK + dx;
                    let va = a[y * width + x] as f64;
                    let vb = b[y * width + x] as f64;
                    sum_a += va;
                    sum_b += vb;
                    sum_aa += va * va;
                    sum_bb += vb * vb;
                    sum_ab += va * vb;
                }
            }

            let mean_a = sum_a / n;
            let mean_b = sum_b / n;
            let var_a = sum_aa / n - mean_a * mean_a;
            let var_b = sum_bb / n - mean_b * mean_b;
            let cov = sum_ab / n - mean_a * mean_b;

            let num = (2.0 * mean_a * mean_b + c1) * (2.0 * cov + c2);
            let den = (mean_a.powi(2) + mean_b.powi(2) + c1) * (var_a + var_b + c2);
            total_ssim += num / den;
            count += 1;
        }
    }

    if count == 0 {
        return 1.0;
    }
    total_ssim / count as f64
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pearson_perfect_correlation() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        assert!((pearson(&a, &b) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn pearson_negative_correlation() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![10.0, 8.0, 6.0, 4.0, 2.0];
        assert!((pearson(&a, &b) + 1.0).abs() < 1e-10);
    }
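
    // Added example (a sketch, not from the original suite): exercises bin_scatter
    // with hand-computed expectations — 0.02-wide bins anchored at the minimum x,
    // y values averaged per occupied bin, empty bins dropped.
    #[test]
    fn bin_scatter_averages_within_bins() {
        let points = vec![(0.0, 1.0), (0.01, 3.0), (0.05, 5.0)];
        let binned = bin_scatter(&points, 0.02);
        // (0.0, 1.0) and (0.01, 3.0) share the first bin; (0.05, 5.0) falls in the third.
        assert_eq!(binned.len(), 2);
        assert!((binned[0].0 - 0.01).abs() < 1e-12);
        assert!((binned[0].1 - 2.0).abs() < 1e-12);
        assert!((binned[1].0 - 0.05).abs() < 1e-12);
        assert!((binned[1].1 - 5.0).abs() < 1e-12);
    }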

    #[test]
    fn ssim_identical_images() {
        let img = vec![128u8; 64 * 64];
        let s = ssim(&img, &img, 64, 64);
        assert!((s - 1.0).abs() < 1e-6, "ssim of identical images: {s}");
    }

    #[test]
    fn ssim_different_images() {
        let a = vec![0u8; 64 * 64];
        let b = vec![255u8; 64 * 64];
        let s = ssim(&a, &b, 64, 64);
        assert!(s < 0.1, "ssim of opposite images should be low: {s}");
    }
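
    // Added example (a sketch, not from the original suite): exercises the
    // global-SSIM fallback for images smaller than one 8x8 block; identical
    // inputs should still score ~1.0.
    #[test]
    fn ssim_small_image_fallback() {
        let img = vec![200u8; 4 * 4];
        let s = ssim(&img, &img, 4, 4);
        assert!((s - 1.0).abs() < 1e-6, "ssim of identical 4x4 images: {s}");
    }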

    #[test]
    fn score_triangle_against_triangle() {
        // Generate a synthetic triangle filter curve
        let points: Vec<(f64, f64)> = (-100..=100)
            .map(|i| {
                let x = i as f64 / 100.0;
                (x, KnownFilter::Triangle.evaluate(x))
            })
            .collect();

        let curve = FilterCurve {
            points,
            area: 1.0,
            scale_factor: 37.0,
            is_scatter: false,
        };

        let score = score_against(&curve, KnownFilter::Triangle);
        assert!(
            score.correlation > 0.999,
            "correlation: {}",
            score.correlation
        );
        assert!(score.rms_error < 0.001, "rms: {}", score.rms_error);
    }
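
    // Added example (a sketch, not from the original suite): an empty scatter
    // curve should produce the degenerate score defined in score_against.
    #[test]
    fn score_against_empty_curve_is_degenerate() {
        let curve = FilterCurve {
            points: Vec::new(),
            area: 0.0,
            scale_factor: 1.0,
            is_scatter: true,
        };
        let score = score_against(&curve, KnownFilter::Triangle);
        assert_eq!(score.correlation, 0.0);
        assert!(score.rms_error.is_infinite());
        assert_eq!(score.detected_support, 0.0);
    }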

    #[test]
    fn score_all_sorts_correctly() {
        let points: Vec<(f64, f64)> = (-300..=300)
            .map(|i| {
                let x = i as f64 / 100.0;
                (x, KnownFilter::Lanczos3.evaluate(x))
            })
            .collect();

        let curve = FilterCurve {
            points,
            area: 1.0,
            scale_factor: 37.0,
            is_scatter: false,
        };

        let scores = score_against_all(&curve);
        assert_eq!(scores[0].filter, KnownFilter::Lanczos3);
        assert!(scores[0].correlation > 0.999);
    }
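
    // Added example (a sketch, not from the original suite): detect_support
    // should report the outermost |x| whose weight magnitude exceeds the
    // threshold, ignoring the near-zero tail.
    #[test]
    fn detect_support_ignores_tail_below_threshold() {
        let points = vec![(-2.0, 0.0), (-1.0, 0.2), (0.0, 1.0), (1.5, 0.003), (2.5, 0.0)];
        let support = detect_support(&points, 0.005);
        assert!((support - 1.0).abs() < 1e-12, "detected support: {support}");
    }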
}