#![cfg_attr(not(feature = "std"), no_std)]
mod field;
mod math;
mod share;
extern crate alloc;
use alloc::vec::Vec;
use hashbrown::HashSet;
use field::GF256;
pub use field::PRIMITIVE_POLYS;
pub use share::Share;
/// A Shamir secret-sharing scheme over GF(256).
///
/// The wrapped `u8` is the threshold: the minimum number of shares with distinct `x`
/// coordinates needed to recover a secret. `POLY` selects the polynomial used by the
/// field arithmetic (presumably one of `PRIMITIVE_POLYS`; the crate's tests use `0x11d`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SecretSharing<const POLY: u16>(pub u8);
impl<const POLY: u16> SecretSharing<POLY> {
pub fn dealer_rng<R: rand::Rng>(
&self,
secret: &[u8],
rng: &mut R,
) -> impl Iterator<Item = Share<POLY>> {
let mut polys = Vec::with_capacity(secret.len());
for chunk in secret {
polys.push(math::random_polynomial(GF256(*chunk), self.0, rng))
}
math::get_evaluator(polys)
}
#[cfg(feature = "std")]
pub fn dealer(&self, secret: &[u8]) -> impl Iterator<Item = Share<POLY>> {
let mut rng = rand::thread_rng();
self.dealer_rng(secret, &mut rng)
}
pub fn recover<'a, T>(&self, shares: T) -> Result<Vec<u8>, &str>
where
T: IntoIterator<Item = &'a Share<POLY>>,
T::IntoIter: Iterator<Item = &'a Share<POLY>>,
{
let mut share_length: Option<usize> = None;
let mut keys: HashSet<u8> = HashSet::new();
let mut values: Vec<Share<POLY>> = Vec::new();
for share in shares.into_iter() {
if share_length.is_none() {
share_length = Some(share.y.len());
}
if Some(share.y.len()) != share_length {
return Err("All shares must have the same length");
} else {
keys.insert(share.x.0);
values.push(share.clone());
}
}
if keys.is_empty() || (keys.len() < self.0 as usize) {
Err("Not enough shares to recover original secret")
} else {
Ok(math::interpolate(values.as_slice()))
}
}
pub fn recover_shares<'a, T>(&self, shares: T, n: usize) -> Result<Vec<Share<POLY>>, &str>
where
T: IntoIterator<Item = Option<&'a Share<POLY>>>,
T::IntoIter: Iterator<Item = Option<&'a Share<POLY>>>,
{
let mut share_length: Option<usize> = None;
let mut keys: HashSet<u8> = HashSet::new();
let mut values: Vec<Share<POLY>> = Vec::new();
let mut count = 0;
for share in shares.into_iter() {
if share.is_none() {
count += 1;
continue;
}
let share = share.unwrap();
if share_length.is_none() {
share_length = Some(share.y.len());
}
if Some(share.y.len()) != share_length {
return Err("All shares must have the same length");
} else {
keys.insert(share.x.0);
values.push(share.clone());
count += 1;
}
}
if count != n {
return Err("provide a shares array of size n; use None for unknown shares");
}
if keys.is_empty() || (keys.len() < self.0 as usize) {
Err("Not enough shares to recover original shares")
} else if self.0 == 1 {
Ok(values.iter().cloned().cycle().take(n).collect())
} else {
Ok((1..=n).map(|i| math::reshare(&values, i)).collect())
}
}
}
#[cfg(test)]
mod tests {
    use super::{SecretSharing, Share};
    use alloc::{vec, vec::Vec};

    /// Field polynomial shared by every test: 0x11d = x^8 + x^4 + x^3 + x^2 + 1.
    const POLY: u16 = 0x11d_u16;

    impl<const POLY: u16> SecretSharing<POLY> {
        /// Deals shares from a fixed-seed ChaCha8 generator when `std` is unavailable.
        #[cfg(not(feature = "std"))]
        fn make_shares(&self, secret: &[u8]) -> impl Iterator<Item = Share<POLY>> {
            use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
            let mut seeded = ChaCha8Rng::from_seed([0x90; 32]);
            self.dealer_rng(secret, &mut seeded)
        }

        /// Deals shares via the `std`-only `dealer` entry point.
        #[cfg(feature = "std")]
        fn make_shares(&self, secret: &[u8]) -> impl Iterator<Item = Share<POLY>> {
            self.dealer(secret)
        }
    }

    /// Asserts that `actual` and `expected` agree pairwise on both `x` and `y`,
    /// over the length of the shorter slice.
    fn assert_same_shares(actual: &[Share<POLY>], expected: &[Share<POLY>]) {
        for (got, want) in actual.iter().zip(expected) {
            assert_eq!(got.x, want.x);
            assert_eq!(got.y, want.y);
        }
    }

    #[test]
    fn test_insufficient_shares_err() {
        let sss = SecretSharing::<POLY>(255);
        let shares: Vec<Share<POLY>> = sss.make_shares(&[1]).take(254).collect();
        assert!(sss.recover(&shares).is_err());
    }

    #[test]
    fn test_duplicate_shares_err() {
        let sss = SecretSharing::<POLY>(255);
        let mut shares: Vec<Share<POLY>> = sss.make_shares(&[1]).take(255).collect();
        let copy_of_first = Share {
            x: shares[0].x.clone(),
            y: shares[0].y.clone(),
        };
        shares[1] = copy_of_first;
        assert!(sss.recover(&shares).is_err());
    }

    #[test]
    fn test_integration_works() {
        let sss = SecretSharing::<POLY>(255);
        let shares: Vec<Share<POLY>> = sss.make_shares(&[1, 2, 3, 4]).take(255).collect();
        assert_eq!(sss.recover(&shares).unwrap(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_reshare_works() {
        let sss = SecretSharing::<POLY>(3);
        let shares: Vec<Share<POLY>> = sss.make_shares(&[1, 2, 3, 4]).take(4).collect();

        // Each pattern hides exactly one share; three known shares meet the threshold.
        let one_missing = [
            [Some(&shares[0]), None, Some(&shares[2]), Some(&shares[3])],
            [None, Some(&shares[1]), Some(&shares[2]), Some(&shares[3])],
            [Some(&shares[0]), Some(&shares[1]), Some(&shares[2]), None],
        ];
        for pattern in one_missing {
            let rebuilt = sss.recover_shares(pattern, 4).unwrap();
            assert_eq!(rebuilt.len(), 4);
            assert_same_shares(&rebuilt, &shares);
        }

        // Two known shares are below the threshold of three.
        let two_missing = [Some(&shares[0]), None, None, Some(&shares[3])];
        assert!(sss.recover_shares(two_missing, 4).is_err());
    }
}