1#![deny(missing_docs)]
2#![warn(clippy::nursery)]
3
4use crate::mock::Response;
11use log::{error, info};
12use native_tls::TlsStream;
13use openssl::pkey::{PKey, PKeyRef, Private};
14use openssl::x509::X509Ref;
15use std::io::{Read, Write};
16use std::net::{SocketAddr, TcpListener, TcpStream};
17use std::sync::mpsc;
18use std::thread;
19
20mod identity;
21mod mock;
22#[cfg(test)]
23mod test;
24pub use crate::mock::Mock;
25
/// Address the proxy tries to listen on first.
///
/// If it cannot be bound, `start_proxy` falls back to an OS-assigned port on `127.0.0.1`.
const SERVER_ADDRESS_INTERNAL: &str = "127.0.0.1:1234";
27
28pub struct Proxy {
30 mocks: Vec<Mock>,
31 listening_addr: Option<SocketAddr>,
32 started: bool,
33 identity: PKey<Private>,
34 cert: openssl::x509::X509,
35}
36
37impl Default for Proxy {
38 fn default() -> Self {
39 let (cert, identity) = crate::identity::mk_ca_cert().unwrap();
40 Self {
41 mocks: Vec::new(),
42 listening_addr: None,
43 started: false,
44 identity,
45 cert,
46 }
47 }
48}
49
/// Borrowed CA certificate (`.0`) and the private key generated with it (`.1`).
///
/// Used to issue the per-host certificate presented inside each `CONNECT` tunnel.
struct Pair<'a>(&'a X509Ref, &'a PKeyRef<Private>);
51
52impl Proxy {
53 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn register(&mut self, mock: Mock) {
63 if self.started {
64 panic!("Cannot add mocks to a started proxy");
65 }
66 self.mocks.push(mock);
67 }
68
69 pub fn start(&mut self) {
74 start_proxy(self);
75 }
76
77 pub fn stop(&mut self) {
81 todo!();
82 }
83
84 pub fn address(&self) -> SocketAddr {
90 self.listening_addr.expect("server should be listening")
91 }
92
93 pub fn url(&self) -> String {
98 format!("http://{}", self.address())
99 }
100
101 pub fn get_certificate(&self) -> Vec<u8> {
106 self.cert.to_pem().unwrap()
107 }
108}
109
/// An HTTP request head as read by the proxy, reduced to the parts it acts on.
///
/// Read and parse problems are recorded in `error` instead of being returned. This way
/// a `Request` is always available for building an error response.
#[derive(Debug, Clone)]
struct Request {
    /// Why reading or parsing failed, if it did.
    error: Option<String>,
    /// For a `CONNECT` request, the target host without its port.
    ///
    /// Tunneled requests inherit it from their `CONNECT`.
    host: Option<String>,
    /// Request target of a non-`CONNECT` request.
    path: Option<String>,
    /// HTTP method, e.g. `CONNECT` or `GET`.
    method: Option<String>,
    /// `(major, minor)` HTTP version; stays `(0, 0)` until a request head is parsed.
    version: (u8, u8),
}

