use super::consts::BOUNDARY_MARKER;
use crate::RLNCError;
use rand::Rng;
#[cfg(not(feature = "parallel"))]
use crate::common::simd::{gf256_inplace_add_vectors, gf256_mul_vec_by_scalar};
#[cfg(feature = "parallel")]
use crate::common::gf256::Gf256;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// RLNC encoder that owns the source bytes, laid out as `piece_count` consecutive
/// pieces of `piece_byte_len` bytes each.
#[derive(Debug, Clone)]
pub struct Encoder {
    /// Source bytes, exactly `piece_count * piece_byte_len` long once constructed.
    data: Vec<u8>,
    /// Number of pieces `data` is split into; never zero after construction.
    piece_count: usize,
    /// Length in bytes of every piece (and of each coded payload).
    piece_byte_len: usize,
}
impl Encoder {
    /// Returns the number of pieces the data is split into.
    pub fn get_piece_count(&self) -> usize {
        self.piece_count
    }

    /// Returns the byte length of each piece, which is also the length of a coded payload.
    pub fn get_piece_byte_len(&self) -> usize {
        self.piece_byte_len
    }

    /// Returns the byte length of a full coded piece: the coding vector
    /// (`piece_count` bytes) followed by the coded payload (`piece_byte_len` bytes).
    pub fn get_full_coded_piece_byte_len(&self) -> usize {
        self.get_piece_count() + self.get_piece_byte_len()
    }

    /// Builds an encoder over `data` exactly as given, without appending a boundary
    /// marker or any padding.
    ///
    /// # Errors
    ///
    /// - [`RLNCError::DataLengthZero`] if `data` is empty.
    /// - [`RLNCError::PieceCountZero`] if `piece_count` is zero.
    /// - [`RLNCError::DataLengthMismatch`] if `data.len()` is not an exact multiple of
    ///   `piece_count` (this includes `piece_count > data.len()`).
    pub(crate) fn without_padding(data: Vec<u8>, piece_count: usize) -> Result<Encoder, RLNCError> {
        if data.is_empty() {
            return Err(RLNCError::DataLengthZero);
        }
        if piece_count == 0 {
            return Err(RLNCError::PieceCountZero);
        }
        // A non-zero remainder means the data cannot be split into equal pieces.
        // Because `data` is non-empty, this also rejects the `piece_byte_len == 0` case.
        if data.len() % piece_count != 0 {
            return Err(RLNCError::DataLengthMismatch);
        }
        let piece_byte_len = data.len() / piece_count;
        Ok(Encoder {
            data,
            piece_count,
            piece_byte_len,
        })
    }

    /// Builds an encoder over `data`.
    ///
    /// The constructor first writes `BOUNDARY_MARKER` immediately after the original
    /// bytes. It then zero-fills the buffer so that it splits into `piece_count`
    /// equal-length pieces of `(data.len() + 1).div_ceil(piece_count)` bytes each.
    ///
    /// # Errors
    ///
    /// - [`RLNCError::DataLengthZero`] if `data` is empty.
    /// - [`RLNCError::PieceCountZero`] if `piece_count` is zero.
    pub fn new(mut data: Vec<u8>, piece_count: usize) -> Result<Encoder, RLNCError> {
        if data.is_empty() {
            return Err(RLNCError::DataLengthZero);
        }
        if piece_count == 0 {
            return Err(RLNCError::PieceCountZero);
        }
        let in_data_len = data.len();
        // One extra byte is always reserved for the marker. The marker presumably lets a
        // decoder tell real data apart from the zero padding (NOTE(review): confirm in decoder).
        let boundary_marker_len = 1;
        let piece_byte_len = (in_data_len + boundary_marker_len).div_ceil(piece_count);
        let padded_data_len = piece_count * piece_byte_len;
        data.resize(padded_data_len, 0);
        // `padded_data_len >= in_data_len + 1`, so this index is always in bounds.
        data[in_data_len] = BOUNDARY_MARKER;
        Ok(Encoder {
            data,
            piece_count,
            piece_byte_len,
        })
    }

