use alloc::vec;
use alloc::vec::Vec;
use reed_solomon_simd::{ReedSolomonDecoder, ReedSolomonEncoder};
/// Data-shard count for Reed–Solomon protection (presumably the value callers pass as
/// `data_shards` to [`encode_parity`] / [`try_recover`] — confirm at call sites).
pub const RS_DATA_SHARDS: usize = 4;
/// Parity-shard count for Reed–Solomon protection (presumably the value callers pass as
/// `parity_shards`; any value >= 2 selects the Reed–Solomon path rather than XOR parity).
pub const RS_PARITY_SHARDS: usize = 2;
/// Size in bytes of each shard when a `payload_len`-byte payload is split into
/// `data_shards` shards.
///
/// The per-shard size is `ceil(payload_len / data_shards)` rounded up to the next even
/// number. Returns `0` when `data_shards` is zero or the payload is empty.
#[must_use]
pub fn shard_bytes(payload_len: usize, data_shards: usize) -> usize {
    match data_shards {
        0 => 0,
        shards => payload_len.div_ceil(shards).next_multiple_of(2),
    }
}
/// Total number of parity bytes for a `payload_len`-byte payload split into `data_shards`
/// data shards and protected by `parity_shards` parity shards.
///
/// Each parity shard is [`shard_bytes`] long, so this is `parity_shards` times that size.
/// It matches the length of the buffer returned by [`encode_parity`].
#[must_use]
pub fn parity_len(payload_len: usize, data_shards: usize, parity_shards: usize) -> usize {
    let per_shard = shard_bytes(payload_len, data_shards);
    parity_shards * per_shard
}
/// Computes the parity bytes protecting `payload`.
///
/// The payload is split into `data_shards` zero-padded shards of
/// `shard_bytes(payload.len(), data_shards)` bytes each. A single parity shard is the
/// plain XOR of the data shards. Two or more parity shards are produced with
/// Reed–Solomon. The result is the parity shards concatenated ([`parity_len`] bytes). It
/// is empty when there is nothing to protect: an empty payload, zero data shards, or zero
/// parity shards.
///
/// # Errors
///
/// Returns [`crate::Error::Unrecoverable`] if the Reed–Solomon encoder rejects the shard
/// geometry or one of the shards.
pub fn encode_parity(
    payload: &[u8],
    data_shards: usize,
    parity_shards: usize,
) -> crate::Result<Vec<u8>> {
    let sb = shard_bytes(payload.len(), data_shards);
    match (sb, parity_shards) {
        (0, _) | (_, 0) => Ok(Vec::new()),
        (_, 1) => Ok(encode_xor(payload, data_shards, sb)),
        (_, _) => encode_rs(payload, data_shards, parity_shards, sb),
    }
}
/// Overwrites `buf` with data shard `i` of `payload`, where every shard spans `sb` bytes.
///
/// Bytes of the shard that fall past the end of `payload` stay zero. As a result, the last
/// shard is zero-padded and a shard lying wholly beyond the payload comes out all zeros.
/// `buf` is expected to be `sb` bytes long (both callers allocate it that way).
fn fill_data_shard(buf: &mut [u8], payload: &[u8], i: usize, sb: usize) {
    buf.fill(0);
    let from = i * sb;
    let to = payload.len().min(from + sb);
    // `get` yields `None` when the shard starts past the payload end (from > to).
    if let Some(chunk) = payload.get(from..to) {
        #[expect(
            clippy::indexing_slicing,
            reason = "chunk.len() <= sb == buf.len() for every caller"
        )]
        buf[..chunk.len()].copy_from_slice(chunk);
    }
}
/// Single-parity encoding: returns the byte-wise XOR of all `data_shards` zero-padded data
/// shards of `payload`. Each shard is `sb` bytes, so the result is one `sb`-byte shard.
fn encode_xor(payload: &[u8], data_shards: usize, sb: usize) -> Vec<u8> {
    let mut scratch = vec![0u8; sb];
    let mut acc = vec![0u8; sb];
    for idx in 0..data_shards {
        fill_data_shard(&mut scratch, payload, idx, sb);
        acc.iter_mut()
            .zip(&scratch)
            .for_each(|(out, byte)| *out ^= byte);
    }
    acc
}
fn encode_rs(
payload: &[u8],
data_shards: usize,
parity_shards: usize,
sb: usize,
) -> crate::Result<Vec<u8>> {
let mut encoder = ReedSolomonEncoder::new(data_shards, parity_shards, sb)
.map_err(|_| crate::Error::Unrecoverable)?;
let mut shard_buf = vec![0u8; sb];
for i in 0..data_shards {
fill_data_shard(&mut shard_buf, payload, i, sb);
encoder
.add_original_shard(&shard_buf)
.map_err(|_| crate::Error::Unrecoverable)?;
}
let result = encoder.encode().map_err(|_| crate::Error::Unrecoverable)?;
let mut out = Vec::with_capacity(sb * parity_shards);
for shard in result.recovery_iter() {
out.extend_from_slice(shard);
}
Ok(out)
}
/// Tries to reconstruct the original payload from `data` and its `parity` bytes.
///
/// Inputs:
/// - `data` holds the (possibly corrupted) payload. Only its first `expected_payload_len`
///   bytes are used; any trailing bytes are ignored.
/// - `parity` must hold at least `parity_len(expected_payload_len, data_shards,
///   parity_shards)` bytes, laid out as produced by [`encode_parity`].
///
/// Every candidate reconstruction is passed to `xxh3_oracle`, and the first one it
/// accepts is returned. Presumably the oracle compares an xxh3 digest of the candidate;
/// that is up to the caller.
///
/// Recovery strategy:
/// - With `parity_shards == 1`, a single XOR parity shard is used, so at most one corrupted
///   data shard can be repaired.
/// - Otherwise, Reed–Solomon decoding is attempted for every choice of `parity_shards`
///   shards treated as erased.
///
/// Returns `Ok(None)` in either of these cases:
/// - The inputs are unusable: zero-sized shards, no parity, `data` shorter than
///   `expected_payload_len`, or `parity` too short.
/// - No candidate satisfies the oracle.
///
/// # Errors
///
/// Returns [`crate::Error::Unrecoverable`] if the Reed–Solomon decoder rejects the shard
/// geometry or one of the shards.
pub fn try_recover<F>(
    data: &[u8],
    parity: &[u8],
    expected_payload_len: usize,
    data_shards: usize,
    parity_shards: usize,
    mut xxh3_oracle: F,
) -> crate::Result<Option<Vec<u8>>>
where
    F: FnMut(&[u8]) -> bool,
{
    let sb = shard_bytes(expected_payload_len, data_shards);
    if sb == 0 || parity_shards == 0 {
        return Ok(None);
    }
    // Only the first `expected_payload_len` bytes are payload. The encoder zero-pads the
    // last shard, so trailing bytes must be dropped here. Otherwise they would be carved
    // into that padding and corrupt every parity-based reconstruction.
    let Some(data) = data.get(..expected_payload_len) else {
        return Ok(None);
    };
    // `checked_mul` keeps an absurd `parity_shards` from wrapping this check in release builds.
    let parity_long_enough = sb
        .checked_mul(parity_shards)
        .is_some_and(|needed| parity.len() >= needed);
    if !parity_long_enough {
        return Ok(None);
    }
    if parity_shards == 1 {
        return Ok(xor_recover(
            data,
            parity,
            expected_payload_len,
            data_shards,
            sb,
            &mut xxh3_oracle,
        ));
    }
    rs_recover(
        data,
        parity,
        expected_payload_len,
        data_shards,
        parity_shards,
        sb,
        &mut xxh3_oracle,
    )
}
/// Splits `data` into `data_shards` owned shards of `sb` bytes each.
///
/// Shard `i` covers `data[i * sb..(i + 1) * sb]`. Any part of it lying past the end of
/// `data` is zero-filled.
fn carve_data_shards(data: &[u8], data_shards: usize, sb: usize) -> Vec<Vec<u8>> {
    (0..data_shards)
        .map(|idx| {
            let mut shard = vec![0u8; sb];
            let from = idx * sb;
            let to = data.len().min(from + sb);
            // `get` yields `None` when this shard starts past the end of `data`.
            if let Some(chunk) = data.get(from..to) {
                #[expect(clippy::indexing_slicing, reason = "chunk.len() <= sb == shard.len()")]
                shard[..chunk.len()].copy_from_slice(chunk);
            }
            shard
        })
        .collect()
}
/// XOR-parity recovery using a single parity shard.
///
/// Procedure:
/// 1. Offer the payload exactly as carved from `data` to `xxh3_oracle`.
/// 2. If that is rejected, assume each data shard in turn (in index order) is the corrupted
///    one. Rebuild it as `parity[..sb]` XOR every other data shard, and offer the resulting
///    payload.
///
/// Every candidate is truncated to `expected_payload_len` bytes. Returns the first
/// accepted candidate, or `None`.
///
/// May panic if `parity` is shorter than `sb`; `try_recover` rules that out before calling.
fn xor_recover<F>(
    data: &[u8],
    parity: &[u8],
    expected_payload_len: usize,
    data_shards: usize,
    sb: usize,
    xxh3_oracle: &mut F,
) -> Option<Vec<u8>>
where
    F: FnMut(&[u8]) -> bool,
{
    let shards = carve_data_shards(data, data_shards, sb);

    let mut as_is = shards.concat();
    as_is.truncate(expected_payload_len);
    if xxh3_oracle(&as_is) {
        return Some(as_is);
    }

    #[expect(
        clippy::indexing_slicing,
        reason = "parity.len() >= sb guarded by caller"
    )]
    let parity_shard = &parity[..sb];
    // syndrome = parity ^ (every data shard). XORing one shard back into it cancels that
    // shard out, leaving parity ^ (all other shards), which is its reconstruction.
    let mut syndrome = parity_shard.to_vec();
    for shard in &shards {
        for (acc, byte) in syndrome.iter_mut().zip(shard) {
            *acc ^= byte;
        }
    }

    for miss in 0..data_shards {
        let mut candidate = Vec::with_capacity(data_shards * sb);
        for (idx, shard) in shards.iter().enumerate() {
            if idx == miss {
                candidate.extend(syndrome.iter().zip(shard).map(|(s, b)| s ^ b));
            } else {
                candidate.extend_from_slice(shard);
            }
        }
        candidate.truncate(expected_payload_len);
        if xxh3_oracle(&candidate) {
            return Some(candidate);
        }
    }
    None
}
fn rs_recover<F>(
data: &[u8],
parity: &[u8],
expected_payload_len: usize,
data_shards: usize,
parity_shards: usize,
sb: usize,
xxh3_oracle: &mut F,
) -> crate::Result<Option<Vec<u8>>>
where
F: FnMut(&[u8]) -> bool,
{
let n = data_shards + parity_shards;
let mut shards = carve_data_shards(data, data_shards, sb);
for i in 0..parity_shards {
let start = i * sb;
let end = start + sb;
if end > parity.len() {
return Ok(None);
}
#[expect(clippy::indexing_slicing, reason = "end <= parity.len() guarded above")]
shards.push(parity[start..end].to_vec());
}
let mut missing = (0..parity_shards).collect::<Vec<usize>>();
loop {
if let Some(payload) = try_decode_one(
&shards,
sb,
expected_payload_len,
data_shards,
parity_shards,
&missing,
)? && xxh3_oracle(&payload)
{
return Ok(Some(payload));
}
if !next_combination(&mut missing, n) {
break;
}
}
Ok(None)
}
/// Advances `combo` to the next `k`-combination of `0..n` in lexicographic order, where
/// `k = combo.len()` and `combo` is strictly increasing.
///
/// Returns `false`, leaving `combo` untouched, when `combo` is empty or already holds the
/// last combination (`n - k, ..., n - 1`).
fn next_combination(combo: &mut [usize], n: usize) -> bool {
    let k = combo.len();
    // Right-most slot that can still grow; slot `pos` tops out at `n - k + pos`.
    let Some(pivot) = combo
        .iter()
        .enumerate()
        .rposition(|(pos, &value)| value < n - k + pos)
    else {
        return false;
    };
    // Bump the pivot, then refill everything after it with consecutive values.
    if let Some((bumped, tail)) = combo.get_mut(pivot..).and_then(|rest| rest.split_first_mut()) {
        *bumped += 1;
        let mut next = *bumped;
        for slot in tail {
            next += 1;
            *slot = next;
        }
    }
    true
}
/// Runs one Reed–Solomon decode, treating the shard indices in `missing` as erased.
///
/// Shard layout in `shards`:
/// - Indices `0..data_shards` are data (original) shards.
/// - The next `parity_shards` entries are recovery shards.
///
/// All non-erased shards are handed to a fresh `ReedSolomonDecoder`. The payload is then
/// rebuilt from the surviving data shards plus any restored ones and truncated to
/// `expected_payload_len`.
///
/// Returns `Ok(None)` if decoding fails or an erased data shard was not restored.
///
/// # Errors
///
/// Returns `crate::Error::Unrecoverable` if the decoder rejects the geometry or a shard.
fn try_decode_one(
    shards: &[Vec<u8>],
    sb: usize,
    expected_payload_len: usize,
    data_shards: usize,
    parity_shards: usize,
    missing: &[usize],
) -> crate::Result<Option<Vec<u8>>> {
    let mut decoder = ReedSolomonDecoder::new(data_shards, parity_shards, sb)
        .map_err(|_| crate::Error::Unrecoverable)?;
    let is_missing = |idx: usize| missing.contains(&idx);
    // Global shard indices: data shards are 0..data_shards, recovery shards follow them.
    let (originals, recoveries) = shards.split_at(data_shards.min(shards.len()));

    for (idx, shard) in originals.iter().enumerate() {
        if !is_missing(idx) {
            decoder
                .add_original_shard(idx, shard)
                .map_err(|_| crate::Error::Unrecoverable)?;
        }
    }
    for (offset, shard) in recoveries.iter().take(parity_shards).enumerate() {
        if !is_missing(data_shards + offset) {
            decoder
                .add_recovery_shard(offset, shard)
                .map_err(|_| crate::Error::Unrecoverable)?;
        }
    }
    // A decode error means "this erasure set yields no candidate", not a hard failure.
    let Ok(decoded) = decoder.decode() else {
        return Ok(None);
    };

    let mut payload = Vec::with_capacity(data_shards * sb);
    for (idx, shard) in originals.iter().enumerate() {
        let piece: &[u8] = if is_missing(idx) {
            let Some(restored) = decoded.restored_original(idx) else {
                return Ok(None);
            };
            restored
        } else {
            shard
        };
        payload.extend_from_slice(piece);
    }
    payload.truncate(expected_payload_len);
    Ok(Some(payload))
}
#[cfg(test)]
#[expect(clippy::expect_used, clippy::indexing_slicing, reason = "test code")]
mod tests;