Skip to main content

phasm_core/stego/armor/
ecc.rs

1// Copyright (c) 2026 Christoph Gaffga
2// SPDX-License-Identifier: GPL-3.0-only
3// https://github.com/cgaffga/phasmcore
4
5//! Reed-Solomon error correction over GF(2^8).
6//!
7//! Implements RS(255, k) with the primitive polynomial 0x11D (x^8+x^4+x^3+x^2+1).
8//! Supports systematic encoding and Berlekamp-Massey decoding with Chien search
9//! and Forney algorithm. Shortened codes are used for payloads smaller than k.
10
/// Primitive (field-generating) polynomial x^8 + x^4 + x^3 + x^2 + 1 for GF(2^8).
const PRIM_POLY: u16 = 0b1_0001_1101;

/// Length of a full (unshortened) RS codeword over GF(2^8): 2^8 - 1 symbols.
const N_MAX: usize = 255;

/// Data symbols carried by a full block of the default code.
const K_DEFAULT: usize = 191;

/// Parity symbols per block of the default code, `n - k` (= 64).
const PARITY_LEN: usize = N_MAX - K_DEFAULT;

/// Symbol errors the default code can correct per block, `t = (n - k) / 2` (= 32).
pub const T_MAX: usize = PARITY_LEN / 2;

