s2n_tls/connection.rs
1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(clippy::missing_safety_doc)] // TODO add safety docs
5
6#[cfg(feature = "unstable-renegotiate")]
7use crate::renegotiate::RenegotiateState;
8use crate::{
9 callbacks::*,
10 cert_chain::{CertificateChain, CertificateChainHandle},
11 config::Config,
12 enums::*,
13 error::{Error, Fallible, Pollable},
14 psk::Psk,
15 security,
16};
17
18use core::{
19 convert::TryInto,
20 fmt,
21 mem::{self, ManuallyDrop, MaybeUninit},
22 pin::Pin,
23 ptr::NonNull,
24 task::{Poll, Waker},
25 time::Duration,
26};
27use libc::c_void;
28use s2n_tls_sys::*;
29use std::{any::Any, ffi::CStr};
30
31mod builder;
32pub use builder::*;
33
/// Borrow a NUL-terminated C string as a `&str`, producing
/// `Err(Error::INVALID_INPUT)` when its bytes are not valid UTF-8.
///
/// SAFETY: the pointer must reference a NUL-terminated string.
///
/// SAFETY: the pointed-to data must stay alive for the whole scope that invokes the macro.
// A macro (rather than a function) is used so that the lifetime of the returned
// `&str` is inferred from the call site instead of being spelled out.
macro_rules! const_str {
    ($ptr:expr) => {
        match CStr::from_ptr($ptr).to_str() {
            Ok(text) => Ok(text),
            Err(_) => Err(Error::INVALID_INPUT),
        }
    };
}
48
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
/// Number of key updates performed in each direction on a TLS1.3 connection,
/// as reported by `Connection::key_update_counts`.
///
/// s2n-tls only tracks up to u8::MAX (255) key updates. If any of the fields show
/// 255 updates, then more than 255 updates may have occurred.
pub struct KeyUpdateCount {
    /// Number of times the sending key has been updated (tracked up to 255).
    pub send_key_updates: u8,
    /// Number of times the receiving key has been updated (tracked up to 255).
    pub recv_key_updates: u8,
}
57
/// Corresponds to [s2n_connection].
///
/// Wraps a non-null pointer to the underlying C connection. Connections created
/// through `Connection::new` also carry a boxed Rust `Context` in their s2n-tls
/// `ctx` slot (installed by `init_context`); `context()` / `context_mut()` panic
/// if that slot is empty.
pub struct Connection {
    connection: NonNull<s2n_connection>,
}
62
63impl fmt::Debug for Connection {
64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 let mut debug = f.debug_struct("Connection");
66 if let Ok(handshake) = self.handshake_type() {
67 debug.field("handshake_type", &handshake);
68 }
69 if let Ok(cipher) = self.cipher_suite() {
70 debug.field("cipher_suite", &cipher);
71 }
72 if let Ok(version) = self.actual_protocol_version() {
73 debug.field("actual_protocol_version", &version);
74 }
75 if let Ok(curve) = self.selected_curve() {
76 debug.field("selected_curve", &curve);
77 }
78 debug.finish_non_exhaustive()
79 }
80}
81
/// # Safety
///
/// s2n_connection objects can be sent across threads. The `Connection` struct
/// itself stores only the `NonNull<s2n_connection>` pointer to such an object.
unsafe impl Send for Connection {}
86
/// # Sync
///
/// Although NonNull isn't Sync and allows access to mutable pointers even from
/// immutable references, the Connection interface enforces that all mutating
/// methods correctly require &mut self.
///
/// Developers and reviewers MUST ensure that new methods correctly use
/// either &self or &mut self depending on their behavior. No mechanism enforces this.
/// For example, `context()` hands out `&Context` from `&self`, while
/// `context_mut()` (which hands out `&mut Context`) requires `&mut self`.
///
/// Note: Although non-mutating methods like getters should be thread-safe by definition,
/// technically the only thread safety guarantee provided by the underlying C library
/// is that s2n_send and s2n_recv can be called concurrently.
///
unsafe impl Sync for Connection {}
101
102impl Connection {
103 /// # Warning
104 ///
105 /// The newly created connection uses the default security policy.
106 /// Consider changing this depending on your security and compatibility requirements
107 /// by calling [`Connection::set_security_policy`].
108 /// Alternatively, you can use [`crate::config::Builder`], [`crate::config::Builder::set_security_policy`],
109 /// and [`Connection::set_config`] to set the policy on the Config instead of on the Connection.
110 /// See the s2n-tls usage guide:
111 /// <https://aws.github.io/s2n-tls/usage-guide/ch06-security-policies.html>
112 ///
113 /// Corresponds to [s2n_connection_new].
114 pub fn new(mode: Mode) -> Self {
115 crate::init::init();
116
117 let connection = unsafe { s2n_connection_new(mode.into()).into_result() }.unwrap();
118
119 unsafe {
120 debug_assert! {
121 s2n_connection_get_config(connection.as_ptr(), &mut core::ptr::null_mut())
122 .into_result()
123 .is_err()
124 }
125 }
126
127 let mut connection = Self { connection };
128 connection.init_context(mode);
129 connection
130 }
131
132 fn init_context(&mut self, mode: Mode) {
133 let context = Box::new(Context::new(mode));
134 let context = Box::into_raw(context) as *mut c_void;
135 // allocate a new context object
136 unsafe {
137 // There should never be an existing context
138 debug_assert!(s2n_connection_get_ctx(self.connection.as_ptr())
139 .into_result()
140 .is_err());
141
142 s2n_connection_set_ctx(self.connection.as_ptr(), context)
143 .into_result()
144 .unwrap();
145 }
146 }
147
148 pub fn new_client() -> Self {
149 Self::new(Mode::Client)
150 }
151
152 pub fn new_server() -> Self {
153 Self::new(Mode::Server)
154 }
155
156 pub(crate) fn as_ptr(&mut self) -> *mut s2n_connection {
157 self.connection.as_ptr()
158 }
159
160 /// # Safety
161 ///
162 /// Caller must ensure s2n_connection is a valid reference to a [`s2n_connection`] object
163 pub(crate) unsafe fn from_raw(connection: NonNull<s2n_connection>) -> Self {
164 Self { connection }
165 }
166
167 pub(crate) fn mode(&self) -> Mode {
168 self.context().mode
169 }
170
171 /// can be used to configure s2n to either use built-in blinding (set blinding
172 /// to Blinding::BuiltIn) or self-service blinding (set blinding to
173 /// Blinding::SelfService).
174 ///
175 /// Corresponds to [s2n_connection_set_blinding].
176 pub fn set_blinding(&mut self, blinding: Blinding) -> Result<&mut Self, Error> {
177 unsafe {
178 s2n_connection_set_blinding(self.connection.as_ptr(), blinding.into()).into_result()
179 }?;
180 Ok(self)
181 }
182
183 /// Reports the remaining nanoseconds before the connection may be gracefully shutdown.
184 ///
185 /// This method is expected to succeed, but could fail if the
186 /// [underlying C call](`s2n_connection_get_delay`) encounters errors.
187 /// Failure indicates that calls to [`Self::poll_shutdown`] will also fail and
188 /// that a graceful two-way shutdown of the connection will not be possible.
189 ///
190 /// Corresponds to [s2n_connection_get_delay].
191 pub fn remaining_blinding_delay(&self) -> Result<Duration, Error> {
192 let nanos = unsafe { s2n_connection_get_delay(self.connection.as_ptr()).into_result() }?;
193 Ok(Duration::from_nanos(nanos))
194 }
195
196 /// Sets whether or not a Client Certificate should be required to complete the TLS Connection.
197 ///
198 /// If this is set to ClientAuthType::Optional the server will request a client certificate
199 /// but allow the client to not provide one. Rejecting a client certificate when using
200 /// ClientAuthType::Optional will terminate the handshake.
201 ///
202 /// Corresponds to [s2n_connection_set_client_auth_type].
203 pub fn set_client_auth_type(
204 &mut self,
205 client_auth_type: ClientAuthType,
206 ) -> Result<&mut Self, Error> {
207 unsafe {
208 s2n_connection_set_client_auth_type(self.connection.as_ptr(), client_auth_type.into())
209 .into_result()
210 }?;
211 Ok(self)
212 }
213
214 /// Attempts to drop the config on the connection.
215 ///
216 /// # Safety
217 ///
218 /// The caller must ensure the config associated with the connection was created
219 /// with a [`config::Builder`].
220 unsafe fn drop_config(&mut self) -> Result<(), Error> {
221 let mut prev_config = core::ptr::null_mut();
222
223 // A valid non-null pointer is returned only if the application previously called
224 // [`Self::set_config()`].
225 if s2n_connection_get_config(self.connection.as_ptr(), &mut prev_config)
226 .into_result()
227 .is_ok()
228 {
229 let prev_config = NonNull::new(prev_config).expect(
230 "config should exist since the call to s2n_connection_get_config was successful",
231 );
232 drop(Config::from_raw(prev_config));
233 }
234
235 Ok(())
236 }
237
238 /// Associates a configuration object with a connection.
239 ///
240 /// Corresponds to [s2n_connection_set_config].
241 pub fn set_config(&mut self, mut config: Config) -> Result<&mut Self, Error> {
242 unsafe {
243 // attempt to drop the currently set config
244 self.drop_config()?;
245
246 s2n_connection_set_config(self.connection.as_ptr(), config.as_mut_ptr())
247 .into_result()?;
248
249 debug_assert! {
250 s2n_connection_get_config(self.connection.as_ptr(), &mut core::ptr::null_mut()).into_result().is_ok(),
251 "s2n_connection_set_config was successful"
252 };
253
254 // Setting the config on the connection creates one additional reference to the config
255 // so do not drop so prevent Rust from calling `drop()` at the end of this function.
256 mem::forget(config);
257 }
258
259 Ok(self)
260 }
261
262 pub(crate) fn config(&self) -> Option<Config> {
263 let mut raw = core::ptr::null_mut();
264 let config = unsafe {
265 s2n_connection_get_config(self.connection.as_ptr(), &mut raw)
266 .into_result()
267 .ok()?;
268 let raw = NonNull::new(raw)?;
269 Config::from_raw(raw)
270 };
271 // Because the config pointer is still set on the connection, this is a copy,
272 // not the original config. This is fine -- Configs are immutable.
273 let _ = ManuallyDrop::new(config.clone());
274 Some(config)
275 }
276
277 /// Corresponds to [s2n_connection_set_cipher_preferences].
278 pub fn set_security_policy(&mut self, policy: &security::Policy) -> Result<&mut Self, Error> {
279 unsafe {
280 s2n_connection_set_cipher_preferences(
281 self.connection.as_ptr(),
282 policy.as_cstr().as_ptr(),
283 )
284 .into_result()
285 }?;
286 Ok(self)
287 }
288
289 /// provides a smooth transition from s2n_connection_prefer_low_latency to s2n_connection_prefer_throughput.
290 ///
291 /// s2n_send uses small TLS records that fit into a single TCP segment for the resize_threshold
292 /// bytes (cap to 8M) of data and reset record size back to a single segment after timeout_threshold
293 /// seconds of inactivity.
294 ///
295 /// Corresponds to [s2n_connection_set_dynamic_record_threshold].
296 pub fn set_dynamic_record_threshold(
297 &mut self,
298 resize_threshold: u32,
299 timeout_threshold: u16,
300 ) -> Result<&mut Self, Error> {
301 unsafe {
302 s2n_connection_set_dynamic_record_threshold(
303 self.connection.as_ptr(),
304 resize_threshold,
305 timeout_threshold,
306 )
307 .into_result()
308 }?;
309 Ok(self)
310 }
311
312 /// Signals the connection to do a key_update at the next possible opportunity.
313 /// Note that the resulting key update message will not be sent until `send` is
314 /// called on the connection.
315 ///
316 /// `peer_request` indicates if a key update should also be requested
317 /// of the peer. When set to `KeyUpdateNotRequested`, then only the sending
318 /// key of the connection will be updated. If set to `KeyUpdateRequested`, then
319 /// the sending key of conn will be updated AND the peer will be requested to
320 /// update their sending key. Note that s2n-tls currently only supports
321 /// `peer_request` being set to `KeyUpdateNotRequested` and will return an error
322 /// if any other value is used.
323 ///
324 /// Corresponds to [s2n_connection_request_key_update].
325 pub fn request_key_update(&mut self, peer_request: PeerKeyUpdate) -> Result<&mut Self, Error> {
326 unsafe {
327 s2n_connection_request_key_update(self.connection.as_ptr(), peer_request.into())
328 .into_result()
329 }?;
330 Ok(self)
331 }
332
/// Reports how many times the sending and receiving keys have been updated.
///
/// This only applies to TLS1.3. Earlier versions do not support key updates.
///
/// Corresponds to [s2n_connection_get_key_update_counts].
#[cfg(feature = "unstable-ktls")]
pub fn key_update_counts(&self) -> Result<KeyUpdateCount, Error> {
    let mut counts = KeyUpdateCount {
        send_key_updates: 0,
        recv_key_updates: 0,
    };
    unsafe {
        s2n_connection_get_key_update_counts(
            self.connection.as_ptr(),
            &mut counts.send_key_updates,
            &mut counts.recv_key_updates,
        )
        .into_result()?;
    }
    Ok(counts)
}
355
356 /// sets the application protocol preferences on an s2n_connection object.
357 ///
358 /// protocols is a list in order of preference, with most preferred protocol first, and of
359 /// length protocol_count. When acting as a client the protocol list is included in the
360 /// Client Hello message as the ALPN extension. As a server, the list is used to negotiate
361 /// a mutual application protocol with the client. After the negotiation for the connection has
362 /// completed, the agreed upon protocol can be retrieved with s2n_get_application_protocol
363 ///
364 /// Corresponds to [s2n_connection_set_protocol_preferences].
365 pub fn set_application_protocol_preference<P: IntoIterator<Item = I>, I: AsRef<[u8]>>(
366 &mut self,
367 protocols: P,
368 ) -> Result<&mut Self, Error> {
369 // reset the list
370 unsafe {
371 s2n_connection_set_protocol_preferences(self.connection.as_ptr(), core::ptr::null(), 0)
372 .into_result()
373 }?;
374
375 for protocol in protocols {
376 self.append_application_protocol_preference(protocol.as_ref())?;
377 }
378
379 Ok(self)
380 }
381
382 /// Corresponds to [s2n_connection_append_protocol_preference].
383 pub fn append_application_protocol_preference(
384 &mut self,
385 protocol: &[u8],
386 ) -> Result<&mut Self, Error> {
387 unsafe {
388 s2n_connection_append_protocol_preference(
389 self.connection.as_ptr(),
390 protocol.as_ptr(),
391 protocol
392 .len()
393 .try_into()
394 .map_err(|_| Error::INVALID_INPUT)?,
395 )
396 .into_result()
397 }?;
398 Ok(self)
399 }
400
401 /// may be used to receive data with callbacks defined by the user.
402 ///
403 /// Corresponds to [s2n_connection_set_recv_cb].
404 pub fn set_receive_callback(&mut self, callback: s2n_recv_fn) -> Result<&mut Self, Error> {
405 unsafe { s2n_connection_set_recv_cb(self.connection.as_ptr(), callback).into_result() }?;
406 Ok(self)
407 }
408
409 /// # Safety
410 ///
411 /// The `context` pointer must live at least as long as the connection
412 ///
413 /// Corresponds to [s2n_connection_set_recv_ctx].
414 pub unsafe fn set_receive_context(&mut self, context: *mut c_void) -> Result<&mut Self, Error> {
415 s2n_connection_set_recv_ctx(self.connection.as_ptr(), context).into_result()?;
416 Ok(self)
417 }
418
419 /// may be used to receive data with callbacks defined by the user.
420 ///
421 /// Corresponds to [s2n_connection_set_send_cb].
422 pub fn set_send_callback(&mut self, callback: s2n_send_fn) -> Result<&mut Self, Error> {
423 unsafe { s2n_connection_set_send_cb(self.connection.as_ptr(), callback).into_result() }?;
424 Ok(self)
425 }
426
427 /// # Safety
428 ///
429 /// The `context` pointer must live at least as long as the connection
430 ///
431 /// Corresponds to [s2n_connection_set_send_ctx].
432 pub unsafe fn set_send_context(&mut self, context: *mut c_void) -> Result<&mut Self, Error> {
433 s2n_connection_set_send_ctx(self.connection.as_ptr(), context).into_result()?;
434 Ok(self)
435 }
436
/// Sets the callback to use for verifying that a hostname from an X.509 certificate is
/// trusted.
///
/// The callback may be called more than once during certificate validation as each SAN on
/// the certificate will be checked.
///
/// The handler is boxed and stored in the connection's `Context`, replacing any
/// previously stored handler. A pointer to that `Context` is registered with
/// s2n-tls as the callback's `context` argument.
///
/// Corresponds to [s2n_connection_set_verify_host_callback].
pub fn set_verify_host_callback<T: 'static + VerifyHostNameCallback>(
    &mut self,
    handler: T,
) -> Result<&mut Self, Error> {
    // C-ABI trampoline: recovers the `Context` from the opaque pointer and
    // forwards to the handler stored there.
    unsafe extern "C" fn verify_host_cb_fn(
        host_name: *const ::libc::c_char,
        host_name_len: usize,
        context: *mut ::libc::c_void,
    ) -> u8 {
        // NOTE(review): relies on `context` still pointing at the live `Context`
        // registered below.
        let context = &mut *(context as *mut Context);
        // The handler is stored before the callback is registered, so this is expected to be `Some`.
        let handler = context.verify_host_callback.as_mut().unwrap();
        verify_host(host_name, host_name_len, handler)
    }

    self.context_mut().verify_host_callback = Some(Box::new(handler));
    unsafe {
        s2n_connection_set_verify_host_callback(
            self.connection.as_ptr(),
            Some(verify_host_cb_fn),
            self.context_mut() as *mut Context as *mut c_void,
        )
        .into_result()
    }?;
    Ok(self)
}
469
470 /// Connections preferring low latency will be encrypted using small record sizes that
471 /// can be decrypted sooner by the recipient.
472 ///
473 /// Corresponds to [s2n_connection_prefer_low_latency].
474 pub fn prefer_low_latency(&mut self) -> Result<&mut Self, Error> {
475 unsafe { s2n_connection_prefer_low_latency(self.connection.as_ptr()).into_result() }?;
476 Ok(self)
477 }
478
479 /// Connections preferring throughput will use large record sizes that minimize overhead.
480 ///
481 /// Corresponds to [s2n_connection_prefer_throughput].
482 pub fn prefer_throughput(&mut self) -> Result<&mut Self, Error> {
483 unsafe { s2n_connection_prefer_throughput(self.connection.as_ptr()).into_result() }?;
484 Ok(self)
485 }
486
487 /// Configure the connection to reduce potentially expensive calls to recv.
488 ///
489 /// Corresponds to [s2n_connection_set_recv_buffering].
490 pub fn set_receive_buffering(&mut self, enabled: bool) -> Result<&mut Self, Error> {
491 unsafe {
492 s2n_connection_set_recv_buffering(self.connection.as_ptr(), enabled).into_result()
493 }?;
494 Ok(self)
495 }
496
497 /// wipes and free the in and out buffers associated with a connection.
498 ///
499 /// This function may be called when a connection is in keep-alive or idle state to
500 /// reduce memory overhead of long lived connections.
501 ///
502 /// Corresponds to [s2n_connection_release_buffers].
503 pub fn release_buffers(&mut self) -> Result<&mut Self, Error> {
504 unsafe { s2n_connection_release_buffers(self.connection.as_ptr()).into_result() }?;
505 Ok(self)
506 }
507
508 /// Corresponds to [s2n_connection_use_corked_io].
509 pub fn use_corked_io(&mut self) -> Result<&mut Self, Error> {
510 unsafe { s2n_connection_use_corked_io(self.connection.as_ptr()).into_result() }?;
511 Ok(self)
512 }
513
514 pub(crate) fn wipe_method<F, T>(&mut self, wipe: F) -> Result<(), Error>
515 where
516 F: FnOnce(&mut Self) -> Result<T, Error>,
517 {
518 let mode = self.mode();
519
520 // Safety:
521 // We re-init the context after the wipe
522 unsafe { self.drop_context()? };
523
524 let result = wipe(self);
525 // We must initialize the context again whether or not wipe succeeds.
526 // A connection without a context is invalid and has undefined behavior.
527 self.init_context(mode);
528 result?;
529
530 Ok(())
531 }
532
533 /// wipes an existing connection and allows it to be reused.
534 ///
535 /// This method erases all data associated with a connection including pending reads.
536 /// This function should be called after all I/O is completed and s2n_shutdown has been
537 /// called. Reusing the same connection handle(s) is more performant than repeatedly
538 /// calling s2n_connection_new and s2n_connection_free
539 ///
540 /// Corresponds to [s2n_connection_wipe].
541 pub fn wipe(&mut self) -> Result<&mut Self, Error> {
542 self.wipe_method(|conn| unsafe { s2n_connection_wipe(conn.as_ptr()).into_result() })?;
543 Ok(self)
544 }
545
546 fn trigger_initializer(&mut self) {
547 if !core::mem::replace(&mut self.context_mut().connection_initialized, true) {
548 if let Some(config) = self.config() {
549 if let Some(callback) = config.context().connection_initializer.as_ref() {
550 let future = callback.initialize_connection(self);
551 AsyncCallback::trigger(future, self);
552 }
553 }
554 }
555 }
556
557 // Poll the connection future if it exists.
558 //
559 // If the future returns Pending, then re-set it back on the Connection.
560 fn poll_async_task(&mut self) -> Option<Poll<Result<(), Error>>> {
561 self.take_async_callback().map(|mut callback| {
562 let waker = self.waker().ok_or(Error::MISSING_WAKER)?.clone();
563 let mut ctx = core::task::Context::from_waker(&waker);
564 match Pin::new(&mut callback).poll(self, &mut ctx) {
565 Poll::Ready(result) => Poll::Ready(result),
566 Poll::Pending => {
567 // replace the future if it hasn't completed yet
568 self.set_async_callback(callback);
569 Poll::Pending
570 }
571 }
572 })
573 }
574
/// Drives one negotiation-style operation (`negotiate`) while servicing any
/// pending async callback.
///
/// Calls `trigger_initializer` first. Each loop iteration then polls the pending
/// async callback (if any) and returns immediately on its error or `Pending`.
/// Otherwise `negotiate` is called. A `Ready` result is returned with its success
/// value discarded. A `Pending` result loops again only when an async callback
/// is now pending, so that callback is polled at least once before `Pending`
/// is returned.
pub(crate) fn poll_negotiate_method<F, T>(
    &mut self,
    mut negotiate: F,
) -> Poll<Result<(), Error>>
where
    F: FnMut(&mut Connection) -> Poll<Result<T, Error>>,
{
    self.trigger_initializer();

    loop {
        // Check whether negotiation is blocked by any async callbacks
        match self.poll_async_task().unwrap_or(Poll::Ready(Ok(()))) {
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Ok(_)) => {}
        };

        match negotiate(self) {
            Poll::Ready(res) => return Poll::Ready(res.map(|_| ())),
            Poll::Pending => {
                // If `negotiate` returned `Pending` it could be blocked on a connection future
                // (i.e. not socket IO) so before we return, we need to make sure we poll
                // the associated future at least once. Otherwise, we will violate the waker contract.
                //
                // See https://github.com/aws/s2n-quic/pull/2248
                if self.context_mut().async_callback.is_some() {
                    // continuing in the loop will poll the task
                    continue;
                }

                // we don't have anything else to poll so return `Pending`
                return Poll::Pending;
            }
        }
    }
}
611
612 /// Performs the TLS handshake to completion
613 ///
614 /// Multiple callbacks can be configured for a connection and config, but
615 /// [`Self::poll_negotiate()`] can only execute and block on one callback at a time.
616 /// The handshake is sequential, not concurrent, and stops execution when
617 /// it encounters an async callback.
618 ///
619 /// The handshake does not continue execution (and therefore can't call
620 /// any other callbacks) until the blocking async task reports completion.
621 ///
622 /// Corresponds to [s2n_negotiate].
623 pub fn poll_negotiate(&mut self) -> Poll<Result<&mut Self, Error>> {
624 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
625 self.poll_negotiate_method(|conn| unsafe {
626 s2n_negotiate(conn.as_ptr(), &mut blocked).into_poll()
627 })
628 .map_ok(|_| self)
629 }
630
631 /// Encrypts and sends data on a connection where
632 /// [negotiate](`Self::poll_negotiate`) has succeeded.
633 ///
634 /// Returns the number of bytes written, and may indicate a partial write.
635 ///
636 /// Corresponds to [s2n_send].
637 #[cfg(not(feature = "unstable-renegotiate"))]
638 pub fn poll_send(&mut self, buf: &[u8]) -> Poll<Result<usize, Error>> {
639 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
640 let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
641 let buf_ptr = buf.as_ptr() as *const ::libc::c_void;
642 unsafe { s2n_send(self.connection.as_ptr(), buf_ptr, buf_len, &mut blocked).into_poll() }
643 }
644
645 #[cfg(not(feature = "unstable-renegotiate"))]
646 pub(crate) fn poll_recv_raw(
647 &mut self,
648 buf_ptr: *mut ::libc::c_void,
649 buf_len: isize,
650 ) -> Poll<Result<usize, Error>> {
651 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
652 unsafe { s2n_recv(self.connection.as_ptr(), buf_ptr, buf_len, &mut blocked).into_poll() }
653 }
654
655 /// Reads and decrypts data from a connection where
656 /// [negotiate](`Self::poll_negotiate`) has succeeded.
657 ///
658 /// Returns the number of bytes read, and may indicate a partial read.
659 /// 0 bytes returned indicates EOF due to connection closure.
660 ///
661 /// Corresponds to [s2n_recv].
662 pub fn poll_recv(&mut self, buf: &mut [u8]) -> Poll<Result<usize, Error>> {
663 let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
664 let buf_ptr = buf.as_ptr() as *mut ::libc::c_void;
665 self.poll_recv_raw(buf_ptr, buf_len)
666 }
667
668 /// Reads and decrypts data from a connection where
669 /// [negotiate](`Self::poll_negotiate`) has succeeded
670 /// to a uninitialized buffer.
671 ///
672 /// Returns the number of bytes read, and may indicate a partial read.
673 /// 0 bytes returned indicates EOF due to connection closure.
674 ///
675 /// Safety: this function is always safe to call, and additionally:
676 /// 1. It will never uninitialize any bytes in `buf`.
677 /// 2. If it returns `Ok(n)`, then the first `n` bytes of `buf`
678 /// will have been initialized by this function.
679 ///
680 /// Corresponds to [s2n_recv].
681 pub fn poll_recv_uninitialized(
682 &mut self,
683 buf: &mut [MaybeUninit<u8>],
684 ) -> Poll<Result<usize, Error>> {
685 let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
686 let buf_ptr = buf.as_ptr() as *mut ::libc::c_void;
687
688 // Safety:
689 // 1. s2n_recv never writes uninitialized garbage to `buf`.
690 // 2. if s2n_recv returns `+n`, it guarantees that the first
691 // `n` bytes of `buf` have been initialized, which allows this
692 // function to return `Ok(n)`
693 self.poll_recv_raw(buf_ptr, buf_len)
694 }
695
696 /// Attempts to flush any data previously buffered by a call to [send](`Self::poll_send`).
697 ///
698 /// poll_flush can only flush data that s2n-tls has already encrypted and
699 /// buffered for sending. poll_send may need to be called again to fully send
700 /// all data. See the [Usage Guide](https://github.com/aws/s2n-tls/blob/main/docs/usage-guide/topics/ch07-io.md)
701 /// for more details.
702 ///
703 /// Corresponds to [s2n_flush].
704 pub fn poll_flush(&mut self) -> Poll<Result<&mut Self, Error>> {
705 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
706 unsafe {
707 s2n_flush(self.connection.as_ptr(), &mut blocked)
708 .into_poll()
709 .map_ok(|_| self)
710 }
711 }
712
713 /// Gets the number of bytes that are currently available in the buffer to be read.
714 ///
715 /// Corresponds to [s2n_peek].
716 pub fn peek_len(&self) -> usize {
717 unsafe { s2n_peek(self.connection.as_ptr()) as usize }
718 }
719
720 /// Attempts a graceful shutdown of the TLS connection.
721 ///
722 /// The shutdown is not complete until the necessary shutdown messages
723 /// have been successfully sent and received. If the peer does not respond
724 /// correctly, the graceful shutdown may fail.
725 ///
726 /// Corresponds to [s2n_shutdown].
727 pub fn poll_shutdown(&mut self) -> Poll<Result<&mut Self, Error>> {
728 if !self.remaining_blinding_delay()?.is_zero() {
729 return Poll::Pending;
730 }
731 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
732 unsafe {
733 s2n_shutdown(self.connection.as_ptr(), &mut blocked)
734 .into_poll()
735 .map_ok(|_| self)
736 }
737 }
738
739 /// Attempts a graceful shutdown of the write side of a TLS connection.
740 ///
741 /// Unlike Self::poll_shutdown, no response from the peer is necessary.
742 /// If using TLS1.3, the connection can continue to be used for reading afterwards.
743 ///
744 /// Corresponds to [s2n_shutdown_send].
745 pub fn poll_shutdown_send(&mut self) -> Poll<Result<&mut Self, Error>> {
746 if !self.remaining_blinding_delay()?.is_zero() {
747 return Poll::Pending;
748 }
749 let mut blocked = s2n_blocked_status::NOT_BLOCKED;
750 unsafe {
751 s2n_shutdown_send(self.connection.as_ptr(), &mut blocked)
752 .into_poll()
753 .map_ok(|_| self)
754 }
755 }
756
757 /// Returns the TLS alert code, if any
758 ///
759 /// Corresponds to [s2n_connection_get_alert].
760 pub fn alert(&self) -> Option<u8> {
761 let alert =
762 unsafe { s2n_connection_get_alert(self.connection.as_ptr()).into_result() }.ok()?;
763 Some(alert as u8)
764 }
765
766 /// Sets the server name value for the connection
767 ///
768 /// Corresponds to [s2n_set_server_name].
769 pub fn set_server_name(&mut self, server_name: &str) -> Result<&mut Self, Error> {
770 let server_name = std::ffi::CString::new(server_name).map_err(|_| Error::INVALID_INPUT)?;
771 unsafe {
772 s2n_set_server_name(self.connection.as_ptr(), server_name.as_ptr()).into_result()
773 }?;
774 Ok(self)
775 }
776
777 /// Get the server name associated with the connection client hello.
778 ///
779 /// Corresponds to [s2n_get_server_name].
780 pub fn server_name(&self) -> Option<&str> {
781 unsafe {
782 let server_name = s2n_get_server_name(self.connection.as_ptr());
783 match server_name.into_result() {
784 Ok(server_name) => CStr::from_ptr(server_name).to_str().ok(),
785 Err(_) => None,
786 }
787 }
788 }
789
790 /// Adds a session ticket from a previous TLS connection to create a resumed session
791 ///
792 /// Corresponds to [s2n_connection_set_session].
793 pub fn set_session_ticket(&mut self, session: &[u8]) -> Result<&mut Self, Error> {
794 unsafe {
795 s2n_connection_set_session(self.connection.as_ptr(), session.as_ptr(), session.len())
796 .into_result()
797 }?;
798 Ok(self)
799 }
800
801 /// Retrieves the size of the session ticket.
802 ///
803 /// Corresponds to [s2n_connection_get_session_length].
804 pub fn session_ticket_length(&self) -> Result<usize, Error> {
805 let len =
806 unsafe { s2n_connection_get_session_length(self.connection.as_ptr()).into_result()? };
807 Ok(len.try_into().unwrap())
808 }
809
810 /// Serializes the session state from the connection into `output` and returns
811 /// the length of the session ticket.
812 ///
813 /// If the buffer does not have the size for the session_ticket,
814 /// `Error::INVALID_INPUT` is returned.
815 ///
816 /// Note: This function is not recommended for > TLS1.2 because in TLS1.3
817 /// servers can send multiple session tickets and this will return only
818 /// the most recently received ticket.
819 ///
820 /// Corresponds to [s2n_connection_get_session].
821 pub fn session_ticket(&self, output: &mut [u8]) -> Result<usize, Error> {
822 if output.len() < self.session_ticket_length()? {
823 return Err(Error::INVALID_INPUT);
824 }
825 let written = unsafe {
826 s2n_connection_get_session(self.connection.as_ptr(), output.as_mut_ptr(), output.len())
827 .into_result()?
828 };
829 Ok(written.try_into().unwrap())
830 }
831
832 /// Sets a Waker on the connection context or clears it if `None` is passed.
833 pub fn set_waker(&mut self, waker: Option<&Waker>) -> Result<&mut Self, Error> {
834 let ctx = self.context_mut();
835
836 if let Some(waker) = waker {
837 if let Some(prev_waker) = ctx.waker.as_mut() {
838 // only replace the Waker if they don't reference the same task
839 if !prev_waker.will_wake(waker) {
840 prev_waker.clone_from(waker);
841 }
842 } else {
843 ctx.waker = Some(waker.clone());
844 }
845 } else {
846 ctx.waker = None;
847 }
848 Ok(self)
849 }
850
851 /// Returns the Waker set on the connection context.
852 pub fn waker(&self) -> Option<&Waker> {
853 let ctx = self.context();
854 ctx.waker.as_ref()
855 }
856
857 /// Takes the [`Option::take`] the connection_future stored on the
858 /// connection context.
859 ///
860 /// If the Future returns `Poll::Pending` and has not completed, then it
861 /// should be re-set using [`Self::set_connection_future()`]
862 fn take_async_callback(&mut self) -> Option<AsyncCallback> {
863 let ctx = self.context_mut();
864 ctx.async_callback.take()
865 }
866
867 /// Sets a `connection_future` on the connection context.
868 pub(crate) fn set_async_callback(&mut self, callback: AsyncCallback) {
869 let ctx = self.context_mut();
870 debug_assert!(ctx.async_callback.is_none());
871 ctx.async_callback = Some(callback);
872 }
873
874 /// Retrieve a mutable reference to the [`Context`] stored on the connection.
875 fn context_mut(&mut self) -> &mut Context {
876 unsafe {
877 let ctx = s2n_connection_get_ctx(self.connection.as_ptr())
878 .into_result()
879 .unwrap();
880 &mut *(ctx.as_ptr() as *mut Context)
881 }
882 }
883
884 /// Retrieve a reference to the [`Context`] stored on the connection.
885 fn context(&self) -> &Context {
886 unsafe {
887 let ctx = s2n_connection_get_ctx(self.connection.as_ptr())
888 .into_result()
889 .unwrap();
890 &*(ctx.as_ptr() as *mut Context)
891 }
892 }
893
    /// Drop the context
    ///
    /// If a context pointer is present, it is reclaimed as a `Box<Context>`
    /// and dropped. Afterwards the connection's context pointer is reset to
    /// NULL. The only error returned is a failure to reset that pointer.
    ///
    /// SAFETY:
    /// A connection without a context is invalid. After calling this method
    /// from anywhere other than Drop, you must reinitialize the context.
    unsafe fn drop_context(&mut self) -> Result<(), Error> {
        // A missing (NULL) context presumably yields `Err` here; in that case
        // there is nothing to free, so it is not treated as an error.
        let ctx = s2n_connection_get_ctx(self.connection.as_ptr()).into_result();
        if let Ok(ctx) = ctx {
            // Take back ownership of the boxed Context so its fields are dropped.
            drop(Box::from_raw(ctx.as_ptr() as *mut Context));
        }
        // Setting a NULL context is important: if we don't also remove the context
        // from the connection, then the invalid memory is still accessible and
        // may even be double-freed.
        s2n_connection_set_ctx(self.connection.as_ptr(), core::ptr::null_mut()).into_result()?;
        Ok(())
    }
