1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
// Copyright Open Logistics Foundation
//
// Licensed under the Open Logistics Foundation License 1.3.
// For details on the licensing terms, see the LICENSE file.
// SPDX-License-Identifier: OLFL-1.3

//! TLS and DTLS interface module
//!
//! Defines [`SslConnection`], the main type for interacting with this library, and the
//! underlying [`SslContext`], which holds the hardware abstractions (network stack, clock and
//! RNG).

#[cfg(feature = "alloc")]
use alloc::boxed::Box;

use core::ops::DerefMut;
use core::time::Duration;

use cty::c_void;
use embedded_mbedtls_sys::{
    mbedtls_ssl_config, mbedtls_ssl_config_init, mbedtls_ssl_context, mbedtls_ssl_init,
    MBEDTLS_ERR_SSL_WANT_READ, MBEDTLS_ERR_SSL_WANT_WRITE,
};
use embedded_nal::{SocketAddr, UdpClientStack};
use embedded_timers::clock::Clock;
use rand_core::{CryptoRng, RngCore};

use crate::{error::Error, rng::rng_try_fill_bytes_callback_fn, timing, udp};

/// Hardware-facing state that backs an [`SslConnection`]
///
/// Bundles the Mbed TLS configuration with the network stack context, the timer (driven by a
/// [`Clock`]) and the cryptographically secure RNG. This is kept as a type of its own, apart
/// from `SslConnection`, because the C library is handed raw pointers into it; it therefore
/// needs a memory location that does not change while a connection uses it.
pub struct SslContext<'a, Net, C, R>
where
    C: Clock,
    R: RngCore + CryptoRng,
{
    config: mbedtls_ssl_config,
    net_context: Net,
    timer_context: timing::MbedtlsTimer<'a, C>,
    csrng: R,
}

impl<'a, U, C, R> SslContext<'a, udp::UdpContext<U>, C, R>
where
    U: UdpClientStack,
    C: Clock,
    R: RngCore + CryptoRng,
{
    /// Build an `SslContext` for a client-side DTLS connection to `server_addr`
    ///
    /// The Mbed TLS configuration is only initialized here. The connection constructors (e.g.
    /// [`SslConnection::new_dtls_client`]) later apply the configuration defaults to it.
    pub fn new_udp_client_side(
        net_stack: U,
        clock: &'a C,
        csrng: R,
        server_addr: SocketAddr,
    ) -> Self {
        let mut config = mbedtls_ssl_config::default();
        // SAFETY: `config` is a valid local that nothing else references; the init call only
        // prepares it for `mbedtls_ssl_config_defaults` / `mbedtls_ssl_config_free`.
        unsafe { mbedtls_ssl_config_init(&mut config) };

        Self {
            config,
            net_context: udp::UdpContext::new(net_stack, server_addr),
            timer_context: timing::MbedtlsTimer::new(clock),
            csrng,
        }
    }
}

impl<'a, Net, C: Clock, R: RngCore + CryptoRng> Drop for SslContext<'a, Net, C, R> {
    fn drop(&mut self) {
        unsafe {
            embedded_mbedtls_sys::mbedtls_ssl_config_free(&mut self.config);
        }
    }
}

/// An SSL connection, i.e. the main type to interact with this library
///
/// To set up the underlying Mbed TLS C library, we need raw pointers, so the [`SslContext`]
/// needs a pinned/stable memory location. Basically, we use two approaches to achieve this:
/// 1. We use a `&'a mut SslContext<...>` for the lifetime of the `SslConnection`, i.e. the
///    underlying `SslContext` cannot move for that lifetime. This is implemented in e.g.
///    [`SslConnection::new_dtls_client`]. This approach has the benefit of being heapless (it does
///    not require `alloc`), which is often desirable for embedded devices.
/// 2. We use a `Box<SslContext<...>>`, i.e. the `SslContext` lies on the heap so it does not
///    move in memory anymore. This is implemented in e.g.
///    `SslConnection::new_dtls_client_heap_context`, which is only available when the `alloc`
///    feature is activated. This approach has the benefit that the `SslConnection` may be set up
///    in an initializer function and can be moved around freely afterwards.
///
/// The explicit `Drop` impl frees the Mbed TLS session (`mbedtls_ctx`) before the fields are
/// dropped, i.e. before an owned (boxed) `SslContext`, and with it the configuration the
/// session points to, is released.
pub struct SslConnection<
    'a,
    Net,
    C: Clock + 'a,
    R: RngCore + CryptoRng,
    CTX: DerefMut<Target = SslContext<'a, Net, C, R>>,
