1use std::path::{Path, PathBuf};
9
10use serde::{Deserialize, Serialize};
11use tokio::process::Command;
12use tracing::warn;
13use viser_ffmpeg::{ProbeCache, ffmpeg_path};
14
15pub mod noref;
16pub mod pool;
17pub use noref::{NoRefOpts, NoRefResult, measure_noref};
18pub use pool::{PoolStrategy, PooledStats};
19
20#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(rename_all = "lowercase")]
23pub enum Metric {
24 #[default]
26 Vmaf,
27 Psnr,
29 Ssim,
31 Ssimulacra2,
33 Butteraugli,
35 MsSsim,
37 Vif,
39 Cambi,
41 Xpsnr,
43}
44
45#[derive(Debug, Clone, Default, Serialize, Deserialize)]
49#[serde(default)]
50pub struct Result {
51 pub vmaf: f64,
53 pub psnr: f64,
55 pub psnr_u: f64,
57 pub psnr_v: f64,
59 pub psnr_avg: f64,
61 pub ssim: f64,
63 pub ssimulacra2: f64,
65 pub butteraugli: f64,
67 pub ms_ssim: f64,
69 pub vif: f64,
71 pub cambi: f64,
73 pub xpsnr: f64,
75 pub pooled: Pooled,
77 #[serde(skip_serializing_if = "Vec::is_empty")]
79 pub frames: Vec<FrameResult>,
80}
81
82#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
84#[serde(default)]
85pub struct Pooled {
86 pub vmaf: PooledStats,
88 pub psnr: PooledStats,
90 pub ssim: PooledStats,
92 pub ssimulacra2: PooledStats,
94 pub butteraugli: PooledStats,
96 pub ms_ssim: PooledStats,
98 pub vif: PooledStats,
100 pub cambi: PooledStats,
102 pub xpsnr: PooledStats,
104}
105
106#[derive(Debug, Clone, Default, Serialize, Deserialize)]
108pub struct FrameResult {
109 pub frame_num: i32,
111 pub vmaf: f64,
113 pub psnr: f64,
115 #[serde(default)]
117 pub psnr_u: f64,
118 #[serde(default)]
120 pub psnr_v: f64,
121 pub ssim: f64,
123 pub ssimulacra2: f64,
125 pub butteraugli: f64,
127 #[serde(default)]
129 pub ms_ssim: f64,
130 #[serde(default)]
132 pub vif: f64,
133 #[serde(default)]
135 pub cambi: f64,
136 #[serde(default)]
138 pub xpsnr: f64,
139}
140
141#[derive(Debug, Clone)]
143pub struct MeasureOpts {
144 pub metrics: Vec<Metric>,
146 pub subsample: i32,
148 pub model: String,
150 pub per_frame: bool,
152 pub frame_samples: usize,
157 pub probe_cache: Option<ProbeCache>,
159}
160
161impl Default for MeasureOpts {
162 fn default() -> Self {
163 Self {
164 metrics: vec![
165 Metric::Vmaf,
166 Metric::Psnr,
167 Metric::Ssim,
168 Metric::Ssimulacra2,
169 Metric::Butteraugli,
170 ],
171 subsample: 0,
172 model: "vmaf_v0.6.1".into(),
173 per_frame: false,
174 frame_samples: 0,
175 probe_cache: None,
176 }
177 }
178}
179
180pub async fn measure(
182 reference: &str,
183 distorted: &str,
184 opts: MeasureOpts,
185) -> anyhow::Result<Result> {
186 let model_name = if opts.model.is_empty() { "vmaf_v0.6.1" } else { &opts.model };
187 let metrics = if opts.metrics.is_empty() {
188 vec![Metric::Vmaf, Metric::Psnr, Metric::Ssim]
189 } else {
190 opts.metrics.clone()
191 };
192
193 if metrics.iter().all(|m| matches!(m, Metric::Psnr | Metric::Ssim)) {
198 return measure_native(reference, distorted, &metrics, &opts).await;
199 }
200
201 let tmp = tempfile::Builder::new().prefix("viser-vmaf-").suffix(".json").tempfile()?;
202 let log_path = tmp.path().to_string_lossy().to_string();
203
204 let mut vmaf_opts = format!("log_fmt=json:log_path={log_path}:model=version={model_name}");
206
207 let mut features: Vec<&str> = Vec::new();
211 for m in &metrics {
212 match m {
213 Metric::Psnr => features.push("name=psnr"),
214 Metric::Ssim => features.push("name=float_ssim"),
215 Metric::MsSsim => features.push("name=float_ms_ssim"),
216 Metric::Cambi => features.push("name=cambi"),
217 Metric::Vmaf | Metric::Vif => {}
219 Metric::Xpsnr | Metric::Ssimulacra2 | Metric::Butteraugli => {}
221 }
222 }
223 if !features.is_empty() {
224 vmaf_opts.push_str(&format!(":feature={}", features.join("|")));
225 }
226
227 if opts.subsample > 0 {
228 vmaf_opts.push_str(&format!(":n_subsample={}", opts.subsample));
229 }
230
231 let ref_info = if let Some(ref cache) = opts.probe_cache {
233 cache.probe(reference).await?
234 } else {
235 viser_ffmpeg::probe(reference).await?
236 };
237
238 let ref_video =
239 ref_info.video_stream().ok_or_else(|| anyhow::anyhow!("no video stream in reference"))?;
240
241 if ref_video.bits_per_raw_sample > 8 {
242 warn!(
243 bits_per_sample = ref_video.bits_per_raw_sample,
244 reference = reference,
245 "10-bit content detected; VMAF scores calibrated for 8-bit may differ"
246 );
247 }
248
249 let filtergraph = format!(
250 "[0:v]scale={}:{}:flags=bicubic[dist];[dist][1:v]libvmaf={}",
251 ref_video.width, ref_video.height, vmaf_opts
252 );
253
254 let args = ["-i", distorted, "-i", reference, "-lavfi", &filtergraph, "-f", "null", "-"];
255
256 let output = Command::new(ffmpeg_path())
257 .args(args)
258 .stderr(std::process::Stdio::piped())
259 .output()
260 .await?;
261
262 if !output.status.success() {
263 let stderr = String::from_utf8_lossy(&output.stderr);
264 anyhow::bail!("ffmpeg quality measurement failed: {stderr}");
265 }
266
267 let data = std::fs::read(&log_path)?;
268 let mut result = parse_vmaf_log(&data, opts.per_frame)?;
269
270 if metrics.contains(&Metric::Ssimulacra2) {
272 let scores = measure_ssimulacra2(reference, distorted, &opts).await?;
273 result.ssimulacra2 = pool::PoolStrategy::Mean.apply(&scores);
274 result.pooled.ssimulacra2 = PooledStats::from_values(&scores);
275 }
276
277 if metrics.contains(&Metric::Butteraugli) {
279 let scores = measure_butteraugli(reference, distorted, &opts).await?;
280 result.butteraugli = pool::PoolStrategy::Mean.apply(&scores);
281 result.pooled.butteraugli = PooledStats::from_values(&scores);
282 }
283
284 if metrics.contains(&Metric::Xpsnr) {
286 let scores = measure_xpsnr(reference, distorted, &opts).await?;
287 result.xpsnr = pool::PoolStrategy::Mean.apply(&scores);
288 result.pooled.xpsnr = PooledStats::from_values(&scores);
289 if opts.per_frame && scores.len() == result.frames.len() {
290 for (fr, s) in result.frames.iter_mut().zip(scores) {
291 fr.xpsnr = s;
292 }
293 }
294 }
295
296 Ok(result)
297}
298
299async fn measure_native(
307 reference: &str,
308 distorted: &str,
309 metrics: &[Metric],
310 opts: &MeasureOpts,
311) -> anyhow::Result<Result> {
312 let ref_info = if let Some(ref cache) = opts.probe_cache {
313 cache.probe(reference).await?
314 } else {
315 viser_ffmpeg::probe(reference).await?
316 };
317 let ref_video =
318 ref_info.video_stream().ok_or_else(|| anyhow::anyhow!("no video stream in reference"))?;
319
320 let sel = if opts.subsample > 1 {
323 format!("select=not(mod(n\\,{}))", opts.subsample)
324 } else {
325 "null".to_string()
326 };
327
328 let mut result = Result::default();
329 for m in metrics {
330 let filter_name = match m {
331 Metric::Psnr => "psnr",
332 Metric::Ssim => "ssim",
333 _ => continue,
334 };
335
336 let filtergraph = format!(
337 "[0:v]scale={}:{}:flags=bicubic,{sel}[dist];[1:v]{sel}[ref];[dist][ref]{filter_name}",
338 ref_video.width, ref_video.height
339 );
340 let args = ["-i", distorted, "-i", reference, "-lavfi", &filtergraph, "-f", "null", "-"];
341
342 let output = Command::new(ffmpeg_path())
343 .args(args)
344 .stderr(std::process::Stdio::piped())
345 .output()
346 .await?;
347 if !output.status.success() {
348 let stderr = String::from_utf8_lossy(&output.stderr);
349 anyhow::bail!("ffmpeg {filter_name} measurement failed: {stderr}");
350 }
351
352 let stderr = String::from_utf8_lossy(&output.stderr);
353 match m {
354 Metric::Psnr => {
355 let line = stderr
356 .lines()
357 .rev()
358 .find(|l| l.contains("PSNR") && l.contains("average:"))
359 .ok_or_else(|| anyhow::anyhow!("could not parse PSNR from ffmpeg output"))?;
360 result.psnr = parse_metric_kv(line, "y:").unwrap_or(0.0);
361 result.psnr_u = parse_metric_kv(line, "u:").unwrap_or(0.0);
362 result.psnr_v = parse_metric_kv(line, "v:").unwrap_or(0.0);
363 result.psnr_avg = parse_metric_kv(line, "average:").unwrap_or(result.psnr);
364 }
365 Metric::Ssim => {
366 let line = stderr
367 .lines()
368 .rev()
369 .find(|l| l.contains("SSIM") && l.contains("All:"))
370 .ok_or_else(|| anyhow::anyhow!("could not parse SSIM from ffmpeg output"))?;
371 result.ssim = parse_metric_kv(line, "All:").unwrap_or(0.0);
372 }
373 _ => {}
374 }
375 }
376
377 Ok(result)
378}
379
380fn parse_metric_kv(line: &str, key: &str) -> Option<f64> {
385 let start = line.find(key)? + key.len();
386 let rest = &line[start..];
387 let end = rest
388 .find(|c: char| !matches!(c, '0'..='9' | '.' | '-' | '+' | 'e' | 'E'))
389 .unwrap_or(rest.len());
390 rest[..end].parse().ok()
391}
392
393#[derive(Deserialize)]
395struct VmafLog {
396 frames: Vec<VmafFrame>,
397 #[serde(default)]
398 pooled_metrics: std::collections::HashMap<String, PooledMetric>,
399}
400
401#[derive(Deserialize)]
402struct VmafFrame {
403 #[serde(rename = "frameNum")]
404 frame_num: i32,
405 metrics: std::collections::HashMap<String, f64>,
406}
407
408#[derive(Deserialize)]
409struct PooledMetric {
410 mean: f64,
411}
412
413fn parse_vmaf_log(data: &[u8], per_frame: bool) -> anyhow::Result<Result> {
414 let log: VmafLog = serde_json::from_slice(data)?;
415
416 let mut result = Result::default();
417
418 result.vmaf = pooled_mean(&log, &["vmaf"]);
420 result.psnr = pooled_mean(&log, &["psnr_y", "psnr"]);
421 result.psnr_u = pooled_mean(&log, &["psnr_cb", "psnr_u"]);
422 result.psnr_v = pooled_mean(&log, &["psnr_cr", "psnr_v"]);
423 result.psnr_avg = if result.psnr_u > 0.0 && result.psnr_v > 0.0 {
424 (6.0 * result.psnr + result.psnr_u + result.psnr_v) / 8.0
428 } else {
429 result.psnr
430 };
431 result.ssim = pooled_mean(&log, &["float_ssim", "ssim"]);
432
433 let mut vmaf_series = Vec::with_capacity(log.frames.len());
435 let mut psnr_series = Vec::with_capacity(log.frames.len());
436 let mut ssim_series = Vec::with_capacity(log.frames.len());
437 let mut ms_ssim_series = Vec::with_capacity(log.frames.len());
438 let mut vif_series = Vec::with_capacity(log.frames.len());
439 let mut cambi_series = Vec::with_capacity(log.frames.len());
440 for f in &log.frames {
441 if let Some(v) = f.metrics.get("vmaf") {
442 vmaf_series.push(*v);
443 }
444 if let Some(v) = frame_metric(&f.metrics, &["psnr_y", "psnr"]) {
445 psnr_series.push(v);
446 }
447 if let Some(v) = frame_metric(&f.metrics, &["float_ssim", "ssim"]) {
448 ssim_series.push(v);
449 }
450 if let Some(v) = frame_metric(&f.metrics, &["float_ms_ssim", "ms_ssim"]) {
451 ms_ssim_series.push(v);
452 }
453 if let Some(v) = vif_mean(&f.metrics) {
454 vif_series.push(v);
455 }
456 if let Some(v) = f.metrics.get("cambi") {
457 cambi_series.push(*v);
458 }
459 }
460 result.pooled.vmaf = PooledStats::from_values(&vmaf_series);
461 result.pooled.psnr = PooledStats::from_values(&psnr_series);
462 result.pooled.ssim = PooledStats::from_values(&ssim_series);
463 result.pooled.ms_ssim = PooledStats::from_values(&ms_ssim_series);
464 result.pooled.vif = PooledStats::from_values(&vif_series);
465 result.pooled.cambi = PooledStats::from_values(&cambi_series);
466 result.ms_ssim = result.pooled.ms_ssim.mean;
467 result.vif = result.pooled.vif.mean;
468 result.cambi = result.pooled.cambi.mean;
469
470 if result.vmaf == 0.0 {
472 result.vmaf = result.pooled.vmaf.mean;
473 }
474 if result.psnr == 0.0 {
475 result.psnr = result.pooled.psnr.mean;
476 if result.psnr_avg == 0.0 {
477 result.psnr_avg = result.psnr;
478 }
479 }
480 if result.ssim == 0.0 {
481 result.ssim = result.pooled.ssim.mean;
482 }
483
484 if per_frame {
485 for f in &log.frames {
486 result.frames.push(FrameResult {
487 frame_num: f.frame_num,
488 vmaf: f.metrics.get("vmaf").copied().unwrap_or(0.0),
489 psnr: frame_metric(&f.metrics, &["psnr_y", "psnr"]).unwrap_or(0.0),
490 psnr_u: frame_metric(&f.metrics, &["psnr_cb", "psnr_u"]).unwrap_or(0.0),
491 psnr_v: frame_metric(&f.metrics, &["psnr_cr", "psnr_v"]).unwrap_or(0.0),
492 ssim: frame_metric(&f.metrics, &["float_ssim", "ssim"]).unwrap_or(0.0),
493 ssimulacra2: f.metrics.get("ssimulacra2").copied().unwrap_or(0.0),
494 butteraugli: f.metrics.get("butteraugli").copied().unwrap_or(0.0),
495 ms_ssim: frame_metric(&f.metrics, &["float_ms_ssim", "ms_ssim"]).unwrap_or(0.0),
496 vif: vif_mean(&f.metrics).unwrap_or(0.0),
497 cambi: f.metrics.get("cambi").copied().unwrap_or(0.0),
498 xpsnr: 0.0,
499 });
500 }
501 }
502
503 Ok(result)
504}
505
506fn pooled_mean(log: &VmafLog, keys: &[&str]) -> f64 {
508 for k in keys {
509 if let Some(m) = log.pooled_metrics.get(*k) {
510 return m.mean;
511 }
512 }
513 0.0
514}
515
516fn frame_metric(metrics: &std::collections::HashMap<String, f64>, keys: &[&str]) -> Option<f64> {
518 for k in keys {
519 if let Some(v) = metrics.get(*k) {
520 return Some(*v);
521 }
522 }
523 None
524}
525
526fn vif_mean(metrics: &std::collections::HashMap<String, f64>) -> Option<f64> {
529 let mut sum = 0.0;
530 let mut n = 0;
531 for s in 0..4 {
532 if let Some(v) = frame_metric(
533 metrics,
534 &[
535 &format!("integer_vif_scale{s}"),
536 &format!("float_vif_scale{s}"),
537 &format!("vif_scale{s}"),
538 ],
539 ) {
540 sum += v;
541 n += 1;
542 }
543 }
544 if n > 0 { Some(sum / n as f64) } else { None }
545}
546
547fn sample_indices(nb_frames: i32, samples: usize) -> Vec<i32> {
552 if samples <= 1 || nb_frames <= 1 {
553 return vec![0];
554 }
555 let count = samples.min(nb_frames as usize);
556 if count <= 1 {
557 return vec![0];
558 }
559 (0..count)
560 .map(|i| ((i as f64) * (nb_frames as f64 - 1.0) / (count as f64 - 1.0)).round() as i32)
561 .collect()
562}
563
564async fn reference_dims(reference: &str, opts: &MeasureOpts) -> anyhow::Result<(i32, i32, i32)> {
566 let ref_info = if let Some(ref cache) = opts.probe_cache {
567 cache.probe(reference).await?
568 } else {
569 viser_ffmpeg::probe(reference).await?
570 };
571 let ref_video =
572 ref_info.video_stream().ok_or_else(|| anyhow::anyhow!("no video stream in reference"))?;
573 Ok((ref_video.width, ref_video.height, ref_video.nb_frames))
574}
575
576async fn extract_frames_png(
583 input: &str,
584 selection: Option<&[i32]>,
585 width: i32,
586 height: i32,
587 dir: &Path,
588) -> anyhow::Result<Vec<PathBuf>> {
589 let scale = format!("scale={width}:{height}:flags=bicubic");
590 let vf = match selection {
591 None => scale,
592 Some(indices) => {
593 let sel = indices.iter().map(|i| format!("eq(n\\,{i})")).collect::<Vec<_>>().join("+");
594 format!("select='{sel}',{scale}")
595 }
596 };
597 let pattern = dir.join("%06d.png");
598 let output = Command::new(ffmpeg_path())
599 .args(["-i", input, "-vf", &vf, "-fps_mode", "passthrough", "-c:v", "png"])
600 .arg(&pattern)
601 .stderr(std::process::Stdio::piped())
602 .output()
603 .await?;
604
605 if !output.status.success() {
606 let stderr = String::from_utf8_lossy(&output.stderr);
607 anyhow::bail!("failed to extract frames from {input}: {stderr}");
608 }
609
610 let mut paths: Vec<PathBuf> = std::fs::read_dir(dir)?
611 .filter_map(|e| e.ok().map(|e| e.path()))
612 .filter(|p| p.extension().is_some_and(|x| x == "png"))
613 .collect();
614 paths.sort();
615 Ok(paths)
616}
617
618struct FramePairs {
622 _ref_dir: tempfile::TempDir,
623 _dist_dir: tempfile::TempDir,
624 pairs: Vec<(PathBuf, PathBuf)>,
625}
626
627async fn extract_frame_pairs(
628 reference: &str,
629 distorted: &str,
630 opts: &MeasureOpts,
631) -> anyhow::Result<FramePairs> {
632 let (width, height, nb_frames) = reference_dims(reference, opts).await?;
633 let (_, _, dist_nb_frames) = reference_dims(distorted, opts).await?;
634 if dist_nb_frames != nb_frames {
635 warn!(
636 reference_frames = nb_frames,
637 distorted_frames = dist_nb_frames,
638 "reference and distorted frame counts differ; sampled perceptual metrics may be misaligned"
639 );
640 }
641
642 let selection: Option<Vec<i32>> = if opts.frame_samples == 0 {
643 None
644 } else {
645 Some(sample_indices(nb_frames, opts.frame_samples))
646 };
647 let sel = selection.as_deref();
648
649 let ref_dir = tempfile::Builder::new().prefix("viser-q-ref-").tempdir()?;
650 let dist_dir = tempfile::Builder::new().prefix("viser-q-dist-").tempdir()?;
651 let ref_paths = extract_frames_png(reference, sel, width, height, ref_dir.path()).await?;
652 let dist_paths = extract_frames_png(distorted, sel, width, height, dist_dir.path()).await?;
653
654 let n = ref_paths.len().min(dist_paths.len());
655 let pairs =
656 ref_paths.into_iter().take(n).zip(dist_paths.into_iter().take(n)).collect::<Vec<_>>();
657 Ok(FramePairs { _ref_dir: ref_dir, _dist_dir: dist_dir, pairs })
658}
659
660async fn measure_ssimulacra2(
662 reference: &str,
663 distorted: &str,
664 opts: &MeasureOpts,
665) -> anyhow::Result<Vec<f64>> {
666 let frames = extract_frame_pairs(reference, distorted, opts).await?;
667 let mut scores = Vec::with_capacity(frames.pairs.len());
668 for (ref_png, dist_png) in &frames.pairs {
669 let s2_output = Command::new("ssimulacra2")
670 .arg(ref_png)
671 .arg(dist_png)
672 .stdout(std::process::Stdio::piped())
673 .stderr(std::process::Stdio::null())
674 .output()
675 .await?;
676
677 if !s2_output.status.success() {
678 anyhow::bail!("ssimulacra2 failed: {}", String::from_utf8_lossy(&s2_output.stderr));
679 }
680
681 let stdout_str = String::from_utf8_lossy(&s2_output.stdout);
682 let score: f64 = stdout_str
683 .trim()
684 .parse()
685 .map_err(|_| anyhow::anyhow!("ssimulacra2: could not parse score: {stdout_str}"))?;
686 scores.push(score);
687 }
688
689 Ok(scores)
690}
691
692async fn measure_butteraugli(
697 reference: &str,
698 distorted: &str,
699 opts: &MeasureOpts,
700) -> anyhow::Result<Vec<f64>> {
701 let frames = extract_frame_pairs(reference, distorted, opts).await?;
702 let mut scores = Vec::with_capacity(frames.pairs.len());
703 for (i, (ref_png, dist_png)) in frames.pairs.iter().enumerate() {
704 let ba_output = Command::new("butteraugli")
705 .arg(ref_png)
706 .arg(dist_png)
707 .stdout(std::process::Stdio::piped())
708 .stderr(std::process::Stdio::null())
709 .output()
710 .await;
711
712 let mut score = 0.0;
713 let mut parsed = false;
714 if let Ok(out) = ba_output
715 && out.status.success()
716 {
717 let stdout_str = String::from_utf8_lossy(&out.stdout);
718 if let Ok(s) = stdout_str.trim().parse::<f64>() {
719 score = s;
720 parsed = true;
721 } else if let Some(last_line) = stdout_str.lines().last() {
722 if let Ok(s) = last_line.trim().parse::<f64>() {
724 score = s;
725 parsed = true;
726 }
727 }
728 }
729 if !parsed {
730 warn!(frame = i, "butteraugli not available or failed; recording 0.0");
731 }
732 scores.push(score);
733 }
734
735 Ok(scores)
736}
737
738fn parse_xpsnr_component(line: &str, tag: &str) -> Option<f64> {
741 let idx = line.find(tag)?;
742 let token = line[idx + tag.len()..].split_whitespace().next()?;
743 match token {
744 "inf" | "-inf" => Some(100.0),
745 t => t.parse::<f64>().ok().map(|x| if x.is_finite() { x } else { 100.0 }),
746 }
747}
748
749async fn measure_xpsnr(
752 reference: &str,
753 distorted: &str,
754 opts: &MeasureOpts,
755) -> anyhow::Result<Vec<f64>> {
756 let (width, height, _nb) = reference_dims(reference, opts).await?;
757 let stats = tempfile::Builder::new().prefix("viser-xpsnr-").suffix(".log").tempfile()?;
758 let stats_path = stats.path().to_string_lossy().to_string();
759
760 let filtergraph = format!(
762 "[0:v]scale={width}:{height}:flags=bicubic[dist];[dist][1:v]xpsnr=stats_file={stats_path}"
763 );
764 let output = Command::new(ffmpeg_path())
765 .args(["-i", distorted, "-i", reference, "-lavfi", &filtergraph, "-f", "null", "-"])
766 .stderr(std::process::Stdio::piped())
767 .output()
768 .await?;
769
770 if !output.status.success() {
771 let stderr = String::from_utf8_lossy(&output.stderr);
772 anyhow::bail!("xpsnr measurement failed: {stderr}");
773 }
774
775 let log = std::fs::read_to_string(stats.path())?;
776 let mut scores = Vec::new();
777 for line in log.lines() {
778 if let Some(y) = parse_xpsnr_component(line, "y:") {
780 let u = parse_xpsnr_component(line, "u:").unwrap_or(y);
781 let v = parse_xpsnr_component(line, "v:").unwrap_or(y);
782 scores.push((6.0 * y + u + v) / 8.0);
783 }
784 }
785 Ok(scores)
786}
787
788#[cfg(test)]
789mod tests {
790 use super::*;
791
792 #[test]
793 fn test_metric_serde_roundtrip() {
794 for m in
795 &[Metric::Vmaf, Metric::Psnr, Metric::Ssim, Metric::Ssimulacra2, Metric::Butteraugli]
796 {
797 let json = serde_json::to_string(m).unwrap();
798 let back: Metric = serde_json::from_str(&json).unwrap();
799 assert_eq!(*m, back);
800 }
801 }
802
803 #[test]
804 fn test_metric_serde_names() {
805 assert_eq!(serde_json::to_string(&Metric::Vmaf).unwrap(), "\"vmaf\"");
806 assert_eq!(serde_json::to_string(&Metric::Psnr).unwrap(), "\"psnr\"");
807 assert_eq!(serde_json::to_string(&Metric::Ssim).unwrap(), "\"ssim\"");
808 assert_eq!(serde_json::to_string(&Metric::Ssimulacra2).unwrap(), "\"ssimulacra2\"");
809 assert_eq!(serde_json::to_string(&Metric::Butteraugli).unwrap(), "\"butteraugli\"");
810 }
811
812 #[test]
813 fn test_metric_eq() {
814 assert_eq!(Metric::Vmaf, Metric::Vmaf);
815 assert_ne!(Metric::Vmaf, Metric::Psnr);
816 assert_eq!(Metric::Ssimulacra2, Metric::Ssimulacra2);
817 assert_ne!(Metric::Ssimulacra2, Metric::Butteraugli);
818 }
819
820 #[test]
821 fn test_result_default() {
822 let r = Result::default();
823 assert!((r.vmaf - 0.0).abs() < 1e-9);
824 assert!((r.psnr - 0.0).abs() < 1e-9);
825 assert!((r.ssim - 0.0).abs() < 1e-9);
826 assert!((r.ssimulacra2 - 0.0).abs() < 1e-9);
827 assert!((r.butteraugli - 0.0).abs() < 1e-9);
828 assert!(r.frames.is_empty());
829 }
830
831 #[test]
832 fn test_parse_vmaf_log_basic() {
833 let json = br#"{
834 "frames": [
835 {"frameNum": 0, "metrics": {"vmaf": 85.0, "psnr_y": 38.5, "float_ssim": 0.95}}
836 ],
837 "pooled_metrics": {
838 "vmaf": {"mean": 86.5},
839 "psnr_y": {"mean": 39.2},
840 "float_ssim": {"mean": 0.96}
841 }
842 }"#;
843 let result = parse_vmaf_log(json, false).unwrap();
844 assert!((result.vmaf - 86.5).abs() < 1e-9);
845 assert!((result.psnr - 39.2).abs() < 1e-9);
846 assert!((result.ssim - 0.96).abs() < 1e-9);
847 assert!(result.frames.is_empty());
848 }
849
850 #[test]
851 fn test_parse_vmaf_log_per_frame() {
852 let json = br#"{
853 "frames": [
854 {"frameNum": 0, "metrics": {"vmaf": 80.0, "psnr_y": 37.0, "float_ssim": 0.93}},
855 {"frameNum": 1, "metrics": {"vmaf": 90.0, "psnr_y": 40.0, "float_ssim": 0.97}}
856 ],
857 "pooled_metrics": {
858 "vmaf": {"mean": 85.0},
859 "psnr_y": {"mean": 38.5},
860 "float_ssim": {"mean": 0.95}
861 }
862 }"#;
863 let result = parse_vmaf_log(json, true).unwrap();
864 assert_eq!(result.frames.len(), 2);
865 assert_eq!(result.frames[0].frame_num, 0);
866 assert!((result.frames[0].vmaf - 80.0).abs() < 1e-9);
867 assert_eq!(result.frames[1].frame_num, 1);
868 assert!((result.frames[1].vmaf - 90.0).abs() < 1e-9);
869 }
870
871 #[test]
872 fn test_parse_vmaf_log_fallback_psnr() {
873 let json = br#"{
874 "frames": [],
875 "pooled_metrics": {
876 "vmaf": {"mean": 85.0},
877 "psnr": {"mean": 39.0},
878 "ssim": {"mean": 0.94}
879 }
880 }"#;
881 let result = parse_vmaf_log(json, false).unwrap();
882 assert!((result.psnr - 39.0).abs() < 1e-9);
883 }
884
885 #[test]
886 fn test_parse_vmaf_log_missing_metrics() {
887 let json = br#"{
888 "frames": [],
889 "pooled_metrics": {}
890 }"#;
891 let result = parse_vmaf_log(json, false).unwrap();
892 assert!((result.vmaf - 0.0).abs() < 1e-9);
893 assert!((result.psnr - 0.0).abs() < 1e-9);
894 assert!((result.ssim - 0.0).abs() < 1e-9);
895 }
896
897 #[test]
898 fn test_parse_vmaf_log_invalid_json() {
899 assert!(parse_vmaf_log(b"not json", false).is_err());
900 }
901
902 #[test]
903 fn test_result_serde_roundtrip() {
904 let r = Result {
905 vmaf: 85.0,
906 psnr: 38.5,
907 ssim: 0.95,
908 ssimulacra2: 70.0,
909 butteraugli: 0.5,
910 ..Default::default()
911 };
912 let json = serde_json::to_string(&r).unwrap();
913 let back: Result = serde_json::from_str(&json).unwrap();
914 assert!((back.vmaf - 85.0).abs() < 1e-9);
915 assert!((back.ssimulacra2 - 70.0).abs() < 1e-9);
916 assert!((back.butteraugli - 0.5).abs() < 1e-9);
917 }
918
919 #[test]
920 fn test_parse_vmaf_log_per_component_psnr() {
921 let json = br#"{
922 "frames": [],
923 "pooled_metrics": {
924 "vmaf": {"mean": 85.0},
925 "psnr_y": {"mean": 40.0},
926 "psnr_cb": {"mean": 44.0},
927 "psnr_cr": {"mean": 46.0},
928 "float_ssim": {"mean": 0.95}
929 }
930 }"#;
931 let result = parse_vmaf_log(json, false).unwrap();
932 assert!((result.psnr - 40.0).abs() < 1e-9, "luma");
933 assert!((result.psnr_u - 44.0).abs() < 1e-9, "Cb");
934 assert!((result.psnr_v - 46.0).abs() < 1e-9, "Cr");
935 assert!((result.psnr_avg - 41.25).abs() < 1e-9, "weighted avg");
937 }
938
939 #[test]
940 fn test_parse_vmaf_log_psnr_avg_falls_back_to_luma() {
941 let json = br#"{
942 "frames": [],
943 "pooled_metrics": {"psnr_y": {"mean": 39.0}}
944 }"#;
945 let result = parse_vmaf_log(json, false).unwrap();
946 assert!((result.psnr_avg - 39.0).abs() < 1e-9);
947 }
948
949 #[test]
950 fn test_parse_vmaf_log_pooled_distribution() {
951 let json = br#"{
952 "frames": [
953 {"frameNum": 0, "metrics": {"vmaf": 80.0, "psnr_y": 37.0, "float_ssim": 0.93}},
954 {"frameNum": 1, "metrics": {"vmaf": 90.0, "psnr_y": 41.0, "float_ssim": 0.97}}
955 ],
956 "pooled_metrics": {"vmaf": {"mean": 85.0}}
957 }"#;
958 let result = parse_vmaf_log(json, false).unwrap();
959 assert_eq!(result.pooled.vmaf.count, 2);
960 assert!((result.pooled.vmaf.min - 80.0).abs() < 1e-9);
961 assert!((result.pooled.vmaf.max - 90.0).abs() < 1e-9);
962 assert!((result.pooled.vmaf.mean - 85.0).abs() < 1e-9);
963 assert!((result.pooled.psnr.min - 37.0).abs() < 1e-9);
965 assert!((result.psnr - 39.0).abs() < 1e-9, "psnr falls back to frame mean");
966 }
967
968 #[test]
969 fn test_sample_indices() {
970 assert_eq!(sample_indices(100, 0), vec![0]);
971 assert_eq!(sample_indices(100, 1), vec![0]);
972 assert_eq!(sample_indices(0, 5), vec![0]);
973 assert_eq!(sample_indices(1, 5), vec![0]);
974 assert_eq!(sample_indices(101, 3), vec![0, 50, 100]);
975 assert_eq!(sample_indices(2, 10), vec![0, 1]);
977 }
978
979 #[test]
980 fn test_result_serde_omits_zero_frames() {
981 let r = Result::default();
982 let json = serde_json::to_string(&r).unwrap();
983 assert!(!json.contains("frames"));
984 }
985
986 #[test]
987 fn test_measure_opts_default() {
988 let opts = MeasureOpts::default();
989 assert_eq!(opts.metrics.len(), 5);
990 assert_eq!(opts.subsample, 0);
991 assert_eq!(opts.model, "vmaf_v0.6.1");
992 assert!(!opts.per_frame);
993 assert_eq!(opts.frame_samples, 0);
994 assert!(opts.probe_cache.is_none());
995 }
996
997 #[test]
998 fn test_vif_mean() {
999 let mut m = std::collections::HashMap::new();
1000 m.insert("integer_vif_scale0".to_string(), 0.2);
1001 m.insert("integer_vif_scale1".to_string(), 0.4);
1002 m.insert("integer_vif_scale2".to_string(), 0.6);
1003 m.insert("integer_vif_scale3".to_string(), 0.8);
1004 assert!((vif_mean(&m).unwrap() - 0.5).abs() < 1e-9);
1005
1006 let mut m2 = std::collections::HashMap::new();
1008 m2.insert("vif_scale0".to_string(), 1.0);
1009 m2.insert("float_vif_scale1".to_string(), 0.0);
1010 assert!((vif_mean(&m2).unwrap() - 0.5).abs() < 1e-9);
1011
1012 assert!(vif_mean(&std::collections::HashMap::new()).is_none());
1013 }
1014
1015 #[test]
1016 fn test_parse_xpsnr_component() {
1017 let line = "n: 1 XPSNR y: 46.9714 XPSNR u: 45.1188 XPSNR v: 45.0873";
1018 assert!((parse_xpsnr_component(line, "y:").unwrap() - 46.9714).abs() < 1e-9);
1019 assert!((parse_xpsnr_component(line, "u:").unwrap() - 45.1188).abs() < 1e-9);
1020 assert!((parse_xpsnr_component(line, "v:").unwrap() - 45.0873).abs() < 1e-9);
1021 assert_eq!(parse_xpsnr_component("XPSNR y: inf", "y:"), Some(100.0));
1023 assert_eq!(parse_xpsnr_component("nothing here", "y:"), None);
1024 }
1025
1026 #[test]
1027 fn test_parse_vmaf_log_extended_metrics() {
1028 let json = br#"{
1029 "frames": [
1030 {"frameNum": 0, "metrics": {"vmaf": 80.0, "float_ms_ssim": 0.90, "cambi": 2.0,
1031 "integer_vif_scale0": 0.2, "integer_vif_scale1": 0.4,
1032 "integer_vif_scale2": 0.6, "integer_vif_scale3": 0.8}},
1033 {"frameNum": 1, "metrics": {"vmaf": 90.0, "float_ms_ssim": 1.00, "cambi": 0.0,
1034 "integer_vif_scale0": 0.4, "integer_vif_scale1": 0.6,
1035 "integer_vif_scale2": 0.8, "integer_vif_scale3": 1.0}}
1036 ],
1037 "pooled_metrics": {"vmaf": {"mean": 85.0}}
1038 }"#;
1039 let result = parse_vmaf_log(json, true).unwrap();
1040 assert!((result.ms_ssim - 0.95).abs() < 1e-9);
1042 assert!((result.cambi - 1.0).abs() < 1e-9);
1044 assert!((result.vif - 0.6).abs() < 1e-9);
1046 assert!((result.frames[0].ms_ssim - 0.90).abs() < 1e-9);
1048 assert!((result.frames[0].vif - 0.5).abs() < 1e-9);
1049 assert!((result.frames[1].cambi - 0.0).abs() < 1e-9);
1050 }
1051
1052 #[test]
1054 fn test_parse_vmaf_log_ssim_no_float_prefix() {
1055 let json = br#"{
1056 "frames": [{"frameNum": 0, "metrics": {"ssim": 0.92}}],
1057 "pooled_metrics": {"ssim": {"mean": 0.92}}
1058 }"#;
1059 let result = parse_vmaf_log(json, false).unwrap();
1060 assert!((result.ssim - 0.92).abs() < 1e-9);
1061 }
1062
1063 #[test]
1064 fn test_parse_vmaf_log_ms_ssim_fallback_name() {
1065 let json = br#"{
1066 "frames": [{"frameNum": 0, "metrics": {"ms_ssim": 0.88}}],
1067 "pooled_metrics": {}
1068 }"#;
1069 let result = parse_vmaf_log(json, false).unwrap();
1070 assert!((result.ms_ssim - 0.88).abs() < 1e-9);
1071 }
1072
1073 #[test]
1074 fn test_parse_vmaf_log_psnr_cb_cr_fallback_names() {
1075 let json = br#"{
1076 "frames": [],
1077 "pooled_metrics": {
1078 "psnr_y": {"mean": 40.0},
1079 "psnr_cb": {"mean": 44.0},
1080 "psnr_cr": {"mean": 46.0}
1081 }
1082 }"#;
1083 let result = parse_vmaf_log(json, false).unwrap();
1084 assert!((result.psnr_u - 44.0).abs() < 1e-9, "Cb via psnr_cb");
1085 assert!((result.psnr_v - 46.0).abs() < 1e-9, "Cr via psnr_cr");
1086 }
1087
1088 #[test]
1089 fn test_parse_vmaf_log_psnr_u_v_fallback_names() {
1090 let json = br#"{
1091 "frames": [],
1092 "pooled_metrics": {
1093 "psnr_y": {"mean": 40.0},
1094 "psnr_u": {"mean": 43.0},
1095 "psnr_v": {"mean": 45.0}
1096 }
1097 }"#;
1098 let result = parse_vmaf_log(json, false).unwrap();
1099 assert!((result.psnr_u - 43.0).abs() < 1e-9, "Cb via psnr_u");
1100 assert!((result.psnr_v - 45.0).abs() < 1e-9, "Cr via psnr_v");
1101 }
1102
1103 #[test]
1104 fn test_parse_vmaf_log_pooled_missing_fallback_to_frame_mean() {
1105 let json = br#"{
1106 "frames": [
1107 {"frameNum": 0, "metrics": {"vmaf": 80.0, "psnr_y": 36.0, "float_ssim": 0.90}},
1108 {"frameNum": 1, "metrics": {"vmaf": 90.0, "psnr_y": 42.0, "float_ssim": 0.96}}
1109 ],
1110 "pooled_metrics": {}
1111 }"#;
1112 let result = parse_vmaf_log(json, false).unwrap();
1113 assert!((result.vmaf - 85.0).abs() < 1e-9);
1114 assert!((result.psnr - 39.0).abs() < 1e-9);
1115 assert!((result.ssim - 0.93).abs() < 1e-9);
1116 }
1117
1118 #[test]
1119 fn test_parse_vmaf_log_empty_frames_and_pooled() {
1120 let json = br#"{
1121 "frames": [],
1122 "pooled_metrics": {}
1123 }"#;
1124 let result = parse_vmaf_log(json, false).unwrap();
1125 assert!((result.vmaf - 0.0).abs() < 1e-9);
1126 assert!((result.psnr - 0.0).abs() < 1e-9);
1127 assert!((result.ssim - 0.0).abs() < 1e-9);
1128 assert!((result.ms_ssim - 0.0).abs() < 1e-9);
1129 assert!((result.vif - 0.0).abs() < 1e-9);
1130 assert!((result.cambi - 0.0).abs() < 1e-9);
1131 }
1132
1133 #[test]
1134 fn test_parse_vmaf_log_single_frame_with_pooled() {
1135 let json = br#"{
1136 "frames": [{"frameNum": 0, "metrics": {"vmaf": 95.0}}],
1137 "pooled_metrics": {"vmaf": {"mean": 95.0}}
1138 }"#;
1139 let result = parse_vmaf_log(json, false).unwrap();
1140 assert!((result.vmaf - 95.0).abs() < 1e-9);
1141 assert_eq!(result.pooled.vmaf.count, 1);
1142 }
1143
1144 #[test]
1145 fn test_parse_vmaf_log_vif_mixed_naming() {
1146 let json = br#"{
1147 "frames": [{"frameNum": 0, "metrics": {
1148 "integer_vif_scale0": 0.5,
1149 "float_vif_scale0": 0.4,
1150 "vif_scale1": 0.6,
1151 "integer_vif_scale1": 0.6
1152 }}],
1153 "pooled_metrics": {}
1154 }"#;
1155 let result = parse_vmaf_log(json, false).unwrap();
1156 assert!((result.vif - 0.55).abs() < 1e-9, "mean of 2 scales with naming variants");
1159 }
1160
1161 #[test]
1162 fn test_parse_vmaf_log_xpsnr_per_frame_propagation() {
1163 let json = br#"{
1164 "frames": [
1165 {"frameNum": 0, "metrics": {"vmaf": 85.0}}
1166 ],
1167 "pooled_metrics": {"vmaf": {"mean": 85.0}}
1168 }"#;
1169 let mut result = parse_vmaf_log(json, true).unwrap();
1170 result.xpsnr = 0.0;
1171 result.frames[0].xpsnr = 45.5;
1172 assert!((result.frames[0].xpsnr - 45.5).abs() < 1e-9);
1173 }
1174
1175 #[test]
1176 fn test_parse_vmaf_log_pooled_distribution_single_frame() {
1177 let json = br#"{
1178 "frames": [{"frameNum": 0, "metrics": {"vmaf": 88.0}}],
1179 "pooled_metrics": {"vmaf": {"mean": 88.0}}
1180 }"#;
1181 let result = parse_vmaf_log(json, false).unwrap();
1182 assert_eq!(result.pooled.vmaf.count, 1);
1183 assert!((result.pooled.vmaf.min - 88.0).abs() < 1e-9);
1184 assert!((result.pooled.vmaf.max - 88.0).abs() < 1e-9);
1185 assert!((result.pooled.vmaf.mean - 88.0).abs() < 1e-9);
1186 }
1187
1188 #[test]
1189 fn test_parse_vmaf_log_per_frame_with_missing_metrics() {
1190 let json = br#"{
1191 "frames": [
1192 {"frameNum": 0, "metrics": {"vmaf": 85.0}},
1193 {"frameNum": 1, "metrics": {}}
1194 ],
1195 "pooled_metrics": {"vmaf": {"mean": 85.0}}
1196 }"#;
1197 let result = parse_vmaf_log(json, true).unwrap();
1198 assert_eq!(result.frames.len(), 2);
1199 assert!((result.frames[0].vmaf - 85.0).abs() < 1e-9);
1200 assert!((result.frames[1].vmaf - 0.0).abs() < 1e-9);
1201 }
1202
1203 #[test]
1204 fn test_parse_xpsnr_component_negative_inf() {
1205 assert_eq!(parse_xpsnr_component("XPSNR y: -inf", "y:"), Some(100.0));
1206 }
1207
1208 #[test]
1209 fn test_parse_xpsnr_component_nan() {
1210 assert_eq!(parse_xpsnr_component("XPSNR y: NaN", "y:"), Some(100.0));
1211 }
1212
1213 #[test]
1214 fn test_parse_xpsnr_component_regular() {
1215 assert!((parse_xpsnr_component("XPSNR u: 44.5678", "u:").unwrap() - 44.5678).abs() < 1e-4);
1216 }
1217
1218 #[test]
1219 fn test_parse_xpsnr_component_bad_format() {
1220 assert_eq!(parse_xpsnr_component("n: 1 XPSNR", "y:"), None);
1221 }
1222
1223 #[test]
1224 fn test_sample_indices_uneven() {
1225 assert_eq!(sample_indices(5, 3), vec![0, 2, 4]);
1226 }
1227
1228 #[test]
1229 fn test_sample_indices_more_samples_than_frames() {
1230 assert_eq!(sample_indices(2, 10), vec![0, 1]);
1231 }
1232
1233 #[test]
1234 fn test_sample_indices_single_frame_input() {
1235 assert_eq!(sample_indices(1, 5), vec![0]);
1236 }
1237
1238 #[test]
1239 fn test_sample_indices_large_values() {
1240 let indices = sample_indices(1000, 5);
1241 assert_eq!(indices.len(), 5);
1242 assert_eq!(indices[0], 0);
1243 assert_eq!(indices[4], 999);
1244 }
1245
1246 #[test]
1248 fn test_pooled_mean_first_match_wins() {
1249 let mut map = std::collections::HashMap::new();
1250 map.insert("psnr_y".to_string(), PooledMetric { mean: 40.0 });
1251 map.insert("psnr".to_string(), PooledMetric { mean: 39.0 });
1252 assert_eq!(
1253 pooled_mean(&VmafLog { frames: vec![], pooled_metrics: map }, &["psnr_y", "psnr"]),
1254 40.0
1255 );
1256 }
1257
1258 #[test]
1259 fn test_pooled_mean_fallback() {
1260 let mut map = std::collections::HashMap::new();
1261 map.insert("psnr".to_string(), PooledMetric { mean: 39.0 });
1262 assert_eq!(
1263 pooled_mean(&VmafLog { frames: vec![], pooled_metrics: map }, &["psnr_y", "psnr"]),
1264 39.0
1265 );
1266 }
1267
1268 #[test]
1269 fn test_pooled_mean_missing_all() {
1270 assert_eq!(
1271 pooled_mean(
1272 &VmafLog { frames: vec![], pooled_metrics: std::collections::HashMap::new() },
1273 &["psnr_y", "psnr"]
1274 ),
1275 0.0
1276 );
1277 }
1278
1279 #[test]
1280 fn test_frame_metric_first_match() {
1281 let mut map = std::collections::HashMap::new();
1282 map.insert("psnr_y".to_string(), 40.0);
1283 map.insert("psnr".to_string(), 39.0);
1284 assert_eq!(frame_metric(&map, &["psnr_y", "psnr"]), Some(40.0));
1285 }
1286
1287 #[test]
1288 fn test_frame_metric_fallback() {
1289 let mut map = std::collections::HashMap::new();
1290 map.insert("psnr".to_string(), 39.0);
1291 assert_eq!(frame_metric(&map, &["psnr_y", "psnr"]), Some(39.0));
1292 }
1293
1294 #[test]
1295 fn test_frame_metric_missing() {
1296 assert_eq!(frame_metric(&std::collections::HashMap::new(), &["psnr_y"]), None);
1297 }
1298}