oozextract 0.1.1

Open source Kraken / Mermaid / Selkie / Leviathan / LZNA / Bitknit decompressor
Documentation
use crate::error::{ErrorBuilder, ErrorContext, Res, ResultBuilder, WithContext};

/// Adaptive frequency model shared by the three symbol alphabets in this file.
///
/// `F` is the number of symbols, `A` the length of the cumulative table
/// (`F + 1` for every alias declared in this file) and `L` the number of
/// lookup-table buckets.
#[derive(Copy, Clone)]
struct Base<const F: usize, const A: usize, const L: usize> {
    /// Cumulative boundaries: `lookup` treats symbol `s` as the range
    /// `a[s]..a[s + 1]` of the 15-bit masked state.
    a: [u16; A],
    /// Per-symbol counts accumulated by `lookup`; reset to 1 by `adapt`.
    freq: [u16; F],
    /// Number of `lookup` calls left before `adapt` runs again.
    adapt_interval: u16,
    /// For each bucket (`masked >> SHIFT`), the symbol at which the search in
    /// `lookup` starts.
    lookup: [u16; L],
}

impl<const F: usize, const A: usize, const L: usize> ErrorContext for Base<F, A, L> {
    /// Names the model after the alias whose symbol count equals `F`.
    ///
    /// Only the three aliases declared in this file instantiate `Base`, so any
    /// other `F` is treated as unreachable.
    fn describe(&mut self) -> Option<String> {
        let name: &str = match F {
            21 => "DistanceBits",
            40 => "DistanceLsb",
            300 => "Literal",
            _ => unreachable!(),
        };
        Some(String::from(name))
    }
}

/// 300-symbol model with 512 lookup buckets. `Core::decode` writes symbols
/// below 256 as literal bytes and turns larger ones into a copy length.
type Literal = Base<300, 301, 512>;

/// 40-symbol model with 64 lookup buckets. In `Core::decode`, symbols below 8
/// select a recent distance; larger ones feed into a newly computed distance.
type DistanceLsb = Base<40, 41, 64>;

/// 21-symbol model with 64 lookup buckets. `Core::decode` uses the symbol as
/// the bit count of a newly computed distance.
type DistanceBits = Base<21, 22, 64>;

impl<const F: usize, const A: usize, const L: usize> Base<F, A, L> {
    /// Shift that maps the 15-bit masked state onto a lookup bucket.
    const SHIFT: u16 = if A == 301 { 6 } else { 9 };
    /// Extra count credited in `adapt` to the symbol that triggered it.
    const F_INC: u16 = 1026 - A as u16;

    /// Rebuilds `lookup` from the cumulative table `a`.
    ///
    /// Each symbol claims the buckets from where the previous symbol stopped
    /// up to and including the bucket holding its last value.
    ///
    /// # Errors
    ///
    /// Fails when a computed bucket range does not fit in `lookup`.
    fn fill_lut(&mut self) -> Res<()> {
        let mut start = 0;
        for (upper, sym) in self.a[1..].iter().copied().zip(0u16..) {
            let end = usize::from((upper - 1) >> Self::SHIFT) + 1;
            let buckets = self
                .lookup
                .get_mut(start..end)
                .message(|_| format!("{}..{} can't index [{}]", start, end, L))?;
            buckets.fill(sym);
            start = end;
        }
        Ok(())
    }

    /// Folds the collected counts into the cumulative table and starts a new
    /// adaptation interval of 1024 lookups.
    ///
    /// # Errors
    ///
    /// Fails when `sym` is not below `F` or when the lookup table cannot be
    /// rebuilt.
    fn adapt(&mut self, sym: usize) -> Res<()> {
        self.adapt_interval = 1024;
        self.assert_lt(sym, F)?;
        self.freq[sym] += Self::F_INC;

        let mut running = 0u32;
        let bounds = self.a[1..].iter_mut();
        for (count, bound) in self.freq.iter().zip(bounds) {
            running += u32::from(*count);
            // Move each boundary halfway towards the new running total. The
            // wrapping arithmetic keeps the low 16 bits correct when the
            // boundary has to move down.
            let current = u32::from(*bound);
            *bound = current.wrapping_add(running.wrapping_sub(current) >> 1) as u16;
        }
        self.freq = [1; F];

        self.fill_lut().at(self)?;
        Ok(())
    }

