use reed_solomon_erasure::galois_8;
use reed_solomon_erasure::ReedSolomon;
use super::blob_ref::Encoding;
use super::blob_tree::{ChunkRefV3, StripeBlock};
use super::error::BlobError;
/// 40 MiB. Not referenced in the visible part of this file; presumably the
/// target amount of data per stripe — TODO confirm against callers.
pub const RS_STRIPE_TARGET_BYTES: u64 = 40 * 1024 * 1024;
/// 8 MiB. Not referenced in the visible part of this file; presumably a lower
/// bound on stripe data size — TODO confirm against callers.
pub const RS_STRIPE_MIN_BYTES: u64 = 8 * 1024 * 1024;
/// Data-shard count (`k`) used by `RsParams::default_production`.
pub const DEFAULT_RS_K: u8 = 10;
/// Parity-shard count (`m`) used by `RsParams::default_production`.
pub const DEFAULT_RS_M: u8 = 4;
/// Largest `k + m` accepted by `RsParams::validate`; the validation error
/// message describes it as the wire-format maximum.
pub const RS_MAX_KM_SUM: u16 = 255;
/// Not referenced in the visible part of this file; presumably a `k + m`
/// threshold above which callers emit a warning — TODO confirm.
pub const RS_WARN_KM_SUM: u16 = 64;
/// String identifier `"dataforts:blob-erasure-supported"`. Not referenced in
/// the visible part of this file; presumably a capability/feature key that
/// signals erasure-coding support — TODO confirm where it is consumed.
pub const DATAFORTS_BLOB_ERASURE_SUPPORTED: &str = "dataforts:blob-erasure-supported";
/// Answers the question "may Reed-Solomon encoding be used right now?".
///
/// The default variant is [`ErasureSupportProbe::AlwaysSupported`].
#[derive(Default)]
pub enum ErasureSupportProbe {
    /// Always reports support (`check` returns `true`).
    #[default]
    AlwaysSupported,
    /// Never reports support (`check` returns `false`).
    ForceReplicated,
    /// Delegates to a caller-supplied closure, invoked on every `check` call.
    Dynamic(Box<dyn Fn() -> bool + Send + Sync>),
}

impl ErasureSupportProbe {
    /// Returns `true` when the probe reports that erasure coding is supported.
    ///
    /// The `Dynamic` variant re-evaluates its closure on each call, so the
    /// answer may change between calls.
    pub fn check(&self) -> bool {
        match self {
            Self::Dynamic(probe) => probe(),
            Self::ForceReplicated => false,
            Self::AlwaysSupported => true,
        }
    }
}

