webp-rust 0.2.1

Pure Rust implementation of a WebP encoder and decoder
Documentation
use crate::encoder::bit_writer::BitWriter;
use crate::encoder::EncoderError;

/// Longest Huffman code length, in bits, that this module will build or accept.
const MAX_ALLOWED_CODE_LENGTH: usize = 15;

/// Node of the temporary tree used while deriving code lengths from a histogram.
///
/// Leaves carry a symbol in `value` and have `left == right == -1`; internal nodes
/// carry `value == -1` and index their children in the separate node pool.
#[derive(Clone, Copy, Debug)]
struct HuffmanTreeNode {
    /// Weight of this subtree (possibly clamped up while depth-limiting).
    total_count: u32,
    /// Symbol index for leaves, `-1` for internal nodes.
    value: isize,
    /// Pool index of the left child, `-1` for leaves.
    left: isize,
    /// Pool index of the right child, `-1` for leaves.
    right: isize,
}

/// One run-length token produced when compressing a list of code lengths.
///
/// `code` is either a literal code length (`0..=15`) or one of the repeat codes
/// `16`, `17`, `18`; `extra_bits` carries the repeat-count payload for those
/// repeat codes and is `0` for literals.
#[derive(Clone, Copy, Debug)]
pub(crate) struct HuffmanTreeToken {
    /// Literal length or repeat code.
    pub(crate) code: u8,
    /// Repeat-count payload accompanying a repeat code.
    pub(crate) extra_bits: u8,
}

/// A canonical Huffman code ready for writing symbols.
#[derive(Clone, Debug)]
pub(crate) struct HuffmanCode {
    /// Code length in bits for every symbol; `0` means the symbol is unused.
    code_lengths: Vec<u8>,
    /// Bit-reversed canonical code for every symbol (`0` for unused symbols).
    codes: Vec<u16>,
    /// The only used symbol when the code has exactly one; such a code emits no bits.
    single_symbol: Option<usize>,
}

impl HuffmanCode {
    /// Builds a canonical Huffman code from one code length per symbol.
    ///
    /// A length of `0` marks an unused symbol. When exactly one symbol is used, the
    /// code becomes a single-symbol code that writes no bits (see
    /// [`HuffmanCode::write_symbol`]). Otherwise, the lengths must fill the code space
    /// exactly (Kraft sum of one).
    ///
    /// # Errors
    ///
    /// Returns [`EncoderError::Bitstream`] when:
    /// - no symbol is used;
    /// - any length exceeds `MAX_ALLOWED_CODE_LENGTH`;
    /// - the lengths over- or under-subscribe the code space.
    pub(crate) fn from_code_lengths(code_lengths: Vec<u8>) -> Result<Self, EncoderError> {
        let (first_symbol, is_single) = {
            let mut used = code_lengths
                .iter()
                .enumerate()
                .filter(|&(_, &len)| len != 0)
                .map(|(symbol, _)| symbol);
            let first = used
                .next()
                .ok_or(EncoderError::Bitstream("empty Huffman tree"))?;
            (first, used.next().is_none())
        };

        // length_counts[n] = number of symbols whose code is n bits long.
        let mut length_counts = [0u32; MAX_ALLOWED_CODE_LENGTH + 1];
        for &len in &code_lengths {
            match usize::from(len) {
                0 => {}
                bits if bits <= MAX_ALLOWED_CODE_LENGTH => length_counts[bits] += 1,
                _ => return Err(EncoderError::Bitstream("invalid Huffman code length")),
            }
        }

        let single_symbol = if is_single { Some(first_symbol) } else { None };

        if single_symbol.is_none() {
            // `available` counts unassigned codes at each depth. Going negative means
            // oversubscribed; ending above zero means incomplete.
            let mut available = 1i32;
            for &count in &length_counts[1..] {
                available = available * 2 - count as i32;
                if available < 0 {
                    return Err(EncoderError::Bitstream("oversubscribed Huffman tree"));
                }
            }
            if available != 0 {
                return Err(EncoderError::Bitstream("incomplete Huffman tree"));
            }
        }

        // First canonical code for each length, as in the DEFLATE construction.
        let mut next_code = [0u32; MAX_ALLOWED_CODE_LENGTH + 1];
        let mut start = 0u32;
        for (slot, &shorter) in next_code.iter_mut().skip(1).zip(length_counts.iter()) {
            start = (start + shorter) << 1;
            *slot = start;
        }

        // Assign codes in symbol order and store them bit-reversed (presumably
        // because the bit writer emits least-significant bit first — confirm
        // against `BitWriter`).
        let codes = code_lengths
            .iter()
            .map(|&len| {
                let bits = usize::from(len);
                if bits == 0 {
                    return 0;
                }
                let canonical = next_code[bits];
                next_code[bits] += 1;
                reverse_bits(canonical, bits)
            })
            .collect::<Vec<u16>>();

        Ok(Self {
            code_lengths,
            codes,
            single_symbol,
        })
    }

