use borsh::{BorshDeserialize, BorshSerialize};
use itertools::Itertools;
use reed_solomon_erasure::Field;
use reed_solomon_erasure::galois_8::ReedSolomon;
use std::collections::HashMap;
use std::io::Error;
use std::sync::Arc;
use tracing::span::EnteredSpan;
/// One Reed-Solomon shard slot: `Some(bytes)` when the part is present, `None` when it is not.
pub type ReedSolomonPart = Option<Box<[u8]>>;
/// Named as the maximum number of parts; its value is the order of the `galois_8` field.
pub const REED_SOLOMON_MAX_PARTS: usize = reed_solomon_erasure::galois_8::Field::ORDER;
/// Largest `encoded_length` accepted by `reed_solomon_decode`: 512 times `bytesize::MIB`.
const MAX_ENCODED_LENGTH: usize = 512usize.strict_mul(bytesize::MIB as usize);
/// Borsh-serializes `data` and splits it into Reed-Solomon parts.
///
/// The serialized bytes are zero-padded so they divide evenly into
/// `rs.data_shard_count()` data shards; `rs.parity_shard_count()` empty slots
/// are appended and then filled in by `rs.reconstruct`.
///
/// Returns all parts (data shards first, then parity shards) together with
/// the length of the serialized bytes before padding.
///
/// # Panics
///
/// Panics if borsh serialization fails or if `rs.reconstruct` returns an error.
// NOTE(review): an empty serialization yields a part length of 0, and
// `chunks_exact(0)` panics per the std docs — confirm empty payloads cannot reach this.
pub fn reed_solomon_encode<T: BorshSerialize>(
    rs: &ReedSolomon,
    data: &T,
) -> (Vec<ReedSolomonPart>, usize) {
    let mut payload = borsh::to_vec(data).unwrap();
    let encoded_length = payload.len();

    let data_shards = rs.data_shard_count();
    let shard_len = reed_solomon_part_length(encoded_length, data_shards);

    // Zero-pad so the payload splits into exactly `data_shards` equal chunks.
    payload.resize(data_shards * shard_len, 0);

    let mut parts: Vec<ReedSolomonPart> = Vec::new();
    parts.extend(payload.chunks_exact(shard_len).map(|chunk| Some(Box::<[u8]>::from(chunk))));
    // Parity slots start out empty; `reconstruct` computes them from the data shards.
    let total_slots = parts.len() + rs.parity_shard_count();
    parts.resize(total_slots, None);

    rs.reconstruct(&mut parts).unwrap();
    (parts, encoded_length)
}
/// Reconstructs missing parts and borsh-deserializes the original value.
///
/// `parts` is reconstructed in place via `rs.reconstruct`. The first
/// `encoded_length` bytes of the concatenated parts are then handed to
/// `T::try_from_slice`.
///
/// # Errors
///
/// Returns an error if `encoded_length` exceeds `MAX_ENCODED_LENGTH`, if
/// reconstruction fails, or if deserialization fails.
///
/// # Panics
///
/// Panics with "Missing shard" if a part is still `None` after a successful
/// reconstruction.
pub fn reed_solomon_decode<T: BorshDeserialize>(
    rs: &ReedSolomon,
    parts: &mut [ReedSolomonPart],
    encoded_length: usize,
) -> Result<T, Error> {
    // Reject oversized lengths before allocating a buffer of that capacity.
    if encoded_length > MAX_ENCODED_LENGTH {
        return Err(Error::other("encoded length is too large"));
    }
    rs.reconstruct(parts).map_err(Error::other)?;

    let mut payload = Vec::with_capacity(encoded_length);
    for shard in parts.iter().map(|slot| slot.as_ref().expect("Missing shard")) {
        // Copy only as many bytes as are still needed to reach `encoded_length`.
        let wanted = (encoded_length - payload.len()).min(shard.len());
        payload.extend_from_slice(&shard[..wanted]);
    }
    payload.truncate(encoded_length);
    T::try_from_slice(&payload)
}
/// Returns the length of each part when `encoded_length` bytes are split
/// across `data_parts` data parts, i.e. `encoded_length / data_parts` rounded up.
///
/// Uses `usize::div_ceil` so that values of `encoded_length` close to
/// `usize::MAX` do not overflow the intermediate sum.
///
/// # Panics
///
/// Panics if `data_parts` is zero.
pub fn reed_solomon_part_length(encoded_length: usize, data_parts: usize) -> usize {
    encoded_length.div_ceil(data_parts)
}
/// Returns how many of `total_parts` are data parts for the given ratio.
///
/// The product `total_parts * ratio_data_parts` is converted to `usize` with an
/// `as` cast (truncating toward zero, saturating, NaN becomes 0) and the result
/// is never less than 1.
pub fn reed_solomon_num_data_parts(total_parts: usize, ratio_data_parts: f64) -> usize {
    let scaled = total_parts as f64 * ratio_data_parts;
    let floored = scaled as usize;
    floored.max(1)
}
/// Encoder/decoder that wraps an optional `ReedSolomon` codec and falls back to
/// a single unsplit part when no codec is configured.
pub struct ReedSolomonEncoder {
/// The underlying codec; `None` means the payload is carried as exactly one part.
rs: Option<ReedSolomon>,
}
/// Types that `ReedSolomonEncoder::encode` can encode.
pub trait ReedSolomonEncoderSerialize: BorshSerialize {
/// Produces the bytes used when the value is carried as a single part.
/// The default implementation borsh-serializes `self`.
fn serialize_single_part(&self) -> std::io::Result<Vec<u8>> {
borsh::to_vec(self)
}
}
/// Types that `ReedSolomonEncoder::decode` can decode.
pub trait ReedSolomonEncoderDeserialize: BorshDeserialize {
/// Rebuilds a value from the bytes of a single part.
/// The default implementation borsh-deserializes `data` via `try_from_slice`.
fn deserialize_single_part(data: &[u8]) -> std::io::Result<Self> {
Self::try_from_slice(data)
}
}
impl ReedSolomonEncoder {
    /// Creates an encoder for `total_parts` parts.
    ///
    /// With `total_parts <= 1` no codec is created and the single-part path is
    /// used. Otherwise the number of data parts comes from
    /// `reed_solomon_num_data_parts` and the rest are parity parts.
    ///
    /// # Panics
    ///
    /// Panics if `ReedSolomon::new` rejects the data/parity split.
    pub fn new(total_parts: usize, ratio_data_parts: f64) -> ReedSolomonEncoder {
        if total_parts <= 1 {
            return Self { rs: None };
        }
        let data_parts = reed_solomon_num_data_parts(total_parts, ratio_data_parts);
        let parity_parts = total_parts - data_parts;
        let codec = ReedSolomon::new(data_parts, parity_parts).unwrap();
        Self { rs: Some(codec) }
    }

