brass_aphid_wire_decryption/decryption/transcript.rs

1use crate::decryption::{
2    s2n_tls_intercept::{self, PeerIntoS2ntlsInsides},
3    Mode,
4};
5use brass_aphid_wire_messages::protocol::{
6    content_value::{ContentValue, HandshakeMessageValue},
7    ClientHello, HelloRetryRequest, ServerHello, ServerHelloConfusionMode,
8};
9use s2n_tls::testing::TestPair;
10use std::{
11    cell::RefCell,
12    ffi::c_void,
13    io::Write,
14    pin::Pin,
15    sync::{Arc, Mutex},
16};
17
18#[derive(Debug)]
19pub struct Transcript {
20    /// a list of the record sizes sent by each peer
21    pub record_transcript: Mutex<Vec<(Mode, usize)>>,
22
23    /// a list of the content sent by each peer
24    /// TODO: why are these mutexes? I think the vast majority of the time
25    pub content_transcript: Mutex<Vec<(Mode, ContentValue)>>,
26}
27
28impl Transcript {
29    // record_record hurts my brain
30    pub fn record_record(&self, sender: Mode, size: usize) {
31        self.record_transcript.lock().unwrap().push((sender, size));
32    }
33
34    pub fn record_content(&self, sender: Mode, content: ContentValue) {
35        self.content_transcript
36            .lock()
37            .unwrap()
38            .push((sender, content));
39    }
40
41    pub fn records(&self) -> Vec<(Mode, usize)> {
42        self.record_transcript.lock().unwrap().clone()
43    }
44
45    pub fn content(&self) -> Vec<(Mode, usize)> {
46        self.record_transcript.lock().unwrap().clone()
47    }
48
49    pub fn client_hellos(&self) -> Vec<ClientHello> {
50        let content = self.content_transcript.lock().unwrap();
51        content
52            .iter()
53            .filter_map(|(_, message)| {
54                if let ContentValue::Handshake(HandshakeMessageValue::ClientHello(ch)) = message {
55                    Some(ch.clone())
56                } else {
57                    None
58                }
59            })
60            .collect()
61    }
62
63    /// panics if there is more than one (TLS 1.3 HRR)
64    pub fn client_hello(&self) -> ClientHello {
65        let client_hellos = self.client_hellos();
66        assert_eq!(client_hellos.len(), 1);
67        client_hellos.first().unwrap().clone()
68    }
69
70    pub fn server_hello(&self) -> ServerHello {
71        let content = self.content_transcript.lock().unwrap();
72        for (_, content) in content.iter() {
73            if let ContentValue::Handshake(HandshakeMessageValue::ServerHelloConfusion(
74                ServerHelloConfusionMode::ServerHello(sh),
75            )) = content
76            {
77                return sh.clone();
78            }
79        }
80        panic!("no server hello. smh, people have no manners");
81    }
82
83    pub fn hello_retry_request(&self) -> Option<HelloRetryRequest> {
84        let content = self.content_transcript.lock().unwrap();
85        for (_, content) in content.iter() {
86            if let ContentValue::Handshake(HandshakeMessageValue::ServerHelloConfusion(
87                ServerHelloConfusionMode::HelloRetryRequest(hrr),
88            )) = content
89            {
90                return Some(hrr.clone());
91            }
92        }
93        None
94    }
95}
96
97use crate::decryption::s2n_tls_intercept::InterceptedSendCallback;
98pub trait TestPairExtension {
99    /// Record all bytes sent by the connections.
100    ///
101    /// This does not decrypt the bytes
102    fn enable_transcript(&mut self) -> TestPairTranscript;
103}
104
105impl TestPairExtension for TestPair {
106    fn enable_transcript(&mut self) -> TestPairTranscript {
107        TestPairTranscript::new(self)
108    }
109}
110
111/// Holds all of the writes that occurred during a TLS handshake
112pub struct TestPairTranscript {
113    records: Pin<Arc<RefCell<Vec<(Mode, Vec<u8>)>>>>,
114    client_handle: Box<RecordingSendHandle>,
115    server_handle: Box<RecordingSendHandle>,
116}
117
118impl TestPairTranscript {
119    /// Create a new empty transcript
120    fn new(pair: &mut TestPair) -> Self {
121        let records = Arc::pin(RefCell::new(Vec::new()));
122
123        // configure client
124        let client_send = pair.client.steal_send_cb();
125        let client_record_handle =
126            RecordingSendHandle::new(Mode::Client, records.clone(), client_send);
127        let client_boxed = Box::new(client_record_handle);
128
129        pair.client
130            .set_send_callback(Some(
131                s2n_tls_intercept::generic_send_cb::<RecordingSendHandle>,
132            ))
133            .unwrap();
134        unsafe { pair.client.set_send_context(client_boxed.as_ref() as *const RecordingSendHandle as *mut c_void) }.unwrap();
135
136        let server_send = pair.server.steal_send_cb();
137        let server_record_handle =
138            RecordingSendHandle::new(Mode::Server, records.clone(), server_send);
139        let server_boxed = Box::new(server_record_handle);
140
141        pair.server
142            .set_send_callback(Some(
143                s2n_tls_intercept::generic_send_cb::<RecordingSendHandle>,
144            ))
145            .unwrap();
146        unsafe { pair.server.set_send_context(server_boxed.as_ref() as *const RecordingSendHandle as *mut c_void) }.unwrap();
147
148        Self {
149            records,
150            client_handle: client_boxed,
151            server_handle: server_boxed,
152        }
153    }
154
155    /// Get all records in order of transmission
156    pub fn get_all_records(&self) -> Vec<(Mode, Vec<u8>)> {
157        self.records.borrow().clone()
158    }
159}
160
161/// A handle that records data sent through it and forwards it to the original IO stream
162pub struct RecordingSendHandle {
163    // client or server
164    identity: Mode,
165    records: Pin<Arc<RefCell<Vec<(Mode, Vec<u8>)>>>>,
166    // Reference to the TestPair's IO stream to forward data to
167    io_stream: InterceptedSendCallback,
168}
169
170impl RecordingSendHandle {
171    pub fn new(
172        identity: Mode,
173        records: Pin<Arc<RefCell<Vec<(Mode, Vec<u8>)>>>>,
174        io_stream: InterceptedSendCallback,
175    ) -> Self {
176        Self {
177            identity,
178            records,
179            io_stream,
180        }
181    }
182}
183
184impl Write for RecordingSendHandle {
185    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
186        // Record the data
187        self.records
188            .borrow_mut()
189            .push((self.identity, buf.to_vec()));
190
191        let bytes_written = self.io_stream.write(buf).unwrap();
192        // should be local io
193        assert_eq!(bytes_written, buf.len());
194
195        Ok(buf.len())
196    }
197
198    fn flush(&mut self) -> std::io::Result<()> {
199        /* no op */
200        Ok(())
201    }
202}