/// Parity lengths offered by adaptive RS, in ascending order. Keeping the set
/// to four values limits the decoder's search space to four attempts.
pub const PARITY_TIERS: [usize; 4] = [64, 128, 192, 240];
28
29// --- GF(2^8) Arithmetic ---
30
/// Log/antilog lookup tables backing all GF(2^8) arithmetic in this module.
struct GfTables {
    /// Antilog table: `exp[i] = α^i`. The 255 distinct powers are stored twice
    /// so that `exp[log a + log b]` can be indexed without reducing mod 255.
    exp: [u8; 512],
    /// Log table: `log[α^i] = i` for nonzero elements; `log[0]` is meaningless.
    log: [u8; 256],
}
36
37/// Build log/exp tables at compile time is not possible in const context with loops
38/// in stable Rust, so we build them once at runtime.
39fn build_gf_tables() -> GfTables {
40    let mut exp = [0u8; 512];
41    let mut log = [0u8; 256];
42
43    let mut x: u16 = 1;
44    for i in 0..255u16 {
45        exp[i as usize] = x as u8;
46        exp[(i + 255) as usize] = x as u8; // wrap-around for easy modular access
47        log[x as usize] = i as u8;
48        x <<= 1;
49        if x & 0x100 != 0 {
50            x ^= PRIM_POLY;
51        }
52    }
53    // log[0] is undefined (log of 0 doesn't exist), leave as 0
54    // exp[510] and exp[511] are unused padding
55    exp[510] = exp[0];
56    exp[511] = exp[1];
57
58    GfTables { exp, log }
59}
60
61fn gf_tables() -> &'static GfTables {
62    use std::sync::OnceLock;
63    static TABLES: OnceLock<GfTables> = OnceLock::new();
64    TABLES.get_or_init(build_gf_tables)
65}
66
67/// GF(2^8) multiplication.
68fn gf_mul(a: u8, b: u8) -> u8 {
69    if a == 0 || b == 0 {
70        return 0;
71    }
72    let t = gf_tables();
73    let log_sum = t.log[a as usize] as usize + t.log[b as usize] as usize;
74    t.exp[log_sum]
75}
76
/// Add two GF(2^8) elements. In characteristic 2, addition and subtraction are
/// both bitwise XOR.
fn gf_add(a: u8, b: u8) -> u8 {
    b ^ a
}
81
82/// GF(2^8) multiplicative inverse. Panics if a == 0.
83fn gf_inv(a: u8) -> u8 {
84    assert_ne!(a, 0, "cannot invert zero in GF(2^8)");
85    let t = gf_tables();
86    t.exp[255 - t.log[a as usize] as usize]
87}
88
/// GF(2^8) power: a^n.
///
/// `0^0` is defined as 1 and `0^n` as 0 for `n > 0`. For nonzero `a`, the
/// exponent is reduced modulo 255 first, because every nonzero element
/// satisfies `a^255 = 1`.
#[cfg(test)]
fn gf_pow(a: u8, n: u32) -> u8 {
    if a == 0 {
        return if n == 0 { 1 } else { 0 };
    }
    let t = gf_tables();
    // Reducing `n` first keeps `log_a * n` within 254 * 254. The unreduced
    // product overflows u32 for n > ~16.9 million.
    let n = n % 255;
    let log_a = u32::from(t.log[a as usize]);
    let exp_idx = (log_a * n) % 255;
    t.exp[exp_idx as usize]
}
100
101/// Evaluate polynomial at x. poly[0] is the highest-degree coefficient.
102fn poly_eval(poly: &[u8], x: u8) -> u8 {
103    let mut result = 0u8;
104    for &coeff in poly {
105        result = gf_add(gf_mul(result, x), coeff);
106    }
107    result
108}
109
110/// Multiply two polynomials. poly[0] is highest-degree coefficient.
111fn poly_mul(a: &[u8], b: &[u8]) -> Vec<u8> {
112    let mut result = vec![0u8; a.len() + b.len() - 1];
113    for (i, &ac) in a.iter().enumerate() {
114        for (j, &bc) in b.iter().enumerate() {
115            result[i + j] = gf_add(result[i + j], gf_mul(ac, bc));
116        }
117    }
118    result
119}
120
121// --- Generator Polynomial ---
122
123/// Build the RS generator polynomial g(x) = prod_{i=0}^{2t-1} (x - alpha^i).
124/// Returns coefficients from highest to lowest degree.
125fn build_gen_poly(parity_len: usize) -> Vec<u8> {
126    let t = gf_tables();
127    let mut gpoly = vec![1u8]; // Start with 1
128
129    for i in 0..parity_len {
130        let root = t.exp[i]; // alpha^i
131        gpoly = poly_mul(&gpoly, &[1, root]);
132    }
133    gpoly
134}
135
136fn gen_poly() -> &'static Vec<u8> {
137    use std::sync::OnceLock;
138    static GEN: OnceLock<Vec<u8>> = OnceLock::new();
139    GEN.get_or_init(|| build_gen_poly(PARITY_LEN))
140}
141
142/// Cached generator polynomial for a given parity length.
143/// Uses OnceLock per tier to avoid recomputing.
144fn gen_poly_for(parity_len: usize) -> &'static Vec<u8> {
145    use std::sync::OnceLock;
146    static GEN_4: OnceLock<Vec<u8>> = OnceLock::new();
147    static GEN_8: OnceLock<Vec<u8>> = OnceLock::new();
148    static GEN_16: OnceLock<Vec<u8>> = OnceLock::new();
149    static GEN_32: OnceLock<Vec<u8>> = OnceLock::new();
150    static GEN_64: OnceLock<Vec<u8>> = OnceLock::new();
151    static GEN_128: OnceLock<Vec<u8>> = OnceLock::new();
152    static GEN_192: OnceLock<Vec<u8>> = OnceLock::new();
153    static GEN_240: OnceLock<Vec<u8>> = OnceLock::new();
154
155    match parity_len {
156        4 => GEN_4.get_or_init(|| build_gen_poly(4)),
157        8 => GEN_8.get_or_init(|| build_gen_poly(8)),
158        16 => GEN_16.get_or_init(|| build_gen_poly(16)),
159        32 => GEN_32.get_or_init(|| build_gen_poly(32)),
160        64 => GEN_64.get_or_init(|| build_gen_poly(64)),
161        128 => GEN_128.get_or_init(|| build_gen_poly(128)),
162        192 => GEN_192.get_or_init(|| build_gen_poly(192)),
163        240 => GEN_240.get_or_init(|| build_gen_poly(240)),
164        _ => {
165            // For the default parity, reuse the existing cache
166            if parity_len == PARITY_LEN {
167                gen_poly()
168            } else {
169                panic!("unsupported parity length: {parity_len}")
170            }
171        }
172    }
173}
174
175// --- Encoding ---
176
177/// RS-encode a single data block (systematic encoding).
178///
179/// # Arguments
180/// - `data`: Data bytes of length <= `K_DEFAULT` (191).
181///
182/// # Returns
183/// A vector of `data.len() + PARITY_LEN` bytes: the original data followed
184/// by 64 parity symbols.
185///
186/// # Panics
187/// Panics if `data.len() > K_DEFAULT`.
188///
189/// For shortened codes (`data.len() < K_DEFAULT`), the data is conceptually
190/// zero-padded at the front to K_DEFAULT, encoded, then the padding is removed.
191/// The parity symbols are computed over this virtual full-length block.
192pub fn rs_encode(data: &[u8]) -> Vec<u8> {
193    assert!(
194        data.len() <= K_DEFAULT,
195        "data length {} exceeds max {}",
196        data.len(),
197        K_DEFAULT
198    );
199
200    let gpoly = gen_poly();
201    let parity_len = PARITY_LEN;
202
203    // Systematic encoding: compute remainder of data * x^parity_len / g(x).
204    // Work with the actual data length (shortened code).
205    let mut shift_reg = vec![0u8; parity_len];
206
207    for &byte in data {
208        let feedback = gf_add(byte, shift_reg[0]);
209        // Shift left
210        for j in 0..parity_len - 1 {
211            shift_reg[j] = gf_add(shift_reg[j + 1], gf_mul(feedback, gpoly[j + 1]));
212        }
213        shift_reg[parity_len - 1] = gf_mul(feedback, gpoly[parity_len]);
214    }
215
216    // Output: data || parity
217    let mut encoded = Vec::with_capacity(data.len() + parity_len);
218    encoded.extend_from_slice(data);
219    encoded.extend_from_slice(&shift_reg);
220    encoded
221}
222
223/// RS-encode an arbitrarily long payload, splitting into [`K_DEFAULT`]-byte blocks.
224///
225/// Returns the concatenation of all RS-encoded blocks. Each block has
226/// `min(remaining_data, K_DEFAULT) + PARITY_LEN` bytes. The last block
227/// may be a shortened code if `payload.len() % K_DEFAULT != 0`.
228pub fn rs_encode_blocks(payload: &[u8]) -> Vec<u8> {
229    let mut encoded = Vec::new();
230    for chunk in payload.chunks(K_DEFAULT) {
231        encoded.extend_from_slice(&rs_encode(chunk));
232    }
233    encoded
234}
235
236// --- Decoding ---
237
238/// Compute syndromes S_0 .. S_{2t-1} for a received block (FCR=0).
239/// poly_eval treats received as highest-degree-first: r(x) = received[0]*x^{n-1} + ...
240fn compute_syndromes(received: &[u8]) -> Vec<u8> {
241    let tab = gf_tables();
242    let two_t = PARITY_LEN;
243    let mut syndromes = vec![0u8; two_t];
244    for i in 0..two_t {
245        syndromes[i] = poly_eval(received, tab.exp[i]); // S_i = r(α^i)
246    }
247    syndromes
248}
249
/// Returns `true` when every syndrome is zero, i.e. the block is already a valid
/// codeword. An empty slice counts as all-zero.
fn syndromes_are_zero(syndromes: &[u8]) -> bool {
    syndromes.iter().fold(0u8, |acc, &s| acc | s) == 0
}
253
/// Berlekamp-Massey algorithm.
///
/// Finds the error-locator polynomial sigma(x): the shortest LFSR connection
/// polynomial that generates the syndrome sequence `syndromes`
/// (length 2t = `syndromes.len()`).
///
/// Returns sigma(x) coefficients in ascending power: sigma[0]=1, sigma[1]=σ_1, etc.
/// Callers treat `len() - 1` of the result as the number of errors.
///
/// NOTE(review): the returned length is the tracked `c_len`, which only grows
/// via `max(b_len + m, c_len)` and is never trimmed. If the top coefficient
/// cancels to zero, `len() - 1` can exceed the true degree. Confirm this only
/// happens for uncorrectable inputs. In that case `chien_search`'s root-count
/// check returns `None`.
fn berlekamp_massey(syndromes: &[u8]) -> Vec<u8> {
    let n = syndromes.len(); // 2t

    // C(x) = error locator, ascending power
    let mut c = vec![0u8; n + 1];
    c[0] = 1;
    let mut c_len = 1usize;

    // B(x) = previous C, ascending power
    let mut b = vec![0u8; n + 1];
    b[0] = 1;
    let mut b_len = 1usize;

    let mut ell = 0usize; // current error count estimate (LFSR length L)
    let mut bval = 1u8; // previous discrepancy
    let mut m = 1usize; // shift applied to B(x): steps since the last length change

    for r in 0..n {
        // Discrepancy: delta = S_r + sum_{i>=1} C_i * S_{r-i}
        let mut delta = syndromes[r];
        for i in 1..c_len {
            delta = gf_add(delta, gf_mul(c[i], syndromes[r - i]));
        }

        // Current C(x) already predicts S_r; just widen the shift for B(x).
        if delta == 0 {
            m += 1;
            continue;
        }

        // Correction scale delta / bval used in C(x) -= (delta/bval) * x^m * B(x).
        let factor = gf_mul(delta, gf_inv(bval));

        if 2 * ell <= r {
            // Length change: LFSR must grow to r + 1 - L.
            // Save C before updating (it becomes the new B)
            let old_c = c[..c_len].to_vec();
            let old_c_len = c_len;

            let new_len = (b_len + m).max(c_len);
            c_len = new_len;
            for j in 0..b_len {
                c[j + m] = gf_add(c[j + m], gf_mul(factor, b[j]));
            }

            // B(x) <- old C(x), clearing any stale higher coefficients.
            b[..old_c_len].copy_from_slice(&old_c[..old_c_len]);
            for j in old_c_len..b.len() {
                b[j] = 0;
            }
            b_len = old_c_len;
            ell = r + 1 - ell;
            bval = delta;
            m = 1;
        } else {
            // No length change: adjust C(x) in place, keep B(x) and widen its shift.
            let new_len = (b_len + m).max(c_len);
            c_len = new_len;
            for j in 0..b_len {
                c[j + m] = gf_add(c[j + m], gf_mul(factor, b[j]));
            }
            m += 1;
        }
    }

    c[..c_len].to_vec()
}
319
320/// Evaluate polynomial in ascending power format at x.
321fn eval_asc(poly: &[u8], x: u8) -> u8 {
322    let mut result = 0u8;
323    let mut x_pow = 1u8;
324    for &coeff in poly {
325        result = gf_add(result, gf_mul(coeff, x_pow));
326        x_pow = gf_mul(x_pow, x);
327    }
328    result
329}
330
331/// Chien search: find roots of sigma(x) to determine error positions.
332///
333/// Convention: poly_eval treats the codeword as c(x) = c[0]*x^{n-1} + c[1]*x^{n-2} + ...
334/// An error at array index k affects the coefficient of x^{n-1-k}.
335/// The error locator polynomial sigma(x) has roots at X_l^{-1} where X_l = α^{n-1-k_l}.
336///
337/// Returns (gf_pos, array_pos) pairs.
338fn chien_search(sigma_asc: &[u8], n: usize) -> Option<Vec<(usize, usize)>> {
339    if n == 0 {
340        return None;
341    }
342    let tab = gf_tables();
343    let num_errors = sigma_asc.len() - 1;
344    let mut found = Vec::with_capacity(num_errors);
345
346    // Test sigma at α^{-p} for each GF position p = 0..n-1.
347    // If sigma(α^{-p}) = 0, error at GF position p → array index n-1-p.
348    for p in 0..n {
349        let x = if p == 0 {
350            1u8
351        } else {
352            tab.exp[(255 - (p % 255)) % 255] // α^{-p}
353        };
354        if eval_asc(sigma_asc, x) == 0 {
355            found.push((p, n - 1 - p));
356        }
357    }
358
359    if found.len() != num_errors {
360        return None;
361    }
362
363    Some(found)
364}
365
366/// Forney algorithm: compute error magnitudes.
367///
368/// With FCR=0: e_l = X_l * Omega(X_l^{-1}) / Sigma'(X_l^{-1})
369/// where X_l = α^{gf_pos}, and Omega = S(x) * Sigma(x) mod x^{2t},
370/// S(x) = S_0 + S_1*x + S_2*x^2 + ...
371fn forney(
372    sigma_asc: &[u8],
373    syndromes: &[u8],
374    found: &[(usize, usize)],
375) -> Vec<u8> {
376    let tab = gf_tables();
377    let two_t = syndromes.len();
378
379    // Omega(x) = S(x) * Sigma(x) mod x^{2t} (ascending power)
380    let mut omega = vec![0u8; two_t];
381    for i in 0..sigma_asc.len().min(two_t) {
382        for j in 0..two_t {
383            if i + j < two_t {
384                omega[i + j] = gf_add(omega[i + j], gf_mul(sigma_asc[i], syndromes[j]));
385            }
386        }
387    }
388
389    // Formal derivative of sigma (ascending power):
390    // d/dx (a_0 + a_1*x + a_2*x^2 + ...) = a_1 + 2*a_2*x + 3*a_3*x^2 + ...
391    // In GF(2^m): even multipliers vanish, odd survive (3=1, 5=1, etc.)
392    // So sigma'[j] = sigma_asc[j+1] if (j+1) is odd, else 0
393    let deriv_len = sigma_asc.len().saturating_sub(1);
394    let mut sigma_prime = vec![0u8; deriv_len];
395    for i in (1..sigma_asc.len()).step_by(2) {
396        sigma_prime[i - 1] = sigma_asc[i];
397    }
398
399    let mut magnitudes = Vec::with_capacity(found.len());
400    for &(gf_pos, _) in found {
401        let x_val = if gf_pos == 0 {
402            1u8
403        } else {
404            tab.exp[gf_pos % 255] // α^{gf_pos} = X_l
405        };
406        let x_inv = if gf_pos == 0 {
407            1u8
408        } else {
409            tab.exp[(255 - (gf_pos % 255)) % 255] // α^{-gf_pos} = X_l^{-1}
410        };
411
412        let omega_val = eval_asc(&omega, x_inv);
413        let sp_val = eval_asc(&sigma_prime, x_inv);
414
415        if sp_val == 0 {
416            magnitudes.push(0);
417            continue;
418        }
419
420        // FCR=0: e = X_l * Omega(X_l^{-1}) / Sigma'(X_l^{-1})
421        magnitudes.push(gf_mul(x_val, gf_mul(omega_val, gf_inv(sp_val))));
422    }
423
424    magnitudes
425}
426
/// Error returned when RS decoding fails.
///
/// The decoders return it when:
/// - a block has more symbol errors than the code can correct;
/// - the located errors are inconsistent, e.g. the root count does not match
///   the locator degree, an error falls in the virtual zero padding of a
///   shortened code, or corrections do not clear the syndromes;
/// - the encoded input is too short for the requested data length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RsDecodeError;

