Skip to main content

oximedia_optimize/entropy/
context.rs

1//! Context modeling optimization.
2
/// Summary statistics describing the outcome of an entropy-coding pass.
#[derive(Debug, Clone, Copy)]
pub struct EntropyStats {
    /// Total number of bits produced by the coder.
    pub total_bits: u64,
    /// Number of symbols that were encoded.
    pub num_symbols: u64,
    /// Mean cost per symbol (`total_bits / num_symbols`, or 0.0 with no symbols).
    pub avg_bits_per_symbol: f64,
    /// Uncompressed size divided by coded size (1.0 when nothing was coded).
    pub compression_ratio: f64,
}

impl Default for EntropyStats {
    /// Empty statistics: no bits, no symbols, neutral compression ratio of 1.0.
    fn default() -> Self {
        Self::new(0, 0, 0)
    }
}

impl EntropyStats {
    /// Builds statistics from raw counters.
    ///
    /// * `total_bits` - bits actually emitted by the coder.
    /// * `num_symbols` - number of symbols encoded.
    /// * `uncompressed_bits` - size of the input before coding, in bits.
    ///
    /// Division by zero is avoided: the average is 0.0 when `num_symbols` is
    /// zero, and the compression ratio is 1.0 when `total_bits` is zero.
    #[must_use]
    pub fn new(total_bits: u64, num_symbols: u64, uncompressed_bits: u64) -> Self {
        // Divides `numerator / denominator`, substituting `fallback` for a zero denominator.
        let guarded_ratio = |numerator: u64, denominator: u64, fallback: f64| -> f64 {
            if denominator == 0 {
                fallback
            } else {
                numerator as f64 / denominator as f64
            }
        };

        Self {
            total_bits,
            num_symbols,
            avg_bits_per_symbol: guarded_ratio(total_bits, num_symbols, 0.0),
            compression_ratio: guarded_ratio(uncompressed_bits, total_bits, 1.0),
        }
    }
}
51
/// Adaptive binary context model used to estimate entropy-coding bit costs.
///
/// Each context stores a small integer state; the probability that the next
/// symbol in that context is `true` is `state / 126`. States start at 63
/// (probability 0.5) and are kept within `1..=125` by [`ContextModel::update`],
/// so neither symbol is ever assigned a probability of exactly 0.0 or 1.0.
#[derive(Debug, Clone)]
pub struct ContextModel {
    /// Probability state for each context.
    states: Vec<u8>,
    /// Number of contexts; always equal to `states.len()`.
    num_contexts: usize,
}

impl ContextModel {
    /// Initial state of every context (`63 / 126 == 0.5`).
    const NEUTRAL_STATE: u8 = 63;
    /// Lowest state reachable through adaptation.
    const MIN_STATE: u8 = 1;
    /// Highest state reachable through adaptation.
    const MAX_STATE: u8 = 125;
    /// Divisor that maps a state to a probability.
    const STATE_SCALE: f64 = 126.0;

    /// Creates a model with `num_contexts` contexts, each at probability 0.5.
    #[must_use]
    pub fn new(num_contexts: usize) -> Self {
        Self {
            states: vec![Self::NEUTRAL_STATE; num_contexts],
            num_contexts,
        }
    }

    /// Adapts context `context_idx` one step toward `symbol`.
    ///
    /// Out-of-range indices are ignored. The state is clamped to
    /// `MIN_STATE..=MAX_STATE`. If it could reach 0 or 126, a run of 63
    /// identical symbols would make the opposite symbol "impossible"
    /// (probability 0). A binary arithmetic coder cannot represent that, and
    /// [`ContextModel::estimate_bit_cost`] would jump to an arbitrary penalty.
    pub fn update(&mut self, context_idx: usize, symbol: bool) {
        if context_idx >= self.num_contexts {
            return;
        }

        let state = &mut self.states[context_idx];
        *state = if symbol {
            // Move towards 1 (higher probability of `true`).
            state.saturating_add(1).min(Self::MAX_STATE)
        } else {
            // Move towards 0 (lower probability of `true`).
            state.saturating_sub(1).max(Self::MIN_STATE)
        };
    }

