pcapsql_core/tls/
decrypt.rs

1//! TLS record decryption engine.
2//!
3//! Implements AEAD decryption for TLS 1.2 and TLS 1.3 records using:
4//! - AES-128-GCM
5//! - AES-256-GCM
6//! - ChaCha20-Poly1305
7
8use ring::aead::{
9    Aad, LessSafeKey, Nonce, UnboundKey, AES_128_GCM, AES_256_GCM, CHACHA20_POLY1305,
10};
11use thiserror::Error;
12
13use super::kdf::{AeadAlgorithm, Tls12KeyMaterial, Tls13KeyMaterial};
14
15/// Errors during TLS record decryption.
16#[derive(Debug, Error)]
17pub enum DecryptionError {
18    #[error("Invalid key length: expected {expected}, got {actual}")]
19    InvalidKeyLength { expected: usize, actual: usize },
20
21    #[error("Invalid IV length: expected {expected}, got {actual}")]
22    InvalidIvLength { expected: usize, actual: usize },
23
24    #[error("Invalid nonce length: expected 12, got {0}")]
25    InvalidNonceLength(usize),
26
27    #[error("Decryption failed: authentication tag mismatch")]
28    AuthenticationFailed,
29
30    #[error("Unsupported algorithm: {0:?}")]
31    UnsupportedAlgorithm(AeadAlgorithm),
32
33    #[error("Ciphertext too short: minimum {min_len} bytes, got {actual}")]
34    CiphertextTooShort { min_len: usize, actual: usize },
35}
36
/// Which way a TLS record travels on the connection.
///
/// `DecryptionContext::new_tls12` uses this to choose between the client and
/// server write key/IV.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Records sent by the client (protected with the client write keys).
    ClientToServer,
    /// Records sent by the server (protected with the server write keys).
    ServerToClient,
}
43
44/// TLS record decryption context.
45///
46/// Holds the AEAD key material and provides methods to decrypt TLS records.
47pub struct DecryptionContext {
48    /// The AEAD algorithm
49    algorithm: AeadAlgorithm,
50    /// Decryption key
51    key: LessSafeKey,
52    /// Implicit IV (TLS 1.2) or base IV (TLS 1.3)
53    iv: Vec<u8>,
54    /// Current sequence number for nonce construction
55    sequence_number: u64,
56}
57
58impl DecryptionContext {
59    /// Create a new decryption context for TLS 1.2.
60    ///
61    /// For TLS 1.2 with AEAD:
62    /// - The nonce is: implicit_iv (4 bytes) || explicit_nonce (8 bytes from record)
63    /// - The explicit_nonce is typically the sequence number
64    pub fn new_tls12(
65        keys: &Tls12KeyMaterial,
66        algorithm: AeadAlgorithm,
67        direction: Direction,
68    ) -> Result<Self, DecryptionError> {
69        let (key_bytes, iv_bytes) = match direction {
70            Direction::ClientToServer => (&keys.client_write_key, &keys.client_write_iv),
71            Direction::ServerToClient => (&keys.server_write_key, &keys.server_write_iv),
72        };
73
74        Self::new(algorithm, key_bytes, iv_bytes)
75    }
76
77    /// Create a new decryption context for TLS 1.3.
78    ///
79    /// For TLS 1.3:
80    /// - The nonce is: iv XOR padded_sequence_number
81    /// - The IV is derived from the traffic secret via HKDF-Expand-Label
82    pub fn new_tls13(
83        keys: &Tls13KeyMaterial,
84        algorithm: AeadAlgorithm,
85    ) -> Result<Self, DecryptionError> {
86        Self::new(algorithm, &keys.key, &keys.iv)
87    }
88
89    /// Create a new decryption context from raw key material.
90    pub fn new(algorithm: AeadAlgorithm, key: &[u8], iv: &[u8]) -> Result<Self, DecryptionError> {
91        let ring_algo = match algorithm {
92            AeadAlgorithm::Aes128Gcm => &AES_128_GCM,
93            AeadAlgorithm::Aes256Gcm => &AES_256_GCM,
94            AeadAlgorithm::Chacha20Poly1305 => &CHACHA20_POLY1305,
95        };
96
97        let expected_key_len = algorithm.key_len();
98        if key.len() != expected_key_len {
99            return Err(DecryptionError::InvalidKeyLength {
100                expected: expected_key_len,
101                actual: key.len(),
102            });
103        }
104
105        let unbound_key =
106            UnboundKey::new(ring_algo, key).map_err(|_| DecryptionError::InvalidKeyLength {
107                expected: expected_key_len,
108                actual: key.len(),
109            })?;
110
111        Ok(Self {
112            algorithm,
113            key: LessSafeKey::new(unbound_key),
114            iv: iv.to_vec(),
115            sequence_number: 0,
116        })
117    }
118
119    /// Get the current sequence number.
120    pub fn sequence_number(&self) -> u64 {
121        self.sequence_number
122    }
123
124    /// Set the sequence number (useful for resuming mid-stream).
125    pub fn set_sequence_number(&mut self, seq: u64) {
126        self.sequence_number = seq;
127    }
128
129    /// Decrypt a TLS 1.2 AEAD record in place.
130    ///
131    /// For TLS 1.2 AEAD ciphers:
132    /// - Record format: explicit_nonce (8 bytes) || ciphertext || tag (16 bytes)
133    /// - Nonce = implicit_iv (4 bytes) || explicit_nonce (8 bytes)
134    /// - AAD = seq_num (8 bytes) || type (1) || version (2) || length (2)
135    ///
136    /// Returns the decrypted plaintext.
137    pub fn decrypt_tls12_record(
138        &mut self,
139        record_type: u8,
140        version: u16,
141        ciphertext: &[u8],
142    ) -> Result<Vec<u8>, DecryptionError> {
143        // TLS 1.2 AEAD record structure:
144        // explicit_nonce (8 bytes) || encrypted_data || auth_tag (16 bytes)
145        let explicit_nonce_len = 8;
146        let tag_len = self.algorithm.tag_len();
147        let min_len = explicit_nonce_len + tag_len;
148
149        if ciphertext.len() < min_len {
150            return Err(DecryptionError::CiphertextTooShort {
151                min_len,
152                actual: ciphertext.len(),
153            });
154        }
155
156        let explicit_nonce = &ciphertext[..explicit_nonce_len];
157        let encrypted_with_tag = &ciphertext[explicit_nonce_len..];
158
159        // Construct the 12-byte nonce: implicit_iv (4) || explicit_nonce (8)
160        let mut nonce_bytes = [0u8; 12];
161        nonce_bytes[..4].copy_from_slice(&self.iv[..4.min(self.iv.len())]);
162        nonce_bytes[4..].copy_from_slice(explicit_nonce);
163
164        let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)
165            .map_err(|_| DecryptionError::InvalidNonceLength(nonce_bytes.len()))?;
166
167        // Construct AAD: seq_num (8) || type (1) || version (2) || length (2)
168        let plaintext_len = encrypted_with_tag.len() - tag_len;
169        let mut aad_bytes = [0u8; 13];
170        aad_bytes[..8].copy_from_slice(&self.sequence_number.to_be_bytes());
171        aad_bytes[8] = record_type;
172        aad_bytes[9..11].copy_from_slice(&version.to_be_bytes());
173        aad_bytes[11..13].copy_from_slice(&(plaintext_len as u16).to_be_bytes());
174
175        let aad = Aad::from(&aad_bytes);
176
177        // Decrypt in place
178        let mut buffer = encrypted_with_tag.to_vec();
179        let plaintext = self
180            .key
181            .open_in_place(nonce, aad, &mut buffer)
182            .map_err(|_| DecryptionError::AuthenticationFailed)?;
183
184        self.sequence_number += 1;
185
186        Ok(plaintext.to_vec())
187    }
188
189    /// Decrypt a TLS 1.3 AEAD record in place.
190    ///
191    /// For TLS 1.3:
192    /// - Record format: ciphertext || tag (16 bytes)
193    /// - Nonce = iv XOR padded_sequence_number
194    /// - AAD = record_header (type || legacy_version || length)
195    /// - Inner plaintext ends with content_type byte
196    ///
197    /// Returns the decrypted plaintext (including inner content type).
198    pub fn decrypt_tls13_record(
199        &mut self,
200        ciphertext: &[u8],
201        record_header: &[u8; 5],
202    ) -> Result<Vec<u8>, DecryptionError> {
203        let tag_len = self.algorithm.tag_len();
204
205        if ciphertext.len() < tag_len {
206            return Err(DecryptionError::CiphertextTooShort {
207                min_len: tag_len,
208                actual: ciphertext.len(),
209            });
210        }
211
212        // Construct the 12-byte nonce: iv XOR padded_seq_num
213        let mut nonce_bytes = [0u8; 12];
214        nonce_bytes.copy_from_slice(&self.iv[..12.min(self.iv.len())]);
215
216        // XOR with padded sequence number (right-aligned)
217        let seq_bytes = self.sequence_number.to_be_bytes();
218        for i in 0..8 {
219            nonce_bytes[4 + i] ^= seq_bytes[i];
220        }
221
222        let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)
223            .map_err(|_| DecryptionError::InvalidNonceLength(nonce_bytes.len()))?;
224
225        // AAD is the TLS record header (type || version || length)
226        let aad = Aad::from(record_header);
227
228        // Decrypt in place
229        let mut buffer = ciphertext.to_vec();
230        let plaintext = self
231            .key
232            .open_in_place(nonce, aad, &mut buffer)
233            .map_err(|_| DecryptionError::AuthenticationFailed)?;
234
235        self.sequence_number += 1;
236
237        Ok(plaintext.to_vec())
238    }
239
240    /// Decrypt a TLS record, auto-detecting the version from context.
241    ///
242    /// This is a convenience wrapper that routes to the appropriate decryption
243    /// method based on TLS version.
244    pub fn decrypt_record(
245        &mut self,
246        tls_version: TlsVersion,
247        record_type: u8,
248        protocol_version: u16,
249        ciphertext: &[u8],
250    ) -> Result<Vec<u8>, DecryptionError> {
251        match tls_version {
252            TlsVersion::Tls12 | TlsVersion::Tls11 | TlsVersion::Tls10 => {
253                self.decrypt_tls12_record(record_type, protocol_version, ciphertext)
254            }
255            TlsVersion::Tls13 => {
256                // Reconstruct record header for TLS 1.3 AAD
257                let mut header = [0u8; 5];
258                header[0] = record_type;
259                header[1..3].copy_from_slice(&protocol_version.to_be_bytes());
260                header[3..5].copy_from_slice(&(ciphertext.len() as u16).to_be_bytes());
261                self.decrypt_tls13_record(ciphertext, &header)
262            }
263        }
264    }
265}
266
/// TLS protocol versions this module distinguishes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsVersion {
    Tls10,
    Tls11,
    Tls12,
    Tls13,
}

