1use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
2use crate::common::decrypted_read_handler::DecryptedReadHandler;
3use crate::connection::*;
4use crate::key_schedule::KeySchedule;
5use crate::key_schedule::{ReadKeySchedule, SharedState, WriteKeySchedule};
6use crate::read_buffer::ReadBuffer;
7use crate::record::{ClientRecord, ClientRecordHeader};
8use crate::record_reader::RecordReader;
9use crate::split::{SplitState, SplitStateContainer};
10use crate::write_buffer::WriteBuffer;
11use crate::TlsError;
12use embedded_io::Error as _;
13use embedded_io::ErrorType;
14use embedded_io_async::{BufRead, Read as AsyncRead, Write as AsyncWrite};
15use rand_core::{CryptoRng, RngCore};
16
17pub use crate::config::*;
18#[cfg(feature = "std")]
19pub use crate::split::ManagedSplitState;
20pub use crate::split::SplitConnectionState;
21
/// An async TLS client connection running over a `Socket` that implements both
/// [`embedded_io_async::Read`] and [`embedded_io_async::Write`].
///
/// The connection starts closed (see `new`); `open` drives the handshake and only then do
/// `read`/`write` accept application data. Record I/O happens in two caller-supplied
/// buffers borrowed for `'a`.
pub struct TlsConnection<'a, Socket, CipherSuite>
where
    Socket: AsyncRead + AsyncWrite + 'a,
    CipherSuite: TlsCipherSuite + 'static,
{
    // Underlying transport that TLS records are read from and written to.
    delegate: Socket,
    // True once `open` has finished the handshake. It is also handed to the read
    // handler as `is_open`, which may clear it (presumably on a close_notify alert;
    // confirm in DecryptedReadHandler).
    opened: bool,
    // Combined read/write key schedule; split into halves by `split_with`.
    key_schedule: KeySchedule<CipherSuite>,
    // Reads incoming TLS records into the caller-provided read buffer.
    record_reader: RecordReader<'a, CipherSuite>,
    // Assembles outgoing records in the caller-provided write buffer.
    record_write_buf: WriteBuffer<'a>,
    // Bookkeeping for decrypted application data still held in the read buffer.
    decrypted: DecryptedBufferInfo,
}
37
38impl<'a, Socket, CipherSuite> TlsConnection<'a, Socket, CipherSuite>
39where
40 Socket: AsyncRead + AsyncWrite + 'a,
41 CipherSuite: TlsCipherSuite + 'static,
42{
43 pub fn new(
55 delegate: Socket,
56 record_read_buf: &'a mut [u8],
57 record_write_buf: &'a mut [u8],
58 ) -> Self {
59 Self {
60 delegate,
61 opened: false,
62 key_schedule: KeySchedule::new(),
63 record_reader: RecordReader::new(record_read_buf),
64 record_write_buf: WriteBuffer::new(record_write_buf),
65 decrypted: DecryptedBufferInfo::default(),
66 }
67 }
68
69 pub async fn open<'v, RNG, Verifier>(
75 &mut self,
76 context: TlsContext<'v, CipherSuite, RNG>,
77 ) -> Result<(), TlsError>
78 where
79 RNG: CryptoRng + RngCore,
80 Verifier: TlsVerifier<'v, CipherSuite>,
81 {
82 let mut handshake: Handshake<CipherSuite, Verifier> =
83 Handshake::new(Verifier::new(context.config.server_name));
84 let mut state = State::ClientHello;
85
86 while state != State::ApplicationData {
87 let next_state = state
88 .process(
89 &mut self.delegate,
90 &mut handshake,
91 &mut self.record_reader,
92 &mut self.record_write_buf,
93 &mut self.key_schedule,
94 context.config,
95 context.rng,
96 )
97 .await?;
98 trace!("State {:?} -> {:?}", state, next_state);
99 state = next_state;
100 }
101 self.opened = true;
102
103 Ok(())
104 }
105
106 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, TlsError> {
115 if self.opened {
116 if !self
117 .record_write_buf
118 .contains(ClientRecordHeader::ApplicationData)
119 {
120 self.flush().await?;
121 self.record_write_buf
122 .start_record(ClientRecordHeader::ApplicationData)?;
123 }
124
125 let buffered = self.record_write_buf.append(buf);
126
127 if self.record_write_buf.is_full() {
128 self.flush().await?;
129 }
130
131 Ok(buffered)
132 } else {
133 Err(TlsError::MissingHandshake)
134 }
135 }
136
137 pub async fn flush(&mut self) -> Result<(), TlsError> {
140 if !self.record_write_buf.is_empty() {
141 let key_schedule = self.key_schedule.write_state();
142 let slice = self.record_write_buf.close_record(key_schedule)?;
143
144 self.delegate
145 .write_all(slice)
146 .await
147 .map_err(|e| TlsError::Io(e.kind()))?;
148
149 key_schedule.increment_counter();
150
151 self.delegate
152 .flush()
153 .await
154 .map_err(|e| TlsError::Io(e.kind()))?;
155 }
156
157 Ok(())
158 }
159
160 fn create_read_buffer(&mut self) -> ReadBuffer {
161 self.decrypted.create_read_buffer(self.record_reader.buf)
162 }
163
164 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, TlsError> {
166 if buf.is_empty() {
167 return Ok(0);
168 }
169 let mut buffer = self.read_buffered().await?;
170
171 let len = buffer.pop_into(buf);
172 trace!("Copied {} bytes", len);
173
174 Ok(len)
175 }
176
177 pub async fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
179 if self.opened {
180 while self.decrypted.is_empty() {
181 self.read_application_data().await?;
182 }
183
184 Ok(self.create_read_buffer())
185 } else {
186 Err(TlsError::MissingHandshake)
187 }
188 }
189
190 async fn read_application_data(&mut self) -> Result<(), TlsError> {
191 let buf_ptr_range = self.record_reader.buf.as_ptr_range();
192 let record = self
193 .record_reader
194 .read(&mut self.delegate, self.key_schedule.read_state())
195 .await?;
196
197 let mut handler = DecryptedReadHandler {
198 source_buffer: buf_ptr_range,
199 buffer_info: &mut self.decrypted,
200 is_open: &mut self.opened,
201 };
202 decrypt_record(
203 self.key_schedule.read_state(),
204 record,
205 |_key_schedule, record| handler.handle(record),
206 )?;
207
208 Ok(())
209 }
210
211 async fn close_internal(&mut self) -> Result<(), TlsError> {
213 self.flush().await?;
214
215 let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
216 let slice = self.record_write_buf.write_record(
217 &ClientRecord::close_notify(self.opened),
218 write_key_schedule,
219 Some(read_key_schedule),
220 )?;
221
222 self.delegate
223 .write_all(slice)
224 .await
225 .map_err(|e| TlsError::Io(e.kind()))?;
226
227 self.key_schedule.write_state().increment_counter();
228
229 self.flush().await
230 }
231
232 pub async fn close(mut self) -> Result<Socket, (Socket, TlsError)> {
234 match self.close_internal().await {
235 Ok(()) => Ok(self.delegate),
236 Err(e) => Err((self.delegate, e)),
237 }
238 }
239
240 #[cfg(feature = "std")]
241 pub fn split(
242 self,
243 ) -> (
244 TlsReader<'a, Socket, CipherSuite, ManagedSplitState>,
245 TlsWriter<'a, Socket, CipherSuite, ManagedSplitState>,
246 )
247 where
248 Socket: Clone,
249 {
250 self.split_with(ManagedSplitState::new())
251 }
252
253 pub fn split_with<StateContainer>(
254 self,
255 state: StateContainer,
256 ) -> (
257 TlsReader<'a, Socket, CipherSuite, StateContainer::State>,
258 TlsWriter<'a, Socket, CipherSuite, StateContainer::State>,
259 )
260 where
261 Socket: Clone,
262 StateContainer: SplitStateContainer,
263 {
264 let state = state.state();
265 state.set_open(self.opened);
266
267 let (shared, wks, rks) = self.key_schedule.split();
268
269 let reader = TlsReader {
270 state: state.clone(),
271 delegate: self.delegate.clone(),
272 key_schedule: rks,
273 record_reader: self.record_reader,
274 decrypted: self.decrypted,
275 };
276 let writer = TlsWriter {
277 state,
278 delegate: self.delegate,
279 key_schedule_shared: shared,
280 key_schedule: wks,
281 record_write_buf: self.record_write_buf,
282 };
283
284 (reader, writer)
285 }
286
287 pub fn unsplit<State>(
288 reader: TlsReader<'a, Socket, CipherSuite, State>,
289 writer: TlsWriter<'a, Socket, CipherSuite, State>,
290 ) -> Self
291 where
292 Socket: Clone,
293 State: SplitState,
294 {
295 debug_assert!(reader.state.same(&writer.state));
296
297 TlsConnection {
298 delegate: writer.delegate,
299 opened: writer.state.is_open(),
300 key_schedule: KeySchedule::unsplit(
301 writer.key_schedule_shared,
302 writer.key_schedule,
303 reader.key_schedule,
304 ),
305 record_reader: reader.record_reader,
306 record_write_buf: writer.record_write_buf,
307 decrypted: reader.decrypted,
308 }
309 }
310}
311
312impl<'a, Socket, CipherSuite> ErrorType for TlsConnection<'a, Socket, CipherSuite>
313where
314 Socket: AsyncRead + AsyncWrite + 'a,
315 CipherSuite: TlsCipherSuite + 'static,
316{
317 type Error = TlsError;
318}
319
320impl<'a, Socket, CipherSuite> AsyncRead for TlsConnection<'a, Socket, CipherSuite>
321where
322 Socket: AsyncRead + AsyncWrite + 'a,
323 CipherSuite: TlsCipherSuite + 'static,
324{
325 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
326 TlsConnection::read(self, buf).await
327 }
328}
329
330impl<'a, Socket, CipherSuite> BufRead for TlsConnection<'a, Socket, CipherSuite>
331where
332 Socket: AsyncRead + AsyncWrite + 'a,
333 CipherSuite: TlsCipherSuite + 'static,
334{
335 async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
336 self.read_buffered().await.map(|mut buf| buf.peek_all())
337 }
338
339 fn consume(&mut self, amt: usize) {
340 self.create_read_buffer().pop(amt);
341 }
342}
343
344impl<'a, Socket, CipherSuite> AsyncWrite for TlsConnection<'a, Socket, CipherSuite>
345where
346 Socket: AsyncRead + AsyncWrite + 'a,
347 CipherSuite: TlsCipherSuite + 'static,
348{
349 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
350 TlsConnection::write(self, buf).await
351 }
352
353 async fn flush(&mut self) -> Result<(), Self::Error> {
354 TlsConnection::flush(self).await
355 }
356}
357
/// Read half of a [`TlsConnection`], produced by `TlsConnection::split_with`
/// (or `split` with the `std` feature) and recombined with `TlsConnection::unsplit`.
pub struct TlsReader<'a, Socket, CipherSuite, State>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    // Split state paired with the matching writer (a clone is given to each half by
    // `split_with`); holds the open flag.
    state: State,
    // This half's socket handle: a clone of the connection's socket made by `split_with`.
    delegate: Socket,
    // Read-direction key schedule passed to `record_reader.read` and `decrypt_record`.
    key_schedule: ReadKeySchedule<CipherSuite>,
    // Reads incoming TLS records into the read buffer.
    record_reader: RecordReader<'a, CipherSuite>,
    // Bookkeeping for decrypted application data still held in the read buffer.
    decrypted: DecryptedBufferInfo,
}
368
369impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsReader<'a, Socket, CipherSuite, State>
370where
371 CipherSuite: TlsCipherSuite + 'static,
372{
373 fn as_ref(&self) -> &Socket {
374 &self.delegate
375 }
376}
377
378impl<'a, Socket, CipherSuite, State> TlsReader<'a, Socket, CipherSuite, State>
379where
380 Socket: AsyncRead + 'a,
381 CipherSuite: TlsCipherSuite + 'static,
382 State: SplitState,
383{
384 fn create_read_buffer(&mut self) -> ReadBuffer {
385 self.decrypted.create_read_buffer(self.record_reader.buf)
386 }
387
388 pub async fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
390 if self.state.is_open() {
391 while self.decrypted.is_empty() {
392 self.read_application_data().await?;
393 }
394
395 Ok(self.create_read_buffer())
396 } else {
397 Err(TlsError::MissingHandshake)
398 }
399 }
400
401 async fn read_application_data(&mut self) -> Result<(), TlsError> {
402 let buf_ptr_range = self.record_reader.buf.as_ptr_range();
403 let record = self
404 .record_reader
405 .read(&mut self.delegate, &mut self.key_schedule)
406 .await?;
407
408 let mut opened = self.state.is_open();
409 let mut handler = DecryptedReadHandler {
410 source_buffer: buf_ptr_range,
411 buffer_info: &mut self.decrypted,
412 is_open: &mut opened,
413 };
414 let result = decrypt_record(&mut self.key_schedule, record, |_key_schedule, record| {
415 handler.handle(record)
416 });
417
418 if !opened {
419 self.state.set_open(false);
420 }
421 result
422 }
423}
424
/// Write half of a [`TlsConnection`], produced by `TlsConnection::split_with`
/// (or `split` with the `std` feature) and recombined with `TlsConnection::unsplit`.
pub struct TlsWriter<'a, Socket, CipherSuite, State>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    // Split state paired with the matching reader; `write` checks its open flag.
    state: State,
    // The connection's original socket, moved here by `split_with`.
    delegate: Socket,
    // Shared key-schedule part produced by `KeySchedule::split`. Not used by the writer's
    // methods in this file; `TlsConnection::unsplit` passes it back to `KeySchedule::unsplit`.
    key_schedule_shared: SharedState<CipherSuite>,
    // Write-direction key schedule used to close records and advance the record counter.
    key_schedule: WriteKeySchedule<CipherSuite>,
    // Assembles outgoing records in the caller-provided write buffer.
    record_write_buf: WriteBuffer<'a>,
}
435
436impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsWriter<'a, Socket, CipherSuite, State>
437where
438 CipherSuite: TlsCipherSuite + 'static,
439{
440 fn as_ref(&self) -> &Socket {
441 &self.delegate
442 }
443}
444
445impl<'a, Socket, CipherSuite, State> ErrorType for TlsWriter<'a, Socket, CipherSuite, State>
446where
447 CipherSuite: TlsCipherSuite + 'static,
448{
449 type Error = TlsError;
450}
451
452impl<'a, Socket, CipherSuite, State> ErrorType for TlsReader<'a, Socket, CipherSuite, State>
453where
454 CipherSuite: TlsCipherSuite + 'static,
455{
456 type Error = TlsError;
457}
458
459impl<'a, Socket, CipherSuite, State> AsyncRead for TlsReader<'a, Socket, CipherSuite, State>
460where
461 Socket: AsyncRead + 'a,
462 CipherSuite: TlsCipherSuite + 'static,
463 State: SplitState,
464{
465 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
466 if buf.is_empty() {
467 return Ok(0);
468 }
469 let mut buffer = self.read_buffered().await?;
470
471 let len = buffer.pop_into(buf);
472 trace!("Copied {} bytes", len);
473
474 Ok(len)
475 }
476}
477
478impl<'a, Socket, CipherSuite, State> BufRead for TlsReader<'a, Socket, CipherSuite, State>
479where
480 Socket: AsyncRead + 'a,
481 CipherSuite: TlsCipherSuite + 'static,
482 State: SplitState,
483{
484 async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
485 self.read_buffered().await.map(|mut buf| buf.peek_all())
486 }
487
488 fn consume(&mut self, amt: usize) {
489 self.create_read_buffer().pop(amt);
490 }
491}
492
493impl<'a, Socket, CipherSuite, State> AsyncWrite for TlsWriter<'a, Socket, CipherSuite, State>
494where
495 Socket: AsyncWrite + 'a,
496 CipherSuite: TlsCipherSuite + 'static,
497 State: SplitState,
498{
499 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
500 if self.state.is_open() {
501 if !self
502 .record_write_buf
503 .contains(ClientRecordHeader::ApplicationData)
504 {
505 self.flush().await?;
506 self.record_write_buf
507 .start_record(ClientRecordHeader::ApplicationData)?;
508 }
509
510 let buffered = self.record_write_buf.append(buf);
511
512 if self.record_write_buf.is_full() {
513 self.flush().await?;
514 }
515
516 Ok(buffered)
517 } else {
518 Err(TlsError::MissingHandshake)
519 }
520 }
521
522 async fn flush(&mut self) -> Result<(), Self::Error> {
523 if !self.record_write_buf.is_empty() {
524 let slice = self.record_write_buf.close_record(&mut self.key_schedule)?;
525
526 self.delegate
527 .write_all(slice)
528 .await
529 .map_err(|e| TlsError::Io(e.kind()))?;
530
531 self.key_schedule.increment_counter();
532
533 self.delegate
534 .flush()
535 .await
536 .map_err(|e| TlsError::Io(e.kind()))?;
537 }
538
539 Ok(())
540 }
541}