// nodedb_codec/rans.rs
//! Interleaved rANS (Asymmetric Numeral Systems) entropy coder.
//!
//! Compresses byte streams to the Shannon entropy limit — optimal
//! compression ratio at Huffman-like speed. Used as the terminal
//! compressor for cold/S3 tier partitions where ratio matters more
//! than decompression speed.
//!
//! 4-stream interleaving breaks the sequential dependency chain:
//! the CPU processes all streams simultaneously, achieving high
//! throughput despite the inherently sequential nature of ANS.
//!
//! Wire format:
//! ```text
//! [4 bytes] uncompressed size (LE u32)
//! [256 × 4 bytes] frequency table (LE u32 per byte value)
//! [4 bytes] compressed size (LE u32)
//! [4 × 4 bytes] per-stream sizes (LE u32 each)
//! [N bytes] interleaved rANS bitstream (4 streams, each read back-to-front)
//! ```
use crate::error::CodecError;

/// Number of interleaved rANS streams; stream `k` handles input bytes
/// `k, k + 4, k + 8, …` so the CPU can advance four states in parallel.
const NUM_STREAMS: usize = 4;

/// rANS probability scale (power of 2 for fast division).
const PROB_BITS: u32 = 14;
const PROB_SCALE: u32 = 1 << PROB_BITS;

/// rANS state lower bound (renormalization keeps states in [RANS_L, RANS_L << 8)).
const RANS_L: u32 = 1 << 23;

/// Frequency table header size: 256 × 4 bytes = 1024 bytes.
const FREQ_TABLE_SIZE: usize = 256 * 4;

