// nodedb_codec/rans.rs
// SPDX-License-Identifier: Apache-2.0

//! Interleaved rANS (Asymmetric Numeral Systems) entropy coder.
//!
//! Compresses byte streams close to the Shannon entropy limit — near-optimal
//! compression ratio at Huffman-like speed. Used as the terminal
//! compressor for cold/S3 tier partitions where ratio matters more
//! than decompression speed.
//!
//! 4-stream interleaving breaks the sequential dependency chain:
//! the CPU processes all streams simultaneously, achieving high
//! throughput despite the inherently sequential nature of ANS.
//!
//! Wire format:
//! ```text
//! [4 bytes] uncompressed size (LE u32)
//! [256 × 4 bytes] normalized frequency table (LE u32 per byte value)
//! [4 bytes] compressed payload size (LE u32)
//! [4 × 4 bytes] per-stream sizes (LE u32 each)
//! [N bytes] concatenated stream data (each stream is read back-to-front)
//! ```
22use crate::error::CodecError;
23
24/// Number of interleaved streams.
25const NUM_STREAMS: usize = 4;
26
27/// rANS probability scale (power of 2 for fast division).
28const PROB_BITS: u32 = 14;
29const PROB_SCALE: u32 = 1 << PROB_BITS;
30
31/// rANS state lower bound.
32const RANS_L: u32 = 1 << 23;
33
34/// Frequency table header size: 256 × 4 bytes = 1024 bytes.
35const FREQ_TABLE_SIZE: usize = 256 * 4;
36
37/// Full header: 4 (uncomp size) + 1024 (freq table) + 4 (comp size).
38const HEADER_SIZE: usize = 4 + FREQ_TABLE_SIZE + 4;
39
40// ---------------------------------------------------------------------------
41// Public API
42// ---------------------------------------------------------------------------
43
44/// Compress bytes using interleaved rANS.
45pub fn encode(data: &[u8]) -> Vec<u8> {
46    if data.is_empty() {
47        let out = vec![0u8; HEADER_SIZE];
48        // uncompressed_size = 0, freq table = all zeros, compressed_size = 0
49        return out;
50    }
51
52    // Build frequency table.
53    let mut freqs = [0u32; 256];
54    for &b in data {
55        freqs[b as usize] += 1;
56    }
57
58    // Normalize frequencies to sum to PROB_SCALE.
59    let norm_freqs = normalize_frequencies(&freqs, data.len());
60
61    // Build cumulative frequency table.
62    let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);
63
64    // Encode using 4 interleaved streams.
65    // Each stream processes every 4th byte: stream 0 gets bytes 0,4,8,...
66    let mut streams: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
67    let mut states = [RANS_L; NUM_STREAMS];
68
69    // Encode in REVERSE order (rANS encodes backward, decodes forward).
70    for i in (0..data.len()).rev() {
71        let stream_idx = i % NUM_STREAMS;
72        let sym = data[i] as usize;
73        let freq = sym_freqs[sym];
74        let start = cum_freqs[sym];
75
76        if freq == 0 {
77            continue; // Symbol with zero frequency — shouldn't happen after normalization.
78        }
79
80        rans_encode_symbol(
81            &mut states[stream_idx],
82            &mut streams[stream_idx],
83            start,
84            freq,
85        );
86    }
87
88    // Flush final states.
89    for i in 0..NUM_STREAMS {
90        let s = states[i];
91        streams[i].push((s & 0xFF) as u8);
92        streams[i].push(((s >> 8) & 0xFF) as u8);
93        streams[i].push(((s >> 16) & 0xFF) as u8);
94        streams[i].push(((s >> 24) & 0xFF) as u8);
95    }
96
97    // Build output.
98    let total_compressed: usize = streams.iter().map(|s| s.len()).sum();
99    let mut out = Vec::with_capacity(HEADER_SIZE + total_compressed + NUM_STREAMS * 4);
100
101    // Header: uncompressed size.
102    out.extend_from_slice(&(data.len() as u32).to_le_bytes());
103
104    // Frequency table (raw counts for decoding).
105    for &f in &norm_freqs {
106        out.extend_from_slice(&f.to_le_bytes());
107    }
108
109    // Compressed size.
110    let comp_payload_size = total_compressed + NUM_STREAMS * 4; // streams + per-stream sizes
111    out.extend_from_slice(&(comp_payload_size as u32).to_le_bytes());
112
113    // Per-stream sizes (4 bytes each).
114    for s in &streams {
115        out.extend_from_slice(&(s.len() as u32).to_le_bytes());
116    }
117
118    // Stream data (reversed — rANS bitstream is read backward).
119    for s in &streams {
120        out.extend_from_slice(s);
121    }
122
123    out
124}
125
126/// Decompress interleaved rANS data.
127pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
128    if data.len() < HEADER_SIZE {
129        return Err(CodecError::Truncated {
130            expected: HEADER_SIZE,
131            actual: data.len(),
132        });
133    }
134
135    let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
136    if uncompressed_size == 0 {
137        return Ok(Vec::new());
138    }
139
140    // Read frequency table.
141    let mut norm_freqs = [0u32; 256];
142    for (i, freq) in norm_freqs.iter_mut().enumerate() {
143        let pos = 4 + i * 4;
144        *freq = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
145    }
146
147    let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);
148
149    // Build reverse lookup table for decoding.
150    let lookup = build_decode_table(&cum_freqs, &sym_freqs);
151
152    let _comp_size = u32::from_le_bytes([
153        data[HEADER_SIZE - 4],
154        data[HEADER_SIZE - 3],
155        data[HEADER_SIZE - 2],
156        data[HEADER_SIZE - 1],
157    ]) as usize;
158
159    // Read per-stream sizes.
160    let mut pos = HEADER_SIZE;
161    if pos + NUM_STREAMS * 4 > data.len() {
162        return Err(CodecError::Truncated {
163            expected: pos + NUM_STREAMS * 4,
164            actual: data.len(),
165        });
166    }
167
168    let mut stream_sizes = [0usize; NUM_STREAMS];
169    for size in stream_sizes.iter_mut() {
170        *size =
171            u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
172        pos += 4;
173    }
174
175    // Read streams.
176    let mut stream_data: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
177    for i in 0..NUM_STREAMS {
178        let end = pos + stream_sizes[i];
179        if end > data.len() {
180            return Err(CodecError::Truncated {
181                expected: end,
182                actual: data.len(),
183            });
184        }
185        stream_data[i] = data[pos..end].to_vec();
186        pos += stream_sizes[i];
187    }
188
189    // Initialize states from the end of each stream.
190    let mut states = [0u32; NUM_STREAMS];
191    let mut stream_pos = [0usize; NUM_STREAMS];
192    for i in 0..NUM_STREAMS {
193        let s = &stream_data[i];
194        if s.len() < 4 {
195            return Err(CodecError::Corrupt {
196                detail: format!("rANS stream {i} too short for state"),
197            });
198        }
199        let end = s.len();
200        states[i] = u32::from_le_bytes([s[end - 4], s[end - 3], s[end - 2], s[end - 1]]);
201        stream_pos[i] = end - 4;
202    }
203
204    // Decode forward.
205    let mut output = vec![0u8; uncompressed_size];
206    for (i, out_byte) in output.iter_mut().enumerate() {
207        let stream_idx = i % NUM_STREAMS;
208        let (sym, new_state) =
209            rans_decode_symbol(states[stream_idx], &lookup, &cum_freqs, &sym_freqs);
210        *out_byte = sym;
211        states[stream_idx] = rans_decode_renorm(
212            new_state,
213            &stream_data[stream_idx],
214            &mut stream_pos[stream_idx],
215        );
216    }
217
218    Ok(output)
219}
220
221// ---------------------------------------------------------------------------
222// rANS core operations
223// ---------------------------------------------------------------------------
224
225fn rans_encode_symbol(state: &mut u32, bitstream: &mut Vec<u8>, start: u32, freq: u32) {
226    // Renormalize: output bytes until state is in the correct range.
227    let max_state = ((RANS_L >> PROB_BITS) << 8) * freq;
228    while *state >= max_state {
229        bitstream.push((*state & 0xFF) as u8);
230        *state >>= 8;
231    }
232
233    // Encode symbol.
234    *state = ((*state / freq) << PROB_BITS) + (*state % freq) + start;
235}
236
237fn rans_decode_symbol(
238    state: u32,
239    lookup: &[u8; PROB_SCALE as usize],
240    cum_freqs: &[u32; 257],
241    sym_freqs: &[u32; 256],
242) -> (u8, u32) {
243    let slot = state & (PROB_SCALE - 1);
244    let sym = lookup[slot as usize];
245    let start = cum_freqs[sym as usize];
246    let freq = sym_freqs[sym as usize];
247
248    let new_state = freq * (state >> PROB_BITS) + slot - start;
249    (sym, new_state)
250}
251
252fn rans_decode_renorm(mut state: u32, stream: &[u8], pos: &mut usize) -> u32 {
253    while state < RANS_L && *pos > 0 {
254        *pos -= 1;
255        state = (state << 8) | stream[*pos] as u32;
256    }
257    state
258}
259
260// ---------------------------------------------------------------------------
261// Frequency table operations
262// ---------------------------------------------------------------------------
263
264/// Normalize raw frequencies to sum to PROB_SCALE.
265fn normalize_frequencies(freqs: &[u32; 256], total: usize) -> [u32; 256] {
266    let mut norm = [0u32; 256];
267    let mut sum = 0u32;
268    let total_f64 = total as f64;
269
270    // First pass: proportional scaling.
271    for i in 0..256 {
272        if freqs[i] > 0 {
273            // Ensure every present symbol gets at least frequency 1.
274            norm[i] = ((freqs[i] as f64 / total_f64 * PROB_SCALE as f64).round() as u32).max(1);
275            sum += norm[i];
276        }
277    }
278
279    // Adjust to make sum exactly PROB_SCALE.
280    if sum > 0 {
281        while sum > PROB_SCALE {
282            // Find the symbol with the highest frequency and reduce it.
283            let max_idx = norm
284                .iter()
285                .enumerate()
286                .filter(|(_, f)| **f > 1)
287                .max_by_key(|(_, f)| **f)
288                .map(|(i, _)| i)
289                .unwrap_or(0);
290            norm[max_idx] -= 1;
291            sum -= 1;
292        }
293        while sum < PROB_SCALE {
294            let max_idx = norm
295                .iter()
296                .enumerate()
297                .max_by_key(|(_, f)| **f)
298                .map(|(i, _)| i)
299                .unwrap_or(0);
300            norm[max_idx] += 1;
301            sum += 1;
302        }
303    }
304
305    norm
306}
307
/// Build the exclusive prefix-sum table — `cum[s]..cum[s + 1]` is symbol
/// `s`'s slot range — along with a copy of the per-symbol frequencies.
fn build_cum_table(freqs: &[u32; 256]) -> ([u32; 257], [u32; 256]) {
    let mut cum = [0u32; 257];
    let mut running = 0u32;
    for (slot, &f) in cum.iter_mut().skip(1).zip(freqs.iter()) {
        running += f;
        *slot = running;
    }
    (cum, *freqs)
}
317
318/// Build decode lookup table: for each slot in [0, PROB_SCALE), which symbol?
319fn build_decode_table(
320    cum_freqs: &[u32; 257],
321    _sym_freqs: &[u32; 256],
322) -> [u8; PROB_SCALE as usize] {
323    let mut table = [0u8; PROB_SCALE as usize];
324    for sym in 0..256u16 {
325        let start = cum_freqs[sym as usize] as usize;
326        let end = cum_freqs[sym as usize + 1] as usize;
327        for entry in table[start..end].iter_mut() {
328            *entry = sym as u8;
329        }
330    }
331    table
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips `data` through encode/decode and returns the ratio.
    fn roundtrip(data: &[u8]) -> f64 {
        let encoded = encode(data);
        assert_eq!(decode(&encoded).unwrap(), data);
        data.len() as f64 / encoded.len() as f64
    }

    #[test]
    fn empty_roundtrip() {
        assert!(decode(&encode(&[])).unwrap().is_empty());
    }

    #[test]
    fn single_byte() {
        assert_eq!(decode(&encode(&[42])).unwrap(), [42]);
    }

    #[test]
    fn repeated_bytes() {
        // Highly repetitive → near-zero entropy → excellent compression.
        let ratio = roundtrip(&vec![0u8; 10_000]);
        assert!(
            ratio > 2.0,
            "repeated bytes should compress >2x, got {ratio:.1}x"
        );
    }

    #[test]
    fn text_data() {
        let text = b"the quick brown fox jumps over the lazy dog. ";
        let data: Vec<u8> = text.iter().copied().cycle().take(10_000).collect();
        let ratio = roundtrip(&data);
        assert!(ratio > 1.5, "text should compress >1.5x, got {ratio:.1}x");
    }

    #[test]
    fn uniform_random_data() {
        // LCG noise ≈ 8 bits/byte — no compression expected, only losslessness.
        let mut rng: u64 = 12345;
        let data: Vec<u8> = (0..5000)
            .map(|_| {
                rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
                (rng >> 33) as u8
            })
            .collect();
        roundtrip(&data);
    }

    #[test]
    fn all_byte_values() {
        // Every possible byte value appears in the input.
        let data: Vec<u8> = (0..=255u8).cycle().take(4096).collect();
        roundtrip(&data);
    }

    #[test]
    fn skewed_distribution() {
        // 90% zeros, 10% ones — should compress well.
        let mut data = vec![0u8; 10_000];
        for i in 0..1000 {
            data[i * 10] = 1;
        }
        let ratio = roundtrip(&data);
        assert!(
            ratio > 1.5,
            "skewed data should compress >1.5x, got {ratio:.1}x"
        );
    }

    #[test]
    fn better_than_raw_on_structured() {
        // Structured data after type-aware preprocessing: 4 bits/byte.
        let data: Vec<u8> = (0..10_000).map(|i| (i % 16) as u8).collect();
        let ratio = roundtrip(&data);
        assert!(
            ratio > 1.5,
            "low-entropy data should compress >1.5x, got {ratio:.1}x"
        );
    }

    #[test]
    fn truncated_input_errors() {
        assert!(decode(&[]).is_err());
        assert!(decode(&[1, 0, 0, 0]).is_err()); // too short for freq table
    }
}