Skip to main content

oximedia_codec/gif/
encoder.rs

1//! GIF encoder implementation.
2//!
3//! Supports:
4//! - Color quantization (median cut, octree algorithms)
5//! - Dithering (Floyd-Steinberg, ordered)
6//! - Animation encoding
7//! - Transparency
8//! - Disposal methods
9
10use super::lzw::LzwEncoder;
11use crate::error::{CodecError, CodecResult};
12use crate::frame::VideoFrame;
13use oximedia_core::PixelFormat;
14use std::collections::HashMap;
15use std::io::Write;
16
/// Largest number of entries a GIF color table can hold.
const MAX_COLORS: usize = 1 << 8;

/// Signature and version bytes that open the stream.
const GIF89A_HEADER: &[u8] = b"GIF89a";

/// Byte that introduces an extension block (`0x21`, ASCII `!`).
const EXTENSION_INTRODUCER: u8 = b'!';

/// Byte that introduces an image descriptor (`0x2C`, ASCII `,`).
const IMAGE_SEPARATOR: u8 = b',';

/// Byte that ends the stream (`0x3B`, ASCII `;`).
const TRAILER: u8 = b';';

/// Extension label of the Graphics Control Extension.
const GRAPHICS_CONTROL_LABEL: u8 = 0xF9;

/// Extension label of the Application Extension.
const APPLICATION_LABEL: u8 = 0xFF;

/// Disposal method 0: no disposal specified.
#[allow(dead_code)]
const DISPOSAL_NONE: u8 = 0;

/// Disposal method 1: keep the frame in place.
#[allow(dead_code)]
const DISPOSAL_KEEP: u8 = 1;

/// Disposal method 2: restore the frame area to the background.
#[allow(dead_code)]
const DISPOSAL_BACKGROUND: u8 = 2;