/// Full header: 4 (uncomp size) + 1024 (freq table) + 4 (comp size).
const HEADER_SIZE: usize = 4 + FREQ_TABLE_SIZE + 4;
37
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------
41
42/// Compress bytes using interleaved rANS.
43pub fn encode(data: &[u8]) -> Vec<u8> {
44    if data.is_empty() {
45        let out = vec![0u8; HEADER_SIZE];
46        // uncompressed_size = 0, freq table = all zeros, compressed_size = 0
47        return out;
48    }
49
50    // Build frequency table.
51    let mut freqs = [0u32; 256];
52    for &b in data {
53        freqs[b as usize] += 1;
54    }
55
56    // Normalize frequencies to sum to PROB_SCALE.
57    let norm_freqs = normalize_frequencies(&freqs, data.len());
58
59    // Build cumulative frequency table.
60    let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);
61
62    // Encode using 4 interleaved streams.
63    // Each stream processes every 4th byte: stream 0 gets bytes 0,4,8,...
64    let mut streams: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
65    let mut states = [RANS_L; NUM_STREAMS];
66
67    // Encode in REVERSE order (rANS encodes backward, decodes forward).
68    for i in (0..data.len()).rev() {
69        let stream_idx = i % NUM_STREAMS;
70        let sym = data[i] as usize;
71        let freq = sym_freqs[sym];
72        let start = cum_freqs[sym];
73
74        if freq == 0 {
75            continue; // Symbol with zero frequency — shouldn't happen after normalization.
76        }
77
78        rans_encode_symbol(
79            &mut states[stream_idx],
80            &mut streams[stream_idx],
81            start,
82            freq,
83        );
84    }
85
86    // Flush final states.
87    for i in 0..NUM_STREAMS {
88        let s = states[i];
89        streams[i].push((s & 0xFF) as u8);
90        streams[i].push(((s >> 8) & 0xFF) as u8);
91        streams[i].push(((s >> 16) & 0xFF) as u8);
92        streams[i].push(((s >> 24) & 0xFF) as u8);
93    }
94
95    // Build output.
96    let total_compressed: usize = streams.iter().map(|s| s.len()).sum();
97    let mut out = Vec::with_capacity(HEADER_SIZE + total_compressed + NUM_STREAMS * 4);
98
99    // Header: uncompressed size.
100    out.extend_from_slice(&(data.len() as u32).to_le_bytes());
101
102    // Frequency table (raw counts for decoding).
103    for &f in &norm_freqs {
104        out.extend_from_slice(&f.to_le_bytes());
105    }
106
107    // Compressed size.
108    let comp_payload_size = total_compressed + NUM_STREAMS * 4; // streams + per-stream sizes
109    out.extend_from_slice(&(comp_payload_size as u32).to_le_bytes());
110
111    // Per-stream sizes (4 bytes each).
112    for s in &streams {
113        out.extend_from_slice(&(s.len() as u32).to_le_bytes());
114    }
115
116    // Stream data (reversed — rANS bitstream is read backward).
117    for s in &streams {
118        out.extend_from_slice(s);
119    }
120
121    out
122}
123
124/// Decompress interleaved rANS data.
125pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
126    if data.len() < HEADER_SIZE {
127        return Err(CodecError::Truncated {
128            expected: HEADER_SIZE,
129            actual: data.len(),
130        });
131    }
132
133    let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
134    if uncompressed_size == 0 {
135        return Ok(Vec::new());
136    }
137
138    // Read frequency table.
139    let mut norm_freqs = [0u32; 256];
140    for (i, freq) in norm_freqs.iter_mut().enumerate() {
141        let pos = 4 + i * 4;
142        *freq = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
143    }
144
145    let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);
146
147    // Build reverse lookup table for decoding.
148    let lookup = build_decode_table(&cum_freqs, &sym_freqs);
149
150    let _comp_size = u32::from_le_bytes([
151        data[HEADER_SIZE - 4],
152        data[HEADER_SIZE - 3],
153        data[HEADER_SIZE - 2],
154        data[HEADER_SIZE - 1],
155    ]) as usize;
156
157    // Read per-stream sizes.
158    let mut pos = HEADER_SIZE;
159    if pos + NUM_STREAMS * 4 > data.len() {
160        return Err(CodecError::Truncated {
161            expected: pos + NUM_STREAMS * 4,
162            actual: data.len(),
163        });
164    }
165
166    let mut stream_sizes = [0usize; NUM_STREAMS];
167    for size in stream_sizes.iter_mut() {
168        *size =
169            u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
170        pos += 4;
171    }
172
173    // Read streams.
174    let mut stream_data: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
175    for i in 0..NUM_STREAMS {
176        let end = pos + stream_sizes[i];
177        if end > data.len() {
178            return Err(CodecError::Truncated {
179                expected: end,
180                actual: data.len(),
181            });
182        }
183        stream_data[i] = data[pos..end].to_vec();
184        pos += stream_sizes[i];
185    }
186
187    // Initialize states from the end of each stream.
188    let mut states = [0u32; NUM_STREAMS];
189    let mut stream_pos = [0usize; NUM_STREAMS];
190    for i in 0..NUM_STREAMS {
191        let s = &stream_data[i];
192        if s.len() < 4 {
193            return Err(CodecError::Corrupt {
194                detail: format!("rANS stream {i} too short for state"),
195            });
196        }
197        let end = s.len();
198        states[i] = u32::from_le_bytes([s[end - 4], s[end - 3], s[end - 2], s[end - 1]]);
199        stream_pos[i] = end - 4;
200    }
201
202    // Decode forward.
203    let mut output = vec![0u8; uncompressed_size];
204    for (i, out_byte) in output.iter_mut().enumerate() {
205        let stream_idx = i % NUM_STREAMS;
206        let (sym, new_state) =
207            rans_decode_symbol(states[stream_idx], &lookup, &cum_freqs, &sym_freqs);
208        *out_byte = sym;
209        states[stream_idx] = rans_decode_renorm(
210            new_state,
211            &stream_data[stream_idx],
212            &mut stream_pos[stream_idx],
213        );
214    }
215
216    Ok(output)
217}
218
// ---------------------------------------------------------------------------
// rANS core operations
// ---------------------------------------------------------------------------
222
223fn rans_encode_symbol(state: &mut u32, bitstream: &mut Vec<u8>, start: u32, freq: u32) {
224    // Renormalize: output bytes until state is in the correct range.
225    let max_state = ((RANS_L >> PROB_BITS) << 8) * freq;
226    while *state >= max_state {
227        bitstream.push((*state & 0xFF) as u8);
228        *state >>= 8;
229    }
230
231    // Encode symbol.
232    *state = ((*state / freq) << PROB_BITS) + (*state % freq) + start;
233}
234
235fn rans_decode_symbol(
236    state: u32,
237    lookup: &[u8; PROB_SCALE as usize],
238    cum_freqs: &[u32; 257],
239    sym_freqs: &[u32; 256],
240) -> (u8, u32) {
241    let slot = state & (PROB_SCALE - 1);
242    let sym = lookup[slot as usize];
243    let start = cum_freqs[sym as usize];
244    let freq = sym_freqs[sym as usize];
245
246    let new_state = freq * (state >> PROB_BITS) + slot - start;
247    (sym, new_state)
248}
249
250fn rans_decode_renorm(mut state: u32, stream: &[u8], pos: &mut usize) -> u32 {
251    while state < RANS_L && *pos > 0 {
252        *pos -= 1;
253        state = (state << 8) | stream[*pos] as u32;
254    }
255    state
256}
257
// ---------------------------------------------------------------------------
// Frequency table operations
// ---------------------------------------------------------------------------
261
262/// Normalize raw frequencies to sum to PROB_SCALE.
263fn normalize_frequencies(freqs: &[u32; 256], total: usize) -> [u32; 256] {
264    let mut norm = [0u32; 256];
265    let mut sum = 0u32;
266    let total_f64 = total as f64;
267
268    // First pass: proportional scaling.
269    for i in 0..256 {
270        if freqs[i] > 0 {
271            // Ensure every present symbol gets at least frequency 1.
272            norm[i] = ((freqs[i] as f64 / total_f64 * PROB_SCALE as f64).round() as u32).max(1);
273            sum += norm[i];
274        }
275    }
276
277    // Adjust to make sum exactly PROB_SCALE.
278    if sum > 0 {
279        while sum > PROB_SCALE {
280            // Find the symbol with the highest frequency and reduce it.
281            let max_idx = norm
282                .iter()
283                .enumerate()
284                .filter(|(_, f)| **f > 1)
285                .max_by_key(|(_, f)| **f)
286                .map(|(i, _)| i)
287                .unwrap_or(0);
288            norm[max_idx] -= 1;
289            sum -= 1;
290        }
291        while sum < PROB_SCALE {
292            let max_idx = norm
293                .iter()
294                .enumerate()
295                .max_by_key(|(_, f)| **f)
296                .map(|(i, _)| i)
297                .unwrap_or(0);
298            norm[max_idx] += 1;
299            sum += 1;
300        }
301    }
302
303    norm
304}
305
/// Build the cumulative frequency table: `cum[i]` is the sum of all
/// frequencies of symbols below `i`, so `cum[256]` is the grand total
/// and symbol `s` owns slots `[cum[s], cum[s + 1])`.
///
/// The per-symbol frequencies are returned alongside for convenience.
fn build_cum_table(freqs: &[u32; 256]) -> ([u32; 257], [u32; 256]) {
    let mut cum = [0u32; 257];
    let mut running = 0u32;
    // cum[0] stays 0; each later entry is the running prefix sum.
    for (slot, &f) in cum.iter_mut().skip(1).zip(freqs.iter()) {
        running += f;
        *slot = running;
    }
    (cum, *freqs)
}
315
316/// Build decode lookup table: for each slot in [0, PROB_SCALE), which symbol?
317fn build_decode_table(
318    cum_freqs: &[u32; 257],
319    _sym_freqs: &[u32; 256],
320) -> [u8; PROB_SCALE as usize] {
321    let mut table = [0u8; PROB_SCALE as usize];
322    for sym in 0..256u16 {
323        let start = cum_freqs[sym as usize] as usize;
324        let end = cum_freqs[sym as usize + 1] as usize;
325        for entry in table[start..end].iter_mut() {
326            *entry = sym as u8;
327        }
328    }
329    table
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_roundtrip() {
        let encoded = encode(&[]);
        let decoded = decode(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn single_byte() {
        let encoded = encode(&[42]);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, vec![42]);
    }

    #[test]
    fn repeated_bytes() {
        let data = vec![0u8; 10_000];
        let encoded = encode(&data);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);

        // Highly repetitive → near-zero entropy → excellent compression.
        let ratio = data.len() as f64 / encoded.len() as f64;
        assert!(
            ratio > 2.0,
            "repeated bytes should compress >2x, got {ratio:.1}x"
        );
    }

    #[test]
    fn text_data() {
        let text = b"the quick brown fox jumps over the lazy dog. ";
        let data: Vec<u8> = text.iter().copied().cycle().take(10_000).collect();
        let encoded = encode(&data);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);

        let ratio = data.len() as f64 / encoded.len() as f64;
        assert!(ratio > 1.5, "text should compress >1.5x, got {ratio:.1}x");
    }

    #[test]
    fn uniform_random_data() {
        // Uniform random → ~8 bits/byte → no compression possible.
        // Deterministic LCG so the test is reproducible.
        let mut data = vec![0u8; 5000];
        let mut rng: u64 = 12345;
        for byte in &mut data {
            rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
            *byte = (rng >> 33) as u8;
        }
        let encoded = encode(&data);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn all_byte_values() {
        // All 256 byte values present.
        let data: Vec<u8> = (0..=255u8).cycle().take(4096).collect();
        let encoded = encode(&data);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn skewed_distribution() {
        // 90% zeros, 10% ones — should compress well.
        let mut data = vec![0u8; 10_000];
        for i in 0..1000 {
            data[i * 10] = 1;
        }
        let encoded = encode(&data);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);

        let ratio = data.len() as f64 / encoded.len() as f64;
        assert!(
            ratio > 1.5,
            "skewed data should compress >1.5x, got {ratio:.1}x"
        );
    }

    #[test]
    fn better_than_raw_on_structured() {
        // Structured data after type-aware preprocessing (typical pipeline output).
        let mut data = Vec::with_capacity(10_000);
        for i in 0..10_000 {
            data.push((i % 16) as u8); // Low entropy, 4 bits/byte → 2x compression.
        }
        let encoded = encode(&data);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);

        let ratio = data.len() as f64 / encoded.len() as f64;
        assert!(
            ratio > 1.5,
            "low-entropy data should compress >1.5x, got {ratio:.1}x"
        );
    }

    #[test]
    fn truncated_input_errors() {
        assert!(decode(&[]).is_err());
        assert!(decode(&[1, 0, 0, 0]).is_err()); // too short for freq table
    }
}
441}