impl std::fmt::Debug for ErasureSupportProbe {
    /// Writes the variant name; the closure inside `Dynamic` is elided as `(..)`
    /// because boxed closures have no `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::AlwaysSupported => "ErasureSupportProbe::AlwaysSupported",
            Self::ForceReplicated => "ErasureSupportProbe::ForceReplicated",
            Self::Dynamic(_) => "ErasureSupportProbe::Dynamic(..)",
        };
        f.write_str(label)
    }
}
/// Falls back to [`Encoding::Replicated`] when Reed-Solomon was requested but
/// the probe reports it is unsupported; every other encoding is returned as-is.
///
/// The probe is consulted only when `encoding` is a Reed-Solomon variant, so a
/// `Dynamic` probe closure is not invoked for already-replicated encodings.
pub fn erasure_downgrade(encoding: Encoding, probe: &ErasureSupportProbe) -> Encoding {
    let wants_rs = matches!(encoding, Encoding::ReedSolomon { .. });
    // `&&` short-circuits, so `probe.check()` runs only for RS encodings.
    if wants_rs && !probe.check() {
        Encoding::Replicated
    } else {
        encoding
    }
}
/// Reed-Solomon shape: `k` data shards protected by `m` parity shards.
///
/// Construction does not validate; call `validate` (or build an `RsEncoder`,
/// which validates) before relying on the values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RsParams {
/// Number of data shards per stripe.
pub k: u8,
/// Number of parity shards per stripe.
pub m: u8,
}
impl RsParams {
pub const fn new(k: u8, m: u8) -> Self {
Self { k, m }
}
pub const fn default_production() -> Self {
Self {
k: DEFAULT_RS_K,
m: DEFAULT_RS_M,
}
}
pub fn validate(&self) -> Result<(), BlobError> {
if self.k == 0 {
return Err(BlobError::Backend(
"RS params: k must be >= 1; zero-data stripe is invalid".to_owned(),
));
}
if self.m == 0 {
return Err(BlobError::Backend(
"RS params: m must be >= 1; zero-parity stripe cannot reconstruct losses"
.to_owned(),
));
}
if self.k as u16 + self.m as u16 > RS_MAX_KM_SUM {
return Err(BlobError::Backend(format!(
"RS params: k + m = {} exceeds the wire-format maximum {}",
self.k as u16 + self.m as u16,
RS_MAX_KM_SUM
)));
}
Ok(())
}
pub fn from_encoding(encoding: Encoding) -> Option<Self> {
match encoding {
Encoding::ReedSolomon { k, m } => Some(Self { k, m }),
Encoding::Replicated => None,
}
}
}
impl Default for RsParams {
fn default() -> Self {
Self::default_production()
}
}
/// Wrapper around a `reed_solomon_erasure` codec over the `galois_8` field,
/// paired with the `RsParams` it was built from.
pub struct RsEncoder {
// Underlying codec, constructed with `params.k` data and `params.m` parity shards.
rs: ReedSolomon<galois_8::Field>,
// Shape used for input checks in `encode` / `reconstruct_data`.
params: RsParams,
}
impl RsEncoder {
pub fn new(params: RsParams) -> Result<Self, BlobError> {
params.validate()?;
let rs = ReedSolomon::<galois_8::Field>::new(params.k as usize, params.m as usize)
.map_err(|e| {
BlobError::Backend(format!(
"RS encoder construction failed for (k={}, m={}): {:?}",
params.k, params.m, e
))
})?;
Ok(Self { rs, params })
}
pub fn params(&self) -> RsParams {
self.params
}
pub fn encode(&self, data: &[Vec<u8>]) -> Result<Vec<Vec<u8>>, BlobError> {
if data.len() != self.params.k as usize {
return Err(BlobError::Backend(format!(
"RS encode: expected {} data shards, got {}",
self.params.k,
data.len()
)));
}
let shard_len = match data.first() {
Some(first) => first.len(),
None => 0,
};
if shard_len == 0 {
return Err(BlobError::Backend(
"RS encode: data shards must be non-empty".to_owned(),
));
}
if data.iter().any(|d| d.len() != shard_len) {
return Err(BlobError::Backend(
"RS encode: all data shards must be the same length (caller is responsible \
for zero-padding short chunks)"
.to_owned(),
));
}
let mut parity: Vec<Vec<u8>> = (0..self.params.m).map(|_| vec![0u8; shard_len]).collect();
let data_refs: Vec<&[u8]> = data.iter().map(|d| d.as_slice()).collect();
let mut parity_refs: Vec<&mut [u8]> = parity.iter_mut().map(|p| p.as_mut_slice()).collect();
self.rs
.encode_sep(&data_refs, &mut parity_refs)
.map_err(|e| BlobError::Backend(format!("RS encode_sep failed: {:?}", e)))?;
Ok(parity)
}
pub fn reconstruct_data(&self, shards: &mut [Option<Vec<u8>>]) -> Result<(), BlobError> {
let expected = self.params.k as usize + self.params.m as usize;
if shards.len() != expected {
return Err(BlobError::Backend(format!(
"RS reconstruct_data: expected {} shard slots (k={} + m={}), got {}",
expected,
self.params.k,
self.params.m,
shards.len()
)));
}
self.rs
.reconstruct_data(shards)
.map_err(|e| BlobError::Backend(format!("RS reconstruct_data failed: {:?}", e)))
}
}
/// Result of closing a stripe in `RsStriper`.
pub struct ClosedStripe {
/// Stripe description: its encoding plus the chunk references it contains.
pub block: StripeBlock,
/// Parity shard payloads produced for the stripe; empty when the stripe was
/// closed as `Encoding::Replicated`.
pub parity_bytes: Vec<Vec<u8>>,
}
/// Buffers data chunks and groups them into stripes of `k` chunks each.
pub struct RsStriper {
// Shape of every RS stripe this striper emits.
rs_params: RsParams,
// Encoder built from `rs_params`, used to compute parity on stripe close.
encoder: RsEncoder,
// Chunks (payload + reference) accepted since the last stripe was closed.
in_flight: Vec<(Vec<u8>, ChunkRefV3)>,
// Running sum of `chunk_ref.size` for the buffered chunks; reset to 0 when a
// stripe closes. Only written, never read, in the visible code.
in_flight_data_bytes: u64,
// Number of stripes closed so far (RS and replicated), saturating.
closed_count: u64,
}
impl RsStriper {
    /// Creates an empty striper for the given Reed-Solomon shape.
    ///
    /// # Errors
    ///
    /// Propagates the `BlobError` from `RsEncoder::new` (invalid parameters or
    /// codec construction failure).
    pub fn new(rs_params: RsParams) -> Result<Self, BlobError> {
        let encoder = RsEncoder::new(rs_params)?;
        Ok(Self {
            rs_params,
            encoder,
            in_flight: Vec::new(),
            in_flight_data_bytes: 0,
            closed_count: 0,
        })
    }

