embedded_tls/
blocking.rs

1use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
2use crate::common::decrypted_read_handler::DecryptedReadHandler;
3use crate::connection::*;
4use crate::key_schedule::KeySchedule;
5use crate::key_schedule::{ReadKeySchedule, SharedState, WriteKeySchedule};
6use crate::read_buffer::ReadBuffer;
7use crate::record::{ClientRecord, ClientRecordHeader};
8use crate::record_reader::RecordReader;
9use crate::split::{SplitState, SplitStateContainer};
10use crate::write_buffer::WriteBuffer;
11use embedded_io::Error as _;
12use embedded_io::{BufRead, ErrorType, Read, Write};
13use rand_core::{CryptoRng, RngCore};
14
15pub use crate::config::*;
16#[cfg(feature = "std")]
17pub use crate::split::ManagedSplitState;
18pub use crate::split::SplitConnectionState;
19pub use crate::TlsError;
20
21/// Type representing a TLS connection. An instance of this type can
22/// be used to establish a TLS connection, write and read encrypted data over this connection,
23/// and closing to free up the underlying resources.
24pub struct TlsConnection<'a, Socket, CipherSuite>
25where
26    Socket: Read + Write + 'a,
27    CipherSuite: TlsCipherSuite + 'static,
28{
29    delegate: Socket,
30    opened: bool,
31    key_schedule: KeySchedule<CipherSuite>,
32    record_reader: RecordReader<'a, CipherSuite>,
33    record_write_buf: WriteBuffer<'a>,
34    decrypted: DecryptedBufferInfo,
35}
36
37impl<'a, Socket, CipherSuite> TlsConnection<'a, Socket, CipherSuite>
38where
39    Socket: Read + Write + 'a,
40    CipherSuite: TlsCipherSuite + 'static,
41{
42    /// Create a new TLS connection with the provided context and a blocking I/O implementation
43    ///
44    /// NOTE: The record read buffer should be sized to fit an encrypted TLS record. The size of this record
45    /// depends on the server configuration, but the maximum allowed value for a TLS record is 16640 bytes,
46    /// which should be a safe value to use.
47    ///
48    /// The write record buffer can be smaller than the read buffer. During writes [`TLS_RECORD_OVERHEAD`] bytes of
49    /// overhead is added per record, so the buffer must at least be this large. Large writes are split into multiple
50    /// records if depending on the size of the write buffer.
51    /// The largest of the two buffers will be used to encode the TLS handshake record, hence either of the
52    /// buffers must at least be large enough to encode a handshake.
53    pub fn new(
54        delegate: Socket,
55        record_read_buf: &'a mut [u8],
56        record_write_buf: &'a mut [u8],
57    ) -> Self {
58        Self {
59            delegate,
60            opened: false,
61            key_schedule: KeySchedule::new(),
62            record_reader: RecordReader::new(record_read_buf),
63            record_write_buf: WriteBuffer::new(record_write_buf),
64            decrypted: DecryptedBufferInfo::default(),
65        }
66    }
67
68    /// Open a TLS connection, performing the handshake with the configuration provided when
69    /// creating the connection instance.
70    ///
71    /// Returns an error if the handshake does not proceed. If an error occurs, the connection
72    /// instance must be recreated.
73    pub fn open<'v, RNG, Verifier>(
74        &mut self,
75        context: TlsContext<'v, CipherSuite, RNG>,
76    ) -> Result<(), TlsError>
77    where
78        RNG: CryptoRng + RngCore,
79        Verifier: TlsVerifier<'v, CipherSuite>,
80    {
81        let mut handshake: Handshake<CipherSuite, Verifier> =
82            Handshake::new(Verifier::new(context.config.server_name));
83        let mut state = State::ClientHello;
84
85        while state != State::ApplicationData {
86            let next_state = state.process_blocking(
87                &mut self.delegate,
88                &mut handshake,
89                &mut self.record_reader,
90                &mut self.record_write_buf,
91                &mut self.key_schedule,
92                context.config,
93                context.rng,
94            )?;
95            trace!("State {:?} -> {:?}", state, next_state);
96            state = next_state;
97        }
98        self.opened = true;
99
100        Ok(())
101    }
102
103    /// Encrypt and send the provided slice over the connection. The connection
104    /// must be opened before writing.
105    ///
106    /// The slice may be buffered internally and not written to the connection immediately.
107    /// In this case [`Self::flush()`] should be called to force the currently buffered writes
108    /// to be written to the connection.
109    ///
110    /// Returns the number of bytes buffered/written.
111    pub fn write(&mut self, buf: &[u8]) -> Result<usize, TlsError> {
112        if self.opened {
113            if !self
114                .record_write_buf
115                .contains(ClientRecordHeader::ApplicationData)
116            {
117                self.flush()?;
118                self.record_write_buf
119                    .start_record(ClientRecordHeader::ApplicationData)?;
120            }
121
122            let buffered = self.record_write_buf.append(buf);
123
124            if self.record_write_buf.is_full() {
125                self.flush()?;
126            }
127
128            Ok(buffered)
129        } else {
130            Err(TlsError::MissingHandshake)
131        }
132    }
133
134    /// Force all previously written, buffered bytes to be encoded into a tls record and written
135    /// to the connection.
136    pub fn flush(&mut self) -> Result<(), TlsError> {
137        if !self.record_write_buf.is_empty() {
138            let key_schedule = self.key_schedule.write_state();
139            let slice = self.record_write_buf.close_record(key_schedule)?;
140
141            self.delegate
142                .write_all(slice)
143                .map_err(|e| TlsError::Io(e.kind()))?;
144
145            key_schedule.increment_counter();
146
147            self.delegate.flush().map_err(|e| TlsError::Io(e.kind()))?;
148        }
149
150        Ok(())
151    }
152
153    fn create_read_buffer(&mut self) -> ReadBuffer {
154        self.decrypted.create_read_buffer(self.record_reader.buf)
155    }
156
157    /// Read and decrypt data filling the provided slice.
158    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, TlsError> {
159        if buf.is_empty() {
160            return Ok(0);
161        }
162        let mut buffer = self.read_buffered()?;
163
164        let len = buffer.pop_into(buf);
165        trace!("Copied {} bytes", len);
166
167        Ok(len)
168    }
169
170    /// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
171    pub fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
172        if self.opened {
173            while self.decrypted.is_empty() {
174                self.read_application_data()?;
175            }
176
177            Ok(self.create_read_buffer())
178        } else {
179            Err(TlsError::MissingHandshake)
180        }
181    }
182
183    fn read_application_data(&mut self) -> Result<(), TlsError> {
184        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
185        let key_schedule = self.key_schedule.read_state();
186        let record = self
187            .record_reader
188            .read_blocking(&mut self.delegate, key_schedule)?;
189
190        let mut handler = DecryptedReadHandler {
191            source_buffer: buf_ptr_range,
192            buffer_info: &mut self.decrypted,
193            is_open: &mut self.opened,
194        };
195        decrypt_record(key_schedule, record, |_key_schedule, record| {
196            handler.handle(record)
197        })?;
198
199        Ok(())
200    }
201
202    fn close_internal(&mut self) -> Result<(), TlsError> {
203        self.flush()?;
204
205        let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
206        let slice = self.record_write_buf.write_record(
207            &ClientRecord::close_notify(self.opened),
208            write_key_schedule,
209            Some(read_key_schedule),
210        )?;
211
212        self.delegate
213            .write_all(slice)
214            .map_err(|e| TlsError::Io(e.kind()))?;
215
216        self.key_schedule.write_state().increment_counter();
217
218        self.flush()?;
219
220        Ok(())
221    }
222
223    /// Close a connection instance, returning the ownership of the I/O provider.
224    pub fn close(mut self) -> Result<Socket, (Socket, TlsError)> {
225        match self.close_internal() {
226            Ok(()) => Ok(self.delegate),
227            Err(e) => Err((self.delegate, e)),
228        }
229    }
230
231    #[cfg(feature = "std")]
232    pub fn split(
233        self,
234    ) -> (
235        TlsReader<'a, Socket, CipherSuite, ManagedSplitState>,
236        TlsWriter<'a, Socket, CipherSuite, ManagedSplitState>,
237    )
238    where
239        Socket: Clone,
240    {
241        self.split_with(ManagedSplitState::new())
242    }
243
244    pub fn split_with<StateContainer>(
245        self,
246        state: StateContainer,
247    ) -> (
248        TlsReader<'a, Socket, CipherSuite, StateContainer::State>,
249        TlsWriter<'a, Socket, CipherSuite, StateContainer::State>,
250    )
251    where
252        Socket: Clone,
253        StateContainer: SplitStateContainer,
254    {
255        let state = state.state();
256        state.set_open(self.opened);
257
258        let (shared, wks, rks) = self.key_schedule.split();
259
260        let reader = TlsReader {
261            state: state.clone(),
262            delegate: self.delegate.clone(),
263            key_schedule: rks,
264            record_reader: self.record_reader,
265            decrypted: self.decrypted,
266        };
267        let writer = TlsWriter {
268            state,
269            delegate: self.delegate,
270            key_schedule_shared: shared,
271            key_schedule: wks,
272            record_write_buf: self.record_write_buf,
273        };
274
275        (reader, writer)
276    }
277
278    pub fn unsplit<State>(
279        reader: TlsReader<'a, Socket, CipherSuite, State>,
280        writer: TlsWriter<'a, Socket, CipherSuite, State>,
281    ) -> Self
282    where
283        Socket: Clone,
284        State: SplitState,
285    {
286        debug_assert!(reader.state.same(&writer.state));
287
288        TlsConnection {
289            delegate: writer.delegate,
290            opened: writer.state.is_open(),
291            key_schedule: KeySchedule::unsplit(
292                writer.key_schedule_shared,
293                writer.key_schedule,
294                reader.key_schedule,
295            ),
296            record_reader: reader.record_reader,
297            record_write_buf: writer.record_write_buf,
298            decrypted: reader.decrypted,
299        }
300    }
301}
302
303impl<'a, Socket, CipherSuite> ErrorType for TlsConnection<'a, Socket, CipherSuite>
304where
305    Socket: Read + Write + 'a,
306    CipherSuite: TlsCipherSuite + 'static,
307{
308    type Error = TlsError;
309}
310
311impl<'a, Socket, CipherSuite> Read for TlsConnection<'a, Socket, CipherSuite>
312where
313    Socket: Read + Write + 'a,
314    CipherSuite: TlsCipherSuite + 'static,
315{
316    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
317        TlsConnection::read(self, buf)
318    }
319}
320
321impl<'a, Socket, CipherSuite> BufRead for TlsConnection<'a, Socket, CipherSuite>
322where
323    Socket: Read + Write + 'a,
324    CipherSuite: TlsCipherSuite + 'static,
325{
326    fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
327        self.read_buffered().map(|mut buf| buf.peek_all())
328    }
329
330    fn consume(&mut self, amt: usize) {
331        self.create_read_buffer().pop(amt);
332    }
333}
334
335impl<'a, Socket, CipherSuite> Write for TlsConnection<'a, Socket, CipherSuite>
336where
337    Socket: Read + Write + 'a,
338    CipherSuite: TlsCipherSuite + 'static,
339{
340    fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
341        TlsConnection::write(self, buf)
342    }
343
344    fn flush(&mut self) -> Result<(), Self::Error> {
345        TlsConnection::flush(self)
346    }
347}
348
349pub struct TlsReader<'a, Socket, CipherSuite, State>
350where
351    CipherSuite: TlsCipherSuite + 'static,
352{
353    state: State,
354    delegate: Socket,
355    key_schedule: ReadKeySchedule<CipherSuite>,
356    record_reader: RecordReader<'a, CipherSuite>,
357    decrypted: DecryptedBufferInfo,
358}
359
360impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsReader<'a, Socket, CipherSuite, State>
361where
362    CipherSuite: TlsCipherSuite + 'static,
363{
364    fn as_ref(&self) -> &Socket {
365        &self.delegate
366    }
367}
368
369impl<'a, Socket, CipherSuite, State> TlsReader<'a, Socket, CipherSuite, State>
370where
371    Socket: Read + 'a,
372    CipherSuite: TlsCipherSuite + 'static,
373    State: SplitState,
374{
375    fn create_read_buffer(&mut self) -> ReadBuffer {
376        self.decrypted.create_read_buffer(self.record_reader.buf)
377    }
378
379    /// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
380    pub fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
381        if self.state.is_open() {
382            while self.decrypted.is_empty() {
383                self.read_application_data()?;
384            }
385
386            Ok(self.create_read_buffer())
387        } else {
388            Err(TlsError::MissingHandshake)
389        }
390    }
391
392    fn read_application_data(&mut self) -> Result<(), TlsError> {
393        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
394        let record = self
395            .record_reader
396            .read_blocking(&mut self.delegate, &mut self.key_schedule)?;
397
398        let mut opened = self.state.is_open();
399        let mut handler = DecryptedReadHandler {
400            source_buffer: buf_ptr_range,
401            buffer_info: &mut self.decrypted,
402            is_open: &mut opened,
403        };
404        let result = decrypt_record(&mut self.key_schedule, record, |_key_schedule, record| {
405            handler.handle(record)
406        });
407
408        if !opened {
409            self.state.set_open(false);
410        }
411        result
412    }
413}
414
415pub struct TlsWriter<'a, Socket, CipherSuite, State>
416where
417    CipherSuite: TlsCipherSuite + 'static,
418{
419    state: State,
420    delegate: Socket,
421    key_schedule_shared: SharedState<CipherSuite>,
422    key_schedule: WriteKeySchedule<CipherSuite>,
423    record_write_buf: WriteBuffer<'a>,
424}
425
426impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsWriter<'a, Socket, CipherSuite, State>
427where
428    CipherSuite: TlsCipherSuite + 'static,
429{
430    fn as_ref(&self) -> &Socket {
431        &self.delegate
432    }
433}
434
435impl<'a, Socket, CipherSuite, State> ErrorType for TlsWriter<'a, Socket, CipherSuite, State>
436where
437    CipherSuite: TlsCipherSuite + 'static,
438{
439    type Error = TlsError;
440}
441
442impl<'a, Socket, CipherSuite, State> ErrorType for TlsReader<'a, Socket, CipherSuite, State>
443where
444    CipherSuite: TlsCipherSuite + 'static,
445{
446    type Error = TlsError;
447}
448
449impl<'a, Socket, CipherSuite, State> Read for TlsReader<'a, Socket, CipherSuite, State>
450where
451    Socket: Read + 'a,
452    CipherSuite: TlsCipherSuite + 'static,
453    State: SplitState,
454{
455    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
456        if buf.is_empty() {
457            return Ok(0);
458        }
459        let mut buffer = self.read_buffered()?;
460
461        let len = buffer.pop_into(buf);
462        trace!("Copied {} bytes", len);
463
464        Ok(len)
465    }
466}
467
468impl<'a, Socket, CipherSuite, State> BufRead for TlsReader<'a, Socket, CipherSuite, State>
469where
470    Socket: Read + 'a,
471    CipherSuite: TlsCipherSuite + 'static,
472    State: SplitState,
473{
474    fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
475        self.read_buffered().map(|mut buf| buf.peek_all())
476    }
477
478    fn consume(&mut self, amt: usize) {
479        self.create_read_buffer().pop(amt);
480    }
481}
482
483impl<'a, Socket, CipherSuite, State> Write for TlsWriter<'a, Socket, CipherSuite, State>
484where
485    Socket: Write + 'a,
486    CipherSuite: TlsCipherSuite + 'static,
487    State: SplitState,
488{
489    fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
490        if self.state.is_open() {
491            if !self
492                .record_write_buf
493                .contains(ClientRecordHeader::ApplicationData)
494            {
495                self.flush()?;
496                self.record_write_buf
497                    .start_record(ClientRecordHeader::ApplicationData)?;
498            }
499
500            let buffered = self.record_write_buf.append(buf);
501
502            if self.record_write_buf.is_full() {
503                self.flush()?;
504            }
505
506            Ok(buffered)
507        } else {
508            Err(TlsError::MissingHandshake)
509        }
510    }
511
512    fn flush(&mut self) -> Result<(), Self::Error> {
513        if !self.record_write_buf.is_empty() {
514            let slice = self.record_write_buf.close_record(&mut self.key_schedule)?;
515
516            self.delegate
517                .write_all(slice)
518                .map_err(|e| TlsError::Io(e.kind()))?;
519
520            self.key_schedule.increment_counter();
521
522            self.delegate.flush().map_err(|e| TlsError::Io(e.kind()))?;
523        }
524
525        Ok(())
526    }
527}