1use num_bigint_dig::algorithms::mod_inverse;
4use num_bigint_dig::prime::probably_prime;
5use num_bigint_dig::traits::ModInverse;
6use num_bigint_dig::{BigUint, IntoBigInt, IntoBigUint, RandPrime};
7use num_integer::Integer;
8use rand::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::borrow::Cow;
11use std::convert::{From, TryFrom, TryInto};
12
13#[derive(Clone, Debug, PartialEq)]
14pub struct Prime {
15 prime: BigUint,
16}
17
/// Serializable carrier for a prime candidate whose primality has *not* been verified.
///
/// Obtained from a [`Prime`] via `From`; converting back with `TryFrom` re-runs the
/// primality test, so deserialized values cannot bypass the check.
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct UncheckedPrime {
    // The raw, unverified value.
    p: BigUint,
}
24
25impl TryFrom<BigUint> for Prime {
26 type Error = String;
27
28 fn try_from(value: BigUint) -> Result<Self, Self::Error> {
29 if Prime::is_prime(&value) {
30 Ok(Prime { prime: value })
31 } else {
32 Err("Given number is not a prime".to_owned())
33 }
34 }
35}
36
37impl TryFrom<UncheckedPrime> for Prime {
38 type Error = String;
39
40 fn try_from(value: UncheckedPrime) -> Result<Self, Self::Error> {
41 value.p.try_into()
42 }
43}
44
45impl From<Prime> for UncheckedPrime {
46 fn from(value: Prime) -> Self {
47 UncheckedPrime { p: value.prime }
48 }
49}
50
51impl Prime {
52 fn is_prime(i: &BigUint) -> bool {
53 probably_prime(i, 256)
56 }
57
58 pub fn random<Rng: CryptoRng + RngCore>(num_bits: usize, rng: &mut Rng) -> Self {
59 Prime {
60 prime: rng.gen_prime(num_bits),
61 }
62 }
63
64 pub fn num_bits(&self) -> usize {
65 self.prime.bits()
66 }
67}
68
69#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
71pub struct UncheckedRsa {
72 e: BigUint,
73 d: BigUint,
74}
75
76impl From<Rsa> for UncheckedRsa {
77 fn from(value: Rsa) -> Self {
78 UncheckedRsa {
79 e: value.e,
80 d: value.d,
81 }
82 }
83}
84
85impl TryFrom<(UncheckedRsa, &RsaParameter)> for Rsa {
86 type Error = String;
87
88 fn try_from(value: (UncheckedRsa, &RsaParameter)) -> Result<Self, Self::Error> {
89 let (UncheckedRsa { e, d }, parameter) = value;
90 if e == 1u32.into() || e.gcd(¶meter.lambda_n) != 1u32.into() {
91 return Err("RSA encryption key not incorrect".to_owned());
92 }
93 if e.clone().mod_inverse(¶meter.lambda_n) != Some(d.clone().into_bigint().unwrap()) {
94 return Err("RSA decryption key incorrect".to_owned());
95 }
96 Ok(Rsa {
97 parameter: parameter.clone(),
98 e,
99 d,
100 })
101 }
102}
103
/// An RSA key: encryption exponent `e`, decryption exponent `d`, and the modulus data
/// they belong to.
#[derive(Clone, Debug, PartialEq)]
pub struct Rsa {
    // Modulus `n` and exponent modulus `lambda_n` this key was built for.
    parameter: RsaParameter,
    // Encryption exponent.
    e: BigUint,
    // Decryption exponent; the validating constructors (`from_e_d`, `TryFrom`) require
    // it to equal the inverse of `e` modulo `lambda_n`.
    d: BigUint,
}
110
111#[derive(Clone, Debug, PartialEq)]
112pub struct RsaParameter {
113 n: BigUint,
114 lambda_n: BigUint,
115}
116
117impl RsaParameter {
118 pub fn from_primes(primes: &[Prime]) -> RsaParameter {
119 let lambda_n = primes
120 .iter()
121 .map(|p| &p.prime - &BigUint::from(1u32))
122 .fold(1u32.into(), |acc: BigUint, n: BigUint| acc.lcm(&n));
123 let n = primes
124 .iter()
125 .fold(1u32.into(), |acc: BigUint, p| acc * &p.prime);
126 RsaParameter { n, lambda_n }
127 }
128
129 pub fn n(&self) -> BigUint {
130 self.n.clone()
131 }
132
133 pub fn lambda_n(&self) -> BigUint {
134 self.lambda_n.clone()
135 }
136}
137
138impl Rsa {
139 pub fn get_e(&self) -> BigUint {
141 self.e.clone()
142 }
143
144 pub fn get_d(&self) -> BigUint {
146 self.d.clone()
147 }
148
149 pub fn encrypt(&self, message: BigUint) -> BigUint {
151 message.modpow(&self.e, &self.parameter.n)
152 }
153
154 pub fn decrypt(&self, message: BigUint) -> BigUint {
156 message.modpow(&self.d, &self.parameter.n)
157 }
158
159 pub fn gen_with_parameter<Rng: CryptoRng + RngCore>(
160 parameter: RsaParameter,
161 rng: &mut Rng,
162 ) -> Rsa {
163 let e = loop {
164 let num_bytes = (parameter.lambda_n.bits() + 7) / 8;
165 let mut number = vec![0u8; num_bytes as usize];
166 rng.fill_bytes(&mut number);
167 let number = BigUint::from_bytes_le(&number) % ¶meter.lambda_n;
169 if number.gcd(¶meter.lambda_n) == 1u32.into() {
170 break number;
171 }
172 };
173 let d = mod_inverse(Cow::Borrowed(&e), Cow::Borrowed(¶meter.lambda_n)).unwrap();
175 Rsa {
176 parameter,
177 e,
178 d: d.into_biguint().unwrap(),
179 }
180 }
181
182 pub fn from_e_d(e: BigUint, d: BigUint, parameter: RsaParameter) -> Result<Rsa, &'static str> {
183 if e >= parameter.lambda_n {
184 return Err("e has to be smaller than lambda_n");
185 }
186 if d >= parameter.lambda_n {
187 return Err("d has to be smaller than lambda_n");
188 }
189 if e.gcd(¶meter.lambda_n) != 1u32.into() {
190 return Err("invalid parameter e");
191 }
192 if d.clone().into_bigint()
193 != mod_inverse(Cow::Borrowed(&e), Cow::Borrowed(¶meter.lambda_n))
194 {
195 return Err("invalid parameter d");
196 }
197 Ok(Self { parameter, e, d })
198 }
199}
200
#[cfg(test)]
mod test {
    use super::*;
    use std::convert::TryInto;

