hpke_rs_rust_crypto/
lib.rs

1#![doc = include_str!("../Readme.md")]
2#![cfg_attr(not(test), no_std)]
3
4extern crate alloc;
5
6use alloc::{string::String, vec::Vec};
7use core::fmt::Display;
8
9use hpke_rs_crypto::{
10    error::Error,
11    types::{AeadAlgorithm, KdfAlgorithm, KemAlgorithm},
12    CryptoRng, HpkeCrypto, HpkeTestRng, RngCore,
13};
14use p256::{
15    elliptic_curve::ecdh::diffie_hellman as p256diffie_hellman, PublicKey as p256PublicKey,
16    SecretKey as p256SecretKey,
17};
18
19use k256::{
20    elliptic_curve::{ecdh::diffie_hellman as k256diffie_hellman, sec1::ToEncodedPoint},
21    PublicKey as k256PublicKey, SecretKey as k256SecretKey,
22};
23
24use rand_core::SeedableRng;
25use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret as X25519StaticSecret};
26
27mod aead;
28mod hkdf;
29use crate::aead::*;
30use crate::hkdf::*;
31
/// HPKE crypto provider backed by the RustCrypto crates.
///
/// This is a stateless marker type; all operations are associated functions
/// of its `HpkeCrypto` implementation.
#[derive(Debug)]
pub struct HpkeRustCrypto {}
35
36/// The PRNG for the Rust Crypto Provider.
37pub struct HpkeRustCryptoPrng {
38    rng: rand_chacha::ChaCha20Rng,
39    #[cfg(feature = "deterministic-prng")]
40    fake_rng: Vec<u8>,
41}
42
43impl HpkeCrypto for HpkeRustCrypto {
44    fn name() -> String {
45        "RustCrypto".into()
46    }
47
48    fn kdf_extract(alg: KdfAlgorithm, salt: &[u8], ikm: &[u8]) -> Result<Vec<u8>, Error> {
49        Ok(match alg {
50            KdfAlgorithm::HkdfSha256 => sha256_extract(salt, ikm),
51            KdfAlgorithm::HkdfSha384 => sha384_extract(salt, ikm),
52            KdfAlgorithm::HkdfSha512 => sha512_extract(salt, ikm),
53        })
54    }
55
56    fn kdf_expand(
57        alg: KdfAlgorithm,
58        prk: &[u8],
59        info: &[u8],
60        output_size: usize,
61    ) -> Result<Vec<u8>, Error> {
62        match alg {
63            KdfAlgorithm::HkdfSha256 => sha256_expand(prk, info, output_size),
64            KdfAlgorithm::HkdfSha384 => sha384_expand(prk, info, output_size),
65            KdfAlgorithm::HkdfSha512 => sha512_expand(prk, info, output_size),
66        }
67    }
68
69    fn dh(alg: KemAlgorithm, pk: &[u8], sk: &[u8]) -> Result<Vec<u8>, Error> {
70        match alg {
71            KemAlgorithm::DhKem25519 => {
72                if sk.len() != 32 {
73                    return Err(Error::KemInvalidSecretKey);
74                }
75                if pk.len() != 32 {
76                    return Err(Error::KemInvalidPublicKey);
77                }
78                assert!(pk.len() == 32);
79                assert!(sk.len() == 32);
80                let sk_array: [u8; 32] = sk.try_into().map_err(|_| Error::KemInvalidSecretKey)?;
81                let pk_array: [u8; 32] = pk.try_into().map_err(|_| Error::KemInvalidPublicKey)?;
82                let sk = X25519StaticSecret::from(sk_array);
83                Ok(sk
84                    .diffie_hellman(&X25519PublicKey::from(pk_array))
85                    .as_bytes()
86                    .to_vec())
87            }
88            KemAlgorithm::DhKemP256 => {
89                let sk = p256SecretKey::from_slice(sk).map_err(|_| Error::KemInvalidSecretKey)?;
90                let pk =
91                    p256PublicKey::from_sec1_bytes(pk).map_err(|_| Error::KemInvalidPublicKey)?;
92                Ok(p256diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine())
93                    .raw_secret_bytes()
94                    .as_slice()
95                    .into())
96            }
97            KemAlgorithm::DhKemK256 => {
98                let sk = k256SecretKey::from_slice(sk).map_err(|_| Error::KemInvalidSecretKey)?;
99                let pk =
100                    k256PublicKey::from_sec1_bytes(pk).map_err(|_| Error::KemInvalidPublicKey)?;
101                Ok(k256diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine())
102                    .raw_secret_bytes()
103                    .as_slice()
104                    .into())
105            }
106            _ => Err(Error::UnknownKemAlgorithm),
107        }
108    }
109
110    fn kem_key_gen_derand(_alg: KemAlgorithm, _seed: &[u8]) -> Result<(Vec<u8>, Vec<u8>), Error> {
111        // No ciphersuite uses this.
112        return Err(Error::UnsupportedKemOperation);
113    }
114
115    fn kem_encaps(
116        _alg: KemAlgorithm,
117        _pk_r: &[u8],
118        _prng: &mut Self::HpkePrng,
119    ) -> Result<(Vec<u8>, Vec<u8>), Error> {
120        // No ciphersuite uses this.
121        return Err(Error::UnsupportedKemOperation);
122    }
123
124    fn kem_decaps(_alg: KemAlgorithm, _ct: &[u8], _sk_r: &[u8]) -> Result<Vec<u8>, Error> {
125        // No ciphersuite uses this.
126        return Err(Error::UnsupportedKemOperation);
127    }
128
129    fn secret_to_public(alg: KemAlgorithm, sk: &[u8]) -> Result<Vec<u8>, Error> {
130        match alg {
131            KemAlgorithm::DhKem25519 => {
132                if sk.len() != 32 {
133                    return Err(Error::KemInvalidSecretKey);
134                }
135                assert!(sk.len() == 32);
136                let sk_array: [u8; 32] = sk.try_into().map_err(|_| Error::KemInvalidSecretKey)?;
137                let sk = X25519StaticSecret::from(sk_array);
138                Ok(X25519PublicKey::from(&sk).as_bytes().to_vec())
139            }
140            KemAlgorithm::DhKemP256 => {
141                let sk = p256SecretKey::from_slice(sk).map_err(|_| Error::KemInvalidSecretKey)?;
142                Ok(sk.public_key().to_encoded_point(false).as_bytes().into())
143            }
144            KemAlgorithm::DhKemK256 => {
145                let sk = k256SecretKey::from_slice(sk).map_err(|_| Error::KemInvalidSecretKey)?;
146                Ok(sk.public_key().to_encoded_point(false).as_bytes().into())
147            }
148            _ => Err(Error::UnsupportedKemOperation),
149        }
150    }
151
152    fn kem_key_gen(
153        alg: KemAlgorithm,
154        prng: &mut Self::HpkePrng,
155    ) -> Result<(Vec<u8>, Vec<u8>), Error> {
156        let rng = &mut prng.rng;
157        match alg {
158            KemAlgorithm::DhKem25519 => {
159                let sk = X25519StaticSecret::random_from_rng(&mut *rng);
160                let pk = X25519PublicKey::from(&sk).as_bytes().to_vec();
161                let sk = sk.to_bytes().to_vec();
162                Ok((pk, sk))
163            }
164            KemAlgorithm::DhKemP256 => {
165                let sk = p256SecretKey::random(&mut *rng);
166                let pk = sk.public_key().to_encoded_point(false).as_bytes().into();
167                let sk = sk.to_bytes().as_slice().into();
168                Ok((pk, sk))
169            }
170            KemAlgorithm::DhKemK256 => {
171                let sk = k256SecretKey::random(&mut *rng);
172                let pk = sk.public_key().to_encoded_point(false).as_bytes().into();
173                let sk = sk.to_bytes().as_slice().into();
174                Ok((pk, sk))
175            }
176            _ => Err(Error::UnknownKemAlgorithm),
177        }
178    }
179
180    fn dh_validate_sk(alg: KemAlgorithm, sk: &[u8]) -> Result<Vec<u8>, Error> {
181        match alg {
182            KemAlgorithm::DhKemP256 => p256SecretKey::from_slice(sk)
183                .map_err(|_| Error::KemInvalidSecretKey)
184                .map(|_| sk.into()),
185            KemAlgorithm::DhKemK256 => k256SecretKey::from_slice(sk)
186                .map_err(|_| Error::KemInvalidSecretKey)
187                .map(|_| sk.into()),
188            _ => Err(Error::UnknownKemAlgorithm),
189        }
190    }
191
192    fn aead_seal(
193        alg: AeadAlgorithm,
194        key: &[u8],
195        nonce: &[u8],
196        aad: &[u8],
197        msg: &[u8],
198    ) -> Result<Vec<u8>, Error> {
199        match alg {
200            AeadAlgorithm::Aes128Gcm => aes128_seal(key, nonce, aad, msg),
201            AeadAlgorithm::Aes256Gcm => aes256_seal(key, nonce, aad, msg),
202            AeadAlgorithm::ChaCha20Poly1305 => chacha_seal(key, nonce, aad, msg),
203            AeadAlgorithm::HpkeExport => Err(Error::UnknownAeadAlgorithm),
204        }
205    }
206
207    fn aead_open(
208        alg: AeadAlgorithm,
209        key: &[u8],
210        nonce: &[u8],
211        aad: &[u8],
212        msg: &[u8],
213    ) -> Result<Vec<u8>, Error> {
214        match alg {
215            AeadAlgorithm::Aes128Gcm => aes128_open(alg, key, nonce, aad, msg),
216            AeadAlgorithm::Aes256Gcm => aes256_open(alg, key, nonce, aad, msg),
217            AeadAlgorithm::ChaCha20Poly1305 => chacha_open(alg, key, nonce, aad, msg),
218            AeadAlgorithm::HpkeExport => Err(Error::UnknownAeadAlgorithm),
219        }
220    }
221
222    type HpkePrng = HpkeRustCryptoPrng;
223
224    fn prng() -> Self::HpkePrng {
225        #[cfg(feature = "deterministic-prng")]
226        {
227            let mut fake_rng = alloc::vec![0u8; 256];
228            rand_chacha::ChaCha20Rng::from_entropy().fill_bytes(&mut fake_rng);
229            HpkeRustCryptoPrng {
230                fake_rng,
231                rng: rand_chacha::ChaCha20Rng::from_entropy(),
232            }
233        }
234        #[cfg(not(feature = "deterministic-prng"))]
235        HpkeRustCryptoPrng {
236            rng: rand_chacha::ChaCha20Rng::from_entropy(),
237        }
238    }
239
240    /// Returns an error if the KDF algorithm is not supported by this crypto provider.
241    fn supports_kdf(_: KdfAlgorithm) -> Result<(), Error> {
242        Ok(())
243    }
244
245    /// Returns an error if the KEM algorithm is not supported by this crypto provider.
246    fn supports_kem(alg: KemAlgorithm) -> Result<(), Error> {
247        match alg {
248            KemAlgorithm::DhKem25519 | KemAlgorithm::DhKemP256 | KemAlgorithm::DhKemK256 => Ok(()),
249            _ => Err(Error::UnknownKemAlgorithm),
250        }
251    }
252
253    /// Returns an error if the AEAD algorithm is not supported by this crypto provider.
254    fn supports_aead(alg: AeadAlgorithm) -> Result<(), Error> {
255        match alg {
256            AeadAlgorithm::Aes128Gcm
257            | AeadAlgorithm::Aes256Gcm
258            | AeadAlgorithm::ChaCha20Poly1305
259            | AeadAlgorithm::HpkeExport => Ok(()),
260        }
261    }
262}
263
// We need to implement both the old and the new RNG traits here because the
// underlying crypto implementation still uses the old one.
266
/// Old-style (`rand_old`) RNG interface; every method forwards to the inner
/// ChaCha20 generator.
impl rand_old::RngCore for HpkeRustCryptoPrng {
    /// Returns the next random `u32` from the inner generator.
    fn next_u32(&mut self) -> u32 {
        self.rng.next_u32()
    }