    /// Total number of parts (data plus parity); 1 in single-part mode.
    pub fn total_parts(&self) -> usize {
        self.rs.as_ref().map_or(1, |codec| codec.total_shard_count())
    }

    /// Number of data parts; 1 in single-part mode.
    pub fn data_parts(&self) -> usize {
        self.rs.as_ref().map_or(1, |codec| codec.data_shard_count())
    }

    /// Encodes `data` into parts and returns them with the encoded length.
    ///
    /// In single-part mode the result is one part holding the output of
    /// `serialize_single_part`, and the length is that part's byte length.
    ///
    /// # Panics
    ///
    /// Panics if `serialize_single_part` fails in single-part mode.
    pub fn encode<T: ReedSolomonEncoderSerialize>(
        &self,
        data: &T,
    ) -> (Vec<ReedSolomonPart>, usize) {
        if let Some(codec) = &self.rs {
            return reed_solomon_encode(codec, data);
        }
        let serialized = data.serialize_single_part().unwrap();
        let encoded_length = serialized.len();
        (vec![Some(serialized.into_boxed_slice())], encoded_length)
    }

    /// Decodes a value from `parts`.
    ///
    /// With a codec this delegates to `reed_solomon_decode`. In single-part
    /// mode `parts` must contain exactly one present part, which is passed to
    /// `deserialize_single_part`; `encoded_length` is not consulted on that path.
    ///
    /// # Errors
    ///
    /// Returns an error if decoding fails, or in single-part mode if `parts`
    /// does not hold exactly one part or that part is `None`.
    pub fn decode<T: ReedSolomonEncoderDeserialize>(
        &self,
        parts: &mut [ReedSolomonPart],
        encoded_length: usize,
    ) -> Result<T, std::io::Error> {
        if let Some(codec) = &self.rs {
            return reed_solomon_decode(codec, parts, encoded_length);
        }
        let [slot] = &*parts else {
            return Err(std::io::Error::other(format!(
                "Expected single part, received {}",
                parts.len()
            )));
        };
        match slot {
            Some(part) => T::deserialize_single_part(&part[..]),
            None => Err(std::io::Error::other("Received part is not expected to be None")),
        }
    }
}
/// Cache of shared `ReedSolomonEncoder` instances keyed by total part count.
pub struct ReedSolomonEncoderCache {
/// Data-part ratio passed to every encoder this cache creates.
ratio_data_parts: f64,
/// Encoders created so far, keyed by `total_parts`.
instances: HashMap<usize, Arc<ReedSolomonEncoder>>,
}
impl ReedSolomonEncoderCache {
    /// Creates an empty cache whose encoders will all use `ratio_data_parts`.
    pub fn new(ratio_data_parts: f64) -> Self {
        Self { instances: HashMap::new(), ratio_data_parts }
    }

