s2n-quic-tls 0.46.0

Internal crate used by s2n-quic
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use bytes::{Bytes, BytesMut};
use core::{ffi::c_void, marker::PhantomData};
use s2n_quic_core::{
    application::ServerName,
    crypto::{tls, tls::CipherSuite, CryptoSuite},
    endpoint, transport,
};
use s2n_quic_crypto::{
    handshake::HandshakeKey, hkdf, one_rtt::OneRttKey, ring_aead as aead, Prk, SecretPair, Suite,
};
use s2n_tls::{connection::Connection, error::Fallible, ffi::*};

/// The preallocated size of the outgoing buffer
///
/// By allocating a larger buffer, we only do a few allocations, even when
/// s2n-tls sends small chunks. A single write that is larger than this value
/// gets a buffer sized to that write instead (see `Callback::on_write`).
const SEND_BUFFER_CAPACITY: usize = 2048;

/// Handles all callback contexts for each session
///
/// A pointer to this struct is registered with the s2n-tls connection as the
/// context for the secret, send and receive callbacks (see `set`).
pub struct Callback<'a, T, C> {
    /// The TLS context that receives keys, parameters and outgoing handshake
    /// data, and that supplies incoming handshake data
    pub context: &'a mut T,
    /// The local endpoint type; passed to key construction and checked with
    /// `is_server()`
    pub endpoint: endpoint::Type,
    /// The rx/tx handshake phases and the secret-pairing state
    pub state: &'a mut State,
    /// Marker for the crypto suite type `C`
    pub suite: PhantomData<C>,
    /// Error recorded by the secret callback; returned from `unset`
    pub err: Option<transport::Error>,
    /// Outgoing bytes written by s2n-tls, held until the next flush
    pub send_buffer: &'a mut BytesMut,
    /// Set to `true` once `on_server_name` has been delivered to the context
    pub emitted_server_name: &'a mut bool,
    /// Server name to report; when `None`, `unset` asks the connection for it
    pub server_name: &'a Option<ServerName>,
    /// Server transport parameters; handed to s2n-tls while deriving the
    /// handshake keys and cleared afterwards
    pub server_params: &'a mut Vec<u8>,
}

