Skip to main content

oximedia_optimize/rdo/
engine.rs

1//! Rate-distortion optimization engine.
2
3use crate::{OptimizationLevel, OptimizerConfig};
4use oximedia_core::OxiResult;
5
6/// RDO engine for mode decision.
7pub struct RdoEngine {
8    lambda_calc: super::LambdaCalculator,
9    optimization_level: OptimizationLevel,
10    parallel_enabled: bool,
11}
12
13impl RdoEngine {
14    /// Creates a new RDO engine.
15    pub fn new(config: &OptimizerConfig) -> OxiResult<Self> {
16        let lambda_calc = super::LambdaCalculator::new(config.lambda_multiplier, config.level);
17
18        Ok(Self {
19            lambda_calc,
20            optimization_level: config.level,
21            parallel_enabled: config.parallel_rdo,
22        })
23    }
24
25    /// Calculates the rate-distortion cost for a decision.
26    ///
27    /// # Parameters
28    /// - `distortion`: Distortion metric (SSE, SAD, SATD)
29    /// - `rate`: Bit rate for this decision
30    /// - `qp`: Quantization parameter
31    ///
32    /// # Returns
33    /// The RD cost: `distortion + lambda * rate`
34    #[must_use]
35    pub fn calculate_cost(&self, distortion: f64, rate: f64, qp: u8) -> f64 {
36        let lambda = self.lambda_calc.calculate(qp);
37        distortion + lambda * rate
38    }
39
40    /// Evaluates multiple mode decisions and returns the best one.
41    pub fn evaluate_modes<F>(&self, candidates: &[ModeCandidate], eval_fn: F) -> RdoResult
42    where
43        F: Fn(&ModeCandidate) -> (f64, f64) + Send + Sync,
44    {
45        if self.parallel_enabled && candidates.len() > 4 {
46            self.evaluate_parallel(candidates, eval_fn)
47        } else {
48            self.evaluate_sequential(candidates, eval_fn)
49        }
50    }
51
52    fn evaluate_sequential<F>(&self, candidates: &[ModeCandidate], eval_fn: F) -> RdoResult
53    where
54        F: Fn(&ModeCandidate) -> (f64, f64),
55    {
56        let mut best_cost = f64::MAX;
57        let mut best_idx = 0;
58
59        for (idx, candidate) in candidates.iter().enumerate() {
60            let (distortion, rate) = eval_fn(candidate);
61            let cost = self.calculate_cost(distortion, rate, candidate.qp);
62
63            if cost < best_cost {
64                best_cost = cost;
65                best_idx = idx;
66            }
67        }
68
69        RdoResult {
70            best_mode_idx: best_idx,
71            cost: best_cost,
72            distortion: 0.0, // Will be filled by caller
73            rate: 0.0,       // Will be filled by caller
74        }
75    }
76
77    fn evaluate_parallel<F>(&self, candidates: &[ModeCandidate], eval_fn: F) -> RdoResult
78    where
79        F: Fn(&ModeCandidate) -> (f64, f64) + Send + Sync,
80    {
81        use rayon::prelude::*;
82
83        let results: Vec<_> = candidates
84            .par_iter()
85            .enumerate()
86            .map(|(idx, candidate)| {
87                let (distortion, rate) = eval_fn(candidate);
88                let cost = self.calculate_cost(distortion, rate, candidate.qp);
89                (idx, cost, distortion, rate)
90            })
91            .collect();
92
93        let best = results
94            .iter()
95            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
96
97        match best {
98            Some(b) => RdoResult {
99                best_mode_idx: b.0,
100                cost: b.1,
101                distortion: b.2,
102                rate: b.3,
103            },
104            None => RdoResult {
105                best_mode_idx: 0,
106                cost: f64::MAX,
107                distortion: 0.0,
108                rate: 0.0,
109            },
110        }
111    }
112
113    /// Gets the optimization level.
114    #[must_use]
115    pub fn optimization_level(&self) -> OptimizationLevel {
116        self.optimization_level
117    }
118
119    /// Checks if the engine should perform full RDO.
120    #[must_use]
121    pub fn should_perform_full_rdo(&self) -> bool {
122        matches!(
123            self.optimization_level,
124            OptimizationLevel::Slow | OptimizationLevel::Placebo
125        )
126    }
127
128    /// Checks if the engine should use SATD instead of SAD.
129    #[must_use]
130    pub fn should_use_satd(&self) -> bool {
131        !matches!(self.optimization_level, OptimizationLevel::Fast)
132    }
133}
134
/// A single coding-mode option submitted to [`RdoEngine::evaluate_modes`].
#[derive(Clone, Debug)]
pub struct ModeCandidate {
    /// Identifier of the mode this candidate represents.
    pub mode_idx: usize,
    /// Quantization parameter at which the candidate's rate is priced.
    pub qp: u8,
    /// Opaque, mode-specific payload; `RdoEngine` itself never reads it.
    pub data: Vec<u8>,
}
145
/// Outcome of an RDO mode decision.
///
/// All fields are plain numbers, so the result is `Copy` and can be compared
/// directly (e.g. to check that two evaluation strategies agree).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RdoResult {
    /// Position of the winning candidate in the slice passed to
    /// `RdoEngine::evaluate_modes` (not the candidate's `mode_idx`).
    pub best_mode_idx: usize,
    /// Rate-distortion cost of the winner (`distortion + lambda * rate`);
    /// `f64::MAX` when no candidate was supplied.
    pub cost: f64,
    /// Distortion component.
    pub distortion: f64,
    /// Rate component (in bits).
    pub rate: f64,
}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative float comparison; RD costs can be large, so a fixed absolute
    /// epsilon would be too strict.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-9 * a.abs().max(b.abs()).max(1.0)
    }

    /// Builds a candidate with no mode-specific payload.
    fn candidate(mode_idx: usize, qp: u8) -> ModeCandidate {
        ModeCandidate {
            mode_idx,
            qp,
            data: vec![],
        }
    }

    #[test]
    fn test_rdo_engine_creation() {
        let config = OptimizerConfig::default();
        let engine = RdoEngine::new(&config).expect("RDO engine creation should succeed");
        assert_eq!(engine.optimization_level(), OptimizationLevel::Medium);
    }

    #[test]
    fn test_cost_calculation() {
        let config = OptimizerConfig::default();
        let engine = RdoEngine::new(&config).expect("RDO engine creation should succeed");
        let cost = engine.calculate_cost(100.0, 50.0, 26);
        assert!(cost > 100.0); // Should include rate penalty
    }

    #[test]
    fn test_mode_evaluation() {
        let config = OptimizerConfig::default();
        let engine = RdoEngine::new(&config).expect("RDO engine creation should succeed");

        let candidates = [candidate(0, 26), candidate(1, 26)];

        let result = engine.evaluate_modes(&candidates, |c| {
            // Simulate: mode 1 has lower distortion but higher rate
            if c.mode_idx == 0 {
                (150.0, 40.0)
            } else {
                (100.0, 60.0)
            }
        });

        // Which mode wins depends on lambda at qp 26, so derive the
        // expectation from the engine's own cost function.
        let cost0 = engine.calculate_cost(150.0, 40.0, 26);
        let cost1 = engine.calculate_cost(100.0, 60.0, 26);
        assert_eq!(result.best_mode_idx, usize::from(cost1 < cost0));
        assert!(approx_eq(result.cost, cost0.min(cost1)));
    }

    #[test]
    fn test_empty_candidates() {
        let config = OptimizerConfig::default();
        let engine = RdoEngine::new(&config).expect("RDO engine creation should succeed");

        let result = engine.evaluate_modes(&[], |_| (0.0, 0.0));

        assert_eq!(result.best_mode_idx, 0);
        assert_eq!(result.cost.to_bits(), f64::MAX.to_bits());
    }

    #[test]
    fn test_parallel_and_sequential_agree() {
        const SCORES: [(f64, f64); 8] = [
            (400.0, 10.0),
            (120.0, 80.0),
            (250.0, 30.0),
            (90.0, 120.0),
            (300.0, 20.0),
            (180.0, 45.0),
            (60.0, 200.0),
            (220.0, 35.0),
        ];

        let mut parallel_config = OptimizerConfig::default();
        parallel_config.parallel_rdo = true;
        let mut sequential_config = OptimizerConfig::default();
        sequential_config.parallel_rdo = false;
        let parallel =
            RdoEngine::new(&parallel_config).expect("RDO engine creation should succeed");
        let sequential =
            RdoEngine::new(&sequential_config).expect("RDO engine creation should succeed");

        // More than four candidates, so the parallel engine takes the rayon path.
        let qps: [u8; 8] = [22, 26, 30, 26, 22, 34, 26, 30];
        let candidates: Vec<ModeCandidate> = qps
            .iter()
            .enumerate()
            .map(|(idx, &qp)| candidate(idx, qp))
            .collect();
        let eval = |c: &ModeCandidate| SCORES[c.mode_idx];

        let par = parallel.evaluate_modes(&candidates, eval);
        let seq = sequential.evaluate_modes(&candidates, eval);

        assert_eq!(par.best_mode_idx, seq.best_mode_idx);
        assert!(approx_eq(par.cost, seq.cost));

        // Both must match a straightforward first-minimum scan.
        let expected = candidates
            .iter()
            .enumerate()
            .map(|(idx, c)| {
                let (distortion, rate) = eval(c);
                (idx, sequential.calculate_cost(distortion, rate, c.qp))
            })
            .fold((0, f64::MAX), |best, cur| if cur.1 < best.1 { cur } else { best });
        assert_eq!(seq.best_mode_idx, expected.0);
        assert!(approx_eq(seq.cost, expected.1));
    }
}