use std::num::NonZeroUsize;
use yuvxyb::LinearRgb;
use crate::blur::Blur;
use crate::input::ToLinearRgb;
use crate::precompute::Ssimulacra2Reference;
use crate::weights::{EDGE_HAS_WEIGHT, NUM_SCALES, SSIM_HAS_WEIGHT};
use crate::{
LinearRgbImage, Msssim, MsssimScale, Ssimulacra2Config, Ssimulacra2Error, downscale_by_2,
image_multiply, linear_rgb_to_xyb_simd, make_positive_xyb, xyb_to_planar_into,
};
/// Default number of context ("halo") rows taken above and below each strip's
/// interior. It is a multiple of the strip walkers' 32-row alignment grid.
pub const HALO_ROWS_DEFAULT: usize = 96;
/// Smallest `strip_height` accepted by the strip entry points; smaller values
/// are rejected with `Ssimulacra2Error::InvalidImageSize`.
pub const MIN_STRIP_HEIGHT: usize = 8;
/// Configuration for the strip-based SSIMULACRA2 entry points.
#[derive(Debug, Clone, Copy)]
pub struct Ssimulacra2StripConfig {
    /// Context rows added above and below each strip's interior. They are
    /// blurred together with the interior but not scored, so filters near the
    /// interior edges see real neighbouring pixels. The strip walkers place
    /// interior boundaries on a 32-row grid, so multiples of 32 keep every
    /// strip aligned with the downscaling pyramid.
    pub halo_rows: usize,
    /// Settings forwarded to the per-strip pipeline (e.g. the SIMD
    /// implementation used by the blur and multiply kernels).
    pub inner: Ssimulacra2Config,
}
impl Default for Ssimulacra2StripConfig {
    /// Returns a config with [`HALO_ROWS_DEFAULT`] halo rows and the default
    /// inner [`Ssimulacra2Config`].
    fn default() -> Self {
        let inner = Ssimulacra2Config::default();
        Self {
            inner,
            halo_rows: HALO_ROWS_DEFAULT,
        }
    }
}
impl Ssimulacra2StripConfig {
#[must_use]
pub fn with_halo_rows(halo_rows: usize) -> Self {
Self {
halo_rows,
inner: Ssimulacra2Config::default(),
}
}
#[must_use]
pub fn with_inner(mut self, inner: Ssimulacra2Config) -> Self {
self.inner = inner;
self
}
}
/// Computes the SSIMULACRA2 score of `distorted` against `source` by
/// processing the images in horizontal strips of about `strip_height`
/// interior rows. Strip boundaries are rounded up to a 32-row grid.
///
/// Uses [`Ssimulacra2StripConfig::default`]; see
/// [`compute_ssimulacra2_strip_with_config`] to customise the halo.
///
/// # Errors
///
/// Same as [`compute_ssimulacra2_strip_with_config`].
pub fn compute_ssimulacra2_strip<S, D>(
    source: S,
    distorted: D,
    strip_height: u32,
) -> Result<f64, Ssimulacra2Error>
where
    S: ToLinearRgb,
    D: ToLinearRgb,
{
    let config = Ssimulacra2StripConfig::default();
    compute_ssimulacra2_strip_with_config(source, distorted, strip_height, config)
}
/// Computes the SSIMULACRA2 score of `distorted` against `source` strip by
/// strip, using the halo size and inner settings from `config`.
///
/// # Errors
///
/// * [`Ssimulacra2Error::NonMatchingImageDimensions`] if the two images differ in size.
/// * [`Ssimulacra2Error::InvalidImageSize`] if either side is smaller than 8
///   pixels or `strip_height` is below [`MIN_STRIP_HEIGHT`].
/// * [`Ssimulacra2Error::ImageTooLarge`] if the pixel count exceeds the crate limit.
pub fn compute_ssimulacra2_strip_with_config<S, D>(
    source: S,
    distorted: D,
    strip_height: u32,
    config: Ssimulacra2StripConfig,
) -> Result<f64, Ssimulacra2Error>
where
    S: ToLinearRgb,
    D: ToLinearRgb,
{
    let (reference_img, distorted_img): (LinearRgbImage, LinearRgbImage) =
        (source.into_linear_rgb(), distorted.into_linear_rgb());
    compute_strip_impl(
        reference_img.into(),
        distorted_img.into(),
        strip_height as usize,
        config,
    )
}
/// Checks that an image and strip height are usable by the strip walkers.
///
/// Returns [`Ssimulacra2Error::InvalidImageSize`] when either image side is
/// below 8 pixels or `strip_height` is below [`MIN_STRIP_HEIGHT`]. Returns
/// [`Ssimulacra2Error::ImageTooLarge`] when the pixel count overflows `usize`
/// (reported as `usize::MAX`) or exceeds `crate::MAX_IMAGE_PIXELS`.
fn validate_strip_dims(
    width: usize,
    height: usize,
    strip_height: usize,
) -> Result<(), Ssimulacra2Error> {
    let too_small = width < 8 || height < 8 || strip_height < MIN_STRIP_HEIGHT;
    if too_small {
        return Err(Ssimulacra2Error::InvalidImageSize);
    }
    match width.checked_mul(height) {
        None => Err(Ssimulacra2Error::ImageTooLarge { actual: usize::MAX }),
        Some(total) if total > crate::MAX_IMAGE_PIXELS => {
            Err(Ssimulacra2Error::ImageTooLarge { actual: total })
        }
        Some(_) => Ok(()),
    }
}
/// Running sums for one scale, accumulated across all strips.
#[derive(Debug, Default, Clone)]
struct ScaleSums {
    /// Per channel `c`: `[c * 2]` is the sum of the SSIM error map value `d`,
    /// and `[c * 2 + 1]` is the sum of `d^4`.
    ssim_sums: [f64; 3 * 2],
    /// Per channel `c`: `[c * 4]`..`[c * 4 + 3]` are the sums of artifact,
    /// artifact^4, detail-lost and detail-lost^4.
    edge_sums: [f64; 3 * 4],
    /// Interior pixels contributed so far at this scale.
    pixels: u64,
    /// Set once any strip has contributed to this scale.
    initialised: bool,
}
/// Collects per-scale SSIM / edge-difference sums from every strip. The
/// final averages are taken over the full image's pixel count at each scale.
struct StripAccumulator {
    /// One entry per scale the full image reaches.
    per_scale: Vec<ScaleSums>,
    /// Pixel count of the full image at each scale. It is the averaging
    /// denominator and is checked against `per_scale[i].pixels` in debug builds.
    target_total_pixels: Vec<u64>,
}
impl StripAccumulator {
    /// Creates an empty accumulator with one slot per scale that the full
    /// `width` x `height` image reaches. It records each scale's pixel count.
    fn new(width: usize, height: usize) -> Self {
        let mut target_total_pixels = Vec::with_capacity(NUM_SCALES);
        let (mut w, mut h) = (width, height);
        for scale in 0..NUM_SCALES {
            // The size gate is evaluated on the dimensions reached so far and
            // only then halved; the per-strip scale loops use the same order.
            if w < 8 || h < 8 {
                break;
            }
            if scale != 0 {
                w = w.div_ceil(2);
                h = h.div_ceil(2);
            }
            target_total_pixels.push((w * h) as u64);
        }
        let per_scale = vec![ScaleSums::default(); target_total_pixels.len()];
        Self {
            per_scale,
            target_total_pixels,
        }
    }

