brass_aphid_wire_decryption/decryption/tls_stream.rs
1use crate::decryption::{
2 key_manager::KeyManager, key_space::SecretSpace, stream_decrypter::ConversationState, Mode,
3};
4use brass_aphid_wire_messages::{
5 codec::{DecodeValue, DecodeValueWithContext},
6 protocol::{
7 content_value::{ContentValue, HandshakeMessageValue},
8 Alert, ChangeCipherSpec, ContentType, RecordHeader, ServerHelloConfusionMode,
9 },
10};
11use std::{collections::VecDeque, fmt::Debug, io::ErrorKind};
12
/// A TlsStream is generally responsible for handling the framing and decrypting
/// of the TLS Record protocol for the data sent by one side of a connection.
///
/// Consider the following scenario
/// ```text
/// messages -> |------m1------|--------m2------|--------m3-------|
/// records  -> |----r1----|-------r2---|-----r3------|----r4-----|
/// packets  -> |--p1--|---p2---|---p3--|---p4--|--p5----|---p6---|
/// ```
///
/// Assume (without loss of generality) that each read call returns an individual,
/// single packet.
///
/// ### Packet Buffering
///
/// We won't be able to decrypt `r1` until we have received both `p1` and `p2`.
/// To handle this we buffer all the reads in `byte_buffer` until the full record
/// is available.
///
/// ### Record Buffering
/// Depending on the key logging implementation, we may not be able to decrypt
/// a record immediately. We buffer complete records in `record_buffer` until
/// the decryption keys are available.
///
/// Even once we're able to decrypt records, we won't be able to parse `m1` until
/// we have received both `r1` and `r2`. We buffer the decrypted plaintext in
/// `plaintext_content_stream`.
///
/// Note that `plaintext_content_stream` will only ever hold a single content
/// type: `digest_bytes` returns an `InvalidData` error if a record of a
/// different content type arrives while a partial message is still buffered.
///
/// TODO: shenanigans here. We either have to "poll_decrypt" each time we have
/// gotten a full record, or we need to poll_decrypt when we see a new content
/// type (before we add it to the stream). I like the first option because it's
/// less modal.
///
/// THOUGHT: Can obfuscated records have multiple inner content types in them? I
/// think the answer is no: per RFC 8446 §5.2 each TLSInnerPlaintext carries a
/// single `type` field. If the answer were yes, this design would need rethinking.
#[derive(Debug)]
pub struct TlsStream {
    /// The identity of this TLS Stream.
    ///
    /// E.g. `Mode::Client` means that this is the stream sent by the client.
    sender: Mode,
    /// Bytes passed to `feed_bytes` are buffered here until they contain a
    /// complete record (header + payload).
    byte_buffer: Vec<u8>,

    /// Complete records (header included) waiting to be deframed by
    /// `digest_bytes`, e.g. because the keys for them are not available yet.
    record_buffer: VecDeque<Vec<u8>>,

