Skip to main content

oximedia_optimize/transform/
select.rs

1//! Transform type selection.
2
/// The transform kinds a block can be coded with.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransformType {
    /// DCT (Discrete Cosine Transform).
    Dct,
    /// ADST (Asymmetric Discrete Sine Transform).
    Adst,
    /// Identity transform.
    Identity,
    /// Hybrid of DCT and ADST.
    Hybrid,
}
15
/// Transform selection optimizer.
///
/// Holds the switches that decide which transform types are offered as
/// candidates when choosing a transform for a block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransformSelection {
    /// When set, ADST and hybrid transforms are candidates for intra blocks.
    enable_adst: bool,
    /// When set, the identity transform is a candidate.
    enable_identity: bool,
}
21
22impl Default for TransformSelection {
23    fn default() -> Self {
24        Self::new(true, false)
25    }
26}
27
28impl TransformSelection {
29    /// Creates a new transform selector.
30    #[must_use]
31    pub fn new(enable_adst: bool, enable_identity: bool) -> Self {
32        Self {
33            enable_adst,
34            enable_identity,
35        }
36    }
37
38    /// Selects the best transform for a block.
39    #[allow(dead_code)]
40    #[must_use]
41    pub fn select(&self, residual: &[i16], is_intra: bool) -> TransformType {
42        let candidates = self.candidate_transforms(is_intra);
43        let mut best_transform = TransformType::Dct;
44        let mut best_cost = f64::MAX;
45
46        for &transform in &candidates {
47            let cost = self.evaluate_transform(residual, transform);
48            if cost < best_cost {
49                best_cost = cost;
50                best_transform = transform;
51            }
52        }
53
54        best_transform
55    }
56
57    fn candidate_transforms(&self, is_intra: bool) -> Vec<TransformType> {
58        let mut transforms = vec![TransformType::Dct];
59
60        if self.enable_adst && is_intra {
61            transforms.push(TransformType::Adst);
62            transforms.push(TransformType::Hybrid);
63        }
64
65        if self.enable_identity {
66            transforms.push(TransformType::Identity);
67        }
68
69        transforms
70    }
71
72    fn evaluate_transform(&self, residual: &[i16], transform: TransformType) -> f64 {
73        match transform {
74            TransformType::Dct => self.evaluate_dct(residual),
75            TransformType::Adst => self.evaluate_adst(residual),
76            TransformType::Identity => self.evaluate_identity(residual),
77            TransformType::Hybrid => self.evaluate_hybrid(residual),
78        }
79    }
80
81    fn evaluate_dct(&self, residual: &[i16]) -> f64 {
82        // Simplified: count non-zero coefficients after transform
83        residual.iter().filter(|&&x| x.abs() > 10).count() as f64
84    }
85
86    fn evaluate_adst(&self, residual: &[i16]) -> f64 {
87        // ADST is better for directional content
88        // Simplified evaluation
89        self.evaluate_dct(residual) * 0.95
90    }
91
92    fn evaluate_identity(&self, residual: &[i16]) -> f64 {
93        // Identity is only good for very low residuals
94        let sum: i32 = residual.iter().map(|&x| i32::from(x.abs())).sum();
95        if sum < 100 {
96            f64::from(sum) * 0.5
97        } else {
98            f64::MAX // Avoid identity for high residuals
99        }
100    }
101
102    fn evaluate_hybrid(&self, residual: &[i16]) -> f64 {
103        // Hybrid can be beneficial for mixed content
104        (self.evaluate_dct(residual) + self.evaluate_adst(residual)) / 2.0
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// The default selector has ADST on and identity off.
    #[test]
    fn test_transform_selection_creation() {
        let TransformSelection {
            enable_adst,
            enable_identity,
        } = TransformSelection::default();
        assert!(enable_adst);
        assert!(!enable_identity);
    }

    /// Intra blocks are offered both DCT and ADST.
    #[test]
    fn test_candidate_transforms_intra() {
        let candidates = TransformSelection::default().candidate_transforms(true);
        for expected in &[TransformType::Dct, TransformType::Adst] {
            assert!(candidates.contains(expected));
        }
    }

    /// Inter blocks are offered DCT but never ADST.
    #[test]
    fn test_candidate_transforms_inter() {
        let candidates = TransformSelection::default().candidate_transforms(false);
        assert!(candidates.iter().any(|&t| t == TransformType::Dct));
        // ADST disabled for inter
        assert!(candidates.iter().all(|&t| t != TransformType::Adst));
    }

    /// Variants compare equal to themselves and unequal to each other.
    #[test]
    fn test_transform_types() {
        let (dct, adst) = (TransformType::Dct, TransformType::Adst);
        assert_ne!(dct, adst);
        assert_eq!(dct, TransformType::Dct);
    }
}