> {
    /// Mbed TLS session state
    mbedtls_ctx: mbedtls_ssl_context,
    /// Handle to the pinned context that the session's config/RNG/BIO/timer pointers refer to
    ssl_ctx: CTX,
}

impl<'a, 'b: 'a, U, C, R>
    SslConnection<'b, udp::UdpContext<U>, C, R, &'a mut SslContext<'b, udp::UdpContext<U>, C, R>>
where
    U: UdpClientStack,
    C: Clock,
    R: RngCore + CryptoRng,
{
    /// Create a new DTLS `SslConnection` that borrows an [`SslContext`]
    ///
    /// The context stays mutably borrowed for as long as the connection lives. This keeps its
    /// address fixed, which is required because the C library holds raw pointers into it, but
    /// it also means the context cannot be moved while the connection exists.
    /// To move the connection together with its context, enable the `alloc` feature and use
    /// `new_dtls_client_heap_context`, which puts the context into a `Box` instead.
    ///
    /// ```
    /// # use embedded_nal::{IpAddr, Ipv4Addr, SocketAddr, UdpClientStack};
    /// # use embedded_timers::clock::Clock;
    /// # use rand_core::{CryptoRng, RngCore};
    /// #
    /// #
    /// use embedded_mbedtls::ssl::{SslConnection, SslContext, Preset};
    ///
    /// # fn _setup_ssl_stack<U: UdpClientStack, R: RngCore + CryptoRng>(
    /// #     net_stack: U,
    /// #     clock: &impl Clock,
    /// #     rng: R,
    /// # ) {
    /// #    let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 22);
    /// let mut ctx = SslContext::new_udp_client_side(net_stack, clock, rng, server_addr);
    /// let connection = SslConnection::new_dtls_client(&mut ctx, Preset::Default).unwrap();
    ///
    /// // Now the connection is ready to use!
    /// # }
    /// ```
    pub fn new_dtls_client(
        ssl_context: &'a mut SslContext<'b, udp::UdpContext<U>, C, R>,
        preset: Preset,
    ) -> Result<Self, Error> {
        let borrowed_ctx = ssl_context;
        Self::new_generic_dtls_client(borrowed_ctx, preset)
    }
}

#[cfg(feature = "alloc")]
impl<'a, U, C, R>
    SslConnection<'a, udp::UdpContext<U>, C, R, Box<SslContext<'a, udp::UdpContext<U>, C, R>>>
