use crate::error::{Result, ThresholdError};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use zeroize::{Zeroize, ZeroizeOnDrop};
#[derive(Clone, Serialize, Deserialize, ZeroizeOnDrop)]
pub struct Share {
index: u8,
value: Vec<u8>,
}
impl Share {
pub fn new(index: u8, value: Vec<u8>) -> Self {
Self { index, value }
}
pub fn index(&self) -> u8 {
self.index
}
pub fn value(&self) -> &[u8] {
&self.value
}
pub fn value_mut(&mut self) -> &mut [u8] {
&mut self.value
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(1 + self.value.len());
bytes.push(self.index);
bytes.extend_from_slice(&self.value);
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
if bytes.is_empty() {
return Err(ThresholdError::InvalidShareFormat);
}
Ok(Self {
index: bytes[0],
value: bytes[1..].to_vec(),
})
}
}
impl std::fmt::Debug for Share {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Share(index={}, {} bytes)", self.index, self.value.len())
}
}
pub struct ShamirScheme;
impl ShamirScheme {
#[must_use = "secret sharing result must be checked for errors"]
pub fn split(secret: &[u8], threshold: usize, total: usize) -> Result<Vec<Share>> {
if threshold == 0 || threshold > total {
return Err(ThresholdError::InvalidThreshold { threshold, total });
}
if total > 255 {
return Err(ThresholdError::InvalidThreshold { threshold, total });
}
if secret.is_empty() {
return Err(ThresholdError::InvalidShareFormat);
}
let mut shares: Vec<Share> = (1..=total as u8)
.map(|i| Share::new(i, vec![0u8; secret.len()]))
.collect();
let mut rng = rand::rngs::OsRng;
let mut coeffs = vec![0u8; threshold];
for (byte_idx, &secret_byte) in secret.iter().enumerate() {
coeffs[0] = secret_byte;
rng.fill_bytes(&mut coeffs[1..]);
for share in &mut shares {
let x = share.index;
let y = evaluate_polynomial(&coeffs, x);
share.value[byte_idx] = y;
}
}
coeffs.zeroize();
Ok(shares)
}
#[must_use = "secret reconstruction result must be checked for errors"]
pub fn combine(shares: &[Share]) -> Result<Vec<u8>> {
if shares.is_empty() {
return Err(ThresholdError::InsufficientShares {
required: 1,
provided: 0,
});
}
let mut seen = [false; 256];
for share in shares {
if seen[share.index as usize] {
return Err(ThresholdError::DuplicateShareIndex {
index: share.index as usize,
});
}
seen[share.index as usize] = true;
}
let len = shares[0].value.len();
if !shares.iter().all(|s| s.value.len() == len) {
return Err(ThresholdError::InvalidShareFormat);
}
let mut secret = vec![0u8; len];
for byte_idx in 0..len {
let points: Vec<(u8, u8)> = shares
.iter()
.map(|s| (s.index, s.value[byte_idx]))
.collect();
secret[byte_idx] = lagrange_interpolate(&points, 0);
}
Ok(secret)
}
pub fn verify_shares(shares: &[Share], threshold: usize) -> bool {
if shares.len() < threshold {
return false;
}
if let Ok(secret1) = Self::combine(&shares[..threshold]) {
if shares.len() > threshold
&& let Ok(secret2) = Self::combine(&shares[1..=threshold])
{
return secret1 == secret2;
}
return true;
}
false
}
}
fn evaluate_polynomial(coeffs: &[u8], x: u8) -> u8 {
if x == 0 {
return coeffs[0];
}
let mut result = 0u8;
let mut x_power = 1u8;
for &coeff in coeffs {
result = gf256_add(result, gf256_mul(coeff, x_power));
x_power = gf256_mul(x_power, x);
}
result
}
fn lagrange_interpolate(points: &[(u8, u8)], x: u8) -> u8 {
let mut result = 0u8;
for (i, &(xi, yi)) in points.iter().enumerate() {
let mut numerator = 1u8;
let mut denominator = 1u8;
for (j, &(xj, _)) in points.iter().enumerate() {
if i != j {
numerator = gf256_mul(numerator, gf256_sub(x, xj));
denominator = gf256_mul(denominator, gf256_sub(xi, xj));
}
}
let term = gf256_mul(yi, gf256_mul(numerator, gf256_inv(denominator)));
result = gf256_add(result, term);
}
result
}
#[inline]
fn gf256_add(a: u8, b: u8) -> u8 {
a ^ b
}
#[inline]
fn gf256_sub(a: u8, b: u8) -> u8 {
a ^ b
}
fn gf256_mul(a: u8, b: u8) -> u8 {
let mut result = 0u8;
let mut a = a;
let mut b = b;
while b != 0 {
if b & 1 != 0 {
result ^= a;
}
let carry = a & 0x80 != 0;
a <<= 1;
if carry {
a ^= 0x1b; }
b >>= 1;
}
result
}
fn gf256_inv(a: u8) -> u8 {
if a == 0 {
return 0; }
let mut result = a;
for _ in 0..6 {
result = gf256_mul(result, result);
result = gf256_mul(result, a);
}
gf256_mul(result, result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_split_and_combine() {
let secret = b"Hello, Shamir!";
let shares = ShamirScheme::split(secret, 3, 5).unwrap();
assert_eq!(shares.len(), 5);
let recovered = ShamirScheme::combine(&shares[..3]).unwrap();
assert_eq!(secret.as_slice(), recovered.as_slice());
let recovered = ShamirScheme::combine(&shares).unwrap();
assert_eq!(secret.as_slice(), recovered.as_slice());
}
#[test]
fn test_different_share_subsets() {
let secret = b"test secret";
let shares = ShamirScheme::split(secret, 3, 5).unwrap();
let r1 = ShamirScheme::combine(&[shares[0].clone(), shares[1].clone(), shares[2].clone()])
.unwrap();
let r2 = ShamirScheme::combine(&[shares[0].clone(), shares[2].clone(), shares[4].clone()])
.unwrap();
let r3 = ShamirScheme::combine(&[shares[1].clone(), shares[3].clone(), shares[4].clone()])
.unwrap();
assert_eq!(secret.as_slice(), r1.as_slice());
assert_eq!(secret.as_slice(), r2.as_slice());
assert_eq!(secret.as_slice(), r3.as_slice());
}
#[test]
fn test_threshold_2_of_3() {
let secret = b"2-of-3";
let shares = ShamirScheme::split(secret, 2, 3).unwrap();
let recovered = ShamirScheme::combine(&shares[..2]).unwrap();
assert_eq!(secret.as_slice(), recovered.as_slice());
}
#[test]
fn test_threshold_1_of_n() {
let secret = b"no security";
let shares = ShamirScheme::split(secret, 1, 5).unwrap();
let recovered = ShamirScheme::combine(&shares[..1]).unwrap();
assert_eq!(secret.as_slice(), recovered.as_slice());
}
#[test]
fn test_invalid_threshold() {
let secret = b"test";
assert!(ShamirScheme::split(secret, 0, 5).is_err());
assert!(ShamirScheme::split(secret, 6, 5).is_err());
}
#[test]
fn test_duplicate_indices() {
let share1 = Share::new(1, vec![1, 2, 3]);
let share2 = Share::new(1, vec![4, 5, 6]);
assert!(ShamirScheme::combine(&[share1, share2]).is_err());
}
#[test]
fn test_share_serialization() {
let share = Share::new(42, vec![1, 2, 3, 4, 5]);
let bytes = share.to_bytes();
let restored = Share::from_bytes(&bytes).unwrap();
assert_eq!(share.index, restored.index);
assert_eq!(share.value, restored.value);
}
#[test]
fn test_gf256_operations() {
assert_eq!(gf256_mul(1, 42), 42);
assert_eq!(gf256_mul(42, 1), 42);
for a in 1..=255u8 {
let inv = gf256_inv(a);
assert_eq!(gf256_mul(a, inv), 1, "Failed for a={}", a);
}
}
#[test]
fn test_large_secret() {
let secret: Vec<u8> = (0..1000).map(|i| i as u8).collect();
let shares = ShamirScheme::split(&secret, 5, 10).unwrap();
let recovered = ShamirScheme::combine(&shares[..5]).unwrap();
assert_eq!(secret, recovered);
}
#[test]
fn test_verify_shares() {
let secret = b"verify me";
let shares = ShamirScheme::split(secret, 3, 5).unwrap();
assert!(ShamirScheme::verify_shares(&shares, 3));
}
}