extern crate alloc;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use smallvec::SmallVec;
use crate::errors::Error;
use crate::errors::SBSError;
use crate::matrix::Matrix;
use lru::LruCache;
#[cfg(feature = "std")]
use parking_lot::Mutex;
#[cfg(not(feature = "std"))]
use spin::Mutex;
use super::Field;
use super::ReconstructShard;
/// Capacity of the LRU cache of inverted data-decode matrices that each
/// `ReedSolomon` instance creates in `new`.
const DATA_DECODE_MATRIX_CACHE_CAPACITY: usize = 254;
/// Incremental encoder that feeds the data shards to a `ReedSolomon` codec one
/// at a time, in index order, updating the parity shards after each one.
///
/// Once every data shard has been fed in, `parity_ready()` becomes `true`.
#[derive(PartialEq, Debug)]
pub struct ShardByShard<'a, F: 'a + Field> {
/// Codec whose parity rows are applied to each data shard.
codec: &'a ReedSolomon<F>,
/// Index of the next data shard to encode; equals `codec.data_shard_count`
/// once the parity is complete.
cur_input: usize,
}
impl<'a, F: 'a + Field> ShardByShard<'a, F> {
    /// Creates an encoder over `codec`, positioned at data shard 0.
    pub fn new(codec: &'a ReedSolomon<F>) -> ShardByShard<'a, F> {
        ShardByShard {
            codec,
            cur_input: 0,
        }
    }

    /// Returns `true` once every data shard has been encoded, i.e. the parity
    /// shards are complete and further `encode`/`encode_sep` calls are rejected.
    pub fn parity_ready(&self) -> bool {
        self.cur_input == self.codec.data_shard_count
    }

    /// Rewinds to data shard 0 so a new set of shards can be encoded.
    ///
    /// # Errors
    ///
    /// Returns [`SBSError::LeftoverShards`] if encoding is part-way through
    /// (some, but not all, data shards consumed). Use
    /// [`reset_force`](Self::reset_force) to discard that partial state.
    pub fn reset(&mut self) -> Result<(), SBSError> {
        if self.cur_input > 0 && !self.parity_ready() {
            return Err(SBSError::LeftoverShards);
        }
        self.cur_input = 0;
        Ok(())
    }

    /// Rewinds to data shard 0 unconditionally, discarding any partial progress.
    pub fn reset_force(&mut self) {
        self.cur_input = 0;
    }

    /// Index of the data shard that the next `encode`/`encode_sep` call will consume.
    pub fn cur_input_index(&self) -> usize {
        self.cur_input
    }

    /// Moves on to the next data shard; always returns `Ok(())` so callers can
    /// use it as their tail expression.
    fn return_ok_and_incre_cur_input(&mut self) -> Result<(), SBSError> {
        self.cur_input += 1;
        Ok(())
    }

    /// Validates a combined data+parity shard array before an `encode` step.
    ///
    /// Fails with `TooManyCalls` when the parity is already complete. Any
    /// shard-count or shard-size error from the codec checks is wrapped in
    /// `SBSError::RSError`.
    fn sbs_encode_checks<U: AsRef<[F::Elem]> + AsMut<[F::Elem]>>(
        &self,
        slices: &mut [U],
    ) -> Result<(), SBSError> {
        // The check macros report failure by returning `Err` from the
        // enclosing function; the closure captures that `Error` so it can be
        // wrapped instead of escaping directly.
        let internal_checks = |codec: &ReedSolomon<F>, data: &mut [U]| {
            check_piece_count!(all => codec, data);
            check_slices!(multi => data);
            Ok(())
        };
        if self.parity_ready() {
            return Err(SBSError::TooManyCalls);
        }
        internal_checks(self.codec, slices).map_err(SBSError::RSError)
    }

    /// Validates separate data and parity arrays before an `encode_sep` step.
    ///
    /// The errors are the same as for `sbs_encode_checks`.
    fn sbs_encode_sep_checks<T: AsRef<[F::Elem]>, U: AsRef<[F::Elem]> + AsMut<[F::Elem]>>(
        &self,
        data: &[T],
        parity: &mut [U],
    ) -> Result<(), SBSError> {
        let internal_checks = |codec: &ReedSolomon<F>, data: &[T], parity: &mut [U]| {
            check_piece_count!(data => codec, data);
            check_piece_count!(parity => codec, parity);
            check_slices!(multi => data, multi => parity);
            Ok(())
        };
        if self.parity_ready() {
            return Err(SBSError::TooManyCalls);
        }
        internal_checks(self.codec, data, parity).map_err(SBSError::RSError)
    }

    /// Encodes the next data shard into the parity shards.
    ///
    /// `shards` holds every shard, with the data shards first and then the
    /// parity shards. Only data shard [`cur_input_index`](Self::cur_input_index)
    /// is read; the parity shards are updated in place.
    ///
    /// # Errors
    ///
    /// * [`SBSError::TooManyCalls`] if [`parity_ready`](Self::parity_ready) is
    ///   already `true`.
    /// * [`SBSError::RSError`] if the shard count or shard sizes are invalid.
    ///
    /// The position only advances on success.
    pub fn encode<T, U>(&mut self, mut shards: T) -> Result<(), SBSError>
    where
        T: AsRef<[U]> + AsMut<[U]>,
        U: AsRef<[F::Elem]> + AsMut<[F::Elem]>,
    {
        let shards = shards.as_mut();
        self.sbs_encode_checks(shards)?;
        // `cur_input < data_shard_count` (parity not ready) and the shard
        // layout was just validated, so the codec cannot reject this call.
        self.codec
            .encode_single(self.cur_input, shards)
            .expect("shard layout validated by sbs_encode_checks");
        self.return_ok_and_incre_cur_input()
    }

    /// Encodes the next data shard, taking data and parity as separate arrays.
    ///
    /// `data` must contain all data shards, but only
    /// `data[cur_input_index()]` is read. The `parity` shards are updated in place.
    ///
    /// # Errors
    ///
    /// The errors are the same as for [`encode`](Self::encode).
    pub fn encode_sep<T: AsRef<[F::Elem]>, U: AsRef<[F::Elem]> + AsMut<[F::Elem]>>(
        &mut self,
        data: &[T],
        parity: &mut [U],
    ) -> Result<(), SBSError> {
        self.sbs_encode_sep_checks(data, parity)?;
        // The checks above guarantee `data.len() == data_shard_count > cur_input`.
        self.codec
            .encode_single_sep(self.cur_input, data[self.cur_input].as_ref(), parity)
            .expect("shard layout validated by sbs_encode_sep_checks");
        self.return_ok_and_incre_cur_input()
    }
}
/// Reed-Solomon erasure codec over the field `F`, with a fixed number of data
/// shards and parity shards chosen at construction time.
///
/// `PartialEq` and `Clone` consider only the shard counts. The decode-matrix
/// cache is neither compared nor copied.
#[derive(Debug)]
pub struct ReedSolomon<F: Field> {
/// Number of data shards; always > 0 (checked in `new`).
data_shard_count: usize,
/// Number of parity shards; always > 0 (checked in `new`).
parity_shard_count: usize,
/// `data_shard_count + parity_shard_count`.
total_shard_count: usize,
/// Encoding matrix with one row per shard. The rows from `data_shard_count`
/// onward are the parity rows used for encoding.
matrix: Matrix<F>,
/// LRU cache of inverted decode matrices, keyed by the indices of the
/// missing shards. `Mutex` is `parking_lot`'s with the `std` feature and
/// `spin`'s without it.
data_decode_matrix_cache: Mutex<LruCache<Vec<usize>, Arc<Matrix<F>>>>,
}
impl<F: Field> Clone for ReedSolomon<F> {
    /// Builds a fresh codec with the same data and parity shard counts.
    ///
    /// The decode-matrix cache is not copied; the clone starts with an empty one.
    fn clone(&self) -> Self {
        let data_shards = self.data_shard_count;
        let parity_shards = self.parity_shard_count;
        // These counts came from a successful `new`, so rebuilding cannot fail.
        Self::new(data_shards, parity_shards)
            .expect("basic checks already passed as precondition of existence of self")
    }
}
impl<F: Field> PartialEq for ReedSolomon<F> {
    /// Two codecs are equal when their data and parity shard counts match.
    ///
    /// The encoding matrix and the decode cache are not compared.
    fn eq(&self, rhs: &ReedSolomon<F>) -> bool {
        let lhs_shape = (self.data_shard_count, self.parity_shard_count);
        let rhs_shape = (rhs.data_shard_count, rhs.parity_shard_count);
        lhs_shape == rhs_shape
    }
}
impl<F: Field> ReedSolomon<F> {
/// Borrows the parity rows of the encoding matrix, i.e. rows
/// `data_shard_count..total_shard_count`, in order.
fn get_parity_rows(&self) -> SmallVec<[&[F::Elem]; 32]> {
    (self.data_shard_count..self.total_shard_count)
        .map(|row| self.matrix.get_row(row))
        .collect()
}
/// Builds the encoding matrix: a `total_shards`-row Vandermonde matrix
/// right-multiplied by the inverse of its top `data_shards` x `data_shards`
/// square. This makes that top square the identity.
///
/// Panics if the top square is not invertible. This assumes
/// `Matrix::vandermonde` produces distinct rows; confirm in `matrix.rs`.
fn build_matrix(data_shards: usize, total_shards: usize) -> Matrix<F> {
    let vandermonde = Matrix::vandermonde(total_shards, data_shards);
    let top_inverse = vandermonde
        .sub_matrix(0, 0, data_shards, data_shards)
        .invert()
        .unwrap();
    vandermonde.multiply(&top_inverse)
}
pub fn new(data_shards: usize, parity_shards: usize) -> Result<ReedSolomon<F>, Error> {
if data_shards == 0 {
return Err(Error::TooFewDataShards);
}
if parity_shards == 0 {
return Err(Error::TooFewParityShards);
}
if data_shards + parity_shards > F::ORDER {
return Err(Error::TooManyShards);
}
let total_shards = data_shards + parity_shards;
let matrix = Self::build_matrix(data_shards, total_shards);
Ok(ReedSolomon {
data_shard_count: data_shards,
parity_shard_count: parity_shards,
total_shard_count: total_shards,
matrix,
data_decode_matrix_cache: Mutex::new(LruCache::new(DATA_DECODE_MATRIX_CACHE_CAPACITY)),
})
}
pub fn data_shard_count(&self) -> usize {
self.data_shard_count
}
pub fn parity_shard_count(&self) -> usize {
self.parity_shard_count
}
pub fn total_shard_count(&self) -> usize {
self.total_shard_count
}
/// Multiplies `matrix_rows` by the first `data_shard_count` `inputs`, in the
/// field `F`, writing one result per entry of `outputs`.
///
/// Previous contents of `outputs` are overwritten, because column 0 uses
/// `mul_slice` in `code_single_slice`.
fn code_some_slices<T: AsRef<[F::Elem]>, U: AsMut<[F::Elem]>>(
    &self,
    matrix_rows: &[&[F::Elem]],
    inputs: &[T],
    outputs: &mut [U],
) {
    let input_count = self.data_shard_count;
    for column in 0..input_count {
        let input: &[F::Elem] = inputs[column].as_ref();
        self.code_single_slice(matrix_rows, column, input, outputs);
    }
}
/// Applies column `i_input` of `matrix_rows` to one `input` slice, for every
/// output row.
///
/// For `i_input == 0` each output is set to `matrix_rows[r][0] * input`.
/// Otherwise `matrix_rows[r][i_input] * input` is added to it. The number of
/// rows used is `outputs.len()`.
fn code_single_slice<U: AsMut<[F::Elem]>>(
    &self,
    matrix_rows: &[&[F::Elem]],
    i_input: usize,
    input: &[F::Elem],
    outputs: &mut [U],
) {
    let overwrite = i_input == 0;
    for (row, output) in outputs.iter_mut().enumerate() {
        let coefficient = matrix_rows[row][i_input];
        if overwrite {
            F::mul_slice(coefficient, input, output.as_mut());
        } else {
            F::mul_slice_add(coefficient, input, output.as_mut());
        }
    }
}
/// Recomputes the shards implied by `inputs` and `matrix_rows` into `buffer`,
/// then reports whether every one equals the corresponding entry of `to_check`.
///
/// `buffer` is scratch space and is overwritten. It must not hold more entries
/// than `to_check`; indexing `to_check` panics if it does.
fn check_some_slices_with_buffer<T, U>(
    &self,
    matrix_rows: &[&[F::Elem]],
    inputs: &[T],
    to_check: &[T],
    buffer: &mut [U],
) -> bool
where
    T: AsRef<[F::Elem]>,
    U: AsRef<[F::Elem]> + AsMut<[F::Elem]>,
{
    self.code_some_slices(matrix_rows, inputs, buffer);
    // Stops at the first mismatch.
    buffer
        .iter()
        .enumerate()
        .all(|(i, expected_parity_shard)| {
            expected_parity_shard.as_ref() == to_check[i].as_ref()
        })
}
/// Adds the contribution of data shard `i_data` to the parity shards.
///
/// `shards` holds the data shards followed by the parity shards. Only
/// `shards[i_data]` is read, and the parity part is updated in place.
///
/// # Errors
///
/// Returns an error if `i_data` is not a data-shard index, if the shard count
/// is wrong, or if shard sizes are empty or inconsistent.
pub fn encode_single<T, U>(&self, i_data: usize, mut shards: T) -> Result<(), Error>
where
    T: AsRef<[U]> + AsMut<[U]>,
    U: AsRef<[F::Elem]> + AsMut<[F::Elem]>,
{
    let all_shards = shards.as_mut();

    check_slice_index!(data => self, i_data);
    check_piece_count!(all => self, all_shards);
    check_slices!(multi => all_shards);

    let (data_part, parity_part) = all_shards.split_at_mut(self.data_shard_count);
    self.encode_single_sep(i_data, data_part[i_data].as_ref(), parity_part)
}
/// Adds the contribution of one data shard (`single_data`, at index `i_data`)
/// to the `parity` shards in place.
///
/// When `i_data == 0` the parity shards are overwritten rather than added to.
///
/// # Errors
///
/// Returns an error if `i_data` is not a data-shard index, if the parity
/// count is wrong, or if slice sizes are empty or inconsistent.
pub fn encode_single_sep<U: AsRef<[F::Elem]> + AsMut<[F::Elem]>>(
    &self,
    i_data: usize,
    single_data: &[F::Elem],
    parity: &mut [U],
) -> Result<(), Error> {
    check_slice_index!(data => self, i_data);
    check_piece_count!(parity => self, parity);
    check_slices!(multi => parity, single => single_data);

    self.code_single_slice(&self.get_parity_rows(), i_data, single_data, parity);
    Ok(())
}
/// Computes all parity shards from the data shards.
///
/// `shards` holds the data shards followed by the parity shards. The parity
/// part is overwritten.
///
/// # Errors
///
/// Returns an error if the shard count is wrong or shard sizes are empty or
/// inconsistent.
pub fn encode<T, U>(&self, mut shards: T) -> Result<(), Error>
where
    T: AsRef<[U]> + AsMut<[U]>,
    U: AsRef<[F::Elem]> + AsMut<[F::Elem]>,
{
    let all_shards: &mut [U] = shards.as_mut();
    check_piece_count!(all => self, all_shards);
    check_slices!(multi => all_shards);

    let data_count = self.data_shard_count;
    let (data, parity) = all_shards.split_at_mut(data_count);
    self.encode_sep(&*data, parity)
}
/// Computes all `parity` shards from the separately supplied `data` shards,
/// overwriting `parity`.
///
/// # Errors
///
/// Returns an error if the data or parity counts are wrong, or if shard sizes
/// are empty or inconsistent.
pub fn encode_sep<T: AsRef<[F::Elem]>, U: AsRef<[F::Elem]> + AsMut<[F::Elem]>>(
    &self,
    data: &[T],
    parity: &mut [U],
) -> Result<(), Error> {
    check_piece_count!(data => self, data);
    check_piece_count!(parity => self, parity);
    check_slices!(multi => data, multi => parity);

    // Each parity shard is a linear combination of the data shards, weighted
    // by its row of the encoding matrix.
    self.code_some_slices(&self.get_parity_rows(), data, parity);
    Ok(())
}
/// Returns `Ok(true)` if the parity shards in `slices` match the parity
/// recomputed from its data shards.
///
/// Allocates one scratch vector per parity shard; use
/// [`verify_with_buffer`](Self::verify_with_buffer) to supply your own.
///
/// # Errors
///
/// Returns an error if the shard count is wrong or shard sizes are empty or
/// inconsistent.
pub fn verify<T: AsRef<[F::Elem]>>(&self, slices: &[T]) -> Result<bool, Error> {
    check_piece_count!(all => self, slices);
    check_slices!(multi => slices);

    let shard_len = slices[0].as_ref().len();
    let mut scratch: SmallVec<[Vec<F::Elem>; 32]> = (0..self.parity_shard_count)
        .map(|_| vec![F::zero(); shard_len])
        .collect();

    self.verify_with_buffer(slices, &mut scratch)
}
/// Like [`verify`](Self::verify), but uses the caller-provided `buffer` as
/// scratch space.
///
/// `buffer` needs one entry per parity shard, each as long as the shards. Its
/// contents are overwritten.
///
/// # Errors
///
/// Returns an error if the shard or buffer count is wrong, or if slice sizes
/// are empty or inconsistent.
pub fn verify_with_buffer<T, U>(&self, slices: &[T], buffer: &mut [U]) -> Result<bool, Error>
where
    T: AsRef<[F::Elem]>,
    U: AsRef<[F::Elem]> + AsMut<[F::Elem]>,
{
    check_piece_count!(all => self, slices);
    check_piece_count!(parity_buf => self, buffer);
    check_slices!(multi => slices, multi => buffer);

    let (data, to_check) = slices.split_at(self.data_shard_count);
    let parity_rows = self.get_parity_rows();
    let all_match = self.check_some_slices_with_buffer(&parity_rows, data, to_check, buffer);
    Ok(all_match)
}
/// Rebuilds every missing shard, data and parity, in place.
///
/// Missing shards are initialised to the common shard length via
/// `ReconstructShard::get_or_initialize`. At least `data_shard_count` shards
/// must be present.
pub fn reconstruct<T: ReconstructShard<F>>(&self, slices: &mut [T]) -> Result<(), Error> {
    let data_only = false;
    self.reconstruct_internal(slices, data_only)
}
/// Rebuilds only the missing data shards, in place.
///
/// Missing parity shards are left untouched and are not initialised.
pub fn reconstruct_data<T: ReconstructShard<F>>(&self, slices: &mut [T]) -> Result<(), Error> {
    let data_only = true;
    self.reconstruct_internal(slices, data_only)
}
/// Returns the inverse of the square sub-matrix of `self.matrix` formed by the
/// rows listed in `valid_indices`, caching it under `invalid_indices`.
///
/// `valid_indices` is expected to hold exactly `data_shard_count` row indices.
/// Keying on the missing-shard list is sufficient because
/// `reconstruct_internal` always picks the first `data_shard_count` present
/// shards, so the missing set determines the rows.
///
/// The cache lock is not held while inverting. Concurrent callers may
/// therefore both compute and insert the same entry.
fn get_data_decode_matrix(
    &self,
    valid_indices: &[usize],
    invalid_indices: &[usize],
) -> Arc<Matrix<F>> {
    let cached = {
        let mut cache = self.data_decode_matrix_cache.lock();
        cache.get(invalid_indices).cloned()
    };
    if let Some(hit) = cached {
        return hit;
    }

    let k = self.data_shard_count;
    let mut sub_matrix = Matrix::new(k, k);
    for (dst_row, &src_row) in valid_indices.iter().enumerate() {
        for col in 0..k {
            sub_matrix.set(dst_row, col, self.matrix.get(src_row, col));
        }
    }
    let inverse = Arc::new(sub_matrix.invert().unwrap());

    self.data_decode_matrix_cache
        .lock()
        .put(Vec::from(invalid_indices), Arc::clone(&inverse));
    inverse
}
/// Shared implementation of [`reconstruct`](Self::reconstruct) and
/// [`reconstruct_data`](Self::reconstruct_data).
///
/// Missing data shards are rebuilt using the inverse of the encoding-matrix
/// rows of the first `data_shard_count` present shards. Unless `data_only` is
/// set, missing parity shards are then recomputed from the complete set of
/// data shards. With `data_only`, missing parity shards are only noted as
/// missing and are never initialised.
///
/// # Errors
///
/// * Any error from `check_piece_count!(all => ..)` when the shard count is wrong.
/// * [`Error::EmptyShard`] if a present shard has length 0.
/// * [`Error::IncorrectShardSize`] if present shards differ in length.
/// * [`Error::TooFewShardsPresent`] if fewer than `data_shard_count` shards are present.
/// * Any error returned by `ReconstructShard::get_or_initialize`.
fn reconstruct_internal<T: ReconstructShard<F>>(
    &self,
    shards: &mut [T],
    data_only: bool,
) -> Result<(), Error> {
    check_piece_count!(all => self, shards);

    let data_shard_count = self.data_shard_count;

    // Count the present shards and make sure they all share one non-zero length.
    let mut number_present = 0;
    let mut shard_len = None;
    for shard in shards.iter_mut() {
        if let Some(len) = shard.len() {
            if len == 0 {
                return Err(Error::EmptyShard);
            }
            number_present += 1;
            if let Some(old_len) = shard_len {
                if len != old_len {
                    return Err(Error::IncorrectShardSize);
                }
            }
            shard_len = Some(len);
        }
    }

    if number_present == self.total_shard_count {
        // Nothing is missing.
        return Ok(());
    }
    if number_present < data_shard_count {
        return Err(Error::TooFewShardsPresent);
    }

    let shard_len = shard_len.expect("at least one shard present; qed");

    // `sub_shards`/`valid_indices`: the first `data_shard_count` present shards,
    // which are the decode inputs and pick the decode-matrix rows.
    // `invalid_indices`: every missing shard, in ascending order; it is also
    // the decode-matrix cache key.
    let mut sub_shards: SmallVec<[&[F::Elem]; 32]> = SmallVec::with_capacity(data_shard_count);
    let mut missing_data_slices: SmallVec<[&mut [F::Elem]; 32]> =
        SmallVec::with_capacity(self.parity_shard_count);
    let mut missing_parity_slices: SmallVec<[&mut [F::Elem]; 32]> =
        SmallVec::with_capacity(self.parity_shard_count);
    let mut valid_indices: SmallVec<[usize; 32]> = SmallVec::with_capacity(data_shard_count);
    let mut invalid_indices: SmallVec<[usize; 32]> = SmallVec::with_capacity(data_shard_count);

    for (matrix_row, shard) in shards.iter_mut().enumerate() {
        // Missing shards are initialised so they can be rebuilt in place,
        // except parity shards in `data_only` mode, which are left alone.
        let shard_data = if matrix_row >= data_shard_count && data_only {
            shard.get().ok_or(None)
        } else {
            shard.get_or_initialize(shard_len).map_err(Some)
        };
        match shard_data {
            Ok(shard) => {
                // Only `data_shard_count` inputs are needed; extra present
                // shards are ignored.
                if sub_shards.len() < data_shard_count {
                    sub_shards.push(shard);
                    valid_indices.push(matrix_row);
                }
            }
            Err(None) => {
                // Missing parity shard that is not going to be rebuilt.
                invalid_indices.push(matrix_row);
            }
            Err(Some(x)) => {
                // Missing shard that was just initialised and will be rebuilt.
                let shard = x?;
                if matrix_row < data_shard_count {
                    missing_data_slices.push(shard);
                } else {
                    missing_parity_slices.push(shard);
                }
                invalid_indices.push(matrix_row);
            }
        }
    }

    // Rebuild the missing data shards. When none is missing (only parity was
    // lost), skip inverting and caching a decode matrix that would go unused.
    if !missing_data_slices.is_empty() {
        let data_decode_matrix = self.get_data_decode_matrix(&valid_indices, &invalid_indices);
        let mut matrix_rows: SmallVec<[&[F::Elem]; 32]> =
            SmallVec::with_capacity(self.parity_shard_count);
        for i_slice in invalid_indices
            .iter()
            .cloned()
            .take_while(|i| i < &data_shard_count)
        {
            matrix_rows.push(data_decode_matrix.get_row(i_slice));
        }
        self.code_some_slices(&matrix_rows, &sub_shards, &mut missing_data_slices);
    }

    if data_only {
        return Ok(());
    }

    // Every data shard is now intact, so the missing parity shards can be
    // recomputed from their rows of the encoding matrix.
    let mut matrix_rows: SmallVec<[&[F::Elem]; 32]> =
        SmallVec::with_capacity(self.parity_shard_count);
    let parity_rows = self.get_parity_rows();
    for i_slice in invalid_indices
        .iter()
        .cloned()
        .skip_while(|i| i < &data_shard_count)
    {
        matrix_rows.push(parity_rows[i_slice - data_shard_count]);
    }

    // Merge the surviving data shards (`sub_shards`) and the rebuilt ones
    // (`missing_data_slices`) back into index order 0..data_shard_count.
    // Both lists are in ascending index order.
    let mut i_old_data_slice = 0;
    let mut i_new_data_slice = 0;
    let mut all_data_slices: SmallVec<[&[F::Elem]; 32]> =
        SmallVec::with_capacity(data_shard_count);
    let mut next_maybe_good = 0;
    let mut push_good_up_to = move |data_slices: &mut SmallVec<_>, up_to| {
        // Push the surviving shards between the previous gap and `up_to`;
        // this is a no-op when `next_maybe_good == up_to`.
        for _ in next_maybe_good..up_to {
            data_slices.push(sub_shards[i_old_data_slice]);
            i_old_data_slice += 1;
        }
        next_maybe_good = up_to + 1;
    };
    for i_slice in invalid_indices
        .iter()
        .cloned()
        .take_while(|i| i < &data_shard_count)
    {
        push_good_up_to(&mut all_data_slices, i_slice);
        all_data_slices.push(missing_data_slices[i_new_data_slice]);
        i_new_data_slice += 1;
    }
    push_good_up_to(&mut all_data_slices, data_shard_count);

    self.code_some_slices(&matrix_rows, &all_data_slices, &mut missing_parity_slices);
    Ok(())
}
}