    /// Computes one coded piece: the GF(2^8) linear combination of all data pieces,
    /// with piece `i` weighted by `coding_vector[i]`.
    ///
    /// The returned buffer holds `coding_vector` followed by the coded payload, for a
    /// total of [`Self::get_full_coded_piece_byte_len`] bytes.
    ///
    /// # Errors
    ///
    /// Returns [`RLNCError::CodingVectorLengthMismatch`] if
    /// `coding_vector.len() != piece_count`.
    #[cfg(not(feature = "parallel"))]
    pub fn code_with_coding_vector(&self, coding_vector: &[u8]) -> Result<Vec<u8>, RLNCError> {
        if coding_vector.len() != self.piece_count {
            return Err(RLNCError::CodingVectorLengthMismatch);
        }
        let mut full_coded_piece = vec![0u8; self.get_full_coded_piece_byte_len()];
        full_coded_piece[..self.piece_count].copy_from_slice(coding_vector);
        // The payload region starts zeroed and accumulates the scaled pieces.
        let coded_piece = &mut full_coded_piece[self.piece_count..];
        self.data
            .chunks_exact(self.piece_byte_len)
            .zip(coding_vector)
            .map(|(piece, &random_symbol)| gf256_mul_vec_by_scalar(piece, random_symbol))
            .for_each(|cur| gf256_inplace_add_vectors(coded_piece, &cur));
        Ok(full_coded_piece)
    }

    /// Computes one coded piece: the GF(2^8) linear combination of all data pieces,
    /// with piece `i` weighted by `coding_vector[i]`. Pieces are processed in parallel
    /// with rayon.
    ///
    /// The returned buffer holds `coding_vector` followed by the coded payload, for a
    /// total of [`Self::get_full_coded_piece_byte_len`] bytes.
    ///
    /// # Errors
    ///
    /// Returns [`RLNCError::CodingVectorLengthMismatch`] if
    /// `coding_vector.len() != piece_count`.
    #[cfg(feature = "parallel")]
    pub fn code_with_coding_vector(&self, coding_vector: &[u8]) -> Result<Vec<u8>, RLNCError> {
        if coding_vector.len() != self.piece_count {
            return Err(RLNCError::CodingVectorLengthMismatch);
        }
        // Addition in GF(2^8) is bitwise XOR. Both the per-thread accumulation (`fold`)
        // and the cross-thread merge (`reduce`) therefore use the same `^=`.
        let coded_piece = self
            .data
            .par_chunks_exact(self.piece_byte_len)
            .zip(coding_vector)
            .map(|(piece, &random_symbol)| {
                piece
                    .iter()
                    .map(move |&symbol| (Gf256::new(symbol) * Gf256::new(random_symbol)).get())
            })
            .fold(
                || vec![0u8; self.piece_byte_len],
                |mut acc, cur| {
                    acc.iter_mut().zip(cur).for_each(|(a, b)| *a ^= b);
                    acc
                },
            )
            .reduce(
                || vec![0u8; self.piece_byte_len],
                |mut acc, cur| {
                    acc.iter_mut().zip(cur).for_each(|(a, b)| *a ^= b);
                    acc
                },
            );
        let mut full_coded_piece = vec![0u8; self.get_full_coded_piece_byte_len()];
        full_coded_piece[..self.piece_count].copy_from_slice(coding_vector);
        full_coded_piece[self.piece_count..].copy_from_slice(&coded_piece);
        Ok(full_coded_piece)
    }

    /// Produces a coded piece from a fresh coding vector of `piece_count` random
    /// bytes drawn from `rng`.
    ///
    /// Returns the full coded piece: the random coding vector followed by the coded payload.
    pub fn code<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<u8> {
        let random_coding_vector = (0..self.piece_count).map(|_| rng.random()).collect::<Vec<u8>>();
        // SAFETY: `random_coding_vector` has exactly `self.piece_count` elements.
        // `code_with_coding_vector` returns `Err` only when the length differs from
        // `piece_count`, so the result here is always `Ok`.
        unsafe { self.code_with_coding_vector(&random_coding_vector).unwrap_unchecked() }
    }
}
#[cfg(test)]
mod tests {
    use super::{BOUNDARY_MARKER, Encoder, RLNCError};
    use rand::Rng;

