use crate::blur::Blur;
use crate::input::ToLinearRgb;
use crate::{
LinearRgb, Msssim, MsssimScale, NUM_SCALES, SimdImpl, Ssimulacra2Error, downscale_by_2,
edge_diff_map, image_multiply, linear_rgb_to_xyb_simd, make_positive_xyb, ssim_map,
xyb_to_planar, xyb_to_planar_into,
};
/// Reusable scratch state for [`Ssimulacra2Reference::compare_with`].
///
/// Holds the blur engine and the per-channel working planes for the distorted
/// image, so repeated comparisons against the same reference do not have to
/// reallocate them. Obtain one from [`Ssimulacra2Reference::compare_context`];
/// a context is only accepted by a reference with the same original
/// dimensions.
pub struct CompareContext {
    // Full-resolution dimensions this context was allocated for.
    width: usize,
    height: usize,
    // Distorted image converted to planar XYB, one Vec per channel.
    img2_planar: [Vec<f32>; 3],
    // Blurred `img2_planar`.
    mu2: [Vec<f32>; 3],
    // Blurred `img2_planar * img2_planar`.
    sigma2_sq: [Vec<f32>; 3],
    // Blurred `img1_planar * img2_planar`.
    sigma12: [Vec<f32>; 3],
    // Scratch buffer for per-pixel products before they are blurred.
    mul: [Vec<f32>; 3],
    // Blur engine reused at every scale.
    blur: Blur,
}
impl CompareContext {
    /// Allocates zero-filled scratch planes and a blur engine sized for
    /// `width`×`height` images.
    fn new(width: usize, height: usize) -> Self {
        let alloc_plane = || vec![0.0f32; width * height];
        let alloc_3planes = || [alloc_plane(), alloc_plane(), alloc_plane()];
        Self {
            width,
            height,
            blur: Blur::new(width, height),
            mul: alloc_3planes(),
            mu2: alloc_3planes(),
            sigma2_sq: alloc_3planes(),
            sigma12: alloc_3planes(),
            img2_planar: alloc_3planes(),
        }
    }

    /// Applies `f` to every channel plane of every scratch buffer.
    fn for_each_plane(&mut self, mut f: impl FnMut(&mut Vec<f32>)) {
        for planes in [
            &mut self.mul,
            &mut self.mu2,
            &mut self.sigma2_sq,
            &mut self.sigma12,
            &mut self.img2_planar,
        ] {
            planes.iter_mut().for_each(&mut f);
        }
    }

    /// Restores all scratch planes and the blur engine to full resolution.
    ///
    /// Planes shortened by a previous comparison are grown back with
    /// zero-filled elements. Truncation keeps the capacity, so this normally
    /// does not reallocate.
    fn reset_to_full(&mut self) {
        let size = self.width * self.height;
        self.for_each_plane(|plane| plane.resize(size, 0.0));
        // NOTE(review): relies on `Blur::shrink_to` also accepting a size
        // larger than its current one here — confirm against `Blur`.
        self.blur.shrink_to(self.width, self.height);
    }

