Skip to main content

oximedia_transcode/
rate_distortion.rs

1#![allow(dead_code)]
2//! Rate-distortion analysis for optimal encoding parameter selection.
3//!
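//! Models the trade-off between bitrate and quality to help select the best
//! CRF, QP, or bitrate for a given quality target. Provides piecewise-linear
//! RD-curve interpolation, operating-point selection, and a simplified
//! Bjontegaard-delta-style comparison (average quality difference between two
//! curves over their overlapping bitrate range).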
7
8use std::fmt;
9
/// One measured operating point on a rate-distortion curve.
///
/// Pairs the bitrate an encode was produced at with the quality score that
/// was measured for it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RdPoint {
    /// Bitrate in kilobits per second.
    pub bitrate_kbps: f64,
    /// Quality metric value (e.g. PSNR in dB, SSIM, VMAF score).
    pub quality: f64,
}

impl RdPoint {
    /// Build a point from a bitrate (kbps) and a quality score.
    #[must_use]
    pub fn new(bitrate_kbps: f64, quality: f64) -> Self {
        Self { bitrate_kbps, quality }
    }
}

impl fmt::Display for RdPoint {
    /// Renders as `(<bitrate> kbps, <quality>)`, with the bitrate shown to one
    /// decimal place and the quality to two.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let RdPoint { bitrate_kbps, quality } = *self;
        write!(f, "({:.1} kbps, {:.2})", bitrate_kbps, quality)
    }
}
35
/// Which quality measure the `quality` values of a curve are expressed in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QualityMetric {
    /// Peak Signal-to-Noise Ratio (dB).
    Psnr,
    /// Structural Similarity Index.
    Ssim,
    /// Video Multimethod Assessment Fusion.
    Vmaf,
}

impl fmt::Display for QualityMetric {
    /// Writes the upper-case abbreviation of the metric: `PSNR`, `SSIM` or `VMAF`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            QualityMetric::Psnr => "PSNR",
            QualityMetric::Ssim => "SSIM",
            QualityMetric::Vmaf => "VMAF",
        };
        f.write_str(name)
    }
}
56
57/// An RD curve consisting of multiple measurement points.
58#[derive(Debug, Clone)]
59pub struct RdCurve {
60    /// Label for this curve (e.g. codec name, preset).
61    pub label: String,
62    /// Quality metric used.
63    pub metric: QualityMetric,
64    /// Measurement points sorted by bitrate ascending.
65    points: Vec<RdPoint>,
66}
67
68impl RdCurve {
69    /// Create a new empty RD curve.
70    pub fn new(label: impl Into<String>, metric: QualityMetric) -> Self {
71        Self {
72            label: label.into(),
73            metric,
74            points: Vec::new(),
75        }
76    }
77
78    /// Add a measurement point and maintain sorted order.
79    pub fn add_point(&mut self, point: RdPoint) {
80        self.points.push(point);
81        self.points.sort_by(|a, b| {
82            a.bitrate_kbps
83                .partial_cmp(&b.bitrate_kbps)
84                .unwrap_or(std::cmp::Ordering::Equal)
85        });
86    }
87
88    /// Return the number of points.
89    #[must_use]
90    pub fn point_count(&self) -> usize {
91        self.points.len()
92    }
93
94    /// Return all points as a slice.
95    #[must_use]
96    pub fn points(&self) -> &[RdPoint] {
97        &self.points
98    }
99
100    /// Find the point with the highest quality.
101    #[must_use]
102    pub fn best_quality(&self) -> Option<&RdPoint> {
103        self.points.iter().max_by(|a, b| {
104            a.quality
105                .partial_cmp(&b.quality)
106                .unwrap_or(std::cmp::Ordering::Equal)
107        })
108    }
109
110    /// Find the point with the lowest bitrate.
111    #[must_use]
112    pub fn lowest_bitrate(&self) -> Option<&RdPoint> {
113        self.points.first()
114    }
115
116    /// Find the operating point closest to a target quality.
117    #[must_use]
118    pub fn find_nearest_quality(&self, target: f64) -> Option<&RdPoint> {
119        self.points.iter().min_by(|a, b| {
120            let da = (a.quality - target).abs();
121            let db = (b.quality - target).abs();
122            da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
123        })
124    }
125
126    /// Find the operating point closest to a target bitrate.
127    #[must_use]
128    pub fn find_nearest_bitrate(&self, target_kbps: f64) -> Option<&RdPoint> {
129        self.points.iter().min_by(|a, b| {
130            let da = (a.bitrate_kbps - target_kbps).abs();
131            let db = (b.bitrate_kbps - target_kbps).abs();
132            da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
133        })
134    }
135
136    /// Linearly interpolate quality at a given bitrate.
137    /// Returns `None` if outside the curve range or fewer than 2 points.
138    #[allow(clippy::cast_precision_loss)]
139    #[must_use]
140    pub fn interpolate_quality(&self, bitrate_kbps: f64) -> Option<f64> {
141        if self.points.len() < 2 {
142            return None;
143        }
144        let first = self.points.first()?;
145        let last = self.points.last()?;
146        if bitrate_kbps < first.bitrate_kbps || bitrate_kbps > last.bitrate_kbps {
147            return None;
148        }
149        // Find the two bounding points
150        for window in self.points.windows(2) {
151            let lo = &window[0];
152            let hi = &window[1];
153            if bitrate_kbps >= lo.bitrate_kbps && bitrate_kbps <= hi.bitrate_kbps {
154                let range = hi.bitrate_kbps - lo.bitrate_kbps;
155                if range.abs() < f64::EPSILON {
156                    return Some(lo.quality);
157                }
158                let t = (bitrate_kbps - lo.bitrate_kbps) / range;
159                return Some(lo.quality + t * (hi.quality - lo.quality));
160            }
161        }
162        None
163    }
164}
165
166/// Compute the average quality difference between two RD curves
167/// over their overlapping bitrate range (simplified BD-rate style comparison).
168/// Positive means `curve_b` has higher quality at the same bitrate.
169#[allow(clippy::cast_precision_loss)]
170#[must_use]
171pub fn average_quality_delta(curve_a: &RdCurve, curve_b: &RdCurve, samples: usize) -> Option<f64> {
172    if curve_a.point_count() < 2 || curve_b.point_count() < 2 || samples == 0 {
173        return None;
174    }
175
176    let a_min = curve_a.points().first()?.bitrate_kbps;
177    let a_max = curve_a.points().last()?.bitrate_kbps;
178    let b_min = curve_b.points().first()?.bitrate_kbps;
179    let b_max = curve_b.points().last()?.bitrate_kbps;
180
181    let lo = a_min.max(b_min);
182    let hi = a_max.min(b_max);
183    if lo >= hi {
184        return None;
185    }
186
187    let step = (hi - lo) / samples as f64;
188    let mut sum = 0.0;
189    let mut count = 0u64;
190
191    let mut br = lo;
192    while br <= hi {
193        if let (Some(qa), Some(qb)) = (
194            curve_a.interpolate_quality(br),
195            curve_b.interpolate_quality(br),
196        ) {
197            sum += qb - qa;
198            count += 1;
199        }
200        br += step;
201    }
202
203    if count == 0 {
204        return None;
205    }
206    Some(sum / count as f64)
207}
208
209/// Compute the efficiency of a point as quality per kbps.
210#[allow(clippy::cast_precision_loss)]
211#[must_use]
212pub fn efficiency(point: &RdPoint) -> f64 {
213    if point.bitrate_kbps.abs() < f64::EPSILON {
214        return 0.0;
215    }
216    point.quality / point.bitrate_kbps
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// PSNR curve through (500, 30), (1000, 35), (2000, 38) and (4000, 40).
    fn sample_curve(label: &str) -> RdCurve {
        let mut curve = RdCurve::new(label, QualityMetric::Psnr);
        let measurements = [(500.0, 30.0), (1000.0, 35.0), (2000.0, 38.0), (4000.0, 40.0)];
        for &(kbps, psnr) in measurements.iter() {
            curve.add_point(RdPoint::new(kbps, psnr));
        }
        curve
    }

    /// `true` when `actual` is strictly within `tolerance` of `expected`.
    fn close(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    #[test]
    fn test_rd_point_display() {
        let rendered = format!("{}", RdPoint::new(1000.0, 35.5));
        assert_eq!(rendered, "(1000.0 kbps, 35.50)");
    }

    #[test]
    fn test_quality_metric_display() {
        let expectations = [
            (QualityMetric::Psnr, "PSNR"),
            (QualityMetric::Ssim, "SSIM"),
            (QualityMetric::Vmaf, "VMAF"),
        ];
        for &(metric, label) in expectations.iter() {
            assert_eq!(metric.to_string(), label);
        }
    }

    #[test]
    fn test_curve_sorted() {
        let mut curve = RdCurve::new("test", QualityMetric::Vmaf);
        for &(kbps, vmaf) in [(2000.0, 90.0), (500.0, 70.0), (1000.0, 80.0)].iter() {
            curve.add_point(RdPoint::new(kbps, vmaf));
        }
        let bitrates: Vec<u64> = curve
            .points()
            .iter()
            .map(|p| p.bitrate_kbps as u64)
            .collect();
        assert_eq!(bitrates, vec![500, 1000, 2000]);
    }

    #[test]
    fn test_best_quality() {
        let curve = sample_curve("x");
        let top = curve.best_quality().expect("should succeed in test");
        assert!(close(top.quality, 40.0, f64::EPSILON));
    }

    #[test]
    fn test_lowest_bitrate() {
        let curve = sample_curve("x");
        let cheapest = curve.lowest_bitrate().expect("should succeed in test");
        assert!(close(cheapest.bitrate_kbps, 500.0, f64::EPSILON));
    }

    #[test]
    fn test_find_nearest_quality() {
        let curve = sample_curve("x");
        let nearest = curve
            .find_nearest_quality(36.0)
            .expect("should succeed in test");
        assert!(close(nearest.quality, 35.0, f64::EPSILON));
    }

    #[test]
    fn test_find_nearest_bitrate() {
        let curve = sample_curve("x");
        let nearest = curve
            .find_nearest_bitrate(1200.0)
            .expect("should succeed in test");
        assert!(close(nearest.bitrate_kbps, 1000.0, f64::EPSILON));
    }

    #[test]
    fn test_interpolate_quality_midpoint() {
        // Halfway between (500, 30) and (1000, 35) the quality is 32.5.
        let quality = sample_curve("x")
            .interpolate_quality(750.0)
            .expect("should succeed in test");
        assert!(close(quality, 32.5, 0.01));
    }

    #[test]
    fn test_interpolate_quality_out_of_range() {
        let curve = sample_curve("x");
        for &kbps in [100.0, 5000.0].iter() {
            assert!(curve.interpolate_quality(kbps).is_none());
        }
    }

    #[test]
    fn test_interpolate_quality_insufficient_points() {
        let mut curve = RdCurve::new("x", QualityMetric::Psnr);
        curve.add_point(RdPoint::new(1000.0, 35.0));
        assert_eq!(curve.interpolate_quality(1000.0), None);
    }

    #[test]
    fn test_average_quality_delta_same_curve() {
        let curve = sample_curve("x");
        let delta = average_quality_delta(&curve, &curve, 10).expect("should succeed in test");
        assert!(close(delta, 0.0, 0.01));
    }

    #[test]
    fn test_average_quality_delta_better_curve() {
        let baseline = sample_curve("a");
        // Same bitrates as the sample curve, each point 2 dB better.
        let mut improved = RdCurve::new("b", QualityMetric::Psnr);
        let measurements = [(500.0, 32.0), (1000.0, 37.0), (2000.0, 40.0), (4000.0, 42.0)];
        for &(kbps, psnr) in measurements.iter() {
            improved.add_point(RdPoint::new(kbps, psnr));
        }
        let delta =
            average_quality_delta(&baseline, &improved, 20).expect("should succeed in test");
        assert!(delta > 0.0, "curve b should be better");
    }

    #[test]
    fn test_average_quality_delta_no_overlap() {
        let mut low = RdCurve::new("a", QualityMetric::Psnr);
        low.add_point(RdPoint::new(100.0, 20.0));
        low.add_point(RdPoint::new(200.0, 25.0));
        let mut high = RdCurve::new("b", QualityMetric::Psnr);
        high.add_point(RdPoint::new(500.0, 30.0));
        high.add_point(RdPoint::new(1000.0, 35.0));
        assert!(average_quality_delta(&low, &high, 10).is_none());
    }

    #[test]
    fn test_efficiency() {
        // 35 dB at 1000 kbps is 0.035 dB per kbps.
        assert!(close(efficiency(&RdPoint::new(1000.0, 35.0)), 0.035, 0.001));
    }

    #[test]
    fn test_efficiency_zero_bitrate() {
        assert!(close(efficiency(&RdPoint::new(0.0, 35.0)), 0.0, f64::EPSILON));
    }
}