    #[test]
    fn test_encoder_without_padding_invalid_data() {
        let mut rng = rand::rng();

        let data_byte_len_zero = 0usize;
        let piece_count_non_zero = 10usize;
        let data_zero: Vec<u8> = (0..data_byte_len_zero).map(|_| rng.random()).collect();
        let result_data_zero = Encoder::without_padding(data_zero, piece_count_non_zero);
        assert!(result_data_zero.is_err());
        assert_eq!(result_data_zero.expect_err("Expected DataLengthZero error"), RLNCError::DataLengthZero);

        let data_byte_len_non_zero = 100usize;
        let piece_count_zero = 0usize;
        let data_non_zero: Vec<u8> = (0..data_byte_len_non_zero).map(|_| rng.random()).collect();
        let result_piece_count_zero = Encoder::without_padding(data_non_zero, piece_count_zero);
        assert!(result_piece_count_zero.is_err());
        assert_eq!(result_piece_count_zero.expect_err("Expected PieceCountZero error"), RLNCError::PieceCountZero);

        let data_byte_len = 1001usize;
        let piece_count = 32usize;
        let data = (0..data_byte_len).map(|_| rng.random()).collect::<Vec<u8>>();
        let result = Encoder::without_padding(data, piece_count);
        assert!(result.is_err());
        assert_eq!(result.expect_err("Expected DataLengthMismatch error"), RLNCError::DataLengthMismatch);

        let data_byte_len_valid = 100usize;
        let piece_count_valid = 10usize;
        let data_valid = (0..data_byte_len_valid).map(|_| rng.random()).collect::<Vec<u8>>();
        let result_valid = Encoder::without_padding(data_valid, piece_count_valid);
        assert!(result_valid.is_ok());
    }

    #[test]
    fn test_encoder_new_invalid_inputs() {
        let mut rng = rand::rng();

        let data_byte_len_zero = 0;
        let piece_count_non_zero = 5;
        let data_zero: Vec<u8> = (0..data_byte_len_zero).map(|_| rng.random()).collect();
        let result_data_zero = Encoder::new(data_zero, piece_count_non_zero);
        assert!(result_data_zero.is_err());
        assert_eq!(result_data_zero.expect_err("Expected DataLengthZero error"), RLNCError::DataLengthZero);

        let data_byte_len_non_zero = 100;
        let piece_count_zero = 0;
        let data_non_zero: Vec<u8> = (0..data_byte_len_non_zero).map(|_| rng.random()).collect();
        let result_piece_count_zero = Encoder::new(data_non_zero, piece_count_zero);
        assert!(result_piece_count_zero.is_err());
        assert_eq!(result_piece_count_zero.expect_err("Expected PieceCountZero error"), RLNCError::PieceCountZero);

        let data_byte_len_both_zero = 0;
        let piece_count_both_zero = 0;
        let data_both_zero: Vec<u8> = (0..data_byte_len_both_zero).map(|_| rng.random()).collect();
        let result_both_zero = Encoder::new(data_both_zero, piece_count_both_zero);
        assert!(result_both_zero.is_err());
        assert_eq!(
            result_both_zero.expect_err("Expected DataLengthZero error for both zero inputs"),
            RLNCError::DataLengthZero
        );

        let data_byte_len_valid = 1024;
        let piece_count_valid = 32;
        let data_valid = (0..data_byte_len_valid).map(|_| rng.random()).collect::<Vec<u8>>();
        let result_valid = Encoder::new(data_valid, piece_count_valid);
        assert!(result_valid.is_ok());
    }

    /// `new` must keep the original bytes, place the boundary marker right after
    /// them, and zero-fill the rest of the last piece.
    #[test]
    fn test_encoder_new_places_boundary_marker_and_zero_padding() {
        let data = vec![1u8, 2, 3, 4, 5];
        let piece_count = 4usize;
        let encoder = Encoder::new(data.clone(), piece_count).expect("Failed to create Encoder");

        // (5 + 1).div_ceil(4) == 2 bytes per piece, 8 bytes in total.
        assert_eq!(encoder.get_piece_byte_len(), 2);
        assert_eq!(encoder.data.len(), 8);
        assert_eq!(&encoder.data[..data.len()], &data[..]);
        assert_eq!(encoder.data[data.len()], BOUNDARY_MARKER);
        assert!(encoder.data[data.len() + 1..].iter().all(|&b| b == 0));
    }