    /// Buffers one data chunk and closes a stripe once `k` chunks are buffered.
    ///
    /// Returns `Ok(None)` while the stripe is still filling and
    /// `Ok(Some(stripe))` when this chunk completed a Reed-Solomon stripe.
    ///
    /// # Errors
    ///
    /// Returns `BlobError::Backend` when `chunk_ref` is not a data chunk, or
    /// when closing the stripe fails. A failed close discards the chunks that
    /// were buffered for that stripe.
    pub fn push_chunk(
        &mut self,
        bytes: Vec<u8>,
        chunk_ref: ChunkRefV3,
    ) -> Result<Option<ClosedStripe>, BlobError> {
        if !chunk_ref.is_data() {
            return Err(BlobError::Backend(
                "RsStriper::push_chunk received a non-data chunk; striper only \
                 accepts Data role chunks (parity is computed internally)"
                    .to_owned(),
            ));
        }
        let chunk_bytes = chunk_ref.size as u64;
        self.in_flight_data_bytes = self.in_flight_data_bytes.saturating_add(chunk_bytes);
        self.in_flight.push((bytes, chunk_ref));
        if self.in_flight.len() >= self.rs_params.k as usize {
            return self.close_stripe_with_rs().map(Some);
        }
        Ok(None)
    }

    /// Consumes the striper and flushes any leftover chunks.
    ///
    /// A partial stripe (fewer than `k` chunks) is emitted as
    /// `Encoding::Replicated` with no parity. Returns `Ok(None)` when nothing
    /// is buffered.
    pub fn finalize(mut self) -> Result<Option<ClosedStripe>, BlobError> {
        if self.in_flight.is_empty() {
            return Ok(None);
        }
        Ok(Some(self.close_stripe_as_replicated()))
    }

    /// Number of stripes closed so far.
    pub fn closed_stripe_count(&self) -> u64 {
        self.closed_count
    }

