clia_rustls_mod/
record_layer.rs

1use alloc::boxed::Box;
2use core::num::NonZeroU64;
3
4use crate::crypto::cipher::{InboundOpaqueMessage, MessageDecrypter, MessageEncrypter};
5use crate::error::Error;
6#[cfg(feature = "logging")]
7use crate::log::trace;
8use crate::msgs::message::{InboundPlainMessage, OutboundOpaqueMessage, OutboundPlainMessage};
9
/// Sequence number at which a direction is considered close to exhausting its
/// key: `wants_close_before_encrypt` and `Decrypted::want_close_before_decrypt`
/// report `true` when the write/read counter equals this value.
const SEQ_SOFT_LIMIT: u64 = 0xffff_ffff_ffff_0000u64;

/// Write sequence numbers at or beyond this value are refused outright
/// (see `encrypt_exhausted`).
const SEQ_HARD_LIMIT: u64 = 0xffff_ffff_ffff_fffeu64;
12
/// Lifecycle of the keying material for one direction (encrypt or decrypt)
/// of the record layer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirectionState {
    /// No keying material.
    Invalid,

    /// Keying material present, but not yet in use.
    Prepared,

    /// Keying material in use.
    Active,
}
24
25/// Record layer that tracks decryption and encryption keys.
26pub struct RecordLayer {
27    message_encrypter: Box<dyn MessageEncrypter>,
28    message_decrypter: Box<dyn MessageDecrypter>,
29    write_seq: u64,
30    read_seq: u64,
31    has_decrypted: bool,
32    encrypt_state: DirectionState,
33    decrypt_state: DirectionState,
34
35    // Message encrypted with other keys may be encountered, so failures
36    // should be swallowed by the caller.  This struct tracks the amount
37    // of message size this is allowed for.
38    trial_decryption_len: Option<usize>,
39}
40
41impl RecordLayer {
42    /// Create new record layer with no keys.
43    pub fn new() -> Self {
44        Self {
45            message_encrypter: <dyn MessageEncrypter>::invalid(),
46            message_decrypter: <dyn MessageDecrypter>::invalid(),
47            write_seq: 0,
48            read_seq: 0,
49            has_decrypted: false,
50            encrypt_state: DirectionState::Invalid,
51            decrypt_state: DirectionState::Invalid,
52            trial_decryption_len: None,
53        }
54    }
55
56    /// Decrypt a TLS message.
57    ///
58    /// `encr` is a decoded message allegedly received from the peer.
59    /// If it can be decrypted, its decryption is returned.  Otherwise,
60    /// an error is returned.
61    pub(crate) fn decrypt_incoming<'a>(
62        &mut self,
63        encr: InboundOpaqueMessage<'a>,
64    ) -> Result<Option<Decrypted<'a>>, Error> {
65        if self.decrypt_state != DirectionState::Active {
66            return Ok(Some(Decrypted {
67                want_close_before_decrypt: false,
68                plaintext: encr.into_plain_message(),
69            }));
70        }
71
72        // Set to `true` if the peer appears to getting close to encrypting
73        // too many messages with this key.
74        //
75        // Perhaps if we send an alert well before their counter wraps, a
76        // buggy peer won't make a terrible mistake here?
77        //
78        // Note that there's no reason to refuse to decrypt: the security
79        // failure has already happened.
80        let want_close_before_decrypt = self.read_seq == SEQ_SOFT_LIMIT;
81
82        let encrypted_len = encr.payload.len();
83        match self
84            .message_decrypter
85            .decrypt(encr, self.read_seq)
86        {
87            Ok(plaintext) => {
88                self.read_seq += 1;
89                if !self.has_decrypted {
90                    self.has_decrypted = true;
91                }
92                Ok(Some(Decrypted {
93                    want_close_before_decrypt,
94                    plaintext,
95                }))
96            }
97            Err(Error::DecryptError) if self.doing_trial_decryption(encrypted_len) => {
98                trace!("Dropping undecryptable message after aborted early_data");
99                Ok(None)
100            }
101            Err(err) => Err(err),
102        }
103    }
104
105    /// Encrypt a TLS message.
106    ///
107    /// `plain` is a TLS message we'd like to send.  This function
108    /// panics if the requisite keying material hasn't been established yet.
109    pub(crate) fn encrypt_outgoing(
110        &mut self,
111        plain: OutboundPlainMessage,
112    ) -> OutboundOpaqueMessage {
113        debug_assert!(self.encrypt_state == DirectionState::Active);
114        assert!(!self.encrypt_exhausted());
115        let seq = self.write_seq;
116        self.write_seq += 1;
117        self.message_encrypter
118            .encrypt(plain, seq)
119            .unwrap()
120    }
121
122    /// Prepare to use the given `MessageEncrypter` for future message encryption.
123    /// It is not used until you call `start_encrypting`.
124    pub(crate) fn prepare_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
125        self.message_encrypter = cipher;
126        self.write_seq = 0;
127        self.encrypt_state = DirectionState::Prepared;
128    }
129
130    /// Prepare to use the given `MessageDecrypter` for future message decryption.
131    /// It is not used until you call `start_decrypting`.
132    pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
133        self.message_decrypter = cipher;
134        self.read_seq = 0;
135        self.decrypt_state = DirectionState::Prepared;
136    }
137
138    /// Start using the `MessageEncrypter` previously provided to the previous
139    /// call to `prepare_message_encrypter`.
140    pub(crate) fn start_encrypting(&mut self) {
141        debug_assert!(self.encrypt_state == DirectionState::Prepared);
142        self.encrypt_state = DirectionState::Active;
143    }
144
145    /// Start using the `MessageDecrypter` previously provided to the previous
146    /// call to `prepare_message_decrypter`.
147    pub(crate) fn start_decrypting(&mut self) {
148        debug_assert!(self.decrypt_state == DirectionState::Prepared);
149        self.decrypt_state = DirectionState::Active;
150    }
151
152    /// Set and start using the given `MessageEncrypter` for future outgoing
153    /// message encryption.
154    pub(crate) fn set_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
155        self.prepare_message_encrypter(cipher);
156        self.start_encrypting();
157    }
158
159    /// Set and start using the given `MessageDecrypter` for future incoming
160    /// message decryption.
161    pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
162        self.prepare_message_decrypter(cipher);
163        self.start_decrypting();
164        self.trial_decryption_len = None;
165    }
166
167    /// Set and start using the given `MessageDecrypter` for future incoming
168    /// message decryption, and enable "trial decryption" mode for when TLS1.3
169    /// 0-RTT is attempted but rejected by the server.
170    pub(crate) fn set_message_decrypter_with_trial_decryption(
171        &mut self,
172        cipher: Box<dyn MessageDecrypter>,
173        max_length: usize,
174    ) {
175        self.prepare_message_decrypter(cipher);
176        self.start_decrypting();
177        self.trial_decryption_len = Some(max_length);
178    }
179
180    pub(crate) fn finish_trial_decryption(&mut self) {
181        self.trial_decryption_len = None;
182    }
183
184    /// Return true if we are getting close to encrypting too many
185    /// messages with our encryption key.
186    pub(crate) fn wants_close_before_encrypt(&self) -> bool {
187        self.write_seq == SEQ_SOFT_LIMIT
188    }
189
190    /// Return true if we outright refuse to do anything with the
191    /// encryption key.
192    pub(crate) fn encrypt_exhausted(&self) -> bool {
193        self.write_seq >= SEQ_HARD_LIMIT
194    }
195
196    pub(crate) fn is_encrypting(&self) -> bool {
197        self.encrypt_state == DirectionState::Active
198    }
199
200    /// Return true if we have ever decrypted a message. This is used in place
201    /// of checking the read_seq since that will be reset on key updates.
202    pub(crate) fn has_decrypted(&self) -> bool {
203        self.has_decrypted
204    }
205
206    pub(crate) fn write_seq(&self) -> u64 {
207        self.write_seq
208    }
209
210    /// Returns the number of remaining write sequences
211    pub(crate) fn remaining_write_seq(&self) -> Option<NonZeroU64> {
212        SEQ_SOFT_LIMIT
213            .checked_sub(self.write_seq)
214            .and_then(NonZeroU64::new)
215    }
216
217    pub(crate) fn read_seq(&self) -> u64 {
218        self.read_seq
219    }
220
221    pub(crate) fn encrypted_len(&self, payload_len: usize) -> usize {
222        self.message_encrypter
223            .encrypted_payload_len(payload_len)
224    }
225
226    fn doing_trial_decryption(&mut self, requested: usize) -> bool {
227        match self
228            .trial_decryption_len
229            .and_then(|value| value.checked_sub(requested))
230        {
231            Some(remaining) => {
232                self.trial_decryption_len = Some(remaining);
233                true
234            }
235            _ => false,
236        }
237    }
238}
239
240/// Result of decryption.
241#[derive(Debug)]
242pub(crate) struct Decrypted<'a> {
243    /// Whether the peer appears to be getting close to encrypting too many messages with this key.
244    pub(crate) want_close_before_decrypt: bool,
245    /// The decrypted message.
246    pub(crate) plaintext: InboundPlainMessage<'a>,
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ContentType, ProtocolVersion};

    /// Decrypter that returns its input unchanged.
    struct PassThroughDecrypter;

    impl MessageDecrypter for PassThroughDecrypter {
        fn decrypt<'a>(
            &mut self,
            m: InboundOpaqueMessage<'a>,
            _: u64,
        ) -> Result<InboundPlainMessage<'a>, Error> {
            Ok(m.into_plain_message())
        }
    }

    /// Decrypter that rejects every message, as if it were protected by a
    /// different key.
    struct FailingDecrypter;

    impl MessageDecrypter for FailingDecrypter {
        fn decrypt<'a>(
            &mut self,
            _: InboundOpaqueMessage<'a>,
            _: u64,
        ) -> Result<InboundPlainMessage<'a>, Error> {
            Err(Error::DecryptError)
        }
    }

    /// Build a handshake record whose opaque payload is `payload`.
    fn handshake(payload: &mut [u8]) -> InboundOpaqueMessage<'_> {
        InboundOpaqueMessage::new(ContentType::Handshake, ProtocolVersion::TLSv1_2, payload)
    }

    #[test]
    fn test_has_decrypted() {
        // A record layer starts out invalid, having never decrypted.
        let mut record_layer = RecordLayer::new();
        assert!(matches!(
            record_layer.decrypt_state,
            DirectionState::Invalid
        ));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Preparing the record layer should update the decrypt state, but shouldn't affect whether it
        // has decrypted.
        record_layer.prepare_message_decrypter(Box::new(PassThroughDecrypter));
        assert!(matches!(
            record_layer.decrypt_state,
            DirectionState::Prepared
        ));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Starting decryption should update the decrypt state, but not affect whether it has decrypted.
        record_layer.start_decrypting();
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Decrypting a message should update the read_seq and track that we have now performed
        // a decryption.
        record_layer
            .decrypt_incoming(handshake(&mut [0xC0, 0xFF, 0xEE]))
            .unwrap();
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 1);
        assert!(record_layer.has_decrypted());

        // Resetting the record layer message decrypter (as if a key update occurred) should reset
        // the read_seq number, but not our knowledge of whether we have decrypted previously.
        record_layer.set_message_decrypter(Box::new(PassThroughDecrypter));
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 0);
        assert!(record_layer.has_decrypted());
    }

    #[test]
    fn test_passthrough_before_decryption_is_active() {
        // Without an active decrypter, messages are handed back as plaintext
        // and do not count as decryptions.
        let mut record_layer = RecordLayer::new();
        let mut payload = [0xC0u8, 0xFF, 0xEE];
        let decrypted = record_layer
            .decrypt_incoming(handshake(&mut payload))
            .unwrap()
            .expect("plaintext is passed through while decryption is inactive");
        assert!(!decrypted.want_close_before_decrypt);
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());
    }

    #[test]
    fn test_trial_decryption_budget() {
        let mut record_layer = RecordLayer::new();
        record_layer.set_message_decrypter_with_trial_decryption(Box::new(FailingDecrypter), 5);

        // 3 bytes fit within the 5-byte budget: the failure is swallowed.
        assert!(matches!(
            record_layer.decrypt_incoming(handshake(&mut [0u8; 3])),
            Ok(None)
        ));
        assert_eq!(record_layer.trial_decryption_len, Some(2));

        // Another 3 bytes exceed the remaining 2-byte budget: the error surfaces
        // and the budget is left untouched.
        assert!(matches!(
            record_layer.decrypt_incoming(handshake(&mut [0u8; 3])),
            Err(Error::DecryptError)
        ));
        assert_eq!(record_layer.trial_decryption_len, Some(2));

        // Once trial decryption is finished, even small failures surface.
        record_layer.finish_trial_decryption();
        assert!(matches!(
            record_layer.decrypt_incoming(handshake(&mut [0u8; 1])),
            Err(Error::DecryptError)
        ));

        // Failed decryptions never advance the sequence number.
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());
    }

    #[test]
    fn test_write_seq_limits() {
        let mut record_layer = RecordLayer::new();
        assert_eq!(record_layer.write_seq(), 0);
        assert_eq!(
            record_layer.remaining_write_seq(),
            NonZeroU64::new(SEQ_SOFT_LIMIT)
        );
        assert!(!record_layer.wants_close_before_encrypt());
        assert!(!record_layer.encrypt_exhausted());

        // At the soft limit we ask to close, but may still encrypt.
        record_layer.write_seq = SEQ_SOFT_LIMIT;
        assert!(record_layer.wants_close_before_encrypt());
        assert_eq!(record_layer.remaining_write_seq(), None);
        assert!(!record_layer.encrypt_exhausted());

        // At the hard limit the key is exhausted.
        record_layer.write_seq = SEQ_HARD_LIMIT;
        assert!(record_layer.encrypt_exhausted());
        assert_eq!(record_layer.remaining_write_seq(), None);
    }
}