use std::{io, net::SocketAddr, str::FromStr, sync::Arc};
use hyper::{
header::{GetAll, HeaderValue},
http::uri::{Authority, Scheme},
upgrade,
Body,
HeaderMap,
Method,
Request,
Response,
StatusCode,
Uri,
Version,
};
use log::{debug, error, trace};
use shadowsocks::relay::socks5::Address;
use crate::local::{
context::ServiceContext,
loadbalancing::PingBalancer,
net::{AutoProxyClientStream, AutoProxyIo},
utils::{establish_tcp_tunnel, establish_tcp_tunnel_bypassed},
};
use super::{
client_cache::ProxyClientCache,
http_client::{BypassHttpClient, HttpClientEnum},
utils::{authority_addr, host_addr},
};
pub struct HttpDispatcher {
context: Arc<ServiceContext>,
req: Request<Body>,
balancer: PingBalancer,
client_addr: SocketAddr,
bypass_client: BypassHttpClient,
proxy_client_cache: Arc<ProxyClientCache>,
}
impl HttpDispatcher {
pub fn new(
context: Arc<ServiceContext>,
req: Request<Body>,
balancer: PingBalancer,
client_addr: SocketAddr,
bypass_client: BypassHttpClient,
proxy_client_cache: Arc<ProxyClientCache>,
) -> HttpDispatcher {
HttpDispatcher {
context,
req,
balancer,
client_addr,
bypass_client,
proxy_client_cache,
}
}
pub async fn dispatch(mut self) -> io::Result<Response<Body>> {
trace!("request {} {:?}", self.client_addr, self.req);
let host = match host_addr(self.req.uri()) {
None => {
if self.req.uri().authority().is_some() {
error!(
"HTTP {} URI {} doesn't have a valid host",
self.req.method(),
self.req.uri()
);
return make_bad_request();
} else {
trace!(
"HTTP {} URI {} doesn't have a valid host",
self.req.method(),
self.req.uri()
);
}
match get_addr_from_header(&mut self.req) {
Ok(h) => h,
Err(()) => return make_bad_request(),
}
}
Some(h) => h,
};
if Method::CONNECT == self.req.method() {
debug!("HTTP CONNECT {}", host);
let mut server_opt = None;
let mut stream = if self.balancer.is_empty() {
AutoProxyClientStream::connect_bypassed(self.context, &host).await?
} else {
let server = self.balancer.best_tcp_server();
let stream = AutoProxyClientStream::connect(self.context, server.as_ref(), &host).await?;
server_opt = Some(server);
stream
};
debug!(
"CONNECT relay connected {} <-> {} ({})",
self.client_addr,
host,
if stream.is_bypassed() { "bypassed" } else { "proxied" }
);
let req = self.req;
let client_addr = self.client_addr;
tokio::spawn(async move {
match upgrade::on(req).await {
Ok(mut upgraded) => {
trace!("CONNECT tunnel upgrade success, {} <-> {}", client_addr, host);
let _ = match server_opt {
Some(server) => {
establish_tcp_tunnel(
server.server_config(),
&mut upgraded,
&mut stream,
client_addr,
&host,
)
.await
}
None => establish_tcp_tunnel_bypassed(&mut upgraded, &mut stream, client_addr, &host).await,
};
}
Err(e) => {
error!(
"failed to upgrade TCP tunnel {} <-> {}, error: {}",
client_addr, host, e
);
}
}
});
let resp = Response::builder().body(Body::empty()).unwrap();
Ok(resp)
} else {
let method = self.req.method().clone();
let version = self.req.version();
debug!("HTTP {} {} {:?}", method, host, version);
let conn_keep_alive = check_keep_alive(version, self.req.headers(), true);
clear_hop_headers(self.req.headers_mut());
set_conn_keep_alive(version, self.req.headers_mut(), conn_keep_alive);
let client = if self.balancer.is_empty() || self.context.check_target_bypassed(&host).await {
trace!("bypassed {} -> {} {:?}", self.client_addr, host, self.req);
HttpClientEnum::Bypass(self.bypass_client)
} else {
trace!("proxied {} -> {} {:?}", self.client_addr, host, self.req);
let server = self.balancer.best_tcp_server();
HttpClientEnum::Proxy(self.proxy_client_cache.get_connected(&server).await)
};
let mut res = match client.send(self.req).await {
Ok(res) => res,
Err(err) => {
error!(
"HTTP {} {} <-> {} relay failed, error: {}",
method, self.client_addr, host, err
);
let mut resp = Response::new(Body::from(format!("relay failed to {host}")));
*resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(resp);
}
};
trace!("received {} <- {} {:?}", self.client_addr, host, res);
let res_keep_alive = conn_keep_alive && check_keep_alive(res.version(), res.headers(), false);
clear_hop_headers(res.headers_mut());
if res.version() != version {
trace!("response version {:?} => {:?}", res.version(), version);
*res.version_mut() = version;
}
set_conn_keep_alive(res.version(), res.headers_mut(), res_keep_alive);
trace!("response {} <- {} {:?}", self.client_addr, host, res);
debug!("HTTP {} relay {} <-> {} finished", method, self.client_addr, host);
Ok(res)
}
}
}
fn make_bad_request() -> io::Result<Response<Body>> {
let mut resp = Response::new(Body::empty());
*resp.status_mut() = StatusCode::BAD_REQUEST;
Ok(resp)
}
fn get_keep_alive_val(values: GetAll<HeaderValue>) -> Option<bool> {
let mut conn_keep_alive = None;
for value in values {
if let Ok(value) = value.to_str() {
if value.eq_ignore_ascii_case("close") {
conn_keep_alive = Some(false);
} else {
for part in value.split(',') {
let part = part.trim();
if part.eq_ignore_ascii_case("keep-alive") {
conn_keep_alive = Some(true);
break;
}
}
}
}
}
conn_keep_alive
}
fn check_keep_alive(version: Version, headers: &HeaderMap<HeaderValue>, check_proxy: bool) -> bool {
let mut conn_keep_alive = !matches!(version, Version::HTTP_09 | Version::HTTP_10);
if check_proxy {
if let Some(b) = get_keep_alive_val(headers.get_all("Proxy-Connection")) {
conn_keep_alive = b
}
}
if let Some(b) = get_keep_alive_val(headers.get_all("Connection")) {
conn_keep_alive = b
}
conn_keep_alive
}
fn get_extra_headers(headers: GetAll<HeaderValue>) -> Vec<String> {
let mut extra_headers = Vec::new();
for connection in headers {
if let Ok(conn) = connection.to_str() {
if conn.eq_ignore_ascii_case("close") {
continue;
}
for header in conn.split(',') {
let header = header.trim();
extra_headers.push(header.to_owned());
}
}
}
extra_headers
}
fn clear_hop_headers(headers: &mut HeaderMap<HeaderValue>) {
let mut extra_headers = get_extra_headers(headers.get_all("Connection"));
extra_headers.extend(get_extra_headers(headers.get_all("Proxy-Connection")));
for header in extra_headers {
while let Some(..) = headers.remove(&header) {}
}
const HOP_BY_HOP_HEADERS: [&str; 9] = [
"Keep-Alive",
"Transfer-Encoding",
"TE",
"Connection",
"Trailer",
"Upgrade",
"Proxy-Authorization",
"Proxy-Authenticate",
"Proxy-Connection", ];
for header in &HOP_BY_HOP_HEADERS {
while let Some(..) = headers.remove(*header) {}
}
}
fn set_conn_keep_alive(version: Version, headers: &mut HeaderMap<HeaderValue>, keep_alive: bool) {
match version {
Version::HTTP_09 | Version::HTTP_10 => {
if keep_alive {
headers.insert("Connection", HeaderValue::from_static("keep-alive"));
}
}
_ => {
if !keep_alive {
headers.insert("Connection", HeaderValue::from_static("close"));
}
}
}
}
fn get_addr_from_header(req: &mut Request<Body>) -> Result<Address, ()> {
match req.headers().get("Host") {
Some(hhost) => match hhost.to_str() {
Ok(shost) => {
match Authority::from_str(shost) {
Ok(authority) => match authority_addr(req.uri().scheme_str(), &authority) {
Some(host) => {
trace!("HTTP {} URI {} got host from header: {}", req.method(), req.uri(), host);
let mut parts = req.uri().clone().into_parts();
if parts.scheme.is_none() {
parts.scheme = Some(Scheme::HTTP);
}
parts.authority = Some(authority);
*req.uri_mut() = Uri::from_parts(parts).expect("Reassemble URI failed");
debug!("reassembled URI from \"Host\", {}", req.uri());
Ok(host)
}
None => {
error!(
"HTTP {} URI {} \"Host\" header invalid, value: {}",
req.method(),
req.uri(),
shost
);
Err(())
}
},
Err(..) => {
error!(
"HTTP {} URI {} \"Host\" header is not an Authority, value: {:?}",
req.method(),
req.uri(),
hhost
);
Err(())
}
}
}
Err(..) => {
error!(
"HTTP {} URI {} \"Host\" header invalid encoding, value: {:?}",
req.method(),
req.uri(),
hhost
);
Err(())
}
},
None => {
error!(
"HTTP {} URI doesn't have valid host and missing the \"Host\" header, URI: {}",
req.method(),
req.uri()
);
Err(())
}
}
}