brass_aphid_wire_decryption/decryption/key_space.rs

1use aws_lc_rs::{aead, hkdf};
2
3use brass_aphid_wire_messages::codec::{DecodeValue, EncodeValue};
4use brass_aphid_wire_messages::{
5    iana,
6    protocol::{ContentType, RecordHeader},
7};
8
9trait DecryptionCipherExtension {
10    fn aead(&self) -> &'static aws_lc_rs::aead::Algorithm;
11
12    fn hkdf(&self) -> aws_lc_rs::hkdf::Algorithm;
13}
14
15impl DecryptionCipherExtension for iana::Cipher {
16    fn aead(&self) -> &'static aws_lc_rs::aead::Algorithm {
17        match self.description {
18            "TLS_AES_128_GCM_SHA256" => &aead::AES_128_GCM,
19            "TLS_AES_256_GCM_SHA384" => &aead::AES_256_GCM,
20            "TLS_CHACHA20_POLY1305_SHA256" => &aead::CHACHA20_POLY1305,
21            _ => panic!("one of us did something stupid. Probably me."),
22        }
23    }
24
25    fn hkdf(&self) -> aws_lc_rs::hkdf::Algorithm {
26        match self.description {
27            "TLS_AES_128_GCM_SHA256" => hkdf::HKDF_SHA256,
28            "TLS_AES_256_GCM_SHA384" => hkdf::HKDF_SHA384,
29            "TLS_CHACHA20_POLY1305_SHA256" => hkdf::HKDF_SHA256,
30            _ => panic!("one of us did something stupid. Probably me."),
31        }
32    }
33}
34
35struct UsizeContainer(usize);
36
37impl UsizeContainer {
38    fn new(num: usize) -> Self {
39        UsizeContainer(num)
40    }
41}
42
43// they have unfortunately made me too angry to put up with their API
44// I am done asking nicely, and will simply transmute it into the shape
45// I wish for, and deal with the consequences later.
46impl hkdf::KeyType for UsizeContainer {
47    fn len(&self) -> usize {
48        self.0
49    }
50}
51
52fn hkdf_expand_label<T: hkdf::KeyType>(
53    secret: &[u8],
54    label: &[u8],
55    context: &[u8],
56    key_type: T,
57    hkdf: hkdf::Algorithm,
58) -> Vec<u8> {
59    let prk = hkdf::Prk::new_less_safe(hkdf, secret);
60
61    let output_length_bytes = (key_type.len() as u16).to_be_bytes();
62    let label = {
63        let mut label_builder = Vec::new();
64        label_builder.extend_from_slice(b"tls13 ");
65        label_builder.extend_from_slice(label);
66        label_builder
67    };
68    let label_bytes = label.len() as u8;
69
70    let context_bytes = context.len() as u8;
71    let label = [
72        output_length_bytes.as_slice(),
73        &[label_bytes],
74        &label,
75        &[context_bytes],
76        context,
77    ];
78
79    let mut key = vec![0; key_type.len()];
80    let out = prk.expand(&label, key_type).unwrap();
81    out.fill(&mut key).unwrap();
82    key
83}
84
/// KeySpace holds the decryption context for records protected under one
/// TLS 1.3 traffic secret.
///
/// E.g. Handshake Space or Traffic Space.
///
/// NOTE(review): the derived `Debug` output includes the raw `secret` bytes.
#[derive(Debug)]
pub struct KeySpace {
    /// The negotiated cipher suite; selects the AEAD and the HKDF hash.
    pub cipher: iana::Cipher,
    /// The traffic secret from which the record key and IV are derived.
    pub secret: Vec<u8>,
    /// Number of records decrypted in this space so far. Used as the record
    /// sequence number when computing each record's nonce.
    pub record_count: u64,
    /// Defined for application traffic: starts at `Some(0)` for the first
    /// traffic secret and is incremented by each key update. `None` for
    /// handshake traffic.
    pub key_epoch: Option<usize>,
}
96
97impl KeySpace {
98    /// Construct a new key space from a handshake secret
99    pub fn handshake_traffic_secret(secret: Vec<u8>, cipher: iana::Cipher) -> Self {
100        // https://www.rfc-editor.org/rfc/rfc8446#section-7.3
101        // [sender]_write_key = HKDF-Expand-Label(Secret, "key", "", key_length)
102        // [sender]_write_iv  = HKDF-Expand-Label(Secret, "iv", "", iv_length)
103
104        Self {
105            cipher,
106            secret,
107            record_count: 0,
108            key_epoch: None,
109        }
110    }
111
112    /// Construct a new key space from the first traffic secret
113    pub fn first_traffic_secret(secret: Vec<u8>, cipher: iana::Cipher) -> Self {
114        // https://www.rfc-editor.org/rfc/rfc8446#section-7.3
115        // [sender]_write_key = HKDF-Expand-Label(Secret, "key", "", key_length)
116        // [sender]_write_iv  = HKDF-Expand-Label(Secret, "iv", "", iv_length)
117
118        Self {
119            cipher,
120            secret,
121            record_count: 0,
122            key_epoch: Some(0),
123        }
124    }
125
126    /// Construct a new key space following a key update
127    ///
128    /// Defined in https://www.rfc-editor.org/rfc/rfc8446#section-7.2
129    pub fn key_update(&self) -> Self {
130        let new_secret = hkdf_expand_label(
131            &self.secret,
132            b"traffic upd",
133            b"",
134            UsizeContainer::new(
135                self.cipher
136                    .hkdf()
137                    .hmac_algorithm()
138                    .digest_algorithm()
139                    .output_len(),
140            ),
141            self.cipher.hkdf(),
142        );
143        Self {
144            cipher: self.cipher,
145            secret: new_secret,
146            record_count: 0,
147            key_epoch: self.key_epoch.map(|epoch| epoch + 1),
148        }
149    }
150
151    /// Return the actual key and IV which will be used the the symmetric cipher
152    pub fn traffic_key(&self) -> std::io::Result<(Vec<u8>, Vec<u8>)> {
153        let secret = &self.secret;
154        // Determine the hash algorithm, key length, and IV length based on the cipher suite
155        let aead = self.cipher.aead();
156
157        let key = hkdf_expand_label(
158            secret,
159            b"key",
160            b"",
161            UsizeContainer::new(aead.key_len()),
162            self.cipher.hkdf(),
163        );
164        let iv = hkdf_expand_label(
165            secret,
166            b"iv",
167            b"",
168            UsizeContainer::new(aead.nonce_len()),
169            self.cipher.hkdf(),
170        );
171
172        Ok((key, iv))
173    }
174
175    /// * `record`: the encrypted record, exclusive of the header
176    /// * `sender`: the party who transmitted the record
177    pub fn decrypt_record(&mut self, header: &RecordHeader, record: &[u8]) -> Vec<u8> {
178        let (key, iv) = self.traffic_key().unwrap();
179
180        let nonce = Self::calculate_nonce(iv, self.record_count);
181        self.record_count += 1;
182
183        let unbound_key = aws_lc_rs::aead::UnboundKey::new(self.cipher.aead(), &key).unwrap();
184        let less_safe_key = aws_lc_rs::aead::LessSafeKey::new(unbound_key);
185
186        // Create a buffer that contains ciphertext + tag for in-place decryption
187        let mut output = record.to_vec();
188
189        // Decrypt the record
190        let nonce_obj = aws_lc_rs::aead::Nonce::try_assume_unique_for_key(&nonce).unwrap();
191
192        let aad = header.encode_to_vec().unwrap();
193
194        let plaintext = less_safe_key
195            .open_in_place(nonce_obj, aws_lc_rs::aead::Aad::from(aad), &mut output)
196            .unwrap();
197        plaintext.to_vec()
198    }
199
200    /// XOR the IV with the record count
201    fn calculate_nonce(iv: Vec<u8>, record_count: u64) -> Vec<u8> {
202        let mut nonce = iv.clone();
203        let record_count = record_count.to_be_bytes();
204        let mut bytes = vec![0; nonce.len() - record_count.len()];
205        bytes.extend_from_slice(&record_count);
206
207        for i in 0..nonce.len() {
208            nonce[i] ^= bytes[i];
209        }
210
211        nonce
212    }
213}
214
/// The record protection currently in effect, together with the keys needed
/// to remove it.
#[derive(Debug)]
pub enum SecretSpace {
    /// Records are not encrypted and are returned as-is when deframed.
    Plaintext,
    /// Records are protected with keys derived from a handshake traffic secret.
    Handshake(KeySpace),
    /// Records are protected with keys derived from an application traffic secret.
    ///
    /// NOTE(review): the meaning of the `usize` field is not visible here (it
    /// is ignored when deframing). It is presumably a key epoch or update count;
    /// confirm against callers.
    Application(KeySpace, usize),
}
221
222impl SecretSpace {
223    /// Deframe (possibly decrypt) a record, returning it's true content type.
224    ///
225    /// E.g. A TLS 1.3 obfuscated record may have an obfuscated content type of "ApplicationData",
226    /// but an internal type of Handshake. This method would return `Handshake`.
227    ///
228    /// This method will also strip off all record padding
229    pub fn deframe_record(&mut self, record: &[u8]) -> std::io::Result<(ContentType, Vec<u8>)> {
230        let remaining = record;
231        let (outer_record_header, remaining) = RecordHeader::decode_from(remaining)?;
232        tracing::debug!("Deframing {outer_record_header:?}");
233        // handle plaintext items which might occur in different places
234        // CCS -> TLS adores complexity, so this is included to make my parsing
235        //        more complicated.
236        // Alert -> we might receive a TLS alert in plaintext even during an
237        //          encrypted space.
238        if matches!(
239            outer_record_header.content_type,
240            ContentType::ChangeCipherSpec | ContentType::Alert
241        ) {
242            return Ok((outer_record_header.content_type, remaining.to_vec()));
243        }
244
245        match self {
246            SecretSpace::Plaintext => Ok((outer_record_header.content_type, remaining.to_vec())),
247            SecretSpace::Handshake(key_space) | SecretSpace::Application(key_space, _) => {
248                let mut plaintext = key_space.decrypt_record(&outer_record_header, remaining);
249
250                // In TLS 1.3, records are "obfuscated". The plaintext record header
251                // contains a fake content type set to "Application Data". To determine
252                // the real content type you must
253                // 1. decrypt the record
254                // 2. remove all padding (0's) from the end of the record
255                // 3. the non-zero byte at the end of the decrypted record content
256                //    is the "real" content type.
257                // "But James!" you say. "That's so complicated. Why wouldn't they
258                // just add another header inside the plaintext? That way you
259                // don't have to try and parse backwards."
260                //
261                // well dear reader, I completely agree, but you are forgetting
262                // the primary point that "TLS Adores Complexity"
263
264                // remove the padding
265                let mut padding = 0;
266                while plaintext.ends_with(&[0]) {
267                    padding += 1;
268                    plaintext.pop();
269                }
270
271                // TODO: is it possible to send a record which is entirely padding?
272
273                // parse the content type from the last byte
274                let content_type =
275                    ContentType::decode_from_exact(&plaintext[plaintext.len() - 1..])?;
276                plaintext.pop();
277
278                tracing::trace!("InnerRecordHeader {{");
279                tracing::trace!("    content_type: {content_type:?}");
280                tracing::trace!("    inner_length: {}", plaintext.len());
281                tracing::trace!("    padding: {padding}");
282                tracing::trace!("}}");
283                Ok((content_type, plaintext))
284            }
285        }
286    }
287}
288
#[cfg(test)]
mod tests {
    use brass_aphid_wire_messages::protocol::HandshakeMessageHeader;