impl TlsVersion {
    /// Map a 16-bit wire version to a [`TlsVersion`].
    ///
    /// Returns `None` for any value other than `0x0301`..=`0x0304`.
    /// `0x0303` maps to TLS 1.2 even though TLS 1.3 also puts that value in
    /// its record layer; `0x0304` is the value carried by the
    /// supported-versions extension.
    pub fn from_wire(version: u16) -> Option<Self> {
        let parsed = match version {
            0x0301 => Self::Tls10,
            0x0302 => Self::Tls11,
            0x0303 => Self::Tls12,
            0x0304 => Self::Tls13,
            _ => return None,
        };
        Some(parsed)
    }

    /// Return the version value written in the record layer.
    ///
    /// TLS 1.3 reports `0x0303`, the same value as TLS 1.2, so this is not
    /// the inverse of [`TlsVersion::from_wire`] for `Tls13`.
    pub fn to_wire(&self) -> u16 {
        match *self {
            Self::Tls10 => 0x0301,
            Self::Tls11 => 0x0302,
            Self::Tls12 | Self::Tls13 => 0x0303,
        }
    }
}
297
/// Split a decrypted TLS 1.3 inner plaintext into its content type and content.
///
/// The inner plaintext is laid out as `content || content_type || zeros`, so
/// the content type is the last non-zero byte and everything before it is
/// the content.
///
/// Returns `Some((content_type, content))`, or `None` when `plaintext` is
/// empty or consists only of zero bytes.
pub fn extract_tls13_inner_content_type(plaintext: &[u8]) -> Option<(u8, &[u8])> {
    // Scan from the end, skipping zero padding, to locate the content type.
    let type_index = plaintext.iter().rposition(|&byte| byte != 0)?;
    let (content, rest) = plaintext.split_at(type_index);

    Some((rest[0], content))
}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::kdf::{derive_tls12_keys, derive_tls13_keys};

    /// Dummy AES-128 key shared by the context tests.
    const AES128_KEY: [u8; 16] = [0x42; 16];

    /// Build an AES-128-GCM context over `AES128_KEY` with the given IV.
    fn aes128_ctx(iv: &[u8]) -> DecryptionContext {
        DecryptionContext::new(AeadAlgorithm::Aes128Gcm, &AES128_KEY, iv)
            .expect("a 16-byte key must be accepted for AES-128-GCM")
    }

    #[test]
    fn test_decryption_context_creation() {
        let ctx = aes128_ctx(&[0x01u8; 12]);
        assert_eq!(ctx.sequence_number(), 0);
    }

    #[test]
    fn test_decryption_context_wrong_key_length() {
        // 15 bytes: one short of an AES-128 key.
        let short_key = [0x42u8; 15];
        let iv = [0x01u8; 12];

        let outcome = DecryptionContext::new(AeadAlgorithm::Aes128Gcm, &short_key, &iv);
        assert!(matches!(
            outcome,
            Err(DecryptionError::InvalidKeyLength { .. })
        ));
    }

    #[test]
    fn test_tls12_context_from_keys() {
        let master_secret = [0x42u8; 48];
        let client_random = [0x01u8; 32];
        let server_random = [0x02u8; 32];

        let keys =
            derive_tls12_keys(&master_secret, &client_random, &server_random, 0xC02F).unwrap();

        // Both directions must yield a usable context from the same key block.
        let directions = [Direction::ClientToServer, Direction::ServerToClient];
        for direction in directions.iter().copied() {
            let ctx = DecryptionContext::new_tls12(&keys, AeadAlgorithm::Aes128Gcm, direction);
            assert!(ctx.is_ok());
        }
    }

    #[test]
    fn test_tls13_context_from_keys() {
        let traffic_secret = [0x42u8; 32];
        let keys = derive_tls13_keys(&traffic_secret, 0x1301).unwrap();

        assert!(DecryptionContext::new_tls13(&keys, AeadAlgorithm::Aes128Gcm).is_ok());
    }

    #[test]
    fn test_sequence_number() {
        let mut ctx = aes128_ctx(&[0x01u8; 12]);
        assert_eq!(ctx.sequence_number(), 0);

        ctx.set_sequence_number(100);
        assert_eq!(ctx.sequence_number(), 100);
    }

    #[test]
    fn test_tls_version_from_wire() {
        let expectations = [
            (0x0301u16, Some(TlsVersion::Tls10)),
            (0x0302, Some(TlsVersion::Tls11)),
            (0x0303, Some(TlsVersion::Tls12)),
            (0x0304, Some(TlsVersion::Tls13)),
            (0x0300, None),
        ];
        for (wire, expected) in expectations.iter() {
            assert_eq!(TlsVersion::from_wire(*wire), *expected);
        }
    }

    #[test]
    fn test_tls_version_to_wire() {
        let expectations = [
            (TlsVersion::Tls10, 0x0301u16),
            (TlsVersion::Tls11, 0x0302),
            (TlsVersion::Tls12, 0x0303),
            // TLS 1.3 uses 0x0303 in the record layer.
            (TlsVersion::Tls13, 0x0303),
        ];
        for (version, wire) in expectations.iter() {
            assert_eq!(version.to_wire(), *wire);
        }
    }

    #[test]
    fn test_extract_tls13_inner_content_type() {
        // Content followed directly by the content type:
        // "HTTP" + application_data(23).
        let plain = [0x48u8, 0x54, 0x54, 0x50, 0x17];
        assert_eq!(
            extract_tls13_inner_content_type(&plain),
            Some((0x17u8, &[0x48u8, 0x54, 0x54, 0x50][..]))
        );

        // Zero padding after the content type is skipped.
        let padded = [0x48u8, 0x54, 0x17, 0x00, 0x00];
        assert_eq!(
            extract_tls13_inner_content_type(&padded),
            Some((0x17u8, &[0x48u8, 0x54][..]))
        );

        // A lone content type byte means empty content.
        let only_type = [0x17u8];
        let (content_type, content) = extract_tls13_inner_content_type(&only_type)
            .expect("a single non-zero byte is a valid inner plaintext");
        assert_eq!(content_type, 0x17);
        assert!(content.is_empty());

        // All zeros carries no content type and is invalid.
        let zeros = [0x00u8, 0x00, 0x00];
        assert!(extract_tls13_inner_content_type(&zeros).is_none());

        // Empty input is invalid as well.
        let empty: [u8; 0] = [];
        assert!(extract_tls13_inner_content_type(&empty).is_none());
    }

    #[test]
    fn test_decrypt_tls12_record_too_short() {
        // The TLS 1.2 implicit IV is 4 bytes.
        let mut ctx = aes128_ctx(&[0x01u8; 4]);

        // 20 bytes is below the 8 (explicit nonce) + 16 (tag) = 24 byte minimum.
        let record = [0u8; 20];
        let outcome = ctx.decrypt_tls12_record(23, 0x0303, &record);
        assert!(matches!(
            outcome,
            Err(DecryptionError::CiphertextTooShort { .. })
        ));
    }

    #[test]
    fn test_decrypt_tls13_record_too_short() {
        let mut ctx = aes128_ctx(&[0x01u8; 12]);

        // 10 bytes cannot even hold the 16-byte tag.
        let record = [0u8; 10];
        let header = [0x17, 0x03, 0x03, 0x00, 0x0A];
        let outcome = ctx.decrypt_tls13_record(&record, &header);
        assert!(matches!(
            outcome,
            Err(DecryptionError::CiphertextTooShort { .. })
        ));
    }

    // Note: successful decryption is not exercised here because it needs a
    // known test vector or real encrypted data; integration tests cover it.
}