/// Disposal method 3: restore the frame area to its previous contents.
#[allow(dead_code)]
const DISPOSAL_PREVIOUS: u8 = 3;
53
/// How pixels are adjusted before being matched against the palette.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DitheringMethod {
    /// Match every pixel directly to its nearest palette entry.
    None,
    /// Floyd-Steinberg error diffusion.
    FloydSteinberg,
    /// Ordered dithering driven by a Bayer threshold matrix.
    Ordered,
}
64
/// Algorithm used to derive the palette from the input colors.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QuantizationMethod {
    /// Median cut: recursively split color buckets at their median.
    MedianCut,
    /// Octree: group colors in an eight-way tree keyed on their bits.
    Octree,
}
73
74/// GIF encoder configuration.
75#[derive(Debug, Clone)]
76pub struct GifEncoderConfig {
77    /// Number of colors in palette (2-256).
78    pub colors: usize,
79    /// Color quantization method.
80    pub quantization: QuantizationMethod,
81    /// Dithering method.
82    pub dithering: DitheringMethod,
83    /// Transparent color index (None = no transparency).
84    pub transparent_index: Option<u8>,
85    /// Loop count (0 = infinite, 1 = no loop).
86    pub loop_count: u16,
87}
88
89impl Default for GifEncoderConfig {
90    fn default() -> Self {
91        Self {
92            colors: 256,
93            quantization: QuantizationMethod::MedianCut,
94            dithering: DitheringMethod::None,
95            transparent_index: None,
96            loop_count: 0,
97        }
98    }
99}
100
101/// GIF frame configuration.
102#[derive(Debug, Clone)]
103pub struct GifFrameConfig {
104    /// Delay time in hundredths of a second.
105    pub delay_time: u16,
106    /// Disposal method.
107    pub disposal_method: u8,
108    /// Left position on canvas.
109    pub left: u16,
110    /// Top position on canvas.
111    pub top: u16,
112}
113
114impl Default for GifFrameConfig {
115    fn default() -> Self {
116        Self {
117            delay_time: 10, // 100ms
118            disposal_method: DISPOSAL_BACKGROUND,
119            left: 0,
120            top: 0,
121        }
122    }
123}
124
125/// GIF encoder state.
126pub struct GifEncoderState {
127    /// Encoder configuration.
128    config: GifEncoderConfig,
129    /// Canvas width.
130    width: u32,
131    /// Canvas height.
132    height: u32,
133    /// Output buffer.
134    output: Vec<u8>,
135    /// Global color palette.
136    palette: Vec<u8>,
137}
138
139impl GifEncoderState {
140    /// Create a new GIF encoder.
141    pub fn new(width: u32, height: u32, config: GifEncoderConfig) -> CodecResult<Self> {
142        if width == 0 || height == 0 || width > 65535 || height > 65535 {
143            return Err(CodecError::InvalidParameter(format!(
144                "Invalid dimensions: {}x{}",
145                width, height
146            )));
147        }
148
149        if !(2..=256).contains(&config.colors) {
150            return Err(CodecError::InvalidParameter(format!(
151                "Invalid color count: {}",
152                config.colors
153            )));
154        }
155
156        Ok(Self {
157            config,
158            width,
159            height,
160            output: Vec::new(),
161            palette: Vec::new(),
162        })
163    }
164
165    /// Encode frames to GIF.
166    ///
167    /// # Errors
168    ///
169    /// Returns error if encoding fails.
170    pub fn encode(
171        &mut self,
172        frames: &[VideoFrame],
173        frame_configs: &[GifFrameConfig],
174    ) -> CodecResult<Vec<u8>> {
175        if frames.is_empty() {
176            return Err(CodecError::InvalidParameter("No frames to encode".into()));
177        }
178
179        if frames.len() != frame_configs.len() {
180            return Err(CodecError::InvalidParameter(
181                "Frame count mismatch with configs".into(),
182            ));
183        }
184
185        self.output.clear();
186
187        // Generate global palette from all frames
188        self.palette = self.generate_global_palette(frames)?;
189
190        // Write header
191        self.write_header()?;
192
193        // Write logical screen descriptor
194        self.write_screen_descriptor()?;
195
196        // Write global color table
197        let palette = self.palette.clone();
198        self.write_color_table(&palette)?;
199
200        // Write Netscape extension for looping
201        if frames.len() > 1 {
202            self.write_netscape_extension()?;
203        }
204
205        // Write frames
206        for (frame, frame_config) in frames.iter().zip(frame_configs) {
207            self.write_frame(frame, frame_config)?;
208        }
209
210        // Write trailer
211        self.output.write_all(&[TRAILER])?;
212
213        Ok(self.output.clone())
214    }
215
216    /// Write GIF header.
217    fn write_header(&mut self) -> CodecResult<()> {
218        self.output.write_all(GIF89A_HEADER)?;
219        Ok(())
220    }
221
222    /// Write Logical Screen Descriptor.
223    fn write_screen_descriptor(&mut self) -> CodecResult<()> {
224        let width = self.width as u16;
225        let height = self.height as u16;
226
227        self.output.write_all(&width.to_le_bytes())?;
228        self.output.write_all(&height.to_le_bytes())?;
229
230        // Packed field
231        let global_color_table_flag = 1u8 << 7;
232        let color_resolution = 7u8 << 4; // 8 bits per color
233        let sort_flag = 0u8;
234        let size = Self::color_table_size_field(self.palette.len() / 3);
235        let packed = global_color_table_flag | color_resolution | sort_flag | size;
236
237        self.output.write_all(&[packed])?;
238        self.output.write_all(&[0])?; // Background color index
239        self.output.write_all(&[0])?; // Pixel aspect ratio
240
241        Ok(())
242    }
243
244    /// Write color table.
245    fn write_color_table(&mut self, table: &[u8]) -> CodecResult<()> {
246        let size = Self::next_power_of_two(table.len() / 3) * 3;
247        self.output.write_all(table)?;
248
249        // Pad to power of 2
250        if table.len() < size {
251            self.output
252                .resize(self.output.len() + (size - table.len()), 0);
253        }
254
255        Ok(())
256    }
257
258    /// Write Netscape extension for animation looping.
259    fn write_netscape_extension(&mut self) -> CodecResult<()> {
260        self.output.write_all(&[EXTENSION_INTRODUCER])?;
261        self.output.write_all(&[APPLICATION_LABEL])?;
262        self.output.write_all(&[11])?; // Block size
263        self.output.write_all(b"NETSCAPE2.0")?;
264        self.output.write_all(&[3])?; // Sub-block size
265        self.output.write_all(&[1])?; // Loop sub-block ID
266        self.output
267            .write_all(&self.config.loop_count.to_le_bytes())?;
268        self.output.write_all(&[0])?; // Block terminator
269
270        Ok(())
271    }
272
273    /// Write a single frame.
274    fn write_frame(&mut self, frame: &VideoFrame, config: &GifFrameConfig) -> CodecResult<()> {
275        // Write Graphics Control Extension
276        self.write_graphics_control_extension(config)?;
277
278        // Convert frame to indexed colors
279        let rgba_data = self.frame_to_rgba(frame)?;
280        let indices = self.quantize_frame(&rgba_data)?;
281
282        // Write Image Descriptor
283        self.write_image_descriptor(config)?;
284
285        // Compress and write image data
286        self.write_image_data(&indices)?;
287
288        Ok(())
289    }
290
291    /// Write Graphics Control Extension.
292    fn write_graphics_control_extension(&mut self, config: &GifFrameConfig) -> CodecResult<()> {
293        self.output.write_all(&[EXTENSION_INTRODUCER])?;
294        self.output.write_all(&[GRAPHICS_CONTROL_LABEL])?;
295        self.output.write_all(&[4])?; // Block size
296
297        // Packed field
298        let disposal_method = (config.disposal_method & 0x07) << 2;
299        let user_input_flag = 0u8;
300        let transparency_flag = if self.config.transparent_index.is_some() {
301            1u8
302        } else {
303            0u8
304        };
305        let packed = disposal_method | user_input_flag | transparency_flag;
306
307        self.output.write_all(&[packed])?;
308        self.output.write_all(&config.delay_time.to_le_bytes())?;
309        self.output
310            .write_all(&[self.config.transparent_index.unwrap_or(0)])?;
311        self.output.write_all(&[0])?; // Block terminator
312
313        Ok(())
314    }
315
316    /// Write Image Descriptor.
317    fn write_image_descriptor(&mut self, config: &GifFrameConfig) -> CodecResult<()> {
318        self.output.write_all(&[IMAGE_SEPARATOR])?;
319        self.output.write_all(&config.left.to_le_bytes())?;
320        self.output.write_all(&config.top.to_le_bytes())?;
321        self.output.write_all(&(self.width as u16).to_le_bytes())?;
322        self.output.write_all(&(self.height as u16).to_le_bytes())?;
323
324        // Packed field (no local color table, no interlace)
325        self.output.write_all(&[0])?;
326
327        Ok(())
328    }
329
330    /// Write compressed image data.
331    fn write_image_data(&mut self, indices: &[u8]) -> CodecResult<()> {
332        // Calculate LZW minimum code size
333        let color_bits = Self::bits_needed(self.palette.len() / 3);
334        let min_code_size = color_bits.max(2);
335
336        self.output.write_all(&[min_code_size])?;
337
338        // Compress with LZW
339        let mut encoder = LzwEncoder::new(min_code_size)?;
340        let compressed = encoder.compress(indices)?;
341
342        // Write data in sub-blocks
343        let mut offset = 0;
344        while offset < compressed.len() {
345            let block_size = (compressed.len() - offset).min(255);
346            self.output.write_all(&[block_size as u8])?;
347            self.output
348                .write_all(&compressed[offset..offset + block_size])?;
349            offset += block_size;
350        }
351
352        // Block terminator
353        self.output.write_all(&[0])?;
354
355        Ok(())
356    }
357
358    /// Convert VideoFrame to RGBA data.
359    fn frame_to_rgba(&self, frame: &VideoFrame) -> CodecResult<Vec<u8>> {
360        if frame.width != self.width || frame.height != self.height {
361            return Err(CodecError::InvalidParameter(format!(
362                "Frame size {}x{} doesn't match canvas {}x{}",
363                frame.width, frame.height, self.width, self.height
364            )));
365        }
366
367        match frame.format {
368            PixelFormat::Rgba32 => {
369                if frame.planes.is_empty() {
370                    return Err(CodecError::InvalidData("Frame has no planes".into()));
371                }
372                Ok(frame.planes[0].data.to_vec())
373            }
374            PixelFormat::Rgb24 => {
375                if frame.planes.is_empty() {
376                    return Err(CodecError::InvalidData("Frame has no planes".into()));
377                }
378                let rgb = &frame.planes[0].data;
379                let mut rgba = Vec::with_capacity((self.width * self.height * 4) as usize);
380                for chunk in rgb.chunks_exact(3) {
381                    rgba.extend_from_slice(chunk);
382                    rgba.push(255);
383                }
384                Ok(rgba)
385            }
386            _ => Err(CodecError::UnsupportedFeature(format!(
387                "Pixel format {} not supported for GIF encoding",
388                frame.format
389            ))),
390        }
391    }
392
393    /// Generate global palette from all frames using quantization.
394    fn generate_global_palette(&self, frames: &[VideoFrame]) -> CodecResult<Vec<u8>> {
395        // Collect all unique colors from all frames
396        let mut all_colors = Vec::new();
397
398        for frame in frames {
399            let rgba = self.frame_to_rgba(frame)?;
400            for chunk in rgba.chunks_exact(4) {
401                let color = [chunk[0], chunk[1], chunk[2]];
402                all_colors.push(color);
403            }
404        }
405
406        // Apply quantization
407        let palette = match self.config.quantization {
408            QuantizationMethod::MedianCut => self.median_cut_quantize(&all_colors)?,
409            QuantizationMethod::Octree => self.octree_quantize(&all_colors)?,
410        };
411
412        Ok(palette)
413    }
414
415    /// Median cut color quantization.
416    fn median_cut_quantize(&self, colors: &[[u8; 3]]) -> CodecResult<Vec<u8>> {
417        let target_colors = self.config.colors.min(MAX_COLORS);
418
419        // Start with all colors in one bucket
420        let mut buckets = vec![colors.to_vec()];
421
422        // Split buckets until we have enough colors
423        while buckets.len() < target_colors {
424            // Find largest bucket
425            let largest_idx = buckets
426                .iter()
427                .enumerate()
428                .max_by_key(|(_, b)| b.len())
429                .map(|(i, _)| i)
430                .expect("buckets is non-empty inside the while loop");
431
432            let bucket = buckets.remove(largest_idx);
433            if bucket.is_empty() {
434                break;
435            }
436
437            // Find channel with largest range
438            let (mut min_r, mut max_r) = (255, 0);
439            let (mut min_g, mut max_g) = (255, 0);
440            let (mut min_b, mut max_b) = (255, 0);
441
442            for color in &bucket {
443                min_r = min_r.min(color[0]);
444                max_r = max_r.max(color[0]);
445                min_g = min_g.min(color[1]);
446                max_g = max_g.max(color[1]);
447                min_b = min_b.min(color[2]);
448                max_b = max_b.max(color[2]);
449            }
450
451            let range_r = max_r - min_r;
452            let range_g = max_g - min_g;
453            let range_b = max_b - min_b;
454
455            // Sort by channel with largest range
456            let mut bucket = bucket;
457            if range_r >= range_g && range_r >= range_b {
458                bucket.sort_by_key(|c| c[0]);
459            } else if range_g >= range_r && range_g >= range_b {
460                bucket.sort_by_key(|c| c[1]);
461            } else {
462                bucket.sort_by_key(|c| c[2]);
463            }
464
465            // Split at median
466            let mid = bucket.len() / 2;
467            let (left, right) = bucket.split_at(mid);
468            buckets.push(left.to_vec());
469            buckets.push(right.to_vec());
470        }
471
472        // Average colors in each bucket to get palette
473        let mut palette = Vec::with_capacity(target_colors * 3);
474        for bucket in buckets {
475            if bucket.is_empty() {
476                continue;
477            }
478
479            let mut sum_r = 0u32;
480            let mut sum_g = 0u32;
481            let mut sum_b = 0u32;
482
483            for color in &bucket {
484                sum_r += u32::from(color[0]);
485                sum_g += u32::from(color[1]);
486                sum_b += u32::from(color[2]);
487            }
488
489            let count = bucket.len() as u32;
490            palette.push((sum_r / count) as u8);
491            palette.push((sum_g / count) as u8);
492            palette.push((sum_b / count) as u8);
493        }
494
495        Ok(palette)
496    }
497
498    /// Octree color quantization.
499    fn octree_quantize(&self, colors: &[[u8; 3]]) -> CodecResult<Vec<u8>> {
500        let mut tree = OctreeQuantizer::new(self.config.colors);
501
502        for &color in colors {
503            tree.add_color(color);
504        }
505
506        let palette = tree.get_palette();
507        Ok(palette)
508    }
509
510    /// Quantize frame to palette indices.
511    fn quantize_frame(&self, rgba: &[u8]) -> CodecResult<Vec<u8>> {
512        let mut indices = Vec::with_capacity((self.width * self.height) as usize);
513
514        match self.config.dithering {
515            DitheringMethod::None => {
516                for chunk in rgba.chunks_exact(4) {
517                    let color = [chunk[0], chunk[1], chunk[2]];
518                    let index = self.find_closest_color(color);
519                    indices.push(index);
520                }
521            }
522            DitheringMethod::FloydSteinberg => {
523                indices = self.floyd_steinberg_dither(rgba)?;
524            }
525            DitheringMethod::Ordered => {
526                indices = self.ordered_dither(rgba)?;
527            }
528        }
529
530        Ok(indices)
531    }
532
533    /// Find closest color in palette.
534    fn find_closest_color(&self, color: [u8; 3]) -> u8 {
535        let mut best_index = 0;
536        let mut best_distance = u32::MAX;
537
538        for i in 0..(self.palette.len() / 3) {
539            let pal_r = self.palette[i * 3];
540            let pal_g = self.palette[i * 3 + 1];
541            let pal_b = self.palette[i * 3 + 2];
542
543            let dr = i32::from(color[0]) - i32::from(pal_r);
544            let dg = i32::from(color[1]) - i32::from(pal_g);
545            let db = i32::from(color[2]) - i32::from(pal_b);
546
547            let distance = (dr * dr + dg * dg + db * db) as u32;
548
549            if distance < best_distance {
550                best_distance = distance;
551                best_index = i;
552            }
553        }
554
555        best_index as u8
556    }
557
558    /// Floyd-Steinberg dithering.
559    #[allow(clippy::cast_possible_wrap)]
560    fn floyd_steinberg_dither(&self, rgba: &[u8]) -> CodecResult<Vec<u8>> {
561        let width = self.width as usize;
562        let height = self.height as usize;
563
564        // Create error buffer
565        let mut errors = vec![[0i16; 3]; width * height];
566        let mut indices = Vec::with_capacity(width * height);
567
568        for y in 0..height {
569            for x in 0..width {
570                let idx = (y * width + x) * 4;
571                let pixel_idx = y * width + x;
572
573                // Get original color with accumulated error
574                let r = (i16::from(rgba[idx]) + errors[pixel_idx][0]).clamp(0, 255) as u8;
575                let g = (i16::from(rgba[idx + 1]) + errors[pixel_idx][1]).clamp(0, 255) as u8;
576                let b = (i16::from(rgba[idx + 2]) + errors[pixel_idx][2]).clamp(0, 255) as u8;
577
578                // Find closest palette color
579                let color = [r, g, b];
580                let index = self.find_closest_color(color);
581                indices.push(index);
582
583                // Calculate error
584                let pal_r = self.palette[index as usize * 3];
585                let pal_g = self.palette[index as usize * 3 + 1];
586                let pal_b = self.palette[index as usize * 3 + 2];
587
588                let err_r = i16::from(r) - i16::from(pal_r);
589                let err_g = i16::from(g) - i16::from(pal_g);
590                let err_b = i16::from(b) - i16::from(pal_b);
591
592                // Distribute error to neighbors
593                if x + 1 < width {
594                    let next_idx = pixel_idx + 1;
595                    errors[next_idx][0] += err_r * 7 / 16;
596                    errors[next_idx][1] += err_g * 7 / 16;
597                    errors[next_idx][2] += err_b * 7 / 16;
598                }
599
600                if y + 1 < height {
601                    if x > 0 {
602                        let next_idx = pixel_idx + width - 1;
603                        errors[next_idx][0] += err_r * 3 / 16;
604                        errors[next_idx][1] += err_g * 3 / 16;
605                        errors[next_idx][2] += err_b * 3 / 16;
606                    }
607
608                    let next_idx = pixel_idx + width;
609                    errors[next_idx][0] += err_r * 5 / 16;
610                    errors[next_idx][1] += err_g * 5 / 16;
611                    errors[next_idx][2] += err_b * 5 / 16;
612
613                    if x + 1 < width {
614                        let next_idx = pixel_idx + width + 1;
615                        errors[next_idx][0] += err_r / 16;
616                        errors[next_idx][1] += err_g / 16;
617                        errors[next_idx][2] += err_b / 16;
618                    }
619                }
620            }
621        }
622
623        Ok(indices)
624    }
625
626    /// Ordered (Bayer) dithering.
627    fn ordered_dither(&self, rgba: &[u8]) -> CodecResult<Vec<u8>> {
628        // 4x4 Bayer matrix
629        #[rustfmt::skip]
630        const BAYER_MATRIX: [[i16; 4]; 4] = [
631            [  0,  8,  2, 10 ],
632            [ 12,  4, 14,  6 ],
633            [  3, 11,  1,  9 ],
634            [ 15,  7, 13,  5 ],
635        ];
636
637        let width = self.width as usize;
638        let height = self.height as usize;
639        let mut indices = Vec::with_capacity(width * height);
640
641        for y in 0..height {
642            for x in 0..width {
643                let idx = (y * width + x) * 4;
644
645                // Apply Bayer matrix threshold
646                let threshold = BAYER_MATRIX[y % 4][x % 4] * 16 - 128;
647
648                let r = (i16::from(rgba[idx]) + threshold).clamp(0, 255) as u8;
649                let g = (i16::from(rgba[idx + 1]) + threshold).clamp(0, 255) as u8;
650                let b = (i16::from(rgba[idx + 2]) + threshold).clamp(0, 255) as u8;
651
652                let color = [r, g, b];
653                let index = self.find_closest_color(color);
654                indices.push(index);
655            }
656        }
657
658        Ok(indices)
659    }
660
661    /// Calculate size field for color table.
662    fn color_table_size_field(colors: usize) -> u8 {
663        let size = Self::next_power_of_two(colors);
664        let bits = Self::bits_needed(size);
665        bits.saturating_sub(1)
666    }
667
668    /// Calculate next power of two.
669    fn next_power_of_two(n: usize) -> usize {
670        let mut power = 2;
671        while power < n {
672            power *= 2;
673        }
674        power
675    }
676
677    /// Calculate bits needed to represent n values.
678    fn bits_needed(n: usize) -> u8 {
679        if n <= 2 {
680            1
681        } else {
682            (n as f64).log2().ceil() as u8
683        }
684    }
685}
686
/// Octree node for color quantization.
///
/// A node is either internal (it routes colors to `children` by one bit of
/// each channel) or a leaf (it accumulates every color that reaches it).
struct OctreeNode {
    children: [Option<Box<OctreeNode>>; 8],
    /// Per-channel sum of the accumulated colors (`u64` so it cannot
    /// realistically overflow).
    color_sum: [u64; 3],
    /// Number of colors accumulated in `color_sum`.
    pixel_count: u64,
    is_leaf: bool,
}