impl<'a, T, C> Callback<'a, T, C>
where
    T: 'a + tls::Context<C>,
    C: CryptoSuite<
        HandshakeKey = <Suite as CryptoSuite>::HandshakeKey,
        HandshakeHeaderKey = <Suite as CryptoSuite>::HandshakeHeaderKey,
        InitialKey = <Suite as CryptoSuite>::InitialKey,
        InitialHeaderKey = <Suite as CryptoSuite>::InitialHeaderKey,
        OneRttKey = <Suite as CryptoSuite>::OneRttKey,
        OneRttHeaderKey = <Suite as CryptoSuite>::OneRttHeaderKey,
        ZeroRttKey = <Suite as CryptoSuite>::ZeroRttKey,
        ZeroRttHeaderKey = <Suite as CryptoSuite>::ZeroRttHeaderKey,
        RetryKey = <Suite as CryptoSuite>::RetryKey,
    >,
{
    /// Registers this `Callback` with the s2n-tls connection
    ///
    /// The secret, send and receive callbacks are pointed at the matching
    /// associated functions of this type, each with `self` as the context
    /// pointer. The context's waker is installed as well.
    ///
    /// # Safety
    ///
    /// * The Callback struct must live at least as long as the connection
    /// * or the `unset` method should be called if it doesn't
    pub unsafe fn set(&mut self, connection: &mut Connection) {
        // The same type-erased pointer is shared by all three callbacks
        let context = (self as *mut Self).cast::<c_void>();

        // We use unwrap here since s2n-tls will just check if connection is not null
        connection
            .set_secret_callback(Some(Self::secret_cb), context)
            .unwrap();

        connection.set_send_callback(Some(Self::send_cb)).unwrap();
        connection.set_send_context(context).unwrap();

        connection
            .set_receive_callback(Some(Self::recv_cb))
            .unwrap();
        connection.set_receive_context(context).unwrap();

        // A Waker is provided for use with the client hello callback.
        let waker = self.context.waker();
        connection.set_waker(Some(waker)).unwrap();
    }

    /// Detaches this `Callback` from the connection and reports the outcome
    ///
    /// Every callback is replaced with a stub that always fails, and every
    /// context pointer and the waker are cleared. Afterwards any buffered
    /// outgoing data is flushed, the server name is emitted if that has not
    /// happened yet, and an error recorded by the secret callback (if any) is
    /// returned.
    pub fn unset(mut self, connection: &mut Connection) -> Result<(), transport::Error> {
        /// Stand-in secret callback that always reports failure
        unsafe extern "C" fn rejecting_secret_cb(
            _context: *mut c_void,
            _conn: *mut s2n_connection,
            _secret_type: s2n_secret_type_t::Type,
            _secret: *mut u8,
            _secret_size: u8,
        ) -> s2n_status_code::Type {
            -1
        }

        /// Stand-in send callback that always reports failure
        unsafe extern "C" fn rejecting_send_cb(
            _context: *mut c_void,
            _data: *const u8,
            _len: u32,
        ) -> s2n_status_code::Type {
            -1
        }

        /// Stand-in receive callback that always reports failure
        unsafe extern "C" fn rejecting_recv_cb(
            _context: *mut c_void,
            _data: *mut u8,
            _len: u32,
        ) -> s2n_status_code::Type {
            -1
        }

        unsafe {
            let no_context: *mut c_void = core::ptr::null_mut();

            // We use unwrap here since s2n-tls will just check if connection is not null
            connection
                .set_secret_callback(Some(rejecting_secret_cb), no_context)
                .unwrap();
            connection
                .set_send_callback(Some(rejecting_send_cb))
                .unwrap();
            connection.set_send_context(no_context).unwrap();
            connection
                .set_receive_callback(Some(rejecting_recv_cb))
                .unwrap();
            connection.set_receive_context(no_context).unwrap();
            connection.set_waker(None).unwrap();

            // Flush the send buffer before returning to the connection
            self.flush();

            // attempt to emit server name after making progress and prior to error handling
            if !*self.emitted_server_name {
                // Prefer the configured name; only ask the connection when there is none
                let resolved = match self.server_name.clone() {
                    Some(name) => Some(name),
                    None => connection.server_name().map(|name| name.into()),
                };

                if let Some(server_name) = resolved {
                    self.context.on_server_name(server_name)?;
                    *self.emitted_server_name = true;
                }
            }

            match self.err {
                Some(err) => Err(err),
                None => Ok(()),
            }
        }
    }

    /// The function s2n-tls calls when it emits secrets
    ///
    /// Recovers the `Callback` from `context`, forwards the secret to
    /// `on_secret`, and returns `0` on success. On failure the error is stored
    /// in `err` and `-1` is returned.
    unsafe extern "C" fn secret_cb(
        context: *mut c_void,
        conn: *mut s2n_connection,
        secret_type: s2n_secret_type_t::Type,
        secret: *mut u8,
        secret_size: u8,
    ) -> s2n_status_code::Type {
        let callback = &mut *(context as *mut Self);
        let secret_bytes = core::slice::from_raw_parts_mut(secret, usize::from(secret_size));

        if let Err(err) = callback.on_secret(conn, secret_type, secret_bytes) {
            // Keep the error so `unset` can return it to the caller
            callback.err = Some(err);
            return -1;
        }

        0
    }

    /// Handles secrets from the s2n-tls connection
    fn on_secret(
        &mut self,
        conn: *mut s2n_connection,
        id: s2n_secret_type_t::Type,
        secret: &mut [u8],
    ) -> Result<(), transport::Error> {
        match core::mem::replace(&mut self.state.secrets, Secrets::Waiting) {
            Secrets::Waiting => {
                if id == s2n_secret_type_t::CLIENT_EARLY_TRAFFIC_SECRET {
                    return Ok(());
                }

                let (prk_algo, _aead, cipher_suite) =
                    get_algo_type(conn).ok_or(tls::Error::INTERNAL_ERROR)?;
                let secret = Prk::new_less_safe(prk_algo, secret);
                self.state.secrets = Secrets::Half { secret, id };
                self.state.cipher_suite = cipher_suite;

                Ok(())
            }
            Secrets::Half {
                id: other_id,
                secret: other_secret,
            } => {
                let (prk_algo, aead_algo, cipher_suite) =
                    get_algo_type(conn).ok_or(tls::Error::INTERNAL_ERROR)?;
                let secret = Prk::new_less_safe(prk_algo, secret);
                self.state.cipher_suite = cipher_suite;
                let pair = match (id, other_id) {
                    (
                        s2n_secret_type_t::CLIENT_HANDSHAKE_TRAFFIC_SECRET,
                        s2n_secret_type_t::SERVER_HANDSHAKE_TRAFFIC_SECRET,
                    )
                    | (
                        s2n_secret_type_t::CLIENT_APPLICATION_TRAFFIC_SECRET,
                        s2n_secret_type_t::SERVER_APPLICATION_TRAFFIC_SECRET,
                    ) => SecretPair {
                        client: secret,
                        server: other_secret,
                    },
                    (
                        s2n_secret_type_t::SERVER_HANDSHAKE_TRAFFIC_SECRET,
                        s2n_secret_type_t::CLIENT_HANDSHAKE_TRAFFIC_SECRET,
                    )
                    | (
                        s2n_secret_type_t::SERVER_APPLICATION_TRAFFIC_SECRET,
                        s2n_secret_type_t::CLIENT_APPLICATION_TRAFFIC_SECRET,
                    ) => SecretPair {
                        server: secret,
                        client: other_secret,
                    },
                    _ => {
                        debug_assert!(false, "invalid key phase");
                        return Err(transport::Error::INTERNAL_ERROR);
                    }
                };

                // Flush the send buffer before transitioning to the next phase
                self.flush();

                match self.state.tx_phase {
                    HandshakePhase::Initial => {
                        let (key, header_key) = HandshakeKey::new(self.endpoint, aead_algo, pair)
                            .expect("invalid cipher");

                        if !self.server_params.is_empty() {
                            debug_assert!(self.endpoint.is_server());

                            // Since the client transport parameters are sent in the Initial packet space
                            // we can access them early on in the handshake, allowing the server transport
                            // parameters to be configured based on the client's parameters.
                            unsafe {
                                // Safety: conn needs to outlive params
                                let client_params = get_application_params(conn)?;

                                // Allow for additional server params to be appended that depend
                                // on the given client params
                                self.context.on_client_application_params(
                                    client_params,
                                    self.server_params,
                                )?;

                                // Now set the server transport parameters on the connection for
                                // transmission to the client
                                s2n_connection_set_quic_transport_parameters(
                                    conn,
                                    self.server_params.as_ptr(),
                                    self.server_params.len() as _,
                                );
                            }
                            // We've given the transport parameters to s2n-tls, so clear
                            // `server_params` to ensure they are not set again.
                            self.server_params.clear();
                        }

                        self.context.on_handshake_keys(key, header_key)?;
                        self.state.tx_phase.transition();
                        self.state.rx_phase.transition();
                    }
                    _ => {
                        let (key, header_key) =
                            OneRttKey::new(self.endpoint, aead_algo, pair).expect("invalid cipher");
                        // At this point the server is done writing Handshake messages
                        if self.endpoint.is_server() {
                            self.state.tx_phase.transition();
                        }
                        let params = unsafe {
                            // Safety: conn needs to outlive params
                            //
                            // TODO use interning for these values
                            // issue: https://github.com/aws/s2n-quic/issues/248
                            //
                            // Move this event to where `on_server_name` is emitted once we expose
                            // the functionality in s2n_tls bindings
                            let application_protocol =
                                Bytes::copy_from_slice(get_application_protocol(conn)?);
                            self.context.on_application_protocol(application_protocol)?;
                            get_application_params(conn)?
                        };

                        self.context.on_one_rtt_keys(key, header_key, params)?;
                    }
                }

                Ok(())
            }
        }
    }

    /// The function s2n-tls calls when it wants to send data
    ///
    /// Recovers the `Callback` from `context`, views `data`/`len` as a byte
    /// slice and hands it to `on_write`, returning the number of bytes accepted.
    unsafe extern "C" fn send_cb(
        context: *mut c_void,
        data: *const u8,
        len: u32,
    ) -> s2n_status_code::Type {
        let callback = &mut *(context as *mut Self);
        let outgoing = core::slice::from_raw_parts(data, len as usize);
        let written = callback.on_write(outgoing);
        written as _
    }

    /// Called when sending data
    ///
    /// Appends `data` to the send buffer and returns `data.len()`. When the
    /// buffer has too little spare capacity, it is flushed first and replaced
    /// with a fresh allocation that is large enough for the write.
    fn on_write(&mut self, data: &[u8]) -> usize {
        let incoming = data.len();

        // If this write would cause the current send buffer to reallocate,
        // we should flush and create a new send buffer.
        let spare = self.send_buffer.capacity() - self.send_buffer.len();
        let must_reallocate = incoming > spare;

        if must_reallocate {
            // Flush the send buffer before reallocating it
            self.flush();

            debug_assert!(
                self.send_buffer.is_empty(),
                "dropping a send buffer with data will result in data loss"
            );

            // ensure we only do one allocation for this write
            let capacity = SEND_BUFFER_CAPACITY.max(incoming);
            *self.send_buffer = BytesMut::with_capacity(capacity);
        }

        // Write the current data to the send buffer
        //
        // NOTE: we don't immediately flush to the packet space since s2n-tls may do
        //       several small writes in a row.
        self.send_buffer.extend_from_slice(data);

        incoming
    }

    /// Flushes the send buffer into the current TX space
    ///
    /// Does nothing when the buffer is empty. Otherwise the buffered bytes are
    /// split off and passed to the context method that matches the current tx
    /// phase.
    fn flush(&mut self) {
        if self.send_buffer.is_empty() {
            return;
        }

        let pending = self.send_buffer.split().freeze();

        match self.state.tx_phase {
            HandshakePhase::Initial => self.context.send_initial(pending),
            HandshakePhase::Handshake => self.context.send_handshake(pending),
            HandshakePhase::Application => self.context.send_application(pending),
        }
    }

    /// The function s2n-tls calls when it wants to receive data
    ///
    /// Recovers the `Callback` from `context` and lets `on_read` fill the
    /// `data`/`len` buffer. Returns the number of bytes read, or sets `errno`
    /// to `EWOULDBLOCK` and returns `-1` when nothing was available.
    unsafe extern "C" fn recv_cb(
        context: *mut c_void,
        data: *mut u8,
        len: u32,
    ) -> s2n_status_code::Type {
        let callback = &mut *(context as *mut Self);
        let incoming = core::slice::from_raw_parts_mut(data, len as usize);

        let read = callback.on_read(incoming);
        if read == 0 {
            // https://github.com/aws/s2n-tls/blob/main/docs/USAGE-GUIDE.md#s2n_connection_set_send_cb
            // s2n-tls wants us to set the global errno to signal blocked
            errno::set_errno(errno::Errno(libc::EWOULDBLOCK));
            return -1;
        }

        read as _
    }

    /// Called when receiving data
    ///
    /// Asks the context for at most `data.len()` bytes from the space matching
    /// the current rx phase, copies them to the front of `data` and returns
    /// how many bytes were copied. Returns `0` when the context has no data.
    fn on_read(&mut self, data: &mut [u8]) -> usize {
        let limit = Some(data.len());

        let received = match self.state.rx_phase {
            HandshakePhase::Initial => self.context.receive_initial(limit),
            HandshakePhase::Handshake => self.context.receive_handshake(limit),
            HandshakePhase::Application => self.context.receive_application(limit),
        };

        match received {
            Some(chunk) => {
                let copied = chunk.len();
                data[..copied].copy_from_slice(&chunk);
                copied
            }
            None => 0,
        }
    }
}

