Skip to main content

djvu_zp/
lib.rs

1//! ZP adaptive binary arithmetic coder — pure-Rust clean-room implementation.
2//!
3//! This crate implements the ZP (Z-Prime) adaptive binary arithmetic coder
4//! from the DjVu v3 specification (<https://www.sndjvu.org/spec.html>),
5//! used by the JB2, IW44, and BZZ codecs that make up a DjVu file.
6//!
7//! ## Usage
8//!
9//! Decoder (no_std-capable, no allocations):
10//!
11//! ```
12//! use djvu_zp::ZpDecoder;
13//! let compressed: &[u8] = &[0x00, 0x00];
14//! let mut dec = ZpDecoder::new(compressed)?;
15//! let mut ctx = 0u8;
16//! let _bit = dec.decode_bit(&mut ctx);
17//! # Ok::<(), djvu_zp::ZpError>(())
18//! ```
19//!
20//! Encoder (requires `std` feature, default-on):
21//!
22//! ```
23//! # #[cfg(feature = "std")]
24//! # {
25//! use djvu_zp::encoder::ZpEncoder;
26//! let mut enc = ZpEncoder::new();
27//! let mut ctx = 0u8;
28//! enc.encode_bit(&mut ctx, true);
29//! let _bytes: Vec<u8> = enc.finish();
30//! # }
31//! ```
32//!
33//! ## Features
34//!
35//! - `std` (default) — enables [`encoder::ZpEncoder`].  The decoder works
36//!   with or without `std` and never allocates.
37
38#![cfg_attr(not(feature = "std"), no_std)]
39
40#[cfg(feature = "std")]
41pub mod encoder;
42pub mod tables;
43
44use tables::{LPS_NEXT, MPS_NEXT, PROB, THRESHOLD};
45
46/// Errors that can occur while initializing or decoding a ZP stream.
47#[derive(Debug, thiserror::Error, PartialEq, Eq)]
48pub enum ZpError {
49    /// Input is too short — the ZP coder needs at least 2 bytes to load the
50    /// initial code register.
51    #[error("ZP input is too short (need at least 2 bytes)")]
52    TooShort,
53}
54
/// ZP (Z-Prime) adaptive binary arithmetic decoder.
///
/// Implements the decoder described in the DjVu v3 specification. Probability
/// models live in caller-owned context bytes, which
/// [`ZpDecoder::decode_bit`] reads and adapts in place as bits are decoded.
///
/// A context byte is used as a whole as the index into the adaptation
/// tables (probability, threshold, next-state-after-MPS, next-state-after-LPS).
/// Its low bit is the current MPS (most probable symbol) value.
///
/// The fields are public for inspection. The decode methods rely on
/// invariants they maintain themselves:
/// - `fence == c.min(0x7fff)`;
/// - `bit_count` stays at least 16 between calls.
///
/// Writing the fields directly can break these invariants.
#[derive(Debug, Clone)]
pub struct ZpDecoder<'a> {
    /// Current interval width register (16-bit value held in low 16 bits).
    pub a: u32,
    /// Current code (value within the interval) register (16-bit value held in low 16 bits).
    pub c: u32,
    /// Cached upper bound for the fast decode path (= min(c, 0x7fff)).
    pub fence: u32,
    /// Bit buffer for feeding bits into the code register; only the low
    /// `bit_count` bits are still unconsumed.
    pub bit_buf: u32,
    /// Number of valid bits remaining in `bit_buf`.
    pub bit_count: i32,
    /// Compressed input bytes.
    pub data: &'a [u8],
    /// Current read position within `data`.
    pub pos: usize,
}
80
81impl<'a> ZpDecoder<'a> {
82    /// Construct a new ZP decoder from the given compressed byte slice.
83    ///
84    /// Reads the initial code register from the first two bytes of `data`.
85    ///
86    /// # Errors
87    ///
88    /// Returns [`ZpError::TooShort`] if `data` has fewer than 2 bytes.
89    pub fn new(data: &'a [u8]) -> Result<Self, ZpError> {
90        if data.len() < 2 {
91            return Err(ZpError::TooShort);
92        }
93
94        let mut dec = ZpDecoder {
95            a: 0,
96            c: 0,
97            fence: 0,
98            bit_buf: 0,
99            bit_count: 0,
100            data,
101            pos: 0,
102        };
103
104        // Load the initial code register from the first two bytes
105        let high = dec.read_byte() as u32;
106        let low = dec.read_byte() as u32;
107        dec.c = (high << 8) | low;
108
109        // Pre-fill the bit buffer
110        dec.refill_buffer();
111
112        // Initialise the fence
113        dec.fence = dec.c.min(0x7fff);
114
115        Ok(dec)
116    }
117
118    /// Read the next byte from the input stream, returning `0xFF` on exhaustion.
119    #[inline(always)]
120    fn read_byte(&mut self) -> u8 {
121        if self.pos < self.data.len() {
122            let b = self.data[self.pos];
123            self.pos += 1;
124            b
125        } else {
126            0xff
127        }
128    }
129
130    /// Fill `bit_buf` with fresh bytes until it holds at least 24 bits.
131    #[inline(always)]
132    fn refill_buffer(&mut self) {
133        while self.bit_count <= 24 {
134            let byte = self.read_byte();
135            self.bit_buf = (self.bit_buf << 8) | (byte as u32);
136            self.bit_count += 8;
137        }
138    }
139
140    /// Decode one bit using an adaptive probability context.
141    ///
142    /// `ctx` is a mutable context byte encoding the current probability state
143    /// and MPS value. It is updated in-place after each call.
144    ///
145    /// Returns `true` if the decoded bit is 1.
146    #[inline(always)]
147    pub fn decode_bit(&mut self, ctx: &mut u8) -> bool {
148        let state = *ctx as usize;
149        let mps_bit = state & 1; // low bit encodes the current MPS
150        let z = self.a + PROB[state] as u32;
151
152        // Fast path: interval stays within the fence — no renormalization needed
153        if z <= self.fence {
154            self.a = z;
155            return mps_bit != 0;
156        }
157
158        // Clamp to the decision boundary
159        let boundary = 0x6000u32 + ((self.a + z) >> 2);
160        let z_clamped = z.min(boundary);
161
162        if z_clamped > self.c {
163            // LPS event: decoded bit is opposite of MPS
164            let lps_bit = 1 - mps_bit;
165            let complement = 0x10000u32 - z_clamped;
166            self.a = (self.a + complement) & 0xffff;
167            self.c = (self.c + complement) & 0xffff;
168            *ctx = LPS_NEXT[state];
169            self.renormalize();
170            lps_bit != 0
171        } else {
172            // MPS event: decoded bit matches MPS
173            if self.a >= THRESHOLD[state] as u32 {
174                *ctx = MPS_NEXT[state];
175            }
176            self.bit_count -= 1;
177            self.a = (z_clamped << 1) & 0xffff;
178            self.c = (self.c << 1 | (self.bit_buf >> self.bit_count as u32) & 1) & 0xffff;
179            if self.bit_count < 16 {
180                self.refill_buffer();
181            }
182            self.fence = self.c.min(0x7fff);
183            mps_bit != 0
184        }
185    }
186
187    /// Returns `true` once all real input bytes have been consumed.
188    ///
189    /// After exhaustion the coder returns `0xFF` bytes indefinitely, producing
190    /// deterministic but meaningless bits. Callers may use this to skip
191    /// remaining work that would otherwise loop on constant input.
192    pub fn is_exhausted(&self) -> bool {
193        self.pos >= self.data.len()
194    }
195
196    /// Decode one bit in passthrough (context-free) mode.
197    ///
198    /// Used by BZZ to decode raw integer values (block size, BWT index).
199    /// The threshold is `z = 0x8000 + (a >> 1)`.
200    ///
201    /// Returns `true` if the decoded bit is 1.
202    #[inline(always)]
203    pub fn decode_passthrough(&mut self) -> bool {
204        let z = (0x8000u32 + (self.a >> 1)) as u16;
205        self.passthrough_with_threshold(z)
206    }
207
208    /// Decode one bit in IW44 passthrough mode.
209    ///
210    /// The threshold is `z = 0x8000 + (3 * a / 8)`.
211    ///
212    /// Returns `true` if the decoded bit is 1.
213    #[inline(always)]
214    pub fn decode_passthrough_iw44(&mut self) -> bool {
215        let z = (0x8000u32 + (3u32 * self.a) / 8) as u16;
216        self.passthrough_with_threshold(z)
217    }
218
219    /// Internal passthrough decode with an explicit threshold `z`.
220    #[inline(always)]
221    fn passthrough_with_threshold(&mut self, z: u16) -> bool {
222        if z as u32 > self.c {
223            // Bit is 1
224            let complement = 0x10000u32 - z as u32;
225            self.a = (self.a + complement) & 0xffff;
226            self.c = (self.c + complement) & 0xffff;
227            self.renormalize();
228            true
229        } else {
230            // Bit is 0
231            self.bit_count -= 1;
232            self.a = (z as u32 * 2) & 0xffff;
233            self.c = (self.c << 1 | (self.bit_buf >> self.bit_count as u32) & 1) & 0xffff;
234            if self.bit_count < 16 {
235                self.refill_buffer();
236            }
237            self.fence = self.c.min(0x7fff);
238            false
239        }
240    }
241
242    /// Renormalize after an LPS event.
243    ///
244    /// Shifts the interval register left until `a >= 0x8000`, pulling fresh bits
245    /// into `c` from the bit buffer.
246    #[inline(always)]
247    fn renormalize(&mut self) {
248        let shift = (self.a as u16).leading_ones();
249        self.bit_count -= shift as i32;
250        self.a = (self.a << shift) & 0xffff;
251        let mask = (1u32 << shift) - 1;
252        self.c = ((self.c << shift) | (self.bit_buf >> self.bit_count as u32) & mask) & 0xffff;
253        if self.bit_count < 16 {
254            self.refill_buffer();
255        }
256        self.fence = self.c.min(0x7fff);
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zp_decoder_rejects_empty_input() {
        assert!(matches!(ZpDecoder::new(&[]), Err(ZpError::TooShort)));
    }

    #[test]
    fn zp_decoder_rejects_one_byte_input() {
        assert!(matches!(ZpDecoder::new(&[0x00]), Err(ZpError::TooShort)));
    }

    #[test]
    fn zp_decoder_accepts_two_byte_input() {
        assert!(ZpDecoder::new(&[0x00, 0x00]).is_ok());
        assert!(ZpDecoder::new(&[0xff, 0xff]).is_ok());
    }

    #[test]
    fn zp_decoder_loads_code_register_high_byte_first() {
        let dec = ZpDecoder::new(&[0x12, 0x34]).unwrap();
        assert_eq!(dec.c, 0x1234);
        assert_eq!(dec.fence, 0x1234);
        assert_eq!(dec.a, 0);
    }

    #[test]
    fn zp_decoder_fence_is_capped() {
        let dec = ZpDecoder::new(&[0xff, 0xff]).unwrap();
        assert_eq!(dec.c, 0xffff);
        assert_eq!(dec.fence, 0x7fff);
    }

    #[test]
    fn zp_decoder_pads_with_ff_after_input() {
        // Two bytes load `c`. The four read-ahead bytes are all padding.
        let dec = ZpDecoder::new(&[0x00, 0x00]).unwrap();
        assert_eq!(dec.bit_buf, 0xffff_ffff);
        assert_eq!(dec.bit_count, 32);
        assert!(dec.is_exhausted());
    }

    #[test]
    fn zp_decoder_is_exhausted_accounts_for_read_ahead() {
        // Construction consumes 2 bytes for `c` plus 4 bytes of read-ahead.
        assert!(ZpDecoder::new(&[0u8; 6]).unwrap().is_exhausted());
        assert!(!ZpDecoder::new(&[0u8; 7]).unwrap().is_exhausted());
    }

    #[test]
    fn zp_passthrough_decodes_both_symbols() {
        // c = 0 lies below the split point 0x8000, so the LPS (1) is decoded.
        assert!(ZpDecoder::new(&[0x00, 0x00]).unwrap().decode_passthrough());
        assert!(ZpDecoder::new(&[0x00, 0x00]).unwrap().decode_passthrough_iw44());
        // c = 0xffff lies above the split point, so the MPS (0) is decoded.
        assert!(!ZpDecoder::new(&[0xff, 0xff]).unwrap().decode_passthrough());
        assert!(!ZpDecoder::new(&[0xff, 0xff]).unwrap().decode_passthrough_iw44());
    }

    #[test]
    fn zp_decode_bit_adapts_context() {
        // c = 0: the split point exceeds c, so the LPS (1) is decoded.
        let mut dec = ZpDecoder::new(&[0x00, 0x00]).unwrap();
        let mut ctx = 0u8;
        assert!(dec.decode_bit(&mut ctx));
        assert_eq!(ctx, LPS_NEXT[0]);

        // c = 0xffff: the MPS (0) is decoded. The state advances only if
        // a (= 0) reaches the state's threshold.
        let mut dec = ZpDecoder::new(&[0xff, 0xff]).unwrap();
        let mut ctx = 0u8;
        assert!(!dec.decode_bit(&mut ctx));
        let expected = if THRESHOLD[0] == 0 { MPS_NEXT[0] } else { 0 };
        assert_eq!(ctx, expected);
    }

    #[cfg(feature = "std")]
    #[test]
    fn zp_error_message() {
        assert_eq!(
            ZpError::TooShort.to_string(),
            "ZP input is too short (need at least 2 bytes)"
        );
    }

    #[test]
    fn zp_tables_spot_check() {
        // These values are from the DjVu v3 spec
        assert_eq!(PROB[0], 0x8000);
        assert_eq!(PROB[250], 0x481a);
        assert_eq!(MPS_NEXT[0], 84);
        assert_eq!(LPS_NEXT[0], 145);
        assert_eq!(THRESHOLD[83], 0);
        assert_eq!(THRESHOLD[250], 0);
    }
}
290}