    use super::*;

    #[test]
    /// Make sure that the correct traffic key is derived
    fn traffic_key_derivation() {
        let server_secret =
            hex::decode("4182e4b0b6565a8f7b8586cc35d2ca23f22fa47764a16eaee9e1b21038efd2a4")
                .unwrap();

        let space = KeySpace::handshake_traffic_secret(
            server_secret,
            iana::Cipher::from_description("TLS_AES_128_GCM_SHA256").unwrap(),
        );

        let (key, iv) = space.traffic_key().unwrap();
        assert_eq!(hex::encode(key), "d4af18cdaa11d3943b4d8bb0f9d6c6ca");
        assert_eq!(hex::encode(iv), "32bd8d44d91fb6e913c3349b");
    }

    #[test]
    /// Make sure that a record is successfully decrypted
    fn handshake_record_decrypt() {
        let server_secret =
            hex::decode("64d7b60c7f0d3ca90e47411c575f7eaa8b24d754f3e68ac2d3f060e28395553d")
                .unwrap();

        let aes_128 = iana::Cipher::from_description("TLS_AES_128_GCM_SHA256").unwrap();

        let mut space = KeySpace::handshake_traffic_secret(server_secret, aes_128);

        let record =
            hex::decode("1703030017c89a8a469e34ecee23cd8fbe8e978763ac2e498ddebcc5").unwrap();
        let record_buffer = record.as_slice();
        let (record_header, record_buffer) = RecordHeader::decode_from(record_buffer).unwrap();

        let decrypted = space.decrypt_record(&record_header, record_buffer);
        assert_eq!(hex::encode(decrypted), "08000002000016");
    }

