1#![doc = include_str!("../README.md")]
2
3use bytes::{Buf, BufMut};
4use thiserror::Error;
5
/// Bits per field element; the code works over GF(2^8).
pub const BITS: usize = 8;
/// Field order (2^BITS = 256); also the upper bound on the total shard count.
pub const ORDER: usize = u8::MAX as usize + 1;
/// Modulus (ORDER - 1 = 255) for log-domain (exponent) arithmetic.
pub const MODULUS: u8 = u8::MAX;
/// Irreducible polynomial x^8 + x^4 + x^3 + x^2 + 1 defining GF(2^8).
pub const POLYNOMIAL: usize = 0x11D;
/// Cantor basis of GF(2^8) over GF(2).
// NOTE(review): presumably consumed by the `lut` table generation — confirm there.
pub const CANTOR_BASIS: [u8; BITS] = [1, 214, 152, 146, 86, 200, 88, 230];

/// Precomputed lookup tables (log/exp/mul rows, FFT skews, Walsh logs).
mod lut;
19
/// Errors returned by the encoding and reconstruction routines.
#[derive(Debug, Error)]
pub enum LeopardError {
    /// More total shards were supplied than the field order (`ORDER`) allows.
    #[error("Maximum shard number ({}) exceeded: {0}", ORDER)]
    MaxShardNumberExceeded(usize),

    /// More parity shards ({0}) than data shards ({1}) were requested.
    #[error("Maximum parity shard number ({0}) exceeded: {1}")]
    MaxParityShardNumberExceeded(usize, usize),

    /// A (data, parity) combination rejected by `is_encode_buf_overflow`.
    #[error("Unsupported number of data ({0}) and parity ({1}) shards")]
    UnsupportedShardsAmounts(usize, usize),

    /// Every shard was empty, so no shard size could be determined.
    #[error("Shards contain no data")]
    EmptyShards,

    /// Two non-empty shards had different lengths.
    #[error("Shards of different lengths found")]
    UnequalShardsLengths,

    /// The shard length is not a multiple of 64 bytes.
    #[error("Shard size ({0}) should be a multiple of 64")]
    InvalidShardSize(usize),