    /// Adds one strip's interior sums for `scale`. Scales beyond those the
    /// full image reaches are silently ignored.
    fn add_strip_sums(&mut self, scale: usize, ssim: &[f64; 6], edge: &[f64; 12], pixels: u64) {
        let Some(slot) = self.per_scale.get_mut(scale) else {
            return;
        };
        slot.ssim_sums
            .iter_mut()
            .zip(ssim)
            .for_each(|(dst, src)| *dst += *src);
        slot.edge_sums
            .iter_mut()
            .zip(edge)
            .for_each(|(dst, src)| *dst += *src);
        slot.pixels += pixels;
        slot.initialised = true;
    }

    /// Converts the accumulated sums into per-scale averages and returns the
    /// combined `Msssim` score.
    ///
    /// The `d` and artifact / detail-lost sums are plain means. The `^4`
    /// sums become the fourth root of their mean. Processing stops at the
    /// first scale no strip contributed to. This never returns `Err`.
    fn finalise(self) -> Result<f64, Ssimulacra2Error> {
        let Self {
            per_scale,
            target_total_pixels,
        } = self;

        for (scale, (sums, &expected)) in per_scale.iter().zip(&target_total_pixels).enumerate() {
            if sums.initialised {
                debug_assert_eq!(
                    sums.pixels, expected,
                    "strip accumulator scale {} pixel count {} != expected {}",
                    scale, sums.pixels, expected,
                );
            }
        }

        let fourth_root = |v: f64| v.sqrt().sqrt();
        let mut msssim = Msssim::default();
        for (sums, &target) in per_scale.iter().zip(&target_total_pixels) {
            if !sums.initialised || target == 0 {
                break;
            }
            let inv = 1.0 / target as f64;

            let mut avg_ssim = [0.0_f64; 6];
            for (dst, src) in avg_ssim
                .chunks_exact_mut(2)
                .zip(sums.ssim_sums.chunks_exact(2))
            {
                dst[0] = inv * src[0];
                dst[1] = fourth_root(inv * src[1]);
            }

            let mut avg_edgediff = [0.0_f64; 12];
            for (dst, src) in avg_edgediff
                .chunks_exact_mut(4)
                .zip(sums.edge_sums.chunks_exact(4))
            {
                dst[0] = inv * src[0];
                dst[1] = fourth_root(inv * src[1]);
                dst[2] = inv * src[2];
                dst[3] = fourth_root(inv * src[3]);
            }

            msssim.scales.push(MsssimScale {
                avg_ssim,
                avg_edgediff,
            });
        }
        Ok(msssim.score())
    }
}
/// Copies rows `row_start..row_end` of `src` into a new image of the same width.
///
/// # Panics
///
/// Panics if the row range is out of bounds or empty (a zero-height image
/// cannot be built).
fn linear_rgb_strip(src: &LinearRgb, row_start: usize, row_end: usize) -> LinearRgb {
    let w = src.width().get();
    debug_assert!(row_start <= row_end);
    debug_assert!(row_end <= src.height().get());
    let rows = &src.data()[row_start * w..row_end * w];
    let strip_width = NonZeroUsize::new(w).expect("width must be non-zero");
    let strip_height = NonZeroUsize::new(row_end - row_start).expect("strip rows non-zero");
    LinearRgb::new(rows.to_vec(), strip_width, strip_height).expect("strip dimensions are valid")
}
/// Runs the multi-scale SSIM / edge-difference pipeline on one pair of strips
/// and adds the sums over the strip's interior rows to `acc`.
///
/// * `img1_strip` / `img2_strip`: reference and distorted image rows
///   `strip_y0..strip_y0 + strip_rows` (interior plus halo). They must be the same size.
/// * `interior_start` / `interior_end`: absolute full-resolution rows this
///   strip is responsible for scoring. Halo rows are blurred but never scored.
/// * `config`: supplies the SIMD implementation for blur and multiply kernels.
///
/// Iterates at most `acc.per_scale.len()` scales. Like
/// `StripAccumulator::new`, it checks the size gate before each downscale.
fn process_strip(
    img1_strip: LinearRgb,
    img2_strip: LinearRgb,
    strip_y0: usize,
    interior_start: usize,
    interior_end: usize,
    acc: &mut StripAccumulator,
    config: Ssimulacra2Config,
) {
    let impl_type = config.impl_type;
    let total_scales = acc.per_scale.len();
    let mut img1 = img1_strip;
    let mut img2 = img2_strip;
    let mut width = img1.width().get();
    let mut height = img1.height().get();

    // Scratch planes are allocated once at scale-0 size and shrunk in place
    // as the scales get smaller.
    let alloc_plane = |w: usize, h: usize| vec![0.0f32; w * h];
    let alloc_3planes =
        |w: usize, h: usize| [alloc_plane(w, h), alloc_plane(w, h), alloc_plane(w, h)];
    let mut mul = alloc_3planes(width, height);
    let mut sigma1_sq = alloc_3planes(width, height);
    let mut sigma2_sq = alloc_3planes(width, height);
    let mut sigma12 = alloc_3planes(width, height);
    let mut mu1 = alloc_3planes(width, height);
    let mut mu2 = alloc_3planes(width, height);
    let mut img1_planar = alloc_3planes(width, height);
    let mut img2_planar = alloc_3planes(width, height);
    let mut blur = Blur::with_simd_impl(width, height, impl_type);

    // Interior row range relative to the strip's first row, expressed at the
    // *current* scale (halved in step with each downscale).
    let mut interior_lo = interior_start - strip_y0;
    let mut interior_hi = interior_end - strip_y0;

    for scale in 0..total_scales {
        if width < 8 || height < 8 {
            break;
        }
        if scale > 0 {
            img1 = downscale_by_2(&img1);
            img2 = downscale_by_2(&img2);
            width = img1.width().get();
            height = img1.height().get();
            // Round up, matching the ceil-halving used for the expected
            // per-scale sizes in `StripAccumulator::new`.
            interior_lo = interior_lo.div_ceil(2);
            interior_hi = interior_hi.div_ceil(2);
        }
        interior_lo = interior_lo.min(height);
        interior_hi = interior_hi.min(height);
        if interior_lo >= interior_hi {
            continue;
        }

        // `resize` both grows and truncates, so it alone fits every plane to
        // this scale.
        let size = width * height;
        for buf in [
            &mut mul,
            &mut sigma1_sq,
            &mut sigma2_sq,
            &mut sigma12,
            &mut mu1,
            &mut mu2,
            &mut img1_planar,
            &mut img2_planar,
        ] {
            for c in buf.iter_mut() {
                c.resize(size, 0.0);
            }
        }
        blur.shrink_to(width, height);

        // `img1` / `img2` are still needed for the next downscale, hence the clones.
        let mut img1_xyb = linear_rgb_to_xyb_simd(img1.clone());
        let mut img2_xyb = linear_rgb_to_xyb_simd(img2.clone());
        make_positive_xyb(&mut img1_xyb);
        make_positive_xyb(&mut img2_xyb);
        xyb_to_planar_into(&img1_xyb, &mut img1_planar);
        xyb_to_planar_into(&img2_xyb, &mut img2_planar);

        // Blurred second moments and means over the whole strip (halo included).
        image_multiply(&img1_planar, &img1_planar, &mut mul, impl_type);
        blur.blur_into(&mul, &mut sigma1_sq);
        image_multiply(&img2_planar, &img2_planar, &mut mul, impl_type);
        blur.blur_into(&mul, &mut sigma2_sq);
        image_multiply(&img1_planar, &img2_planar, &mut mul, impl_type);
        blur.blur_into(&mul, &mut sigma12);
        blur.blur_into(&img1_planar, &mut mu1);
        blur.blur_into(&img2_planar, &mut mu2);

        // Only the interior rows are scored.
        let ssim_sums = ssim_map_strip(
            scale,
            width,
            interior_lo,
            interior_hi,
            &mu1,
            &mu2,
            &sigma1_sq,
            &sigma2_sq,
            &sigma12,
            total_scales,
        );
        let edge_sums = edge_diff_map_strip(
            scale,
            width,
            interior_lo,
            interior_hi,
            &img1_planar,
            &mu1,
            &img2_planar,
            &mu2,
            total_scales,
        );
        let interior_h = interior_hi - interior_lo;
        let interior_pixels = (interior_h as u64) * (width as u64);
        acc.add_strip_sums(scale, &ssim_sums, &edge_sums, interior_pixels);
    }
}
/// Sums the per-pixel SSIM error `d` and `d^4` over the interior rows
/// `interior_start..interior_end` of each channel.
///
/// `m1`/`m2` are the blurred means and `s11`/`s22`/`s12` the blurred second
/// moments, all row-major with `width` columns. The error is computed as:
/// `d = max(0, 1 - (1 - (mu1 - mu2)^2) * (2 (s12 - mu1 mu2) + C2)
///            / ((s11 - mu1^2) + (s22 - mu2^2) + C2))`.
/// Channels whose `SSIM_HAS_WEIGHT` flag is false at `scale_idx` are skipped
/// and left as zero. Output layout: `[c * 2]` = sum of d, `[c * 2 + 1]` = sum of d^4.
#[allow(clippy::too_many_arguments)]
fn ssim_map_strip(
    scale_idx: usize,
    width: usize,
    interior_start: usize,
    interior_end: usize,
    m1: &[Vec<f32>; 3],
    m2: &[Vec<f32>; 3],
    s11: &[Vec<f32>; 3],
    s22: &[Vec<f32>; 3],
    s12: &[Vec<f32>; 3],
    total_scales: usize,
) -> [f64; 6] {
    const C2: f32 = 0.0009f32;
    let has_weight = SSIM_HAS_WEIGHT[total_scales.min(NUM_SCALES)];
    let mut out = [0.0_f64; 6];
    for (c, pair) in out.chunks_exact_mut(2).enumerate() {
        if scale_idx < NUM_SCALES && !has_weight[c][scale_idx] {
            continue;
        }
        let mut sum = 0.0f64;
        let mut sum4 = 0.0f64;
        for y in interior_start..interior_end {
            let start = y * width;
            let span = start..start + width;
            let texels = m1[c][span.clone()]
                .iter()
                .zip(&m2[c][span.clone()])
                .zip(&s11[c][span.clone()])
                .zip(&s22[c][span.clone()])
                .zip(&s12[c][span]);
            for ((((&a, &b), &blur_aa), &blur_bb), &blur_ab) in texels {
                let diff = a - b;
                let luma_term = diff.mul_add(-diff, 1.0f32);
                let structure_term = 2.0f32.mul_add(blur_ab - a * b, C2);
                let denom = (blur_aa - a * a) + (blur_bb - b * b) + C2;
                let d = (1.0f32 - (luma_term * structure_term) / denom).max(0.0f32);
                let d_sq = d * d;
                sum += f64::from(d);
                sum4 += f64::from(d_sq * d_sq);
            }
        }
        pair[0] = sum;
        pair[1] = sum4;
    }
    out
}
/// Sums the edge-difference terms over the interior rows
/// `interior_start..interior_end` of each channel.
///
/// Per pixel: `ratio = (1 + |img2 - mu2|) / (1 + |img1 - mu1|) - 1`.
/// `artifact = max(ratio, 0)` and `detail_lost = max(-ratio, 0)`.
/// Channels whose `EDGE_HAS_WEIGHT` flag is false at `scale_idx` are skipped
/// and left as zero. Output layout per channel `c`: `[c * 4]`..`[c * 4 + 3]` =
/// sum of artifact, artifact^4, detail_lost, detail_lost^4.
#[allow(clippy::too_many_arguments)]
fn edge_diff_map_strip(
    scale_idx: usize,
    width: usize,
    interior_start: usize,
    interior_end: usize,
    img1: &[Vec<f32>; 3],
    mu1: &[Vec<f32>; 3],
    img2: &[Vec<f32>; 3],
    mu2: &[Vec<f32>; 3],
    total_scales: usize,
) -> [f64; 12] {
    let has_weight = EDGE_HAS_WEIGHT[total_scales.min(NUM_SCALES)];
    let mut out = [0.0_f64; 12];
    for (c, quad) in out.chunks_exact_mut(4).enumerate() {
        let unweighted = scale_idx < NUM_SCALES && !has_weight[c][scale_idx];
        if unweighted {
            continue;
        }
        let (mut artifact_sum, mut artifact_sum4) = (0.0_f64, 0.0_f64);
        let (mut lost_sum, mut lost_sum4) = (0.0_f64, 0.0_f64);
        for y in interior_start..interior_end {
            let start = y * width;
            let span = start..start + width;
            let pixels = img1[c][span.clone()]
                .iter()
                .zip(&img2[c][span.clone()])
                .zip(&mu1[c][span.clone()])
                .zip(&mu2[c][span]);
            for (((&p1, &p2), &m1), &m2) in pixels {
                let distorted_detail = 1.0 + f64::from((p2 - m2).abs());
                let source_detail = 1.0 + f64::from((p1 - m1).abs());
                let ratio = distorted_detail / source_detail - 1.0;
                let artifact = ratio.max(0.0);
                artifact_sum += artifact;
                artifact_sum4 += artifact.powi(4);
                let detail_lost = (-ratio).max(0.0);
                lost_sum += detail_lost;
                lost_sum4 += detail_lost.powi(4);
            }
        }
        quad.copy_from_slice(&[artifact_sum, artifact_sum4, lost_sum, lost_sum4]);
    }
    out
}
/// Computes the SSIMULACRA2 score of `img2` against `img1` by walking both
/// images in horizontal strips.
///
/// Each strip has an *interior* (the rows it scores) plus up to the
/// configured halo of context rows above and below. Interior boundaries sit
/// on a 32-row grid, and a trailing remainder shorter than 32 rows is merged
/// into the last strip.
///
/// # Errors
///
/// * [`Ssimulacra2Error::NonMatchingImageDimensions`] if the images differ in size.
/// * Any error reported by `validate_strip_dims`.
fn compute_strip_impl(
    img1: LinearRgb,
    img2: LinearRgb,
    strip_height: usize,
    config: Ssimulacra2StripConfig,
) -> Result<f64, Ssimulacra2Error> {
    // Strip starts must stay on this grid so that repeatedly halving a strip
    // lands on the same pixel grid as halving the full image.
    // NOTE(review): assumed to equal 2^(NUM_SCALES - 1); confirm if NUM_SCALES changes.
    const ALIGNMENT: usize = 32;

    if img1.width() != img2.width() || img1.height() != img2.height() {
        return Err(Ssimulacra2Error::NonMatchingImageDimensions);
    }
    let width = img1.width().get();
    let height = img1.height().get();
    validate_strip_dims(width, height, strip_height)?;

    // A halo that is not a multiple of ALIGNMENT would put `strip_y0` off the
    // grid and misalign every downscaled scale, so round it up. Clamping to
    // `height` first prevents overflow. It does not change the effective halo,
    // because the halo is clamped to the image bounds below anyway.
    let halo = config.halo_rows.min(height).next_multiple_of(ALIGNMENT);
    let mut acc = StripAccumulator::new(width, height);
    let strip_h = strip_height.max(MIN_STRIP_HEIGHT);
    let mut y = 0usize;
    while y < height {
        let mut next_y = (y + strip_h).next_multiple_of(ALIGNMENT);
        // Merge a short tail into this strip rather than emitting a sliver.
        if next_y >= height || height - next_y < ALIGNMENT {
            next_y = height;
        }
        let interior_start = y;
        let interior_end = next_y;
        let halo_above = halo.min(interior_start);
        let halo_below = halo.min(height - interior_end);
        let strip_y0 = interior_start - halo_above;
        let strip_y1 = interior_end + halo_below;
        let img1_strip = linear_rgb_strip(&img1, strip_y0, strip_y1);
        let img2_strip = linear_rgb_strip(&img2, strip_y0, strip_y1);
        process_strip(
            img1_strip,
            img2_strip,
            strip_y0,
            interior_start,
            interior_end,
            &mut acc,
            config.inner,
        );
        y = next_y;
    }
    acc.finalise()
}
impl Ssimulacra2Reference {
    /// Compares `distorted` against this precomputed reference strip by strip,
    /// using [`Ssimulacra2StripConfig::default`].
    ///
    /// # Errors
    ///
    /// Same as [`Ssimulacra2Reference::compare_strip_with_config`].
    pub fn compare_strip<T: ToLinearRgb>(
        &self,
        distorted: T,
        strip_height: u32,
    ) -> Result<f64, Ssimulacra2Error> {
        let config = Ssimulacra2StripConfig::default();
        self.compare_strip_with_config(distorted, strip_height, config)
    }

