Skip to main content

fast_ssim2/
precompute.rs

1//! Precomputed reference data for fast repeated SSIMULACRA2 comparisons.
2//!
3//! When comparing multiple distorted images against the same reference image,
4//! you can precompute the reference data once and reuse it for ~2x speedup.
5//!
6//! # Example
7//!
8//! ```
9//! use fast_ssim2::Ssimulacra2Reference;
10//! use yuvxyb::{Rgb, TransferCharacteristic, ColorPrimaries};
11//!
12//! // Load reference image
13//! use std::num::NonZeroUsize;
14//! let reference_rgb = vec![[1.0f32, 1.0, 1.0]; 512 * 512];
15//! let reference = Rgb::new(
16//!     reference_rgb,
17//!     NonZeroUsize::new(512).unwrap(),
18//!     NonZeroUsize::new(512).unwrap(),
19//!     TransferCharacteristic::SRGB,
20//!     ColorPrimaries::BT709,
21//! ).unwrap();
22//!
23//! // Precompute reference data once
24//! let precomputed = Ssimulacra2Reference::new(reference).unwrap();
25//!
26//! // Compare against a distorted image
27//! let distorted_rgb = vec![[0.9f32, 0.95, 1.05]; 512 * 512];
28//! let distorted = Rgb::new(
29//!     distorted_rgb,
30//!     NonZeroUsize::new(512).unwrap(),
31//!     NonZeroUsize::new(512).unwrap(),
32//!     TransferCharacteristic::SRGB,
33//!     ColorPrimaries::BT709,
34//! ).unwrap();
35//! let score = precomputed.compare(distorted).unwrap();
36//! println!("SSIMULACRA2 score: {}", score);
37//! ```
38
39use crate::blur::Blur;
40use crate::input::ToLinearRgb;
41use crate::{
42    LinearRgb, Msssim, MsssimScale, NUM_SCALES, SimdImpl, Ssimulacra2Error, downscale_by_2,
43    edge_diff_map, image_multiply, linear_rgb_to_xyb_simd, make_positive_xyb, ssim_map,
44    xyb_to_planar, xyb_to_planar_into,
45};
46
47/// Reusable scratch buffers for [`Ssimulacra2Reference::compare_with`].
48///
49/// `Ssimulacra2Reference::compare` allocates roughly 13 image-sized
50/// `Vec<f32>` planes (`mul`, `mu2`, `sigma2_sq`, `sigma12`, `img2_planar`)
51/// plus the [`Blur`] working memory on every call. When you compare many
52/// distorted images against the same reference (encoder rate-distortion
53/// search, simulated annealing, picker training), reuse a `CompareContext`
54/// to amortise those allocations across all calls. Buffers grow only on
55/// the first call and are reused thereafter; later calls do no `Vec` heap
56/// allocation.
57///
58/// Allocated for a specific reference dimension via
59/// [`Ssimulacra2Reference::compare_context`]; passed to
60/// [`Ssimulacra2Reference::compare_with`].
61///
62/// `Send` but not `Sync` — give each worker thread its own context.
63pub struct CompareContext {
64    width: usize,
65    height: usize,
66    blur: Blur,
67    mul: [Vec<f32>; 3],
68    mu2: [Vec<f32>; 3],
69    sigma2_sq: [Vec<f32>; 3],
70    sigma12: [Vec<f32>; 3],
71    img2_planar: [Vec<f32>; 3],
72}
73
74impl CompareContext {
75    fn new(width: usize, height: usize) -> Self {
76        let alloc_plane = || vec![0.0f32; width * height];
77        let alloc_3planes = || [alloc_plane(), alloc_plane(), alloc_plane()];
78        Self {
79            width,
80            height,
81            blur: Blur::new(width, height),
82            mul: alloc_3planes(),
83            mu2: alloc_3planes(),
84            sigma2_sq: alloc_3planes(),
85            sigma12: alloc_3planes(),
86            img2_planar: alloc_3planes(),
87        }
88    }
89
90    /// Restore the working buffers to the original reference dimensions.
91    /// Called at the start of each comparison so previous calls' truncations
92    /// don't leave the buffers under-sized for the next call's scale 0.
93    /// Cheap: the underlying `Vec` capacity is retained from construction,
94    /// so this only updates length (no allocation) plus fills the regrown
95    /// portion with zero.
96    fn reset_to_full(&mut self) {
97        let size = self.width * self.height;
98        for buf in [
99            &mut self.mul,
100            &mut self.mu2,
101            &mut self.sigma2_sq,
102            &mut self.sigma12,
103            &mut self.img2_planar,
104        ] {
105            for c in buf.iter_mut() {
106                c.resize(size, 0.0);
107            }
108        }
109        self.blur.shrink_to(self.width, self.height);
110    }
111
112    /// Truncate the working buffers to fit `width * height` of the current scale.
113    /// `Vec::truncate` does not free memory, so subsequent scales just shrink
114    /// and we never reallocate while iterating the pyramid.
115    fn shrink_to(&mut self, width: usize, height: usize) {
116        let size = width * height;
117        for buf in [
118            &mut self.mul,
119            &mut self.mu2,
120            &mut self.sigma2_sq,
121            &mut self.sigma12,
122            &mut self.img2_planar,
123        ] {
124            for c in buf.iter_mut() {
125                c.truncate(size);
126            }
127        }
128        self.blur.shrink_to(width, height);
129    }
130}
131
/// Reference-side data cached for one pyramid scale, so a comparison only
/// has to process the distorted image and the cross terms.
#[derive(Clone, Debug)]
struct ScaleData {
    /// Reference image at this scale in planar XYB form (after
    /// `make_positive_xyb`), one `Vec` per channel.
    img1_planar: [Vec<f32>; 3],
    /// `blur(img1)`: the blurred (local-mean) reference, per channel.
    mu1: [Vec<f32>; 3],
    /// `blur(img1 * img1)`: the blurred square of the reference, i.e. the
    /// reference-side term of the SSIM variance.
    sigma1_sq: [Vec<f32>; 3],
}
142
143/// Precomputed SSIMULACRA2 reference data for fast repeated comparisons.
144///
145/// This struct stores precomputed data for the reference image at all scales,
146/// allowing you to quickly compare multiple distorted images against the same
147/// reference without recomputing the reference-side data each time.
148///
149/// For simulated annealing or other optimization where you compare many variations
150/// against the same source, this provides approximately 2x speedup.
151#[derive(Clone, Debug)]
152pub struct Ssimulacra2Reference {
153    scales: Vec<ScaleData>,
154    /// Dimensions of the source image as supplied by the caller
155    /// (before any sub-8px reflect-padding).
156    original_width: usize,
157    original_height: usize,
158    /// Working dimensions after sub-8px reflect-padding — equal to the
159    /// original dimensions whenever the source is at least 8x8. The
160    /// per-scale planes are sized from these.
161    padded_width: usize,
162    padded_height: usize,
163}
164
/// Read-only view of a single scale of the precomputed reference.
///
/// Exposes the three planar buffers needed by the strip-aware
/// comparison path (`compare_strip`) so the walker can use the cached
/// data directly without re-running the ref-side conversion.
///
/// The view holds only shared borrows and two sizes, so it is cheap to
/// copy. `Debug` is derived for diagnostics, like the other types in this
/// module.
#[doc(hidden)]
#[derive(Clone, Copy, Debug)]
pub struct ScalePlanesView<'a> {
    /// Reference XYB-planar image at this scale (one `Vec` per channel).
    pub img1_planar: &'a [Vec<f32>; 3],
    /// Reference `blur(img1)` at this scale.
    pub mu1: &'a [Vec<f32>; 3],
    /// Reference `blur(img1 * img1)` at this scale.
    pub sigma1_sq: &'a [Vec<f32>; 3],
    /// Width of the reference image at this scale, in pixels.
    pub width: usize,
    /// Height of the reference image at this scale, in pixels.
    pub height: usize,
}
183
184impl Ssimulacra2Reference {
185    /// Borrow the precomputed data for scale `scale`, or `None` if
186    /// `scale >= self.num_scales()`.
187    ///
188    /// `#[doc(hidden)]` because the exact representation is an
189    /// implementation detail shared between the precompute and strip
190    /// modules; do not depend on the type signature from outside the
191    /// crate.
192    #[doc(hidden)]
193    #[must_use]
194    pub fn scale_planes(&self, scale: usize) -> Option<ScalePlanesView<'_>> {
195        let data = self.scales.get(scale)?;
196        // Scale-s dimensions follow the same `div_ceil(2)` rule as
197        // `downscale_by_2`. We recompute them here rather than store
198        // per-scale so this view stays zero-cost when not used. The walk
199        // starts from the padded dimensions — that is what the per-scale
200        // planes are sized from (== original dims for sources >= 8x8).
201        let mut w = self.padded_width;
202        let mut h = self.padded_height;
203        for _ in 0..scale {
204            w = w.div_ceil(2);
205            h = h.div_ceil(2);
206        }
207        Some(ScalePlanesView {
208            img1_planar: &data.img1_planar,
209            mu1: &data.mu1,
210            sigma1_sq: &data.sigma1_sq,
211            width: w,
212            height: h,
213        })
214    }
215
216    /// Precompute reference data for the given source image.
217    ///
218    /// Supports:
219    /// - `imgref` types (with the `imgref` feature): `ImgRef<[u8; 3]>`, `ImgRef<[f32; 3]>`, etc.
220    /// - `yuvxyb` types: `Rgb`, `LinearRgb`
221    /// - Custom types implementing [`ToLinearRgb`]
222    ///
223    /// Sub-8px sources are reflect(mirror)-padded up to the 8px pyramid
224    /// floor, matching [`crate::compute_ssimulacra2`]; [`Self::compare`]
225    /// then expects distorted images at the *original* (pre-padding)
226    /// dimensions and pads them the same way.
227    ///
228    /// # Errors
229    /// - If the image (after padding) exceeds [`crate::MAX_IMAGE_PIXELS`] pixels
230    pub fn new<T: ToLinearRgb>(source: T) -> Result<Self, Ssimulacra2Error> {
231        let source_img = source.into_linear_rgb();
232        let original_width = source_img.width();
233        let original_height = source_img.height();
234        // Reflect-pad sub-8px sources up to the pyramid floor, exactly as
235        // the one-shot `compute_ssimulacra2` path does. NO-OP at >= 8px.
236        let mut img1: LinearRgb = crate::reflect_pad_linear(source_img, 8).into();
237        if img1.width().get() < 8 || img1.height().get() < 8 {
238            return Err(Ssimulacra2Error::InvalidImageSize);
239        }
240
241        // Cap pixel count to prevent unbounded working-buffer allocation.
242        let pixels = img1
243            .width()
244            .get()
245            .checked_mul(img1.height().get())
246            .ok_or(Ssimulacra2Error::ImageTooLarge { actual: usize::MAX })?;
247        if pixels > crate::MAX_IMAGE_PIXELS {
248            return Err(Ssimulacra2Error::ImageTooLarge { actual: pixels });
249        }
250
251        let padded_width = img1.width().get();
252        let padded_height = img1.height().get();
253        let mut width = padded_width;
254        let mut height = padded_height;
255
256        let mut mul = [
257            vec![0.0f32; width * height],
258            vec![0.0f32; width * height],
259            vec![0.0f32; width * height],
260        ];
261        let mut blur = Blur::new(width, height);
262        let mut scales = Vec::with_capacity(NUM_SCALES);
263
264        for scale in 0..NUM_SCALES {
265            if width < 8 || height < 8 {
266                break;
267            }
268
269            if scale > 0 {
270                img1 = downscale_by_2(&img1);
271                width = img1.width().get();
272                height = img1.height().get();
273            }
274
275            for c in &mut mul {
276                c.truncate(width * height);
277            }
278            blur.shrink_to(width, height);
279
280            let mut img1_xyb = linear_rgb_to_xyb_simd(img1.clone());
281            make_positive_xyb(&mut img1_xyb);
282
283            let img1_planar = xyb_to_planar(&img1_xyb);
284
285            // Precompute mu1 = blur(img1)
286            let mu1 = blur.blur(&img1_planar);
287
288            // Precompute sigma1_sq = blur(img1 * img1)
289            image_multiply(&img1_planar, &img1_planar, &mut mul, SimdImpl::default());
290            let sigma1_sq = blur.blur(&mul);
291
292            scales.push(ScaleData {
293                img1_planar,
294                mu1,
295                sigma1_sq,
296            });
297        }
298
299        Ok(Self {
300            scales,
301            original_width,
302            original_height,
303            padded_width,
304            padded_height,
305        })
306    }
307
308    /// Allocate a [`CompareContext`] sized for this reference's dimensions.
309    ///
310    /// Pair this with [`Self::compare_with`] to do repeated comparisons
311    /// without allocating fresh working buffers on each call.
312    #[must_use]
313    pub fn compare_context(&self) -> CompareContext {
314        CompareContext::new(self.padded_width, self.padded_height)
315    }
316
317    /// Compare a distorted image against the precomputed reference.
318    ///
319    /// This is approximately 2x faster than calling `compute_ssimulacra2`
320    /// because it only needs to process the distorted image and compute cross-terms.
321    ///
322    /// For batch comparisons (many distorted images vs the same reference),
323    /// prefer [`Self::compare_with`] together with a reusable
324    /// [`CompareContext`] — that path performs zero `Vec` allocations after
325    /// the first call.
326    ///
327    /// # Errors
328    /// - If the distorted image dimensions don't match the reference
329    pub fn compare<T: ToLinearRgb>(&self, distorted: T) -> Result<f64, Ssimulacra2Error> {
330        let mut ctx = self.compare_context();
331        self.compare_with(&mut ctx, distorted)
332    }
333
334    /// Compare a distorted image against the precomputed reference, reusing
335    /// the scratch buffers in `ctx`. Zero `Vec` allocations after the first
336    /// call (`ctx` retains its buffers between invocations).
337    ///
338    /// `ctx` must have been produced by [`Self::compare_context`] on this
339    /// reference. Using a context sized for different dimensions returns
340    /// [`Ssimulacra2Error::NonMatchingImageDimensions`].
341    ///
342    /// # Errors
343    /// - If the distorted image dimensions don't match the reference
344    /// - If `ctx` was sized for a different reference
345    pub fn compare_with<T: ToLinearRgb>(
346        &self,
347        ctx: &mut CompareContext,
348        distorted: T,
349    ) -> Result<f64, Ssimulacra2Error> {
350        let distorted_img = distorted.into_linear_rgb();
351        // Dimensions must match the *original* (pre-padding) reference
352        // dimensions; sub-8px distorted images are then reflect-padded
353        // identically to the reference in `new`.
354        if distorted_img.width() != self.original_width
355            || distorted_img.height() != self.original_height
356        {
357            return Err(Ssimulacra2Error::NonMatchingImageDimensions);
358        }
359        let mut img2: LinearRgb = crate::reflect_pad_linear(distorted_img, 8).into();
360        if ctx.width != self.padded_width || ctx.height != self.padded_height {
361            return Err(Ssimulacra2Error::NonMatchingImageDimensions);
362        }
363
364        let mut width = img2.width().get();
365        let mut height = img2.height().get();
366
367        // Re-expand buffers to full reference size in case a previous call
368        // left them truncated to a small scale. `Vec::resize` reuses
369        // existing capacity, so no heap allocation happens after the first
370        // `compare_context()` call.
371        ctx.reset_to_full();
372
373        // Use the actual number of cached reference scales — the skip-map
374        // must agree with what `score()`'s linear WEIGHT walk will index.
375        let scales_n = self.scales.len();
376        let mut msssim = Msssim::default();
377
378        for (scale_idx, scale_data) in self.scales.iter().enumerate() {
379            if width < 8 || height < 8 {
380                break;
381            }
382
383            if scale_idx > 0 {
384                img2 = downscale_by_2(&img2);
385                width = img2.width().get();
386                height = img2.height().get();
387            }
388
389            ctx.shrink_to(width, height);
390
391            let mut img2_xyb = linear_rgb_to_xyb_simd(img2.clone());
392            make_positive_xyb(&mut img2_xyb);
393
394            // Reuse ctx.img2_planar instead of allocating a fresh [Vec; 3].
395            xyb_to_planar_into(&img2_xyb, &mut ctx.img2_planar);
396
397            // mu2 = blur(img2)
398            ctx.blur.blur_into(&ctx.img2_planar, &mut ctx.mu2);
399
400            // sigma2_sq = blur(img2 * img2)
401            image_multiply(
402                &ctx.img2_planar,
403                &ctx.img2_planar,
404                &mut ctx.mul,
405                SimdImpl::default(),
406            );
407            ctx.blur.blur_into(&ctx.mul, &mut ctx.sigma2_sq);
408
409            // sigma12 = blur(img1 * img2) — cross-term
410            image_multiply(
411                &scale_data.img1_planar,
412                &ctx.img2_planar,
413                &mut ctx.mul,
414                SimdImpl::default(),
415            );
416            ctx.blur.blur_into(&ctx.mul, &mut ctx.sigma12);
417
418            // Use precomputed mu1 and sigma1_sq from reference
419            let avg_ssim = ssim_map(
420                scales_n,
421                scale_idx,
422                width,
423                height,
424                &scale_data.mu1,
425                &ctx.mu2,
426                &scale_data.sigma1_sq,
427                &ctx.sigma2_sq,
428                &ctx.sigma12,
429                SimdImpl::default(),
430            );
431
432            let avg_edgediff = edge_diff_map(
433                scales_n,
434                scale_idx,
435                width,
436                height,
437                &scale_data.img1_planar,
438                &scale_data.mu1,
439                &ctx.img2_planar,
440                &ctx.mu2,
441                SimdImpl::default(),
442            );
443
444            msssim.scales.push(MsssimScale {
445                avg_ssim,
446                avg_edgediff,
447            });
448        }
449
450        Ok(msssim.score())
451    }
452
453    /// Get the width of the original reference image, as supplied by the
454    /// caller (before any sub-8px reflect-padding).
455    #[must_use]
456    pub fn width(&self) -> usize {
457        self.original_width
458    }
459
460    /// Get the height of the original reference image, as supplied by the
461    /// caller (before any sub-8px reflect-padding).
462    #[must_use]
463    pub fn height(&self) -> usize {
464        self.original_height
465    }
466
467    /// Get the number of scales that were precomputed.
468    #[must_use]
469    pub fn num_scales(&self) -> usize {
470        self.scales.len()
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compute_ssimulacra2;
    use std::num::NonZeroUsize;
    use yuvxyb::{ColorPrimaries, Rgb, TransferCharacteristic};

    /// Build an sRGB / BT.709 `Rgb` image from packed pixels.
    fn srgb(data: Vec<[f32; 3]>, width: usize, height: usize) -> Rgb {
        Rgb::new(
            data,
            NonZeroUsize::new(width).unwrap(),
            NonZeroUsize::new(height).unwrap(),
            TransferCharacteristic::SRGB,
            ColorPrimaries::BT709,
        )
        .unwrap()
    }

    /// Smooth test pattern: red ramps left-to-right, green top-to-bottom.
    fn gradient(width: usize, height: usize) -> Vec<[f32; 3]> {
        (0..width * height)
            .map(|i| {
                let x = (i % width) as f32 / width as f32;
                let y = (i / width) as f32 / height as f32;
                [x, y, 0.5]
            })
            .collect()
    }

    #[test]
    fn test_precompute_matches_full_compute() {
        let (width, height) = (64usize, 64usize);
        let source_data = gradient(width, height);
        let distorted_data: Vec<[f32; 3]> = source_data
            .iter()
            .map(|&[r, g, b]| [r * 0.9, g * 0.95, b * 1.05])
            .collect();
        let source = srgb(source_data, width, height);
        let distorted = srgb(distorted_data, width, height);

        // Compute using full method
        let full_score = compute_ssimulacra2(source.clone(), distorted.clone()).unwrap();

        // Compute using precomputed reference
        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        let precomputed_score = precomputed.compare(distorted).unwrap();

        // Scores should match exactly (both use same SIMD XYB path)
        assert!(
            (full_score - precomputed_score).abs() < 1e-6,
            "Scores don't match: full={full_score}, precomputed={precomputed_score}"
        );
    }

    #[test]
    fn test_precompute_dimension_mismatch() {
        let source = srgb(vec![[0.5, 0.5, 0.5]; 64 * 64], 64, 64);
        let distorted = srgb(vec![[0.4, 0.4, 0.4]; 32 * 32], 32, 32); // Wrong size

        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        let result = precomputed.compare(distorted);

        assert!(matches!(
            result,
            Err(Ssimulacra2Error::NonMatchingImageDimensions)
        ));
    }

    #[test]
    fn test_compare_with_matches_compare() {
        // `compare_with(ctx, ..)` must produce the same score as `compare(..)`
        // — it's just the buffer-reusing form of the same computation. We
        // compare the two paths on a small gradient image and a mildly
        // colour-shifted copy of it.
        let (width, height) = (64usize, 64usize);
        let source_data = gradient(width, height);
        let distorted_data: Vec<[f32; 3]> = source_data
            .iter()
            .map(|&[r, g, b]| [r * 0.92, g * 0.97, b * 1.03])
            .collect();
        let source = srgb(source_data, width, height);
        let distorted = srgb(distorted_data, width, height);

        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        let score_compare = precomputed.compare(distorted.clone()).unwrap();
        let mut ctx = precomputed.compare_context();
        let score_compare_with = precomputed
            .compare_with(&mut ctx, distorted.clone())
            .unwrap();
        // Calling compare_with a second time exercises buffer reuse — the
        // result must still match exactly.
        let score_compare_with_repeat = precomputed.compare_with(&mut ctx, distorted).unwrap();

        // Both paths share the SIMD ops, so the scores should be exactly
        // equal modulo reduce-order. 1e-9 leaves room for the f64
        // accumulator order to differ if rustc reorders the loops.
        assert!(
            (score_compare - score_compare_with).abs() < 1e-9,
            "compare={score_compare} vs compare_with={score_compare_with}"
        );
        assert!(
            (score_compare_with - score_compare_with_repeat).abs() < 1e-12,
            "compare_with should be deterministic across reuse"
        );
    }

    #[test]
    fn test_compare_context_dimension_mismatch() {
        // A context sized for one reference must be rejected by a reference
        // of different dimensions, even when the distorted image itself is
        // the right size. Both error paths (wrong context, wrong image) are
        // exercised separately.
        let ref_a = Ssimulacra2Reference::new(srgb(vec![[0.5f32; 3]; 64 * 64], 64, 64)).unwrap();
        let ref_b = Ssimulacra2Reference::new(srgb(vec![[0.5f32; 3]; 32 * 32], 32, 32)).unwrap();

        // Wrong context, correctly-sized distorted image: the context check fires.
        let mut ctx_b = ref_b.compare_context();
        assert!(matches!(
            ref_a.compare_with(&mut ctx_b, srgb(vec![[0.4f32; 3]; 64 * 64], 64, 64)),
            Err(Ssimulacra2Error::NonMatchingImageDimensions)
        ));

        // Right context, wrongly-sized distorted image: the image check fires.
        let mut ctx_a = ref_a.compare_context();
        assert!(matches!(
            ref_a.compare_with(&mut ctx_a, srgb(vec![[0.5f32; 3]; 32 * 32], 32, 32)),
            Err(Ssimulacra2Error::NonMatchingImageDimensions)
        ));

        // Rejected calls must leave both contexts usable with their own reference.
        assert!(
            ref_a
                .compare_with(&mut ctx_a, srgb(vec![[0.5f32; 3]; 64 * 64], 64, 64))
                .is_ok()
        );
        assert!(
            ref_b
                .compare_with(&mut ctx_b, srgb(vec![[0.4f32; 3]; 32 * 32], 32, 32))
                .is_ok()
        );
    }

    #[test]
    fn test_sub_8_reference_pads_and_matches_one_shot() {
        use crate::LinearRgbImage;
        // Sub-8px references are reflect-padded like the one-shot path:
        // identical pairs score ~100, differing pairs score the same as
        // compute_ssimulacra2 on the same inputs, and width()/height()
        // report the caller-supplied (pre-padding) dimensions.
        for (w, h) in [(4usize, 4usize), (1, 1), (3, 7), (7, 3)] {
            let img = LinearRgbImage::new(vec![[0.5f32, 0.5, 0.5]; w * h], w, h);
            let reference = Ssimulacra2Reference::new(img.clone())
                .unwrap_or_else(|e| panic!("{w}x{h} reference must build, got {e:?}"));
            assert_eq!(reference.width(), w);
            assert_eq!(reference.height(), h);
            let score = reference.compare(img).unwrap();
            assert!(
                (score - 100.0).abs() < 0.01,
                "identical {w}x{h} should score ~100, got {score}"
            );
        }

        // Differing sub-8 pair: Reference path == one-shot path (both pad
        // then run the same SIMD pipeline).
        let a = LinearRgbImage::new(vec![[0.5f32, 0.5, 0.5]; 25], 5, 5);
        let b = LinearRgbImage::new(vec![[0.9f32, 0.1, 0.2]; 25], 5, 5);
        let one_shot = compute_ssimulacra2(a.clone(), b.clone()).unwrap();
        let reference = Ssimulacra2Reference::new(a).unwrap();
        let via_ref = reference.compare(b).unwrap();
        assert!(
            (one_shot - via_ref).abs() < 1e-9,
            "one-shot {one_shot} vs reference {via_ref}"
        );
        assert!(via_ref.is_finite() && via_ref < 100.0);
    }

    #[test]
    fn test_sub_8_reference_rejects_mismatched_dims() {
        use crate::LinearRgbImage;
        // A 5x5 reference must reject a 4x4 distorted image even though
        // both would pad to 8x8 — dimension matching happens on the
        // caller-supplied (pre-padding) dimensions.
        let reference =
            Ssimulacra2Reference::new(LinearRgbImage::new(vec![[0.5f32, 0.5, 0.5]; 25], 5, 5))
                .unwrap();
        let distorted = LinearRgbImage::new(vec![[0.5f32, 0.5, 0.5]; 16], 4, 4);
        assert!(matches!(
            reference.compare(distorted),
            Err(Ssimulacra2Error::NonMatchingImageDimensions)
        ));
    }

    #[test]
    fn test_scale_planes_dimensions() {
        use crate::LinearRgbImage;
        // `scale_planes` reports each scale's size by halving (rounding up)
        // from the padded reference dimensions, and returns `None` past the
        // last precomputed scale.
        let reference = Ssimulacra2Reference::new(srgb(gradient(64, 48), 64, 48)).unwrap();
        let s0 = reference
            .scale_planes(0)
            .expect("scale 0 is always precomputed");
        assert_eq!((s0.width, s0.height), (64, 48));
        let s1 = reference
            .scale_planes(1)
            .expect("a 64x48 reference has at least two scales");
        assert_eq!((s1.width, s1.height), (32, 24));
        assert!(reference.scale_planes(reference.num_scales()).is_none());

        // Sub-8px references report the padded 8x8 size at scale 0.
        let tiny =
            Ssimulacra2Reference::new(LinearRgbImage::new(vec![[0.5f32, 0.5, 0.5]; 25], 5, 5))
                .unwrap();
        let t0 = tiny.scale_planes(0).expect("scale 0 is always precomputed");
        assert_eq!((t0.width, t0.height), (8, 8));
    }

    #[test]
    fn test_precompute_metadata() {
        let source = srgb(vec![[0.5, 0.5, 0.5]; 128 * 96], 128, 96);

        let precomputed = Ssimulacra2Reference::new(source).unwrap();

        assert_eq!(precomputed.width(), 128);
        assert_eq!(precomputed.height(), 96);
        assert!(precomputed.num_scales() > 0);
        assert!(precomputed.num_scales() <= NUM_SCALES);
    }
}