impl core::fmt::Display for RsDecodeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Reed-Solomon: too many errors to correct")
    }
}

/// Lets `RsDecodeError` be boxed as `dyn Error` and used with `?` in callers
/// that return `Box<dyn Error>`.
impl std::error::Error for RsDecodeError {}
436
437/// RS-decode a single block with error correction.
438///
439/// # Arguments
440/// - `received`: Received block of length `data_len + PARITY_LEN`.
441/// - `data_len`: Original data length (before parity was appended).
442///
443/// # Returns
444/// A tuple of (corrected data of length `data_len`, number of symbol errors corrected).
445///
446/// # Panics
447/// Panics if `received.len() != data_len + PARITY_LEN`.
448///
449/// # Errors
450/// Returns [`RsDecodeError`] if there are more than `t=32` symbol errors,
451/// or if errors fall in the zero-padded region of a shortened code.
452///
453/// For shortened codes (`data_len < K_DEFAULT`), the received block is
454/// conceptually zero-padded at the front to form a full 255-symbol block
455/// during syndrome computation.
456pub fn rs_decode(received: &[u8], data_len: usize) -> Result<(Vec<u8>, usize), RsDecodeError> {
457    let block_len = data_len + PARITY_LEN;
458    assert_eq!(
459        received.len(),
460        block_len,
461        "received length {} != expected {}",
462        received.len(),
463        block_len
464    );
465
466    // For shortened codes, prepend zeros to make a full 255-symbol block
467    let padding = N_MAX - block_len;
468    let mut full_block = vec![0u8; N_MAX];
469    full_block[padding..].copy_from_slice(received);
470
471    // Compute syndromes on the full block
472    let syndromes = compute_syndromes(&full_block);
473
474    if syndromes_are_zero(&syndromes) {
475        return Ok((received[..data_len].to_vec(), 0));
476    }
477
478    // Find error locator polynomial (ascending power)
479    let sigma_asc = berlekamp_massey(&syndromes);
480    let num_errors = sigma_asc.len() - 1;
481
482    if num_errors > T_MAX {
483        return Err(RsDecodeError);
484    }
485
486    // Chien search: find (gf_pos, array_pos) pairs in the full 255-symbol block
487    let found = chien_search(&sigma_asc, N_MAX).ok_or(RsDecodeError)?;
488
489    // Forney: compute error magnitudes
490    let magnitudes = forney(&sigma_asc, &syndromes, &found);
491
492    // Apply corrections
493    let mut corrected = full_block;
494    for (i, &(_, array_pos)) in found.iter().enumerate() {
495        if array_pos < padding {
496            // Error in the zero-padded region of a shortened code — can't correct
497            return Err(RsDecodeError);
498        }
499        corrected[array_pos] = gf_add(corrected[array_pos], magnitudes[i]);
500    }
501
502    // Verify syndromes are now zero
503    let check_syndromes = compute_syndromes(&corrected);
504    if !syndromes_are_zero(&check_syndromes) {
505        return Err(RsDecodeError);
506    }
507
508    // Extract data (skip padding)
509    Ok((corrected[padding..padding + data_len].to_vec(), num_errors))
510}
511
/// Statistics from RS decoding across all blocks.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RsDecodeStats {
    /// Total symbol errors corrected across all blocks.
    pub total_errors: usize,
    /// Total correctable-error budget across all decoded blocks: the per-block
    /// capability t (`T_MAX` for the default code, `parity_len / 2` for
    /// adaptive parity) times `num_blocks`.
    pub error_capacity: usize,
    /// Maximum errors found in any single block.
    pub max_block_errors: usize,
    /// Number of RS blocks decoded.
    pub num_blocks: usize,
}
524
525/// RS-decode a payload that was encoded with [`rs_encode_blocks`].
526///
527/// Splits the encoded data into blocks based on `total_data_len`, decodes
528/// each block independently (correcting up to `t=32` symbol errors per block),
529/// and concatenates the results. Returns decode stats with error counts.
530///
531/// # Arguments
532/// - `encoded`: The RS-encoded data (output of `rs_encode_blocks`).
533/// - `total_data_len`: The original payload length before encoding.
534///
535/// # Errors
536/// Returns [`RsDecodeError`] if any block has too many errors to correct
537/// or the encoded data is too short.
538pub fn rs_decode_blocks(encoded: &[u8], total_data_len: usize) -> Result<(Vec<u8>, RsDecodeStats), RsDecodeError> {
539    let mut decoded = Vec::with_capacity(total_data_len);
540    let mut remaining_data = total_data_len;
541    let mut offset = 0;
542    let mut stats = RsDecodeStats::default();
543
544    while remaining_data > 0 {
545        let chunk_data_len = remaining_data.min(K_DEFAULT);
546        let block_len = chunk_data_len + PARITY_LEN;
547
548        if offset + block_len > encoded.len() {
549            return Err(RsDecodeError);
550        }
551
552        let block = &encoded[offset..offset + block_len];
553        let (data, errors) = rs_decode(block, chunk_data_len)?;
554        decoded.extend_from_slice(&data);
555
556        stats.total_errors += errors;
557        stats.num_blocks += 1;
558        if errors > stats.max_block_errors {
559            stats.max_block_errors = errors;
560        }
561
562        offset += block_len;
563        remaining_data -= chunk_data_len;
564    }
565
566    stats.error_capacity = stats.num_blocks * T_MAX;
567    Ok((decoded, stats))
568}
569
570/// Return the RS-encoded length for a given data length.
571pub fn rs_encoded_len(data_len: usize) -> usize {
572    let full_blocks = data_len / K_DEFAULT;
573    let remainder = data_len % K_DEFAULT;
574    let mut total = full_blocks * (K_DEFAULT + PARITY_LEN);
575    if remainder > 0 {
576        total += remainder + PARITY_LEN;
577    }
578    total
579}
580
581/// Return the parity length per block.
582pub const fn parity_len() -> usize {
583    PARITY_LEN
584}
585
586// --- Adaptive RS with configurable parity ---
587
588/// RS-encode a single data block with configurable parity length.
589///
590/// # Arguments
591/// - `data`: Data bytes of length <= `255 - parity_len`.
592/// - `parity_len`: Number of parity symbols (must be even, <= 240).
593///
594/// # Returns
595/// A vector of `data.len() + parity_len` bytes.
596pub fn rs_encode_with_parity(data: &[u8], parity_len: usize) -> Vec<u8> {
597    if parity_len == 0 { return data.to_vec(); }
598    let k_max = N_MAX - parity_len;
599    assert!(
600        data.len() <= k_max,
601        "data length {} exceeds max {} for parity_len={}",
602        data.len(),
603        k_max,
604        parity_len
605    );
606    assert!(parity_len <= 240, "parity_len {} exceeds 240", parity_len);
607
608    let gpoly = gen_poly_for(parity_len);
609    let mut shift_reg = vec![0u8; parity_len];
610
611    for &byte in data {
612        let feedback = gf_add(byte, shift_reg[0]);
613        for j in 0..parity_len - 1 {
614            shift_reg[j] = gf_add(shift_reg[j + 1], gf_mul(feedback, gpoly[j + 1]));
615        }
616        shift_reg[parity_len - 1] = gf_mul(feedback, gpoly[parity_len]);
617    }
618
619    let mut encoded = Vec::with_capacity(data.len() + parity_len);
620    encoded.extend_from_slice(data);
621    encoded.extend_from_slice(&shift_reg);
622    encoded
623}
624
625/// RS-decode a single block with configurable parity length.
626///
627/// # Arguments
628/// - `received`: Received block of length `data_len + parity_len`.
629/// - `data_len`: Original data length.
630/// - `parity_len`: Parity symbols used during encoding.
631///
632/// # Returns
633/// (corrected data, number of errors corrected).
634pub fn rs_decode_with_parity(
635    received: &[u8],
636    data_len: usize,
637    parity_len: usize,
638) -> Result<(Vec<u8>, usize), RsDecodeError> {
639    let block_len = data_len + parity_len;
640    assert_eq!(
641        received.len(),
642        block_len,
643        "received length {} != expected {}",
644        received.len(),
645        block_len
646    );
647
648    let padding = N_MAX - block_len;
649    let mut full_block = vec![0u8; N_MAX];
650    full_block[padding..].copy_from_slice(received);
651
652    // Compute syndromes with the given parity length
653    let tab = gf_tables();
654    let mut syndromes = vec![0u8; parity_len];
655    for i in 0..parity_len {
656        syndromes[i] = poly_eval(&full_block, tab.exp[i]);
657    }
658
659    if syndromes.iter().all(|&s| s == 0) {
660        return Ok((received[..data_len].to_vec(), 0));
661    }
662
663    let t_max = parity_len / 2;
664    let sigma_asc = berlekamp_massey(&syndromes);
665    let num_errors = sigma_asc.len() - 1;
666
667    if num_errors > t_max {
668        return Err(RsDecodeError);
669    }
670
671    let found = chien_search(&sigma_asc, N_MAX).ok_or(RsDecodeError)?;
672    let magnitudes = forney(&sigma_asc, &syndromes, &found);
673
674    let mut corrected = full_block;
675    for (i, &(_, array_pos)) in found.iter().enumerate() {
676        if array_pos < padding {
677            return Err(RsDecodeError);
678        }
679        corrected[array_pos] = gf_add(corrected[array_pos], magnitudes[i]);
680    }
681
682    // Verify syndromes are now zero
683    let mut check_ok = true;
684    for i in 0..parity_len {
685        if poly_eval(&corrected, tab.exp[i]) != 0 {
686            check_ok = false;
687            break;
688        }
689    }
690    if !check_ok {
691        return Err(RsDecodeError);
692    }
693
694    Ok((corrected[padding..padding + data_len].to_vec(), num_errors))
695}
696
697/// RS-encode an arbitrarily long payload with configurable parity, splitting into blocks.
698pub fn rs_encode_blocks_with_parity(payload: &[u8], parity_len: usize) -> Vec<u8> {
699    let k_max = N_MAX - parity_len;
700    let mut encoded = Vec::new();
701    for chunk in payload.chunks(k_max) {
702        encoded.extend_from_slice(&rs_encode_with_parity(chunk, parity_len));
703    }
704    encoded
705}
706
707/// RS-decode a payload encoded with [`rs_encode_blocks_with_parity`].
708pub fn rs_decode_blocks_with_parity(
709    encoded: &[u8],
710    total_data_len: usize,
711    parity_len: usize,
712) -> Result<(Vec<u8>, RsDecodeStats), RsDecodeError> {
713    let k_max = N_MAX - parity_len;
714    let t_max = parity_len / 2;
715    let mut decoded = Vec::with_capacity(total_data_len);
716    let mut remaining_data = total_data_len;
717    let mut offset = 0;
718    let mut stats = RsDecodeStats::default();
719
720    while remaining_data > 0 {
721        let chunk_data_len = remaining_data.min(k_max);
722        let block_len = chunk_data_len + parity_len;
723
724        if offset + block_len > encoded.len() {
725            return Err(RsDecodeError);
726        }
727
728        let block = &encoded[offset..offset + block_len];
729        let (data, errors) = rs_decode_with_parity(block, chunk_data_len, parity_len)?;
730        decoded.extend_from_slice(&data);
731
732        stats.total_errors += errors;
733        stats.num_blocks += 1;
734        if errors > stats.max_block_errors {
735            stats.max_block_errors = errors;
736        }
737
738        offset += block_len;
739        remaining_data -= chunk_data_len;
740    }
741
742    stats.error_capacity = stats.num_blocks * t_max;
743    Ok((decoded, stats))
744}
745
746/// Return the RS-encoded length for a given data length and parity length.
747pub fn rs_encoded_len_with_parity(data_len: usize, parity_len: usize) -> usize {
748    let k_max = N_MAX - parity_len;
749    let full_blocks = data_len / k_max;
750    let remainder = data_len % k_max;
751    let mut total = full_blocks * (k_max + parity_len);
752    if remainder > 0 {
753        total += remainder + parity_len;
754    }
755    total
756}
757
758/// Choose the best parity tier for a given frame size and embedding capacity (in bits).
759///
760/// Picks the largest parity from [`PARITY_TIERS`] where the RS-encoded data
761/// (in bits) still fits within `num_units`.
762pub fn choose_parity_tier(frame_len: usize, num_units: usize) -> usize {
763    let mut best = PARITY_TIERS[0]; // fallback to smallest
764    for &tier in &PARITY_TIERS {
765        let rs_bits = rs_encoded_len_with_parity(frame_len, tier) * 8;
766        if rs_bits <= num_units {
767            best = tier;
768        } else {
769            break;
770        }
771    }
772    best
773}
774
775#[cfg(test)]
776mod tests {
777    use super::*;
778
779    #[test]
780    fn gf_mul_identity() {
781        for a in 0..=255u16 {
782            assert_eq!(gf_mul(a as u8, 1), a as u8);
783            assert_eq!(gf_mul(1, a as u8), a as u8);
784        }
785    }
786
787    #[test]
788    fn gf_mul_zero() {
789        for a in 0..=255u16 {
790            assert_eq!(gf_mul(a as u8, 0), 0);
791            assert_eq!(gf_mul(0, a as u8), 0);
792        }
793    }
794
795    #[test]
796    fn gf_inverse_roundtrip() {
797        for a in 1..=255u16 {
798            let inv = gf_inv(a as u8);
799            assert_eq!(gf_mul(a as u8, inv), 1, "a={a}, inv={inv}");
800        }
801    }
802
803    #[test]
804    fn gf_pow_consistency() {
805        let t = gf_tables();
806        for a in 1..=255u16 {
807            // a^1 == a
808            assert_eq!(gf_pow(a as u8, 1), a as u8);
809            // a^0 == 1
810            assert_eq!(gf_pow(a as u8, 0), 1);
811            // a^255 == 1 (Fermat's little theorem for GF(2^8))
812            assert_eq!(gf_pow(a as u8, 255), 1, "a={a}");
813        }
814        let _ = t;
815    }
816
817    #[test]
818    fn encode_decode_no_errors() {
819        let data = b"Hello, Reed-Solomon!";
820        let encoded = rs_encode(data);
821        let (decoded, errors) = rs_decode(&encoded, data.len()).unwrap();
822        assert_eq!(decoded, data);
823        assert_eq!(errors, 0);
824    }
825
826    #[test]
827    fn encode_decode_with_errors() {
828        let data = b"Test message for RS error correction.";
829        let mut encoded = rs_encode(data);
830
831        // Introduce 10 symbol errors (well within t=32 correction capability).
832        // Note: data.len()=37, so avoid position 37 in the data region to
833        // prevent overlap with the first parity error at data.len().
834        encoded[0] ^= 0xFF;
835        encoded[5] ^= 0xAA;
836        encoded[10] ^= 0x55;
837        encoded[15] ^= 0x11;
838        encoded[20] ^= 0x22;
839        encoded[25] ^= 0x33;
840        encoded[30] ^= 0x01;
841        encoded[data.len()] ^= 0x77; // error in parity
842        encoded[data.len() + 10] ^= 0x88;
843        encoded[data.len() + 30] ^= 0x99;
844
845        let (decoded, errors) = rs_decode(&encoded, data.len()).unwrap();
846        assert_eq!(decoded, data);
847        assert_eq!(errors, 10);
848    }
849
850    #[test]
851    fn encode_decode_max_correctable() {
852        let data = vec![42u8; 100];
853        let mut encoded = rs_encode(&data);
854
855        // Introduce exactly t=32 errors
856        for i in 0..32 {
857            encoded[i * 3] ^= 0xFF;
858        }
859
860        let (decoded, errors) = rs_decode(&encoded, data.len()).unwrap();
861        assert_eq!(decoded, data);
862        assert_eq!(errors, 32);
863    }
864
865    #[test]
866    fn too_many_errors_fails() {
867        let data = vec![0u8; 50];
868        let mut encoded = rs_encode(&data);
869
870        // Introduce 33 errors (exceeds t=32)
871        for i in 0..33 {
872            encoded[i] ^= 0xFF;
873        }
874
875        assert!(rs_decode(&encoded, data.len()).is_err());
876    }
877
878    #[test]
879    fn shortened_code_works() {
880        // Very short data (much less than K_DEFAULT=191)
881        let data = b"Hi";
882        let encoded = rs_encode(data);
883        assert_eq!(encoded.len(), data.len() + PARITY_LEN);
884
885        let (decoded, errors) = rs_decode(&encoded, data.len()).unwrap();
886        assert_eq!(decoded, data);
887        assert_eq!(errors, 0);
888    }
889
890    #[test]
891    fn shortened_code_with_errors() {
892        let data = b"Short";
893        let mut encoded = rs_encode(data);
894        encoded[0] ^= 0xFF;
895        encoded[2] ^= 0xAA;
896
897        let (decoded, errors) = rs_decode(&encoded, data.len()).unwrap();
898        assert_eq!(decoded, data);
899        assert_eq!(errors, 2);
900    }
901
902    #[test]
903    fn blocks_roundtrip() {
904        // Data larger than K_DEFAULT, requiring multiple blocks
905        let data: Vec<u8> = (0..400).map(|i| (i % 256) as u8).collect();
906        let encoded = rs_encode_blocks(&data);
907
908        // Should be 2 blocks: 191+64=255 and 209+64=273 → 528 total
909        assert_eq!(encoded.len(), rs_encoded_len(data.len()));
910
911        let (decoded, stats) = rs_decode_blocks(&encoded, data.len()).unwrap();
912        assert_eq!(decoded, data);
913        assert_eq!(stats.total_errors, 0);
914    }
915
916    #[test]
917    fn blocks_with_errors() {
918        let data: Vec<u8> = (0..400).map(|i| (i % 256) as u8).collect();
919        let mut encoded = rs_encode_blocks(&data);
920
921        // Corrupt a few bytes in block 1 (starts at 0, len 255)
922        encoded[10] ^= 0xFF;
923        encoded[100] ^= 0xAA;
924        // Block 2 starts at 255 (len 255)
925        encoded[260] ^= 0x55;
926        encoded[300] ^= 0x11;
927        // Block 3 starts at 510 (len 82)
928        encoded[520] ^= 0x33;
929
930        let (decoded, stats) = rs_decode_blocks(&encoded, data.len()).unwrap();
931        assert_eq!(decoded, data);
932        assert_eq!(stats.total_errors, 5);
933        assert!(stats.max_block_errors <= 2);
934    }
935
936    #[test]
937    fn empty_data() {
938        let data: &[u8] = &[];
939        let encoded = rs_encode(data);
940        assert_eq!(encoded.len(), PARITY_LEN);
941        let (decoded, errors) = rs_decode(&encoded, 0).unwrap();
942        assert_eq!(decoded, data);
943        assert_eq!(errors, 0);
944    }
945
946    #[test]
947    fn rs_encoded_len_correct() {
948        assert_eq!(rs_encoded_len(100), 100 + 64);
949        assert_eq!(rs_encoded_len(191), 191 + 64);
950        assert_eq!(rs_encoded_len(192), (191 + 64) + (1 + 64));
951        // 400 / 191 = 2 full blocks (382), remainder 18 → 3 blocks
952        assert_eq!(rs_encoded_len(400), 2 * (191 + 64) + (18 + 64));
953    }
954
955    #[test]
956    fn rs_encoded_len_edge_cases() {
957        assert_eq!(rs_encoded_len(0), 0);
958        assert_eq!(rs_encoded_len(1), 1 + 64);
959        // Full block boundary
960        assert_eq!(rs_encoded_len(191), 191 + 64);
961        // Just over one block
962        assert_eq!(rs_encoded_len(192), (191 + 64) + (1 + 64));
963    }
964
965    #[test]
966    fn single_error_full_block() {
967        let data = vec![42u8; K_DEFAULT];
968        let mut encoded = rs_encode(&data);
969        encoded[50] ^= 0x01;
970        let (decoded, errors) = rs_decode(&encoded, K_DEFAULT).unwrap();
971        assert_eq!(decoded, data);
972        assert_eq!(errors, 1);
973    }
974
975    #[test]
976    fn single_error_shortened() {
977        let data = b"Short";
978        let mut encoded = rs_encode(data);
979        encoded[0] ^= 0xFF;
980        let (decoded, errors) = rs_decode(&encoded, data.len()).unwrap();
981        assert_eq!(decoded, data);
982        assert_eq!(errors, 1);
983    }
984
985    #[test]
986    fn two_errors_full_block() {
987        let data = vec![42u8; K_DEFAULT];
988        let mut encoded = rs_encode(&data);
989        encoded[0] ^= 0xFF;
990        encoded[50] ^= 0xAA;
991        let (decoded, errors) = rs_decode(&encoded, K_DEFAULT).unwrap();
992        assert_eq!(decoded, data);
993        assert_eq!(errors, 2);
994    }
995
996    #[test]
997    fn two_errors_shortened() {
998        let data = b"Short";
999        let mut encoded = rs_encode(data);
1000        encoded[0] ^= 0xFF;
1001        encoded[2] ^= 0xAA;
1002        let (decoded, errors) = rs_decode(&encoded, data.len()).unwrap();
1003        assert_eq!(decoded, data);
1004        assert_eq!(errors, 2);
1005    }
1006
1007    #[test]
1008    fn generator_polynomial_correct() {
1009        let gpoly = gen_poly();
1010        // Gen poly should have degree = PARITY_LEN, so length = PARITY_LEN + 1
1011        assert_eq!(gpoly.len(), PARITY_LEN + 1);
1012        // Leading coefficient should be 1
1013        assert_eq!(gpoly[0], 1);
1014        // All roots alpha^0 .. alpha^{2t-1} should evaluate to 0
1015        let t = gf_tables();
1016        for i in 0..PARITY_LEN {
1017            assert_eq!(poly_eval(gpoly, t.exp[i]), 0, "root alpha^{i} failed");
1018        }
1019    }
1020
1021    // --- Adaptive RS tests ---
1022
1023    #[test]
1024    fn adaptive_rs_roundtrip_each_tier() {
1025        for &parity in &PARITY_TIERS {
1026            let k_max = N_MAX - parity;
1027            let data: Vec<u8> = (0..k_max.min(100)).map(|i| (i % 256) as u8).collect();
1028            let encoded = rs_encode_with_parity(&data, parity);
1029            assert_eq!(encoded.len(), data.len() + parity);
1030            let (decoded, errors) = rs_decode_with_parity(&encoded, data.len(), parity).unwrap();
1031            assert_eq!(decoded, data, "parity={parity}");
1032            assert_eq!(errors, 0, "parity={parity}");
1033        }
1034    }
1035
1036    #[test]
1037    fn adaptive_rs_corrects_errors_at_each_tier() {
1038        for &parity in &PARITY_TIERS {
1039            let k_max = N_MAX - parity;
1040            let t = parity / 2;
1041            let data: Vec<u8> = (0..k_max.min(50)).map(|i| (i % 256) as u8).collect();
1042            let mut encoded = rs_encode_with_parity(&data, parity);
1043
1044            // Introduce t/2 errors (well within correction capability)
1045            let num_errors = (t / 2).min(encoded.len());
1046            let elen = encoded.len();
1047            for i in 0..num_errors {
1048                encoded[i * 2 % elen] ^= 0xFF;
1049            }
1050
1051            let (decoded, errors) = rs_decode_with_parity(&encoded, data.len(), parity).unwrap();
1052            assert_eq!(decoded, data, "parity={parity}");
1053            assert!(errors > 0, "parity={parity}");
1054        }
1055    }
1056
1057    #[test]
1058    fn adaptive_rs_blocks_roundtrip() {
1059        let data: Vec<u8> = (0..200).map(|i| (i % 256) as u8).collect();
1060        for &parity in &PARITY_TIERS {
1061            let encoded = rs_encode_blocks_with_parity(&data, parity);
1062            assert_eq!(encoded.len(), rs_encoded_len_with_parity(data.len(), parity));
1063            let (decoded, stats) = rs_decode_blocks_with_parity(&encoded, data.len(), parity).unwrap();
1064            assert_eq!(decoded, data, "parity={parity}");
1065            assert_eq!(stats.total_errors, 0, "parity={parity}");
1066        }
1067    }
1068
1069    #[test]
1070    fn rs_encoded_len_with_parity_correct() {
1071        // With parity=128, k_max=127
1072        assert_eq!(rs_encoded_len_with_parity(100, 128), 100 + 128);
1073        assert_eq!(rs_encoded_len_with_parity(127, 128), 127 + 128);
1074        // 128 bytes at parity=128: k_max=127, so 2 blocks: 127+128 + 1+128
1075        assert_eq!(rs_encoded_len_with_parity(128, 128), (127 + 128) + (1 + 128));
1076    }
1077
1078    #[test]
1079    fn choose_parity_tier_picks_largest_fitting() {
1080        // 100-byte frame, 10000 embedding units
1081        // tier 64: rs_len = 100+64=164 bytes = 1312 bits → fits
1082        // tier 128: rs_len = 100+128=228 bytes = 1824 bits → fits
1083        // tier 192: rs_len = 100+192=292 bytes (but k_max=63, 100>63, so 2 blocks)
1084        //   = 63+192 + 37+192 = 484 bytes = 3872 bits → fits
1085        // tier 240: rs_len = 100+240 but k_max=15, many blocks
1086        //   = ceil(100/15)*255 = 7*255 = 1785 bytes = 14280 bits → exceeds 10000
1087        let tier = choose_parity_tier(100, 10000);
1088        assert_eq!(tier, 192);
1089    }
1090}