impl OctreeNode {
    fn new() -> Self {
        Self {
            children: Default::default(),
            color_sum: [0, 0, 0],
            pixel_count: 0,
            is_leaf: false,
        }
    }
}

/// Octree quantizer for color reduction.
///
/// Colors are inserted down to depth 8 (one level per bit). Whenever the
/// number of leaves exceeds `max_colors`, the deepest internal node is
/// folded into a leaf, so the palette never holds more than `max_colors`
/// entries (a budget of 0 is treated as 1).
struct OctreeQuantizer {
    root: OctreeNode,
    max_colors: usize,
    leaf_count: usize,
}

impl OctreeQuantizer {
    fn new(max_colors: usize) -> Self {
        Self {
            root: OctreeNode::new(),
            max_colors,
            leaf_count: 0,
        }
    }

    /// Insert one color, then shrink the tree back within the budget.
    fn add_color(&mut self, color: [u8; 3]) {
        Self::add_color_recursive(&mut self.root, color, 0, &mut self.leaf_count);

        let budget = self.max_colors.max(1);
        while self.leaf_count > budget {
            // Folding a node with a single child does not lower the leaf
            // count, hence the loop; it ends at the latest when the root
            // itself has become a leaf.
            if !Self::reduce(&mut self.root, &mut self.leaf_count) {
                break;
            }
        }
    }

