// brass_aphid_wire_decryption/decryption/tls_stream.rs

1use crate::decryption::{
2    key_manager::KeyManager, key_space::SecretSpace, stream_decrypter::ConversationState, Mode,
3};
4use brass_aphid_wire_messages::{
5    codec::{DecodeValue, DecodeValueWithContext},
6    protocol::{
7        content_value::{ContentValue, HandshakeMessageValue},
8        Alert, ChangeCipherSpec, ContentType, RecordHeader, ServerHelloConfusionMode,
9    },
10};
11use std::{collections::VecDeque, fmt::Debug, io::ErrorKind};
12
13/// A TlsStream is generally responsible for handling the framing and decrypting
14/// of the TLS Record protocol.
15///
16/// Consider the following scenario
17/// ```text
18/// messages ->   |------m1------|--------m2------|--------m3-------|
19/// records  ->   |----r1----|-------r2---|-----r3------|----r4-----|
20/// packets  ->   |--p1--|---p2---|---p3--|---p4--|--p5----|---p6---|
21/// ```
22///
23/// Assume (without loss of generality) that each read call returns an individual,
24/// single packet.
25///
26/// ### Packet Buffering
27///
28/// We won't be able to decrypt `r1` until we have received both `p1` and `p2`.
29/// To handle this we buffer all the reads in `byte_buffer` until the full record
30/// is available.
31///
32/// ### Record Buffering
33/// Depending on different key logging implementations, we won't be able to decrypt
34/// the record immediately. We buffer complete records in `record_buffer` until
35/// the decryption keys are available.
36///
37/// Even once we're able to decrypt records, We won't be able to parse `m1` until
38/// we have received both `r1` and `r2`. We buffer the decrypted plaintext in
39/// `plaintext_content_stream`
40///
41/// Note that the content_stream will only ever hold a single content type.
42///
43/// TODO: shenanigans here. We either have to "poll_decrypt" each time we have
44/// gotten a full record, or we need to poll_decrypt when we see a new content
45/// type (before we add it to the stream). I like the first option because it's
46/// less modality.
47///
48/// THOUGHT: Can obfuscated records have multiple inner content types in them? I
49/// think the answer is no. And if the answer is yes then 😭.
50#[derive(Debug)]
51pub struct TlsStream {
52    /// The identity of this TLS Stream.
53    ///
54    /// E.g. `Mode::Client` means that this is the stream sent by the client.
55    sender: Mode,
56    /// all tx calls are buffered here until there is enough to read a record
57    byte_buffer: Vec<u8>,
58
59    /// records are buffered here until keys are available to decrypt them
60    record_buffer: VecDeque<Vec<u8>>,
61
62    /// text is buffered here until there is enough to read a message
63    plaintext_content_stream: VecDeque<u8>,
64    plaintext_content_type: ContentType,
65    /// the current encryption level of the connection. E.g. "Plaintext" or "Application"
66    key_space: SecretSpace,
67    needs_next_key_space: bool,
68}
69
70impl TlsStream {
71    pub fn new(sender: Mode) -> Self {
72        Self {
73            sender,
74            byte_buffer: Default::default(),
75            record_buffer: Default::default(),
76            plaintext_content_stream: Default::default(),
77            // first data is alert, and also it shouldn't matter
78            plaintext_content_type: ContentType::Handshake,
79            key_space: SecretSpace::Plaintext,
80            needs_next_key_space: false,
81        }
82    }
83
84    /// Set `need_next_key_space` to false
85    ///
86    /// While the "sender" streams are almost entirely independent, that get's broken
87    /// in the event of a hello retry.
88    ///
89    /// Normally, when the Server sends a ServerHello this means that the next message
90    /// will be the EncryptedExtension (which the TLS Stream will need to decrypt).
91    ///
92    /// But if it was a
93    /// We need a way for the server stream to tell
94    /// the client stream that there's actually another client hello on the way.
95    pub fn suppress_next_key_state(&mut self) {
96        debug_assert_eq!(self.sender, Mode::Client);
97        debug_assert!(matches!(self.key_space, SecretSpace::Plaintext));
98        debug_assert!(self.needs_next_key_space);
99        self.needs_next_key_space = false;
100    }
101
102    /// Add bytes to a TLS stream.
103    ///
104    /// In the case of a DecryptingPipe, this is the method called by the Read &
105    /// Write IO methods.
106    ///
107    /// This method will not do any decryption, but will try and assemble existing
108    /// data into complete records.
109    pub fn feed_bytes(&mut self, data: &[u8]) {
110        // first buffer into byte buffer.
111        tracing::info!(
112            "feeding {:?} bytes, record buffer currently {}",
113            data.len(),
114            self.record_buffer.len()
115        );
116        self.byte_buffer.extend_from_slice(data);
117
118        // get all of the records
119        while let Some(record) = self.byte_buffer_has_record() {
120            // TODO: record header size constant
121            let record_and_header_len = record.record_length as usize + 5;
122            // pop the record off the front of the byte buffer
123            let record = self.byte_buffer.drain(..record_and_header_len).collect();
124            // store it in the record buffer
125            self.record_buffer.push_back(record);
126        }
127
128        tracing::info!("record buffer now {}", self.record_buffer.len());
129    }
130
131    /// Attempt to decrypt available bytes.
132    pub fn digest_bytes(
133        &mut self,
134        state: &mut ConversationState,
135        key_manger: &KeyManager,
136    ) -> std::io::Result<Vec<ContentValue>> {
137        tracing::trace!("digesting bytes from {:?}", self.sender);
138        // precondition: any data currently in plaintext_content_stream must have the
139        // same content type as what we are about to decrypt. If that's not true,
140        // then it means that we were unable to "clear out" the data in a previous
141        // content type and the data is malformed
142
143        // check if "needed keyspace" is some, before popping off the record.
144        // e.g. once we have seen the finished message we should
145
146        let mut content = Vec::new();
147
148        loop {
149            // don't try to get the next space if there isn't anything to read
150            // otherwise we'd try and grab the handshake keys for the client before
151            // reading the server hello
152            if self.plaintext_content_stream.is_empty() && self.record_buffer.is_empty() {
153                return Ok(content);
154            }
155
156            if self.needs_next_key_space {
157                let next_space = match &self.key_space {
158                    SecretSpace::Plaintext => key_manger
159                        .handshake_space(
160                            self.sender,
161                            state.client_random.as_ref().unwrap(),
162                            state.selected_cipher.unwrap(),
163                        )
164                        .map(SecretSpace::Handshake),
165                    SecretSpace::Handshake(_) => key_manger
166                        .first_application_space(
167                            self.sender,
168                            state.client_random.as_ref().unwrap(),
169                            state.selected_cipher.unwrap(),
170                        )
171                        .map(|space| SecretSpace::Application(space, 0)),
172                    SecretSpace::Application(key_space, current_key_epoch) => Some(
173                        SecretSpace::Application(key_space.key_update(), *current_key_epoch + 1),
174                    ),
175                };
176
177                match next_space {
178                    Some(space) => {
179                        self.key_space = space;
180                        self.needs_next_key_space = false;
181                    }
182                    None => {
183                        // give up on decrypting now, and hope that the next time
184                        // digest bytes is called we will have the keys that we need
185                        tracing::warn!(
186                            "Needed next space after {:?}, but it was unavailable",
187                            self.key_space
188                        );
189                        return Ok(content);
190                    }
191                }
192            }
193
194            // if there are records, deframe. otherwise return
195            let (content_type, record_payload) = match self.record_buffer.pop_front() {
196                Some(record) => self.key_space.deframe_record(&record)?,
197                None => return Ok(content),
198            };
199
200            // make sure that the record is the right type.
201            //
202            // If we are switching content types, then we should have received a
203            // full message (and been able to decrypt it) so the stream should be
204            // empty.
205            if self.plaintext_content_stream.is_empty() {
206                self.plaintext_content_type = content_type;
207            } else if content_type != self.plaintext_content_type {
208                return Err(std::io::Error::new(
209                    ErrorKind::InvalidData,
210                    "unable to fully parse plaintext stream, malformed message",
211                ));
212            }
213
214            // add the record into the plaintext stream
215            self.plaintext_content_stream.extend(record_payload);
216
217            tracing::trace!(
218                "plaintext stream length: {:?}",
219                self.plaintext_content_stream.len()
220            );
221
222            loop {
223                let message = Self::plaintext_stream_message(
224                    self.plaintext_content_type,
225                    &mut self.plaintext_content_stream,
226                    state,
227                );
228                let value = match message {
229                    // we got a value, yay!
230                    Ok(Some(content)) => content,
231                    // we didn't have enough data for a value, but maybe if we shove
232                    // more records onto the stream then we will.
233                    Ok(None) => break,
234                    Err(e) if e.kind() == ErrorKind::UnexpectedEof => break,
235                    // something went wrong
236                    Err(e) => return Err(e),
237                };
238
239                tracing::trace!("from plaintext stream: {value:?}");
240
241                // client hello is end of client plaintext
242                if matches!(
243                    value,
244                    ContentValue::Handshake(HandshakeMessageValue::ClientHello(_))
245                ) {
246                    // hmm, is there some way to check if the next mesage
247                    self.needs_next_key_space = true;
248                }
249                // server hello is end of server plaintext
250                // important: HelloRetryRequest is modelled as a totally different
251                // struct enum.
252                if let ContentValue::Handshake(HandshakeMessageValue::ServerHelloConfusion(
253                    ServerHelloConfusionMode::ServerHello(_),
254                )) = &value
255                {
256                    // important: this is why we represent the Hello Retry as a
257                    // different message
258                    self.needs_next_key_space = true;
259                }
260                // server finished is end of server handshake space
261                // client finished is end of client handshake space
262                if matches!(
263                    value,
264                    ContentValue::Handshake(HandshakeMessageValue::Finished(_))
265                ) {
266                    self.needs_next_key_space = true;
267                }
268
269                // server/client has updated their keys
270                // If the peer update was requested, then we will update when the
271                // peer sends their own KeyUpdate message.
272                if matches!(
273                    value,
274                    ContentValue::Handshake(HandshakeMessageValue::KeyUpdate(_))
275                ) {
276                    self.needs_next_key_space = true;
277                }
278
279                // update the connection state
280                if let ContentValue::Handshake(HandshakeMessageValue::ClientHello(s)) = &value {
281                    state.client_random = Some(s.random.to_vec());
282                }
283                if let ContentValue::Handshake(HandshakeMessageValue::ServerHelloConfusion(
284                    ServerHelloConfusionMode::ServerHello(sh),
285                )) = &value
286                {
287                    // Even if it was a hello retry, this is fine. Because the HRR still
288                    // contains the actual selected parameters.
289                    // TODO: I don't like this. I think I should only be doing this
290                    // on the actual ServerHello, this branching is confusing.
291                    // Actually, maybe I do need to do this because the HRR is a
292                    // TLS 1.3 only message?
293                    state.selected_cipher = Some(sh.cipher_suite);
294                    state.selected_protocol = Some(sh.selected_version().unwrap());
295                    tracing::info!("setting cipher and selected version: {state:?}");
296                }
297
298                content.push(value);
299            }
300        }
301    }
302
303    /// attempt to decrypt a record header from `byte_buffer`.
304    ///
305    /// A `Some` return value means that
306    /// - a record header was successfully decrypted
307    /// - the byte_buffer contains the full record
308    fn byte_buffer_has_record(&self) -> Option<RecordHeader> {
309        let (record_header, remaining) =
310            match RecordHeader::decode_from(self.byte_buffer.as_slice()) {
311                Ok(decode) => decode,
312                // TODO: we should only return None if there isn't enough data to
313                // decrypt the RecordHeader. We should bubble up different parsing errors.
314                Err(e) => return None,
315            };
316
317        if remaining.len() >= record_header.record_length as usize {
318            Some(record_header)
319        } else {
320            None
321        }
322    }
323
324    /// parse a message/content value from the plaintext stream.
325    ///
326    /// The
327    fn plaintext_stream_message(
328        content_type: ContentType,
329        stream: &mut VecDeque<u8>,
330        state: &ConversationState,
331    ) -> std::io::Result<Option<ContentValue>> {
332        stream.make_contiguous();
333        let (buffer, empty) = stream.as_slices();
334        assert!(empty.is_empty());
335
336        // TODO: neater handling.
337        // Can't rely on EOF because some things can be zero sized
338        if buffer.is_empty() {
339            return Ok(None);
340        }
341
342        tracing::info!(
343            "plaintext stream length before message pull of {content_type:?}: {:?}",
344            stream.len()
345        );
346        let (value, buffer) = match content_type {
347            ContentType::Invalid => panic!("invalid"),
348            ContentType::ChangeCipherSpec => {
349                let (ccs, record_buffer) = ChangeCipherSpec::decode_from(buffer)?;
350                (ContentValue::ChangeCipherSpec(ccs), record_buffer)
351            }
352            ContentType::Alert => {
353                let (alert, buffer) = Alert::decode_from(buffer)?;
354                (ContentValue::Alert(alert), buffer)
355            }
356            ContentType::Handshake => {
357                let (handshake_message, inner_buffer) =
358                    match (state.selected_protocol, state.selected_cipher) {
359                        (Some(protocol), Some(cipher)) => {
360                            HandshakeMessageValue::decode_from_with_context(
361                                buffer,
362                                (protocol, cipher),
363                            )?
364                        }
365                        _unknown_state => HandshakeMessageValue::decode_from(buffer)?,
366                    };
367                (ContentValue::Handshake(handshake_message), inner_buffer)
368            }
369            ContentType::ApplicationData => (
370                // we consume the entire stream for application data, no message length
371                ContentValue::ApplicationData(buffer.to_vec()),
372                [].as_slice(),
373            ),
374        };
375
376        let consumed = stream.len() - buffer.len();
377        for _ in 0..consumed {
378            // TODO: vectorized instead
379            stream.pop_front();
380        }
381
382        tracing::info!(
383            "plaintext stream length after message pull: {:?}",
384            stream.len()
385        );
386
387        Ok(Some(value))
388    }
389}