1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// Copyright Open Logistics Foundation
//
// Licensed under the Open Logistics Foundation License 1.3.
// For details on the licensing terms, see the LICENSE file.
// SPDX-License-Identifier: OLFL-1.3

//! UDP transport for DTLS: defines the [`UdpContext`] used internally by
//! [`SslContext`](crate::ssl::SslContext) for DTLS connections, together with the C
//! send/receive callbacks that operate on it.

use cty::{c_int, c_uchar, c_void};
use embedded_nal::{nb::Error, SocketAddr, UdpClientStack};

/// Udp receive (non-blocking) C callback
///
/// param ctx: *mut [UdpContext<T>]
///
/// Returns [`MBEDTLS_ERR_SSL_WANT_READ`](embedded_mbedtls_sys::MBEDTLS_ERR_SSL_WANT_READ) if the
/// network stack isn't available.
pub(crate) unsafe extern "C" fn udp_recv<T: UdpClientStack>(
    ctx: *mut c_void,
    buf: *mut c_uchar,
    len: usize,
) -> c_int {
    let buf_slice = unsafe { core::slice::from_raw_parts_mut(buf, len) };
    let context = unsafe { (ctx as *mut UdpContext<T>).as_mut() };
    let Some(context) = context else {
        log::error!("failed to retrieve UdpContext (null pointer exception)!");
        return -1;
    };

    let (socket, stack) = match context.get_socket_and_stack() {
        Ok(ret) => ret,
        Err(e) => {
            log::error!("failed to create udp socket: {e:?}");
            return -1;
        }
    };

    let res = stack.receive(socket, buf_slice);
    match res {
        Ok((n, _)) => n as c_int,
        Err(Error::WouldBlock) => embedded_mbedtls_sys::MBEDTLS_ERR_SSL_WANT_READ,
        Err(Error::Other(e)) => {
            log::warn!("udp receive failed: {e:?}");
            -1
        }
    }
}

/// Udp send C callback (non-blocking)
///
/// param ctx: *mut [UdpContext<T>]
pub(crate) unsafe extern "C" fn udp_send<T: UdpClientStack>(
    ctx: *mut c_void,
    buf: *const c_uchar,
    len: usize,
) -> c_int {
    let buf_slice = unsafe { core::slice::from_raw_parts(buf, len) };
    let context = unsafe { (ctx as *mut UdpContext<T>).as_mut() };
    let Some(context) = context else {
        log::error!("failed to retrieve UdpContext (null pointer exception)!");
        return -1;
    };

    let (socket, stack) = match context.get_socket_and_stack() {
        Ok(ret) => ret,
        Err(e) => {
            log::error!("failed to create udp socket: {e:?}");
            return -1;
        }
    };

    let res = stack.send(socket, buf_slice);

    if let Err(e) = res {
        match e {
            embedded_nal::nb::Error::Other(e) => {
                log::warn!("udp send failed: {:?}", e);
                return -1;
            }
            embedded_nal::nb::Error::WouldBlock => {
                return embedded_mbedtls_sys::MBEDTLS_ERR_SSL_WANT_WRITE
            }
        }
    }

    len as c_int // assume that all bytes have been send
}

/// UDP network context handed to the C send/receive callbacks.
///
/// Owns the network stack and a socket to `server_addr`. The socket starts out as `None` and
/// is only opened when first needed.
pub struct UdpContext<U>
where
    U: UdpClientStack,
{
    /// Socket connected to `server_addr`; `None` until first use.
    socket: Option<U::UdpSocket>,
    /// Network stack used to open the socket and perform I/O on it.
    stack: U,
    /// Remote peer address the socket is connected to.
    server_addr: SocketAddr,
}
impl<U: UdpClientStack> UdpContext<U> {
    pub(crate) fn new(net_stack: U, server_addr: SocketAddr) -> Self {
        Self {
            socket: None,
            stack: net_stack,
            server_addr,
        }
    }
    /// Borrow both socket and stack, creating a new socket if necessary
    ///
    /// The socket is connected to the server address.
    pub(crate) fn get_socket_and_stack(&mut self) -> Result<(&mut U::UdpSocket, &mut U), U::Error> {
        let socket = if self.socket.is_none() {
            let s = self.socket.insert(self.stack.socket()?);
            self.stack.connect(s, self.server_addr)?;
            s
        } else {
            self.socket.as_mut().unwrap()
        };
        Ok((socket, &mut self.stack))
    }
}