relay-crypto 0.2.0-beta.4

The crypto library for the Relay Ecosystem.
Documentation
use rand_core::{CryptoRng, RngCore};

/// Length in bytes of a secret, and of the `value` carried by every [`Share`].
pub const SECRET_LENGTH: usize = 32;
/// Length in bytes of the identifier that accompanies an encoded share.
pub const ID_LENGTH: usize = 16;
/// Total length in bytes of an encoded share: three header bytes (index,
/// threshold, share count), then the share value, then the identifier.
pub const ENCODED_LENGTH: usize = 3 + SECRET_LENGTH + ID_LENGTH;

/// Errors reported when splitting a secret into shares or recombining shares.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum ShamirError {
    /// A threshold value was rejected; carries the offending value.
    #[error("Threshold must be more than 1, but got {0}")]
    InvalidThreshold(u32),
    /// A share count was rejected; carries the offending value.
    #[error("Share count must be between 1 and 255, but got {0}")]
    InvalidShareCount(u32),
    /// Fewer shares were requested than the threshold needed to recombine
    /// them. Fields are `(share_count, threshold)`.
    #[error("Share count {0} must be at least the threshold {1}")]
    LessSharesThanThreshhold(u32, u32),
    /// Too few shares were supplied for reconstruction. Fields are
    /// `(provided, threshold)`; an empty input is reported as `(0, 0)`.
    #[error(
        "Not enough shares to reconstruct the secret. Provided {0} shares, but threshold is {1}"
    )]
    NotEnoughShares(u32, u32),
    /// The supplied shares disagree on their threshold or share count, or one
    /// of them carries the index `0`.
    #[error("Inconsistent shares provided")]
    InconsistentShares,
    /// Two of the supplied shares carry the same index; the fields hold the
    /// index of each of the two colliding shares.
    #[error("Duplicate share index: {0} and {1}")]
    DuplicateShareIndex(u8, u8),
}

/// One share of a split secret, together with the parameters it was split with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Share {
    /// The x-coordinate at which the sharing polynomial was evaluated.
    /// `split_secret` assigns `1..=share_count`; `combine_shares` rejects `0`.
    pub index: u8,
    /// Number of shares required to reconstruct the secret.
    pub threshold: u8,
    /// Total number of shares that were produced in the same split.
    pub share_count: u8,
    /// The polynomial evaluations, one per byte of the secret.
    pub value: [u8; SECRET_LENGTH],
}

impl Share {
    /// Serialises this share together with `id` into a fixed-size buffer.
    ///
    /// Layout: `index`, `threshold`, `share_count` (one byte each), followed
    /// by the `SECRET_LENGTH` bytes of `value` and the `ID_LENGTH` bytes of `id`.
    pub fn encode(&self, id: [u8; ID_LENGTH]) -> [u8; ENCODED_LENGTH] {
        let mut buf = [0u8; ENCODED_LENGTH];

        {
            // Carve the buffer into its three regions and fill each in turn.
            let (header, body) = buf.split_at_mut(3);
            let (value_part, id_part) = body.split_at_mut(SECRET_LENGTH);

            header.copy_from_slice(&[self.index, self.threshold, self.share_count]);
            value_part.copy_from_slice(&self.value);
            id_part.copy_from_slice(&id);
        }

        buf
    }

    /// Parses a buffer produced by [`Share::encode`] back into the share and
    /// the identifier that was stored alongside it.
    ///
    /// No validation is performed on the header bytes; they are copied into
    /// the returned [`Share`] verbatim.
    pub fn decode(encoded: &[u8; ENCODED_LENGTH]) -> (Self, [u8; ID_LENGTH]) {
        let (header, body) = encoded.split_at(3);
        let (value_part, id_part) = body.split_at(SECRET_LENGTH);

        let mut value = [0u8; SECRET_LENGTH];
        value.copy_from_slice(value_part);

        let mut id = [0u8; ID_LENGTH];
        id.copy_from_slice(id_part);

        let share = Self {
            index: header[0],
            threshold: header[1],
            share_count: header[2],
            value,
        };

        (share, id)
    }
}

