brass_aphid_wire_decryption/decryption/
transcript.rs1use crate::decryption::{
2 s2n_tls_intercept::{self, PeerIntoS2ntlsInsides},
3 Mode,
4};
5use brass_aphid_wire_messages::protocol::{
6 content_value::{ContentValue, HandshakeMessageValue},
7 ClientHello, HelloRetryRequest, ServerHello, ServerHelloConfusionMode,
8};
9use s2n_tls::testing::TestPair;
10use std::{
11 cell::RefCell,
12 ffi::c_void,
13 io::Write,
14 pin::Pin,
15 sync::{Arc, Mutex},
16};
17
18#[derive(Debug)]
19pub struct Transcript {
20 pub record_transcript: Mutex<Vec<(Mode, usize)>>,
22
23 pub content_transcript: Mutex<Vec<(Mode, ContentValue)>>,
26}
27
28impl Transcript {
29 pub fn record_record(&self, sender: Mode, size: usize) {
31 self.record_transcript.lock().unwrap().push((sender, size));
32 }
33
34 pub fn record_content(&self, sender: Mode, content: ContentValue) {
35 self.content_transcript
36 .lock()
37 .unwrap()
38 .push((sender, content));
39 }
40
41 pub fn records(&self) -> Vec<(Mode, usize)> {
42 self.record_transcript.lock().unwrap().clone()
43 }
44
45 pub fn content(&self) -> Vec<(Mode, usize)> {
46 self.record_transcript.lock().unwrap().clone()
47 }
48
49 pub fn client_hellos(&self) -> Vec<ClientHello> {
50 let content = self.content_transcript.lock().unwrap();
51 content
52 .iter()
53 .filter_map(|(_, message)| {
54 if let ContentValue::Handshake(HandshakeMessageValue::ClientHello(ch)) = message {
55 Some(ch.clone())
56 } else {
57 None
58 }
59 })
60 .collect()
61 }
62
63 pub fn client_hello(&self) -> ClientHello {
65 let client_hellos = self.client_hellos();
66 assert_eq!(client_hellos.len(), 1);
67 client_hellos.first().unwrap().clone()
68 }
69
70 pub fn server_hello(&self) -> ServerHello {
71 let content = self.content_transcript.lock().unwrap();
72 for (_, content) in content.iter() {
73 if let ContentValue::Handshake(HandshakeMessageValue::ServerHelloConfusion(
74 ServerHelloConfusionMode::ServerHello(sh),
75 )) = content
76 {
77 return sh.clone();
78 }
79 }
80 panic!("no server hello. smh, people have no manners");
81 }
82
83 pub fn hello_retry_request(&self) -> Option<HelloRetryRequest> {
84 let content = self.content_transcript.lock().unwrap();
85 for (_, content) in content.iter() {
86 if let ContentValue::Handshake(HandshakeMessageValue::ServerHelloConfusion(
87 ServerHelloConfusionMode::HelloRetryRequest(hrr),
88 )) = content
89 {
90 return Some(hrr.clone());
91 }
92 }
93 None
94 }
95}
96
97use crate::decryption::s2n_tls_intercept::InterceptedSendCallback;
98pub trait TestPairExtension {
99 fn enable_transcript(&mut self) -> TestPairTranscript;
103}
104
105impl TestPairExtension for TestPair {
106 fn enable_transcript(&mut self) -> TestPairTranscript {
107 TestPairTranscript::new(self)
108 }
109}
110
111pub struct TestPairTranscript {
113 records: Pin<Arc<RefCell<Vec<(Mode, Vec<u8>)>>>>,
114 client_handle: Box<RecordingSendHandle>,
115 server_handle: Box<RecordingSendHandle>,
116}
117
118impl TestPairTranscript {
119 fn new(pair: &mut TestPair) -> Self {
121 let records = Arc::pin(RefCell::new(Vec::new()));
122
123 let client_send = pair.client.steal_send_cb();
125 let client_record_handle =
126 RecordingSendHandle::new(Mode::Client, records.clone(), client_send);
127 let client_boxed = Box::new(client_record_handle);
128
129 pair.client
130 .set_send_callback(Some(
131 s2n_tls_intercept::generic_send_cb::<RecordingSendHandle>,
132 ))
133 .unwrap();
134 unsafe { pair.client.set_send_context(client_boxed.as_ref() as *const RecordingSendHandle as *mut c_void) }.unwrap();
135
136 let server_send = pair.server.steal_send_cb();
137 let server_record_handle =
138 RecordingSendHandle::new(Mode::Server, records.clone(), server_send);
139 let server_boxed = Box::new(server_record_handle);
140
141 pair.server
142 .set_send_callback(Some(
143 s2n_tls_intercept::generic_send_cb::<RecordingSendHandle>,
144 ))
145 .unwrap();
146 unsafe { pair.server.set_send_context(server_boxed.as_ref() as *const RecordingSendHandle as *mut c_void) }.unwrap();
147
148 Self {
149 records,
150 client_handle: client_boxed,
151 server_handle: server_boxed,
152 }
153 }
154
155 pub fn get_all_records(&self) -> Vec<(Mode, Vec<u8>)> {
157 self.records.borrow().clone()
158 }
159}
160
161pub struct RecordingSendHandle {
163 identity: Mode,
165 records: Pin<Arc<RefCell<Vec<(Mode, Vec<u8>)>>>>,
166 io_stream: InterceptedSendCallback,
168}
169
170impl RecordingSendHandle {
171 pub fn new(
172 identity: Mode,
173 records: Pin<Arc<RefCell<Vec<(Mode, Vec<u8>)>>>>,
174 io_stream: InterceptedSendCallback,
175 ) -> Self {
176 Self {
177 identity,
178 records,
179 io_stream,
180 }
181 }
182}
183
184impl Write for RecordingSendHandle {
185 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
186 self.records
188 .borrow_mut()
189 .push((self.identity, buf.to_vec()));
190
191 let bytes_written = self.io_stream.write(buf).unwrap();
192 assert_eq!(bytes_written, buf.len());
194
195 Ok(buf.len())
196 }
197
198 fn flush(&mut self) -> std::io::Result<()> {
199 Ok(())
201 }
202}