1use crate::{Deserializable, HpkeError, Serializable};
4
5use core::fmt::Debug;
6
7use generic_array::{ArrayLength, GenericArray};
8use rand_core::{CryptoRng, RngCore};
9use zeroize::Zeroize;
10
11mod dhkem;
12pub use dhkem::*;
13
14pub trait Kem: Sized {
16 type PublicKey: Clone + Debug + PartialEq + Eq + Serializable + Deserializable;
19
20 type PrivateKey: Clone + PartialEq + Eq + Serializable + Deserializable;
23
24 fn sk_to_pk(sk: &Self::PrivateKey) -> Self::PublicKey;
26 type EncappedKey: Clone + Serializable + Deserializable;
29
30 #[doc(hidden)]
32 type NSecret: ArrayLength<u8>;
33
34 const KEM_ID: u16;
36
37 fn derive_keypair(ikm: &[u8]) -> (Self::PrivateKey, Self::PublicKey);
45
46 fn gen_keypair<R: CryptoRng + RngCore>(csprng: &mut R) -> (Self::PrivateKey, Self::PublicKey) {
48 let mut ikm: GenericArray<u8, <Self::PrivateKey as Serializable>::OutputSize> =
50 GenericArray::default();
51 csprng.fill_bytes(&mut ikm);
53 Self::derive_keypair(&ikm)
55 }
56
57 #[doc(hidden)]
65 fn decap(
66 sk_recip: &Self::PrivateKey,
67 pk_sender_id: Option<&Self::PublicKey>,
68 encapped_key: &Self::EncappedKey,
69 ) -> Result<SharedSecret<Self>, HpkeError>;
70
71 #[doc(hidden)]
81 fn encap<R: CryptoRng + RngCore>(
82 pk_recip: &Self::PublicKey,
83 sender_id_keypair: Option<(&Self::PrivateKey, &Self::PublicKey)>,
84 csprng: &mut R,
85 ) -> Result<(SharedSecret<Self>, Self::EncappedKey), HpkeError>;
86}
87
88use Kem as KemTrait;
90
91#[doc(hidden)]
93pub struct SharedSecret<Kem: KemTrait>(pub GenericArray<u8, Kem::NSecret>);
94
95impl<Kem: KemTrait> Default for SharedSecret<Kem> {
96 fn default() -> SharedSecret<Kem> {
97 SharedSecret(GenericArray::<u8, Kem::NSecret>::default())
98 }
99}
100
101impl<Kem: KemTrait> Zeroize for SharedSecret<Kem> {
103 fn zeroize(&mut self) {
104 self.0.zeroize()
105 }
106}
107impl<Kem: KemTrait> Drop for SharedSecret<Kem> {
108 fn drop(&mut self) {
109 self.zeroize();
110 }
111}
112
#[cfg(test)]
mod tests {
    use crate::{kem::Kem as KemTrait, Deserializable, Serializable};

    use rand::{rngs::StdRng, SeedableRng};

    // Generates a test checking that `decap` recovers the same shared secret `encap` produced,
    // both without a sender identity and with one (authenticated mode).
    macro_rules! test_encap_correctness {
        ($test_name:ident, $kem_ty:ty) => {
            #[test]
            fn $test_name() {
                type Kem = $kem_ty;

                let mut csprng = StdRng::from_os_rng();
                let (sk_recip, pk_recip) = Kem::gen_keypair(&mut csprng);

                // Unauthenticated: no sender identity keypair
                let (shared_secret, encapped_key) =
                    Kem::encap(&pk_recip, None, &mut csprng).unwrap();
                let decapped_shared_secret =
                    Kem::decap(&sk_recip, None, &encapped_key).unwrap();
                assert_eq!(shared_secret.0, decapped_shared_secret.0);

                // Authenticated: the sender supplies its identity keypair
                let (sk_sender_id, pk_sender_id) = Kem::gen_keypair(&mut csprng);
                let (auth_shared_secret, encapped_key) =
                    Kem::encap(&pk_recip, Some((&sk_sender_id, &pk_sender_id)), &mut csprng)
                        .unwrap();
                let decapped_auth_shared_secret =
                    Kem::decap(&sk_recip, Some(&pk_sender_id), &encapped_key).unwrap();
                assert_eq!(auth_shared_secret.0, decapped_auth_shared_secret.0);

                // Claiming the wrong sender identity must not reproduce the secret. Either an
                // error or a different secret is acceptable here.
                if let Ok(wrong_shared_secret) =
                    Kem::decap(&sk_recip, Some(&pk_recip), &encapped_key)
                {
                    assert_ne!(
                        wrong_shared_secret.0, auth_shared_secret.0,
                        "decap with the wrong sender identity yielded the auth shared secret"
                    );
                }
            }
        };
    }

    // Generates a test checking that an encapsulated key survives a `to_bytes` /
    // `from_bytes` round trip unchanged.
    macro_rules! test_encapped_serialize {
        ($test_name:ident, $kem_ty:ty) => {
            #[test]
            fn $test_name() {
                type Kem = $kem_ty;

                // Produce an arbitrary encapsulated key
                let encapped_key = {
                    let mut csprng = StdRng::from_os_rng();
                    let (_, pk_recip) = Kem::gen_keypair(&mut csprng);
                    Kem::encap(&pk_recip, None, &mut csprng).unwrap().1
                };

                let encapped_key_bytes = encapped_key.to_bytes();
                let new_encapped_key =
                    <<Kem as KemTrait>::EncappedKey as Deserializable>::from_bytes(
                        &encapped_key_bytes,
                    )
                    .unwrap();

                assert_eq!(
                    new_encapped_key.0, encapped_key.0,
                    "encapped key doesn't serialize correctly"
                );
            }
        };
    }

    #[cfg(feature = "x25519")]
    mod x25519_tests {
        use super::*;

        test_encap_correctness!(test_encap_correctness_x25519, crate::kem::X25519HkdfSha256);
        test_encapped_serialize!(test_encapped_serialize_x25519, crate::kem::X25519HkdfSha256);
    }

    #[cfg(feature = "p256")]
    mod p256_tests {
        use super::*;

        test_encap_correctness!(test_encap_correctness_p256, crate::kem::DhP256HkdfSha256);
        test_encapped_serialize!(test_encapped_serialize_p256, crate::kem::DhP256HkdfSha256);
    }

    #[cfg(feature = "p384")]
    mod p384_tests {
        use super::*;

        test_encap_correctness!(test_encap_correctness_p384, crate::kem::DhP384HkdfSha384);
        test_encapped_serialize!(test_encapped_serialize_p384, crate::kem::DhP384HkdfSha384);
    }

    #[cfg(feature = "p521")]
    mod p521_tests {
        use super::*;

        test_encap_correctness!(test_encap_correctness_p521, crate::kem::DhP521HkdfSha512);
        test_encapped_serialize!(test_encapped_serialize_p521, crate::kem::DhP521HkdfSha512);
    }
}