relay-crypto 0.2.0-alpha.7

The crypto library for the Relay Ecosystem.
Documentation
use rand_core::{CryptoRng, RngCore};

use super::CryptoError;

/// Length in bytes of the secret passed to `split_secret`, and of every share's `value`.
pub const SECRET_LENGTH: usize = 32;
/// Length in bytes of the identifier that `Share::encode` appends after the share value.
pub const ID_LENGTH: usize = 16;
/// Total length in bytes of an encoded share: three one-byte header fields
/// (index, threshold, share count), then the share value, then the identifier.
pub const ENCODED_LENGTH: usize = 3 + SECRET_LENGTH + ID_LENGTH;

/// One share of a secret, as produced by `split_secret` and consumed by `combine_shares`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Share {
    /// Evaluation point of this share. `split_secret` assigns `1..=share_count`;
    /// `combine_shares` rejects a share whose index is 0.
    pub index: u8,
    /// Number of shares `combine_shares` interpolates over to recover the secret.
    pub threshold: u8,
    /// Total number of shares produced by the split this share came from.
    pub share_count: u8,
    /// Per-byte polynomial evaluations at `index`, one for each byte of the secret.
    pub value: [u8; SECRET_LENGTH],
}

impl Share {
    /// Serializes this share together with `id` into a fixed-size buffer.
    ///
    /// Layout: `index`, `threshold`, `share_count` (one byte each), followed by
    /// the `SECRET_LENGTH` bytes of `value`, followed by the `ID_LENGTH` bytes of `id`.
    pub fn encode(&self, id: [u8; ID_LENGTH]) -> [u8; ENCODED_LENGTH] {
        let header = [self.index, self.threshold, self.share_count];

        let mut encoded = [0u8; ENCODED_LENGTH];
        {
            // Carve the buffer into its three regions and fill each one.
            let (header_part, rest) = encoded.split_at_mut(header.len());
            let (value_part, id_part) = rest.split_at_mut(SECRET_LENGTH);
            header_part.copy_from_slice(&header);
            value_part.copy_from_slice(&self.value);
            id_part.copy_from_slice(&id);
        }
        encoded
    }

    /// Parses a buffer produced by [`Share::encode`] back into the share and its identifier.
    ///
    /// The header bytes are copied verbatim; no validation of the field values is performed.
    pub fn decode(encoded: &[u8; ENCODED_LENGTH]) -> (Self, [u8; ID_LENGTH]) {
        let (header, body) = encoded.split_at(3);
        let (value_bytes, id_bytes) = body.split_at(SECRET_LENGTH);

        let mut value = [0u8; SECRET_LENGTH];
        value.copy_from_slice(value_bytes);

        let mut id = [0u8; ID_LENGTH];
        id.copy_from_slice(id_bytes);

        let share = Self {
            index: header[0],
            threshold: header[1],
            share_count: header[2],
            value,
        };
        (share, id)
    }
}

/// Splits `secret` into `share_count` shares such that `threshold` of them
/// are used by `combine_shares` to recover it.
///
/// Every byte of the secret is shared independently: it becomes the constant
/// term of a polynomial with `threshold` coefficients over GF(2^8) whose other
/// coefficients are drawn from `rng`, and share `x` holds that polynomial
/// evaluated at `x` for `x` in `1..=share_count`.
///
/// # Errors
///
/// * `CryptoError::ThresholdTooLow` if `threshold` is 0.
/// * `CryptoError::ShareCountError` if `share_count` is 0.
/// * `CryptoError::ThresholdMoreThanShareCound` if `share_count < threshold`.
pub fn split_secret(
    mut rng: impl RngCore + CryptoRng,
    secret: [u8; SECRET_LENGTH],
    threshold: u8,
    share_count: u8,
) -> Result<Vec<Share>, CryptoError> {
    if threshold == 0 {
        return Err(CryptoError::ThresholdTooLow(u32::from(threshold)));
    }
    if share_count == 0 {
        return Err(CryptoError::ShareCountError(u32::from(share_count)));
    }
    if threshold > share_count {
        return Err(CryptoError::ThresholdMoreThanShareCound(
            u32::from(share_count),
            u32::from(threshold),
        ));
    }

    // polynomial[k][b] is the coefficient of x^k for secret byte b.
    // The constant term is the secret itself; the rest come from the RNG, lowest degree first.
    let mut polynomial = Vec::with_capacity(usize::from(threshold));
    polynomial.push(secret);
    for _ in 1..threshold {
        let mut random_coeff = [0u8; SECRET_LENGTH];
        rng.fill_bytes(&mut random_coeff);
        polynomial.push(random_coeff);
    }

    let shares = (1..=share_count)
        .map(|x| {
            let mut value = [0u8; SECRET_LENGTH];
            for (byte_idx, slot) in value.iter_mut().enumerate() {
                // Horner's rule from the highest-degree coefficient down to the constant term.
                *slot = polynomial.iter().rev().fold(0u8, |acc, coeff| {
                    gf256_add(gf256_mul(acc, x), coeff[byte_idx])
                });
            }
            Share {
                index: x,
                threshold,
                share_count,
                value,
            }
        })
        .collect();

    Ok(shares)
}