    /// Returns the shared encoder for `total_parts`, creating and caching it on
    /// first use.
    pub fn entry(&mut self, total_parts: usize) -> Arc<ReedSolomonEncoder> {
        let ratio = self.ratio_data_parts;
        let encoder = self
            .instances
            .entry(total_parts)
            .or_insert_with(|| Arc::new(ReedSolomonEncoder::new(total_parts, ratio)));
        Arc::clone(encoder)
    }
}
/// Collects parts one at a time and decodes a `T` once enough have been inserted.
pub struct ReedSolomonPartsTracker<T> {
/// One slot per part of the encoder; `None` until that part is inserted.
parts: Vec<ReedSolomonPart>,
/// Encoded length handed to the encoder's `decode`.
encoded_length: usize,
/// Count of parts accepted so far (incremented for every accepted part).
data_parts_present: usize,
/// Encoder used to size `parts` and to decode.
encoder: Arc<ReedSolomonEncoder>,
/// Sum of the byte lengths of all accepted parts.
total_parts_size: usize,
/// Ties the tracker to the decoded type `T` without storing a value of it.
phantom: std::marker::PhantomData<T>,
}
/// Outcome of `ReedSolomonPartsTracker::insert_part`.
pub enum InsertPartResult<T> {
/// The part was stored, but there are not yet enough parts to decode.
Accepted,
/// A part with this ordinal was already stored; the new one was ignored.
PartAlreadyAvailable,
/// The part ordinal is outside the range of tracked parts.
InvalidPartOrd,
/// The part was stored and decoding was attempted; carries the decode result.
Decoded(std::io::Result<T>),
}
impl<T: ReedSolomonEncoderDeserialize> ReedSolomonPartsTracker<T> {
    /// Creates a tracker with one empty slot per part of `encoder`.
    pub fn new(encoder: Arc<ReedSolomonEncoder>, encoded_length: usize) -> Self {
        let parts = vec![None; encoder.total_parts()];
        Self {
            parts,
            encoded_length,
            data_parts_present: 0,
            encoder,
            total_parts_size: 0,
            phantom: std::marker::PhantomData,
        }
    }

    /// Number of parts accepted so far.
    pub fn data_parts_present(&self) -> usize {
        self.data_parts_present
    }

    /// Sum of the byte lengths of all accepted parts.
    pub fn total_parts_size(&self) -> usize {
        self.total_parts_size
    }

    /// Number of parts that must be present before decoding is attempted.
    pub fn data_parts_required(&self) -> usize {
        self.encoder.data_parts()
    }

    /// Whether at least `data_parts_required()` parts have been accepted.
    pub fn has_enough_parts(&self) -> bool {
        self.data_parts_required() <= self.data_parts_present
    }

    /// Encoded length this tracker was created with.
    pub fn encoded_length(&self) -> usize {
        self.encoded_length
    }

    /// Whether the slot at `part_ord` exists and already holds a part.
    pub fn has_part(&self, part_ord: usize) -> bool {
        matches!(self.parts.get(part_ord), Some(Some(_)))
    }

    /// Stores `part` at `part_ord` and attempts to decode once enough parts
    /// are present.
    ///
    /// Out-of-range ordinals and already-filled slots are rejected without
    /// changing any state. When decoding is attempted, `create_decode_span`
    /// (if provided) is invoked first and the span it returns is held until
    /// decoding finishes.
    pub fn insert_part(
        &mut self,
        part_ord: usize,
        part: Box<[u8]>,
        create_decode_span: Option<Box<dyn Fn() -> EnteredSpan>>,
    ) -> InsertPartResult<T> {
        let Some(slot) = self.parts.get_mut(part_ord) else {
            return InsertPartResult::InvalidPartOrd;
        };
        if slot.is_some() {
            return InsertPartResult::PartAlreadyAvailable;
        }

        self.data_parts_present += 1;
        self.total_parts_size += part.len();
        *slot = Some(part);

        if !self.has_enough_parts() {
            return InsertPartResult::Accepted;
        }
        // Keep the span alive for the duration of the decode call.
        let _decode_span = create_decode_span.map(|make_span| make_span());
        let decoded: std::io::Result<T> =
            self.encoder.decode(&mut self.parts, self.encoded_length);
        InsertPartResult::Decoded(decoded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    /// Decoding with an `encoded_length` of `usize::MAX` must return an error.
    #[test]
    fn reed_solomon_decode_returns_error_with_large_encoded_length() {
        // One data shard plus one parity shard.
        let rs = ReedSolomon::new(1, 1).unwrap();
        let payload = b"my favorite data".to_vec();
        let (mut parts, _) = reed_solomon_encode(&rs, &payload);

        let outcome = reed_solomon_decode::<Vec<u8>>(&rs, &mut parts, usize::MAX);
        assert_matches!(outcome, Err(Error { .. }));
    }
}