Skip to main content

oximedia_codec/intra/
palette.rs

1//! Palette mode implementation (AV1).
2//!
3//! Palette mode uses a small set of colors (2-8) to represent the block,
4//! with each pixel being an index into the palette. This is particularly
5//! effective for blocks with limited color variation, such as graphics
6//! or screen content.
7//!
8//! # Structure
9//!
10//! - **PaletteInfo**: Contains palette colors and size
11//! - **ColorCache**: Previously used colors for prediction
12//! - **ColorIndexMap**: Per-pixel palette indices
13//!
14//! # Encoding
15//!
16//! 1. Palette colors are signaled in the bitstream
17//! 2. Color indices are entropy coded using context
18//! 3. The decoder reconstructs pixels by looking up palette colors
19
20#![forbid(unsafe_code)]
21#![allow(dead_code)]
22#![allow(clippy::cast_possible_truncation)]
23#![allow(clippy::doc_markdown)]
24
25use super::{BitDepth, BlockDimensions, IntraPredContext, IntraPredictor};
26
/// Largest palette a block may use (AV1 palettes hold 2-8 colors).
pub const MAX_PALETTE_SIZE: usize = 8;

/// Smallest palette a block may use.
pub const MIN_PALETTE_SIZE: usize = 2;

/// Maximum color cache size; this is the capacity `ColorCache::default()` uses.
pub const MAX_COLOR_CACHE_SIZE: usize = 64;
35
36/// Palette information for a block.
37#[derive(Clone, Debug)]
38pub struct PaletteInfo {
39    /// Palette colors (Y/U/V or R/G/B values).
40    colors: [u16; MAX_PALETTE_SIZE],
41    /// Number of colors in the palette (2-8).
42    size: usize,
43    /// Bit depth of color values.
44    bit_depth: BitDepth,
45}
46
47impl PaletteInfo {
48    /// Create a new empty palette.
49    #[must_use]
50    pub const fn new(bit_depth: BitDepth) -> Self {
51        Self {
52            colors: [0; MAX_PALETTE_SIZE],
53            size: 0,
54            bit_depth,
55        }
56    }
57
58    /// Create a palette with specified colors.
59    #[must_use]
60    pub fn with_colors(colors: &[u16], bit_depth: BitDepth) -> Self {
61        let size = colors.len().min(MAX_PALETTE_SIZE);
62        let mut palette_colors = [0u16; MAX_PALETTE_SIZE];
63        palette_colors[..size].copy_from_slice(&colors[..size]);
64
65        Self {
66            colors: palette_colors,
67            size,
68            bit_depth,
69        }
70    }
71
72    /// Get the palette size.
73    #[must_use]
74    pub const fn size(&self) -> usize {
75        self.size
76    }
77
78    /// Set the palette size.
79    pub fn set_size(&mut self, size: usize) {
80        self.size = size.clamp(MIN_PALETTE_SIZE, MAX_PALETTE_SIZE);
81    }
82
83    /// Get a color at the specified index.
84    #[must_use]
85    pub fn get_color(&self, idx: usize) -> u16 {
86        if idx < self.size {
87            self.colors[idx]
88        } else {
89            0
90        }
91    }
92
93    /// Set a color at the specified index.
94    pub fn set_color(&mut self, idx: usize, color: u16) {
95        if idx < MAX_PALETTE_SIZE {
96            self.colors[idx] = color.min(self.bit_depth.max_value());
97            if idx >= self.size {
98                self.size = idx + 1;
99            }
100        }
101    }
102
103    /// Get all colors as a slice.
104    #[must_use]
105    pub fn colors(&self) -> &[u16] {
106        &self.colors[..self.size]
107    }
108
109    /// Check if the palette is valid.
110    #[must_use]
111    pub const fn is_valid(&self) -> bool {
112        self.size >= MIN_PALETTE_SIZE && self.size <= MAX_PALETTE_SIZE
113    }
114
115    /// Sort colors in ascending order.
116    pub fn sort_colors(&mut self) {
117        self.colors[..self.size].sort_unstable();
118    }
119
120    /// Find the nearest color index for a given value.
121    #[must_use]
122    pub fn find_nearest(&self, value: u16) -> usize {
123        let mut best_idx = 0;
124        let mut best_diff = u32::MAX;
125
126        for (idx, &color) in self.colors[..self.size].iter().enumerate() {
127            let diff = (i32::from(value) - i32::from(color)).unsigned_abs();
128            if diff < best_diff {
129                best_diff = diff;
130                best_idx = idx;
131            }
132        }
133
134        best_idx
135    }
136}
137
138impl Default for PaletteInfo {
139    fn default() -> Self {
140        Self::new(BitDepth::Bits8)
141    }
142}
143
144/// Color cache for palette mode.
145///
146/// Stores recently used colors to improve entropy coding efficiency.
147#[derive(Clone, Debug)]
148pub struct ColorCache {
149    /// Cached colors.
150    colors: Vec<u16>,
151    /// Maximum cache size.
152    max_size: usize,
153    /// Bit depth.
154    bit_depth: BitDepth,
155}
156
157impl ColorCache {
158    /// Create a new color cache.
159    #[must_use]
160    pub fn new(max_size: usize, bit_depth: BitDepth) -> Self {
161        Self {
162            colors: Vec::with_capacity(max_size),
163            max_size,
164            bit_depth,
165        }
166    }
167
168    /// Add a color to the cache.
169    pub fn add(&mut self, color: u16) {
170        // Don't add duplicates
171        if self.colors.contains(&color) {
172            return;
173        }
174
175        if self.colors.len() >= self.max_size {
176            // Remove oldest color
177            self.colors.remove(0);
178        }
179
180        self.colors.push(color);
181    }
182
183    /// Check if a color is in the cache.
184    #[must_use]
185    pub fn contains(&self, color: u16) -> bool {
186        self.colors.contains(&color)
187    }
188
189    /// Find the index of a color in the cache.
190    #[must_use]
191    pub fn find(&self, color: u16) -> Option<usize> {
192        self.colors.iter().position(|&c| c == color)
193    }
194
195    /// Get all cached colors.
196    #[must_use]
197    pub fn colors(&self) -> &[u16] {
198        &self.colors
199    }
200
201    /// Get the cache size.
202    #[must_use]
203    pub fn len(&self) -> usize {
204        self.colors.len()
205    }
206
207    /// Check if the cache is empty.
208    #[must_use]
209    pub fn is_empty(&self) -> bool {
210        self.colors.is_empty()
211    }
212
213    /// Clear the cache.
214    pub fn clear(&mut self) {
215        self.colors.clear();
216    }
217
218    /// Build cache from neighbor samples.
219    pub fn build_from_neighbors(&mut self, top: &[u16], left: &[u16]) {
220        self.clear();
221
222        // Add unique colors from top
223        for &color in top {
224            self.add(color);
225        }
226
227        // Add unique colors from left
228        for &color in left {
229            self.add(color);
230        }
231    }
232}
233
234impl Default for ColorCache {
235    fn default() -> Self {
236        Self::new(MAX_COLOR_CACHE_SIZE, BitDepth::Bits8)
237    }
238}
239
/// Color index map for a block.
///
/// Holds one palette index per pixel, laid out row-major
/// (`indices[y * width + x]`).
#[derive(Clone, Debug)]
pub struct ColorIndexMap {
    /// Per-pixel palette indices, row-major.
    indices: Vec<u8>,
    /// Block width in pixels.
    width: usize,
    /// Block height in pixels.
    height: usize,
}