    fn add_color_recursive(
        node: &mut OctreeNode,
        color: [u8; 3],
        depth: u8,
        leaf_count: &mut usize,
    ) {
        // Existing leaves (including folded ones above depth 8) absorb the
        // color; reaching depth 8 creates a new leaf.
        if depth >= 8 || node.is_leaf {
            node.color_sum[0] += u64::from(color[0]);
            node.color_sum[1] += u64::from(color[1]);
            node.color_sum[2] += u64::from(color[2]);
            node.pixel_count += 1;
            if !node.is_leaf {
                node.is_leaf = true;
                *leaf_count += 1;
            }
            return;
        }

        let index = Self::get_child_index(color, depth);
        let child = node.children[index].get_or_insert_with(|| Box::new(OctreeNode::new()));
        Self::add_color_recursive(child, color, depth + 1, leaf_count);
    }

    /// Child slot for `color` at `depth`: one bit per channel, most
    /// significant bit first, combined as `r g b`.
    fn get_child_index(color: [u8; 3], depth: u8) -> usize {
        let shift = 7 - depth;
        let r_bit = ((color[0] >> shift) & 1) as usize;
        let g_bit = ((color[1] >> shift) & 1) as usize;
        let b_bit = ((color[2] >> shift) & 1) as usize;
        (r_bit << 2) | (g_bit << 1) | b_bit
    }

