fast_ssim2/
precompute.rs

//! Precomputed reference data for fast repeated SSIMULACRA2 comparisons.
//!
//! When many distorted images are compared against the same reference image,
//! the reference-side work can be done once up front and reused for every
//! comparison, for roughly a 2x speedup.
//!
6//! # Example
7//!
8//! ```
9//! use fast_ssim2::Ssimulacra2Reference;
10//! use yuvxyb::{Rgb, TransferCharacteristic, ColorPrimaries};
11//!
12//! // Load reference image
13//! let reference_rgb = vec![[1.0f32, 1.0, 1.0]; 512 * 512];
14//! let reference = Rgb::new(
15//!     reference_rgb,
16//!     512,
17//!     512,
18//!     TransferCharacteristic::SRGB,
19//!     ColorPrimaries::BT709,
20//! ).unwrap();
21//!
22//! // Precompute reference data once
23//! let precomputed = Ssimulacra2Reference::new(reference).unwrap();
24//!
25//! // Compare against a distorted image
26//! let distorted_rgb = vec![[0.9f32, 0.95, 1.05]; 512 * 512];
27//! let distorted = Rgb::new(
28//!     distorted_rgb,
29//!     512,
30//!     512,
31//!     TransferCharacteristic::SRGB,
32//!     ColorPrimaries::BT709,
33//! ).unwrap();
34//! let score = precomputed.compare(distorted).unwrap();
35//! println!("SSIMULACRA2 score: {}", score);
36//! ```
37
38use crate::blur::Blur;
39use crate::input::ToLinearRgb;
40use crate::{
41    downscale_by_2, edge_diff_map, image_multiply, linear_rgb_to_xyb_simd, make_positive_xyb,
42    ssim_map, xyb_to_planar, LinearRgb, Msssim, MsssimScale, SimdImpl, Ssimulacra2Error,
43    NUM_SCALES,
44};
45
/// Reference-side quantities cached for one level of the scale pyramid.
///
/// Every field holds one `Vec` per XYB channel, laid out in planar form.
#[derive(Debug, Clone)]
struct ScaleData {
    /// The reference image at this level, as positive planar XYB.
    img1_planar: [Vec<f32>; 3],
    /// `blur(img1)`: local mean of the reference.
    mu1: [Vec<f32>; 3],
    /// `blur(img1 * img1)`: blurred second moment of the reference
    /// (stored as-is, without subtracting `mu1 * mu1`).
    sigma1_sq: [Vec<f32>; 3],
}
56
57/// Precomputed SSIMULACRA2 reference data for fast repeated comparisons.
58///
59/// This struct stores precomputed data for the reference image at all scales,
60/// allowing you to quickly compare multiple distorted images against the same
61/// reference without recomputing the reference-side data each time.
62///
63/// For simulated annealing or other optimization where you compare many variations
64/// against the same source, this provides approximately 2x speedup.
65#[derive(Clone, Debug)]
66pub struct Ssimulacra2Reference {
67    scales: Vec<ScaleData>,
68    original_width: usize,
69    original_height: usize,
70}
71
72impl Ssimulacra2Reference {
73    /// Precompute reference data for the given source image.
74    ///
75    /// Supports:
76    /// - `imgref` types (with the `imgref` feature): `ImgRef<[u8; 3]>`, `ImgRef<[f32; 3]>`, etc.
77    /// - `yuvxyb` types: `Rgb`, `LinearRgb`
78    /// - Custom types implementing [`ToLinearRgb`]
79    ///
80    /// # Errors
81    /// - If the image is smaller than 8x8 pixels
82    pub fn new<T: ToLinearRgb>(source: T) -> Result<Self, Ssimulacra2Error> {
83        let mut img1: LinearRgb = source.to_linear_rgb().into();
84        if img1.width() < 8 || img1.height() < 8 {
85            return Err(Ssimulacra2Error::InvalidImageSize);
86        }
87
88        let original_width = img1.width();
89        let original_height = img1.height();
90        let mut width = original_width;
91        let mut height = original_height;
92
93        let mut mul = [
94            vec![0.0f32; width * height],
95            vec![0.0f32; width * height],
96            vec![0.0f32; width * height],
97        ];
98        let mut blur = Blur::new(width, height);
99        let mut scales = Vec::with_capacity(NUM_SCALES);
100
101        for scale in 0..NUM_SCALES {
102            if width < 8 || height < 8 {
103                break;
104            }
105
106            if scale > 0 {
107                img1 = downscale_by_2(&img1);
108                width = img1.width();
109                height = img1.height();
110            }
111
112            for c in &mut mul {
113                c.truncate(width * height);
114            }
115            blur.shrink_to(width, height);
116
117            let mut img1_xyb = linear_rgb_to_xyb_simd(img1.clone());
118            make_positive_xyb(&mut img1_xyb);
119
120            let img1_planar = xyb_to_planar(&img1_xyb);
121
122            // Precompute mu1 = blur(img1)
123            let mu1 = blur.blur(&img1_planar);
124
125            // Precompute sigma1_sq = blur(img1 * img1)
126            image_multiply(&img1_planar, &img1_planar, &mut mul, SimdImpl::default());
127            let sigma1_sq = blur.blur(&mul);
128
129            scales.push(ScaleData {
130                img1_planar,
131                mu1,
132                sigma1_sq,
133            });
134        }
135
136        Ok(Self {
137            scales,
138            original_width,
139            original_height,
140        })
141    }
142
143    /// Compare a distorted image against the precomputed reference.
144    ///
145    /// This is approximately 2x faster than calling `compute_ssimulacra2`
146    /// because it only needs to process the distorted image and compute cross-terms.
147    ///
148    /// # Errors
149    /// - If the distorted image dimensions don't match the reference
150    pub fn compare<T: ToLinearRgb>(&self, distorted: T) -> Result<f64, Ssimulacra2Error> {
151        let mut img2: LinearRgb = distorted.to_linear_rgb().into();
152        if img2.width() != self.original_width || img2.height() != self.original_height {
153            return Err(Ssimulacra2Error::NonMatchingImageDimensions);
154        }
155
156        let mut width = img2.width();
157        let mut height = img2.height();
158
159        let mut mul = [
160            vec![0.0f32; width * height],
161            vec![0.0f32; width * height],
162            vec![0.0f32; width * height],
163        ];
164        let mut blur = Blur::new(width, height);
165        let mut msssim = Msssim::default();
166
167        for (scale_idx, scale_data) in self.scales.iter().enumerate() {
168            if width < 8 || height < 8 {
169                break;
170            }
171
172            if scale_idx > 0 {
173                img2 = downscale_by_2(&img2);
174                width = img2.width();
175                height = img2.height();
176            }
177
178            for c in &mut mul {
179                c.truncate(width * height);
180            }
181            blur.shrink_to(width, height);
182
183            let mut img2_xyb = linear_rgb_to_xyb_simd(img2.clone());
184            make_positive_xyb(&mut img2_xyb);
185
186            let img2_planar = xyb_to_planar(&img2_xyb);
187
188            // Compute mu2 = blur(img2)
189            let mu2 = blur.blur(&img2_planar);
190
191            // Compute sigma2_sq = blur(img2 * img2)
192            image_multiply(&img2_planar, &img2_planar, &mut mul, SimdImpl::default());
193            let sigma2_sq = blur.blur(&mul);
194
195            // Compute sigma12 = blur(img1 * img2) - cross-term
196            image_multiply(
197                &scale_data.img1_planar,
198                &img2_planar,
199                &mut mul,
200                SimdImpl::default(),
201            );
202            let sigma12 = blur.blur(&mul);
203
204            // Use precomputed mu1 and sigma1_sq from reference
205            let avg_ssim = ssim_map(
206                width,
207                height,
208                &scale_data.mu1,
209                &mu2,
210                &scale_data.sigma1_sq,
211                &sigma2_sq,
212                &sigma12,
213                SimdImpl::default(),
214            );
215
216            let avg_edgediff = edge_diff_map(
217                width,
218                height,
219                &scale_data.img1_planar,
220                &scale_data.mu1,
221                &img2_planar,
222                &mu2,
223                SimdImpl::default(),
224            );
225
226            msssim.scales.push(MsssimScale {
227                avg_ssim,
228                avg_edgediff,
229            });
230        }
231
232        Ok(msssim.score())
233    }
234
235    /// Get the width of the original reference image.
236    #[must_use]
237    pub fn width(&self) -> usize {
238        self.original_width
239    }
240
241    /// Get the height of the original reference image.
242    #[must_use]
243    pub fn height(&self) -> usize {
244        self.original_height
245    }
246
247    /// Get the number of scales that were precomputed.
248    #[must_use]
249    pub fn num_scales(&self) -> usize {
250        self.scales.len()
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compute_frame_ssimulacra2;
    use yuvxyb::{ColorPrimaries, Rgb, TransferCharacteristic};

    /// Wraps sRGB-encoded BT.709 pixels in an `Rgb` image.
    fn srgb(data: Vec<[f32; 3]>, width: usize, height: usize) -> Rgb {
        Rgb::new(
            data,
            width,
            height,
            TransferCharacteristic::SRGB,
            ColorPrimaries::BT709,
        )
        .unwrap()
    }

    /// Red ramps left-to-right, green ramps top-to-bottom, blue is constant.
    fn gradient(width: usize, height: usize) -> Vec<[f32; 3]> {
        (0..width * height)
            .map(|i| {
                let x = (i % width) as f32 / width as f32;
                let y = (i / width) as f32 / height as f32;
                [x, y, 0.5]
            })
            .collect()
    }

    /// Applies a mild per-channel gain so the result differs from the input.
    fn distort(pixels: &[[f32; 3]]) -> Vec<[f32; 3]> {
        pixels
            .iter()
            .map(|&[r, g, b]| [r * 0.9, g * 0.95, b * 1.05])
            .collect()
    }

    #[test]
    fn test_precompute_matches_full_compute() {
        let width = 64;
        let height = 64;
        let source_data = gradient(width, height);
        let source = srgb(source_data.clone(), width, height);
        let distorted = srgb(distort(&source_data), width, height);

        // Compute using the full method
        let full_score = compute_frame_ssimulacra2(source.clone(), distorted.clone()).unwrap();

        // Compute using the precomputed reference
        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        let precomputed_score = precomputed.compare(distorted).unwrap();

        // Both paths use the same SIMD XYB pipeline, so the scores must agree
        // to within floating-point noise.
        assert!(
            (full_score - precomputed_score).abs() < 1e-6,
            "Scores don't match: full={}, precomputed={}",
            full_score,
            precomputed_score
        );
    }

    #[test]
    fn test_precompute_reusable_across_compares() {
        let width = 64;
        let height = 48;
        let source_data = gradient(width, height);
        let precomputed = Ssimulacra2Reference::new(srgb(source_data.clone(), width, height)).unwrap();
        let distorted = srgb(distort(&source_data), width, height);

        // Reusing the same reference must not leak state between calls.
        let first = precomputed.compare(distorted.clone()).unwrap();
        let second = precomputed.compare(distorted).unwrap();
        assert_eq!(first.to_bits(), second.to_bits());

        // An exact copy of the reference can never score worse than a distorted one.
        let identical = precomputed.compare(srgb(source_data, width, height)).unwrap();
        assert!(
            identical >= first,
            "identical={} scored below distorted={}",
            identical,
            first
        );
    }

    #[test]
    fn test_precompute_rejects_tiny_images() {
        for &(w, h) in &[(4usize, 4usize), (7, 64), (64, 7)] {
            let source = srgb(vec![[0.5f32, 0.5, 0.5]; w * h], w, h);
            assert!(
                matches!(
                    Ssimulacra2Reference::new(source),
                    Err(Ssimulacra2Error::InvalidImageSize)
                ),
                "{}x{} should be rejected",
                w,
                h
            );
        }
    }

    #[test]
    fn test_precompute_dimension_mismatch() {
        let source = srgb(vec![[0.5, 0.5, 0.5]; 64 * 64], 64, 64);
        let distorted = srgb(vec![[0.4, 0.4, 0.4]; 32 * 32], 32, 32); // Wrong size

        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        let result = precomputed.compare(distorted);

        assert!(matches!(
            result,
            Err(Ssimulacra2Error::NonMatchingImageDimensions)
        ));
    }

    #[test]
    fn test_precompute_metadata() {
        let source = srgb(vec![[0.5, 0.5, 0.5]; 128 * 96], 128, 96);

        let precomputed = Ssimulacra2Reference::new(source).unwrap();

        assert_eq!(precomputed.width(), 128);
        assert_eq!(precomputed.height(), 96);
        assert!(precomputed.num_scales() > 0);
        assert!(precomputed.num_scales() <= NUM_SCALES);
    }
}