clia_rustls_mod/
record_layer.rs1use alloc::boxed::Box;
2use core::num::NonZeroU64;
3
4use crate::crypto::cipher::{InboundOpaqueMessage, MessageDecrypter, MessageEncrypter};
5use crate::error::Error;
6#[cfg(feature = "logging")]
7use crate::log::trace;
8use crate::msgs::message::{InboundPlainMessage, OutboundOpaqueMessage, OutboundPlainMessage};
9
/// Sequence number at which a direction is considered close to exhausting
/// its key: `RecordLayer` reports "wants close" once a counter equals this.
/// Equal to `0xffff_ffff_ffff_0000`.
static SEQ_SOFT_LIMIT: u64 = u64::MAX - 0xffff;

/// Sequence number at (or beyond) which the write side refuses to encrypt
/// any further records. Equal to `0xffff_ffff_ffff_fffe`.
static SEQ_HARD_LIMIT: u64 = u64::MAX - 1;
/// Lifecycle of the keying material for one direction (read or write) of
/// the record layer.
#[derive(PartialEq)]
enum DirectionState {
    /// No usable keying material has been installed yet.
    Invalid,

    /// Keying material is installed but not yet used to protect records.
    Prepared,

    /// Keying material is installed and records are being protected with it.
    Active,
}
25pub struct RecordLayer {
27 message_encrypter: Box<dyn MessageEncrypter>,
28 message_decrypter: Box<dyn MessageDecrypter>,
29 write_seq: u64,
30 read_seq: u64,
31 has_decrypted: bool,
32 encrypt_state: DirectionState,
33 decrypt_state: DirectionState,
34
35 trial_decryption_len: Option<usize>,
39}
40
41impl RecordLayer {
42 pub fn new() -> Self {
44 Self {
45 message_encrypter: <dyn MessageEncrypter>::invalid(),
46 message_decrypter: <dyn MessageDecrypter>::invalid(),
47 write_seq: 0,
48 read_seq: 0,
49 has_decrypted: false,
50 encrypt_state: DirectionState::Invalid,
51 decrypt_state: DirectionState::Invalid,
52 trial_decryption_len: None,
53 }
54 }
55
56 pub(crate) fn decrypt_incoming<'a>(
62 &mut self,
63 encr: InboundOpaqueMessage<'a>,
64 ) -> Result<Option<Decrypted<'a>>, Error> {
65 if self.decrypt_state != DirectionState::Active {
66 return Ok(Some(Decrypted {
67 want_close_before_decrypt: false,
68 plaintext: encr.into_plain_message(),
69 }));
70 }
71
72 let want_close_before_decrypt = self.read_seq == SEQ_SOFT_LIMIT;
81
82 let encrypted_len = encr.payload.len();
83 match self
84 .message_decrypter
85 .decrypt(encr, self.read_seq)
86 {
87 Ok(plaintext) => {
88 self.read_seq += 1;
89 if !self.has_decrypted {
90 self.has_decrypted = true;
91 }
92 Ok(Some(Decrypted {
93 want_close_before_decrypt,
94 plaintext,
95 }))
96 }
97 Err(Error::DecryptError) if self.doing_trial_decryption(encrypted_len) => {
98 trace!("Dropping undecryptable message after aborted early_data");
99 Ok(None)
100 }
101 Err(err) => Err(err),
102 }
103 }
104
105 pub(crate) fn encrypt_outgoing(
110 &mut self,
111 plain: OutboundPlainMessage,
112 ) -> OutboundOpaqueMessage {
113 debug_assert!(self.encrypt_state == DirectionState::Active);
114 assert!(!self.encrypt_exhausted());
115 let seq = self.write_seq;
116 self.write_seq += 1;
117 self.message_encrypter
118 .encrypt(plain, seq)
119 .unwrap()
120 }
121
122 pub(crate) fn prepare_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
125 self.message_encrypter = cipher;
126 self.write_seq = 0;
127 self.encrypt_state = DirectionState::Prepared;
128 }
129
130 pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
133 self.message_decrypter = cipher;
134 self.read_seq = 0;
135 self.decrypt_state = DirectionState::Prepared;
136 }
137
138 pub(crate) fn start_encrypting(&mut self) {
141 debug_assert!(self.encrypt_state == DirectionState::Prepared);
142 self.encrypt_state = DirectionState::Active;
143 }
144
145 pub(crate) fn start_decrypting(&mut self) {
148 debug_assert!(self.decrypt_state == DirectionState::Prepared);
149 self.decrypt_state = DirectionState::Active;
150 }
151
152 pub(crate) fn set_message_encrypter(&mut self, cipher: Box<dyn MessageEncrypter>) {
155 self.prepare_message_encrypter(cipher);
156 self.start_encrypting();
157 }
158
159 pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
162 self.prepare_message_decrypter(cipher);
163 self.start_decrypting();
164 self.trial_decryption_len = None;
165 }
166
167 pub(crate) fn set_message_decrypter_with_trial_decryption(
171 &mut self,
172 cipher: Box<dyn MessageDecrypter>,
173 max_length: usize,
174 ) {
175 self.prepare_message_decrypter(cipher);
176 self.start_decrypting();
177 self.trial_decryption_len = Some(max_length);
178 }
179
180 pub(crate) fn finish_trial_decryption(&mut self) {
181 self.trial_decryption_len = None;
182 }
183
184 pub(crate) fn wants_close_before_encrypt(&self) -> bool {
187 self.write_seq == SEQ_SOFT_LIMIT
188 }
189
190 pub(crate) fn encrypt_exhausted(&self) -> bool {
193 self.write_seq >= SEQ_HARD_LIMIT
194 }
195
196 pub(crate) fn is_encrypting(&self) -> bool {
197 self.encrypt_state == DirectionState::Active
198 }
199
200 pub(crate) fn has_decrypted(&self) -> bool {
203 self.has_decrypted
204 }
205
206 pub(crate) fn write_seq(&self) -> u64 {
207 self.write_seq
208 }
209
210 pub(crate) fn remaining_write_seq(&self) -> Option<NonZeroU64> {
212 SEQ_SOFT_LIMIT
213 .checked_sub(self.write_seq)
214 .and_then(NonZeroU64::new)
215 }
216
217 pub(crate) fn read_seq(&self) -> u64 {
218 self.read_seq
219 }
220
221 pub(crate) fn encrypted_len(&self, payload_len: usize) -> usize {
222 self.message_encrypter
223 .encrypted_payload_len(payload_len)
224 }
225
226 fn doing_trial_decryption(&mut self, requested: usize) -> bool {
227 match self
228 .trial_decryption_len
229 .and_then(|value| value.checked_sub(requested))
230 {
231 Some(remaining) => {
232 self.trial_decryption_len = Some(remaining);
233 true
234 }
235 _ => false,
236 }
237 }
238}
239
240#[derive(Debug)]
242pub(crate) struct Decrypted<'a> {
243 pub(crate) want_close_before_decrypt: bool,
245 pub(crate) plaintext: InboundPlainMessage<'a>,
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_has_decrypted() {
        use crate::{ContentType, ProtocolVersion};

        // Decrypter that returns every record unchanged as plaintext.
        struct PassThroughDecrypter;

        impl MessageDecrypter for PassThroughDecrypter {
            fn decrypt<'a>(
                &mut self,
                m: InboundOpaqueMessage<'a>,
                _: u64,
            ) -> Result<InboundPlainMessage<'a>, Error> {
                Ok(m.into_plain_message())
            }
        }

        // Checks the read-side state, read sequence number and
        // `has_decrypted` flag together.
        fn check(layer: &RecordLayer, state: DirectionState, seq: u64, decrypted: bool) {
            assert!(layer.decrypt_state == state);
            assert_eq!(layer.read_seq, seq);
            assert_eq!(layer.has_decrypted(), decrypted);
        }

        let mut layer = RecordLayer::new();
        check(&layer, DirectionState::Invalid, 0, false);

        layer.prepare_message_decrypter(Box::new(PassThroughDecrypter));
        check(&layer, DirectionState::Prepared, 0, false);

        layer.start_decrypting();
        check(&layer, DirectionState::Active, 0, false);

        layer
            .decrypt_incoming(InboundOpaqueMessage::new(
                ContentType::Handshake,
                ProtocolVersion::TLSv1_2,
                &mut [0xC0, 0xFF, 0xEE],
            ))
            .unwrap();
        check(&layer, DirectionState::Active, 1, true);

        // A fresh decrypter restarts the sequence but keeps the flag set.
        layer.set_message_decrypter(Box::new(PassThroughDecrypter));
        check(&layer, DirectionState::Active, 0, true);
    }
}