/// Splits `secret` into `share_count` shares such that any `threshold` of them
/// reconstruct it via [`combine_shares`].
///
/// Each byte of the secret is shared independently: it becomes the constant
/// term of a polynomial of degree `threshold - 1` over GF(2^8) whose remaining
/// coefficients are drawn from `rng`, and share `x` (for `x` in
/// `1..=share_count`) holds the polynomial evaluated at `x`.
///
/// # Errors
///
/// * [`ShamirError::InvalidThreshold`] if `threshold` is less than 2. A
///   threshold of 1 would yield a constant polynomial, i.e. every share's
///   value would be the secret itself.
/// * [`ShamirError::InvalidShareCount`] if `share_count` is 0.
/// * [`ShamirError::LessSharesThanThreshhold`] if `share_count < threshold`.
pub fn split_secret(
    mut rng: impl RngCore + CryptoRng,
    secret: [u8; SECRET_LENGTH],
    threshold: u8,
    share_count: u8,
) -> Result<Vec<Share>, ShamirError> {
    // The error message promises "more than 1"; enforce exactly that.
    if threshold < 2 {
        return Err(ShamirError::InvalidThreshold(u32::from(threshold)));
    }
    if share_count == 0 {
        return Err(ShamirError::InvalidShareCount(u32::from(share_count)));
    }
    if share_count < threshold {
        return Err(ShamirError::LessSharesThanThreshhold(
            u32::from(share_count),
            u32::from(threshold),
        ));
    }

    // coefficients[k][b] is the coefficient of x^k for secret byte b.
    let mut coefficients = vec![[0u8; SECRET_LENGTH]; usize::from(threshold)];
    coefficients[0] = secret;
    for coeff in coefficients.iter_mut().skip(1) {
        rng.fill_bytes(coeff);
    }

    let shares = (1..=share_count)
        .map(|x| {
            let mut value = [0u8; SECRET_LENGTH];
            for (byte_idx, y) in value.iter_mut().enumerate() {
                // Horner evaluation from the highest-degree coefficient down.
                *y = coefficients.iter().rev().fold(0u8, |acc, coeff| {
                    gf256::add(gf256::mul(acc, x), coeff[byte_idx])
                });
            }

            Share {
                index: x,
                threshold,
                share_count,
                value,
            }
        })
        .collect();

    Ok(shares)
}

/// Reconstructs the secret from `shares` by Lagrange interpolation at `x = 0`
/// over GF(2^8).
///
/// Every supplied share is validated, but only the first `threshold` shares
/// (the threshold recorded in the shares themselves) take part in the
/// interpolation; any further shares are otherwise ignored.
///
/// # Errors
///
/// * [`ShamirError::NotEnoughShares`] if `shares` is empty (reported as
///   `(0, 0)`) or holds fewer shares than the recorded threshold.
/// * [`ShamirError::InconsistentShares`] if the shares disagree on threshold
///   or share count, or any share has index 0.
/// * [`ShamirError::InvalidThreshold`] if the recorded threshold is 0; without
///   this check an all-zero "secret" would be returned as a success.
/// * [`ShamirError::DuplicateShareIndex`] if two shares carry the same index.
pub fn combine_shares(shares: &[Share]) -> Result<[u8; SECRET_LENGTH], ShamirError> {
    let first = match shares.first() {
        Some(first) => first,
        None => return Err(ShamirError::NotEnoughShares(0, 0)),
    };

    let inconsistent = shares.iter().any(|share| {
        share.threshold != first.threshold
            || share.share_count != first.share_count
            || share.index == 0
    });
    if inconsistent {
        return Err(ShamirError::InconsistentShares);
    }

    // A zero threshold (e.g. from a corrupted or crafted encoded share) would
    // select no shares at all and silently produce an all-zero result.
    if first.threshold == 0 {
        return Err(ShamirError::InvalidThreshold(0));
    }

    let threshold = usize::from(first.threshold);
    if shares.len() < threshold {
        return Err(ShamirError::NotEnoughShares(
            shares.len() as u32,
            threshold as u32,
        ));
    }

    // Distinct x-coordinates are required; they also guarantee that every
    // denominator below is non-zero, so `gf256::inv` cannot panic.
    for (i, a) in shares.iter().enumerate() {
        if let Some(b) = shares[i + 1..].iter().find(|b| b.index == a.index) {
            return Err(ShamirError::DuplicateShareIndex(a.index, b.index));
        }
    }

    let used = &shares[..threshold];

    // The Lagrange basis value at x = 0 for each share depends only on the
    // share indices, so compute it once rather than once per secret byte.
    let weights: Vec<u8> = used
        .iter()
        .enumerate()
        .map(|(i, share)| {
            let xi = share.index;
            let (num, den) = used
                .iter()
                .enumerate()
                .filter(|&(m, _)| m != i)
                .fold((1u8, 1u8), |(num, den), (_, other)| {
                    let xm = other.index;
                    (
                        gf256::mul(num, xm),
                        // Subtraction in GF(2^8) is the same XOR as addition.
                        gf256::mul(den, gf256::add(xm, xi)),
                    )
                });
            gf256::mul(num, gf256::inv(den))
        })
        .collect();

    let mut secret = [0u8; SECRET_LENGTH];
    for (byte_idx, out_byte) in secret.iter_mut().enumerate() {
        *out_byte = used
            .iter()
            .zip(weights.iter())
            .fold(0u8, |acc, (share, &weight)| {
                gf256::add(acc, gf256::mul(share.value[byte_idx], weight))
            });
    }

    Ok(secret)
}

/// Arithmetic in the finite field GF(2^8), reduced by the polynomial
/// x^8 + x^4 + x^3 + x + 1 (the `0x1B` reduction constant used below).
pub mod gf256 {
    /// Adds two field elements. Addition (and subtraction) in GF(2^8) is XOR.
    pub fn add(a: u8, b: u8) -> u8 {
        a ^ b
    }