/// Handshake progress that persists across `Callback` instances
///
/// `Callback` only borrows this state, so it has to be stored by the caller.
#[derive(Default, Debug)]
pub struct State {
    /// Selects which space incoming handshake data is read from
    rx_phase: HandshakePhase,
    /// Selects which space outgoing handshake data is flushed to
    tx_phase: HandshakePhase,
    /// Holds the first secret of a client/server pair until the second arrives
    secrets: Secrets,
    /// The cipher suite recorded when a secret was last stored or paired
    cipher_suite: CipherSuite,
}

impl State {
    /// Complete the handshake
    ///
    /// Advances both the tx and rx phase by one step. Debug builds assert that
    /// both have reached `Application` afterwards.
    pub fn on_handshake_complete(&mut self) {
        let Self {
            rx_phase, tx_phase, ..
        } = self;

        tx_phase.transition();
        rx_phase.transition();

        debug_assert_eq!(*tx_phase, HandshakePhase::Application);
        debug_assert_eq!(*rx_phase, HandshakePhase::Application);
    }

    /// Returns the cipher suite currently recorded in the state
    pub fn cipher_suite(&self) -> CipherSuite {
        let Self { cipher_suite, .. } = self;
        *cipher_suite
    }
}

/// The space that handshake data is currently exchanged in
///
/// Phases are ordered: `Initial` < `Handshake` < `Application`.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
enum HandshakePhase {
    /// The starting phase
    Initial,
    /// Entered from `Initial`
    Handshake,
    /// The final phase; further transitions stay here
    Application,
}

