use alloc::vec;
use alloc::vec::Vec;
use reed_solomon_simd::{ReedSolomonDecoder, ReedSolomonEncoder};
/// Number of data shards in the RS(4, 2) configuration.
///
/// The tests in this file pass it to `parity_len` together with
/// `RS_PARITY_SHARDS`.
pub const RS_DATA_SHARDS: usize = 4;
/// Number of parity shards in the RS(4, 2) configuration.
pub const RS_PARITY_SHARDS: usize = 2;
/// Returns the size in bytes of one shard when `payload_len` bytes are spread
/// over `data_shards` shards.
///
/// The payload is divided evenly (rounding up) and the per-shard size is then
/// rounded up to the next even number, so the result is always a multiple of
/// two. Returns `0` when `data_shards` is zero or the payload is empty.
// NOTE(review): the even rounding presumably exists to satisfy the shard-size
// requirement of the Reed-Solomon backend — confirm against its documentation.
#[must_use]
pub fn shard_bytes(payload_len: usize, data_shards: usize) -> usize {
    match data_shards {
        0 => 0,
        shards => {
            let per_shard = payload_len.div_ceil(shards);
            // Adding the low bit bumps odd sizes to the next even value and
            // leaves even sizes untouched.
            per_shard + (per_shard & 1)
        }
    }
}
/// Returns the total number of parity bytes for a payload: `parity_shards`
/// shards, each [`shard_bytes`]`(payload_len, data_shards)` bytes long.
#[must_use]
pub fn parity_len(payload_len: usize, data_shards: usize, parity_shards: usize) -> usize {
    let per_shard = shard_bytes(payload_len, data_shards);
    parity_shards * per_shard
}
/// Computes the parity bytes that protect `payload`.
///
/// The scheme is chosen from the shard configuration:
/// * shard size `0` (empty payload or zero `data_shards`) or
///   `parity_shards == 0`: no parity, an empty vector is returned;
/// * `parity_shards == 1`: a single XOR parity shard (`encode_xor`);
/// * otherwise: Reed-Solomon parity (`encode_rs`).
///
/// # Errors
///
/// Propagates the error returned by `encode_rs`.
pub fn encode_parity(
    payload: &[u8],
    data_shards: usize,
    parity_shards: usize,
) -> crate::Result<Vec<u8>> {
    let sb = shard_bytes(payload.len(), data_shards);
    match (sb, parity_shards) {
        (0, _) | (_, 0) => Ok(Vec::new()),
        (_, 1) => Ok(encode_xor(payload, data_shards, sb)),
        _ => encode_rs(payload, data_shards, parity_shards, sb),
    }
}
/// Loads data shard `i` into `buf`.
///
/// `buf` is zeroed first, then the bytes `payload[i * sb..(i + 1) * sb]`
/// (clamped to the end of `payload`) are copied to its front. A shard that
/// starts at or beyond the end of `payload` is left all zero.
fn fill_data_shard(buf: &mut [u8], payload: &[u8], i: usize, sb: usize) {
    buf.fill(0);
    let start = i * sb;
    let end = ((i + 1) * sb).min(payload.len());
    if start >= payload.len() {
        // Shard lies entirely in the zero padding.
        return;
    }
    #[expect(
        clippy::indexing_slicing,
        reason = "start < payload.len() and end <= payload.len() checked above"
    )]
    let (dst, src) = (&mut buf[..end - start], &payload[start..end]);
    dst.copy_from_slice(src);
}
/// Builds a single XOR parity shard of `sb` bytes.
///
/// Each of the `data_shards` zero-padded data shards (see `fill_data_shard`)
/// is XOR-ed byte-wise into the result.
fn encode_xor(payload: &[u8], data_shards: usize, sb: usize) -> Vec<u8> {
    // One scratch buffer is reused for every shard.
    let mut scratch = vec![0u8; sb];
    (0..data_shards).fold(vec![0u8; sb], |mut parity, shard_idx| {
        fill_data_shard(&mut scratch, payload, shard_idx, sb);
        parity
            .iter_mut()
            .zip(&scratch)
            .for_each(|(acc, &byte)| *acc ^= byte);
        parity
    })
}
fn encode_rs(
payload: &[u8],
data_shards: usize,
parity_shards: usize,
sb: usize,
) -> crate::Result<Vec<u8>> {
let mut encoder = ReedSolomonEncoder::new(data_shards, parity_shards, sb)
.map_err(|_| crate::Error::Unrecoverable)?;
let mut shard_buf = vec![0u8; sb];
for i in 0..data_shards {
fill_data_shard(&mut shard_buf, payload, i, sb);
encoder
.add_original_shard(&shard_buf)
.map_err(|_| crate::Error::Unrecoverable)?;
}
let result = encoder.encode().map_err(|_| crate::Error::Unrecoverable)?;
let mut out = Vec::with_capacity(sb * parity_shards);
for shard in result.recovery_iter() {
out.extend_from_slice(shard);
}
Ok(out)
}
/// Attempts to reconstruct a payload of `expected_payload_len` bytes from
/// possibly damaged `data` and its `parity`.
///
/// Candidate payloads are handed to `xxh3_oracle`; the first candidate it
/// accepts is returned. With `parity_shards == 1` the XOR scheme
/// (`xor_recover`) is used, otherwise Reed-Solomon (`rs_recover`).
///
/// Only the first `expected_payload_len` bytes of `data` take part in
/// recovery. Encoding zero-pads the last shard past the end of the payload,
/// so any extra bytes following the payload in `data` must not leak into that
/// padding — they would make an intact shard look corrupt.
///
/// Returns `Ok(None)` when recovery is not applicable (zero shard size, no
/// parity shards, `data` shorter than `expected_payload_len`, `parity`
/// shorter than `parity_shards` shards) or when no candidate is accepted.
///
/// # Errors
///
/// Propagates the error returned by `rs_recover`.
pub fn try_recover<F>(
    data: &[u8],
    parity: &[u8],
    expected_payload_len: usize,
    data_shards: usize,
    parity_shards: usize,
    mut xxh3_oracle: F,
) -> crate::Result<Option<Vec<u8>>>
where
    F: FnMut(&[u8]) -> bool,
{
    let sb = shard_bytes(expected_payload_len, data_shards);
    if sb == 0 || parity_shards == 0 {
        return Ok(None);
    }
    // Restrict recovery to the payload bytes; also rejects truncated data.
    let Some(data) = data.get(..expected_payload_len) else {
        return Ok(None);
    };
    // A parity size that overflows `usize` can never be satisfied.
    match sb.checked_mul(parity_shards) {
        Some(needed) if parity.len() >= needed => {}
        _ => return Ok(None),
    }

    if parity_shards == 1 {
        return Ok(xor_recover(
            data,
            parity,
            expected_payload_len,
            data_shards,
            sb,
            &mut xxh3_oracle,
        ));
    }
    rs_recover(
        data,
        parity,
        expected_payload_len,
        data_shards,
        parity_shards,
        sb,
        &mut xxh3_oracle,
    )
}
/// Splits `data` into `data_shards` owned shards of exactly `sb` bytes each.
///
/// Shard `i` holds `data[i * sb..(i + 1) * sb]`, clamped to the end of `data`;
/// anything not covered by `data` is zero. Bytes of `data` past
/// `data_shards * sb` are ignored.
fn carve_data_shards(data: &[u8], data_shards: usize, sb: usize) -> Vec<Vec<u8>> {
    (0..data_shards)
        .map(|shard_idx| {
            let mut shard = vec![0u8; sb];
            let lo = shard_idx * sb;
            let hi = ((shard_idx + 1) * sb).min(data.len());
            if lo < data.len() {
                #[expect(
                    clippy::indexing_slicing,
                    reason = "lo < data.len() and hi <= data.len() checked above"
                )]
                shard[..hi - lo].copy_from_slice(&data[lo..hi]);
            }
            shard
        })
        .collect()
}
/// Tries to recover the payload using a single XOR parity shard.
///
/// Candidates are offered to `xxh3_oracle` in this order, each truncated to
/// `expected_payload_len`:
/// 1. the zero-padded data exactly as given;
/// 2. for every shard index `0..data_shards`, the data with that one shard
///    replaced by `parity ^ (all other shards)`.
///
/// The first accepted candidate is returned. `None` means no candidate was
/// accepted, or `parity` holds fewer than `sb` bytes.
///
/// The XOR of the parity with all data shards (the syndrome) is computed
/// once. Rebuilding shard `m` from the parity and the other shards equals
/// XOR-ing the syndrome into shard `m`, so each candidate costs `O(sb)`
/// instead of `O(data_shards * sb)` and is patched in place in one buffer.
fn xor_recover<F>(
    data: &[u8],
    parity: &[u8],
    expected_payload_len: usize,
    data_shards: usize,
    sb: usize,
    xxh3_oracle: &mut F,
) -> Option<Vec<u8>>
where
    F: FnMut(&[u8]) -> bool,
{
    // XORs `src` into `dst` byte-wise over their common length.
    fn xor_into(dst: &mut [u8], src: &[u8]) {
        for (d, &s) in dst.iter_mut().zip(src) {
            *d ^= s;
        }
    }

    // Working copy of the shard area: the leading `data_shards * sb` bytes of
    // `data`, zero-padded when `data` is shorter.
    let total = data_shards * sb;
    let mut candidate = data.get(..total).unwrap_or(data).to_vec();
    candidate.resize(total, 0);
    // Like `Vec::truncate`, never reach past what the shards hold.
    let payload_len = expected_payload_len.min(total);

    // `payload_len <= candidate.len()`, so this `get` (and the ones on shard
    // ranges below) always succeeds; `?` merely avoids a panicking index.
    if xxh3_oracle(candidate.get(..payload_len)?) {
        candidate.truncate(payload_len);
        return Some(candidate);
    }

    // syndrome = parity ^ shard_0 ^ shard_1 ^ ...
    let mut syndrome = parity.get(..sb)?.to_vec();
    for idx in 0..data_shards {
        xor_into(&mut syndrome, candidate.get(idx * sb..(idx + 1) * sb)?);
    }

    for miss in 0..data_shards {
        let (lo, hi) = (miss * sb, (miss + 1) * sb);
        // Assume shard `miss` is the damaged one and rebuild it.
        xor_into(candidate.get_mut(lo..hi)?, &syndrome);
        if xxh3_oracle(candidate.get(..payload_len)?) {
            candidate.truncate(payload_len);
            return Some(candidate);
        }
        // XOR is its own inverse: restore the shard before the next guess.
        xor_into(candidate.get_mut(lo..hi)?, &syndrome);
    }
    None
}
/// Tries to recover the payload with Reed-Solomon parity by brute-forcing
/// which shards to treat as erased.
///
/// The data shards carved from `data` are followed by the `parity_shards`
/// recovery shards cut from `parity`. Every combination of `parity_shards`
/// indices out of all `data_shards + parity_shards` shards is, in turn,
/// treated as missing and decoded via `try_decode_one`. The first decoded
/// payload accepted by `xxh3_oracle` is returned.
///
/// Returns `Ok(None)` when `parity` is too short or no combination yields an
/// accepted payload.
///
/// # Errors
///
/// Propagates the error returned by `try_decode_one`.
fn rs_recover<F>(
    data: &[u8],
    parity: &[u8],
    expected_payload_len: usize,
    data_shards: usize,
    parity_shards: usize,
    sb: usize,
    xxh3_oracle: &mut F,
) -> crate::Result<Option<Vec<u8>>>
where
    F: FnMut(&[u8]) -> bool,
{
    let total_shards = data_shards + parity_shards;

    let mut shards = carve_data_shards(data, data_shards, sb);
    for parity_idx in 0..parity_shards {
        let lo = parity_idx * sb;
        let Some(chunk) = parity.get(lo..lo + sb) else {
            return Ok(None);
        };
        shards.push(chunk.to_vec());
    }

    // Start with the lexicographically first combination: 0, 1, ..
    let mut erased: Vec<usize> = (0..parity_shards).collect();
    loop {
        let attempt = try_decode_one(
            &shards,
            sb,
            expected_payload_len,
            data_shards,
            parity_shards,
            &erased,
        )?;
        match attempt {
            Some(candidate) if xxh3_oracle(&candidate) => return Ok(Some(candidate)),
            _ => {}
        }
        if !next_combination(&mut erased, total_shards) {
            return Ok(None);
        }
    }
}
/// Advances `combo` to the next `k`-combination of `0..n` in lexicographic
/// order, where `k == combo.len()`.
///
/// `combo` is expected to hold strictly increasing indices. Returns `false`,
/// leaving `combo` untouched, when it is empty or already the last
/// combination.
fn next_combination(combo: &mut [usize], n: usize) -> bool {
    let k = combo.len();
    // Right-most position that has not yet reached its maximum, `n - k + slot`.
    let Some(pivot) = combo
        .iter()
        .enumerate()
        .rposition(|(slot, &value)| value < n - k + slot)
    else {
        return false;
    };
    match combo.get_mut(pivot..) {
        Some([head, tail @ ..]) => {
            // Bump the pivot and restart everything after it right behind it.
            *head += 1;
            let mut previous = *head;
            for slot in tail {
                previous += 1;
                *slot = previous;
            }
            true
        }
        // Not reachable: `pivot` always indexes an existing element.
        _ => false,
    }
}
/// Runs one Reed-Solomon decode, treating the shards listed in `missing` as
/// erased.
///
/// `shards` holds the `data_shards` data shards followed by the
/// `parity_shards` recovery shards; `missing` contains indices into that
/// combined list. Shards not listed in `missing` are fed to the decoder. The
/// payload is then assembled from the data shards, taking restored content
/// for the erased ones, and truncated to `expected_payload_len`.
///
/// Returns `Ok(None)` when decoding fails or the decoder has no restored
/// content for an erased data shard.
///
/// # Errors
///
/// Decoder construction failures and rejected shards are mapped to
/// `crate::Error::Unrecoverable`.
fn try_decode_one(
    shards: &[Vec<u8>],
    sb: usize,
    expected_payload_len: usize,
    data_shards: usize,
    parity_shards: usize,
    missing: &[usize],
) -> crate::Result<Option<Vec<u8>>> {
    let is_erased = |shard_idx: usize| missing.contains(&shard_idx);

    let mut decoder = ReedSolomonDecoder::new(data_shards, parity_shards, sb)
        .map_err(|_| crate::Error::Unrecoverable)?;

    // Surviving data shards.
    for (idx, shard) in shards.iter().enumerate().take(data_shards) {
        if !is_erased(idx) {
            decoder
                .add_original_shard(idx, shard)
                .map_err(|_| crate::Error::Unrecoverable)?;
        }
    }
    // Surviving recovery shards, numbered from zero for the decoder.
    for (recovery_idx, shard) in shards
        .iter()
        .skip(data_shards)
        .take(parity_shards)
        .enumerate()
    {
        if !is_erased(data_shards + recovery_idx) {
            decoder
                .add_recovery_shard(recovery_idx, shard)
                .map_err(|_| crate::Error::Unrecoverable)?;
        }
    }

    let Ok(decoded) = decoder.decode() else {
        return Ok(None);
    };

    let mut payload = Vec::with_capacity(data_shards * sb);
    for (idx, shard) in shards.iter().enumerate().take(data_shards) {
        let bytes = if is_erased(idx) {
            match decoded.restored_original(idx) {
                Some(restored) => restored,
                None => return Ok(None),
            }
        } else {
            shard.as_slice()
        };
        payload.extend_from_slice(bytes);
    }
    payload.truncate(expected_payload_len);
    Ok(Some(payload))
}
#[cfg(test)]
#[expect(clippy::expect_used, clippy::indexing_slicing, reason = "test code")]
mod tests {
    use super::*;
    use test_log::test;

