use reed_solomon::{Decoder, Encoder};
/// Size in bytes of the stream header written by `protect` and parsed by `recover`:
/// one parity-count byte followed by the payload length as a 4-byte little-endian `u32`.
const HEADER: usize = 5;
pub fn protect(data: &[u8], parity: u8) -> Vec<u8> {
let parity = parity.max(2); let chunk = 255 - parity as usize;
let encoder = Encoder::new(parity as usize);
let mut out = Vec::with_capacity(HEADER + data.len() + parity as usize);
out.push(parity);
out.extend_from_slice(&(data.len() as u32).to_le_bytes());
for block in data.chunks(chunk) {
let encoded = encoder.encode(block);
out.extend_from_slice(&encoded); }
out
}
/// Decodes a stream produced by `protect`, correcting errors where the code allows it.
///
/// Returns `None` when:
/// - the input is shorter than the `HEADER`-byte header,
/// - the parity byte is below 2 or equal to 255,
/// - the declared payload length is larger than the whole input,
/// - the body ends before all declared payload bytes have been decoded, or
/// - the decoder reports that a block cannot be corrected.
///
/// Bytes left over after the last required block are ignored.
pub fn recover(protected: &[u8]) -> Option<Vec<u8>> {
    if protected.len() < HEADER {
        return None;
    }
    let (header, mut rest) = protected.split_at(HEADER);

    // Byte 0: number of parity bytes per block. 255 would leave no room for data.
    let ecc_len = usize::from(header[0]);
    if !(2..255).contains(&ecc_len) {
        return None;
    }

    // Bytes 1..HEADER: payload length, little-endian.
    let len_bytes: [u8; 4] = header[1..].try_into().ok()?;
    let total = u32::from_le_bytes(len_bytes) as usize;
    // The payload can never be longer than the stream carrying it; checking this
    // before allocating keeps a forged length from triggering a huge reservation.
    if total > protected.len() {
        return None;
    }

    let max_payload = 255 - ecc_len;
    let decoder = Decoder::new(ecc_len);
    let mut recovered = Vec::with_capacity(total);

    while recovered.len() < total {
        // The final block may hold fewer data bytes than a full one.
        let payload = (total - recovered.len()).min(max_payload);
        let span = payload + ecc_len;
        if rest.len() < span {
            return None;
        }
        let (block, tail) = rest.split_at(span);
        let fixed = decoder.correct(block, None).ok()?;
        recovered.extend_from_slice(&fixed.data()[..payload]);
        rest = tail;
    }

    Some(recovered)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An untouched stream decodes back to the original payload.
    #[test]
    fn round_trips_without_errors() {
        let data = b"datos a proteger con correccion de errores";
        let prot = protect(data, 8);
        assert_eq!(recover(&prot).unwrap(), data);
    }

    /// Payloads larger than one block are split and reassembled correctly.
    #[test]
    fn round_trips_large_data_across_blocks() {
        let data: Vec<u8> = (0..1000u32).map(|i| (i % 251) as u8).collect();
        let prot = protect(&data, 8);
        assert_eq!(recover(&prot).unwrap(), data);
    }

    /// Four corrupted bytes are recoverable with eight parity bytes.
    #[test]
    fn corrects_errors_within_capacity() {
        let data = b"mensaje que sufrira corrupcion en el canal";
        let mut prot = protect(data, 8);
        for k in 0..4 {
            prot[HEADER + k] ^= 0xFF;
        }
        assert_eq!(recover(&prot).unwrap(), data);
    }

    /// Five corrupted bytes exceed what four parity bytes can repair.
    #[test]
    fn fails_when_too_many_errors() {
        let data = b"corto";
        let mut prot = protect(data, 4);
        for k in 0..5 {
            prot[HEADER + k] ^= 0xFF;
        }
        assert!(recover(&prot).is_none());
    }

    /// An empty payload produces a header-only stream that decodes to nothing.
    #[test]
    fn round_trips_empty() {
        let prot = protect(b"", 8);
        assert_eq!(recover(&prot).unwrap(), b"");
    }

    /// A forged length field must be rejected before any large allocation happens.
    #[test]
    fn rejects_malicious_data_len_without_oom() {
        let mut prot = protect(b"hola", 8);
        prot[1..5].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(recover(&prot).is_none());
    }

    /// Forged parity bytes at the edge of the valid range must not panic.
    #[test]
    fn rejects_degenerate_parity_byte() {
        let mut prot = protect(b"hola", 8);
        prot[0] = 255;
        assert!(recover(&prot).is_none());

        // Parity 254 is structurally valid, but the first block would then need
        // 1 + 254 bytes and the body is far shorter, so decoding must fail cleanly.
        let mut prot2 = protect(b"hola", 8);
        prot2[0] = 254;
        assert!(recover(&prot2).is_none());
    }
}