#[cfg(test)]
mod tests;
use std::io::{Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use crate::application::Application;
use crate::core::New;
use crate::middleware::Middleware;
use crate::mime_type::MimeType;
use crate::range::Range;
use crate::request::Request;
use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
use crate::server::ConnectionInfo;
/// Lower-case names of hop-by-hop headers that must not be forwarded to a
/// backend.
///
/// Callers are expected to lower-case a header name before looking it up here.
/// `trailer` is the actual header name defined by HTTP/1.1; `trailers` is kept
/// because older specifications listed it. `proxy-connection` is a
/// non-standard header that some clients still send and that proxies
/// conventionally strip.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
/// Strategy used to pick the backend that receives the next request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancing {
    /// Rotate through the configured backends in order, one request each.
    RoundRobin,
}
/// Middleware that forwards requests to one of several backends over plain TCP,
/// rotating through them and falling back to the next one when a backend fails.
pub struct ReverseProxy {
// Backends that were successfully parsed from the URLs passed to `new`.
backends: Vec<Backend>,
// When set, only requests whose `request_uri` starts with this prefix are
// proxied; all other requests are handed to the next application.
path_prefix: Option<String>,
// Limit for establishing the TCP connection to a backend (5 s by default).
connect_timeout: Duration,
// Read timeout installed on the backend socket (30 s by default).
read_timeout: Duration,
// Incremented once per proxied request; selects the backend to try first.
counter: AtomicUsize,
}
impl ReverseProxy {
    /// Timeout applied to writes on the backend socket while the request is sent.
    const WRITE_TIMEOUT: Duration = Duration::from_secs(10);

    /// Port that is implied, and therefore omitted from the `Host` header.
    const DEFAULT_HTTP_PORT: u16 = 80;