    /// Oracle accepting only candidates whose 128-bit hash equals `expected`.
    fn xxh3_oracle(expected: u128) -> impl FnMut(&[u8]) -> bool {
        move |candidate: &[u8]| crate::hash::hash128(candidate) == expected
    }

    /// `len` bytes counting up from zero and wrapping at 256.
    fn ramp(len: u32) -> Vec<u8> {
        (0..len).map(|i| (i & 0xff) as u8).collect()
    }

    /// XORs `mask` into every byte of `buf[span]`.
    fn corrupt(buf: &mut [u8], span: core::ops::Range<usize>, mask: u8) {
        for byte in &mut buf[span] {
            *byte ^= mask;
        }
    }

    /// Runs `try_recover` on `damaged`, using the hash of `original` as oracle.
    fn recover(
        damaged: &[u8],
        parity: &[u8],
        original: &[u8],
        data_shards: usize,
        parity_shards: usize,
    ) -> Option<Vec<u8>> {
        let expected = crate::hash::hash128(original);
        try_recover(
            damaged,
            parity,
            original.len(),
            data_shards,
            parity_shards,
            xxh3_oracle(expected),
        )
        .expect("try_recover")
    }

    #[test]
    fn shard_bytes_rounds_up_to_even_quarter() {
        let cases = [
            (4usize, 4usize, 2usize),
            (33, 4, 10),
            (4096, 4, 1024),
            (4097, 4, 1026),
            (4096, 10, 410),
            (0, 4, 0),
        ];
        for (payload_len, data_shards, want) in cases {
            assert_eq!(shard_bytes(payload_len, data_shards), want);
        }
    }