    /// Textbook parameters for p = 61, q = 53: n = 3233, lambda_n = lcm(60, 52) = 780.
    fn textbook_parameter() -> RsaParameter {
        RsaParameter {
            n: BigUint::from(3233u32),
            lambda_n: BigUint::from(780u32),
        }
    }

    #[test]
    fn encrypt_decrypt() {
        let key = Rsa {
            parameter: textbook_parameter(),
            e: BigUint::from(17u32),
            d: BigUint::from(413u32),
        };
        let m = BigUint::from(65u8);
        let c = key.encrypt(m.clone());
        assert_eq!(c, BigUint::from(2790u32));
        let d = key.decrypt(c);
        assert_eq!(d, m);
    }

    #[test]
    fn from_e_d() {
        let key = Rsa::from_e_d(
            BigUint::from(17u32),
            BigUint::from(413u32),
            textbook_parameter(),
        )
        .unwrap();
        let m = BigUint::from(65u8);
        let c = key.encrypt(m.clone());
        assert_eq!(c, BigUint::from(2790u32));
        let d = key.decrypt(c);
        assert_eq!(d, m);
    }

    #[test]
    fn from_e_d_rejects_invalid() {
        // d is not the inverse of e modulo 780.
        assert!(
            Rsa::from_e_d(BigUint::from(17u32), BigUint::from(414u32), textbook_parameter())
                .is_err()
        );
        // e is not smaller than lambda_n.
        assert!(
            Rsa::from_e_d(BigUint::from(797u32), BigUint::from(413u32), textbook_parameter())
                .is_err()
        );
        // e shares the factor 15 with 780.
        assert!(
            Rsa::from_e_d(BigUint::from(15u32), BigUint::from(1u32), textbook_parameter())
                .is_err()
        );
    }

    #[test]
    fn prime_try_from_rejects_composite() {
        assert!(Prime::try_from(BigUint::from(13u32)).is_ok());
        assert!(Prime::try_from(BigUint::from(15u32)).is_err());
        assert!(Prime::try_from(BigUint::from(3233u32)).is_err());
    }

    #[test]
    fn from_primes_textbook() {
        let primes = [
            Prime::try_from(BigUint::from(61u32)).unwrap(),
            Prime::try_from(BigUint::from(53u32)).unwrap(),
        ];
        assert_eq!(RsaParameter::from_primes(&primes), textbook_parameter());
    }

    #[test]
    fn generate_keys_1() {
        let mut rng = rand::thread_rng();
        let p = Prime::random(128, &mut rng);
        let rsa_parameter = RsaParameter::from_primes(&[p]);
        let key = Rsa::gen_with_parameter(rsa_parameter, &mut rng);
        let m = BigUint::from_bytes_be(&[65u8, 66, 67, 68]);
        let c = key.encrypt(m.clone());
        let d = key.decrypt(c);
        assert_eq!(d, m);
    }

    #[test]
    fn generate_keys_2() {
        let mut rng = rand::thread_rng();
        let p = Prime::random(128, &mut rng);
        let q = Prime::random(128, &mut rng);
        let rsa_parameter = RsaParameter::from_primes(&[p, q]);
        let key = Rsa::gen_with_parameter(rsa_parameter, &mut rng);
        let m = BigUint::from_bytes_be(&[65u8, 66, 67, 68]);
        let c = key.encrypt(m.clone());
        let d = key.decrypt(c);
        assert_eq!(d, m);
    }

    #[test]
    fn serde() {
        let mut rng = rand::thread_rng();
        let p: UncheckedPrime = Prime::random(128, &mut rng).into();
        let p_str = serde_json::to_string(&p).unwrap();
        assert_eq!(p, serde_json::from_str(&p_str).unwrap())
    }

    #[test]
    fn import() {
        let mut rng = rand::thread_rng();
        let ps = [Prime::random(128, &mut rng), Prime::random(128, &mut rng)];
        let rsa_parameter = RsaParameter::from_primes(&ps);
        let k = Rsa::gen_with_parameter(rsa_parameter.clone(), &mut rng);
        let send_rsa: UncheckedRsa = k.clone().into();
        assert_eq!(Ok(k), (send_rsa, &rsa_parameter).try_into())
    }

    #[test]
    fn import_rejects_invalid() {
        let parameter = textbook_parameter();
        let wrong_d: Result<Rsa, String> = (
            UncheckedRsa {
                e: BigUint::from(17u32),
                d: BigUint::from(414u32),
            },
            &parameter,
        )
            .try_into();
        assert!(wrong_d.is_err());
        let identity_e: Result<Rsa, String> = (
            UncheckedRsa {
                e: BigUint::from(1u32),
                d: BigUint::from(1u32),
            },
            &parameter,
        )
            .try_into();
        assert!(identity_e.is_err());
    }
}
281}