    /// Derives depth-limited code lengths from `histogram` and builds the code.
    ///
    /// # Errors
    ///
    /// Propagates errors from length generation and from [`HuffmanCode::from_code_lengths`].
    pub(crate) fn from_histogram(
        histogram: &[u32],
        tree_depth_limit: usize,
    ) -> Result<Self, EncoderError> {
        generate_code_lengths(histogram, tree_depth_limit).and_then(Self::from_code_lengths)
    }

    /// Per-symbol code lengths (`0` = unused).
    pub(crate) fn code_lengths(&self) -> &[u8] {
        self.code_lengths.as_slice()
    }

    /// Indices of all symbols with a non-zero code length, in ascending order.
    pub(crate) fn used_symbols(&self) -> Vec<usize> {
        (0..self.code_lengths.len())
            .filter(|&symbol| self.code_lengths[symbol] != 0)
            .collect()
    }

    /// Writes the code for `symbol` to `bw`.
    ///
    /// A single-symbol code writes nothing for its one symbol.
    ///
    /// # Errors
    ///
    /// - [`EncoderError::Bitstream`] if a single-symbol code is asked for a different
    ///   symbol, or if `symbol` is unused.
    /// - [`EncoderError::InvalidParam`] if `symbol` is outside the alphabet.
    /// - Whatever `BitWriter::put_bits` returns.
    pub(crate) fn write_symbol(
        &self,
        bw: &mut BitWriter,
        symbol: usize,
    ) -> Result<(), EncoderError> {
        match self.single_symbol {
            Some(only) if only == symbol => return Ok(()),
            Some(_) => {
                return Err(EncoderError::Bitstream(
                    "attempted to write unexpected single-symbol Huffman code",
                ))
            }
            None => {}
        }

        let depth = match self.code_lengths.get(symbol).copied() {
            None => return Err(EncoderError::InvalidParam("Huffman symbol is out of range")),
            Some(0) => {
                return Err(EncoderError::Bitstream(
                    "attempted to write unused Huffman symbol",
                ))
            }
            Some(len) => usize::from(len),
        };
        bw.put_bits(u32::from(self.codes[symbol]), depth)
    }
}

/// Run-length encodes a list of code lengths into [`HuffmanTreeToken`]s.
///
/// Each maximal run of equal lengths is passed on by value:
/// - runs of zeros go to `code_repeated_zeros`;
/// - runs of a non-zero length go to `code_repeated_values`.
///
/// The "previous non-zero length" handed to `code_repeated_values` starts at 8. It
/// is updated after every non-zero run; zero runs leave it unchanged.
pub(crate) fn compress_huffman_tree(code_lengths: &[u8]) -> Vec<HuffmanTreeToken> {
    let mut tokens = Vec::with_capacity(code_lengths.len());
    let mut last_nonzero = 8u8;
    let mut remaining = code_lengths;

    while let Some(&value) = remaining.first() {
        let run = remaining.iter().take_while(|&&len| len == value).count();
        match value {
            0 => code_repeated_zeros(run, &mut tokens),
            _ => {
                code_repeated_values(run, value, last_nonzero, &mut tokens);
                last_nonzero = value;
            }
        }
        remaining = &remaining[run..];
    }

    tokens
}