    #[test]
    fn rs_4_2_layout_is_byte_identical_to_legacy() {
        for n in [1usize, 4, 33, 4096, 4097] {
            assert_eq!(
                parity_len(n, RS_DATA_SHARDS, RS_PARITY_SHARDS),
                shard_bytes(n, 4) * 2,
            );
        }
    }

    #[test]
    fn rs_encode_decode_roundtrip_no_corruption() {
        let payload = ramp(4096);
        let parity = encode_parity(&payload, 4, 2).expect("encode");
        assert_eq!(parity.len(), parity_len(payload.len(), 4, 2));

        let recovered = recover(&payload, &parity, &payload, 4, 2).expect("recoverable");
        assert_eq!(recovered, payload);
    }

    #[test]
    fn rs_recovers_from_double_data_shard_corruption() {
        let payload = ramp(4096);
        let parity = encode_parity(&payload, 4, 2).expect("encode");
        let sb = shard_bytes(payload.len(), 4);

        let mut corrupted = payload.clone();
        corrupt(&mut corrupted, 0..sb, 0xAA);
        corrupt(&mut corrupted, 2 * sb..3 * sb, 0xBB);

        let recovered = recover(&corrupted, &parity, &payload, 4, 2)
            .expect("double-shard corruption recoverable under RS(4,2)");
        assert_eq!(recovered, payload);
    }

