av1an_core/
target_quality.rs

1use std::{
2    borrow::Cow,
3    cmp::{self, Ordering},
4    collections::HashSet,
5    io::Read,
6    path::{Path, PathBuf},
7    process::{Child, Stdio},
8    str::FromStr,
9    thread::{self, available_parallelism},
10};
11
12use anyhow::{anyhow, bail};
13use serde::{Deserialize, Serialize};
14use tracing::{debug, trace};
15
16use crate::{
17    broker::EncoderCrash,
18    chunk::Chunk,
19    ffmpeg::FFPixelFormat,
20    interpol::{
21        akima_interpolate,
22        catmull_rom_interpolate,
23        cubic_polynomial_interpolate,
24        linear_interpolate,
25        natural_cubic_spline,
26        pchip_interpolate,
27        quadratic_interpolate,
28    },
29    metrics::{
30        butteraugli::ButteraugliSubMetric,
31        statistics::MetricStatistics,
32        vmaf::{get_vmaf_model_version, read_vmaf_file, run_vmaf, run_vmaf_weighted},
33        xpsnr::{read_xpsnr_file, run_xpsnr, XPSNRSubMetric},
34    },
35    progress_bar::update_mp_msg,
36    vapoursynth::{measure_butteraugli, measure_ssimulacra2, measure_xpsnr, VapoursynthPlugins},
37    Encoder,
38    ProbingStatistic,
39    ProbingStatisticName,
40    TargetMetric,
41    VmafFeature,
42};
43
/// Interpolation method used when predicting the next quantizer to probe from
/// previous probe results.
///
/// Parsed case-insensitively via [`FromStr`]: `linear`, `quadratic`, `natural`,
/// `pchip`, `catmull`, `akima`, and `cubicpolynomial` (alias `cubic`).
/// NOTE(review): the mapping of each variant onto the `crate::interpol`
/// functions happens outside this section — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InterpolationMethod {
    Linear,
    Quadratic,
    /// Presumably a natural cubic spline (`natural_cubic_spline`) — confirm.
    Natural,
    /// Presumably piecewise cubic Hermite interpolation (`pchip_interpolate`) — confirm.
    Pchip,
    /// Presumably Catmull-Rom interpolation (`catmull_rom_interpolate`) — confirm.
    Catmull,
    Akima,
    CubicPolynomial,
}
54
55impl FromStr for InterpolationMethod {
56    type Err = ();
57
58    #[inline]
59    fn from_str(s: &str) -> Result<Self, Self::Err> {
60        match s.to_lowercase().as_str() {
61            "linear" => Ok(Self::Linear),
62            "quadratic" => Ok(Self::Quadratic),
63            "natural" => Ok(Self::Natural),
64            "pchip" => Ok(Self::Pchip),
65            "catmull" => Ok(Self::Catmull),
66            "akima" => Ok(Self::Akima),
67            "cubicpolynomial" | "cubic" => Ok(Self::CubicPolynomial),
68            _ => Err(()),
69        }
70    }
71}
72
/// Configuration for per-chunk target-quality quantizer search.
///
/// `per_shot_target_quality` probes several quantizers for a chunk and picks
/// one whose metric score lands inside `target`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TargetQuality {
    /// Resolution (`"<width>x<height>"`) handed to the ffmpeg-based VMAF/XPSNR
    /// runs when `probe_res` is `None`.
    pub vmaf_res:              String,
    /// Probe resolution as `(width, height)`. When set, it overrides `vmaf_res`
    /// for the ffmpeg-based metrics and is passed to the Vapoursynth metrics.
    pub probe_res:             Option<(u32, u32)>,
    /// Scaler name passed to the ffmpeg-based VMAF/XPSNR runs.
    pub vmaf_scaler:           String,
    /// Optional extra filter passed to the VMAF run.
    pub vmaf_filter:           Option<String>,
    /// Thread count for metric calculation. `encode_probe` treats `0` as
    /// "automatic" (derived from `workers`).
    pub vmaf_threads:          usize,
    /// Explicit VMAF model path. When `None`, a model named after the
    /// `probing_vmaf_features` version is used.
    pub model:                 Option<PathBuf>,
    /// Probing rate passed to the probe encode command and to the metrics.
    /// XPSNR switches to Vapoursynth when this is greater than 1.
    pub probing_rate:          usize,
    /// Maximum number of probes per chunk.
    pub probes:                u32,
    /// Target score range `(min, max)` in the metric's native scale. Must be
    /// `Some` before calling `per_shot_target_quality`.
    pub target:                Option<(f64, f64)>,
    /// Metric the probes are scored with.
    pub metric:                TargetMetric,
    /// Lower bound of the quantizer search window.
    pub min_q:                 u32,
    /// Upper bound of the quantizer search window.
    pub max_q:                 u32,
    /// Interpolation methods passed to `predict_quantizer`. NOTE(review): the
    /// meaning of each element of the pair is defined outside this section.
    pub interp_method:         Option<(InterpolationMethod, InterpolationMethod)>,
    /// Encoder used for probe encodes.
    pub encoder:               Encoder,
    /// Pixel format of the probe encodes.
    pub pix_format:            FFPixelFormat,
    /// Temporary directory passed to the probe command builder.
    pub temp:                  String,
    /// Number of workers; used to derive the automatic metric thread count.
    pub workers:               usize,
    /// Encoder parameters used for probe encodes (also included in probe logs).
    pub video_params:          Option<Vec<String>>,
    /// Not read within this section of the file — see other usages.
    pub params_copied:         bool,
    /// Extra arguments forwarded to the VMAF/XPSNR runs.
    pub vspipe_args:           Vec<String>,
    /// VMAF features, e.g. `Weighted` and `Motionless`; these also select the
    /// default VMAF model.
    pub probing_vmaf_features: Vec<VmafFeature>,
    /// Statistic used to reduce per-frame scores to one probe score.
    pub probing_statistic:     ProbingStatistic,
}
98
99impl TargetQuality {
100    #[inline]
101    pub fn default(temp_dir: &str, encoder: Encoder) -> Self {
102        Self {
103            vmaf_res: "1920x1080".to_string(),
104            probe_res: Some((1920, 1080)),
105            vmaf_scaler: "bicubic".to_string(),
106            vmaf_filter: None,
107            vmaf_threads: available_parallelism()
108                .expect("Unrecoverable: Failed to get thread count")
109                .get(),
110            model: None,
111            probing_rate: 1,
112            probes: 4,
113            target: None,
114            metric: TargetMetric::VMAF,
115            min_q: encoder.get_default_cq_range().0 as u32,
116            max_q: encoder.get_default_cq_range().1 as u32,
117            interp_method: None,
118            encoder,
119            pix_format: FFPixelFormat::YUV420P10LE,
120            temp: temp_dir.to_owned(),
121            workers: 1,
122            video_params: None,
123            params_copied: false,
124            vspipe_args: vec![],
125            probing_vmaf_features: vec![VmafFeature::Default],
126            probing_statistic: ProbingStatistic {
127                name:  ProbingStatisticName::Automatic,
128                value: None,
129            },
130        }
131    }
132
133    #[inline]
134    pub fn per_shot_target_quality(
135        &self,
136        chunk: &Chunk,
137        worker_id: Option<usize>,
138        plugins: Option<VapoursynthPlugins>,
139    ) -> anyhow::Result<f32> {
140        anyhow::ensure!(self.target.is_some(), "Target must be some");
141        let target = self.target.expect("target is some");
142        // History of probe results as quantizer-score pairs
143        let mut quantizer_score_history: Vec<(f32, f64)> = vec![];
144
145        let update_progress_bar = |next_quantizer: f32| {
146            if let Some(worker_id) = worker_id {
147                update_mp_msg(
148                    worker_id,
149                    format!(
150                        "Targeting {metric} Quality {min}-{max} - Testing {quantizer}",
151                        metric = self.metric,
152                        min = target.0,
153                        max = target.1,
154                        quantizer = next_quantizer
155                    ),
156                );
157            }
158        };
159
160        // Initialize quantizer limits from specified range or encoder defaults
161        let step = match self.encoder {
162            Encoder::x264 | Encoder::x265 => 0.25,
163            Encoder::svt_av1 if crate::encoder::svt_av1_supports_quarter_steps(&self.temp) => 0.25,
164            _ => 1.0,
165        };
166        let mut lower_quantizer_limit = self.min_q as f32;
167        let mut upper_quantizer_limit = self.max_q as f32;
168
169        let skip_reason;
170
171        loop {
172            let next_quantizer = predict_quantizer(
173                lower_quantizer_limit,
174                upper_quantizer_limit,
175                &quantizer_score_history,
176                // Invert for butteraugli
177                match self.metric {
178                    TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3 => {
179                        let (min, max) = target;
180                        (-max, -min)
181                    },
182                    _ => target,
183                },
184                self.interp_method,
185                step,
186            )?;
187
188            if quantizer_score_history
189                .iter()
190                .any(|(quantizer, _)| *quantizer == next_quantizer)
191            {
192                // Predicted quantizer has already been probed
193                skip_reason = SkipProbingReason::None;
194                break;
195            }
196
197            update_progress_bar(next_quantizer);
198
199            let score = {
200                let value = self.probe(chunk, next_quantizer, plugins)?;
201
202                // Butteraugli is an inverse metric, invert score for comparisons
203                match self.metric {
204                    TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3 => -value,
205                    _ => value,
206                }
207            };
208            let score_within_range = within_range(
209                match self.metric {
210                    TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3 => -score,
211                    _ => score,
212                },
213                target,
214            );
215
216            quantizer_score_history.push((next_quantizer, score));
217
218            if score_within_range || quantizer_score_history.len() >= self.probes as usize {
219                skip_reason = if score_within_range {
220                    SkipProbingReason::WithinTolerance
221                } else {
222                    SkipProbingReason::ProbeLimitReached
223                };
224                break;
225            }
226
227            let target_range = match self.metric {
228                TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3 => (-target.1, -target.0),
229                _ => target,
230            };
231
232            if score > target_range.1 {
233                lower_quantizer_limit = ((next_quantizer) + step).min(upper_quantizer_limit);
234            } else if score < target_range.0 {
235                upper_quantizer_limit = ((next_quantizer) - step).max(lower_quantizer_limit);
236            }
237
238            // Ensure quantizer limits are valid
239            if lower_quantizer_limit > upper_quantizer_limit {
240                skip_reason = if score > target_range.1 {
241                    SkipProbingReason::QuantizerTooHigh
242                } else {
243                    SkipProbingReason::QuantizerTooLow
244                };
245                break;
246            }
247        }
248
249        // Calculate final quantizer and score BEFORE logging
250        let final_quantizer_score = quantizer_score_history
251            .iter()
252            .filter(|(_, score)| {
253                within_range(
254                    match self.metric {
255                        TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3 => -score,
256                        _ => *score,
257                    },
258                    target,
259                )
260            })
261            .max_by(|(q1, _), (q2, _)| q1.partial_cmp(q2).unwrap_or(std::cmp::Ordering::Equal))
262            .unwrap_or_else(|| {
263                // No quantizers within tolerance, choose the quantizer closest to target
264                let target_midpoint = f64::midpoint(target.0, target.1);
265                quantizer_score_history
266                    .iter()
267                    .min_by(|(_, score1), (_, score2)| {
268                        let score_1 = match self.metric {
269                            TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3 => -score1,
270                            _ => *score1,
271                        };
272                        let score_2 = match self.metric {
273                            TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3 => -score2,
274                            _ => *score2,
275                        };
276                        let difference1 = (score_1 - target_midpoint).abs();
277                        let difference2 = (score_2 - target_midpoint).abs();
278                        difference1.partial_cmp(&difference2).unwrap_or(Ordering::Equal)
279                    })
280                    .expect("quantizer_score_history is not empty")
281            });
282
283        log_probes(
284            &quantizer_score_history,
285            self.metric,
286            target,
287            chunk.frames() as u32,
288            self.probing_rate as u32,
289            self.video_params.as_ref(),
290            &chunk.name(),
291            final_quantizer_score.0,
292            // Inverse reverse metrics
293            match self.metric {
294                TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3 => {
295                    -final_quantizer_score.1
296                },
297                _ => final_quantizer_score.1,
298            },
299            skip_reason,
300        );
301
302        Ok(final_quantizer_score.0)
303    }
304
305    fn probe(
306        &self,
307        chunk: &Chunk,
308        quantizer: f32,
309        plugins: Option<VapoursynthPlugins>,
310    ) -> anyhow::Result<f64> {
311        let probe_name = self.encode_probe(chunk, quantizer)?;
312        let reference_pipe_cmd =
313            chunk.proxy_cmd.as_ref().map_or(chunk.source_cmd.as_slice(), |proxy_cmd| {
314                proxy_cmd.as_slice()
315            });
316
317        let aggregate_frame_scores = |scores: Vec<f64>| -> anyhow::Result<f64> {
318            let mut statistics = MetricStatistics::new(scores);
319
320            let aggregate = match self.probing_statistic.name {
321                ProbingStatisticName::Automatic => {
322                    if self.metric == TargetMetric::VMAF {
323                        // Preserve legacy VMAF aggregation
324                        return Ok(statistics.percentile(1));
325                    }
326
327                    let sigma_1 = {
328                        let sigma_distance = statistics.standard_deviation();
329                        let statistic = statistics.mean() - sigma_distance;
330                        statistic.clamp(statistics.minimum(), statistics.maximum())
331                    };
332
333                    // Based on quantizer - lower quantizer leads to more accurate scores (lower
334                    // variance) (citation needed)
335                    if self.encoder.get_cq_relative_percentage(quantizer as usize) > 0.25 {
336                        // Liberal: Use mean to determine aggregate
337                        statistics.mean()
338                    } else {
339                        // Less liberal: Use -1 sigma to determine aggregate
340                        sigma_1
341                    }
342                },
343                ProbingStatisticName::Mean => statistics.mean(),
344                ProbingStatisticName::RootMeanSquare => statistics.root_mean_square(),
345                ProbingStatisticName::Median => statistics.median(),
346                ProbingStatisticName::Harmonic => statistics.harmonic_mean(),
347                ProbingStatisticName::Percentile => {
348                    let value = self
349                        .probing_statistic
350                        .value
351                        .ok_or_else(|| anyhow::anyhow!("Percentile statistic requires a value"))?;
352                    statistics.percentile(value as usize)
353                },
354                ProbingStatisticName::StandardDeviation => {
355                    let value = self.probing_statistic.value.ok_or_else(|| {
356                        anyhow::anyhow!("Standard deviation statistic requires a value")
357                    })?;
358                    let sigma_distance = value * statistics.standard_deviation();
359                    let statistic = statistics.mean() + sigma_distance;
360                    statistic.clamp(statistics.minimum(), statistics.maximum())
361                },
362                ProbingStatisticName::Mode => statistics.mode(),
363                ProbingStatisticName::Minimum => statistics.minimum(),
364                ProbingStatisticName::Maximum => statistics.maximum(),
365            };
366
367            Ok(aggregate)
368        };
369
370        match self.metric {
371            TargetMetric::VMAF => {
372                let features: HashSet<_> = self.probing_vmaf_features.iter().copied().collect();
373                let use_weighted = features.contains(&VmafFeature::Weighted);
374                let disable_motion = features.contains(&VmafFeature::Motionless);
375
376                // TODO: Update when nightly changes come to stable (2025-07-15)
377                //   let model = if self.model.is_some() {
378                //     self.model.as_ref()
379                // } else {
380                //     some(&pathbuf::from(format!(
381                //         "{}.json",
382                //         get_vmaf_model_version(&self.probing_vmaf_features)
383                //     )))
384                // };
385
386                let default_model = Some(PathBuf::from(format!(
387                    "{}.json",
388                    get_vmaf_model_version(&self.probing_vmaf_features)
389                )));
390
391                let model = if self.model.is_none() {
392                    default_model.as_ref()
393                } else {
394                    self.model.as_ref()
395                };
396
397                let vmaf_scores = if use_weighted {
398                    run_vmaf_weighted(
399                        &probe_name,
400                        reference_pipe_cmd,
401                        self.vspipe_args.clone(),
402                        model,
403                        self.vmaf_threads,
404                        chunk.frame_rate,
405                        disable_motion,
406                        &self.probing_vmaf_features,
407                    )
408                    .map_err(|e| {
409                        Box::new(EncoderCrash {
410                            exit_status:        std::process::ExitStatus::default(),
411                            source_pipe_stderr: String::new().into(),
412                            ffmpeg_pipe_stderr: None,
413                            stderr:             format!("VMAF calculation failed: {e}").into(),
414                            stdout:             String::new().into(),
415                        })
416                    })?
417                } else {
418                    let fl_path = std::path::Path::new(&chunk.temp)
419                        .join("split")
420                        .join(format!("{index}.json", index = chunk.index));
421
422                    run_vmaf(
423                        &probe_name,
424                        reference_pipe_cmd,
425                        self.vspipe_args.clone(),
426                        &fl_path,
427                        model,
428                        &self.probe_res.map_or_else(
429                            || self.vmaf_res.clone(),
430                            |(width, height)| format!("{width}x{height}"),
431                        ),
432                        &self.vmaf_scaler,
433                        self.probing_rate,
434                        self.vmaf_filter.as_deref(),
435                        self.vmaf_threads,
436                        chunk.frame_rate,
437                        disable_motion,
438                        &self.probing_vmaf_features,
439                    )?;
440
441                    read_vmaf_file(&fl_path)?
442                };
443
444                aggregate_frame_scores(vmaf_scores)
445            },
446            TargetMetric::SSIMULACRA2 => {
447                let scores = if let Some(plugins) = plugins {
448                    measure_ssimulacra2(
449                        chunk.proxy.as_ref().unwrap_or(&chunk.input),
450                        &probe_name,
451                        (chunk.start_frame as u32, chunk.end_frame as u32),
452                        self.probe_res,
453                        self.probing_rate,
454                        plugins,
455                    )?
456                } else {
457                    bail!("SSIMULACRA2 requires Vapoursynth to be installed");
458                };
459
460                aggregate_frame_scores(scores)
461            },
462            TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3 => {
463                let scores = if let Some(plugins) = plugins {
464                    measure_butteraugli(
465                        match self.metric {
466                            TargetMetric::ButteraugliINF => ButteraugliSubMetric::InfiniteNorm,
467                            TargetMetric::Butteraugli3 => ButteraugliSubMetric::ThreeNorm,
468                            _ => unreachable!(),
469                        },
470                        chunk.proxy.as_ref().unwrap_or(&chunk.input),
471                        &probe_name,
472                        (chunk.start_frame as u32, chunk.end_frame as u32),
473                        self.probe_res,
474                        self.probing_rate,
475                        plugins,
476                    )?
477                } else {
478                    bail!("Butteraugli requires Vapoursynth to be installed");
479                };
480
481                aggregate_frame_scores(scores)
482            },
483            TargetMetric::XPSNR | TargetMetric::XPSNRWeighted => {
484                let submetric = if self.metric == TargetMetric::XPSNR {
485                    XPSNRSubMetric::Minimum
486                } else {
487                    XPSNRSubMetric::Weighted
488                };
489                if self.probing_rate > 1 {
490                    let scores = if let Some(plugins) = plugins {
491                        measure_xpsnr(
492                            submetric,
493                            chunk.proxy.as_ref().unwrap_or(&chunk.input),
494                            &probe_name,
495                            (chunk.start_frame as u32, chunk.end_frame as u32),
496                            self.probe_res,
497                            self.probing_rate,
498                            plugins,
499                        )?
500                    } else {
501                        bail!("XPSNR with probing_rate > 1 requires Vapoursynth to be installed");
502                    };
503
504                    aggregate_frame_scores(scores)
505                } else {
506                    let fl_path =
507                        Path::new(&chunk.temp).join("split").join(format!("{}.json", chunk.index));
508
509                    run_xpsnr(
510                        &probe_name,
511                        reference_pipe_cmd,
512                        self.vspipe_args.clone(),
513                        &fl_path,
514                        &self.probe_res.map_or_else(
515                            || self.vmaf_res.clone(),
516                            |(width, height)| format!("{width}x{height}"),
517                        ),
518                        &self.vmaf_scaler,
519                        self.probing_rate,
520                        chunk.frame_rate,
521                    )?;
522
523                    let (aggregate, scores) = read_xpsnr_file(fl_path, submetric)?;
524
525                    match self.probing_statistic.name {
526                        ProbingStatisticName::Automatic => Ok(aggregate),
527                        _ => aggregate_frame_scores(scores),
528                    }
529                }
530            },
531        }
532    }
533
    /// Encodes a probe of `chunk` at quantizer `q` and returns the path of the
    /// encoded probe file.
    ///
    /// The probe command comes from `self.encoder.probe_cmd` and runs as a
    /// process pipeline:
    /// 1. The chunk's proxy command (or its source command when there is no
    ///    proxy).
    /// 2. An optional ffmpeg stage, when `probe_cmd` supplies one.
    /// 3. The encoder command.
    ///
    /// The stderr of every stage is read on a scoped thread while the pipeline
    /// runs.
    ///
    /// The returned path is `<chunk.temp>/split/v_<index:05>_<q>.<ext>`. `<ext>`
    /// is `264` for x264, `hevc` for x265 and `ivf` otherwise.
    /// NOTE(review): this path is computed independently of `probe_cmd`.
    /// Presumably it matches the file the probe command writes — confirm
    /// against `probe_cmd`.
    ///
    /// # Errors
    ///
    /// Returns an [`EncoderCrash`] in these cases:
    /// - a pipeline stage cannot be spawned, or `build_encoder_pipe` fails;
    /// - waiting on the encoder fails;
    /// - the encoder exits unsuccessfully. The error then carries the captured
    ///   stderr of each stage.
    ///
    /// NOTE(review): several error paths are not cleaned up.
    /// - When spawning ffmpeg or building the encoder pipe fails, the `?` returns
    ///   without killing or waiting on the already-spawned `source` process.
    /// - An ffmpeg spawn failure is reported in `source_pipe_stderr` rather than
    ///   `ffmpeg_pipe_stderr`.
    fn encode_probe(&self, chunk: &Chunk, q: f32) -> Result<PathBuf, Box<EncoderCrash>> {
        // A configured thread count of 0 means "derive it from the worker count".
        let vmaf_threads = if self.vmaf_threads == 0 {
            vmaf_auto_threads(self.workers)
        } else {
            self.vmaf_threads
        };

        // Destructured below as `(ff_cmd, output)`: an optional ffmpeg stage
        // command and the encoder command.
        let cmd = self.encoder.probe_cmd(
            self.temp.clone(),
            chunk.index,
            q,
            self.pix_format,
            self.probing_rate,
            vmaf_threads,
            self.video_params.clone(),
        );

        // Probe from the proxy when present, otherwise from the source.
        let source_cmd = chunk.proxy_cmd.clone().unwrap_or_else(|| chunk.source_cmd.clone());
        let (ff_cmd, output) = cmd.clone();

        thread::scope(move |scope| {
            // Stage 1: source/proxy pipe. An empty command hits `unreachable!()`,
            // i.e. `source_cmd` is assumed to be non-empty.
            let mut source = if let [pipe_cmd, args @ ..] = &*source_cmd {
                std::process::Command::new(pipe_cmd)
                    .args(args)
                    .stderr(std::process::Stdio::piped())
                    .stdout(std::process::Stdio::piped())
                    .spawn()
                    .map_err(|e| EncoderCrash {
                        exit_status:        std::process::ExitStatus::default(),
                        source_pipe_stderr: format!("Failed to spawn source: {e}").into(),
                        ffmpeg_pipe_stderr: None,
                        stderr:             String::new().into(),
                        stdout:             String::new().into(),
                    })?
            } else {
                unreachable!()
            };

            let source_stdout = source.stdout.take().expect("source stdout should exist");

            // Stages 2-3: either source -> ffmpeg -> encoder, or source -> encoder.
            let (mut source_pipe, mut enc_pipe) = {
                if let Some(ff_cmd) = ff_cmd.as_deref() {
                    let (ffmpeg, args) = ff_cmd.split_first().expect("not empty");
                    let mut source_pipe = std::process::Command::new(ffmpeg)
                        .args(args)
                        .stdin(source_stdout)
                        .stdout(std::process::Stdio::piped())
                        .stderr(std::process::Stdio::piped())
                        .spawn()
                        .map_err(|e| EncoderCrash {
                            exit_status:        std::process::ExitStatus::default(),
                            source_pipe_stderr: format!("Failed to spawn ffmpeg: {e}").into(),
                            ffmpeg_pipe_stderr: None,
                            stderr:             String::new().into(),
                            stdout:             String::new().into(),
                        })?;

                    let source_pipe_stdout =
                        source_pipe.stdout.take().expect("source_pipe stdout should exist");

                    let enc_pipe = if let [cmd, args @ ..] = &*output {
                        build_encoder_pipe(cmd, args, source_pipe_stdout)?
                    } else {
                        unreachable!()
                    };
                    (Some(source_pipe), enc_pipe)
                } else {
                    // We unfortunately have to duplicate the code like this
                    // in order to satisfy the borrow checker for `source_stdout`
                    let enc_pipe = if let [cmd, args @ ..] = &*output {
                        build_encoder_pipe(cmd, args, source_stdout)?
                    } else {
                        unreachable!()
                    };
                    (None, enc_pipe)
                }
            };

            // Drop stdout to prevent buffer deadlock
            drop(enc_pipe.stdout.take());

            // Read each stage's stderr to the end on its own scoped thread.
            let source_stderr = source.stderr.take().expect("source stderr should exist");
            let stderr_thread1 = scope.spawn(move || {
                let mut buf = Vec::new();
                let mut stderr = source_stderr;
                stderr.read_to_end(&mut buf).ok();
                buf
            });

            let source_pipe_stderr = source_pipe
                .as_mut()
                .map(|p| p.stderr.take().expect("source_pipe stderr should exist"));
            let stderr_thread2 = source_pipe_stderr.map(|source_pipe_stderr| {
                scope.spawn(move || {
                    let mut buf = Vec::new();
                    let mut stderr = source_pipe_stderr;
                    stderr.read_to_end(&mut buf).ok();
                    buf
                })
            });

            let enc_pipe_stderr = enc_pipe.stderr.take().expect("enc_pipe stderr should exist");
            let stderr_thread3 = scope.spawn(move || {
                let mut buf = Vec::new();
                let mut stderr = enc_pipe_stderr;
                stderr.read_to_end(&mut buf).ok();
                buf
            });

            // Wait for encoder & other processes to finish
            let enc_status = enc_pipe.wait().map_err(|e| EncoderCrash {
                exit_status:        std::process::ExitStatus::default(),
                source_pipe_stderr: String::new().into(),
                ffmpeg_pipe_stderr: None,
                stderr:             format!("Failed to wait for encoder: {e}").into(),
                stdout:             String::new().into(),
            })?;

            // Upstream stages are reaped best-effort; only the encoder's exit status
            // decides success.
            if let Some(source_pipe) = source_pipe.as_mut() {
                let _ = source_pipe.wait();
            };
            let _ = source.wait();

            // Collect stderr after process finishes
            let stderr_handles = (
                stderr_thread1.join().unwrap_or_default(),
                stderr_thread2.map(|t| t.join().unwrap_or_default()),
                stderr_thread3.join().unwrap_or_default(),
            );

            if !enc_status.success() {
                return Err(EncoderCrash {
                    exit_status:        enc_status,
                    source_pipe_stderr: stderr_handles.0.into(),
                    ffmpeg_pipe_stderr: stderr_handles.1.map(|h| h.into()),
                    stderr:             stderr_handles.2.into(),
                    stdout:             String::new().into(),
                });
            }

            Ok(())
        })?;

        // Output file extension per encoder family.
        let extension = match self.encoder {
            crate::encoder::Encoder::x264 => "264",
            crate::encoder::Encoder::x265 => "hevc",
            _ => "ivf",
        };

        let q_str = crate::encoder::format_q(q);
        let probe_name = format!("v_{index:05}_{q_str}.{extension}", index = chunk.index);

        Ok(std::path::Path::new(&chunk.temp).join("split").join(&probe_name))
    }
688
689    #[inline]
690    pub fn parse_probing_statistic(stat: &str) -> anyhow::Result<ProbingStatistic> {
691        Ok(match stat.to_lowercase().as_str() {
692            "auto" => ProbingStatistic {
693                name:  ProbingStatisticName::Automatic,
694                value: None,
695            },
696            "mean" => ProbingStatistic {
697                name:  ProbingStatisticName::Mean,
698                value: None,
699            },
700            "harmonic" => ProbingStatistic {
701                name:  ProbingStatisticName::Harmonic,
702                value: None,
703            },
704            "root-mean-square" => ProbingStatistic {
705                name:  ProbingStatisticName::RootMeanSquare,
706                value: None,
707            },
708            "median" => ProbingStatistic {
709                name:  ProbingStatisticName::Median,
710                value: None,
711            },
712            "mode" => ProbingStatistic {
713                name:  ProbingStatisticName::Mode,
714                value: None,
715            },
716            "minimum" => ProbingStatistic {
717                name:  ProbingStatisticName::Minimum,
718                value: None,
719            },
720            "maximum" => ProbingStatistic {
721                name:  ProbingStatisticName::Maximum,
722                value: None,
723            },
724            probe_statistic if probe_statistic.starts_with("percentile") => {
725                if probe_statistic.matches('=').count() != 1
726                    || !probe_statistic.starts_with("percentile=")
727                {
728                    return Err(anyhow!(
729                        "Probing Statistic percentile must have a value between 0.0 and 100.0 set \
730                         using \"=\" (eg. \"--probing-stat percentile=1\")"
731                    ));
732                }
733                let value = probe_statistic
734                    .split("=")
735                    .last()
736                    .and_then(|s| s.parse::<f64>().ok())
737                    .and_then(|v| (0.0..=100.0).contains(&v).then_some(v))
738                    .ok_or_else(|| {
739                        anyhow!(
740                            "Probing Statistic percentile must be set to a value between 0 and 100"
741                        )
742                    })?;
743                ProbingStatistic {
744                    name:  ProbingStatisticName::Percentile,
745                    value: Some(value),
746                }
747            },
748            probe_statistic if probe_statistic.starts_with("standard-deviation") => {
749                if probe_statistic.matches('=').count() != 1
750                    || !probe_statistic.starts_with("standard-deviation=")
751                {
752                    return Err(anyhow!(
753                        "Probing Statistic standard deviation must have a positive or negative \
754                         value set using \"=\" (eg. \"--probing-stat standard-deviation=-0.25\")"
755                    ));
756                }
757                let value = probe_statistic
758                    .split('=')
759                    .next_back()
760                    .and_then(|s| s.parse::<f64>().ok())
761                    .ok_or_else(|| {
762                        anyhow!("Probing Statistic standard deviation must have a value appended")
763                    })?;
764                ProbingStatistic {
765                    name:  ProbingStatisticName::StandardDeviation,
766                    value: Some(value),
767                }
768            },
769            _ => {
770                return Err(anyhow!("Unknown Probing Statistic: {}", stat));
771            },
772        })
773    }
774
775    #[inline]
776    pub fn parse_target_qp_range(s: &str) -> Result<(f64, f64), String> {
777        if let Some((min_str, max_str)) = s.split_once('-') {
778            let min = min_str.parse::<f64>().map_err(|_| "Invalid range format")?;
779            let max = max_str.parse::<f64>().map_err(|_| "Invalid range format")?;
780            if min >= max {
781                return Err("Min must be < max".to_string());
782            }
783            Ok((min, max))
784        } else {
785            let mut val = s.parse::<f64>().map_err(|_| "Invalid number")?;
786            // Convert 0 to 0.001 to avoid degenerate range issues
787            if val == 0.0 {
788                val = 0.001;
789            }
790            let tol = val * 0.01;
791            Ok((val - tol, val + tol))
792        }
793    }
794
795    #[inline]
796    pub fn parse_interp_method(
797        s: &str,
798    ) -> anyhow::Result<(InterpolationMethod, InterpolationMethod)> {
799        let parts: Vec<&str> = s.split('-').collect();
800        if parts.len() != 2 {
801            return Err(anyhow::anyhow!(
802                "Invalid format. Use: --interp-method method4-method5"
803            ));
804        }
805
806        let method4 = parts[0]
807            .parse::<InterpolationMethod>()
808            .map_err(|_| anyhow::anyhow!("Invalid 4th round method: {}", parts[0]))?;
809        let method5 = parts[1]
810            .parse::<InterpolationMethod>()
811            .map_err(|_| anyhow::anyhow!("Invalid 5th round method: {}", parts[1]))?;
812
813        // Validate methods for correct round
814        match method4 {
815            InterpolationMethod::Linear
816            | InterpolationMethod::Quadratic
817            | InterpolationMethod::Natural => {},
818            _ => {
819                return Err(anyhow::anyhow!(
820                    "Method '{}' not available for 4th round",
821                    parts[0]
822                ))
823            },
824        }
825
826        Ok((method4, method5))
827    }
828
829    #[inline]
830    pub fn parse_qp_range(s: &str) -> Result<(u32, u32), String> {
831        if let Some((min_str, max_str)) = s.split_once('-') {
832            let min = min_str.parse::<u32>().map_err(|_| "Invalid range format")?;
833            let max = max_str.parse::<u32>().map_err(|_| "Invalid range format")?;
834            if min >= max {
835                return Err("Min must be < max".to_string());
836            }
837            Ok((min, max))
838        } else {
839            Err("Quality range must be specified as min-max (e.g., 10-50)".to_string())
840        }
841    }
842
843    #[inline]
844    pub fn parse_probe_res(probe_resolution: &str) -> Result<(u32, u32), String> {
845        let parts: Vec<_> = probe_resolution.split('x').collect();
846        if parts.len() != 2 {
847            return Err(format!(
848                "Invalid probe resolution: {probe_resolution}. Expected widthxheight"
849            ));
850        }
851        let width = parts
852            .first()
853            .expect("Probe resolution has width and height")
854            .parse::<u32>()
855            .map_err(|_| format!("Invalid probe resolution width: {probe_resolution}"))?;
856        let height = parts
857            .get(1)
858            .expect("Probe resolution has width and height")
859            .parse::<u32>()
860            .map_err(|_| format!("Invalid probe resolution height: {probe_resolution}"))?;
861
862        Ok((width, height))
863    }
864
865    #[inline]
866    pub fn validate_probes(probes: u32) -> Result<(u32, Option<String>), String> {
867        match probes {
868            probes if probes >= 4 => Ok((probes, None)),
869            1..4 => Ok((
870                probes,
871                Some("Number of probes is recommended to be at least 4".to_string()),
872            )),
873            _ => Err("Number of probes must be greater than 0".to_string()),
874        }
875    }
876
877    #[inline]
878    pub fn validate_probing_rate(probing_rate: usize) -> Result<(usize, Option<String>), String> {
879        match probing_rate {
880            1..=4 => Ok((probing_rate, None)),
881            _ => Err("Probing rate must be an integer from 1 to 4".to_string()),
882        }
883    }
884}
885
886#[expect(clippy::result_large_err)]
887fn build_encoder_pipe(
888    cmd: &str,
889    args: &[Cow<'_, str>],
890    in_pipe: impl Into<Stdio>,
891) -> Result<Child, EncoderCrash> {
892    std::process::Command::new(cmd)
893        .args(args.iter().map(AsRef::as_ref))
894        .stdin(in_pipe)
895        .stdout(std::process::Stdio::piped())
896        .stderr(std::process::Stdio::piped())
897        .spawn()
898        .map_err(|e| EncoderCrash {
899            exit_status:        std::process::ExitStatus::default(),
900            source_pipe_stderr: String::new().into(),
901            ffmpeg_pipe_stderr: None,
902            stderr:             format!("Failed to spawn encoder: {e}").into(),
903            stdout:             String::new().into(),
904        })
905}
906
907fn predict_quantizer(
908    lower_quantizer_limit: f32,
909    upper_quantizer_limit: f32,
910    quantizer_score_history: &[(f32, f64)],
911    target_range: (f64, f64),
912    interp_method: Option<(InterpolationMethod, InterpolationMethod)>,
913    step: f32,
914) -> anyhow::Result<f32> {
915    let target = f64::midpoint(target_range.0, target_range.1);
916    let binary_search = f32::midpoint(lower_quantizer_limit, upper_quantizer_limit);
917
918    let predicted_quantizer = match quantizer_score_history.len() {
919        0..=1 => binary_search as f64,
920        n => {
921            // Sort history by quantizer
922            let mut sorted = quantizer_score_history.to_vec();
923            sorted.sort_by(|(_, s1), (_, s2)| {
924                s1.partial_cmp(s2).unwrap_or(std::cmp::Ordering::Equal)
925            });
926
927            let (scores, quantizers): (Vec<f64>, Vec<f64>) =
928                sorted.iter().map(|(q, s)| (*s, *q as f64)).unzip();
929
930            let result = match n {
931                2 => {
932                    // 3rd probe: linear interpolation
933                    linear_interpolate(
934                        &[scores[0], scores[1]],
935                        &[quantizers[0], quantizers[1]],
936                        target,
937                    )
938                },
939                3 => {
940                    // 4th probe: configurable method
941                    let method = interp_method.map_or(InterpolationMethod::Natural, |(m, _)| m);
942                    match method {
943                        InterpolationMethod::Linear => linear_interpolate(
944                            &[scores[0], scores[1]],
945                            &[quantizers[0], quantizers[1]],
946                            target,
947                        ),
948                        InterpolationMethod::Quadratic => quadratic_interpolate(
949                            &[scores[0], scores[1], scores[2]],
950                            &[quantizers[0], quantizers[1], quantizers[2]],
951                            target,
952                        ),
953                        InterpolationMethod::Natural => {
954                            natural_cubic_spline(&scores, &quantizers, target)
955                        },
956                        _ => None,
957                    }
958                },
959                4 => {
960                    // 5th probe: configurable method
961                    let method = interp_method.map_or(InterpolationMethod::Pchip, |(_, m)| m);
962                    let s: &[f64; 4] = &scores[..4].try_into()?;
963                    let q: &[f64; 4] = &quantizers[..4].try_into()?;
964
965                    match method {
966                        InterpolationMethod::Linear => {
967                            linear_interpolate(&[s[0], s[1]], &[q[0], q[1]], target)
968                        },
969                        InterpolationMethod::Quadratic => {
970                            quadratic_interpolate(&[s[0], s[1], s[2]], &[q[0], q[1], q[2]], target)
971                        },
972                        InterpolationMethod::Natural => {
973                            natural_cubic_spline(&scores, &quantizers, target)
974                        },
975                        InterpolationMethod::Pchip => pchip_interpolate(s, q, target),
976                        InterpolationMethod::Catmull => catmull_rom_interpolate(s, q, target),
977                        InterpolationMethod::Akima => akima_interpolate(s, q, target),
978                        InterpolationMethod::CubicPolynomial => {
979                            cubic_polynomial_interpolate(s, q, target)
980                        },
981                    }
982                },
983                _ => None,
984            };
985
986            result.unwrap_or_else(|| {
987                trace!("Interpolation failed, falling back to binary search");
988                binary_search as f64
989            })
990        },
991    };
992
993    // Round the result of the interpolation to the nearest integer
994    Ok(
995        (((predicted_quantizer / step as f64).round() * step as f64) as f32)
996            .clamp(lower_quantizer_limit, upper_quantizer_limit),
997    )
998}
999
/// Returns `true` when `score` lies in the inclusive window `target_range.0..=target_range.1`.
///
/// A NaN score, or an inverted window, is never "within range".
fn within_range(score: f64, target_range: (f64, f64)) -> bool {
    let (lower, upper) = target_range;
    (lower..=upper).contains(&score)
}
1003
/// Returns a per-worker thread count.
///
/// The machine's available parallelism is divided evenly over `workers` (integer division),
/// over-provisioned by 25%, and never less than 1.
///
/// `workers == 0` is treated like a single worker instead of dividing by zero. If the
/// platform cannot report its parallelism, a single available thread is assumed instead of
/// panicking.
pub fn vmaf_auto_threads(workers: usize) -> usize {
    const OVER_PROVISION_FACTOR: f64 = 1.25;

    // `available_parallelism` can fail (unknown platform, insufficient permissions).
    // Degrade to one thread rather than aborting.
    let threads = available_parallelism().map_or(1, |count| count.get());
    // Guard the division below against a zero worker count.
    let workers = workers.max(1);

    // Integer division happens before over-provisioning, as it always has.
    cmp::max(
        ((threads / workers) as f64 * OVER_PROVISION_FACTOR) as usize,
        1,
    )
}
1016
/// Why target-quality probing ended, as reported in the probe log of `log_probes`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SkipProbingReason {
    /// Logged as "Early Skip High Quantizer".
    QuantizerTooHigh,
    /// Logged as "Early Skip Low Quantizer".
    QuantizerTooLow,
    /// Logged as "Early Skip Within Tolerance".
    WithinTolerance,
    /// Logged as "Early Skip Probe Limit Reached".
    ProbeLimitReached,
    /// No early skip; nothing is appended to the probe log.
    None,
}
1025
1026#[expect(clippy::too_many_arguments)]
1027pub fn log_probes(
1028    quantizer_score_history: &[(f32, f64)],
1029    metric: TargetMetric,
1030    target: (f64, f64),
1031    frames: u32,
1032    probing_rate: u32,
1033    video_params: Option<&Vec<String>>,
1034    chunk_name: &str,
1035    target_quantizer: f32,
1036    target_score: f64,
1037    skip: SkipProbingReason,
1038) {
1039    // Sort history by quantizer
1040    let mut sorted_quantizer_scores = quantizer_score_history.to_vec();
1041    sorted_quantizer_scores
1042        .sort_by(|(q1, _), (q2, _)| q1.partial_cmp(q2).unwrap_or(std::cmp::Ordering::Equal));
1043    // Butteraugli is an inverse metric and needs to be inverted back before display
1044    if matches!(
1045        metric,
1046        TargetMetric::ButteraugliINF | TargetMetric::Butteraugli3
1047    ) {
1048        sorted_quantizer_scores = sorted_quantizer_scores
1049            .iter()
1050            .map(|(quantizer, score)| (*quantizer, -score))
1051            .collect();
1052    }
1053
1054    debug!(
1055        "chunk {name}: Target={min}-{max}, Metric={target_metric}, P-Rate={rate}, {frame_count} \
1056         frames{custom_params_string}
1057       TQ-Probes: {history:.2?}{suffix}
1058       Final Q={target_quantizer:.2}, Final Score={target_score:.2}",
1059        name = chunk_name,
1060        min = target.0,
1061        max = target.1,
1062        target_metric = metric,
1063        rate = probing_rate,
1064        frame_count = frames,
1065        custom_params_string = video_params
1066            .map(|params| format!(
1067                ", P-Video-Params: {params_string}",
1068                params_string = params.join(" ")
1069            ))
1070            .unwrap_or_default(),
1071        history = sorted_quantizer_scores,
1072        suffix = match skip {
1073            SkipProbingReason::None => "",
1074            SkipProbingReason::QuantizerTooHigh => "Early Skip High Quantizer",
1075            SkipProbingReason::QuantizerTooLow => " Early Skip Low Quantizer",
1076            SkipProbingReason::WithinTolerance => " Early Skip Within Tolerance",
1077            SkipProbingReason::ProbeLimitReached => " Early Skip Probe Limit Reached",
1078        },
1079        target_quantizer = target_quantizer,
1080        target_score = target_score
1081    );
1082}
1083
#[cfg(test)]
mod tests {
    use super::*;

