mysql_connector/utils/
crypt.rs1use {
2 num::BigUint,
3 rand::{CryptoRng, Rng},
4 sha1::{Digest, Sha1},
5};
6
/// Errors produced while parsing RSA public keys or padding/encrypting with them.
#[derive(Debug)]
pub enum Error {
    /// The PEM armor, base64 body, or DER structure of a public key is malformed.
    InvalidPem,
    /// The message (or an MGF1 mask request) is too long for the key / OAEP parameters.
    MessageTooLong,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::InvalidPem => "invalid PEM-encoded RSA public key",
            Self::MessageTooLong => "message too long for RSA-OAEP with this key",
        })
    }
}

// Lets callers propagate this error with `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for Error {}
12
13mod der {
14 use {
15 super::{Error, PublicKey},
16 base64::{engine::general_purpose::STANDARD, Engine as _},
17 num::BigUint,
18 };
19
20 fn eat_len(der: &mut &[u8]) -> Result<usize, Error> {
21 if der[0] & 0x80 == 0x80 {
22 const BITS: usize = (usize::BITS / 8) as usize;
23 let len = (der[0] & (!0x80)) as usize;
24 if len > BITS {
25 return Err(Error::InvalidPem);
26 }
27 let mut bytes = [0u8; BITS];
28 bytes[BITS - len..].copy_from_slice(&der[1..=len]);
29 *der = &der[len + 1..];
30 Ok(usize::from_be_bytes(bytes))
31 } else {
32 let len = der[0] as usize;
33 *der = &der[1..];
34 Ok(len)
35 }
36 }
37
38 fn eat_uint(der: &mut &[u8]) -> Result<BigUint, Error> {
39 if der[0] != 0x02 {
40 return Err(Error::InvalidPem);
41 }
42 *der = &der[1..];
43 let len = eat_len(der)?;
44 let uint = BigUint::from_bytes_be(&der[..len]);
45 *der = &der[len..];
46 Ok(uint)
47 }
48
49 fn eat_sequence<'a>(der: &mut &'a [u8]) -> Result<&'a [u8], Error> {
50 if der[0] != 0x30 {
51 return Err(Error::InvalidPem);
52 }
53 *der = &der[1..];
54 let len = eat_len(der)?;
55 let sequence = &der[..len];
56 *der = &der[len..];
57 Ok(sequence)
58 }
59
60 fn eat_bit_string<'a>(der: &mut &'a [u8]) -> Result<(u8, &'a [u8]), Error> {
61 if der[0] != 0x03 {
62 return Err(Error::InvalidPem);
63 }
64 *der = &der[1..];
65 let len = eat_len(der)?;
66 let unused_bits = der[0];
67 let bit_string = &der[1..len];
68 *der = &der[len..];
69 Ok((unused_bits, bit_string))
70 }
71
72 impl PublicKey {
73 pub fn try_from_pkcs1(mut der: &[u8]) -> Result<Self, Error> {
74 let mut pub_key = eat_sequence(&mut der)?;
75 let modulus = eat_uint(&mut pub_key)?;
76 let exponent = eat_uint(&mut pub_key)?;
77 Ok(Self { modulus, exponent })
78 }
79
80 pub fn try_from_pkcs8(mut der: &[u8]) -> Result<Self, Error> {
81 let mut seq_data = eat_sequence(&mut der)?;
82 eat_sequence(&mut seq_data)?;
83 let (unused_bits, pub_key) = eat_bit_string(&mut seq_data)?;
84 if unused_bits != 0 {
85 return Err(Error::InvalidPem);
86 }
87 Self::try_from_pkcs1(pub_key)
88 }
89
90 pub fn try_from_pem(pem: &[u8]) -> Result<Self, Error> {
91 const PKCS1: (&[u8], &[u8]) =
92 (b"-----BEGINRSAPUBLICKEY-----", b"-----ENDRSAPUBLICKEY-----");
93 const PKCS8: (&[u8], &[u8]) = (b"-----BEGINPUBLICKEY-----", b"-----ENDPUBLICKEY-----");
94
95 let pem: Vec<u8> = pem
96 .iter()
97 .filter(|x| !b" \n\t\r\x0b\x0c".contains(x))
98 .cloned()
99 .collect();
100
101 let (body, is_pkcs_1) = if pem.starts_with(PKCS1.0) && pem.ends_with(PKCS1.1) {
102 (&pem[PKCS1.0.len()..pem.len() - PKCS1.1.len()], true)
103 } else if pem.starts_with(PKCS8.0) && pem.ends_with(PKCS8.1) {
104 (&pem[PKCS8.0.len()..pem.len() - PKCS8.1.len()], false)
105 } else {
106 return Err(Error::InvalidPem);
107 };
108
109 let body = STANDARD.decode(body).map_err(|_| Error::InvalidPem)?;
110 match is_pkcs_1 {
111 true => Self::try_from_pkcs1(&body),
112 false => Self::try_from_pkcs8(&body),
113 }
114 }
115 }
116}
117
118#[derive(Debug)]
119pub struct PublicKey {
120 modulus: BigUint,
121 exponent: BigUint,
122}
123
124impl PublicKey {
125 pub const fn new(modulus: BigUint, exponent: BigUint) -> Self {
126 Self { modulus, exponent }
127 }
128
129 pub fn num_octets(&self) -> usize {
130 (self.modulus().bits() as usize + 6) >> 3
131 }
132
133 pub fn modulus(&self) -> &BigUint {
134 &self.modulus
135 }
136
137 pub fn exponent(&self) -> &BigUint {
138 &self.exponent
139 }
140
141 pub fn encrypt_padded<R: Rng + CryptoRng>(
142 &self,
143 data: &[u8],
144 mut padding: OaepPadding<R>,
145 ) -> Result<Vec<u8>, Error> {
146 let octets = self.num_octets();
147 let padded = BigUint::from_bytes_be(&padding.pad(data, octets)?);
148 let mut encrypted = padded.modpow(self.exponent(), self.modulus()).to_bytes_be();
149
150 let fill = octets - encrypted.len();
151 if fill > 0 {
152 let mut encrypted_new = vec![0u8; octets];
153 encrypted_new[fill..].copy_from_slice(&encrypted);
154 encrypted = encrypted_new;
155 }
156 Ok(encrypted)
157 }
158}
159
160pub struct OaepPadding<R: Rng + CryptoRng> {
161 rng: R,
162}
163
164impl<R: Rng + CryptoRng> OaepPadding<R> {
165 const HASH_LEN: usize = 20;
166
167 pub fn new(rng: R) -> Self {
168 Self { rng }
169 }
170
171 fn mgf1(seed: &[u8], len: usize) -> Result<Vec<u8>, Error> {
172 #[cfg(target_pointer_width = "64")]
173 if len > Self::HASH_LEN << 32 {
174 return Err(Error::MessageTooLong);
175 }
176
177 let mut output = vec![0u8; len];
178 let mut hash_source = vec![0u8; seed.len() + 4];
179 hash_source[0..seed.len()].copy_from_slice(seed);
180
181 for i in 0..(len / Self::HASH_LEN) {
182 hash_source[seed.len()..].copy_from_slice(&(i as u32).to_be_bytes());
183 let pos = i * Self::HASH_LEN;
184 output[pos..pos + Self::HASH_LEN].copy_from_slice(&Sha1::digest(&hash_source));
185 }
186
187 let remaining = len % Self::HASH_LEN;
188 if remaining > 0 {
189 hash_source[seed.len()..]
190 .copy_from_slice(&((len / Self::HASH_LEN) as u32).to_be_bytes());
191 output[len - remaining..].copy_from_slice(&Sha1::digest(&hash_source)[..remaining]);
192 }
193 Ok(output)
194 }
195
196 pub fn pad(&mut self, data: &[u8], n: usize) -> Result<Vec<u8>, Error> {
215 let seed_len = Self::HASH_LEN;
216 if n < 1 + seed_len + Self::HASH_LEN + 2 + 1 + data.len() {
217 return Err(Error::MessageTooLong);
218 }
219
220 let msg_len = n - seed_len - 1;
221 let filling_len = msg_len - data.len();
222
223 let mut padded = vec![0u8; n];
224 let (seed, msg) = padded[1..].split_at_mut(seed_len);
225 {
226 for byte in seed.iter_mut() {
227 *byte = self.rng.gen();
228 }
229 let (filling, msg_data) = msg.split_at_mut(filling_len);
230 filling[0..Self::HASH_LEN].copy_from_slice(&Sha1::digest([]));
231 filling[filling_len - 1] = 0x01;
232 msg_data.copy_from_slice(data);
233 }
234
235 let msg_mask = Self::mgf1(seed, msg.len())?;
236 for i in 0..msg.len() {
237 msg[i] ^= msg_mask[i];
238 }
239
240 let seed_mask = Self::mgf1(msg, seed_len)?;
241 for i in 0..seed_len {
242 seed[i] ^= seed_mask[i];
243 }
244
245 Ok(padded)
246 }
247}