impl HandshakePhase {
    /// Advances to the next phase, staying at `Application` once it is reached
    fn transition(&mut self) {
        let next = match *self {
            Self::Initial => Self::Handshake,
            Self::Handshake | Self::Application => Self::Application,
        };
        *self = next;
    }
}

impl Default for HandshakePhase {
    /// A handshake starts in the `Initial` phase
    fn default() -> Self {
        HandshakePhase::Initial
    }
}

/// Tracks how much of a client/server secret pair has been received
#[derive(Debug)]
enum Secrets {
    /// No secret of the current pair has been stored yet
    Waiting,
    /// One secret of the pair has been stored; the other is still outstanding
    Half {
        /// The stored secret
        secret: Prk,
        /// The s2n-tls secret type that `secret` was reported with
        id: s2n_secret_type_t::Type,
    },
}

impl Default for Secrets {
    fn default() -> Self {
        Self::Waiting
    }
}

/// Looks up the algorithms for the cipher suite negotiated on `connection`
///
/// Returns the HKDF algorithm, the AEAD algorithm and the `CipherSuite` that
/// correspond to the connection's IANA cipher value. Returns `None` when the
/// value cannot be read or names a suite that is not handled here.
fn get_algo_type(
    connection: *mut s2n_connection,
) -> Option<(hkdf::Algorithm, &'static aead::Algorithm, CipherSuite)> {
    // The IANA value is reported as two separate bytes
    let (mut first, mut second) = (0u8, 0u8);
    unsafe {
        s2n_connection_get_cipher_iana_value(connection, &mut first, &mut second)
            .into_result()
            .ok()?;
    }

    //= https://www.rfc-editor.org/rfc/rfc8446#appendix-B.4
    //# This specification defines the following cipher suites for use with
    //# TLS 1.3.
    //#
    //#              +------------------------------+-------------+
    //#              | Description                  | Value       |
    //#              +------------------------------+-------------+
    //#              | TLS_AES_128_GCM_SHA256       | {0x13,0x01} |
    //#              |                              |             |
    //#              | TLS_AES_256_GCM_SHA384       | {0x13,0x02} |
    //#              |                              |             |
    //#              | TLS_CHACHA20_POLY1305_SHA256 | {0x13,0x03} |
    //#              |                              |             |
    //#              | TLS_AES_128_CCM_SHA256       | {0x13,0x04} |
    //#              |                              |             |
    //#              | TLS_AES_128_CCM_8_SHA256     | {0x13,0x05} |
    //#              +------------------------------+-------------+
    const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
    const TLS_AES_256_GCM_SHA384: [u8; 2] = [0x13, 0x02];
    const TLS_CHACHA20_POLY1305_SHA256: [u8; 2] = [0x13, 0x03];

    // NOTE: we don't have CCM support implemented currently

    // NOTE: CCM_8 is not allowed by QUIC
    //= https://www.rfc-editor.org/rfc/rfc9001#section-5.3
    //# QUIC can use any of the cipher suites defined in [TLS13] with the
    //# exception of TLS_AES_128_CCM_8_SHA256.

    let negotiated = match [first, second] {
        TLS_AES_128_GCM_SHA256 => (
            hkdf::HKDF_SHA256,
            &aead::AES_128_GCM,
            CipherSuite::TLS_AES_128_GCM_SHA256,
        ),
        TLS_AES_256_GCM_SHA384 => (
            hkdf::HKDF_SHA384,
            &aead::AES_256_GCM,
            CipherSuite::TLS_AES_256_GCM_SHA384,
        ),
        TLS_CHACHA20_POLY1305_SHA256 => (
            hkdf::HKDF_SHA256,
            &aead::CHACHA20_POLY1305,
            CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
        ),
        _ => return None,
    };

    Some(negotiated)
}

