#[cfg(test)]
mod tests;
use std::fmt;
/// Coefficients of the field's generator polynomial written as ASCII bits, lowest degree
/// first (byte `i` is the coefficient of `x^i`), i.e. `1 + x^2 + x^3 + x^4 + x^8`.
/// `Statics::generate_gf` reads the first eight entries to derive `alpha^8`.
const PP: &[u8; 9] = b"101110001";
/// Maximum number of bytes of each chunk that `Fec::encode` processes per pass
/// over the check blocks.
const STRIDE: usize = 8192;
/// Errors reported by [`Fec`] construction, encoding and decoding.
///
/// The type is cheap to clone and comparable, so callers and tests can write
/// `assert_eq!(result.unwrap_err(), Error::ZeroK)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// `k` (number of data chunks) was zero.
    ZeroK,
    /// `m` (total number of chunks) was zero.
    ZeroM,
    /// `m` exceeded 256.
    BigN,
    /// `k` was greater than `m`.
    KGtN,
    /// Fewer than `k` usable chunks were supplied to `decode`.
    NotEnoughChunks,
    /// Catch-all for failures that have no dedicated variant.
    Tbd,
}
/// Human-readable rendering: a fixed `Zfec error: ` prefix followed by a
/// per-variant explanation.
impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let reason = match self {
            Self::ZeroK => "'k' must be greater than 0",
            Self::ZeroM => "'m' must be greater than 0",
            Self::BigN => "'n' must be less than 257",
            Self::KGtN => "'k' must be less than 'n'",
            Self::NotEnoughChunks => "Not enough chunks were provided",
            Self::Tbd => "Unknown error",
        };
        write!(f, "Zfec error: {}", reason)
    }
}
// Marker impl: lets `Error` be used wherever a `std::error::Error` is expected
// (e.g. boxed as `Box<dyn std::error::Error>`); all trait methods keep their defaults.
impl std::error::Error for Error {}
type Gf = u8;
type Result<Fec> = std::result::Result<Fec, Error>;
/// One share of encoded data together with its position in the code word.
///
/// Chunks compare equal when both their payload and their index match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// Payload bytes of this share.
    pub data: Vec<u8>,
    /// Share number; `Fec::encode` assigns `0..k` to data chunks and `k..m` to check chunks.
    pub index: usize,
}
impl Chunk {
    /// Wraps `data` as the share numbered `index`.
    pub fn new(data: Vec<u8>, index: usize) -> Self {
        Self { data, index }
    }
}
impl From<Chunk> for (Vec<Gf>, usize) {
fn from(val: Chunk) -> Self {
(val.data, val.index)
}
}
/// Renders a chunk as `<index>, <debug list of payload bytes>`.
impl std::fmt::Display for Chunk {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { data, index } = self;
        write!(f, "{}, {:?}", index, data)
    }
}
/// Erasure codec that turns `k` data chunks into `m` total chunks and can
/// rebuild the data from any `k` of them (see `encode` / `decode`).
pub struct Fec {
// Number of data (primary) chunks.
k: usize,
// Total number of chunks produced by `encode` (data + check).
m: usize,
// Row-major `m x k` encoding matrix; `new` makes the top `k x k` part the identity,
// and row `r` holds the coefficients used to produce chunk `r`.
enc_matrix: Vec<Gf>,
// Precomputed Galois-field lookup tables.
statics: Statics,
}
impl Fec {
/// Creates a codec producing `m` chunks out of which any `k` suffice for decoding.
///
/// # Errors
///
/// * [`Error::ZeroK`] if `k == 0`
/// * [`Error::ZeroM`] if `m == 0`
/// * [`Error::BigN`] if `m > 256`
/// * [`Error::KGtN`] if `k > m`
pub fn new(k: usize, m: usize) -> Result<Fec> {
    if k == 0 {
        return Err(Error::ZeroK);
    }
    if m == 0 {
        return Err(Error::ZeroM);
    }
    if m > 256 {
        return Err(Error::BigN);
    }
    if k > m {
        return Err(Error::KGtN);
    }

    let statics = Statics::new();

    // Row 0 is the unit vector [1, 0, ..., 0]; every following row `r + 1` holds the
    // successive powers (alpha^r)^col looked up through the exponent table.
    let mut vdm: Vec<Gf> = vec![0; m * k];
    vdm[0] = 1;
    for (row, cells) in vdm.chunks_exact_mut(k).skip(1).enumerate() {
        for (col, cell) in cells.iter_mut().enumerate() {
            *cell = statics.gf_exp[Statics::modnn((row * col) as i32) as usize];
        }
    }

    // Invert the top k x k block in place, then multiply the remaining rows by that
    // inverse to obtain the check-block rows of the encoding matrix.
    statics._invert_vdm(&mut vdm, k);
    let mut enc_matrix: Vec<Gf> = vec![0; m * k];
    let (top, bottom) = vdm.split_at(k * k);
    statics._matmul(bottom, top, &mut enc_matrix[k * k..], m - k, k, k);

    // The first k rows form the identity: data chunks are emitted unchanged.
    for diagonal in enc_matrix.iter_mut().step_by(k + 1).take(k) {
        *diagonal = 1;
    }

    Ok(Fec {
        k,
        m,
        enc_matrix,
        statics,
    })
}
/// Splits `data` into `k` equally sized data chunks (zero-padding the tail) and
/// appends `m - k` check chunks computed from them.
///
/// Returns the `m` chunks, indexed `0..m`, together with the number of padding
/// bytes that were added; that count must be handed back to `decode`.
pub fn encode(&self, data: &[u8]) -> Result<(Vec<Chunk>, usize)> {
    let total = data.len();
    let chunk_size = self.chunk_size(total);

    // Cut the input into k blocks; whatever lies past the end of `data` becomes zeros.
    let primary: Vec<Vec<u8>> = (0..self.k)
        .map(|i| {
            let start = (i * chunk_size).min(total);
            let end = ((i + 1) * chunk_size).min(total);
            let mut block = Vec::with_capacity(chunk_size);
            block.extend_from_slice(&data[start..end]);
            block.resize(chunk_size, 0);
            block
        })
        .collect();
    let padding = (self.k * chunk_size).saturating_sub(total);

    // Each check block is the GF-linear combination of the data blocks given by its
    // row of the encoding matrix, accumulated in windows of at most STRIDE bytes.
    let mut parity: Vec<Vec<u8>> = vec![vec![0; chunk_size]; self.m - self.k];
    for offset in (0..chunk_size).step_by(STRIDE) {
        let window = STRIDE.min(chunk_size - offset);
        for (i, block) in parity.iter_mut().enumerate() {
            let coeffs = &self.enc_matrix[(self.k + i) * self.k..];
            for (source, &coeff) in primary.iter().zip(coeffs) {
                self.statics.addmul(
                    &mut block[offset..],
                    &source[offset..offset + window],
                    coeff,
                    window,
                );
            }
        }
    }

    let chunks = primary
        .into_iter()
        .chain(parity)
        .enumerate()
        .map(|(index, data)| Chunk { index, data })
        .collect();
    Ok((chunks, padding))
}
pub fn decode(&self, encoded_data: &Vec<Chunk>, padding: usize) -> Result<Vec<u8>> {
if encoded_data.len() < self.k {
return Err(Error::NotEnoughChunks);
}
let mut share_nums: Vec<usize> = vec![];
let mut chunks: Vec<Vec<u8>> = vec![vec![]; self.m];
for chunk in encoded_data {
let num = chunk.index;
share_nums.push(num);
chunks[num] = chunk.data.clone();
}
let sz = chunks[share_nums[0] as usize].len();
let mut ret_chunks = vec![vec![0; sz]; self.k];
let mut complete = true;
let mut missing = std::collections::VecDeque::new();
let mut replaced = vec![];
for i in 0..self.k {
if !share_nums.contains(&i) {
complete = false;
missing.push_back(i);
}
}
for i in self.k..self.m {
if chunks[i].len() != 0 {
match missing.pop_front() {
Some(index) => {
replaced.push(index);
share_nums.insert(index, i);
chunks[index] = chunks[i].to_vec();
}
None => {}
}
}
}
if complete {
let flat = Self::flatten(&mut chunks[..self.k].to_vec());
return Ok(flat[..flat.len() - padding].to_vec());
}
let mut m_dec = vec![0; self.k * self.k];
let mut outix = 0;
self.build_decode_matrix_into_space(&share_nums, self.k, &mut m_dec[..]);
for row in 0..self.k {
assert!((share_nums[row] >= self.k) || (share_nums[row] == row));
if share_nums[row] >= self.k {
for i in 0..sz {
ret_chunks[outix][i] = 0;
}
for col in 0..self.k {
self.statics.addmul(
&mut ret_chunks[outix][..],
&chunks[col][..],
m_dec[row * self.k + col],
sz,
);
}
outix += 1;
}
}
for i in 0..replaced.len() {
chunks[replaced[i]] = ret_chunks[i].to_vec();
}
let ret_vec = Self::flatten(&mut chunks[0..self.k].to_vec());
Ok(ret_vec[..ret_vec.len() - padding].to_vec())
}
/// Size of each chunk for an input of `data_len` bytes: `ceil(data_len / k)`.
///
/// Uses integer arithmetic so the result stays exact for every `usize`; a detour
/// through `f64` loses precision once `data_len` exceeds 2^53.
fn chunk_size(&self, data_len: usize) -> usize {
    data_len / self.k + usize::from(data_len % self.k != 0)
}
/// Concatenates all rows of `square` into one byte vector.
///
/// The bytes are moved out, so every inner vector is left empty afterwards.
fn flatten(square: &mut Vec<Vec<u8>>) -> Vec<u8> {
    square
        .iter_mut()
        .flat_map(|chunk| chunk.drain(..))
        .collect()
}
/// Fills `matrix` (row-major, `k x k`) with the decoding matrix for the shares
/// listed in `index` and inverts it in place.
///
/// Row `i` becomes the `i`-th unit vector when `index[i]` is a data chunk
/// (`< k`), otherwise a copy of row `index[i]` of the encoding matrix. Every row
/// is written in full, so `matrix` does not need to be zeroed by the caller.
fn build_decode_matrix_into_space(&self, index: &[usize], k: usize, matrix: &mut [Gf]) {
    for i in 0..k {
        let row = &mut matrix[i * k..(i + 1) * k];
        let share = index[i];
        if share < k {
            row.fill(0);
            row[i] = 1;
        } else {
            let start = share * self.k;
            row.copy_from_slice(&self.enc_matrix[start..start + k]);
        }
    }
    self.statics._invert_mat(matrix, k);
}
}
/// Lookup tables for the byte-sized Galois field, built once by `Statics::new`.
struct Statics {
// Exponent table: entry `i` is `alpha^i`; entries 255..510 repeat entries 0..255
// so that a sum of two logarithms can index it directly.
gf_exp: [Gf; 510],
// Multiplicative inverse of every element; entry 0 is set to 0.
inverse: [Gf; 256],
// Full multiplication table: `gf_mul_table[x][y]` is the product of `x` and `y`.
gf_mul_table: [[Gf; 256]; 256],
}
impl Statics {
/// Builds the exponent, inverse and multiplication tables.
///
/// The logarithm table is only needed while constructing the others and is
/// dropped afterwards.
pub fn new() -> Self {
    let mut gf_exp = [0; 510];
    let mut gf_log = [0; 256];
    let mut inverse = [0; 256];
    Self::generate_gf(&mut gf_exp, &mut gf_log, &mut inverse);

    let mut gf_mul_table = [[0; 256]; 256];
    Self::_init_mul_table(&mut gf_mul_table, &gf_exp, &gf_log);

    Self {
        gf_exp,
        inverse,
        gf_mul_table,
    }
}
/// Fills the exponent table (`gf_exp[i] = alpha^i`), its inverse mapping
/// (`gf_log[alpha^i] = i`) and the table of multiplicative inverses.
fn generate_gf(gf_exp: &mut [Gf; 510], gf_log: &mut [i32; 256], inverse: &mut [Gf; 256]) {
    // The top bit of a byte; when set, the next doubling overflows into alpha^8.
    const HIGH_BIT: Gf = 1 << 7;

    // alpha^0 .. alpha^7 are the single-bit values. alpha^8 is the sum of the
    // low-order terms of the generator polynomial PP.
    let mut alpha8: Gf = 0;
    for i in 0..8 {
        let bit: Gf = 1 << i;
        gf_exp[i] = bit;
        gf_log[bit as usize] = i as i32;
        if PP[i] == b'1' {
            alpha8 ^= bit;
        }
    }
    gf_exp[8] = alpha8;
    gf_log[alpha8 as usize] = 8;

    // Every further power is the previous one shifted left by one bit, folding an
    // overflowing top bit back in as alpha^8.
    for i in 9..255 {
        let prev = gf_exp[i - 1];
        let next = if prev >= HIGH_BIT {
            alpha8 ^ ((prev ^ HIGH_BIT) << 1)
        } else {
            prev << 1
        };
        gf_exp[i] = next;
        gf_log[next as usize] = i as i32;
    }

    // log(0) is undefined; 255 is stored as a placeholder.
    gf_log[0] = 255;

    // Mirror the first 255 entries so that log(a) + log(b) can index the table
    // without an extra reduction.
    let (base, mirror) = gf_exp.split_at_mut(255);
    mirror.copy_from_slice(base);

    // 0 has no inverse (0 is stored); 1 is its own inverse; otherwise
    // a^-1 = alpha^(255 - log a).
    inverse[0] = 0;
    inverse[1] = 1;
    for (slot, &log) in inverse.iter_mut().zip(gf_log.iter()).skip(2) {
        *slot = gf_exp[255 - log as usize];
    }
}
/// Fills `gf_mul_table[i][j]` with the product of `i` and `j`, computed as
/// `exp(log i + log j)`; any product involving 0 is 0.
fn _init_mul_table(
    gf_mul_table: &mut [[Gf; 256]; 256],
    gf_exp: &[Gf; 510],
    gf_log: &[i32; 256],
) {
    for (i, row) in gf_mul_table.iter_mut().enumerate() {
        for (j, cell) in row.iter_mut().enumerate() {
            *cell = if i == 0 || j == 0 {
                0
            } else {
                gf_exp[Self::modnn(gf_log[i] + gf_log[j]) as usize]
            };
        }
    }
}
/// Reduces `x` modulo 255.
///
/// Values below 255 (including negative ones) are not reduced; they are simply
/// truncated to a byte.
fn modnn(x: i32) -> Gf {
    if x < 255 {
        x as Gf
    } else {
        (x % 255) as Gf
    }
}
/// Computes `dst[i] ^= c * src[i]` over the first `sz` bytes, where `*` is the
/// field multiplication. A zero multiplier leaves `dst` untouched.
pub fn addmul(&self, dst: &mut [Gf], src: &[Gf], c: Gf, sz: usize) {
    if c == 0 {
        return;
    }
    self._addmul1(dst, src, c, sz);
}
/// Worker for `addmul`: XORs `c * src[i]` into `dst[i]` for `i < sz`.
///
/// Does nothing when `src` is empty; otherwise both slices must hold at least
/// `sz` bytes or the call panics.
fn _addmul1(&self, dst: &mut [Gf], src: &[Gf], c: Gf, sz: usize) {
    if src.is_empty() {
        return;
    }
    // Row `c` of the multiplication table maps each byte to its product with `c`.
    let products = &self.gf_mul_table[c as usize];
    for (out, &byte) in dst[..sz].iter_mut().zip(&src[..sz]) {
        *out ^= products[byte as usize];
    }
}
/// Matrix product over the field: `c = a * b`, where `a` is `n x k`, `b` is
/// `k x m` and `c` is `n x m`, all stored row-major.
fn _matmul(&self, a: &[Gf], b: &[Gf], c: &mut [Gf], n: usize, k: usize, m: usize) {
    for row in 0..n {
        for col in 0..m {
            // Dot product of row `row` of `a` with column `col` of `b`;
            // addition in the field is XOR.
            c[row * m + col] = (0..k).fold(0, |acc: Gf, i| {
                acc ^ self.gf_mul(a[row * k + i], b[col + i * m])
            });
        }
    }
}
/// Inverts, in place, the `k x k` matrix stored row-major at the start of `src`.
///
/// The matrix is expected to be of Vandermonde form: only its second column
/// (the generating elements) is read. A `1 x 1` matrix is left untouched.
fn _invert_vdm(&self, src: &mut Vec<Gf>, k: usize) {
    if k == 1 {
        return;
    }

    // Generating elements p_i: column 1 of each row.
    let roots: Vec<Gf> = (0..k).map(|i| src[1 + i * k]).collect();

    // Coefficients of P(x) = prod (x - p_i), built by multiplying in one factor at
    // a time; the leading coefficient (always 1) is implicit.
    let mut coeffs: Vec<Gf> = vec![0; k];
    coeffs[k - 1] = roots[0];
    for i in 1..k {
        let root = roots[i];
        for j in (k - i)..(k - 1) {
            coeffs[j] ^= self.gf_mul(root, coeffs[j + 1]);
        }
        coeffs[k - 1] ^= root;
    }

    // For every root, synthetic division of P by (x - root) yields the quotient
    // coefficients and the scaling term `t`; the scaled quotient forms one
    // column of the inverse.
    let mut quotient: Vec<Gf> = vec![0; k];
    for (row, &xx) in roots.iter().enumerate() {
        let mut t: Gf = 1;
        quotient[k - 1] = 1;
        for i in (1..k).rev() {
            quotient[i - 1] = coeffs[i] ^ self.gf_mul(xx, quotient[i]);
            t = self.gf_mul(xx, t) ^ quotient[i - 1];
        }
        let scale = self.inverse[t as usize];
        for col in 0..k {
            src[col * k + row] = self.gf_mul(scale, quotient[col]);
        }
    }
}
/// Product of `x` and `y` in the field, via table lookup.
fn gf_mul(&self, x: Gf, y: Gf) -> Gf {
    let row = &self.gf_mul_table[usize::from(x)];
    row[usize::from(y)]
}
fn _invert_mat(&self, src: &mut [Gf], k: usize) {
let mut c: Gf;
let (mut irow, mut icol) = (0, 0);
let mut indxc = vec![0; k];
let mut indxr = vec![0; k];
let mut ipiv = vec![0; k];
let mut id_row = vec![0; k];
for i in 0..k {
ipiv[i] = 0;
}
for col in 0..k {
let mut piv_found: bool = false;
if ipiv[col] != 1 && src[col * k + col] != 0 {
irow = col;
icol = col;
}
for row in 0..k {
if ipiv[row] != 1 {
for ix in 0..k {
if ipiv[ix] == 0 {
if src[row * k + ix] != 0 {
irow = row;
icol = ix;
piv_found = true;
}
} else {
assert!(ipiv[ix] <= 1);
}
if piv_found {
break;
}
}
}
if piv_found {
break;
}
}
ipiv[icol] += 1;
if irow != icol {
for ix in 0..k {
let tmp = src[irow * k + ix];
src[irow * k + ix] = src[icol * k + ix];
src[icol * k + ix] = tmp;
}
}
indxr[col] = irow;
indxc[col] = icol;
let pivot_row = &mut src[icol * k..(icol + 1) * k];
c = pivot_row[icol];
assert!(c != 0);
if c != 1 {
c = self.inverse[c as usize];
pivot_row[icol] = 1;
for ix in 0..k {
pivot_row[ix] = self.gf_mul(c, pivot_row[ix]);
}
}
id_row[icol] = 1;
if pivot_row != id_row {
let mut pivot_clone = vec![0; pivot_row.len()];
for val in pivot_row.iter() {
pivot_clone.push(*val);
}
for ix in 0..k {
let p = &mut src[ix * k..(ix + 1) * k];
if ix != icol {
c = p[icol];
p[icol] = 0;
self.addmul(p, &pivot_clone[k..], c, k);
}
}
}
id_row[icol] = 0;
}
for col in (1..=k).rev() {
if indxr[col - 1] != indxc[col - 1] {
for row in 0..k {
let tmp = src[row * k + indxr[col - 1]];
src[row * k + indxr[col - 1]] = src[row * k + indxc[col - 1]];
src[row * k + indxc[col - 1]] = tmp;
}
}
}
}
}