1use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
2use crate::common::decrypted_read_handler::DecryptedReadHandler;
3use crate::connection::*;
4use crate::key_schedule::KeySchedule;
5use crate::key_schedule::{ReadKeySchedule, SharedState, WriteKeySchedule};
6use crate::read_buffer::ReadBuffer;
7use crate::record::{ClientRecord, ClientRecordHeader};
8use crate::record_reader::RecordReader;
9use crate::split::{SplitState, SplitStateContainer};
10use crate::write_buffer::WriteBuffer;
11use embedded_io::Error as _;
12use embedded_io::{BufRead, ErrorType, Read, Write};
13use rand_core::{CryptoRng, RngCore};
14
15pub use crate::config::*;
16#[cfg(feature = "std")]
17pub use crate::split::ManagedSplitState;
18pub use crate::split::SplitConnectionState;
19pub use crate::TlsError;
20
/// Blocking TLS connection layered on top of a `Socket` that implements the
/// `embedded_io` [`Read`] and [`Write`] traits.
///
/// The connection borrows two caller-provided byte buffers for the lifetime
/// `'a`: one backing the record reader and one backing the record write
/// buffer (see [`TlsConnection::new`]).
pub struct TlsConnection<'a, Socket, CipherSuite>
where
    Socket: Read + Write + 'a,
    CipherSuite: TlsCipherSuite + 'static,
{
    // Underlying transport that finished records are written to and that
    // incoming records are read from.
    delegate: Socket,
    // `false` until `open` completes; `write` and `read_buffered` return
    // `TlsError::MissingHandshake` while it is `false`.
    opened: bool,
    // Key material for both directions; exposes `read_state()` and
    // `write_state()` and can be split into separate halves.
    key_schedule: KeySchedule<CipherSuite>,
    // Reads incoming records into the caller-provided read buffer.
    record_reader: RecordReader<'a, CipherSuite>,
    // Accumulates outgoing record data in the caller-provided write buffer.
    record_write_buf: WriteBuffer<'a>,
    // Bookkeeping for decrypted data that has not been handed to the caller yet.
    decrypted: DecryptedBufferInfo,
}
36
37impl<'a, Socket, CipherSuite> TlsConnection<'a, Socket, CipherSuite>
38where
39 Socket: Read + Write + 'a,
40 CipherSuite: TlsCipherSuite + 'static,
41{
42 pub fn new(
54 delegate: Socket,
55 record_read_buf: &'a mut [u8],
56 record_write_buf: &'a mut [u8],
57 ) -> Self {
58 Self {
59 delegate,
60 opened: false,
61 key_schedule: KeySchedule::new(),
62 record_reader: RecordReader::new(record_read_buf),
63 record_write_buf: WriteBuffer::new(record_write_buf),
64 decrypted: DecryptedBufferInfo::default(),
65 }
66 }
67
68 pub fn open<'v, RNG, Verifier>(
74 &mut self,
75 context: TlsContext<'v, CipherSuite, RNG>,
76 ) -> Result<(), TlsError>
77 where
78 RNG: CryptoRng + RngCore,
79 Verifier: TlsVerifier<'v, CipherSuite>,
80 {
81 let mut handshake: Handshake<CipherSuite, Verifier> =
82 Handshake::new(Verifier::new(context.config.server_name));
83 let mut state = State::ClientHello;
84
85 while state != State::ApplicationData {
86 let next_state = state.process_blocking(
87 &mut self.delegate,
88 &mut handshake,
89 &mut self.record_reader,
90 &mut self.record_write_buf,
91 &mut self.key_schedule,
92 context.config,
93 context.rng,
94 )?;
95 trace!("State {:?} -> {:?}", state, next_state);
96 state = next_state;
97 }
98 self.opened = true;
99
100 Ok(())
101 }
102
103 pub fn write(&mut self, buf: &[u8]) -> Result<usize, TlsError> {
112 if self.opened {
113 if !self
114 .record_write_buf
115 .contains(ClientRecordHeader::ApplicationData)
116 {
117 self.flush()?;
118 self.record_write_buf
119 .start_record(ClientRecordHeader::ApplicationData)?;
120 }
121
122 let buffered = self.record_write_buf.append(buf);
123
124 if self.record_write_buf.is_full() {
125 self.flush()?;
126 }
127
128 Ok(buffered)
129 } else {
130 Err(TlsError::MissingHandshake)
131 }
132 }
133
134 pub fn flush(&mut self) -> Result<(), TlsError> {
137 if !self.record_write_buf.is_empty() {
138 let key_schedule = self.key_schedule.write_state();
139 let slice = self.record_write_buf.close_record(key_schedule)?;
140
141 self.delegate
142 .write_all(slice)
143 .map_err(|e| TlsError::Io(e.kind()))?;
144
145 key_schedule.increment_counter();
146
147 self.delegate.flush().map_err(|e| TlsError::Io(e.kind()))?;
148 }
149
150 Ok(())
151 }
152
153 fn create_read_buffer(&mut self) -> ReadBuffer {
154 self.decrypted.create_read_buffer(self.record_reader.buf)
155 }
156
157 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, TlsError> {
159 if buf.is_empty() {
160 return Ok(0);
161 }
162 let mut buffer = self.read_buffered()?;
163
164 let len = buffer.pop_into(buf);
165 trace!("Copied {} bytes", len);
166
167 Ok(len)
168 }
169
170 pub fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
172 if self.opened {
173 while self.decrypted.is_empty() {
174 self.read_application_data()?;
175 }
176
177 Ok(self.create_read_buffer())
178 } else {
179 Err(TlsError::MissingHandshake)
180 }
181 }
182
183 fn read_application_data(&mut self) -> Result<(), TlsError> {
184 let buf_ptr_range = self.record_reader.buf.as_ptr_range();
185 let key_schedule = self.key_schedule.read_state();
186 let record = self
187 .record_reader
188 .read_blocking(&mut self.delegate, key_schedule)?;
189
190 let mut handler = DecryptedReadHandler {
191 source_buffer: buf_ptr_range,
192 buffer_info: &mut self.decrypted,
193 is_open: &mut self.opened,
194 };
195 decrypt_record(key_schedule, record, |_key_schedule, record| {
196 handler.handle(record)
197 })?;
198
199 Ok(())
200 }
201
202 fn close_internal(&mut self) -> Result<(), TlsError> {
203 self.flush()?;
204
205 let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
206 let slice = self.record_write_buf.write_record(
207 &ClientRecord::close_notify(self.opened),
208 write_key_schedule,
209 Some(read_key_schedule),
210 )?;
211
212 self.delegate
213 .write_all(slice)
214 .map_err(|e| TlsError::Io(e.kind()))?;
215
216 self.key_schedule.write_state().increment_counter();
217
218 self.flush()?;
219
220 Ok(())
221 }
222
223 pub fn close(mut self) -> Result<Socket, (Socket, TlsError)> {
225 match self.close_internal() {
226 Ok(()) => Ok(self.delegate),
227 Err(e) => Err((self.delegate, e)),
228 }
229 }
230
231 #[cfg(feature = "std")]
232 pub fn split(
233 self,
234 ) -> (
235 TlsReader<'a, Socket, CipherSuite, ManagedSplitState>,
236 TlsWriter<'a, Socket, CipherSuite, ManagedSplitState>,
237 )
238 where
239 Socket: Clone,
240 {
241 self.split_with(ManagedSplitState::new())
242 }
243
244 pub fn split_with<StateContainer>(
245 self,
246 state: StateContainer,
247 ) -> (
248 TlsReader<'a, Socket, CipherSuite, StateContainer::State>,
249 TlsWriter<'a, Socket, CipherSuite, StateContainer::State>,
250 )
251 where
252 Socket: Clone,
253 StateContainer: SplitStateContainer,
254 {
255 let state = state.state();
256 state.set_open(self.opened);
257
258 let (shared, wks, rks) = self.key_schedule.split();
259
260 let reader = TlsReader {
261 state: state.clone(),
262 delegate: self.delegate.clone(),
263 key_schedule: rks,
264 record_reader: self.record_reader,
265 decrypted: self.decrypted,
266 };
267 let writer = TlsWriter {
268 state,
269 delegate: self.delegate,
270 key_schedule_shared: shared,
271 key_schedule: wks,
272 record_write_buf: self.record_write_buf,
273 };
274
275 (reader, writer)
276 }
277
278 pub fn unsplit<State>(
279 reader: TlsReader<'a, Socket, CipherSuite, State>,
280 writer: TlsWriter<'a, Socket, CipherSuite, State>,
281 ) -> Self
282 where
283 Socket: Clone,
284 State: SplitState,
285 {
286 debug_assert!(reader.state.same(&writer.state));
287
288 TlsConnection {
289 delegate: writer.delegate,
290 opened: writer.state.is_open(),
291 key_schedule: KeySchedule::unsplit(
292 writer.key_schedule_shared,
293 writer.key_schedule,
294 reader.key_schedule,
295 ),
296 record_reader: reader.record_reader,
297 record_write_buf: writer.record_write_buf,
298 decrypted: reader.decrypted,
299 }
300 }
301}
302
/// Declares [`TlsError`] as the error type used by the `embedded_io` trait
/// implementations on [`TlsConnection`].
impl<'a, Socket, CipherSuite> ErrorType for TlsConnection<'a, Socket, CipherSuite>
where
    Socket: Read + Write + 'a,
    CipherSuite: TlsCipherSuite + 'static,
{
    type Error = TlsError;
}
310
311impl<'a, Socket, CipherSuite> Read for TlsConnection<'a, Socket, CipherSuite>
312where
313 Socket: Read + Write + 'a,
314 CipherSuite: TlsCipherSuite + 'static,
315{
316 fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
317 TlsConnection::read(self, buf)
318 }
319}
320
321impl<'a, Socket, CipherSuite> BufRead for TlsConnection<'a, Socket, CipherSuite>
322where
323 Socket: Read + Write + 'a,
324 CipherSuite: TlsCipherSuite + 'static,
325{
326 fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
327 self.read_buffered().map(|mut buf| buf.peek_all())
328 }
329
330 fn consume(&mut self, amt: usize) {
331 self.create_read_buffer().pop(amt);
332 }
333}
334
335impl<'a, Socket, CipherSuite> Write for TlsConnection<'a, Socket, CipherSuite>
336where
337 Socket: Read + Write + 'a,
338 CipherSuite: TlsCipherSuite + 'static,
339{
340 fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
341 TlsConnection::write(self, buf)
342 }
343
344 fn flush(&mut self) -> Result<(), Self::Error> {
345 TlsConnection::flush(self)
346 }
347}
348
/// Read half of a connection, produced by `TlsConnection::split_with`.
///
/// It owns the read side of the key schedule, the record reader and the
/// decrypted-data bookkeeping. `state` is the open/closed flag it shares
/// with the matching [`TlsWriter`].
pub struct TlsReader<'a, Socket, CipherSuite, State>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    // Open/closed flag shared with the writer half.
    state: State,
    // Socket handle that records are read from (a clone made when splitting).
    delegate: Socket,
    // Read-direction part of the key schedule.
    key_schedule: ReadKeySchedule<CipherSuite>,
    // Reads incoming records into the caller-provided read buffer.
    record_reader: RecordReader<'a, CipherSuite>,
    // Bookkeeping for decrypted data that has not been handed to the caller yet.
    decrypted: DecryptedBufferInfo,
}
359
/// Gives shared access to the socket held by the reader half.
impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsReader<'a, Socket, CipherSuite, State>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Returns a reference to the reader's socket.
    fn as_ref(&self) -> &Socket {
        &self.delegate
    }
}
368
369impl<'a, Socket, CipherSuite, State> TlsReader<'a, Socket, CipherSuite, State>
370where
371 Socket: Read + 'a,
372 CipherSuite: TlsCipherSuite + 'static,
373 State: SplitState,
374{
375 fn create_read_buffer(&mut self) -> ReadBuffer {
376 self.decrypted.create_read_buffer(self.record_reader.buf)
377 }
378
379 pub fn read_buffered(&mut self) -> Result<ReadBuffer, TlsError> {
381 if self.state.is_open() {
382 while self.decrypted.is_empty() {
383 self.read_application_data()?;
384 }
385
386 Ok(self.create_read_buffer())
387 } else {
388 Err(TlsError::MissingHandshake)
389 }
390 }
391
392 fn read_application_data(&mut self) -> Result<(), TlsError> {
393 let buf_ptr_range = self.record_reader.buf.as_ptr_range();
394 let record = self
395 .record_reader
396 .read_blocking(&mut self.delegate, &mut self.key_schedule)?;
397
398 let mut opened = self.state.is_open();
399 let mut handler = DecryptedReadHandler {
400 source_buffer: buf_ptr_range,
401 buffer_info: &mut self.decrypted,
402 is_open: &mut opened,
403 };
404 let result = decrypt_record(&mut self.key_schedule, record, |_key_schedule, record| {
405 handler.handle(record)
406 });
407
408 if !opened {
409 self.state.set_open(false);
410 }
411 result
412 }
413}
414
/// Write half of a connection, produced by `TlsConnection::split_with`.
///
/// It owns the write side of the key schedule, the shared key schedule state
/// and the record write buffer. `state` is the open/closed flag it shares
/// with the matching [`TlsReader`].
pub struct TlsWriter<'a, Socket, CipherSuite, State>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    // Open/closed flag shared with the reader half.
    state: State,
    // Socket that finished records are written to.
    delegate: Socket,
    // Shared key schedule part, kept so the halves can be reunited with
    // `TlsConnection::unsplit`.
    key_schedule_shared: SharedState<CipherSuite>,
    // Write-direction part of the key schedule.
    key_schedule: WriteKeySchedule<CipherSuite>,
    // Accumulates outgoing record data in the caller-provided write buffer.
    record_write_buf: WriteBuffer<'a>,
}
425
/// Gives shared access to the socket held by the writer half.
impl<'a, Socket, CipherSuite, State> AsRef<Socket> for TlsWriter<'a, Socket, CipherSuite, State>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    /// Returns a reference to the writer's socket.
    fn as_ref(&self) -> &Socket {
        &self.delegate
    }
}
434
/// Declares [`TlsError`] as the error type used by the `embedded_io` trait
/// implementations on [`TlsWriter`].
impl<'a, Socket, CipherSuite, State> ErrorType for TlsWriter<'a, Socket, CipherSuite, State>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    type Error = TlsError;
}
441
/// Declares [`TlsError`] as the error type used by the `embedded_io` trait
/// implementations on [`TlsReader`].
impl<'a, Socket, CipherSuite, State> ErrorType for TlsReader<'a, Socket, CipherSuite, State>
where
    CipherSuite: TlsCipherSuite + 'static,
{
    type Error = TlsError;
}
448
449impl<'a, Socket, CipherSuite, State> Read for TlsReader<'a, Socket, CipherSuite, State>
450where
451 Socket: Read + 'a,
452 CipherSuite: TlsCipherSuite + 'static,
453 State: SplitState,
454{
455 fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
456 if buf.is_empty() {
457 return Ok(0);
458 }
459 let mut buffer = self.read_buffered()?;
460
461 let len = buffer.pop_into(buf);
462 trace!("Copied {} bytes", len);
463
464 Ok(len)
465 }
466}
467
468impl<'a, Socket, CipherSuite, State> BufRead for TlsReader<'a, Socket, CipherSuite, State>
469where
470 Socket: Read + 'a,
471 CipherSuite: TlsCipherSuite + 'static,
472 State: SplitState,
473{
474 fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
475 self.read_buffered().map(|mut buf| buf.peek_all())
476 }
477
478 fn consume(&mut self, amt: usize) {
479 self.create_read_buffer().pop(amt);
480 }
481}
482
483impl<'a, Socket, CipherSuite, State> Write for TlsWriter<'a, Socket, CipherSuite, State>
484where
485 Socket: Write + 'a,
486 CipherSuite: TlsCipherSuite + 'static,
487 State: SplitState,
488{
489 fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
490 if self.state.is_open() {
491 if !self
492 .record_write_buf
493 .contains(ClientRecordHeader::ApplicationData)
494 {
495 self.flush()?;
496 self.record_write_buf
497 .start_record(ClientRecordHeader::ApplicationData)?;
498 }
499
500 let buffered = self.record_write_buf.append(buf);
501
502 if self.record_write_buf.is_full() {
503 self.flush()?;
504 }
505
506 Ok(buffered)
507 } else {
508 Err(TlsError::MissingHandshake)
509 }
510 }
511
512 fn flush(&mut self) -> Result<(), Self::Error> {
513 if !self.record_write_buf.is_empty() {
514 let slice = self.record_write_buf.close_record(&mut self.key_schedule)?;
515
516 self.delegate
517 .write_all(slice)
518 .map_err(|e| TlsError::Io(e.kind()))?;
519
520 self.key_schedule.increment_counter();
521
522 self.delegate.flush().map_err(|e| TlsError::Io(e.kind()))?;
523 }
524
525 Ok(())
526 }
527}