impl std::fmt::Display for Request {
    /// Renders method, host and path in `Debug`-struct form, for log lines.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        let mut summary = f.debug_struct("Request");
        summary.field("method", &self.method);
        summary.field("host", &self.host);
        summary.field("path", &self.path);
        summary.finish()
    }
}
128impl Request {
129 fn is_ok(&self) -> bool {
130 self.error().is_none()
131 }
132 fn error(&self) -> Option<&String> {
133 self.error.as_ref()
134 }
135
136 fn from(stream: &mut dyn Read) -> Self {
137 let mut request = Self {
138 error: None,
139 host: None,
140 path: None,
141 method: None,
142 version: (0, 0),
143 };
144
145 let mut all_buf = Vec::new();
146
147 loop {
148 let mut buf = [0; 1024];
149
150 let rlen = match stream.read(&mut buf) {
151 Err(e) => Err(e.to_string()),
152 Ok(0) => Err("Nothing to read.".into()),
153 Ok(i) => Ok(i),
154 }
155 .map_err(|e| request.error = Some(e))
156 .unwrap_or(0);
157 if request.error().is_some() {
158 break;
159 }
160
161 all_buf.extend_from_slice(&buf[..rlen]);
162
163 if rlen < 1024 {
164 break;
165 }
166 }
167
168 let mut headers = [httparse::EMPTY_HEADER; 16];
169 let mut req = httparse::Request::new(&mut headers);
170
171 let _ = req
172 .parse(&all_buf)
173 .map_err(|err| {
174 request.error = Some(err.to_string());
175 })
176 .map(|result| match result {
177 httparse::Status::Complete(_head_length) => {
178 request.method = req.method.map(|s| s.to_string());
179
180 if req.method.as_ref().unwrap().eq(&"CONNECT") {
181 request.host = req.path.unwrap().split(':').next().map(|f| f.to_string());
182 } else {
183 request.path = req.path.map(|f| f.to_string());
184 }
185
186 if let Some(a @ 0..=1) = req.version {
187 request.version = (1, a);
188 }
189 }
190 httparse::Status::Partial => panic!("Incomplete request"),
191 });
192
193 request
194 }
195}
196
197fn create_identity(cn: &str, pair: Pair) -> native_tls::Identity {
198 let (cert, key) = crate::identity::mk_ca_signed_cert(cn, pair.0, pair.1).unwrap();
199
200 let password = "password";
201 let encrypted = openssl::pkcs12::Pkcs12::builder()
202 .build(password, cn, &key, &cert)
203 .unwrap()
204 .to_der()
205 .unwrap();
206
207 native_tls::Identity::from_pkcs12(&encrypted, password).expect("Unable to build identity")
208}
209
210fn start_proxy(proxy: &mut Proxy) {
211 if proxy.started {
212 panic!("Tried to start an already started proxy");
213 }
214 proxy.started = true;
215 let mocks = proxy.mocks.clone();
216 let cert = proxy.cert.clone();
217 let pkey = proxy.identity.clone();
218
219 let (tx, rx) = mpsc::channel();
224
225 thread::spawn(move || {
226 let res = TcpListener::bind(SERVER_ADDRESS_INTERNAL).or_else(|err| {
227 error!("TcpListener::bind: {}", err);
228 TcpListener::bind("127.0.0.1:0")
229 });
230 let (listener, addr) = match res {
231 Ok(listener) => {
232 let addr = listener.local_addr().unwrap();
233 tx.send(Some(addr)).unwrap();
234 (listener, addr)
235 }
236 Err(err) => {
237 error!("alt bind: {}", err);
238 tx.send(None).unwrap();
239 return;
240 }
241 };
242
243 info!("Server is listening at {}", addr);
244 for stream in listener.incoming() {
245 info!("Got stream: {:?}", stream);
246 if let Ok(mut stream) = stream {
247 let request = Request::from(&mut stream);
248 info!("Request received: {}", request);
249 if request.is_ok() {
250 handle_request(Pair(cert.as_ref(), pkey.as_ref()), &mocks, request, stream)
251 .unwrap();
252 } else {
253 let message = request
254 .error()
255 .map_or("Could not parse the request.", |err| err.as_str());
256 error!("Could not parse request because: {}", message);
257 respond_with_error(&mut stream as &mut dyn Write, &request, message).unwrap();
258 }
259 } else {
260 error!("Could not read from stream");
261 }
262 }
263 });
264
265 proxy.listening_addr = rx.recv().ok().and_then(|addr| addr);
266}
267
268fn open_tunnel<'a>(
269 identity: Pair,
270 request: &Request,
271 stream: &'a mut TcpStream,
272) -> Result<TlsStream<&'a mut TcpStream>, Box<dyn std::error::Error>> {
273 let version = request.version;
274 let status = 200;
275
276 let response = Vec::from(format!(
277 "HTTP/{}.{} {}\r\n\r\n",
278 version.0, version.1, status
279 ));
280
281 stream.write_all(&response)?;
282 stream.flush()?;
283 info!("Tunnel open response written");
284
285 let identity = create_identity(request.host.as_ref().expect("No host??"), identity);
286
287 info!("Wrapping with tls");
288 let tstream = native_tls::TlsAcceptor::builder(identity)
289 .build()
290 .expect("Unable to build acceptor")
291 .accept(stream)
292 .expect("Unable to accept connection");
293 info!("Wrapped: {:?}", tstream);
294
295 Ok(tstream)
296}
297
298fn handle_request(
299 identity: Pair,
300 mocks: &[Mock],
301 request: Request,
302 mut stream: TcpStream,
303) -> Result<(), Box<dyn std::error::Error>> {
304 if !request.method.as_ref().unwrap().eq("CONNECT") {
305 panic!("Not a CONNECT request");
306 }
307
308 let mut tstream = open_tunnel(identity, &request, &mut stream)?;
309
310 let mut req = Request::from(&mut tstream);
311 req.host = request.host;
312
313 let mut matched = false;
314 for m in mocks {
315 if m.matches(&req) {
316 write_response(&mut tstream, &req, &m.response)?;
317 matched = true;
318 break;
319 }
320 }
321
322 if !matched {
323 respond_with_error(&mut tstream, &req, "No matching response")?;
324 }
325
326 Ok(())
327}
328
329fn write_response(
330 tstream: &mut dyn Write,
331 request: &Request,
332 response: &Response,
333) -> Result<(), Box<dyn std::error::Error>> {
334 tstream.write_fmt(format_args!(
335 "HTTP/1.{} {}\r\n",
336 request.version.1, response.status
337 ))?;
338 for (header, value) in &response.headers {
339 tstream.write_fmt(format_args!("{}: {}\r\n", header, value))?;
340 }
341 tstream.write_all(b"\r\n")?;
342 tstream.write_all(&response.body)?;
343 tstream.write_all(b"\r\n")?;
344
345 Ok(())
346}
347
348fn respond_with_error(
349 _stream: &mut dyn Write,
350 request: &Request,
351 message: &str,
352) -> Result<(), Box<dyn std::error::Error>> {
353 write_response(
354 _stream,
355 request,
356 &Response {
357 headers: vec![],
358 status: http::StatusCode::INTERNAL_SERVER_ERROR,
359 body: message.as_bytes().to_vec(),
360 },
361 )
362}