    #[test]
    fn rs_unrecoverable_when_three_shards_corrupt() {
        let payload = ramp(4096);
        let parity = encode_parity(&payload, 4, 2).expect("encode");
        let sb = shard_bytes(payload.len(), 4);

        let mut corrupted = payload.clone();
        corrupt(&mut corrupted, 0..3 * sb, 0xCC);

        let outcome = recover(&corrupted, &parity, &payload, 4, 2);
        assert!(
            outcome.is_none(),
            "three-shard corruption must be unrecoverable"
        );
    }

    #[test]
    fn xor_single_parity_overhead_is_one_over_data_shards() {
        let payload = ramp(4096);
        let parity = encode_parity(&payload, 10, 1).expect("encode");
        assert_eq!(parity.len(), parity_len(payload.len(), 10, 1));
        assert_eq!(parity.len(), shard_bytes(4096, 10));
    }

    #[test]
    fn xor_recovers_single_data_shard_loss() {
        let payload: Vec<u8> = (0..4096_u32)
            .map(|i| (i.wrapping_mul(7) & 0xff) as u8)
            .collect();
        let parity = encode_parity(&payload, 8, 1).expect("encode");
        let sb = shard_bytes(payload.len(), 8);

        let mut corrupted = payload.clone();
        corrupt(&mut corrupted, 3 * sb..4 * sb, 0xFF);

        let recovered = recover(&corrupted, &parity, &payload, 8, 1)
            .expect("single-shard loss recoverable under XOR");
        assert_eq!(recovered, payload);
    }

