leopard_codec/lib.rs

#![doc = include_str!("../README.md")]

use bytes::{Buf, BufMut};
use thiserror::Error;

/// Number of bits per element
pub const BITS: usize = 8;
/// Finite field order: number of elements in the field
pub const ORDER: usize = u8::MAX as usize + 1;
/// Modulus for field operations
pub const MODULUS: u8 = u8::MAX;
/// LFSR polynomial that generates the field elements
pub const POLYNOMIAL: usize = 0x11D;
/// Basis used for generating logarithm tables
pub const CANTOR_BASIS: [u8; BITS] = [1, 214, 152, 146, 86, 200, 88, 230];

/// Lookup tables
mod lut;

/// Possible errors that can happen when interacting with Leopard.
#[derive(Debug, Error)]
pub enum LeopardError {
    /// Maximum number of shards exceeded.
    #[error("Maximum shard number ({}) exceeded: {0}", ORDER)]
    MaxShardNumberExceeded(usize),

    /// Maximum number of parity shards exceeded.
    #[error("Maximum parity shard number ({1}) exceeded: {0}")]
    MaxParityShardNumberExceeded(usize, usize),

    /// This combination of data and parity shards is not supported by the Leopard algorithm
    /// and would result in a buffer overflow on the skew lookup table during encoding.
    /// Please try using different shard counts.
    #[error("Unsupported number of data ({0}) and parity ({1}) shards")]
    UnsupportedShardsAmounts(usize, usize),

    /// Some shards contain no data.
    #[error("Shards contain no data")]
    EmptyShards,

    /// Some shards are of different lengths.
    #[error("Shards of different lengths found")]
    UnequalShardsLengths,

    /// Shard size is invalid.
    #[error("Shard size ({0}) should be a multiple of 64")]
    InvalidShardSize(usize),

    /// Too few shards to reconstruct data.
    #[error("Too few shards ({0}) to reconstruct data, at least {1} needed")]
    TooFewShards(usize, usize),
}

/// A result type with [`LeopardError`].
pub type Result<T, E = LeopardError> = std::result::Result<T, E>;

/// Encode parity data into the given shards.
///
/// The first `data_shards` shards will be treated as data shards
/// and the rest as parity shards.
///
/// # Errors
///
/// If too many shards are provided, or if the shards have invalid or unequal lengths.
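///
/// # Example
///
/// A minimal usage sketch (assuming this crate is built as `leopard_codec`);
/// every shard must be the same length and a multiple of 64 bytes:
///
/// ```
/// // 2 data shards followed by 1 parity shard, 64 bytes each
/// let mut shards = vec![vec![7u8; 64], vec![13u8; 64], vec![0u8; 64]];
/// leopard_codec::encode(&mut shards, 2).unwrap();
/// // the last shard now holds the computed parity data
/// ```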
pub fn encode(shards: &mut [impl AsMut<[u8]>], data_shards: usize) -> Result<()> {
    if shards.len() > ORDER {
        return Err(LeopardError::MaxShardNumberExceeded(shards.len()));
    }
    let parity_shards = shards.len() - data_shards;
    if parity_shards > data_shards {
        return Err(LeopardError::MaxParityShardNumberExceeded(
            parity_shards,
            data_shards,
        ));
    }
    if is_encode_buf_overflow(data_shards, parity_shards) {
        return Err(LeopardError::UnsupportedShardsAmounts(
            data_shards,
            parity_shards,
        ));
    }

    let mut shards: Vec<&mut [u8]> = shards.iter_mut().map(|shard| shard.as_mut()).collect();
    let shard_size = check_shards(&shards, false)?;

    if shard_size % 64 != 0 {
        return Err(LeopardError::InvalidShardSize(shard_size));
    }

    encode_inner(&mut shards, data_shards, shard_size);

    Ok(())
}

fn encode_inner(shards: &mut [&mut [u8]], data_shards: usize, shard_size: usize) {
    let parity_shards = shards.len() - data_shards;

    let m = ceil_pow2(parity_shards);
    let mtrunc = m.min(data_shards);

    // 'work' is a temporary buffer where the parity shards are computed.
    // The first half of it is where the resulting parity shards will end up.
    let mut work_mem = vec![0; 2 * m * shard_size];
    let mut work: Vec<_> = work_mem.chunks_exact_mut(shard_size).collect();

    let skew_lut = &lut::FFT_SKEW[m - 1..];

    // copy the input to the work table
    for (shard, work) in shards[data_shards..].iter().zip(work.iter_mut()) {
        work.copy_from_slice(shard);
    }

    ifft_dit_encoder(
        &shards[..data_shards],
        mtrunc,
        &mut work,
        None, // No xor output
        m,
        skew_lut,
    );

    let last_count = data_shards % m;

    // equivalent of 'goto skip_body' in the original implementation:
    // the block below is only needed when m < data_shards
    if m < data_shards {
        let (xor_out, work) = work.split_at_mut(m);
        let mut n = m;

        // for each full set of m data pieces
        while n <= data_shards - m {
            // work <- work xor IFFT(data + i, m, m + i)
            ifft_dit_encoder(&shards[n..], m, work, Some(&mut *xor_out), m, &skew_lut[n..]);
            n += m;
        }

        // Handle final partial set of m pieces:
        if last_count != 0 {
            ifft_dit_encoder(
                &shards[n..],
                last_count,
                work,
                Some(xor_out),
                m,
                &skew_lut[n..],
            );
        }
    }

    // work <- FFT(work, m, 0)
    fft_dit(&mut work, parity_shards, m, &*lut::FFT_SKEW);

    for (shard, work) in shards[data_shards..].iter_mut().zip(work.iter()) {
        shard.copy_from_slice(work);
    }
}

/// The Leopard algorithm is imperfect and can result in a buffer overflow on the `lut::FFT_SKEW` table.
/// Encoding happens in passes, where each pass encodes parity for data shards in chunks
/// of the smallest power of 2 greater than or equal to the number of parity shards.
/// If the last chunk is not a full pass, we can hit the overflow for some pairs
/// of (data_shards, parity_shards).
/// This function detects such inputs.
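///
/// For example, (data_shards, parity_shards) = (253, 3) gives m = 4 with a partial
/// last chunk, and (253 / 4 + 1) * 4 + 1 = 257 > 255, so that pair is rejected.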
fn is_encode_buf_overflow(data_shards: usize, parity_shards: usize) -> bool {
    debug_assert!(data_shards >= parity_shards);
    debug_assert!(data_shards + parity_shards <= ORDER);

    let m = ceil_pow2(parity_shards);
    let last_count = data_shards % m;

    // we can finish encoding with only full passes
    if m >= data_shards || last_count == 0 {
        return false;
    }

    let full_passes = data_shards / m;
    // if this is 'true', we would overflow the fft skew table, which has size `MODULUS`
    (full_passes + 1) * m + 1 > MODULUS as usize
}

/// Reconstructs the original shards from the provided slice.
///
/// Missing shards should be provided as empty `Vec`s.
///
/// Reconstruction is only possible if the number of present data and parity shards
/// is equal to or greater than `data_shards`.
///
/// The first `data_shards` shards will be treated as data shards
/// and the rest as parity shards.
///
/// # Errors
///
/// If too few shards are present to reconstruct the original data, or if the shards have invalid or unequal lengths.
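///
/// # Example
///
/// A minimal sketch (assuming this crate is built as `leopard_codec`):
/// encode, drop one shard, then recover it from the remaining ones.
///
/// ```
/// let mut shards = vec![vec![7u8; 64], vec![13u8; 64], vec![0u8; 64]];
/// leopard_codec::encode(&mut shards, 2).unwrap();
///
/// // lose one data shard and reconstruct it
/// let lost = shards[0].clone();
/// shards[0] = vec![];
/// leopard_codec::reconstruct(&mut shards, 2).unwrap();
/// assert_eq!(shards[0], lost);
/// ```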
pub fn reconstruct(shards: &mut [impl AsMut<Vec<u8>>], data_shards: usize) -> Result<()> {
    if shards.len() > ORDER {
        return Err(LeopardError::MaxShardNumberExceeded(shards.len()));
    }
    let parity_shards = shards.len() - data_shards;
    if parity_shards > data_shards {
        return Err(LeopardError::MaxParityShardNumberExceeded(
            parity_shards,
            data_shards,
        ));
    }

    let mut shards: Vec<_> = shards.iter_mut().map(|shard| shard.as_mut()).collect();
    let shard_size = check_shards(&shards, true)?;

    let present_shards = shards.iter().filter(|shard| !shard.is_empty()).count();
    if present_shards == shards.len() {
        // all shards present, nothing to do
        return Ok(());
    }

    // Check if we have enough to reconstruct.
    if present_shards < data_shards {
        return Err(LeopardError::TooFewShards(present_shards, data_shards));
    }

    if shard_size % 64 != 0 {
        return Err(LeopardError::InvalidShardSize(shard_size));
    }

    reconstruct_inner(&mut shards, data_shards, shard_size);

    Ok(())
}

fn reconstruct_inner(shards: &mut [&mut Vec<u8>], data_shards: usize, shard_size: usize) {
    let parity_shards = shards.len() - data_shards;

    // TODO: error bitfields for avoiding unnecessary FFT steps
    // orig:
    // Use only if we are missing less than 1/4 parity,
    // And we are restoring a significant amount of data.
    // useBits := r.totalShards-numberPresent <= r.parityShards/4 && shardSize*r.totalShards >= 64<<10

    let m = ceil_pow2(parity_shards);
    let n = ceil_pow2(m + data_shards);

    // save the info about which shards were empty
    let empty_shards_mask: Vec<_> = shards.iter().map(|shard| shard.is_empty()).collect();
    // and recreate them
    for shard in shards.iter_mut().filter(|shard| shard.is_empty()) {
        shard.resize(shard_size, 0);
    }

    let mut err_locs = [0u8; ORDER];

    for (&is_empty, err_loc) in empty_shards_mask
        .iter()
        .skip(data_shards)
        .zip(err_locs.iter_mut())
    {
        if is_empty {
            *err_loc = 1;
        }
    }

    for err in &mut err_locs[parity_shards..m] {
        *err = 1;
    }

    for (&is_empty, err_loc) in empty_shards_mask
        .iter()
        .take(data_shards)
        .zip(err_locs[m..].iter_mut())
    {
        if is_empty {
            *err_loc = 1;
        }
    }

    // TODO: No inversion...

    // Evaluate error locator polynomial
    fwht(&mut err_locs, ORDER, m + data_shards);

    for (err, &log_walsh) in err_locs.iter_mut().zip(lut::LOG_WALSH.iter()) {
        let mul = (*err) as usize * log_walsh as usize;
        *err = (mul % MODULUS as usize) as u8;
    }

    fwht(&mut err_locs, ORDER, ORDER);

    let mut work_mem = vec![0u8; shard_size * n];
    let mut work: Vec<_> = work_mem.chunks_exact_mut(shard_size).collect();

    for i in 0..parity_shards {
        if !empty_shards_mask[i + data_shards] {
            mul_gf(work[i], shards[i + data_shards], err_locs[i]);
        } else {
            work[i].fill(0);
        }
    }
    for work in work.iter_mut().take(m).skip(parity_shards) {
        work.fill(0);
    }

    // work <- original data
    for i in 0..data_shards {
        if !empty_shards_mask[i] {
            mul_gf(work[m + i], shards[i], err_locs[m + i]);
        } else {
            work[m + i].fill(0);
        }
    }
    for work in work.iter_mut().take(n).skip(m + data_shards) {
        work.fill(0);
    }

    // work <- IFFT(work, n, 0)
    ifft_dit_decoder(m + data_shards, &mut work, n, &lut::FFT_SKEW[..]);

    // work <- FormalDerivative(work, n)
    for i in 1..n {
        let width = ((i ^ (i - 1)) + 1) >> 1;
        let (output, input) = work.split_at_mut(i);
        slices_xor(
            &mut output[i - width..],
            input.iter_mut().map(|elem| &**elem),
        );
    }

    // work <- FFT(work, n, 0) truncated to m + dataShards
    fft_dit(&mut work, m + data_shards, n, &lut::FFT_SKEW[..]);

    // Reveal erasures
    //
    //  Original = -ErrLocator * FFT( Derivative( IFFT( ErrLocator * ReceivedData ) ) )
    //  mul_mem(x, y, log_m) equals x[] = y[] * log_m
    //
    // mem layout: [Recovery Data (Power of Two = M)] [Original Data (K)] [Zero Padding out to N]
    for (i, shard) in shards.iter_mut().enumerate() {
        if !empty_shards_mask[i] {
            continue;
        }

        if i >= data_shards {
            // parity shard
            mul_gf(
                shard,
                work[i - data_shards],
                MODULUS - err_locs[i - data_shards],
            );
        } else {
            // data shard
            mul_gf(shard, work[i + m], MODULUS - err_locs[i + m]);
        }
    }
}

fn shard_size(shards: &[impl AsRef<[u8]>]) -> usize {
    shards
        .iter()
        .map(|shard| shard.as_ref().len())
        .find(|&len| len != 0)
        .unwrap_or(0)
}

// check_shards verifies that all shards are the same size,
// or zero-length if that is allowed. An error is returned if this fails.
// An error is also returned if all shards are empty.
//
// Returns the length of a single shard.
fn check_shards(shards: &[impl AsRef<[u8]>], allow_zero: bool) -> Result<usize> {
    let size = shard_size(shards);

    if size == 0 {
        if allow_zero {
            return Ok(0);
        } else {
            return Err(LeopardError::EmptyShards);
        }
    }

    // NOTE: for the happy case a fold would be faster
    let are_all_same_size = shards.iter().all(|shard| {
        let shard = shard.as_ref();
        if allow_zero && shard.is_empty() {
            true
        } else {
            shard.len() == size
        }
    });

    if !are_all_same_size {
        return Err(LeopardError::UnequalShardsLengths);
    }

    Ok(size)
}

// Returns a + b (mod MODULUS)
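// e.g. add_mod(200, 100): sum = 300, 300 + (300 >> 8) = 301, truncated to u8 = 45 = 300 % 255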
#[inline]
const fn add_mod(a: u8, b: u8) -> u8 {
    let sum = a as u32 + b as u32;

    // Partial reduction step, allowing MODULUS itself to be returned
    (sum + (sum >> BITS)) as u8
}

// Returns a - b (mod MODULUS)
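// e.g. sub_mod(10, 20) = 245, which is (10 - 20) mod 255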
#[inline]
const fn sub_mod(a: u8, b: u8) -> u8 {
    let b = if a < b { b as u32 + 1 } else { b as u32 };
    // make sure we don't underflow
    let a = a as u32 + ORDER as u32;
    let dif = a - b;

    dif as u8
}

// Note that this operation is not a normal multiplication in a finite
// field, because the right operand is already a logarithm. This is done
// because it moves K table lookups from the decoding path into the
// initialization step that is less performance critical. The LOG_WALSH
// table contains precalculated logarithms, so it is easier to do
// all the other multiplies in that form as well.
#[inline]
fn mul_log(a: u8, log_b: u8) -> u8 {
    if a == 0 {
        0
    } else {
        let log_a = lut::log(a);
        lut::exp(add_mod(log_a, log_b))
    }
}

fn mul_add(x: &mut [u8], y: &[u8], log_m: u8) {
    x.iter_mut().zip(y.iter()).for_each(|(x, y)| {
        *x ^= lut::mul(*y, log_m);
    })
}

fn mul_gf(out: &mut [u8], input: &[u8], log_m: u8) {
    let mul_lut = lut::MUL[log_m as usize];
    for (out, &input) in out.iter_mut().zip(input.iter()) {
        *out = mul_lut[input as usize];
    }
}

// Decimation in time (DIT) Fast Walsh-Hadamard Transform
// Unrolls pairs of layers to perform cross-layer operations in registers
// mtrunc: number of elements that are non-zero at the front of data
fn fwht(data: &mut [u8; ORDER], m: usize, mtrunc: usize) {
    // Decimation in time: Unroll 2 layers at a time
    let mut dist: usize = 1;
    let mut dist4: usize = 4;

    while dist4 <= m {
        for offset in (0..mtrunc).step_by(dist4) {
            let mut offset = offset;

            for _ in 0..dist {
                // TODO: maybe in Rust this is not faster
                // TODO: bound checks
                // fwht4(data[i:], dist) inlined...
                // Reading values appears faster than updating pointers.
                // Casting to uint is not faster.
                let t0 = data[offset];
                let t1 = data[offset + dist];
                let t2 = data[offset + dist * 2];
                let t3 = data[offset + dist * 3];

                let (t0, t1) = fwht2alt(t0, t1);
                let (t2, t3) = fwht2alt(t2, t3);
                let (t0, t2) = fwht2alt(t0, t2);
                let (t1, t3) = fwht2alt(t1, t3);

                data[offset] = t0;
                data[offset + dist] = t1;
                data[offset + dist * 2] = t2;
                data[offset + dist * 3] = t3;

                offset += 1;
            }
        }
        dist = dist4;
        dist4 <<= 2;
    }

    // If there is one layer left:
    if dist < m {
        for i in 0..dist {
            let (first, second) = data.split_at_mut(i + 1);
            fwht2(&mut first[i], &mut second[dist - 1]);
        }
    }
}

// {a, b} = {a + b, a - b} (mod MODULUS)
#[inline]
fn fwht2(a: &mut u8, b: &mut u8) {
    let sum = add_mod(*a, *b);
    let dif = sub_mod(*a, *b);

    *a = sum;
    *b = dif;
}

// fwht2, but not in place
#[inline]
fn fwht2alt(a: u8, b: u8) -> (u8, u8) {
    (add_mod(a, b), sub_mod(a, b))
}

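// Smallest power of two that is >= x (for x >= 1),
// e.g. ceil_pow2(5) == 8 and ceil_pow2(8) == 8.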
#[inline]
const fn ceil_pow2(x: usize) -> usize {
    let bitwidth = usize::BITS;
    1 << (bitwidth - (x - 1).leading_zeros())
}

// Unrolled IFFT for encoder
fn ifft_dit_encoder(
    data: &[impl AsRef<[u8]>],
    mtrunc: usize,
    work: &mut [&mut [u8]],
    xor_output: Option<&mut [&mut [u8]]>,
    m: usize,
    skew_lut: &[u8],
) {
    // NOTE: may not be valid in Rust
    // I tried rolling the memcpy/memset into the first layer of the FFT and
    // found that it only yields a 4% performance improvement, which is not
    // worth the extra complexity.
    for i in 0..mtrunc {
        work[i].copy_from_slice(data[i].as_ref());
    }
    for row in work[mtrunc..m].iter_mut() {
        row.fill(0);
    }

    // Decimation in time: Unroll 2 layers at a time
    let mut dist = 1;
    let mut dist4 = 4;

    while dist4 <= m {
        for r in (0..mtrunc).step_by(dist4) {
            let iend = r + dist;
            let log_m01 = skew_lut[iend];
            let log_m02 = skew_lut[iend + dist];
            let log_m23 = skew_lut[iend + dist * 2];

            // For each set of dist elements:
            // NOTE: this is compatible with klauspost/reedsolomon but diverges from catid/leopard
            for i in r..iend {
                ifft_dit4(&mut work[i..], dist, log_m01, log_m23, log_m02);
            }
        }

        dist = dist4;
        dist4 <<= 2;
        // orig:
        // I tried alternating sweeps left->right and right->left to reduce cache misses.
        // It provides about 1% performance boost when done for both FFT and IFFT, so it
        // does not seem to be worth the extra complexity.
    }

    // If there is one layer left:
    if dist < m {
        // assume that dist = m / 2
        debug_assert_eq!(dist * 2, m);

        let log_m = skew_lut[dist];

        if log_m == MODULUS {
            let (input, output) = work.split_at_mut(dist);
            slices_xor(&mut output[..dist], input.iter_mut().map(|elem| &**elem));
        } else {
            let (x, y) = work.split_at_mut(dist);
            for i in 0..dist {
                ifft_dit2(x[i], y[i], log_m);
            }
        }
    }
    // orig:
    // I tried unrolling this but it does not provide more than 5% performance
    // improvement for 16-bit finite fields, so it's not worth the complexity.

    // NOTE: this is compatible with klauspost/reedsolomon but diverges from catid/leopard
    if let Some(xor_output) = xor_output {
        slices_xor(
            &mut xor_output[..m],
            work[..m].iter_mut().map(|elem| &**elem),
        );
    }
}

// Basic no-frills version for decoder
fn ifft_dit_decoder(mtrunc: usize, work: &mut [&mut [u8]], m: usize, skew_lut: &[u8]) {
    // Decimation in time: Unroll 2 layers at a time
    let mut dist = 1;
    let mut dist4 = 4;

    while dist4 <= m {
        // For each set of dist*4 elements:
        for r in (0..mtrunc).step_by(dist4) {
            let iend = r + dist;
            let log_m01 = skew_lut[iend - 1];
            let log_m02 = skew_lut[iend + dist - 1];
            let log_m23 = skew_lut[iend + 2 * dist - 1];

            // For each set of dist elements:
            for i in r..iend {
                ifft_dit4(&mut work[i..], dist, log_m01, log_m23, log_m02);
            }
        }

        dist = dist4;
        dist4 <<= 2;
    }

    // If there is one layer left:
    if dist < m {
        // Assuming that dist = m / 2
        debug_assert_eq!(2 * dist, m);

        let log_m = skew_lut[dist - 1];

        if log_m == MODULUS {
            let (input, output) = work.split_at_mut(dist);
            slices_xor(&mut output[..dist], input.iter_mut().map(|elem| &**elem));
        } else {
            let (x, y) = work.split_at_mut(dist);
            for i in 0..dist {
                ifft_dit2(x[i], y[i], log_m);
            }
        }
    }
}

fn ifft_dit4(work: &mut [&mut [u8]], dist: usize, log_m01: u8, log_m23: u8, log_m02: u8) {
    if work[0].is_empty() {
        return;
    }

    // TODO: support AVX2 if enabled
    // No SIMD version

    let (dist0, dist1) = work.split_at_mut(dist);
    let (dist1, dist2) = dist1.split_at_mut(dist);
    let (dist2, dist3) = dist2.split_at_mut(dist);

    // First layer:
    if log_m01 == MODULUS {
        slice_xor(&*dist0[0], dist1[0]);
    } else {
        ifft_dit2(dist0[0], dist1[0], log_m01);
    }

    if log_m23 == MODULUS {
        slice_xor(&*dist2[0], dist3[0]);
    } else {
        ifft_dit2(dist2[0], dist3[0], log_m23);
    }

    // Second layer:
    if log_m02 == MODULUS {
        slice_xor(&*dist0[0], dist2[0]);
        slice_xor(&*dist1[0], dist3[0]);
    } else {
        ifft_dit2(dist0[0], dist2[0], log_m02);
        ifft_dit2(dist1[0], dist3[0], log_m02);
    }
}

fn ifft_dit2(x: &mut [u8], y: &mut [u8], log_m: u8) {
    slice_xor(&*x, y);
    mul_add(x, y, log_m);
}

// In-place FFT for encoder and decoder
fn fft_dit(work: &mut [&mut [u8]], mtrunc: usize, m: usize, skew_lut: &[u8]) {
    // Decimation in time: Unroll 2 layers at a time
    let mut dist4 = m;
    let mut dist = m >> 2;

    while dist != 0 {
        // For each set of dist*4 elements:
        for r in (0..mtrunc).step_by(dist4) {
            let iend = r + dist;
            let log_m01 = skew_lut[iend - 1];
            let log_m02 = skew_lut[iend + dist - 1];
            let log_m23 = skew_lut[iend + 2 * dist - 1];

            // For each set of dist elements:
            for i in r..iend {
                fft_dit4(&mut work[i..], dist, log_m01, log_m23, log_m02);
            }
        }

        dist4 = dist;
        dist >>= 2;
    }

    // If there is one layer left:
    if dist4 == 2 {
        for r in (0..mtrunc).step_by(2) {
            let log_m = skew_lut[r];
            let (x, y) = work.split_at_mut(r + 1);

            if log_m == MODULUS {
                slice_xor(&*x[r], y[0]);
            } else {
                fft_dit2(x[r], y[0], log_m);
            }
        }
    }
}

// 4-way butterfly
fn fft_dit4(work: &mut [&mut [u8]], dist: usize, log_m01: u8, log_m23: u8, log_m02: u8) {
    if work[0].is_empty() {
        return;
    }

    // TODO: support AVX2 if enabled
    // No SIMD version

    let (dist0, dist1) = work.split_at_mut(dist);
    let (dist1, dist2) = dist1.split_at_mut(dist);
    let (dist2, dist3) = dist2.split_at_mut(dist);

    // First layer:
    if log_m02 == MODULUS {
        slice_xor(&*dist0[0], dist2[0]);
        slice_xor(&*dist1[0], dist3[0]);
    } else {
        fft_dit2(dist0[0], dist2[0], log_m02);
        fft_dit2(dist1[0], dist3[0], log_m02);
    }

    // Second layer:
    if log_m01 == MODULUS {
        slice_xor(&*dist0[0], dist1[0]);
    } else {
        fft_dit2(dist0[0], dist1[0], log_m01);
    }

    if log_m23 == MODULUS {
        slice_xor(&*dist2[0], dist3[0]);
    } else {
        fft_dit2(dist2[0], dist3[0], log_m23);
    }
}

// 2-way butterfly forward
fn fft_dit2(x: &mut [u8], y: &mut [u8], log_m: u8) {
    if x.is_empty() {
        return;
    }

    mul_add(x, y, log_m);
    slice_xor(&*x, y);
}

fn slices_xor(output: &mut [&mut [u8]], input: impl Iterator<Item = impl Buf>) {
    output
        .iter_mut()
        .zip(input)
        .for_each(|(out, inp)| slice_xor(inp, out));
}

fn slice_xor(mut input: impl Buf, mut output: &mut [u8]) {
    // TODO: this unroll is inherited from the Go code; it might not be needed
    while output.remaining_mut() >= 32 && input.remaining() >= 32 {
        let mut output_buf = &*output;
        let v0 = output_buf.get_u64_le() ^ input.get_u64_le();
        let v1 = output_buf.get_u64_le() ^ input.get_u64_le();
        let v2 = output_buf.get_u64_le() ^ input.get_u64_le();
        let v3 = output_buf.get_u64_le() ^ input.get_u64_le();

        output.put_u64_le(v0);
        output.put_u64_le(v1);
        output.put_u64_le(v2);
        output.put_u64_le(v3);
    }

    let rest = output.remaining_mut().min(input.remaining());
    for _ in 0..rest {
        let xor = (&*output).get_u8() ^ input.get_u8();
        output.put_u8(xor);
    }
}

#[cfg(test)]
mod tests {
    use std::panic::catch_unwind;

    use rand::{seq::index, Fill, Rng};
    use test_strategy::{proptest, Arbitrary};

    use super::*;

    #[proptest]
    fn go_reedsolomon_encode_compatibility(input: TestCase) {
        let TestCase {
            data_shards,
            parity_shards,
            shard_size,
        } = input;
        let total_shards = data_shards + parity_shards;
        let test_shards = random_shards(total_shards, shard_size);

        let mut shards = test_shards.clone();
        encode(&mut shards, data_shards).unwrap();

        let mut expected = test_shards;
        go_leopard::encode(&mut expected, data_shards, shard_size).unwrap();

        if expected != shards {
            panic!("Go and Rust encoding differ for {input:#?}")
        }
    }

    #[proptest]
    fn encode_reconstruct(input: TestCase) {
        let TestCase {
            data_shards,
            parity_shards,
            shard_size,
        } = input;
        let total_shards = data_shards + parity_shards;
        let mut shards = random_shards(total_shards, shard_size);

        encode(&mut shards, data_shards).unwrap();

        let expected = shards.clone();

        let mut rng = rand::thread_rng();
        let missing_shards = rng.gen_range(1..=parity_shards);
        for idx in index::sample(&mut rng, total_shards, missing_shards) {
            shards[idx] = vec![];
        }

        reconstruct(&mut shards, data_shards).unwrap();

        if expected != shards {
            panic!("shards differ after reconstruction");
        }
    }

    #[test]
    fn overflow_detection() {
        for data_shards in 1..MODULUS as usize {
            for parity_shards in 1..data_shards {
                let total_shards = data_shards + parity_shards;

                // too many shards
                if total_shards > ORDER {
                    continue;
                }

                let overflow = is_encode_buf_overflow(data_shards, parity_shards);

                let result = catch_unwind(|| {
                    let mut shards = random_shards(total_shards, 64);
                    let mut shards_ref: Vec<_> = shards
                        .iter_mut()
                        .map(|shard| shard.as_mut_slice())
                        .collect();
                    encode_inner(&mut shards_ref, data_shards, 64);
                });

                assert_eq!(result.is_err(), overflow, "{data_shards} {parity_shards}");
            }
        }
    }

    #[derive(Arbitrary, Debug)]
    #[filter(!is_encode_buf_overflow(#data_shards, #parity_shards))]
    struct TestCase {
        #[strategy(1..ORDER - 1)]
        data_shards: usize,

        #[strategy(1..=(ORDER - #data_shards).min(#data_shards))]
        parity_shards: usize,

        #[strategy(1usize..1024)]
        #[map(|x| x * 64)]
        shard_size: usize,
    }

    fn random_shards(shards: usize, shard_size: usize) -> Vec<Vec<u8>> {
        let mut rng = rand::thread_rng();
        (0..shards)
            .map(|_| {
                let mut shard = vec![0; shard_size];
                shard.try_fill(&mut rng).unwrap();
                shard
            })
            .collect()
    }
895}