1use core::sync::atomic::Ordering;
2
3use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
4use crate::common::decrypted_read_handler::DecryptedReadHandler;
5use crate::connection::{Handshake, State, decrypt_record};
6use crate::flush_policy::FlushPolicy;
7use crate::key_schedule::KeySchedule;
8use crate::key_schedule::{ReadKeySchedule, WriteKeySchedule};
9use crate::read_buffer::ReadBuffer;
10use crate::record::{ClientRecord, ClientRecordHeader};
11use crate::record_reader::{RecordReader, RecordReaderBorrowMut};
12use crate::write_buffer::{WriteBuffer, WriteBufferBorrowMut};
13use embedded_io::Error as _;
14use embedded_io::{BufRead, ErrorType, Read, Write};
15use portable_atomic::AtomicBool;
16
17pub use crate::TlsError;
18pub use crate::config::*;
19
20pub struct TlsConnection<'a, Socket, CipherSuite>
24where
25 Socket: Read + Write + 'a,
26 CipherSuite: TlsCipherSuite + 'static,
27{
28 delegate: Socket,
29 opened: AtomicBool,
30 key_schedule: KeySchedule<CipherSuite>,
31 record_reader: RecordReader<'a>,
32 record_write_buf: WriteBuffer<'a>,
33 decrypted: DecryptedBufferInfo,
34 flush_policy: FlushPolicy,
35}
36
37impl<'a, Socket, CipherSuite> TlsConnection<'a, Socket, CipherSuite>
38where
39 Socket: Read + Write + 'a,
40 CipherSuite: TlsCipherSuite + 'static,
41{
42 fn is_opened(&mut self) -> bool {
43 *self.opened.get_mut()
44 }
45
46 pub fn new(
58 delegate: Socket,
59 record_read_buf: &'a mut [u8],
60 record_write_buf: &'a mut [u8],
61 ) -> Self {
62 Self {
63 delegate,
64 opened: AtomicBool::new(false),
65 key_schedule: KeySchedule::new(),
66 record_reader: RecordReader::new(record_read_buf),
67 record_write_buf: WriteBuffer::new(record_write_buf),
68 decrypted: DecryptedBufferInfo::default(),
69 flush_policy: FlushPolicy::default(),
70 }
71 }
72
73 #[inline]
78 pub fn flush_policy(&self) -> FlushPolicy {
79 self.flush_policy
80 }
81
82 #[inline]
87 pub fn set_flush_policy(&mut self, policy: FlushPolicy) {
88 self.flush_policy = policy;
89 }
90
91 pub fn open<Provider>(&mut self, mut context: TlsContext<Provider>) -> Result<(), TlsError>
97 where
98 Provider: CryptoProvider<CipherSuite = CipherSuite>,
99 {
100 let mut handshake: Handshake<CipherSuite> = Handshake::new();
101 if let (Ok(verifier), Some(server_name)) = (
102 context.crypto_provider.verifier(),
103 context.config.server_name,
104 ) {
105 verifier.set_hostname_verification(server_name)?;
106 }
107 let mut state = State::ClientHello;
108
109 while state != State::ApplicationData {
110 let next_state = state.process_blocking(
111 &mut self.delegate,
112 &mut handshake,
113 &mut self.record_reader,
114 &mut self.record_write_buf,
115 &mut self.key_schedule,
116 context.config,
117 &mut context.crypto_provider,
118 )?;
119 trace!("State {:?} -> {:?}", state, next_state);
120 state = next_state;
121 }
122 *self.opened.get_mut() = true;
123
124 Ok(())
125 }
126
127 pub fn write(&mut self, buf: &[u8]) -> Result<usize, TlsError> {
136 if self.is_opened() {
137 if !self
138 .record_write_buf
139 .contains(ClientRecordHeader::ApplicationData)
140 {
141 self.flush()?;
142 self.record_write_buf
143 .start_record(ClientRecordHeader::ApplicationData)?;
144 }
145
146 let buffered = self.record_write_buf.append(buf);
147
148 if self.record_write_buf.is_full() {
149 self.flush()?;
150 }
151
152 Ok(buffered)
153 } else {
154 Err(TlsError::MissingHandshake)
155 }
156 }
157
158 pub fn flush(&mut self) -> Result<(), TlsError> {
161 if !self.record_write_buf.is_empty() {
162 let key_schedule = self.key_schedule.write_state();
163 let slice = self.record_write_buf.close_record(key_schedule)?;
164
165 self.delegate
166 .write_all(slice)
167 .map_err(|e| TlsError::Io(e.kind()))?;
168
169 key_schedule.increment_counter();
170
171 if self.flush_policy.flush_transport() {
172 self.flush_transport()?;
173 }
174 }
175
176 Ok(())
177 }
178
179 #[inline]
180 fn flush_transport(&mut self) -> Result<(), TlsError> {
181 self.delegate.flush().map_err(|e| TlsError::Io(e.kind()))
182 }
183
184 fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
185 self.decrypted.create_read_buffer(self.record_reader.buf)
186 }
187
188 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, TlsError> {
190 if buf.is_empty() {
191 return Ok(0);
192 }
193 let mut buffer = self.read_buffered()?;
194
195 let len = buffer.pop_into(buf);
196 trace!("Copied {} bytes", len);
197
198 Ok(len)
199 }
200
201 pub fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, TlsError> {
203 if self.is_opened() {
204 while self.decrypted.is_empty() {
205 self.read_application_data()?;
206 }
207
208 Ok(self.create_read_buffer())
209 } else {
210 Err(TlsError::MissingHandshake)
211 }
212 }
213
214 fn read_application_data(&mut self) -> Result<(), TlsError> {
215 let buf_ptr_range = self.record_reader.buf.as_ptr_range();
216 let key_schedule = self.key_schedule.read_state();
217 let record = self
218 .record_reader
219 .read_blocking(&mut self.delegate, key_schedule)?;
220
221 let mut handler = DecryptedReadHandler {
222 source_buffer: buf_ptr_range,
223 buffer_info: &mut self.decrypted,
224 is_open: self.opened.get_mut(),
225 };
226 decrypt_record(key_schedule, record, |_key_schedule, record| {
227 handler.handle(record)
228 })?;
229
230 Ok(())
231 }
232
233 fn close_internal(&mut self) -> Result<(), TlsError> {
234 self.flush()?;
235
236 let is_opened = self.is_opened();
237 let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
238 let slice = self.record_write_buf.write_record(
239 &ClientRecord::close_notify(is_opened),
240 write_key_schedule,
241 Some(read_key_schedule),
242 )?;
243
244 self.delegate
245 .write_all(slice)
246 .map_err(|e| TlsError::Io(e.kind()))?;
247
248 self.key_schedule.write_state().increment_counter();
249
250 self.flush_transport()?;
251
252 Ok(())
253 }
254
255 pub fn close(mut self) -> Result<Socket, (Socket, TlsError)> {
257 match self.close_internal() {
258 Ok(()) => Ok(self.delegate),
259 Err(e) => Err((self.delegate, e)),
260 }
261 }
262
263 pub fn split(
264 &mut self,
265 ) -> (
266 TlsReader<'_, Socket, CipherSuite>,
267 TlsWriter<'_, Socket, CipherSuite>,
268 )
269 where
270 Socket: Clone,
271 {
272 let (wks, rks) = self.key_schedule.as_split();
273
274 let reader = TlsReader {
275 opened: &self.opened,
276 delegate: self.delegate.clone(),
277 key_schedule: rks,
278 record_reader: self.record_reader.reborrow_mut(),
279 decrypted: &mut self.decrypted,
280 };
281 let writer = TlsWriter {
282 opened: &self.opened,
283 delegate: self.delegate.clone(),
284 key_schedule: wks,
285 record_write_buf: self.record_write_buf.reborrow_mut(),
286 flush_policy: self.flush_policy,
287 };
288
289 (reader, writer)
290 }
291}
292
293impl<'a, Socket, CipherSuite> ErrorType for TlsConnection<'a, Socket, CipherSuite>
294where
295 Socket: Read + Write + 'a,
296 CipherSuite: TlsCipherSuite + 'static,
297{
298 type Error = TlsError;
299}
300
301impl<'a, Socket, CipherSuite> Read for TlsConnection<'a, Socket, CipherSuite>
302where
303 Socket: Read + Write + 'a,
304 CipherSuite: TlsCipherSuite + 'static,
305{
306 fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
307 TlsConnection::read(self, buf)
308 }
309}
310
311impl<'a, Socket, CipherSuite> BufRead for TlsConnection<'a, Socket, CipherSuite>
312where
313 Socket: Read + Write + 'a,
314 CipherSuite: TlsCipherSuite + 'static,
315{
316 fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
317 self.read_buffered().map(|mut buf| buf.peek_all())
318 }
319
320 fn consume(&mut self, amt: usize) {
321 self.create_read_buffer().pop(amt);
322 }
323}
324
325impl<'a, Socket, CipherSuite> Write for TlsConnection<'a, Socket, CipherSuite>
326where
327 Socket: Read + Write + 'a,
328 CipherSuite: TlsCipherSuite + 'static,
329{
330 fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
331 TlsConnection::write(self, buf)
332 }
333
334 fn flush(&mut self) -> Result<(), Self::Error> {
335 TlsConnection::flush(self)
336 }
337}
338
339pub struct TlsReader<'a, Socket, CipherSuite>
340where
341 CipherSuite: TlsCipherSuite + 'static,
342{
343 opened: &'a AtomicBool,
344 delegate: Socket,
345 key_schedule: &'a mut ReadKeySchedule<CipherSuite>,
346 record_reader: RecordReaderBorrowMut<'a>,
347 decrypted: &'a mut DecryptedBufferInfo,
348}
349
350impl<Socket, CipherSuite> AsRef<Socket> for TlsReader<'_, Socket, CipherSuite>
351where
352 CipherSuite: TlsCipherSuite + 'static,
353{
354 fn as_ref(&self) -> &Socket {
355 &self.delegate
356 }
357}
358
359impl<'a, Socket, CipherSuite> TlsReader<'a, Socket, CipherSuite>
360where
361 Socket: Read + 'a,
362 CipherSuite: TlsCipherSuite + 'static,
363{
364 fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
365 self.decrypted.create_read_buffer(self.record_reader.buf)
366 }
367
368 pub fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, TlsError> {
370 if self.opened.load(Ordering::Acquire) {
371 while self.decrypted.is_empty() {
372 self.read_application_data()?;
373 }
374
375 Ok(self.create_read_buffer())
376 } else {
377 Err(TlsError::MissingHandshake)
378 }
379 }
380
381 fn read_application_data(&mut self) -> Result<(), TlsError> {
382 let buf_ptr_range = self.record_reader.buf.as_ptr_range();
383 let record = self
384 .record_reader
385 .read_blocking(&mut self.delegate, self.key_schedule)?;
386
387 let mut opened = self.opened.load(Ordering::Acquire);
388 let mut handler = DecryptedReadHandler {
389 source_buffer: buf_ptr_range,
390 buffer_info: self.decrypted,
391 is_open: &mut opened,
392 };
393 let result = decrypt_record(self.key_schedule, record, |_key_schedule, record| {
394 handler.handle(record)
395 });
396
397 if !opened {
398 self.opened.store(false, Ordering::Release);
399 }
400 result
401 }
402}
403
404pub struct TlsWriter<'a, Socket, CipherSuite>
405where
406 CipherSuite: TlsCipherSuite + 'static,
407{
408 opened: &'a AtomicBool,
409 delegate: Socket,
410 key_schedule: &'a mut WriteKeySchedule<CipherSuite>,
411 record_write_buf: WriteBufferBorrowMut<'a>,
412 flush_policy: FlushPolicy,
413}
414
415impl<'a, Socket, CipherSuite> TlsWriter<'a, Socket, CipherSuite>
416where
417 Socket: Write + 'a,
418 CipherSuite: TlsCipherSuite + 'static,
419{
420 fn flush_transport(&mut self) -> Result<(), TlsError> {
421 self.delegate.flush().map_err(|e| TlsError::Io(e.kind()))
422 }
423}
424
425impl<Socket, CipherSuite> AsRef<Socket> for TlsWriter<'_, Socket, CipherSuite>
426where
427 CipherSuite: TlsCipherSuite + 'static,
428{
429 fn as_ref(&self) -> &Socket {
430 &self.delegate
431 }
432}
433
434impl<Socket, CipherSuite> ErrorType for TlsWriter<'_, Socket, CipherSuite>
435where
436 CipherSuite: TlsCipherSuite + 'static,
437{
438 type Error = TlsError;
439}
440
441impl<Socket, CipherSuite> ErrorType for TlsReader<'_, Socket, CipherSuite>
442where
443 CipherSuite: TlsCipherSuite + 'static,
444{
445 type Error = TlsError;
446}
447
448impl<'a, Socket, CipherSuite> Read for TlsReader<'a, Socket, CipherSuite>
449where
450 Socket: Read + 'a,
451 CipherSuite: TlsCipherSuite + 'static,
452{
453 fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
454 if buf.is_empty() {
455 return Ok(0);
456 }
457 let mut buffer = self.read_buffered()?;
458
459 let len = buffer.pop_into(buf);
460 trace!("Copied {} bytes", len);
461
462 Ok(len)
463 }
464}
465
466impl<'a, Socket, CipherSuite> BufRead for TlsReader<'a, Socket, CipherSuite>
467where
468 Socket: Read + 'a,
469 CipherSuite: TlsCipherSuite + 'static,
470{
471 fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
472 self.read_buffered().map(|mut buf| buf.peek_all())
473 }
474
475 fn consume(&mut self, amt: usize) {
476 self.create_read_buffer().pop(amt);
477 }
478}
479
480impl<'a, Socket, CipherSuite> Write for TlsWriter<'a, Socket, CipherSuite>
481where
482 Socket: Write + 'a,
483 CipherSuite: TlsCipherSuite + 'static,
484{
485 fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
486 if self.opened.load(Ordering::Acquire) {
487 if !self
488 .record_write_buf
489 .contains(ClientRecordHeader::ApplicationData)
490 {
491 self.flush()?;
492 self.record_write_buf
493 .start_record(ClientRecordHeader::ApplicationData)?;
494 }
495
496 let buffered = self.record_write_buf.append(buf);
497
498 if self.record_write_buf.is_full() {
499 self.flush()?;
500 }
501
502 Ok(buffered)
503 } else {
504 Err(TlsError::MissingHandshake)
505 }
506 }
507
508 fn flush(&mut self) -> Result<(), Self::Error> {
509 if !self.record_write_buf.is_empty() {
510 let slice = self.record_write_buf.close_record(self.key_schedule)?;
511
512 self.delegate
513 .write_all(slice)
514 .map_err(|e| TlsError::Io(e.kind()))?;
515
516 self.key_schedule.increment_counter();
517
518 if self.flush_policy.flush_transport() {
519 self.flush_transport()?;
520 }
521 }
522
523 Ok(())
524 }
525}