use std::num::NonZeroU16;
use commonware_codec::{Decode, Encode};
use commonware_coding::{CodecConfig, Scheme as _};
use commonware_cryptography::Sha256;
use commonware_parallel::Sequential;
use crate::hash::{self, HASH_LEN, Hash};
pub use commonware_coding::Config;
// Reed-Solomon coding scheme instantiated with SHA-256, together with the
// commitment and shard types it exposes through `commonware_coding::Scheme`.
type RsScheme = commonware_coding::ReedSolomon<Sha256>;
type Commitment = <RsScheme as commonware_coding::Scheme>::Commitment;
type RsChunk = <RsScheme as commonware_coding::Scheme>::Shard;
// Strategy value handed to `RsScheme::encode` and `RsScheme::decode`.
const STRATEGY: Sequential = Sequential;
// Passed as `CodecConfig::maximum_shard_size` when decoding a shard's bytes.
// NOTE(review): 4 GiB (2^32) does not fit in a 32-bit `usize`, so this
// expression presumably fails to compile on 32-bit targets — confirm whether
// such targets need to be supported.
const MAX_SHARD_BYTES: usize = 4 * 1024 * 1024 * 1024;
/// 1 MiB threshold.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the pack size (in bytes) at which callers switch to sharding — confirm
/// against the call sites.
pub const SHARD_SIZE_THRESHOLD: u64 = 1024 * 1024;
/// Four magic bytes written at the very start of an encoded manifest.
pub const MANIFEST_MAGIC: [u8; 4] = *b"MKSH";
/// Format version byte written immediately after [`MANIFEST_MAGIC`].
pub const MANIFEST_VERSION: u8 = 0x01;
// Prologue length: 4 magic bytes + 1 version byte.
const MANIFEST_PROLOGUE_LEN: usize = 5;
/// Largest input length, in bytes, that [`decode_manifest`] accepts.
pub const MANIFEST_MAX_BYTES: usize = 1024 * 1024;
/// Returns the default coding configuration: 16 minimum shards plus 4 extra
/// shards, i.e. 20 shards in total.
#[must_use]
pub fn default_config() -> Config {
    // Both literals are non-zero, so neither `expect` can fire.
    let minimum_shards = NonZeroU16::new(16).expect("16 != 0");
    let extra_shards = NonZeroU16::new(4).expect("4 != 0");
    Config {
        minimum_shards,
        extra_shards,
    }
}
/// One encoded shard of a pack, as produced by [`encode_pack_to_shards`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Shard {
/// Position of this shard within the encoded set (0-based).
pub index: u16,
/// Codec-encoded bytes of the underlying coding-scheme shard.
pub bytes: Vec<u8>,
}
/// Manifest describing a sharded pack: everything needed to verify individual
/// shards and the reconstructed pack.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardSet {
/// Hash of the original, un-sharded pack bytes.
pub pack_hash: Hash,
/// Coding configuration the pack was encoded with.
pub config: Config,
/// Hash of each shard's encoded bytes, ordered by shard index; its length
/// must equal `config.total_shards()`.
pub shard_hashes: Vec<Hash>,
/// Raw bytes of the coding scheme's commitment for this pack.
pub commitment: Hash,
}
/// Errors produced while sharding, reconstructing, or (de)serializing a
/// manifest.
#[derive(Debug, thiserror::Error)]
pub enum ShardError {
/// The coding scheme failed to encode the pack; carries the `Debug`
/// rendering of the underlying error.
#[error("reed-solomon encode failed: {0}")]
EncodeFailed(String),
/// The coding scheme rejected a shard during its check step or failed to
/// reconstruct the pack; carries the `Debug` rendering of the underlying error.
#[error("reed-solomon decode failed: {0}")]
DecodeFailed(String),
/// A shard's bytes could not be decoded into the scheme's shard type.
#[error("shard codec decode failed at index {index}: {source}")]
ShardCodecFailed {
index: u16,
#[source]
source: commonware_codec::Error,
},
/// The hash of a shard's bytes differs from the manifest entry for its index.
#[error("shard {index} BLAKE3 mismatch (manifest tampered or shard corrupted)")]
ShardHashMismatch { index: u16 },
/// A shard's index is not below the configuration's total shard count.
#[error("shard index {index} is out of range for config (total = {total})")]
IndexOutOfRange { index: u16, total: u32 },
/// The same shard index was supplied more than once.
#[error("duplicate shard index {index}")]
DuplicateIndex { index: u16 },
/// The number of shard hashes (in memory or declared on the wire) does not
/// equal `config.total_shards()`.
#[error(
"manifest has {actual} shard_hashes, expected {expected} \
(config.total_shards())"
)]
ManifestShardCountMismatch { actual: usize, expected: usize },
/// The reconstructed pack does not hash to the manifest's `pack_hash`.
#[error("reconstructed pack hash does not match manifest.pack_hash")]
PackHashMismatch,
/// Fewer verified shards were supplied than the configured minimum.
#[error("insufficient shards: {provided} < {minimum}")]
InsufficientShards { provided: usize, minimum: u16 },
/// The manifest's magic/version prologue is missing or not recognized.
#[error("invalid manifest prologue: {0}")]
InvalidManifestPrologue(&'static str),
/// The manifest bytes ended before a required field was complete.
#[error("unexpected eof while decoding manifest")]
ManifestUnexpectedEof,
/// Extra bytes follow the last shard hash in the manifest.
#[error("trailing bytes after manifest body")]
ManifestTrailingBytes,
/// The encoded manifest carries a zero `minimum` or `extra` shard count.
#[error("manifest declares zero shard count (min={minimum}, extra={extra})")]
ManifestZeroShardCount { minimum: u16, extra: u16 },
/// A manifest's encoded size exceeds `MANIFEST_MAX_BYTES`.
#[error("manifest is too large: {actual} > {max}")]
ManifestTooLarge { actual: usize, max: usize },
}
/// Erasure-codes `pack` under `config` and returns the shards together with
/// the manifest that describes them.
///
/// Each returned [`Shard`] carries its index and the codec-encoded bytes of
/// the scheme's shard. The manifest records the hash of the pack, the hash of
/// every shard's bytes (in index order), the configuration, and the scheme's
/// commitment.
///
/// # Errors
///
/// Returns [`ShardError::EncodeFailed`] when the coding scheme rejects the
/// input.
///
/// # Panics
///
/// Panics if the scheme yields a shard position that does not fit in `u16`.
pub fn encode_pack_to_shards(
    pack: &[u8],
    config: Config,
) -> Result<(Vec<Shard>, ShardSet), ShardError> {
    let (commitment, chunks) = match RsScheme::encode(&config, pack, &STRATEGY) {
        Ok(encoded) => encoded,
        Err(e) => return Err(ShardError::EncodeFailed(format!("{e:?}"))),
    };
    let total = config.total_shards() as usize;
    debug_assert_eq!(chunks.len(), total);

    // Serialize every chunk, tagging it with its position in the encoded set.
    let shards: Vec<Shard> = chunks
        .into_iter()
        .enumerate()
        .map(|(position, chunk)| Shard {
            index: u16::try_from(position).expect("commonware emits <= u16::MAX shards"),
            bytes: chunk.encode().to_vec(),
        })
        .collect();

    // Per-shard hashes cover the serialized bytes, in index order.
    let shard_hashes = shards
        .iter()
        .map(|shard| hash::hash(&shard.bytes))
        .collect::<Vec<_>>();

    let pack_hash = hash::hash(pack);
    let commitment = digest_to_bytes(&commitment);
    Ok((
        shards,
        ShardSet {
            pack_hash,
            config,
            shard_hashes,
            commitment,
        },
    ))
}
/// Verifies `shards` against `manifest` and reconstructs the original pack.
///
/// Shards are processed in the order given. For each one the function checks,
/// in this order: index range, uniqueness, hash against the manifest, codec
/// decoding, and the coding scheme's own check against the commitment. Only
/// after every shard passes is the shard count compared with the configured
/// minimum and the pack reconstructed and hashed.
///
/// # Errors
///
/// * [`ShardError::ManifestShardCountMismatch`] — the manifest's hash list
///   does not have `config.total_shards()` entries.
/// * [`ShardError::IndexOutOfRange`], [`ShardError::DuplicateIndex`],
///   [`ShardError::ShardHashMismatch`], [`ShardError::ShardCodecFailed`],
///   [`ShardError::DecodeFailed`] — a shard failed the corresponding check.
/// * [`ShardError::InsufficientShards`] — fewer shards than the minimum.
/// * [`ShardError::DecodeFailed`] — the scheme could not reconstruct the pack.
/// * [`ShardError::PackHashMismatch`] — the reconstructed pack's hash differs
///   from `manifest.pack_hash`.
pub fn decode_pack_from_shards(
    shards: &[Shard],
    manifest: &ShardSet,
) -> Result<Vec<u8>, ShardError> {
    let total = manifest.config.total_shards();
    let expected_hashes = total as usize;
    let actual_hashes = manifest.shard_hashes.len();
    if actual_hashes != expected_hashes {
        return Err(ShardError::ManifestShardCountMismatch {
            actual: actual_hashes,
            expected: expected_hashes,
        });
    }

    let minimum = manifest.config.minimum_shards.get();
    let commitment = bytes_to_digest(&manifest.commitment);
    let codec_cfg = CodecConfig {
        maximum_shard_size: MAX_SHARD_BYTES,
    };

    // `seen[i]` flips to true the first time shard index `i` shows up.
    let mut seen = vec![false; expected_hashes];
    let mut checked = Vec::with_capacity(shards.len());
    for shard in shards {
        let index = shard.index;
        if u32::from(index) >= total {
            return Err(ShardError::IndexOutOfRange { index, total });
        }
        let position = usize::from(index);

        // Mark the slot and learn whether it had already been marked.
        let already_seen = std::mem::replace(&mut seen[position], true);
        if already_seen {
            return Err(ShardError::DuplicateIndex { index });
        }

        // Cheap integrity check before handing the bytes to the codec.
        if hash::hash(&shard.bytes) != manifest.shard_hashes[position] {
            return Err(ShardError::ShardHashMismatch { index });
        }

        let chunk = match RsChunk::decode_cfg(shard.bytes.as_slice(), &codec_cfg) {
            Ok(chunk) => chunk,
            Err(source) => return Err(ShardError::ShardCodecFailed { index, source }),
        };
        let verified = match RsScheme::check(&manifest.config, &commitment, index, &chunk) {
            Ok(verified) => verified,
            Err(e) => {
                return Err(ShardError::DecodeFailed(format!(
                    "check({}): {e:?}",
                    index
                )))
            }
        };
        checked.push(verified);
    }

    let provided = checked.len();
    if provided < usize::from(minimum) {
        return Err(ShardError::InsufficientShards { provided, minimum });
    }

    let pack = match RsScheme::decode(&manifest.config, &commitment, checked.iter(), &STRATEGY) {
        Ok(pack) => pack,
        Err(e) => return Err(ShardError::DecodeFailed(format!("{e:?}"))),
    };

    // Final end-to-end check on the reconstructed bytes.
    if hash::hash(&pack) == manifest.pack_hash {
        Ok(pack)
    } else {
        Err(ShardError::PackHashMismatch)
    }
}
/// Copies the commitment digest into a plain `HASH_LEN`-byte array.
///
/// # Panics
///
/// Panics if the digest's byte length is not exactly `HASH_LEN`.
fn digest_to_bytes(d: &Commitment) -> [u8; HASH_LEN] {
    let digest_bytes: &[u8] = d.as_ref();
    let mut raw = [0u8; HASH_LEN];
    // `copy_from_slice` requires both sides to have the same length.
    raw.copy_from_slice(digest_bytes);
    raw
}
fn bytes_to_digest(b: &[u8; HASH_LEN]) -> Commitment {
use commonware_codec::FixedSize;
debug_assert_eq!(<Commitment as FixedSize>::SIZE, HASH_LEN);
Commitment::from(*b)
}
pub fn encode_manifest(manifest: &ShardSet) -> Result<Vec<u8>, ShardError> {
let total = manifest.config.total_shards() as usize;
if manifest.shard_hashes.len() != total {
return Err(ShardError::ManifestShardCountMismatch {
actual: manifest.shard_hashes.len(),
expected: total,
});
}
let body_len = MANIFEST_PROLOGUE_LEN + HASH_LEN + 2 + 2 + HASH_LEN + 4 + total * HASH_LEN;
let mut out = Vec::with_capacity(body_len);
out.extend_from_slice(&MANIFEST_MAGIC);
out.push(MANIFEST_VERSION);
out.extend_from_slice(&manifest.pack_hash);
out.extend_from_slice(&manifest.config.minimum_shards.get().to_le_bytes());
out.extend_from_slice(&manifest.config.extra_shards.get().to_le_bytes());
out.extend_from_slice(&manifest.commitment);
out.extend_from_slice(
&u32::try_from(total)
.expect("total_shards fits in u32")
.to_le_bytes(),
);
for h in &manifest.shard_hashes {
out.extend_from_slice(h);
}
debug_assert_eq!(out.len(), body_len);
Ok(out)
}
/// Parses a manifest from its binary wire format.
///
/// The input must consist of exactly: magic, version, pack hash, minimum and
/// extra shard counts (`u16` LE each), commitment, shard count (`u32` LE), and
/// that many shard hashes — with nothing after the last hash.
///
/// # Errors
///
/// * [`ShardError::ManifestTooLarge`] — input longer than `MANIFEST_MAX_BYTES`.
/// * [`ShardError::InvalidManifestPrologue`] — input shorter than the
///   prologue, wrong magic, or unsupported version.
/// * [`ShardError::ManifestUnexpectedEof`] — input ends inside a field.
/// * [`ShardError::ManifestZeroShardCount`] — minimum or extra count is zero.
/// * [`ShardError::ManifestShardCountMismatch`] — declared shard count differs
///   from the total implied by the minimum/extra counts.
/// * [`ShardError::ManifestTrailingBytes`] — bytes remain after the last hash.
pub fn decode_manifest(bytes: &[u8]) -> Result<ShardSet, ShardError> {
    /// Detaches the first `n` bytes from `rest`, or reports EOF when fewer remain.
    fn take<'a>(rest: &mut &'a [u8], n: usize) -> Result<&'a [u8], ShardError> {
        let current: &'a [u8] = *rest;
        if current.len() < n {
            return Err(ShardError::ManifestUnexpectedEof);
        }
        let (head, tail) = current.split_at(n);
        *rest = tail;
        Ok(head)
    }

    /// Copies a slice of exactly `HASH_LEN` bytes into an owned array.
    fn to_hash(raw: &[u8]) -> [u8; HASH_LEN] {
        let mut out = [0u8; HASH_LEN];
        out.copy_from_slice(raw);
        out
    }

    let input_len = bytes.len();
    if input_len > MANIFEST_MAX_BYTES {
        return Err(ShardError::ManifestTooLarge {
            actual: input_len,
            max: MANIFEST_MAX_BYTES,
        });
    }
    if input_len < MANIFEST_PROLOGUE_LEN {
        return Err(ShardError::InvalidManifestPrologue(
            "input shorter than prologue",
        ));
    }

    // Prologue: 4 magic bytes followed by 1 version byte.
    let (prologue, mut rest) = bytes.split_at(MANIFEST_PROLOGUE_LEN);
    if prologue[..4] != MANIFEST_MAGIC {
        return Err(ShardError::InvalidManifestPrologue("bad magic"));
    }
    if prologue[4] != MANIFEST_VERSION {
        return Err(ShardError::InvalidManifestPrologue("unsupported version"));
    }

    let pack_hash = to_hash(take(&mut rest, HASH_LEN)?);

    // Minimum and extra shard counts travel together as two LE u16 values.
    let counts = take(&mut rest, 4)?;
    let minimum = u16::from_le_bytes([counts[0], counts[1]]);
    let extra = u16::from_le_bytes([counts[2], counts[3]]);
    let config = match (NonZeroU16::new(minimum), NonZeroU16::new(extra)) {
        (Some(minimum_shards), Some(extra_shards)) => Config {
            minimum_shards,
            extra_shards,
        },
        _ => return Err(ShardError::ManifestZeroShardCount { minimum, extra }),
    };
    let total = config.total_shards();

    let commitment = to_hash(take(&mut rest, HASH_LEN)?);

    let len_field = take(&mut rest, 4)?;
    let declared_len =
        u32::from_le_bytes([len_field[0], len_field[1], len_field[2], len_field[3]]);
    if declared_len != total {
        return Err(ShardError::ManifestShardCountMismatch {
            actual: declared_len as usize,
            expected: total as usize,
        });
    }

    // Claim the whole hash region at once so a short input fails before any
    // hash is copied.
    let hash_region = take(&mut rest, (declared_len as usize).saturating_mul(HASH_LEN))?;
    let shard_hashes: Vec<_> = hash_region.chunks_exact(HASH_LEN).map(to_hash).collect();

    if !rest.is_empty() {
        return Err(ShardError::ManifestTrailingBytes);
    }
    Ok(ShardSet {
        pack_hash,
        config,
        shard_hashes,
        commitment,
    })
}
// Unit tests covering shard round-trips, tamper detection, and the manifest
// wire format.
#[cfg(test)]
mod tests {
use super::*;
// Builds `bytes` deterministic pseudo-random bytes from a fixed seed using
// repeated shift/xor steps, so every run exercises identical input.
fn synthetic_pack(bytes: usize) -> Vec<u8> {
let mut x: u64 = 0x9E37_79B9_7F4A_7C15;
let mut out = Vec::with_capacity(bytes);
while out.len() < bytes {
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
out.extend_from_slice(&x.to_le_bytes());
}
// The loop appends 8 bytes at a time; trim any overshoot.
out.truncate(bytes);
out
}
// A 1 MiB pack encoded with the default config yields 20 shards, and the
// first 16 of them are enough to recover the pack.
#[test]
fn round_trip_default_config_1_mib_first_n_shards() {
let pack = synthetic_pack(1024 * 1024);
let config = default_config();
let (shards, manifest) = encode_pack_to_shards(&pack, config).unwrap();
assert_eq!(shards.len(), 20);
assert_eq!(manifest.shard_hashes.len(), 20);
assert_eq!(manifest.pack_hash, hash::hash(&pack));
let subset: Vec<Shard> = shards.into_iter().take(16).collect();
let recovered = decode_pack_from_shards(&subset, &manifest).unwrap();
assert_eq!(recovered, pack);
}
// Recovery still works when four non-contiguous shards are missing, leaving
// exactly the minimum of 16.
#[test]
fn lossy_round_trip_drops_shards_0_5_10_17() {
let pack = synthetic_pack(1024 * 1024);
let config = default_config();
let (shards, manifest) = encode_pack_to_shards(&pack, config).unwrap();
let dropped = [0u16, 5, 10, 17];
let subset: Vec<Shard> = shards
.into_iter()
.filter(|s| !dropped.contains(&s.index))
.collect();
assert_eq!(subset.len(), 16);
let recovered = decode_pack_from_shards(&subset, &manifest).unwrap();
assert_eq!(recovered, pack);
}
// Flipping one bit in shard 0 must surface as a hash mismatch for index 0.
#[test]
fn tampered_shard_is_rejected_before_decode() {
let pack = synthetic_pack(256 * 1024);
let config = default_config();
let (mut shards, manifest) = encode_pack_to_shards(&pack, config).unwrap();
let last = shards[0].bytes.len() - 1;
shards[0].bytes[last] ^= 0x01;
let subset: Vec<Shard> = shards.into_iter().take(16).collect();
let err = decode_pack_from_shards(&subset, &manifest).unwrap_err();
assert!(
matches!(err, ShardError::ShardHashMismatch { index: 0 }),
"expected ShardHashMismatch{{index: 0}}, got {err:?}"
);
}
// Encoding then decoding a manifest is lossless, and the default-config
// manifest is pinned at 717 bytes.
#[test]
fn manifest_wire_format_round_trip_default_config() {
let pack = synthetic_pack(64 * 1024);
let (_, manifest) = encode_pack_to_shards(&pack, default_config()).unwrap();
let bytes = encode_manifest(&manifest).unwrap();
assert_eq!(bytes.len(), 717);
assert_eq!(&bytes[..4], &MANIFEST_MAGIC);
assert_eq!(bytes[4], MANIFEST_VERSION);
let decoded = decode_manifest(&bytes).unwrap();
assert_eq!(decoded, manifest);
}
// Corrupting the first magic byte is reported as "bad magic".
#[test]
fn manifest_decode_rejects_bad_magic() {
let pack = synthetic_pack(32 * 1024);
let (_, manifest) = encode_pack_to_shards(&pack, default_config()).unwrap();
let mut bytes = encode_manifest(&manifest).unwrap();
bytes[0] = b'X';
let err = decode_manifest(&bytes).unwrap_err();
assert!(
matches!(err, ShardError::InvalidManifestPrologue("bad magic")),
"expected InvalidManifestPrologue(bad magic), got {err:?}"
);
}
// An unknown version byte is reported as "unsupported version".
#[test]
fn manifest_decode_rejects_unsupported_version() {
let pack = synthetic_pack(32 * 1024);
let (_, manifest) = encode_pack_to_shards(&pack, default_config()).unwrap();
let mut bytes = encode_manifest(&manifest).unwrap();
bytes[4] = 0xFF;
let err = decode_manifest(&bytes).unwrap_err();
assert!(
matches!(
err,
ShardError::InvalidManifestPrologue("unsupported version")
),
"expected InvalidManifestPrologue(unsupported version), got {err:?}"
);
}
// One extra byte after a valid manifest is rejected.
#[test]
fn manifest_decode_rejects_trailing_bytes() {
let pack = synthetic_pack(32 * 1024);
let (_, manifest) = encode_pack_to_shards(&pack, default_config()).unwrap();
let mut bytes = encode_manifest(&manifest).unwrap();
bytes.push(0xAB);
let err = decode_manifest(&bytes).unwrap_err();
assert!(
matches!(err, ShardError::ManifestTrailingBytes),
"expected ManifestTrailingBytes, got {err:?}"
);
}
// Dropping the final byte of a valid manifest is reported as unexpected EOF.
#[test]
fn manifest_decode_rejects_truncated_body() {
let pack = synthetic_pack(32 * 1024);
let (_, manifest) = encode_pack_to_shards(&pack, default_config()).unwrap();
let mut bytes = encode_manifest(&manifest).unwrap();
bytes.truncate(bytes.len() - 1);
let err = decode_manifest(&bytes).unwrap_err();
assert!(
matches!(err, ShardError::ManifestUnexpectedEof),
"expected ManifestUnexpectedEof, got {err:?}"
);
}
// Input one byte over the size limit is rejected before any parsing.
#[test]
fn manifest_decode_rejects_oversize_input() {
let bytes = vec![0u8; MANIFEST_MAX_BYTES + 1];
let err = decode_manifest(&bytes).unwrap_err();
assert!(
matches!(err, ShardError::ManifestTooLarge { .. }),
"expected ManifestTooLarge, got {err:?}"
);
}
// A hand-built manifest whose minimum shard count is zero is rejected.
#[test]
fn manifest_decode_rejects_zero_config() {
let mut bytes = Vec::new();
bytes.extend_from_slice(&MANIFEST_MAGIC);
bytes.push(MANIFEST_VERSION);
// Body, in wire order: zeroed pack hash, minimum = 0, extra = 4, zeroed
// commitment, declared shard count = 0; then decode and capture the error.
bytes.extend_from_slice(&[0u8; HASH_LEN]); bytes.extend_from_slice(&0u16.to_le_bytes()); bytes.extend_from_slice(&4u16.to_le_bytes()); bytes.extend_from_slice(&[0u8; HASH_LEN]); bytes.extend_from_slice(&0u32.to_le_bytes()); let err = decode_manifest(&bytes).unwrap_err();
assert!(
matches!(err, ShardError::ManifestZeroShardCount { .. }),
"expected ManifestZeroShardCount, got {err:?}"
);
}
// Supplying 15 valid shards (one below the minimum of 16) fails with
// InsufficientShards carrying both counts.
#[test]
fn insufficient_shards_returns_error() {
let pack = synthetic_pack(64 * 1024);
let config = default_config();
let (shards, manifest) = encode_pack_to_shards(&pack, config).unwrap();
let subset: Vec<Shard> = shards.into_iter().take(15).collect();
let err = decode_pack_from_shards(&subset, &manifest).unwrap_err();
assert!(
matches!(
err,
ShardError::InsufficientShards {
provided: 15,
minimum: 16,
}
),
"expected InsufficientShards{{15, 16}}, got {err:?}"
);
}
}