#![deny(missing_docs)]
#![warn(clippy::nursery)]
use crate::mock::Response;
use log::{error, info};
use native_tls::TlsStream;
use openssl::pkey::{PKey, PKeyRef, Private};
use openssl::x509::X509Ref;
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::mpsc;
use std::thread;
mod identity;
mod mock;
#[cfg(test)]
mod test;
pub use crate::mock::Mock;
/// Preferred listen address for the proxy. `start_proxy` falls back to an
/// OS-assigned port on `127.0.0.1` if binding this address fails.
const SERVER_ADDRESS_INTERNAL: &str = "127.0.0.1:1234";
/// A local proxy that intercepts `CONNECT` tunnels and answers the tunnelled
/// requests from registered [`Mock`]s.
///
/// For each `CONNECT host:port` request, the proxy completes a TLS handshake
/// with a certificate for `host` signed by its own CA. It then replies with the
/// first registered mock that matches the tunnelled request. The CA
/// certificate is available from [`Proxy::get_certificate`].
pub struct Proxy {
    /// Mocks consulted, in registration order, for each intercepted request.
    mocks: Vec<Mock>,
    /// Address the listener bound to; `None` until started, or if binding failed.
    listening_addr: Option<SocketAddr>,
    /// Whether the proxy has been started; no more mocks may be added afterwards.
    started: bool,
    /// Private key of the CA used to sign per-host certificates.
    identity: PKey<Private>,
    /// CA certificate matching `identity`.
    cert: openssl::x509::X509,
}
impl Default for Proxy {
    /// Creates an unstarted proxy with no mocks and a freshly generated CA
    /// certificate and private key.
    ///
    /// # Panics
    ///
    /// Panics if the CA certificate cannot be generated.
    fn default() -> Self {
        let (cert, identity) = crate::identity::mk_ca_cert()
            .expect("failed to generate the proxy's CA certificate");
        Self {
            mocks: Vec::new(),
            listening_addr: None,
            started: false,
            identity,
            cert,
        }
    }
}
/// Borrowed CA certificate (`.0`) and its private key (`.1`), passed to
/// `create_identity` to sign a certificate for an intercepted host.
struct Pair<'a>(&'a X509Ref, &'a PKeyRef<Private>);
impl Proxy {
    /// Creates a new, unstarted proxy with no mocks. Equivalent to
    /// [`Proxy::default`].
    ///
    /// # Panics
    ///
    /// Panics if the CA certificate cannot be generated.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a mock. Mocks are consulted in registration order and the
    /// first one that matches a request answers it.
    ///
    /// # Panics
    ///
    /// Panics if the proxy has already been started. The listener works on a
    /// snapshot of the mocks taken at start time, so later additions would be
    /// silently ignored.
    pub fn register(&mut self, mock: Mock) {
        assert!(!self.started, "Cannot add mocks to a started proxy");
        self.mocks.push(mock);
    }

    /// Starts the listener thread and records the address it bound to.
    ///
    /// # Panics
    ///
    /// Panics if the proxy has already been started.
    pub fn start(&mut self) {
        start_proxy(self);
    }

    /// Stops the proxy.
    ///
    /// # Panics
    ///
    /// Not implemented yet: always panics.
    pub fn stop(&mut self) {
        todo!();
    }

    /// Returns the address the proxy is listening on.
    ///
    /// # Panics
    ///
    /// Panics if the proxy has not been started or could not bind a port.
    pub fn address(&self) -> SocketAddr {
        self.listening_addr.expect("server should be listening")
    }

    /// Returns the proxy URL, `http://<address>`, for configuring HTTP clients.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Proxy::address`].
    pub fn url(&self) -> String {
        format!("http://{}", self.address())
    }