where
    U: UdpClientStack,
    C: Clock,
    R: RngCore + CryptoRng,
{
    /// Create a new DTLS `SslConnection` that owns its [`SslContext`] on the heap
    ///
    /// The context is moved into a `Box`, so its address stays fixed even when the returned
    /// [`SslConnection`] itself is moved. In particular, the connection can be built inside an
    /// initializer function that set up the `SslContext` and then be returned from it.
    ///
    /// ```
    /// # type _BoxedDtls<'a, U, C, R> = SslConnection<
    /// #     'a,
    /// #     embedded_mbedtls::udp::UdpContext<U>,
    /// #     C,
    /// #     R,
    /// #     Box<SslContext<'a, embedded_mbedtls::udp::UdpContext<U>, C, R>>,
    /// # >;
    /// # use embedded_nal::{IpAddr, Ipv4Addr, SocketAddr, UdpClientStack};
    /// # use embedded_timers::clock::Clock;
    /// # use rand_core::{CryptoRng, RngCore};
    /// #
    /// #
    /// use embedded_mbedtls::ssl::{SslConnection, SslContext, Preset};
    ///
    /// # fn _setup_ssl_heap<'a, U: UdpClientStack, C: Clock, R: RngCore + CryptoRng>(
    /// #     net_stack: U,
    /// #     clock: &'a C,
    /// #     rng: R,
    /// # ) -> _BoxedDtls<'a, U, C, R> {
    /// #     let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 22);
    /// #
    /// let ctx = SslContext::new_udp_client_side(net_stack, clock, rng, server_addr);
    /// let connection = SslConnection::new_dtls_client_heap_context(ctx, Preset::Default).unwrap();
    ///
    /// // now the connection can be moved freely:
    /// return connection
    /// # }
    /// ```
    pub fn new_dtls_client_heap_context(
        ssl_context: SslContext<'a, udp::UdpContext<U>, C, R>,
        preset: Preset,
    ) -> Result<Self, Error> {
        let boxed_ctx = Box::new(ssl_context);
        Self::new_generic_dtls_client(boxed_ctx, preset)
    }
}

