1use crate::blur::Blur;
39use crate::input::ToLinearRgb;
40use crate::{
41 downscale_by_2, edge_diff_map, image_multiply, linear_rgb_to_xyb_simd, make_positive_xyb,
42 ssim_map, xyb_to_planar, LinearRgb, Msssim, MsssimScale, SimdImpl, Ssimulacra2Error,
43 NUM_SCALES,
44};
45
/// Reference-side data kept for a single scale of the source image.
///
/// Every field stores three planes, one per channel of the planar image.
#[derive(Debug, Clone)]
struct ScaleData {
    /// Planar XYB source image at this scale, after `make_positive_xyb`.
    img1_planar: [Vec<f32>; 3],
    /// `img1_planar` passed through the blur.
    mu1: [Vec<f32>; 3],
    /// Blur of `img1_planar` multiplied with itself via `image_multiply`.
    sigma1_sq: [Vec<f32>; 3],
}
56
57#[derive(Clone, Debug)]
66pub struct Ssimulacra2Reference {
67 scales: Vec<ScaleData>,
68 original_width: usize,
69 original_height: usize,
70}
71
72impl Ssimulacra2Reference {
73 pub fn new<T: ToLinearRgb>(source: T) -> Result<Self, Ssimulacra2Error> {
83 let mut img1: LinearRgb = source.to_linear_rgb().into();
84 if img1.width() < 8 || img1.height() < 8 {
85 return Err(Ssimulacra2Error::InvalidImageSize);
86 }
87
88 let original_width = img1.width();
89 let original_height = img1.height();
90 let mut width = original_width;
91 let mut height = original_height;
92
93 let mut mul = [
94 vec![0.0f32; width * height],
95 vec![0.0f32; width * height],
96 vec![0.0f32; width * height],
97 ];
98 let mut blur = Blur::new(width, height);
99 let mut scales = Vec::with_capacity(NUM_SCALES);
100
101 for scale in 0..NUM_SCALES {
102 if width < 8 || height < 8 {
103 break;
104 }
105
106 if scale > 0 {
107 img1 = downscale_by_2(&img1);
108 width = img1.width();
109 height = img1.height();
110 }
111
112 for c in &mut mul {
113 c.truncate(width * height);
114 }
115 blur.shrink_to(width, height);
116
117 let mut img1_xyb = linear_rgb_to_xyb_simd(img1.clone());
118 make_positive_xyb(&mut img1_xyb);
119
120 let img1_planar = xyb_to_planar(&img1_xyb);
121
122 let mu1 = blur.blur(&img1_planar);
124
125 image_multiply(&img1_planar, &img1_planar, &mut mul, SimdImpl::default());
127 let sigma1_sq = blur.blur(&mul);
128
129 scales.push(ScaleData {
130 img1_planar,
131 mu1,
132 sigma1_sq,
133 });
134 }
135
136 Ok(Self {
137 scales,
138 original_width,
139 original_height,
140 })
141 }
142
143 pub fn compare<T: ToLinearRgb>(&self, distorted: T) -> Result<f64, Ssimulacra2Error> {
151 let mut img2: LinearRgb = distorted.to_linear_rgb().into();
152 if img2.width() != self.original_width || img2.height() != self.original_height {
153 return Err(Ssimulacra2Error::NonMatchingImageDimensions);
154 }
155
156 let mut width = img2.width();
157 let mut height = img2.height();
158
159 let mut mul = [
160 vec![0.0f32; width * height],
161 vec![0.0f32; width * height],
162 vec![0.0f32; width * height],
163 ];
164 let mut blur = Blur::new(width, height);
165 let mut msssim = Msssim::default();
166
167 for (scale_idx, scale_data) in self.scales.iter().enumerate() {
168 if width < 8 || height < 8 {
169 break;
170 }
171
172 if scale_idx > 0 {
173 img2 = downscale_by_2(&img2);
174 width = img2.width();
175 height = img2.height();
176 }
177
178 for c in &mut mul {
179 c.truncate(width * height);
180 }
181 blur.shrink_to(width, height);
182
183 let mut img2_xyb = linear_rgb_to_xyb_simd(img2.clone());
184 make_positive_xyb(&mut img2_xyb);
185
186 let img2_planar = xyb_to_planar(&img2_xyb);
187
188 let mu2 = blur.blur(&img2_planar);
190
191 image_multiply(&img2_planar, &img2_planar, &mut mul, SimdImpl::default());
193 let sigma2_sq = blur.blur(&mul);
194
195 image_multiply(
197 &scale_data.img1_planar,
198 &img2_planar,
199 &mut mul,
200 SimdImpl::default(),
201 );
202 let sigma12 = blur.blur(&mul);
203
204 let avg_ssim = ssim_map(
206 width,
207 height,
208 &scale_data.mu1,
209 &mu2,
210 &scale_data.sigma1_sq,
211 &sigma2_sq,
212 &sigma12,
213 SimdImpl::default(),
214 );
215
216 let avg_edgediff = edge_diff_map(
217 width,
218 height,
219 &scale_data.img1_planar,
220 &scale_data.mu1,
221 &img2_planar,
222 &mu2,
223 SimdImpl::default(),
224 );
225
226 msssim.scales.push(MsssimScale {
227 avg_ssim,
228 avg_edgediff,
229 });
230 }
231
232 Ok(msssim.score())
233 }
234
235 #[must_use]
237 pub fn width(&self) -> usize {
238 self.original_width
239 }
240
241 #[must_use]
243 pub fn height(&self) -> usize {
244 self.original_height
245 }
246
247 #[must_use]
249 pub fn num_scales(&self) -> usize {
250 self.scales.len()
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compute_frame_ssimulacra2;
    use yuvxyb::{ColorPrimaries, Rgb, TransferCharacteristic};

    /// Wraps interleaved RGB pixels in an `Rgb` image with sRGB transfer and
    /// BT.709 primaries.
    fn srgb_image(data: Vec<[f32; 3]>, width: usize, height: usize) -> Rgb {
        Rgb::new(
            data,
            width,
            height,
            TransferCharacteristic::SRGB,
            ColorPrimaries::BT709,
        )
        .unwrap()
    }

    /// Asserts that scoring through a precomputed reference agrees with the
    /// one-shot `compute_frame_ssimulacra2`.
    fn assert_matches_full_compute(source: Rgb, distorted: Rgb) {
        let full_score = compute_frame_ssimulacra2(source.clone(), distorted.clone()).unwrap();
        let precomputed_score = Ssimulacra2Reference::new(source)
            .unwrap()
            .compare(distorted)
            .unwrap();

        assert!(
            (full_score - precomputed_score).abs() < 1e-6,
            "Scores don't match: full={}, precomputed={}",
            full_score,
            precomputed_score
        );
    }

    #[test]
    fn test_precompute_matches_full_compute() {
        let width = 64;
        let height = 64;
        let source_data: Vec<[f32; 3]> = (0..width * height)
            .map(|i| {
                let x = (i % width) as f32 / width as f32;
                let y = (i / width) as f32 / height as f32;
                [x, y, 0.5]
            })
            .collect();
        let distorted_data: Vec<[f32; 3]> = source_data
            .iter()
            .map(|&[r, g, b]| [r * 0.9, g * 0.95, b * 1.05])
            .collect();

        assert_matches_full_compute(
            srgb_image(source_data, width, height),
            srgb_image(distorted_data, width, height),
        );
    }

    #[test]
    fn test_precompute_matches_full_compute_with_early_stop() {
        // Non-square input whose height drops below 8 before NUM_SCALES is
        // reached. Assuming `downscale_by_2` halves each dimension, the scales
        // are 32x24, 16x12 and 8x6. The check before the next downscale then
        // sees a height of 6 and stops.
        let (width, height) = (32, 24);
        let source_data: Vec<[f32; 3]> = (0..width * height)
            .map(|i| {
                let v = (i % 17) as f32 / 17.0;
                [v, 1.0 - v, 0.25]
            })
            .collect();
        let distorted_data: Vec<[f32; 3]> = source_data
            .iter()
            .map(|&[r, g, b]| [r * 0.8, g, b])
            .collect();

        let source = srgb_image(source_data, width, height);
        let reference = Ssimulacra2Reference::new(source.clone()).unwrap();
        assert_eq!(reference.num_scales(), NUM_SCALES.min(3));

        assert_matches_full_compute(source, srgb_image(distorted_data, width, height));
    }

    #[test]
    fn test_precompute_reusable_across_compares() {
        let (width, height) = (16, 16);
        let source_data: Vec<[f32; 3]> = (0..width * height)
            .map(|i| [(i % width) as f32 / width as f32, 0.3, 0.2])
            .collect();
        let distorted_data: Vec<[f32; 3]> = source_data
            .iter()
            .map(|&[r, g, b]| [r, g * 1.1, b])
            .collect();

        let reference = Ssimulacra2Reference::new(srgb_image(source_data, width, height)).unwrap();
        let distorted = srgb_image(distorted_data, width, height);

        // `compare` takes `&self`; repeated calls must not depend on each other.
        let first = reference.compare(distorted.clone()).unwrap();
        let second = reference.compare(distorted).unwrap();
        assert_eq!(first.to_bits(), second.to_bits());
    }

    #[test]
    fn test_precompute_rejects_small_image() {
        let source = srgb_image(vec![[0.5, 0.5, 0.5]; 7 * 7], 7, 7);

        assert!(matches!(
            Ssimulacra2Reference::new(source),
            Err(Ssimulacra2Error::InvalidImageSize)
        ));
    }

    #[test]
    fn test_precompute_dimension_mismatch() {
        let source = srgb_image(vec![[0.5, 0.5, 0.5]; 64 * 64], 64, 64);
        let distorted = srgb_image(vec![[0.4, 0.4, 0.4]; 32 * 32], 32, 32);

        let precomputed = Ssimulacra2Reference::new(source).unwrap();
        let result = precomputed.compare(distorted);

        assert!(matches!(
            result,
            Err(Ssimulacra2Error::NonMatchingImageDimensions)
        ));
    }

    #[test]
    fn test_precompute_metadata() {
        let source = srgb_image(vec![[0.5, 0.5, 0.5]; 128 * 96], 128, 96);

        let precomputed = Ssimulacra2Reference::new(source).unwrap();

        assert_eq!(precomputed.width(), 128);
        assert_eq!(precomputed.height(), 96);
        assert!(precomputed.num_scales() > 0);
        assert!(precomputed.num_scales() <= NUM_SCALES);
    }
}