    /// Decodes one symbol from the state word `bits` and updates `bits` in place.
    ///
    /// The low 15 bits select the symbol; the remaining high bits are scaled
    /// by the width of the symbol's range to form the new state.
    ///
    /// # Errors
    ///
    /// Fails when an index falls outside the tables, when no boundary above
    /// the masked value exists, or when the triggered `adapt` fails.
    fn lookup(&mut self, bits: &mut u32) -> Res<usize> {
        let state = *bits;
        let masked = (state & 0x7FFF) as u16;
        let bucket = usize::from(masked >> Self::SHIFT);
        self.assert_lt(bucket, L)?;

        let mut sym = usize::from(self.lookup[bucket]);
        self.assert_lt(sym + 1, A)?;
        if masked > self.a[sym + 1] {
            sym += 1;
            self.assert_lt(sym + 1, A)?;
        }
        // Walk forward to the first boundary strictly above the masked value.
        let skipped = self.a[sym + 1..]
            .iter()
            .position(|&bound| bound > masked)
            .ok_or_else(ErrorBuilder::default)?;
        sym += skipped;

        let low = u32::from(self.a[sym]);
        let high = u32::from(self.a[sym + 1]);
        *bits = u32::from(masked) + (state >> 15) * (high - low) - low;

        self.freq[sym] += 31;
        self.adapt_interval -= 1;
        if self.adapt_interval == 0 {
            self.adapt(sym).at(self)?;
        }
        Ok(sym)
    }
}

impl<const F: usize, const A: usize, const L: usize> Default for Base<F, A, L> {
    /// Builds the initial model: every count is 1, the adaptation interval is
    /// 1024 and the lookup table is derived from the initial cumulative table.
    ///
    /// When `SHIFT` is 6 the first 264 boundaries are spread evenly up to
    /// `0x8000 - 300 + 264` and the rest advance by one each. Otherwise the
    /// boundaries are spread evenly over `0x8000` across `F` symbols.
    fn default() -> Self {
        let a: [u16; A] = core::array::from_fn(|i| {
            let bound = if Self::SHIFT != 6 {
                0x8000 * i / F
            } else if i < 264 {
                (0x8000 - 300 + 264) * i / 264
            } else {
                (0x8000 - 300) + i
            };
            bound as u16
        });

        let mut model = Self {
            a,
            freq: [1; F],
            adapt_interval: 1024,
            lookup: [0; L],
        };
        model.fill_lut().unwrap();
        model
    }
}

/// Decoder state that is borrowed by a `Core` rather than owned by it:
/// recent match distances and the adaptive models.
pub(crate) struct State {
    /// Match distances, addressed through the indices in `recent_dist_mask`.
    recent_dist: [u32; 8],
    /// Distance of the most recent match; `Core::write_sym` adds the byte
    /// found at this distance to every literal it writes.
    last_match_dist: u32,
    /// Eight 3-bit indices into `recent_dist`, packed from bit 0 upwards.
    recent_dist_mask: u32,

    /// Literal models, selected by `Core::litmodel[dst & 3]`.
    literals: [Literal; 4],
    /// Distance-LSB models, selected by `Core::distancelsb[dst & 3]`.
    distance_lsb: [DistanceLsb; 4],
    /// Model for the bit count of a newly computed distance.
    distance_bits: DistanceBits,
}

impl State {
    pub(crate) fn new() -> Self {
        Self {
            last_match_dist: 1,
            recent_dist: [1; 8],
            recent_dist_mask: (1 << 3)
                | (2 << (2 * 3))
                | (3 << (3 * 3))
                | (4 << (4 * 3))
                | (5 << (5 * 3))
                | (6 << (6 * 3))
                | (7 << (7 * 3)),
            literals: Default::default(),
            distance_lsb: Default::default(),
            distance_bits: Default::default(),
        }
    }
}

