Skip to main content

djvu_zp/
encoder.rs

1//! ZP adaptive binary arithmetic encoder.
2//!
3//! Encoding counterpart to [`super::ZpDecoder`]. Produces byte streams
4//! that the decoder can consume. Matches DjVuLibre's ZPCodec encoder.
5
6use super::tables::{LPS_NEXT, MPS_NEXT, PROB, THRESHOLD};
7
/// ZP adaptive binary arithmetic encoder.
///
/// All internal registers (`a`, `subend`) are `u32` to match DjVuLibre's
/// `unsigned int` types. They hold u16-range values but intermediate
/// arithmetic can exceed 0xFFFF, which is critical for correct carry
/// propagation in `zemit`.
///
/// The type derives `Debug` so encoder state can be inspected in logs and
/// assertion messages, and `Clone` so a partially written stream can be
/// snapshotted.
#[derive(Debug, Clone)]
pub struct ZpEncoder {
    /// Current interval width — stored as u32 but logically u16 after shifts.
    a: u32,
    /// Sub-interval lower bound for bit emission — u32 for carry propagation.
    subend: u32,
    /// 24-bit shift buffer for carry propagation (initialized to 0xFFFFFF).
    buffer: u32,
    /// Number of pending bits whose value depends on a not-yet-resolved carry.
    nrun: i32,
    /// Delay counter: the first 25 `outbit` calls are absorbed.
    delay: i32,
    /// Byte accumulator for output; bits are shifted in MSB-first.
    byte: u8,
    /// Number of bits currently accumulated in `byte` (0..8).
    scount: u32,
    /// Completed output bytes.
    output: Vec<u8>,
}
32
33impl Default for ZpEncoder {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl ZpEncoder {
40    pub fn new() -> Self {
41        Self {
42            a: 0,
43            subend: 0,
44            buffer: 0xffffff,
45            nrun: 0,
46            delay: 25,
47            byte: 0,
48            scount: 0,
49            output: Vec::new(),
50        }
51    }
52
53    /// Encode one bit using an adaptive probability context.
54    ///
55    /// Matches DjVuLibre's inline `encoder(int bit, BitContext &ctx)`:
56    /// - LPS always calls encode_lps
57    /// - MPS with z >= 0x8000 calls encode_mps
58    /// - MPS with z < 0x8000 takes fast path (a = z, no shift)
59    pub fn encode_bit(&mut self, ctx: &mut u8, bit: bool) {
60        let state = *ctx as usize;
61        let mps_bit = (state & 1) != 0;
62        let z = self.a + PROB[state] as u32;
63
64        if bit != mps_bit {
65            self.encode_lps(ctx, z);
66        } else if z >= 0x8000 {
67            self.encode_mps(ctx, z);
68        } else {
69            // Fast path: MPS and z < 0x8000 — just update a, no shift
70            self.a = z;
71        }
72    }
73
74    /// Encode one bit in IW44 passthrough mode (threshold `z = 0x8000 + 3a/8`).
75    ///
76    /// Counterpart to [`ZpDecoder::decode_passthrough_iw44`]; must produce a
77    /// stream that it correctly decodes.
78    pub fn encode_passthrough_iw44(&mut self, bit: bool) {
79        let z = 0x8000 + (3 * self.a / 8);
80        // Invariant: self.a < 0x8000 (all encode paths maintain this).
81        // Therefore z = 0x8000 + 3a/8 ∈ [0x8000, 0xB000) — always ≥ 0x8000.
82        if !bit {
83            self.a = z;
84            // z ≥ 0x8000 always — single unconditional shift
85            self.zemit(1 - (self.subend >> 15) as i32);
86            self.subend = (self.subend << 1) & 0xffff;
87            self.a = (self.a << 1) & 0xffff;
88        } else {
89            let z_comp = 0x10000 - z;
90            self.subend += z_comp;
91            self.a += z_comp;
92            while self.a >= 0x8000 {
93                self.zemit(1 - (self.subend >> 15) as i32);
94                self.subend = (self.subend << 1) & 0xffff;
95                self.a = (self.a << 1) & 0xffff;
96            }
97        }
98    }
99
100    pub fn encode_passthrough(&mut self, bit: bool) {
101        let z = 0x8000 + (self.a >> 1);
102        // Invariant: self.a < 0x8000, so z = 0x8000 + a/2 ∈ [0x8000, 0xC000) — always ≥ 0x8000.
103        if !bit {
104            // false (MPS-like): a = z, single unconditional shift
105            self.a = z;
106            self.zemit(1 - (self.subend >> 15) as i32);
107            self.subend = (self.subend << 1) & 0xffff;
108            self.a = (self.a << 1) & 0xffff;
109        } else {
110            // true (LPS-like): z_comp = 0x10000 - z
111            let z_comp = 0x10000 - z;
112            self.subend += z_comp;
113            self.a += z_comp;
114            while self.a >= 0x8000 {
115                self.zemit(1 - (self.subend >> 15) as i32);
116                self.subend = (self.subend << 1) & 0xffff;
117                self.a = (self.a << 1) & 0xffff;
118            }
119        }
120    }
121
122    /// Flush the encoder and return the compressed byte stream.
123    pub fn finish(mut self) -> Vec<u8> {
124        // eflush: round subend up to disambiguate
125        if self.subend > 0x8000 {
126            self.subend = 0x10000;
127        } else if self.subend > 0 {
128            self.subend = 0x8000;
129        }
130        // Emit until buffer is flushed and subend is 0
131        while self.buffer != 0xffffff || self.subend != 0 {
132            self.zemit(1 - (self.subend >> 15) as i32);
133            self.subend = (self.subend << 1) & 0xffff;
134        }
135        // Final bits
136        self.outbit(1);
137        while self.nrun > 0 {
138            self.nrun -= 1;
139            self.outbit(0);
140        }
141        // Pad remaining byte with 1s
142        while self.scount > 0 {
143            self.outbit(1);
144        }
145        self.delay = 0xff; // prevent further output
146        // Ensure minimum 2 bytes for decoder initialization
147        while self.output.len() < 2 {
148            self.output.push(0xff);
149        }
150        self.output
151    }
152
153    fn encode_mps(&mut self, ctx: &mut u8, z: u32) {
154        // Clamp z: d = 0x6000 + (z + a) / 4
155        let d = 0x6000 + ((z + self.a) >> 2);
156        let z = z.min(d);
157
158        if (self.a & 0xffff) as u16 >= THRESHOLD[*ctx as usize] {
159            *ctx = MPS_NEXT[*ctx as usize];
160        }
161        // Code MPS bit + single shift
162        self.a = z;
163        self.zemit(1 - (self.subend >> 15) as i32);
164        self.subend = (self.subend << 1) & 0xffff;
165        self.a = (self.a << 1) & 0xffff;
166    }
167
168    fn encode_lps(&mut self, ctx: &mut u8, z: u32) {
169        // Clamp z
170        let d = 0x6000 + ((z + self.a) >> 2);
171        let z = z.min(d);
172
173        *ctx = LPS_NEXT[*ctx as usize];
174        let z_comp = 0x10000 - z;
175        self.subend += z_comp;
176        self.a += z_comp;
177        while self.a >= 0x8000 {
178            self.zemit(1 - (self.subend >> 15) as i32);
179            self.subend = (self.subend << 1) & 0xffff;
180            self.a = (self.a << 1) & 0xffff;
181        }
182    }
183
184    /// Emit one bit through the 24-bit carry-propagation buffer.
185    fn zemit(&mut self, b: i32) {
186        self.buffer = (self.buffer << 1).wrapping_add(b as u32);
187        let top = self.buffer >> 24;
188        self.buffer &= 0xffffff;
189        match top {
190            1 => {
191                self.outbit(1);
192                while self.nrun > 0 {
193                    self.nrun -= 1;
194                    self.outbit(0);
195                }
196            }
197            0xff => {
198                self.outbit(0);
199                while self.nrun > 0 {
200                    self.nrun -= 1;
201                    self.outbit(1);
202                }
203            }
204            0 => {
205                self.nrun += 1;
206            }
207            _ => {} // shouldn't happen
208        }
209    }
210
211    /// Emit one bit to the output byte stream (with delay).
212    fn outbit(&mut self, bit: i32) {
213        if self.delay > 0 {
214            if self.delay < 0xff {
215                self.delay -= 1;
216            }
217            return;
218        }
219        self.byte = (self.byte << 1) | (bit as u8);
220        self.scount += 1;
221        if self.scount == 8 {
222            self.output.push(self.byte);
223            self.scount = 0;
224            self.byte = 0;
225        }
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ZpDecoder;

    /// Advance a xorshift64 state and return its low bit as a pseudo-random bool.
    fn next_bit(state: &mut u64) -> bool {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        (*state & 1) != 0
    }

    /// Round-trip 100 identical passthrough bits through encoder and decoder.
    fn check_uniform_passthrough(bit: bool) {
        let mut enc = ZpEncoder::new();
        for _ in 0..100 {
            enc.encode_passthrough(bit);
        }
        let compressed = enc.finish();
        assert!(!compressed.is_empty());

        let mut dec = ZpDecoder::new(&compressed).expect("init");
        for i in 0..100 {
            let got = dec.decode_passthrough();
            assert_eq!(got, bit, "expected {bit} at bit {i}");
        }
    }

    /// Round-trip 200 identical bits through a single adaptive context.
    fn check_uniform_context(bit: bool, label: &str) {
        let n = 200;
        let mut enc = ZpEncoder::new();
        let mut ctx = 0u8;
        for _ in 0..n {
            enc.encode_bit(&mut ctx, bit);
        }
        let compressed = enc.finish();

        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = 0u8;
        for i in 0..n {
            let got = dec.decode_bit(&mut dec_ctx);
            assert_eq!(got, bit, "{label} mismatch at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_passthrough_false() {
        check_uniform_passthrough(false);
    }

    #[test]
    fn zp_roundtrip_passthrough_true() {
        check_uniform_passthrough(true);
    }

    #[test]
    fn zp_roundtrip_context_all_mps() {
        check_uniform_context(false, "all-MPS");
    }

    #[test]
    fn zp_roundtrip_context_all_lps() {
        check_uniform_context(true, "all-LPS");
    }

    #[test]
    fn zp_roundtrip_context_bits() {
        let mut rng: u64 = 0xdead_beef;
        let n = 2000;
        let mut enc = ZpEncoder::new();
        let mut ctx = 0u8;
        let mut bits = Vec::with_capacity(n);
        for _ in 0..n {
            let bit = next_bit(&mut rng);
            bits.push(bit);
            enc.encode_bit(&mut ctx, bit);
        }
        let compressed = enc.finish();

        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = 0u8;
        for (i, &expected) in bits.iter().enumerate() {
            let got = dec.decode_bit(&mut dec_ctx);
            assert_eq!(got, expected, "mismatch at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_mixed() {
        let mut enc = ZpEncoder::new();
        let mut ctx = [0u8; 2];
        let mut seq: Vec<(bool, bool)> = Vec::new();

        for i in 0..500 {
            // Every fifth symbol goes through passthrough; the rest alternate
            // between the two adaptive contexts.
            let is_pt = i % 5 == 0;
            let bit = (i * 13 + 7) % 3 != 0;
            seq.push((is_pt, bit));
            if is_pt {
                enc.encode_passthrough(bit);
            } else {
                enc.encode_bit(&mut ctx[i % 2], bit);
            }
        }
        let compressed = enc.finish();

        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = [0u8; 2];
        for (i, &(is_pt, expected)) in seq.iter().enumerate() {
            let got = if is_pt {
                dec.decode_passthrough()
            } else {
                dec.decode_bit(&mut dec_ctx[i % 2])
            };
            assert_eq!(got, expected, "mismatch at step {i} (pt={is_pt})");
        }
    }

    #[test]
    fn zp_roundtrip_multiple_contexts() {
        let mut rng: u64 = 42;
        let n = 1000;
        let nctx = 4;
        let mut enc = ZpEncoder::new();
        let mut ctx = vec![0u8; nctx];
        let mut bits = Vec::with_capacity(n);

        for i in 0..n {
            let bit = next_bit(&mut rng);
            bits.push((i % nctx, bit));
            enc.encode_bit(&mut ctx[i % nctx], bit);
        }
        let compressed = enc.finish();

        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = vec![0u8; nctx];
        for (i, &(ci, expected)) in bits.iter().enumerate() {
            let got = dec.decode_bit(&mut dec_ctx[ci]);
            assert_eq!(got, expected, "mismatch at bit {i} ctx {ci}");
        }
    }

    /// An encoder that never coded a symbol still yields the two-byte minimum.
    #[test]
    fn zp_finish_empty_stream_is_padded() {
        assert_eq!(ZpEncoder::new().finish(), vec![0xff, 0xff]);
    }

    /// Decoder-independent golden bytes for short passthrough streams.
    #[test]
    fn zp_passthrough_golden_bytes() {
        let mut enc = ZpEncoder::new();
        enc.encode_passthrough(true);
        assert_eq!(enc.finish(), vec![0x7f, 0xff]);

        let mut enc = ZpEncoder::new();
        enc.encode_passthrough(false);
        enc.encode_passthrough(true);
        assert_eq!(enc.finish(), vec![0xbf, 0xff]);
    }

    /// The IW44 threshold equals the plain one while `a == 0`, so a
    /// passthrough-only stream yields the same bytes in both modes.
    #[test]
    fn zp_passthrough_iw44_golden_bytes() {
        let mut enc = ZpEncoder::new();
        enc.encode_passthrough_iw44(true);
        assert_eq!(enc.finish(), vec![0x7f, 0xff]);

        let mut enc = ZpEncoder::new();
        enc.encode_passthrough_iw44(false);
        enc.encode_passthrough_iw44(true);
        assert_eq!(enc.finish(), vec![0xbf, 0xff]);
    }
}