// oximedia_optimize/rdo/engine.rs

use oximedia_core::OxiResult;

use crate::{OptimizationLevel, OptimizerConfig};

6pub struct RdoEngine {
8 lambda_calc: super::LambdaCalculator,
9 optimization_level: OptimizationLevel,
10 parallel_enabled: bool,
11}
12
13impl RdoEngine {
14 pub fn new(config: &OptimizerConfig) -> OxiResult<Self> {
16 let lambda_calc = super::LambdaCalculator::new(config.lambda_multiplier, config.level);
17
18 Ok(Self {
19 lambda_calc,
20 optimization_level: config.level,
21 parallel_enabled: config.parallel_rdo,
22 })
23 }
24
25 #[must_use]
35 pub fn calculate_cost(&self, distortion: f64, rate: f64, qp: u8) -> f64 {
36 let lambda = self.lambda_calc.calculate(qp);
37 distortion + lambda * rate
38 }
39
40 pub fn evaluate_modes<F>(&self, candidates: &[ModeCandidate], eval_fn: F) -> RdoResult
42 where
43 F: Fn(&ModeCandidate) -> (f64, f64) + Send + Sync,
44 {
45 if self.parallel_enabled && candidates.len() > 4 {
46 self.evaluate_parallel(candidates, eval_fn)
47 } else {
48 self.evaluate_sequential(candidates, eval_fn)
49 }
50 }
51
52 fn evaluate_sequential<F>(&self, candidates: &[ModeCandidate], eval_fn: F) -> RdoResult
53 where
54 F: Fn(&ModeCandidate) -> (f64, f64),
55 {
56 let mut best_cost = f64::MAX;
57 let mut best_idx = 0;
58
59 for (idx, candidate) in candidates.iter().enumerate() {
60 let (distortion, rate) = eval_fn(candidate);
61 let cost = self.calculate_cost(distortion, rate, candidate.qp);
62
63 if cost < best_cost {
64 best_cost = cost;
65 best_idx = idx;
66 }
67 }
68
69 RdoResult {
70 best_mode_idx: best_idx,
71 cost: best_cost,
72 distortion: 0.0, rate: 0.0, }
75 }
76
77 fn evaluate_parallel<F>(&self, candidates: &[ModeCandidate], eval_fn: F) -> RdoResult
78 where
79 F: Fn(&ModeCandidate) -> (f64, f64) + Send + Sync,
80 {
81 use rayon::prelude::*;
82
83 let results: Vec<_> = candidates
84 .par_iter()
85 .enumerate()
86 .map(|(idx, candidate)| {
87 let (distortion, rate) = eval_fn(candidate);
88 let cost = self.calculate_cost(distortion, rate, candidate.qp);
89 (idx, cost, distortion, rate)
90 })
91 .collect();
92
93 let best = results
94 .iter()
95 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
96
97 match best {
98 Some(b) => RdoResult {
99 best_mode_idx: b.0,
100 cost: b.1,
101 distortion: b.2,
102 rate: b.3,
103 },
104 None => RdoResult {
105 best_mode_idx: 0,
106 cost: f64::MAX,
107 distortion: 0.0,
108 rate: 0.0,
109 },
110 }
111 }
112
113 #[must_use]
115 pub fn optimization_level(&self) -> OptimizationLevel {
116 self.optimization_level
117 }
118
119 #[must_use]
121 pub fn should_perform_full_rdo(&self) -> bool {
122 matches!(
123 self.optimization_level,
124 OptimizationLevel::Slow | OptimizationLevel::Placebo
125 )
126 }
127
128 #[must_use]
130 pub fn should_use_satd(&self) -> bool {
131 !matches!(self.optimization_level, OptimizationLevel::Fast)
132 }
133}
134
/// One mode under consideration by `RdoEngine::evaluate_modes`.
#[derive(Clone, Debug)]
pub struct ModeCandidate {
    /// Caller-defined identifier of the mode.
    pub mode_idx: usize,
    /// Quantization parameter; the engine uses it to look up λ for this
    /// candidate's cost.
    pub qp: u8,
    /// Opaque payload for the caller's evaluation callback; the engine itself
    /// never reads it.
    pub data: Vec<u8>,
}