unsafe fn get_application_params<'a>(
    connection: *mut s2n_connection,
) -> Result<tls::ApplicationParameters<'a>, tls::Error> {
    //= https://www.rfc-editor.org/rfc/rfc9001#section-8.1
    //# When using ALPN, endpoints MUST immediately close a connection (see
    //# Section 10.2 of [QUIC-TRANSPORT]) with a no_application_protocol TLS
    //# alert (QUIC error code 0x178; see Section 4.8) if an application
    //# protocol is not negotiated.

    //= https://www.rfc-editor.org/rfc/rfc9001#section-8.1
    //# While [ALPN] only specifies that servers
    //# use this alert, QUIC clients MUST use error 0x178 to terminate a
    //# connection when ALPN negotiation fails.
    let transport_parameters =
        get_transport_parameters(connection).ok_or(tls::Error::MISSING_EXTENSION)?;

    Ok(tls::ApplicationParameters {
        transport_parameters,
    })
}

//= https://www.rfc-editor.org/rfc/rfc9001#section-8.1
//# Unless
//# another mechanism is used for agreeing on an application protocol,
//# endpoints MUST use ALPN for this purpose.
//
//= https://www.rfc-editor.org/rfc/rfc7301#section-3.1
//# Client                                              Server
//#
//#    ClientHello                     -------->       ServerHello
//#      (ALPN extension &                               (ALPN extension &
//#       list of protocols)                              selected protocol)
//#                                                    [ChangeCipherSpec]
//#                                    <--------       Finished
//#    [ChangeCipherSpec]
//#    Finished                        -------->
//#    Application Data                <------->       Application Data
unsafe fn get_application_protocol<'a>(
    connection: *mut s2n_connection,
) -> Result<&'a [u8], tls::Error> {
    let ptr = s2n_get_application_protocol(connection).into_result().ok();
    ptr.and_then(|ptr| get_cstr_slice(ptr))
        .ok_or(tls::Error::MISSING_EXTENSION)
}

