Skip to main content

sym_adv_encoding/
utf8.rs

1//! Utilities for encoding bytes/UTF-8 text into ring elements.
2//!
3//! This module provides two related but distinct encodings for turning a byte
4//! string into a sequence of `RingElement`s with a given `modulus`:
5//!
6//! - Bit-packing (`encode_bytes` / `decode_bytes`):
7//!   Packs the input bitstream into chunks of b = floor(log2(modulus)) bits.
8//!   Every produced element satisfies value < 2^b <= modulus, so some residues
9//!   are intentionally unused when the modulus is not a power of two.
10//!   This format does not include the original byte length.
11//!
12//! - Length-delimited base-m transduction
13//!   (`encode_bytes_base_m_len` / `decode_bytes_base_m_len`):
14//!   Encodes the byte length and a fixed rANS state as base-modulus digits,
15//!   then converts the byte stream into base-modulus payload digits using a
16//!   uniform rANS transducer. Every produced digit satisfies digit < modulus,
17//!   so it uses all residues. Decoding is length-delimited and ignores trailing
18//!   elements.
19//!
20//! ## Target behavior (base-m-len variant)
21//!
22//! - Uses all residues: every emitted `RingElement` satisfies value < modulus,
23//!   with no "[0, 2^b)" restriction.
24//! - Near-linear time: amortized O(1) work per input byte, plus a constant-size
25//!   header (2 * k digits for a fixed k determined by `modulus`).
26//! - Same public API: the public `encode_bytes_base_m_len` /
27//!   `decode_bytes_base_m_len` signatures and
28//!   `encode_text_base_m_len` / `decode_text_base_m_len` behavior are unchanged.
29//! - Padding tolerance: appending extra elements at the end (for example, block
30//!   padding) does not affect decoding; the decoder stops after emitting the
31//!   byte count from the length prefix.
32//!
33//! ## Design: uniform rANS as a radix transducer
34//!
35//! Treat each input byte as a uniform symbol (frequency = 1 over 256 symbols).
36//! The symbol update is:
37//!
38//! - encode: x = x * 256 + byte
39//! - decode: byte = x % 256; x = x / 256
40//!
41//! Renormalization uses base = modulus, so every emitted digit is a valid ring
42//! residue (digit < modulus). The stream stores a fixed-size initial state
43//! immediately after the length prefix so decoding can proceed FIFO
44//! (left-to-right) and ignore any trailing padding.
45//!
46//! ## Future work
47//!
48//! Investigate non-uniform rANS as a radix transducer to incorporate empirical
49//! byte distributions while retaining the base-modulus digit stream interface.
50//!
//! Bit order of the bit-packing variant: the bitstream in `encode_bytes` is
//! packed little-endian, so earlier bytes occupy lower bits of the internal
//! buffer and are emitted first.
53//!
54//! Important limitations (bit-packing variant):
55//!
56//! - The encoding does not include the original byte length. For some parameter
57//!   choices (notably when b > 8), decoding may yield extra trailing 0x00 bytes
58//!   that come from zero-padding the final partial chunk. If you need exact
59//!   round-trips, store the original length separately and truncate after
60//!   decoding.
61//! - Decoding assumes elements are in the canonical range produced by this
62//!   module (each element value fits in b bits). If you modify elements via
63//!   ring arithmetic, decoding will generally not recover the original bytes.
64//!
65//! # References (ANS / rANS)
66//!
67//! - J. Duda, "Asymmetric numeral systems: entropy coding combining speed of
68//!   Huffman coding with compression rate of arithmetic coding",
69//!   <https://arxiv.org/abs/1311.2540>
70//! - F. Giesen, "rANS notes":
71//!   <https://fgiesen.wordpress.com/2014/02/02/rans-notes/>
72//! - F. Giesen, "rANS with static probability distributions":
73//!   <https://fgiesen.wordpress.com/2014/02/18/rans-with-static-probability-distributions/>
74//! - Y. Collet, "Finite State Entropy, a new breed of entropy coders" (FSE/tANS):
75//!   <https://fastcompression.blogspot.com/2013/12/finite-state-entropy-new-breed-of.html>
76
77use sym_adv_ring::RingElement;
78
/// Selects which wire format a [`Utf8Encoding`] is configured for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Utf8EncodingType {
    /// Length-prefixed format `len || rANS state || payload`, where every
    /// emitted digit lies in `[0, modulus)`. Required by the `*_base_m_len`
    /// methods.
    LengthDelimitedBaseMTransduction,
    /// Bit-packing into `floor(log2(modulus))`-bit chunks with no length
    /// prefix (the format of `encode_bytes` / `decode_bytes`).
    BitPacking,
}