impl<
        'a,
        U: UdpClientStack,
        C: Clock,
        R: RngCore + CryptoRng,
        CTX: DerefMut<Target = SslContext<'a, udp::UdpContext<U>, C, R>>,
    > SslConnection<'a, udp::UdpContext<U>, C, R, CTX>
{
    /// Create a new DTLS `SslConnection` from an [`SslContext`]
    ///
    /// _Note_: Internal function which must __not__ be made public.
    /// Here the `ssl_context` is only constrained by `DerefMut<Target = SslContext<...>>`,
    /// which would make it possible to pass a type that implements `DerefMut` itself without
    /// guaranteeing that the `SslContext` is never moved. That guarantee is crucial since the C
    /// parts hold raw pointers to members of [`SslContext`]. The public constructors
    /// `new_dtls_client` (`&mut SslContext`) and `new_dtls_client_heap_context`
    /// (`Box<SslContext>`) only call it with types that keep the context in place.
    ///
    /// # Errors
    ///
    /// Returns the converted Mbed TLS error code if `mbedtls_ssl_config_defaults` or
    /// `mbedtls_ssl_setup` fails.
    fn new_generic_dtls_client(ssl_context: CTX, preset: Preset) -> Result<Self, Error> {
        use embedded_mbedtls_sys::{MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_DATAGRAM};

        let mut context = mbedtls_ssl_context::default();
        // SAFETY: `context` is a valid local that nothing else references yet.
        unsafe { mbedtls_ssl_init(&mut context) };

        // Wrap both contexts immediately: from here on, every early `return Err(..)` drops
        // `this`, whose `Drop` impl calls `mbedtls_ssl_free` on the session context.
        let mut this = SslConnection {
            mbedtls_ctx: context,
            ssl_ctx: ssl_context,
        };

        // SAFETY: `config` was initialized by the `SslContext` constructor and is exclusively
        // borrowed through `this.ssl_ctx`.
        let ret = unsafe {
            embedded_mbedtls_sys::mbedtls_ssl_config_defaults(
                &mut this.ssl_ctx.config as *mut mbedtls_ssl_config,
                MBEDTLS_SSL_IS_CLIENT as i32,
                MBEDTLS_SSL_TRANSPORT_DATAGRAM as i32,
                preset.into(),
            )
        };
        if ret < 0 {
            return Err(ret.into());
        }

        // SAFETY: the RNG pointer targets `csrng` inside the `SslContext`, which stays at a
        // fixed address for as long as this connection exists (see the note above), and the
        // callback is instantiated with the matching type `R`.
        unsafe {
            embedded_mbedtls_sys::mbedtls_ssl_conf_rng(
                &mut this.ssl_ctx.config,
                Some(rng_try_fill_bytes_callback_fn::<R>),
                &mut this.ssl_ctx.csrng as *mut R as *mut c_void,
            );
        }

        // SAFETY: both the session and the configuration are initialized; the session keeps a
        // pointer to the configuration, which lives in the pinned `SslContext`.
        let ret = unsafe {
            embedded_mbedtls_sys::mbedtls_ssl_setup(&mut this.mbedtls_ctx, &this.ssl_ctx.config)
        };
        if ret < 0 {
            return Err(ret.into());
        }

        // SAFETY: the BIO and timer pointers target fields of the pinned `SslContext`, and each
        // callback is instantiated with the type its context pointer actually has.
        unsafe {
            embedded_mbedtls_sys::mbedtls_ssl_set_bio(
                &mut this.mbedtls_ctx,
                &mut this.ssl_ctx.net_context as *mut udp::UdpContext<U> as *mut c_void,
                Some(crate::udp::udp_send::<U>),
                Some(crate::udp::udp_recv::<U>),
                // No `recv_timeout` callback is registered; Mbed TLS then uses `udp_recv`.
                None,
            );
            embedded_mbedtls_sys::mbedtls_ssl_set_timer_cb(
                &mut this.mbedtls_ctx,
                &mut this.ssl_ctx.timer_context as *mut timing::MbedtlsTimer<C> as *mut c_void,
                Some(timing::set_timer::<C>),
                Some(timing::get_timer::<C>),
            );
        }

        Ok(this)
    }

    /// Set retransmit timeout values for the DTLS handshake
    ///
    /// This method lives in the `impl SslConnection` block that uses the
    /// [`UdpContext`](udp::UdpContext) as net context. This block is DTLS-specific, i.e. the
    /// method is only available on DTLS connections because it would have no effect on a TLS
    /// connection.
    ///
    /// The Mbed TLS documentation says the following about choosing timeout values:
    ///
    /// > Default values are from RFC 6347 section 4.2.4.1.
    /// >
    /// > The ‘min’ value should typically be slightly above the expected round-trip time to your
    /// > peer, plus whatever time it takes for the peer to process the message.
    /// > For example, if your RTT is about 600ms and your peer needs up to 1s
    /// > to do the cryptographic operations in the handshake,
    /// > then you should set `min` slightly above 1600.
    /// > Lower values of `min` might cause spurious resends which waste network resources,
    /// > while larger value of `min` will increase overall latency on unreliable network links.
    /// >
    /// > Messages are retransmitted up to `log2(ceil(max/min))` times.
    /// > For example, if `min = 1s` and `max = 5s`, the retransmit plan goes:
    /// > send … 1s -> resend … 2s -> resend … 4s -> resend … 5s ->
    /// > give up and return a timeout error.
    ///
    /// Both durations are converted to whole milliseconds, and sub-millisecond parts are
    /// truncated. If a value does not fit into a `u32` number of milliseconds, a huge fallback is
    /// used instead of panicking: `u32::MAX / 4` for `min` and `u32::MAX` for `max`.
    pub fn conf_handshake_timeout(&mut self, min: Duration, max: Duration) {
        // SAFETY: `config` is the initialized configuration owned by the pinned `SslContext`.
        unsafe {
            embedded_mbedtls_sys::mbedtls_ssl_conf_handshake_timeout(
                &mut self.ssl_ctx.config,
                min.as_millis().try_into().unwrap_or(u32::MAX / 4),
                max.as_millis().try_into().unwrap_or(u32::MAX),
            );
        }
    }
}

