use std::path::PathBuf;
use std::{
error::Error as StdError,
fs::File,
io::{BufReader, Error as IoError, ErrorKind as IoErrorKind},
net::SocketAddr,
path::Path,
str,
sync::Arc,
};
use super::{
chain::{Handler, Middleware},
svc::Svc,
};
use crate::{
GATEWAY_LOGGER_DOMAIN, debug_log, error_log,
gateway::{
context::Context,
response::{BoxBodyType, ResponseBuilder},
},
info_log, warn_log,
};
use hyper::{Response, StatusCode, body::Incoming, server::conn::http1};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto as hyper_conn_auto,
};
use rustls::{ServerConfig, crypto::aws_lc_rs};
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
/// HTTP(S) gateway server configuration: listen address, request handler,
/// middleware list and optional TLS file locations.
pub struct Gateway {
/// Listen address exactly as passed to `new`; parsed into a `SocketAddr` by `listen`.
addr: String,
/// Request handler; when `None`, `listen` substitutes `default_handler()`.
handler: Option<Handler>,
/// Middlewares in registration order; moved out of the struct by `listen`.
middlewares: Vec<Box<dyn Middleware>>,
/// Certificate file path; set by `with_tls` only when both files exist.
cert_path: Option<String>,
/// Private-key file path; set by `with_tls` together with `cert_path`.
key_path: Option<String>,
}
impl Gateway {
pub fn new(addr: &str) -> Self {
Self {
addr: addr.to_string(),
handler: None,
middlewares: Vec::new(),
cert_path: None,
key_path: None,
}
}
/// Configures the certificate and private-key files used for TLS.
///
/// TLS paths are recorded only when both `cert_path` and `key_path` are
/// given and both files exist at call time. In every other case the gateway
/// is returned without TLS paths:
///
/// * both given but at least one file missing: a warning is logged;
/// * only one of the two given: a warning is logged, so a half-specified
///   configuration does not silently result in a plain-HTTP server;
/// * neither given: nothing is logged.
///
/// The paths are stored as lossily converted UTF-8 strings.
pub fn with_tls(
    mut self,
    cert_path: Option<PathBuf>,
    key_path: Option<PathBuf>,
) -> Self {
    match (cert_path, key_path) {
        (Some(cert), Some(key)) => {
            if cert.exists() && key.exists() {
                debug_log!(
                    GATEWAY_LOGGER_DOMAIN,
                    "SSL certificate exist, start loading cert_path={:?}, key_path={:?}",
                    cert,
                    key
                );
                self.cert_path = Some(cert.to_string_lossy().into_owned());
                self.key_path = Some(key.to_string_lossy().into_owned());
            } else {
                warn_log!(
                    GATEWAY_LOGGER_DOMAIN,
                    "SSL certificate does not exist: cert_path={:?}, key_path={:?}",
                    cert,
                    key
                );
            }
        }
        // TLS was not requested at all: nothing to report.
        (None, None) => {}
        // Exactly one of the two paths was supplied; TLS cannot be enabled,
        // so make the misconfiguration visible instead of ignoring it.
        (cert, key) => {
            warn_log!(
                GATEWAY_LOGGER_DOMAIN,
                "Incomplete TLS configuration ignored, both certificate and key are required: cert_path={:?}, key_path={:?}",
                cert,
                key
            );
        }
    }
    self
}
/// Appends `middleware` to the end of the middleware list and returns the
/// gateway, so registrations can be chained builder-style.
pub fn add_middleware(mut self, middleware: Box<dyn Middleware>) -> Self {
    let registered = &mut self.middlewares;
    registered.push(middleware);
    self
}
/// Sets the request handler, discarding any handler set previously.
pub fn set_handler(&mut self, handler: Handler) {
    // `replace` hands back the previous handler (if any); drop it right away.
    drop(self.handler.replace(handler));
}
/// Installs the aws-lc-rs crypto provider as the rustls default.
///
/// # Errors
///
/// Returns a string error containing the debug representation of the value
/// returned by `install_default` when installation fails.
pub fn setup_crypto_provider() -> Result<(), Box<dyn StdError + Send + Sync>> {
    match aws_lc_rs::default_provider().install_default() {
        Ok(()) => Ok(()),
        Err(e) => Err(format!(
            "Failed to install rustls crypto provider: {e:?}"
        )
        .into()),
    }
}
/// Binds the configured address and serves connections until an error occurs.
///
/// The stored handler is used when one was set, otherwise the default
/// handler is used. The registered middlewares are moved out of `self` and
/// shared with the server, so `self` holds an empty middleware list once
/// serving has started.
///
/// # Errors
///
/// Returns an error when the address cannot be parsed as a `SocketAddr`,
/// when binding fails, or when the server loop itself returns an error.
pub async fn listen(
    &mut self,
) -> Result<(), Box<dyn StdError + Send + Sync>> {
    let socket_addr = self.addr.parse::<SocketAddr>()?;
    let listener = TcpListener::bind(&socket_addr).await?;
    let handler = match &self.handler {
        Some(configured) => configured.clone(),
        None => Self::default_handler(),
    };
    let drained = std::mem::take(&mut self.middlewares);
    self.run_server(listener, handler, Arc::new(drained)).await
}
/// Chooses between the HTTPS and the plain-HTTP accept loop and runs it.
///
/// * No TLS paths configured: serves plain HTTP.
/// * TLS paths configured and the TLS config loads: serves HTTPS.
/// * TLS paths configured but loading fails: logs a warning and falls back
///   to plain HTTP.
///
/// # Errors
///
/// Returns an error when the listener's local address cannot be read or
/// when the selected accept loop returns an error.
async fn run_server(
    &self,
    listener: TcpListener,
    handler: Handler,
    middlewares: Arc<Vec<Box<dyn Middleware>>>,
) -> Result<(), Box<dyn StdError + Send + Sync>> {
    let addr = listener.local_addr()?;

    // Guard: without both paths there is nothing to load.
    let (Some(cert_path), Some(key_path)) = (&self.cert_path, &self.key_path)
    else {
        return self
            .run_http_server(&addr, listener, handler, middlewares)
            .await;
    };

    let loaded =
        self.load_tls_config(Path::new(cert_path), Path::new(key_path));
    let tls_config = match loaded {
        Ok(config) => config,
        Err(e) => {
            warn_log!(
                GATEWAY_LOGGER_DOMAIN,
                "Failed to load TLS config: {}. Falling back to plain HTTP/1.1.",
                e
            );
            return self
                .run_http_server(&addr, listener, handler, middlewares)
                .await;
        }
    };

    let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));
    self.run_https_server(&addr, listener, handler, middlewares, tls_acceptor)
        .await
}
async fn run_http_server(
&self,
addr: &SocketAddr,
listener: TcpListener,
handler: Handler,
middlewares: Arc<Vec<Box<dyn Middleware>>>,
) -> Result<(), Box<dyn StdError + Send + Sync>> {
info_log!(
GATEWAY_LOGGER_DOMAIN,
"Gateway listening with HTTP/1.1 on addr {}",
addr
);
loop {
let (stream, peer_addr) = listener.accept().await?;
let service = Svc::new(handler.clone(), middlewares.clone());
tokio::spawn(async move {
let io = TokioIo::new(stream);
if let Err(err) =
http1::Builder::new().serve_connection(io, service).await
{
if !Self::is_ignorable_connection_error(&err) {
error_log!(
GATEWAY_LOGGER_DOMAIN,
"Error serving HTTP connection from {}: {:?}",
peer_addr,
err
);
}
}
});
}
}
async fn run_https_server(
&self,
addr: &SocketAddr,
listener: TcpListener,
handler: Handler,
middlewares: Arc<Vec<Box<dyn Middleware>>>,
tls_acceptor: TlsAcceptor,
) -> Result<(), Box<dyn StdError + Send + Sync>> {
info_log!(
GATEWAY_LOGGER_DOMAIN,
"Gateway listening with TLS (H2/H1) on addr {}",
addr
);
loop {
let (stream, peer_addr) = listener.accept().await?;
debug_log!(
GATEWAY_LOGGER_DOMAIN,
"Incoming TCP connection from {}",
peer_addr
);
let tls_acceptor = tls_acceptor.clone();
let service = Svc::new(handler.clone(), middlewares.clone());
tokio::spawn(async move {
debug_log!(
GATEWAY_LOGGER_DOMAIN,
"Shake hands for {}",
peer_addr
);
match tls_acceptor.accept(stream).await {
Ok(tls_stream) => {
let (_, conn) = tls_stream.get_ref();
let alpn_protocol =
conn.alpn_protocol().map_or("None", |p| {
str::from_utf8(p).unwrap_or("Invalid UTF-8")
});
debug_log!(
GATEWAY_LOGGER_DOMAIN,
"TLS shake hands for {} success。ALPN protocol: {}",
peer_addr,
alpn_protocol
);
let io = TokioIo::new(tls_stream);
debug_log!(
GATEWAY_LOGGER_DOMAIN,
"Handling the connection from {} over to Hyper",
peer_addr
);
if let Err(err) =
hyper_conn_auto::Builder::new(TokioExecutor::new())
.serve_connection(io, service)
.await
{
if !Self::is_ignorable_connection_error(
err.as_ref(),
) {
error_log!(
GATEWAY_LOGGER_DOMAIN,
"Error occurred while processing HTTPS connection from {}: {:?}",
peer_addr,
err
);
}
}
}
Err(e) => {
error_log!(
GATEWAY_LOGGER_DOMAIN,
"TLS shake hands error for {}: {:?}",
peer_addr,
e
);
}
}
});
}
}
/// Builds a rustls `ServerConfig` from PEM files.
///
/// Reads every certificate from `cert_path` and one private key from
/// `key_path`, disables client authentication, and advertises `h2` and
/// `http/1.1` through ALPN.
///
/// # Errors
///
/// Returns an error naming the offending file when a file cannot be opened
/// or parsed, when the certificate file contains no certificate, when the
/// key file contains no private key, or when rustls rejects the
/// certificate/key pair.
fn load_tls_config(
    &self,
    cert_path: &Path,
    key_path: &Path,
) -> Result<ServerConfig, Box<dyn StdError + Send + Sync>> {
    let cert_file = File::open(cert_path).map_err(|e| {
        format!("failed to open cert file {cert_path:?}: {e}")
    })?;
    let certs = rustls_pemfile::certs(&mut BufReader::new(cert_file))
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| {
            format!("failed to parse cert file {cert_path:?}: {e}")
        })?;
    // A PEM file without any CERTIFICATE section parses "successfully" into
    // an empty chain; reject it here with a message that names the file.
    if certs.is_empty() {
        return Err(
            format!("no certificates found in file {cert_path:?}").into()
        );
    }

    let key_file = File::open(key_path).map_err(|e| {
        format!("failed to open key file {key_path:?}: {e}")
    })?;
    let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
        .map_err(|e| {
            format!("failed to parse key file {key_path:?}: {e}")
        })?
        .ok_or_else(|| {
            format!("no private key found in file {key_path:?}")
        })?;

    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|e| {
            format!(
                "invalid certificate/key pair (cert {cert_path:?}, key {key_path:?}): {e}"
            )
        })?;
    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    Ok(config)
}
/// Returns the handler used when none was set: it ignores the context and
/// the body and answers every request with `500 Internal Server Error`.
fn default_handler() -> Handler {
    let fallback =
        |_ctx: Context, _body: Option<Incoming>| -> Response<BoxBodyType> {
            let status = StatusCode::INTERNAL_SERVER_ERROR;
            ResponseBuilder::with_status_code(status)
        };
    Arc::new(fallback)
}
fn is_ignorable_connection_error(err: &(dyn StdError + 'static)) -> bool {
let mut source = Some(err);
while let Some(current_err) = source {
if let Some(io_err) = current_err.downcast_ref::<IoError>() {
if matches!(
io_err.kind(),
IoErrorKind::ConnectionReset | IoErrorKind::BrokenPipe
) {
return true;
}
}
if let Some(hyper_err) = current_err.downcast_ref::<hyper::Error>()
{
if hyper_err.is_canceled() || hyper_err.is_closed() {
return true;
}
}
source = current_err.source();
}
false
}
}