Skip to main content

picky_krb/crypto/
diffie_hellman.rs

1use crypto_bigint::modular::{BoxedMontyForm, BoxedMontyParams};
2use crypto_bigint::{BoxedUint, Odd, RandomBits, Resize};
3use rand::TryRngCore;
4use sha1::{Digest, Sha1};
5use thiserror::Error;
6
7use crate::crypto::Cipher;
8
9#[derive(Error, Debug)]
10pub enum DiffieHellmanError {
11    #[error("Invalid bit len: {0}")]
12    BitLen(String),
13    #[error("Invalid data len: expected at least {0} but got {1}.")]
14    DataLen(usize, usize),
15    #[error("modulus is not odd")]
16    ModulusIsNotOdd,
17}
18
19pub type DiffieHellmanResult<T> = Result<T, DiffieHellmanError>;
20
21/// [Using Diffie-Hellman Key Exchange](https://www.rfc-editor.org/rfc/rfc4556.html#section-3.2.3.1)
22/// K-truncate truncates its input to the first K bits
23fn k_truncate(k: usize, mut data: Vec<u8>) -> DiffieHellmanResult<Vec<u8>> {
24    if k % 8 != 0 {
25        return Err(DiffieHellmanError::BitLen(format!(
26            "Seed bit len must be a multiple of 8. Got: {k}"
27        )));
28    }
29
30    let bytes_len = k / 8;
31
32    if bytes_len > data.len() {
33        return Err(DiffieHellmanError::DataLen(bytes_len, data.len()));
34    }
35
36    data.resize(bytes_len, 0);
37
38    Ok(data)
39}
40
41/// [Using Diffie-Hellman Key Exchange](https://www.rfc-editor.org/rfc/rfc4556.html#section-3.2.3.1)
42/// octetstring2key(x) == random-to-key(K-truncate(
43///                          SHA1(0x00 | x) |
44///                          SHA1(0x01 | x) |
45///                          SHA1(0x02 | x) |
46///                          ...
47///                          ))
48fn octet_string_to_key(x: &[u8], cipher: &dyn Cipher) -> DiffieHellmanResult<Vec<u8>> {
49    let seed_len = cipher.seed_bit_len() / 8;
50
51    let mut key = Vec::new();
52
53    let mut i = 0;
54    while key.len() < seed_len {
55        let mut data = vec![i];
56        data.extend_from_slice(x);
57
58        let mut sha1 = Sha1::new();
59        sha1.update(data);
60
61        key.extend_from_slice(sha1.finalize().as_slice());
62        i += 1;
63    }
64
65    Ok(cipher.random_to_key(k_truncate(seed_len * 8, key)?))
66}
67
/// Optional nonces used in PKINIT Diffie-Hellman key derivation.
///
/// When supplied, they are appended to the shared secret as
/// `DHSharedSecret | client_nonce | server_nonce` before the key is derived
/// ([RFC 4556 §3.2.3.1](https://www.rfc-editor.org/rfc/rfc4556.html#section-3.2.3.1)).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DhNonce<'a> {
    /// The client's DH nonce (`clientDHNonce`, written `n_c` in the RFC).
    pub client_nonce: &'a [u8],
    /// The server's DH nonce (`serverDHNonce`, written `n_k` in the RFC).
    pub server_nonce: &'a [u8],
}
72
73/// [Using Diffie-Hellman Key Exchange](https://www.rfc-editor.org/rfc/rfc4556.html#section-3.2.3.1)
74/// let n_c be the clientDHNonce and n_k be the serverDHNonce; otherwise, let both n_c and n_k be empty octet strings.
75/// k = octetstring2key(DHSharedSecret | n_c | n_k)
76pub fn generate_key_from_shared_secret(
77    mut dh_shared_secret: Vec<u8>,
78    nonce: Option<DhNonce>,
79    cipher: &dyn Cipher,
80) -> DiffieHellmanResult<Vec<u8>> {
81    if let Some(DhNonce {
82        client_nonce,
83        server_nonce,
84    }) = nonce
85    {
86        dh_shared_secret.extend_from_slice(client_nonce);
87        dh_shared_secret.extend_from_slice(server_nonce);
88    }
89
90    octet_string_to_key(&dh_shared_secret, cipher)
91}
92
93/// [Using Diffie-Hellman Key Exchange](https://www.rfc-editor.org/rfc/rfc4556.html#section-3.2.3.1)
94/// let DHSharedSecret be the shared secret value. DHSharedSecret is the value ZZ
95///
96/// [Generation of ZZ](https://www.rfc-editor.org/rfc/rfc2631#section-2.1.1)
97/// ZZ = g ^ (xb * xa) mod p
98/// ZZ = (yb ^ xa)  mod p  = (ya ^ xb)  mod p
99/// where ^ denotes exponentiation
100fn generate_dh_shared_secret(public_key: &[u8], private_key: &[u8], p: &[u8]) -> DiffieHellmanResult<Vec<u8>> {
101    let public_key = BoxedUint::from_be_slice_vartime(public_key);
102    let private_key = BoxedUint::from_be_slice_vartime(private_key);
103    let p = Odd::new(BoxedUint::from_be_slice_vartime(p))
104        .into_option()
105        .ok_or(DiffieHellmanError::ModulusIsNotOdd)?;
106    let p = BoxedMontyParams::new_vartime(p);
107
108    // ZZ = (public_key ^ private_key) mod p
109    let out = pow_mod_params(&public_key, &private_key, &p);
110    Ok(out.to_be_bytes().to_vec())
111}
112
113//= [Using Diffie-Hellman Key Exchange](https://www.rfc-editor.org/rfc/rfc4556.html#section-3.2.3.1) =//
114pub fn generate_key(
115    public_key: &[u8],
116    private_key: &[u8],
117    modulus: &[u8],
118    nonce: Option<DhNonce>,
119    cipher: &dyn Cipher,
120) -> DiffieHellmanResult<Vec<u8>> {
121    let dh_shared_secret = generate_dh_shared_secret(public_key, private_key, modulus)?;
122    generate_key_from_shared_secret(dh_shared_secret, nonce, cipher)
123}
124
125/// [Key and Parameter Requirements](https://www.rfc-editor.org/rfc/rfc2631#section-2.2)
126/// X9.42 requires that the private key x be in the interval [2, (q - 2)]
127pub fn generate_private_key<R: TryRngCore>(q: &[u8], rng: &mut R) -> Result<Vec<u8>, R::Error> {
128    let q = BoxedUint::from_be_slice_vartime(q);
129    let low_bound = BoxedUint::from_be_slice_vartime(&[2]);
130    let high_bound = q - 1_u32;
131
132    let min_bits = low_bound.bits();
133    let max_bits = high_bound.bits();
134    loop {
135        let bit_length = rng.try_next_u32()? % (max_bits - min_bits) + min_bits;
136        let x = BoxedUint::random_bits(rng, bit_length);
137
138        if (&low_bound..&high_bound).contains(&&x) {
139            return Ok(x.to_be_bytes().into_vec());
140        }
141    }
142}
143
144/// [Key and Parameter Requirements](https://www.rfc-editor.org/rfc/rfc2631#section-2.2)
145/// y is then computed by calculating g^x mod p.
146pub fn compute_public_key(private_key: &[u8], modulus: &[u8], base: &[u8]) -> DiffieHellmanResult<Vec<u8>> {
147    generate_dh_shared_secret(base, private_key, modulus)
148}
149
150// Copied from `rsa` crate: https://github.com/RustCrypto/RSA/blob/eb1cca7b7ea42445dc874c1c1ce38873e4adade7/src/algorithms/rsa.rs#L232-L241
151fn pow_mod_params(base: &BoxedUint, exp: &BoxedUint, n_params: &BoxedMontyParams) -> BoxedUint {
152    let base = reduce_vartime(base, n_params);
153    base.pow(exp).retrieve()
154}
155
156fn reduce_vartime(n: &BoxedUint, p: &BoxedMontyParams) -> BoxedMontyForm {
157    let modulus = p.modulus().as_nz_ref().clone();
158    let n_reduced = n.rem_vartime(&modulus).resize_unchecked(p.bits_precision());
159    BoxedMontyForm::new(n_reduced, p)
160}