psp_net/socket/tls.rs
1#![allow(clippy::module_name_repetitions)]
2
3use core::fmt::Debug;
4
5use alloc::string::String;
6use embedded_io::{ErrorType, Read, Write};
7use embedded_tls::{blocking::TlsConnection, Aes128GcmSha256, NoVerify, TlsConfig, TlsContext};
8
9use rand::SeedableRng;
10use rand_chacha::ChaCha20Rng;
11use regex::Regex;
12
13use crate::{
14 traits::io::{EasySocket, Open, OptionType},
15 types::TlsSocketOptions,
16};
17
18use super::{
19 state::{Connected, NotReady, Ready, SocketState},
20 tcp::TcpSocket,
21};
22
23lazy_static::lazy_static! {
24 static ref REGEX: Regex = Regex::new("\r|\0").unwrap();
25}
26
/// Maximum TLS fragment length: 2^14 (`16_384`) bytes.
///
/// This is the size of the record buffers produced by [`TlsSocket::new_buffer`].
pub const MAX_FRAGMENT_LENGTH: u16 = 1 << 14;
29
30/// A TLS socket.
31/// This is a wrapper around a [`TcpSocket`] that provides a TLS connection.
32///
33/// # Notes
34/// For the Debug trait a dummy implementation is provided.
35pub struct TlsSocket<'a, S: SocketState = NotReady> {
36 /// The TLS connection
37 tls_connection: TlsConnection<'a, TcpSocket<Connected>, Aes128GcmSha256>,
38 /// The TLS config
39 tls_config: TlsConfig<'a, Aes128GcmSha256>,
40 /// marker for the socket state
41 _marker: core::marker::PhantomData<S>,
42}
43
44impl Debug for TlsSocket<'_> {
45 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
46 f.debug_struct("TlsSocket").finish()
47 }
48}
49
50impl<'a> TlsSocket<'_> {
51 /// Create a new TLS socket.
52 /// This will create a new TLS connection using the provided [`TcpSocket`].
53 ///
54 /// # Parameters
55 /// - `socket`: The TCP socket to use for the TLS connection
56 /// - `record_read_buf`: A buffer to use for reading records
57 /// - `record_write_buf`: A buffer to use for writing records
58 ///
59 /// # Returns
60 /// A new TLS socket in the [`NotReady`] state. Use [`TlsSocket::open()`] to get a
61 /// ready socket.
62 ///
63 /// # Example
64 /// ```no_run
65 /// let mut read_buf = TlsSocket::new_buffer();
66 /// let mut write_buf = TlsSocket::new_buffer();
67 /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf);
68 /// let tls_socket = tls_socket.open(&options)?;
69 /// ```
70 ///
71 /// # Notes
72 /// In most cases you can pass `None` for the `cert` parameter.
73 pub fn new(
74 socket: TcpSocket<Connected>,
75 record_read_buf: &'a mut [u8],
76 record_write_buf: &'a mut [u8],
77 ) -> TlsSocket<'a, NotReady> {
78 let tls_config: TlsConfig<'_, Aes128GcmSha256> = TlsConfig::new();
79
80 let tls_connection: TlsConnection<TcpSocket<Connected>, Aes128GcmSha256> =
81 TlsConnection::new(socket, record_read_buf, record_write_buf);
82 TlsSocket {
83 tls_connection,
84 tls_config,
85 _marker: core::marker::PhantomData,
86 }
87 }
88
89 /// Create a new buffer.
90 /// It is a utility function to create the read/write buffer to pass to [`Self::new()`].
91 ///
92 /// # Returns
93 /// A new buffer of [`MAX_FRAGMENT_LENGTH`] (`16_384`) bytes.
94 ///
95 /// # Example
96 /// ```no_run
97 /// let mut read_buf = TlsSocket::new_buffer();
98 /// let mut write_buf = TlsSocket::new_buffer();
99 /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf);
100 /// ```
101 #[must_use]
102 pub fn new_buffer() -> [u8; MAX_FRAGMENT_LENGTH as usize] {
103 [0; 16_384]
104 }
105}
106
107impl TlsSocket<'_, Ready> {
108 /// Write all data to the TLS connection.
109 ///
110 /// Writes until all data is written or an error occurs.
111 ///
112 /// # Parameters
113 /// - `buf`: The buffer containing the data to be sent.
114 ///
115 /// # Returns
116 /// - `Ok(())` if the write was successful.
117 /// - `Err(TlsError)` if the write was unsuccessful.
118 ///
119 /// # Errors
120 /// [`embedded_tls::TlsError`] if the write fails.
121 pub fn write_all(&mut self, buf: &[u8]) -> Result<(), embedded_tls::TlsError> {
122 self.tls_connection.write_all(buf)
123 }
124
125 /// Read data from the TLS connection and converts it to a [`String`].
126 ///
127 /// # Returns
128 /// - `Ok(String)` if the read was successful.
129 /// - `Err(TlsError)` if the read was unsuccessful.
130 ///
131 /// # Errors
132 /// [`embedded_tls::TlsError`] if the read fails.
133 pub fn read_string(&mut self) -> Result<String, embedded_tls::TlsError> {
134 let mut buf = TlsSocket::new_buffer();
135 let _ = self.read(&mut buf)?;
136
137 let text = String::from_utf8_lossy(&buf);
138 let text = REGEX.replace_all(&text, "");
139 Ok(text.into_owned())
140 }
141}
142
143impl<S: SocketState> ErrorType for TlsSocket<'_, S> {
144 /// The error type for the TLS socket.
145 type Error = embedded_tls::TlsError;
146}
147
148impl<S: SocketState> OptionType for TlsSocket<'_, S> {
149 /// The options type for the TLS socket.
150 type Options<'b> = TlsSocketOptions<'b>;
151}
152
153impl<'a, 'b> Open<'a, 'b> for TlsSocket<'b, NotReady>
154where
155 'a: 'b,
156{
157 type Return = TlsSocket<'a, Ready>;
158 /// Open the TLS connection.
159 ///
160 /// # Parameters
161 /// - `options`: The TLS options, of type [`TlsSocketOptions`].
162 ///
163 /// # Returns
164 /// A new [`TlsSocket<Ready>`], or an error if opening fails.
165 ///
166 /// # Example
167 /// ```no_run
168 /// let tls_socket = TlsSocket::new(tcp_socket, &mut read_buf, &mut write_buf);
169 /// let tls_socket = tls_socket.open(&options)?;
170 /// ```
171 ///
172 /// # Notes
173 /// The function takes ownership of the socket ([`TcpSocket<NotReady>`]), and returns a new socket of type [`TlsSocket<Ready>`].
174 /// Therefore, you must assign the returned socket to a variable in order to use it.
175 fn open(self, options: &'b Self::Options<'_>) -> Result<Self::Return, embedded_tls::TlsError>
176 where
177 'b: 'a,
178 {
179 let mut rng = ChaCha20Rng::seed_from_u64(options.seed());
180
181 let mut tls_socket: TlsSocket<Ready> = TlsSocket {
182 tls_connection: self.tls_connection,
183 tls_config: self.tls_config,
184 _marker: core::marker::PhantomData,
185 };
186
187 tls_socket.tls_config = tls_socket
188 .tls_config
189 .with_server_name(options.server_name());
190
191 if options.rsa_signatures_enabled() {
192 tls_socket.tls_config = tls_socket.tls_config.enable_rsa_signatures();
193 }
194
195 if options.reset_max_fragment_length() {
196 tls_socket.tls_config = tls_socket.tls_config.reset_max_fragment_length();
197 }
198
199 if let Some(cert) = options.cert() {
200 tls_socket.tls_config = tls_socket.tls_config.with_cert(cert.clone());
201 }
202
203 if let Some(ca) = options.ca() {
204 tls_socket.tls_config = tls_socket.tls_config.with_ca(ca.clone());
205 }
206
207 let tls_context = TlsContext::new(&tls_socket.tls_config, &mut rng);
208 tls_socket
209 .tls_connection
210 .open::<ChaCha20Rng, NoVerify>(tls_context)?;
211
212 Ok(tls_socket)
213 }
214}
215
216impl embedded_io::Read for TlsSocket<'_, Ready> {
217 /// Read data from the TLS connection.
218 ///
219 /// # Parameters
220 /// - `buf`: The buffer where the data will be stored.
221 ///
222 /// # Returns
223 /// - `Ok(usize)` if the read was successful. The number of bytes read
224 /// - `Err(SocketError)` if the read was unsuccessful.
225 fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
226 self.tls_connection.read(buf)
227 }
228}
229
230impl embedded_io::Write for TlsSocket<'_, Ready> {
231 /// Write data to the TLS connection.
232 ///
233 /// # Parameters
234 /// - `buf`: The buffer containing the data to be sent.
235 ///
236 /// # Returns
237 /// - `Ok(usize)` if the write was successful. The number of bytes written
238 /// - `Err(SocketError)` if the write was unsuccessful.
239 fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
240 self.tls_connection.write(buf)
241 }
242
243 /// Flush the TLS connection.
244 fn flush(&mut self) -> Result<(), Self::Error> {
245 self.tls_connection.flush()
246 }
247}
248
249impl EasySocket for TlsSocket<'_, Ready> {}