distributed_cards/
crypto.rs

1//! Crypto primitives necessary for distributed shuffling
2
3use num_bigint_dig::algorithms::mod_inverse;
4use num_bigint_dig::prime::probably_prime;
5use num_bigint_dig::traits::ModInverse;
6use num_bigint_dig::{BigUint, IntoBigInt, IntoBigUint, RandPrime};
7use num_integer::Integer;
8use rand::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::borrow::Cow;
11use std::convert::{From, TryFrom, TryInto};
12
13#[derive(Clone, Debug, PartialEq)]
14pub struct Prime {
15    prime: BigUint,
16}
17
18/// Type for serialization and sending over the network
19#[repr(transparent)]
20#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
21pub struct UncheckedPrime {
22    p: BigUint,
23}
24
25impl TryFrom<BigUint> for Prime {
26    type Error = String;
27
28    fn try_from(value: BigUint) -> Result<Self, Self::Error> {
29        if Prime::is_prime(&value) {
30            Ok(Prime { prime: value })
31        } else {
32            Err("Given number is not a prime".to_owned())
33        }
34    }
35}
36
37impl TryFrom<UncheckedPrime> for Prime {
38    type Error = String;
39
40    fn try_from(value: UncheckedPrime) -> Result<Self, Self::Error> {
41        value.p.try_into()
42    }
43}
44
45impl From<Prime> for UncheckedPrime {
46    fn from(value: Prime) -> Self {
47        UncheckedPrime { p: value.prime }
48    }
49}
50
51impl Prime {
52    fn is_prime(i: &BigUint) -> bool {
53        // TODO: better/safer prime checking function? check if it divides first n (~= 40 primes)
54        // or https://crates.io/crates/glass_pumpkin
55        probably_prime(i, 256)
56    }
57
58    pub fn random<Rng: CryptoRng + RngCore>(num_bits: usize, rng: &mut Rng) -> Self {
59        Prime {
60            prime: rng.gen_prime(num_bits),
61        }
62    }
63
64    pub fn num_bits(&self) -> usize {
65        self.prime.bits()
66    }
67}
68
69/// Type for serialization and sending over the network
70#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
71pub struct UncheckedRsa {
72    e: BigUint,
73    d: BigUint,
74}
75
76impl From<Rsa> for UncheckedRsa {
77    fn from(value: Rsa) -> Self {
78        UncheckedRsa {
79            e: value.e,
80            d: value.d,
81        }
82    }
83}
84
85impl TryFrom<(UncheckedRsa, &RsaParameter)> for Rsa {
86    type Error = String;
87
88    fn try_from(value: (UncheckedRsa, &RsaParameter)) -> Result<Self, Self::Error> {
89        let (UncheckedRsa { e, d }, parameter) = value;
90        if e == 1u32.into() || e.gcd(&parameter.lambda_n) != 1u32.into() {
91            return Err("RSA encryption key not incorrect".to_owned());
92        }
93        if e.clone().mod_inverse(&parameter.lambda_n) != Some(d.clone().into_bigint().unwrap()) {
94            return Err("RSA decryption key incorrect".to_owned());
95        }
96        Ok(Rsa {
97            parameter: parameter.clone(),
98            e,
99            d,
100        })
101    }
102}
103
104#[derive(Clone, Debug, PartialEq)]
105pub struct Rsa {
106    parameter: RsaParameter,
107    e: BigUint,
108    d: BigUint,
109}
110
111#[derive(Clone, Debug, PartialEq)]
112pub struct RsaParameter {
113    n: BigUint,
114    lambda_n: BigUint,
115}
116
117impl RsaParameter {
118    pub fn from_primes(primes: &[Prime]) -> RsaParameter {
119        let lambda_n = primes
120            .iter()
121            .map(|p| &p.prime - &BigUint::from(1u32))
122            .fold(1u32.into(), |acc: BigUint, n: BigUint| acc.lcm(&n));
123        let n = primes
124            .iter()
125            .fold(1u32.into(), |acc: BigUint, p| acc * &p.prime);
126        RsaParameter { n, lambda_n }
127    }
128
129    pub fn n(&self) -> BigUint {
130        self.n.clone()
131    }
132
133    pub fn lambda_n(&self) -> BigUint {
134        self.lambda_n.clone()
135    }
136}
137
138impl Rsa {
139    /// get the encrypt key
140    pub fn get_e(&self) -> BigUint {
141        self.e.clone()
142    }
143
144    /// get the decrypt key
145    pub fn get_d(&self) -> BigUint {
146        self.d.clone()
147    }
148
149    /// encrypt given integer
150    pub fn encrypt(&self, message: BigUint) -> BigUint {
151        message.modpow(&self.e, &self.parameter.n)
152    }
153
154    /// decrypt given integer
155    pub fn decrypt(&self, message: BigUint) -> BigUint {
156        message.modpow(&self.d, &self.parameter.n)
157    }
158
159    pub fn gen_with_parameter<Rng: CryptoRng + RngCore>(
160        parameter: RsaParameter,
161        rng: &mut Rng,
162    ) -> Rsa {
163        let e = loop {
164            let num_bytes = (parameter.lambda_n.bits() + 7) / 8;
165            let mut number = vec![0u8; num_bytes as usize];
166            rng.fill_bytes(&mut number);
167            // ensure that e < lambda_n
168            let number = BigUint::from_bytes_le(&number) % &parameter.lambda_n;
169            if number.gcd(&parameter.lambda_n) == 1u32.into() {
170                break number;
171            }
172        };
173        // inverse exists, since gcd(e, lambda_n) == 1 (therefore e and lambda_n are coprime)
174        let d = mod_inverse(Cow::Borrowed(&e), Cow::Borrowed(&parameter.lambda_n)).unwrap();
175        Rsa {
176            parameter,
177            e,
178            d: d.into_biguint().unwrap(),
179        }
180    }
181
182    pub fn from_e_d(e: BigUint, d: BigUint, parameter: RsaParameter) -> Result<Rsa, &'static str> {
183        if e >= parameter.lambda_n {
184            return Err("e has to be smaller than lambda_n");
185        }
186        if d >= parameter.lambda_n {
187            return Err("d has to be smaller than lambda_n");
188        }
189        if e.gcd(&parameter.lambda_n) != 1u32.into() {
190            return Err("invalid parameter e");
191        }
192        if d.clone().into_bigint()
193            != mod_inverse(Cow::Borrowed(&e), Cow::Borrowed(&parameter.lambda_n))
194        {
195            return Err("invalid parameter d");
196        }
197        Ok(Self { parameter, e, d })
198    }
199}
200
#[cfg(test)]
mod test {
    use super::*;
    use std::convert::TryInto;

