Skip to main content

oximedia_codec/gif/
encoder.rs

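//! GIF encoder implementation.
//!
//! Supports:
//! - Color quantization (median cut, octree algorithms) into one global palette
//! - Dithering (Floyd-Steinberg, ordered)
//! - Animation encoding, with NETSCAPE2.0 looping for multi-frame output
//! - Transparency via a configured palette index (pixel alpha is not used)
//! - Per-frame disposal methods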
9
10use super::lzw::LzwEncoder;
11use crate::error::{CodecError, CodecResult};
12use crate::frame::VideoFrame;
13use oximedia_core::PixelFormat;
14use std::collections::HashMap;
15use std::io::Write;
16
/// Maximum colors in GIF palette.
const MAX_COLORS: usize = 256;

/// GIF signature and version.
const GIF89A_HEADER: &[u8] = b"GIF89a";

/// Extension introducer.
const EXTENSION_INTRODUCER: u8 = 0x21;

/// Image separator.
const IMAGE_SEPARATOR: u8 = 0x2C;

/// Trailer.
const TRAILER: u8 = 0x3B;

/// Graphics Control Extension label.
const GRAPHICS_CONTROL_LABEL: u8 = 0xF9;

/// Application Extension label.
const APPLICATION_LABEL: u8 = 0xFF;

// Disposal methods for `GifFrameConfig::disposal_method`. They are `pub` because
// that field is public, and without them callers must hard-code magic numbers.
// `dead_code` stays allowed on the values this module does not use itself, in
// case the module is not re-exported from the crate.

/// Disposal method: No disposal specified.
#[allow(dead_code)]
pub const DISPOSAL_NONE: u8 = 0;

/// Disposal method: Keep frame.
#[allow(dead_code)]
pub const DISPOSAL_KEEP: u8 = 1;

/// Disposal method: Restore to background.
pub const DISPOSAL_BACKGROUND: u8 = 2;

/// Disposal method: Restore to previous.
#[allow(dead_code)]
pub const DISPOSAL_PREVIOUS: u8 = 3;
53
/// How quantization error is treated when pixels are mapped onto the palette.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DitheringMethod {
    /// Map every pixel straight to its nearest palette entry.
    None,
    /// Push each pixel's quantization error onto its not-yet-visited
    /// neighbours (Floyd-Steinberg error diffusion).
    FloydSteinberg,
    /// Offset pixels by a tiled 4x4 Bayer matrix before mapping them.
    Ordered,
}
64
/// Algorithm used to reduce the input colors to a palette.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QuantizationMethod {
    /// Recursively split the color set at the median of its widest channel.
    MedianCut,
    /// Bucket colors in an RGB octree and merge subtrees.
    Octree,
}
73
74/// GIF encoder configuration.
75#[derive(Debug, Clone)]
76pub struct GifEncoderConfig {
77    /// Number of colors in palette (2-256).
78    pub colors: usize,
79    /// Color quantization method.
80    pub quantization: QuantizationMethod,
81    /// Dithering method.
82    pub dithering: DitheringMethod,
83    /// Transparent color index (None = no transparency).
84    pub transparent_index: Option<u8>,
85    /// Loop count (0 = infinite, 1 = no loop).
86    pub loop_count: u16,
87}
88
89impl Default for GifEncoderConfig {
90    fn default() -> Self {
91        Self {
92            colors: 256,
93            quantization: QuantizationMethod::MedianCut,
94            dithering: DitheringMethod::None,
95            transparent_index: None,
96            loop_count: 0,
97        }
98    }
99}
100
/// Per-frame timing, disposal and placement settings.
#[derive(Debug, Clone)]
pub struct GifFrameConfig {
    /// How long the frame stays on screen, in hundredths of a second.
    pub delay_time: u16,
    /// Disposal method written into the frame's Graphics Control Extension.
    /// Use one of the `DISPOSAL_*` constants; only the low three bits are encoded.
    pub disposal_method: u8,
    /// Horizontal offset of the frame on the canvas, in pixels.
    pub left: u16,
    /// Vertical offset of the frame on the canvas, in pixels.
    pub top: u16,
}
113
114impl Default for GifFrameConfig {
115    fn default() -> Self {
116        Self {
117            delay_time: 10, // 100ms
118            disposal_method: DISPOSAL_BACKGROUND,
119            left: 0,
120            top: 0,
121        }
122    }
123}
124
125/// GIF encoder state.
126pub struct GifEncoderState {
127    /// Encoder configuration.
128    config: GifEncoderConfig,
129    /// Canvas width.
130    width: u32,
131    /// Canvas height.
132    height: u32,
133    /// Output buffer.
134    output: Vec<u8>,
135    /// Global color palette.
136    palette: Vec<u8>,
137}
138
139impl GifEncoderState {
140    /// Create a new GIF encoder.
141    pub fn new(width: u32, height: u32, config: GifEncoderConfig) -> CodecResult<Self> {
142        if width == 0 || height == 0 || width > 65535 || height > 65535 {
143            return Err(CodecError::InvalidParameter(format!(
144                "Invalid dimensions: {}x{}",
145                width, height
146            )));
147        }
148
149        if !(2..=256).contains(&config.colors) {
150            return Err(CodecError::InvalidParameter(format!(
151                "Invalid color count: {}",
152                config.colors
153            )));
154        }
155
156        Ok(Self {
157            config,
158            width,
159            height,
160            output: Vec::new(),
161            palette: Vec::new(),
162        })
163    }
164
165    /// Encode frames to GIF.
166    ///
167    /// # Errors
168    ///
169    /// Returns error if encoding fails.
170    pub fn encode(
171        &mut self,
172        frames: &[VideoFrame],
173        frame_configs: &[GifFrameConfig],
174    ) -> CodecResult<Vec<u8>> {
175        if frames.is_empty() {
176            return Err(CodecError::InvalidParameter("No frames to encode".into()));
177        }
178
179        if frames.len() != frame_configs.len() {
180            return Err(CodecError::InvalidParameter(
181                "Frame count mismatch with configs".into(),
182            ));
183        }
184
185        self.output.clear();
186
187        // Generate global palette from all frames
188        self.palette = self.generate_global_palette(frames)?;
189
190        // Write header
191        self.write_header()?;
192
193        // Write logical screen descriptor
194        self.write_screen_descriptor()?;
195
196        // Write global color table
197        let palette = self.palette.clone();
198        self.write_color_table(&palette)?;
199
200        // Write Netscape extension for looping
201        if frames.len() > 1 {
202            self.write_netscape_extension()?;
203        }
204
205        // Write frames
206        for (frame, frame_config) in frames.iter().zip(frame_configs) {
207            self.write_frame(frame, frame_config)?;
208        }
209
210        // Write trailer
211        self.output.write_all(&[TRAILER])?;
212
213        Ok(self.output.clone())
214    }
215
216    /// Write GIF header.
217    fn write_header(&mut self) -> CodecResult<()> {
218        self.output.write_all(GIF89A_HEADER)?;
219        Ok(())
220    }
221
222    /// Write Logical Screen Descriptor.
223    fn write_screen_descriptor(&mut self) -> CodecResult<()> {
224        let width = self.width as u16;
225        let height = self.height as u16;
226
227        self.output.write_all(&width.to_le_bytes())?;
228        self.output.write_all(&height.to_le_bytes())?;
229
230        // Packed field
231        let global_color_table_flag = 1u8 << 7;
232        let color_resolution = 7u8 << 4; // 8 bits per color
233        let sort_flag = 0u8;
234        let size = Self::color_table_size_field(self.palette.len() / 3);
235        let packed = global_color_table_flag | color_resolution | sort_flag | size;
236
237        self.output.write_all(&[packed])?;
238        self.output.write_all(&[0])?; // Background color index
239        self.output.write_all(&[0])?; // Pixel aspect ratio
240
241        Ok(())
242    }
243
244    /// Write color table.
245    fn write_color_table(&mut self, table: &[u8]) -> CodecResult<()> {
246        let size = Self::next_power_of_two(table.len() / 3) * 3;
247        self.output.write_all(table)?;
248
249        // Pad to power of 2
250        if table.len() < size {
251            self.output
252                .resize(self.output.len() + (size - table.len()), 0);
253        }
254
255        Ok(())
256    }
257
258    /// Write Netscape extension for animation looping.
259    fn write_netscape_extension(&mut self) -> CodecResult<()> {
260        self.output.write_all(&[EXTENSION_INTRODUCER])?;
261        self.output.write_all(&[APPLICATION_LABEL])?;
262        self.output.write_all(&[11])?; // Block size
263        self.output.write_all(b"NETSCAPE2.0")?;
264        self.output.write_all(&[3])?; // Sub-block size
265        self.output.write_all(&[1])?; // Loop sub-block ID
266        self.output
267            .write_all(&self.config.loop_count.to_le_bytes())?;
268        self.output.write_all(&[0])?; // Block terminator
269
270        Ok(())
271    }
272
273    /// Write a single frame.
274    fn write_frame(&mut self, frame: &VideoFrame, config: &GifFrameConfig) -> CodecResult<()> {
275        // Write Graphics Control Extension
276        self.write_graphics_control_extension(config)?;
277
278        // Convert frame to indexed colors
279        let rgba_data = self.frame_to_rgba(frame)?;
280        let indices = self.quantize_frame(&rgba_data)?;
281
282        // Write Image Descriptor
283        self.write_image_descriptor(config)?;
284
285        // Compress and write image data
286        self.write_image_data(&indices)?;
287
288        Ok(())
289    }
290
291    /// Write Graphics Control Extension.
292    fn write_graphics_control_extension(&mut self, config: &GifFrameConfig) -> CodecResult<()> {
293        self.output.write_all(&[EXTENSION_INTRODUCER])?;
294        self.output.write_all(&[GRAPHICS_CONTROL_LABEL])?;
295        self.output.write_all(&[4])?; // Block size
296
297        // Packed field
298        let disposal_method = (config.disposal_method & 0x07) << 2;
299        let user_input_flag = 0u8;
300        let transparency_flag = if self.config.transparent_index.is_some() {
301            1u8
302        } else {
303            0u8
304        };
305        let packed = disposal_method | user_input_flag | transparency_flag;
306
307        self.output.write_all(&[packed])?;
308        self.output.write_all(&config.delay_time.to_le_bytes())?;
309        self.output
310            .write_all(&[self.config.transparent_index.unwrap_or(0)])?;
311        self.output.write_all(&[0])?; // Block terminator
312
313        Ok(())
314    }
315
316    /// Write Image Descriptor.
317    fn write_image_descriptor(&mut self, config: &GifFrameConfig) -> CodecResult<()> {
318        self.output.write_all(&[IMAGE_SEPARATOR])?;
319        self.output.write_all(&config.left.to_le_bytes())?;
320        self.output.write_all(&config.top.to_le_bytes())?;
321        self.output.write_all(&(self.width as u16).to_le_bytes())?;
322        self.output.write_all(&(self.height as u16).to_le_bytes())?;
323
324        // Packed field (no local color table, no interlace)
325        self.output.write_all(&[0])?;
326
327        Ok(())
328    }
329
330    /// Write compressed image data.
331    fn write_image_data(&mut self, indices: &[u8]) -> CodecResult<()> {
332        // Calculate LZW minimum code size
333        let color_bits = Self::bits_needed(self.palette.len() / 3);
334        let min_code_size = color_bits.max(2);
335
336        self.output.write_all(&[min_code_size])?;
337
338        // Compress with LZW
339        let mut encoder = LzwEncoder::new(min_code_size)?;
340        let compressed = encoder.compress(indices)?;
341
342        // Write data in sub-blocks
343        let mut offset = 0;
344        while offset < compressed.len() {
345            let block_size = (compressed.len() - offset).min(255);
346            self.output.write_all(&[block_size as u8])?;
347            self.output
348                .write_all(&compressed[offset..offset + block_size])?;
349            offset += block_size;
350        }
351
352        // Block terminator
353        self.output.write_all(&[0])?;
354
355        Ok(())
356    }
357
358    /// Convert VideoFrame to RGBA data.
359    fn frame_to_rgba(&self, frame: &VideoFrame) -> CodecResult<Vec<u8>> {
360        if frame.width != self.width || frame.height != self.height {
361            return Err(CodecError::InvalidParameter(format!(
362                "Frame size {}x{} doesn't match canvas {}x{}",
363                frame.width, frame.height, self.width, self.height
364            )));
365        }
366
367        match frame.format {
368            PixelFormat::Rgba32 => {
369                if frame.planes.is_empty() {
370                    return Err(CodecError::InvalidData("Frame has no planes".into()));
371                }
372                Ok(frame.planes[0].data.to_vec())
373            }
374            PixelFormat::Rgb24 => {
375                if frame.planes.is_empty() {
376                    return Err(CodecError::InvalidData("Frame has no planes".into()));
377                }
378                let rgb = &frame.planes[0].data;
379                let mut rgba = Vec::with_capacity((self.width * self.height * 4) as usize);
380                for chunk in rgb.chunks_exact(3) {
381                    rgba.extend_from_slice(chunk);
382                    rgba.push(255);
383                }
384                Ok(rgba)
385            }
386            _ => Err(CodecError::UnsupportedFeature(format!(
387                "Pixel format {} not supported for GIF encoding",
388                frame.format
389            ))),
390        }
391    }
392
393    /// Generate global palette from all frames using quantization.
394    fn generate_global_palette(&self, frames: &[VideoFrame]) -> CodecResult<Vec<u8>> {
395        // Collect all unique colors from all frames
396        let mut all_colors = Vec::new();
397
398        for frame in frames {
399            let rgba = self.frame_to_rgba(frame)?;
400            for chunk in rgba.chunks_exact(4) {
401                let color = [chunk[0], chunk[1], chunk[2]];
402                all_colors.push(color);
403            }
404        }
405
406        // Apply quantization
407        let palette = match self.config.quantization {
408            QuantizationMethod::MedianCut => self.median_cut_quantize(&all_colors)?,
409            QuantizationMethod::Octree => self.octree_quantize(&all_colors)?,
410        };
411
412        Ok(palette)
413    }
414
415    /// Median cut color quantization.
416    fn median_cut_quantize(&self, colors: &[[u8; 3]]) -> CodecResult<Vec<u8>> {
417        let target_colors = self.config.colors.min(MAX_COLORS);
418
419        // Start with all colors in one bucket
420        let mut buckets = vec![colors.to_vec()];
421
422        // Split buckets until we have enough colors
423        while buckets.len() < target_colors {
424            // Find largest bucket (buckets is non-empty: guarded by while condition)
425            let Some(largest_idx) = buckets
426                .iter()
427                .enumerate()
428                .max_by_key(|(_, b)| b.len())
429                .map(|(i, _)| i)
430            else {
431                break;
432            };
433
434            let bucket = buckets.remove(largest_idx);
435            if bucket.is_empty() {
436                break;
437            }
438
439            // Find channel with largest range
440            let (mut min_r, mut max_r) = (255, 0);
441            let (mut min_g, mut max_g) = (255, 0);
442            let (mut min_b, mut max_b) = (255, 0);
443
444            for color in &bucket {
445                min_r = min_r.min(color[0]);
446                max_r = max_r.max(color[0]);
447                min_g = min_g.min(color[1]);
448                max_g = max_g.max(color[1]);
449                min_b = min_b.min(color[2]);
450                max_b = max_b.max(color[2]);
451            }
452
453            let range_r = max_r - min_r;
454            let range_g = max_g - min_g;
455            let range_b = max_b - min_b;
456
457            // Sort by channel with largest range
458            let mut bucket = bucket;
459            if range_r >= range_g && range_r >= range_b {
460                bucket.sort_by_key(|c| c[0]);
461            } else if range_g >= range_r && range_g >= range_b {
462                bucket.sort_by_key(|c| c[1]);
463            } else {
464                bucket.sort_by_key(|c| c[2]);
465            }
466
467            // Split at median
468            let mid = bucket.len() / 2;
469            let (left, right) = bucket.split_at(mid);
470            buckets.push(left.to_vec());
471            buckets.push(right.to_vec());
472        }
473
474        // Average colors in each bucket to get palette
475        let mut palette = Vec::with_capacity(target_colors * 3);
476        for bucket in buckets {
477            if bucket.is_empty() {
478                continue;
479            }
480
481            let mut sum_r = 0u32;
482            let mut sum_g = 0u32;
483            let mut sum_b = 0u32;
484
485            for color in &bucket {
486                sum_r += u32::from(color[0]);
487                sum_g += u32::from(color[1]);
488                sum_b += u32::from(color[2]);
489            }
490
491            let count = bucket.len() as u32;
492            palette.push((sum_r / count) as u8);
493            palette.push((sum_g / count) as u8);
494            palette.push((sum_b / count) as u8);
495        }
496
497        Ok(palette)
498    }
499
500    /// Octree color quantization.
501    fn octree_quantize(&self, colors: &[[u8; 3]]) -> CodecResult<Vec<u8>> {
502        let mut tree = OctreeQuantizer::new(self.config.colors);
503
504        for &color in colors {
505            tree.add_color(color);
506        }
507
508        let palette = tree.get_palette();
509        Ok(palette)
510    }
511
512    /// Quantize frame to palette indices.
513    fn quantize_frame(&self, rgba: &[u8]) -> CodecResult<Vec<u8>> {
514        let mut indices = Vec::with_capacity((self.width * self.height) as usize);
515
516        match self.config.dithering {
517            DitheringMethod::None => {
518                for chunk in rgba.chunks_exact(4) {
519                    let color = [chunk[0], chunk[1], chunk[2]];
520                    let index = self.find_closest_color(color);
521                    indices.push(index);
522                }
523            }
524            DitheringMethod::FloydSteinberg => {
525                indices = self.floyd_steinberg_dither(rgba)?;
526            }
527            DitheringMethod::Ordered => {
528                indices = self.ordered_dither(rgba)?;
529            }
530        }
531
532        Ok(indices)
533    }
534
535    /// Find closest color in palette.
536    fn find_closest_color(&self, color: [u8; 3]) -> u8 {
537        let mut best_index = 0;
538        let mut best_distance = u32::MAX;
539
540        for i in 0..(self.palette.len() / 3) {
541            let pal_r = self.palette[i * 3];
542            let pal_g = self.palette[i * 3 + 1];
543            let pal_b = self.palette[i * 3 + 2];
544
545            let dr = i32::from(color[0]) - i32::from(pal_r);
546            let dg = i32::from(color[1]) - i32::from(pal_g);
547            let db = i32::from(color[2]) - i32::from(pal_b);
548
549            let distance = (dr * dr + dg * dg + db * db) as u32;
550
551            if distance < best_distance {
552                best_distance = distance;
553                best_index = i;
554            }
555        }
556
557        best_index as u8
558    }
559
560    /// Floyd-Steinberg dithering.
561    #[allow(clippy::cast_possible_wrap)]
562    fn floyd_steinberg_dither(&self, rgba: &[u8]) -> CodecResult<Vec<u8>> {
563        let width = self.width as usize;
564        let height = self.height as usize;
565
566        // Create error buffer
567        let mut errors = vec![[0i16; 3]; width * height];
568        let mut indices = Vec::with_capacity(width * height);
569
570        for y in 0..height {
571            for x in 0..width {
572                let idx = (y * width + x) * 4;
573                let pixel_idx = y * width + x;
574
575                // Get original color with accumulated error
576                let r = (i16::from(rgba[idx]) + errors[pixel_idx][0]).clamp(0, 255) as u8;
577                let g = (i16::from(rgba[idx + 1]) + errors[pixel_idx][1]).clamp(0, 255) as u8;
578                let b = (i16::from(rgba[idx + 2]) + errors[pixel_idx][2]).clamp(0, 255) as u8;
579
580                // Find closest palette color
581                let color = [r, g, b];
582                let index = self.find_closest_color(color);
583                indices.push(index);
584
585                // Calculate error
586                let pal_r = self.palette[index as usize * 3];
587                let pal_g = self.palette[index as usize * 3 + 1];
588                let pal_b = self.palette[index as usize * 3 + 2];
589
590                let err_r = i16::from(r) - i16::from(pal_r);
591                let err_g = i16::from(g) - i16::from(pal_g);
592                let err_b = i16::from(b) - i16::from(pal_b);
593
594                // Distribute error to neighbors
595                if x + 1 < width {
596                    let next_idx = pixel_idx + 1;
597                    errors[next_idx][0] += err_r * 7 / 16;
598                    errors[next_idx][1] += err_g * 7 / 16;
599                    errors[next_idx][2] += err_b * 7 / 16;
600                }
601
602                if y + 1 < height {
603                    if x > 0 {
604                        let next_idx = pixel_idx + width - 1;
605                        errors[next_idx][0] += err_r * 3 / 16;
606                        errors[next_idx][1] += err_g * 3 / 16;
607                        errors[next_idx][2] += err_b * 3 / 16;
608                    }
609
610                    let next_idx = pixel_idx + width;
611                    errors[next_idx][0] += err_r * 5 / 16;
612                    errors[next_idx][1] += err_g * 5 / 16;
613                    errors[next_idx][2] += err_b * 5 / 16;
614
615                    if x + 1 < width {
616                        let next_idx = pixel_idx + width + 1;
617                        errors[next_idx][0] += err_r / 16;
618                        errors[next_idx][1] += err_g / 16;
619                        errors[next_idx][2] += err_b / 16;
620                    }
621                }
622            }
623        }
624
625        Ok(indices)
626    }
627
628    /// Ordered (Bayer) dithering.
629    fn ordered_dither(&self, rgba: &[u8]) -> CodecResult<Vec<u8>> {
630        // 4x4 Bayer matrix
631        #[rustfmt::skip]
632        const BAYER_MATRIX: [[i16; 4]; 4] = [
633            [  0,  8,  2, 10 ],
634            [ 12,  4, 14,  6 ],
635            [  3, 11,  1,  9 ],
636            [ 15,  7, 13,  5 ],
637        ];
638
639        let width = self.width as usize;
640        let height = self.height as usize;
641        let mut indices = Vec::with_capacity(width * height);
642
643        for y in 0..height {
644            for x in 0..width {
645                let idx = (y * width + x) * 4;
646
647                // Apply Bayer matrix threshold
648                let threshold = BAYER_MATRIX[y % 4][x % 4] * 16 - 128;
649
650                let r = (i16::from(rgba[idx]) + threshold).clamp(0, 255) as u8;
651                let g = (i16::from(rgba[idx + 1]) + threshold).clamp(0, 255) as u8;
652                let b = (i16::from(rgba[idx + 2]) + threshold).clamp(0, 255) as u8;
653
654                let color = [r, g, b];
655                let index = self.find_closest_color(color);
656                indices.push(index);
657            }
658        }
659
660        Ok(indices)
661    }
662
663    /// Calculate size field for color table.
664    fn color_table_size_field(colors: usize) -> u8 {
665        let size = Self::next_power_of_two(colors);
666        let bits = Self::bits_needed(size);
667        bits.saturating_sub(1)
668    }
669
670    /// Calculate next power of two.
671    fn next_power_of_two(n: usize) -> usize {
672        let mut power = 2;
673        while power < n {
674            power *= 2;
675        }
676        power
677    }
678
679    /// Calculate bits needed to represent n values.
680    fn bits_needed(n: usize) -> u8 {
681        if n <= 2 {
682            1
683        } else {
684            (n as f64).log2().ceil() as u8
685        }
686    }
687}
688
/// Octree node for color quantization.
///
/// Interior nodes route a color by one bit of each channel per level; leaves
/// accumulate the colors that reach them.
struct OctreeNode {
    /// Child slots, indexed by `(r_bit << 2) | (g_bit << 1) | b_bit` at this level.
    children: [Option<Box<OctreeNode>>; 8],
    /// Per-channel sum of the colors accumulated here. Stored as `u64` because a
    /// merged leaf may absorb every pixel of every frame, which overflows `u32`.
    color_sum: [u64; 3],
    /// Number of colors accumulated here.
    pixel_count: u64,
    /// True for depth-8 nodes and for nodes whose subtree has been merged.
    is_leaf: bool,
}

impl OctreeNode {
    fn new() -> Self {
        Self {
            children: Default::default(),
            color_sum: [0, 0, 0],
            pixel_count: 0,
            is_leaf: false,
        }
    }
}

/// Octree quantizer for color reduction.
///
/// Colors are inserted at full 8-bit precision. `get_palette` then merges
/// subtrees until at most `max_colors` leaves remain. Merging goes deepest
/// level first and, within a level, least-populated first. This guarantees the
/// palette never exceeds the requested size (previously the tree was never
/// reduced and could emit more than 256 entries).
struct OctreeQuantizer {
    root: OctreeNode,
    /// Maximum number of palette entries to produce (at least one is always kept).
    max_colors: usize,
    /// Number of leaves currently in the tree.
    leaf_count: usize,
}

impl OctreeQuantizer {
    fn new(max_colors: usize) -> Self {
        Self {
            root: OctreeNode::new(),
            max_colors,
            leaf_count: 0,
        }
    }

    fn add_color(&mut self, color: [u8; 3]) {
        Self::add_color_recursive(&mut self.root, color, 0, &mut self.leaf_count);
    }

    /// Descends one level per call until it reaches depth 8 or an existing leaf,
    /// where the color is accumulated.
    fn add_color_recursive(
        node: &mut OctreeNode,
        color: [u8; 3],
        depth: u8,
        leaf_count: &mut usize,
    ) {
        if depth >= 8 || node.is_leaf {
            for (sum, &channel) in node.color_sum.iter_mut().zip(color.iter()) {
                *sum += u64::from(channel);
            }
            node.pixel_count += 1;
            if !node.is_leaf {
                node.is_leaf = true;
                *leaf_count += 1;
            }
            return;
        }

        let index = Self::get_child_index(color, depth);
        let child = node.children[index].get_or_insert_with(|| Box::new(OctreeNode::new()));
        Self::add_color_recursive(child, color, depth + 1, leaf_count);
    }

    fn get_child_index(color: [u8; 3], depth: u8) -> usize {
        let shift = 7 - depth;
        let r_bit = ((color[0] >> shift) & 1) as usize;
        let g_bit = ((color[1] >> shift) & 1) as usize;
        let b_bit = ((color[2] >> shift) & 1) as usize;
        (r_bit << 2) | (g_bit << 1) | b_bit
    }

    /// Reduces the tree to at most `max(max_colors, 1)` leaves and returns their
    /// mean colors as flat RGB triples, in depth-first child-index order.
    fn get_palette(&mut self) -> Vec<u8> {
        self.reduce();
        let mut palette = Vec::with_capacity(self.leaf_count * 3);
        Self::collect_colors(&self.root, &mut palette);
        palette
    }

    /// Merges subtrees until `leaf_count <= max(max_colors, 1)`.
    ///
    /// Levels are processed from the deepest interior level upward, so the
    /// finest distinctions go first. Within a level, the least-populated subtrees
    /// are merged first so heavily used regions keep their detail. Merging the
    /// root leaves a single leaf, so the limit is always reached.
    fn reduce(&mut self) {
        let limit = self.max_colors.max(1);
        for depth in (0..8u8).rev() {
            if self.leaf_count <= limit {
                break;
            }
            let mut candidates = Vec::new();
            Self::collect_interior_at(&mut self.root, 0, depth, &mut candidates);
            candidates.sort_by_cached_key(|node| Self::subtree_pixels(node));
            for node in candidates {
                if self.leaf_count <= limit {
                    break;
                }
                Self::merge_subtree(node, &mut self.leaf_count);
            }
        }
    }

    /// Pushes every interior node at depth `target` beneath `node` (at `depth`) into `out`.
    fn collect_interior_at<'a>(
        node: &'a mut OctreeNode,
        depth: u8,
        target: u8,
        out: &mut Vec<&'a mut OctreeNode>,
    ) {
        if node.is_leaf {
            return;
        }
        if depth == target {
            out.push(node);
            return;
        }
        for child in node.children.iter_mut().flatten() {
            Self::collect_interior_at(child, depth + 1, target, out);
        }
    }

    /// Total pixels accumulated in `node` and all of its descendants.
    fn subtree_pixels(node: &OctreeNode) -> u64 {
        node.pixel_count
            + node
                .children
                .iter()
                .flatten()
                .map(|child| Self::subtree_pixels(child))
                .sum::<u64>()
    }

    /// Folds every descendant of `node` into `node`, turning it into a leaf and
    /// keeping `leaf_count` in step.
    fn merge_subtree(node: &mut OctreeNode, leaf_count: &mut usize) {
        for slot in &mut node.children {
            if let Some(mut child) = slot.take() {
                if !child.is_leaf {
                    Self::merge_subtree(&mut child, leaf_count);
                }
                for (sum, part) in node.color_sum.iter_mut().zip(child.color_sum.iter()) {
                    *sum += part;
                }
                node.pixel_count += child.pixel_count;
                // The child leaf has been absorbed into this node.
                *leaf_count -= 1;
            }
        }
        if !node.is_leaf {
            node.is_leaf = true;
            *leaf_count += 1;
        }
    }

    fn collect_colors(node: &OctreeNode, palette: &mut Vec<u8>) {
        if node.is_leaf && node.pixel_count > 0 {
            let [r, g, b] = node.color_sum.map(|sum| (sum / node.pixel_count) as u8);
            palette.extend_from_slice(&[r, g, b]);
            return;
        }

        for child in node.children.iter().flatten() {
            Self::collect_colors(child, palette);
        }
    }
}
788
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_color_table_size_field() {
        let cases: [(usize, u8); 3] = [(2, 0), (4, 1), (256, 7)];
        for (colors, expected) in cases {
            assert_eq!(GifEncoderState::color_table_size_field(colors), expected);
        }
    }

    #[test]
    fn test_next_power_of_two() {
        let cases: [(usize, usize); 3] = [(1, 2), (3, 4), (100, 128)];
        for (n, expected) in cases {
            assert_eq!(GifEncoderState::next_power_of_two(n), expected);
        }
    }

    #[test]
    fn test_bits_needed() {
        let cases: [(usize, u8); 3] = [(2, 1), (4, 2), (256, 8)];
        for (n, expected) in cases {
            assert_eq!(GifEncoderState::bits_needed(n), expected);
        }
    }
}