/// Emits tokens encoding `repetitions` copies of the non-zero code length `value`.
///
/// `prev_value` is the most recently emitted non-zero length. Code `16` repeats that
/// previous length 3..=6 times (`extra_bits` = count - 3). If `value` differs from
/// `prev_value`, one literal `value` token is emitted first so that following `16`s
/// repeat it. Runs longer than 6 are split into maximal `16` tokens; leftovers of 1-2
/// are emitted as literals.
///
/// An empty run (`repetitions == 0`) emits nothing.
fn code_repeated_values(
    mut repetitions: usize,
    value: u8,
    prev_value: u8,
    tokens: &mut Vec<HuffmanTreeToken>,
) {
    // Nothing to encode. Returning here also keeps the decrement below from
    // underflowing when `value != prev_value`.
    if repetitions == 0 {
        return;
    }

    if value != prev_value {
        // Code 16 only repeats the previous length, so the new length goes out
        // literally once.
        tokens.push(HuffmanTreeToken {
            code: value,
            extra_bits: 0,
        });
        repetitions -= 1;
    }

    while repetitions >= 1 {
        if repetitions < 3 {
            for _ in 0..repetitions {
                tokens.push(HuffmanTreeToken {
                    code: value,
                    extra_bits: 0,
                });
            }
            break;
        } else if repetitions < 7 {
            tokens.push(HuffmanTreeToken {
                code: 16,
                extra_bits: (repetitions - 3) as u8,
            });
            break;
        } else {
            // Maximal code-16 token: 3 + 3 = 6 repeats.
            tokens.push(HuffmanTreeToken {
                code: 16,
                extra_bits: 3,
            });
            repetitions -= 6;
        }
    }
}

/// Emits tokens encoding a run of `repetitions` zero code lengths.
///
/// How a run is encoded depends on its length:
/// - 1-2 zeros: literal `0` tokens;
/// - 3..=10: a single code `17` token (`extra_bits` = run - 3);
/// - 11..=138: a single code `18` token (`extra_bits` = run - 11).
///
/// Longer runs are split into maximal 138-zero `18` tokens, followed by the
/// remainder. An empty run emits nothing.
fn code_repeated_zeros(mut repetitions: usize, tokens: &mut Vec<HuffmanTreeToken>) {
    // Longest run a single code-18 token can carry: 11 + 0x7f.
    const LONGEST_RUN: usize = 138;

    while repetitions > LONGEST_RUN {
        tokens.push(HuffmanTreeToken {
            code: 18,
            extra_bits: 0x7f,
        });
        repetitions -= LONGEST_RUN;
    }

    match repetitions {
        0 => {}
        1 | 2 => tokens.extend(
            std::iter::repeat(HuffmanTreeToken {
                code: 0,
                extra_bits: 0,
            })
            .take(repetitions),
        ),
        3..=10 => tokens.push(HuffmanTreeToken {
            code: 17,
            extra_bits: (repetitions - 3) as u8,
        }),
        // 11..=LONGEST_RUN, guaranteed by the loop above.
        _ => tokens.push(HuffmanTreeToken {
            code: 18,
            extra_bits: (repetitions - 11) as u8,
        }),
    }
}