910
911 /// Mark that the server_name extension was used to configure the connection.
912 ///
913 /// Corresponds to [s2n_connection_server_name_extension_used].
914 pub fn server_name_extension_used(&mut self) {
915 // TODO: requiring the application to call this method is a pretty sharp edge.
916 // Figure out if its possible to automatically call this from the Rust bindings.
917 unsafe {
918 s2n_connection_server_name_extension_used(self.connection.as_ptr())
919 .into_result()
920 .unwrap();
921 }
922 }
923
924 /// Check if client auth was used for a connection.
925 ///
926 /// This is only relevant if [`ClientAuthType::Optional] was used.
927 ///
928 /// Corresponds to [s2n_connection_client_cert_used].
929 pub fn client_cert_used(&self) -> bool {
930 unsafe { s2n_connection_client_cert_used(self.connection.as_ptr()) == 1 }
931 }
932
933 /// Retrieves the raw bytes of the client cert chain received from the peer, if present.
934 ///
935 /// Corresponds to [s2n_connection_get_client_cert_chain].
936 pub fn client_cert_chain_bytes(&self) -> Result<Option<&[u8]>, Error> {
937 if !self.client_cert_used() {
938 return Ok(None);
939 }
940
941 let mut chain = std::ptr::null_mut();
942 let mut len = 0;
943 unsafe {
944 s2n_connection_get_client_cert_chain(self.connection.as_ptr(), &mut chain, &mut len)
945 .into_result()?;
946 }
947
948 if chain.is_null() || len == 0 {
949 return Ok(None);
950 }
951
952 unsafe { Ok(Some(std::slice::from_raw_parts(chain, len as usize))) }
953 }
954
955 // The memory backing the ClientHello is owned by the Connection, so we
956 // tie the ClientHello to the lifetime of the Connection. This is validated
957 // with a doc test that ensures the ClientHello is invalid once the
958 // connection has gone out of scope.
959 //
960 /// Returns a reference to the ClientHello associated with the connection.
961 /// ```compile_fail
962 /// use s2n_tls::client_hello::ClientHello;
963 /// use s2n_tls::connection::Connection;
964 /// use s2n_tls::enums::Mode;
965 ///
966 /// let mut conn = Connection::new(Mode::Server);
967 /// let mut client_hello: &ClientHello = conn.client_hello().unwrap();
968 /// drop(conn);
969 /// client_hello.raw_message();
970 /// ```
971 ///
972 /// The compilation could be failing for a variety of reasons, so make sure
973 /// that the test case is actually good.
974 /// ```no_run
975 /// use s2n_tls::client_hello::ClientHello;
976 /// use s2n_tls::connection::Connection;
977 /// use s2n_tls::enums::Mode;
978 ///
979 /// let mut conn = Connection::new(Mode::Server);
980 /// let mut client_hello: &ClientHello = conn.client_hello().unwrap();
981 /// client_hello.raw_message();
982 /// drop(conn);
983 /// ```
984 ///
985 /// Corresponds to [s2n_connection_get_client_hello].
986 pub fn client_hello(&self) -> Result<&crate::client_hello::ClientHello, Error> {
987 let mut handle =
988 unsafe { s2n_connection_get_client_hello(self.connection.as_ptr()).into_result()? };
989 Ok(crate::client_hello::ClientHello::from_ptr(unsafe {
990 handle.as_mut()
991 }))
992 }
993
994 /// Corresponds to [s2n_client_hello_cb_done].
995 pub(crate) fn mark_client_hello_cb_done(&mut self) -> Result<(), Error> {
996 unsafe {
997 s2n_client_hello_cb_done(self.connection.as_ptr()).into_result()?;
998 }
999 Ok(())
1000 }
1001
1002 /// Access the protocol version selected for the connection.
1003 ///
1004 /// Corresponds to [s2n_connection_get_actual_protocol_version].
1005 pub fn actual_protocol_version(&self) -> Result<Version, Error> {
1006 let version = unsafe {
1007 s2n_connection_get_actual_protocol_version(self.connection.as_ptr()).into_result()?
1008 };
1009 version.try_into()
1010 }
1011
1012 /// Detects if the client hello is using the SSLv2 format.
1013 ///
1014 /// s2n-tls will not negotiate SSLv2, but will accept SSLv2 ClientHellos
1015 /// advertising a higher protocol version like SSLv3 or TLS1.0.
1016 /// [Connection::actual_protocol_version()] can be used to retrieve the
1017 /// protocol version that is actually used on the connection.
1018 ///
1019 /// Corresponds to [s2n_connection_get_client_hello_version], but only checks
1020 /// for SSLv2.
1021 pub fn client_hello_is_sslv2(&self) -> Result<bool, Error> {
1022 let version = unsafe {
1023 s2n_connection_get_client_hello_version(self.connection.as_ptr()).into_result()?
1024 };
1025 let version: Version = version.try_into()?;
1026 Ok(version == Version::SSLV2)
1027 }
1028
1029 /// Corresponds to [s2n_connection_get_handshake_type_name].
1030 pub fn handshake_type(&self) -> Result<&str, Error> {
1031 let handshake = unsafe {
1032 s2n_connection_get_handshake_type_name(self.connection.as_ptr()).into_result()?
1033 };
1034 unsafe {
1035 // SAFETY: Constructed strings have a null byte appended to them.
1036 // SAFETY: The data has a 'static lifetime, because it resides in a
1037 // static char array, and is never modified after its initial
1038 // creation.
1039 const_str!(handshake)
1040 }
1041 }
1042
1043 /// Corresponds to [s2n_connection_get_cipher].
1044 pub fn cipher_suite(&self) -> Result<&str, Error> {
1045 let cipher = unsafe { s2n_connection_get_cipher(self.connection.as_ptr()).into_result()? };
1046 unsafe {
1047 // SAFETY: The data is null terminated because it is declared as a C
1048 // string literal.
1049 // SAFETY: cipher has a static lifetime because it lives on s2n_cipher_suite,
1050 // a static struct.
1051 const_str!(cipher)
1052 }
1053 }
1054
1055 /// Corresponds to [s2n_connection_get_kem_name].
1056 #[deprecated = "PQ TLS 1.2 KEM Names are no longer supported. Use kem_group_name() to retrieve PQ TLS 1.3 Group name."]
1057 pub fn kem_name(&self) -> Option<&str> {
1058 let name_bytes = {
1059 let name = unsafe { s2n_connection_get_kem_name(self.connection.as_ptr()) };
1060 if name.is_null() {
1061 return None;
1062 }
1063 name
1064 };
1065
1066 let name_str = unsafe {
1067 // SAFETY: The data is null terminated because it is declared as a C
1068 // string literal.
1069 // SAFETY: kem_name has a static lifetime because it lives on a const
1070 // struct s2n_kem with file scope.
1071 const_str!(name_bytes)
1072 };
1073
1074 match name_str {
1075 Ok("NONE") => None,
1076 Ok(name) => Some(name),
1077 Err(_) => {
1078 // Unreachable: This would indicate a non-utf-8 string literal in
1079 // the s2n-tls C codebase.
1080 None
1081 }
1082 }
1083 }
1084
1085 /// Corresponds to [s2n_connection_get_kem_group_name].
1086 pub fn kem_group_name(&self) -> Option<&str> {
1087 let name_bytes = {
1088 let name = unsafe { s2n_connection_get_kem_group_name(self.connection.as_ptr()) };
1089 if name.is_null() {
1090 return None;
1091 }
1092 name
1093 };
1094
1095 let name_str = unsafe {
1096 // SAFETY: The data is null terminated because it is declared as a C
1097 // string literal.
1098 // SAFETY: kem_name has a static lifetime because it lives on a const
1099 // struct s2n_kem with file scope.
1100 const_str!(name_bytes)
1101 };
1102
1103 match name_str {
1104 Ok("NONE") => None,
1105 Ok(name) => Some(name),
1106 Err(_) => {
1107 // Unreachable: This would indicate a non-utf-8 string literal in
1108 // the s2n-tls C codebase.
1109 None
1110 }
1111 }
1112 }
1113
1114 /// Corresponds to [s2n_connection_get_curve].
1115 pub fn selected_curve(&self) -> Result<&str, Error> {
1116 let curve = unsafe { s2n_connection_get_curve(self.connection.as_ptr()).into_result()? };
1117 unsafe {
1118 // SAFETY: The data is null terminated because it is declared as a C
1119 // string literal.
1120 // SAFETY: curve has a static lifetime because it lives on s2n_ecc_named_curve,
1121 // which is a static const struct.
1122 const_str!(curve)
1123 }
1124 }
1125
1126 /// Corresponds to [s2n_connection_get_selected_signature_algorithm].
1127 pub fn selected_signature_algorithm(&self) -> Result<SignatureAlgorithm, Error> {
1128 let mut sig_alg = s2n_tls_signature_algorithm::ANONYMOUS;
1129 unsafe {
1130 s2n_connection_get_selected_signature_algorithm(self.connection.as_ptr(), &mut sig_alg)
1131 .into_result()?;
1132 }
1133 sig_alg.try_into()
1134 }
1135
1136 /// Corresponds to [s2n_connection_get_selected_digest_algorithm].
1137 pub fn selected_hash_algorithm(&self) -> Result<HashAlgorithm, Error> {
1138 let mut hash_alg = s2n_tls_hash_algorithm::NONE;
1139 unsafe {
1140 s2n_connection_get_selected_digest_algorithm(self.connection.as_ptr(), &mut hash_alg)
1141 .into_result()?;
1142 }
1143 hash_alg.try_into()
1144 }
1145
1146 /// Corresponds to [s2n_connection_get_selected_client_cert_signature_algorithm].
1147 pub fn selected_client_signature_algorithm(&self) -> Result<Option<SignatureAlgorithm>, Error> {
1148 let mut sig_alg = s2n_tls_signature_algorithm::ANONYMOUS;
1149 unsafe {
1150 s2n_connection_get_selected_client_cert_signature_algorithm(
1151 self.connection.as_ptr(),
1152 &mut sig_alg,
1153 )
1154 .into_result()?;
1155 }
1156 Ok(match sig_alg {
1157 s2n_tls_signature_algorithm::ANONYMOUS => None,
1158 sig_alg => Some(sig_alg.try_into()?),
1159 })
1160 }
1161
1162 /// Corresponds to [s2n_connection_get_selected_client_cert_digest_algorithm].
1163 pub fn selected_client_hash_algorithm(&self) -> Result<Option<HashAlgorithm>, Error> {
1164 let mut hash_alg = s2n_tls_hash_algorithm::NONE;
1165 unsafe {
1166 s2n_connection_get_selected_client_cert_digest_algorithm(
1167 self.connection.as_ptr(),
1168 &mut hash_alg,
1169 )
1170 .into_result()?;
1171 }
1172 Ok(match hash_alg {
1173 s2n_tls_hash_algorithm::NONE => None,
1174 hash_alg => Some(hash_alg.try_into()?),
1175 })
1176 }
1177
1178 /// Corresponds to [s2n_get_application_protocol].
1179 pub fn application_protocol(&self) -> Option<&[u8]> {
1180 let protocol = unsafe { s2n_get_application_protocol(self.connection.as_ptr()) };
1181 if protocol.is_null() {
1182 return None;
1183 }
1184 Some(unsafe { CStr::from_ptr(protocol).to_bytes() })
1185 }
1186
1187 /// Provides access to the TLS-Exporter functionality.
1188 ///
1189 /// See https://datatracker.ietf.org/doc/html/rfc5705 and https://www.rfc-editor.org/rfc/rfc8446.
1190 ///
1191 /// This is currently only available with TLS 1.3 connections which have finished a handshake.
1192 ///
1193 /// Corresponds to [s2n_connection_tls_exporter].
1194 pub fn tls_exporter(
1195 &self,
1196 label: &[u8],
1197 context: &[u8],
1198 output: &mut [u8],
1199 ) -> Result<(), Error> {
1200 unsafe {
1201 s2n_connection_tls_exporter(
1202 self.connection.as_ptr(),
1203 label.as_ptr(),
1204 label.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1205 context.as_ptr(),
1206 context.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1207 output.as_mut_ptr(),
1208 output.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1209 )
1210 .into_result()
1211 .map(|_| ())
1212 }
1213 }
1214
1215 /// Returns the validated peer certificate chain.
1216 // 'static lifetime is because this copies the certificate chain from the connection into a new
1217 // chain, so the lifetime is independent of the connection.
1218 ///
1219 /// Corresponds to [s2n_connection_get_peer_cert_chain].
1220 pub fn peer_cert_chain(&self) -> Result<CertificateChain<'static>, Error> {
1221 unsafe {
1222 let chain_handle = CertificateChainHandle::allocate()?;
1223 s2n_connection_get_peer_cert_chain(
1224 self.connection.as_ptr(),
1225 chain_handle.cert.as_ptr(),
1226 )
1227 .into_result()
1228 .map(|_| ())?;
1229 Ok(CertificateChain::from_allocated(chain_handle))
1230 }
1231 }
1232
1233 /// Get the certificate used during the TLS handshake
1234 ///
1235 /// - If `self` is a server connection, the certificate selected will depend on the
1236 /// ServerName sent by the client and supported ciphers.
1237 /// - If `self` is a client connection, the certificate sent in response to a CertificateRequest
1238 /// message is returned. Currently s2n-tls supports loading only one certificate in client mode. Note that
1239 /// not all TLS endpoints will request a certificate.
1240 ///
1241 /// Corresponds to [s2n_connection_get_selected_cert].
1242 pub fn selected_cert(&self) -> Option<CertificateChain<'_>> {
1243 unsafe {
1244 // The API only returns null, no error is actually set.
1245 // Clippy doesn't realize from_ptr_reference is unsafe.
1246 #[allow(clippy::manual_map)]
1247 if let Some(ptr) =
1248 NonNull::new(s2n_connection_get_selected_cert(self.connection.as_ptr()))
1249 {
1250 Some(CertificateChain::from_ptr_reference(ptr))
1251 } else {
1252 None
1253 }
1254 }
1255 }
1256
1257 /// Corresponds to [s2n_connection_get_master_secret].
1258 pub fn master_secret(&self) -> Result<Vec<u8>, Error> {
1259 // TLS1.2 master secrets are always 48 bytes
1260 let mut secret = vec![0; 48];
1261 unsafe {
1262 s2n_connection_get_master_secret(
1263 self.connection.as_ptr(),
1264 secret.as_mut_ptr(),
1265 secret.len(),
1266 )
1267 .into_result()?;
1268 }
1269 Ok(secret)
1270 }
1271
1272 /// Retrieves the size of the serialized connection
1273 ///
1274 /// Corresponds to [s2n_connection_serialization_length].
1275 pub fn serialization_length(&self) -> Result<usize, Error> {
1276 unsafe {
1277 let mut length = 0;
1278 s2n_connection_serialization_length(self.connection.as_ptr(), &mut length)
1279 .into_result()?;
1280 Ok(length.try_into().unwrap())
1281 }
1282 }
1283
1284 /// Serializes the TLS connection into the provided buffer
1285 ///
1286 /// Corresponds to [s2n_connection_serialize].
1287 pub fn serialize(&self, output: &mut [u8]) -> Result<(), Error> {
1288 unsafe {
1289 s2n_connection_serialize(
1290 self.connection.as_ptr(),
1291 output.as_mut_ptr(),
1292 output.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
1293 )
1294 .into_result()?;
1295 Ok(())
1296 }
1297 }
1298
1299 /// Deserializes the input buffer into a new TLS connection that can send/recv
1300 /// data from the original peer.
1301 ///
1302 /// Corresponds to [s2n_connection_deserialize].
1303 pub fn deserialize(&mut self, input: &[u8]) -> Result<(), Error> {
1304 let size = input.len();
1305 /* This is not ideal, we know that s2n_connection_deserialize will not mutate the
1306 * input value, however, the mut is needed to use the stuffer functions. */
1307 let input = input.as_ptr() as *mut u8;
1308 unsafe {
1309 s2n_connection_deserialize(
1310 self.as_ptr(),
1311 input,
1312 size.try_into().map_err(|_| Error::INVALID_INPUT)?,
1313 )
1314 .into_result()?;
1315 Ok(())
1316 }
1317 }
1318
1319 /// Determines whether the connection was resumed from an earlier handshake.
1320 ///
1321 /// Corresponds to [s2n_connection_is_session_resumed].
1322 pub fn resumed(&self) -> bool {
1323 unsafe { s2n_connection_is_session_resumed(self.connection.as_ptr()) == 1 }
1324 }
1325
1326 /// Append an external psk to a connection.
1327 ///
1328 /// This may be called repeatedly to support multiple PSKs.
1329 ///
1330 /// Corresponds to [s2n_connection_append_psk].
1331 pub fn append_psk(&mut self, psk: &Psk) -> Result<(), Error> {
1332 unsafe {
1333 // SAFETY: retrieving a *mut s2n_psk from &Psk: s2n-tls does not treat
1334 // the pointer as mutable, and only holds the reference to copy the
1335 // PSK onto the connection.
1336 s2n_connection_append_psk(self.as_ptr(), psk.ptr.as_ptr()).into_result()?
1337 };
1338 Ok(())
1339 }
1340
1341 /// Corresponds to [s2n_connection_get_negotiated_psk_identity_length].
1342 pub fn negotiated_psk_identity_length(&self) -> Result<usize, Error> {
1343 let mut length = 0;
1344 unsafe {
1345 s2n_connection_get_negotiated_psk_identity_length(self.connection.as_ptr(), &mut length)
1346 .into_result()?
1347 };
1348 Ok(length as usize)
1349 }
1350
1351 /// Retrieve the negotiated psk identity. Use [Connection::negotiated_psk_identity_length]
1352 /// to retrieve the length of the psk identity.
1353 ///
1354 /// Corresponds to [s2n_connection_get_negotiated_psk_identity].
1355 pub fn negotiated_psk_identity(&self, destination: &mut [u8]) -> Result<(), Error> {
1356 unsafe {
1357 s2n_connection_get_negotiated_psk_identity(
1358 self.connection.as_ptr(),
1359 destination.as_mut_ptr(),
1360 destination.len().min(u16::MAX as usize) as u16,
1361 )
1362 .into_result()?;
1363 }
1364 Ok(())
1365 }
1366
1367 /// Associates an arbitrary application context with the Connection to be later retrieved via
1368 /// the [`Self::application_context()`] and [`Self::application_context_mut()`] APIs.
1369 ///
1370 /// This API will override an existing application context set on the Connection.
1371 ///
1372 /// Corresponds to [s2n_connection_set_ctx].
1373 pub fn set_application_context<T: Send + Sync + 'static>(&mut self, app_context: T) {
1374 self.context_mut().app_context = Some(Box::new(app_context));
1375 }
1376
1377 /// Retrieves a reference to the application context associated with the Connection.
1378 ///
1379 /// If an application context hasn't already been set on the Connection, or if the set
1380 /// application context isn't of type T, None will be returned.
1381 ///
1382 /// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve a
1383 /// mutable reference to the context, use [`Self::application_context_mut()`].
1384 ///
1385 /// Corresponds to [s2n_connection_get_ctx].
1386 pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
1387 match self.context().app_context.as_ref() {
1388 None => None,
1389 // The Any trait keeps track of the application context's type. downcast_ref() returns
1390 // Some only if the correct type is provided:
1391 // https://doc.rust-lang.org/std/any/trait.Any.html#method.downcast_ref
1392 Some(app_context) => app_context.downcast_ref::<T>(),
1393 }
1394 }
1395
1396 /// Retrieves a mutable reference to the application context associated with the Connection.
1397 ///
1398 /// If an application context hasn't already been set on the Connection, or if the set
1399 /// application context isn't of type T, None will be returned.
1400 ///
1401 /// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve an
1402 /// immutable reference to the context, use [`Self::application_context()`].
1403 ///
1404 /// Corresponds to [s2n_connection_get_ctx].
1405 pub fn application_context_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
1406 match self.context_mut().app_context.as_mut() {
1407 None => None,
1408 Some(app_context) => app_context.downcast_mut::<T>(),
1409 }
1410 }
1411
1412 #[cfg(feature = "unstable-renegotiate")]
1413 pub(crate) fn renegotiate_state_mut(&mut self) -> &mut RenegotiateState {
1414 &mut self.context_mut().renegotiate_state
1415 }
1416
1417 #[cfg(feature = "unstable-renegotiate")]
1418 pub(crate) fn renegotiate_state(&self) -> &RenegotiateState {
1419 &self.context().renegotiate_state
1420 }
1421}
1422
/// Rust-side state attached to an `s2n_connection` through its ctx pointer.
///
/// Stored as a `Box<Context>`; `Connection::drop_context` reclaims and drops it.
/// Field declaration order is also the drop order.
struct Context {
    /// Mode passed to `Context::new`.
    mode: Mode,
    /// Waker most recently stored on the context; `None` when cleared.
    waker: Option<Waker>,
    /// Pending async callback, stored by `set_async_callback` and removed by
    /// `take_async_callback`.
    async_callback: Option<AsyncCallback>,
    /// Optional host name verification callback. NOTE(review): set elsewhere in this file.
    verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
    /// Starts as `false` in `Context::new`. NOTE(review): updated elsewhere; confirm semantics.
    connection_initialized: bool,
    /// Arbitrary application data stored via `set_application_context`.
    app_context: Option<Box<dyn Any + Send + Sync>>,
    #[cfg(feature = "unstable-renegotiate")]
    pub(crate) renegotiate_state: RenegotiateState,
}
1433
1434impl Context {
1435 fn new(mode: Mode) -> Self {
1436 Context {
1437 mode,
1438 waker: None,
1439 async_callback: None,
1440 verify_host_callback: None,
1441 connection_initialized: false,
1442 app_context: None,
1443 #[cfg(feature = "unstable-renegotiate")]
1444 renegotiate_state: RenegotiateState::default(),
1445 }
1446 }
1447}
1448
#[cfg(feature = "quic")]
impl Connection {
    /// Enables QUIC support on this connection.
    ///
    /// Corresponds to [s2n_connection_enable_quic].
    pub fn enable_quic(&mut self) -> Result<&mut Self, Error> {
        unsafe { s2n_connection_enable_quic(self.connection.as_ptr()).into_result() }?;
        Ok(self)
    }