/// Internal form of [`Utf8EncodingType`] carrying parameters precomputed from
/// the modulus at construction time.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Utf8EncodingTypeInner {
    LengthDelimitedBaseMTransduction {
        /// Number of base-`modulus` digits used for the `len` and `state` headers.
        length_prefix_digits: usize,
        /// rANS decoder lower bound `L` (a multiple of 256).
        decoder_lower_bound: u64,
        /// rANS encoder renormalization threshold `(L / 256) * modulus`.
        encoder_threshold: u64,
    },
    BitPacking,
}

/// An encoder/decoder between byte strings and `RingElement`s for a fixed
/// ring `modulus` (always `>= 2` once constructed).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Utf8Encoding {
    modulus: u64,
    encoding_type: Utf8EncodingTypeInner,
}
98
99impl Utf8Encoding {
100    /// Builds an encoding configuration for the provided `modulus`.
101    ///
102    /// # Errors
103    ///
104    /// - Returns [`EncodingError::ModulusTooSmall`] if `modulus < 2`.
105    /// - Returns [`EncodingError::DecodingError`] when base-m rANS parameters are
106    ///   invalid for the provided `modulus`.
107    pub fn try_from(modulus: u64, encoding_type: Utf8EncodingType) -> Result<Self, EncodingError> {
108        if modulus < 2 {
109            return Err(EncodingError::ModulusTooSmall);
110        }
111
112        Ok(Self {
113            modulus,
114            encoding_type: Self::build_inner_encoding_type(modulus, encoding_type)?,
115        })
116    }
117
118    fn build_inner_encoding_type(
119        modulus: u64,
120        encoding_type: Utf8EncodingType,
121    ) -> Result<Utf8EncodingTypeInner, EncodingError> {
122        match encoding_type {
123            Utf8EncodingType::LengthDelimitedBaseMTransduction => {
124                let (decoder_lower_bound, encoder_threshold) = Self::rans_params(modulus)?;
125
126                Ok(Utf8EncodingTypeInner::LengthDelimitedBaseMTransduction {
127                    length_prefix_digits: Self::length_prefix_digits(modulus),
128                    decoder_lower_bound,
129                    encoder_threshold,
130                })
131            }
132            Utf8EncodingType::BitPacking => Ok(Utf8EncodingTypeInner::BitPacking),
133        }
134    }
135}
136
137impl Utf8Encoding {
138    /// Returns the fixed base-`modulus` digit count used for u64 prefixes.
139    ///
140    /// This computes the smallest `k` such that `modulus^k >= 2^64`. The
141    /// length-delimited `*_base_m_len` format encodes both the byte length and the
142    /// rANS state using exactly `k` base-`modulus` digits, which makes the stream
143    /// layout self-delimiting (`length || state || payload`) without separators.
144    ///
145    /// Digits are interpreted little-endian (least significant digit first).
146    ///
147    /// # Math
148    ///
149    /// Let m = modulus (with m >= 2). We choose the minimal digit count k >= 0 such
150    /// that every 64-bit value fits into k base-m digits:
151    ///
152    ///   k = min { k in N0 : m^k >= 2^64 }.
153    ///
154    /// Equivalently:
155    ///
156    ///   k = ceil( `log_m(2^64)` )
157    ///     = ceil( log(2^64) / log(m) ).
158    ///
159    /// This guarantees that any x in [0, 2^64 - 1] can be represented using exactly
160    /// k base-m digits (padding with leading zero digits as needed).
161    ///
162    /// # Complexity
163    ///
164    /// Runs in O(k) time and O(1) extra space, where k is the returned digit count
165    /// (the number of times we multiply by `modulus` until reaching 2^64).
166    ///
167    /// In terms of `modulus = m` (m >= 2), k = `ceil(log_m(2^64))`, so the loop runs
168    /// about ceil(64 / log2(m)) iterations (worst case at m = 2 gives k = 64).
169    fn length_prefix_digits(modulus: u64) -> usize {
170        let target: u128 = 1u128 << 64;
171        let base: u128 = u128::from(modulus);
172        let mut pow: u128 = 1;
173        let mut k: usize = 0;
174        while pow < target {
175            pow = pow.saturating_mul(base);
176            k += 1;
177        }
178
179        k
180    }
181
182    /// Chooses renormalization parameters for the uniform (256-symbol) rANS transducer.
183    ///
184    /// The `*_base_m_len` encoding treats bytes as symbols in a uniform 256-ary
185    /// alphabet and produces payload digits in radix `modulus` (so every digit is
186    /// `< modulus`). The internal state `x` is maintained in a `u64`.
187    ///
188    /// This function returns `(L, x_max)`:
189    /// - `L` is a decoder lower bound (chosen as a multiple of 256),
190    /// - `x_max = (L / 256) * modulus` is an encoder threshold.
191    ///
192    /// The encoder emits base-`modulus` digits while `x >= x_max` to ensure the
193    /// update `x = x * 256 + byte` cannot overflow `u64`. The decoder performs the
194    /// inverse operation by consuming digits while `x < L`.
195    ///
196    /// # Errors
197    ///
198    /// - Returns [`EncodingError::DecodingError`] if `modulus` is so large that no
199    ///   valid `L >= 256` can be chosen while keeping the state in `u64`.
200    fn rans_params(modulus: u64) -> Result<(u64, u64), EncodingError> {
201        let max_l = u64::MAX / modulus;
202        let l = (max_l / 256) * 256;
203        if l < 256 {
204            return Err(EncodingError::DecodingError(
205                "Modulus too large for base-m-len encoding".to_string(),
206            ));
207        }
208
209        let x_max_u128 = u128::from(l / 256) * u128::from(modulus);
210        if x_max_u128 > u128::from(u64::MAX) {
211            return Err(EncodingError::DecodingError(
212                "Invalid rANS parameters".to_string(),
213            ));
214        }
215
216        let x_max = u64::try_from(x_max_u128)
217            .map_err(|_| EncodingError::DecodingError("Invalid rANS parameters".to_string()))?;
218        Ok((l, x_max))
219    }
220
221    /// Encodes a `u64` as exactly `k` base-`modulus` digits (little-endian).
222    ///
223    /// Here `k = length_prefix_digits(modulus)`, so every `u64` fits. The returned
224    /// vector always has length `k`; most-significant digits are `0` when `value` is
225    /// small.
226    fn encode_u64_base_m_fixed(&self, value: u64) -> Vec<u64> {
227        let Utf8EncodingTypeInner::LengthDelimitedBaseMTransduction {
228            length_prefix_digits,
229            ..
230        } = self.encoding_type
231        else {
232            unreachable!()
233        };
234
235        let mut digits = Vec::with_capacity(length_prefix_digits);
236        let mut v = value;
237        for _ in 0..length_prefix_digits {
238            digits.push(v % self.modulus);
239            v /= self.modulus;
240        }
241
242        digits
243    }
244
245    /// Decodes a fixed-width base-`modulus` digit sequence (little-endian) into a `u64`.
246    ///
247    /// Each digit must satisfy `digit < modulus`. The digit slice is typically of
248    /// length `k = length_prefix_digits(modulus)`, as produced by
249    /// [`encode_u64_base_m_fixed`].
250    ///
251    /// # Errors
252    ///
253    /// - Returns [`EncodingError::DecodingError`] if any digit is out of range or if
254    ///   the decoded value does not fit in `u64`.
255    fn decode_u64_base_m_fixed(&self, digits: &[u64]) -> Result<u64, EncodingError> {
256        let mut value: u128 = 0;
257        let mut pow: u128 = 1;
258        let base: u128 = u128::from(self.modulus);
259        for &d in digits {
260            if d >= self.modulus {
261                return Err(EncodingError::DecodingError(
262                    "Invalid length prefix digit".to_string(),
263                ));
264            }
265            value += u128::from(d) * pow;
266            pow = pow.saturating_mul(base);
267        }
268
269        if value > u128::from(u64::MAX) {
270            return Err(EncodingError::DecodingError(
271                "Decoded length exceeds u64".to_string(),
272            ));
273        }
274
275        u64::try_from(value)
276            .map_err(|_| EncodingError::DecodingError("Decoded length exceeds u64".to_string()))
277    }
278
279    /// Uniform rANS encoder for bytes, emitting digits in radix `modulus`.
280    ///
281    /// Given `bytes`, this returns a pair `(state, stream_digits)` such that
282    /// [`rans_decode_uniform_bytes_base_m`] can reconstruct the original bytes when
283    /// provided with:
284    /// - the original `len = bytes.len()`,
285    /// - the returned `state` (encoded separately as fixed-width base-`modulus`
286    ///   digits),
287    /// - the returned `stream_digits` as the payload digit stream.
288    ///
289    /// Implementation details:
290    /// - Symbols are bytes with a *uniform* distribution, so the rANS update reduces
291    ///   to `x = x * 256 + byte`.
292    /// - Renormalization emits digits in radix `modulus`, ensuring every output
293    ///   digit is `< modulus` (i.e., all residues are available).
294    /// - Bytes are processed in reverse order, as is standard for streaming rANS.
295    ///
296    fn rans_encode_uniform_bytes_base_m(&self, bytes: &[u8]) -> (u64, Vec<u64>) {
297        let Utf8EncodingTypeInner::LengthDelimitedBaseMTransduction {
298            decoder_lower_bound,
299            encoder_threshold,
300            ..
301        } = self.encoding_type
302        else {
303            unreachable!()
304        };
305
306        let mut x = decoder_lower_bound;
307        let mut stream_digits = Vec::new();
308
309        for &b in bytes.iter().rev() {
310            while x >= encoder_threshold {
311                stream_digits.push(x % self.modulus);
312                x /= self.modulus;
313            }
314
315            x = x * 256 + u64::from(b);
316        }
317
318        stream_digits.reverse();
319        (x, stream_digits)
320    }
321
322    /// Uniform rANS decoder for bytes, consuming digits in radix `modulus`.
323    ///
324    /// This is the inverse of [`rans_encode_uniform_bytes_base_m`]. It reconstructs
325    /// exactly `len` bytes from an initial `state` and a digit stream `stream`.
326    ///
327    /// In the uniform 256-symbol case, symbol extraction corresponds to:
328    /// - decode: `byte = x % 256; x = x / 256`
329    ///
330    /// Only as many payload digits as necessary are consumed; any remaining trailing
331    /// elements in `stream` are ignored. This makes the overall `*_base_m_len`
332    /// decoding robust to zero-padding or garbage appended at the end.
333    ///
334    /// # Errors
335    ///
336    /// Returns [`EncodingError::DecodingError`] if:
337    /// - `state` is invalid for the chosen parameters,
338    /// - there are not enough payload digits to decode `len` bytes,
339    /// - any payload digit is out of range (`>= modulus`), or
340    /// - the reconstructed internal state would overflow `u64`.
341    fn rans_decode_uniform_bytes_base_m(
342        &self,
343        len: usize,
344        state: u64,
345        stream: &[RingElement],
346    ) -> Result<Vec<u8>, EncodingError> {
347        let Utf8EncodingTypeInner::LengthDelimitedBaseMTransduction {
348            decoder_lower_bound,
349            ..
350        } = self.encoding_type
351        else {
352            unreachable!()
353        };
354
355        if len == 0 {
356            return Ok(Vec::new());
357        }
358
359        if state < decoder_lower_bound {
360            return Err(EncodingError::DecodingError(
361                "Invalid rANS state".to_string(),
362            ));
363        }
364
365        let mut x = state;
366        let mut out = Vec::with_capacity(len);
367        let mut i = 0usize;
368
369        for _ in 0..len {
370            out.push(
371                u8::try_from(x & u64::from(u8::MAX))
372                    .expect("low 8 bits of the rANS state must fit in u8"),
373            );
374            x >>= 8;
375
376            while x < decoder_lower_bound {
377                let d = stream.get(i).ok_or_else(|| {
378                    EncodingError::DecodingError("Not enough payload digits".to_string())
379                })?;
380                let dv = d.value();
381                if dv >= self.modulus {
382                    return Err(EncodingError::DecodingError(
383                        "Invalid payload digit".to_string(),
384                    ));
385                }
386                let new_x = u128::from(x) * u128::from(self.modulus) + u128::from(dv);
387                if new_x > u128::from(u64::MAX) {
388                    return Err(EncodingError::DecodingError(
389                        "Invalid rANS state".to_string(),
390                    ));
391                }
392                x = u64::try_from(new_x)
393                    .map_err(|_| EncodingError::DecodingError("Invalid rANS state".to_string()))?;
394                i += 1;
395            }
396        }
397
398        Ok(out)
399    }
400
401    /// Encodes bytes into base-`modulus` digits with an embedded byte length.
402    ///
403    /// This encoding is **length-delimited** and **self-contained**: the output
404    /// embeds the original byte length (as a fixed-width base-`modulus` prefix) and
405    /// uses a uniform rANS transducer to convert the bytes into payload digits, all
406    /// strictly `< modulus`.
407    ///
408    /// # Target behavior
409    ///
410    /// - **Uses all residues**: digits are only required to satisfy `value < modulus`
411    ///   (there is no `"[0, 2^b)"` restriction as in [`encode_bytes`]).
412    /// - **Near-linear time**: amortized O(1) work per input byte, plus a constant-size
413    ///   header for the given modulus.
414    /// - **Padding tolerance**: trailing elements do not affect decoding because the
415    ///   decoded length is explicit.
416    /// - **Same public API**: this preserves the existing function signature and
417    ///   high-level behavior (but the on-wire representation is not backward compatible).
418    ///
419    /// # Wire format
420    ///
421    /// For `k = length_prefix_digits(modulus)`, the returned element sequence is:
422    ///
423    /// - `k` digits: `len` as a `u64` in base `modulus` (little-endian),
424    /// - `k` digits: `state` as a `u64` in base `modulus` (little-endian),
425    /// - `n` digits: payload stream digits produced by uniform rANS (each `< modulus`).
426    ///
427    /// The payload digit count `n` depends on `bytes.len()` and `modulus`.
428    ///
429    /// The rANS `state` is placed immediately after the length prefix so decoding can
430    /// proceed FIFO (left-to-right): read `len`, read `state`, then consume as many
431    /// payload digits as needed to emit exactly `len` bytes.
432    ///
433    /// # Decoding behavior
434    ///
435    /// [`decode_bytes_base_m_len`] uses the embedded `len` to decode exactly that
436    /// many bytes and ignores any trailing elements beyond what is required. This
437    /// makes the representation robust to padding or garbage appended at the end.
438    ///
439    /// # Complexity
440    ///
441    /// Near-linear in `bytes.len()` (amortized O(1) renormalization per byte),
442    /// avoiding quadratic base-conversion of large payloads.
443    ///
444    /// # Compatibility
445    ///
446    /// This is **not** compatible with older `*_base_m_len` encodings that used a
447    /// different payload conversion scheme.
448    ///
449    /// # Errors
450    ///
451    /// - Returns [`EncodingError::ModulusTooSmall`] if `modulus < 2`.
452    /// - Returns [`EncodingError::DecodingError`] if the byte length does not fit in
453    ///   `u64` or if valid rANS parameters cannot be chosen for the given modulus.
454    pub fn encode_bytes_base_m_len(&self, bytes: &[u8]) -> Result<Vec<RingElement>, EncodingError> {
455        let len_u64 = u64::try_from(bytes.len()).map_err(|_| {
456            EncodingError::DecodingError("Input length does not fit in u64".to_string())
457        })?;
458
459        let mut result: Vec<RingElement> = Vec::new();
460
461        let len_digits_le = self.encode_u64_base_m_fixed(len_u64);
462        result.extend(
463            len_digits_le
464                .into_iter()
465                .map(|d| RingElement::new(d, self.modulus)),
466        );
467
468        let (state, stream_digits) = self.rans_encode_uniform_bytes_base_m(bytes);
469
470        let state_digits_le = self.encode_u64_base_m_fixed(state);
471        result.extend(
472            state_digits_le
473                .into_iter()
474                .map(|d| RingElement::new(d, self.modulus)),
475        );
476
477        result.extend(
478            stream_digits
479                .into_iter()
480                .map(|d| RingElement::new(d, self.modulus)),
481        );
482
483        Ok(result)
484    }
485
486    /// Decodes bytes encoded by [`encode_bytes_base_m_len`].
487    ///
488    /// This parses the fixed-width base-`modulus` `len` prefix and `state`, then
489    /// uses the uniform rANS decoder to reconstruct exactly `len` bytes.
490    ///
491    /// Any trailing elements after the consumed payload are ignored, so appending
492    /// zero-padding or garbage does not change the decoded result.
493    ///
494    /// Decoding is FIFO: it reads the `len` prefix, then the fixed-width `state`, and
495    /// then consumes only as many payload digits as needed to emit `len` bytes.
496    ///
497    /// # Errors
498    ///
499    /// - Returns [`EncodingError::DecodingError`] if the stream is malformed (not
500    ///   enough prefix/payload digits, out-of-range digits, invalid state, or length
501    ///   that does not fit in `usize`).
502    pub fn decode_bytes_base_m_len(
503        &self,
504        elements: &[RingElement],
505    ) -> Result<Vec<u8>, EncodingError> {
506        let Utf8EncodingTypeInner::LengthDelimitedBaseMTransduction {
507            length_prefix_digits,
508            ..
509        } = self.encoding_type
510        else {
511            unreachable!()
512        };
513
514        if elements.len() < length_prefix_digits {
515            return Err(EncodingError::DecodingError(
516                "Not enough elements for length prefix".to_string(),
517            ));
518        }
519
520        let len_digits: Vec<u64> = elements[..length_prefix_digits]
521            .iter()
522            .map(RingElement::value)
523            .collect();
524        let len_u64 = self.decode_u64_base_m_fixed(&len_digits)?;
525        let len_usize = usize::try_from(len_u64).map_err(|_| {
526            EncodingError::DecodingError("Decoded length does not fit in usize".to_string())
527        })?;
528
529        if len_usize == 0 {
530            return Ok(Vec::new());
531        }
532
533        if elements.len() < 2 * length_prefix_digits {
534            return Err(EncodingError::DecodingError(
535                "Not enough elements for rANS state".to_string(),
536            ));
537        }
538
539        let state_digits: Vec<u64> = elements[length_prefix_digits..2 * length_prefix_digits]
540            .iter()
541            .map(RingElement::value)
542            .collect();
543        let state = self.decode_u64_base_m_fixed(&state_digits)?;
544
545        let payload_stream = &elements[2 * length_prefix_digits..];
546        self.rans_decode_uniform_bytes_base_m(len_usize, state, payload_stream)
547    }
548
549    /// Encodes a UTF-8 string using [`encode_bytes_base_m_len`].
550    ///
551    /// # Errors
552    ///
553    /// - Propagates any error returned by [`Self::encode_bytes_base_m_len`].
554    pub fn encode_text_base_m_len(&self, text: &str) -> Result<Vec<RingElement>, EncodingError> {
555        self.encode_bytes_base_m_len(text.as_bytes())
556    }
557
558    /// Decodes text encoded by [`encode_text_base_m_len`].
559    ///
560    /// This is [`decode_bytes_base_m_len`] followed by UTF-8 validation.
561    ///
562    /// # Errors
563    ///
564    /// - Returns [`EncodingError::DecodingError`] if the element stream is malformed
565    ///   or the decoded bytes are not valid UTF-8.
566    pub fn decode_text_base_m_len(
567        &self,
568        elements: &[RingElement],
569    ) -> Result<String, EncodingError> {
570        let bytes = self.decode_bytes_base_m_len(elements)?;
571        String::from_utf8(bytes).map_err(|e| EncodingError::DecodingError(e.to_string()))
572    }
573}
574
575impl Utf8Encoding {
576    /// Computes how many payload bits can be stored per element for a given modulus.
577    ///
578    /// Returns `b = floor(log2(modulus))`. With this choice,
579    /// every `b`-bit value is guaranteed to be `< modulus`, even when `modulus` is
580    /// not a power of two.
581    fn bits_per_element(&self) -> usize {
582        usize::try_from(self.modulus.ilog2()).expect("u32 bit width must fit in usize")
583    }
584
585    /// Encodes a byte slice into a sequence of [`RingElement`]s.
586    ///
587    /// The input bytes are treated as a little-endian bitstream and packed into
588    /// chunks of `b` bits, where `b = floor(log2(modulus))`. Each chunk becomes the
589    /// `value` of a `RingElement` with the provided `modulus`.
590    ///
591    /// This encoding uses only values in the range `[0, 2^b)`, which is always a
592    /// subset of the ring when `modulus` is not a power of two.
593    ///
594    /// # Notes
595    ///
596    /// - The output does not carry the original byte length. If the final chunk is
597    ///   only partially filled, it is emitted with implicit zero-padding in the
598    ///   high bits. For some moduli (notably when `b > 8`), this can cause decoding
599    ///   to produce extra trailing `0x00` bytes unless you truncate to the original
600    ///   length.
601    #[must_use]
602    pub fn encode_bytes(&self, bytes: &[u8]) -> Vec<RingElement> {
603        let bits_per_elem = self.bits_per_element();
604
605        let mut result = Vec::new();
606        let mut bit_buffer: u64 = 0;
607        let mut bits_in_buffer: usize = 0;
608
609        for &byte in bytes {
610            bit_buffer |= u64::from(byte) << bits_in_buffer;
611            bits_in_buffer += 8;
612
613            while bits_in_buffer >= bits_per_elem {
614                let mask = (1u64 << bits_per_elem) - 1;
615                let value = bit_buffer & mask;
616
617                result.push(RingElement::new(value, self.modulus));
618
619                bit_buffer >>= bits_per_elem;
620                bits_in_buffer -= bits_per_elem;
621            }
622        }
623
624        if bits_in_buffer > 0 {
625            result.push(RingElement::new(bit_buffer, self.modulus));
626        }
627
628        result
629    }
630
631    /// Decodes a sequence of [`RingElement`]s back into bytes.
632    ///
633    /// This reconstructs the same little-endian bitstream produced by
634    /// [`encode_bytes`], assuming `b = floor(log2(modulus))` bits of payload per
635    /// element and emitting 8-bit bytes from the stream.
636    ///
637    /// # Notes
638    ///
639    /// - This function does not validate that the provided `modulus` matches the
640    ///   modulus stored in each element; they must agree for meaningful results.
641    /// - This function assumes each element value fits in `b` bits (i.e., it lies
642    ///   in `[0, 2^b)`) as produced by [`encode_bytes`]. If elements were modified
643    ///   (or originate elsewhere), higher bits can leak into the reconstructed byte
644    ///   stream.
645    /// - Because the encoding does not include the original length, decoding may
646    ///   yield extra trailing `0x00` bytes for some parameter choices; truncate to
647    ///   the known original length when exact recovery is required.
648    ///
649    /// # Panics
650    ///
651    /// Panics only if an internal invariant is violated and extracting the low
652    /// 8 bits of the bit buffer fails to fit in a `u8`.
653    #[must_use]
654    pub fn decode_bytes(&self, elements: &[RingElement]) -> Vec<u8> {
655        if elements.is_empty() {
656            return Vec::new();
657        }
658
659        let bits_per_elem = self.bits_per_element();
660
661        let mut result = Vec::new();
662        let mut bit_buffer: u64 = 0;
663        let mut bits_in_buffer: usize = 0;
664
665        for elem in elements {
666            bit_buffer |= elem.value() << bits_in_buffer;
667            bits_in_buffer += bits_per_elem;
668
669            while bits_in_buffer >= 8 {
670                result.push(
671                    u8::try_from(bit_buffer & u64::from(u8::MAX))
672                        .expect("low 8 bits of the bit buffer must fit in u8"),
673                );
674                bit_buffer >>= 8;
675                bits_in_buffer -= 8;
676            }
677        }
678
679        result
680    }
681
682    /// Encodes a UTF-8 string as bytes using [`encode_bytes`].
683    #[must_use]
684    pub fn encode_text(&self, text: &str) -> Vec<RingElement> {
685        self.encode_bytes(text.as_bytes())
686    }
687
688    /// Decodes elements into bytes with [`decode_bytes`] and interprets them as UTF-8.
689    ///
690    /// # Errors
691    ///
692    /// - Returns [`EncodingError::DecodingError`] if the decoded bytes are not
693    ///   valid UTF-8.
694    ///
695    /// # Notes
696    ///
697    /// This encoding does not embed the original byte length. Depending on the
698    /// modulus and upstream padding, the decoded byte stream may contain trailing
699    /// `0x00` bytes, which are valid UTF-8 and will appear as `\0` characters in the
700    /// returned string.
701    pub fn decode_text(&self, elements: &[RingElement]) -> Result<String, EncodingError> {
702        let bytes = self.decode_bytes(elements);
703        String::from_utf8(bytes).map_err(|e| EncodingError::DecodingError(e.to_string()))
704    }
705}
706
707/// Errors that can occur while encoding or decoding.
708#[derive(Debug, Clone, thiserror::Error)]
709pub enum EncodingError {
710    /// The provided modulus is too small to encode any payload bits.
711    #[error("Modulus too small for encoding")]
712    ModulusTooSmall,
713    /// Encoding/decoding failed due to malformed input or invalid UTF-8.
714    #[error("Decoding error: {0}")]
715    DecodingError(String),
716}
717
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_text_roundtrip_mod_256() {
        let encoder = Utf8Encoding::try_from(256, Utf8EncodingType::BitPacking).unwrap();

        let text = "Hello, World!";
        let enc = encoder.encode_text(text);
        let dec = encoder.decode_text(&enc).unwrap();
        assert_eq!(dec, text);
    }

    proptest! {
        #[test]
        fn proptest_text_roundtrip_default(
            text in proptest::string::string_regex(r"[\s\S]{0,200000}").unwrap(),
            modulus in 2u64..=511,
        ) {
            let encoder = Utf8Encoding::try_from(modulus, Utf8EncodingType::BitPacking).unwrap();

            let enc = encoder.encode_text(&text);
            let dec = encoder.decode_text(&enc).unwrap();
            prop_assert_eq!(dec, text);
        }
    }

    #[test]
    fn test_bit_stream_roundtrip_small_modulus() {
        let encoder = Utf8Encoding::try_from(16, Utf8EncodingType::BitPacking).unwrap();

        let data = vec![0xFF, 0x00, 0xAB, 0xCD];
        let enc = encoder.encode_bytes(&data); // bits_per_elem = 4
        let dec = encoder.decode_bytes(&enc);
        assert_eq!(dec[..data.len()], data[..]);
    }

    #[test]
    fn test_empty() {
        let encoder = Utf8Encoding::try_from(16, Utf8EncodingType::BitPacking).unwrap();

        let enc = encoder.encode_bytes(&[]);
        let dec = encoder.decode_bytes(&enc);
        assert!(dec.is_empty());
    }

    #[test]
    fn test_modulus_too_small() {
        let encoder = Utf8Encoding::try_from(1, Utf8EncodingType::BitPacking);
        assert!(matches!(encoder, Err(EncodingError::ModulusTooSmall)));
    }

    #[test]
    fn test_length_prefix_digits_known_values() {
        // Smallest k with modulus^k >= 2^64.
        assert_eq!(Utf8Encoding::length_prefix_digits(2), 64);
        assert_eq!(Utf8Encoding::length_prefix_digits(3), 41);
        assert_eq!(Utf8Encoding::length_prefix_digits(256), 8);
        assert_eq!(Utf8Encoding::length_prefix_digits(257), 8);
        assert_eq!(Utf8Encoding::length_prefix_digits(1 << 32), 2);
    }

    #[test]
    fn test_base_m_len_roundtrip_bytes_various_moduli() {
        let data = vec![0x12, 0x00, 0x00, 0xFF, 0x00];
        let moduli = [2u64, 3, 13, 128, 257, 65535];
        for &m in &moduli {
            let encoder =
                Utf8Encoding::try_from(m, Utf8EncodingType::LengthDelimitedBaseMTransduction)
                    .unwrap();
            let enc = encoder.encode_bytes_base_m_len(&data).unwrap();
            assert!(enc.iter().all(|e| e.value() < m));
            let dec = encoder.decode_bytes_base_m_len(&enc).unwrap();
            assert_eq!(dec, data);
        }
    }

    #[test]
    fn test_base_m_len_empty_roundtrip() {
        for &m in &[2u64, 10, 257, 65535] {
            let encoder =
                Utf8Encoding::try_from(m, Utf8EncodingType::LengthDelimitedBaseMTransduction)
                    .unwrap();
            let enc = encoder.encode_bytes_base_m_len(&[]).unwrap();
            assert!(enc.iter().all(|e| e.value() < m));
            let dec = encoder.decode_bytes_base_m_len(&enc).unwrap();
            assert!(dec.is_empty());
        }
    }

    #[test]
    fn test_base_m_len_zero_padding_invariance() {
        let data = vec![0x00, 0x00, 0x01, 0x02, 0x00, 0x00];

        let encoder =
            Utf8Encoding::try_from(257u64, Utf8EncodingType::LengthDelimitedBaseMTransduction)
                .unwrap();

        let enc = encoder.encode_bytes_base_m_len(&data).unwrap();
        let mut padded = enc;
        for _ in 0..17 {
            padded.push(RingElement::zero(257u64));
        }
        let dec = encoder.decode_bytes_base_m_len(&padded).unwrap();
        assert_eq!(dec, data);
    }

    #[test]
    fn test_base_m_len_trailing_garbage_ignored() {
        let data = vec![0x10, 0x20, 0x00, 0xFF, 0x01];
        let modulus = 257u64;

        let encoder =
            Utf8Encoding::try_from(modulus, Utf8EncodingType::LengthDelimitedBaseMTransduction)
                .unwrap();

        let mut enc = encoder.encode_bytes_base_m_len(&data).unwrap();
        enc.push(RingElement::new(1, modulus));
        enc.push(RingElement::new(200, modulus));
        enc.push(RingElement::new(256, modulus));
        let dec = encoder.decode_bytes_base_m_len(&enc).unwrap();
        assert_eq!(dec, data);
    }

    #[test]
    fn test_base_m_len_truncated_payload_errors() {
        // The rANS decoder consumes exactly the digits the encoder emitted, so
        // dropping the last payload digit must surface as an error.
        let data = vec![0x10, 0x20, 0x00, 0xFF, 0x01];
        let encoder =
            Utf8Encoding::try_from(257u64, Utf8EncodingType::LengthDelimitedBaseMTransduction)
                .unwrap();

        let mut enc = encoder.encode_bytes_base_m_len(&data).unwrap();
        enc.pop();
        let res = encoder.decode_bytes_base_m_len(&enc);
        assert!(matches!(res, Err(EncodingError::DecodingError(_))));
    }

    #[test]
    fn test_base_m_len_modulus_too_large_errors() {
        let res =
            Utf8Encoding::try_from(u64::MAX, Utf8EncodingType::LengthDelimitedBaseMTransduction);
        assert!(matches!(res, Err(EncodingError::DecodingError(_))));
    }

    #[test]
    fn test_base_m_len_decode_empty_errors() {
        let encoder =
            Utf8Encoding::try_from(10, Utf8EncodingType::LengthDelimitedBaseMTransduction).unwrap();
        let res = encoder.decode_bytes_base_m_len(&[]);
        assert!(matches!(res, Err(EncodingError::DecodingError(_))));
    }

    #[test]
    fn test_base_m_len_text_roundtrip() {
        let text = "Hello\u{0000}\u{0000}";
        let encoder =
            Utf8Encoding::try_from(257, Utf8EncodingType::LengthDelimitedBaseMTransduction)
                .unwrap();
        let enc = encoder.encode_text_base_m_len(text).unwrap();
        let dec = encoder.decode_text_base_m_len(&enc).unwrap();
        assert_eq!(dec, text);
    }

    proptest! {
        #[test]
        fn proptest_text_roundtrip_with_base_m_len(
            text in proptest::string::string_regex(r"[\s\S]{0,200000}").unwrap(),
            modulus in 2u64..=400,
        ) {
            let encoder = Utf8Encoding::try_from(modulus, Utf8EncodingType::LengthDelimitedBaseMTransduction).unwrap();
            let enc = encoder.encode_text_base_m_len(&text).unwrap();
            let dec = encoder.decode_text_base_m_len(&enc).unwrap();
            prop_assert_eq!(dec, text);
        }
    }
}