impl<'a, Net, C, R, CTX> SslConnection<'a, Net, C, R, CTX>
where
    C: Clock,
    R: RngCore + CryptoRng,
    CTX: DerefMut<Target = SslContext<'a, Net, C, R>>,
{
    /// Perform the ssl handshake (non-blocking)
    pub fn handshake(&mut self) -> nb::Result<(), Error> {
        unsafe {
            use embedded_mbedtls_sys::mbedtls_ssl_handshake;
            let ret = mbedtls_ssl_handshake(&mut self.mbedtls_ctx);
            if matches!(ret, MBEDTLS_ERR_SSL_WANT_READ | MBEDTLS_ERR_SSL_WANT_WRITE) {
                return Err(nb::Error::WouldBlock);
            }
            if ret < 0 {
                return Err(nb::Error::Other(ret.into()));
            }
        }

        Ok(())
    }

    /// Configure pre-shared keys (PSKs) and their identities to be used in PSK-based cipher suites.
    ///
    /// Only one PSK can be registered. If no more PSKs can be configured,
    /// [`SslFeatureUnavailable`](crate::error::MbedtlsError::SslFeatureUnavailable) is returned.
    pub fn configure_psk(&mut self, psk: &[u8], psk_identity: &[u8]) -> Result<(), Error> {
        unsafe {
            let ret = embedded_mbedtls_sys::mbedtls_ssl_conf_psk(
                &mut self.ssl_ctx.config,
                psk.as_ptr(),
                psk.len(),
                psk_identity.as_ptr(),
                psk_identity.len(),
            );

            if ret < 0 {
                Err(ret.into())
            } else {
                Ok(())
            }
        }
    }

    /// Try to write application data bytes (non-blocking)
    ///
    /// Returns the number of bytes actually written (may be less than data.len()).
    ///
    /// ## Note:
    ///
    /// If the requested length is greater than the maximum fragment length (either the built-in
    /// limit or the one set or negotiated with the peer), then:
    /// - with TLS, less bytes than requested are written.
    /// - with DTLS, a [`SslBadInputData`](crate::error::MbedtlsError::SslBadInputData) Error is
    /// returned.
    ///
    /// Attempting to write 0 bytes will result in an empty TLS application record being sent.
    pub fn write(&mut self, data: &[u8]) -> nb::Result<usize, Error> {
        use embedded_mbedtls_sys::{
            MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS, MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS,
        };
        unsafe {
            use embedded_mbedtls_sys::mbedtls_ssl_write;
            let ret = mbedtls_ssl_write(&mut self.mbedtls_ctx, data.as_ptr(), data.len());
            if matches!(
                ret,
                MBEDTLS_ERR_SSL_WANT_READ
                    | MBEDTLS_ERR_SSL_WANT_WRITE
                    | MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS
                    | MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS
            ) {
                return Err(nb::Error::WouldBlock);
            }
            if ret < 0 {
                return Err(nb::Error::Other(ret.into()));
            }
            Ok(ret as usize)
        }
    }

    /// Read at most `buf.len()` application data bytes (non-blocking)
    ///
    /// Returns the number of bytes actually read.
    pub fn read(&mut self, buf: &mut [u8]) -> nb::Result<usize, Error> {
        unsafe {
            use embedded_mbedtls_sys::mbedtls_ssl_read;
            use embedded_mbedtls_sys::{
                MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS, MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS,
            };

            let ret = mbedtls_ssl_read(&mut self.mbedtls_ctx, buf.as_mut_ptr(), buf.len());
            if matches!(
                ret,
                MBEDTLS_ERR_SSL_WANT_READ
                    | MBEDTLS_ERR_SSL_WANT_WRITE
                    | MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS
                    | MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS
            ) {
                return Err(nb::Error::WouldBlock);
            }
            let len = match ret {
                x if x >= 0 => x,
                e => {
                    return Err(nb::Error::Other(e.into()));
                }
            };

            Ok(len as usize)
        }
    }

    /// Notify the peer that the connection is being closed (non-blocking)
    ///
    /// On error, the connection has to be reset to be used again.
    pub fn close_notify(&mut self) -> nb::Result<(), Error> {
        let ret = unsafe { embedded_mbedtls_sys::mbedtls_ssl_close_notify(&mut self.mbedtls_ctx) };

        if ret == 0 {
            return Ok(());
        }
        if matches!(ret, MBEDTLS_ERR_SSL_WANT_READ | MBEDTLS_ERR_SSL_WANT_WRITE) {
            return Err(nb::Error::WouldBlock);
        }
        Err(nb::Error::Other(ret.into()))
    }

    /// Reset an already initialized SSL connection for re-use
    pub fn session_reset(&mut self) -> Result<(), Error> {
        let ret = unsafe { embedded_mbedtls_sys::mbedtls_ssl_session_reset(&mut self.mbedtls_ctx) };
        if ret < 0 {
            Err(ret.into())
        } else {
            Ok(())
        }
    }
}