    /// Compares `distorted` against this precomputed reference strip by strip,
    /// using the halo size and inner settings from `config`.
    ///
    /// # Errors
    ///
    /// * [`Ssimulacra2Error::NonMatchingImageDimensions`] if `distorted` does not
    ///   match the reference's width and height.
    /// * [`Ssimulacra2Error::InvalidImageSize`] / [`Ssimulacra2Error::ImageTooLarge`]
    ///   from the strip-dimension validation.
    pub fn compare_strip_with_config<T: ToLinearRgb>(
        &self,
        distorted: T,
        strip_height: u32,
        config: Ssimulacra2StripConfig,
    ) -> Result<f64, Ssimulacra2Error> {
        let img2: LinearRgb = distorted.into_linear_rgb().into();
        let (width, height) = (img2.width().get(), img2.height().get());
        let same_size = width == self.width() && height == self.height();
        if !same_size {
            return Err(Ssimulacra2Error::NonMatchingImageDimensions);
        }
        let rows_per_strip = strip_height as usize;
        validate_strip_dims(width, height, rows_per_strip)?;
        compare_strip_with_cached_ref(self, img2, rows_per_strip, config)
    }
}
/// Strip walker for [`Ssimulacra2Reference::compare_strip_with_config`].
///
/// Splits `img2_full` into strips whose interiors sit on a 32-row grid, and
/// scores each strip against the matching rows of the cached reference.
/// Dimensions are not re-validated here; `compare_strip_with_config` checks
/// them before calling.
fn compare_strip_with_cached_ref(
    reference: &Ssimulacra2Reference,
    img2_full: LinearRgb,
    strip_height: usize,
    config: Ssimulacra2StripConfig,
) -> Result<f64, Ssimulacra2Error> {
    // Strip starts must stay on this grid: the per-scale reference row offset
    // is derived by repeatedly halving `strip_y0`, which is only exact when
    // `strip_y0` is divisible by every downscale factor.
    // NOTE(review): assumed to equal 2^(NUM_SCALES - 1); confirm if NUM_SCALES changes.
    const ALIGNMENT: usize = 32;

    let width = img2_full.width().get();
    let height = img2_full.height().get();
    // Round a user-supplied halo up to the grid. Otherwise `strip_y0` lands off
    // the grid and the copied reference rows drift relative to the downscaled
    // distorted strip. Clamping to `height` first prevents overflow and does
    // not change the effective (image-clamped) halo.
    let halo = config.halo_rows.min(height).next_multiple_of(ALIGNMENT);
    let mut acc = StripAccumulator::new(width, height);
    let strip_h = strip_height.max(MIN_STRIP_HEIGHT);
    let mut y = 0usize;
    while y < height {
        let mut next_y = (y + strip_h).next_multiple_of(ALIGNMENT);
        // Merge a short tail into this strip rather than emitting a sliver.
        if next_y >= height || height - next_y < ALIGNMENT {
            next_y = height;
        }
        let interior_start = y;
        let interior_end = next_y;
        let halo_above = halo.min(interior_start);
        let halo_below = halo.min(height - interior_end);
        let strip_y0 = interior_start - halo_above;
        let strip_y1 = interior_end + halo_below;
        let img2_strip = linear_rgb_strip(&img2_full, strip_y0, strip_y1);
        process_dist_strip_with_cached_ref(
            reference,
            img2_strip,
            strip_y0,
            interior_start,
            interior_end,
            &mut acc,
            config.inner,
        );
        y = next_y;
    }
    acc.finalise()
}
/// Scores one strip of the distorted image against the matching rows of a
/// precomputed reference, and adds the interior-row sums for every scale to `acc`.
///
/// Unlike `process_strip`, the reference side is not re-derived from pixels.
/// At each scale its planes are copied out of `reference.scale_planes(scale)`
/// starting at `strip_y0_in_ref`, which is the strip's first row halved once
/// per scale. The reference means and moments are then blurred over the
/// strip rows only.
///
/// * `img2_strip`: distorted rows `strip_y0..strip_y0 + strip_rows` (interior + halo).
/// * `interior_start` / `interior_end`: absolute full-resolution rows to score.
///
/// NOTE(review): `strip_y0_in_ref /= 2` floors. The copied reference rows
/// only match the downscaled distorted strip when `strip_y0` is divisible by
/// every downscale factor. Otherwise only the `debug_assert_eq!` below reports it.
fn process_dist_strip_with_cached_ref(
    reference: &Ssimulacra2Reference,
    img2_strip: LinearRgb,
    strip_y0: usize,
    interior_start: usize,
    interior_end: usize,
    acc: &mut StripAccumulator,
    config: Ssimulacra2Config,
) {
    let impl_type = config.impl_type;
    let mut img2 = img2_strip;
    let mut width = img2.width().get();
    let mut height = img2.height().get();
    let total_scales = acc.per_scale.len();
    // Scratch planes are sized for scale 0 and resized in place per scale.
    let alloc_plane = |w: usize, h: usize| vec![0.0f32; w * h];
    let alloc_3planes =
        |w: usize, h: usize| [alloc_plane(w, h), alloc_plane(w, h), alloc_plane(w, h)];
    let mut mul = alloc_3planes(width, height);
    let mut sigma1_sq_strip = alloc_3planes(width, height);
    let mut sigma2_sq = alloc_3planes(width, height);
    let mut sigma12 = alloc_3planes(width, height);
    let mut mu1_strip = alloc_3planes(width, height);
    let mut mu2 = alloc_3planes(width, height);
    let mut img2_planar = alloc_3planes(width, height);
    let mut img1_planar_strip = alloc_3planes(width, height);
    let mut blur = Blur::with_simd_impl(width, height, impl_type);
    // Interior bounds relative to the strip's first row, at the current scale.
    let mut interior_start_in_strip = interior_start - strip_y0;
    let mut interior_end_in_strip = interior_end - strip_y0;
    // Row in the reference's per-scale planes where this strip begins.
    let mut strip_y0_in_ref = strip_y0;
    for scale in 0..total_scales {
        // Size gate checked before downscaling, as in `StripAccumulator::new`.
        if width < 8 || height < 8 {
            break;
        }
        if scale > 0 {
            img2 = downscale_by_2(&img2);
            width = img2.width().get();
            height = img2.height().get();
            interior_start_in_strip = interior_start_in_strip.div_ceil(2);
            interior_end_in_strip = interior_end_in_strip.div_ceil(2);
            strip_y0_in_ref /= 2;
        }
        interior_start_in_strip = interior_start_in_strip.min(height);
        interior_end_in_strip = interior_end_in_strip.min(height);
        if interior_start_in_strip >= interior_end_in_strip {
            continue;
        }
        let ref_planes = reference
            .scale_planes(scale)
            .expect("scale index is bounded by total_scales which equals reference.num_scales()");
        let ref_width = ref_planes.width;
        let ref_height = ref_planes.height;
        debug_assert_eq!(ref_width, width);
        // Number of reference rows available from `strip_y0_in_ref`, capped at
        // the distorted strip height. Equal to `height` when strips are aligned.
        let ref_h_for_strip = (ref_height - strip_y0_in_ref).min(height);
        let actual_strip_h = height.min(ref_h_for_strip);
        debug_assert_eq!(
            actual_strip_h, height,
            "strip walker scale={scale} ref_h={ref_height} strip_y0_in_ref={strip_y0_in_ref} \
             dist_strip_h={height} ref_h_for_strip={ref_h_for_strip} — alignment regression"
        );
        let size = width * actual_strip_h;
        for buf in [
            &mut mul,
            &mut sigma1_sq_strip,
            &mut sigma2_sq,
            &mut sigma12,
            &mut mu1_strip,
            &mut mu2,
            &mut img2_planar,
            &mut img1_planar_strip,
        ] {
            for c in buf.iter_mut() {
                c.resize(size, 0.0);
                c.truncate(size);
            }
        }
        blur.shrink_to(width, actual_strip_h);
        // Copy this strip's rows out of the cached reference planes
        // (presumably planar XYB, matching `img2_planar` below; confirm in precompute).
        for (dst_chan, src_chan) in img1_planar_strip
            .iter_mut()
            .zip(ref_planes.img1_planar.iter())
        {
            for row in 0..actual_strip_h {
                let src_row_start = (strip_y0_in_ref + row) * width;
                let dst_row_start = row * width;
                dst_chan[dst_row_start..dst_row_start + width]
                    .copy_from_slice(&src_chan[src_row_start..src_row_start + width]);
            }
        }
        // Distorted side: convert the (downscaled) strip to planar XYB.
        let mut img2_xyb = linear_rgb_to_xyb_simd(img2.clone());
        make_positive_xyb(&mut img2_xyb);
        xyb_to_planar_into(&img2_xyb, &mut img2_planar);
        // Blurred means and second moments over the whole strip (halo included).
        blur.blur_into(&img1_planar_strip, &mut mu1_strip);
        image_multiply(&img1_planar_strip, &img1_planar_strip, &mut mul, impl_type);
        blur.blur_into(&mul, &mut sigma1_sq_strip);
        image_multiply(&img2_planar, &img2_planar, &mut mul, impl_type);
        blur.blur_into(&mul, &mut sigma2_sq);
        image_multiply(&img1_planar_strip, &img2_planar, &mut mul, impl_type);
        blur.blur_into(&mul, &mut sigma12);
        blur.blur_into(&img2_planar, &mut mu2);
        // Only the interior rows are scored.
        let ssim_sums = ssim_map_strip(
            scale,
            width,
            interior_start_in_strip,
            interior_end_in_strip,
            &mu1_strip,
            &mu2,
            &sigma1_sq_strip,
            &sigma2_sq,
            &sigma12,
            total_scales,
        );
        let edge_sums = edge_diff_map_strip(
            scale,
            width,
            interior_start_in_strip,
            interior_end_in_strip,
            &img1_planar_strip,
            &mu1_strip,
            &img2_planar,
            &mu2,
            total_scales,
        );
        let interior_h = interior_end_in_strip - interior_start_in_strip;
        let interior_pixels = (interior_h as u64) * (width as u64);
        acc.add_strip_sums(scale, &ssim_sums, &edge_sums, interior_pixels);
    }
}