/// One decoding pass over `input` that writes into `output`, using the
/// borrowed models and distance history in `state`.
pub(crate) struct Core<'a> {
    /// Models and recent-distance history used and updated by `decode`.
    state: &'a mut State,
    /// Compressed bytes.
    input: &'a [u8],
    /// Destination buffer; matches copy from bytes already written to it.
    output: &'a mut [u8],
    /// Read position in `input`.
    src: usize,
    /// Write position in `output`.
    dst: usize,
    /// Decoder state word used by the next model lookup.
    bits: u32,
    /// Second decoder state word; `renormalize` swaps it with `bits`.
    bits2: u32,
    /// Maps `dst & 3` to an index into `State::literals`.
    litmodel: [usize; 4],
    /// Maps `dst & 3` to an index into `State::distance_lsb`.
    distancelsb: [usize; 4],
}

// No trait items are overridden here, so `Core` relies on `ErrorContext`'s
// provided defaults.
impl ErrorContext for Core<'_> {}

impl<'a> Core<'a> {
    pub(crate) fn new(
        input: &'a [u8],
        output: &'a mut [u8],
        state: &'a mut State,
        dst: usize,
    ) -> Core<'a> {
        Self {
            state,
            input,
            output,
            src: 0,
            dst,
            bits: 0x10000,
            bits2: 0x10000,
            litmodel: core::array::from_fn(|i| i),
            distancelsb: core::array::from_fn(|i| i),
        }
    }

    /// Borrows the next `N` input bytes at `src` without advancing `src`.
    ///
    /// # Errors
    ///
    /// Fails when fewer than `N` bytes remain at `src`.
    fn read<const N: usize>(&self) -> Result<&[u8; N], ErrorBuilder> {
        let remaining = self.input.get(self.src..);
        let chunk = remaining
            .and_then(|rest| rest.first_chunk::<N>())
            .message(|_| {
                format!(
                    "Can't read {} bytes from [{}] at {}",
                    N,
                    self.input.len(),
                    self.src
                )
            })?;
        Ok(chunk)
    }

    /// Reads a little-endian `u16` at `src`, widens it to `u32` and advances
    /// `src` by two bytes.
    fn read_2(&mut self) -> Res<u32> {
        let raw: [u8; 2] = *self.read()?;
        self.src += 2;
        Ok(u32::from(u16::from_le_bytes(raw)))
    }

    /// Reads a little-endian `u32` at `src` and advances `src` by four bytes.
    fn read_4(&mut self) -> Res<u32> {
        let raw: [u8; 4] = *self.read()?;
        self.src += 4;
        Ok(u32::from_le_bytes(raw))
    }

    /// Stores `v` at `dst` and advances `dst` by one.
    ///
    /// # Errors
    ///
    /// Fails when `dst` is not inside `output`.
    fn write_1(&mut self, v: u8) -> Res<()> {
        let at = self.dst;
        self.assert_lt(at, self.output.len())?;
        self.output[at] = v;
        self.dst = at + 1;
        Ok(())
    }

    /// Stores `v` as two little-endian bytes at `dst` and advances `dst` by two.
    ///
    /// # Errors
    ///
    /// Fails when the two bytes do not fit inside `output`.
    fn write_2(&mut self, v: u16) -> Res<()> {
        let start = self.dst;
        let bytes = v.to_le_bytes();
        let slot = self
            .output
            .get_mut(start..start + 2)
            .message(|_| format!("{} out of bounds", start))?;
        slot.copy_from_slice(&bytes);
        self.dst += 2;
        Ok(())
    }

    /// Writes a literal at `dst`: `sym` plus (wrapping) the byte that lies
    /// `last_match_dist` positions back, then advances `dst` by one.
    ///
    /// # Errors
    ///
    /// Fails when `dst` is not inside `output`, or when `last_match_dist`
    /// reaches back before the start of `output`.
    fn write_sym(&mut self, sym: u8) -> Res<()> {
        let at = self.dst;
        self.assert_lt(at, self.output.len())?;
        // `last_match` indexes `output[dst - last_match_dist]` unchecked. The
        // distance lives in `State`, which outlives this `Core`, so a stale
        // value must be reported as an error instead of panicking.
        let dist = self.state.last_match_dist as usize;
        self.assert_le(dist, at)?;
        self.output[at] = sym.wrapping_add(self.last_match());
        self.dst = at + 1;
        Ok(())
    }

    /// Copies `copy_length` bytes from `match_dist` positions behind `dst` to
    /// `dst`, in pieces of `CHUNK_SIZE` bytes followed by the remainder.
    ///
    /// Copying piecewise lets later pieces read bytes written by earlier ones
    /// when source and destination overlap. `dst` itself is not advanced.
    ///
    /// # Errors
    ///
    /// Fails when `match_dist` exceeds `dst` or the copy would run past the
    /// end of `output`.
    fn copy_chunks<const CHUNK_SIZE: usize>(
        &mut self,
        copy_length: usize,
        match_dist: usize,
    ) -> Res<()> {
        self.assert_le(match_dist, self.dst)?;
        self.assert_le(self.dst + copy_length, self.output.len())?;

        let base = self.dst;
        let tail_len = copy_length % CHUNK_SIZE;
        let whole_len = copy_length - tail_len;

        for offset in (0..whole_len).step_by(CHUNK_SIZE) {
            let to = base + offset;
            let from = to - match_dist;
            self.output.copy_within(from..from + CHUNK_SIZE, to);
        }

        let to = base + whole_len;
        let from = to - match_dist;
        self.output.copy_within(from..from + tail_len, to);
        Ok(())
    }

    /// Returns the output byte `last_match_dist` positions before `dst`.
    ///
    /// # Panics
    ///
    /// Panics if `last_match_dist` is greater than `dst` or the resulting
    /// index is outside `output`.
    fn last_match(&self) -> u8 {
        let dist = self.state.last_match_dist as usize;
        self.output[self.dst - dist]
    }

    /// Decodes one symbol from the literal model selected by `dst & 3`.
    fn lookup_literal(&mut self) -> Res<usize> {
        let slot = self.litmodel[self.dst & 3];
        let model = &mut self.state.literals[slot];
        model.lookup(&mut self.bits)
    }

    /// Decodes one symbol from the distance-LSB model selected by `dst & 3`.
    fn lookup_lsb(&mut self) -> Res<usize> {
        let slot = self.distancelsb[self.dst & 3];
        let model = &mut self.state.distance_lsb[slot];
        model.lookup(&mut self.bits)
    }

    /// Decodes one symbol from the distance-bits model.
    fn lookup_bits(&mut self) -> Res<usize> {
        let model = &mut self.state.distance_bits;
        model.lookup(&mut self.bits)
    }

    /// Refills `bits` with 16 input bits when it has dropped below `0x10000`,
    /// then swaps `bits` and `bits2` so the other state word is used next.
    fn renormalize(&mut self) -> Res<()> {
        if self.bits < 0x10000 {
            let refill = self.read_2().at(self)?;
            self.bits = (self.bits << 16) | refill;
        }
        std::mem::swap(&mut self.bits, &mut self.bits2);
        Ok(())
    }

    /// Decodes literals and matches from `input` into `output`, starting at
    /// the current `dst`, until fewer than five bytes of `output` remain; the
    /// low 16 bits of both state words are then written as four final bytes.
    ///
    /// Returns `src`, the read position in `input` after decoding, or `Ok(0)`
    /// when the leading 32-bit word is below `0x10000`.
    ///
    /// # Errors
    ///
    /// Fails when `input` runs out, when a model lookup fails, when a match
    /// would read before the start of `output` or write past its end, or when
    /// the four final bytes do not fit.
    pub(crate) fn decode(&mut self) -> Res<usize> {
        let mut recent_mask = self.state.recent_dist_mask as usize;

        let v = self.read_4().at(self)?;
        if v < 0x10000 {
            return Ok(0);
        }

        // The leading word carries a 4-bit split count `n` in its low bits;
        // the remaining bits seed the two interleaved state words.
        let mut a = v >> 4;
        let n = v & 0xF;
        if a < 0x10000 {
            a = (a << 16) | self.read_2().at(self)?;
        }
        self.bits = a >> n;
        if self.bits < 0x10000 {
            self.bits = (self.bits << 16) | self.read_2().at(self)?;
        }
        a = (a << 16) | self.read_2().at(self)?;

        self.bits2 = (1 << (n + 16)) | (a & ((1 << (n + 16)) - 1));

        // At the very start of the output there is no earlier byte to predict
        // from, so the first byte is taken directly from the state word.
        if self.dst == 0 {
            self.write_1(self.bits as u8).at(self)?;
            self.bits >>= 8;
            self.renormalize().at(self)?;
        }

        while self.dst + 4 < self.output.len() {
            let mut sym = self.lookup_literal().at(self)?;
            self.renormalize().at(self)?;

            if sym < 256 {
                self.write_sym(sym as u8).at(self)?;

                if self.dst + 4 >= self.output.len() {
                    break;
                }

                sym = self.lookup_literal().at(self)?;
                self.renormalize().at(self)?;

                if sym < 256 {
                    self.write_sym(sym as u8).at(self)?;
                    continue;
                }
            }

            // Symbols from 288 up carry `nb` extra length bits in the state word.
            if sym >= 288 {
                let nb = sym - 287;
                sym = (self.bits as usize & ((1 << nb) - 1)) + (1 << nb) + 286;
                self.bits >>= nb;
                self.renormalize().at(self)?;
            }

            let copy_length = sym - 254;

            sym = self.lookup_lsb().at(self)?;
            self.renormalize().at(self)?;

            let mut match_dist;
            if sym >= 8 {
                // New distance: bit count from the model, low bits from the
                // state word, plus 16 more bits from the input for large counts.
                let nb = self.lookup_bits().at(self)?;
                self.renormalize().at(self)?;

                match_dist = self.bits & ((1 << (nb & 0xF)) - 1);
                self.bits >>= nb & 0xF;
                self.renormalize().at(self)?;
                if nb >= 0x10 {
                    match_dist = (match_dist << 16) | self.read_2().at(self)?;
                }
                match_dist = (32 << nb) + (match_dist << 5) + sym as u32 - 39;

                // Overwrite the slot named by mask position 7 with the entry
                // at position 6, then store the new distance at position 6.
                self.state.recent_dist[(recent_mask >> 21) & 7] =
                    self.state.recent_dist[(recent_mask >> 18) & 7];
                self.state.recent_dist[(recent_mask >> 18) & 7] = match_dist;
            } else {
                // Reuse a recent distance and move its index to the front of
                // the packed mask, shifting the indices below it up by one.
                let idx = (recent_mask >> (3 * sym)) & 7;
                let mask = !7 << (3 * sym);
                match_dist = self.state.recent_dist[idx];
                recent_mask = (recent_mask & mask) | ((idx + 8 * recent_mask) & !mask);
            }

            // Both values come from the compressed stream. Validate them so
            // malformed input yields an error instead of an arithmetic or
            // indexing panic in the copies below.
            let dist = match_dist as usize;
            self.assert_le(dist, self.dst)?;
            self.assert_le(self.dst + copy_length, self.output.len())?;

            if dist == 1 {
                // Run of the previous byte.
                let v = self.output[self.dst - 1];
                self.output[self.dst..][..copy_length].fill(v);
            } else if dist > copy_length {
                // Source and destination do not overlap.
                let src = self.dst - dist;
                self.output.copy_within(src..src + copy_length, self.dst);
            } else if dist >= 8 {
                self.copy_chunks::<8>(copy_length, dist).at(self)?;
            } else if dist >= 4 {
                self.copy_chunks::<4>(copy_length, dist).at(self)?;
            } else {
                // Short overlapping distance: copy byte by byte so each byte
                // can read one written earlier in this same loop.
                for i in 0..copy_length {
                    self.output[self.dst + i] = self.output[self.dst + i - dist];
                }
            }

            self.dst += copy_length;
            self.state.last_match_dist = match_dist;
        }
        self.write_2(self.bits as u16).at(self)?;
        self.write_2(self.bits2 as u16).at(self)?;

        self.state.recent_dist_mask = recent_mask as u32;
        Ok(self.src)
    }
}