    /// Returns the proxy's CA certificate, PEM-encoded.
    ///
    /// # Panics
    ///
    /// Panics if the certificate cannot be PEM-encoded.
    pub fn get_certificate(&self) -> Vec<u8> {
        self.cert
            .to_pem()
            .expect("failed to PEM-encode the proxy's CA certificate")
    }
}
/// An HTTP request head read from a client, or the reason it could not be
/// read or parsed.
#[derive(Debug, Clone)]
struct Request {
    /// Read/parse failure message; `None` when the head was parsed successfully.
    error: Option<String>,
    /// Target host of a `CONNECT` request, with the `:port` part removed.
    /// `handle_request` copies it onto the request read inside the tunnel.
    host: Option<String>,
    /// Request target; only set for non-`CONNECT` requests.
    path: Option<String>,
    /// HTTP method, e.g. `"CONNECT"` or `"GET"`.
    method: Option<String>,
    /// HTTP version as `(major, minor)`; stays `(0, 0)` if it was not parsed.
    version: (u8, u8),
}
impl std::fmt::Display for Request {
    /// Formats the method, host and path as a debug-style struct, e.g.
    /// `Request { method: Some("CONNECT"), host: Some("example.com"), path: None }`.
    /// The error and version fields are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        let mut builder = f.debug_struct("Request");
        builder.field("method", &self.method);
        builder.field("host", &self.host);
        builder.field("path", &self.path);
        builder.finish()
    }
}
impl Request {
    /// Returns `true` if the request head was read and parsed without error.
    fn is_ok(&self) -> bool {
        self.error().is_none()
    }

    /// Returns the read/parse error message, if any.
    fn error(&self) -> Option<&String> {
        self.error.as_ref()
    }