    /// Returns the next random `u64` from the inner generator.
    fn next_u64(&mut self) -> u64 {
        self.rng.next_u64()
    }

    /// Fills `dest` with random bytes from the inner generator.
    fn fill_bytes(&mut self, dest: &mut [u8]) {
        self.rng.fill_bytes(dest);
    }

    /// Fallible fill; returns the inner generator's result unchanged.
    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
        // NOTE(review): the trait is `rand_old::RngCore` but the signature names
        // `rand_core::Error`, so the two paths presumably refer to the same
        // crate/version — confirm in Cargo.toml.
        self.rng.try_fill_bytes(dest)
    }
}

/// Marks the wrapper as cryptographically secure for `rand_old` consumers.
impl rand_old::CryptoRng for HpkeRustCryptoPrng {}
286
287use rand_old::RngCore as _;
288
/// New-style RNG interface (the `RngCore` re-exported by `hpke_rs_crypto`);
/// every method forwards to the inner ChaCha20 generator.
impl RngCore for HpkeRustCryptoPrng {
    /// Returns the next random `u32` from the inner generator.
    fn next_u32(&mut self) -> u32 {
        // NOTE(review): presumably resolves through `rand_old::RngCore`
        // (imported anonymously via `use rand_old::RngCore as _;`), since the
        // ChaCha20 generator appears to implement only the old trait — confirm.
        self.rng.next_u32()
    }