impl ColorIndexMap {
    /// Create a `width` x `height` map with every index set to 0.
    #[must_use]
    pub fn new(width: usize, height: usize) -> Self {
        let indices = vec![0; width * height];
        Self {
            indices,
            width,
            height,
        }
    }

    /// Row-major offset of `(x, y)`, or `None` when the position is outside the block.
    fn offset(&self, x: usize, y: usize) -> Option<usize> {
        if x >= self.width || y >= self.height {
            None
        } else {
            Some(y * self.width + x)
        }
    }

    /// Read the index at `(x, y)`; out-of-bounds positions read as 0.
    #[must_use]
    pub fn get(&self, x: usize, y: usize) -> u8 {
        self.offset(x, y).map_or(0, |pos| self.indices[pos])
    }

    /// Store `idx` at `(x, y)`; out-of-bounds writes are silently dropped.
    pub fn set(&mut self, x: usize, y: usize, idx: u8) {
        if let Some(pos) = self.offset(x, y) {
            self.indices[pos] = idx;
        }
    }

    /// Borrow every index as a row-major slice.
    #[must_use]
    pub fn indices(&self) -> &[u8] {
        self.indices.as_slice()
    }

    /// Mutably borrow every index as a row-major slice.
    pub fn indices_mut(&mut self) -> &mut [u8] {
        self.indices.as_mut_slice()
    }
}
290
291/// Palette predictor.
292#[derive(Clone, Debug)]
293pub struct PalettePredictor {
294    /// Palette information.
295    palette: PaletteInfo,
296    /// Color index map.
297    index_map: ColorIndexMap,
298}
299
300impl PalettePredictor {
301    /// Create a new palette predictor.
302    #[must_use]
303    pub fn new(palette: PaletteInfo, width: usize, height: usize) -> Self {
304        Self {
305            palette,
306            index_map: ColorIndexMap::new(width, height),
307        }
308    }
309
310    /// Get the palette info.
311    #[must_use]
312    pub const fn palette(&self) -> &PaletteInfo {
313        &self.palette
314    }
315
316    /// Get mutable palette info.
317    pub fn palette_mut(&mut self) -> &mut PaletteInfo {
318        &mut self.palette
319    }
320
321    /// Get the index map.
322    #[must_use]
323    pub const fn index_map(&self) -> &ColorIndexMap {
324        &self.index_map
325    }
326
327    /// Get mutable index map.
328    pub fn index_map_mut(&mut self) -> &mut ColorIndexMap {
329        &mut self.index_map
330    }
331
332    /// Set a color index at a position.
333    pub fn set_index(&mut self, x: usize, y: usize, idx: u8) {
334        self.index_map.set(x, y, idx);
335    }
336
337    /// Reconstruct the block from palette indices.
338    pub fn reconstruct(&self, output: &mut [u16], stride: usize, dims: BlockDimensions) {
339        for y in 0..dims.height {
340            let row_start = y * stride;
341            for x in 0..dims.width {
342                let idx = self.index_map.get(x, y) as usize;
343                output[row_start + x] = self.palette.get_color(idx);
344            }
345        }
346    }
347}
348
349impl IntraPredictor for PalettePredictor {
350    fn predict(
351        &self,
352        _ctx: &IntraPredContext,
353        output: &mut [u16],
354        stride: usize,
355        dims: BlockDimensions,
356    ) {
357        self.reconstruct(output, stride, dims);
358    }
359}
360
361/// Decode color indices using run-length coding.
362pub fn decode_color_indices_rle(
363    data: &[u8],
364    index_map: &mut ColorIndexMap,
365    palette_size: usize,
366) -> usize {
367    let mut offset = 0;
368    let mut x = 0;
369    let mut y = 0;
370    let width = index_map.width;
371    let height = index_map.height;
372
373    while y < height && offset < data.len() {
374        let color_idx = data[offset] % (palette_size as u8);
375        offset += 1;
376
377        let mut run_length = 1;
378        if offset < data.len() {
379            run_length = data[offset] as usize + 1;
380            offset += 1;
381        }
382
383        for _ in 0..run_length {
384            if y >= height {
385                break;
386            }
387            index_map.set(x, y, color_idx);
388            x += 1;
389            if x >= width {
390                x = 0;
391                y += 1;
392            }
393        }
394    }
395
396    offset
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_palette_info_creation() {
        let palette = PaletteInfo::new(BitDepth::Bits8);
        assert_eq!(palette.size(), 0);
        assert!(!palette.is_valid());

        let palette = PaletteInfo::with_colors(&[100, 150, 200], BitDepth::Bits8);
        assert_eq!(palette.size(), 3);
        assert!(palette.is_valid());
        assert_eq!(palette.get_color(0), 100);
        assert_eq!(palette.get_color(1), 150);
        assert_eq!(palette.get_color(2), 200);
    }

    #[test]
    fn test_palette_set_color() {
        let mut palette = PaletteInfo::new(BitDepth::Bits8);
        palette.set_color(0, 100);
        palette.set_color(1, 200);

        assert_eq!(palette.size(), 2);
        assert!(palette.is_valid());
        assert_eq!(palette.get_color(0), 100);
        assert_eq!(palette.get_color(1), 200);
    }

    #[test]
    fn test_palette_get_color_out_of_range() {
        let palette = PaletteInfo::with_colors(&[10, 20], BitDepth::Bits8);
        assert_eq!(palette.get_color(1), 20);
        assert_eq!(palette.get_color(2), 0); // Past the active size
        assert_eq!(palette.get_color(MAX_PALETTE_SIZE), 0);
    }

    #[test]
    fn test_palette_with_colors_truncates() {
        let colors: Vec<u16> = (1..=10).collect();
        let palette = PaletteInfo::with_colors(&colors, BitDepth::Bits8);
        assert_eq!(palette.size(), MAX_PALETTE_SIZE);
        assert_eq!(palette.colors(), &colors[..MAX_PALETTE_SIZE]);
    }

    #[test]
    fn test_palette_set_size_clamps() {
        let mut palette = PaletteInfo::new(BitDepth::Bits8);
        palette.set_size(0);
        assert_eq!(palette.size(), MIN_PALETTE_SIZE);
        palette.set_size(100);
        assert_eq!(palette.size(), MAX_PALETTE_SIZE);
    }

    #[test]
    fn test_palette_find_nearest() {
        let palette = PaletteInfo::with_colors(&[0, 100, 200, 255], BitDepth::Bits8);

        assert_eq!(palette.find_nearest(0), 0);
        assert_eq!(palette.find_nearest(50), 0); // Equidistant from 0 and 100 (50 each), first wins
        assert_eq!(palette.find_nearest(60), 1); // Closer to 100 (dist 40 vs 60)
        assert_eq!(palette.find_nearest(150), 1); // Equal distance to 100 and 200, first wins
        assert_eq!(palette.find_nearest(160), 2); // Closer to 200 (dist 40 vs 60)
        assert_eq!(palette.find_nearest(255), 3);
    }

    #[test]
    fn test_palette_find_nearest_empty() {
        let palette = PaletteInfo::new(BitDepth::Bits8);
        assert_eq!(palette.find_nearest(42), 0);
    }

    #[test]
    fn test_palette_sort() {
        let mut palette = PaletteInfo::with_colors(&[200, 50, 150, 100], BitDepth::Bits8);
        palette.sort_colors();

        assert_eq!(palette.colors(), &[50, 100, 150, 200]);
    }

    #[test]
    fn test_color_cache() {
        let mut cache = ColorCache::new(4, BitDepth::Bits8);

        cache.add(100);
        cache.add(150);
        cache.add(200);

        assert_eq!(cache.len(), 3);
        assert!(cache.contains(100));
        assert!(cache.contains(150));
        assert!(!cache.contains(50));

        assert_eq!(cache.find(150), Some(1));
        assert_eq!(cache.find(50), None);
    }

    #[test]
    fn test_color_cache_overflow() {
        let mut cache = ColorCache::new(3, BitDepth::Bits8);

        cache.add(100);
        cache.add(150);
        cache.add(200);
        cache.add(250); // Should evict 100

        assert_eq!(cache.len(), 3);
        assert!(!cache.contains(100));
        assert!(cache.contains(150));
        assert!(cache.contains(250));
    }

    #[test]
    fn test_color_cache_no_duplicates() {
        let mut cache = ColorCache::new(4, BitDepth::Bits8);

        cache.add(100);
        cache.add(100);
        cache.add(100);

        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_color_index_map() {
        let mut map = ColorIndexMap::new(4, 4);

        map.set(0, 0, 1);
        map.set(1, 0, 2);
        map.set(0, 1, 3);

        assert_eq!(map.get(0, 0), 1);
        assert_eq!(map.get(1, 0), 2);
        assert_eq!(map.get(0, 1), 3);
        assert_eq!(map.get(2, 2), 0); // Default
    }

    #[test]
    fn test_color_index_map_out_of_bounds_set_ignored() {
        let mut map = ColorIndexMap::new(4, 4);

        map.set(4, 0, 5);
        map.set(0, 4, 5);

        assert!(map.indices().iter().all(|&idx| idx == 0));
        assert_eq!(map.get(4, 0), 0);
    }

    #[test]
    fn test_palette_predictor() {
        let palette = PaletteInfo::with_colors(&[0, 128, 255], BitDepth::Bits8);
        let mut predictor = PalettePredictor::new(palette, 2, 2);

        predictor.set_index(0, 0, 0);
        predictor.set_index(1, 0, 1);
        predictor.set_index(0, 1, 2);
        predictor.set_index(1, 1, 1);

        let dims = BlockDimensions::new(2, 2);
        let mut output = vec![0u16; 4];

        predictor.reconstruct(&mut output, 2, dims);

        assert_eq!(output[0], 0);
        assert_eq!(output[1], 128);
        assert_eq!(output[2], 255);
        assert_eq!(output[3], 128);
    }

    #[test]
    fn test_decode_color_indices_rle() {
        let mut map = ColorIndexMap::new(4, 2);

        // Color 0 with run of 4, Color 1 with run of 4
        let data = [0, 3, 1, 3];

        let bytes_read = decode_color_indices_rle(&data, &mut map, 3);

        assert_eq!(bytes_read, 4);
        // First row: 0, 0, 0, 0
        assert_eq!(map.get(0, 0), 0);
        assert_eq!(map.get(3, 0), 0);
        // Second row: 1, 1, 1, 1
        assert_eq!(map.get(0, 1), 1);
        assert_eq!(map.get(3, 1), 1);
    }

    #[test]
    fn test_decode_color_indices_rle_truncates_run() {
        let mut map = ColorIndexMap::new(2, 1);

        // Color 1 with a run of 10 only fills the 2 available pixels; the
        // trailing pair is not read because the map is already full.
        let bytes_read = decode_color_indices_rle(&[1, 9, 0, 0], &mut map, 2);

        assert_eq!(bytes_read, 2);
        assert_eq!(map.indices(), &[1, 1]);
    }

    #[test]
    fn test_build_cache_from_neighbors() {
        let mut cache = ColorCache::new(16, BitDepth::Bits8);

        let top = [100, 100, 150, 200];
        let left = [100, 125, 175, 200];

        cache.build_from_neighbors(&top, &left);

        // Should have unique colors: 100, 150, 200, 125, 175
        assert!(cache.contains(100));
        assert!(cache.contains(150));
        assert!(cache.contains(200));
        assert!(cache.contains(125));
        assert!(cache.contains(175));
        // Top colors come first, then new colors from the left, in order.
        assert_eq!(cache.len(), 5);
        assert_eq!(cache.colors(), &[100, 150, 200, 125, 175]);
    }
}