1use super::tables::{LPS_NEXT, MPS_NEXT, PROB, THRESHOLD};
7
/// Encoder for the ZP adaptive binary arithmetic coder.
///
/// Bits are fed in with [`ZpEncoder::encode_bit`] (adaptive context) or the
/// passthrough methods, and [`ZpEncoder::finish`] flushes the coder and
/// returns the encoded byte stream.
#[derive(Debug)]
pub struct ZpEncoder {
    /// Interval register; doubled and masked to 16 bits on every
    /// renormalisation shift.
    a: u32,
    /// Low part of the code value. Each renormalisation shift emits
    /// `1 - (subend >> 15)` and then doubles it modulo 0x10000.
    subend: u32,
    /// 24-bit shift register fed by `zemit`. The value shifted out of its top
    /// decides what is written. Starts at 0xffffff.
    buffer: u32,
    /// Number of deferred output bits. Their value is decided by the next 1 or
    /// 0xff that leaves `buffer`.
    nrun: i32,
    /// Output bits still to be discarded. The value 0xff is never decremented
    /// and therefore disables output permanently.
    delay: i32,
    /// Output byte under construction. Bits are shifted in from the right, so
    /// the first bit written ends up as the most significant bit.
    byte: u8,
    /// Number of bits currently held in `byte` (0 to 7 between calls).
    scount: u32,
    /// Completed output bytes.
    output: Vec<u8>,
}
32
33impl Default for ZpEncoder {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl ZpEncoder {
40 pub fn new() -> Self {
41 Self {
42 a: 0,
43 subend: 0,
44 buffer: 0xffffff,
45 nrun: 0,
46 delay: 25,
47 byte: 0,
48 scount: 0,
49 output: Vec::new(),
50 }
51 }
52
53 pub fn encode_bit(&mut self, ctx: &mut u8, bit: bool) {
60 let state = *ctx as usize;
61 let mps_bit = (state & 1) != 0;
62 let z = self.a + PROB[state] as u32;
63
64 if bit != mps_bit {
65 self.encode_lps(ctx, z);
66 } else if z >= 0x8000 {
67 self.encode_mps(ctx, z);
68 } else {
69 self.a = z;
71 }
72 }
73
74 pub fn encode_passthrough_iw44(&mut self, bit: bool) {
79 let z = 0x8000 + (3 * self.a / 8);
80 if !bit {
83 self.a = z;
84 self.zemit(1 - (self.subend >> 15) as i32);
86 self.subend = (self.subend << 1) & 0xffff;
87 self.a = (self.a << 1) & 0xffff;
88 } else {
89 let z_comp = 0x10000 - z;
90 self.subend += z_comp;
91 self.a += z_comp;
92 while self.a >= 0x8000 {
93 self.zemit(1 - (self.subend >> 15) as i32);
94 self.subend = (self.subend << 1) & 0xffff;
95 self.a = (self.a << 1) & 0xffff;
96 }
97 }
98 }
99
100 pub fn encode_passthrough(&mut self, bit: bool) {
101 let z = 0x8000 + (self.a >> 1);
102 if !bit {
104 self.a = z;
106 self.zemit(1 - (self.subend >> 15) as i32);
107 self.subend = (self.subend << 1) & 0xffff;
108 self.a = (self.a << 1) & 0xffff;
109 } else {
110 let z_comp = 0x10000 - z;
112 self.subend += z_comp;
113 self.a += z_comp;
114 while self.a >= 0x8000 {
115 self.zemit(1 - (self.subend >> 15) as i32);
116 self.subend = (self.subend << 1) & 0xffff;
117 self.a = (self.a << 1) & 0xffff;
118 }
119 }
120 }
121
122 pub fn finish(mut self) -> Vec<u8> {
124 if self.subend > 0x8000 {
126 self.subend = 0x10000;
127 } else if self.subend > 0 {
128 self.subend = 0x8000;
129 }
130 while self.buffer != 0xffffff || self.subend != 0 {
132 self.zemit(1 - (self.subend >> 15) as i32);
133 self.subend = (self.subend << 1) & 0xffff;
134 }
135 self.outbit(1);
137 while self.nrun > 0 {
138 self.nrun -= 1;
139 self.outbit(0);
140 }
141 while self.scount > 0 {
143 self.outbit(1);
144 }
145 self.delay = 0xff; while self.output.len() < 2 {
148 self.output.push(0xff);
149 }
150 self.output
151 }
152
153 fn encode_mps(&mut self, ctx: &mut u8, z: u32) {
154 let d = 0x6000 + ((z + self.a) >> 2);
156 let z = z.min(d);
157
158 if (self.a & 0xffff) as u16 >= THRESHOLD[*ctx as usize] {
159 *ctx = MPS_NEXT[*ctx as usize];
160 }
161 self.a = z;
163 self.zemit(1 - (self.subend >> 15) as i32);
164 self.subend = (self.subend << 1) & 0xffff;
165 self.a = (self.a << 1) & 0xffff;
166 }
167
168 fn encode_lps(&mut self, ctx: &mut u8, z: u32) {
169 let d = 0x6000 + ((z + self.a) >> 2);
171 let z = z.min(d);
172
173 *ctx = LPS_NEXT[*ctx as usize];
174 let z_comp = 0x10000 - z;
175 self.subend += z_comp;
176 self.a += z_comp;
177 while self.a >= 0x8000 {
178 self.zemit(1 - (self.subend >> 15) as i32);
179 self.subend = (self.subend << 1) & 0xffff;
180 self.a = (self.a << 1) & 0xffff;
181 }
182 }
183
184 fn zemit(&mut self, b: i32) {
186 self.buffer = (self.buffer << 1).wrapping_add(b as u32);
187 let top = self.buffer >> 24;
188 self.buffer &= 0xffffff;
189 match top {
190 1 => {
191 self.outbit(1);
192 while self.nrun > 0 {
193 self.nrun -= 1;
194 self.outbit(0);
195 }
196 }
197 0xff => {
198 self.outbit(0);
199 while self.nrun > 0 {
200 self.nrun -= 1;
201 self.outbit(1);
202 }
203 }
204 0 => {
205 self.nrun += 1;
206 }
207 _ => {} }
209 }
210
211 fn outbit(&mut self, bit: i32) {
213 if self.delay > 0 {
214 if self.delay < 0xff {
215 self.delay -= 1;
216 }
217 return;
218 }
219 self.byte = (self.byte << 1) | (bit as u8);
220 self.scount += 1;
221 if self.scount == 8 {
222 self.output.push(self.byte);
223 self.scount = 0;
224 self.byte = 0;
225 }
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ZpDecoder;

    /// Advances the xorshift64 state and returns its lowest bit. Gives the
    /// roundtrip tests a deterministic pseudo-random bit sequence.
    fn xorshift_bit(state: &mut u64) -> bool {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        (*state & 1) != 0
    }

    #[test]
    fn zp_finish_empty_is_padding_only() {
        // With nothing encoded, the single flush bit written by `finish` is
        // swallowed by the initial output delay, leaving only 0xff padding.
        assert_eq!(ZpEncoder::new().finish(), vec![0xff, 0xff]);
    }

    #[test]
    fn zp_default_matches_new() {
        let mut from_default = ZpEncoder::default();
        let mut from_new = ZpEncoder::new();
        let mut ctx_default = 0u8;
        let mut ctx_new = 0u8;
        for i in 0..300 {
            let bit = (i * 7) % 5 < 2;
            from_default.encode_bit(&mut ctx_default, bit);
            from_new.encode_bit(&mut ctx_new, bit);
        }
        assert_eq!(from_default.finish(), from_new.finish());
    }

    #[test]
    fn zp_roundtrip_passthrough_false() {
        let mut enc = ZpEncoder::new();
        for _ in 0..100 {
            enc.encode_passthrough(false);
        }
        let compressed = enc.finish();
        assert!(!compressed.is_empty());

        let mut dec = ZpDecoder::new(&compressed).expect("init");
        for i in 0..100 {
            let got = dec.decode_passthrough();
            assert!(!got, "expected false at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_passthrough_true() {
        let mut enc = ZpEncoder::new();
        for _ in 0..100 {
            enc.encode_passthrough(true);
        }
        let compressed = enc.finish();
        assert!(!compressed.is_empty());

        let mut dec = ZpDecoder::new(&compressed).expect("init");
        for i in 0..100 {
            let got = dec.decode_passthrough();
            assert!(got, "expected true at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_context_all_mps() {
        let n = 200;
        let mut enc = ZpEncoder::new();
        let mut ctx = 0u8;
        for _ in 0..n {
            enc.encode_bit(&mut ctx, false);
        }
        let compressed = enc.finish();
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = 0u8;
        for i in 0..n {
            let got = dec.decode_bit(&mut dec_ctx);
            assert!(!got, "all-MPS mismatch at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_context_all_lps() {
        let n = 200;
        let mut enc = ZpEncoder::new();
        let mut ctx = 0u8;
        for _ in 0..n {
            enc.encode_bit(&mut ctx, true);
        }
        let compressed = enc.finish();
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = 0u8;
        for i in 0..n {
            let got = dec.decode_bit(&mut dec_ctx);
            assert!(got, "all-LPS mismatch at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_context_bits() {
        let mut rng: u64 = 0xdead_beef;
        let n = 2000;
        let mut bits = Vec::with_capacity(n);
        let mut enc = ZpEncoder::new();
        let mut ctx = 0u8;
        for _ in 0..n {
            let bit = xorshift_bit(&mut rng);
            bits.push(bit);
            enc.encode_bit(&mut ctx, bit);
        }
        let compressed = enc.finish();
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = 0u8;
        for (i, &expected) in bits.iter().enumerate() {
            let got = dec.decode_bit(&mut dec_ctx);
            assert_eq!(got, expected, "mismatch at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_mixed() {
        let mut enc = ZpEncoder::new();
        let mut ctx = [0u8; 2];
        let mut seq: Vec<(bool, bool)> = Vec::new();

        for i in 0..500 {
            let is_pt = i % 5 == 0;
            let bit = (i * 13 + 7) % 3 != 0;
            seq.push((is_pt, bit));
            if is_pt {
                enc.encode_passthrough(bit);
            } else {
                enc.encode_bit(&mut ctx[i % 2], bit);
            }
        }
        let compressed = enc.finish();

        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = [0u8; 2];
        for (i, &(is_pt, expected)) in seq.iter().enumerate() {
            let got = if is_pt {
                dec.decode_passthrough()
            } else {
                dec.decode_bit(&mut dec_ctx[i % 2])
            };
            assert_eq!(got, expected, "mismatch at step {i} (pt={is_pt})");
        }
    }

    #[test]
    fn zp_roundtrip_multiple_contexts() {
        let mut rng: u64 = 42;
        let n = 1000;
        let nctx = 4;
        let mut bits = Vec::with_capacity(n);
        let mut enc = ZpEncoder::new();
        let mut ctx = vec![0u8; nctx];

        for i in 0..n {
            let bit = xorshift_bit(&mut rng);
            bits.push((i % nctx, bit));
            enc.encode_bit(&mut ctx[i % nctx], bit);
        }
        let compressed = enc.finish();

        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = vec![0u8; nctx];
        for (i, &(ci, expected)) in bits.iter().enumerate() {
            let got = dec.decode_bit(&mut dec_ctx[ci]);
            assert_eq!(got, expected, "mismatch at bit {i} ctx {ci}");
        }
    }
}