1use ring::aead::{
9 Aad, LessSafeKey, Nonce, UnboundKey, AES_128_GCM, AES_256_GCM, CHACHA20_POLY1305,
10};
11use thiserror::Error;
12
13use super::kdf::{AeadAlgorithm, Tls12KeyMaterial, Tls13KeyMaterial};
14
15#[derive(Debug, Error)]
17pub enum DecryptionError {
18 #[error("Invalid key length: expected {expected}, got {actual}")]
19 InvalidKeyLength { expected: usize, actual: usize },
20
21 #[error("Invalid IV length: expected {expected}, got {actual}")]
22 InvalidIvLength { expected: usize, actual: usize },
23
24 #[error("Invalid nonce length: expected 12, got {0}")]
25 InvalidNonceLength(usize),
26
27 #[error("Decryption failed: authentication tag mismatch")]
28 AuthenticationFailed,
29
30 #[error("Unsupported algorithm: {0:?}")]
31 UnsupportedAlgorithm(AeadAlgorithm),
32
33 #[error("Ciphertext too short: minimum {min_len} bytes, got {actual}")]
34 CiphertextTooShort { min_len: usize, actual: usize },
35}
36
/// Which side of the connection produced the traffic being decrypted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Direction {
    /// Records written by the client.
    ClientToServer,
    /// Records written by the server.
    ServerToClient,
}
43
44pub struct DecryptionContext {
48 algorithm: AeadAlgorithm,
50 key: LessSafeKey,
52 iv: Vec<u8>,
54 sequence_number: u64,
56}
57
58impl DecryptionContext {
59 pub fn new_tls12(
65 keys: &Tls12KeyMaterial,
66 algorithm: AeadAlgorithm,
67 direction: Direction,
68 ) -> Result<Self, DecryptionError> {
69 let (key_bytes, iv_bytes) = match direction {
70 Direction::ClientToServer => (&keys.client_write_key, &keys.client_write_iv),
71 Direction::ServerToClient => (&keys.server_write_key, &keys.server_write_iv),
72 };
73
74 Self::new(algorithm, key_bytes, iv_bytes)
75 }
76
77 pub fn new_tls13(
83 keys: &Tls13KeyMaterial,
84 algorithm: AeadAlgorithm,
85 ) -> Result<Self, DecryptionError> {
86 Self::new(algorithm, &keys.key, &keys.iv)
87 }
88
89 pub fn new(algorithm: AeadAlgorithm, key: &[u8], iv: &[u8]) -> Result<Self, DecryptionError> {
91 let ring_algo = match algorithm {
92 AeadAlgorithm::Aes128Gcm => &AES_128_GCM,
93 AeadAlgorithm::Aes256Gcm => &AES_256_GCM,
94 AeadAlgorithm::Chacha20Poly1305 => &CHACHA20_POLY1305,
95 };
96
97 let expected_key_len = algorithm.key_len();
98 if key.len() != expected_key_len {
99 return Err(DecryptionError::InvalidKeyLength {
100 expected: expected_key_len,
101 actual: key.len(),
102 });
103 }
104
105 let unbound_key =
106 UnboundKey::new(ring_algo, key).map_err(|_| DecryptionError::InvalidKeyLength {
107 expected: expected_key_len,
108 actual: key.len(),
109 })?;
110
111 Ok(Self {
112 algorithm,
113 key: LessSafeKey::new(unbound_key),
114 iv: iv.to_vec(),
115 sequence_number: 0,
116 })
117 }
118
119 pub fn sequence_number(&self) -> u64 {
121 self.sequence_number
122 }
123
124 pub fn set_sequence_number(&mut self, seq: u64) {
126 self.sequence_number = seq;
127 }
128
129 pub fn decrypt_tls12_record(
138 &mut self,
139 record_type: u8,
140 version: u16,
141 ciphertext: &[u8],
142 ) -> Result<Vec<u8>, DecryptionError> {
143 let explicit_nonce_len = 8;
146 let tag_len = self.algorithm.tag_len();
147 let min_len = explicit_nonce_len + tag_len;
148
149 if ciphertext.len() < min_len {
150 return Err(DecryptionError::CiphertextTooShort {
151 min_len,
152 actual: ciphertext.len(),
153 });
154 }
155
156 let explicit_nonce = &ciphertext[..explicit_nonce_len];
157 let encrypted_with_tag = &ciphertext[explicit_nonce_len..];
158
159 let mut nonce_bytes = [0u8; 12];
161 nonce_bytes[..4].copy_from_slice(&self.iv[..4.min(self.iv.len())]);
162 nonce_bytes[4..].copy_from_slice(explicit_nonce);
163
164 let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)
165 .map_err(|_| DecryptionError::InvalidNonceLength(nonce_bytes.len()))?;
166
167 let plaintext_len = encrypted_with_tag.len() - tag_len;
169 let mut aad_bytes = [0u8; 13];
170 aad_bytes[..8].copy_from_slice(&self.sequence_number.to_be_bytes());
171 aad_bytes[8] = record_type;
172 aad_bytes[9..11].copy_from_slice(&version.to_be_bytes());
173 aad_bytes[11..13].copy_from_slice(&(plaintext_len as u16).to_be_bytes());
174
175 let aad = Aad::from(&aad_bytes);
176
177 let mut buffer = encrypted_with_tag.to_vec();
179 let plaintext = self
180 .key
181 .open_in_place(nonce, aad, &mut buffer)
182 .map_err(|_| DecryptionError::AuthenticationFailed)?;
183
184 self.sequence_number += 1;
185
186 Ok(plaintext.to_vec())
187 }
188
189 pub fn decrypt_tls13_record(
199 &mut self,
200 ciphertext: &[u8],
201 record_header: &[u8; 5],
202 ) -> Result<Vec<u8>, DecryptionError> {
203 let tag_len = self.algorithm.tag_len();
204
205 if ciphertext.len() < tag_len {
206 return Err(DecryptionError::CiphertextTooShort {
207 min_len: tag_len,
208 actual: ciphertext.len(),
209 });
210 }
211
212 let mut nonce_bytes = [0u8; 12];
214 nonce_bytes.copy_from_slice(&self.iv[..12.min(self.iv.len())]);
215
216 let seq_bytes = self.sequence_number.to_be_bytes();
218 for i in 0..8 {
219 nonce_bytes[4 + i] ^= seq_bytes[i];
220 }
221
222 let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)
223 .map_err(|_| DecryptionError::InvalidNonceLength(nonce_bytes.len()))?;
224
225 let aad = Aad::from(record_header);
227
228 let mut buffer = ciphertext.to_vec();
230 let plaintext = self
231 .key
232 .open_in_place(nonce, aad, &mut buffer)
233 .map_err(|_| DecryptionError::AuthenticationFailed)?;
234
235 self.sequence_number += 1;
236
237 Ok(plaintext.to_vec())
238 }
239
240 pub fn decrypt_record(
245 &mut self,
246 tls_version: TlsVersion,
247 record_type: u8,
248 protocol_version: u16,
249 ciphertext: &[u8],
250 ) -> Result<Vec<u8>, DecryptionError> {
251 match tls_version {
252 TlsVersion::Tls12 | TlsVersion::Tls11 | TlsVersion::Tls10 => {
253 self.decrypt_tls12_record(record_type, protocol_version, ciphertext)
254 }
255 TlsVersion::Tls13 => {
256 let mut header = [0u8; 5];
258 header[0] = record_type;
259 header[1..3].copy_from_slice(&protocol_version.to_be_bytes());
260 header[3..5].copy_from_slice(&(ciphertext.len() as u16).to_be_bytes());
261 self.decrypt_tls13_record(ciphertext, &header)
262 }
263 }
264 }
265}
266
/// TLS protocol versions distinguished by the decryption code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TlsVersion {
    Tls10,
    Tls11,
    Tls12,
    Tls13,
}

