use ps_buffer::Buffer;
use crate::{codeword::Codeword, long, DecodeError, EncodeError, ReedSolomon};
/// Encodes `message`, appending `2 * parity` parity bytes per Reed-Solomon codeword.
///
/// When `message.len() + 2 * parity` fits in one 255-byte codeword, a single
/// [`ReedSolomon`] codeword is produced. Otherwise the message is handed to
/// [`long::encode`] with the largest segment length that still leaves room for
/// the `2 * parity` parity bytes (used as both segment length and distance).
///
/// # Errors
///
/// Returns an [`EncodeError`] if the long encoder fails, if the
/// [`ReedSolomon`] codec cannot be constructed for `parity`, or if encoding
/// with it fails.
pub fn encode(message: &[u8], parity: u8) -> Result<Buffer, EncodeError> {
    if message.len() + (usize::from(parity) << 1) > 0xFF {
        // `parity << 1` on a `u8` silently wraps for parity >= 128, which used to
        // produce a meaningless segment length. Only segment the message when the
        // parity bytes actually fit in a 255-byte segment; otherwise fall through
        // so the `ReedSolomon` constructor / encoder reports the unusable parity.
        if let Some(parity_bytes) = parity.checked_mul(2) {
            // parity_bytes <= 254 here, so this subtraction cannot underflow.
            let segment_length = 0xFF - parity_bytes;
            return Ok(long::encode(
                message,
                parity,
                segment_length,
                segment_length,
            )?);
        }
    }

    let rs = ReedSolomon::new(parity)?;
    Ok(rs.encode(message)?)
}
/// Decodes `received`, correcting errors where the code allows it.
///
/// Inputs of at most 255 bytes are decoded as a single [`ReedSolomon`]
/// codeword built with `parity` (the same value given to [`encode`]). Longer
/// inputs are delegated to [`long::decode`], which is called without `parity`.
///
/// # Errors
///
/// - [`DecodeError::InsufficientParityBytes`] when `parity` is greater than half
///   the length of a short input.
/// - Any error raised while constructing the codec or while decoding, converted
///   into [`DecodeError`].
pub fn decode(received: &[u8], parity: u8) -> Result<Codeword<'_>, DecodeError> {
    let Ok(length) = u8::try_from(received.len()) else {
        // Too long to be one 255-byte codeword: use the segmented format.
        return Ok(long::decode(received)?);
    };

    // Reject a parity that exceeds half of the received bytes.
    if parity > length / 2 {
        return Err(DecodeError::InsufficientParityBytes(parity, length));
    }

    let codec = ReedSolomon::new(parity)?;
    Ok(codec.decode(received)?)
}
/// Checks whether `received` is an error-free codeword, without correcting it.
///
/// For inputs of at most 255 bytes this returns `false` when `parity` exceeds
/// half of the input length, when the [`ReedSolomon`] codec cannot be built,
/// when validation itself errors, or when any error is detected (even a
/// correctable one). Longer inputs are checked with [`long::fast_validate`],
/// and a failure from it is reported as `false`.
#[must_use]
pub fn validate(received: &[u8], parity: u8) -> bool {
    let Ok(length) = u8::try_from(received.len()) else {
        return long::fast_validate(received).unwrap_or_default();
    };

    if parity > length / 2 {
        return false;
    }

    match ReedSolomon::new(parity) {
        // Only `Ok(None)` is treated as a clean codeword.
        Ok(codec) => matches!(codec.validate(received), Ok(None)),
        Err(_) => false,
    }
}
#[cfg(test)]
mod tests {
    use super::{decode, encode};
    use crate::EccError;

    /// Round-trips a UTF-8 message through `encode`/`decode`, corrupting one more
    /// byte of the same codeword on every pass and checking that the original
    /// message is still recovered each time.
    #[test]
    fn ecc_works() -> Result<(), EccError> {
        const PARITY: u8 = 13;

        let original = "Strč prst skrz krk! ¯\\_(ツ)_/¯".as_bytes();
        let mut codeword = encode(original, PARITY)?;

        for step in 0..usize::from(PARITY) {
            let position = (step * 37) % codeword.len();
            let mask = (step * position + 13).to_le_bytes()[0];
            codeword[position] ^= mask;

            let recovered = decode(&codeword, PARITY)?;
            assert_eq!(original, &recovered[..]);
        }

        Ok(())
    }
}
#[cfg(test)]
mod validate_tests {
    use crate::{
        long, validate, LongEccConstructorError, LongEccDecodeError, LongEccEncodeError,
        LongEccToBytesError, RSEncodeError, ReedSolomon,
    };
    use ps_buffer::ToBuffer;

