1#![forbid(clippy::let_underscore_drop)]
20#![forbid(unsafe_code)]
21#![warn(clippy::unwrap_used)]
22#![warn(missing_docs)]
23
24use std::collections::HashMap;
25use std::fs::File;
26use std::io::{self, BufWriter, Read, Write};
27use std::path::Path;
28use thiserror::Error as ThisError;
29
/// Range-coder threshold for a settled top byte: while `low` and
/// `low + range` differ only in bits below this value (that is, they share
/// their most significant byte), the coder shifts that byte out.
pub const TOP: u32 = 1 << 24;
/// Lower bound the range coder maintains on `range`; normalisation keeps
/// shifting while `range` is below this value.
pub const BOT: u32 = 1 << 15;
/// Cap applied to a symbol's frequency count when it is bumped.
pub const MAX_FREQ: u8 = 124;
/// Context order used by `encode_file` when the caller passes `None`, and
/// always used by `decode_file`.
pub const DEFAULT_ORDER: u8 = 5;

/// Result alias used throughout this file, with [`PpmError`] as the error type.
pub type PpmResult<T> = Result<T, PpmError>;
37
/// Errors reported by the coder, the model and the file helpers in this file.
#[derive(ThisError, Debug)]
pub enum PpmError {
    /// An underlying reader or writer failed; converted automatically from
    /// [`io::Error`] by the `?` operator.
    #[error("IO error: {0}")]
    IoError(#[from] io::Error),

    /// The compressed stream decoded to a value that no valid stream can
    /// contain (for example a frequency at or above the context total).
    #[error("Corrupt input data")]
    CorruptData,

    /// The decoder reached a state from which it cannot continue.
    #[error("Invalid decoder state")]
    InvalidState,

    /// A model-level failure; the payload is a static description of it.
    #[error("Model error: {0}")]
    ModelError(&'static str),
}
57
58pub struct RangeEncoder<W: Write> {
94 low: u32,
95 range: u32,
96 buffer: Vec<u8>,
97 writer: W,
98}
99
100impl<W: Write> RangeEncoder<W> {
101 pub fn new(writer: W) -> Self {
111 Self {
112 low: 0,
113 range: u32::MAX,
114 buffer: Vec::with_capacity(4096),
115 writer,
116 }
117 }
118
119 fn encode(&mut self, cum_freq: u32, freq: u32, tot_freq: u32) -> PpmResult<()> {
120 assert!(tot_freq > 0, "total frequency must be positive");
121 assert!(freq > 0, "symbol frequency must be positive");
122 assert!(cum_freq < tot_freq, "cumulative freq out of range");
123 assert!(cum_freq + freq <= tot_freq, "freq interval exceeds total");
124
125 self.range /= tot_freq;
126 self.low = self.low.wrapping_add(cum_freq * self.range);
127 self.range = self.range.wrapping_mul(freq);
128
129 while (self.low ^ (self.low.wrapping_add(self.range))) < TOP || self.range < BOT {
130 if self.range < BOT {
131 self.range = (-(self.low as i32) as u32) & (BOT - 1);
132 }
133 self.buffer.push((self.low >> 24) as u8);
134 self.low <<= 8;
135 self.range <<= 8;
136 if self.buffer.len() >= 4096 {
137 self.writer.write_all(&self.buffer)?;
138 self.buffer.clear();
139 }
140 }
141 Ok(())
142 }
143
144 fn finish(mut self) -> PpmResult<W> {
145 assert!(self.range > 0, "range became zero in finish");
146 for _ in 0..4 {
147 self.buffer.push((self.low >> 24) as u8);
148 self.low <<= 8;
149 }
150 assert!(!self.buffer.is_empty(), "nothing to flush in finish");
151 self.writer.write_all(&self.buffer)?;
152 Ok(self.writer)
153 }
154}
155
156pub struct RangeDecoder<R: Read> {
188 low: u32,
189 code: u32,
190 range: u32,
191 reader: R,
192 buffer: [u8; 1],
193}
194
195impl<R: Read> RangeDecoder<R> {
196 pub fn new(mut reader: R) -> PpmResult<Self> {
207 let mut code = 0;
208 for _ in 0..4 {
209 let mut byte = [0];
210 reader.read_exact(&mut byte)?;
211 code = (code << 8) | u32::from(byte[0]);
212 }
213 Ok(Self {
214 low: 0,
215 code,
216 range: u32::MAX,
217 reader,
218 buffer: [0],
219 })
220 }
221
222 fn get_freq(&mut self, tot_freq: u32) -> PpmResult<u32> {
223 assert!(tot_freq > 0, "total frequency must be positive");
224 self.range /= tot_freq;
225 let tmp = (self.code.wrapping_sub(self.low)) / self.range;
226 if tmp >= tot_freq {
227 return Err(PpmError::CorruptData);
228 }
229 Ok(tmp)
230 }
231
232 fn decode(&mut self, cum_freq: u32, freq: u32, tot_freq: u32) -> PpmResult<()> {
233 assert!(freq > 0, "frequency must be positive");
234 assert!(cum_freq < tot_freq, "cumulative freq out of range");
235 assert!(cum_freq + freq <= tot_freq, "freq interval exceeds total");
236
237 if cum_freq.wrapping_add(freq) > tot_freq {
238 return Err(PpmError::CorruptData);
239 }
240 self.low = self.low.wrapping_add(cum_freq * self.range);
241 self.range = self.range.wrapping_mul(freq);
242 while (self.low ^ (self.low.wrapping_add(self.range))) < TOP || self.range < BOT {
243 if self.range < BOT {
244 self.range = (-(self.low as i32) as u32) & (BOT - 1);
245 }
246 self.reader.read_exact(&mut self.buffer)?;
247 self.code = (self.code << 8) | u32::from(self.buffer[0]);
248 self.low <<= 8;
249 self.range <<= 8;
250 }
251 Ok(())
252 }
253}
254
255#[derive(Debug, Clone)]
256struct State {
257 symbol: u8,
258 freq: u8,
259}
260
261#[derive(Clone, Debug)]
262struct PpmContext {
263 stats: Vec<State>,
264 total_freq: u32,
265}
266
267impl PpmContext {
268 fn new() -> Self {
269 PpmContext {
270 stats: Vec::new(),
271 total_freq: 0,
272 }
273 }
274
275 fn inherit_from(&mut self, parent: &PpmContext) {
278 assert!(parent.stats.len() > 0, "parent context has no stats");
280
281 self.stats.clear();
282 for st in &parent.stats {
283 let f = (st.freq / 2).max(1);
284 assert!(f >= 1, "inherited frequency dropped below 1");
285 self.stats.push(State {
286 symbol: st.symbol,
287 freq: f,
288 });
289 }
290 self.total_freq = self.stats.iter().map(|s| s.freq as u32).sum();
291 assert!(
292 self.total_freq > 0,
293 "total_freq must be positive after inherit"
294 );
295 }
296
297 fn get_cumulative(&self) -> (Vec<u8>, Vec<u32>, u32, u32) {
302 let c: u32 = self.stats.iter().map(|s| s.freq as u32).sum();
303 let q = self.stats.len() as u32;
304 let tot = 2 * c;
305 assert!(tot > 0, "total (2*C) must be positive");
306
307 let mut syms = Vec::with_capacity(self.stats.len());
308 let mut freqs = Vec::with_capacity(self.stats.len());
309 for st in &self.stats {
310 let f = (st.freq as u32) * 2 - 1;
311 assert!(f > 0, "computed symbol frequency must be positive");
312 syms.push(st.symbol);
313 freqs.push(f);
314 }
315 (syms, freqs, q, tot)
316 }
317
318 fn update_exclusion(&mut self, symbol: u8) {
321 let before = self.total_freq;
322 if let Some(st) = self.stats.iter_mut().find(|s| s.symbol == symbol) {
323 let new_freq = st.freq.saturating_add(1).min(MAX_FREQ);
324 assert!(new_freq >= st.freq, "freq must not decrease on bump");
325
326 st.freq = new_freq;
327 self.total_freq = self.stats.iter().map(|s| s.freq as u32).sum();
328 assert!(
329 self.total_freq >= before,
330 "total_freq must not shrink after update"
331 );
332 }
333 }
334}
335
336pub struct PpmModel {
365 max_order: u8,
366 contexts: HashMap<Vec<u8>, PpmContext>,
367}
368
369impl PpmModel {
370 pub fn new(max_order: u8) -> PpmResult<Self> {
376 assert!(
377 max_order > 0 && max_order <= 16,
378 "max_order out of valid range"
379 );
380 let mut m = PpmModel {
381 max_order,
382 contexts: HashMap::new(),
383 };
384 let mut root = PpmContext::new();
386 for sym in 0u8..=255 {
387 root.stats.push(State {
388 symbol: sym,
389 freq: 1,
390 });
391 }
392 root.total_freq = 256;
393 assert_eq!(root.stats.len(), 256, "root must contain all 256 symbols");
394 m.contexts.insert(Vec::new(), root);
395 Ok(m)
396 }
397
398 pub fn encode<R: Read, W: Write>(&mut self, mut input: R, output: W) -> PpmResult<W> {
401 let mut encoder = RangeEncoder::new(output);
402 let mut history = Vec::new();
403 let mut buf = [0u8; 1];
404
405 while input.read(&mut buf)? > 0 {
406 let sym = buf[0];
407 self.encode_symbol(&mut encoder, &history, sym)?;
408 self.update_model(&mut history, sym)?;
409 }
410 encoder.finish()
411 }
412
413 fn encode_symbol<W: Write>(
414 &self,
415 encoder: &mut RangeEncoder<W>,
416 history: &[u8],
417 symbol: u8,
418 ) -> PpmResult<()> {
419 assert!(history.len() <= self.max_order as usize, "history too long");
420
421 for order in (1..=self.max_order.min(history.len() as u8)).rev() {
423 let key = history[history.len() - order as usize..].to_vec();
424 if let Some(ctx) = self.contexts.get(&key) {
425 let (syms, freqs, esc, tot) = ctx.get_cumulative();
426 let mut cum = 0;
427 for (i, &s) in syms.iter().enumerate() {
429 if s == symbol {
430 return encoder.encode(cum, freqs[i], tot);
431 }
432 cum += freqs[i];
433 }
434 encoder.encode(cum, esc, tot)?;
436 }
437 }
438 let root = &self.contexts[&Vec::new()];
440 let tot0 = (root.stats.len() as u32) + 1;
441 if let Some(pos) = root.stats.iter().position(|s| s.symbol == symbol) {
442 encoder.encode(pos as u32, 1, tot0)
443 } else {
444 encoder.encode(root.stats.len() as u32, 1, tot0)
446 }
447 }
448
449 fn update_model(&mut self, history: &mut Vec<u8>, symbol: u8) -> PpmResult<()> {
451 let before = self.contexts.len();
452 let mut bumped = false;
454 for i in 0..history.len() {
455 let key = history[i..].to_vec();
456 if let Some(ctx) = self.contexts.get_mut(&key) {
457 if !bumped {
458 ctx.update_exclusion(symbol);
459 bumped = ctx.stats.iter().any(|s| s.symbol == symbol);
460 }
461 }
462 }
463 if !bumped {
465 let root = self.contexts.get_mut(&Vec::new()).unwrap();
466 root.stats.push(State { symbol, freq: 1 });
467 root.total_freq += 1;
468 }
469
470 history.push(symbol);
472 if history.len() > self.max_order as usize {
473 history.remove(0);
474 }
475 assert!(self.contexts.len() >= before, "contexts should not shrink");
476
477 let current_len = history.len();
479 let max_ctx = self.max_order.min(current_len as u8) as usize;
480 for order in 1..=max_ctx {
481 let key = history[current_len - order..].to_vec();
482 if !self.contexts.contains_key(&key) {
483 let parent_key = if key.len() > 1 {
485 key[1..].to_vec()
486 } else {
487 Vec::new()
488 };
489 let mut ctx = PpmContext::new();
490 if let Some(parent) = self.contexts.get(&parent_key) {
491 ctx.inherit_from(parent);
492 }
493 self.contexts.insert(key, ctx);
494 }
495 }
496
497 assert!(
498 history.len() <= self.max_order as usize,
499 "history exceeded max_order"
500 );
501 Ok(())
502 }
503
504 pub fn decode_symbol<R: Read>(
509 &mut self,
510 decoder: &mut RangeDecoder<R>,
511 history: &mut Vec<u8>,
512 out: &mut [u8],
513 ) -> PpmResult<()> {
514 assert!(out.len() == 1, "output buffer must be exactly one byte");
515 for order in (1..=self.max_order.min(history.len() as u8)).rev() {
517 let key = history[history.len() - order as usize..].to_vec();
518 if let Some(ctx) = self.contexts.get(&key) {
519 let (syms, freqs, esc, tot) = ctx.get_cumulative();
520 let threshold = tot.saturating_sub(esc);
521 let r = decoder.get_freq(tot)?;
522 if r < threshold {
523 let mut cum = 0;
525 for (i, &f) in freqs.iter().enumerate() {
526 if r < cum + f {
527 let sym = syms[i];
528 decoder.decode(cum, f, tot)?;
529 out[0] = sym;
530 self.update_model(history, sym)?;
531 return Ok(());
532 }
533 cum += f;
534 }
535 unreachable!();
536 } else {
537 decoder.decode(threshold, esc, tot)?;
539 }
540 }
541 }
542
543 let root = &self.contexts[&Vec::new()];
545 let tot0 = (root.stats.len() as u32) + 1;
546 let r0 = decoder.get_freq(tot0)?;
547 if r0 < root.stats.len() as u32 {
548 let sym = root.stats[r0 as usize].symbol;
549 decoder.decode(r0, 1, tot0)?;
550 out[0] = sym;
551 self.update_model(history, sym)?;
552 Ok(())
553 } else {
554 decoder.decode(root.stats.len() as u32, 1, tot0)?;
556 Err(PpmError::CorruptData)
557 }
558 }
559}
560
561pub fn encode_file<P: AsRef<Path>, Q: AsRef<Path>>(
573 input_path: P,
574 output_path: Q,
575 max_order: Option<usize>,
576) -> PpmResult<()> {
577 let input_path = input_path.as_ref();
578 let output_path = output_path.as_ref();
579
580 let input_file = File::open(input_path)?;
582 let input_len = input_file.metadata()?.len();
583
584 let mut output = File::create(output_path)?;
586 output.write_all(&input_len.to_le_bytes())?;
587
588 let mut input = File::open(input_path)?;
590
591 let order = max_order.unwrap_or(DEFAULT_ORDER as usize);
593 let mut model = PpmModel::new(order.try_into().unwrap())?;
594 model.encode(&mut input, &mut output)?;
595
596 Ok(())
597}
598
599pub fn decode_file<P: AsRef<Path>, Q: AsRef<Path>>(input_path: P, output_path: Q) -> PpmResult<()> {
610 let input_path = input_path.as_ref();
611 let output_path = output_path.as_ref();
612
613 let mut input = File::open(input_path)?;
614 let mut len_buf = [0u8; 8];
615 input.read_exact(&mut len_buf)?;
616 let expected = u64::from_le_bytes(len_buf);
617
618 let mut decoder = RangeDecoder::new(input)?;
619 let mut model = PpmModel::new(DEFAULT_ORDER)?;
620 let mut history = Vec::new();
621 let mut writer = BufWriter::new(File::create(output_path)?);
622
623 let mut buf = [0u8; 1];
624 let mut actual = 0;
625 while actual < expected {
626 model.decode_symbol(&mut decoder, &mut history, &mut buf)?;
627 writer.write_all(&buf)?;
628 actual += 1;
629 }
630
631 Ok(())
632}