    /// Shortens all scratch planes to `width * height` elements and resizes
    /// the blur engine to match.
    ///
    /// This never grows a plane: `Vec::truncate` is a no-op for larger sizes.
    fn shrink_to(&mut self, width: usize, height: usize) {
        let size = width * height;
        self.for_each_plane(|plane| plane.truncate(size));
        self.blur.shrink_to(width, height);
    }
}
/// Reference-image data cached for a single scale.
#[derive(Clone, Debug)]
struct ScaleData {
    // Reference image converted to positive XYB, one plane per channel.
    img1_planar: [Vec<f32>; 3],
    // Blurred `img1_planar`.
    mu1: [Vec<f32>; 3],
    // Blurred `img1_planar * img1_planar`.
    sigma1_sq: [Vec<f32>; 3],
}
/// Precomputed SSIMULACRA2 data for a reference (source) image.
///
/// Built once with [`Ssimulacra2Reference::new`]. Distorted images with the
/// same dimensions can then be scored with `compare` / `compare_with`
/// without redoing the reference-side work.
#[derive(Clone, Debug)]
pub struct Ssimulacra2Reference {
    // One entry per computed scale; index 0 is full resolution and each
    // following entry is the previous one downscaled by 2.
    scales: Vec<ScaleData>,
    // Dimensions of the source image at scale 0.
    original_width: usize,
    original_height: usize,
}
/// Borrowed view of the reference data cached for one scale.
///
/// Produced by `Ssimulacra2Reference::scale_planes`. `width` and `height` are
/// that scale's image dimensions.
#[doc(hidden)]
pub struct ScalePlanesView<'r> {
    pub width: usize,
    pub height: usize,
    pub img1_planar: &'r [Vec<f32>; 3],
    pub mu1: &'r [Vec<f32>; 3],
    pub sigma1_sq: &'r [Vec<f32>; 3],
}
impl Ssimulacra2Reference {
    /// Returns borrowed views of the cached reference planes at `scale`.
    ///
    /// The view's `width`/`height` are the original dimensions halved
    /// `scale` times, rounding up at each step. Returns `None` when
    /// `scale >= self.num_scales()`.
    #[doc(hidden)]
    #[must_use]
    pub fn scale_planes(&self, scale: usize) -> Option<ScalePlanesView<'_>> {
        let data = self.scales.get(scale)?;
        let mut w = self.original_width;
        let mut h = self.original_height;
        for _ in 0..scale {
            w = w.div_ceil(2);
            h = h.div_ceil(2);
        }
        Some(ScalePlanesView {
            img1_planar: &data.img1_planar,
            mu1: &data.mu1,
            sigma1_sq: &data.sigma1_sq,
            width: w,
            height: h,
        })
    }

    /// Precomputes the reference-side data for every scale of `source`.
    ///
    /// At each scale the source is converted to positive planar XYB. That
    /// is stored together with its blur and the blur of its per-pixel square.
    /// Scale iteration mirrors `compare_with`, so both sides produce the same
    /// number of scales.
    ///
    /// # Errors
    ///
    /// - [`Ssimulacra2Error::InvalidImageSize`] if either dimension is below 8.
    /// - [`Ssimulacra2Error::ImageTooLarge`] if `width * height` overflows
    ///   `usize` or exceeds `crate::MAX_IMAGE_PIXELS`.
    pub fn new<T: ToLinearRgb>(source: T) -> Result<Self, Ssimulacra2Error> {
        let mut img1: LinearRgb = source.into_linear_rgb().into();
        let original_width = img1.width().get();
        let original_height = img1.height().get();
        if original_width < 8 || original_height < 8 {
            return Err(Ssimulacra2Error::InvalidImageSize);
        }
        let pixels = original_width
            .checked_mul(original_height)
            .ok_or(Ssimulacra2Error::ImageTooLarge { actual: usize::MAX })?;
        if pixels > crate::MAX_IMAGE_PIXELS {
            return Err(Ssimulacra2Error::ImageTooLarge { actual: pixels });
        }

        let mut width = original_width;
        let mut height = original_height;
        // Scratch buffer for squared planes: allocated once at full size and
        // truncated as the image shrinks.
        let mut mul: [Vec<f32>; 3] = std::array::from_fn(|_| vec![0.0f32; pixels]);
        let mut blur = Blur::new(width, height);
        let mut scales = Vec::with_capacity(NUM_SCALES);
        for scale in 0..NUM_SCALES {
            // The size check runs before downscaling, so the last computed
            // scale may be smaller than 8 in a dimension.
            if width < 8 || height < 8 {
                break;
            }
            if scale > 0 {
                img1 = downscale_by_2(&img1);
                width = img1.width().get();
                height = img1.height().get();
            }
            for c in &mut mul {
                c.truncate(width * height);
            }
            blur.shrink_to(width, height);
            // The XYB conversion takes its input by value; keep `img1` for
            // the next downscale.
            let mut img1_xyb = linear_rgb_to_xyb_simd(img1.clone());
            make_positive_xyb(&mut img1_xyb);
            let img1_planar = xyb_to_planar(&img1_xyb);
            let mu1 = blur.blur(&img1_planar);
            image_multiply(&img1_planar, &img1_planar, &mut mul, SimdImpl::default());
            let sigma1_sq = blur.blur(&mul);
            scales.push(ScaleData {
                img1_planar,
                mu1,
                sigma1_sq,
            });
        }
        Ok(Self {
            scales,
            original_width,
            original_height,
        })
    }

    /// Creates a reusable [`CompareContext`] sized for this reference, for
    /// use with [`Self::compare_with`].
    #[must_use]
    pub fn compare_context(&self) -> CompareContext {
        CompareContext::new(self.original_width, self.original_height)
    }

    /// Scores `distorted` against this reference.
    ///
    /// Allocates a fresh [`CompareContext`] on every call. When comparing many
    /// images, prefer [`Self::compare_with`] with a reused context.
    ///
    /// # Errors
    ///
    /// Returns [`Ssimulacra2Error::NonMatchingImageDimensions`] if
    /// `distorted` does not have the reference's dimensions.
    pub fn compare<T: ToLinearRgb>(&self, distorted: T) -> Result<f64, Ssimulacra2Error> {
        let mut ctx = self.compare_context();
        self.compare_with(&mut ctx, distorted)
    }

    /// Scores `distorted` against this reference, reusing the scratch
    /// buffers in `ctx`.
    ///
    /// # Errors
    ///
    /// Returns [`Ssimulacra2Error::NonMatchingImageDimensions`] if `ctx` was
    /// created for different dimensions, or if `distorted` does not have the
    /// reference's dimensions.
    pub fn compare_with<T: ToLinearRgb>(
        &self,
        ctx: &mut CompareContext,
        distorted: T,
    ) -> Result<f64, Ssimulacra2Error> {
        // Check the context first: it is cheap and avoids converting the
        // distorted image only to reject it.
        if ctx.width != self.original_width || ctx.height != self.original_height {
            return Err(Ssimulacra2Error::NonMatchingImageDimensions);
        }
        let mut img2: LinearRgb = distorted.into_linear_rgb().into();
        let mut width = img2.width().get();
        let mut height = img2.height().get();
        if width != self.original_width || height != self.original_height {
            return Err(Ssimulacra2Error::NonMatchingImageDimensions);
        }
        // Undo any shrinking left in the context by a previous call.
        ctx.reset_to_full();
        let scales_n = self.scales.len();
        let mut msssim = Msssim::default();
        for (scale_idx, scale_data) in self.scales.iter().enumerate() {
            // Same check-then-downscale order as `new`.
            if width < 8 || height < 8 {
                break;
            }
            if scale_idx > 0 {
                img2 = downscale_by_2(&img2);
                width = img2.width().get();
                height = img2.height().get();
            }
            ctx.shrink_to(width, height);
            let mut img2_xyb = linear_rgb_to_xyb_simd(img2.clone());
            make_positive_xyb(&mut img2_xyb);
            xyb_to_planar_into(&img2_xyb, &mut ctx.img2_planar);
            ctx.blur.blur_into(&ctx.img2_planar, &mut ctx.mu2);
            image_multiply(
                &ctx.img2_planar,
                &ctx.img2_planar,
                &mut ctx.mul,
                SimdImpl::default(),
            );
            ctx.blur.blur_into(&ctx.mul, &mut ctx.sigma2_sq);
            image_multiply(
                &scale_data.img1_planar,
                &ctx.img2_planar,
                &mut ctx.mul,
                SimdImpl::default(),
            );
            ctx.blur.blur_into(&ctx.mul, &mut ctx.sigma12);
            let avg_ssim = ssim_map(
                scales_n,
                scale_idx,
                width,
                height,
                &scale_data.mu1,
                &ctx.mu2,
                &scale_data.sigma1_sq,
                &ctx.sigma2_sq,
                &ctx.sigma12,
                SimdImpl::default(),
            );
            let avg_edgediff = edge_diff_map(
                scales_n,
                scale_idx,
                width,
                height,
                &scale_data.img1_planar,
                &scale_data.mu1,
                &ctx.img2_planar,
                &ctx.mu2,
                SimdImpl::default(),
            );
            msssim.scales.push(MsssimScale {
                avg_ssim,
                avg_edgediff,
            });
        }
        Ok(msssim.score())
    }

    /// Width of the reference image at full resolution.
    #[must_use]
    pub fn width(&self) -> usize {
        self.original_width
    }

    /// Height of the reference image at full resolution.
    #[must_use]
    pub fn height(&self) -> usize {
        self.original_height
    }

    /// Number of scales that were precomputed. This is at most `NUM_SCALES`.
    #[must_use]
    pub fn num_scales(&self) -> usize {
        self.scales.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compute_ssimulacra2;
    use std::num::NonZeroUsize;
    use yuvxyb::{ColorPrimaries, Rgb, TransferCharacteristic};

    /// Builds an sRGB / BT.709 `Rgb` image from row-major pixel data.
    fn rgb(data: Vec<[f32; 3]>, width: usize, height: usize) -> Rgb {
        Rgb::new(
            data,
            NonZeroUsize::new(width).unwrap(),
            NonZeroUsize::new(height).unwrap(),
            TransferCharacteristic::SRGB,
            ColorPrimaries::BT709,
        )
        .unwrap()
    }

    /// Test pattern: red ramps along x, green ramps along y, blue is 0.5.
    fn gradient(width: usize, height: usize) -> Vec<[f32; 3]> {
        (0..width * height)
            .map(|i| {
                let x = (i % width) as f32 / width as f32;
                let y = (i / width) as f32 / height as f32;
                [x, y, 0.5]
            })
            .collect()
    }

    /// Multiplies each channel of every pixel by the matching factor.
    fn scale_channels(data: &[[f32; 3]], factors: [f32; 3]) -> Vec<[f32; 3]> {
        data.iter()
            .map(|&[r, g, b]| [r * factors[0], g * factors[1], b * factors[2]])
            .collect()
    }

    #[test]
    fn test_precompute_matches_full_compute() {
        let (width, height) = (64, 64);
        let source_data = gradient(width, height);
        let distorted_data = scale_channels(&source_data, [0.9, 0.95, 1.05]);
        let source = rgb(source_data, width, height);
        let distorted = rgb(distorted_data, width, height);
        let full_score = compute_ssimulacra2(source.clone(), distorted.clone()).unwrap();
        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        let precomputed_score = precomputed.compare(distorted).unwrap();
        assert!(
            (full_score - precomputed_score).abs() < 1e-6,
            "Scores don't match: full={}, precomputed={}",
            full_score,
            precomputed_score
        );
    }

    #[test]
    fn test_precompute_dimension_mismatch() {
        let source = rgb(vec![[0.5, 0.5, 0.5]; 64 * 64], 64, 64);
        let distorted = rgb(vec![[0.4, 0.4, 0.4]; 32 * 32], 32, 32);
        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        let result = precomputed.compare(distorted);
        assert!(matches!(
            result,
            Err(Ssimulacra2Error::NonMatchingImageDimensions)
        ));
    }

    #[test]
    fn test_compare_with_matches_compare() {
        let (width, height) = (64, 64);
        let source_data = gradient(width, height);
        let distorted_data = scale_channels(&source_data, [0.92, 0.97, 1.03]);
        let source = rgb(source_data, width, height);
        let distorted = rgb(distorted_data, width, height);
        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        let score_compare = precomputed.compare(distorted.clone()).unwrap();
        let mut ctx = precomputed.compare_context();
        let score_compare_with = precomputed
            .compare_with(&mut ctx, distorted.clone())
            .unwrap();
        let score_compare_with_repeat = precomputed.compare_with(&mut ctx, distorted).unwrap();
        assert!(
            (score_compare - score_compare_with).abs() < 1e-9,
            "compare={} vs compare_with={}",
            score_compare,
            score_compare_with
        );
        assert!(
            (score_compare_with - score_compare_with_repeat).abs() < 1e-12,
            "compare_with should be deterministic across reuse"
        );
    }

    #[test]
    fn test_compare_context_dimension_mismatch() {
        let ref_a =
            Ssimulacra2Reference::new(rgb(vec![[0.5, 0.5, 0.5]; 64 * 64], 64, 64)).unwrap();
        let distorted_b = rgb(vec![[0.5, 0.5, 0.5]; 32 * 32], 32, 32);
        let mut ctx = ref_a.compare_context();
        assert!(matches!(
            ref_a.compare_with(&mut ctx, distorted_b),
            Err(Ssimulacra2Error::NonMatchingImageDimensions)
        ));
    }

    #[test]
    fn test_precompute_metadata() {
        let source = rgb(vec![[0.5, 0.5, 0.5]; 128 * 96], 128, 96);
        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        assert_eq!(precomputed.width(), 128);
        assert_eq!(precomputed.height(), 96);
        assert!(precomputed.num_scales() > 0);
        assert!(precomputed.num_scales() <= NUM_SCALES);
    }

    #[test]
    fn test_scale_planes_dimensions() {
        let precomputed = Ssimulacra2Reference::new(rgb(gradient(100, 60), 100, 60)).unwrap();
        let full = precomputed.scale_planes(0).unwrap();
        assert_eq!((full.width, full.height), (100, 60));
        // 100x60 -> 50x30 -> 25x15 -> 13x8: odd sizes round up.
        let third = precomputed.scale_planes(3).unwrap();
        assert_eq!((third.width, third.height), (13, 8));
        assert!(precomputed
            .scale_planes(precomputed.num_scales())
            .is_none());
    }
}