    /// Multiplies two field elements using shift-and-add with reduction.
    ///
    /// The operands may be derived from secret material, so the selection
    /// steps use bit masks instead of branches: the sequence of operations is
    /// the same for every input. (Best effort: the compiler is not obliged to
    /// preserve this.)
    pub fn mul(mut a: u8, mut b: u8) -> u8 {
        let mut product = 0u8;

        for _ in 0..8 {
            // 0xFF when the low bit of `b` is set, 0x00 otherwise.
            let take = 0u8.wrapping_sub(b & 1);
            product ^= a & take;

            // 0xFF when doubling `a` overflows past x^7 and must be reduced.
            let reduce = 0u8.wrapping_sub(a >> 7);
            a = (a << 1) ^ (0x1B & reduce);

            b >>= 1;
        }

        product
    }

    /// Raises `a` to the power `e` by square-and-multiply.
    ///
    /// The number of iterations depends on `e`, so `e` should not be secret.
    pub fn pow(mut a: u8, mut e: u8) -> u8 {
        let mut result = 1u8;

        while e != 0 {
            if (e & 1) != 0 {
                result = mul(result, a);
            }

            a = mul(a, a);
            e >>= 1;
        }

        result
    }

    /// Returns the multiplicative inverse of `a`.
    ///
    /// The non-zero elements form a group of order 255, so `a^254 == a^-1`.
    ///
    /// # Panics
    ///
    /// Panics if `a` is zero, which has no inverse.
    pub fn inv(a: u8) -> u8 {
        assert!(a != 0, "zero has no multiplicative inverse in GF(2^8)");
        pow(a, 254)
    }
}

#[cfg(test)]
mod tests {
    use crate::rng::os_rng_hkdf;

    use super::*;

    /// Produces a fresh secret filled via `getrandom::fill`.
    fn random_secret() -> [u8; SECRET_LENGTH] {
        let mut bytes = [0u8; SECRET_LENGTH];
        getrandom::fill(&mut bytes).unwrap();
        bytes
    }

    /// Splits a fresh random secret into five shares with a threshold of
    /// three and returns the secret alongside its shares.
    fn split_3_of_5() -> ([u8; SECRET_LENGTH], Vec<Share>) {
        let mut rng = os_rng_hkdf(None, b"").unwrap();
        let secret = random_secret();
        let shares = split_secret(&mut rng, secret, 3, 5).unwrap();
        (secret, shares)
    }

    #[test]
    fn split_and_recombine_3_of_5() {
        let (secret, shares) = split_3_of_5();

        assert_eq!(secret, combine_shares(&shares[..3]).unwrap());
    }

    #[test]
    fn recombinations_3_of_5() {
        let (secret, shares) = split_3_of_5();

        // Every 3-element subset of the five shares must recover the secret.
        for a in 0..5 {
            for b in (a + 1)..5 {
                for c in (b + 1)..5 {
                    let subset = [shares[a].clone(), shares[b].clone(), shares[c].clone()];
                    assert_eq!(secret, combine_shares(&subset).unwrap());
                }
            }
        }
    }

    #[test]
    fn rejects_not_enough_shares() {
        let (_, shares) = split_3_of_5();

        assert!(combine_shares(&shares[..2]).is_err());
    }

    #[test]
    fn fails_with_duplicate_indices() {
        let (_, mut shares) = split_3_of_5();

        let first_index = shares[0].index;
        shares[1].index = first_index;

        assert!(combine_shares(&shares[..3]).is_err());
    }

    #[test]
    fn rejects_inconsistent_threshold() {
        let (_, mut shares) = split_3_of_5();

        shares[0].threshold = 2;

        assert!(combine_shares(&shares[..3]).is_err());
    }

    #[test]
    fn rejects_inconsistent_share_count() {
        let (_, mut shares) = split_3_of_5();

        shares[0].share_count = 9;

        assert!(combine_shares(&shares[..3]).is_err());
    }

    #[test]
    fn multiple_random_runs() {
        // One generator is shared across all fifty 4-of-7 splits.
        let mut rng = os_rng_hkdf(None, b"").unwrap();

        (0..50).for_each(|_| {
            let secret = random_secret();
            let shares = split_secret(&mut rng, secret, 4, 7).unwrap();

            // Use a window of four shares that does not start at the first one.
            assert_eq!(secret, combine_shares(&shares[1..5]).unwrap());
        });
    }

    #[test]
    fn encode_decode() {
        let mut rng = os_rng_hkdf(None, b"").unwrap();

        let mut value = [0u8; SECRET_LENGTH];
        rng.fill_bytes(&mut value);
        let mut id = [0u8; ID_LENGTH];
        rng.fill_bytes(&mut id);

        let original = Share {
            index: 3,
            threshold: 4,
            share_count: 7,
            value,
        };

        let (round_tripped, round_tripped_id) = Share::decode(&original.encode(id));
        assert_eq!(original, round_tripped);
        assert_eq!(id, round_tripped_id);
    }
}