    /// Window every simulated search is expected to finish in.
    const TARGET_RANGE: (f64, f64) = (79.5, 80.5);

    /// Returns the pre-recorded `(quantizer, score)` table for simulation `case` (1..=6).
    ///
    /// # Panics
    ///
    /// Panics for any other case number.
    fn get_score_map(case: usize) -> Vec<(f32, f64)> {
        let table: &[(f32, f64)] = match case {
            1 => &[(35.0, 80.08)],
            2 => &[(17.0, 80.03), (35.0, 65.73)],
            3 => &[(17.0, 83.15), (22.0, 80.02), (35.0, 71.94)],
            4 => &[(17.0, 85.81), (30.0, 80.92), (32.0, 80.01), (35.0, 78.05)],
            5 => &[(35.0, 83.31), (53.0, 81.22), (55.0, 80.03), (61.0, 73.56), (64.0, 67.56)],
            6 => &[
                (35.0, 86.99),
                (53.0, 84.41),
                (57.0, 82.47),
                (59.0, 81.14),
                (60.0, 80.09),
                (61.0, 78.58),
                (69.0, 68.57),
                (70.0, 64.90),
            ],
            _ => panic!("Unknown case"),
        };
        table.to_vec()
    }

    /// Replays the quantizer search against a recorded score table and returns the probes in
    /// the order they were made.
    ///
    /// At most 10 rounds are run. The search stops early when the bounds cross, a quantizer
    /// repeats, a quantizer is missing from the table, or a score lands inside `TARGET_RANGE`.
    fn run_av1an_simulation(case: usize) -> Vec<(f32, f64)> {
        let score_map = get_score_map(case);
        let mut probed: Vec<(f32, f64)> = Vec::new();
        let (mut low, mut high) = (1.0f32, 70.0f32);

        for _round in 0..10 {
            if low > high {
                break;
            }
            let predicted = predict_quantizer(low, high, &probed, TARGET_RANGE, None, 1.0)
                .expect("predict_quantizer should succeed");

            // Snap the prediction onto the closest quantizer present in the recorded table.
            let quantizer = score_map
                .iter()
                .map(|&(q, _)| q)
                .min_by(|a, b| {
                    (a - predicted)
                        .abs()
                        .partial_cmp(&(b - predicted).abs())
                        .expect("partial_cmp should succeed")
                })
                .unwrap_or(predicted);

            // Probing the same quantizer twice means the search has nowhere left to go.
            if probed.iter().any(|&(q, _)| q == quantizer) {
                break;
            }

            let Some(&(_, score)) = score_map.iter().find(|&&(q, _)| q == quantizer) else {
                break;
            };
            probed.push((quantizer, score));

            if within_range(score, TARGET_RANGE) {
                break;
            }
            if score > TARGET_RANGE.1 {
                low = (quantizer + 1.0).min(high);
            } else if score < TARGET_RANGE.0 {
                high = (quantizer - 1.0).max(low);
            }
        }

        probed
    }

    #[test]
    fn target_quality_all_cases() {
        for case in 1..=6 {
            let history = run_av1an_simulation(case);
            let Some(&(_, final_score)) = history.last() else {
                panic!("Case {} returned empty result", case);
            };
            assert!(
                within_range(final_score, TARGET_RANGE),
                "Case {} final score not in range",
                case
            );
        }
    }
}