    /// Creates a proxy for the given backend URLs.
    ///
    /// URLs that `Backend::parse` rejects are skipped silently, so the proxy
    /// may end up with fewer backends than were passed in (possibly none, in
    /// which case every proxied request fails). Defaults: no path prefix,
    /// 5 second connect timeout, 30 second read timeout.
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        Self {
            backends: backends
                .into_iter()
                .filter_map(|u| Backend::parse(u.as_ref()))
                .collect(),
            path_prefix: None,
            connect_timeout: Duration::from_secs(5),
            read_timeout: Duration::from_secs(30),
            counter: AtomicUsize::new(0),
        }
    }

    /// Restricts proxying to requests whose URI starts with `prefix`.
    pub fn path_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.path_prefix = Some(prefix.into());
        self
    }

    /// Selects the load-balancing strategy.
    ///
    /// Round-robin is the only strategy implemented, so the argument is
    /// currently accepted for API completeness and otherwise ignored.
    pub fn strategy(self, _strategy: LoadBalancing) -> Self {
        self
    }

    /// Sets the backend connect timeout in milliseconds; `0` disables it.
    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
        self.connect_timeout = Duration::from_millis(ms);
        self
    }

    /// Sets the backend read timeout in milliseconds; `0` disables it.
    pub fn read_timeout_ms(mut self, ms: u64) -> Self {
        self.read_timeout = Duration::from_millis(ms);
        self
    }

    /// Maps a zero duration to `None`.
    ///
    /// The standard library rejects zero durations for socket timeouts, so a
    /// configured `0` is interpreted as "no timeout" instead of failing every
    /// request.
    fn non_zero(timeout: Duration) -> Option<Duration> {
        if timeout == Duration::from_secs(0) {
            None
        } else {
            Some(timeout)
        }
    }

    /// Forwards `request` to the backends in round-robin order.
    ///
    /// Every backend is tried at most once, starting with the one selected by
    /// the shared counter. Returns the first successful response, or the error
    /// of the last backend tried.
    fn proxy(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
        let n = self.backends.len();
        if n == 0 {
            return Err("no backends configured".to_string());
        }
        // Reduce modulo `n` first so `start + attempt` cannot overflow even
        // after the counter has wrapped close to `usize::MAX`.
        let start = self.counter.fetch_add(1, Ordering::Relaxed) % n;
        let mut last_error = "all backends failed".to_string();
        for attempt in 0..n {
            let backend = &self.backends[(start + attempt) % n];
            match self.try_backend(request, connection, backend) {
                Ok(resp) => return Ok(resp),
                Err(e) => last_error = e,
            }
        }
        Err(last_error)
    }

    /// Sends `request` to a single backend and parses its response.
    ///
    /// Every address the backend name resolves to is tried in turn, so a
    /// dual-stack name still works when its first address is unreachable.
    /// Errors are human-readable strings describing the failing step.
    fn try_backend(
        &self,
        request: &Request,
        connection: &ConnectionInfo,
        backend: &Backend,
    ) -> Result<Response, String> {
        let addr_str = format!("{}:{}", backend.host, backend.port);
        let candidates = addr_str
            .to_socket_addrs()
            .map_err(|e| format!("DNS lookup for {} failed: {}", addr_str, e))?;

        // Stays in place only when the resolver returned no addresses at all.
        let mut failure = format!("no address resolved for {}", addr_str);
        let mut connected = None;
        for addr in candidates {
            let outcome = match Self::non_zero(self.connect_timeout) {
                Some(limit) => TcpStream::connect_timeout(&addr, limit),
                None => TcpStream::connect(addr),
            };
            match outcome {
                Ok(stream) => {
                    connected = Some(stream);
                    break;
                }
                Err(e) => failure = format!("connect to {} failed: {}", addr_str, e),
            }
        }
        let mut stream = connected.ok_or(failure)?;

        stream
            .set_read_timeout(Self::non_zero(self.read_timeout))
            .map_err(|e| e.to_string())?;
        stream
            .set_write_timeout(Some(Self::WRITE_TIMEOUT))
            .map_err(|e| e.to_string())?;

        // The Host header must carry the port unless it is the default one.
        let host_header: &str = if backend.port == Self::DEFAULT_HTTP_PORT {
            &backend.host
        } else {
            &addr_str
        };
        let req_bytes = build_request(request, host_header, &connection.client.ip);
        stream
            .write_all(&req_bytes)
            .map_err(|e| format!("write to backend failed: {}", e))?;
        let resp_bytes = read_response(&mut stream)?;
        Response::parse(&resp_bytes)
    }
}
impl Middleware for ReverseProxy {
    /// Proxies the request, or delegates to `next` when a path prefix is
    /// configured and the request URI does not start with it.
    ///
    /// A failed proxy attempt is never surfaced as an `Err`: it is answered
    /// with a `502 Bad Gateway` response instead.
    fn handle(
        &self,
        request: &Request,
        connection: &ConnectionInfo,
        next: &dyn Application,
    ) -> Result<Response, String> {
        let outside_prefix = self
            .path_prefix
            .as_deref()
            .map_or(false, |prefix| !request.request_uri.starts_with(prefix));
        if outside_prefix {
            return next.execute(request, connection);
        }

        let response = self
            .proxy(request, connection)
            .unwrap_or_else(|_| bad_gateway());
        Ok(response)
    }
}
/// Serializes `request` as an HTTP/1.1 message addressed to a backend.
///
/// * `backend_host` is written verbatim as the `Host` header value.
/// * `client_ip` is written as an `X-Forwarded-For` header.
///
/// The client's `Host` header and all hop-by-hop headers are dropped, and
/// `Via` and `Connection: close` are added. The client's `Content-Length` is
/// never copied: a single fresh one is emitted from the actual body length, so
/// the backend cannot receive duplicate or mismatching length headers.
fn build_request(request: &Request, backend_host: &str, client_ip: &str) -> Vec<u8> {
    let mut out: Vec<u8> = Vec::with_capacity(256 + request.body.len());
    // Writing into a `Vec<u8>` cannot fail, hence the discarded results below.
    let _ = write!(
        out,
        "{} {} HTTP/1.1\r\nHost: {}\r\n",
        request.method, request.request_uri, backend_host
    );

    let mut client_sent_length = false;
    for h in &request.headers {
        let lower = h.name.to_lowercase();
        if lower == "content-length" {
            // Re-emitted below from the real body size.
            client_sent_length = true;
            continue;
        }
        if lower == "host" || HOP_BY_HOP.contains(&lower.as_str()) {
            continue;
        }
        let _ = write!(out, "{}: {}\r\n", h.name, h.value);
    }

    let _ = write!(out, "X-Forwarded-For: {}\r\n", client_ip);
    let _ = write!(out, "Via: 1.1 rws\r\n");
    let _ = write!(out, "Connection: close\r\n");
    // Keep an explicit zero length when the client announced one, so that
    // body-less POST/PUT requests stay well-formed.
    if client_sent_length || !request.body.is_empty() {
        let _ = write!(out, "Content-Length: {}\r\n", request.body.len());
    }
    let _ = write!(out, "\r\n");
    out.extend_from_slice(&request.body);
    out
}
/// Reads one complete HTTP response from `stream` and returns its raw bytes.
///
/// The header section is read until the blank line that terminates it. If a
/// parsable `Content-Length` header is present, reading continues until that
/// many body bytes have arrived or the peer closes the connection; otherwise
/// everything up to end-of-stream is collected.
///
/// # Errors
///
/// Returns the I/O error text when a read fails, or a fixed message when the
/// backend closes the connection before sending any data. A connection closed
/// in the middle of the headers yields the partial bytes as `Ok`.
fn read_response(stream: &mut TcpStream) -> Result<Vec<u8>, String> {
    let mut buf: Vec<u8> = Vec::with_capacity(8192);
    let mut tmp = [0u8; 4096];

    // Number of bytes already searched for the header terminator.
    let mut scanned = 0usize;
    let header_end = loop {
        let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
        if n == 0 {
            return if buf.is_empty() {
                Err("backend closed connection without sending a response".to_string())
            } else {
                Ok(buf)
            };
        }
        buf.extend_from_slice(&tmp[..n]);
        // Back up three bytes so a terminator split across two reads is still
        // found, without rescanning the whole buffer after every read.
        let from = scanned.saturating_sub(3);
        if let Some(pos) = buf[from..].windows(4).position(|w| w == b"\r\n\r\n") {
            break from + pos + 4;
        }
        scanned = buf.len();
    };

    let content_length = std::str::from_utf8(&buf[..header_end])
        .ok()
        .and_then(|head| {
            head.lines().find_map(|line| {
                let mut parts = line.splitn(2, ':');
                let name = parts.next()?;
                let value = parts.next()?;
                if name.eq_ignore_ascii_case("content-length") {
                    value.trim().parse::<usize>().ok()
                } else {
                    None
                }
            })
        });

    // The length comes from the backend and is untrusted: saturate so an
    // absurd value cannot overflow the target size.
    let wanted = content_length.map(|len| header_end.saturating_add(len));
    while wanted.map_or(true, |total| buf.len() < total) {
        let n = stream.read(&mut tmp).map_err(|e| e.to_string())?;
        if n == 0 {
            break;
        }
        buf.extend_from_slice(&tmp[..n]);
    }
    Ok(buf)
}
/// Builds the `502 Bad Gateway` response returned when no backend could serve
/// the request. The body is the plain-text string `502 Bad Gateway`.
fn bad_gateway() -> Response {
    let status = &STATUS_CODE_REASON_PHRASE.n502_bad_gateway;
    let body = Range::get_content_range(
        b"502 Bad Gateway".to_vec(),
        MimeType::TEXT_PLAIN.to_string(),
    );

    let mut response = Response::new();
    response.content_range_list = vec![body];
    response.reason_phrase = status.reason_phrase.to_string();
    response.status_code = *status.status_code;
    response
}
/// Network location of a backend server.
struct Backend {
    /// Host name or IP literal; IPv6 literals keep their square brackets.
    host: String,
    /// TCP port to connect to.
    port: u16,
}