    /// Every error the helpers below may return, so each test can use `?`.
    #[derive(thiserror::Error, Debug)]
    enum TestError {
        #[error(transparent)]
        LongEccConstructor(#[from] LongEccConstructorError),
        #[error(transparent)]
        LongEccEncode(#[from] LongEccEncodeError),
        #[error(transparent)]
        LongEccDecode(#[from] LongEccDecodeError),
        #[error(transparent)]
        LongEccToBytes(#[from] LongEccToBytesError),
        #[error(transparent)]
        Buffer(#[from] ps_buffer::BufferError),
        #[error(transparent)]
        RSConstructorError(#[from] crate::RSConstructorError),
        #[error(transparent)]
        RSEncodeError(#[from] RSEncodeError),
    }

    /// `validate` only routes inputs longer than this to `long::fast_validate`;
    /// anything at or below it is checked as a single Reed-Solomon codeword.
    const MAX_SHORT_LEN: usize = 0xFF;

    #[test]
    fn test_validate_short_data_valid_no_errors() -> Result<(), TestError> {
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;
        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_short_data_invalid_with_errors() -> Result<(), TestError> {
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let mut codeword = rs.encode(data)?;
        codeword[0] ^= 1;
        codeword[1] ^= 1;
        codeword[2] ^= 1;
        assert!(!validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_short_data_parity_too_large() {
        let data = b"test";
        let parity = 3;
        assert!(!validate(data, parity));
    }

    #[test]
    fn test_validate_short_data_rs_constructor_error() {
        // NOTE: for any input of at most 255 bytes, a parity >= 128 always exceeds
        // `length / 2`, so this is rejected by the length guard before
        // `ReedSolomon::new` is reached; the constructor-error branch is not hit.
        let data = b"test";
        let parity = 255;
        assert!(!validate(data, parity));
    }

    #[test]
    fn test_validate_short_data_correctable_errors() -> Result<(), TestError> {
        // `validate` reports any detected error, even one `decode` could fix.
        let data = b"test";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let mut codeword = rs.encode(data)?;
        codeword[0] ^= 1;
        assert!(!validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_valid_no_errors() -> Result<(), TestError> {
        let message = b"This is a longer message that will use long ECC".repeat(7);
        let parity = 2;
        let segment_length = 20;
        let segment_distance = 16;
        let encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        assert!(encoded.len() > MAX_SHORT_LEN, "must exercise the long path");
        assert!(validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_invalid_with_errors() -> Result<(), TestError> {
        // NOTE(review): this message is short, so the encoded output presumably
        // fits in 255 bytes and `validate` takes the single-codeword path rather
        // than `long::fast_validate` — confirm and lengthen the message if so.
        let message = b"This is a longer message that will use long ECC".to_buffer()?;
        let parity = 2;
        let segment_length = 20;
        let segment_distance = 16;
        let mut encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        encoded[32] ^= 1;
        encoded[37] ^= 1;
        assert!(!validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_fast_path_valid() -> Result<(), TestError> {
        let message = b"Fast path validation test".repeat(12);
        let parity = 1;
        let segment_length = 15;
        let segment_distance = 12;
        let encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        assert!(encoded.len() > MAX_SHORT_LEN, "must exercise the long path");
        assert!(validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_empty_data() {
        // Zero parity passes the `parity > length / 2` guard for an empty input.
        let data = b"";
        let parity = 0;
        assert!(validate(data, parity));
    }

    #[test]
    fn test_validate_single_byte() -> Result<(), TestError> {
        let data = b"A";
        let parity = 1;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;
        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_large_short_data() -> Result<(), TestError> {
        let data = b"This is exactly 32 bytes of test";
        assert_eq!(data.len(), 32);
        let parity = 4;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;
        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_edge_case_parity_equals_length_div_2() -> Result<(), TestError> {
        // A 1-byte message with parity 2 yields a 5-byte codeword, and 5 / 2 == 2,
        // so this sits exactly on the `parity > length / 2` boundary.
        let data = b"A";
        let parity = 2;
        let rs = ReedSolomon::new(parity)?;
        let codeword = rs.encode(data)?;
        assert_eq!(usize::from(parity), codeword.len() / 2);
        assert!(validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_edge_case_parity_just_over_length_div_2() {
        let data = b"test";
        let parity = 3;
        assert!(!validate(data, parity));
    }

    #[test]
    fn test_validate_long_data_with_zero_parity() -> Result<(), TestError> {
        let message = b"Zero parity test".repeat(20);
        let parity = 0;
        let segment_length = 15;
        let segment_distance = 12;
        let encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        assert!(encoded.len() > MAX_SHORT_LEN, "must exercise the long path");
        assert!(validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_header_corrupted() -> Result<(), TestError> {
        // NOTE(review): the encoded output of this short message presumably fits
        // in 255 bytes, so `validate` checks it as a single RS codeword instead of
        // exercising header validation in `long::fast_validate` — confirm.
        let message = b"Header corruption test".to_buffer()?;
        let parity = 2;
        let segment_length = 15;
        let segment_distance = 12;
        let mut encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        encoded[0] ^= 1;
        encoded[5] ^= 1;
        assert!(!validate(&encoded, parity));
        Ok(())
    }

    #[test]
    fn test_validate_short_data_length_conversion_error() {
        // 300 bytes do not fit in `u8`, so this goes to `long::fast_validate`,
        // which is expected to reject arbitrary bytes.
        let data: Vec<u8> = vec![0x42; 300];
        let parity = 2;
        assert!(!validate(&data, parity));
    }

    #[test]
    fn test_validate_short_data_unrecoverable_errors() -> Result<(), TestError> {
        let data = b"test data";
        let parity = 1;
        let rs = ReedSolomon::new(parity)?;
        let mut codeword = rs.encode(data)?;
        codeword[0] ^= 1;
        codeword[1] ^= 1;
        assert!(!validate(&codeword, parity));
        Ok(())
    }

    #[test]
    fn test_validate_long_data_corrupted_parity() -> Result<(), TestError> {
        // NOTE(review): as above, this short message presumably stays on the
        // single-codeword path; `32 + message.len()` also assumes a 32-byte
        // header layout that is not visible here — confirm.
        let message = b"Corrupted parity test".to_buffer()?;
        let parity = 2;
        let segment_length = 15;
        let segment_distance = 12;
        let mut encoded = long::encode(&message, parity, segment_length, segment_distance)?;
        let parity_start = 32 + message.len();
        if parity_start < encoded.len() {
            encoded[parity_start] ^= 1;
        }
        assert!(!validate(&encoded, parity));
        Ok(())
    }
}