embedded_tls/
asynch.rs

1use core::sync::atomic::{AtomicBool, Ordering};
2
3use crate::TlsError;
4use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
5use crate::common::decrypted_read_handler::DecryptedReadHandler;
6use crate::connection::{Handshake, State, decrypt_record};
7use crate::flush_policy::FlushPolicy;
8use crate::key_schedule::KeySchedule;
9use crate::key_schedule::{ReadKeySchedule, WriteKeySchedule};
10use crate::read_buffer::ReadBuffer;
11use crate::record::{ClientRecord, ClientRecordHeader};
12use crate::record_reader::{RecordReader, RecordReaderBorrowMut};
13use crate::write_buffer::{WriteBuffer, WriteBufferBorrowMut};
14use embedded_io::Error as _;
15use embedded_io::ErrorType;
16use embedded_io_async::{BufRead, Read as AsyncRead, Write as AsyncWrite};
17
18pub use crate::config::*;
19
/// Type representing an async TLS connection. An instance of this type can
/// be used to establish a TLS connection, write and read encrypted data over this connection,
/// and close it to free up the underlying resources.
pub struct TlsConnection<'a, Socket, CipherSuite>
where
    Socket: AsyncRead + AsyncWrite + 'a,
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Underlying async transport; encrypted records are read from and written to it.
    delegate: Socket,
    /// Set to `true` by `open()` once the handshake reaches `State::ApplicationData`; the
    /// read path hands it to `DecryptedReadHandler`, which may clear it again. Kept as an
    /// `AtomicBool` so the halves returned by `split()` can share it via `&AtomicBool`.
    opened: AtomicBool,
    /// Combined key schedule; `as_split()` / `write_state()` / `read_state()` expose its halves.
    key_schedule: KeySchedule<CipherSuite>,
    /// Reads incoming TLS records into the caller-provided read buffer.
    record_reader: RecordReader<'a>,
    /// Accumulates outgoing data into a record inside the caller-provided write buffer.
    record_write_buf: WriteBuffer<'a>,
    /// Bookkeeping for decrypted, not-yet-consumed application data; read buffers are
    /// created over `record_reader.buf`.
    decrypted: DecryptedBufferInfo,
    /// Decides whether `flush()` also flushes `delegate` after writing a record.
    flush_policy: FlushPolicy,
}
36
37impl<'a, Socket, CipherSuite> TlsConnection<'a, Socket, CipherSuite>
38where
39    Socket: AsyncRead + AsyncWrite + 'a,
40    CipherSuite: TlsCipherSuite + 'static,
41{
42    pub fn is_opened(&mut self) -> bool {
43        *self.opened.get_mut()
44    }
45    /// Create a new TLS connection with the provided context and a async I/O implementation
46    ///
47    /// NOTE: The record read buffer should be sized to fit an encrypted TLS record. The size of this record
48    /// depends on the server configuration, but the maximum allowed value for a TLS record is 16640 bytes,
49    /// which should be a safe value to use.
50    ///
51    /// The write record buffer can be smaller than the read buffer. During writes [`TLS_RECORD_OVERHEAD`] bytes of
52    /// overhead is added per record, so the buffer must at least be this large. Large writes are split into multiple
53    /// records if depending on the size of the write buffer.
54    /// The largest of the two buffers will be used to encode the TLS handshake record, hence either of the
55    /// buffers must at least be large enough to encode a handshake.
56    pub fn new(
57        delegate: Socket,
58        record_read_buf: &'a mut [u8],
59        record_write_buf: &'a mut [u8],
60    ) -> Self {
61        Self {
62            delegate,
63            opened: AtomicBool::new(false),
64            key_schedule: KeySchedule::new(),
65            record_reader: RecordReader::new(record_read_buf),
66            record_write_buf: WriteBuffer::new(record_write_buf),
67            decrypted: DecryptedBufferInfo::default(),
68            flush_policy: FlushPolicy::default(),
69        }
70    }
71
72    /// Returns a reference to the current flush policy.
73    ///
74    /// The flush policy controls whether the underlying transport is flushed
75    /// (via its `flush()` method) after writing a TLS record.
76    #[inline]
77    pub fn flush_policy(&self) -> FlushPolicy {
78        self.flush_policy
79    }
80
81    /// Replace the current flush policy with the provided one.
82    ///
83    /// This sets how and when the connection will call `flush()` on the
84    /// underlying transport after writing records.
85    #[inline]
86    pub fn set_flush_policy(&mut self, policy: FlushPolicy) {
87        self.flush_policy = policy;
88    }
89
90    /// Open a TLS connection, performing the handshake with the configuration provided when
91    /// creating the connection instance.
92    ///
93    /// Returns an error if the handshake does not proceed. If an error occurs, the connection
94    /// instance must be recreated.
95    pub async fn open<Provider>(
96        &mut self,
97        mut context: TlsContext<'_, Provider>,
98    ) -> Result<(), TlsError>
99    where
100        Provider: CryptoProvider<CipherSuite = CipherSuite>,
101    {
102        let mut handshake: Handshake<CipherSuite> = Handshake::new();
103        if let (Ok(verifier), Some(server_name)) = (
104            context.crypto_provider.verifier(),
105            context.config.server_name,
106        ) {
107            verifier.set_hostname_verification(server_name)?;
108        }
109        let mut state = State::ClientHello;
110
111        while state != State::ApplicationData {
112            let next_state = state
113                .process(
114                    &mut self.delegate,
115                    &mut handshake,
116                    &mut self.record_reader,
117                    &mut self.record_write_buf,
118                    &mut self.key_schedule,
119                    context.config,
120                    &mut context.crypto_provider,
121                )
122                .await?;
123            trace!("State {:?} -> {:?}", state, next_state);
124            state = next_state;
125        }
126        *self.opened.get_mut() = true;
127
128        Ok(())
129    }
130
131    /// Encrypt and send the provided slice over the connection. The connection
132    /// must be opened before writing.
133    ///
134    /// The slice may be buffered internally and not written to the connection immediately.
135    /// In this case [`Self::flush()`] should be called to force the currently buffered writes
136    /// to be written to the connection.
137    ///
138    /// Returns the number of bytes buffered/written.
139    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, TlsError> {
140        if self.is_opened() {
141            if !self
142                .record_write_buf
143                .contains(ClientRecordHeader::ApplicationData)
144            {
145                self.flush().await?;
146                self.record_write_buf
147                    .start_record(ClientRecordHeader::ApplicationData)?;
148            }
149
150            let buffered = self.record_write_buf.append(buf);
151
152            if self.record_write_buf.is_full() {
153                self.flush().await?;
154            }
155
156            Ok(buffered)
157        } else {
158            Err(TlsError::MissingHandshake)
159        }
160    }
161
162    /// Force all previously written, buffered bytes to be encoded into a tls record and written
163    /// to the connection.
164    pub async fn flush(&mut self) -> Result<(), TlsError> {
165        if !self.record_write_buf.is_empty() {
166            let key_schedule = self.key_schedule.write_state();
167            let slice = self.record_write_buf.close_record(key_schedule)?;
168
169            self.delegate
170                .write_all(slice)
171                .await
172                .map_err(|e| TlsError::Io(e.kind()))?;
173
174            key_schedule.increment_counter();
175
176            if self.flush_policy.flush_transport() {
177                self.flush_transport().await?;
178            }
179        }
180
181        Ok(())
182    }
183
184    #[inline]
185    async fn flush_transport(&mut self) -> Result<(), TlsError> {
186        self.delegate
187            .flush()
188            .await
189            .map_err(|e| TlsError::Io(e.kind()))
190    }
191
192    fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
193        self.decrypted.create_read_buffer(self.record_reader.buf)
194    }
195
196    /// Read and decrypt data filling the provided slice.
197    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, TlsError> {
198        if buf.is_empty() {
199            return Ok(0);
200        }
201        let mut buffer = self.read_buffered().await?;
202
203        let len = buffer.pop_into(buf);
204        trace!("Copied {} bytes", len);
205
206        Ok(len)
207    }
208
209    /// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
210    pub async fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, TlsError> {
211        if self.is_opened() {
212            while self.decrypted.is_empty() {
213                self.read_application_data().await?;
214            }
215
216            Ok(self.create_read_buffer())
217        } else {
218            Err(TlsError::MissingHandshake)
219        }
220    }
221
222    async fn read_application_data(&mut self) -> Result<(), TlsError> {
223        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
224        let record = self
225            .record_reader
226            .read(&mut self.delegate, self.key_schedule.read_state())
227            .await?;
228
229        let mut handler = DecryptedReadHandler {
230            source_buffer: buf_ptr_range,
231            buffer_info: &mut self.decrypted,
232            is_open: self.opened.get_mut(),
233        };
234        decrypt_record(
235            self.key_schedule.read_state(),
236            record,
237            |_key_schedule, record| handler.handle(record),
238        )?;
239
240        Ok(())
241    }
242
243    /// Close a connection instance, returning the ownership of the config, random generator and the async I/O provider.
244    async fn close_internal(&mut self) -> Result<(), TlsError> {
245        self.flush().await?;
246
247        let is_opened = self.is_opened();
248        let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
249        let slice = self.record_write_buf.write_record(
250            &ClientRecord::close_notify(is_opened),
251            write_key_schedule,
252            Some(read_key_schedule),
253        )?;
254
255        self.delegate
256            .write_all(slice)
257            .await
258            .map_err(|e| TlsError::Io(e.kind()))?;
259
260        self.key_schedule.write_state().increment_counter();
261
262        self.flush_transport().await
263    }
264
265    /// Close a connection instance, returning the ownership of the async I/O provider.
266    pub async fn close(mut self) -> Result<Socket, (Socket, TlsError)> {
267        match self.close_internal().await {
268            Ok(()) => Ok(self.delegate),
269            Err(e) => Err((self.delegate, e)),
270        }
271    }
272
273    pub fn split(
274        &mut self,
275    ) -> (
276        TlsReader<'_, Socket, CipherSuite>,
277        TlsWriter<'_, Socket, CipherSuite>,
278    )
279    where
280        Socket: Clone,
281    {
282        let (wks, rks) = self.key_schedule.as_split();
283
284        let reader = TlsReader {
285            opened: &self.opened,
286            delegate: self.delegate.clone(),
287            key_schedule: rks,
288            record_reader: self.record_reader.reborrow_mut(),
289            decrypted: &mut self.decrypted,
290        };
291        let writer = TlsWriter {
292            opened: &self.opened,
293            delegate: self.delegate.clone(),
294            key_schedule: wks,
295            record_write_buf: self.record_write_buf.reborrow_mut(),
296            flush_policy: self.flush_policy,
297        };
298
299        (reader, writer)
300    }
301}
302
303impl<'a, Socket, CipherSuite> ErrorType for TlsConnection<'a, Socket, CipherSuite>
304where
305    Socket: AsyncRead + AsyncWrite + 'a,
306    CipherSuite: TlsCipherSuite + 'static,
307{
308    type Error = TlsError;
309}
310
311impl<'a, Socket, CipherSuite> AsyncRead for TlsConnection<'a, Socket, CipherSuite>
312where
313    Socket: AsyncRead + AsyncWrite + 'a,
314    CipherSuite: TlsCipherSuite + 'static,
315{
316    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
317        TlsConnection::read(self, buf).await
318    }
319}
320
321impl<'a, Socket, CipherSuite> BufRead for TlsConnection<'a, Socket, CipherSuite>
322where
323    Socket: AsyncRead + AsyncWrite + 'a,
324    CipherSuite: TlsCipherSuite + 'static,
325{
326    async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
327        self.read_buffered().await.map(|mut buf| buf.peek_all())
328    }
329
330    fn consume(&mut self, amt: usize) {
331        self.create_read_buffer().pop(amt);
332    }
333}
334
335impl<'a, Socket, CipherSuite> AsyncWrite for TlsConnection<'a, Socket, CipherSuite>
336where
337    Socket: AsyncRead + AsyncWrite + 'a,
338    CipherSuite: TlsCipherSuite + 'static,
339{
340    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
341        TlsConnection::write(self, buf).await
342    }
343
344    async fn flush(&mut self) -> Result<(), Self::Error> {
345        TlsConnection::flush(self).await
346    }
347}
348
/// Read half of a [`TlsConnection`], produced by [`TlsConnection::split`].
pub struct TlsReader<'a, Socket, CipherSuite>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Open/closed flag shared with the originating connection and its writer half.
    opened: &'a AtomicBool,
    /// This half's own copy of the transport (`split()` clones the connection's socket).
    delegate: Socket,
    /// Read half of the connection's key schedule.
    key_schedule: &'a mut ReadKeySchedule<CipherSuite>,
    /// Mutable reborrow of the connection's record reader (and its read buffer).
    record_reader: RecordReaderBorrowMut<'a>,
    /// Decrypted-data bookkeeping borrowed from the connection.
    decrypted: &'a mut DecryptedBufferInfo,
}
359
360impl<Socket, CipherSuite> AsRef<Socket> for TlsReader<'_, Socket, CipherSuite>
361where
362    CipherSuite: TlsCipherSuite + 'static,
363{
364    fn as_ref(&self) -> &Socket {
365        &self.delegate
366    }
367}
368
369impl<'a, Socket, CipherSuite> TlsReader<'a, Socket, CipherSuite>
370where
371    Socket: AsyncRead + 'a,
372    CipherSuite: TlsCipherSuite + 'static,
373{
374    fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
375        self.decrypted.create_read_buffer(self.record_reader.buf)
376    }
377
378    /// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
379    pub async fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, TlsError> {
380        if self.opened.load(Ordering::Acquire) {
381            while self.decrypted.is_empty() {
382                self.read_application_data().await?;
383            }
384
385            Ok(self.create_read_buffer())
386        } else {
387            Err(TlsError::MissingHandshake)
388        }
389    }
390
391    async fn read_application_data(&mut self) -> Result<(), TlsError> {
392        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
393        let record = self
394            .record_reader
395            .read(&mut self.delegate, self.key_schedule)
396            .await?;
397
398        let mut opened = self.opened.load(Ordering::Acquire);
399        let mut handler = DecryptedReadHandler {
400            source_buffer: buf_ptr_range,
401            buffer_info: self.decrypted,
402            is_open: &mut opened,
403        };
404        let result = decrypt_record(self.key_schedule, record, |_key_schedule, record| {
405            handler.handle(record)
406        });
407
408        if !opened {
409            self.opened.store(false, Ordering::Release);
410        }
411        result
412    }
413}
414
/// Write half of a [`TlsConnection`], produced by [`TlsConnection::split`].
pub struct TlsWriter<'a, Socket, CipherSuite>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Open/closed flag shared with the originating connection and its reader half.
    opened: &'a AtomicBool,
    /// This half's own copy of the transport (`split()` clones the connection's socket).
    delegate: Socket,
    /// Write half of the connection's key schedule.
    key_schedule: &'a mut WriteKeySchedule<CipherSuite>,
    /// Mutable reborrow of the connection's record write buffer.
    record_write_buf: WriteBufferBorrowMut<'a>,
    /// Copy of the connection's flush policy, taken when `split()` was called.
    flush_policy: FlushPolicy,
}
425
426impl<'a, Socket, CipherSuite> TlsWriter<'a, Socket, CipherSuite>
427where
428    Socket: AsyncWrite + 'a,
429    CipherSuite: TlsCipherSuite + 'static,
430{
431    #[inline]
432    async fn flush_transport(&mut self) -> Result<(), TlsError> {
433        self.delegate
434            .flush()
435            .await
436            .map_err(|e| TlsError::Io(e.kind()))
437    }
438}
439
440impl<Socket, CipherSuite> AsRef<Socket> for TlsWriter<'_, Socket, CipherSuite>
441where
442    CipherSuite: TlsCipherSuite + 'static,
443{
444    fn as_ref(&self) -> &Socket {
445        &self.delegate
446    }
447}
448
449impl<Socket, CipherSuite> ErrorType for TlsWriter<'_, Socket, CipherSuite>
450where
451    CipherSuite: TlsCipherSuite + 'static,
452{
453    type Error = TlsError;
454}
455
456impl<Socket, CipherSuite> ErrorType for TlsReader<'_, Socket, CipherSuite>
457where
458    CipherSuite: TlsCipherSuite + 'static,
459{
460    type Error = TlsError;
461}
462
463impl<'a, Socket, CipherSuite> AsyncRead for TlsReader<'a, Socket, CipherSuite>
464where
465    Socket: AsyncRead + 'a,
466    CipherSuite: TlsCipherSuite + 'static,
467{
468    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
469        if buf.is_empty() {
470            return Ok(0);
471        }
472        let mut buffer = self.read_buffered().await?;
473
474        let len = buffer.pop_into(buf);
475        trace!("Copied {} bytes", len);
476
477        Ok(len)
478    }
479}
480
481impl<'a, Socket, CipherSuite> BufRead for TlsReader<'a, Socket, CipherSuite>
482where
483    Socket: AsyncRead + 'a,
484    CipherSuite: TlsCipherSuite + 'static,
485{
486    async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
487        self.read_buffered().await.map(|mut buf| buf.peek_all())
488    }
489
490    fn consume(&mut self, amt: usize) {
491        self.create_read_buffer().pop(amt);
492    }
493}
494
495impl<'a, Socket, CipherSuite> AsyncWrite for TlsWriter<'a, Socket, CipherSuite>
496where
497    Socket: AsyncWrite + 'a,
498    CipherSuite: TlsCipherSuite + 'static,
499{
500    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
501        if self.opened.load(Ordering::Acquire) {
502            if !self
503                .record_write_buf
504                .contains(ClientRecordHeader::ApplicationData)
505            {
506                self.flush().await?;
507                self.record_write_buf
508                    .start_record(ClientRecordHeader::ApplicationData)?;
509            }
510
511            let buffered = self.record_write_buf.append(buf);
512
513            if self.record_write_buf.is_full() {
514                self.flush().await?;
515            }
516
517            Ok(buffered)
518        } else {
519            Err(TlsError::MissingHandshake)
520        }
521    }
522
523    async fn flush(&mut self) -> Result<(), Self::Error> {
524        if !self.record_write_buf.is_empty() {
525            let slice = self.record_write_buf.close_record(self.key_schedule)?;
526
527            self.delegate
528                .write_all(slice)
529                .await
530                .map_err(|e| TlsError::Io(e.kind()))?;
531
532            self.key_schedule.increment_counter();
533
534            if self.flush_policy.flush_transport() {
535                self.flush_transport().await?;
536            }
537        }
538
539        Ok(())
540    }
541}