use crate::config::Config;
use crate::error::{Result, ShamirError};
use crate::finite_field::FiniteField;
use rand::rngs::OsRng;
use rand_chacha::ChaCha20Rng;
use rand_chacha::rand_core::RngCore;
use rand_core::SeedableRng;
use rayon::iter::ParallelIterator;
use rayon::prelude::*;
use sha2::{Digest, Sha256};
use std::io::{Read, Write};
#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Length in bytes of the SHA-256 digest that prefixes the shared payload when integrity
/// checking is enabled.
const HASH_SIZE: usize = 32;
/// One share of a secret, as produced by [`ShamirShare::split`] or a [`Dealer`].
///
/// With the `zeroize` feature enabled, the share's contents are wiped when it is dropped.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct Share {
/// x-coordinate at which this share's polynomials were evaluated (dealt shares start at 1).
pub index: u8,
/// One evaluated byte per byte of the shared payload.
pub data: Vec<u8>,
/// Minimum number of shares required for reconstruction.
pub threshold: u8,
/// Number of shares the scheme was configured to produce.
pub total_shares: u8,
/// Whether the shared payload starts with a SHA-256 digest of the secret.
pub integrity_check: bool,
/// Whether the secret was compressed before splitting.
pub compression: bool,
}
/// Borrowed view of one share's x-coordinate and data, used when reconstructing streamed
/// chunks without building owned [`Share`] values.
#[derive(Debug, Clone, Copy)]
pub struct ShareView<'a> {
/// x-coordinate of the share.
pub index: u8,
/// The share's bytes for the current chunk.
pub data: &'a [u8],
}
/// Iterator that lazily produces shares of a prepared payload; create one with
/// [`ShamirShare::dealer`].
///
/// With the `zeroize` feature enabled, the payload and coefficients are wiped on drop.
#[cfg_attr(feature = "zeroize", derive(Zeroize, ZeroizeOnDrop))]
pub struct Dealer {
/// Payload being shared: optional SHA-256 prefix followed by the (optionally compressed) secret.
data: Vec<u8>,
/// Random polynomial coefficients, `threshold - 1` consecutive bytes per payload byte.
coefficients: Vec<u8>,
/// Next x-coordinate to evaluate; 0 means the dealer is exhausted.
current_x: u8,
threshold: u8,
total_shares: u8,
integrity_check: bool,
compression: bool,
}
/// A validated Shamir secret-sharing scheme together with its random number generator.
/// Build one with [`ShamirShare::builder`].
#[derive(Debug)]
pub struct ShamirShare {
total_shares: u8,
threshold: u8,
config: Config,
/// ChaCha20 RNG seeded from the operating system when the scheme is built; it supplies the
/// random polynomial coefficients.
rng: ChaCha20Rng,
}
/// Builder for [`ShamirShare`]; parameters are validated in [`ShamirShareBuilder::build`].
#[derive(Debug)]
pub struct ShamirShareBuilder {
total_shares: u8,
threshold: u8,
config: Config,
}
impl ShamirShareBuilder {
pub fn new(total_shares: u8, threshold: u8) -> Self {
Self {
total_shares,
threshold,
config: Config::default(),
}
}
pub fn with_config(mut self, config: Config) -> Self {
self.config = config;
self
}
pub fn build(self) -> Result<ShamirShare> {
if self.total_shares == 0 {
return Err(ShamirError::InvalidShareCount(self.total_shares));
}
if self.threshold == 0 {
return Err(ShamirError::InvalidThreshold(self.threshold));
}
if self.threshold > self.total_shares {
return Err(ShamirError::ThresholdTooLarge {
threshold: self.threshold,
total_shares: self.total_shares,
});
}
self.config.validate()?;
Ok(ShamirShare {
total_shares: self.total_shares,
threshold: self.threshold,
config: self.config,
rng: ChaCha20Rng::try_from_rng(&mut OsRng).unwrap(),
})
}
}
impl ShamirShare {
pub fn threshold(&self) -> u8 {
self.threshold
}
pub fn total_shares(&self) -> u8 {
self.total_shares
}
pub fn builder(total_shares: u8, threshold: u8) -> ShamirShareBuilder {
ShamirShareBuilder::new(total_shares, threshold)
}
pub fn dealer(&mut self, secret: &[u8]) -> Dealer {
#[cfg_attr(not(feature = "zeroize"), allow(unused_mut))]
let mut data_to_split = if self.config.integrity_check {
let hash = Sha256::digest(secret);
let mut data = Vec::with_capacity(HASH_SIZE + secret.len());
data.extend_from_slice(&hash);
#[cfg(feature = "compress")]
if self.config.compression {
let compressed_secret = zstd::encode_all(secret, 0)
.map_err(|e| ShamirError::CompressionError(e.to_string()))
.unwrap();
data.extend_from_slice(&compressed_secret);
} else {
data.extend_from_slice(secret);
}
#[cfg(not(feature = "compress"))]
data.extend_from_slice(secret);
data
} else {
#[cfg(feature = "compress")]
if self.config.compression {
zstd::encode_all(secret, 0)
.map_err(|e| ShamirError::CompressionError(e.to_string()))
.unwrap()
} else {
secret.to_vec()
}
#[cfg(not(feature = "compress"))]
secret.to_vec()
};
let secret_len = data_to_split.len();
let t = self.threshold as usize;
let mut coefficients = vec![0u8; secret_len * (t - 1)];
self.rng.fill_bytes(&mut coefficients);
let dealer = Dealer {
data: data_to_split.clone(),
coefficients: coefficients.clone(),
current_x: 1,
threshold: self.threshold,
total_shares: self.total_shares,
integrity_check: self.config.integrity_check,
compression: self.config.compression,
};
#[cfg(feature = "zeroize")]
{
data_to_split.zeroize();
coefficients.zeroize();
}
dealer
}
pub fn split(&mut self, secret: &[u8]) -> Result<Vec<Share>> {
Ok(self
.dealer(secret)
.take(self.total_shares as usize)
.collect())
}
pub fn reconstruct(shares: &[Share]) -> Result<Vec<u8>> {
if shares.is_empty() {
return Err(ShamirError::InsufficientShares { needed: 1, got: 0 });
}
let threshold = shares[0].threshold;
if shares.len() < threshold as usize {
return Err(ShamirError::InsufficientShares {
needed: threshold,
got: shares.len() as u8,
});
}
let integrity_check = shares[0].integrity_check;
let compression = shares[0].compression;
if !shares.iter().all(|s| {
s.data.len() == shares[0].data.len()
&& s.integrity_check == integrity_check
&& s.compression == compression
}) {
return Err(ShamirError::InconsistentShareLength);
}
#[cfg_attr(not(feature = "zeroize"), allow(unused_mut))]
let mut reconstructed_data = Self::reconstruct_chunk(shares)?;
let result = if integrity_check {
if reconstructed_data.len() < HASH_SIZE {
return Err(ShamirError::IntegrityCheckFailed);
}
let (reconstructed_hash, compressed_secret) = reconstructed_data.split_at(HASH_SIZE);
let secret = {
#[cfg(feature = "compress")]
if compression {
zstd::decode_all(compressed_secret)
.map_err(|e| ShamirError::DecompressionError(e.to_string()))?
} else {
compressed_secret.to_vec()
}
#[cfg(not(feature = "compress"))]
compressed_secret.to_vec()
};
let calculated_hash = Sha256::digest(&secret);
let mut hash_match = 0u8;
for (a, b) in calculated_hash
.as_slice()
.iter()
.zip(reconstructed_hash.iter())
{
hash_match |= a ^ b;
}
if hash_match != 0 {
return Err(ShamirError::IntegrityCheckFailed);
}
Ok(secret)
} else {
#[cfg(feature = "compress")]
if compression {
zstd::decode_all(reconstructed_data.as_slice())
.map_err(|e| ShamirError::DecompressionError(e.to_string()))
} else {
Ok(reconstructed_data.clone())
}
#[cfg(not(feature = "compress"))]
Ok(reconstructed_data.clone())
};
#[cfg(feature = "zeroize")]
reconstructed_data.zeroize();
result
}
pub fn split_stream<R: Read, W: Write>(
&mut self,
source: &mut R,
destinations: &mut [W],
) -> Result<()> {
if destinations.len() != self.total_shares as usize {
return Err(ShamirError::InvalidConfig(format!(
"Expected {} destinations, got {}",
self.total_shares,
destinations.len()
)));
}
let integrity_flag = if self.config.integrity_check { 1 } else { 0 };
let compression_flag = if self.config.compression { 2 } else { 0 };
let flags = integrity_flag | compression_flag;
for (i, dest) in destinations.iter_mut().enumerate() {
dest.write_all(&[flags, (i + 1) as u8])
.map_err(ShamirError::IoError)?;
}
let chunk_size = self.config.chunk_size;
let mut chunk_read_buffer = vec![0u8; chunk_size];
let mut chunk_with_hash_buffer = Vec::with_capacity(if self.config.integrity_check {
HASH_SIZE + chunk_size
} else {
chunk_size
});
let max_chunk_size_with_hash = if self.config.integrity_check {
HASH_SIZE + chunk_size
} else {
chunk_size
};
let mut share_data_buffers: Vec<Vec<u8>> = (0..self.total_shares)
.map(|_| Vec::with_capacity(max_chunk_size_with_hash))
.collect();
loop {
let bytes_read = source
.read(&mut chunk_read_buffer)
.map_err(ShamirError::IoError)?;
if bytes_read == 0 {
break; }
let chunk = &chunk_read_buffer[..bytes_read];
chunk_with_hash_buffer.clear();
if self.config.integrity_check {
let hash = Sha256::digest(chunk);
chunk_with_hash_buffer.extend_from_slice(&hash);
}
#[cfg(feature = "compress")]
if self.config.compression {
let compressed_chunk = zstd::encode_all(chunk, 0)
.map_err(|e| ShamirError::CompressionError(e.to_string()))?;
chunk_with_hash_buffer.extend_from_slice(&compressed_chunk);
} else {
chunk_with_hash_buffer.extend_from_slice(chunk);
}
#[cfg(not(feature = "compress"))]
chunk_with_hash_buffer.extend_from_slice(chunk);
let chunk_share_data = self.split_chunk(&chunk_with_hash_buffer)?;
for (share_idx, chunk_data) in chunk_share_data.iter().enumerate() {
let share_buffer = &mut share_data_buffers[share_idx];
share_buffer.clear();
share_buffer.extend_from_slice(chunk_data);
}
for (i, share_data) in share_data_buffers.iter().enumerate() {
let length = share_data.len() as u32;
destinations[i]
.write_all(&length.to_le_bytes())
.map_err(ShamirError::IoError)?;
destinations[i]
.write_all(share_data)
.map_err(ShamirError::IoError)?;
}
}
#[cfg(feature = "zeroize")]
{
chunk_read_buffer.zeroize();
chunk_with_hash_buffer.zeroize();
for buffer in &mut share_data_buffers {
buffer.zeroize();
}
}
for dest in destinations.iter_mut() {
dest.flush().map_err(ShamirError::IoError)?;
}
Ok(())
}
pub fn reconstruct_stream<R: Read, W: Write>(
sources: &mut [R],
destination: &mut W,
) -> Result<()> {
if sources.is_empty() {
return Err(ShamirError::InsufficientShares { needed: 1, got: 0 });
}
let mut headers: Vec<[u8; 2]> = Vec::with_capacity(sources.len());
for source in sources.iter_mut() {
let mut header = [0u8; 2];
source
.read_exact(&mut header)
.map_err(ShamirError::IoError)?;
headers.push(header);
}
let first_flags = headers[0][0];
let integrity_check = (first_flags & 1) != 0;
let compression = (first_flags & 2) != 0;
for header in headers.iter().skip(1) {
if header[0] != first_flags {
return Err(ShamirError::InvalidConfig(
"Inconsistent flags across sources".to_string(),
));
}
}
let share_indices: Vec<u8> = headers.iter().map(|h| h[1]).collect();
let mut chunk_lengths_buffer = Vec::with_capacity(sources.len());
let mut share_chunk_data_buffers: Vec<Vec<u8>> =
(0..sources.len()).map(|_| Vec::new()).collect();
let mut reconstructed_chunk_buffer = Vec::new();
loop {
chunk_lengths_buffer.clear();
let mut eof_reached = false;
for source in sources.iter_mut() {
let mut length_bytes = [0u8; 4];
match source.read_exact(&mut length_bytes) {
Ok(()) => {
let length = u32::from_le_bytes(length_bytes) as usize;
chunk_lengths_buffer.push(length);
}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
eof_reached = true;
break;
}
Err(e) => return Err(ShamirError::IoError(e)),
}
}
if eof_reached {
break; }
for (i, source) in sources.iter_mut().enumerate() {
let share_chunk_buffer = &mut share_chunk_data_buffers[i];
let chunk_length = chunk_lengths_buffer[i];
if share_chunk_buffer.len() != chunk_length {
share_chunk_buffer.resize(chunk_length, 0);
}
source
.read_exact(share_chunk_buffer)
.map_err(ShamirError::IoError)?;
}
let share_views: Vec<ShareView> = share_chunk_data_buffers
.iter()
.enumerate()
.map(|(i, share_chunk_data)| ShareView {
index: share_indices[i], data: share_chunk_data, })
.collect();
let reconstructed_chunk = Self::reconstruct_chunk_from_views(
&share_views,
&mut reconstructed_chunk_buffer,
)?;
if integrity_check {
if reconstructed_chunk.len() < HASH_SIZE {
return Err(ShamirError::IntegrityCheckFailed);
}
let (reconstructed_hash, compressed_data) = reconstructed_chunk.split_at(HASH_SIZE);
let data = {
#[cfg(feature = "compress")]
if compression {
zstd::decode_all(compressed_data)
.map_err(|e| ShamirError::DecompressionError(e.to_string()))?
} else {
compressed_data.to_vec()
}
#[cfg(not(feature = "compress"))]
compressed_data.to_vec()
};
let calculated_hash = Sha256::digest(&data);
let mut hash_match = 0u8;
for (a, b) in calculated_hash
.as_slice()
.iter()
.zip(reconstructed_hash.iter())
{
hash_match |= a ^ b;
}
if hash_match != 0 {
return Err(ShamirError::IntegrityCheckFailed);
}
destination.write_all(&data).map_err(ShamirError::IoError)?;
} else {
#[cfg(feature = "compress")]
if compression {
let data = zstd::decode_all(reconstructed_chunk)
.map_err(|e| ShamirError::DecompressionError(e.to_string()))?;
destination.write_all(&data).map_err(ShamirError::IoError)?;
} else {
destination
.write_all(reconstructed_chunk)
.map_err(ShamirError::IoError)?;
}
#[cfg(not(feature = "compress"))]
destination
.write_all(reconstructed_chunk)
.map_err(ShamirError::IoError)?;
};
}
#[cfg(feature = "zeroize")]
{
for buffer in &mut share_chunk_data_buffers {
buffer.zeroize();
}
reconstructed_chunk_buffer.zeroize();
}
destination.flush().map_err(ShamirError::IoError)?;
Ok(())
}
#[inline]
fn split_chunk(&mut self, data: &[u8]) -> Result<Vec<Vec<u8>>> {
let secret_len = data.len();
let t = self.threshold as usize;
let mut random_data = vec![0u8; secret_len * (t - 1)];
self.rng.fill_bytes(&mut random_data);
let x_values: Vec<FiniteField> = (1..=self.total_shares).map(FiniteField::new).collect();
let share_data: Vec<Vec<u8>> = x_values
.into_par_iter()
.map(|x| {
(0..secret_len)
.map(|idx| {
let mut acc = FiniteField::new(0);
for j in (0..t).rev() {
let coeff = if j == 0 {
FiniteField::new(data[idx])
} else {
FiniteField::new(random_data[idx * (t - 1) + (j - 1)])
};
acc = acc * x + coeff;
}
acc.0
})
.collect()
})
.collect();
#[cfg(feature = "zeroize")]
random_data.zeroize();
Ok(share_data)
}
#[inline]
fn compute_lagrange_coefficients(shares: &[Share]) -> Result<Vec<FiniteField>> {
let xs: Vec<FiniteField> = shares
.iter()
.map(|share| FiniteField::new(share.index))
.collect();
for i in 0..xs.len() {
for j in (i + 1)..xs.len() {
if xs[i] == xs[j] {
return Err(ShamirError::InvalidShareFormat);
}
}
}
let p = xs.iter().fold(FiniteField::new(1), |acc, &x| acc * x);
let lagrange_coefficients: Result<Vec<FiniteField>> = xs
.iter()
.enumerate()
.map(|(i, &x_i)| {
let numerator = p * x_i.inverse().unwrap();
let mut denominator = FiniteField::new(1);
for (j, &x_j) in xs.iter().enumerate() {
if i != j {
denominator = denominator * (x_i + x_j);
}
}
denominator
.inverse()
.ok_or(ShamirError::InvalidShareFormat)
.map(|inv| numerator * inv)
})
.collect();
lagrange_coefficients
}
#[inline]
fn compute_lagrange_coefficients_from_views(share_views: &[ShareView]) -> Result<Vec<FiniteField>> {
let xs: Vec<FiniteField> = share_views
.iter()
.map(|view| FiniteField::new(view.index))
.collect();
for i in 0..xs.len() {
for j in (i + 1)..xs.len() {
if xs[i] == xs[j] {
return Err(ShamirError::InvalidShareFormat);
}
}
}
let p = xs.iter().fold(FiniteField::new(1), |acc, &x| acc * x);
let lagrange_coefficients: Result<Vec<FiniteField>> = xs
.iter()
.enumerate()
.map(|(i, &x_i)| {
let numerator = p * x_i.inverse().unwrap();
let mut denominator = FiniteField::new(1);
for (j, &x_j) in xs.iter().enumerate() {
if i != j {
denominator = denominator * (x_i + x_j);
}
}
denominator
.inverse()
.ok_or(ShamirError::InvalidShareFormat)
.map(|inv| numerator * inv)
})
.collect();
lagrange_coefficients
}
#[inline]
fn reconstruct_chunk(shares: &[Share]) -> Result<Vec<u8>> {
if shares.is_empty() {
return Err(ShamirError::InsufficientShares { needed: 1, got: 0 });
}
let secret_len = shares[0].data.len();
if !shares.iter().all(|s| s.data.len() == secret_len) {
return Err(ShamirError::InconsistentShareLength);
}
let lagrange_coefficients = Self::compute_lagrange_coefficients(shares)?;
let reconstructed_data = (0..secret_len)
.into_par_iter()
.map(|byte_idx| {
shares
.iter()
.zip(&lagrange_coefficients)
.fold(FiniteField::new(0), |acc, (share, &coeff)| {
acc + coeff * FiniteField::new(share.data[byte_idx])
})
.0
})
.collect::<Vec<u8>>();
Ok(reconstructed_data)
}
#[inline]
fn reconstruct_chunk_from_views<'a>(
share_views: &[ShareView],
output_buffer: &'a mut Vec<u8>,
) -> Result<&'a [u8]> {
if share_views.is_empty() {
return Err(ShamirError::InsufficientShares { needed: 1, got: 0 });
}
let secret_len = share_views[0].data.len();
if !share_views.iter().all(|v| v.data.len() == secret_len) {
return Err(ShamirError::InconsistentShareLength);
}
let lagrange_coefficients = Self::compute_lagrange_coefficients_from_views(share_views)?;
output_buffer.clear();
output_buffer.reserve(secret_len);
for byte_idx in 0..secret_len {
let reconstructed_byte = share_views
.iter()
.zip(&lagrange_coefficients)
.fold(FiniteField::new(0), |acc, (view, &coeff)| {
acc + coeff * FiniteField::new(view.data[byte_idx])
})
.0;
output_buffer.push(reconstructed_byte);
}
Ok(output_buffer)
}
fn generate_zero_polynomial_shares(
&mut self,
share_indices: &[u8],
data_length: usize,
) -> Result<Vec<Vec<u8>>> {
let t = self.threshold as usize;
let mut random_data = vec![0u8; data_length * (t - 1)];
self.rng.fill_bytes(&mut random_data);
let delta_shares: Vec<Vec<u8>> = share_indices
.par_iter()
.map(|&index| {
let x = FiniteField::new(index);
(0..data_length)
.map(|byte_idx| {
let mut acc = FiniteField::new(0);
for j in (1..t).rev() {
let coeff = FiniteField::new(random_data[byte_idx * (t - 1) + (j - 1)]);
acc = acc * x + coeff;
}
acc = acc * x;
acc.0
})
.collect()
})
.collect();
#[cfg(feature = "zeroize")]
random_data.zeroize();
Ok(delta_shares)
}
pub fn refresh_shares(&mut self, shares: &[Share]) -> Result<Vec<Share>> {
if shares.is_empty() {
return Err(ShamirError::InsufficientShares { needed: 1, got: 0 });
}
if shares.len() < self.threshold as usize {
return Err(ShamirError::InsufficientShares {
needed: self.threshold,
got: shares.len() as u8,
});
}
let data_length = shares[0].data.len();
let integrity_check = shares[0].integrity_check;
if !shares
.iter()
.all(|s| s.data.len() == data_length && s.integrity_check == integrity_check)
{
return Err(ShamirError::InconsistentShareLength);
}
let indices: Vec<u8> = shares.iter().map(|s| s.index).collect();
let deltas = self.generate_zero_polynomial_shares(&indices, data_length)?;
let refreshed_shares: Vec<Share> = shares
.iter()
.zip(deltas.iter())
.map(|(old_share, delta_data)| {
let new_data: Vec<u8> = old_share
.data
.iter()
.zip(delta_data.iter())
.map(|(&old_byte, &delta_byte)| old_byte ^ delta_byte)
.collect();
Share {
index: old_share.index,
data: new_data,
threshold: old_share.threshold,
total_shares: old_share.total_shares,
integrity_check: old_share.integrity_check,
compression: old_share.compression,
}
})
.collect();
Ok(refreshed_shares)
}
}
impl Iterator for Dealer {
    type Item = Share;

