use crate::gf256::GF256;
use crate::interpolate::interpolate;
use crate::polynomial::Polynomial;
use crate::Field;
use rand::{thread_rng, Rng};
use subtle::ConstantTimeEq;
/// One participant's share of a secret produced by [`Shamir::split`].
///
/// Holds one `(x, y)` point per secret byte, stored in the same order as the
/// bytes of the secret they were derived from.
pub struct Share<Fq>
where
    Fq: Field,
{
    points: Vec<(Fq, Fq)>,
}
/// Stateless namespace for Shamir secret-sharing over GF(256): see
/// [`Shamir::split`] and [`Shamir::combine`].
pub struct Shamir;
impl Shamir {
    /// Splits `secret` into `n` shares such that any `k` of them reconstruct it.
    ///
    /// Each secret byte is shared independently: a polynomial of degree `k - 1`
    /// over GF(256) is built with the byte as its intercept (via
    /// `Polynomial::from_intercept`, assumed to draw random higher-order
    /// coefficients), and every share receives one `(x, y)` point on it.
    /// For a given byte, the `n` x-coordinates are distinct and non-zero; x = 0
    /// is excluded because the polynomial's value there is the secret byte.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `n` is not in `1..=255` (GF(256) has only 255 non-zero
    /// elements available as x-coordinates) or if `k` is not in `1..=n`.
    pub fn split(secret: &[u8], k: usize, n: usize) -> Result<Vec<Share<GF256>>, &'static str> {
        // n == 255 is valid: it uses every non-zero x-coordinate exactly once.
        if n == 0 || n > 255 {
            return Err("n must be between 1 and 255");
        } else if k == 0 || k > n {
            return Err("k must be between 1 and n");
        }

        let mut shares: Vec<Share<GF256>> = (0..n)
            .map(|_| Share {
                points: Vec::with_capacity(secret.len()),
            })
            .collect();

        // Fetch the RNG handle once rather than on every single draw.
        let mut rng = thread_rng();
        // x-coordinates already handed out for the current secret byte.
        let mut used_x = std::collections::HashSet::with_capacity(n);

        for &secret_byte in secret {
            let polynomial: Polynomial<GF256> =
                Polynomial::from_intercept(GF256(secret_byte), k - 1);
            for share in shares.iter_mut() {
                // Rejection-sample a non-zero, not-yet-used x. Because n <= 255 an
                // unused non-zero value always remains, so the loop terminates.
                let mut random_byte = rng.gen::<u8>();
                while used_x.contains(&random_byte) || random_byte.ct_eq(&0).into() {
                    random_byte = rng.gen::<u8>();
                }
                used_x.insert(random_byte);
                let x = GF256(random_byte);
                let y = polynomial.evaluate_at(x);
                share.points.push((x, y));
            }
            // x-coordinates only need to be distinct within one byte position.
            used_x.clear();
        }
        Ok(shares)
    }

    /// Reconstructs a secret from shares produced by [`Shamir::split`].
    ///
    /// For each byte position, the points at that position from every supplied
    /// share are passed to `interpolate` (presumably evaluating the interpolated
    /// polynomial at x = 0). The result is the secret byte for that position.
    /// At least `k` shares from the same split are needed to recover the
    /// original secret. Fewer shares produce unrelated bytes, not an error.
    ///
    /// An empty `shares` slice yields an empty `Vec`.
    ///
    /// # Panics
    ///
    /// Panics if the supplied shares do not all hold the same number of points.
    pub fn combine(shares: &[Share<GF256>]) -> Vec<u8> {
        let secret_size = match shares.first() {
            Some(first) => first.points.len(),
            None => return Vec::new(),
        };
        assert!(
            shares.iter().all(|share| share.points.len() == secret_size),
            "all shares must have the same length"
        );

        (0..secret_size)
            .map(|i| {
                let points: Vec<(GF256, GF256)> =
                    shares.iter().map(|share| share.points[i]).collect();
                interpolate(&points).0
            })
            .collect()
    }
}
#[cfg(test)]
mod test {
    use crate::shamir::Shamir;
    use rand::{thread_rng, Rng};

    /// Builds a secret of `size` uniformly random bytes.
    fn random_secret(size: usize) -> Vec<u8> {
        let mut rng = thread_rng();
        (0..size).map(|_| rng.gen::<u8>()).collect()
    }

    #[test]
    fn test_split_and_combine_10_of_20() {
        test_split_and_combine(32, 10, 20);
    }

    #[test]
    fn test_split_and_combine_1_of_2() {
        test_split_and_combine(32, 1, 2);
    }

    #[test]
    fn test_split_and_combine_2_of_2() {
        test_split_and_combine(32, 2, 2);
    }

    #[test]
    fn test_split_and_combine_any_k_shares() {
        test_split_and_combine_threshold_subset(32, 5, 10);
    }

    #[test]
    fn test_split_and_combine_insufficient() {
        test_split_and_combine_insufficient_shares(32, 5, 10);
    }

    #[test]
    fn test_split_invalid() {
        let random_secret = random_secret(32);
        let k_is_0 = Shamir::split(&random_secret, 0, 1);
        assert!(k_is_0.is_err());
        let n_is_0 = Shamir::split(&random_secret, 1, 0);
        assert!(n_is_0.is_err());
        let k_gt_n = Shamir::split(&random_secret, 10, 1);
        assert!(k_gt_n.is_err());
        let n_gt_255 = Shamir::split(&random_secret, 10, 256);
        assert!(n_gt_255.is_err());
        let k_and_n_gt_255 = Shamir::split(&random_secret, 256, 256);
        assert!(k_and_n_gt_255.is_err());
    }

    /// Splits a random secret and recombines it from all `n` shares.
    fn test_split_and_combine(size: usize, k: usize, n: usize) {
        let secret = random_secret(size);
        let shares = Shamir::split(&secret, k, n).unwrap();
        let combined = Shamir::combine(&shares);
        assert_eq!(combined, secret);
    }

    /// Splits a random secret and recombines it from exactly `k` shares (the
    /// last `k`), checking that any threshold-sized subset suffices.
    fn test_split_and_combine_threshold_subset(size: usize, k: usize, n: usize) {
        let secret = random_secret(size);
        let shares = Shamir::split(&secret, k, n).unwrap();
        let combined = Shamir::combine(&shares[n - k..]);
        assert_eq!(combined, secret);
    }

    /// Splits a random secret and checks that `k - 1` shares (one below the
    /// threshold) do not reconstruct it. A match here could only occur by
    /// chance, with probability 256^-size.
    fn test_split_and_combine_insufficient_shares(size: usize, k: usize, n: usize) {
        let secret = random_secret(size);
        let shares = Shamir::split(&secret, k, n).unwrap();
        let combined = Shamir::combine(&shares[..k - 1]);
        assert_ne!(combined, secret);
    }
}