    /// Closes the buffered chunks as one Reed-Solomon stripe.
    ///
    /// Short chunks are zero-padded to the longest chunk's length before
    /// encoding; the data chunk references keep their original sizes. Parity
    /// chunk references carry the BLAKE3 hash and length of each parity shard.
    ///
    /// # Errors
    ///
    /// Returns `BlobError::Backend` when the buffered chunk count is not `k`,
    /// when encoding fails, when a parity shard length does not fit in `u32`,
    /// or when `StripeBlock::validate` rejects the block. The buffer has
    /// already been drained at that point, so the chunks are dropped.
    fn close_stripe_with_rs(&mut self) -> Result<ClosedStripe, BlobError> {
        let k = self.rs_params.k as usize;
        let m = self.rs_params.m as usize;
        let in_flight = std::mem::take(&mut self.in_flight);
        self.in_flight_data_bytes = 0;
        if in_flight.len() != k {
            return Err(BlobError::Backend(format!(
                "RS striper: stripe close expected exactly {} data shards, got {}",
                k,
                in_flight.len()
            )));
        }

        // The encoder rejects empty shards, so pad to at least one byte even
        // when every buffered chunk is empty.
        let max_len = in_flight
            .iter()
            .map(|(b, _)| b.len())
            .max()
            .unwrap_or(1)
            .max(1);

        let mut padded: Vec<Vec<u8>> = Vec::with_capacity(k);
        let mut data_refs: Vec<ChunkRefV3> = Vec::with_capacity(k);
        for (mut bytes, chunk_ref) in in_flight {
            if bytes.len() < max_len {
                bytes.resize(max_len, 0);
            }
            padded.push(bytes);
            data_refs.push(chunk_ref);
        }

        let parity_bytes = self.encoder.encode(&padded)?;
        let mut parity_refs: Vec<ChunkRefV3> = Vec::with_capacity(m);
        for (i, pbytes) in parity_bytes.iter().enumerate() {
            let phash: [u8; 32] = blake3::hash(pbytes).into();
            // Checked conversion: a plain `as u32` would silently truncate the
            // recorded size of a shard longer than `u32::MAX` bytes. Spelled
            // fully qualified so it does not depend on `TryFrom` being in the
            // prelude of the crate's edition.
            let psize = <u32 as std::convert::TryFrom<usize>>::try_from(pbytes.len())
                .map_err(|_| {
                    BlobError::Backend(format!(
                        "RS striper: parity shard {} is {} bytes, which exceeds the u32 \
                         chunk size limit",
                        i,
                        pbytes.len()
                    ))
                })?;
            // `i < m` and `m` is a `u8`, so the index cast cannot truncate.
            parity_refs.push(ChunkRefV3::parity(phash, psize, i as u8));
        }

        // Data references first, parity references after.
        let mut chunks: Vec<ChunkRefV3> = data_refs;
        chunks.extend(parity_refs);
        let block = StripeBlock {
            encoding: Encoding::ReedSolomon {
                k: self.rs_params.k,
                m: self.rs_params.m,
            },
            chunks,
        };
        block.validate()?;
        self.closed_count = self.closed_count.saturating_add(1);
        Ok(ClosedStripe {
            block,
            parity_bytes,
        })
    }