    #[test]
    fn test_encoder_code_with_coding_vector_invalid_inputs() {
        let mut rng = rand::rng();
        let data_byte_len = 1024usize;
        let piece_count = 32usize;
        let data = (0..data_byte_len).map(|_| rng.random()).collect::<Vec<u8>>();
        let encoder = Encoder::new(data, piece_count).expect("Failed to create Encoder for invalid inputs test");

        let short_coding_vector_len = piece_count - 1;
        let short_coding_vector: Vec<u8> = (0..short_coding_vector_len).map(|_| rng.random()).collect();
        let result_short = encoder.code_with_coding_vector(&short_coding_vector);
        assert!(result_short.is_err());
        assert_eq!(
            result_short.expect_err("Expected CodingVectorLengthMismatch error for short vector"),
            RLNCError::CodingVectorLengthMismatch
        );

        let long_coding_vector_len = piece_count + 1;
        let long_coding_vector: Vec<u8> = (0..long_coding_vector_len).map(|_| rng.random()).collect();
        let result_long = encoder.code_with_coding_vector(&long_coding_vector);
        assert!(result_long.is_err());
        assert_eq!(
            result_long.expect_err("Expected CodingVectorLengthMismatch error for long vector"),
            RLNCError::CodingVectorLengthMismatch
        );

        let empty_coding_vector: Vec<u8> = Vec::new();
        let result_empty = encoder.code_with_coding_vector(&empty_coding_vector);
        assert!(result_empty.is_err());
        assert_eq!(
            result_empty.expect_err("Expected CodingVectorLengthMismatch error for empty vector"),
            RLNCError::CodingVectorLengthMismatch
        );

        let valid_coding_vector: Vec<u8> = (0..piece_count).map(|_| rng.random()).collect();
        let result_valid = encoder.code_with_coding_vector(&valid_coding_vector);
        assert!(result_valid.is_ok());
        assert_eq!(
            result_valid.expect("Expected a valid coded piece").len(),
            encoder.get_full_coded_piece_byte_len()
        );
    }

    /// A unit coding vector (a single 1, all other entries 0) must reproduce the
    /// selected piece exactly. An all-zero coding vector must produce an all-zero payload.
    #[test]
    fn test_encoder_code_with_unit_and_zero_coding_vectors() {
        let mut rng = rand::rng();
        let piece_count = 8usize;
        let data: Vec<u8> = (0..64).map(|_| rng.random()).collect();
        let encoder = Encoder::without_padding(data.clone(), piece_count).expect("Failed to create Encoder");
        let piece_byte_len = encoder.get_piece_byte_len();

        for i in 0..piece_count {
            let mut coding_vector = vec![0u8; piece_count];
            coding_vector[i] = 1;
            let full = encoder
                .code_with_coding_vector(&coding_vector)
                .expect("Expected a valid coded piece");
            assert_eq!(&full[..piece_count], &coding_vector[..]);
            assert_eq!(&full[piece_count..], &data[i * piece_byte_len..(i + 1) * piece_byte_len]);
        }

        let zero_vector = vec![0u8; piece_count];
        let zero_coded = encoder
            .code_with_coding_vector(&zero_vector)
            .expect("Expected a valid coded piece");
        assert!(zero_coded[piece_count..].iter().all(|&b| b == 0));
    }