impl TlsVersion {
    /// Maps a 16-bit wire version to a [`TlsVersion`].
    ///
    /// Returns `None` for any value other than `0x0301` to `0x0304`.
    pub fn from_wire(version: u16) -> Option<Self> {
        let parsed = match version {
            0x0301 => Self::Tls10,
            0x0302 => Self::Tls11,
            0x0303 => Self::Tls12,
            0x0304 => Self::Tls13,
            _ => return None,
        };
        Some(parsed)
    }

    /// Returns the 16-bit version value for this protocol version.
    ///
    /// TLS 1.3 yields `0x0303`, the same value as TLS 1.2, so
    /// `from_wire(v.to_wire())` does not round-trip for [`TlsVersion::Tls13`].
    pub fn to_wire(&self) -> u16 {
        match *self {
            Self::Tls10 => 0x0301,
            Self::Tls11 => 0x0302,
            Self::Tls12 | Self::Tls13 => 0x0303,
        }
    }
}
297
/// Splits a decrypted TLS 1.3 record into its inner content type and content.
///
/// Trailing zero bytes are treated as padding and skipped. The last non-zero
/// byte is returned as the content type and everything before it as the
/// content. Returns `None` when `plaintext` is empty or contains only zeros.
pub fn extract_tls13_inner_content_type(plaintext: &[u8]) -> Option<(u8, &[u8])> {
    // Index of the last non-zero byte, i.e. the content type.
    let type_index = plaintext.iter().rposition(|&byte| byte != 0)?;
    let (content, rest) = plaintext.split_at(type_index);

    Some((rest[0], content))
}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::kdf::{derive_tls12_keys, derive_tls13_keys};

    /// Key with the 16-byte length used for the AES-128-GCM contexts below.
    const AES128_KEY: [u8; 16] = [0x42u8; 16];

    /// Builds an AES-128-GCM context over `AES128_KEY` with the given IV.
    fn aes128_context(iv: &[u8]) -> DecryptionContext {
        DecryptionContext::new(AeadAlgorithm::Aes128Gcm, &AES128_KEY, iv)
            .expect("a 16-byte key must be accepted for AES-128-GCM")
    }

    #[test]
    fn test_decryption_context_creation() {
        let ctx = aes128_context(&[0x01u8; 12]);
        assert_eq!(ctx.sequence_number(), 0);
    }

    #[test]
    fn test_decryption_context_wrong_key_length() {
        // One byte short of the AES-128 key size.
        let short_key = [0x42u8; 15];
        let iv = [0x01u8; 12];

        let outcome = DecryptionContext::new(AeadAlgorithm::Aes128Gcm, &short_key, &iv);
        assert!(matches!(
            outcome,
            Err(DecryptionError::InvalidKeyLength { .. })
        ));
    }

    #[test]
    fn test_tls12_context_from_keys() {
        let master_secret = [0x42u8; 48];
        let client_random = [0x01u8; 32];
        let server_random = [0x02u8; 32];

        let keys =
            derive_tls12_keys(&master_secret, &client_random, &server_random, 0xC02F).unwrap();

        for &direction in [Direction::ClientToServer, Direction::ServerToClient].iter() {
            let ctx = DecryptionContext::new_tls12(&keys, AeadAlgorithm::Aes128Gcm, direction);
            assert!(ctx.is_ok());
        }
    }

    #[test]
    fn test_tls13_context_from_keys() {
        let traffic_secret = [0x42u8; 32];
        let keys = derive_tls13_keys(&traffic_secret, 0x1301).unwrap();

        assert!(DecryptionContext::new_tls13(&keys, AeadAlgorithm::Aes128Gcm).is_ok());
    }

    #[test]
    fn test_sequence_number() {
        let mut ctx = aes128_context(&[0x01u8; 12]);

        assert_eq!(ctx.sequence_number(), 0);
        ctx.set_sequence_number(100);
        assert_eq!(ctx.sequence_number(), 100);
    }

    #[test]
    fn test_tls_version_from_wire() {
        let cases = [
            (0x0301, Some(TlsVersion::Tls10)),
            (0x0302, Some(TlsVersion::Tls11)),
            (0x0303, Some(TlsVersion::Tls12)),
            (0x0304, Some(TlsVersion::Tls13)),
            (0x0300, None),
        ];

        for &(wire, expected) in cases.iter() {
            assert_eq!(TlsVersion::from_wire(wire), expected);
        }
    }

    #[test]
    fn test_tls_version_to_wire() {
        let cases = [
            (TlsVersion::Tls10, 0x0301),
            (TlsVersion::Tls11, 0x0302),
            (TlsVersion::Tls12, 0x0303),
            // TLS 1.3 is expected to report the same value as TLS 1.2.
            (TlsVersion::Tls13, 0x0303),
        ];

        for &(version, wire) in cases.iter() {
            assert_eq!(version.to_wire(), wire);
        }
    }

    #[test]
    fn test_extract_tls13_inner_content_type() {
        // Content followed directly by the content type.
        let (content_type, content) =
            extract_tls13_inner_content_type(&[0x48, 0x54, 0x54, 0x50, 0x17])
                .expect("content type present");
        assert_eq!(content_type, 0x17);
        assert_eq!(content, &[0x48u8, 0x54, 0x54, 0x50][..]);

        // Trailing zero padding after the content type is skipped.
        let (content_type, content) =
            extract_tls13_inner_content_type(&[0x48, 0x54, 0x17, 0x00, 0x00])
                .expect("content type present");
        assert_eq!(content_type, 0x17);
        assert_eq!(content, &[0x48u8, 0x54][..]);

        // A lone content type yields empty content.
        let (content_type, content) =
            extract_tls13_inner_content_type(&[0x17]).expect("content type present");
        assert_eq!(content_type, 0x17);
        assert!(content.is_empty());

        // Nothing but padding, or nothing at all, has no content type.
        assert!(extract_tls13_inner_content_type(&[0x00, 0x00, 0x00]).is_none());
        assert!(extract_tls13_inner_content_type(&[]).is_none());
    }

    #[test]
    fn test_decrypt_tls12_record_too_short() {
        let mut ctx = aes128_context(&[0x01u8; 4]);

        // 20 bytes is too short for a TLS 1.2 AES-GCM record.
        let outcome = ctx.decrypt_tls12_record(23, 0x0303, &[0u8; 20]);
        assert!(matches!(
            outcome,
            Err(DecryptionError::CiphertextTooShort { .. })
        ));
    }

    #[test]
    fn test_decrypt_tls13_record_too_short() {
        let mut ctx = aes128_context(&[0x01u8; 12]);

        // 10 bytes is too short for a TLS 1.3 record.
        let header = [0x17, 0x03, 0x03, 0x00, 0x0A];
        let outcome = ctx.decrypt_tls13_record(&[0u8; 10], &header);
        assert!(matches!(
            outcome,
            Err(DecryptionError::CiphertextTooShort { .. })
        ));
    }
}