/// Returns the QUIC transport parameters stored on `connection`
///
/// Yields `None` when s2n-tls reports a failure or when the returned buffer
/// is null or empty.
///
/// # Safety
///
/// The connection must outlive the returned slice, which points into memory
/// reported by s2n-tls.
unsafe fn get_transport_parameters<'a>(connection: *mut s2n_connection) -> Option<&'a [u8]> {
    let mut data = core::ptr::null();
    let mut data_len = 0u16;

    let status = s2n_connection_get_quic_transport_parameters(connection, &mut data, &mut data_len);
    status.into_result().ok()?;

    get_slice(data, usize::from(data_len))
}

/// Views a NUL-terminated C string as a byte slice, excluding the terminator
///
/// Returns `None` when `ptr` is null or the string is empty.
///
/// # Safety
///
/// A non-null `ptr` must point to a valid NUL-terminated string that stays
/// alive and unmodified for the lifetime `'a`.
unsafe fn get_cstr_slice<'a>(ptr: *const libc::c_char) -> Option<&'a [u8]> {
    // `strlen` reads through its argument, so null has to be rejected before
    // the length is measured rather than afterwards in `get_slice`.
    if ptr.is_null() {
        return None;
    }

    let len = libc::strlen(ptr);
    get_slice(ptr as *const _, len)
}

/// Builds a byte slice from a raw pointer and length
///
/// Returns `None` when `ptr` is null or `len` is zero.
///
/// # Safety
///
/// A non-null `ptr` must be valid for reads of `len` bytes for the lifetime
/// `'a`.
unsafe fn get_slice<'a>(ptr: *const u8, len: usize) -> Option<&'a [u8]> {
    match (ptr.is_null(), len) {
        (true, _) | (_, 0) => None,
        _ => Some(core::slice::from_raw_parts(ptr, len)),
    }
}