    /// Stores `buffer` as the connection's QUIC transport parameters.
    ///
    /// Returns [`Error::INVALID_INPUT`] if `buffer` is too long for s2n-tls's
    /// length parameter.
    ///
    /// Corresponds to [s2n_connection_set_quic_transport_parameters].
    pub fn set_quic_transport_parameters(&mut self, buffer: &[u8]) -> Result<&mut Self, Error> {
        unsafe {
            s2n_connection_set_quic_transport_parameters(
                self.connection.as_ptr(),
                buffer.as_ptr(),
                buffer.len().try_into().map_err(|_| Error::INVALID_INPUT)?,
            )
            .into_result()
        }?;
        Ok(self)
    }

    /// Retrieves the QUIC transport parameters reported by s2n-tls.
    ///
    /// Returns an empty slice when s2n-tls reports a null or zero-length buffer.
    ///
    /// Corresponds to [s2n_connection_get_quic_transport_parameters].
    pub fn quic_transport_parameters(&mut self) -> Result<&[u8], Error> {
        let mut ptr = core::ptr::null();
        let mut len = 0;
        unsafe {
            s2n_connection_get_quic_transport_parameters(
                self.connection.as_ptr(),
                &mut ptr,
                &mut len,
            )
            .into_result()
        }?;
        // `slice::from_raw_parts` requires a non-null pointer even when the
        // length is 0. s2n-tls may hand back a null buffer, e.g. before any
        // parameters are available (NOTE(review): confirm exact cases), so the
        // empty case is handled without touching the pointer.
        if ptr.is_null() || len == 0 {
            let empty: &[u8] = &[];
            return Ok(empty);
        }
        // SAFETY: `ptr` is non-null and points at `len` bytes. The buffer is
        // assumed to be owned by the connection for the borrow of `self`.
        let buffer = unsafe { core::slice::from_raw_parts(ptr, len as _) };
        Ok(buffer)
    }