    /// Produces the share at the current x-coordinate, then advances to the next one.
    ///
    /// x runs from 1 through 255. Once the `u8` counter wraps to 0, the dealer is exhausted
    /// and keeps returning `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let index = self.current_x;
        if index == 0 {
            return None;
        }
        let x = FiniteField::new(index);
        let t = usize::from(self.threshold);
        let coefficients = &self.coefficients;
        let share_data: Vec<u8> = self
            .data
            .iter()
            .enumerate()
            .map(|(byte_idx, &secret_byte)| {
                // Horner's rule: highest-degree coefficient first, the payload byte last.
                (0..t)
                    .rev()
                    .fold(FiniteField::new(0), |acc, j| {
                        let coeff = if j == 0 {
                            secret_byte
                        } else {
                            coefficients[byte_idx * (t - 1) + (j - 1)]
                        };
                        acc * x + FiniteField::new(coeff)
                    })
                    .0
            })
            .collect();
        self.current_x = index.wrapping_add(1);
        Some(Share {
            index,
            data: share_data,
            threshold: self.threshold,
            total_shares: self.total_shares,
            integrity_check: self.integrity_check,
            compression: self.compression,
        })
    }

    /// Exact number of shares left: one per x-coordinate from `current_x` through 255.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = match self.current_x {
            0 => 0,
            x => 256 - usize::from(x),
        };
        (remaining, Some(remaining))
    }
}
impl ExactSizeIterator for Dealer {
    /// Number of shares the dealer can still produce; matches the exact `size_hint`.
    fn len(&self) -> usize {
        let (remaining, _) = self.size_hint();
        remaining
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_split_and_reconstruct() {
let secret = b"Hello, World!";
let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
let shares = shamir.split(secret).unwrap();
assert_eq!(shares.len(), 5);
let reconstructed = ShamirShare::reconstruct(&shares[0..3]).unwrap();
assert_eq!(&reconstructed, secret);
let reconstructed = ShamirShare::reconstruct(&shares[1..5]).unwrap();
assert_eq!(&reconstructed, secret);
}
#[test]
fn test_invalid_parameters() {
assert!(ShamirShare::builder(0, 1).build().is_err());
assert!(ShamirShare::builder(1, 0).build().is_err());
assert!(ShamirShare::builder(3, 4).build().is_err());
}
#[test]
fn test_insufficient_shares() {
let secret = b"Test";
let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
let shares = shamir.split(secret).unwrap();
assert!(ShamirShare::reconstruct(&shares[0..2]).is_err());
}
#[test]
fn test_different_share_combinations() {
let secret = b"Different combinations test";
let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
let shares = shamir.split(secret).unwrap();
let combinations = vec![vec![0, 1, 2], vec![1, 2, 3], vec![2, 3, 4], vec![0, 2, 4]];
for combo in combinations {
let selected_shares: Vec<Share> = combo.iter().map(|&i| shares[i].clone()).collect();
let reconstructed = ShamirShare::reconstruct(&selected_shares).unwrap();
assert_eq!(&reconstructed, secret);
}
}
#[test]
fn test_empty_secret() {
let secret = b"";
let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
let shares = shamir.split(secret).unwrap();
let reconstructed = ShamirShare::reconstruct(&shares[0..3]).unwrap();
assert_eq!(reconstructed, secret);
}
#[test]
fn test_single_byte_secret() {
let secret = b"x";
let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
let shares = shamir.split(secret).unwrap();
let reconstructed = ShamirShare::reconstruct(&shares[0..3]).unwrap();
assert_eq!(reconstructed, secret);
}
#[test]
fn test_max_shares() {
let secret = b"Maximum shares test";
let mut shamir = ShamirShare::builder(255, 128).build().unwrap();
let shares = shamir.split(secret).unwrap();
assert_eq!(shares.len(), 255);
let reconstructed = ShamirShare::reconstruct(&shares[0..128]).unwrap();
assert_eq!(reconstructed, secret);
}
#[test]
fn test_duplicate_share_indices() {
let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
let shares = shamir.split(b"test").unwrap();
let mut corrupted_shares = shares[0..3].to_vec();
corrupted_shares[1].index = corrupted_shares[0].index;
assert!(matches!(
ShamirShare::reconstruct(&corrupted_shares),
Err(ShamirError::InvalidShareFormat)
));
}
#[test]
fn test_corrupted_share_data() {
let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
let mut shares = shamir.split(b"test").unwrap();
if shares[0].data[0] == 0 {
shares[0].data[0] = 1;
} else {
shares[0].data[0] = 0;
}
assert!(matches!(
ShamirShare::reconstruct(&shares[0..3]),
Err(ShamirError::IntegrityCheckFailed)
));
}
#[test]
fn test_builder_pattern() {
let shamir = ShamirShare::builder(5, 3).build().unwrap();
assert_eq!(shamir.total_shares, 5);
assert_eq!(shamir.threshold, 3);
assert!(shamir.config.integrity_check);
let config = Config::new().with_integrity_check(false);
let shamir = ShamirShare::builder(7, 4)
.with_config(config)
.build()
.unwrap();
assert_eq!(shamir.total_shares, 7);
assert_eq!(shamir.threshold, 4);
assert!(!shamir.config.integrity_check);
}
#[test]
fn test_builder_validation() {
assert!(ShamirShare::builder(0, 1).build().is_err());
assert!(ShamirShare::builder(1, 0).build().is_err());
assert!(ShamirShare::builder(3, 5).build().is_err());
let invalid_config = Config::new().with_chunk_size(0).unwrap_err();
assert!(matches!(invalid_config, ShamirError::InvalidConfig(_)));
}
#[test]
fn test_integrity_check_disabled() {
let config = Config::new().with_integrity_check(false);
let mut shamir = ShamirShare::builder(5, 3)
.with_config(config)
.build()
.unwrap();
let secret = b"test secret without integrity check";
let shares = shamir.split(secret).unwrap();
assert!(!shares[0].integrity_check);
let reconstructed = ShamirShare::reconstruct(&shares[0..3]).unwrap();
assert_eq!(&reconstructed, secret);
let mut shamir_with_integrity = ShamirShare::builder(5, 3).build().unwrap();
let shares_with_integrity = shamir_with_integrity.split(secret).unwrap();
assert!(shares[0].data.len() < shares_with_integrity[0].data.len());
assert_eq!(
shares_with_integrity[0].data.len() - shares[0].data.len(),
HASH_SIZE
);
}
#[test]
fn test_integrity_check_enabled() {
let config = Config::new().with_integrity_check(true);
let mut shamir = ShamirShare::builder(5, 3)
.with_config(config)
.build()
.unwrap();
let secret = b"test secret with integrity check";
let shares = shamir.split(secret).unwrap();
assert!(shares[0].integrity_check);
let reconstructed = ShamirShare::reconstruct(&shares[0..3]).unwrap();
assert_eq!(&reconstructed, secret);
let mut corrupted_shares = shares[0..3].to_vec();
if corrupted_shares[0].data[0] == 0 {
corrupted_shares[0].data[0] = 1;
} else {
corrupted_shares[0].data[0] = 0;
}
assert!(matches!(
ShamirShare::reconstruct(&corrupted_shares),
Err(ShamirError::IntegrityCheckFailed)
));
}
#[test]
fn test_mixed_integrity_check_shares() {
let config_with_integrity = Config::new().with_integrity_check(true);
let mut shamir_with_integrity = ShamirShare::builder(5, 3)
.with_config(config_with_integrity)
.build()
.unwrap();
let config_without_integrity = Config::new().with_integrity_check(false);
let mut shamir_without_integrity = ShamirShare::builder(5, 3)
.with_config(config_without_integrity)
.build()
.unwrap();
let secret = b"test secret";
let shares_with_integrity = shamir_with_integrity.split(secret).unwrap();
let shares_without_integrity = shamir_without_integrity.split(secret).unwrap();
let mixed_shares = vec![
shares_with_integrity[0].clone(),
shares_without_integrity[1].clone(),
shares_with_integrity[2].clone(),
];
assert!(matches!(
ShamirShare::reconstruct(&mixed_shares),
Err(ShamirError::InconsistentShareLength)
));
}
#[test]
fn test_config_builder_methods() {
use crate::config::SplitMode;
let config = Config::new()
.with_chunk_size(2048)
.unwrap()
.with_mode(SplitMode::Parallel)
.with_compression(true)
.with_integrity_check(false);
let shamir = ShamirShare::builder(5, 3)
.with_config(config.clone())
.build()
.unwrap();
assert_eq!(shamir.config.chunk_size, 2048);
assert_eq!(shamir.config.mode, SplitMode::Parallel);
assert!(shamir.config.compression);
assert!(!shamir.config.integrity_check);
}
#[test]
fn test_split_stream_basic() {
use std::io::Cursor;
let mut shamir = ShamirShare::builder(3, 2).build().unwrap();
let data = b"This is a test message for streaming functionality";
let mut source = Cursor::new(data);
let mut destinations = vec![Vec::new(); 3];
let mut dest_cursors: Vec<Cursor<Vec<u8>>> = destinations
.iter_mut()
.map(|d| Cursor::new(std::mem::take(d)))
.collect();
shamir.split_stream(&mut source, &mut dest_cursors).unwrap();
let share_data: Vec<Vec<u8>> = dest_cursors
.into_iter()
.map(|cursor| cursor.into_inner())
.collect();
for share in &share_data {
assert!(!share.is_empty());
}
let mut sources: Vec<Cursor<Vec<u8>>> = share_data[0..2]
.iter()
.map(|data| Cursor::new(data.clone()))
.collect();
let mut destination = Vec::new();
let mut dest_cursor = Cursor::new(&mut destination);
ShamirShare::reconstruct_stream(&mut sources, &mut dest_cursor).unwrap();
assert_eq!(&destination, data);
}
#[test]
fn test_split_stream_with_custom_chunk_size() {
use std::io::Cursor;
let config = Config::new().with_chunk_size(10).unwrap(); let mut shamir = ShamirShare::builder(3, 2)
.with_config(config)
.build()
.unwrap();
let data = b"This is a longer test message that will be split into multiple chunks";
let mut source = Cursor::new(data);
let mut destinations = vec![Vec::new(); 3];
let mut dest_cursors: Vec<Cursor<Vec<u8>>> = destinations
.iter_mut()
.map(|d| Cursor::new(std::mem::take(d)))
.collect();
shamir.split_stream(&mut source, &mut dest_cursors).unwrap();
let share_data: Vec<Vec<u8>> = dest_cursors
.into_iter()
.map(|cursor| cursor.into_inner())
.collect();
let mut sources: Vec<Cursor<Vec<u8>>> = share_data[0..2]
.iter()
.map(|data| Cursor::new(data.clone()))
.collect();
let mut destination = Vec::new();
let mut dest_cursor = Cursor::new(&mut destination);
ShamirShare::reconstruct_stream(&mut sources, &mut dest_cursor).unwrap();
assert_eq!(&destination, data);
}
#[test]
fn test_split_stream_without_integrity_check() {
use std::io::Cursor;
let config = Config::new()
.with_integrity_check(false)
.with_chunk_size(20)
.unwrap();
let mut shamir = ShamirShare::builder(3, 2)
.with_config(config)
.build()
.unwrap();
let data = b"Test message without integrity checking";
let mut source = Cursor::new(data);
let mut destinations = vec![Vec::new(); 3];
let mut dest_cursors: Vec<Cursor<Vec<u8>>> = destinations
.iter_mut()
.map(|d| Cursor::new(std::mem::take(d)))
.collect();
shamir.split_stream(&mut source, &mut dest_cursors).unwrap();
let share_data: Vec<Vec<u8>> = dest_cursors
.into_iter()
.map(|cursor| cursor.into_inner())
.collect();
let mut sources: Vec<Cursor<Vec<u8>>> = share_data[0..2]
.iter()
.map(|data| Cursor::new(data.clone()))
.collect();
let mut destination = Vec::new();
let mut dest_cursor = Cursor::new(&mut destination);
ShamirShare::reconstruct_stream(&mut sources, &mut dest_cursor).unwrap();
assert_eq!(&destination, data);
}
#[test]
fn test_split_stream_empty_data() {
use std::io::Cursor;
let mut shamir = ShamirShare::builder(3, 2).build().unwrap();
let data = b"";
let mut source = Cursor::new(data);
let mut destinations = vec![Vec::new(); 3];
let mut dest_cursors: Vec<Cursor<Vec<u8>>> = destinations
.iter_mut()
.map(|d| Cursor::new(std::mem::take(d)))
.collect();
shamir.split_stream(&mut source, &mut dest_cursors).unwrap();
let share_data: Vec<Vec<u8>> = dest_cursors
.into_iter()
.map(|cursor| cursor.into_inner())
.collect();
for share in &share_data {
assert_eq!(share.len(), 2); }
let mut sources: Vec<Cursor<Vec<u8>>> = share_data[0..2]
.iter()
.map(|data| Cursor::new(data.clone()))
.collect();
let mut destination = Vec::new();
let mut dest_cursor = Cursor::new(&mut destination);
ShamirShare::reconstruct_stream(&mut sources, &mut dest_cursor).unwrap();
assert_eq!(&destination, data);
}
#[test]
fn test_split_stream_wrong_destination_count() {
use std::io::Cursor;
let mut shamir = ShamirShare::builder(3, 2).build().unwrap();
let data = b"test";
let mut source = Cursor::new(data);
let mut destinations = vec![Vec::new(); 2];
let mut dest_cursors: Vec<Cursor<Vec<u8>>> = destinations
.iter_mut()
.map(|d| Cursor::new(std::mem::take(d)))
.collect();
let result = shamir.split_stream(&mut source, &mut dest_cursors);
assert!(matches!(result, Err(ShamirError::InvalidConfig(_))));
}
#[test]
fn test_reconstruct_stream_insufficient_sources() {
use std::io::Cursor;
let mut sources: Vec<Cursor<Vec<u8>>> = vec![];
let mut destination = Vec::new();
let mut dest_cursor = Cursor::new(&mut destination);
let result = ShamirShare::reconstruct_stream(&mut sources, &mut dest_cursor);
assert!(matches!(
result,
Err(ShamirError::InsufficientShares { .. })
));
}
#[test]
fn test_stream_data_format() {
use std::io::Cursor;
let config = Config::new().with_chunk_size(5).unwrap(); let mut shamir = ShamirShare::builder(3, 2)
.with_config(config)
.build()
.unwrap();
let data = b"Hello World!"; let mut source = Cursor::new(data);
let mut destinations = vec![Vec::new(); 3];
let mut dest_cursors: Vec<Cursor<Vec<u8>>> = destinations
.iter_mut()
.map(|d| Cursor::new(std::mem::take(d)))
.collect();
shamir.split_stream(&mut source, &mut dest_cursors).unwrap();
let share_data: Vec<Vec<u8>> = dest_cursors
.into_iter()
.map(|cursor| cursor.into_inner())
.collect();
for share in &share_data {
let mut cursor = Cursor::new(share);
let mut total_chunks = 0;
let mut header = [0u8; 2];
cursor.read_exact(&mut header).unwrap();
loop {
let mut length_bytes = [0u8; 4];
match cursor.read_exact(&mut length_bytes) {
Ok(()) => {
let length = u32::from_le_bytes(length_bytes) as usize;
let mut chunk_data = vec![0u8; length];
cursor.read_exact(&mut chunk_data).unwrap();
total_chunks += 1;
}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => panic!("Unexpected error: {}", e),
}
}
assert_eq!(total_chunks, 3);
}
let mut sources: Vec<Cursor<Vec<u8>>> = share_data[0..2]
.iter()
.map(|data| Cursor::new(data.clone()))
.collect();
let mut destination = Vec::new();
let mut dest_cursor = Cursor::new(&mut destination);
ShamirShare::reconstruct_stream(&mut sources, &mut dest_cursor).unwrap();
assert_eq!(&destination, data);
}
#[test]
fn test_stream_integrity_check_detection() {
    use std::io::Cursor;

    // NOTE(review): despite its name, this test never tampers with a share. It checks that
    // enabling integrity checking enlarges every share stream and that both modes round-trip.
    // A corrupted-share case would first need the expected failure of `reconstruct_stream`
    // to be confirmed against its implementation.
    let data: &[u8] = b"Test data for integrity checking";

    // Splits `data` 2-of-3 with 10-byte chunks and the given integrity setting, returning the
    // raw bytes written to each of the three share streams.
    let split_with_integrity = |integrity: bool| -> Vec<Vec<u8>> {
        let config = Config::new()
            .with_integrity_check(integrity)
            .with_chunk_size(10)
            .unwrap();
        let mut shamir = ShamirShare::builder(3, 2)
            .with_config(config)
            .build()
            .unwrap();
        let mut source = Cursor::new(data);
        let mut dest_cursors: Vec<Cursor<Vec<u8>>> =
            (0..3).map(|_| Cursor::new(Vec::new())).collect();
        shamir.split_stream(&mut source, &mut dest_cursors).unwrap();
        dest_cursors.into_iter().map(Cursor::into_inner).collect()
    };

    // Feeds the given share streams to `reconstruct_stream` and returns the recovered bytes.
    let reconstruct = |shares: &[Vec<u8>]| -> Vec<u8> {
        let mut sources: Vec<Cursor<Vec<u8>>> =
            shares.iter().cloned().map(Cursor::new).collect();
        let mut destination = Vec::new();
        let mut dest_cursor = Cursor::new(&mut destination);
        ShamirShare::reconstruct_stream(&mut sources, &mut dest_cursor).unwrap();
        destination
    };

    let with_integrity = split_with_integrity(true);
    let without_integrity = split_with_integrity(false);

    // Integrity data must add bytes to every share stream, not only the first one.
    assert_eq!(with_integrity.len(), without_integrity.len());
    for (checked, unchecked) in with_integrity.iter().zip(&without_integrity) {
        assert!(checked.len() > unchecked.len());
    }

    // Two of three shares meet the threshold, so both modes must recover the input.
    assert_eq!(reconstruct(&with_integrity[..2]), data);
    assert_eq!(reconstruct(&without_integrity[..2]), data);
}
#[test]
fn test_stream_large_data() {
    use std::io::Cursor;

    // 10_000 is not a multiple of the configured chunk size (1024 = 9 * 1024 + 784), so the
    // stream presumably ends on a partial chunk.
    let config = Config::new().with_chunk_size(1024).unwrap();
    let mut shamir = ShamirShare::builder(5, 3)
        .with_config(config)
        .build()
        .unwrap();
    let data: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();

    // Write each share stream straight into a fresh in-memory buffer.
    let mut source = Cursor::new(&data);
    let mut dest_cursors: Vec<Cursor<Vec<u8>>> =
        (0..5).map(|_| Cursor::new(Vec::new())).collect();
    shamir.split_stream(&mut source, &mut dest_cursors).unwrap();
    let share_data: Vec<Vec<u8>> = dest_cursors.into_iter().map(Cursor::into_inner).collect();

    // Any threshold-sized subset must reconstruct: try a non-contiguous and a contiguous one.
    for subset in [[0, 2, 4], [1, 2, 3]] {
        let mut sources: Vec<Cursor<Vec<u8>>> = subset
            .iter()
            .map(|&i| Cursor::new(share_data[i].clone()))
            .collect();
        let mut destination = Vec::new();
        let mut dest_cursor = Cursor::new(&mut destination);
        ShamirShare::reconstruct_stream(&mut sources, &mut dest_cursor).unwrap();
        assert_eq!(destination, data);
    }
}
#[test]
fn test_dealer_basic_functionality() {
    // The dealer numbers shares from 1, stamps them with the builder's parameters, and any
    // set of at least `threshold` of them recovers the secret.
    let secret = b"Hello, Dealer!";
    let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
    let shares: Vec<Share> = shamir.dealer(secret).take(5).collect();
    assert_eq!(shares.len(), 5);

    for (expected_index, share) in (1u8..).zip(shares.iter()) {
        assert_eq!(share.index, expected_index);
        assert_eq!(share.threshold, 3);
        assert_eq!(share.total_shares, 5);
        // The default configuration is expected to enable integrity checking.
        assert!(share.integrity_check);
    }

    let from_first_three = ShamirShare::reconstruct(&shares[..3]).unwrap();
    assert_eq!(&from_first_three, secret);
    let from_last_four = ShamirShare::reconstruct(&shares[1..]).unwrap();
    assert_eq!(&from_last_four, secret);
}
#[test]
fn test_dealer_vs_split_equivalence() {
    // The lazy dealer and the eager `split` are not expected to produce identical share
    // bytes (presumably each draws fresh random coefficients). What must agree is the
    // number of shares and the secret that each share set reconstructs.
    let secret = b"Test equivalence between dealer and split";
    let mut shamir = ShamirShare::builder(7, 4).build().unwrap();
    let eager = shamir.split(secret).unwrap();
    let lazy: Vec<Share> = shamir.dealer(secret).take(7).collect();
    assert_eq!(eager.len(), lazy.len());

    let from_eager = ShamirShare::reconstruct(&eager[..4]).unwrap();
    let from_lazy = ShamirShare::reconstruct(&lazy[..4]).unwrap();
    assert_eq!(&from_eager, secret);
    assert_eq!(&from_lazy, secret);
    assert_eq!(from_eager, from_lazy);
}
#[test]
fn test_dealer_lazy_evaluation() {
    // Pulling from one dealer in batches must continue the index sequence rather than
    // restart it, and the batches together must form a valid threshold-sized set.
    let secret = b"Lazy evaluation test";
    let mut shamir = ShamirShare::builder(10, 5).build().unwrap();
    let mut dealer = shamir.dealer(secret);
    let indices_of = |batch: &[Share]| batch.iter().map(|s| s.index).collect::<Vec<u8>>();

    let mut collected: Vec<Share> = dealer.by_ref().take(3).collect();
    assert_eq!(collected.len(), 3);
    assert_eq!(indices_of(&collected), [1, 2, 3]);

    let batch: Vec<Share> = dealer.by_ref().take(2).collect();
    assert_eq!(batch.len(), 2);
    assert_eq!(indices_of(&batch), [4, 5]);

    collected.extend(batch);
    let reconstructed = ShamirShare::reconstruct(&collected).unwrap();
    assert_eq!(&reconstructed, secret);
}
#[test]
fn test_dealer_max_shares_limit() {
    // An unbounded `collect` must terminate on its own after every non-zero u8 index has
    // been used, yielding indices 1..=255 in order.
    let secret = b"Max shares test";
    let mut shamir = ShamirShare::builder(255, 128).build().unwrap();
    let all_shares: Vec<Share> = shamir.dealer(secret).collect();
    assert_eq!(all_shares.len(), 255);
    assert!(all_shares.iter().map(|s| s.index).eq(1..=255));

    let reconstructed = ShamirShare::reconstruct(&all_shares[..128]).unwrap();
    assert_eq!(&reconstructed, secret);
}
#[test]
fn test_dealer_stops_at_255() {
    // Once exhausted, the dealer must keep returning `None` on repeated calls rather than,
    // for example, wrapping the u8 index back around.
    let secret = b"Stop at 255 test";
    let mut shamir = ShamirShare::builder(255, 128).build().unwrap();
    let mut dealer = shamir.dealer(secret);

    let drained: Vec<Share> = dealer.by_ref().collect();
    assert_eq!(drained.len(), 255);
    for _ in 0..2 {
        assert!(dealer.next().is_none());
    }
}
#[test]
fn test_dealer_size_hint() {
    // Asserts that an exact-size iterator reports `expected` items left via both APIs.
    fn assert_remaining<I: ExactSizeIterator>(iter: &I, expected: usize) {
        assert_eq!(iter.size_hint(), (expected, Some(expected)));
        assert_eq!(iter.len(), expected);
    }

    // The reported length starts at 255 (all available indices), not the configured
    // total of 10, and shrinks by one per share consumed.
    let secret = b"Size hint test";
    let mut shamir = ShamirShare::builder(10, 5).build().unwrap();
    let mut dealer = shamir.dealer(secret);

    assert_remaining(&dealer, 255);
    assert!(dealer.next().is_some());
    assert_remaining(&dealer, 254);
    dealer.by_ref().take(10).for_each(drop);
    assert_remaining(&dealer, 244);
}
#[test]
fn test_dealer_with_integrity_check_disabled() {
    // With integrity checking switched off, dealer shares must carry the flag as `false`
    // and still reconstruct to the same secret as shares produced by `split`.
    let mut shamir = ShamirShare::builder(5, 3)
        .with_config(Config::new().with_integrity_check(false))
        .build()
        .unwrap();
    let secret = b"No integrity check";

    let dealt: Vec<Share> = shamir.dealer(secret).take(5).collect();
    assert!(dealt.iter().all(|share| !share.integrity_check));
    let from_dealer = ShamirShare::reconstruct(&dealt[..3]).unwrap();
    assert_eq!(&from_dealer, secret);

    let split_shares = shamir.split(secret).unwrap();
    let from_split = ShamirShare::reconstruct(&split_shares[..3]).unwrap();
    assert_eq!(from_dealer, from_split);
}
#[test]
fn test_dealer_empty_secret() {
    // Edge case: a zero-length secret still yields shares, and two of them recover an
    // empty secret.
    let secret = b"";
    let mut shamir = ShamirShare::builder(3, 2).build().unwrap();
    let shares = shamir.dealer(secret).take(3).collect::<Vec<Share>>();
    assert_eq!(shares.len(), 3);

    let recovered = ShamirShare::reconstruct(&shares[..2]).unwrap();
    assert_eq!(&recovered, secret);
}
#[test]
fn test_dealer_single_byte_secret() {
    // The smallest non-empty secret: one byte shared 3-of-5, recovered from the first three.
    let secret = b"x";
    let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
    let shares = shamir.dealer(secret).take(5).collect::<Vec<Share>>();
    assert_eq!(shares.len(), 5);

    let recovered = ShamirShare::reconstruct(&shares[..3]).unwrap();
    assert_eq!(&recovered, secret);
}
#[test]
fn test_dealer_different_share_combinations() {
    // Check every 4-of-7 subset (C(7, 4) = 35) instead of a hand-picked few: any
    // threshold-sized subset, contiguous or not, must recover the secret.
    let secret = b"Different dealer combinations test";
    let mut shamir = ShamirShare::builder(7, 4).build().unwrap();
    let dealer_shares: Vec<Share> = shamir.dealer(secret).take(7).collect();
    assert_eq!(dealer_shares.len(), 7);

    let mut checked = 0;
    // Each bit of `mask` selects one share. Masks with exactly four bits set enumerate
    // every 4-element subset, with shares kept in index order.
    for mask in 0u32..(1 << dealer_shares.len()) {
        if mask.count_ones() != 4 {
            continue;
        }
        let selected: Vec<Share> = dealer_shares
            .iter()
            .enumerate()
            .filter(|&(i, _)| mask & (1 << i) != 0)
            .map(|(_, share)| share.clone())
            .collect();
        let reconstructed = ShamirShare::reconstruct(&selected).unwrap();
        assert_eq!(&reconstructed, secret);
        checked += 1;
    }
    assert_eq!(checked, 35);
}
#[test]
fn test_dealer_iterator_chain() {
    // The dealer composes with ordinary iterator adaptors. Keeping only even indices and
    // taking five must still give a valid threshold-sized share set.
    let secret = b"Iterator chain test";
    let mut shamir = ShamirShare::builder(10, 5).build().unwrap();
    let even_shares = shamir
        .dealer(secret)
        .filter(|s| s.index % 2 == 0)
        .take(5)
        .collect::<Vec<Share>>();
    assert_eq!(even_shares.len(), 5);
    assert!(even_shares.iter().all(|s| s.index % 2 == 0));

    let recovered = ShamirShare::reconstruct(&even_shares).unwrap();
    assert_eq!(&recovered, secret);
}
#[test]
#[cfg(feature = "zeroize")]
fn test_zeroize_feature_compilation() {
    // Import the trait explicitly so `field.zeroize()` below does not depend on what the
    // parent module happens to re-export into this test module.
    use zeroize::Zeroize;

    // With the feature enabled, splitting, dealing and reconstructing must keep working.
    let secret = b"test secret for zeroize";
    let mut shamir = ShamirShare::builder(5, 3).build().unwrap();
    let shares = shamir.split(secret).unwrap();
    assert_eq!(shares.len(), 5);
    let dealer_shares: Vec<Share> = shamir.dealer(secret).take(3).collect();
    assert_eq!(dealer_shares.len(), 3);

    let reconstructed = ShamirShare::reconstruct(&shares[0..3]).unwrap();
    assert_eq!(&reconstructed, secret);
    // Exactly `threshold` dealer shares must also be enough to recover the secret.
    let reconstructed_dealer = ShamirShare::reconstruct(&dealer_shares).unwrap();
    assert_eq!(&reconstructed_dealer, secret);

    // Field elements can be wiped explicitly.
    let mut field = crate::FiniteField::new(42);
    field.zeroize();
    assert_eq!(field.0, 0);
}
#[test]
#[cfg(feature = "zeroize")]
fn test_share_zeroize_on_drop() {
    use zeroize::Zeroize;

    let secret = b"test secret for drop";
    let mut shamir = ShamirShare::builder(3, 2).build().unwrap();

    // Drop path: `shares` goes out of scope at the end of this block, running the derived
    // `ZeroizeOnDrop` impl. Safe Rust cannot inspect freed memory, so this only checks that
    // dropping works and that the data cloned out beforehand is intact.
    let share_data = {
        let shares = shamir.split(secret).unwrap();
        shares[0].data.clone()
    };
    assert!(!share_data.is_empty());

    // Explicit wipe: the derived `Zeroize` clears every field of `Share`. For `Vec<u8>`,
    // zeroize wipes the buffer and truncates it to length 0, so an "all bytes are zero"
    // check alone would pass vacuously. Assert emptiness and the scalar fields instead.
    let mut shares = shamir.split(secret).unwrap();
    let original_data = shares[0].data.clone();
    assert!(!original_data.is_empty());
    shares[0].zeroize();
    assert!(shares[0].data.is_empty());
    assert_eq!(shares[0].index, 0);
    assert_eq!(shares[0].threshold, 0);
    assert_eq!(shares[0].total_shares, 0);
    assert!(!shares[0].integrity_check);
    assert!(!shares[0].compression);
    assert_ne!(original_data, shares[0].data);
}
}