    /// Fold one of the deepest internal nodes into a leaf.
    ///
    /// Returns `false` when the tree has no internal node left to fold.
    fn reduce(root: &mut OctreeNode, leaf_count: &mut usize) -> bool {
        match Self::deepest_internal_depth(root, 0) {
            Some(depth) => Self::merge_at_depth(root, 0, depth, leaf_count),
            None => false,
        }
    }

    /// Depth of the deepest non-leaf node with children at or below `node`.
    fn deepest_internal_depth(node: &OctreeNode, depth: u8) -> Option<u8> {
        if node.is_leaf {
            return None;
        }

        let mut has_child = false;
        let mut deepest = None;
        for child in node.children.iter().flatten() {
            has_child = true;
            deepest = deepest.max(Self::deepest_internal_depth(child, depth + 1));
        }

        if has_child {
            // No internal descendant means this node is the deepest one here.
            Some(deepest.unwrap_or(depth))
        } else {
            None
        }
    }

    /// Fold the first internal node found at `target_depth` into a leaf,
    /// moving the totals of its children into it.
    fn merge_at_depth(
        node: &mut OctreeNode,
        depth: u8,
        target_depth: u8,
        leaf_count: &mut usize,
    ) -> bool {
        if node.is_leaf {
            return false;
        }

        if depth == target_depth {
            let mut merged = 0usize;
            for slot in node.children.iter_mut() {
                if let Some(child) = slot.take() {
                    // Children of a deepest internal node are always leaves.
                    debug_assert!(child.is_leaf, "deepest internal node has a non-leaf child");
                    for ch in 0..3 {
                        node.color_sum[ch] += child.color_sum[ch];
                    }
                    node.pixel_count += child.pixel_count;
                    merged += 1;
                }
            }
            if merged == 0 {
                return false;
            }
            node.is_leaf = true;
            *leaf_count = *leaf_count + 1 - merged;
            return true;
        }

        for child in node.children.iter_mut().flatten() {
            if Self::merge_at_depth(child, depth + 1, target_depth, leaf_count) {
                return true;
            }
        }
        false
    }

