1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
use crate::gf256::GF256;
use crate::interpolate::interpolate;
use crate::polynomial::Polynomial;
use crate::Field;
use rand::{thread_rng, Rng};
use subtle::ConstantTimeEq;

/// One participant's share of a secret.
///
/// A share stores one `(x, y)` point per byte of the original secret, kept in the same
/// order as the secret's bytes: the point at index `i` lies on the polynomial that was
/// used to share byte `i`.
pub struct Share<Fq>
where
    Fq: Field,
{
    points: Vec<(Fq, Fq)>,
}

/// Namespace for the Shamir Secret Sharing operations [`Shamir::split`] and
/// [`Shamir::combine`], which operate on shares over GF(256).
///
/// This unit struct carries no state; its functionality is exposed purely as associated
/// functions, e.g. `Shamir::split(...)`.
pub struct Shamir;

impl Shamir {
    /// Splits `secret` into `n` shares, any `k` of which are sufficient to recover it.
    ///
    /// Each byte of the secret is shared independently. A polynomial of degree `k - 1`
    /// over GF(256) is built with the byte as its intercept (via
    /// `Polynomial::from_intercept`, which presumably fills the other coefficients
    /// randomly — see `crate::polynomial`). Every share then receives one point `(x, y)`
    /// on that polynomial.
    ///
    /// For each byte the `x` coordinates are drawn at random, subject to two rules:
    /// - They are never zero, because evaluating at zero yields the intercept, i.e. the
    ///   secret byte itself.
    /// - They are distinct across shares, because duplicate `x` values would make
    ///   interpolation impossible.
    ///
    /// Secrets of any length are supported; an empty secret yields `n` empty shares.
    ///
    /// # Errors
    ///
    /// Returns an error unless `0 < k <= n <= 255`. GF(256) has only 255 non-zero
    /// elements, so at most 255 shares can be given distinct `x` coordinates.
    pub fn split(secret: &[u8], k: usize, n: usize) -> Result<Vec<Share<GF256>>, &'static str> {
        // 255 shares is the true maximum (one per non-zero field element), matching
        // both the documented range and the error message.
        if n == 0 || n > 255 {
            return Err("n must be between 1 and 255");
        } else if k == 0 || k > n {
            return Err("k must be between 1 and n");
        }

        let mut shares: Vec<Share<GF256>> = (0..n)
            .map(|_| Share {
                points: Vec::with_capacity(secret.len()),
            })
            .collect();

        let mut rng = thread_rng();
        for &secret_byte in secret {
            let polynomial: Polynomial<GF256> =
                Polynomial::from_intercept(GF256(secret_byte), k - 1);

            // Which x coordinates have already been handed out for this byte.
            let mut used = [false; 256];
            for share in shares.iter_mut() {
                let mut x_byte = rng.gen::<u8>();
                // Reject x = 0 and any x already used by another share for this byte.
                // Because n <= 255, at most 254 values are taken when the last share is
                // drawn, so a valid candidate always remains and the loop terminates.
                while used[usize::from(x_byte)] || x_byte.ct_eq(&0).into() {
                    x_byte = rng.gen::<u8>();
                }
                used[usize::from(x_byte)] = true;

                let x = GF256(x_byte);
                let y = polynomial.evaluate_at(x);
                share.points.push((x, y));
            }
        }

        Ok(shares)
    }

    // TODO: Find a more efficient means of combining
    /// Recovers a secret from a slice of shares, returning it as a vector of bytes.
    ///
    /// For every byte position `i`, the `i`-th point of each share is collected and passed
    /// to `interpolate`. The resulting field element becomes output byte `i`.
    ///
    /// At least `k` shares produced by the same `split` call must be supplied. With fewer,
    /// the output will, with overwhelming probability, differ from the secret, and this
    /// cannot be detected here.
    ///
    /// An empty `shares` slice yields an empty vector.
    ///
    /// # Panics
    ///
    /// Panics if the shares do not all hold the same number of points (i.e. they were not
    /// produced from the same secret).
    pub fn combine(shares: &[Share<GF256>]) -> Vec<u8> {
        let secret_size = shares.first().map_or(0, |share| share.points.len());
        assert!(
            shares.iter().all(|share| share.points.len() == secret_size),
            "all shares must contain the same number of points"
        );

        // Reused across byte positions to avoid one allocation per secret byte.
        let mut points: Vec<(GF256, GF256)> = Vec::with_capacity(shares.len());
        let mut result = Vec::with_capacity(secret_size);
        for i in 0..secret_size {
            points.clear();
            points.extend(shares.iter().map(|share| share.points[i]));
            result.push(interpolate(&points).0);
        }

        result
    }
}

#[cfg(test)]
mod test {
    use super::Shamir;
    use rand::{thread_rng, Rng};

    /// Returns `size` random bytes to use as a secret.
    fn random_secret(size: usize) -> Vec<u8> {
        let mut rng = thread_rng();
        (0..size).map(|_| rng.gen::<u8>()).collect()
    }

    #[test]
    fn test_split_and_combine_10_of_20() {
        test_split_and_combine(32, 10, 20);
    }

    #[test]
    fn test_split_and_combine_1_of_2() {
        test_split_and_combine(32, 1, 2);
    }

    #[test]
    fn test_split_and_combine_2_of_2() {
        test_split_and_combine(32, 2, 2);
    }

    #[test]
    fn test_split_and_combine_empty_secret() {
        test_split_and_combine(0, 2, 3);
    }

    #[test]
    fn test_split_and_combine_insufficient() {
        test_split_and_combine_insufficient_shares(32, 5, 10);
    }

    #[test]
    fn test_split_invalid() {
        let secret = random_secret(32);

        assert!(Shamir::split(&secret, 0, 1).is_err(), "k = 0 must be rejected");
        assert!(Shamir::split(&secret, 1, 0).is_err(), "n = 0 must be rejected");
        assert!(Shamir::split(&secret, 10, 1).is_err(), "k > n must be rejected");
        assert!(Shamir::split(&secret, 10, 256).is_err(), "n > 255 must be rejected");
        assert!(
            Shamir::split(&secret, 256, 256).is_err(),
            "k and n > 255 must be rejected"
        );
    }

    /// Splits a random secret of `size` bytes into `n` shares with threshold `k` and
    /// checks that combining all of them reproduces the secret.
    fn test_split_and_combine(size: usize, k: usize, n: usize) {
        let secret = random_secret(size);

        let shares = Shamir::split(&secret, k, n).unwrap();
        let combined = Shamir::combine(&shares);

        assert_eq!(combined, secret);
    }

    /// Splits a random secret and checks that combining only `k - 1` shares (one fewer
    /// than the threshold) does not reproduce it.
    ///
    /// A spurious failure needs every byte to coincide by chance, roughly 1/256 per byte,
    /// which is negligible for a 32-byte secret.
    fn test_split_and_combine_insufficient_shares(size: usize, k: usize, n: usize) {
        assert!(k >= 2, "need k >= 2 for a non-empty insufficient subset of shares");
        let secret = random_secret(size);

        let shares = Shamir::split(&secret, k, n).unwrap();
        let combined = Shamir::combine(&shares[..k - 1]);

        assert_ne!(combined, secret);
    }
}