mock_proxy/
lib.rs

1#![deny(missing_docs)]
2#![warn(clippy::nursery)]
3
//! This library was built to help test systems that use libraries which don't provide any
//! testing utilities themselves. It works by overriding the client's proxy and root CA
//! settings, intercepting the proxied requests, and returning mock responses defined by
//! the user.
//!
//! See [simple_test](https://github.com/Mause/mock_proxy/blob/main/src/test.rs) for how to
//! set up reqwest to send requests to a [`Proxy`] instance.
9
10use crate::mock::Response;
11use log::{error, info};
12use native_tls::TlsStream;
13use openssl::pkey::{PKey, PKeyRef, Private};
14use openssl::x509::X509Ref;
15use std::io::{Read, Write};
16use std::net::{SocketAddr, TcpListener, TcpStream};
17use std::sync::mpsc;
18use std::thread;
19
20mod identity;
21mod mock;
22#[cfg(test)]
23mod test;
24pub use crate::mock::Mock;
25
26const SERVER_ADDRESS_INTERNAL: &str = "127.0.0.1:1234";
27
28/// Primary interface for the library
29pub struct Proxy {
30    mocks: Vec<Mock>,
31    listening_addr: Option<SocketAddr>,
32    started: bool,
33    identity: PKey<Private>,
34    cert: openssl::x509::X509,
35}
36
37impl Default for Proxy {
38    fn default() -> Self {
39        let (cert, identity) = crate::identity::mk_ca_cert().unwrap();
40        Self {
41            mocks: Vec::new(),
42            listening_addr: None,
43            started: false,
44            identity,
45            cert,
46        }
47    }
48}
49
50struct Pair<'a>(&'a X509Ref, &'a PKeyRef<Private>);
51
52impl Proxy {
53    /// Builds a [`Default`] instance
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    /// Register a given mock with the proxy
59    ///
60    /// # Panics
61    /// Will panic if proxy has already been started
62    pub fn register(&mut self, mock: Mock) {
63        if self.started {
64            panic!("Cannot add mocks to a started proxy");
65        }
66        self.mocks.push(mock);
67    }
68
69    /// Start the proxy server
70    ///
71    /// # Panics
72    /// Will panic if proxy has already been started
73    pub fn start(&mut self) {
74        start_proxy(self);
75    }
76
77    /// Start the server
78    /// # Panics
79    /// Not supported yet
80    pub fn stop(&mut self) {
81        todo!();
82    }
83
84    /// Address and port of the local server.
85    /// Can be used with `std::net::TcpStream`.
86    ///
87    /// # Panics
88    /// If server is not running
89    pub fn address(&self) -> SocketAddr {
90        self.listening_addr.expect("server should be listening")
91    }
92
93    /// A local `http://…` URL of the server.
94    ///
95    /// # Panics
96    /// If server is not running
97    pub fn url(&self) -> String {
98        format!("http://{}", self.address())
99    }
100
101    /// Returns the root CA certificate of the server
102    ///
103    /// # Panics
104    /// If PEM conversion fails
105    pub fn get_certificate(&self) -> Vec<u8> {
106        self.cert.to_pem().unwrap()
107    }
108}
109
/// An HTTP request head as seen by the proxy.
///
/// `host` is filled from a `CONNECT` request's target (port stripped), while `path`
/// holds the target of any other request. `error` records why reading or parsing
/// failed, if it did.
#[derive(Debug, Clone)]
struct Request {
    /// Reason the request could not be read or parsed.
    error: Option<String>,
    /// Target host, without port.
    host: Option<String>,
    /// Request target of a non-`CONNECT` request.
    path: Option<String>,
    /// HTTP method, such as `CONNECT` or `GET`.
    method: Option<String>,
    /// `(major, minor)` HTTP version; `(0, 0)` until a request head has been parsed.
    version: (u8, u8),
}