    /// Fewer shards present ({0}) than the {1} needed for reconstruction.
    #[error("Too few shards ({0}) to reconstruct data, at least {1} needed")]
    TooFewShards(usize, usize),
}
53
/// Convenience alias defaulting the error type to [`LeopardError`].
pub type Result<T, E = LeopardError> = std::result::Result<T, E>;
56
57pub fn encode(shards: &mut [impl AsMut<[u8]>], data_shards: usize) -> Result<()> {
66 if shards.len() > ORDER {
67 return Err(LeopardError::MaxShardNumberExceeded(shards.len()));
68 }
69 let parity_shards = shards.len() - data_shards;
70 if parity_shards > data_shards {
71 return Err(LeopardError::MaxParityShardNumberExceeded(
72 parity_shards,
73 data_shards,
74 ));
75 }
76 if is_encode_buf_overflow(data_shards, parity_shards) {
77 return Err(LeopardError::UnsupportedShardsAmounts(
78 data_shards,
79 parity_shards,
80 ));
81 }
82
83 let mut shards: Vec<&mut [u8]> = shards.iter_mut().map(|shard| shard.as_mut()).collect();
84 let shard_size = check_shards(&shards, false)?;
85
86 if shard_size % 64 != 0 {
87 return Err(LeopardError::InvalidShardSize(shard_size));
88 }
89
90 encode_inner(&mut shards, data_shards, shard_size);
91
92 Ok(())
93}
94
/// Core encode: fills the trailing `parity_shards` rows from the data rows.
///
/// Leopard high-rate layout: `m` is `parity_shards` rounded up to a power of
/// two; the data shards are folded chunk-by-chunk (chunks of `m`) through an
/// IFFT accumulator, then one forward FFT of size `m` produces the parity.
fn encode_inner(shards: &mut [&mut [u8]], data_shards: usize, shard_size: usize) {
    let parity_shards = shards.len() - data_shards;

    // m: FFT block size; mtrunc: rows of the first chunk that are real data.
    let m = ceil_pow2(parity_shards);
    let mtrunc = m.min(data_shards);

    // Scratch of 2*m rows, viewed as per-row mutable slices.
    let mut work_mem = vec![0; 2 * m * shard_size];
    let mut work: Vec<_> = work_mem.chunks_exact_mut(shard_size).collect();

    // Skew table pre-offset by m-1: ifft_dit_encoder indexes relative to this
    // base (the decoder variant uses the unoffset table with `- 1` indexing).
    let skew_lut = &lut::FFT_SKEW[m - 1..];

    // NOTE(review): these rows are overwritten (or zeroed) again inside the
    // first ifft_dit_encoder call below, so this copy looks redundant; kept
    // as-is to match the reference port — confirm before removing.
    for (shard, work) in shards[data_shards..].iter().zip(work.iter_mut()) {
        work.copy_from_slice(shard);
    }

    // work[..m] <- IFFT of the first chunk of data shards.
    ifft_dit_encoder(
        &shards[..data_shards],
        mtrunc,
        &mut work,
        None, // first chunk writes directly; no accumulator yet
        m,
        skew_lut,
    );

    // Size of the trailing partial chunk (0 when data_shards divides evenly).
    let last_count = data_shards % m;

    if m < data_shards {
        // Accumulate the IFFT of each further chunk into xor_out (= work[..m]),
        // using the second half of the scratch as the transform buffer.
        let (xor_out, work) = work.split_at_mut(m);
        let mut n = m;

        // Full chunks of m data shards.
        while n <= data_shards - m {
            ifft_dit_encoder(&shards[n..], m, work, Some(xor_out), m, &skew_lut[n..]);
            n += m;
        }

        // Trailing partial chunk, if any.
        if last_count != 0 {
            ifft_dit_encoder(
                &shards[n..],
                last_count,
                work,
                Some(xor_out),
                m,
                &skew_lut[n..],
            );
        }
    }

    // Forward FFT of the accumulated transform; the first parity_shards rows
    // are the parity shards.
    fft_dit(&mut work, parity_shards, m, &*lut::FFT_SKEW);

    for (shard, work) in shards[data_shards..].iter_mut().zip(work.iter()) {
        shard.copy_from_slice(work);
    }
}
156
157fn is_encode_buf_overflow(data_shards: usize, parity_shards: usize) -> bool {
164 debug_assert!(data_shards >= parity_shards);
165 debug_assert!(data_shards + parity_shards <= ORDER);
166
167 let m = ceil_pow2(parity_shards);
168 let last_count = data_shards % m;
169
170 if m >= data_shards || last_count == 0 {
172 return false;
173 }
174
175 let full_passes = data_shards / m;
176 (full_passes + 1) * m + 1 > MODULUS as usize
178}
179
/// Reconstructs missing shards in place.
///
/// Missing shards are marked by empty vectors; present shards must all share
/// one length that is a multiple of 64. On success every formerly-empty
/// shard has been resized and filled with its recovered contents.
///
/// NOTE(review): assumes `data_shards <= shards.len()`; a larger value makes
/// the subtraction below underflow — confirm callers uphold this.
///
/// # Errors
///
/// - [`LeopardError::MaxShardNumberExceeded`] for more than `ORDER` shards.
/// - [`LeopardError::MaxParityShardNumberExceeded`] when parity > data.
/// - [`LeopardError::TooFewShards`] if fewer than `data_shards` are present.
/// - [`LeopardError::UnequalShardsLengths`] / [`LeopardError::InvalidShardSize`]
///   for inconsistent or non-64-multiple shard sizes.
pub fn reconstruct(shards: &mut [impl AsMut<Vec<u8>>], data_shards: usize) -> Result<()> {
    if shards.len() > ORDER {
        return Err(LeopardError::MaxShardNumberExceeded(shards.len()));
    }
    let parity_shards = shards.len() - data_shards;
    if parity_shards > data_shards {
        return Err(LeopardError::MaxParityShardNumberExceeded(
            parity_shards,
            data_shards,
        ));
    }

    let mut shards: Vec<_> = shards.iter_mut().map(|shard| shard.as_mut()).collect();
    // Empty (missing) shards are allowed here; the size comes from the first
    // non-empty one.
    let shard_size = check_shards(&shards, true)?;

    let present_shards = shards.iter().filter(|shard| !shard.is_empty()).count();
    if present_shards == shards.len() {
        // Nothing is missing; no work to do.
        return Ok(());
    }

    if present_shards < data_shards {
        return Err(LeopardError::TooFewShards(present_shards, data_shards));
    }

    if shard_size % 64 != 0 {
        return Err(LeopardError::InvalidShardSize(shard_size));
    }

    reconstruct_inner(&mut shards, data_shards, shard_size);

    Ok(())
}
227
/// Core decode: recovers the shards that were empty on entry.
///
/// Per the Leopard algorithm: derive a per-position weight from the
/// missing-shard bitmap via two Walsh-Hadamard transforms over the log
/// domain, scale the surviving shards by those weights, run
/// IFFT -> formal-derivative -> FFT over a size-`n` workspace, and unscale
/// the rows at the erased positions to obtain the lost shards.
fn reconstruct_inner(shards: &mut [&mut Vec<u8>], data_shards: usize, shard_size: usize) {
    let parity_shards = shards.len() - data_shards;

    // m: parity block size; n: FFT size covering parity block + data rows.
    let m = ceil_pow2(parity_shards);
    let n = ceil_pow2(m + data_shards);

    // Remember which shards were missing, then zero-fill them so every row
    // below is shard_size bytes long.
    let empty_shards_mask: Vec<_> = shards.iter().map(|shard| shard.is_empty()).collect();
    for shard in shards.iter_mut().filter(|shard| shard.is_empty()) {
        shard.resize(shard_size, 0);
    }

    // Erasure indicator per FFT position.
    let mut err_locs = [0u8; ORDER];

    // Parity shards map to positions [0, parity_shards).
    for (&is_empty, err_loc) in empty_shards_mask
        .iter()
        .skip(data_shards)
        .zip(err_locs.iter_mut())
    {
        if is_empty {
            *err_loc = 1;
        }
    }

    // Positions [parity_shards, m) are padding parity rows that don't exist;
    // treat them as erased as well.
    for err in &mut err_locs[parity_shards..m] {
        *err = 1;
    }

    // Data shards map to positions [m, m + data_shards).
    for (&is_empty, err_loc) in empty_shards_mask
        .iter()
        .take(data_shards)
        .zip(err_locs[m..].iter_mut())
    {
        if is_empty {
            *err_loc = 1;
        }
    }

    // Error-locator evaluation over the whole field:
    // FWHT -> pointwise multiply by LOG_WALSH (mod 255) -> FWHT.
    fwht(&mut err_locs, ORDER, m + data_shards);

    for (err, &log_walsh) in err_locs.iter_mut().zip(lut::LOG_WALSH.iter()) {
        let mul = (*err) as usize * log_walsh as usize;
        *err = (mul % MODULUS as usize) as u8;
    }

    fwht(&mut err_locs, ORDER, ORDER);

    // Workspace of n rows of shard_size bytes.
    let mut work_mem = vec![0u8; shard_size * n];
    let mut work: Vec<_> = work_mem.chunks_exact_mut(shard_size).collect();

    // work[0..parity] <- surviving parity scaled by its weight; erased and
    // padding rows are zeroed.
    for i in 0..parity_shards {
        if !empty_shards_mask[i + data_shards] {
            mul_gf(work[i], shards[i + data_shards], err_locs[i]);
        } else {
            work[i].fill(0);
        }
    }
    for work in work.iter_mut().take(m).skip(parity_shards) {
        work.fill(0);
    }

    // work[m..m+data] <- surviving data scaled likewise; the rest zeroed.
    for i in 0..data_shards {
        if !empty_shards_mask[i] {
            mul_gf(work[m + i], shards[i], err_locs[m + i])
        } else {
            work[m + i].fill(0);
        }
    }
    for work in work.iter_mut().take(n).skip(m + data_shards) {
        work.fill(0);
    }

    // Inverse transform of the scaled codeword.
    ifft_dit_decoder(m + data_shards, &mut work, n, &lut::FFT_SKEW[..]);

    // Formal-derivative pass (matches the reference implementation): for each
    // i, XOR the `width` rows starting at i into the `width` rows ending at i,
    // where width = ((i ^ (i - 1)) + 1) >> 1.
    for i in 1..n {
        let width = ((i ^ (i - 1)) + 1) >> 1;
        let (output, input) = work.split_at_mut(i);
        slices_xor(
            &mut output[i - width..],
            input.iter_mut().map(|elem| &**elem),
        );
    }

    // Forward transform back to the evaluation domain.
    fft_dit(&mut work, m + data_shards, n, &lut::FFT_SKEW[..]);

    // Unscale the erased positions: MODULUS - err_loc negates the weight in
    // the log domain. Parity rows sit at [0, parity) of work, data rows at
    // [m, m + data).
    for (i, shard) in shards.iter_mut().enumerate() {
        if !empty_shards_mask[i] {
            continue;
        }

        if i >= data_shards {
            // A missing parity shard.
            mul_gf(
                shard,
                work[i - data_shards],
                MODULUS - err_locs[i - data_shards],
            );
        } else {
            // A missing data shard.
            mul_gf(shard, work[i + m], MODULUS - err_locs[i + m]);
        }
    }
}
351
/// Length of the first non-empty shard, or 0 if every shard is empty
/// (or there are no shards at all).
fn shard_size(shards: &[impl AsRef<[u8]>]) -> usize {
    for shard in shards {
        let len = shard.as_ref().len();
        if len != 0 {
            return len;
        }
    }
    0
}
359
360fn check_shards(shards: &[impl AsRef<[u8]>], allow_zero: bool) -> Result<usize> {
366 let size = shard_size(shards);
367
368 if size == 0 {
369 if allow_zero {
370 return Ok(0);
371 } else {
372 return Err(LeopardError::EmptyShards);
373 }
374 }
375
376 let are_all_same_size = shards.iter().all(|shard| {
378 let shard = shard.as_ref();
379 if allow_zero && shard.is_empty() {
380 true
381 } else {
382 shard.len() == size
383 }
384 });
385
386 if !are_all_same_size {
387 return Err(LeopardError::UnequalShardsLengths);
388 }
389
390 Ok(size)
391}
392
#[inline]
/// Addition in the log (exponent) domain, reduced with an end-around carry:
/// sums of two bytes never exceed 510, so a single conditional subtraction of
/// 255 matches the original `sum + (sum >> BITS)` carry fold exactly.
const fn add_mod(a: u8, b: u8) -> u8 {
    let sum = a as u32 + b as u32;
    if sum >= 256 {
        (sum - 255) as u8
    } else {
        sum as u8
    }
}
401
#[inline]
/// Subtraction in the log (exponent) domain: `a - b` wrapped around the
/// modulus 255 when it would go negative (matches the original borrow-adjust
/// formulation for every input pair).
const fn sub_mod(a: u8, b: u8) -> u8 {
    if a >= b {
        a - b
    } else {
        // a < b keeps the result strictly below 255, so the cast is lossless.
        (a as u32 + 255 - b as u32) as u8
    }
}
412
413#[inline]
420fn mul_log(a: u8, log_b: u8) -> u8 {
421 if a == 0 {
422 0
423 } else {
424 let log_a = lut::log(a);
425 lut::exp(add_mod(log_a, log_b))
426 }
427}
428
429fn mul_add(x: &mut [u8], y: &[u8], log_m: u8) {
430 x.iter_mut().zip(y.iter()).for_each(|(x, y)| {
431 *x ^= lut::mul(*y, log_m);
432 })
433}
434
435fn mul_gf(out: &mut [u8], input: &[u8], log_m: u8) {
436 let mul_lut = lut::MUL[log_m as usize];
437 for (out, &input) in out.iter_mut().zip(input.iter()) {
438 *out = mul_lut[input as usize];
439 }
440}
441
442fn fwht(data: &mut [u8; ORDER], m: usize, mtrunc: usize) {
446 let mut dist: usize = 1;
448 let mut dist4: usize = 4;
449
450 while dist4 <= m {
451 for offset in (0..mtrunc).step_by(dist4) {
452 let mut offset = offset;
453
454 for _ in 0..dist {
455 let t0 = data[offset];
461 let t1 = data[offset + dist];
462 let t2 = data[offset + dist * 2];
463 let t3 = data[offset + dist * 3];
464
465 let (t0, t1) = fwht2alt(t0, t1);
466 let (t2, t3) = fwht2alt(t2, t3);
467 let (t0, t2) = fwht2alt(t0, t2);
468 let (t1, t3) = fwht2alt(t1, t3);
469
470 data[offset] = t0;
471 data[offset + dist] = t1;
472 data[offset + dist * 2] = t2;
473 data[offset + dist * 3] = t3;
474
475 offset += 1
476 }
477 }
478 dist = dist4;
479 dist4 <<= 2;
480 }
481
482 if dist < m {
484 for i in 0..dist {
485 let (first, second) = data.split_at_mut(i + 1);
486 fwht2(&mut first[i], &mut second[dist]);
487 }
488 }
489}
490
491#[inline]
493fn fwht2(a: &mut u8, b: &mut u8) {
494 let sum = add_mod(*a, *b);
495 let dif = sub_mod(*a, *b);
496
497 *a = sum;
498 *b = dif;
499}
500
501#[inline]
503fn fwht2alt(a: u8, b: u8) -> (u8, u8) {
504 (add_mod(a, b), sub_mod(a, b))
505}
506
#[inline]
/// Rounds `x` up to the next power of two (1 for `x <= 1`).
///
/// The previous hand-rolled `1 << (BITS - (x - 1).leading_zeros())` formula
/// underflowed on `x == 0` (and then overflowed the shift);
/// `usize::next_power_of_two` is defined for all inputs and agrees with the
/// old formula for every `x >= 1`.
const fn ceil_pow2(x: usize) -> usize {
    x.next_power_of_two()
}
512
/// Unrolled decimation-in-time IFFT used by the encoder.
///
/// Copies `mtrunc` rows from `data` into `work[..mtrunc]`, zeroes
/// `work[mtrunc..m]`, and transforms `work[..m]` in place. When `xor_output`
/// is given, the transformed rows are additionally XOR-accumulated into
/// `xor_output[..m]`.
///
/// `skew_lut` is pre-offset by the caller (encode_inner passes
/// `FFT_SKEW[m - 1..]`), which is why indexing here has no `- 1` while the
/// decoder variant subtracts one.
fn ifft_dit_encoder(
    data: &[impl AsRef<[u8]>],
    mtrunc: usize,
    work: &mut [&mut [u8]],
    xor_output: Option<&mut [&mut [u8]]>,
    m: usize,
    skew_lut: &[u8],
) {
    // work[..mtrunc] <- data, work[mtrunc..m] <- 0.
    for i in 0..mtrunc {
        work[i].copy_from_slice(data[i].as_ref());
    }
    for row in work[mtrunc..m].iter_mut() {
        row.fill(0);
    }

    // Unroll two layers at a time (radix-4 butterflies).
    let mut dist = 1;
    let mut dist4 = 4;

    while dist4 <= m {
        // For each group of dist4 rows:
        for r in (0..mtrunc).step_by(dist4) {
            let iend = r + dist;
            let log_m01 = skew_lut[iend];
            let log_m02 = skew_lut[iend + dist];
            let log_m23 = skew_lut[iend + dist * 2];

            // For each row within the group:
            for i in r..iend {
                ifft_dit4(&mut work[i..], dist, log_m01, log_m23, log_m02);
            }
        }

        dist = dist4;
        dist4 <<= 2;
    }

    // If m is not a power of 4, one radix-2 layer remains.
    if dist < m {
        debug_assert_eq!(dist * 2, m);

        let log_m = skew_lut[dist];

        // A skew of MODULUS marks a zero multiplier: only the XOR half of the
        // butterfly runs.
        if log_m == MODULUS {
            let (input, output) = work.split_at_mut(dist);
            slices_xor(&mut output[..dist], input.iter_mut().map(|elem| &**elem));
        } else {
            let (x, y) = work.split_at_mut(dist);
            for i in 0..dist {
                ifft_dit2(x[i], y[i], log_m);
            }
        }
    }

    // Optionally fold the finished transform into the caller's accumulator.
    if let Some(xor_output) = xor_output {
        slices_xor(
            &mut xor_output[..m],
            work[..m].iter_mut().map(|elem| &**elem),
        );
    }
}
588
/// Unrolled decimation-in-time IFFT used by the decoder.
///
/// Transforms `work[..m]` in place, skipping the all-zero rows beyond
/// `mtrunc`. Unlike the encoder variant, `skew_lut` here is the unoffset
/// `FFT_SKEW` table — hence the `- 1` in every index.
fn ifft_dit_decoder(mtrunc: usize, work: &mut [&mut [u8]], m: usize, skew_lut: &[u8]) {
    // Unroll two layers at a time (radix-4 butterflies).
    let mut dist = 1;
    let mut dist4 = 4;

    while dist4 <= m {
        // For each group of dist4 rows:
        for r in (0..mtrunc).step_by(dist4) {
            let iend = r + dist;
            let log_m01 = skew_lut[iend - 1];
            let log_m02 = skew_lut[iend + dist - 1];
            let log_m23 = skew_lut[iend + 2 * dist - 1];

            // For each row within the group:
            for i in r..iend {
                ifft_dit4(&mut work[i..], dist, log_m01, log_m23, log_m02);
            }
        }

        dist = dist4;
        dist4 <<= 2;
    }

    // If m is not a power of 4, one radix-2 layer remains.
    if dist < m {
        debug_assert_eq!(2 * dist, m);

        let log_m = skew_lut[dist - 1];

        // A skew of MODULUS marks a zero multiplier: XOR-only butterfly.
        if log_m == MODULUS {
            let (input, output) = work.split_at_mut(dist);
            slices_xor(&mut output[..dist], input.iter_mut().map(|elem| &**elem));
        } else {
            let (x, y) = work.split_at_mut(dist);
            for i in 0..dist {
                ifft_dit2(x[i], y[i], log_m)
            }
        }
    }
}
631
/// In-place 4-point IFFT butterfly over the rows `work[0]`, `work[dist]`,
/// `work[2*dist]`, `work[3*dist]`.
///
/// A log multiplier equal to MODULUS denotes a zero factor, for which only
/// the XOR half of each 2-point butterfly is applied.
fn ifft_dit4(work: &mut [&mut [u8]], dist: usize, log_m01: u8, log_m23: u8, log_m02: u8) {
    // Zero-length rows: nothing to transform.
    if work[0].is_empty() {
        return;
    }

    // Split so the four rows of interest are dist0[0], dist1[0], dist2[0],
    // dist3[0] (i.e. work[0], work[dist], work[2*dist], work[3*dist]).
    let (dist0, dist1) = work.split_at_mut(dist);
    let (dist1, dist2) = dist1.split_at_mut(dist);
    let (dist2, dist3) = dist2.split_at_mut(dist);

    // First layer: pairs (0,1) and (2,3).
    if log_m01 == MODULUS {
        slice_xor(&*dist0[0], dist1[0]);
    } else {
        ifft_dit2(dist0[0], dist1[0], log_m01);
    }

    if log_m23 == MODULUS {
        slice_xor(&*dist2[0], dist3[0]);
    } else {
        ifft_dit2(dist2[0], dist3[0], log_m23);
    }

    // Second layer: pairs (0,2) and (1,3), sharing log_m02.
    if log_m02 == MODULUS {
        slice_xor(&*dist0[0], dist2[0]);
        slice_xor(&*dist1[0], dist3[0]);
    } else {
        ifft_dit2(dist0[0], dist2[0], log_m02);
        ifft_dit2(dist1[0], dist3[0], log_m02);
    }
}
666
/// In-place 2-point IFFT butterfly: first `y ^= x`, then
/// `x ^= mul(y, log_m)` (multiplication via the field's log-indexed table).
fn ifft_dit2(x: &mut [u8], y: &mut [u8], log_m: u8) {
    slice_xor(&*x, y);
    mul_add(x, y, log_m);
}
671
/// Unrolled decimation-in-time forward FFT over `work[..m]`, in place.
///
/// Mirrors the IFFT but walks the layers from widest (`m / 4`) down to the
/// narrowest, with a final radix-2 pass when m is not a power of 4.
fn fft_dit(work: &mut [&mut [u8]], mtrunc: usize, m: usize, skew_lut: &[u8]) {
    let mut dist4 = m;
    let mut dist = m >> 2;

    while dist != 0 {
        // For each group of dist4 rows:
        for r in (0..mtrunc).step_by(dist4) {
            let iend = r + dist;
            let log_m01 = skew_lut[iend - 1];
            let log_m02 = skew_lut[iend + dist - 1];
            let log_m23 = skew_lut[iend + 2 * dist - 1];

            // For each row within the group:
            for i in r..iend {
                fft_dit4(&mut work[i..], dist, log_m01, log_m23, log_m02);
            }
        }

        dist4 = dist;
        dist >>= 2;
    }

    // Final radix-2 layer (reached when m is 2 times a power of 4).
    if dist4 == 2 {
        for r in (0..mtrunc).step_by(2) {
            let log_m = skew_lut[r];
            // Pair work[r] with work[r + 1].
            let (x, y) = work.split_at_mut(r + 1);

            // A skew of MODULUS marks a zero multiplier: XOR-only butterfly.
            if log_m == MODULUS {
                slice_xor(&*x[r], y[0]);
            } else {
                fft_dit2(x[r], y[0], log_m);
            }
        }
    }
}
710
/// In-place 4-point forward FFT butterfly over the rows `work[0]`,
/// `work[dist]`, `work[2*dist]`, `work[3*dist]`.
///
/// Layer order is the reverse of `ifft_dit4`: the (0,2)/(1,3) layer runs
/// first, then (0,1) and (2,3). MODULUS again marks a zero multiplier
/// (XOR-only butterfly).
fn fft_dit4(work: &mut [&mut [u8]], dist: usize, log_m01: u8, log_m23: u8, log_m02: u8) {
    // Zero-length rows: nothing to transform.
    if work[0].is_empty() {
        return;
    }

    // Split so the four rows of interest are dist0[0], dist1[0], dist2[0],
    // dist3[0].
    let (dist0, dist1) = work.split_at_mut(dist);
    let (dist1, dist2) = dist1.split_at_mut(dist);
    let (dist2, dist3) = dist2.split_at_mut(dist);

    // First layer: pairs (0,2) and (1,3), sharing log_m02.
    if log_m02 == MODULUS {
        slice_xor(&*dist0[0], dist2[0]);
        slice_xor(&*dist1[0], dist3[0]);
    } else {
        fft_dit2(dist0[0], dist2[0], log_m02);
        fft_dit2(dist1[0], dist3[0], log_m02);
    }

    // Second layer: pair (0,1).
    if log_m01 == MODULUS {
        slice_xor(&*dist0[0], dist1[0]);
    } else {
        fft_dit2(dist0[0], dist1[0], log_m01);
    }

    // Second layer: pair (2,3).
    if log_m23 == MODULUS {
        slice_xor(&*dist2[0], dist3[0]);
    } else {
        fft_dit2(dist2[0], dist3[0], log_m23);
    }
}
747
/// In-place 2-point forward FFT butterfly: first `x ^= mul(y, log_m)`, then
/// `y ^= x` — the reverse order of `ifft_dit2`.
fn fft_dit2(x: &mut [u8], y: &mut [u8], log_m: u8) {
    // Zero-length rows: nothing to transform.
    if x.is_empty() {
        return;
    }

    mul_add(x, y, log_m);
    slice_xor(&*x, y);
}
757
758fn slices_xor(output: &mut [&mut [u8]], input: impl Iterator<Item = impl Buf>) {
759 output
760 .iter_mut()
761 .zip(input)
762 .for_each(|(out, inp)| slice_xor(inp, out));
763}
764
/// XORs `input` into `output` byte-wise (`output[i] ^= input[i]`), stopping
/// at the shorter of the two.
///
/// The fast path processes 32 bytes per iteration as four u64 words; a scalar
/// loop handles the remainder. The `_le` accessors are used symmetrically for
/// read and write, so endianness cancels out.
fn slice_xor(mut input: impl Buf, mut output: &mut [u8]) {
    while output.remaining_mut() >= 32 && input.remaining() >= 32 {
        // Read ahead through an immutable reborrow: `output_buf` advances on
        // its own, while `output` itself only advances via the put_* calls.
        let mut output_buf = &*output;
        let v0 = output_buf.get_u64_le() ^ input.get_u64_le();
        let v1 = output_buf.get_u64_le() ^ input.get_u64_le();
        let v2 = output_buf.get_u64_le() ^ input.get_u64_le();
        let v3 = output_buf.get_u64_le() ^ input.get_u64_le();

        // Writing through BufMut advances `output` past the 32 bytes just
        // produced, keeping it in sync with what was read above.
        output.put_u64_le(v0);
        output.put_u64_le(v1);
        output.put_u64_le(v2);
        output.put_u64_le(v3);
    }

    // Scalar tail for whatever both sides still have.
    let rest = output.remaining_mut().min(input.remaining());
    for _ in 0..rest {
        let xor = (&*output).get_u8() ^ input.get_u8();
        output.put_u8(xor);
    }
}
786
#[cfg(test)]
mod tests {
    use std::panic::catch_unwind;