    /// Palette as packed `R, G, B` bytes: the mean color of every leaf, in
    /// depth-first child order.
    fn get_palette(&self) -> Vec<u8> {
        let mut palette = Vec::with_capacity(self.leaf_count * 3);
        Self::collect_colors(&self.root, &mut palette);
        palette
    }

    fn collect_colors(node: &OctreeNode, palette: &mut Vec<u8>) {
        if node.is_leaf && node.pixel_count > 0 {
            let count = node.pixel_count;
            palette.extend(node.color_sum.iter().map(|&sum| (sum / count) as u8));
            return;
        }

        for child in node.children.iter().flatten() {
            Self::collect_colors(child, palette);
        }
    }
}
786
#[cfg(test)]
mod tests {
    use super::*;

    /// Size field for table sizes that are already powers of two.
    #[test]
    fn test_color_table_size_field() {
        let cases: [(usize, u8); 3] = [(2, 0), (4, 1), (256, 7)];
        for &(colors, expected) in cases.iter() {
            assert_eq!(GifEncoderState::color_table_size_field(colors), expected);
        }
    }

    /// Rounding up, including the minimum result of 2.
    #[test]
    fn test_next_power_of_two() {
        let cases: [(usize, usize); 3] = [(1, 2), (3, 4), (100, 128)];
        for &(n, expected) in cases.iter() {
            assert_eq!(GifEncoderState::next_power_of_two(n), expected);
        }
    }

    /// Bit counts for exact powers of two.
    #[test]
    fn test_bits_needed() {
        let cases: [(usize, u8); 3] = [(2, 1), (4, 2), (256, 8)];
        for &(n, expected) in cases.iter() {
            assert_eq!(GifEncoderState::bits_needed(n), expected);
        }
    }
}
811}