1use crate::error::CodecError;
23
/// Number of interleaved rANS streams; byte `i` of the input is coded on
/// stream `i % NUM_STREAMS`.
const NUM_STREAMS: usize = 4;

/// Base-2 logarithm of the total that a normalized frequency table sums to.
const PROB_BITS: u32 = 14;

/// Total of a normalized frequency table (`2^PROB_BITS`).
const PROB_SCALE: u32 = 2u32.pow(PROB_BITS);

/// Initial encoder state and the threshold below which the decoder refills
/// its state from the byte stream.
const RANS_L: u32 = 0x0080_0000;

/// Size in bytes of the serialized frequency table: 256 `u32` entries.
const FREQ_TABLE_SIZE: usize = 256 * std::mem::size_of::<u32>();

/// Size in bytes of the fixed header: a `u32` uncompressed length, the
/// frequency table, and a `u32` compressed payload size.
const HEADER_SIZE: usize =
    std::mem::size_of::<u32>() + FREQ_TABLE_SIZE + std::mem::size_of::<u32>();
39
40pub fn encode(data: &[u8]) -> Vec<u8> {
46 if data.is_empty() {
47 let out = vec![0u8; HEADER_SIZE];
48 return out;
50 }
51
52 let mut freqs = [0u32; 256];
54 for &b in data {
55 freqs[b as usize] += 1;
56 }
57
58 let norm_freqs = normalize_frequencies(&freqs, data.len());
60
61 let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);
63
64 let mut streams: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
67 let mut states = [RANS_L; NUM_STREAMS];
68
69 for i in (0..data.len()).rev() {
71 let stream_idx = i % NUM_STREAMS;
72 let sym = data[i] as usize;
73 let freq = sym_freqs[sym];
74 let start = cum_freqs[sym];
75
76 if freq == 0 {
77 continue; }
79
80 rans_encode_symbol(
81 &mut states[stream_idx],
82 &mut streams[stream_idx],
83 start,
84 freq,
85 );
86 }
87
88 for i in 0..NUM_STREAMS {
90 let s = states[i];
91 streams[i].push((s & 0xFF) as u8);
92 streams[i].push(((s >> 8) & 0xFF) as u8);
93 streams[i].push(((s >> 16) & 0xFF) as u8);
94 streams[i].push(((s >> 24) & 0xFF) as u8);
95 }
96
97 let total_compressed: usize = streams.iter().map(|s| s.len()).sum();
99 let mut out = Vec::with_capacity(HEADER_SIZE + total_compressed + NUM_STREAMS * 4);
100
101 out.extend_from_slice(&(data.len() as u32).to_le_bytes());
103
104 for &f in &norm_freqs {
106 out.extend_from_slice(&f.to_le_bytes());
107 }
108
109 let comp_payload_size = total_compressed + NUM_STREAMS * 4; out.extend_from_slice(&(comp_payload_size as u32).to_le_bytes());
112
113 for s in &streams {
115 out.extend_from_slice(&(s.len() as u32).to_le_bytes());
116 }
117
118 for s in &streams {
120 out.extend_from_slice(s);
121 }
122
123 out
124}
125
126pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
128 if data.len() < HEADER_SIZE {
129 return Err(CodecError::Truncated {
130 expected: HEADER_SIZE,
131 actual: data.len(),
132 });
133 }
134
135 let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
136 if uncompressed_size == 0 {
137 return Ok(Vec::new());
138 }
139
140 let mut norm_freqs = [0u32; 256];
142 for (i, freq) in norm_freqs.iter_mut().enumerate() {
143 let pos = 4 + i * 4;
144 *freq = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
145 }
146
147 let (cum_freqs, sym_freqs) = build_cum_table(&norm_freqs);
148
149 let lookup = build_decode_table(&cum_freqs, &sym_freqs);
151
152 let _comp_size = u32::from_le_bytes([
153 data[HEADER_SIZE - 4],
154 data[HEADER_SIZE - 3],
155 data[HEADER_SIZE - 2],
156 data[HEADER_SIZE - 1],
157 ]) as usize;
158
159 let mut pos = HEADER_SIZE;
161 if pos + NUM_STREAMS * 4 > data.len() {
162 return Err(CodecError::Truncated {
163 expected: pos + NUM_STREAMS * 4,
164 actual: data.len(),
165 });
166 }
167
168 let mut stream_sizes = [0usize; NUM_STREAMS];
169 for size in stream_sizes.iter_mut() {
170 *size =
171 u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
172 pos += 4;
173 }
174
175 let mut stream_data: [Vec<u8>; NUM_STREAMS] = std::array::from_fn(|_| Vec::new());
177 for i in 0..NUM_STREAMS {
178 let end = pos + stream_sizes[i];
179 if end > data.len() {
180 return Err(CodecError::Truncated {
181 expected: end,
182 actual: data.len(),
183 });
184 }
185 stream_data[i] = data[pos..end].to_vec();
186 pos += stream_sizes[i];
187 }
188
189 let mut states = [0u32; NUM_STREAMS];
191 let mut stream_pos = [0usize; NUM_STREAMS];
192 for i in 0..NUM_STREAMS {
193 let s = &stream_data[i];
194 if s.len() < 4 {
195 return Err(CodecError::Corrupt {
196 detail: format!("rANS stream {i} too short for state"),
197 });
198 }
199 let end = s.len();
200 states[i] = u32::from_le_bytes([s[end - 4], s[end - 3], s[end - 2], s[end - 1]]);
201 stream_pos[i] = end - 4;
202 }
203
204 let mut output = vec![0u8; uncompressed_size];
206 for (i, out_byte) in output.iter_mut().enumerate() {
207 let stream_idx = i % NUM_STREAMS;
208 let (sym, new_state) =
209 rans_decode_symbol(states[stream_idx], &lookup, &cum_freqs, &sym_freqs);
210 *out_byte = sym;
211 states[stream_idx] = rans_decode_renorm(
212 new_state,
213 &stream_data[stream_idx],
214 &mut stream_pos[stream_idx],
215 );
216 }
217
218 Ok(output)
219}
220
221fn rans_encode_symbol(state: &mut u32, bitstream: &mut Vec<u8>, start: u32, freq: u32) {
226 let max_state = ((RANS_L >> PROB_BITS) << 8) * freq;
228 while *state >= max_state {
229 bitstream.push((*state & 0xFF) as u8);
230 *state >>= 8;
231 }
232
233 *state = ((*state / freq) << PROB_BITS) + (*state % freq) + start;
235}
236
237fn rans_decode_symbol(
238 state: u32,
239 lookup: &[u8; PROB_SCALE as usize],
240 cum_freqs: &[u32; 257],
241 sym_freqs: &[u32; 256],
242) -> (u8, u32) {
243 let slot = state & (PROB_SCALE - 1);
244 let sym = lookup[slot as usize];
245 let start = cum_freqs[sym as usize];
246 let freq = sym_freqs[sym as usize];
247
248 let new_state = freq * (state >> PROB_BITS) + slot - start;
249 (sym, new_state)
250}
251
252fn rans_decode_renorm(mut state: u32, stream: &[u8], pos: &mut usize) -> u32 {
253 while state < RANS_L && *pos > 0 {
254 *pos -= 1;
255 state = (state << 8) | stream[*pos] as u32;
256 }
257 state
258}
259
260fn normalize_frequencies(freqs: &[u32; 256], total: usize) -> [u32; 256] {
266 let mut norm = [0u32; 256];
267 let mut sum = 0u32;
268 let total_f64 = total as f64;
269
270 for i in 0..256 {
272 if freqs[i] > 0 {
273 norm[i] = ((freqs[i] as f64 / total_f64 * PROB_SCALE as f64).round() as u32).max(1);
275 sum += norm[i];
276 }
277 }
278
279 if sum > 0 {
281 while sum > PROB_SCALE {
282 let max_idx = norm
284 .iter()
285 .enumerate()
286 .filter(|(_, f)| **f > 1)
287 .max_by_key(|(_, f)| **f)
288 .map(|(i, _)| i)
289 .unwrap_or(0);
290 norm[max_idx] -= 1;
291 sum -= 1;
292 }
293 while sum < PROB_SCALE {
294 let max_idx = norm
295 .iter()
296 .enumerate()
297 .max_by_key(|(_, f)| **f)
298 .map(|(i, _)| i)
299 .unwrap_or(0);
300 norm[max_idx] += 1;
301 sum += 1;
302 }
303 }
304
305 norm
306}
307
/// Builds the cumulative-frequency table for `freqs`.
///
/// Returns `(cum, sym_freqs)`, where `cum[s]` is the sum of the frequencies
/// of all symbols below `s` (so `cum[0] == 0` and `cum[256]` is the grand
/// total), and `sym_freqs` is a copy of `freqs`.
fn build_cum_table(freqs: &[u32; 256]) -> ([u32; 257], [u32; 256]) {
    let mut cum = [0u32; 257];
    let mut running = 0u32;
    for (slot, &f) in cum[1..].iter_mut().zip(freqs.iter()) {
        running += f;
        *slot = running;
    }
    (cum, *freqs)
}
317
318fn build_decode_table(
320 cum_freqs: &[u32; 257],
321 _sym_freqs: &[u32; 256],
322) -> [u8; PROB_SCALE as usize] {
323 let mut table = [0u8; PROB_SCALE as usize];
324 for sym in 0..256u16 {
325 let start = cum_freqs[sym as usize] as usize;
326 let end = cum_freqs[sym as usize + 1] as usize;
327 for entry in table[start..end].iter_mut() {
328 *entry = sym as u8;
329 }
330 }
331 table
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `input`, checks that decoding gives it back, and returns the
    /// encoded bytes for further inspection.
    fn check_roundtrip(input: &[u8]) -> Vec<u8> {
        let packed = encode(input);
        let unpacked = decode(&packed).unwrap();
        assert_eq!(unpacked.as_slice(), input);
        packed
    }

    /// Ratio of raw size to encoded size; above 1.0 means the data shrank.
    fn compression_ratio(raw: &[u8], packed: &[u8]) -> f64 {
        raw.len() as f64 / packed.len() as f64
    }

    #[test]
    fn empty_roundtrip() {
        let decoded = decode(&encode(&[])).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn single_byte() {
        let decoded = decode(&encode(&[42])).unwrap();
        assert_eq!(decoded, vec![42]);
    }

    #[test]
    fn repeated_bytes() {
        let data = vec![0u8; 10_000];
        let encoded = check_roundtrip(&data);

        let ratio = compression_ratio(&data, &encoded);
        assert!(
            ratio > 2.0,
            "repeated bytes should compress >2x, got {ratio:.1}x"
        );
    }

    #[test]
    fn text_data() {
        let text = b"the quick brown fox jumps over the lazy dog. ";
        let data: Vec<u8> = (0..10_000usize).map(|i| text[i % text.len()]).collect();
        let encoded = check_roundtrip(&data);

        let ratio = compression_ratio(&data, &encoded);
        assert!(ratio > 1.5, "text should compress >1.5x, got {ratio:.1}x");
    }

    #[test]
    fn uniform_random_data() {
        // Small linear congruential generator, so the input is reproducible.
        let mut seed: u64 = 12345;
        let data: Vec<u8> = (0..5000)
            .map(|_| {
                seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
                (seed >> 33) as u8
            })
            .collect();
        check_roundtrip(&data);
    }

    #[test]
    fn all_byte_values() {
        let data: Vec<u8> = (0..4096usize).map(|i| (i % 256) as u8).collect();
        check_roundtrip(&data);
    }

    #[test]
    fn skewed_distribution() {
        // One byte in ten is 1, the rest are 0.
        let data: Vec<u8> = (0..10_000usize).map(|i| u8::from(i % 10 == 0)).collect();
        let encoded = check_roundtrip(&data);

        let ratio = compression_ratio(&data, &encoded);
        assert!(
            ratio > 1.5,
            "skewed data should compress >1.5x, got {ratio:.1}x"
        );
    }

    #[test]
    fn better_than_raw_on_structured() {
        let data: Vec<u8> = (0..10_000usize).map(|i| (i % 16) as u8).collect();
        let encoded = check_roundtrip(&data);

        let ratio = compression_ratio(&data, &encoded);
        assert!(
            ratio > 1.5,
            "low-entropy data should compress >1.5x, got {ratio:.1}x"
        );
    }

    #[test]
    fn truncated_input_errors() {
        for input in [&[][..], &[1, 0, 0, 0][..]] {
            assert!(decode(input).is_err());
        }
    }
}