    /// Returns the next random `u64` from the inner generator.
    fn next_u64(&mut self) -> u64 {
        self.rng.next_u64()
    }

    /// Fills `dest` with random bytes from the inner generator.
    fn fill_bytes(&mut self, dest: &mut [u8]) {
        self.rng.fill_bytes(dest);
    }
}

/// Marks the wrapper as cryptographically secure for new-style consumers.
impl CryptoRng for HpkeRustCryptoPrng {}
304
305impl HpkeTestRng for HpkeRustCryptoPrng {
306    #[cfg(feature = "deterministic-prng")]
307    fn try_fill_test_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_old::Error> {
308        // Here we fake our randomness for testing.
309        if dest.len() > self.fake_rng.len() {
310            return Err(rand_core::Error::new(Error::InsufficientRandomness));
311        }
312        dest.clone_from_slice(&self.fake_rng.split_off(self.fake_rng.len() - dest.len()));
313        Ok(())
314    }
315
316    #[cfg(feature = "deterministic-prng")]
317    fn seed(&mut self, seed: &[u8]) {
318        self.fake_rng = seed.to_vec();
319    }
320    #[cfg(not(feature = "deterministic-prng"))]
321    fn try_fill_test_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_old::Error> {
322        self.rng.try_fill_bytes(dest)
323    }
324
325    #[cfg(not(feature = "deterministic-prng"))]
326    fn seed(&mut self, _: &[u8]) {}
327
328    type Error = rand_old::Error;
329}
330
331impl Display for HpkeRustCrypto {
332    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
333        write!(f, "{}", Self::name())
334    }
335}