use crate::domain::errors::ReconstructionError;
use crate::domain::ports::{ErrorCorrector, ExtractTechnique, Reconstructor};
use crate::domain::reconstruction::{
arrange_shards, count_present, deserialize_shard, validate_shard_count,
};
use crate::domain::types::{CoverMedia, Payload};
/// [`Reconstructor`] that pulls one shard out of each cover with an
/// [`ExtractTechnique`] and rebuilds the payload through a pluggable
/// [`ErrorCorrector`].
pub struct ReconstructorImpl {
    /// Number of data shards; also the minimum number of shards checked for
    /// before decoding.
    data_shards: u8,
    /// Number of parity shards accompanying the data shards.
    parity_shards: u8,
    /// Payload length in bytes. Decoded output is trimmed to this length when
    /// it is non-zero; `0` disables trimming.
    original_len: usize,
    /// Error-correction backend that decodes the arranged shard slots.
    corrector: Box<dyn ErrorCorrector>,
}

// Manual impl: `dyn ErrorCorrector` is not required to implement `Debug`, so the
// corrector is left out and the output is marked non-exhaustive.
impl std::fmt::Debug for ReconstructorImpl {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReconstructorImpl")
            .field("data_shards", &self.data_shards)
            .field("parity_shards", &self.parity_shards)
            .field("original_len", &self.original_len)
            .finish_non_exhaustive()
    }
}

impl ReconstructorImpl {
    /// Creates a reconstructor for `data_shards` data shards plus
    /// `parity_shards` parity shards.
    ///
    /// * `original_len` — payload length in bytes used to trim the decoded
    ///   output; pass `0` to return the decoded bytes untrimmed.
    /// * `corrector` — backend used to decode the recovered shards.
    #[must_use]
    pub fn new(
        data_shards: u8,
        parity_shards: u8,
        original_len: usize,
        corrector: Box<dyn ErrorCorrector>,
    ) -> Self {
        Self {
            data_shards,
            parity_shards,
            original_len,
            corrector,
        }
    }
}
impl Reconstructor for ReconstructorImpl {
    /// Rebuilds the original payload from a set of stego covers.
    ///
    /// Each cover is run through `extractor`, and the extracted bytes are
    /// deserialized into a shard. A cover whose extraction or deserialization
    /// fails is logged and treated as a missing shard instead of aborting the
    /// whole reconstruction. `progress_cb(done, total)` is invoked once per
    /// cover after it has been processed, with `done` running from `1` to
    /// `covers.len()` and `total == covers.len()`.
    ///
    /// The collected shards are arranged into `data_shards + parity_shards`
    /// slots and decoded by the configured [`ErrorCorrector`]. The decoded
    /// bytes are trimmed to `original_len` when that is non-zero.
    ///
    /// # Errors
    ///
    /// * Propagates the error from `validate_shard_count` when the number of
    ///   present shards is rejected for the configured `data_shards`.
    /// * Returns [`ReconstructionError::CorrectionFailed`] when the corrector
    ///   cannot decode the available shards.
    ///
    /// # Panics
    ///
    /// Panics if `data_shards + parity_shards` overflows `u8`.
    fn reconstruct(
        &self,
        covers: Vec<CoverMedia>,
        extractor: &dyn ExtractTechnique,
        progress_cb: &dyn Fn(usize, usize),
    ) -> Result<Payload, ReconstructionError> {
        let total = covers.len();
        // `strict_add` panics instead of wrapping if the sum exceeds `u8::MAX`.
        let total_shards = self.data_shards.strict_add(self.parity_shards);
        let mut shards = Vec::with_capacity(total);
        for (i, cover) in covers.into_iter().enumerate() {
            match extractor.extract(&cover) {
                Ok(payload) => {
                    if let Some(shard) = deserialize_shard(payload.as_bytes()) {
                        shards.push(shard);
                    } else {
                        tracing::warn!(
                            cover_index = i,
                            "could not deserialize shard, treating as missing"
                        );
                    }
                }
                Err(e) => {
                    tracing::warn!(
                        cover_index = i,
                        error = %e,
                        "extraction failed, treating as missing shard"
                    );
                }
            }
            progress_cb(i.strict_add(1), total);
        }
        let slots = arrange_shards(shards, total_shards);
        let present = count_present(&slots);
        validate_shard_count(present, usize::from(self.data_shards))?;
        let recovered = self
            .corrector
            .decode(&slots, self.data_shards, self.parity_shards)
            .map_err(|source| ReconstructionError::CorrectionFailed { source })?;
        Ok(Payload::from_bytes(trim_to_original_len(
            &recovered,
            self.original_len,
        )))
    }
}

