use bytes::Bytes;
use crate::domain::correction::{decode_shards, encode_shards};
use crate::domain::errors::CorrectionError;
use crate::domain::ports::ErrorCorrector;
use crate::domain::types::Shard;
/// [`ErrorCorrector`] adapter that forwards a secret HMAC key to the crate's
/// `encode_shards` / `decode_shards` helpers.
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// corrector (logs, `dbg!`, panic messages) never reveals the key bytes.
pub struct RsErrorCorrector {
    /// Secret key handed to every encode/decode call; `new` rejects an empty key.
    hmac_key: Vec<u8>,
}

impl std::fmt::Debug for RsErrorCorrector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit the key material: a derived impl would print every byte.
        f.debug_struct("RsErrorCorrector")
            .field("hmac_key", &format_args!("<redacted>"))
            .finish()
    }
}
impl RsErrorCorrector {
    /// Builds a corrector that passes `hmac_key` to every encode and decode call.
    ///
    /// # Panics
    ///
    /// Panics with `"HMAC key must not be empty"` when `hmac_key` has no bytes.
    #[must_use]
    pub fn new(hmac_key: Vec<u8>) -> Self {
        let key_has_bytes = !hmac_key.is_empty();
        assert!(key_has_bytes, "HMAC key must not be empty");
        Self { hmac_key }
    }
}
impl ErrorCorrector for RsErrorCorrector {
    /// Splits `data` into shards by delegating to `encode_shards` with the
    /// configured HMAC key.
    ///
    /// # Errors
    ///
    /// Propagates whatever `CorrectionError` `encode_shards` reports.
    fn encode(
        &self,
        data: &[u8],
        data_shards: u8,
        parity_shards: u8,
    ) -> Result<Vec<Shard>, CorrectionError> {
        let key = &self.hmac_key;
        encode_shards(data, data_shards, parity_shards, key)
    }

    /// Rebuilds the payload from the surviving `shards`; a `None` entry marks a
    /// lost shard.
    ///
    /// The length passed to `decode_shards` is the first surviving shard's
    /// length multiplied by `data_shards`. The returned bytes therefore
    /// presumably include any padding added during encoding. The tests only
    /// check `starts_with`, not equality.
    ///
    /// # Errors
    ///
    /// Returns `CorrectionError::InsufficientShards { needed: data_shards, available: 0 }`
    /// when no entry is `Some` (including an empty slice). Any other error comes
    /// from `decode_shards`.
    ///
    /// # Panics
    ///
    /// Panics if `shard length * data_shards` overflows `usize` (`strict_mul`).
    fn decode(
        &self,
        shards: &[Option<Shard>],
        data_shards: u8,
        parity_shards: u8,
    ) -> Result<Bytes, CorrectionError> {
        let data_shard_count = usize::from(data_shards);

        // Any present shard serves as the size reference. This assumes all
        // shards share one length; presumably `decode_shards` validates that.
        let Some(reference) = shards.iter().flatten().next() else {
            return Err(CorrectionError::InsufficientShards {
                needed: data_shard_count,
                available: 0,
            });
        };

        let padded_len = reference.data.len().strict_mul(data_shard_count);
        decode_shards(
            shards,
            data_shards,
            parity_shards,
            &self.hmac_key,
            padded_len,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    const HMAC_KEY: &[u8] = b"test_key_32_bytes_long_padding!!";
    const DATA_SHARDS: u8 = 10;
    const PARITY_SHARDS: u8 = 5;
    /// Total shard count for the configuration above (`DATA_SHARDS + PARITY_SHARDS`).
    const TOTAL_SHARDS: usize = 15;

    /// Builds a corrector with the fixed test key.
    fn test_corrector() -> RsErrorCorrector {
        RsErrorCorrector::new(HMAC_KEY.to_vec())
    }

    /// Encodes `data` and wraps every shard in `Some`, ready to be selectively erased.
    fn encode_all(
        corrector: &RsErrorCorrector,
        data: &[u8],
    ) -> Result<Vec<Option<Shard>>, Box<dyn std::error::Error>> {
        let shards = corrector.encode(data, DATA_SHARDS, PARITY_SHARDS)?;
        Ok(shards.into_iter().map(Some).collect())
    }

    /// Replaces the shard at each index with `None`, failing instead of panicking
    /// if an index is out of range.
    fn erase(shards: &mut [Option<Shard>], indices: impl IntoIterator<Item = usize>) -> TestResult {
        for index in indices {
            *shards.get_mut(index).ok_or("out of bounds")? = None;
        }
        Ok(())
    }

    #[test]
    fn test_rs_error_corrector_roundtrip() -> TestResult {
        let corrector = test_corrector();
        let data = b"The quick brown fox jumps over the lazy dog";
        let shards = encode_all(&corrector, data)?;
        let recovered = corrector.decode(&shards, DATA_SHARDS, PARITY_SHARDS)?;
        assert!(recovered.starts_with(data));
        Ok(())
    }

    #[test]
    fn test_rs_error_corrector_with_missing_shards() -> TestResult {
        let corrector = test_corrector();
        let data = b"The quick brown fox jumps over the lazy dog";
        let mut shards = encode_all(&corrector, data)?;
        // Lose exactly as many shards as there are parity shards.
        erase(&mut shards, [0, 3, 7, 10, 13])?;
        let recovered = corrector.decode(&shards, DATA_SHARDS, PARITY_SHARDS)?;
        assert!(recovered.starts_with(data));
        Ok(())
    }

    #[test]
    fn test_rs_error_corrector_insufficient_shards() -> TestResult {
        let corrector = test_corrector();
        let mut shards = encode_all(&corrector, b"test data")?;
        // One more loss than the parity budget allows.
        erase(&mut shards, 0..6)?;
        let result = corrector.decode(&shards, DATA_SHARDS, PARITY_SHARDS);
        assert!(matches!(
            result,
            Err(CorrectionError::InsufficientShards { .. })
        ));
        Ok(())
    }

    #[test]
    fn test_rs_error_corrector_all_shards_missing() {
        // Exercises the adapter's own early return: no shard to size the output from.
        let shards: Vec<Option<Shard>> = std::iter::repeat_with(|| None).take(TOTAL_SHARDS).collect();
        let result = test_corrector().decode(&shards, DATA_SHARDS, PARITY_SHARDS);
        assert!(matches!(
            result,
            Err(CorrectionError::InsufficientShards {
                needed: 10,
                available: 0
            })
        ));
    }

    #[test]
    fn test_rs_error_corrector_empty_shard_slice() {
        let result = test_corrector().decode(&[], DATA_SHARDS, PARITY_SHARDS);
        assert!(matches!(
            result,
            Err(CorrectionError::InsufficientShards {
                needed: 10,
                available: 0
            })
        ));
    }

    #[test]
    #[should_panic(expected = "HMAC key must not be empty")]
    fn test_rs_error_corrector_rejects_empty_key() {
        let _corrector = RsErrorCorrector::new(Vec::new());
    }
}