1use core::sync::atomic::{AtomicBool, Ordering};
2
3use crate::TlsError;
4use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
5use crate::common::decrypted_read_handler::DecryptedReadHandler;
6use crate::connection::{Handshake, State, decrypt_record};
7use crate::flush_policy::FlushPolicy;
8use crate::key_schedule::KeySchedule;
9use crate::key_schedule::{ReadKeySchedule, WriteKeySchedule};
10use crate::read_buffer::ReadBuffer;
11use crate::record::{ClientRecord, ClientRecordHeader};
12use crate::record_reader::{RecordReader, RecordReaderBorrowMut};
13use crate::write_buffer::{WriteBuffer, WriteBufferBorrowMut};
14use embedded_io::Error as _;
15use embedded_io::ErrorType;
16use embedded_io_async::{BufRead, Read as AsyncRead, Write as AsyncWrite};
17
18pub use crate::config::*;
19
/// Asynchronous TLS connection layered on top of a `Socket` transport.
///
/// Bundles the transport with the key schedule, the record read and write
/// buffers (both borrowed for `'a`), the bookkeeping for already-decrypted
/// data, and the policy that decides whether the transport is flushed after a
/// record has been written.
pub struct TlsConnection<'a, Socket, CipherSuite>
where
    Socket: AsyncRead + AsyncWrite + 'a,
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Underlying transport that records are read from and written to.
    delegate: Socket,
    /// `true` once `open` has completed; `read_buffered` and `write` fail with
    /// `TlsError::MissingHandshake` while it is `false`.
    opened: AtomicBool,
    /// Key schedule, exposing separate read and write states.
    key_schedule: KeySchedule<CipherSuite>,
    /// Reader for incoming records, backed by the caller-provided read buffer.
    record_reader: RecordReader<'a>,
    /// Buffer in which outgoing records are assembled before being written.
    record_write_buf: WriteBuffer<'a>,
    /// Bookkeeping for decrypted data that has not been handed out yet.
    decrypted: DecryptedBufferInfo,
    /// Decides whether `flush` also flushes the transport.
    flush_policy: FlushPolicy,
}
36
37impl<'a, Socket, CipherSuite> TlsConnection<'a, Socket, CipherSuite>
38where
39 Socket: AsyncRead + AsyncWrite + 'a,
40 CipherSuite: TlsCipherSuite + 'static,
41{
42 pub fn is_opened(&mut self) -> bool {
43 *self.opened.get_mut()
44 }
45 pub fn new(
57 delegate: Socket,
58 record_read_buf: &'a mut [u8],
59 record_write_buf: &'a mut [u8],
60 ) -> Self {
61 Self {
62 delegate,
63 opened: AtomicBool::new(false),
64 key_schedule: KeySchedule::new(),
65 record_reader: RecordReader::new(record_read_buf),
66 record_write_buf: WriteBuffer::new(record_write_buf),
67 decrypted: DecryptedBufferInfo::default(),
68 flush_policy: FlushPolicy::default(),
69 }
70 }
71
72 #[inline]
77 pub fn flush_policy(&self) -> FlushPolicy {
78 self.flush_policy
79 }
80
81 #[inline]
86 pub fn set_flush_policy(&mut self, policy: FlushPolicy) {
87 self.flush_policy = policy;
88 }
89
90 pub async fn open<Provider>(
96 &mut self,
97 mut context: TlsContext<'_, Provider>,
98 ) -> Result<(), TlsError>
99 where
100 Provider: CryptoProvider<CipherSuite = CipherSuite>,
101 {
102 let mut handshake: Handshake<CipherSuite> = Handshake::new();
103 if let (Ok(verifier), Some(server_name)) = (
104 context.crypto_provider.verifier(),
105 context.config.server_name,
106 ) {
107 verifier.set_hostname_verification(server_name)?;
108 }
109 let mut state = State::ClientHello;
110
111 while state != State::ApplicationData {
112 let next_state = state
113 .process(
114 &mut self.delegate,
115 &mut handshake,
116 &mut self.record_reader,
117 &mut self.record_write_buf,
118 &mut self.key_schedule,
119 context.config,
120 &mut context.crypto_provider,
121 )
122 .await?;
123 trace!("State {:?} -> {:?}", state, next_state);
124 state = next_state;
125 }
126 *self.opened.get_mut() = true;
127
128 Ok(())
129 }
130
131 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, TlsError> {
140 if self.is_opened() {
141 if !self
142 .record_write_buf
143 .contains(ClientRecordHeader::ApplicationData)
144 {
145 self.flush().await?;
146 self.record_write_buf
147 .start_record(ClientRecordHeader::ApplicationData)?;
148 }
149
150 let buffered = self.record_write_buf.append(buf);
151
152 if self.record_write_buf.is_full() {
153 self.flush().await?;
154 }
155
156 Ok(buffered)
157 } else {
158 Err(TlsError::MissingHandshake)
159 }
160 }
161
162 pub async fn flush(&mut self) -> Result<(), TlsError> {
165 if !self.record_write_buf.is_empty() {
166 let key_schedule = self.key_schedule.write_state();
167 let slice = self.record_write_buf.close_record(key_schedule)?;
168
169 self.delegate
170 .write_all(slice)
171 .await
172 .map_err(|e| TlsError::Io(e.kind()))?;
173
174 key_schedule.increment_counter();
175
176 if self.flush_policy.flush_transport() {
177 self.flush_transport().await?;
178 }
179 }
180
181 Ok(())
182 }
183
184 #[inline]
185 async fn flush_transport(&mut self) -> Result<(), TlsError> {
186 self.delegate
187 .flush()
188 .await
189 .map_err(|e| TlsError::Io(e.kind()))
190 }
191
192 fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
193 self.decrypted.create_read_buffer(self.record_reader.buf)
194 }
195
196 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, TlsError> {
198 if buf.is_empty() {
199 return Ok(0);
200 }
201 let mut buffer = self.read_buffered().await?;
202
203 let len = buffer.pop_into(buf);
204 trace!("Copied {} bytes", len);
205
206 Ok(len)
207 }
208
209 pub async fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, TlsError> {
211 if self.is_opened() {
212 while self.decrypted.is_empty() {
213 self.read_application_data().await?;
214 }
215
216 Ok(self.create_read_buffer())
217 } else {
218 Err(TlsError::MissingHandshake)
219 }
220 }
221
    /// Reads one record from the transport and passes it through
    /// `decrypt_record` to a [`DecryptedReadHandler`].
    ///
    /// # Errors
    ///
    /// Propagates any [`TlsError`] from reading the record or from
    /// `decrypt_record`.
    async fn read_application_data(&mut self) -> Result<(), TlsError> {
        // Capture the buffer's address range before `read` borrows the reader;
        // the handler receives it as `source_buffer`.
        // NOTE(review): presumably used to locate the decrypted payload inside
        // the record buffer — confirm in `DecryptedReadHandler`.
        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
        let record = self
            .record_reader
            .read(&mut self.delegate, self.key_schedule.read_state())
            .await?;

        // The handler gets mutable access to both the decrypted-data
        // bookkeeping and the `opened` flag, so it may update either.
        let mut handler = DecryptedReadHandler {
            source_buffer: buf_ptr_range,
            buffer_info: &mut self.decrypted,
            is_open: self.opened.get_mut(),
        };
        decrypt_record(
            self.key_schedule.read_state(),
            record,
            |_key_schedule, record| handler.handle(record),
        )?;

        Ok(())
    }