    /// Reads and parses an HTTP request head from `stream`.
    ///
    /// Bytes are read in 1 KiB chunks, and the accumulated buffer is
    /// re-parsed after every chunk until `httparse` reports a complete head.
    /// This means a head split across several reads is handled correctly.
    /// A head that happens to be an exact multiple of the chunk size does not
    /// trigger an extra read that would block forever. Any body bytes that
    /// arrive in the same reads are ignored.
    ///
    /// For `CONNECT` requests only `host` (the target without its `:port`
    /// suffix) is set. For all other methods only `path` is set. `version` is
    /// set to `(1, minor)` for HTTP/1.0 and HTTP/1.1.
    ///
    /// Never panics. The following are reported through [`Request::error`]:
    /// - I/O failures,
    /// - EOF before a complete head,
    /// - malformed input.
    fn from(stream: &mut dyn Read) -> Self {
        // Generous upper bound on the number of request headers; real clients
        // routinely send more than 16.
        const MAX_HEADERS: usize = 64;

        let mut request = Self {
            error: None,
            host: None,
            path: None,
            method: None,
            version: (0, 0),
        };
        let mut all_buf = Vec::new();
        let mut buf = [0u8; 1024];
        loop {
            match stream.read(&mut buf) {
                Ok(0) => {
                    request.error = Some(if all_buf.is_empty() {
                        "Nothing to read.".to_string()
                    } else {
                        "Connection closed before the request head was complete.".to_string()
                    });
                    return request;
                }
                Ok(n) => all_buf.extend_from_slice(&buf[..n]),
                // A signal interrupted the read before any data arrived; retry.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => {
                    request.error = Some(e.to_string());
                    return request;
                }
            }

            let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
            let mut req = httparse::Request::new(&mut headers);
            match req.parse(&all_buf) {
                // Head not finished yet: read more bytes, then re-parse from the start.
                Ok(httparse::Status::Partial) => {}
                Ok(httparse::Status::Complete(_head_length)) => {
                    request.method = req.method.map(str::to_string);
                    if req.method == Some("CONNECT") {
                        // CONNECT targets are `host:port`; keep only the host,
                        // which names the certificate to forge.
                        request.host = req
                            .path
                            .and_then(|target| target.split(':').next())
                            .map(str::to_string);
                    } else {
                        request.path = req.path.map(str::to_string);
                    }
                    if let Some(minor @ 0..=1) = req.version {
                        request.version = (1, minor);
                    }
                    return request;
                }
                Err(err) => {
                    request.error = Some(err.to_string());
                    return request;
                }
            }
        }
    }
}
fn create_identity(cn: &str, pair: Pair) -> native_tls::Identity {
let (cert, key) = crate::identity::mk_ca_signed_cert(cn, pair.0, pair.1).unwrap();
let password = "password";
let encrypted = openssl::pkcs12::Pkcs12::builder()
.build(password, cn, &key, &cert)
.unwrap()
.to_der()
.unwrap();
native_tls::Identity::from_pkcs12(&encrypted, password).expect("Unable to build identity")
}
fn start_proxy(proxy: &mut Proxy) {
if proxy.started {
panic!("Tried to start an already started proxy");
}
proxy.started = true;
let mocks = proxy.mocks.clone();
let cert = proxy.cert.clone();
let pkey = proxy.identity.clone();
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let res = TcpListener::bind(SERVER_ADDRESS_INTERNAL).or_else(|err| {
error!("TcpListener::bind: {}", err);
TcpListener::bind("127.0.0.1:0")
});
let (listener, addr) = match res {
Ok(listener) => {
let addr = listener.local_addr().unwrap();
tx.send(Some(addr)).unwrap();
(listener, addr)
}
Err(err) => {
error!("alt bind: {}", err);
tx.send(None).unwrap();
return;
}
};
info!("Server is listening at {}", addr);
for stream in listener.incoming() {
info!("Got stream: {:?}", stream);
if let Ok(mut stream) = stream {
let request = Request::from(&mut stream);
info!("Request received: {}", request);
if request.is_ok() {
handle_request(Pair(cert.as_ref(), pkey.as_ref()), &mocks, request, stream)
.unwrap();
} else {
let message = request
.error()
.map_or("Could not parse the request.", |err| err.as_str());
error!("Could not parse request because: {}", message);
respond_with_error(&mut stream as &mut dyn Write, &request, message).unwrap();
}
} else {
error!("Could not read from stream");
}
}
});
proxy.listening_addr = rx.recv().ok().and_then(|addr| addr);
}
fn open_tunnel<'a>(
identity: Pair,
request: &Request,
stream: &'a mut TcpStream,
) -> Result<TlsStream<&'a mut TcpStream>, Box<dyn std::error::Error>> {
let version = request.version;
let status = 200;
let response = Vec::from(format!(
"HTTP/{}.{} {}\r\n\r\n",
version.0, version.1, status
));
stream.write_all(&response)?;
stream.flush()?;
info!("Tunnel open response written");
let identity = create_identity(request.host.as_ref().expect("No host??"), identity);
info!("Wrapping with tls");
let tstream = native_tls::TlsAcceptor::builder(identity)
.build()
.expect("Unable to build acceptor")
.accept(stream)
.expect("Unable to accept connection");
info!("Wrapped: {:?}", tstream);
Ok(tstream)
}
fn handle_request(
identity: Pair,
mocks: &[Mock],
request: Request,
mut stream: TcpStream,
) -> Result<(), Box<dyn std::error::Error>> {
if !request.method.as_ref().unwrap().eq("CONNECT") {
panic!("Not a CONNECT request");
}
let mut tstream = open_tunnel(identity, &request, &mut stream)?;
let mut req = Request::from(&mut tstream);
req.host = request.host;
let mut matched = false;
for m in mocks {
if m.matches(&req) {
write_response(&mut tstream, &req, &m.response)?;
matched = true;
break;
}
}
if !matched {
respond_with_error(&mut tstream, &req, "No matching response")?;
}
Ok(())
}
/// Writes `response` to `tstream` as an HTTP/1.x response.
///
/// The minor version mirrors the request's. The output is the status line,
/// then the mock's headers, then a blank line, then the body exactly as
/// stored. Nothing is appended after the body, because HTTP message bodies are
/// not CRLF-terminated; extra bytes would be seen by clients as part of the
/// body.
///
/// # Errors
///
/// Returns any error from writing to or flushing `tstream`.
fn write_response(
    tstream: &mut dyn Write,
    request: &Request,
    response: &Response,
) -> Result<(), Box<dyn std::error::Error>> {
    write!(
        tstream,
        "HTTP/1.{} {}\r\n",
        request.version.1, response.status
    )?;
    for (header, value) in &response.headers {
        write!(tstream, "{}: {}\r\n", header, value)?;
    }
    // Blank line terminates the head.
    tstream.write_all(b"\r\n")?;
    tstream.write_all(&response.body)?;
    tstream.flush()?;
    Ok(())
}
/// Answers `request` with a `500 Internal Server Error` response whose body is
/// `message`. No headers are added.
///
/// # Errors
///
/// Propagates any error from `write_response`.
fn respond_with_error(
    stream: &mut dyn Write,
    request: &Request,
    message: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    let error_response = Response {
        status: http::StatusCode::INTERNAL_SERVER_ERROR,
        headers: Vec::new(),
        body: Vec::from(message.as_bytes()),
    };
    write_response(stream, request, &error_response)
}