impl Backend {
    /// Port used when the URL does not name one.
    const DEFAULT_PORT: u16 = 80;

    /// Parses a backend URL such as `http://host:8080/path` or `host:8080`.
    ///
    /// An optional `http://` or `https://` scheme is stripped and otherwise
    /// ignored: the scheme does not influence the default port, and nothing in
    /// this type enables TLS. Anything after the authority (path, query,
    /// fragment) is discarded. A missing or empty port means port 80.
    ///
    /// Returns `None` when the host is empty, when the port is not a valid
    /// `u16`, or when the authority is malformed (for example an IPv6 literal
    /// without square brackets).
    fn parse(url: &str) -> Option<Self> {
        let rest = url
            .strip_prefix("https://")
            .or_else(|| url.strip_prefix("http://"))
            .unwrap_or(url);
        let authority = rest
            .split(|c: char| c == '/' || c == '?' || c == '#')
            .next()
            .unwrap_or(rest);

        let bracketed = authority.starts_with('[');
        let (host, port_text) = if bracketed {
            // `[v6-literal]` optionally followed by `:port`.
            let close = authority.find(']')?;
            let (literal, tail) = authority.split_at(close + 1);
            let port_text = if tail.is_empty() {
                None
            } else {
                Some(tail.strip_prefix(':')?)
            };
            (literal, port_text)
        } else {
            match authority.rfind(':') {
                Some(colon) => (&authority[..colon], Some(&authority[colon + 1..])),
                None => (authority, None),
            }
        };

        // Reject `""`, `[]`, and hosts with stray colons such as a bare `::1`.
        let malformed_host = if bracketed {
            host.len() <= 2
        } else {
            host.is_empty() || host.contains(':')
        };
        if malformed_host {
            return None;
        }

        let port = match port_text {
            None | Some("") => Self::DEFAULT_PORT,
            Some(text) => text.parse::<u16>().ok()?,
        };
        Some(Backend {
            host: host.to_string(),
            port,
        })
    }
}