Skip to main content

ppmd_core/
lib.rs

//! PPMd-style entropy coder for byte streams.
//!
//! This crate implements a Prediction by Partial Matching (PPM) compressor/decompressor
//! using a range encoder/decoder underneath.  PPM builds adaptive probability models
//! based on the last N bytes of context, where N is the "order."  Higher orders
//! give better compression at the cost of more memory and CPU; lower orders
//! run faster but yield larger output.
//!
//! # Key Parameters
//!
//! - `DEFAULT_ORDER: u8 = 5` — the default context length (order-5), a good middle
//!   ground between speed and compression.
//! - `MAX_FREQ: u8 = 124` — the maximum per-symbol frequency in any context; it caps
//!   frequencies to avoid overflow.
//! - `TOP: u32 = 1 << 24` and `BOT: u32 = 1 << 15` — thresholds used by the underlying
//!   range coder to renormalize its internal registers. You generally don't need to
//!   touch these unless you're tuning the coder itself.
18
19#![forbid(clippy::let_underscore_drop)]
20#![forbid(unsafe_code)]
21#![warn(clippy::unwrap_used)]
22#![warn(missing_docs)]
23
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::path::Path;
use thiserror::Error as ThisError;
29
/// Range-coder renormalization threshold: while `low` and `low + range` agree
/// in their top byte (`(low ^ (low + range)) < TOP`), that byte is shifted out.
pub const TOP: u32 = 1 << 24;
/// Minimum interval width the range coder keeps after renormalization.
///
/// Total frequencies passed to the coder must not exceed this value, otherwise
/// `range / tot_freq` can drop to zero.
pub const BOT: u32 = 1 << 15;
/// Maximum per-symbol frequency in any context; counts stop growing here.
pub const MAX_FREQ: u8 = 124;
/// Context order used when no explicit order is given, and always by `decode_file`.
pub const DEFAULT_ORDER: u8 = 5;
34
35/// A specialized `Result` using [`PpmError`] for errors.
36pub type PpmResult<T> = Result<T, PpmError>;
37
38/// The set of errors that can occur during PPM encoding or decoding.
39#[derive(ThisError, Debug)]
40pub enum PpmError {
41    /// Errors from any underlying I/O operations (file, stream, etc.).
42    #[error("IO error: {0}")]
43    IoError(#[from] io::Error),
44
45    /// The input data was unexpectedly corrupt (e.g., decoder sees an impossible symbol).
46    #[error("Corrupt input data")]
47    CorruptData,
48
49    /// The decoder was put into an invalid state (should not happen in normal use).
50    #[error("Invalid decoder state")]
51    InvalidState,
52
53    /// Errors in the PPM model itself (e.g., invalid parameters).
54    #[error("Model error: {0}")]
55    ModelError(&'static str),
56}
57
58/// Streaming range‐encoder for arithmetic coding.
59///
60/// The encoder maintains three key internal values:
61/// - `low`: the low end of the current coding interval  
62/// - `range`: the size of the current coding interval  
63/// - `buffer`: a byte buffer (flushed in 4 KB chunks) holding the high‐order output bytes
64///
65/// For each symbol you wish to encode, call [`encode`](#method.encode) with:
66/// 1. `cum_freq`: cumulative frequency of all symbols less than the one you’re encoding  
67/// 2. `freq`: the frequency of the symbol itself  
68/// 3. `tot_freq`: the total of all symbol frequencies in the current context  
69///
70/// After encoding every symbol, call [`finish`](#method.finish) to flush any remaining bytes
71/// and retrieve the underlying writer.
72///
73/// # Example
74///
75/// ```no_run
76/// use std::fs::File;
77/// use ppmd_core::{RangeEncoder, PpmResult};
78///
79/// fn encode_stream() -> PpmResult<()> {
80///     let file = File::create("out.ppm")?;
81///     let mut encoder = RangeEncoder::new(file);
82///
83///     // Suppose `model` yields (cum, freq, tot) triples for each byte:
84///     for (cum, freq, tot) in model.symbols() {
85///         encoder.encode(cum, freq, tot)?;
86///     }
87///
88///     // Finalize and get back the file writer
89///     let _file = encoder.finish()?;
90///     Ok(())
91/// }
92/// ```
93pub struct RangeEncoder<W: Write> {
94    low: u32,
95    range: u32,
96    buffer: Vec<u8>,
97    writer: W,
98}
99
100impl<W: Write> RangeEncoder<W> {
101    /// Create a new range encoder wrapping `writer`.
102    ///
103    /// Initializes:
104    /// - `low = 0`  
105    /// - `range = u32::MAX` (the full 32-bit interval)  
106    /// - an internal 4 KB output buffer  
107    ///
108    /// The encoder will emit one output byte at a time into `buffer`, flushing
109    /// to `writer` whenever the buffer is full.
110    pub fn new(writer: W) -> Self {
111        Self {
112            low: 0,
113            range: u32::MAX,
114            buffer: Vec::with_capacity(4096),
115            writer,
116        }
117    }
118
119    fn encode(&mut self, cum_freq: u32, freq: u32, tot_freq: u32) -> PpmResult<()> {
120        assert!(tot_freq > 0, "total frequency must be positive");
121        assert!(freq > 0, "symbol frequency must be positive");
122        assert!(cum_freq < tot_freq, "cumulative freq out of range");
123        assert!(cum_freq + freq <= tot_freq, "freq interval exceeds total");
124
125        self.range /= tot_freq;
126        self.low = self.low.wrapping_add(cum_freq * self.range);
127        self.range = self.range.wrapping_mul(freq);
128
129        while (self.low ^ (self.low.wrapping_add(self.range))) < TOP || self.range < BOT {
130            if self.range < BOT {
131                self.range = (-(self.low as i32) as u32) & (BOT - 1);
132            }
133            self.buffer.push((self.low >> 24) as u8);
134            self.low <<= 8;
135            self.range <<= 8;
136            if self.buffer.len() >= 4096 {
137                self.writer.write_all(&self.buffer)?;
138                self.buffer.clear();
139            }
140        }
141        Ok(())
142    }
143
144    fn finish(mut self) -> PpmResult<W> {
145        assert!(self.range > 0, "range became zero in finish");
146        for _ in 0..4 {
147            self.buffer.push((self.low >> 24) as u8);
148            self.low <<= 8;
149        }
150        assert!(!self.buffer.is_empty(), "nothing to flush in finish");
151        self.writer.write_all(&self.buffer)?;
152        Ok(self.writer)
153    }
154}
155
156/// Streaming range‐decoder for arithmetic coding.
157///
158/// The decoder maintains three internal registers:
159/// - `low`: the low end of the current coding interval  
160/// - `code`: the buffered input bits read from the stream  
161/// - `range`: the size of the current coding interval  
162///
163/// On each symbol decode, you first call [`get_freq`] to map the current code
164/// to a frequency within the total frequency range, then call [`decode`]
165/// to narrow the interval and consume bits as needed.
166///
167/// # Example
168///
169/// ```no_run
170/// use std::fs::File;
171/// use ppmd_core::{RangeDecoder, PpmModel, PpmResult};
172///
173/// fn decode_stream() -> PpmResult<()> {
174///     let file = File::open("data.ppm")?;
175///     let mut decoder = RangeDecoder::new(file)?;
176///     let mut model = PpmModel::new(5)?;
177///     let mut history = Vec::new();
178///     let mut out_byte = [0u8; 1];
179///
180///     // Repeatedly call `decode_symbol` until end of stream
181///     while model.decode_symbol(&mut decoder, &mut history, &mut out_byte).is_ok() {
182///         print!("{}", out_byte[0] as char);
183///     }
184///     Ok(())
185/// }
186/// ```
187pub struct RangeDecoder<R: Read> {
188    low: u32,
189    code: u32,
190    range: u32,
191    reader: R,
192    buffer: [u8; 1],
193}
194
195impl<R: Read> RangeDecoder<R> {
196    /// Initialize a new `RangeDecoder` by reading the first 4 bytes
197    /// from `reader` into the internal `code` register.
198    ///
199    /// The range decoder uses these first 32 bits as its starting buffer.
200    /// Subsequent calls to [`get_freq`] and [`decode`] will consume more
201    /// bytes from `reader` as needed to renormalize the interval.
202    ///
203    /// # Errors
204    ///
205    /// Returns `PpmError::IoError` if reading the initial 4-byte code prefix fails.
206    pub fn new(mut reader: R) -> PpmResult<Self> {
207        let mut code = 0;
208        for _ in 0..4 {
209            let mut byte = [0];
210            reader.read_exact(&mut byte)?;
211            code = (code << 8) | u32::from(byte[0]);
212        }
213        Ok(Self {
214            low: 0,
215            code,
216            range: u32::MAX,
217            reader,
218            buffer: [0],
219        })
220    }
221
222    fn get_freq(&mut self, tot_freq: u32) -> PpmResult<u32> {
223        assert!(tot_freq > 0, "total frequency must be positive");
224        self.range /= tot_freq;
225        let tmp = (self.code.wrapping_sub(self.low)) / self.range;
226        if tmp >= tot_freq {
227            return Err(PpmError::CorruptData);
228        }
229        Ok(tmp)
230    }
231
232    fn decode(&mut self, cum_freq: u32, freq: u32, tot_freq: u32) -> PpmResult<()> {
233        assert!(freq > 0, "frequency must be positive");
234        assert!(cum_freq < tot_freq, "cumulative freq out of range");
235        assert!(cum_freq + freq <= tot_freq, "freq interval exceeds total");
236
237        if cum_freq.wrapping_add(freq) > tot_freq {
238            return Err(PpmError::CorruptData);
239        }
240        self.low = self.low.wrapping_add(cum_freq * self.range);
241        self.range = self.range.wrapping_mul(freq);
242        while (self.low ^ (self.low.wrapping_add(self.range))) < TOP || self.range < BOT {
243            if self.range < BOT {
244                self.range = (-(self.low as i32) as u32) & (BOT - 1);
245            }
246            self.reader.read_exact(&mut self.buffer)?;
247            self.code = (self.code << 8) | u32::from(self.buffer[0]);
248            self.low <<= 8;
249            self.range <<= 8;
250        }
251        Ok(())
252    }
253}
254
255#[derive(Debug, Clone)]
256struct State {
257    symbol: u8,
258    freq: u8,
259}
260
261#[derive(Clone, Debug)]
262struct PpmContext {
263    stats: Vec<State>,
264    total_freq: u32,
265}
266
267impl PpmContext {
268    fn new() -> Self {
269        PpmContext {
270            stats: Vec::new(),
271            total_freq: 0,
272        }
273    }
274
275    /// PPMII "information inheritance":
276    /// copy each parent frequency as max(1, parent.freq/2)
277    fn inherit_from(&mut self, parent: &PpmContext) {
278        // parents stats should be non‐empty if called in contexts
279        assert!(parent.stats.len() > 0, "parent context has no stats");
280
281        self.stats.clear();
282        for st in &parent.stats {
283            let f = (st.freq / 2).max(1);
284            assert!(f >= 1, "inherited frequency dropped below 1");
285            self.stats.push(State {
286                symbol: st.symbol,
287                freq: f,
288            });
289        }
290        self.total_freq = self.stats.iter().map(|s| s.freq as u32).sum();
291        assert!(
292            self.total_freq > 0,
293            "total_freq must be positive after inherit"
294        );
295    }
296
297    /// PPMD escape probabilities:
298    ///   symbol fᵢ = 2·cᵢ − 1
299    ///   escape = q  (number of distinct symbols)
300    ///   tot    = 2·C
301    fn get_cumulative(&self) -> (Vec<u8>, Vec<u32>, u32, u32) {
302        let c: u32 = self.stats.iter().map(|s| s.freq as u32).sum();
303        let q = self.stats.len() as u32;
304        let tot = 2 * c;
305        assert!(tot > 0, "total (2*C) must be positive");
306
307        let mut syms = Vec::with_capacity(self.stats.len());
308        let mut freqs = Vec::with_capacity(self.stats.len());
309        for st in &self.stats {
310            let f = (st.freq as u32) * 2 - 1;
311            assert!(f > 0, "computed symbol frequency must be positive");
312            syms.push(st.symbol);
313            freqs.push(f);
314        }
315        (syms, freqs, q, tot)
316    }
317
318    /// Lazy exclusion: bump only the first (highest-order) context
319    /// that actually contained the symbol.
320    fn update_exclusion(&mut self, symbol: u8) {
321        let before = self.total_freq;
322        if let Some(st) = self.stats.iter_mut().find(|s| s.symbol == symbol) {
323            let new_freq = st.freq.saturating_add(1).min(MAX_FREQ);
324            assert!(new_freq >= st.freq, "freq must not decrease on bump");
325
326            st.freq = new_freq;
327            self.total_freq = self.stats.iter().map(|s| s.freq as u32).sum();
328            assert!(
329                self.total_freq >= before,
330                "total_freq must not shrink after update"
331            );
332        }
333    }
334}
335
336/// The central PPM model.  Maintains up to `max_order` contexts
337/// and dynamically updates symbol frequencies as you encode or decode.
338///
339/// Higher `max_order` (e.g. `Some(8)`) means the model looks at up to 8 previous
340/// bytes for each prediction:  
341/// - **Pros**: Better predictions and higher compression ratio  
342/// - **Cons**: More memory and CPU overhead  
343///
344/// Lower `max_order` (e.g. `None` → default order 5) is faster and lighter,
345/// but compresses less effectively.
346///
347/// # Examples
348///
349/// ```no_run
350/// use ppmd_core::{encode_file, decode_file, PpmResult};
351///
352/// fn main() -> PpmResult<()> {
353///     // Use default order = 5
354///     encode_file("input.bin", "out.ppm", None)?;
355///
356///     // Use a custom order = 8 for potentially better compression
357///     encode_file("input.bin", "out8.ppm", Some(8))?;
358///
359///     // Decode (always uses order 5)
360///     decode_file("out.ppm", "decoded.bin")?;
361///     Ok(())
362/// }
363/// ```
364pub struct PpmModel {
365    max_order: u8,
366    contexts: HashMap<Vec<u8>, PpmContext>,
367}
368
369impl PpmModel {
370    /// Create a new PPM model with contexts up to `max_order` (1..=16).
371    ///
372    /// # Panics
373    ///
374    /// Panics if `max_order == 0` or `max_order > 16`.
375    pub fn new(max_order: u8) -> PpmResult<Self> {
376        assert!(
377            max_order > 0 && max_order <= 16,
378            "max_order out of valid range"
379        );
380        let mut m = PpmModel {
381            max_order,
382            contexts: HashMap::new(),
383        };
384        // build the order−1 root context with uniform order‑0 stats
385        let mut root = PpmContext::new();
386        for sym in 0u8..=255 {
387            root.stats.push(State {
388                symbol: sym,
389                freq: 1,
390            });
391        }
392        root.total_freq = 256;
393        assert_eq!(root.stats.len(), 256, "root must contain all 256 symbols");
394        m.contexts.insert(Vec::new(), root);
395        Ok(m)
396    }
397
398    /// Encode the entire contents of `input` into `output`, updating the model
399    /// adaptively as you go.
400    pub fn encode<R: Read, W: Write>(&mut self, mut input: R, output: W) -> PpmResult<W> {
401        let mut encoder = RangeEncoder::new(output);
402        let mut history = Vec::new();
403        let mut buf = [0u8; 1];
404
405        while input.read(&mut buf)? > 0 {
406            let sym = buf[0];
407            self.encode_symbol(&mut encoder, &history, sym)?;
408            self.update_model(&mut history, sym)?;
409        }
410        encoder.finish()
411    }
412
413    fn encode_symbol<W: Write>(
414        &self,
415        encoder: &mut RangeEncoder<W>,
416        history: &[u8],
417        symbol: u8,
418    ) -> PpmResult<()> {
419        assert!(history.len() <= self.max_order as usize, "history too long");
420
421        // back‑off from highest order down to order−1:
422        for order in (1..=self.max_order.min(history.len() as u8)).rev() {
423            let key = history[history.len() - order as usize..].to_vec();
424            if let Some(ctx) = self.contexts.get(&key) {
425                let (syms, freqs, esc, tot) = ctx.get_cumulative();
426                let mut cum = 0;
427                // if symbol found in this context, emit it
428                for (i, &s) in syms.iter().enumerate() {
429                    if s == symbol {
430                        return encoder.encode(cum, freqs[i], tot);
431                    }
432                    cum += freqs[i];
433                }
434                // otherwise emit escape
435                encoder.encode(cum, esc, tot)?;
436            }
437        }
438        // final fallback at order−1 root (uniform)
439        let root = &self.contexts[&Vec::new()];
440        let tot0 = (root.stats.len() as u32) + 1;
441        if let Some(pos) = root.stats.iter().position(|s| s.symbol == symbol) {
442            encoder.encode(pos as u32, 1, tot0)
443        } else {
444            // one final escape (should never really happen)
445            encoder.encode(root.stats.len() as u32, 1, tot0)
446        }
447    }
448
449    /// After emitting symbol, update ALL contexts up to max_order
450    fn update_model(&mut self, history: &mut Vec<u8>, symbol: u8) -> PpmResult<()> {
451        let before = self.contexts.len();
452        // lazy‐exclusion update on the longest suffix that contained the symbol
453        let mut bumped = false;
454        for i in 0..history.len() {
455            let key = history[i..].to_vec();
456            if let Some(ctx) = self.contexts.get_mut(&key) {
457                if !bumped {
458                    ctx.update_exclusion(symbol);
459                    bumped = ctx.stats.iter().any(|s| s.symbol == symbol);
460                }
461            }
462        }
463        // if it never existed, add to the root order‑0
464        if !bumped {
465            let root = self.contexts.get_mut(&Vec::new()).unwrap();
466            root.stats.push(State { symbol, freq: 1 });
467            root.total_freq += 1;
468        }
469
470        // Slide the history window
471        history.push(symbol);
472        if history.len() > self.max_order as usize {
473            history.remove(0);
474        }
475        assert!(self.contexts.len() >= before, "contexts should not shrink");
476
477        // Create or inherit every missing context suffix up to max_order
478        let current_len = history.len();
479        let max_ctx = self.max_order.min(current_len as u8) as usize;
480        for order in 1..=max_ctx {
481            let key = history[current_len - order..].to_vec();
482            if !self.contexts.contains_key(&key) {
483                // build from the (order−1) parent:
484                let parent_key = if key.len() > 1 {
485                    key[1..].to_vec()
486                } else {
487                    Vec::new()
488                };
489                let mut ctx = PpmContext::new();
490                if let Some(parent) = self.contexts.get(&parent_key) {
491                    ctx.inherit_from(parent);
492                }
493                self.contexts.insert(key, ctx);
494            }
495        }
496
497        assert!(
498            history.len() <= self.max_order as usize,
499            "history exceeded max_order"
500        );
501        Ok(())
502    }
503
504    /// Decode one symbol at a time from `decoder`, writing to `out`,
505    /// and update the model adaptively.
506    ///
507    /// This is used by `decode_file` under the hood.
508    pub fn decode_symbol<R: Read>(
509        &mut self,
510        decoder: &mut RangeDecoder<R>,
511        history: &mut Vec<u8>,
512        out: &mut [u8],
513    ) -> PpmResult<()> {
514        assert!(out.len() == 1, "output buffer must be exactly one byte");
515        // back‑off decode:
516        for order in (1..=self.max_order.min(history.len() as u8)).rev() {
517            let key = history[history.len() - order as usize..].to_vec();
518            if let Some(ctx) = self.contexts.get(&key) {
519                let (syms, freqs, esc, tot) = ctx.get_cumulative();
520                let threshold = tot.saturating_sub(esc);
521                let r = decoder.get_freq(tot)?;
522                if r < threshold {
523                    // actual symbol
524                    let mut cum = 0;
525                    for (i, &f) in freqs.iter().enumerate() {
526                        if r < cum + f {
527                            let sym = syms[i];
528                            decoder.decode(cum, f, tot)?;
529                            out[0] = sym;
530                            self.update_model(history, sym)?;
531                            return Ok(());
532                        }
533                        cum += f;
534                    }
535                    unreachable!();
536                } else {
537                    // escape
538                    decoder.decode(threshold, esc, tot)?;
539                }
540            }
541        }
542
543        // root fallback:
544        let root = &self.contexts[&Vec::new()];
545        let tot0 = (root.stats.len() as u32) + 1;
546        let r0 = decoder.get_freq(tot0)?;
547        if r0 < root.stats.len() as u32 {
548            let sym = root.stats[r0 as usize].symbol;
549            decoder.decode(r0, 1, tot0)?;
550            out[0] = sym;
551            self.update_model(history, sym)?;
552            Ok(())
553        } else {
554            // end of stream
555            decoder.decode(root.stats.len() as u32, 1, tot0)?;
556            Err(PpmError::CorruptData)
557        }
558    }
559}
560
561/// Compress the file at `input_path` into `output_path` using PPM.
562///  
563/// - `max_order = None` ⇒ uses the crate’s `DEFAULT_ORDER` (5).  
564/// - `max_order = Some(n)` ⇒ uses order n (up to 16).  
565///
566/// By default, we first write an 8-byte little-endian prefix giving the
567/// original file length, then the PPM‐encoded payload.
568///
569/// # Errors
570///
571/// Returns an error if any I/O or encoding step fails.
572pub fn encode_file<P: AsRef<Path>, Q: AsRef<Path>>(
573    input_path: P,
574    output_path: Q,
575    max_order: Option<usize>,
576) -> PpmResult<()> {
577    let input_path = input_path.as_ref();
578    let output_path = output_path.as_ref();
579
580    // Step 1: determine original input length
581    let input_file = File::open(input_path)?;
582    let input_len = input_file.metadata()?.len();
583
584    // Step 2: open output and write length prefix
585    let mut output = File::create(output_path)?;
586    output.write_all(&input_len.to_le_bytes())?;
587
588    // Step 3: re-open input for reading (we already used it for metadata)
589    let mut input = File::open(input_path)?;
590
591    // Step 4: encode using PPM
592    let order = max_order.unwrap_or(DEFAULT_ORDER as usize);
593    let mut model = PpmModel::new(order.try_into().unwrap())?;
594    model.encode(&mut input, &mut output)?;
595
596    Ok(())
597}
598
599/// Decompress `input_path` (which must have been produced by `encode_file`)
600/// back into `output_path`, using the default `DEFAULT_ORDER = 5`.
601///
602/// Reads the 8-byte length prefix, then decodes exactly that many bytes
603/// via the range decoder + PPM model.
604///
605/// # Errors
606///
607/// Returns an error if any I/O or decoding step fails, or if the input
608/// is corrupt.
609pub fn decode_file<P: AsRef<Path>, Q: AsRef<Path>>(input_path: P, output_path: Q) -> PpmResult<()> {
610    let input_path = input_path.as_ref();
611    let output_path = output_path.as_ref();
612
613    let mut input = File::open(input_path)?;
614    let mut len_buf = [0u8; 8];
615    input.read_exact(&mut len_buf)?;
616    let expected = u64::from_le_bytes(len_buf);
617
618    let mut decoder = RangeDecoder::new(input)?;
619    let mut model = PpmModel::new(DEFAULT_ORDER)?;
620    let mut history = Vec::new();
621    let mut writer = BufWriter::new(File::create(output_path)?);
622
623    let mut buf = [0u8; 1];
624    let mut actual = 0;
625    while actual < expected {
626        model.decode_symbol(&mut decoder, &mut history, &mut buf)?;
627        writer.write_all(&buf)?;
628        actual += 1;
629    }
630
631    Ok(())
632}