psp_net/socket/tls.rs

1#![allow(clippy::module_name_repetitions)]
2
3use core::fmt::Debug;
4
5use alloc::string::String;
6use embedded_io::{ErrorType, Read, Write};
7use embedded_tls::{blocking::TlsConnection, Aes128GcmSha256, NoVerify, TlsConfig, TlsContext};
8
9use rand::SeedableRng;
10use rand_chacha::ChaCha20Rng;
11use regex::Regex;
12
13use crate::{
14    traits::io::{EasySocket, Open, OptionType},
15    types::TlsSocketOptions,
16};
17
18use super::{
19    state::{Connected, NotReady, Ready, SocketState},
20    tcp::TcpSocket,
21};
22
23lazy_static::lazy_static! {
24    static ref REGEX: Regex = Regex::new("\r|\0").unwrap();
25}
26
/// Largest TLS fragment size in bytes: 2^14, i.e. `16_384`.
///
/// This is also the length of the buffers produced by `TlsSocket::new_buffer`.
pub const MAX_FRAGMENT_LENGTH: u16 = 1 << 14;
29
30/// A TLS socket.
31/// This is a wrapper around a [`TcpSocket`] that provides a TLS connection.
32///
33/// # Notes
34/// For the Debug trait a dummy implementation is provided.
35pub struct TlsSocket<'a, S: SocketState = NotReady> {
36    /// The TLS connection
37    tls_connection: TlsConnection<'a, TcpSocket<Connected>, Aes128GcmSha256>,
38    /// The TLS config
39    tls_config: TlsConfig<'a, Aes128GcmSha256>,
40    /// marker for the socket state
41    _marker: core::marker::PhantomData<S>,
42}
43
44impl Debug for TlsSocket<'_> {
45    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
46        f.debug_struct("TlsSocket").finish()
47    }
48}
49
50impl<'a> TlsSocket<'_> {
51    /// Create a new TLS socket.
52    /// This will create a new TLS connection using the provided [`TcpSocket`].
53    ///
54    /// # Parameters
55    /// - `socket`: The TCP socket to use for the TLS connection
56    /// - `record_read_buf`: A buffer to use for reading records
57    /// - `record_write_buf`: A buffer to use for writing records
58    ///
59    /// # Returns
60    /// A new TLS socket in the [`NotReady`] state. Use [`TlsSocket::open()`] to get a
61    /// ready socket.
62    ///
63    /// # Example
64    /// ```no_run
65    /// let mut read_buf = TlsSocket::new_buffer();
66    /// let mut write_buf = TlsSocket::new_buffer();
67    /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf);
68    /// let tls_socket = tls_socket.open(&options)?;
69    /// ```
70    ///
71    /// # Notes
72    /// In most cases you can pass `None` for the `cert` parameter.
73    pub fn new(
74        socket: TcpSocket<Connected>,
75        record_read_buf: &'a mut [u8],
76        record_write_buf: &'a mut [u8],
77    ) -> TlsSocket<'a, NotReady> {
78        let tls_config: TlsConfig<'_, Aes128GcmSha256> = TlsConfig::new();
79
80        let tls_connection: TlsConnection<TcpSocket<Connected>, Aes128GcmSha256> =
81            TlsConnection::new(socket, record_read_buf, record_write_buf);
82        TlsSocket {
83            tls_connection,
84            tls_config,
85            _marker: core::marker::PhantomData,
86        }
87    }
88
89    /// Create a new buffer.
90    /// It is a utility function to create the read/write buffer to pass to [`Self::new()`].
91    ///
92    /// # Returns
93    /// A new buffer of [`MAX_FRAGMENT_LENGTH`] (`16_384`) bytes.
94    ///
95    /// # Example
96    /// ```no_run
97    /// let mut read_buf = TlsSocket::new_buffer();
98    /// let mut write_buf = TlsSocket::new_buffer();
99    /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf);
100    /// ```
101    #[must_use]
102    pub fn new_buffer() -> [u8; MAX_FRAGMENT_LENGTH as usize] {
103        [0; 16_384]
104    }
105}
106
107impl TlsSocket<'_, Ready> {
108    /// Write all data to the TLS connection.
109    ///
110    /// Writes until all data is written or an error occurs.
111    ///
112    /// # Parameters
113    /// - `buf`: The buffer containing the data to be sent.
114    ///
115    /// # Returns
116    /// - `Ok(())` if the write was successful.
117    /// - `Err(TlsError)` if the write was unsuccessful.
118    ///
119    /// # Errors
120    /// [`embedded_tls::TlsError`] if the write fails.
121    pub fn write_all(&mut self, buf: &[u8]) -> Result<(), embedded_tls::TlsError> {
122        self.tls_connection.write_all(buf)
123    }
124
125    /// Read data from the TLS connection and converts it to a [`String`].
126    ///
127    /// # Returns
128    /// - `Ok(String)` if the read was successful.
129    /// - `Err(TlsError)` if the read was unsuccessful.
130    ///
131    /// # Errors
132    /// [`embedded_tls::TlsError`] if the read fails.
133    pub fn read_string(&mut self) -> Result<String, embedded_tls::TlsError> {
134        let mut buf = TlsSocket::new_buffer();
135        let _ = self.read(&mut buf)?;
136
137        let text = String::from_utf8_lossy(&buf);
138        let text = REGEX.replace_all(&text, "");
139        Ok(text.into_owned())
140    }
141}
142
143impl<S: SocketState> ErrorType for TlsSocket<'_, S> {
144    /// The error type for the TLS socket.
145    type Error = embedded_tls::TlsError;
146}
147
148impl<S: SocketState> OptionType for TlsSocket<'_, S> {
149    /// The options type for the TLS socket.
150    type Options<'b> = TlsSocketOptions<'b>;
151}
152
153impl<'a, 'b> Open<'a, 'b> for TlsSocket<'b, NotReady>
154where
155    'a: 'b,
156{
157    type Return = TlsSocket<'a, Ready>;
158    /// Open the TLS connection.
159    ///
160    /// # Parameters
161    /// - `options`: The TLS options, of type [`TlsSocketOptions`].
162    ///
163    /// # Returns
164    /// A new [`TlsSocket<Ready>`], or an error if opening fails.
165    ///
166    /// # Example
167    /// ```no_run
168    /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf);
169    /// let tls_socket = tls_socket.open(&options)?;
170    /// ```
171    ///
172    /// # Notes
173    /// The function takes ownership of the socket ([`TcpSocket<NotReady>`]), and returns a new socket of type [`TlsSocket<Ready>`].
174    /// Therefore, you must assign the returned socket to a variable in order to use it.
175    fn open(self, options: &'b Self::Options<'_>) -> Result<Self::Return, embedded_tls::TlsError>
176    where
177        'b: 'a,
178    {
179        let mut rng = ChaCha20Rng::seed_from_u64(options.seed());
180
181        let mut tls_socket: TlsSocket<Ready> = TlsSocket {
182            tls_connection: self.tls_connection,
183            tls_config: self.tls_config,
184            _marker: core::marker::PhantomData,
185        };
186
187        tls_socket.tls_config = tls_socket
188            .tls_config
189            .with_server_name(options.server_name());
190
191        if options.rsa_signatures_enabled() {
192            tls_socket.tls_config = tls_socket.tls_config.enable_rsa_signatures();
193        }
194
195        if options.reset_max_fragment_length() {
196            tls_socket.tls_config = tls_socket.tls_config.reset_max_fragment_length();
197        }
198
199        if let Some(cert) = options.cert() {
200            tls_socket.tls_config = tls_socket.tls_config.with_cert(cert.clone());
201        }
202
203        if let Some(ca) = options.ca() {
204            tls_socket.tls_config = tls_socket.tls_config.with_ca(ca.clone());
205        }
206
207        let tls_context = TlsContext::new(&tls_socket.tls_config, &mut rng);
208        tls_socket
209            .tls_connection
210            .open::<ChaCha20Rng, NoVerify>(tls_context)?;
211
212        Ok(tls_socket)
213    }
214}
215
216impl embedded_io::Read for TlsSocket<'_, Ready> {
217    /// Read data from the TLS connection.
218    ///
219    /// # Parameters
220    /// - `buf`: The buffer where the data will be stored.
221    ///
222    /// # Returns
223    /// - `Ok(usize)` if the read was successful. The number of bytes read
224    /// - `Err(SocketError)` if the read was unsuccessful.
225    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
226        self.tls_connection.read(buf)
227    }
228}
229
230impl embedded_io::Write for TlsSocket<'_, Ready> {
231    /// Write data to the TLS connection.
232    ///
233    /// # Parameters
234    /// - `buf`: The buffer containing the data to be sent.
235    ///
236    /// # Returns
237    /// - `Ok(usize)` if the write was successful. The number of bytes written
238    /// - `Err(SocketError)` if the write was unsuccessful.
239    fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
240        self.tls_connection.write(buf)
241    }
242
243    /// Flush the TLS connection.
244    fn flush(&mut self) -> Result<(), Self::Error> {
245        self.tls_connection.flush()
246    }
247}
248
249impl EasySocket for TlsSocket<'_, Ready> {}