// oximedia_optimize/transform/select.rs
//
// Heuristic selection of the transform type used for a residual block.

/// Transform kinds that [`TransformSelection`] can choose between.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Discrete cosine transform; always evaluated by the selector.
    Dct,
    /// Asymmetric discrete sine transform; evaluated only for intra blocks when enabled.
    Adst,
    /// Identity transform; evaluated only when enabled on the selector.
    Identity,
    /// Hybrid transform; evaluated together with ADST for intra blocks.
    Hybrid,
}

/// Chooses a [`TransformType`] for a residual block using simple cost heuristics.
///
/// DCT is always a candidate; the flags below enable the optional ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransformSelection {
    /// Also evaluate ADST and the hybrid transform for intra blocks.
    enable_adst: bool,
    /// Also evaluate the identity transform.
    enable_identity: bool,
}

22impl Default for TransformSelection {
23 fn default() -> Self {
24 Self::new(true, false)
25 }
26}
27
28impl TransformSelection {
29 #[must_use]
31 pub fn new(enable_adst: bool, enable_identity: bool) -> Self {
32 Self {
33 enable_adst,
34 enable_identity,
35 }
36 }
37
38 #[allow(dead_code)]
40 #[must_use]
41 pub fn select(&self, residual: &[i16], is_intra: bool) -> TransformType {
42 let candidates = self.candidate_transforms(is_intra);
43 let mut best_transform = TransformType::Dct;
44 let mut best_cost = f64::MAX;
45
46 for &transform in &candidates {
47 let cost = self.evaluate_transform(residual, transform);
48 if cost < best_cost {
49 best_cost = cost;
50 best_transform = transform;
51 }
52 }
53
54 best_transform
55 }
56
57 fn candidate_transforms(&self, is_intra: bool) -> Vec<TransformType> {
58 let mut transforms = vec![TransformType::Dct];
59
60 if self.enable_adst && is_intra {
61 transforms.push(TransformType::Adst);
62 transforms.push(TransformType::Hybrid);
63 }
64
65 if self.enable_identity {
66 transforms.push(TransformType::Identity);
67 }
68
69 transforms
70 }
71
72 fn evaluate_transform(&self, residual: &[i16], transform: TransformType) -> f64 {
73 match transform {
74 TransformType::Dct => self.evaluate_dct(residual),
75 TransformType::Adst => self.evaluate_adst(residual),
76 TransformType::Identity => self.evaluate_identity(residual),
77 TransformType::Hybrid => self.evaluate_hybrid(residual),
78 }
79 }
80
81 fn evaluate_dct(&self, residual: &[i16]) -> f64 {
82 residual.iter().filter(|&&x| x.abs() > 10).count() as f64
84 }
85
86 fn evaluate_adst(&self, residual: &[i16]) -> f64 {
87 self.evaluate_dct(residual) * 0.95
90 }
91
92 fn evaluate_identity(&self, residual: &[i16]) -> f64 {
93 let sum: i32 = residual.iter().map(|&x| i32::from(x.abs())).sum();
95 if sum < 100 {
96 f64::from(sum) * 0.5
97 } else {
98 f64::MAX }
100 }
101
102 fn evaluate_hybrid(&self, residual: &[i16]) -> f64 {
103 (self.evaluate_dct(residual) + self.evaluate_adst(residual)) / 2.0
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transform_selection_creation() {
        let selector = TransformSelection::default();
        assert!(selector.enable_adst);
        assert!(!selector.enable_identity);
    }

    #[test]
    fn test_candidate_transforms_intra() {
        let selector = TransformSelection::default();
        let transforms = selector.candidate_transforms(true);
        assert!(transforms.contains(&TransformType::Dct));
        assert!(transforms.contains(&TransformType::Adst));
        assert!(transforms.contains(&TransformType::Hybrid));
        assert!(!transforms.contains(&TransformType::Identity));
    }

    #[test]
    fn test_candidate_transforms_inter() {
        let selector = TransformSelection::default();
        let transforms = selector.candidate_transforms(false);
        assert!(transforms.contains(&TransformType::Dct));
        assert!(!transforms.contains(&TransformType::Adst));
    }

    #[test]
    fn test_candidate_transforms_identity_only() {
        let selector = TransformSelection::new(false, true);
        let expected = vec![TransformType::Dct, TransformType::Identity];
        assert_eq!(selector.candidate_transforms(false), expected);
        // ADST is disabled, so intra blocks get the same list.
        assert_eq!(selector.candidate_transforms(true), expected);
    }

    #[test]
    fn test_transform_types() {
        assert_ne!(TransformType::Dct, TransformType::Adst);
        assert_eq!(TransformType::Dct, TransformType::Dct);
    }

    #[test]
    fn test_select_empty_residual_falls_back_to_dct() {
        let selector = TransformSelection::default();
        assert_eq!(selector.select(&[], true), TransformType::Dct);
        assert_eq!(selector.select(&[], false), TransformType::Dct);
    }

    #[test]
    fn test_select_prefers_adst_for_intra_only() {
        let selector = TransformSelection::default();
        let residual = [20, -20, 0, 3];
        assert_eq!(selector.select(&residual, true), TransformType::Adst);
        assert_eq!(selector.select(&residual, false), TransformType::Dct);
    }
}