1use alloc::vec::Vec;
4use core::num::NonZeroUsize;
5use core::{fmt, mem};
6#[cfg(feature = "std")]
7use std::error::Error as StdError;
8
9use super::UnbufferedConnectionCommon;
10use crate::client::ClientConnectionData;
11use crate::msgs::deframer::DeframerSliceBuffer;
12use crate::server::ServerConnectionData;
13use crate::Error;
14
15impl UnbufferedConnectionCommon<ClientConnectionData> {
16    pub fn process_tls_records<'c, 'i>(
19        &'c mut self,
20        incoming_tls: &'i mut [u8],
21    ) -> UnbufferedStatus<'c, 'i, ClientConnectionData> {
22        self.process_tls_records_common(incoming_tls, |_| None, |_, _, ()| unreachable!())
23    }
24}
25
26impl UnbufferedConnectionCommon<ServerConnectionData> {
27    pub fn process_tls_records<'c, 'i>(
30        &'c mut self,
31        incoming_tls: &'i mut [u8],
32    ) -> UnbufferedStatus<'c, 'i, ServerConnectionData> {
33        self.process_tls_records_common(
34            incoming_tls,
35            |conn| conn.pop_early_data(),
36            |conn, incoming_tls, chunk| ReadEarlyData::new(conn, incoming_tls, chunk).into(),
37        )
38    }
39}
40
41impl<Data> UnbufferedConnectionCommon<Data> {
42    fn process_tls_records_common<'c, 'i, T>(
43        &'c mut self,
44        incoming_tls: &'i mut [u8],
45        mut check: impl FnMut(&mut Self) -> Option<T>,
46        execute: impl FnOnce(&'c mut Self, &'i mut [u8], T) -> ConnectionState<'c, 'i, Data>,
47    ) -> UnbufferedStatus<'c, 'i, Data> {
48        let mut buffer = DeframerSliceBuffer::new(incoming_tls);
49
50        let (discard, state) = loop {
51            if let Some(value) = check(self) {
52                break (buffer.pending_discard(), execute(self, incoming_tls, value));
53            }
54
55            if let Some(chunk) = self
56                .core
57                .common_state
58                .received_plaintext
59                .pop()
60            {
61                break (
62                    buffer.pending_discard(),
63                    ReadTraffic::new(self, incoming_tls, chunk).into(),
64                );
65            }
66
67            if let Some(chunk) = self
68                .core
69                .common_state
70                .sendable_tls
71                .pop()
72            {
73                break (
74                    buffer.pending_discard(),
75                    EncodeTlsData::new(self, chunk).into(),
76                );
77            }
78
79            let deframer_output = match self.core.deframe(None, &mut buffer) {
80                Err(err) => {
81                    return UnbufferedStatus {
82                        discard: buffer.pending_discard(),
83                        state: Err(err),
84                    };
85                }
86                Ok(r) => r,
87            };
88
89            if let Some(msg) = deframer_output {
90                let mut state =
91                    match mem::replace(&mut self.core.state, Err(Error::HandshakeNotComplete)) {
92                        Ok(state) => state,
93                        Err(e) => {
94                            self.core.state = Err(e.clone());
95                            return UnbufferedStatus {
96                                discard: buffer.pending_discard(),
97                                state: Err(e),
98                            };
99                        }
100                    };
101
102                match self.core.process_msg(msg, state, None) {
103                    Ok(new) => state = new,
104
105                    Err(e) => {
106                        self.core.state = Err(e.clone());
107                        return UnbufferedStatus {
108                            discard: buffer.pending_discard(),
109                            state: Err(e),
110                        };
111                    }
112                }
113
114                self.core.state = Ok(state);
115            } else if self.wants_write {
116                break (
117                    buffer.pending_discard(),
118                    TransmitTlsData { conn: self }.into(),
119                );
120            } else if self
121                .core
122                .common_state
123                .has_received_close_notify
124            {
125                break (buffer.pending_discard(), ConnectionState::Closed);
126            } else if self
127                .core
128                .common_state
129                .may_send_application_data
130            {
131                break (
132                    buffer.pending_discard(),
133                    ConnectionState::WriteTraffic(WriteTraffic { conn: self }),
134                );
135            } else {
136                break (buffer.pending_discard(), ConnectionState::BlockedHandshake);
137            }
138        };
139
140        UnbufferedStatus {
141            discard,
142            state: Ok(state),
143        }
144    }
145}
146
147#[must_use]
149#[derive(Debug)]
150pub struct UnbufferedStatus<'c, 'i, Data> {
151    pub discard: usize,
160
161    pub state: Result<ConnectionState<'c, 'i, Data>, Error>,
167}
168
169#[non_exhaustive] pub enum ConnectionState<'c, 'i, Data> {
172    ReadTraffic(ReadTraffic<'c, 'i, Data>),
177
178    Closed,
180
181    ReadEarlyData(ReadEarlyData<'c, 'i, Data>),
183
184    EncodeTlsData(EncodeTlsData<'c, Data>),
189
190    TransmitTlsData(TransmitTlsData<'c, Data>),
203
204    BlockedHandshake,
209
210    WriteTraffic(WriteTraffic<'c, Data>),
224}
225
226impl<'c, 'i, Data> From<ReadTraffic<'c, 'i, Data>> for ConnectionState<'c, 'i, Data> {
227    fn from(v: ReadTraffic<'c, 'i, Data>) -> Self {
228        Self::ReadTraffic(v)
229    }
230}
231
232impl<'c, 'i, Data> From<ReadEarlyData<'c, 'i, Data>> for ConnectionState<'c, 'i, Data> {
233    fn from(v: ReadEarlyData<'c, 'i, Data>) -> Self {
234        Self::ReadEarlyData(v)
235    }
236}
237
238impl<'c, 'i, Data> From<EncodeTlsData<'c, Data>> for ConnectionState<'c, 'i, Data> {
239    fn from(v: EncodeTlsData<'c, Data>) -> Self {
240        Self::EncodeTlsData(v)
241    }
242}
243
244impl<'c, 'i, Data> From<TransmitTlsData<'c, Data>> for ConnectionState<'c, 'i, Data> {
245    fn from(v: TransmitTlsData<'c, Data>) -> Self {
246        Self::TransmitTlsData(v)
247    }
248}
249
250impl<Data> fmt::Debug for ConnectionState<'_, '_, Data> {
251    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252        match self {
253            Self::ReadTraffic(..) => f.debug_tuple("ReadTraffic").finish(),
254
255            Self::Closed => write!(f, "Closed"),
256
257            Self::ReadEarlyData(..) => f.debug_tuple("ReadEarlyData").finish(),
258
259            Self::EncodeTlsData(..) => f.debug_tuple("EncodeTlsData").finish(),
260
261            Self::TransmitTlsData(..) => f
262                .debug_tuple("TransmitTlsData")
263                .finish(),
264
265            Self::BlockedHandshake => f
266                .debug_tuple("BlockedHandshake")
267                .finish(),
268
269            Self::WriteTraffic(..) => f.debug_tuple("WriteTraffic").finish(),
270        }
271    }
272}
273
274pub struct ReadTraffic<'c, 'i, Data> {
276    _conn: &'c mut UnbufferedConnectionCommon<Data>,
277    _incoming_tls: &'i mut [u8],
279    chunk: Vec<u8>,
280    taken: bool,
281}
282
283impl<'c, 'i, Data> ReadTraffic<'c, 'i, Data> {
284    fn new(
285        _conn: &'c mut UnbufferedConnectionCommon<Data>,
286        _incoming_tls: &'i mut [u8],
287        chunk: Vec<u8>,
288    ) -> Self {
289        Self {
290            _conn,
291            _incoming_tls,
292            chunk,
293            taken: false,
294        }
295    }
296
297    pub fn next_record(&mut self) -> Option<Result<AppDataRecord, Error>> {
300        if self.taken {
301            None
302        } else {
303            self.taken = true;
304            Some(Ok(AppDataRecord {
305                discard: 0,
306                payload: &self.chunk,
307            }))
308        }
309    }
310
311    pub fn peek_len(&self) -> Option<NonZeroUsize> {
315        if self.taken {
316            None
317        } else {
318            NonZeroUsize::new(self.chunk.len())
319        }
320    }
321}
322
323pub struct ReadEarlyData<'c, 'i, Data> {
325    _conn: &'c mut UnbufferedConnectionCommon<Data>,
326    _incoming_tls: &'i mut [u8],
328    chunk: Vec<u8>,
329    taken: bool,
330}
331
332impl<'c, 'i, Data> ReadEarlyData<'c, 'i, Data> {
333    fn new(
334        _conn: &'c mut UnbufferedConnectionCommon<Data>,
335        _incoming_tls: &'i mut [u8],
336        chunk: Vec<u8>,
337    ) -> Self {
338        Self {
339            _conn,
340            _incoming_tls,
341            chunk,
342            taken: false,
343        }
344    }
345}
346
347impl<'c, 'i> ReadEarlyData<'c, 'i, ServerConnectionData> {
348    pub fn next_record(&mut self) -> Option<Result<AppDataRecord, Error>> {
351        if self.taken {
352            None
353        } else {
354            self.taken = true;
355            Some(Ok(AppDataRecord {
356                discard: 0,
357                payload: &self.chunk,
358            }))
359        }
360    }
361
362    pub fn peek_len(&self) -> Option<NonZeroUsize> {
366        if self.taken {
367            None
368        } else {
369            NonZeroUsize::new(self.chunk.len())
370        }
371    }
372}
373
374pub struct AppDataRecord<'i> {
376    pub discard: usize,
381
382    pub payload: &'i [u8],
384}
385
386pub struct WriteTraffic<'c, Data> {
388    conn: &'c mut UnbufferedConnectionCommon<Data>,
389}
390
391impl<Data> WriteTraffic<'_, Data> {
392    pub fn encrypt(
397        &mut self,
398        application_data: &[u8],
399        outgoing_tls: &mut [u8],
400    ) -> Result<usize, EncryptError> {
401        self.conn
402            .core
403            .common_state
404            .write_plaintext(application_data.into(), outgoing_tls)
405    }
406
407    pub fn queue_close_notify(&mut self, outgoing_tls: &mut [u8]) -> Result<usize, EncryptError> {
412        self.conn
413            .core
414            .common_state
415            .eager_send_close_notify(outgoing_tls)
416    }
417}
418
419pub struct EncodeTlsData<'c, Data> {
421    conn: &'c mut UnbufferedConnectionCommon<Data>,
422    chunk: Option<Vec<u8>>,
423}
424
425impl<'c, Data> EncodeTlsData<'c, Data> {
426    fn new(conn: &'c mut UnbufferedConnectionCommon<Data>, chunk: Vec<u8>) -> Self {
427        Self {
428            conn,
429            chunk: Some(chunk),
430        }
431    }
432
433    pub fn encode(&mut self, outgoing_tls: &mut [u8]) -> Result<usize, EncodeError> {
438        let chunk = match self.chunk.take() {
439            Some(chunk) => chunk,
440            None => return Err(EncodeError::AlreadyEncoded),
441        };
442
443        let required_size = chunk.len();
444
445        if required_size > outgoing_tls.len() {
446            self.chunk = Some(chunk);
447            Err(InsufficientSizeError { required_size }.into())
448        } else {
449            let written = chunk.len();
450            outgoing_tls[..written].copy_from_slice(&chunk);
451
452            self.conn.wants_write = true;
453
454            Ok(written)
455        }
456    }
457}
458
459pub struct TransmitTlsData<'c, Data> {
461    pub(crate) conn: &'c mut UnbufferedConnectionCommon<Data>,
462}
463
464impl<Data> TransmitTlsData<'_, Data> {
465    pub fn done(self) {
467        self.conn.wants_write = false;
468    }
469
470    pub fn may_encrypt_app_data(&mut self) -> Option<WriteTraffic<Data>> {
474        if self
475            .conn
476            .core
477            .common_state
478            .may_send_application_data
479        {
480            Some(WriteTraffic { conn: self.conn })
481        } else {
482            None
483        }
484    }
485}
486
487#[derive(Debug)]
489pub enum EncodeError {
490    InsufficientSize(InsufficientSizeError),
492
493    AlreadyEncoded,
495}
496
497impl From<InsufficientSizeError> for EncodeError {
498    fn from(v: InsufficientSizeError) -> Self {
499        Self::InsufficientSize(v)
500    }
501}
502
503impl fmt::Display for EncodeError {
504    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
505        match self {
506            Self::InsufficientSize(InsufficientSizeError { required_size }) => write!(
507                f,
508                "cannot encode due to insufficient size, {} bytes are required",
509                required_size
510            ),
511            Self::AlreadyEncoded => "cannot encode, data has already been encoded".fmt(f),
512        }
513    }
514}
515
516#[cfg(feature = "std")]
517impl StdError for EncodeError {}
518
519#[derive(Debug)]
521pub enum EncryptError {
522    InsufficientSize(InsufficientSizeError),
524
525    EncryptExhausted,
527}
528
529impl From<InsufficientSizeError> for EncryptError {
530    fn from(v: InsufficientSizeError) -> Self {
531        Self::InsufficientSize(v)
532    }
533}
534
535impl fmt::Display for EncryptError {
536    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
537        match self {
538            Self::InsufficientSize(InsufficientSizeError { required_size }) => write!(
539                f,
540                "cannot encrypt due to insufficient size, {required_size} bytes are required"
541            ),
542            Self::EncryptExhausted => f.write_str("encrypter has been exhausted"),
543        }
544    }
545}
546
547#[cfg(feature = "std")]
548impl StdError for EncryptError {}
549
550#[derive(Clone, Copy, Debug)]
552pub struct InsufficientSizeError {
553    pub required_size: usize,
555}