mysql_connector/utils/
crypt.rs

1use {
2    num::BigUint,
3    rand::{CryptoRng, Rng},
4    sha1::{Digest, Sha1},
5};
6
/// Errors produced while parsing RSA public keys or encrypting with RSA-OAEP.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The input is not a well-formed PEM/DER encoded RSA public key.
    InvalidPem,
    /// The data does not fit into the key once OAEP-padded (also reported when
    /// an MGF1 mask would exceed its maximum length).
    MessageTooLong,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::InvalidPem => "invalid PEM-encoded RSA public key",
            Self::MessageTooLong => "message too long for RSA-OAEP encryption",
        })
    }
}

impl std::error::Error for Error {}
12
13mod der {
14    use {
15        super::{Error, PublicKey},
16        base64::{engine::general_purpose::STANDARD, Engine as _},
17        num::BigUint,
18    };
19
20    fn eat_len(der: &mut &[u8]) -> Result<usize, Error> {
21        if der[0] & 0x80 == 0x80 {
22            const BITS: usize = (usize::BITS / 8) as usize;
23            let len = (der[0] & (!0x80)) as usize;
24            if len > BITS {
25                return Err(Error::InvalidPem);
26            }
27            let mut bytes = [0u8; BITS];
28            bytes[BITS - len..].copy_from_slice(&der[1..=len]);
29            *der = &der[len + 1..];
30            Ok(usize::from_be_bytes(bytes))
31        } else {
32            let len = der[0] as usize;
33            *der = &der[1..];
34            Ok(len)
35        }
36    }
37
38    fn eat_uint(der: &mut &[u8]) -> Result<BigUint, Error> {
39        if der[0] != 0x02 {
40            return Err(Error::InvalidPem);
41        }
42        *der = &der[1..];
43        let len = eat_len(der)?;
44        let uint = BigUint::from_bytes_be(&der[..len]);
45        *der = &der[len..];
46        Ok(uint)
47    }
48
49    fn eat_sequence<'a>(der: &mut &'a [u8]) -> Result<&'a [u8], Error> {
50        if der[0] != 0x30 {
51            return Err(Error::InvalidPem);
52        }
53        *der = &der[1..];
54        let len = eat_len(der)?;
55        let sequence = &der[..len];
56        *der = &der[len..];
57        Ok(sequence)
58    }
59
60    fn eat_bit_string<'a>(der: &mut &'a [u8]) -> Result<(u8, &'a [u8]), Error> {
61        if der[0] != 0x03 {
62            return Err(Error::InvalidPem);
63        }
64        *der = &der[1..];
65        let len = eat_len(der)?;
66        let unused_bits = der[0];
67        let bit_string = &der[1..len];
68        *der = &der[len..];
69        Ok((unused_bits, bit_string))
70    }
71
72    impl PublicKey {
73        pub fn try_from_pkcs1(mut der: &[u8]) -> Result<Self, Error> {
74            let mut pub_key = eat_sequence(&mut der)?;
75            let modulus = eat_uint(&mut pub_key)?;
76            let exponent = eat_uint(&mut pub_key)?;
77            Ok(Self { modulus, exponent })
78        }
79
80        pub fn try_from_pkcs8(mut der: &[u8]) -> Result<Self, Error> {
81            let mut seq_data = eat_sequence(&mut der)?;
82            eat_sequence(&mut seq_data)?;
83            let (unused_bits, pub_key) = eat_bit_string(&mut seq_data)?;
84            if unused_bits != 0 {
85                return Err(Error::InvalidPem);
86            }
87            Self::try_from_pkcs1(pub_key)
88        }
89
90        pub fn try_from_pem(pem: &[u8]) -> Result<Self, Error> {
91            const PKCS1: (&[u8], &[u8]) =
92                (b"-----BEGINRSAPUBLICKEY-----", b"-----ENDRSAPUBLICKEY-----");
93            const PKCS8: (&[u8], &[u8]) = (b"-----BEGINPUBLICKEY-----", b"-----ENDPUBLICKEY-----");
94
95            let pem: Vec<u8> = pem
96                .iter()
97                .filter(|x| !b" \n\t\r\x0b\x0c".contains(x))
98                .cloned()
99                .collect();
100
101            let (body, is_pkcs_1) = if pem.starts_with(PKCS1.0) && pem.ends_with(PKCS1.1) {
102                (&pem[PKCS1.0.len()..pem.len() - PKCS1.1.len()], true)
103            } else if pem.starts_with(PKCS8.0) && pem.ends_with(PKCS8.1) {
104                (&pem[PKCS8.0.len()..pem.len() - PKCS8.1.len()], false)
105            } else {
106                return Err(Error::InvalidPem);
107            };
108
109            let body = STANDARD.decode(body).map_err(|_| Error::InvalidPem)?;
110            match is_pkcs_1 {
111                true => Self::try_from_pkcs1(&body),
112                false => Self::try_from_pkcs8(&body),
113            }
114        }
115    }
116}
117
118#[derive(Debug)]
119pub struct PublicKey {
120    modulus: BigUint,
121    exponent: BigUint,
122}
123
124impl PublicKey {
125    pub const fn new(modulus: BigUint, exponent: BigUint) -> Self {
126        Self { modulus, exponent }
127    }
128
129    pub fn num_octets(&self) -> usize {
130        (self.modulus().bits() as usize + 6) >> 3
131    }
132
133    pub fn modulus(&self) -> &BigUint {
134        &self.modulus
135    }
136
137    pub fn exponent(&self) -> &BigUint {
138        &self.exponent
139    }
140
141    pub fn encrypt_padded<R: Rng + CryptoRng>(
142        &self,
143        data: &[u8],
144        mut padding: OaepPadding<R>,
145    ) -> Result<Vec<u8>, Error> {
146        let octets = self.num_octets();
147        let padded = BigUint::from_bytes_be(&padding.pad(data, octets)?);
148        let mut encrypted = padded.modpow(self.exponent(), self.modulus()).to_bytes_be();
149
150        let fill = octets - encrypted.len();
151        if fill > 0 {
152            let mut encrypted_new = vec![0u8; octets];
153            encrypted_new[fill..].copy_from_slice(&encrypted);
154            encrypted = encrypted_new;
155        }
156        Ok(encrypted)
157    }
158}
159
160pub struct OaepPadding<R: Rng + CryptoRng> {
161    rng: R,
162}
163
164impl<R: Rng + CryptoRng> OaepPadding<R> {
165    const HASH_LEN: usize = 20;
166
167    pub fn new(rng: R) -> Self {
168        Self { rng }
169    }
170
171    fn mgf1(seed: &[u8], len: usize) -> Result<Vec<u8>, Error> {
172        #[cfg(target_pointer_width = "64")]
173        if len > Self::HASH_LEN << 32 {
174            return Err(Error::MessageTooLong);
175        }
176
177        let mut output = vec![0u8; len];
178        let mut hash_source = vec![0u8; seed.len() + 4];
179        hash_source[0..seed.len()].copy_from_slice(seed);
180
181        for i in 0..(len / Self::HASH_LEN) {
182            hash_source[seed.len()..].copy_from_slice(&(i as u32).to_be_bytes());
183            let pos = i * Self::HASH_LEN;
184            output[pos..pos + Self::HASH_LEN].copy_from_slice(&Sha1::digest(&hash_source));
185        }
186
187        let remaining = len % Self::HASH_LEN;
188        if remaining > 0 {
189            hash_source[seed.len()..]
190                .copy_from_slice(&((len / Self::HASH_LEN) as u32).to_be_bytes());
191            output[len - remaining..].copy_from_slice(&Sha1::digest(&hash_source)[..remaining]);
192        }
193        Ok(output)
194    }
195
196    /// Pads data according to RFC 8017.
197    ///
198    /// Returns an error if data is too long.
199    ///
200    ///  ```text
201    ///                                                  msg_len
202    ///                           ┣━━━━━━━ filling_len ━━┻━━━━━━━━━━━━━┫
203    ///                           ┣━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┫
204    ///                     seed  hash([]) || 0x00..0x00 || 0x01 || data
205    ///                     ━┳━━  ━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━
206    ///                      ┃              ┃
207    ///                      ┣━━━> MGF ━━> xor
208    ///                      ┃             ━┳━
209    ///                     xor <━━ MGF <━━━┫
210    ///                     ━┳━             ┃
211    ///                      ┃              ┃
212    /// padded = 0x00 || masked seed || masked msg
213    /// ```
214    pub fn pad(&mut self, data: &[u8], n: usize) -> Result<Vec<u8>, Error> {
215        let seed_len = Self::HASH_LEN;
216        if n < 1 + seed_len + Self::HASH_LEN + 2 + 1 + data.len() {
217            return Err(Error::MessageTooLong);
218        }
219
220        let msg_len = n - seed_len - 1;
221        let filling_len = msg_len - data.len();
222
223        let mut padded = vec![0u8; n];
224        let (seed, msg) = padded[1..].split_at_mut(seed_len);
225        {
226            for byte in seed.iter_mut() {
227                *byte = self.rng.gen();
228            }
229            let (filling, msg_data) = msg.split_at_mut(filling_len);
230            filling[0..Self::HASH_LEN].copy_from_slice(&Sha1::digest([]));
231            filling[filling_len - 1] = 0x01;
232            msg_data.copy_from_slice(data);
233        }
234
235        let msg_mask = Self::mgf1(seed, msg.len())?;
236        for i in 0..msg.len() {
237            msg[i] ^= msg_mask[i];
238        }
239
240        let seed_mask = Self::mgf1(msg, seed_len)?;
241        for i in 0..seed_len {
242            seed[i] ^= seed_mask[i];
243        }
244
245        Ok(padded)
246    }
247}