    /// Textbook RSA parameters: n = 61 * 53 = 3233, lambda_n = lcm(60, 52) = 780.
    fn textbook_parameter() -> RsaParameter {
        RsaParameter {
            n: BigUint::from(3233u32),
            lambda_n: BigUint::from(780u32),
        }
    }

    /// Encrypts the textbook message 65 and checks the known ciphertext 2790
    /// as well as the decryption round trip.
    fn check_textbook_key(key: &Rsa) {
        let plain = BigUint::from(65u8);
        let cipher = key.encrypt(plain.clone());
        assert_eq!(cipher, BigUint::from(2790u32));
        assert_eq!(key.decrypt(cipher), plain);
    }

    /// Checks that decrypting the encryption of a fixed 4-byte message restores it.
    fn check_round_trip(key: &Rsa) {
        let plain = BigUint::from_bytes_be(b"ABCD");
        let recovered = key.decrypt(key.encrypt(plain.clone()));
        assert_eq!(recovered, plain);
    }

    #[test]
    fn encrypt_decrypt() {
        let key = Rsa {
            parameter: textbook_parameter(),
            e: BigUint::from(17u32),
            d: BigUint::from(413u32),
        };
        check_textbook_key(&key);
    }

    #[test]
    fn from_e_d() {
        let e = BigUint::from(17u32);
        let d = BigUint::from(413u32);
        let key = Rsa::from_e_d(e, d, textbook_parameter()).unwrap();
        check_textbook_key(&key);
    }

    #[test]
    fn generate_keys_1() {
        let mut rng = rand::thread_rng();
        let primes = [Prime::random(128, &mut rng)];
        let key = Rsa::gen_with_parameter(RsaParameter::from_primes(&primes), &mut rng);
        check_round_trip(&key);
    }

    #[test]
    fn generate_keys_2() {
        let mut rng = rand::thread_rng();
        let primes = [Prime::random(128, &mut rng), Prime::random(128, &mut rng)];
        let key = Rsa::gen_with_parameter(RsaParameter::from_primes(&primes), &mut rng);
        check_round_trip(&key);
    }

    #[test]
    fn serde() {
        let mut rng = rand::thread_rng();
        let sent: UncheckedPrime = Prime::random(128, &mut rng).into();
        let json = serde_json::to_string(&sent).unwrap();
        let received: UncheckedPrime = serde_json::from_str(&json).unwrap();
        assert_eq!(sent, received);
    }

    #[test]
    fn import() {
        let mut rng = rand::thread_rng();
        let primes = [Prime::random(128, &mut rng), Prime::random(128, &mut rng)];
        let parameter = RsaParameter::from_primes(&primes);
        let key = Rsa::gen_with_parameter(parameter.clone(), &mut rng);
        let sent: UncheckedRsa = key.clone().into();
        let imported: Result<Rsa, String> = (sent, &parameter).try_into();
        assert_eq!(imported, Ok(key));
    }
}