1use crate::error::CodecError;
21
/// Number of interleaved rANS streams; input byte `i` is coded on stream `i % NUM_STREAMS`.
const NUM_STREAMS: usize = 4;

/// Precision, in bits, of the normalized frequency model.
const PROB_BITS: u32 = 14;

/// Total that the normalized frequencies are scaled to (`2^PROB_BITS`).
const PROB_SCALE: u32 = 1u32 << PROB_BITS;

/// Initial encoder state and the threshold below which the decoder pulls in more bytes.
const RANS_L: u32 = 0x0080_0000;

/// Serialized size of the frequency table: 256 little-endian `u32` entries.
const FREQ_TABLE_SIZE: usize = 256 * std::mem::size_of::<u32>();

/// Fixed header size: uncompressed length (`u32`), frequency table, payload size (`u32`).
const HEADER_SIZE: usize = FREQ_TABLE_SIZE + 2 * std::mem::size_of::<u32>();
37
38pub fn encode(data: &[u8]) -> Vec<u8> {
44 if data.is_empty() {
45 let out = vec![0u8; HEADER_SIZE];
46 return out;
48 }
49
50 let mut freqs = [0u32; 256];
52 for &b in data {
53 freqs[b as usize] += 1;
54 }
55
56 let norm_freqs = normalize_frequencies(&freqs, data.len());
58
59 let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);
61
62 let mut streams: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
65 let mut states = [RANS_L; NUM_STREAMS];
66
67 for i in (0..data.len()).rev() {
69 let stream_idx = i % NUM_STREAMS;
70 let sym = data[i] as usize;
71 let freq = sym_freqs[sym];
72 let start = cum_freqs[sym];
73
74 if freq == 0 {
75 continue; }
77
78 rans_encode_symbol(
79 &mut states[stream_idx],
80 &mut streams[stream_idx],
81 start,
82 freq,
83 );
84 }
85
86 for i in 0..NUM_STREAMS {
88 let s = states[i];
89 streams[i].push((s & 0xFF) as u8);
90 streams[i].push(((s >> 8) & 0xFF) as u8);
91 streams[i].push(((s >> 16) & 0xFF) as u8);
92 streams[i].push(((s >> 24) & 0xFF) as u8);
93 }
94
95 let total_compressed: usize = streams.iter().map(|s| s.len()).sum();
97 let mut out = Vec::with_capacity(HEADER_SIZE + total_compressed + NUM_STREAMS * 4);
98
99 out.extend_from_slice(&(data.len() as u32).to_le_bytes());
101
102 for &f in &norm_freqs {
104 out.extend_from_slice(&f.to_le_bytes());
105 }
106
107 let comp_payload_size = total_compressed + NUM_STREAMS * 4; out.extend_from_slice(&(comp_payload_size as u32).to_le_bytes());
110
111 for s in &streams {
113 out.extend_from_slice(&(s.len() as u32).to_le_bytes());
114 }
115
116 for s in &streams {
118 out.extend_from_slice(s);
119 }
120
121 out
122}
123
124pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
126 if data.len() < HEADER_SIZE {
127 return Err(CodecError::Truncated {
128 expected: HEADER_SIZE,
129 actual: data.len(),
130 });
131 }
132
133 let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
134 if uncompressed_size == 0 {
135 return Ok(Vec::new());
136 }
137
138 let mut norm_freqs = [0u32; 256];
140 for (i, freq) in norm_freqs.iter_mut().enumerate() {
141 let pos = 4 + i * 4;
142 *freq = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
143 }
144
145 let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);
146
147 let lookup = build_decode_table(&cum_freqs, &sym_freqs);
149
150 let _comp_size = u32::from_le_bytes([
151 data[HEADER_SIZE - 4],
152 data[HEADER_SIZE - 3],
153 data[HEADER_SIZE - 2],
154 data[HEADER_SIZE - 1],
155 ]) as usize;
156
157 let mut pos = HEADER_SIZE;
159 if pos + NUM_STREAMS * 4 > data.len() {
160 return Err(CodecError::Truncated {
161 expected: pos + NUM_STREAMS * 4,
162 actual: data.len(),
163 });
164 }
165
166 let mut stream_sizes = [0usize; NUM_STREAMS];
167 for size in stream_sizes.iter_mut() {
168 *size =
169 u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
170 pos += 4;
171 }
172
173 let mut stream_data: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
175 for i in 0..NUM_STREAMS {
176 let end = pos + stream_sizes[i];
177 if end > data.len() {
178 return Err(CodecError::Truncated {
179 expected: end,
180 actual: data.len(),
181 });
182 }
183 stream_data[i] = data[pos..end].to_vec();
184 pos += stream_sizes[i];
185 }
186
187 let mut states = [0u32; NUM_STREAMS];
189 let mut stream_pos = [0usize; NUM_STREAMS];
190 for i in 0..NUM_STREAMS {
191 let s = &stream_data[i];
192 if s.len() < 4 {
193 return Err(CodecError::Corrupt {
194 detail: format!("rANS stream {i} too short for state"),
195 });
196 }
197 let end = s.len();
198 states[i] = u32::from_le_bytes([s[end - 4], s[end - 3], s[end - 2], s[end - 1]]);
199 stream_pos[i] = end - 4;
200 }
201
202 let mut output = vec![0u8; uncompressed_size];
204 for (i, out_byte) in output.iter_mut().enumerate() {
205 let stream_idx = i % NUM_STREAMS;
206 let (sym, new_state) =
207 rans_decode_symbol(states[stream_idx], &lookup, &cum_freqs, &sym_freqs);
208 *out_byte = sym;
209 states[stream_idx] = rans_decode_renorm(
210 new_state,
211 &stream_data[stream_idx],
212 &mut stream_pos[stream_idx],
213 );
214 }
215
216 Ok(output)
217}
218
219fn rans_encode_symbol(state: &mut u32, bitstream: &mut Vec<u8>, start: u32, freq: u32) {
224 let max_state = ((RANS_L >> PROB_BITS) << 8) * freq;
226 while *state >= max_state {
227 bitstream.push((*state & 0xFF) as u8);
228 *state >>= 8;
229 }
230
231 *state = ((*state / freq) << PROB_BITS) + (*state % freq) + start;
233}
234
235fn rans_decode_symbol(
236 state: u32,
237 lookup: &[u8; PROB_SCALE as usize],
238 cum_freqs: &[u32; 257],
239 sym_freqs: &[u32; 256],
240) -> (u8, u32) {
241 let slot = state & (PROB_SCALE - 1);
242 let sym = lookup[slot as usize];
243 let start = cum_freqs[sym as usize];
244 let freq = sym_freqs[sym as usize];
245
246 let new_state = freq * (state >> PROB_BITS) + slot - start;
247 (sym, new_state)
248}
249
250fn rans_decode_renorm(mut state: u32, stream: &[u8], pos: &mut usize) -> u32 {
251 while state < RANS_L && *pos > 0 {
252 *pos -= 1;
253 state = (state << 8) | stream[*pos] as u32;
254 }
255 state
256}
257
258fn normalize_frequencies(freqs: &[u32; 256], total: usize) -> [u32; 256] {
264 let mut norm = [0u32; 256];
265 let mut sum = 0u32;
266 let total_f64 = total as f64;
267
268 for i in 0..256 {
270 if freqs[i] > 0 {
271 norm[i] = ((freqs[i] as f64 / total_f64 * PROB_SCALE as f64).round() as u32).max(1);
273 sum += norm[i];
274 }
275 }
276
277 if sum > 0 {
279 while sum > PROB_SCALE {
280 let max_idx = norm
282 .iter()
283 .enumerate()
284 .filter(|(_, f)| **f > 1)
285 .max_by_key(|(_, f)| **f)
286 .map(|(i, _)| i)
287 .unwrap_or(0);
288 norm[max_idx] -= 1;
289 sum -= 1;
290 }
291 while sum < PROB_SCALE {
292 let max_idx = norm
293 .iter()
294 .enumerate()
295 .max_by_key(|(_, f)| **f)
296 .map(|(i, _)| i)
297 .unwrap_or(0);
298 norm[max_idx] += 1;
299 sum += 1;
300 }
301 }
302
303 norm
304}
305
/// Builds the cumulative frequency table for `freqs`.
///
/// Returns `(cum, sym_freqs)` where `cum[s]` is the sum of `freqs[..s]` (so `cum[0] == 0` and
/// `cum[256]` is the grand total) and `sym_freqs` is a copy of `freqs`.
fn build_cum_table(freqs: &[u32; 256]) -> ([u32; 257], [u32; 256]) {
    let mut cum = [0u32; 257];
    let mut running = 0u32;
    for (slot, &f) in cum[1..].iter_mut().zip(freqs.iter()) {
        running += f;
        *slot = running;
    }
    (cum, *freqs)
}
315
316fn build_decode_table(
318 cum_freqs: &[u32; 257],
319 _sym_freqs: &[u32; 256],
320) -> [u8; PROB_SCALE as usize] {
321 let mut table = [0u8; PROB_SCALE as usize];
322 for sym in 0..256u16 {
323 let start = cum_freqs[sym as usize] as usize;
324 let end = cum_freqs[sym as usize + 1] as usize;
325 for entry in table[start..end].iter_mut() {
326 *entry = sym as u8;
327 }
328 }
329 table
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `input`, asserts that decoding restores it exactly, and returns the compression
    /// ratio (input length divided by encoded length).
    fn roundtrip_ratio(input: &[u8]) -> f64 {
        let encoded = encode(input);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, input);
        input.len() as f64 / encoded.len() as f64
    }

    #[test]
    fn empty_roundtrip() {
        let decoded = decode(&encode(&[])).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn single_byte() {
        let decoded = decode(&encode(&[42])).unwrap();
        assert_eq!(decoded, vec![42]);
    }

    #[test]
    fn repeated_bytes() {
        let ratio = roundtrip_ratio(&[0u8; 10_000]);
        assert!(
            ratio > 2.0,
            "repeated bytes should compress >2x, got {ratio:.1}x"
        );
    }

    #[test]
    fn text_data() {
        let text = b"the quick brown fox jumps over the lazy dog. ";
        let data: Vec<u8> = (0..10_000).map(|i| text[i % text.len()]).collect();
        let ratio = roundtrip_ratio(&data);
        assert!(ratio > 1.5, "text should compress >1.5x, got {ratio:.1}x");
    }

    #[test]
    fn uniform_random_data() {
        // Small deterministic LCG so the test needs no external randomness.
        let mut rng: u64 = 12345;
        let data: Vec<u8> = (0..5000)
            .map(|_| {
                rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
                (rng >> 33) as u8
            })
            .collect();
        roundtrip_ratio(&data);
    }

    #[test]
    fn all_byte_values() {
        let data: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
        roundtrip_ratio(&data);
    }

    #[test]
    fn skewed_distribution() {
        // Every tenth byte is 1, the rest are 0.
        let data: Vec<u8> = (0..10_000).map(|i| u8::from(i % 10 == 0)).collect();
        let ratio = roundtrip_ratio(&data);
        assert!(
            ratio > 1.5,
            "skewed data should compress >1.5x, got {ratio:.1}x"
        );
    }

    #[test]
    fn better_than_raw_on_structured() {
        let data: Vec<u8> = (0..10_000).map(|i| (i % 16) as u8).collect();
        let ratio = roundtrip_ratio(&data);
        assert!(
            ratio > 1.5,
            "low-entropy data should compress >1.5x, got {ratio:.1}x"
        );
    }

    #[test]
    fn truncated_input_errors() {
        for input in [&[][..], &[1, 0, 0, 0][..]] {
            assert!(decode(input).is_err());
        }
    }
}