242
243 async fn close_internal(&mut self) -> Result<(), TlsError> {
245 self.flush().await?;
246
247 let is_opened = self.is_opened();
248 let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
249 let slice = self.record_write_buf.write_record(
250 &ClientRecord::close_notify(is_opened),
251 write_key_schedule,
252 Some(read_key_schedule),
253 )?;
254
255 self.delegate
256 .write_all(slice)
257 .await
258 .map_err(|e| TlsError::Io(e.kind()))?;
259
260 self.key_schedule.write_state().increment_counter();
261
262 self.flush_transport().await
263 }
264
265 pub async fn close(mut self) -> Result<Socket, (Socket, TlsError)> {
267 match self.close_internal().await {
268 Ok(()) => Ok(self.delegate),
269 Err(e) => Err((self.delegate, e)),
270 }
271 }
272
273 pub fn split(
274 &mut self,
275 ) -> (
276 TlsReader<'_, Socket, CipherSuite>,
277 TlsWriter<'_, Socket, CipherSuite>,
278 )
279 where
280 Socket: Clone,
281 {
282 let (wks, rks) = self.key_schedule.as_split();
283
284 let reader = TlsReader {
285 opened: &self.opened,
286 delegate: self.delegate.clone(),
287 key_schedule: rks,
288 record_reader: self.record_reader.reborrow_mut(),
289 decrypted: &mut self.decrypted,
290 };
291 let writer = TlsWriter {
292 opened: &self.opened,
293 delegate: self.delegate.clone(),
294 key_schedule: wks,
295 record_write_buf: self.record_write_buf.reborrow_mut(),
296 flush_policy: self.flush_policy,
297 };
298
299 (reader, writer)
300 }
301}
302
/// All I/O trait implementations on [`TlsConnection`] report [`TlsError`].
impl<'a, Socket, CipherSuite> ErrorType for TlsConnection<'a, Socket, CipherSuite>
where
    Socket: AsyncRead + AsyncWrite + 'a,
    CipherSuite: TlsCipherSuite + 'static,
{
    type Error = TlsError;
}
310
/// Async read support; forwards to the inherent [`TlsConnection::read`].
impl<'a, Socket, CipherSuite> AsyncRead for TlsConnection<'a, Socket, CipherSuite>
where
    Socket: AsyncRead + AsyncWrite + 'a,
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Reads decrypted application data into `buf`; see the inherent `read`.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        // Path call so the inherent method is selected, not this trait method.
        TlsConnection::read(self, buf).await
    }
}
320
321impl<'a, Socket, CipherSuite> BufRead for TlsConnection<'a, Socket, CipherSuite>
322where
323 Socket: AsyncRead + AsyncWrite + 'a,
324 CipherSuite: TlsCipherSuite + 'static,
325{
326 async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
327 self.read_buffered().await.map(|mut buf| buf.peek_all())
328 }
329
330 fn consume(&mut self, amt: usize) {
331 self.create_read_buffer().pop(amt);
332 }
333}
334
/// Async write support; forwards to the inherent [`TlsConnection::write`]
/// and [`TlsConnection::flush`].
impl<'a, Socket, CipherSuite> AsyncWrite for TlsConnection<'a, Socket, CipherSuite>
where
    Socket: AsyncRead + AsyncWrite + 'a,
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Buffers `buf` for sending; see the inherent `write`.
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        // Path call so the inherent method is selected, not this trait method.
        TlsConnection::write(self, buf).await
    }

    /// Writes out the buffered record; see the inherent `flush`.
    async fn flush(&mut self) -> Result<(), Self::Error> {
        TlsConnection::flush(self).await
    }
}
348
/// Read half of a split [`TlsConnection`].
///
/// Created by `TlsConnection::split`; borrows the connection's read-side
/// state for `'a` and owns its own transport handle.
pub struct TlsReader<'a, Socket, CipherSuite>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    /// `opened` flag shared with the connection and the writer half.
    opened: &'a AtomicBool,
    /// Transport handle that records are read from.
    delegate: Socket,
    /// Read side of the connection's key schedule.
    key_schedule: &'a mut ReadKeySchedule<CipherSuite>,
    /// Borrowed view of the connection's record reader.
    record_reader: RecordReaderBorrowMut<'a>,
    /// Bookkeeping for decrypted data that has not been handed out yet.
    decrypted: &'a mut DecryptedBufferInfo,
}
359
/// Gives shared access to the reader's transport handle.
impl<Socket, CipherSuite> AsRef<Socket> for TlsReader<'_, Socket, CipherSuite>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Returns a reference to the transport held by this reader.
    fn as_ref(&self) -> &Socket {
        &self.delegate
    }
}
368
369impl<'a, Socket, CipherSuite> TlsReader<'a, Socket, CipherSuite>
370where
371 Socket: AsyncRead + 'a,
372 CipherSuite: TlsCipherSuite + 'static,
373{
374 fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
375 self.decrypted.create_read_buffer(self.record_reader.buf)
376 }
377
378 pub async fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, TlsError> {
380 if self.opened.load(Ordering::Acquire) {
381 while self.decrypted.is_empty() {
382 self.read_application_data().await?;
383 }
384
385 Ok(self.create_read_buffer())
386 } else {
387 Err(TlsError::MissingHandshake)
388 }
389 }
390
    /// Reads one record from the transport and passes it through
    /// `decrypt_record` to a [`DecryptedReadHandler`].
    ///
    /// If the `opened` value is `false` after handling, `false` is stored into
    /// the shared flag so the writer half observes it too.
    ///
    /// # Errors
    ///
    /// Propagates any [`TlsError`] from reading the record or from
    /// `decrypt_record`.
    async fn read_application_data(&mut self) -> Result<(), TlsError> {
        // Capture the buffer's address range before `read` borrows the reader;
        // the handler receives it as `source_buffer`.
        // NOTE(review): presumably used to locate the decrypted payload inside
        // the record buffer — confirm in `DecryptedReadHandler`.
        let buf_ptr_range = self.record_reader.buf.as_ptr_range();
        let record = self
            .record_reader
            .read(&mut self.delegate, self.key_schedule)
            .await?;

        // The flag sits behind a shared `&AtomicBool`, so the handler works on
        // a local copy that is written back below.
        let mut opened = self.opened.load(Ordering::Acquire);
        let mut handler = DecryptedReadHandler {
            source_buffer: buf_ptr_range,
            buffer_info: self.decrypted,
            is_open: &mut opened,
        };
        let result = decrypt_record(self.key_schedule, record, |_key_schedule, record| {
            handler.handle(record)
        });

        // Only ever publish `false`; this path never sets the flag to `true`.
        // Done before returning so it also happens when `result` is an error.
        if !opened {
            self.opened.store(false, Ordering::Release);
        }
        result
    }