/// Outcome of a rate-distortion mode decision.
#[derive(Clone, Debug)]
pub struct RdoResult {
    /// Index of the selected candidate. As produced by
    /// `RdoEngine::evaluate_modes` this is the winner's position in the
    /// candidate slice (not its `ModeCandidate::mode_idx`); `0` when no
    /// candidate was selected.
    pub best_mode_idx: usize,
    /// Rate-distortion cost of the selected candidate; `f64::MAX` when no
    /// candidate was selected.
    pub cost: f64,
    /// Distortion associated with the selected candidate; `0.0` when none was
    /// recorded.
    pub distortion: f64,
    /// Rate associated with the selected candidate; `0.0` when none was
    /// recorded.
    pub rate: f64,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an engine from the default optimizer configuration.
    fn default_engine() -> RdoEngine {
        RdoEngine::new(&OptimizerConfig::default()).expect("RDO engine creation should succeed")
    }

    /// Builds `count` candidates at `qp` whose `mode_idx` equals their slice position.
    fn make_candidates(count: usize, qp: u8) -> Vec<ModeCandidate> {
        (0..count)
            .map(|mode_idx| ModeCandidate {
                mode_idx,
                qp,
                data: vec![],
            })
            .collect()
    }

    #[test]
    fn test_rdo_engine_creation() {
        let engine = default_engine();
        assert_eq!(engine.optimization_level(), OptimizationLevel::Medium);
    }

    #[test]
    fn test_default_level_feature_gates() {
        // `Medium` is neither `Slow`/`Placebo` (full RDO) nor `Fast` (no SATD).
        let engine = default_engine();
        assert!(!engine.should_perform_full_rdo());
        assert!(engine.should_use_satd());
    }

    #[test]
    fn test_cost_calculation() {
        let engine = default_engine();
        let cost = engine.calculate_cost(100.0, 50.0, 26);
        assert!(cost > 100.0);
    }

    #[test]
    fn test_mode_evaluation() {
        let engine = default_engine();

        let result = engine.evaluate_modes(&make_candidates(2, 26), |c| {
            if c.mode_idx == 0 {
                (150.0, 40.0)
            } else {
                (100.0, 60.0)
            }
        });

        // The winner depends on λ(26), so derive the expectation from the engine
        // itself; on a tie the earlier candidate must win.
        let cost0 = engine.calculate_cost(150.0, 40.0, 26);
        let cost1 = engine.calculate_cost(100.0, 60.0, 26);
        let (expected_idx, expected_cost) = if cost1 < cost0 {
            (1, cost1)
        } else {
            (0, cost0)
        };
        assert_eq!(result.best_mode_idx, expected_idx);
        assert_eq!(result.cost, expected_cost);
    }

    #[test]
    fn test_mode_evaluation_empty() {
        let engine = default_engine();
        let result = engine.evaluate_modes(&[], |_| (0.0, 0.0));
        assert_eq!(result.best_mode_idx, 0);
        assert_eq!(result.cost, f64::MAX);
    }

    #[test]
    fn test_parallel_mode_evaluation() {
        let mut config = OptimizerConfig::default();
        config.parallel_rdo = true;
        let engine = RdoEngine::new(&config).expect("RDO engine creation should succeed");

        // Six candidates (more than the parallel threshold) with equal rate; the
        // distortion is uniquely smallest for mode 3.
        let result = engine.evaluate_modes(&make_candidates(6, 26), |c| {
            (100.0 + 10.0 * c.mode_idx.abs_diff(3) as f64, 50.0)
        });

        assert_eq!(result.best_mode_idx, 3);
        assert_eq!(result.distortion, 100.0);
        assert_eq!(result.rate, 50.0);
        assert_eq!(result.cost, engine.calculate_cost(100.0, 50.0, 26));
    }
}