1use std::ptr::NonNull;
5
6use crate::{
7 enums::PskHmac,
8 error::{Error, ErrorType, Fallible},
9};
10use s2n_tls_sys::*;
11
12#[derive(Debug)]
13pub struct Builder {
14 psk: Psk,
15 has_identity: bool,
16 has_secret: bool,
17 has_hmac: bool,
18}
19
20impl Builder {
21 pub fn new() -> Result<Self, crate::error::Error> {
22 crate::init::init();
23 let psk = Psk::allocate()?;
24 Ok(Self {
25 psk,
26 has_identity: false,
27 has_secret: false,
28 has_hmac: false,
29 })
30 }
31
32 pub fn set_identity(&mut self, identity: &[u8]) -> Result<&mut Self, crate::error::Error> {
36 let identity_length = identity.len().try_into().map_err(|_| {
37 Error::bindings(
38 ErrorType::UsageError,
39 "invalid psk identity",
40 "The identity must be no longer than u16::MAX",
41 )
42 })?;
43 unsafe {
44 s2n_psk_set_identity(self.psk.ptr.as_ptr(), identity.as_ptr(), identity_length)
45 .into_result()
46 }?;
47 self.has_identity = true;
48 Ok(self)
49 }
50
51 pub fn set_secret(&mut self, secret: &[u8]) -> Result<&mut Self, crate::error::Error> {
57 let secret_length = secret.len().try_into().map_err(|_| {
58 Error::bindings(
59 ErrorType::UsageError,
60 "invalid psk secret",
61 "The secret must be no longer than u16::MAX",
62 )
63 })?;
64
65 if secret_length < (128 / 8) {
70 return Err(Error::bindings(
71 ErrorType::UsageError,
72 "invalid psk secret",
73 "PSK secret must be at least 128 bits",
74 ));
75 }
76 unsafe {
77 s2n_psk_set_secret(self.psk.ptr.as_ptr(), secret.as_ptr(), secret_length).into_result()
78 }?;
79 self.has_secret = true;
80 Ok(self)
81 }
82
83 pub fn set_hmac(&mut self, hmac: PskHmac) -> Result<&mut Self, crate::error::Error> {
87 unsafe { s2n_psk_set_hmac(self.psk.ptr.as_ptr(), hmac.into()).into_result() }?;
88 self.has_hmac = true;
89 Ok(self)
90 }
91
92 pub fn build(self) -> Result<Psk, crate::error::Error> {
93 if !self.has_identity {
94 Err(Error::bindings(
95 crate::error::ErrorType::UsageError,
96 "invalid psk",
97 "You must set an identity using `with_identity`",
98 ))
99 } else if !self.has_secret {
100 Err(Error::bindings(
101 crate::error::ErrorType::UsageError,
102 "invalid psk",
103 "You must set a secret using `with_secret`",
104 ))
105 } else if !self.has_hmac {
106 Err(Error::bindings(
107 crate::error::ErrorType::UsageError,
108 "invalid psk",
109 "You must set an hmac `with_hmac`",
110 ))
111 } else {
112 Ok(self.psk)
113 }
114 }
115}
116
117#[derive(Debug)]
122pub struct Psk {
123 pub(crate) ptr: NonNull<s2n_psk>,
128}
129
130unsafe impl Send for Psk {}
134
135unsafe impl Sync for Psk {}
140
141impl Psk {
142 fn allocate() -> Result<Self, crate::error::Error> {
146 let psk = unsafe { s2n_external_psk_new().into_result() }?;
147 Ok(Self { ptr: psk })
148 }
149
150 pub fn builder() -> Result<Builder, crate::error::Error> {
151 Builder::new()
152 }
153}
154
155impl Drop for Psk {
156 fn drop(&mut self) {
158 let _ = unsafe { s2n_psk_free(&mut self.ptr.as_ptr()).into_result() };
161 }
162}
163
#[cfg(test)]
mod tests {
    use crate::{config::Config, error::ErrorSource, security::DEFAULT_TLS13, testing::TestPair};

    use super::*;

    /// `build` must fail whenever at least one of identity, secret, or hmac
    /// has not been set.
    #[test]
    fn build_errors() -> Result<(), crate::error::Error> {
        const IDENTITY_BIT: u8 = 0b001;
        const SECRET_BIT: u8 = 0b010;
        const HMAC_BIT: u8 = 0b100;
        // The mask with every bit set is the only complete configuration, so
        // the exclusive range below stops just short of it.
        const ALL_SET: u8 = IDENTITY_BIT | SECRET_BIT | HMAC_BIT;

        for mask in 0..ALL_SET {
            let mut builder = Builder::new()?;
            if mask & IDENTITY_BIT != 0 {
                builder.set_identity(b"Alice")?;
            }
            if mask & SECRET_BIT != 0 {
                builder
                    .set_secret(b"Rabbits don't actually jump. They instead push the world down")?;
            }
            if mask & HMAC_BIT != 0 {
                builder.set_hmac(PskHmac::SHA384)?;
            }
            assert!(builder.build().is_err());
        }
        Ok(())
    }

    /// A 15-byte secret is one byte short of 128 bits and must be rejected
    /// with the bindings-level usage error.
    #[test]
    fn psk_secret_must_be_at_least_128_bits() -> Result<(), crate::error::Error> {
        let too_short = [5u8; 15];

        let mut builder = Builder::new()?;
        let error = builder.set_secret(&too_short).unwrap_err();
        assert_eq!(error.source(), ErrorSource::Bindings);
        assert_eq!(error.kind(), ErrorType::UsageError);
        assert_eq!(error.name(), "invalid psk secret");
        assert_eq!(error.message(), "PSK secret must be at least 128 bits");
        Ok(())
    }

    const TEST_PSK_IDENTITY: &[u8] = b"alice";

    /// Builds a fully configured PSK whose identity is `TEST_PSK_IDENTITY`.
    fn test_psk() -> Psk {
        let mut builder = Psk::builder().unwrap();
        builder
            .set_identity(TEST_PSK_IDENTITY)
            .unwrap()
            .set_secret(b"contrary to popular belief, the moon is yogurt, not cheese")
            .unwrap()
            .set_hmac(PskHmac::SHA384)
            .unwrap();
        builder.build().unwrap()
    }

    /// Both peers append the same PSK, complete a handshake, and then report
    /// the test identity as the negotiated PSK identity.
    #[test]
    fn psk_handshake() -> Result<(), crate::error::Error> {
        let psk = test_psk();
        let config = {
            let mut builder = Config::builder();
            builder.set_security_policy(&DEFAULT_TLS13)?;
            builder.build()?
        };

        let mut test_pair = TestPair::from_config(&config);
        test_pair.client.append_psk(&psk)?;
        test_pair.server.append_psk(&psk)?;
        assert!(test_pair.handshake().is_ok());

        for peer in [test_pair.client, test_pair.server] {
            assert_eq!(
                peer.negotiated_psk_identity_length()?,
                TEST_PSK_IDENTITY.len()
            );
            let mut negotiated = [0; TEST_PSK_IDENTITY.len()];
            peer.negotiated_psk_identity(&mut negotiated)?;
            assert_eq!(negotiated, TEST_PSK_IDENTITY);
        }
        Ok(())
    }
}