    use rand::{seq::index, Fill, Rng};
    use test_strategy::{proptest, Arbitrary};

    use super::*;

    /// Parity produced here must be byte-identical to the reference Go
    /// implementation (via the `go_leopard` bindings) for the same input.
    #[proptest]
    fn go_reedsolomon_encode_compatibility(input: TestCase) {
        let TestCase {
            data_shards,
            parity_shards,
            shard_size,
        } = input;
        let total_shards = data_shards + parity_shards;
        let test_shards = random_shards(total_shards, shard_size);

        let mut shards = test_shards.clone();
        encode(&mut shards, data_shards).unwrap();

        let mut expected = test_shards;
        go_leopard::encode(&mut expected, data_shards, shard_size).unwrap();

        if expected != shards {
            panic!("Go and Rust encoding differ for {input:#?}")
        }
    }

    /// Round-trip: encode, erase up to `parity_shards` random shards, then
    /// reconstruct and compare with the originals.
    #[proptest]
    fn encode_reconstruct(input: TestCase) {
        let TestCase {
            data_shards,
            parity_shards,
            shard_size,
        } = input;
        let total_shards = data_shards + parity_shards;
        let mut shards = random_shards(total_shards, shard_size);

        encode(&mut shards, data_shards).unwrap();

        let expected = shards.clone();

        // Erase a random subset; an empty vec marks a missing shard.
        let mut rng = rand::thread_rng();
        let missing_shards = rng.gen_range(1..=parity_shards);
        for idx in index::sample(&mut rng, total_shards, missing_shards) {
            shards[idx] = vec![];
        }

        reconstruct(&mut shards, data_shards).unwrap();

        if expected != shards {
            panic!("shares differ after reconstruction");
        }
    }

