// aws_kms_tls_auth/psk_parser.rs
//
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

4use crate::{
5    codec::{DecodeByteSource, DecodeValue},
6    prefixed_list::{PrefixedBlob, PrefixedList},
7};
8use s2n_tls::error::Fallible;
9use s2n_tls_sys::{s2n_client_hello_get_extension_by_id, s2n_client_hello_get_extension_length};
10use std::ffi::c_uint;
11
12trait S2NClientHelloExtension {
13    /// retrieve the pre-shared-key extension from an s2n-tls client hello
14    fn pre_shared_key(&self) -> Result<Option<Vec<u8>>, s2n_tls::error::Error>;
15}
16
17impl S2NClientHelloExtension for s2n_tls::client_hello::ClientHello {
18    /// Retrieve the extension_data from the pre-shared-key extension
19    fn pre_shared_key(&self) -> Result<Option<Vec<u8>>, s2n_tls::error::Error> {
20        // we are depending on an internal implementation detail, where ClientHello
21        // is aliased to the raw s2n_tls_sys type. Below is a compile-time assertion
22        // that this hasn't changed. Even if the change isn't run through
23        // aws-kms-tls-auth CI, customers will fail to build instead of
24        // encountering a runtime error.
25        static_assertions::assert_eq_size!(
26            s2n_tls::client_hello::ClientHello,
27            s2n_tls_sys::s2n_client_hello
28        );
29
30        // https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#tls-extensiontype-values-1
31        const PRE_SHARED_KEY: c_uint = 41;
32
33        let raw_ch =
34            self as *const s2n_tls::client_hello::ClientHello as *mut s2n_tls_sys::s2n_client_hello;
35        let psk_length =
36            unsafe { s2n_client_hello_get_extension_length(raw_ch, PRE_SHARED_KEY).into_result() }?;
37        if psk_length == 0 {
38            return Ok(None);
39        }
40
41        let mut psk_extension = vec![0; psk_length];
42        let written_length = unsafe {
43            s2n_client_hello_get_extension_by_id(
44                raw_ch,
45                PRE_SHARED_KEY,
46                psk_extension.as_mut_ptr(),
47                psk_extension.len() as u32,
48            )
49            .into_result()
50        }?;
51
52        debug_assert_eq!(psk_length, written_length);
53
54        Ok(Some(psk_extension))
55    }
56}
57
58/// retrieve the PskIdentity items from the Psk extension in the ClientHello.
59pub fn retrieve_psk_identities(
60    client_hello: &s2n_tls::client_hello::ClientHello,
61) -> anyhow::Result<PrefixedList<PskIdentity, u16>> {
62    let psk_extension_data = match client_hello.pre_shared_key()? {
63        Some(data) => data,
64        None => anyhow::bail!("no psk extension found"),
65    };
66    let psk = PresharedKeyClientHello::decode_from_exact(&psk_extension_data)?;
67    Ok(psk.identities)
68}
69
70#[derive(Clone, Debug, PartialEq, Eq)]
71pub struct PskIdentity {
72    pub identity: PrefixedBlob<u16>,
73    obfuscated_ticket_age: u32,
74}
75
76impl DecodeValue for PskIdentity {
77    fn decode_from(buffer: &[u8]) -> std::io::Result<(Self, &[u8])> {
78        let (identity, buffer) = buffer.decode_value()?;
79        let (obfuscated_ticket_age, buffer) = buffer.decode_value()?;
80
81        let value = Self {
82            identity,
83            obfuscated_ticket_age,
84        };
85
86        Ok((value, buffer))
87    }
88}
89
90#[derive(Clone, Debug, PartialEq, Eq)]
91struct PskBinderEntry {
92    entry: PrefixedBlob<u8>,
93}
94
95impl DecodeValue for PskBinderEntry {
96    fn decode_from(buffer: &[u8]) -> std::io::Result<(Self, &[u8])> {
97        let (entry, buffer) = buffer.decode_value()?;
98
99        let value = Self { entry };
100
101        Ok((value, buffer))
102    }
103}
104
105#[derive(Clone, Debug, PartialEq, Eq)]
106pub struct PresharedKeyClientHello {
107    identities: PrefixedList<PskIdentity, u16>,
108    binders: PrefixedList<PskBinderEntry, u16>,
109}
110
111impl DecodeValue for PresharedKeyClientHello {
112    fn decode_from(buffer: &[u8]) -> std::io::Result<(Self, &[u8])> {
113        let (identities, buffer) = buffer.decode_value()?;
114        let (binders, buffer) = buffer.decode_value()?;
115
116        let value = Self {
117            identities,
118            binders,
119        };
120
121        Ok((value, buffer))
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use rcgen::{generate_simple_self_signed, CertifiedKey};
    use s2n_tls::{callbacks::VerifyHostNameCallback, enums::PskHmac, testing::TestPair};

    const PSK_IDENTITY: &[u8] = b"hello there imma psk";
    const PSK_SECRET: &[u8] = b"secret material for the psk";

    /// Run a TLS 1.3 handshake where both peers share the same external PSK,
    /// and return the server side of the connection.
    fn connection_with_psk() -> Result<s2n_tls::connection::Connection, s2n_tls::error::Error> {
        let mut config = s2n_tls::config::Config::builder();
        config.set_security_policy(&s2n_tls::security::DEFAULT_TLS13)?;
        let config = config.build()?;

        let psk = {
            let mut psk = s2n_tls::psk::Psk::builder()?;
            psk.set_hmac(PskHmac::SHA384)?;
            psk.set_identity(PSK_IDENTITY)?;
            psk.set_secret(PSK_SECRET)?;
            psk.build()?
        };

        let mut pair = TestPair::from_config(&config);
        pair.client.append_psk(&psk)?;
        pair.server.append_psk(&psk)?;

        pair.handshake()?;
        Ok(pair.server)
    }

    /// Run a certificate-authenticated TLS 1.3 handshake with no PSKs
    /// configured, and return the server side of the connection.
    fn connection_without_psk() -> Result<s2n_tls::connection::Connection, s2n_tls::error::Error> {
        /// Accepts every host name; the test certificate is self-signed for
        /// "localhost".
        struct Verifier;
        impl VerifyHostNameCallback for Verifier {
            fn verify_host_name(&self, _host_name: &str) -> bool {
                true
            }
        }

        // rcgen's error type does not convert into an s2n-tls error, so a
        // failure here is reported as a test panic.
        let CertifiedKey { cert, signing_key } =
            generate_simple_self_signed(vec!["localhost".to_string()])
                .expect("generating a self-signed test certificate should succeed");

        let mut config = s2n_tls::config::Config::builder();
        config.set_security_policy(&s2n_tls::security::DEFAULT_TLS13)?;
        config.trust_pem(cert.pem().as_bytes())?;
        config.load_pem(
            cert.pem().as_bytes(),
            signing_key.serialize_pem().as_bytes(),
        )?;
        config.set_verify_host_callback(Verifier)?;
        let config = config.build()?;

        let mut pair = TestPair::from_config(&config);

        // Propagate handshake failures through the declared Result, matching
        // `connection_with_psk`.
        pair.handshake()?;
        Ok(pair.server)
    }

    #[test]
    fn retrieve_identities() -> anyhow::Result<()> {
        let conn = connection_with_psk()?;
        let client_hello = conn.client_hello()?;
        let identities = retrieve_psk_identities(client_hello)?;

        assert_eq!(identities.list().len(), 1);
        assert_eq!(identities.list()[0].identity.blob(), PSK_IDENTITY);
        Ok(())
    }

    #[test]
    fn no_available_identities() -> anyhow::Result<()> {
        let conn = connection_without_psk()?;
        let client_hello = conn.client_hello()?;
        let no_psk_error = retrieve_psk_identities(client_hello).unwrap_err();
        assert!(no_psk_error.to_string().contains("no psk extension found"));
        Ok(())
    }
}
201}