embedded_tls/
asynch.rs

1use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
2use crate::common::decrypted_read_handler::DecryptedReadHandler;
3use crate::connection::*;
4use crate::key_schedule::KeySchedule;
5use crate::key_schedule::{ReadKeySchedule, SharedState, WriteKeySchedule};
6use crate::read_buffer::ReadBuffer;
7use crate::record::{ClientRecord, ClientRecordHeader};
8use crate::record_reader::RecordReader;
9use crate::split::{SplitState, SplitStateContainer};
10use crate::write_buffer::WriteBuffer;
11use crate::TlsError;
12use embedded_io::Error as _;
13use embedded_io::ErrorType;
14use embedded_io_async::{BufRead, Read as AsyncRead, Write as AsyncWrite};
15use rand_core::{CryptoRng, RngCore};
16
17pub use crate::config::*;
18#[cfg(feature = "std")]
19pub use crate::split::ManagedSplitState;
20pub use crate::split::SplitConnectionState;
21
22/// Type representing an async TLS connection. An instance of this type can
23/// be used to establish a TLS connection, write and read encrypted data over this connection,
24/// and closing to free up the underlying resources.
25pub struct TlsConnection<'a, Socket, CipherSuite>
26where
27    Socket: AsyncRead + AsyncWrite + 'a,
28    CipherSuite: TlsCipherSuite + 'static,
29{
30    delegate: Socket,
31    opened: bool,
32    key_schedule: KeySchedule<CipherSuite>,
33    record_reader: RecordReader<'a, CipherSuite>,
34    record_write_buf: WriteBuffer<'a>,
35    decrypted: DecryptedBufferInfo,
36}
37
38impl<'a, Socket, CipherSuite> TlsConnection<'a, Socket, CipherSuite>
39where
40    Socket: AsyncRead + AsyncWrite + 'a,
41    CipherSuite: TlsCipherSuite + 'static,
42{
43    /// Create a new TLS connection with the provided context and a async I/O implementation
44    ///
45    /// NOTE: The record read buffer should be sized to fit an encrypted TLS record. The size of this record
46    /// depends on the server configuration, but the maximum allowed value for a TLS record is 16640 bytes,
47    /// which should be a safe value to use.
48    ///
49    /// The write record buffer can be smaller than the read buffer. During writes [`TLS_RECORD_OVERHEAD`] bytes of
50    /// overhead is added per record, so the buffer must at least be this large. Large writes are split into multiple
51    /// records if depending on the size of the write buffer.
52    /// The largest of the two buffers will be used to encode the TLS handshake record, hence either of the
53    /// buffers must at least be large enough to encode a handshake.
54    pub fn new(
55        delegate: Socket,
56        record_read_buf: &'a mut [u8],
57        record_write_buf: &'a mut [u8],
58    ) -> Self {
59        Self {
60            delegate,
61            opened: false,
62            key_schedule: KeySchedule::new(),
63            record_reader: RecordReader::new(record_read_buf),
64            record_write_buf: WriteBuffer::new(record_write_buf),
65            decrypted: DecryptedBufferInfo::default(),
66        }
67    }
68
69    /// Open a TLS connection, performing the handshake with the configuration provided when
70    /// creating the connection instance.
71    ///
72    /// Returns an error if the handshake does not proceed. If an error occurs, the connection
73    /// instance must be recreated.
74    pub async fn open<'v, RNG, Verifier>(
75        &mut self,
76        context: TlsContext<'v, CipherSuite, RNG>,
77    ) -> Result<(), TlsError>
78    where
79        RNG: CryptoRng + RngCore,
80        Verifier: TlsVerifier<'v, CipherSuite>,
81    {
82        let mut handshake: Handshake<CipherSuite, Verifier> =
83            Handshake::new(Verifier::new(context.config.server_name));
84        let mut state = State::ClientHello;
85
86        while state != State::ApplicationData {
87            let next_state = state
88                .process(
89                    &mut self.delegate,
90                    &mut handshake,
91                    &mut self.record_reader,
92                    &mut self.record_write_buf,
93                    &mut self.key_schedule,
94                    context.config,
95                    context.rng,
96                )
97                .await?;
98            trace!("State {:?} -> {:?}", state, next_state);
99            state = next_state;
100        }
101        self.opened = true;
102
103        Ok(())
104    }
105
106    /// Encrypt and send the provided slice over the connection. The connection
107    /// must be opened before writing.
108    ///
109    /// The slice may be buffered internally and not written to the connection immediately.
110    /// In this case [`Self::flush()`] should be called to force the currently buffered writes
111    /// to be written to the connection.
112    ///
113    /// Returns the number of bytes buffered/written.
114    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, TlsError> {
115        if self.opened {
116            if !self
117                .record_write_buf
118                .contains(ClientRecordHeader::ApplicationData)
119            {
120                self.flush().await?;
121                self.record_write_buf
122                    .start_record(ClientRecordHeader::ApplicationData)?;
123            }
124
125            let buffered = self.record_write_buf.append(buf);
126
127            if self.record_write_buf.is_full() {
128                self.flush().await?;
129            }
130
131            Ok(buffered)
132        } else {
133            Err(TlsError::MissingHandshake)
134        }
135    }
136
137    /// Force all previously written, buffered bytes to be encoded into a tls record and written
138    /// to the connection.
139    pub async fn flush(&mut self) -> Result<(), TlsError> {
140        if !self.record_write_buf.is_empty() {
141            let key_schedule = self.key_schedule.write_state();
142            let slice = self.record_write_buf.close_record(key_schedule)?;
143
144            self.delegate
145                .write_all(slice)
146                .await
147                .map_err(|e| TlsError::Io(e.kind()))?;
148
149            key_schedule.increment_counter();
150
151            self.delegate
152                .flush()
153                .await
154                .map_err(|e| TlsError::Io(e.kind()))?;
155        }
156
157        Ok(())
158    }
159
160    fn create_read_buffer(&mut self) -> ReadBuffer {
161        self.decrypted.create_read_buffer(self.record_reader.buf)
162    }
163
164    /// Read and decrypt data filling the provided slice.
165    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, TlsError> {
166        if buf.is_empty() {
167            return Ok(0);
168        }
169        let mut buffer = self.read_buffered().await?;
170
171        let len = buffer.pop_into(buf);
172        trace!("Copied {} bytes", len);
173
174        Ok(len)
175    }
176
177    /// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
178    pub async fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
179        if self.opened {
180            while self.decrypted.is_empty() {
181                self.read_application_data().await?;
182            }
183
184            Ok(self.create_read_buffer())
185        } else {
186            Err(TlsError::MissingHandshake)
187        }
188    }
189
190    async fn read_application_data(&mut self) -> Result<(), TlsError> {
191        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
192        let record = self
193            .record_reader
194            .read(&mut self.delegate, self.key_schedule.read_state())
195            .await?;
196
197        let mut handler = DecryptedReadHandler {
198            source_buffer: buf_ptr_range,
199            buffer_info: &mut self.decrypted,
200            is_open: &mut self.opened,
201        };
202        decrypt_record(
203            self.key_schedule.read_state(),
204            record,
205            |_key_schedule, record| handler.handle(record),
206        )?;
207
208        Ok(())
209    }
210
211    /// Close a connection instance, returning the ownership of the config, random generator and the async I/O provider.
212    async fn close_internal(&mut self) -> Result<(), TlsError> {
213        self.flush().await?;
214
215        let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
216        let slice = self.record_write_buf.write_record(
217            &ClientRecord::close_notify(self.opened),
218            write_key_schedule,
219            Some(read_key_schedule),
220        )?;
221
222        self.delegate
223            .write_all(slice)
224            .await
225            .map_err(|e| TlsError::Io(e.kind()))?;
226
227        self.key_schedule.write_state().increment_counter();
228
229        self.flush().await
230    }
231
232    /// Close a connection instance, returning the ownership of the async I/O provider.
233    pub async fn close(mut self) -> Result<Socket, (Socket, TlsError)> {
234        match self.close_internal().await {
235            Ok(()) => Ok(self.delegate),
236            Err(e) => Err((self.delegate, e)),
237        }
238    }
239
240    #[cfg(feature = "std")]
241    pub fn split(
242        self,
243    ) -> (
244        TlsReader<'a, Socket, CipherSuite, ManagedSplitState>,
245        TlsWriter<'a, Socket, CipherSuite, ManagedSplitState>,
246    )
247    where
248        Socket: Clone,
249    {
250        self.split_with(ManagedSplitState::new())
251    }
252
253    pub fn split_with<StateContainer>(
254        self,
255        state: StateContainer,
256    ) -> (
257        TlsReader<'a, Socket, CipherSuite, StateContainer::State>,
258        TlsWriter<'a, Socket, CipherSuite, StateContainer::State>,
259    )
260    where
261        Socket: Clone,
262        StateContainer: SplitStateContainer,
263    {
264        let state = state.state();
265        state.set_open(self.opened);
266
267        let (shared, wks, rks) = self.key_schedule.split();
268
269        let reader = TlsReader {
270            state: state.clone(),
271            delegate: self.delegate.clone(),
272            key_schedule: rks,
273            record_reader: self.record_reader,
274            decrypted: self.decrypted,
275        };
276        let writer = TlsWriter {
277            state,
278            delegate: self.delegate,
279            key_schedule_shared: shared,
280            key_schedule: wks,
281            record_write_buf: self.record_write_buf,
282        };
283
284        (reader, writer)
285    }
286
287    pub fn unsplit<State>(
288        reader: TlsReader<'a, Socket, CipherSuite, State>,
289        writer: TlsWriter<'a, Socket, CipherSuite, State>,
290    ) -> Self
291    where
292        Socket: Clone,
293        State: SplitState,
294    {
295        debug_assert!(reader.state.same(&writer.state));
296
297        TlsConnection {
298            delegate: writer.delegate,
299            opened: writer.state.is_open(),
300            key_schedule: KeySchedule::unsplit(
301                writer.key_schedule_shared,
302                writer.key_schedule,
303                reader.key_schedule,
304            ),
305            record_reader: reader.record_reader,
306            record_write_buf: writer.record_write_buf,
307            decrypted: reader.decrypted,
308        }
309    }
310}
311
312impl<'a, Socket, CipherSuite> ErrorType for TlsConnection<'a, Socket, CipherSuite>
313where
314    Socket: AsyncRead + AsyncWrite + 'a,
315    CipherSuite: TlsCipherSuite + 'static,
316{
317    type Error = TlsError;
318}
319
320impl<'a, Socket, CipherSuite> AsyncRead for TlsConnection<'a, Socket, CipherSuite>
321where
322    Socket: AsyncRead + AsyncWrite + 'a,
323    CipherSuite: TlsCipherSuite + 'static,
324{
325    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
326        TlsConnection::read(self, buf).await
327    }
328}
329
330impl<'a, Socket, CipherSuite> BufRead for TlsConnection<'a, Socket, CipherSuite>
331where
332    Socket: AsyncRead + AsyncWrite + 'a,
333    CipherSuite: TlsCipherSuite + 'static,
334{
335    async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
336        self.read_buffered().await.map(|mut buf| buf.peek_all())
337    }
338
339    fn consume(&mut self, amt: usize) {
340        self.create_read_buffer().pop(amt);
341    }
342}
343
344impl<'a, Socket, CipherSuite> AsyncWrite for TlsConnection<'a, Socket, CipherSuite>
345where
346    Socket: AsyncRead + AsyncWrite + 'a,
347    CipherSuite: TlsCipherSuite + 'static,
348{
349    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
350        TlsConnection::write(self, buf).await
351    }
352
353    async fn flush(&mut self) -> Result<(), Self::Error> {
354        TlsConnection::flush(self).await
355    }
356}
357
358pub struct TlsReader<'a, Socket, CipherSuite, State>
359where
360    CipherSuite: TlsCipherSuite + 'static,
361{
362    state: State,
363    delegate: Socket,
364    key_schedule: ReadKeySchedule<CipherSuite>,
365    record_reader: RecordReader<'a, CipherSuite>,
366    decrypted: DecryptedBufferInfo,
367}
368
369impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsReader<'a, Socket, CipherSuite, State>
370where
371    CipherSuite: TlsCipherSuite + 'static,
372{
373    fn as_ref(&self) -> &Socket {
374        &self.delegate
375    }
376}
377
378impl<'a, Socket, CipherSuite, State> TlsReader<'a, Socket, CipherSuite, State>
379where
380    Socket: AsyncRead + 'a,
381    CipherSuite: TlsCipherSuite + 'static,
382    State: SplitState,
383{
384    fn create_read_buffer(&mut self) -> ReadBuffer {
385        self.decrypted.create_read_buffer(self.record_reader.buf)
386    }
387
388    /// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
389    pub async fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
390        if self.state.is_open() {
391            while self.decrypted.is_empty() {
392                self.read_application_data().await?;
393            }
394
395            Ok(self.create_read_buffer())
396        } else {
397            Err(TlsError::MissingHandshake)
398        }
399    }
400
401    async fn read_application_data(&mut self) -> Result<(), TlsError> {
402        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
403        let record = self
404            .record_reader
405            .read(&mut self.delegate, &mut self.key_schedule)
406            .await?;
407
408        let mut opened = self.state.is_open();
409        let mut handler = DecryptedReadHandler {
410            source_buffer: buf_ptr_range,
411            buffer_info: &mut self.decrypted,
412            is_open: &mut opened,
413        };
414        let result = decrypt_record(&mut self.key_schedule, record, |_key_schedule, record| {
415            handler.handle(record)
416        });
417
418        if !opened {
419            self.state.set_open(false);
420        }
421        result
422    }
423}
424
425pub struct TlsWriter<'a, Socket, CipherSuite, State>
426where
427    CipherSuite: TlsCipherSuite + 'static,
428{
429    state: State,
430    delegate: Socket,
431    key_schedule_shared: SharedState<CipherSuite>,
432    key_schedule: WriteKeySchedule<CipherSuite>,
433    record_write_buf: WriteBuffer<'a>,
434}
435
436impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsWriter<'a, Socket, CipherSuite, State>
437where
438    CipherSuite: TlsCipherSuite + 'static,
439{
440    fn as_ref(&self) -> &Socket {
441        &self.delegate
442    }
443}
444
445impl<'a, Socket, CipherSuite, State> ErrorType for TlsWriter<'a, Socket, CipherSuite, State>
446where
447    CipherSuite: TlsCipherSuite + 'static,
448{
449    type Error = TlsError;
450}
451
452impl<'a, Socket, CipherSuite, State> ErrorType for TlsReader<'a, Socket, CipherSuite, State>
453where
454    CipherSuite: TlsCipherSuite + 'static,
455{
456    type Error = TlsError;
457}
458
459impl<'a, Socket, CipherSuite, State> AsyncRead for TlsReader<'a, Socket, CipherSuite, State>
460where
461    Socket: AsyncRead + 'a,
462    CipherSuite: TlsCipherSuite + 'static,
463    State: SplitState,
464{
465    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
466        if buf.is_empty() {
467            return Ok(0);
468        }
469        let mut buffer = self.read_buffered().await?;
470
471        let len = buffer.pop_into(buf);
472        trace!("Copied {} bytes", len);
473
474        Ok(len)
475    }
476}
477
478impl<'a, Socket, CipherSuite, State> BufRead for TlsReader<'a, Socket, CipherSuite, State>
479where
480    Socket: AsyncRead + 'a,
481    CipherSuite: TlsCipherSuite + 'static,
482    State: SplitState,
483{
484    async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
485        self.read_buffered().await.map(|mut buf| buf.peek_all())
486    }
487
488    fn consume(&mut self, amt: usize) {
489        self.create_read_buffer().pop(amt);
490    }
491}
492
493impl<'a, Socket, CipherSuite, State> AsyncWrite for TlsWriter<'a, Socket, CipherSuite, State>
494where
495    Socket: AsyncWrite + 'a,
496    CipherSuite: TlsCipherSuite + 'static,
497    State: SplitState,
498{
499    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
500        if self.state.is_open() {
501            if !self
502                .record_write_buf
503                .contains(ClientRecordHeader::ApplicationData)
504            {
505                self.flush().await?;
506                self.record_write_buf
507                    .start_record(ClientRecordHeader::ApplicationData)?;
508            }
509
510            let buffered = self.record_write_buf.append(buf);
511
512            if self.record_write_buf.is_full() {
513                self.flush().await?;
514            }
515
516            Ok(buffered)
517        } else {
518            Err(TlsError::MissingHandshake)
519        }
520    }
521
522    async fn flush(&mut self) -> Result<(), Self::Error> {
523        if !self.record_write_buf.is_empty() {
524            let slice = self.record_write_buf.close_record(&mut self.key_schedule)?;
525
526            self.delegate
527                .write_all(slice)
528                .await
529                .map_err(|e| TlsError::Io(e.kind()))?;
530
531            self.key_schedule.increment_counter();
532
533            self.delegate
534                .flush()
535                .await
536                .map_err(|e| TlsError::Io(e.kind()))?;
537        }
538
539        Ok(())
540    }
541}