    /// # Safety
    ///
    /// The `context` pointer must live at least as long as the connection
    ///
    /// Corresponds to [s2n_connection_set_secret_callback].
    pub unsafe fn set_secret_callback(
        &mut self,
        callback: s2n_secret_cb,
        context: *mut c_void,
    ) -> Result<&mut Self, Error> {
        s2n_connection_set_secret_callback(self.connection.as_ptr(), callback, context)
            .into_result()?;
        Ok(self)
    }

    /// Processes a post-handshake QUIC message. The blocked status reported by
    /// s2n-tls is discarded.
    ///
    /// Corresponds to [s2n_recv_quic_post_handshake_message].
    pub fn quic_process_post_handshake_message(&mut self) -> Result<&mut Self, Error> {
        let mut blocked = s2n_blocked_status::NOT_BLOCKED;
        unsafe {
            s2n_recv_quic_post_handshake_message(self.connection.as_ptr(), &mut blocked)
                .into_result()
        }?;
        Ok(self)
    }

    /// Allows the quic library to check if session tickets are expected
    ///
    /// Corresponds to [s2n_connection_are_session_tickets_enabled].
    pub fn are_session_tickets_enabled(&self) -> bool {
        unsafe { s2n_connection_are_session_tickets_enabled(self.connection.as_ptr()) }
    }
}
1518
1519impl AsRef<Connection> for Connection {
1520 fn as_ref(&self) -> &Connection {
1521 self
1522 }
1523}
1524
1525impl AsMut<Connection> for Connection {
1526 fn as_mut(&mut self) -> &mut Connection {
1527 self
1528 }
1529}
1530
impl Drop for Connection {
    /// Corresponds to [s2n_connection_free].
    ///
    /// Tears down, in order: the Rust-side context, the config (via
    /// `drop_config`, defined elsewhere), and finally the `s2n_connection` itself.
    fn drop(&mut self) {
        // ignore failures since there's not much we can do about it
        unsafe {
            // clean up context; this still reads/writes the connection's ctx
            // pointer, so it must run before the connection is freed
            let _ = self.drop_context();

            // cleanup config
            let _ = self.drop_config();

            // cleanup connection; last, because the steps above use `self.connection`
            let _ = s2n_connection_free(self.connection.as_ptr()).into_result();
        }
    }
}
1547
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: the connection context must be `Send`.
    #[test]
    fn context_send_test() {
        fn require_send<S: Send + 'static>() {}
        require_send::<Context>();
    }

    /// Compile-time check: the connection context must be `Sync`.
    #[test]
    fn context_sync_test() {
        fn require_sync<S: Sync + 'static>() {}
        require_sync::<Context>();
    }

    /// Test that an application context can be set and retrieved.
    #[test]
    fn test_app_context_set_and_retrieve() {
        let mut conn = Connection::new_server();

        // Nothing has been stored yet.
        assert_eq!(conn.application_context::<u32>(), None);

        conn.set_application_context(1142_u32);

        // The stored value is now visible.
        assert_eq!(conn.application_context::<u32>(), Some(&1142));
    }

    /// Test that an application context can be modified.
    #[test]
    fn test_app_context_modify() {
        let mut conn = Connection::new_server();
        conn.set_application_context(0_u64);

        *conn.application_context_mut::<u64>().unwrap() += 1;

        assert_eq!(conn.application_context::<u64>(), Some(&1));
    }

    /// Test that an application context can be overridden.
    #[test]
    fn test_app_context_override() {
        let mut conn = Connection::new_server();

        conn.set_application_context(1142_u16);
        assert_eq!(conn.application_context::<u16>(), Some(&1142));

        // Replace the value, keeping the same type.
        conn.set_application_context(10_u16);
        assert_eq!(conn.application_context::<u16>(), Some(&10));

        // Replace the value with one of a different type.
        conn.set_application_context(-20_i16);
        assert_eq!(conn.application_context::<i16>(), Some(&-20));
    }

    /// Test that a context of another type can't be retrieved.
    #[test]
    fn test_app_context_invalid_type() {
        let mut conn = Connection::new_server();
        conn.set_application_context(0_u32);

        // Asking for a type that was never stored yields nothing.
        assert_eq!(conn.application_context::<i16>(), None);

        // Asking for the stored type succeeds.
        assert!(matches!(conn.application_context::<u32>(), Some(_)));
    }
}