aws_kms_tls_auth/
receiver.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    codec::DecodeValue,
6    identity::{ObfuscationKey, PskIdentity},
7    psk_from_material,
8    psk_parser::retrieve_psk_identities,
9    KeyArn, KEY_ROTATION_PERIOD, MAXIMUM_KEY_CACHE_SIZE,
10};
11use aws_sdk_kms::{primitives::Blob, Client};
12use moka::sync::Cache;
13use pin_project::pin_project;
14use s2n_tls::{
15    callbacks::{ClientHelloCallback, ConnectionFuture},
16    error::Error as S2NError,
17};
18use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
19
20/// DecryptFuture wraps a future from the SDK into a format that s2n-tls understands
21/// and can poll.
22///
23/// Specifically, it implements ConnectionFuture for the interior future type.
24#[pin_project]
25struct DecryptFuture<F> {
26    #[pin]
27    future: F,
28}
29
30impl<F> DecryptFuture<F>
31where
32    F: 'static + Send + Sync + Future<Output = anyhow::Result<s2n_tls::psk::Psk>>,
33{
34    pub fn new(future: F) -> Self {
35        DecryptFuture { future }
36    }
37}
38
39impl<F> s2n_tls::callbacks::ConnectionFuture for DecryptFuture<F>
40where
41    F: 'static + Send + Sync + Future<Output = anyhow::Result<s2n_tls::psk::Psk>>,
42{
43    fn poll(
44        self: Pin<&mut Self>,
45        connection: &mut s2n_tls::connection::Connection,
46        ctx: &mut core::task::Context,
47    ) -> std::task::Poll<Result<(), S2NError>> {
48        let this = self.project();
49        let psk = match this.future.poll(ctx) {
50            Poll::Ready(Ok(psk)) => psk,
51            Poll::Ready(Err(e)) => {
52                return Poll::Ready(Err(s2n_tls::error::Error::application(
53                    e.into_boxed_dyn_error(),
54                )));
55            }
56            Poll::Pending => return Poll::Pending,
57        };
58        connection.append_psk(&psk)?;
59        Poll::Ready(Ok(()))
60    }
61}
62
63/// The `PskReceiver` is used along with the [`PskProvider`] to perform TLS
64/// 1.3 out-of-band PSK authentication, using PSK's generated from KMS.
65///
66/// This struct can be enabled on a config with [`s2n_tls::config::Builder::set_client_hello_callback`].
67pub struct PskReceiver {
68    kms_client: Client,
69    obfuscation_keys: Vec<ObfuscationKey>,
70    trusted_key_arns: Arc<Vec<KeyArn>>,
71    /// The key_cache maps from the ciphertext datakey to the plaintext datakey.
72    /// It has a bounded size, and will also evict items after 2 * KEY_ROTATION_PERIOD
73    /// has elapsed.
74    key_cache: Cache<Vec<u8>, Vec<u8>>,
75}
76
77impl PskReceiver {
78    /// Create a new PskReceiver.
79    ///
80    /// This will receive the ciphertext datakey identities from a TLS client hello,
81    /// then decrypt them using KMS. This establishes a mutually authenticated TLS
82    /// handshake between parties with IAM permissions to generate and decrypt data keys
83    ///
84    /// * `kms_client`: The KMS Client that will be used for the decrypt calls
85    ///
86    /// * `trusted_key_arns`: The list of KMS KeyArns that the PskReceiver will
87    ///   accept PSKs from. This is necessary because an attacker could grant the
88    ///   server decrypt permissions on AttackerKeyArn, but the PskReceiver should
89    ///   _not_ trust any Psk's from AttackerKeyArn.
90    ///
91    /// * `obfuscation_keys`: The keys that will be used to deobfuscate the received
92    ///   identities. The client `PskProvider` must be using one of the obfuscation
93    ///   keys in this list. If the PskReceiver receives a Psk identity obfuscated
94    ///   using a key _not_ on this list, then the handshake will fail.
95    pub fn new(
96        kms_client: Client,
97        trusted_key_arns: Vec<KeyArn>,
98        obfuscation_keys: Vec<ObfuscationKey>,
99    ) -> Self {
100        let key_cache = moka::sync::Cache::builder()
101            .max_capacity(MAXIMUM_KEY_CACHE_SIZE as u64)
102            .time_to_idle(KEY_ROTATION_PERIOD)
103            .build();
104        Self {
105            kms_client,
106            trusted_key_arns: Arc::new(trusted_key_arns),
107            obfuscation_keys,
108            key_cache,
109        }
110    }
111
112    /// This is the main async future that s2n-tls polls.
113    ///
114    /// It will
115    /// 1. decrypt the ciphertext datakey
116    /// 2. check that the decrypted material is associated with a trusted key id
117    /// 3. cache the decrypted material in the key cache
118    /// 4. return an s2n-tls psk
119    ///
120    /// All of the arguments are owned to satisfy the `'static` bound that s2n-tls
121    /// requires on connection futures.
122    async fn kms_decrypt_and_update(
123        psk_identity: Vec<u8>,
124        ciphertext_datakey: Vec<u8>,
125        client: Client,
126        trusted_key_arns: Arc<Vec<KeyArn>>,
127        key_cache: Cache<Vec<u8>, Vec<u8>>,
128    ) -> anyhow::Result<s2n_tls::psk::Psk> {
129        let ciphertext_datakey_clone = ciphertext_datakey.clone();
130        let decrypted = tokio::spawn(async move {
131            client
132                .decrypt()
133                .ciphertext_blob(Blob::new(ciphertext_datakey_clone))
134                .send()
135                .await
136        })
137        .await??;
138
139        // although the field is called `key_id`, it is actually the key arn. This
140        // is confirmed in the documentation:
141        // https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html#API_Decrypt_ResponseSyntax
142        let associated_key_arn = decrypted.key_id.as_ref().unwrap();
143        if !trusted_key_arns.contains(associated_key_arn) {
144            anyhow::bail!("untrusted KMS Key: {associated_key_arn} is not trusted");
145        }
146
147        let plaintext_datakey = decrypted.plaintext.unwrap().into_inner();
148        key_cache.insert(ciphertext_datakey, plaintext_datakey.clone());
149        let psk = psk_from_material(&psk_identity, &plaintext_datakey)?;
150
151        Ok(psk)
152    }
153}
154
155impl ClientHelloCallback for PskReceiver {
156    fn on_client_hello(
157        &self,
158        connection: &mut s2n_tls::connection::Connection,
159    ) -> Result<Option<Pin<Box<dyn ConnectionFuture>>>, s2n_tls::error::Error> {
160        // parse the identity list from the client hello
161        let client_hello = connection.client_hello()?;
162        let identities = match retrieve_psk_identities(client_hello) {
163            Ok(identities) => identities,
164            Err(e) => {
165                return Err(s2n_tls::error::Error::application(e.into()));
166            }
167        };
168
169        // extract the identity bytes from the first PSK entry. We assume that we
170        // are talking to a PskProvider, so we don't look at any additional entries.
171        let psk_identity = match identities.list().first() {
172            Some(id) => id.identity.blob(),
173            None => {
174                return Err(s2n_tls::error::Error::application(
175                    "identities list was zero-length".into(),
176                ))
177            }
178        };
179
180        // parse the identity bytes to a PskIdentity
181        let identity = PskIdentity::decode_from_exact(psk_identity)
182            .map_err(|e| s2n_tls::error::Error::application(e.into()))?;
183
184        // deobfuscate the identity to get the ciphertext datakey
185        let ciphertext_datakey = identity
186            .deobfuscate_datakey(&self.obfuscation_keys)
187            .map_err(|e| s2n_tls::error::Error::application(e.into()))?;
188
189        let maybe_cached = self.key_cache.get(&ciphertext_datakey);
190        if let Some(plaintext_datakey) = maybe_cached {
191            // if we already had it cached, then append the PSK and return
192            let psk = psk_from_material(psk_identity, &plaintext_datakey)?;
193            connection.append_psk(&psk)?;
194            Ok(None)
195        } else {
196            // otherwise return a future to decrypt with KMS
197            let future = Self::kms_decrypt_and_update(
198                psk_identity.to_vec(),
199                ciphertext_datakey,
200                self.kms_client.clone(),
201                self.trusted_key_arns.clone(),
202                self.key_cache.clone(),
203            );
204            let wrapped = DecryptFuture::new(future);
205            Ok(Some(Box::pin(wrapped)))
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use crate::{
        identity::PskVersion,
        test_utils::{
            configs_from_callbacks, decrypt_mocks, gdk_mocks, handshake, test_psk_provider,
            CIPHERTEXT_DATAKEY_A, CONSTANT_OBFUSCATION_KEY, KMS_KEY_ARN, OBFUSCATION_KEY,
            PLAINTEXT_DATAKEY_A,
        },
        PskProvider,
    };

    use super::*;
    use aws_sdk_kms::{operation::decrypt::DecryptError, types::error::InvalidKeyUsageException};
    // https://docs.aws.amazon.com/sdk-for-rust/latest/dg/testing-smithy-mocks.html
    use aws_smithy_mocks::{mock, mock_client};
    use s2n_tls::config::ConnectionInitializer;

    /// When a new identity isn't in the cache, we
    /// 1. call KMS to decrypt it
    /// 2. store the result in the key cache
    /// When an identity is in the cache
    /// 1. no calls are made to KMS to decrypt it
    #[tokio::test]
    async fn decrypt_path() {
        let psk_provider = test_psk_provider().await;

        let (decrypt_rule, decrypt_client) = decrypt_mocks();
        let psk_receiver = PskReceiver::new(
            decrypt_client,
            vec![KMS_KEY_ARN.to_owned()],
            vec![OBFUSCATION_KEY.clone()],
        );

        let cache_handle = psk_receiver.key_cache.clone();

        let (client_config, server_config) = configs_from_callbacks(psk_provider, psk_receiver);
        assert_eq!(decrypt_rule.num_calls(), 0);

        handshake(&client_config, &server_config).await.unwrap();
        assert_eq!(decrypt_rule.num_calls(), 1);
        assert_eq!(
            cache_handle.get(CIPHERTEXT_DATAKEY_A).unwrap().as_slice(),
            PLAINTEXT_DATAKEY_A
        );

        // no additional decrypt calls, the cached key was used
        handshake(&client_config, &server_config).await.unwrap();
        assert_eq!(decrypt_rule.num_calls(), 1);
    }

    // if the key ARN isn't recognized, then the handshake fails
    #[tokio::test]
    async fn untrusted_key_arn() {
        let psk_provider = test_psk_provider().await;

        let (_decrypt_rule, decrypt_client) = decrypt_mocks();
        let psk_receiver = PskReceiver::new(
            decrypt_client,
            // use an ARN different from the one KMS will return
            vec!["arn::wont-be-seen".to_string()],
            vec![OBFUSCATION_KEY.clone()],
        );

        let (client_config, server_config) = configs_from_callbacks(psk_provider, psk_receiver);

        let err = handshake(&client_config, &server_config).await.unwrap_err();
        assert!(err.to_string().contains("untrusted KMS Key: arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab is not trusted"));
    }

    #[tokio::test]
    async fn obfuscation_key_unavailable() {
        let psk_provider = test_psk_provider().await;

        // we configured the Psk Receiver with a different obfuscation key
        let (decrypt_rule, decrypt_client) = decrypt_mocks();
        let psk_receiver = PskReceiver::new(
            decrypt_client,
            vec![KMS_KEY_ARN.to_owned()],
            vec![ObfuscationKey::random_test_key()],
        );

        let (client_config, server_config) = configs_from_callbacks(psk_provider, psk_receiver);

        let err = handshake(&client_config, &server_config).await.unwrap_err();
        // unable to deobfuscate: f6c9d1107f9b86a7bfbf836458d0483e not available
        assert!(err.to_string().starts_with("unable to deobfuscate: "));
        assert!(err.to_string().ends_with("not available"));

        // we should not have attempted to decrypt the key
        assert_eq!(decrypt_rule.num_calls(), 0);
    }

    // when the map is at capacity, old items are evicted when new ones are added
    #[tokio::test]
    async fn cache_max_capacity() {
        let (decrypt_rule, decrypt_client) = decrypt_mocks();
        let (_gdk_rule, gdk_client) = gdk_mocks();

        let obfuscation_key = ObfuscationKey::random_test_key();
        let psk_provider = PskProvider::initialize(
            PskVersion::V1,
            gdk_client,
            KMS_KEY_ARN.to_string(),
            obfuscation_key.clone(),
            |_| {},
        )
        .await
        .unwrap();

        let psk_receiver = PskReceiver::new(
            decrypt_client,
            vec![KMS_KEY_ARN.to_owned()],
            vec![obfuscation_key],
        );

        let cache_handle = psk_receiver.key_cache.clone();
        for i in 0..MAXIMUM_KEY_CACHE_SIZE {
            cache_handle.insert(i.to_be_bytes().to_vec(), i.to_be_bytes().to_vec());
        }
        cache_handle.run_pending_tasks();
        assert_eq!(cache_handle.entry_count(), MAXIMUM_KEY_CACHE_SIZE as u64);

        let (client_config, server_config) = configs_from_callbacks(psk_provider, psk_receiver);

        assert_eq!(decrypt_rule.num_calls(), 0);
        handshake(&client_config, &server_config).await.unwrap();
        assert_eq!(decrypt_rule.num_calls(), 1);

        cache_handle.run_pending_tasks();
        assert_eq!(cache_handle.entry_count(), MAXIMUM_KEY_CACHE_SIZE as u64);
    }

    // when the decrypt operation fails, the handshake should also fail
    #[tokio::test]
    async fn decrypt_error() {
        let decrypt_rule = mock!(aws_sdk_kms::Client::decrypt).then_error(|| {
            DecryptError::InvalidKeyUsageException(InvalidKeyUsageException::builder().build())
        });
        let decrypt_client = mock_client!(aws_sdk_kms, [&decrypt_rule]);

        let psk_provider = test_psk_provider().await;

        let psk_receiver = PskReceiver::new(
            decrypt_client,
            vec![KMS_KEY_ARN.to_owned()],
            vec![OBFUSCATION_KEY.clone()],
        );

        let (client_config, server_config) = configs_from_callbacks(psk_provider, psk_receiver);

        let decrypt_error = handshake(&client_config, &server_config).await.unwrap_err();
        assert!(decrypt_error.to_string().contains("service error"));
    }

    /// When an old PskIdentity is received, the handshake should fail, and no
    /// decrypt calls should be made.
    #[tokio::test]
    async fn receiver_rejects_old_identity() {
        const OLD_IDENTITY: &[u8] = include_bytes!("../resources/psk_identity.bin");
        struct OldIdentityInitializer;
        impl ConnectionInitializer for OldIdentityInitializer {
            fn initialize_connection(
                &self,
                connection: &mut s2n_tls::connection::Connection,
            ) -> Result<Option<Pin<Box<dyn ConnectionFuture>>>, s2n_tls::error::Error> {
                let psk =
                    psk_from_material(OLD_IDENTITY, b"doesn't matter, should fail before using")?;
                connection.append_psk(&psk)?;
                Ok(None)
            }
        }

        let (decrypt_rule, decrypt_client) = decrypt_mocks();
        let psk_receiver = PskReceiver::new(
            decrypt_client,
            vec![KMS_KEY_ARN.to_owned()],
            vec![CONSTANT_OBFUSCATION_KEY.clone()],
        );

        let (client_config, server_config) =
            configs_from_callbacks(OldIdentityInitializer, psk_receiver);
        let too_old_error = handshake(&client_config, &server_config).await.unwrap_err();
        assert_eq!(decrypt_rule.num_calls(), 0);
        assert!(too_old_error.to_string().contains("too old"));
    }
}