impl<'a, Net, C, R, CTX> Drop for SslConnection<'a, Net, C, R, CTX>
where
    C: Clock,
    R: RngCore + CryptoRng,
    CTX: DerefMut<Target = SslContext<'a, Net, C, R>>,
{
    fn drop(&mut self) {
        unsafe {
            embedded_mbedtls_sys::mbedtls_ssl_free(&mut self.mbedtls_ctx);
        }
    }
}

/// Cryptography profile preset which is used to configure an [`SslConnection`]
///
/// Converted into the matching `MBEDTLS_SSL_PRESET_*` value via `From<Preset> for cty::c_int`.
/// Besides `Debug`, `Clone` and `Copy`, it derives `PartialEq`, `Eq` and `Hash`, so callers
/// can compare presets and use them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Preset {
    /// Default cryptography profile
    Default,
    /// "NSA Suite B Cryptography" profile
    ///
    /// See [RFC 6460](https://datatracker.ietf.org/doc/html/rfc6460) for more information.
    SuiteB,
}

impl From<Preset> for cty::c_int {
    fn from(value: Preset) -> Self {
        match value {
            Preset::Default => embedded_mbedtls_sys::MBEDTLS_SSL_PRESET_DEFAULT as cty::c_int,
            Preset::SuiteB => embedded_mbedtls_sys::MBEDTLS_SSL_PRESET_SUITEB as cty::c_int,
        }
    }
}

#[cfg(test)]
mod test {
    use embedded_nal::{IpAddr, Ipv4Addr, SocketAddr, UdpClientStack};
    use embedded_timers::clock::Clock;
    use rand_core::{CryptoRng, RngCore};

    use crate::udp;

    use super::{SslConnection, SslContext};

    /// Compile-time check: the heapless DTLS setup type-checks for generic stacks and RNGs
    fn _setup_ssl_stack<U: UdpClientStack, R: RngCore + CryptoRng>(
        net_stack: U,
        clock: &impl Clock,
        rng: R,
    ) {
        let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 22);

        let mut ctx = SslContext::new_udp_client_side(net_stack, clock, rng, server_addr);
        let _connection = SslConnection::new_dtls_client(&mut ctx, super::Preset::Default).unwrap();

        // Now the connection is ready to use!
    }

    /// DTLS connection owning a boxed context.
    ///
    /// Gated like its only user, so builds without `alloc` never need `Box` to resolve.
    #[cfg(feature = "alloc")]
    type _BoxedDtls<'a, U, C, R> =
        SslConnection<'a, udp::UdpContext<U>, C, R, Box<SslContext<'a, udp::UdpContext<U>, C, R>>>;

    /// Compile-time check: a boxed-context connection can be returned from an initializer
    #[cfg(feature = "alloc")]
    fn _setup_ssl_heap<'a, U: UdpClientStack, C: Clock, R: RngCore + CryptoRng>(
        net_stack: U,
        clock: &'a C,
        rng: R,
    ) -> _BoxedDtls<'a, U, C, R> {
        let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 22);

        let ctx = SslContext::new_udp_client_side(net_stack, clock, rng, server_addr);
        SslConnection::new_dtls_client_heap_context(ctx, super::Preset::Default).unwrap()
    }
}