    /// `is_encode_buf_overflow` must predict exactly the (data, parity)
    /// combinations for which `encode_inner` panics.
    #[test]
    fn overflow_detection() {
        for data_shards in 1..MODULUS as usize {
            for parity_shards in 1..data_shards {
                let total_shards = data_shards + parity_shards;

                if total_shards > ORDER {
                    continue;
                }

                let overflow = is_encode_buf_overflow(data_shards, parity_shards);

                // encode_inner panicking is the ground truth for overflow.
                let result = catch_unwind(|| {
                    let mut shards = random_shards(total_shards, 64);
                    let mut shards_ref: Vec<_> = shards
                        .iter_mut()
                        .map(|shard| shard.as_mut_slice())
                        .collect();
                    encode_inner(&mut shards_ref, data_shards, 64);
                });

                assert_eq!(result.is_err(), overflow, "{data_shards} {parity_shards}");
            }
        }
    }

    /// Random but valid shard-count / shard-size combinations for the
    /// property tests; overflowing combinations are filtered out.
    #[derive(Arbitrary, Debug)]
    #[filter(!is_encode_buf_overflow(#data_shards, #parity_shards))]
    struct TestCase {
        #[strategy(1..ORDER - 1)]
        data_shards: usize,

        // parity <= data and data + parity <= ORDER.
        #[strategy(1..=(ORDER - #data_shards).min(#data_shards))]
        parity_shards: usize,

        // Always a multiple of 64, as encode/reconstruct require.
        #[strategy(1usize..1024)]
        #[map(|x| x * 64)]
        shard_size: usize,
    }

    /// Builds `shards` shards of `shard_size` random bytes each.
    fn random_shards(shards: usize, shard_size: usize) -> Vec<Vec<u8>> {
        let mut rng = rand::thread_rng();
        (0..shards)
            .map(|_| {
                let mut shard = vec![0; shard_size];
                shard.try_fill(&mut rng).unwrap();
                shard
            })
            .collect()
    }
}