    /// Deframed record payloads, buffered until they contain a complete message.
    plaintext_content_stream: VecDeque<u8>,
    /// Content type of the bytes in `plaintext_content_stream`. Only meaningful
    /// while that stream is non-empty; otherwise the next record overwrites it.
    plaintext_content_type: ContentType,
    /// The key space used to deframe this sender's records. E.g.
    /// `SecretSpace::Plaintext` or `SecretSpace::Application`.
    key_space: SecretSpace,
    /// Set after a message that ends the current key space (ClientHello,
    /// ServerHello, Finished, KeyUpdate); `digest_bytes` installs the next key
    /// space before deframing another record.
    needs_next_key_space: bool,
}
69
70impl TlsStream {
71 pub fn new(sender: Mode) -> Self {
72 Self {
73 sender,
74 byte_buffer: Default::default(),
75 record_buffer: Default::default(),
76 plaintext_content_stream: Default::default(),
77 // first data is alert, and also it shouldn't matter
78 plaintext_content_type: ContentType::Handshake,
79 key_space: SecretSpace::Plaintext,
80 needs_next_key_space: false,
81 }
82 }
83
84 /// Set `need_next_key_space` to false
85 ///
86 /// While the "sender" streams are almost entirely independent, that get's broken
87 /// in the event of a hello retry.
88 ///
89 /// Normally, when the Server sends a ServerHello this means that the next message
90 /// will be the EncryptedExtension (which the TLS Stream will need to decrypt).
91 ///
92 /// But if it was a
93 /// We need a way for the server stream to tell
94 /// the client stream that there's actually another client hello on the way.
95 pub fn suppress_next_key_state(&mut self) {
96 debug_assert_eq!(self.sender, Mode::Client);
97 debug_assert!(matches!(self.key_space, SecretSpace::Plaintext));
98 debug_assert!(self.needs_next_key_space);
99 self.needs_next_key_space = false;
100 }
101
102 /// Add bytes to a TLS stream.
103 ///
104 /// In the case of a DecryptingPipe, this is the method called by the Read &
105 /// Write IO methods.
106 ///
107 /// This method will not do any decryption, but will try and assemble existing
108 /// data into complete records.
109 pub fn feed_bytes(&mut self, data: &[u8]) {
110 // first buffer into byte buffer.
111 tracing::info!(
112 "feeding {:?} bytes, record buffer currently {}",
113 data.len(),
114 self.record_buffer.len()
115 );
116 self.byte_buffer.extend_from_slice(data);
117
118 // get all of the records
119 while let Some(record) = self.byte_buffer_has_record() {
120 // TODO: record header size constant
121 let record_and_header_len = record.record_length as usize + 5;
122 // pop the record off the front of the byte buffer
123 let record = self.byte_buffer.drain(..record_and_header_len).collect();
124 // store it in the record buffer
125 self.record_buffer.push_back(record);
126 }
127
128 tracing::info!("record buffer now {}", self.record_buffer.len());
129 }
130
131 /// Attempt to decrypt available bytes.
132 pub fn digest_bytes(
133 &mut self,
134 state: &mut ConversationState,
135 key_manger: &KeyManager,
136 ) -> std::io::Result<Vec<ContentValue>> {
137 tracing::trace!("digesting bytes from {:?}", self.sender);
138 // precondition: any data currently in plaintext_content_stream must have the
139 // same content type as what we are about to decrypt. If that's not true,
140 // then it means that we were unable to "clear out" the data in a previous
141 // content type and the data is malformed
142
143 // check if "needed keyspace" is some, before popping off the record.
144 // e.g. once we have seen the finished message we should
145
146 let mut content = Vec::new();
147
148 loop {
149 // don't try to get the next space if there isn't anything to read
150 // otherwise we'd try and grab the handshake keys for the client before
151 // reading the server hello
152 if self.plaintext_content_stream.is_empty() && self.record_buffer.is_empty() {
153 return Ok(content);
154 }
155
156 if self.needs_next_key_space {
157 let next_space = match &self.key_space {
158 SecretSpace::Plaintext => key_manger
159 .handshake_space(
160 self.sender,
161 state.client_random.as_ref().unwrap(),
162 state.selected_cipher.unwrap(),
163 )
164 .map(SecretSpace::Handshake),
165 SecretSpace::Handshake(_) => key_manger
166 .first_application_space(
167 self.sender,
168 state.client_random.as_ref().unwrap(),
169 state.selected_cipher.unwrap(),
170 )
171 .map(|space| SecretSpace::Application(space, 0)),
172 SecretSpace::Application(key_space, current_key_epoch) => Some(
173 SecretSpace::Application(key_space.key_update(), *current_key_epoch + 1),
174 ),
175 };
176
177 match next_space {
178 Some(space) => {
179 self.key_space = space;
180 self.needs_next_key_space = false;
181 }
182 None => {
183 // give up on decrypting now, and hope that the next time
184 // digest bytes is called we will have the keys that we need
185 tracing::warn!(
186 "Needed next space after {:?}, but it was unavailable",
187 self.key_space
188 );
189 return Ok(content);
190 }
191 }
192 }
193
194 // if there are records, deframe. otherwise return
195 let (content_type, record_payload) = match self.record_buffer.pop_front() {
196 Some(record) => self.key_space.deframe_record(&record)?,
197 None => return Ok(content),
198 };
199
200 // make sure that the record is the right type.
201 //
202 // If we are switching content types, then we should have received a
203 // full message (and been able to decrypt it) so the stream should be
204 // empty.
205 if self.plaintext_content_stream.is_empty() {
206 self.plaintext_content_type = content_type;
207 } else if content_type != self.plaintext_content_type {
208 return Err(std::io::Error::new(
209 ErrorKind::InvalidData,
210 "unable to fully parse plaintext stream, malformed message",
211 ));
212 }
213
214 // add the record into the plaintext stream
215 self.plaintext_content_stream.extend(record_payload);
216
217 tracing::trace!(
218 "plaintext stream length: {:?}",
219 self.plaintext_content_stream.len()
220 );
221
222 loop {
223 let message = Self::plaintext_stream_message(
224 self.plaintext_content_type,
225 &mut self.plaintext_content_stream,
226 state,
227 );
228 let value = match message {
229 // we got a value, yay!
230 Ok(Some(content)) => content,
231 // we didn't have enough data for a value, but maybe if we shove
232 // more records onto the stream then we will.
233 Ok(None) => break,
234 Err(e) if e.kind() == ErrorKind::UnexpectedEof => break,
235 // something went wrong
236 Err(e) => return Err(e),
237 };
238
239 tracing::trace!("from plaintext stream: {value:?}");
240
241 // client hello is end of client plaintext
242 if matches!(
243 value,
244 ContentValue::Handshake(HandshakeMessageValue::ClientHello(_))
245 ) {
246 // hmm, is there some way to check if the next mesage
247 self.needs_next_key_space = true;
248 }
249 // server hello is end of server plaintext
250 // important: HelloRetryRequest is modelled as a totally different
251 // struct enum.
252 if let ContentValue::Handshake(HandshakeMessageValue::ServerHelloConfusion(
253 ServerHelloConfusionMode::ServerHello(_),
254 )) = &value
255 {
256 // important: this is why we represent the Hello Retry as a
257 // different message
258 self.needs_next_key_space = true;
259 }
260 // server finished is end of server handshake space
261 // client finished is end of client handshake space
262 if matches!(
263 value,
264 ContentValue::Handshake(HandshakeMessageValue::Finished(_))
265 ) {
266 self.needs_next_key_space = true;
267 }
268
269 // server/client has updated their keys
270 // If the peer update was requested, then we will update when the
271 // peer sends their own KeyUpdate message.
272 if matches!(
273 value,
274 ContentValue::Handshake(HandshakeMessageValue::KeyUpdate(_))
275 ) {
276 self.needs_next_key_space = true;
277 }
278
279 // update the connection state
280 if let ContentValue::Handshake(HandshakeMessageValue::ClientHello(s)) = &value {
281 state.client_random = Some(s.random.to_vec());
282 }
283 if let ContentValue::Handshake(HandshakeMessageValue::ServerHelloConfusion(
284 ServerHelloConfusionMode::ServerHello(sh),
285 )) = &value
286 {
287 // Even if it was a hello retry, this is fine. Because the HRR still
288 // contains the actual selected parameters.
289 // TODO: I don't like this. I think I should only be doing this
290 // on the actual ServerHello, this branching is confusing.
291 // Actually, maybe I do need to do this because the HRR is a
292 // TLS 1.3 only message?
293 state.selected_cipher = Some(sh.cipher_suite);
294 state.selected_protocol = Some(sh.selected_version().unwrap());
295 tracing::info!("setting cipher and selected version: {state:?}");
296 }
297
298 content.push(value);
299 }
300 }
301 }
302
303 /// attempt to decrypt a record header from `byte_buffer`.
304 ///
305 /// A `Some` return value means that
306 /// - a record header was successfully decrypted
307 /// - the byte_buffer contains the full record
308 fn byte_buffer_has_record(&self) -> Option<RecordHeader> {
309 let (record_header, remaining) =
310 match RecordHeader::decode_from(self.byte_buffer.as_slice()) {
311 Ok(decode) => decode,
312 // TODO: we should only return None if there isn't enough data to
313 // decrypt the RecordHeader. We should bubble up different parsing errors.
314 Err(e) => return None,
315 };
316
317 if remaining.len() >= record_header.record_length as usize {
318 Some(record_header)
319 } else {
320 None
321 }
322 }
323
324 /// parse a message/content value from the plaintext stream.
325 ///
326 /// The
327 fn plaintext_stream_message(
328 content_type: ContentType,
329 stream: &mut VecDeque<u8>,
330 state: &ConversationState,
331 ) -> std::io::Result<Option<ContentValue>> {
332 stream.make_contiguous();
333 let (buffer, empty) = stream.as_slices();
334 assert!(empty.is_empty());
335
336 // TODO: neater handling.
337 // Can't rely on EOF because some things can be zero sized
338 if buffer.is_empty() {
339 return Ok(None);
340 }
341
342 tracing::info!(
343 "plaintext stream length before message pull of {content_type:?}: {:?}",
344 stream.len()
345 );
346 let (value, buffer) = match content_type {
347 ContentType::Invalid => panic!("invalid"),
348 ContentType::ChangeCipherSpec => {
349 let (ccs, record_buffer) = ChangeCipherSpec::decode_from(buffer)?;
350 (ContentValue::ChangeCipherSpec(ccs), record_buffer)
351 }
352 ContentType::Alert => {
353 let (alert, buffer) = Alert::decode_from(buffer)?;
354 (ContentValue::Alert(alert), buffer)
355 }
356 ContentType::Handshake => {
357 let (handshake_message, inner_buffer) =
358 match (state.selected_protocol, state.selected_cipher) {
359 (Some(protocol), Some(cipher)) => {
360 HandshakeMessageValue::decode_from_with_context(
361 buffer,
362 (protocol, cipher),
363 )?
364 }
365 _unknown_state => HandshakeMessageValue::decode_from(buffer)?,
366 };
367 (ContentValue::Handshake(handshake_message), inner_buffer)
368 }
369 ContentType::ApplicationData => (
370 // we consume the entire stream for application data, no message length
371 ContentValue::ApplicationData(buffer.to_vec()),
372 [].as_slice(),
373 ),
374 };
375
376 let consumed = stream.len() - buffer.len();
377 for _ in 0..consumed {
378 // TODO: vectorized instead
379 stream.pop_front();
380 }
381
382 tracing::info!(
383 "plaintext stream length after message pull: {:?}",
384 stream.len()
385 );
386
387 Ok(Some(value))
388 }
389}