    #[test]
    /// Make sure that deframing an obfuscated record recovers the inner content
    /// type (0x16, Handshake) and strips it from the returned content.
    fn handshake_record_deframe() {
        let server_secret =
            hex::decode("64d7b60c7f0d3ca90e47411c575f7eaa8b24d754f3e68ac2d3f060e28395553d")
                .unwrap();

        let aes_128 = iana::Cipher::from_description("TLS_AES_128_GCM_SHA256").unwrap();
        let mut space =
            SecretSpace::Handshake(KeySpace::handshake_traffic_secret(server_secret, aes_128));

        let record =
            hex::decode("1703030017c89a8a469e34ecee23cd8fbe8e978763ac2e498ddebcc5").unwrap();

        let (content_type, content) = space.deframe_record(&record).unwrap();
        assert!(
            matches!(content_type, ContentType::Handshake),
            "unexpected content type {content_type:?}"
        );
        assert_eq!(hex::encode(content), "080000020000");
    }

    #[test]
    /// The decrypted record above starts with a handshake message header; the
    /// bytes after the header are the message body plus the inner content type.
    fn maybe_app_data() {
        let data = hex::decode("08000002000016").unwrap();

        let (header, remaining) = HandshakeMessageHeader::decode_from(data.as_slice()).unwrap();
        println!("header : {header:#?}");
        assert_eq!(hex::encode(remaining), "000016");
    }
}