/// Computes Huffman code lengths for `histogram`, none longer than `tree_depth_limit`.
///
/// Symbols with a zero count get length 0; a histogram with one used symbol yields
/// length 1 for it. The tree is built by repeatedly merging the two lightest nodes.
/// If the resulting depth exceeds the limit, every non-zero count is raised to at
/// least `count_min` (doubled each round) and the tree is rebuilt. This flattens the
/// tree until it fits.
///
/// # Errors
///
/// - [`EncoderError::Bitstream`] when:
///   - the histogram has no non-zero count;
///   - more than `2^(tree_depth_limit - 1)` symbols are used;
///   - a count computation overflows `u32`.
/// - [`EncoderError::InvalidParam`] if `tree_depth_limit` is zero.
fn generate_code_lengths(
    histogram: &[u32],
    tree_depth_limit: usize,
) -> Result<Vec<u8>, EncoderError> {
    let mut code_lengths = vec![0u8; histogram.len()];
    let tree_size_orig = histogram.iter().filter(|&&count| count != 0).count();
    if tree_size_orig == 0 {
        return Err(EncoderError::Bitstream("empty Huffman histogram"));
    }
    // Guard before computing `tree_depth_limit - 1`, which would underflow.
    if tree_depth_limit == 0 {
        return Err(EncoderError::InvalidParam(
            "Huffman tree depth limit must be at least 1",
        ));
    }
    let shift = tree_depth_limit - 1;
    // A limit at least as wide as `usize` can never be exceeded; avoid an overflowing shift.
    let max_symbols = if shift < usize::BITS as usize {
        1usize << shift
    } else {
        usize::MAX
    };
    if tree_size_orig > max_symbols {
        return Err(EncoderError::Bitstream("Huffman tree exceeds depth limit"));
    }

    let mut count_min = 1u32;
    loop {
        code_lengths.fill(0);
        // Leaves sorted by descending weight, ties broken by ascending symbol.
        // The two lightest nodes are therefore always at the end.
        let mut tree = histogram
            .iter()
            .enumerate()
            .filter_map(|(value, &count)| {
                (count != 0).then_some(HuffmanTreeNode {
                    total_count: count.max(count_min),
                    value: value as isize,
                    left: -1,
                    right: -1,
                })
            })
            .collect::<Vec<_>>();
        tree.sort_by(|a, b| {
            b.total_count
                .cmp(&a.total_count)
                .then_with(|| a.value.cmp(&b.value))
        });

        if tree.len() == 1 {
            code_lengths[tree[0].value as usize] = 1;
        } else {
            // Merged-away nodes live here; internal nodes reference children by index.
            let mut tree_pool = Vec::with_capacity(tree.len() * 2);
            while tree.len() > 1 {
                let len = tree.len();
                let (lightest, next_lightest) = (tree[len - 1], tree[len - 2]);
                // Remove the merged nodes so `tree` shrinks by one per merge.
                // Leaving stale copies would grow it and inflate every insert.
                tree.truncate(len - 2);
                let count = lightest
                    .total_count
                    .checked_add(next_lightest.total_count)
                    .ok_or(EncoderError::Bitstream("Huffman count overflow"))?;
                tree_pool.push(lightest);
                tree_pool.push(next_lightest);

                // `tree` stays sorted by descending weight, so a binary search finds
                // the first node that is not heavier than the merged one.
                let insert_at = tree.partition_point(|node| node.total_count > count);
                tree.insert(
                    insert_at,
                    HuffmanTreeNode {
                        total_count: count,
                        value: -1,
                        left: (tree_pool.len() - 1) as isize,
                        right: (tree_pool.len() - 2) as isize,
                    },
                );
            }
            set_bit_depths(&tree[0], &tree_pool, &mut code_lengths, 0);
        }

        let max_depth = code_lengths.iter().copied().max().unwrap_or(0) as usize;
        if max_depth <= tree_depth_limit {
            return Ok(code_lengths);
        }

        count_min = count_min
            .checked_mul(2)
            .ok_or(EncoderError::Bitstream("Huffman count limit overflow"))?;
    }
}

/// Records the depth of every leaf under `node` into `bit_depths`.
///
/// `node` itself sits at depth `level`. Internal nodes (`left >= 0`) have their
/// children in `pool`. Leaves write their depth to `bit_depths[leaf.value]`.
/// An explicit stack replaces recursion.
fn set_bit_depths(
    node: &HuffmanTreeNode,
    pool: &[HuffmanTreeNode],
    bit_depths: &mut [u8],
    level: u8,
) {
    let mut pending = vec![(node, level)];
    while let Some((current, depth)) = pending.pop() {
        if current.left < 0 {
            bit_depths[current.value as usize] = depth;
        } else {
            // Push right first so the left subtree is visited first.
            pending.push((&pool[current.right as usize], depth + 1));
            pending.push((&pool[current.left as usize], depth + 1));
        }
    }
}

/// Returns the low `bits` bits of `code` in reversed order, truncated to `u16`.
fn reverse_bits(code: u32, bits: usize) -> u16 {
    // Carry both the reversed accumulator and the not-yet-consumed input so no
    // shift amount ever depends on `bits`.
    let (reversed, _) = (0..bits).fold((0u32, code), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed as u16
}