    /// Closes the buffered chunks as a replicated stripe with no parity.
    ///
    /// Only the chunk references are kept; the buffered payloads are dropped.
    fn close_stripe_as_replicated(&mut self) -> ClosedStripe {
        let in_flight = std::mem::take(&mut self.in_flight);
        self.in_flight_data_bytes = 0;
        let chunks: Vec<ChunkRefV3> = in_flight.into_iter().map(|(_, r)| r).collect();
        let block = StripeBlock {
            encoding: Encoding::Replicated,
            chunks,
        };
        self.closed_count = self.closed_count.saturating_add(1);
        ClosedStripe {
            block,
            parity_bytes: Vec::new(),
        }
    }
}
// Unit tests for the RS parameter validation, encoder, support probe,
// downgrade helper and striper defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Encodes 4 data shards with 2 parity shards, drops one data and one parity
// shard, and checks that reconstruction restores the dropped data shard.
#[test]
fn encode_then_drop_m_shards_then_reconstruct_round_trips() {
let params = RsParams { k: 4, m: 2 };
let encoder = RsEncoder::new(params).unwrap();
let data: Vec<Vec<u8>> = (0..4u8)
.map(|i| (0..1024).map(|j| i.wrapping_add(j as u8)).collect())
.collect();
let parity = encoder.encode(&data).unwrap();
assert_eq!(parity.len(), 2);
assert_eq!(parity[0].len(), 1024);
let mut shards: Vec<Option<Vec<u8>>> = data
.iter()
.cloned()
.chain(parity.iter().cloned())
.map(Some)
.collect();
shards[1] = None; shards[5] = None;
encoder.reconstruct_data(&mut shards).unwrap();
assert_eq!(
shards[1].as_ref().unwrap(),
&data[1],
"reconstructed data shard 1 must equal the original"
);
assert_eq!(shards[0].as_ref().unwrap(), &data[0]);
assert_eq!(shards[2].as_ref().unwrap(), &data[2]);
assert_eq!(shards[3].as_ref().unwrap(), &data[3]);
}
// Drops three shards with m = 2 and expects reconstruction to fail with an
// error mentioning either the wrapper or the library's "TooFew" condition.
#[test]
fn dropping_more_than_m_shards_fails_reconstruction() {
let params = RsParams { k: 4, m: 2 };
let encoder = RsEncoder::new(params).unwrap();
let data: Vec<Vec<u8>> = (0..4u8).map(|i| vec![i; 512]).collect();
let parity = encoder.encode(&data).unwrap();
let mut shards: Vec<Option<Vec<u8>>> = data
.iter()
.cloned()
.chain(parity.iter().cloned())
.map(Some)
.collect();
shards[0] = None;
shards[1] = None;
shards[2] = None;
let err = encoder.reconstruct_data(&mut shards).unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("reconstruct_data") || msg.contains("TooFew"),
"expected an RS-library failure, got: {}",
msg
);
}
// Drops only the two parity shards and checks that reconstruction succeeds
// and leaves all four data shards equal to the originals.
#[test]
fn parity_loss_with_full_data_set_succeeds_without_touching_data() {
let params = RsParams { k: 4, m: 2 };
let encoder = RsEncoder::new(params).unwrap();
let data: Vec<Vec<u8>> = (0..4u8).map(|i| vec![i.wrapping_mul(7); 256]).collect();
let parity = encoder.encode(&data).unwrap();
let mut shards: Vec<Option<Vec<u8>>> = data
.iter()
.cloned()
.chain(parity.iter().cloned())
.map(Some)
.collect();
shards[4] = None;
shards[5] = None;
encoder.reconstruct_data(&mut shards).unwrap();
for i in 0..4 {
assert_eq!(shards[i].as_ref().unwrap(), &data[i]);
}
}
// k = 0, m = 0 and k + m = 400 are rejected; the production default passes.
#[test]
fn validate_rejects_malformed_params() {
assert!(RsParams { k: 0, m: 4 }.validate().is_err());
assert!(RsParams { k: 10, m: 0 }.validate().is_err());
assert!(RsParams { k: 200, m: 200 }.validate().is_err());
assert!(RsParams::default_production().validate().is_ok());
}
// `from_encoding` yields the (k, m) pair for RS and `None` for Replicated.
#[test]
fn from_encoding_extracts_params() {
assert_eq!(
RsParams::from_encoding(Encoding::ReedSolomon { k: 6, m: 3 }),
Some(RsParams { k: 6, m: 3 })
);
assert_eq!(RsParams::from_encoding(Encoding::Replicated), None);
}
// Data shards of differing lengths (100, 50, 100) must be rejected.
#[test]
fn encode_rejects_uneven_data_shard_lengths() {
let encoder = RsEncoder::new(RsParams { k: 3, m: 2 }).unwrap();
let data = vec![vec![0u8; 100], vec![1u8; 50], vec![2u8; 100]];
assert!(encoder.encode(&data).is_err());
}
// Passing 2 shards to a k = 4 encoder must be rejected.
#[test]
fn encode_rejects_wrong_data_shard_count() {
let encoder = RsEncoder::new(RsParams { k: 4, m: 2 }).unwrap();
let data = vec![vec![0u8; 100], vec![1u8; 100]];
assert!(encoder.encode(&data).is_err());
}
// The two static probe variants return fixed answers; the default probe
// reports support.
#[test]
fn erasure_support_probe_static_variants() {
assert!(ErasureSupportProbe::AlwaysSupported.check());
assert!(!ErasureSupportProbe::ForceReplicated.check());
assert!(ErasureSupportProbe::default().check());
}
// A Dynamic probe reflects the current value of the flag its closure reads,
// i.e. the closure is evaluated on every `check`.
#[test]
fn erasure_support_probe_dynamic_consults_closure() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
let flag = Arc::new(AtomicBool::new(false));
let f = flag.clone();
let probe = ErasureSupportProbe::Dynamic(Box::new(move || f.load(Ordering::Relaxed)));
assert!(!probe.check());
flag.store(true, Ordering::Relaxed);
assert!(probe.check());
}
// Only the (RS encoding, unsupported probe) combination is downgraded to
// Replicated; the other three combinations pass through unchanged.
#[test]
fn erasure_downgrade_substitutes_only_for_rs_on_reject() {
let rs = Encoding::ReedSolomon { k: 10, m: 4 };
let rep = Encoding::Replicated;
assert_eq!(
erasure_downgrade(rs, &ErasureSupportProbe::AlwaysSupported),
rs
);
assert_eq!(
erasure_downgrade(rep, &ErasureSupportProbe::AlwaysSupported),
rep
);
assert_eq!(
erasure_downgrade(rs, &ErasureSupportProbe::ForceReplicated),
Encoding::Replicated
);
assert_eq!(
erasure_downgrade(rep, &ErasureSupportProbe::ForceReplicated),
rep
);
}
// Test helper: produces `len` deterministic pseudo-random bytes from `seed`
// using a wrapping multiply-add recurrence on a u64 state, taking bits 33..
// of each state as the output byte.
fn det_bytes(seed: u8, len: usize) -> Vec<u8> {
let mut state: u64 = seed as u64;
(0..len)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(state >> 33) as u8
})
.collect()
}
// With k = 4, the first three pushes return None and the fourth closes a
// stripe holding 4 data and 2 parity chunk references.
#[test]
fn striper_closes_at_k_chunks_into_rs_stripe() {
let params = RsParams { k: 4, m: 2 };
let mut striper = RsStriper::new(params).unwrap();
for i in 0..3u8 {
let bytes = det_bytes(i, 100);
let hash: [u8; 32] = blake3::hash(&bytes).into();
let cref = ChunkRefV3::data(hash, 100);
assert!(striper.push_chunk(bytes, cref).unwrap().is_none());
}
let bytes = det_bytes(3, 100);
let hash: [u8; 32] = blake3::hash(&bytes).into();
let cref = ChunkRefV3::data(hash, 100);
let closed = striper.push_chunk(bytes, cref).unwrap().unwrap();
assert_eq!(closed.block.chunks.len(), 6); assert_eq!(closed.parity_bytes.len(), 2);
assert_eq!(
closed.block.chunks.iter().filter(|c| c.is_data()).count(),
4
);
assert_eq!(
closed.block.chunks.iter().filter(|c| c.is_parity()).count(),
2
);
}
// Chunks of 200, 100 and 150 bytes: data references keep their original
// sizes while parity references report the padded length (200).
#[test]
fn striper_preserves_pre_padding_sizes_in_chunk_refs() {
let params = RsParams { k: 3, m: 2 };
let mut striper = RsStriper::new(params).unwrap();
let sizes = [200, 100, 150];
let mut sent = Vec::new();
for (i, &size) in sizes.iter().enumerate() {
let bytes = det_bytes(i as u8, size);
let hash: [u8; 32] = blake3::hash(&bytes).into();
let cref = ChunkRefV3::data(hash, size as u32);
sent.push(cref);
let result = striper.push_chunk(bytes, cref).unwrap();
if i + 1 == sizes.len() {
let closed = result.unwrap();
for (j, &expected_size) in sizes.iter().enumerate() {
assert_eq!(closed.block.chunks[j].size as usize, expected_size);
}
for parity in closed.block.chunks.iter().filter(|c| c.is_parity()) {
assert_eq!(parity.size as usize, 200);
}
} else {
assert!(result.is_none());
}
}
}
// Finalizing with 3 of 5 chunks buffered yields a Replicated stripe with
// only data references and no parity payloads.
#[test]
fn striper_finalize_with_partial_emits_replicated_stripe() {
let params = RsParams { k: 5, m: 2 };
let mut striper = RsStriper::new(params).unwrap();
for i in 0..3u8 {
let bytes = det_bytes(i, 50);
let hash: [u8; 32] = blake3::hash(&bytes).into();
let cref = ChunkRefV3::data(hash, 50);
assert!(striper.push_chunk(bytes, cref).unwrap().is_none());
}
let closed = striper.finalize().unwrap().unwrap();
assert_eq!(closed.block.encoding, Encoding::Replicated);
assert_eq!(closed.block.chunks.len(), 3);
assert!(closed.parity_bytes.is_empty());
assert!(closed.block.chunks.iter().all(|c| c.is_data()));
}
// Finalizing a striper that never received a chunk returns None.
#[test]
fn striper_finalize_with_no_chunks_returns_none() {
let params = RsParams { k: 4, m: 2 };
let striper = RsStriper::new(params).unwrap();
assert!(striper.finalize().unwrap().is_none());
}
// Six pushes with k = 3 close exactly two stripes, and the striper's own
// counter agrees.
#[test]
fn striper_closes_multiple_stripes() {
let params = RsParams { k: 3, m: 2 };
let mut striper = RsStriper::new(params).unwrap();
let mut closed_count = 0u64;
for i in 0..6u8 {
let bytes = det_bytes(i, 64);
let hash: [u8; 32] = blake3::hash(&bytes).into();
let cref = ChunkRefV3::data(hash, 64);
if striper.push_chunk(bytes, cref).unwrap().is_some() {
closed_count += 1;
}
}
assert_eq!(closed_count, 2);
assert_eq!(striper.closed_stripe_count(), 2);
}
// Pushing a parity-role chunk reference is an error.
#[test]
fn striper_rejects_parity_role_inputs() {
let mut striper = RsStriper::new(RsParams { k: 3, m: 2 }).unwrap();
let bytes = vec![0u8; 10];
let parity_ref = ChunkRefV3::parity([0u8; 32], 10, 0);
assert!(striper.push_chunk(bytes, parity_ref).is_err());
}
// Parity produced by the striper can be fed to a fresh RsEncoder to rebuild
// a dropped data shard (compared against the zero-padded original).
#[test]
fn striper_output_round_trips_through_rs_encoder() {
let params = RsParams { k: 3, m: 2 };
let mut striper = RsStriper::new(params).unwrap();
let originals: Vec<Vec<u8>> = (0..3u8).map(|i| det_bytes(i, 128)).collect();
let mut closed: Option<ClosedStripe> = None;
for (i, bytes) in originals.iter().enumerate() {
let hash: [u8; 32] = blake3::hash(bytes).into();
let cref = ChunkRefV3::data(hash, bytes.len() as u32);
let result = striper.push_chunk(bytes.clone(), cref).unwrap();
if i + 1 == originals.len() {
closed = Some(result.unwrap());
}
}
let closed = closed.unwrap();
let shard_len = closed.parity_bytes[0].len();
let mut shards: Vec<Option<Vec<u8>>> = Vec::with_capacity(5);
for orig in &originals {
let mut padded = orig.clone();
padded.resize(shard_len, 0);
shards.push(Some(padded));
}
for p in &closed.parity_bytes {
shards.push(Some(p.clone()));
}
shards[1] = None; shards[4] = None;
let encoder = RsEncoder::new(params).unwrap();
encoder.reconstruct_data(&mut shards).unwrap();
let mut expected = originals[1].clone();
expected.resize(shard_len, 0);
assert_eq!(shards[1].as_ref().unwrap(), &expected);
}
// Pins the default constants so an accidental change is caught.
#[test]
fn striper_constants_match_plan_defaults() {
assert_eq!(DEFAULT_RS_K, 10);
assert_eq!(DEFAULT_RS_M, 4);
assert_eq!(RS_STRIPE_TARGET_BYTES, 40 * 1024 * 1024);
assert_eq!(RS_STRIPE_MIN_BYTES, 8 * 1024 * 1024);
assert_eq!(RsParams::default_production(), RsParams { k: 10, m: 4 });
}
}