/// Recovers the secret from shares produced by `split_secret`.
///
/// All shares must agree on `threshold` and `share_count`, have a non-zero
/// index, and have pairwise distinct indices. Only the first `threshold`
/// shares of the slice take part in the reconstruction; any further shares
/// are validated but otherwise ignored.
///
/// Reconstruction is Lagrange interpolation at `x = 0` over GF(2^8), applied
/// independently to each of the `SECRET_LENGTH` bytes.
///
/// # Errors
///
/// * `CryptoError::NotEnoughShares` if `shares` is empty or holds fewer than
///   `threshold` shares.
/// * `CryptoError::InconsistentShares` if the shares disagree on `threshold`
///   or `share_count`, or any share has index 0.
/// * `CryptoError::ThresholdTooLow` if the shares carry a threshold of 0.
/// * `CryptoError::DuplicateShareIndex` if two shares have the same index.
pub fn combine_shares(shares: &[Share]) -> Result<[u8; SECRET_LENGTH], CryptoError> {
    let first = match shares.first() {
        Some(share) => share,
        None => return Err(CryptoError::NotEnoughShares(0, 0)),
    };

    let inconsistent = shares.iter().any(|share| {
        share.threshold != first.threshold
            || share.share_count != first.share_count
            || share.index == 0
    });
    if inconsistent {
        return Err(CryptoError::InconsistentShares);
    }

    // `split_secret` never emits a zero threshold. Without this check the
    // interpolation below would run over zero shares and hand back an
    // all-zero "secret" as if reconstruction had succeeded.
    if first.threshold == 0 {
        return Err(CryptoError::ThresholdTooLow(0));
    }

    let threshold = usize::from(first.threshold);
    if shares.len() < threshold {
        return Err(CryptoError::NotEnoughShares(
            shares.len() as u32,
            threshold as u32,
        ));
    }

    // Pairwise scan in slice order, so the first colliding share is the one reported.
    for (i, share) in shares.iter().enumerate() {
        if shares[i + 1..].iter().any(|other| other.index == share.index) {
            return Err(CryptoError::DuplicateShareIndex(share.index, share.index));
        }
    }

    let used = &shares[..threshold];

    // Lagrange basis polynomials evaluated at 0:
    //   l_i(0) = prod_{m != i} x_m / (x_m - x_i)
    // Subtraction in GF(2^8) is XOR. The weights depend only on the share
    // indices, so they are computed once rather than once per secret byte.
    // The denominator is never zero because indices are distinct.
    let weights: Vec<u8> = used
        .iter()
        .enumerate()
        .map(|(i, share)| {
            let mut num = 1u8;
            let mut den = 1u8;
            for (m, other) in used.iter().enumerate() {
                if m != i {
                    num = gf256_mul(num, other.index);
                    den = gf256_mul(den, gf256_add(other.index, share.index));
                }
            }
            gf256_mul(num, gf256_inv(den))
        })
        .collect();

    let mut secret = [0u8; SECRET_LENGTH];
    for (byte_idx, out_byte) in secret.iter_mut().enumerate() {
        *out_byte = used
            .iter()
            .zip(&weights)
            .fold(0u8, |acc, (share, &weight)| {
                gf256_add(acc, gf256_mul(share.value[byte_idx], weight))
            });
    }

    Ok(secret)
}

/// Adds two elements of GF(2^8).
///
/// Addition in a field of characteristic 2 is bitwise XOR, which also makes
/// it its own inverse (addition and subtraction coincide).
fn gf256_add(a: u8, b: u8) -> u8 {
    a ^ b
}

/// Multiplies two elements of GF(2^8) modulo x^8 + x^4 + x^3 + x + 1 (0x11B).
///
/// This is shift-and-add multiplication written without data-dependent
/// branches: the operands are secret or share bytes, so the conditional
/// add and conditional reduction are done with all-ones/all-zeros masks
/// rather than `if` statements to avoid an obvious timing side channel.
/// (The compiler is still free to choose its own code generation.)
fn gf256_mul(mut a: u8, mut b: u8) -> u8 {
    let mut product = 0u8;

    for _ in 0..8 {
        // 0xFF when the current low bit of `b` is set, 0x00 otherwise.
        let add_mask = (b & 1).wrapping_neg();
        product ^= a & add_mask;

        // 0xFF when doubling `a` overflows past x^7 and must be reduced.
        let reduce_mask = (a >> 7).wrapping_neg();
        a = (a << 1) ^ (0x1B & reduce_mask);

        b >>= 1;
    }

    product
}