    /// Coding must be linear over GF(2^8): coding with `a + b` (bytewise XOR) equals
    /// the XOR of the pieces coded with `a` and with `b`.
    #[test]
    fn test_encoder_code_with_coding_vector_is_linear() {
        let mut rng = rand::rng();
        let piece_count = 16usize;
        let data: Vec<u8> = (0..1000).map(|_| rng.random()).collect();
        let encoder = Encoder::new(data, piece_count).expect("Failed to create Encoder");

        let a: Vec<u8> = (0..piece_count).map(|_| rng.random()).collect();
        let b: Vec<u8> = (0..piece_count).map(|_| rng.random()).collect();
        let a_plus_b: Vec<u8> = a.iter().zip(&b).map(|(x, y)| x ^ y).collect();

        let coded_a = encoder.code_with_coding_vector(&a).expect("Expected a valid coded piece");
        let coded_b = encoder.code_with_coding_vector(&b).expect("Expected a valid coded piece");
        let coded_sum = encoder.code_with_coding_vector(&a_plus_b).expect("Expected a valid coded piece");

        let xor_of_coded: Vec<u8> = coded_a.iter().zip(&coded_b).map(|(x, y)| x ^ y).collect();
        assert_eq!(coded_sum, xor_of_coded);
    }

    /// `code` must return a piece whose payload matches re-coding with the coding
    /// vector embedded in its own prefix.
    #[test]
    fn test_encoder_code_matches_embedded_coding_vector() {
        let mut rng = rand::rng();
        let piece_count = 32usize;
        let data: Vec<u8> = (0..1024).map(|_| rng.random()).collect();
        let encoder = Encoder::new(data, piece_count).expect("Failed to create Encoder");

        let coded = encoder.code(&mut rng);
        assert_eq!(coded.len(), encoder.get_full_coded_piece_byte_len());

        let recoded = encoder
            .code_with_coding_vector(&coded[..encoder.get_piece_count()])
            .expect("Expected a valid coded piece");
        assert_eq!(coded, recoded);
    }

    #[test]
    fn test_encoder_getters() {
        let mut rng = rand::rng();

        let data_byte_len_single = 100usize;
        let piece_count_single = 1usize;
        let data_single = (0..data_byte_len_single).map(|_| rng.random()).collect::<Vec<u8>>();
        let encoder_single = Encoder::new(data_single.clone(), piece_count_single).expect("Failed to create Encoder (single piece)");
        assert_eq!(encoder_single.get_piece_count(), piece_count_single);
        assert_eq!(encoder_single.get_piece_byte_len(), (data_byte_len_single + 1).div_ceil(piece_count_single));
        assert_eq!(
            encoder_single.get_full_coded_piece_byte_len(),
            piece_count_single + (data_byte_len_single + 1).div_ceil(piece_count_single)
        );

        let piece_count_min = 1usize;
        let data_min = vec![42u8];
        let encoder_min = Encoder::new(data_min, piece_count_min).expect("Failed to create Encoder (min data)");
        assert_eq!(encoder_min.get_piece_count(), piece_count_min);
        assert_eq!(encoder_min.get_piece_byte_len(), 2);
        assert_eq!(encoder_min.get_full_coded_piece_byte_len(), 3);

        let data_byte_len_eq = 10usize;
        let piece_count_eq = 10usize;
        let data_eq = (0..data_byte_len_eq).map(|_| rng.random()).collect::<Vec<u8>>();
        let encoder_eq = Encoder::new(data_eq, piece_count_eq).expect("Failed to create Encoder (equal length)");
        assert_eq!(encoder_eq.get_piece_count(), piece_count_eq);
        assert_eq!(encoder_eq.get_piece_byte_len(), (data_byte_len_eq + 1).div_ceil(piece_count_eq));
        assert_eq!(encoder_eq.get_full_coded_piece_byte_len(), piece_count_eq + 2);

        let data_byte_len_large = 100usize;
        let piece_count_large = 50usize;
        let data_large = (0..data_byte_len_large).map(|_| rng.random()).collect::<Vec<u8>>();
        let encoder_large = Encoder::new(data_large, piece_count_large).expect("Failed to create Encoder (large piece count)");
        assert_eq!(encoder_large.get_piece_count(), piece_count_large);
        assert_eq!(encoder_large.get_piece_byte_len(), (data_byte_len_large + 1).div_ceil(piece_count_large));
        assert_eq!(
            encoder_large.get_full_coded_piece_byte_len(),
            piece_count_large + (data_byte_len_large + 1).div_ceil(piece_count_large)
        );
    }
}