/// Short summary used in log lines; deliberately leaves out `error` and `version`.
impl std::fmt::Display for Request {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("Request");
        summary.field("method", &self.method);
        summary.field("host", &self.host);
        summary.field("path", &self.path);
        summary.finish()
    }
}
128impl Request {
129    fn is_ok(&self) -> bool {
130        self.error().is_none()
131    }
132    fn error(&self) -> Option<&String> {
133        self.error.as_ref()
134    }
135
136    fn from(stream: &mut dyn Read) -> Self {
137        let mut request = Self {
138            error: None,
139            host: None,
140            path: None,
141            method: None,
142            version: (0, 0),
143        };
144
145        let mut all_buf = Vec::new();
146
147        loop {
148            let mut buf = [0; 1024];
149
150            let rlen = match stream.read(&mut buf) {
151                Err(e) => Err(e.to_string()),
152                Ok(0) => Err("Nothing to read.".into()),
153                Ok(i) => Ok(i),
154            }
155            .map_err(|e| request.error = Some(e))
156            .unwrap_or(0);
157            if request.error().is_some() {
158                break;
159            }
160
161            all_buf.extend_from_slice(&buf[..rlen]);
162
163            if rlen < 1024 {
164                break;
165            }
166        }
167
168        let mut headers = [httparse::EMPTY_HEADER; 16];
169        let mut req = httparse::Request::new(&mut headers);
170
171        let _ = req
172            .parse(&all_buf)
173            .map_err(|err| {
174                request.error = Some(err.to_string());
175            })
176            .map(|result| match result {
177                httparse::Status::Complete(_head_length) => {
178                    request.method = req.method.map(|s| s.to_string());
179
180                    if req.method.as_ref().unwrap().eq(&"CONNECT") {
181                        request.host = req.path.unwrap().split(':').next().map(|f| f.to_string());
182                    } else {
183                        request.path = req.path.map(|f| f.to_string());
184                    }
185
186                    if let Some(a @ 0..=1) = req.version {
187                        request.version = (1, a);
188                    }
189                }
190                httparse::Status::Partial => panic!("Incomplete request"),
191            });
192
193        request
194    }
195}
196
197fn create_identity(cn: &str, pair: Pair) -> native_tls::Identity {
198    let (cert, key) = crate::identity::mk_ca_signed_cert(cn, pair.0, pair.1).unwrap();
199
200    let password = "password";
201    let encrypted = openssl::pkcs12::Pkcs12::builder()
202        .build(password, cn, &key, &cert)
203        .unwrap()
204        .to_der()
205        .unwrap();
206
207    native_tls::Identity::from_pkcs12(&encrypted, password).expect("Unable to build identity")
208}
209
210fn start_proxy(proxy: &mut Proxy) {
211    if proxy.started {
212        panic!("Tried to start an already started proxy");
213    }
214    proxy.started = true;
215    let mocks = proxy.mocks.clone();
216    let cert = proxy.cert.clone();
217    let pkey = proxy.identity.clone();
218
219    // if state.listening_addr.is_some() {
220    //     return;
221    // }
222
223    let (tx, rx) = mpsc::channel();
224
225    thread::spawn(move || {
226        let res = TcpListener::bind(SERVER_ADDRESS_INTERNAL).or_else(|err| {
227            error!("TcpListener::bind: {}", err);
228            TcpListener::bind("127.0.0.1:0")
229        });
230        let (listener, addr) = match res {
231            Ok(listener) => {
232                let addr = listener.local_addr().unwrap();
233                tx.send(Some(addr)).unwrap();
234                (listener, addr)
235            }
236            Err(err) => {
237                error!("alt bind: {}", err);
238                tx.send(None).unwrap();
239                return;
240            }
241        };
242
243        info!("Server is listening at {}", addr);
244        for stream in listener.incoming() {
245            info!("Got stream: {:?}", stream);
246            if let Ok(mut stream) = stream {
247                let request = Request::from(&mut stream);
248                info!("Request received: {}", request);
249                if request.is_ok() {
250                    handle_request(Pair(cert.as_ref(), pkey.as_ref()), &mocks, request, stream)
251                        .unwrap();
252                } else {
253                    let message = request
254                        .error()
255                        .map_or("Could not parse the request.", |err| err.as_str());
256                    error!("Could not parse request because: {}", message);
257                    respond_with_error(&mut stream as &mut dyn Write, &request, message).unwrap();
258                }
259            } else {
260                error!("Could not read from stream");
261            }
262        }
263    });
264
265    proxy.listening_addr = rx.recv().ok().and_then(|addr| addr);
266}
267
268fn open_tunnel<'a>(
269    identity: Pair,
270    request: &Request,
271    stream: &'a mut TcpStream,
272) -> Result<TlsStream<&'a mut TcpStream>, Box<dyn std::error::Error>> {
273    let version = request.version;
274    let status = 200;
275
276    let response = Vec::from(format!(
277        "HTTP/{}.{} {}\r\n\r\n",
278        version.0, version.1, status
279    ));
280
281    stream.write_all(&response)?;
282    stream.flush()?;
283    info!("Tunnel open response written");
284
285    let identity = create_identity(request.host.as_ref().expect("No host??"), identity);
286
287    info!("Wrapping with tls");
288    let tstream = native_tls::TlsAcceptor::builder(identity)
289        .build()
290        .expect("Unable to build acceptor")
291        .accept(stream)
292        .expect("Unable to accept connection");
293    info!("Wrapped: {:?}", tstream);
294
295    Ok(tstream)
296}
297
298fn handle_request(
299    identity: Pair,
300    mocks: &[Mock],
301    request: Request,
302    mut stream: TcpStream,
303) -> Result<(), Box<dyn std::error::Error>> {
304    if !request.method.as_ref().unwrap().eq("CONNECT") {
305        panic!("Not a CONNECT request");
306    }
307
308    let mut tstream = open_tunnel(identity, &request, &mut stream)?;
309
310    let mut req = Request::from(&mut tstream);
311    req.host = request.host;
312
313    let mut matched = false;
314    for m in mocks {
315        if m.matches(&req) {
316            write_response(&mut tstream, &req, &m.response)?;
317            matched = true;
318            break;
319        }
320    }
321
322    if !matched {
323        respond_with_error(&mut tstream, &req, "No matching response")?;
324    }
325
326    Ok(())
327}
328
329fn write_response(
330    tstream: &mut dyn Write,
331    request: &Request,
332    response: &Response,
333) -> Result<(), Box<dyn std::error::Error>> {
334    tstream.write_fmt(format_args!(
335        "HTTP/1.{} {}\r\n",
336        request.version.1, response.status
337    ))?;
338    for (header, value) in &response.headers {
339        tstream.write_fmt(format_args!("{}: {}\r\n", header, value))?;
340    }
341    tstream.write_all(b"\r\n")?;
342    tstream.write_all(&response.body)?;
343    tstream.write_all(b"\r\n")?;
344
345    Ok(())
346}
347
348fn respond_with_error(
349    _stream: &mut dyn Write,
350    request: &Request,
351    message: &str,
352) -> Result<(), Box<dyn std::error::Error>> {
353    write_response(
354        _stream,
355        request,
356        &Response {
357            headers: vec![],
358            status: http::StatusCode::INTERNAL_SERVER_ERROR,
359            body: message.as_bytes().to_vec(),
360        },
361    )
362}