1use crate::{
5 codec::DecodeValue,
6 identity::{ObfuscationKey, PskIdentity},
7 psk_from_material,
8 psk_parser::retrieve_psk_identities,
9 KeyArn, KEY_ROTATION_PERIOD, MAXIMUM_KEY_CACHE_SIZE,
10};
11use aws_sdk_kms::{primitives::Blob, Client};
12use moka::sync::Cache;
13use pin_project::pin_project;
14use s2n_tls::{
15 callbacks::{ClientHelloCallback, ConnectionFuture},
16 error::Error as S2NError,
17};
18use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
19
20#[pin_project]
25struct DecryptFuture<F> {
26 #[pin]
27 future: F,
28}
29
30impl<F> DecryptFuture<F>
31where
32 F: 'static + Send + Sync + Future<Output = anyhow::Result<s2n_tls::psk::Psk>>,
33{
34 pub fn new(future: F) -> Self {
35 DecryptFuture { future }
36 }
37}
38
39impl<F> s2n_tls::callbacks::ConnectionFuture for DecryptFuture<F>
40where
41 F: 'static + Send + Sync + Future<Output = anyhow::Result<s2n_tls::psk::Psk>>,
42{
43 fn poll(
44 self: Pin<&mut Self>,
45 connection: &mut s2n_tls::connection::Connection,
46 ctx: &mut core::task::Context,
47 ) -> std::task::Poll<Result<(), S2NError>> {
48 let this = self.project();
49 let psk = match this.future.poll(ctx) {
50 Poll::Ready(Ok(psk)) => psk,
51 Poll::Ready(Err(e)) => {
52 return Poll::Ready(Err(s2n_tls::error::Error::application(
53 e.into_boxed_dyn_error(),
54 )));
55 }
56 Poll::Pending => return Poll::Pending,
57 };
58 connection.append_psk(&psk)?;
59 Poll::Ready(Ok(()))
60 }
61}
62
63pub struct PskReceiver {
68 kms_client: Client,
69 obfuscation_keys: Vec<ObfuscationKey>,
70 trusted_key_arns: Arc<Vec<KeyArn>>,
71 key_cache: Cache<Vec<u8>, Vec<u8>>,
75}
76
77impl PskReceiver {
78 pub fn new(
96 kms_client: Client,
97 trusted_key_arns: Vec<KeyArn>,
98 obfuscation_keys: Vec<ObfuscationKey>,
99 ) -> Self {
100 let key_cache = moka::sync::Cache::builder()
101 .max_capacity(MAXIMUM_KEY_CACHE_SIZE as u64)
102 .time_to_idle(KEY_ROTATION_PERIOD)
103 .build();
104 Self {
105 kms_client,
106 trusted_key_arns: Arc::new(trusted_key_arns),
107 obfuscation_keys,
108 key_cache,
109 }
110 }
111
112 async fn kms_decrypt_and_update(
123 psk_identity: Vec<u8>,
124 ciphertext_datakey: Vec<u8>,
125 client: Client,
126 trusted_key_arns: Arc<Vec<KeyArn>>,
127 key_cache: Cache<Vec<u8>, Vec<u8>>,
128 ) -> anyhow::Result<s2n_tls::psk::Psk> {
129 let ciphertext_datakey_clone = ciphertext_datakey.clone();
130 let decrypted = tokio::spawn(async move {
131 client
132 .decrypt()
133 .ciphertext_blob(Blob::new(ciphertext_datakey_clone))
134 .send()
135 .await
136 })
137 .await??;
138
139 let associated_key_arn = decrypted.key_id.as_ref().unwrap();
143 if !trusted_key_arns.contains(associated_key_arn) {
144 anyhow::bail!("untrusted KMS Key: {associated_key_arn} is not trusted");
145 }
146
147 let plaintext_datakey = decrypted.plaintext.unwrap().into_inner();
148 key_cache.insert(ciphertext_datakey, plaintext_datakey.clone());
149 let psk = psk_from_material(&psk_identity, &plaintext_datakey)?;
150
151 Ok(psk)
152 }
153}
154
155impl ClientHelloCallback for PskReceiver {
156 fn on_client_hello(
157 &self,
158 connection: &mut s2n_tls::connection::Connection,
159 ) -> Result<Option<Pin<Box<dyn ConnectionFuture>>>, s2n_tls::error::Error> {
160 let client_hello = connection.client_hello()?;
162 let identities = match retrieve_psk_identities(client_hello) {
163 Ok(identities) => identities,
164 Err(e) => {
165 return Err(s2n_tls::error::Error::application(e.into()));
166 }
167 };
168
169 let psk_identity = match identities.list().first() {
172 Some(id) => id.identity.blob(),
173 None => {
174 return Err(s2n_tls::error::Error::application(
175 "identities list was zero-length".into(),
176 ))
177 }
178 };
179
180 let identity = PskIdentity::decode_from_exact(psk_identity)
182 .map_err(|e| s2n_tls::error::Error::application(e.into()))?;
183
184 let ciphertext_datakey = identity
186 .deobfuscate_datakey(&self.obfuscation_keys)
187 .map_err(|e| s2n_tls::error::Error::application(e.into()))?;
188
189 let maybe_cached = self.key_cache.get(&ciphertext_datakey);
190 if let Some(plaintext_datakey) = maybe_cached {
191 let psk = psk_from_material(psk_identity, &plaintext_datakey)?;
193 connection.append_psk(&psk)?;
194 Ok(None)
195 } else {
196 let future = Self::kms_decrypt_and_update(
198 psk_identity.to_vec(),
199 ciphertext_datakey,
200 self.kms_client.clone(),
201 self.trusted_key_arns.clone(),
202 self.key_cache.clone(),
203 );
204 let wrapped = DecryptFuture::new(future);
205 Ok(Some(Box::pin(wrapped)))
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        identity::PskVersion,
        test_utils::{
            configs_from_callbacks, decrypt_mocks, gdk_mocks, handshake, test_psk_provider,
            CIPHERTEXT_DATAKEY_A, CONSTANT_OBFUSCATION_KEY, KMS_KEY_ARN, OBFUSCATION_KEY,
            PLAINTEXT_DATAKEY_A,
        },
        PskProvider,
    };
    use aws_sdk_kms::{operation::decrypt::DecryptError, types::error::InvalidKeyUsageException};
    use aws_smithy_mocks::{mock, mock_client};
    use s2n_tls::config::ConnectionInitializer;

    /// Builds a receiver that trusts `KMS_KEY_ARN` and holds exactly one
    /// obfuscation key.
    fn receiver_trusting_test_arn(client: Client, obfuscation_key: ObfuscationKey) -> PskReceiver {
        PskReceiver::new(client, vec![KMS_KEY_ARN.to_owned()], vec![obfuscation_key])
    }

    /// The first handshake calls KMS decrypt once and populates the cache;
    /// the second handshake does not call decrypt again.
    #[tokio::test]
    async fn decrypt_path() {
        let provider = test_psk_provider().await;

        let (decrypt_rule, decrypt_client) = decrypt_mocks();
        let receiver = receiver_trusting_test_arn(decrypt_client, OBFUSCATION_KEY.clone());

        // Keep a handle on the cache before the receiver is moved into the config.
        let cache = receiver.key_cache.clone();

        let (client_config, server_config) = configs_from_callbacks(provider, receiver);
        assert_eq!(decrypt_rule.num_calls(), 0);

        handshake(&client_config, &server_config).await.unwrap();
        assert_eq!(decrypt_rule.num_calls(), 1);
        assert_eq!(
            cache.get(CIPHERTEXT_DATAKEY_A).unwrap().as_slice(),
            PLAINTEXT_DATAKEY_A
        );

        handshake(&client_config, &server_config).await.unwrap();
        assert_eq!(decrypt_rule.num_calls(), 1);
    }

    /// A receiver that does not trust the ARN reported by KMS fails the
    /// handshake with an "untrusted KMS Key" error.
    #[tokio::test]
    async fn untrusted_key_arn() {
        let provider = test_psk_provider().await;

        let (_decrypt_rule, decrypt_client) = decrypt_mocks();
        let receiver = PskReceiver::new(
            decrypt_client,
            vec!["arn::wont-be-seen".to_string()],
            vec![OBFUSCATION_KEY.clone()],
        );

        let (client_config, server_config) = configs_from_callbacks(provider, receiver);

        let failure = handshake(&client_config, &server_config).await.unwrap_err();
        let message = failure.to_string();
        assert!(message.contains("untrusted KMS Key: arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab is not trusted"));
    }

    /// A receiver holding only a random obfuscation key fails with a
    /// deobfuscation error and never calls KMS decrypt.
    #[tokio::test]
    async fn obfuscation_key_unavailable() {
        let provider = test_psk_provider().await;

        let (decrypt_rule, decrypt_client) = decrypt_mocks();
        let receiver =
            receiver_trusting_test_arn(decrypt_client, ObfuscationKey::random_test_key());

        let (client_config, server_config) = configs_from_callbacks(provider, receiver);

        let failure = handshake(&client_config, &server_config).await.unwrap_err();
        let message = failure.to_string();
        assert!(message.starts_with("unable to deobfuscate: "));
        assert!(message.ends_with("not available"));

        assert_eq!(decrypt_rule.num_calls(), 0);
    }

    /// With the cache pre-filled to `MAXIMUM_KEY_CACHE_SIZE`, a handshake
    /// that needs a new key still succeeds and the entry count stays at the
    /// maximum.
    #[tokio::test]
    async fn cache_max_capacity() {
        let (decrypt_rule, decrypt_client) = decrypt_mocks();
        let (_gdk_rule, gdk_client) = gdk_mocks();

        let shared_key = ObfuscationKey::random_test_key();
        let provider = PskProvider::initialize(
            PskVersion::V1,
            gdk_client,
            KMS_KEY_ARN.to_string(),
            shared_key.clone(),
            |_| {},
        )
        .await
        .unwrap();

        let receiver = receiver_trusting_test_arn(decrypt_client, shared_key);

        let cache = receiver.key_cache.clone();
        for i in 0..MAXIMUM_KEY_CACHE_SIZE {
            let filler = i.to_be_bytes().to_vec();
            cache.insert(filler.clone(), filler);
        }
        cache.run_pending_tasks();
        assert_eq!(cache.entry_count(), MAXIMUM_KEY_CACHE_SIZE as u64);

        let (client_config, server_config) = configs_from_callbacks(provider, receiver);

        assert_eq!(decrypt_rule.num_calls(), 0);
        handshake(&client_config, &server_config).await.unwrap();
        assert_eq!(decrypt_rule.num_calls(), 1);

        cache.run_pending_tasks();
        assert_eq!(cache.entry_count(), MAXIMUM_KEY_CACHE_SIZE as u64);
    }

    /// An error returned by the KMS decrypt call surfaces as a handshake
    /// failure mentioning "service error".
    #[tokio::test]
    async fn decrypt_error() {
        let failing_rule = mock!(aws_sdk_kms::Client::decrypt).then_error(|| {
            DecryptError::InvalidKeyUsageException(InvalidKeyUsageException::builder().build())
        });
        let failing_client = mock_client!(aws_sdk_kms, [&failing_rule]);

        let provider = test_psk_provider().await;

        let receiver = receiver_trusting_test_arn(failing_client, OBFUSCATION_KEY.clone());

        let (client_config, server_config) = configs_from_callbacks(provider, receiver);

        let failure = handshake(&client_config, &server_config).await.unwrap_err();
        assert!(failure.to_string().contains("service error"));
    }

    /// A client offering the stored identity fixture is rejected with a
    /// "too old" error before any KMS decrypt call is made.
    #[tokio::test]
    async fn receiver_rejects_old_identity() {
        const OLD_IDENTITY: &[u8] = include_bytes!("../resources/psk_identity.bin");

        /// Client-side initializer that always offers `OLD_IDENTITY`.
        struct OldIdentityInitializer;

        impl ConnectionInitializer for OldIdentityInitializer {
            fn initialize_connection(
                &self,
                connection: &mut s2n_tls::connection::Connection,
            ) -> Result<Option<Pin<Box<dyn ConnectionFuture>>>, s2n_tls::error::Error> {
                let stale_psk =
                    psk_from_material(OLD_IDENTITY, b"doesn't matter, should fail before using")?;
                connection.append_psk(&stale_psk)?;
                Ok(None)
            }
        }

        let (decrypt_rule, decrypt_client) = decrypt_mocks();
        let receiver =
            receiver_trusting_test_arn(decrypt_client, CONSTANT_OBFUSCATION_KEY.clone());

        let (client_config, server_config) =
            configs_from_callbacks(OldIdentityInitializer, receiver);
        let failure = handshake(&client_config, &server_config).await.unwrap_err();
        assert_eq!(decrypt_rule.num_calls(), 0);
        assert!(failure.to_string().contains("too old"));
    }
}