413}
414
/// Write half of a split [`TlsConnection`].
///
/// Created by `TlsConnection::split`; borrows the connection's write-side
/// state for `'a` and owns its own transport handle.
pub struct TlsWriter<'a, Socket, CipherSuite>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    /// `opened` flag shared with the connection and the reader half.
    opened: &'a AtomicBool,
    /// Transport handle that records are written to.
    delegate: Socket,
    /// Write side of the connection's key schedule.
    key_schedule: &'a mut WriteKeySchedule<CipherSuite>,
    /// Borrowed view of the connection's outgoing record buffer.
    record_write_buf: WriteBufferBorrowMut<'a>,
    /// Flush policy copied from the connection when it was split.
    flush_policy: FlushPolicy,
}
425
426impl<'a, Socket, CipherSuite> TlsWriter<'a, Socket, CipherSuite>
427where
428 Socket: AsyncWrite + 'a,
429 CipherSuite: TlsCipherSuite + 'static,
430{
431 #[inline]
432 async fn flush_transport(&mut self) -> Result<(), TlsError> {
433 self.delegate
434 .flush()
435 .await
436 .map_err(|e| TlsError::Io(e.kind()))
437 }
438}
439
/// Gives shared access to the writer's transport handle.
impl<Socket, CipherSuite> AsRef<Socket> for TlsWriter<'_, Socket, CipherSuite>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Returns a reference to the transport held by this writer.
    fn as_ref(&self) -> &Socket {
        &self.delegate
    }
}
448
/// All I/O trait implementations on [`TlsWriter`] report [`TlsError`].
impl<Socket, CipherSuite> ErrorType for TlsWriter<'_, Socket, CipherSuite>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    type Error = TlsError;
}
455
/// All I/O trait implementations on [`TlsReader`] report [`TlsError`].
impl<Socket, CipherSuite> ErrorType for TlsReader<'_, Socket, CipherSuite>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    type Error = TlsError;
}
462
463impl<'a, Socket, CipherSuite> AsyncRead for TlsReader<'a, Socket, CipherSuite>
464where
465 Socket: AsyncRead + 'a,
466 CipherSuite: TlsCipherSuite + 'static,
467{
468 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
469 if buf.is_empty() {
470 return Ok(0);
471 }
472 let mut buffer = self.read_buffered().await?;
473
474 let len = buffer.pop_into(buf);
475 trace!("Copied {} bytes", len);
476
477 Ok(len)
478 }
479}
480
481impl<'a, Socket, CipherSuite> BufRead for TlsReader<'a, Socket, CipherSuite>
482where
483 Socket: AsyncRead + 'a,
484 CipherSuite: TlsCipherSuite + 'static,
485{
486 async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
487 self.read_buffered().await.map(|mut buf| buf.peek_all())
488 }
489
490 fn consume(&mut self, amt: usize) {
491 self.create_read_buffer().pop(amt);
492 }
493}
494
495impl<'a, Socket, CipherSuite> AsyncWrite for TlsWriter<'a, Socket, CipherSuite>
496where
497 Socket: AsyncWrite + 'a,
498 CipherSuite: TlsCipherSuite + 'static,
499{
500 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
501 if self.opened.load(Ordering::Acquire) {
502 if !self
503 .record_write_buf
504 .contains(ClientRecordHeader::ApplicationData)
505 {
506 self.flush().await?;
507 self.record_write_buf
508 .start_record(ClientRecordHeader::ApplicationData)?;
509 }
510
511 let buffered = self.record_write_buf.append(buf);
512
513 if self.record_write_buf.is_full() {
514 self.flush().await?;
515 }
516
517 Ok(buffered)
518 } else {
519 Err(TlsError::MissingHandshake)
520 }
521 }
522
523 async fn flush(&mut self) -> Result<(), Self::Error> {
524 if !self.record_write_buf.is_empty() {
525 let slice = self.record_write_buf.close_record(self.key_schedule)?;
526
527 self.delegate
528 .write_all(slice)
529 .await
530 .map_err(|e| TlsError::Io(e.kind()))?;
531
532 self.key_schedule.increment_counter();
533
534 if self.flush_policy.flush_transport() {
535 self.flush_transport().await?;
536 }
537 }
538
539 Ok(())
540 }
541}