/// Raises `a` to the power `e` in GF(2^8).
///
/// Square-and-multiply, scanning the exponent from its most significant bit
/// down: the accumulator is squared for every bit and multiplied by `a`
/// whenever the bit is set. `e == 0` yields 1 for every `a`, including 0.
fn gf256_pow(a: u8, e: u8) -> u8 {
    (0..8).rev().fold(1u8, |acc, bit| {
        let squared = gf256_mul(acc, acc);
        if (e >> bit) & 1 == 1 {
            gf256_mul(squared, a)
        } else {
            squared
        }
    })
}

/// Returns the multiplicative inverse of `a` in GF(2^8).
///
/// Every non-zero element satisfies `a^255 == 1`, so `a^254` is its inverse.
///
/// # Panics
///
/// Panics if `a` is 0, which has no inverse.
fn gf256_inv(a: u8) -> u8 {
    // Order of the multiplicative group (255) minus one.
    const INVERSE_EXPONENT: u8 = 254;

    assert!(a != 0);
    gf256_pow(a, INVERSE_EXPONENT)
}

// Unit tests for splitting, recombining and encoding shares.
// Secrets and polynomial coefficients come from the operating system RNG,
// so each run exercises different values.
#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::OsRng;

    /// Returns a fresh `SECRET_LENGTH`-byte secret filled from `OsRng`.
    fn random_secret() -> [u8; SECRET_LENGTH] {
        let mut rng = OsRng;
        let mut s = [0u8; SECRET_LENGTH];
        rng.fill_bytes(&mut s);
        s
    }

    /// The first three shares of a 3-of-5 split recover the secret.
    #[test]
    fn split_and_recombine_3_of_5() {
        let mut rng = OsRng;
        let secret = random_secret();

        let shares = split_secret(&mut rng, secret, 3, 5).unwrap();

        let recovered = combine_shares(&shares[0..3]).unwrap();
        assert_eq!(secret, recovered);
    }

    /// Every 3-element subset of a 3-of-5 split recovers the secret.
    #[test]
    fn recombinations_3_of_5() {
        let mut rng = OsRng;
        let secret = random_secret();

        let shares = split_secret(&mut rng, secret, 3, 5).unwrap();

        // Enumerate all index triples i < j < k.
        for i in 0..5 {
            for j in (i + 1)..5 {
                for k in (j + 1)..5 {
                    let combo = vec![shares[i].clone(), shares[j].clone(), shares[k].clone()];
                    let recovered = combine_shares(&combo).unwrap();
                    assert_eq!(secret, recovered);
                }
            }
        }
    }

    /// Supplying fewer shares than the threshold is an error.
    #[test]
    fn rejects_not_enough_shares() {
        let mut rng = OsRng;
        let secret = random_secret();

        let shares = split_secret(&mut rng, secret, 3, 5).unwrap();

        let result = combine_shares(&shares[0..2]);
        assert!(result.is_err());
    }

    /// Two shares carrying the same index are rejected.
    #[test]
    fn fails_with_duplicate_indices() {
        let mut rng = OsRng;
        let secret = random_secret();

        let mut shares = split_secret(&mut rng, secret, 3, 5).unwrap();

        // Forge a collision between the first two shares.
        shares[1].index = shares[0].index;

        let result = combine_shares(&shares[0..3]);
        assert!(result.is_err());
    }

    /// A share whose threshold differs from the others is rejected.
    #[test]
    fn rejects_inconsistent_threshold() {
        let mut rng = OsRng;
        let secret = random_secret();

        let mut shares = split_secret(&mut rng, secret, 3, 5).unwrap();

        shares[0].threshold = 2;

        let result = combine_shares(&shares[0..3]);
        assert!(result.is_err());
    }

    /// A share whose share count differs from the others is rejected.
    #[test]
    fn rejects_inconsistent_share_count() {
        let mut rng = OsRng;
        let secret = random_secret();

        let mut shares = split_secret(&mut rng, secret, 3, 5).unwrap();

        shares[0].share_count = 9;

        let result = combine_shares(&shares[0..3]);
        assert!(result.is_err());
    }

    /// Repeats a 4-of-7 split with fresh secrets, recombining from shares 2 through 5.
    #[test]
    fn multiple_random_runs() {
        let mut rng = OsRng;

        for _ in 0..50 {
            let secret = random_secret();
            let shares = split_secret(&mut rng, secret, 4, 7).unwrap();

            // Deliberately skip the first share so reconstruction does not rely on index 1.
            let recovered = combine_shares(&shares[1..5]).unwrap();
            assert_eq!(secret, recovered);
        }
    }

    /// `Share::decode` inverts `Share::encode` for both the share and the identifier.
    #[test]
    fn encode_decode() {
        let mut rng = OsRng;

        let mut value = [0u8; SECRET_LENGTH];
        rng.fill_bytes(&mut value);
        let mut id = [0u8; ID_LENGTH];
        rng.fill_bytes(&mut id);

        let share = Share {
            index: 3,
            threshold: 4,
            share_count: 7,
            value,
        };

        let encoded = share.encode(id);
        let (decoded_share, decoded_id) = Share::decode(&encoded);
        assert_eq!(share, decoded_share);
        assert_eq!(id, decoded_id);
    }
}