aws_kms_tls_auth/
psk_parser.rs1use crate::{
5 codec::{DecodeByteSource, DecodeValue},
6 prefixed_list::{PrefixedBlob, PrefixedList},
7};
8use s2n_tls::error::Fallible;
9use s2n_tls_sys::{s2n_client_hello_get_extension_by_id, s2n_client_hello_get_extension_length};
10use std::ffi::c_uint;
11
12trait S2NClientHelloExtension {
13 fn pre_shared_key(&self) -> Result<Option<Vec<u8>>, s2n_tls::error::Error>;
15}
16
17impl S2NClientHelloExtension for s2n_tls::client_hello::ClientHello {
18 fn pre_shared_key(&self) -> Result<Option<Vec<u8>>, s2n_tls::error::Error> {
20 static_assertions::assert_eq_size!(
26 s2n_tls::client_hello::ClientHello,
27 s2n_tls_sys::s2n_client_hello
28 );
29
30 const PRE_SHARED_KEY: c_uint = 41;
32
33 let raw_ch =
34 self as *const s2n_tls::client_hello::ClientHello as *mut s2n_tls_sys::s2n_client_hello;
35 let psk_length =
36 unsafe { s2n_client_hello_get_extension_length(raw_ch, PRE_SHARED_KEY).into_result() }?;
37 if psk_length == 0 {
38 return Ok(None);
39 }
40
41 let mut psk_extension = vec![0; psk_length];
42 let written_length = unsafe {
43 s2n_client_hello_get_extension_by_id(
44 raw_ch,
45 PRE_SHARED_KEY,
46 psk_extension.as_mut_ptr(),
47 psk_extension.len() as u32,
48 )
49 .into_result()
50 }?;
51
52 debug_assert_eq!(psk_length, written_length);
53
54 Ok(Some(psk_extension))
55 }
56}
57
58pub fn retrieve_psk_identities(
60 client_hello: &s2n_tls::client_hello::ClientHello,
61) -> anyhow::Result<PrefixedList<PskIdentity, u16>> {
62 let psk_extension_data = match client_hello.pre_shared_key()? {
63 Some(data) => data,
64 None => anyhow::bail!("no psk extension found"),
65 };
66 let psk = PresharedKeyClientHello::decode_from_exact(&psk_extension_data)?;
67 Ok(psk.identities)
68}
69
70#[derive(Clone, Debug, PartialEq, Eq)]
71pub struct PskIdentity {
72 pub identity: PrefixedBlob<u16>,
73 obfuscated_ticket_age: u32,
74}
75
76impl DecodeValue for PskIdentity {
77 fn decode_from(buffer: &[u8]) -> std::io::Result<(Self, &[u8])> {
78 let (identity, buffer) = buffer.decode_value()?;
79 let (obfuscated_ticket_age, buffer) = buffer.decode_value()?;
80
81 let value = Self {
82 identity,
83 obfuscated_ticket_age,
84 };
85
86 Ok((value, buffer))
87 }
88}
89
90#[derive(Clone, Debug, PartialEq, Eq)]
91struct PskBinderEntry {
92 entry: PrefixedBlob<u8>,
93}
94
95impl DecodeValue for PskBinderEntry {
96 fn decode_from(buffer: &[u8]) -> std::io::Result<(Self, &[u8])> {
97 let (entry, buffer) = buffer.decode_value()?;
98
99 let value = Self { entry };
100
101 Ok((value, buffer))
102 }
103}
104
105#[derive(Clone, Debug, PartialEq, Eq)]
106pub struct PresharedKeyClientHello {
107 identities: PrefixedList<PskIdentity, u16>,
108 binders: PrefixedList<PskBinderEntry, u16>,
109}
110
111impl DecodeValue for PresharedKeyClientHello {
112 fn decode_from(buffer: &[u8]) -> std::io::Result<(Self, &[u8])> {
113 let (identities, buffer) = buffer.decode_value()?;
114 let (binders, buffer) = buffer.decode_value()?;
115
116 let value = Self {
117 identities,
118 binders,
119 };
120
121 Ok((value, buffer))
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use rcgen::{generate_simple_self_signed, CertifiedKey};
    use s2n_tls::{callbacks::VerifyHostNameCallback, enums::PskHmac, testing::TestPair};

    const PSK_IDENTITY: &[u8] = b"hello there imma psk";
    const PSK_SECRET: &[u8] = b"secret material for the psk";

    /// Performs a TLS 1.3 handshake in which client and server share a single
    /// external PSK, and returns the server side so the ClientHello it
    /// received can be inspected.
    fn connection_with_psk() -> Result<s2n_tls::connection::Connection, s2n_tls::error::Error> {
        let mut config = s2n_tls::config::Config::builder();
        config.set_security_policy(&s2n_tls::security::DEFAULT_TLS13)?;
        let config = config.build()?;

        let psk = {
            let mut psk = s2n_tls::psk::Psk::builder()?;
            psk.set_hmac(PskHmac::SHA384)?;
            psk.set_identity(PSK_IDENTITY)?;
            psk.set_secret(PSK_SECRET)?;
            psk.build()?
        };

        let mut pair = TestPair::from_config(&config);
        pair.client.append_psk(&psk)?;
        pair.server.append_psk(&psk)?;

        pair.handshake()?;
        Ok(pair.server)
    }

    /// Performs a certificate-authenticated TLS 1.3 handshake with no PSKs
    /// configured, and returns the server side of the connection.
    fn connection_without_psk() -> Result<s2n_tls::connection::Connection, s2n_tls::error::Error> {
        /// Accepts any host name; the self-signed certificate is trusted
        /// explicitly via `trust_pem` below.
        struct Verifier;
        impl VerifyHostNameCallback for Verifier {
            fn verify_host_name(&self, _host_name: &str) -> bool {
                true
            }
        }

        // Certificate generation is not an s2n-tls operation, so its error
        // type cannot flow through this helper's `Result`.
        let CertifiedKey { cert, signing_key } =
            generate_simple_self_signed(vec!["localhost".to_string()])
                .expect("failed to generate a self-signed test certificate");

        let mut config = s2n_tls::config::Config::builder();
        config.set_security_policy(&s2n_tls::security::DEFAULT_TLS13)?;
        config.trust_pem(cert.pem().as_bytes())?;
        config.load_pem(
            cert.pem().as_bytes(),
            signing_key.serialize_pem().as_bytes(),
        )?;
        config.set_verify_host_callback(Verifier)?;
        let config = config.build()?;

        let mut pair = TestPair::from_config(&config);

        pair.handshake()?;
        Ok(pair.server)
    }

    #[test]
    fn retrieve_identities() -> anyhow::Result<()> {
        let conn = connection_with_psk()?;
        let client_hello = conn.client_hello()?;
        let identities = retrieve_psk_identities(client_hello)?;

        assert_eq!(identities.list().len(), 1);
        assert_eq!(identities.list()[0].identity.blob(), PSK_IDENTITY);
        Ok(())
    }

    #[test]
    fn no_available_identities() -> anyhow::Result<()> {
        let conn = connection_without_psk()?;
        let client_hello = conn.client_hello()?;
        let no_psk_error = retrieve_psk_identities(client_hello).unwrap_err();
        assert!(no_psk_error.to_string().contains("no psk extension found"));
        Ok(())
    }

    /// An empty payload has no bytes for the identities list, so decoding
    /// must fail rather than produce an extension value.
    #[test]
    fn empty_extension_fails_to_decode() {
        let empty: &[u8] = &[];
        assert!(PresharedKeyClientHello::decode_from_exact(empty).is_err());
    }
}