mod filters;
use filters::remove_hop_by_hop_headers;
use flask::httpx::{read_http_request, read_http_response};
use http::{
header::HeaderValue,
Request,
Response,
StatusCode
};
use num_cpus;
use rayon::ThreadPool;
use std::io::Write;
use std::net::SocketAddr;
use std::net::{TcpListener, TcpStream};
use std::str::from_utf8;
use std::string::String;
use std::process::exit;
/// Hook applied to an incoming request before it is forwarded.
///
/// The hook may mutate the request in place. Returning `Some(response)`
/// short-circuits the proxy: `handle_client` writes that response to the
/// client and never contacts the upstream service. Returning `None` lets the
/// (possibly modified) request be forwarded.
type RequestTransform = fn(&mut Request<Vec<u8>>) -> Option<Response<Vec<u8>>>;
/// Hook applied to the response obtained from `send_request` (which may be a
/// locally generated error response) before it is written to the client.
type ResponseTransform = fn(&mut Response<Vec<u8>>);
/// Serializes `resp` as an HTTP/1.x message and writes it to `client`.
///
/// The status line uses the canonical reason phrase for the status code, or
/// an empty phrase when the code has none. Headers are emitted in map order;
/// an existing `content-length` header is rewritten to the actual body length
/// so that it always agrees with the bytes sent. No `content-length` header is
/// added when the response does not already carry one.
///
/// Returns `true` when the whole message was written, `false` on any I/O
/// error.
fn write_response(resp: Response<Vec<u8>>, mut client: TcpStream) -> bool {
    let reason = resp.status().canonical_reason().unwrap_or("");

    // Build the status line and headers in one buffer so they go out in a
    // single write instead of one small write per header.
    let mut head = format!(
        "{:?} {} {}\r\n",
        resp.version(),
        resp.status().as_str(),
        reason
    )
    .into_bytes();

    for (key, value) in resp.headers().iter() {
        let name = key.as_str();
        head.extend_from_slice(name.as_bytes());
        head.extend_from_slice(b": ");
        if name == "content-length" {
            head.extend_from_slice(resp.body().len().to_string().as_bytes());
        } else {
            // Copy the raw bytes: header values are not guaranteed to be
            // visible ASCII, so converting through `to_str()` could fail.
            head.extend_from_slice(value.as_bytes());
        }
        head.extend_from_slice(b"\r\n");
    }
    head.extend_from_slice(b"\r\n");

    // `write_all` retries short writes; a bare `write` may send only a prefix
    // of the buffer and still report success.
    client.write_all(&head).is_ok() && client.write_all(resp.body()).is_ok()
}
/// Serializes `req` as an HTTP/1.x message and writes it to `client`.
///
/// Headers are emitted in map order with two special cases:
/// * an existing `content-length` header is rewritten to the actual body
///   length;
/// * the `user-agent` header is not forwarded.
///
/// Returns `true` when the whole message was written, `false` on any I/O
/// error.
fn write_request(req: Request<Vec<u8>>, mut client: TcpStream) -> bool {
    // Build the request line and headers in one buffer so they go out in a
    // single write instead of one small write per header.
    let mut head =
        format!("{} {} {:?}\r\n", req.method(), req.uri(), req.version()).into_bytes();

    for (key, value) in req.headers().iter() {
        let name = key.as_str();
        if name == "user-agent" {
            // Dropped on purpose, matching the existing proxy behaviour.
            continue;
        }
        head.extend_from_slice(name.as_bytes());
        head.extend_from_slice(b": ");
        if name == "content-length" {
            head.extend_from_slice(req.body().len().to_string().as_bytes());
        } else {
            // Copy the raw bytes: header values are not guaranteed to be
            // visible ASCII, so converting through `to_str()` could fail.
            head.extend_from_slice(value.as_bytes());
        }
        head.extend_from_slice(b"\r\n");
    }
    head.extend_from_slice(b"\r\n");

    // `write_all` retries short writes; a bare `write` may send only a prefix
    // of the buffer and still report success.
    client.write_all(&head).is_ok() && client.write_all(req.body()).is_ok()
}
fn handle_request(stream: TcpStream, req: Request<Vec<u8>>,) {
let proxy_addr = req.headers().get(http::header::HOST).unwrap();
let proxy_stream = match TcpStream::connect(from_utf8(proxy_addr.as_bytes()).unwrap()) {
Ok(stream) => stream,
Err(err) => {
eprintln!("Error: could not connect to downstream service ... {}", err);
return;
}
};
if write_request(req, proxy_stream.try_clone().unwrap()) {
match read_http_response(proxy_stream.try_clone().unwrap()) {
Ok(resp) => {
write_response(resp, stream);
},
Err(_) => {
let resp = Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Vec::new()).unwrap();
write_response(resp, stream);
}
}
} else {
eprintln!("Error: sending request to client")
}
}
fn send_request(req: Request<Vec<u8>>) -> Response<Vec<u8>> {
let proxy_addr = req.headers().get(http::header::HOST).unwrap();
let proxy_stream = match TcpStream::connect(from_utf8(proxy_addr.as_bytes()).unwrap()) {
Ok(stream) => stream,
Err(err) => {
eprintln!("Error: could not connect to downstream service ... {}", err);
let err_resp = Response::builder().status(StatusCode::BAD_GATEWAY).body(Vec::new()).unwrap();
return err_resp;
}
};
let err_resp = Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Vec::new()).unwrap();
if write_request(req, proxy_stream.try_clone().unwrap()) {
match read_http_response(proxy_stream.try_clone().unwrap()) {
Ok(resp) => {
resp
},
Err(_) => {
err_resp
}
}
} else {
eprintln!("Error: sending request to client");
err_resp
}
}
/// Serves one client connection for `generic_proxy`.
///
/// Reads a single request from `stream`, strips hop-by-hop headers and runs
/// `req_trans` on it. If the transform returns a response, that response is
/// sent straight back to the client. Otherwise the request is forwarded with
/// `send_request`, the result is passed through `resp_trans`, and then written
/// to the client.
///
/// A connection whose request cannot be read is logged and dropped rather
/// than panicking, so a single malformed request cannot take down the worker
/// that is serving it.
fn handle_client(stream: TcpStream, req_trans: RequestTransform, resp_trans: ResponseTransform) {
    // The reader consumes its stream, so hand it a clone and keep the
    // original for writing the response.
    let reader = match stream.try_clone() {
        Ok(reader) => reader,
        Err(err) => {
            eprintln!("Error: could not clone client connection ... {}", err);
            return;
        }
    };
    let mut req = match read_http_request(reader) {
        Ok(req) => req,
        Err(_) => {
            eprintln!("Error: could not read request from client");
            return;
        }
    };
    *req.headers_mut() = remove_hop_by_hop_headers(req.headers());

    let resp = match (req_trans)(&mut req) {
        // The request transform answered on its own; skip the upstream call.
        Some(short_circuit) => short_circuit,
        None => {
            let mut upstream_resp = send_request(req);
            (resp_trans)(&mut upstream_resp);
            upstream_resp
        }
    };
    if !write_response(resp, stream) {
        eprintln!("Error: sending response to client");
    }
}
/// Runs a transforming proxy on `listen_addr`.
///
/// Each accepted connection is served by `handle_client` with the supplied
/// request and response transforms. Connections are dispatched onto a rayon
/// pool sized at twice the CPU count; if the pool cannot be built, they are
/// served one at a time on the calling thread instead.
///
/// Exits the process with status 1 if `listen_addr` cannot be bound.
pub fn generic_proxy(
    listen_addr: SocketAddr,
    req_trans: RequestTransform,
    resp_trans: ResponseTransform,
) {
    let listener: TcpListener = TcpListener::bind(listen_addr).unwrap_or_else(|err| {
        eprintln!("{}", err);
        exit(1)
    });

    let worker_count = 2 * num_cpus::get();
    let pool = match rayon::ThreadPoolBuilder::new()
        .num_threads(worker_count)
        .build()
    {
        Ok(pool) => pool,
        Err(err) => {
            // Sequential fallback: serve each connection inline.
            eprintln!("{}", err);
            println!("Running proxy without thread pool");
            for incoming in listener.incoming() {
                if let Ok(stream) = incoming {
                    handle_client(stream, req_trans, resp_trans);
                } else {
                    eprintln!("Error accessing TcpStream in generic_proxy");
                }
            }
            return;
        }
    };

    // The accept loop itself runs inside the pool and hands every connection
    // to a spawned task.
    pool.install(|| {
        for incoming in listener.incoming() {
            if let Ok(stream) = incoming {
                pool.spawn(move || handle_client(stream, req_trans, resp_trans));
            } else {
                eprintln!("Error accessing TcpStream in generic_proxy");
            }
        }
    });
}
/// Runs a plain forwarding proxy: every request accepted on `listen_addr` is
/// sent on to `proxy_addr`.
///
/// Connections are served on a rayon pool sized at twice the CPU count, or
/// sequentially on the calling thread if the pool cannot be built.
///
/// Exits the process with status 1 if `listen_addr` cannot be bound.
pub fn simple_proxy(listen_addr: SocketAddr, proxy_addr: SocketAddr) {
    let listener: TcpListener = TcpListener::bind(listen_addr).unwrap_or_else(|err| {
        eprintln!("{}", err);
        exit(1)
    });

    let worker_count = 2 * num_cpus::get();
    let built = rayon::ThreadPoolBuilder::new()
        .num_threads(worker_count)
        .build();

    if let Err(err) = &built {
        eprintln!("{}", err);
        println!("Running proxy without thread pool");
    }
    match built {
        Ok(pool) => proxy_with_thread_pool(pool, listener, proxy_addr),
        Err(_) => proxy_without_thread_pool(listener, proxy_addr),
    }
}
/// Serves one client connection for `simple_proxy`.
///
/// Reads a single request from `stream`, strips hop-by-hop headers, points
/// the `Host` header at `proxy_addr`, and hands the request to
/// `handle_request` for forwarding.
///
/// A connection whose request cannot be read is logged and dropped rather
/// than panicking, so a single malformed request cannot take down the worker
/// that is serving it.
fn proxy_tcp_stream(stream: TcpStream, proxy_addr: SocketAddr) {
    let proxy_addr_hdr = HeaderValue::from_str(&proxy_addr.to_string())
        .expect("a formatted SocketAddr contains only visible ASCII");

    // The reader consumes its stream, so hand it a clone and keep the
    // original for writing the response.
    let reader = match stream.try_clone() {
        Ok(reader) => reader,
        Err(err) => {
            eprintln!("Error: could not clone client connection ... {}", err);
            return;
        }
    };
    let mut req = match read_http_request(reader) {
        Ok(req) => req,
        Err(_) => {
            eprintln!("Error: could not read request from client");
            return;
        }
    };
    *req.headers_mut() = remove_hop_by_hop_headers(req.headers());

    // `insert` replaces any existing values for the key, so no prior
    // `remove` is needed.
    req.headers_mut().insert(http::header::HOST, proxy_addr_hdr);

    handle_request(stream, req);
}
/// Accept loop for `simple_proxy` when no thread pool is available: each
/// connection is proxied to `proxy_addr` on the calling thread, one at a
/// time. Failed accepts are logged and skipped.
fn proxy_without_thread_pool(listener: TcpListener, proxy_addr: SocketAddr) {
    listener.incoming().for_each(|incoming| {
        if let Ok(stream) = incoming {
            proxy_tcp_stream(stream, proxy_addr);
        } else {
            eprintln!("Error accessing TcpStream in generic_proxy");
        }
    });
}
/// Accept loop for `simple_proxy` backed by a rayon pool: the loop runs
/// inside `pool` and every accepted connection is proxied to `proxy_addr` in
/// its own spawned task. Failed accepts are logged and skipped.
fn proxy_with_thread_pool(pool: ThreadPool, listener: TcpListener, proxy_addr: SocketAddr) {
    pool.install(|| {
        listener.incoming().for_each(|incoming| {
            if let Ok(stream) = incoming {
                pool.spawn(move || proxy_tcp_stream(stream, proxy_addr));
            } else {
                eprintln!("Error accessing TcpStream in generic_proxy");
            }
        });
    });
}