/// Trims the corrector's decoded output to the declared payload length.
///
/// A zero `original_len` disables trimming, and the decoded bytes are returned
/// unchanged. If `original_len` is larger than the decoded output, the whole
/// decoded buffer is returned and a warning is logged. In that case the result
/// is shorter than the caller declared, which usually means the length or the
/// shards are inconsistent.
fn trim_to_original_len(recovered: &[u8], original_len: usize) -> Vec<u8> {
    if original_len == 0 {
        return recovered.to_vec();
    }
    if let Some(prefix) = recovered.get(..original_len) {
        prefix.to_vec()
    } else {
        tracing::warn!(
            expected_len = original_len,
            recovered_len = recovered.len(),
            "recovered payload is shorter than the declared original length"
        );
        recovered.to_vec()
    }
}
// Unit tests: payloads are encoded with the real `RsErrorCorrector`, hidden in
// covers by a mock embedder, and recovered through `ReconstructorImpl` with a
// matching mock extractor.
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::errors::StegoError;
use crate::domain::ports::EmbedTechnique;
use crate::domain::ports::ErrorCorrector;
use crate::domain::reconstruction::serialize_shard;
use crate::domain::types::{Capacity, CoverMedia, CoverMediaKind, StegoTechnique};
use bytes::Bytes;
use std::cell::Cell;
/// Lets tests use `?` on any error type.
type TestResult = Result<(), Box<dyn std::error::Error>>;
/// Test embedder: appends a little-endian `u32` length followed by the payload
/// bytes to the end of the cover's existing data.
struct MockEmbedder;
impl EmbedTechnique for MockEmbedder {
fn technique(&self) -> StegoTechnique {
StegoTechnique::LsbImage
}
/// Reports the cover's data length in bytes as its capacity.
fn capacity(&self, cover: &CoverMedia) -> Result<Capacity, StegoError> {
Ok(Capacity {
bytes: cover.data.len() as u64,
technique: StegoTechnique::LsbImage,
})
}
fn embed(&self, cover: CoverMedia, payload: &Payload) -> Result<CoverMedia, StegoError> {
let mut data = cover.data.to_vec();
#[expect(clippy::cast_possible_truncation, reason = "test data < 4 GiB")]
let len = payload.len() as u32;
// Layout: [original cover bytes][u32 LE payload length][payload bytes]
data.extend_from_slice(&len.to_le_bytes());
data.extend_from_slice(payload.as_bytes());
Ok(CoverMedia {
kind: cover.kind,
data: Bytes::from(data),
metadata: cover.metadata,
})
}
}
/// Inverse of `MockEmbedder`. It skips `cover_prefix_len` bytes of original
/// cover data, reads a `u32` LE length, then returns that many payload bytes.
struct MockExtractor {
/// Size of the original cover data that precedes the embedded length field.
cover_prefix_len: usize,
}
impl ExtractTechnique for MockExtractor {
fn technique(&self) -> StegoTechnique {
StegoTechnique::LsbImage
}
fn extract(&self, stego: &CoverMedia) -> Result<Payload, StegoError> {
let data = &stego.data;
// `<=` (not `<`) also rejects a cover whose embedded payload is empty.
if data.len() <= self.cover_prefix_len + 4 {
return Err(StegoError::NoPayloadFound);
}
let offset = self.cover_prefix_len;
let len_bytes: [u8; 4] = data
.get(offset..offset + 4)
.ok_or(StegoError::NoPayloadFound)?
.try_into()
.map_err(|_| StegoError::NoPayloadFound)?;
let len = u32::from_le_bytes(len_bytes) as usize;
let start = offset + 4;
// Fails (rather than panics) if the declared length runs past the data.
let payload_data = data
.get(start..start + len)
.ok_or(StegoError::NoPayloadFound)?;
Ok(Payload::from_bytes(payload_data.to_vec()))
}
}
/// Builds a zero-filled PNG-kind cover of `size` bytes with empty metadata.
fn make_cover(size: usize) -> CoverMedia {
CoverMedia {
kind: CoverMediaKind::PngImage,
data: Bytes::from(vec![0u8; size]),
metadata: std::collections::HashMap::new(),
}
}
/// Encodes `payload` into `data_shards + parity_shards` shards with an
/// `RsErrorCorrector` keyed by `hmac_key`, then embeds each serialized shard
/// into its own fresh `cover_size`-byte cover. Returns one cover per shard, in
/// the order `encode` produced them.
fn distribute_and_get_covers(
payload: &[u8],
data_shards: u8,
parity_shards: u8,
hmac_key: &[u8],
cover_size: usize,
) -> Result<Vec<CoverMedia>, Box<dyn std::error::Error>> {
let corrector = crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec());
let shards = corrector.encode(payload, data_shards, parity_shards)?;
let embedder = MockEmbedder;
let covers = shards
.iter()
.map(|shard| {
let cover = make_cover(cover_size);
let serialized = serialize_shard(shard);
let shard_payload = Payload::from_bytes(serialized);
embedder.embed(cover, &shard_payload)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(covers)
}
/// With all 3 + 2 shards present, the payload round-trips exactly and the
/// progress callback fires once per cover.
#[test]
fn full_recovery_all_shards_present() -> TestResult {
let original = b"hello reconstruction world!";
let hmac_key = b"test-hmac-key";
let covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
assert_eq!(covers.len(), 5);
let corrector: Box<dyn ErrorCorrector> = Box::new(
crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec()),
);
let reconstructor = ReconstructorImpl::new(3, 2, original.len(), corrector);
// Must match the `cover_size` passed to `distribute_and_get_covers`.
let extractor = MockExtractor {
cover_prefix_len: 128,
};
let progress_calls = Cell::new(0usize);
let result = reconstructor.reconstruct(covers, &extractor, &|_done, _total| {
progress_calls.set(progress_calls.get().strict_add(1));
})?;
assert_eq!(result.as_bytes(), original);
assert_eq!(progress_calls.get(), 5);
Ok(())
}
/// Dropping two of five covers leaves exactly `data_shards` (3) shards,
/// which must still be enough to recover the payload.
#[test]
fn partial_recovery_minimum_shards() -> TestResult {
let original = b"partial recovery test payload";
let hmac_key = b"test-hmac-key";
let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
assert_eq!(covers.len(), 5);
// Remove the highest index first so the next index still refers to the
// intended cover.
covers.remove(4);
covers.remove(3);
let corrector: Box<dyn ErrorCorrector> = Box::new(
crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec()),
);
let reconstructor = ReconstructorImpl::new(3, 2, original.len(), corrector);
let extractor = MockExtractor {
cover_prefix_len: 128,
};
let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {})?;
assert_eq!(result.as_bytes(), original);
Ok(())
}
/// With only 2 of the 3 required data shards left, reconstruction must fail.
/// Only `is_err()` is checked, not the specific error variant.
#[test]
fn insufficient_shards_returns_error() -> TestResult {
let original = b"not enough shards";
let hmac_key = b"test-hmac-key";
let mut covers = distribute_and_get_covers(original, 3, 2, hmac_key, 128)?;
// Highest index first, as above.
covers.remove(4);
covers.remove(3);
covers.remove(2);
let corrector: Box<dyn ErrorCorrector> = Box::new(
crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec()),
);
let reconstructor = ReconstructorImpl::new(3, 2, original.len(), corrector);
let extractor = MockExtractor {
cover_prefix_len: 128,
};
let result = reconstructor.reconstruct(covers, &extractor, &|_, _| {});
assert!(result.is_err());
Ok(())
}
/// The progress callback receives `(1, n), (2, n), ..., (n, n)` in order,
/// where `n` is the number of covers.
#[test]
fn progress_callback_called_correctly() -> TestResult {
let original = b"track progress";
let hmac_key = b"test-hmac-key";
let covers = distribute_and_get_covers(original, 2, 1, hmac_key, 64)?;
let total_covers = covers.len();
let corrector: Box<dyn ErrorCorrector> = Box::new(
crate::adapters::correction::RsErrorCorrector::new(hmac_key.to_vec()),
);
let reconstructor = ReconstructorImpl::new(2, 1, original.len(), corrector);
let extractor = MockExtractor {
cover_prefix_len: 64,
};
// `RefCell` because the callback is a shared `&dyn Fn`, not `FnMut`.
let progress_log = std::cell::RefCell::new(Vec::new());
let result = reconstructor.reconstruct(covers, &extractor, &|done, total| {
progress_log.borrow_mut().push((done, total));
})?;
assert_eq!(result.as_bytes(), original);
let log = progress_log.borrow();
assert_eq!(log.len(), total_covers);
for (i, &(done, total)) in log.iter().enumerate() {
assert_eq!(done, i + 1);
assert_eq!(total, total_covers);
}
Ok(())
}
}