embedded_tls/
blocking.rs

1use core::sync::atomic::Ordering;
2
3use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
4use crate::common::decrypted_read_handler::DecryptedReadHandler;
5use crate::connection::{Handshake, State, decrypt_record};
6use crate::flush_policy::FlushPolicy;
7use crate::key_schedule::KeySchedule;
8use crate::key_schedule::{ReadKeySchedule, WriteKeySchedule};
9use crate::read_buffer::ReadBuffer;
10use crate::record::{ClientRecord, ClientRecordHeader};
11use crate::record_reader::{RecordReader, RecordReaderBorrowMut};
12use crate::write_buffer::{WriteBuffer, WriteBufferBorrowMut};
13use embedded_io::Error as _;
14use embedded_io::{BufRead, ErrorType, Read, Write};
15use portable_atomic::AtomicBool;
16
17pub use crate::TlsError;
18pub use crate::config::*;
19
20/// Type representing a TLS connection. An instance of this type can
21/// be used to establish a TLS connection, write and read encrypted data over this connection,
22/// and closing to free up the underlying resources.
23pub struct TlsConnection<'a, Socket, CipherSuite>
24where
25    Socket: Read + Write + 'a,
26    CipherSuite: TlsCipherSuite + 'static,
27{
28    delegate: Socket,
29    opened: AtomicBool,
30    key_schedule: KeySchedule<CipherSuite>,
31    record_reader: RecordReader<'a>,
32    record_write_buf: WriteBuffer<'a>,
33    decrypted: DecryptedBufferInfo,
34    flush_policy: FlushPolicy,
35}
36
37impl<'a, Socket, CipherSuite> TlsConnection<'a, Socket, CipherSuite>
38where
39    Socket: Read + Write + 'a,
40    CipherSuite: TlsCipherSuite + 'static,
41{
42    fn is_opened(&mut self) -> bool {
43        *self.opened.get_mut()
44    }
45
46    /// Create a new TLS connection with the provided context and a blocking I/O implementation
47    ///
48    /// NOTE: The record read buffer should be sized to fit an encrypted TLS record. The size of this record
49    /// depends on the server configuration, but the maximum allowed value for a TLS record is 16640 bytes,
50    /// which should be a safe value to use.
51    ///
52    /// The write record buffer can be smaller than the read buffer. During writes [`TLS_RECORD_OVERHEAD`] bytes of
53    /// overhead is added per record, so the buffer must at least be this large. Large writes are split into multiple
54    /// records if depending on the size of the write buffer.
55    /// The largest of the two buffers will be used to encode the TLS handshake record, hence either of the
56    /// buffers must at least be large enough to encode a handshake.
57    pub fn new(
58        delegate: Socket,
59        record_read_buf: &'a mut [u8],
60        record_write_buf: &'a mut [u8],
61    ) -> Self {
62        Self {
63            delegate,
64            opened: AtomicBool::new(false),
65            key_schedule: KeySchedule::new(),
66            record_reader: RecordReader::new(record_read_buf),
67            record_write_buf: WriteBuffer::new(record_write_buf),
68            decrypted: DecryptedBufferInfo::default(),
69            flush_policy: FlushPolicy::default(),
70        }
71    }
72
73    /// Returns a reference to the current flush policy.
74    ///
75    /// The flush policy controls whether the underlying transport is flushed
76    /// (via its `flush()` method) after writing a TLS record.
77    #[inline]
78    pub fn flush_policy(&self) -> FlushPolicy {
79        self.flush_policy
80    }
81
82    /// Replace the current flush policy with the provided one.
83    ///
84    /// This sets how and when the connection will call `flush()` on the
85    /// underlying transport after writing records.
86    #[inline]
87    pub fn set_flush_policy(&mut self, policy: FlushPolicy) {
88        self.flush_policy = policy;
89    }
90
91    /// Open a TLS connection, performing the handshake with the configuration provided when
92    /// creating the connection instance.
93    ///
94    /// Returns an error if the handshake does not proceed. If an error occurs, the connection
95    /// instance must be recreated.
96    pub fn open<Provider>(&mut self, mut context: TlsContext<Provider>) -> Result<(), TlsError>
97    where
98        Provider: CryptoProvider<CipherSuite = CipherSuite>,
99    {
100        let mut handshake: Handshake<CipherSuite> = Handshake::new();
101        if let (Ok(verifier), Some(server_name)) = (
102            context.crypto_provider.verifier(),
103            context.config.server_name,
104        ) {
105            verifier.set_hostname_verification(server_name)?;
106        }
107        let mut state = State::ClientHello;
108
109        while state != State::ApplicationData {
110            let next_state = state.process_blocking(
111                &mut self.delegate,
112                &mut handshake,
113                &mut self.record_reader,
114                &mut self.record_write_buf,
115                &mut self.key_schedule,
116                context.config,
117                &mut context.crypto_provider,
118            )?;
119            trace!("State {:?} -> {:?}", state, next_state);
120            state = next_state;
121        }
122        *self.opened.get_mut() = true;
123
124        Ok(())
125    }
126
127    /// Encrypt and send the provided slice over the connection. The connection
128    /// must be opened before writing.
129    ///
130    /// The slice may be buffered internally and not written to the connection immediately.
131    /// In this case [`Self::flush()`] should be called to force the currently buffered writes
132    /// to be written to the connection.
133    ///
134    /// Returns the number of bytes buffered/written.
135    pub fn write(&mut self, buf: &[u8]) -> Result<usize, TlsError> {
136        if self.is_opened() {
137            if !self
138                .record_write_buf
139                .contains(ClientRecordHeader::ApplicationData)
140            {
141                self.flush()?;
142                self.record_write_buf
143                    .start_record(ClientRecordHeader::ApplicationData)?;
144            }
145
146            let buffered = self.record_write_buf.append(buf);
147
148            if self.record_write_buf.is_full() {
149                self.flush()?;
150            }
151
152            Ok(buffered)
153        } else {
154            Err(TlsError::MissingHandshake)
155        }
156    }
157
158    /// Force all previously written, buffered bytes to be encoded into a tls record and written
159    /// to the connection.
160    pub fn flush(&mut self) -> Result<(), TlsError> {
161        if !self.record_write_buf.is_empty() {
162            let key_schedule = self.key_schedule.write_state();
163            let slice = self.record_write_buf.close_record(key_schedule)?;
164
165            self.delegate
166                .write_all(slice)
167                .map_err(|e| TlsError::Io(e.kind()))?;
168
169            key_schedule.increment_counter();
170
171            if self.flush_policy.flush_transport() {
172                self.flush_transport()?;
173            }
174        }
175
176        Ok(())
177    }
178
179    #[inline]
180    fn flush_transport(&mut self) -> Result<(), TlsError> {
181        self.delegate.flush().map_err(|e| TlsError::Io(e.kind()))
182    }
183
184    fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
185        self.decrypted.create_read_buffer(self.record_reader.buf)
186    }
187
188    /// Read and decrypt data filling the provided slice.
189    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, TlsError> {
190        if buf.is_empty() {
191            return Ok(0);
192        }
193        let mut buffer = self.read_buffered()?;
194
195        let len = buffer.pop_into(buf);
196        trace!("Copied {} bytes", len);
197
198        Ok(len)
199    }
200
201    /// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
202    pub fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, TlsError> {
203        if self.is_opened() {
204            while self.decrypted.is_empty() {
205                self.read_application_data()?;
206            }
207
208            Ok(self.create_read_buffer())
209        } else {
210            Err(TlsError::MissingHandshake)
211        }
212    }
213
214    fn read_application_data(&mut self) -> Result<(), TlsError> {
215        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
216        let key_schedule = self.key_schedule.read_state();
217        let record = self
218            .record_reader
219            .read_blocking(&mut self.delegate, key_schedule)?;
220
221        let mut handler = DecryptedReadHandler {
222            source_buffer: buf_ptr_range,
223            buffer_info: &mut self.decrypted,
224            is_open: self.opened.get_mut(),
225        };
226        decrypt_record(key_schedule, record, |_key_schedule, record| {
227            handler.handle(record)
228        })?;
229
230        Ok(())
231    }
232
233    fn close_internal(&mut self) -> Result<(), TlsError> {
234        self.flush()?;
235
236        let is_opened = self.is_opened();
237        let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
238        let slice = self.record_write_buf.write_record(
239            &ClientRecord::close_notify(is_opened),
240            write_key_schedule,
241            Some(read_key_schedule),
242        )?;
243
244        self.delegate
245            .write_all(slice)
246            .map_err(|e| TlsError::Io(e.kind()))?;
247
248        self.key_schedule.write_state().increment_counter();
249
250        self.flush_transport()?;
251
252        Ok(())
253    }
254
255    /// Close a connection instance, returning the ownership of the I/O provider.
256    pub fn close(mut self) -> Result<Socket, (Socket, TlsError)> {
257        match self.close_internal() {
258            Ok(()) => Ok(self.delegate),
259            Err(e) => Err((self.delegate, e)),
260        }
261    }
262
263    pub fn split(
264        &mut self,
265    ) -> (
266        TlsReader<'_, Socket, CipherSuite>,
267        TlsWriter<'_, Socket, CipherSuite>,
268    )
269    where
270        Socket: Clone,
271    {
272        let (wks, rks) = self.key_schedule.as_split();
273
274        let reader = TlsReader {
275            opened: &self.opened,
276            delegate: self.delegate.clone(),
277            key_schedule: rks,
278            record_reader: self.record_reader.reborrow_mut(),
279            decrypted: &mut self.decrypted,
280        };
281        let writer = TlsWriter {
282            opened: &self.opened,
283            delegate: self.delegate.clone(),
284            key_schedule: wks,
285            record_write_buf: self.record_write_buf.reborrow_mut(),
286            flush_policy: self.flush_policy,
287        };
288
289        (reader, writer)
290    }
291}
292
293impl<'a, Socket, CipherSuite> ErrorType for TlsConnection<'a, Socket, CipherSuite>
294where
295    Socket: Read + Write + 'a,
296    CipherSuite: TlsCipherSuite + 'static,
297{
298    type Error = TlsError;
299}
300
301impl<'a, Socket, CipherSuite> Read for TlsConnection<'a, Socket, CipherSuite>
302where
303    Socket: Read + Write + 'a,
304    CipherSuite: TlsCipherSuite + 'static,
305{
306    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
307        TlsConnection::read(self, buf)
308    }
309}
310
311impl<'a, Socket, CipherSuite> BufRead for TlsConnection<'a, Socket, CipherSuite>
312where
313    Socket: Read + Write + 'a,
314    CipherSuite: TlsCipherSuite + 'static,
315{
316    fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
317        self.read_buffered().map(|mut buf| buf.peek_all())
318    }
319
320    fn consume(&mut self, amt: usize) {
321        self.create_read_buffer().pop(amt);
322    }
323}
324
325impl<'a, Socket, CipherSuite> Write for TlsConnection<'a, Socket, CipherSuite>
326where
327    Socket: Read + Write + 'a,
328    CipherSuite: TlsCipherSuite + 'static,
329{
330    fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
331        TlsConnection::write(self, buf)
332    }
333
334    fn flush(&mut self) -> Result<(), Self::Error> {
335        TlsConnection::flush(self)
336    }
337}
338
339pub struct TlsReader<'a, Socket, CipherSuite>
340where
341    CipherSuite: TlsCipherSuite + 'static,
342{
343    opened: &'a AtomicBool,
344    delegate: Socket,
345    key_schedule: &'a mut ReadKeySchedule<CipherSuite>,
346    record_reader: RecordReaderBorrowMut<'a>,
347    decrypted: &'a mut DecryptedBufferInfo,
348}
349
350impl<Socket, CipherSuite> AsRef<Socket> for TlsReader<'_, Socket, CipherSuite>
351where
352    CipherSuite: TlsCipherSuite + 'static,
353{
354    fn as_ref(&self) -> &Socket {
355        &self.delegate
356    }
357}
358
359impl<'a, Socket, CipherSuite> TlsReader<'a, Socket, CipherSuite>
360where
361    Socket: Read + 'a,
362    CipherSuite: TlsCipherSuite + 'static,
363{
364    fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
365        self.decrypted.create_read_buffer(self.record_reader.buf)
366    }
367
368    /// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
369    pub fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, TlsError> {
370        if self.opened.load(Ordering::Acquire) {
371            while self.decrypted.is_empty() {
372                self.read_application_data()?;
373            }
374
375            Ok(self.create_read_buffer())
376        } else {
377            Err(TlsError::MissingHandshake)
378        }
379    }
380
381    fn read_application_data(&mut self) -> Result<(), TlsError> {
382        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
383        let record = self
384            .record_reader
385            .read_blocking(&mut self.delegate, self.key_schedule)?;
386
387        let mut opened = self.opened.load(Ordering::Acquire);
388        let mut handler = DecryptedReadHandler {
389            source_buffer: buf_ptr_range,
390            buffer_info: self.decrypted,
391            is_open: &mut opened,
392        };
393        let result = decrypt_record(self.key_schedule, record, |_key_schedule, record| {
394            handler.handle(record)
395        });
396
397        if !opened {
398            self.opened.store(false, Ordering::Release);
399        }
400        result
401    }
402}
403
404pub struct TlsWriter<'a, Socket, CipherSuite>
405where
406    CipherSuite: TlsCipherSuite + 'static,
407{
408    opened: &'a AtomicBool,
409    delegate: Socket,
410    key_schedule: &'a mut WriteKeySchedule<CipherSuite>,
411    record_write_buf: WriteBufferBorrowMut<'a>,
412    flush_policy: FlushPolicy,
413}
414
415impl<'a, Socket, CipherSuite> TlsWriter<'a, Socket, CipherSuite>
416where
417    Socket: Write + 'a,
418    CipherSuite: TlsCipherSuite + 'static,
419{
420    fn flush_transport(&mut self) -> Result<(), TlsError> {
421        self.delegate.flush().map_err(|e| TlsError::Io(e.kind()))
422    }
423}
424
425impl<Socket, CipherSuite> AsRef<Socket> for TlsWriter<'_, Socket, CipherSuite>
426where
427    CipherSuite: TlsCipherSuite + 'static,
428{
429    fn as_ref(&self) -> &Socket {
430        &self.delegate
431    }
432}
433
434impl<Socket, CipherSuite> ErrorType for TlsWriter<'_, Socket, CipherSuite>
435where
436    CipherSuite: TlsCipherSuite + 'static,
437{
438    type Error = TlsError;
439}
440
441impl<Socket, CipherSuite> ErrorType for TlsReader<'_, Socket, CipherSuite>
442where
443    CipherSuite: TlsCipherSuite + 'static,
444{
445    type Error = TlsError;
446}
447
448impl<'a, Socket, CipherSuite> Read for TlsReader<'a, Socket, CipherSuite>
449where
450    Socket: Read + 'a,
451    CipherSuite: TlsCipherSuite + 'static,
452{
453    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
454        if buf.is_empty() {
455            return Ok(0);
456        }
457        let mut buffer = self.read_buffered()?;
458
459        let len = buffer.pop_into(buf);
460        trace!("Copied {} bytes", len);
461
462        Ok(len)
463    }
464}
465
466impl<'a, Socket, CipherSuite> BufRead for TlsReader<'a, Socket, CipherSuite>
467where
468    Socket: Read + 'a,
469    CipherSuite: TlsCipherSuite + 'static,
470{
471    fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
472        self.read_buffered().map(|mut buf| buf.peek_all())
473    }
474
475    fn consume(&mut self, amt: usize) {
476        self.create_read_buffer().pop(amt);
477    }
478}
479
480impl<'a, Socket, CipherSuite> Write for TlsWriter<'a, Socket, CipherSuite>
481where
482    Socket: Write + 'a,
483    CipherSuite: TlsCipherSuite + 'static,
484{
485    fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
486        if self.opened.load(Ordering::Acquire) {
487            if !self
488                .record_write_buf
489                .contains(ClientRecordHeader::ApplicationData)
490            {
491                self.flush()?;
492                self.record_write_buf
493                    .start_record(ClientRecordHeader::ApplicationData)?;
494            }
495
496            let buffered = self.record_write_buf.append(buf);
497
498            if self.record_write_buf.is_full() {
499                self.flush()?;
500            }
501
502            Ok(buffered)
503        } else {
504            Err(TlsError::MissingHandshake)
505        }
506    }
507
508    fn flush(&mut self) -> Result<(), Self::Error> {
509        if !self.record_write_buf.is_empty() {
510            let slice = self.record_write_buf.close_record(self.key_schedule)?;
511
512            self.delegate
513                .write_all(slice)
514                .map_err(|e| TlsError::Io(e.kind()))?;
515
516            self.key_schedule.increment_counter();
517
518            if self.flush_policy.flush_transport() {
519                self.flush_transport()?;
520            }
521        }
522
523        Ok(())
524    }
525}