    #[test]
    fn xor_recovers_when_parity_itself_is_corrupt() {
        let payload = ramp(2048);
        let mut parity = encode_parity(&payload, 8, 1).expect("encode");
        parity[0] ^= 0xFF;

        let recovered = recover(&payload, &parity, &payload, 8, 1)
            .expect("data intact, parity corrupt is recoverable");
        assert_eq!(recovered, payload);
    }

    #[test]
    fn xor_unrecoverable_when_two_data_shards_lost() {
        let payload = ramp(4096);
        let parity = encode_parity(&payload, 8, 1).expect("encode");
        let sb = shard_bytes(payload.len(), 8);

        let mut corrupted = payload.clone();
        corrupt(&mut corrupted, 0..sb, 0xAA);
        corrupt(&mut corrupted, 2 * sb..3 * sb, 0xBB);

        let outcome = recover(&corrupted, &parity, &payload, 8, 1);
        assert!(
            outcome.is_none(),
            "two-shard loss must be unrecoverable under XOR"
        );
    }

    #[test]
    fn rs_8_2_recovers_double_loss_low_overhead() {
        let payload = ramp(8192);
        let parity = encode_parity(&payload, 8, 2).expect("encode");
        assert_eq!(parity.len(), shard_bytes(8192, 8) * 2);
        let sb = shard_bytes(payload.len(), 8);

        let mut corrupted = payload.clone();
        corrupt(&mut corrupted, 5 * sb..6 * sb, 0xAA);
        corrupt(&mut corrupted, 7 * sb..8 * sb, 0xBB);

        let recovered = recover(&corrupted, &parity, &payload, 8, 2)
            .expect("double-shard loss recoverable under RS(8,2)");
        assert_eq!(recovered, payload);
    }

    #[test]
    fn handles_unaligned_block_size() {
        let payload: Vec<u8> = (0..33_u8).collect();
        let parity = encode_parity(&payload, 4, 2).expect("encode");

        let mut corrupted = payload.clone();
        corrupted[0] ^= 0xFF;

        let recovered = recover(&corrupted, &parity, &payload, 4, 2)
            .expect("unaligned single-shard corruption recoverable");
        assert_eq!(recovered, payload);
    }

    #[test]
    fn next_combination_enumerates_all_pairs() {
        let mut combo = vec![0usize, 1];
        let mut seen = vec![combo.clone()];
        while next_combination(&mut combo, 4) {
            seen.push(combo.clone());
        }
        let expected: Vec<Vec<usize>> = vec![
            vec![0, 1],
            vec![0, 2],
            vec![0, 3],
            vec![1, 2],
            vec![1, 3],
            vec![2, 3],
        ];
        assert_eq!(seen, expected);
    }
}