use bytes::Bytes;
use hmac::{Hmac, Mac};
use reed_solomon_erasure::galois_8::ReedSolomon;
use sha2::Sha256;
use crate::domain::errors::CorrectionError;
use crate::domain::types::Shard;
/// HMAC-SHA256, the MAC used by `compute_hmac_tag` to authenticate shards.
type HmacSha256 = Hmac<Sha256>;
/// Computes the HMAC-SHA256 tag that authenticates a single shard.
///
/// The MAC input is the shard `index` byte, then the `total` shard-count byte,
/// then the shard payload. Both header fields are fixed-width single bytes, so
/// the encoding is unambiguous.
///
/// Note that the tag covers only `(index, total, data)`. It does not bind the
/// shard to a particular encode call, so under the same key the tag cannot
/// distinguish two encodings with equal shard counts.
///
/// # Errors
///
/// Returns [`CorrectionError::InvalidParameters`] if `HmacSha256::new_from_slice`
/// rejects `hmac_key`.
fn compute_hmac_tag(
    hmac_key: &[u8],
    index: u8,
    total: u8,
    data: &[u8],
) -> Result<[u8; 32], CorrectionError> {
    let Ok(mut hasher) = HmacSha256::new_from_slice(hmac_key) else {
        return Err(CorrectionError::InvalidParameters {
            reason: "invalid HMAC key length".into(),
        });
    };
    // HMAC input is streamed, so one two-byte header update produces the same
    // tag as two separate one-byte updates.
    hasher.update(&[index, total]);
    hasher.update(data);
    let digest = hasher.finalize().into_bytes();
    Ok(digest.into())
}
/// Checks a shard's stored `hmac_tag` against the tag recomputed from its
/// `index`, `total` and `data` under `hmac_key`.
///
/// Returns `Ok(true)` when the tags match. The comparison goes through
/// `subtle::ConstantTimeEq`, which is intended to keep the comparison time
/// independent of where the tags first differ.
///
/// # Errors
///
/// Propagates any error from [`compute_hmac_tag`].
fn verify_hmac_tag(hmac_key: &[u8], shard: &Shard) -> Result<bool, CorrectionError> {
    use subtle::ConstantTimeEq;
    let recomputed = compute_hmac_tag(hmac_key, shard.index, shard.total, &shard.data)?;
    let tags_match = recomputed.ct_eq(&shard.hmac_tag);
    Ok(bool::from(tags_match))
}
/// Splits `data` into `data_shards` equally sized data shards, appends
/// `parity_shards` Reed-Solomon parity shards, and tags every shard with
/// [`compute_hmac_tag`].
///
/// The input is zero-padded up to `shard_size * data_shards` bytes, where
/// `shard_size` is `ceil(data.len() / data_shards)` but never less than 1.
/// Callers must remember `data.len()` and pass it to [`decode_shards`] so the
/// padding can be stripped. Empty input therefore yields one-byte shards that
/// decode back to an empty payload.
///
/// # Errors
///
/// - [`CorrectionError::InvalidParameters`] if `data_shards` or `parity_shards`
///   is zero, or if their sum does not fit in a `u8`.
/// - [`CorrectionError::ReedSolomonError`] if the Reed-Solomon encoder rejects
///   the configuration or fails to encode.
/// - Any error from [`compute_hmac_tag`].
pub fn encode_shards(
    data: &[u8],
    data_shards: u8,
    parity_shards: u8,
    hmac_key: &[u8],
) -> Result<Vec<Shard>, CorrectionError> {
    if data_shards == 0 {
        return Err(CorrectionError::InvalidParameters {
            reason: "data_shards must be > 0".into(),
        });
    }
    if parity_shards == 0 {
        return Err(CorrectionError::InvalidParameters {
            reason: "parity_shards must be > 0".into(),
        });
    }
    // Shard counts come from the caller: report an out-of-range sum as an
    // error instead of panicking.
    let total_shards =
        data_shards
            .checked_add(parity_shards)
            .ok_or_else(|| CorrectionError::InvalidParameters {
                reason: "data_shards + parity_shards must be <= 255".into(),
            })?;
    // Clamp to 1: for empty input the ceiling division yields 0, and
    // `chunks_exact(0)` panics (the Reed-Solomon encoder also rejects empty shards).
    let shard_size = data.len().div_ceil(usize::from(data_shards)).max(1);
    let total_size = shard_size.strict_mul(usize::from(data_shards));
    // `total_size >= data.len()` by construction, so this only zero-pads.
    let mut padded = data.to_vec();
    padded.resize(total_size, 0);
    let rs =
        ReedSolomon::new(usize::from(data_shards), usize::from(parity_shards)).map_err(|e| {
            CorrectionError::ReedSolomonError {
                reason: e.to_string(),
            }
        })?;
    let mut chunks: Vec<Vec<u8>> = padded
        .chunks_exact(shard_size)
        .map(<[u8]>::to_vec)
        .collect();
    // Zero-filled placeholders that `encode` overwrites with parity.
    chunks.resize(usize::from(total_shards), vec![0u8; shard_size]);
    rs.encode(&mut chunks)
        .map_err(|e| CorrectionError::ReedSolomonError {
            reason: e.to_string(),
        })?;
    chunks
        .into_iter()
        .enumerate()
        .map(|(position, chunk)| {
            // `position < total_shards <= u8::MAX`, so this conversion cannot fail.
            let index = u8::try_from(position).map_err(|_| CorrectionError::InvalidParameters {
                reason: "shard index exceeds u8 range".into(),
            })?;
            let hmac_tag = compute_hmac_tag(hmac_key, index, total_shards, &chunk)?;
            Ok(Shard {
                index,
                total: total_shards,
                data: chunk,
                hmac_tag,
            })
        })
        .collect()
}
/// Reconstructs the original payload from shards produced by [`encode_shards`].
///
/// `shards` must contain exactly `data_shards + parity_shards` slots. Slot `i`
/// holds the shard whose index is `i`, or `None` if that shard is missing.
///
/// Every present shard is checked before any Reed-Solomon work:
/// - it is authenticated with [`verify_hmac_tag`];
/// - its authenticated `index` must equal its slot;
/// - its `total` must equal `data_shards + parity_shards`.
///
/// Missing data shards are then rebuilt and the data shards are concatenated.
/// The result is truncated to `original_len` to drop the zero padding added by
/// [`encode_shards`].
///
/// # Errors
///
/// - [`CorrectionError::InvalidParameters`] in any of these cases:
///   - `data_shards` or `parity_shards` is zero;
///   - their sum does not fit in a `u8`;
///   - `shards` has the wrong length;
///   - a present shard's index or total disagrees with its slot;
///   - `original_len` exceeds the reconstructed data length.
/// - [`CorrectionError::HmacMismatch`] if a present shard fails authentication.
/// - [`CorrectionError::InsufficientShards`] if fewer than `data_shards` shards
///   are present.
/// - [`CorrectionError::ReedSolomonError`] if the decoder rejects the
///   configuration or the shards (e.g. inconsistent shard sizes).
pub fn decode_shards(
    shards: &[Option<Shard>],
    data_shards: u8,
    parity_shards: u8,
    hmac_key: &[u8],
    original_len: usize,
) -> Result<Bytes, CorrectionError> {
    if data_shards == 0 {
        return Err(CorrectionError::InvalidParameters {
            reason: "data_shards must be > 0".into(),
        });
    }
    if parity_shards == 0 {
        return Err(CorrectionError::InvalidParameters {
            reason: "parity_shards must be > 0".into(),
        });
    }
    // Shard counts come from the caller: report an out-of-range sum as an
    // error instead of panicking.
    let total_shards =
        data_shards
            .checked_add(parity_shards)
            .ok_or_else(|| CorrectionError::InvalidParameters {
                reason: "data_shards + parity_shards must be <= 255".into(),
            })?;
    if shards.len() != usize::from(total_shards) {
        return Err(CorrectionError::InvalidParameters {
            reason: format!("expected {} shards, got {}", total_shards, shards.len()),
        });
    }
    let present = shards
        .iter()
        .enumerate()
        .filter_map(|(slot, entry)| entry.as_ref().map(|shard| (slot, shard)));
    for (slot, shard) in present {
        if !verify_hmac_tag(hmac_key, shard)? {
            return Err(CorrectionError::HmacMismatch { index: shard.index });
        }
        // The tag authenticates the shard's *claimed* index and total, but
        // Reed-Solomon treats slot position as the index. A genuine shard in the
        // wrong slot, or one from a different-sized encoding, would otherwise
        // pass authentication and silently corrupt the reconstruction.
        if usize::from(shard.index) != slot || shard.total != total_shards {
            return Err(CorrectionError::InvalidParameters {
                reason: format!(
                    "shard in slot {slot} claims index {} of {} (expected {slot} of {total_shards})",
                    shard.index, shard.total,
                ),
            });
        }
    }
    let needed = usize::from(data_shards);
    let available = shards.iter().flatten().count();
    if available < needed {
        return Err(CorrectionError::InsufficientShards { needed, available });
    }
    let rs =
        ReedSolomon::new(usize::from(data_shards), usize::from(parity_shards)).map_err(|e| {
            CorrectionError::ReedSolomonError {
                reason: e.to_string(),
            }
        })?;
    let mut chunks: Vec<Option<Vec<u8>>> = shards
        .iter()
        .map(|entry| entry.as_ref().map(|s| s.data.clone()))
        .collect();
    // Only the data shards are read below, so parity does not need rebuilding.
    rs.reconstruct_data(&mut chunks)
        .map_err(|e| CorrectionError::ReedSolomonError {
            reason: e.to_string(),
        })?;
    let mut recovered = Vec::new();
    for chunk in chunks.iter().take(needed).flatten() {
        recovered.extend_from_slice(chunk);
    }
    // A caller-supplied length longer than the reconstructed data cannot be
    // honored; returning padding-filled output would silently be wrong.
    if original_len > recovered.len() {
        return Err(CorrectionError::InvalidParameters {
            reason: format!(
                "original_len {original_len} exceeds reconstructed length {}",
                recovered.len()
            ),
        });
    }
    recovered.truncate(original_len);
    Ok(Bytes::from(recovered))
}
#[cfg(test)]
mod tests {
    use super::*;
    type TestResult = Result<(), Box<dyn std::error::Error>>;
    /// Fixed 32-byte key shared by every test (the previous literal was 31 bytes
    /// despite its name).
    const HMAC_KEY: &[u8] = b"test_hmac_key_32_bytes_long_!!!!";
    #[test]
    fn test_encode_decode_roundtrip() -> TestResult {
        let data = b"The quick brown fox jumps over the lazy dog";
        let data_shards = 10;
        let parity_shards = 5;
        let shards = encode_shards(data, data_shards, parity_shards, HMAC_KEY)?;
        assert_eq!(shards.len(), 15);
        let opt_shards: Vec<Option<Shard>> = shards.into_iter().map(Some).collect();
        let recovered = decode_shards(
            &opt_shards,
            data_shards,
            parity_shards,
            HMAC_KEY,
            data.len(),
        )?;
        assert_eq!(recovered.as_ref(), data);
        Ok(())
    }
    #[test]
    fn test_decode_with_missing_shards() -> TestResult {
        let data = b"The quick brown fox jumps over the lazy dog";
        let data_shards = 10;
        let parity_shards = 5;
        let shards = encode_shards(data, data_shards, parity_shards, HMAC_KEY)?;
        let mut opt_shards: Vec<Option<Shard>> = shards.into_iter().map(Some).collect();
        *opt_shards.get_mut(0).ok_or("out of bounds")? = None;
        *opt_shards.get_mut(3).ok_or("out of bounds")? = None;
        *opt_shards.get_mut(7).ok_or("out of bounds")? = None;
        *opt_shards.get_mut(10).ok_or("out of bounds")? = None;
        *opt_shards.get_mut(13).ok_or("out of bounds")? = None;
        let recovered = decode_shards(
            &opt_shards,
            data_shards,
            parity_shards,
            HMAC_KEY,
            data.len(),
        )?;
        assert_eq!(recovered.as_ref(), data);
        Ok(())
    }
    #[test]
    fn test_decode_with_only_parity_missing() -> TestResult {
        let data = b"parity loss must not affect data shards";
        let data_shards = 4;
        let parity_shards = 2;
        let shards = encode_shards(data, data_shards, parity_shards, HMAC_KEY)?;
        let mut opt_shards: Vec<Option<Shard>> = shards.into_iter().map(Some).collect();
        *opt_shards.get_mut(4).ok_or("out of bounds")? = None;
        *opt_shards.get_mut(5).ok_or("out of bounds")? = None;
        let recovered = decode_shards(
            &opt_shards,
            data_shards,
            parity_shards,
            HMAC_KEY,
            data.len(),
        )?;
        assert_eq!(recovered.as_ref(), data);
        Ok(())
    }
    #[test]
    fn test_decode_insufficient_shards() -> TestResult {
        let data = b"test data";
        let data_shards = 10;
        let parity_shards = 5;
        let shards = encode_shards(data, data_shards, parity_shards, HMAC_KEY)?;
        let mut opt_shards: Vec<Option<Shard>> = shards.into_iter().map(Some).collect();
        for i in 0..6 {
            *opt_shards.get_mut(i).ok_or("out of bounds")? = None;
        }
        let result = decode_shards(
            &opt_shards,
            data_shards,
            parity_shards,
            HMAC_KEY,
            data.len(),
        );
        assert!(matches!(
            result,
            Err(CorrectionError::InsufficientShards { .. })
        ));
        Ok(())
    }
    #[test]
    fn test_decode_hmac_mismatch() -> TestResult {
        let data = b"test data";
        let data_shards = 10;
        let parity_shards = 5;
        let mut shards = encode_shards(data, data_shards, parity_shards, HMAC_KEY)?;
        let shard = shards.get_mut(0).ok_or("missing shard 0")?;
        *shard.data.first_mut().ok_or("empty shard data")? ^= 0xFF;
        let opt_shards: Vec<Option<Shard>> = shards.into_iter().map(Some).collect();
        let result = decode_shards(
            &opt_shards,
            data_shards,
            parity_shards,
            HMAC_KEY,
            data.len(),
        );
        assert!(matches!(result, Err(CorrectionError::HmacMismatch { .. })));
        Ok(())
    }
    #[test]
    fn test_decode_wrong_key() -> TestResult {
        let data = b"test data";
        let data_shards = 4;
        let parity_shards = 2;
        let shards = encode_shards(data, data_shards, parity_shards, HMAC_KEY)?;
        let opt_shards: Vec<Option<Shard>> = shards.into_iter().map(Some).collect();
        let result = decode_shards(
            &opt_shards,
            data_shards,
            parity_shards,
            b"a_different_hmac_key_of_32_bytes",
            data.len(),
        );
        assert!(matches!(result, Err(CorrectionError::HmacMismatch { .. })));
        Ok(())
    }
    #[test]
    fn test_decode_wrong_shard_count() -> TestResult {
        let data = b"test data";
        let data_shards = 4;
        let parity_shards = 2;
        let shards = encode_shards(data, data_shards, parity_shards, HMAC_KEY)?;
        let mut opt_shards: Vec<Option<Shard>> = shards.into_iter().map(Some).collect();
        opt_shards.pop();
        let result = decode_shards(
            &opt_shards,
            data_shards,
            parity_shards,
            HMAC_KEY,
            data.len(),
        );
        assert!(matches!(
            result,
            Err(CorrectionError::InvalidParameters { .. })
        ));
        Ok(())
    }
    #[test]
    fn test_encode_zero_data_shards() {
        let data = b"test";
        let result = encode_shards(data, 0, 5, HMAC_KEY);
        assert!(matches!(
            result,
            Err(CorrectionError::InvalidParameters { .. })
        ));
    }
    #[test]
    fn test_encode_zero_parity_shards() {
        let data = b"test";
        let result = encode_shards(data, 10, 0, HMAC_KEY);
        assert!(matches!(
            result,
            Err(CorrectionError::InvalidParameters { .. })
        ));
    }
}