    /// Returns the probability that the next symbol in `context_idx` is `true`.
    ///
    /// Out-of-range indices yield the neutral probability 0.5.
    #[must_use]
    pub fn get_probability(&self, context_idx: usize) -> f64 {
        if context_idx >= self.num_contexts {
            return 0.5;
        }

        f64::from(self.states[context_idx]) / Self::STATE_SCALE
    }

    /// Estimates the cost in bits of coding `symbol` in `context_idx`, i.e.
    /// `-log2(P(symbol))` under the current model state.
    ///
    /// Does not modify the model.
    #[must_use]
    pub fn estimate_bit_cost(&self, context_idx: usize, symbol: bool) -> f64 {
        let prob = self.get_probability(context_idx);
        let symbol_prob = if symbol { prob } else { 1.0 - prob };

        if symbol_prob > 0.0 {
            -symbol_prob.log2()
        } else {
            // Defensive cap only: states never leave MIN_STATE..=MAX_STATE, so
            // `symbol_prob` is always at least 1/126 here.
            16.0
        }
    }
}
112
113/// Context optimizer for entropy coding.
114pub struct ContextOptimizer {
115    models: Vec<ContextModel>,
116    enable_adaptive: bool,
117}
118
119impl Default for ContextOptimizer {
120    fn default() -> Self {
121        Self::new(256, true)
122    }
123}
124
125impl ContextOptimizer {
126    /// Creates a new context optimizer.
127    #[must_use]
128    pub fn new(num_contexts: usize, enable_adaptive: bool) -> Self {
129        Self {
130            models: vec![ContextModel::new(num_contexts)],
131            enable_adaptive,
132        }
133    }
134
135    /// Selects optimal context for a symbol.
136    #[must_use]
137    pub fn select_context(&self, neighbors: &[bool], position: usize) -> usize {
138        // Simple context selection based on neighbors
139        let mut context = 0usize;
140
141        for (i, &neighbor) in neighbors.iter().enumerate().take(4) {
142            if neighbor {
143                context |= 1 << i;
144            }
145        }
146
147        // Add position-based context
148        context += (position % 8) * 16;
149
150        context.min(255)
151    }
152
153    /// Encodes a symbol and updates context.
154    #[allow(dead_code)]
155    pub fn encode_symbol(&mut self, symbol: bool, context_idx: usize, model_idx: usize) -> f64 {
156        if model_idx >= self.models.len() {
157            return 0.0;
158        }
159
160        let cost = self.models[model_idx].estimate_bit_cost(context_idx, symbol);
161
162        if self.enable_adaptive {
163            self.models[model_idx].update(context_idx, symbol);
164        }
165
166        cost
167    }
168
169    /// Calculates entropy for a sequence of symbols.
170    #[must_use]
171    pub fn calculate_entropy(&self, symbols: &[bool]) -> f64 {
172        if symbols.is_empty() {
173            return 0.0;
174        }
175
176        // Calculate empirical probabilities
177        let ones = symbols.iter().filter(|&&s| s).count();
178        let p1 = ones as f64 / symbols.len() as f64;
179        let p0 = 1.0 - p1;
180
181        let mut entropy = 0.0;
182        if p0 > 0.0 {
183            entropy -= p0 * p0.log2();
184        }
185        if p1 > 0.0 {
186            entropy -= p1 * p1.log2();
187        }
188
189        entropy
190    }
191
192    /// Gets statistics for encoding.
193    #[must_use]
194    pub fn get_stats(&self, total_bits: u64, num_symbols: u64) -> EntropyStats {
195        let uncompressed_bits = num_symbols; // 1 bit per symbol if uncompressed
196        EntropyStats::new(total_bits, num_symbols, uncompressed_bits)
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_model_creation() {
        let model = ContextModel::new(128);
        assert_eq!(model.num_contexts, 128);
        assert_eq!(model.states.len(), 128);
    }

    #[test]
    fn test_context_model_probability() {
        let model = ContextModel::new(1);
        let prob = model.get_probability(0);
        assert!((prob - 0.5).abs() < 0.01); // Initial state should be ~0.5
    }

    #[test]
    fn test_context_model_update() {
        let mut model = ContextModel::new(1);
        let initial_prob = model.get_probability(0);

        model.update(0, true);
        let updated_prob = model.get_probability(0);
        assert!(updated_prob > initial_prob); // Probability should increase
    }

    #[test]
    fn test_context_model_out_of_range() {
        let mut model = ContextModel::new(2);
        model.update(2, true); // Out of range: must be ignored, not panic.
        assert!((model.get_probability(0) - 0.5).abs() < 0.01);
        assert!((model.get_probability(1) - 0.5).abs() < 0.01);
        assert_eq!(model.get_probability(2), 0.5);
    }

    #[test]
    fn test_bit_cost_estimation() {
        let model = ContextModel::new(1);
        let cost_true = model.estimate_bit_cost(0, true);
        let cost_false = model.estimate_bit_cost(0, false);
        assert!(cost_true > 0.0);
        assert!(cost_false > 0.0);
        assert!((cost_true - cost_false).abs() < 0.01); // Should be similar for neutral state
    }

    #[test]
    fn test_context_optimizer_creation() {
        let optimizer = ContextOptimizer::default();
        assert!(optimizer.enable_adaptive);
        assert_eq!(optimizer.models.len(), 1);
    }

    #[test]
    fn test_context_selection() {
        let optimizer = ContextOptimizer::default();
        let neighbors = [true, false, true, false];
        let context = optimizer.select_context(&neighbors, 0);
        assert_eq!(context, 0b0101);

        // Only the first four neighbours count; position adds (position % 8) * 16.
        assert_eq!(optimizer.select_context(&[true; 6], 9), 15 + 16);
        assert!(optimizer.select_context(&[true; 4], 7) < 256);
    }

    #[test]
    fn test_encode_symbol_adapts() {
        let mut optimizer = ContextOptimizer::default();
        let first = optimizer.encode_symbol(true, 3, 0);
        assert!((first - 1.0).abs() < 1e-9); // Neutral context costs one bit.
        let second = optimizer.encode_symbol(true, 3, 0);
        assert!(second < first); // Model adapted towards `true`.
        assert_eq!(optimizer.encode_symbol(true, 3, 1), 0.0); // Unknown model.
    }

    #[test]
    fn test_encode_symbol_non_adaptive() {
        let mut optimizer = ContextOptimizer::new(16, false);
        let a = optimizer.encode_symbol(true, 0, 0);
        let b = optimizer.encode_symbol(true, 0, 0);
        assert_eq!(a, b);
    }

    #[test]
    fn test_entropy_calculation() {
        let optimizer = ContextOptimizer::default();

        // Empty input -> entropy = 0
        assert_eq!(optimizer.calculate_entropy(&[]), 0.0);

        // All same symbols -> entropy = 0
        let uniform = [true; 100];
        let entropy_uniform = optimizer.calculate_entropy(&uniform);
        assert_eq!(entropy_uniform, 0.0);

        // 50/50 split -> entropy = 1
        let mut balanced = vec![true; 50];
        balanced.extend_from_slice(&[false; 50]);
        let entropy_balanced = optimizer.calculate_entropy(&balanced);
        assert!((entropy_balanced - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_entropy_stats() {
        let stats = EntropyStats::new(100, 200, 200);
        assert_eq!(stats.total_bits, 100);
        assert_eq!(stats.num_symbols, 200);
        assert_eq!(stats.avg_bits_per_symbol, 0.5);
        assert_eq!(stats.compression_ratio, 2.0);
    }

    #[test]
    fn test_entropy_stats_default() {
        let stats = EntropyStats::default();
        assert_eq!(stats.total_bits, 0);
        assert_eq!(stats.num_symbols, 0);
        assert_eq!(stats.avg_bits_per_symbol, 0.0);
        assert_eq!(stats.compression_ratio, 1.0);
    }

    #[test]
    fn test_get_stats() {
        let optimizer = ContextOptimizer::default();
        let stats = optimizer.get_stats(100, 200);
        assert_eq!(stats.total_bits, 100);
        assert_eq!(stats.num_symbols, 200);
        assert_eq!(stats.avg_bits_per_symbol, 0.5);
        assert_eq!(stats.compression_ratio, 2.0);
    }
}