use crate::{
errors::SSSError,
operations::{
big_nums_to_bytes, bytes_to_big_nums, evaluate, random, DEFAULT_PRIME,
U8S_TO_BIG_INT_INITIAL,
},
};
use openssl::bn::{BigNum, BigNumContext};
const COEFFICIENTS_PER_SHARE: usize = 2;
const COEFFICIENTS_SIZE: usize = COEFFICIENTS_PER_SHARE * U8S_TO_BIG_INT_INITIAL;
#[inline(always)]
pub(crate) fn create_shares(
min: usize,
shares: usize,
secret: &[u8],
) -> Result<Vec<Vec<u8>>, SSSError> {
if min > shares {
return Err(SSSError::WithReason(
"Minimum value cannot be bigger then total shares.".to_owned(),
));
}
let mut ctx = BigNumContext::new()?;
let prime = BigNum::from_dec_str(DEFAULT_PRIME)?;
let secret = bytes_to_big_nums(secret)?;
let mut polynomial: Vec<Vec<BigNum>> = Vec::with_capacity(secret.len());
for part in secret.iter() {
let mut coefficients = Vec::with_capacity(min - 1);
let temp = BigNum::from_slice(&part.to_vec())?;
coefficients.push(temp);
for _ in 1..min {
coefficients.push(random(&prime)?);
}
polynomial.push(coefficients);
}
let mut results: Vec<Vec<u8>> = Vec::with_capacity(shares);
for _ in 0..shares {
let mut bytes: Vec<u8> = Vec::with_capacity(secret.len() * COEFFICIENTS_SIZE);
let mut counter = 0;
while counter < secret.len() {
let coefficient_x = random(&prime)?;
let coefficient_y = evaluate(&mut ctx, &polynomial[counter], &coefficient_x, &prime)?;
let coefficients: &[BigNum; 2] = &[coefficient_x, coefficient_y];
bytes.extend(big_nums_to_bytes(coefficients));
counter += 1;
}
results.push(bytes);
}
Ok(results)
}
#[inline(always)]
pub(crate) fn combine_shares(shares: Vec<Vec<u8>>) -> Result<Vec<u8>, SSSError> {
let mut ctx = BigNumContext::new()?;
let negative_one = BigNum::from_dec_str("-1")?;
let prime = BigNum::from_dec_str(DEFAULT_PRIME)?;
let mut shares_polynomials: Vec<Vec<Vec<BigNum>>> = Vec::with_capacity(shares.len());
let mut first_share_count = None;
for share in shares.iter() {
if share.len() % COEFFICIENTS_SIZE != 0 {
return Err(SSSError::WithReason(format!(
"Share size is not divisible by {COEFFICIENTS_SIZE}"
)));
}
let share_count = share.len() / COEFFICIENTS_SIZE;
if let Some(first_share_count) = first_share_count {
if share_count != first_share_count {
return Err(SSSError::WithReason(format!(
"All shares shall have the same size of {first_share_count}"
)));
}
} else {
first_share_count = Some(share_count);
}
let mut polynomials: Vec<Vec<BigNum>> = Vec::with_capacity(share_count);
for i in 0..share_count {
let polynomial =
bytes_to_big_nums(&share[i * COEFFICIENTS_SIZE..(i + 1) * COEFFICIENTS_SIZE])?;
polynomials.push(polynomial);
}
shares_polynomials.push(polynomials);
}
let share_count = first_share_count.unwrap_or_default();
let mut pre_secret_coeffisiances: Vec<BigNum> = Vec::with_capacity(share_count);
for j in 0..share_count {
let mut candidate = BigNum::from_dec_str("0")?;
for (i, polys_i) in shares_polynomials.iter().enumerate() {
let origin = &polys_i[j][0];
let origin_y = &polys_i[j][1];
let mut numerator = BigNum::from_dec_str("1")?;
let mut denominator = BigNum::from_dec_str("1")?;
'k_iter: for (k, polys_k) in shares_polynomials.iter().enumerate() {
if k == i {
continue 'k_iter;
}
let current = &polys_k[j][0];
let mut negative = BigNum::from_dec_str("0")?;
negative.checked_mul(&negative_one, current, &mut ctx)?;
let mut added = BigNum::from_dec_str("0")?;
added.checked_sub(origin, current)?;
let mut temp = BigNum::new()?;
temp.checked_mul(&numerator, &negative, &mut ctx)?;
numerator.nnmod(&temp, &prime, &mut ctx)?;
let mut temp = BigNum::new()?;
temp.checked_mul(&denominator, &added, &mut ctx)?;
denominator.nnmod(&temp, &prime, &mut ctx)?;
}
let mut working = BigNum::from_dec_str("0")?;
working.checked_mul(origin_y, &numerator, &mut ctx)?;
let mut temp = BigNum::new()?;
temp.mod_inverse(&denominator, &prime, &mut ctx)?;
denominator = temp;
let mut temp = BigNum::new()?;
temp.checked_mul(&working, &denominator, &mut ctx)?;
working = temp;
let mut temp = BigNum::new()?;
temp.checked_add(&candidate, &working)?;
candidate.nnmod(&temp, &prime, &mut ctx)?;
}
pre_secret_coeffisiances.push(candidate);
}
Ok(big_nums_to_bytes(&pre_secret_coeffisiances))
}
#[cfg(test)]
mod tests {
use crate::{
combine_std, create_std,
errors::SSSError,
shamirss::{combine_shares, create_shares},
};
use openssl::rand::rand_bytes;
use rand::seq::SliceRandom;
use rand::thread_rng;
fn get_random_bytes(size: usize) -> Result<Vec<u8>, SSSError> {
let mut bytes = vec![0; size];
rand_bytes(&mut bytes)?;
Ok(bytes)
}
#[test]
fn it_should_create_shares_and_combine_shares_for_min_half_and_all_shares(
) -> Result<(), SSSError> {
struct TestCase<'a> {
min: usize,
shares: usize,
secret: &'a [u8],
fails: bool,
}
let test_cases: &[TestCase; 12] = &[
TestCase {
min: 4,
shares: 5,
secret: &get_random_bytes(128)?,
fails: false,
},
TestCase {
min: 10,
shares: 20,
secret: &get_random_bytes(256)?,
fails: false,
},
TestCase {
min: 20,
shares: 40,
secret: &get_random_bytes(512)?,
fails: false,
},
TestCase {
min: 40,
shares: 41,
secret: &get_random_bytes(128)?,
fails: false,
},
TestCase {
min: 40,
shares: 41,
secret: &get_random_bytes(256)?,
fails: false,
},
TestCase {
min: 40,
shares: 41,
secret: &get_random_bytes(512)?,
fails: false,
},
TestCase {
min: 40,
shares: 80,
secret: &get_random_bytes(128)?,
fails: false,
},
TestCase {
min: 40,
shares: 80,
secret: &get_random_bytes(256)?,
fails: false,
},
TestCase {
min: 40,
shares: 80,
secret: &get_random_bytes(512)?,
fails: false,
},
TestCase {
min: 100,
shares: 200,
secret: &get_random_bytes(128)?,
fails: false,
},
TestCase {
min: 100,
shares: 200,
secret: &get_random_bytes(256)?,
fails: false,
},
TestCase {
min: 100,
shares: 200,
secret: &get_random_bytes(512)?,
fails: false,
},
];
for c in test_cases.iter() {
let result = create_shares(c.min, c.shares, c.secret);
let mut shares = match result {
Ok(s) => {
assert!(!c.fails);
s
}
Err(_) => {
assert!(c.fails);
vec![]
}
};
if shares.is_empty() {
continue;
}
let mut rng = thread_rng();
shares.shuffle(&mut rng);
let secret = combine_shares(shares[0..c.min - 1].to_vec())?;
assert_ne!(secret, c.secret);
let secret = combine_shares(shares[0..c.min].to_vec())?;
assert_eq!(secret, c.secret);
let secret = combine_shares(shares[0..c.min + (c.shares - c.min) / 2].to_vec())?;
assert_eq!(secret, c.secret);
let secret = combine_shares(shares)?;
assert_eq!(secret, c.secret);
}
Ok(())
}
#[test]
fn it_should_create_shares_and_combine_shares_randomized_data() -> Result<(), SSSError> {
let mut size = 32;
let min = 50;
let shares_count = 100;
for _ in 0..8 {
let secret = get_random_bytes(size)?;
let secret_shares = create_std(min, shares_count, &secret)?;
let secret_decoded = combine_std(secret_shares)?;
assert_eq!(secret, secret_decoded);
size *= 2;
}
Ok(())
}
}