embedded_mbedtls/
udp.rs

1// Copyright Open Logistics Foundation
2//
3// Licensed under the Open Logistics Foundation License 1.3.
4// For details on the licensing terms, see the LICENSE file.
5// SPDX-License-Identifier: OLFL-1.3
6
7//! UDP module which defines the [`UdpContext`] which is used internally in the
8//! [`SslContext`](crate::ssl::SslContext) for DTLS connections
9
10use core::ffi::{c_int, c_uchar, c_void};
11use core::net::SocketAddr;
12use embedded_nal::{nb::Error, UdpClientStack};
13
14/// Udp receive (non-blocking) C callback
15///
16/// param ctx: *mut [UdpContext<T>]
17///
18/// Returns [`MBEDTLS_ERR_SSL_WANT_READ`](embedded_mbedtls_sys::MBEDTLS_ERR_SSL_WANT_READ) if the
19/// network stack isn't available.
20pub(crate) unsafe extern "C" fn udp_recv<T: UdpClientStack>(
21    ctx: *mut c_void,
22    buf: *mut c_uchar,
23    len: usize,
24) -> c_int {
25    let buf_slice = unsafe { core::slice::from_raw_parts_mut(buf, len) };
26    let context = unsafe { (ctx as *mut UdpContext<T>).as_mut() };
27    let Some(context) = context else {
28        log::error!("failed to retrieve UdpContext (null pointer exception)!");
29        return -1;
30    };
31
32    let (socket, stack) = match context.get_socket_and_stack() {
33        Ok(ret) => ret,
34        Err(e) => {
35            log::error!("failed to create udp socket: {e:?}");
36            return -1;
37        }
38    };
39
40    let res = stack.receive(socket, buf_slice);
41    match res {
42        Ok((n, _)) => n as c_int,
43        Err(Error::WouldBlock) => embedded_mbedtls_sys::MBEDTLS_ERR_SSL_WANT_READ,
44        Err(Error::Other(e)) => {
45            log::warn!("udp receive failed: {e:?}");
46            -1
47        }
48    }
49}
50
51/// Udp send C callback (non-blocking)
52///
53/// param ctx: *mut [UdpContext<T>]
54pub(crate) unsafe extern "C" fn udp_send<T: UdpClientStack>(
55    ctx: *mut c_void,
56    buf: *const c_uchar,
57    len: usize,
58) -> c_int {
59    let buf_slice = unsafe { core::slice::from_raw_parts(buf, len) };
60    let context = unsafe { (ctx as *mut UdpContext<T>).as_mut() };
61    let Some(context) = context else {
62        log::error!("failed to retrieve UdpContext (null pointer exception)!");
63        return -1;
64    };
65
66    let (socket, stack) = match context.get_socket_and_stack() {
67        Ok(ret) => ret,
68        Err(e) => {
69            log::error!("failed to create udp socket: {e:?}");
70            return -1;
71        }
72    };
73
74    let res = stack.send(socket, buf_slice);
75
76    if let Err(e) = res {
77        match e {
78            embedded_nal::nb::Error::Other(e) => {
79                log::warn!("udp send failed: {:?}", e);
80                return -1;
81            }
82            embedded_nal::nb::Error::WouldBlock => {
83                return embedded_mbedtls_sys::MBEDTLS_ERR_SSL_WANT_WRITE
84            }
85        }
86    }
87
88    len as c_int // assume that all bytes have been send
89}
90
91/// UDP network context
92pub struct UdpContext<U: UdpClientStack> {
93    socket: Option<U::UdpSocket>,
94    stack: U,
95    server_addr: SocketAddr,
96}
97impl<U: UdpClientStack> UdpContext<U> {
98    pub(crate) fn new(net_stack: U, server_addr: SocketAddr) -> Self {
99        Self {
100            socket: None,
101            stack: net_stack,
102            server_addr,
103        }
104    }
105    /// Borrow both socket and stack, creating a new socket if necessary
106    ///
107    /// The socket is connected to the server address.
108    pub(crate) fn get_socket_and_stack(&mut self) -> Result<(&mut U::UdpSocket, &mut U), U::Error> {
109        let socket = if self.socket.is_none() {
110            let s = self.socket.insert(self.stack.socket()?);
111            self.stack.connect(s, self.server_addr)?;
112            s
113        } else {
114            self.socket.as_mut().unwrap()
115        };
116        Ok((socket, &mut self.stack))
117    }
118}