pub mod config;
mod conn_pool;
pub mod error;
mod handle;
#[macro_use]
pub mod log;
mod cache;
#[cfg(feature = "tls")]
pub mod tls;
#[cfg_attr(feature = "logging", macro_use(info, error, debug, trace))]
#[cfg(feature = "logging")]
extern crate tracing;
use std::net::SocketAddr;
use std::sync::Arc;
use std::task::ready;
use cache::init_caches;
use conn_pool::init_conn_pools;
use hyper::body::Incoming;
use hyper::server;
use hyper::service::service_fn;
use hyper::Request;
#[cfg(feature = "tls")]
use rustls::ServerConfig;
#[cfg(feature = "tls")]
use tls::stream::TlsStream;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
pub use config::{CacheSettings, Config, Rule};
pub use error::Error;
/// The Motorx proxy server.
///
/// `Server` implements [`Future`](std::future::Future): awaiting it runs the accept
/// loop, which spawns one task per accepted connection.
#[must_use = "Server does nothing unless it is `.await`ed"]
pub struct Server {
    /// Shared proxy configuration; an `Arc` clone is handed to every connection.
    config: Arc<Config>,
    /// Socket on which incoming connections are accepted.
    listener: TcpListener,
    /// Bounds the number of concurrently served connections: a permit is taken
    /// before each accept and held by the connection's task until it finishes.
    /// The constructors size it from `config.max_connections`.
    semaphore: Arc<Semaphore>,
    /// TLS settings. `Some` wraps every accepted stream in TLS; `None` serves plain TCP.
    #[cfg(feature = "tls")]
    tls_config: Option<Arc<ServerConfig>>,
}
impl Server {
    /// Setup shared by [`Server::new`] and [`Server::new_tls`].
    ///
    /// Initialises the connection pools and caches from `config`, sorts the
    /// routing rules by path, and binds the listener on `config.addr`.
    ///
    /// # Panics
    /// Panics if the listener cannot be bound to `config.addr`.
    fn common_config(mut config: Config) -> (Arc<Config>, TcpListener) {
        init_conn_pools(&config);
        init_caches(&config);
        config.rules.sort_by(|a, b| a.path.cmp(&b.path));
        let config = Arc::new(config);
        cfg_logging! {debug!("Starting with config: {:#?}", *config);}
        let listener = tcp_listener(config.addr)
            .unwrap_or_else(|err| panic!("Failed to bind listener to {}: {err}", config.addr));
        (config, listener)
    }

    /// Address shown in the startup log line.
    ///
    /// On WASI the configured address is reported without querying the socket.
    /// Elsewhere the bound address is preferred (so a `:0` port shows the port
    /// actually assigned). If that lookup fails, the configured address is used
    /// rather than panicking inside a log statement.
    #[cfg_attr(not(feature = "logging"), allow(dead_code))]
    fn display_addr(config: &Config, listener: &TcpListener) -> SocketAddr {
        if cfg!(target_os = "wasi") {
            config.addr
        } else {
            listener.local_addr().unwrap_or(config.addr)
        }
    }

    /// Creates a plain-HTTP server listening on `config.addr`.
    ///
    /// At most `config.max_connections` connections are served concurrently.
    ///
    /// # Panics
    /// Panics if the listener cannot be bound.
    pub fn new(config: Config) -> Self {
        let (config, listener) = Self::common_config(config);
        cfg_logging! {
            info!("Motorx proxy listening on http://{}", Self::display_addr(&config, &listener));
        }
        Self {
            semaphore: Arc::new(Semaphore::new(config.max_connections)),
            config,
            listener,
            #[cfg(feature = "tls")]
            tls_config: None,
        }
    }

    /// Creates an HTTPS server listening on `config.addr`.
    ///
    /// The certificate chain and private key are loaded from the paths given in
    /// `config.certs` and `config.private_key`.
    ///
    /// # Panics
    /// Panics if any of the following holds:
    /// - `certs` or `private_key` is missing from the config;
    /// - either one fails to load;
    /// - they do not form a valid certificate/key pair;
    /// - the listener cannot be bound.
    #[cfg(feature = "tls")]
    pub fn new_tls(config: Config) -> Self {
        let (config, listener) = Self::common_config(config);
        let tls_config = {
            let certs = tls::load_certs(
                config
                    .certs
                    .as_ref()
                    .expect("Must provide `certs` in config to use tls."),
            )
            .expect("Failed to load TLS certificates");
            let key = tls::load_private_key(
                config
                    .private_key
                    .as_ref()
                    .expect("Must provide `private_key` in config to use tls."),
            )
            .expect("Failed to load TLS private key");
            let mut cfg = rustls::ServerConfig::builder()
                .with_safe_defaults()
                .with_no_client_auth()
                .with_single_cert(certs, key)
                .expect("Invalid TLS certificate/private key pair");
            // Every connection is served by hyper's HTTP/1 server, so only
            // advertise HTTP/1.1. Offering `h2` would let clients negotiate
            // HTTP/2 and then send a preface the HTTP/1 server cannot parse.
            cfg.alpn_protocols = vec![b"http/1.1".to_vec()];
            Arc::new(cfg)
        };
        cfg_logging! {
            info!("Motorx proxy listening on https://{}", Self::display_addr(&config, &listener));
        }
        Self {
            semaphore: Arc::new(Semaphore::new(config.max_connections)),
            config,
            listener,
            tls_config: Some(tls_config),
        }
    }

    /// Returns the local address the listener is bound to.
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        self.listener.local_addr()
    }
}
/// Running the server: each poll accepts as many ready connections as permits
/// allow and spawns a task to serve each one. The future never resolves on its
/// own; it only ever returns `Poll::Pending`.
impl std::future::Future for Server {
    type Output = Result<(), hyper::Error>;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        loop {
            // Reserve a connection slot *before* accepting. Once the limit is
            // reached, new clients then wait in the listen backlog. If the
            // accept below is not ready, the permit is dropped (released) on return.
            let permit = match Arc::clone(&self.semaphore).try_acquire_owned() {
                Ok(permit) => permit,
                Err(_) => {
                    // Every permit is held by a connection task, and only those
                    // tasks can release one. Looping here would never yield to the
                    // executor (hanging a current-thread runtime outright), and
                    // `Semaphore` has no poll-based acquire that could register our
                    // waker. So request a re-poll and yield, letting those tasks run.
                    // NOTE: this still re-polls while saturated; `tokio_util`'s
                    // `PollSemaphore` would allow sleeping until a permit is freed.
                    cx.waker().wake_by_ref();
                    return std::task::Poll::Pending;
                }
            };
            match ready!(self.listener.poll_accept(cx)) {
                Ok((stream, peer_addr)) => {
                    cfg_logging! {
                        trace!("Accepted connection from {}", peer_addr);
                    }
                    #[cfg(feature = "tls")]
                    if let Some(tls_config) = self.tls_config.as_ref() {
                        let tls_stream = TlsStream::new(stream, Arc::clone(tls_config));
                        handle_connection(
                            tls_stream,
                            peer_addr,
                            Arc::clone(&self.config),
                            permit,
                        )
                    } else {
                        handle_connection(stream, peer_addr, Arc::clone(&self.config), permit)
                    };
                    #[cfg(not(feature = "tls"))]
                    handle_connection(stream, peer_addr, Arc::clone(&self.config), permit);
                }
                Err(e) => {
                    // Accept errors are logged and the loop retries. The permit
                    // taken above is dropped at the end of this iteration.
                    cfg_logging! {
                        error!("Error connecting, {:?}", e);
                    }
                }
            }
        }
    }
}
/// Serves one accepted connection on a freshly spawned task.
///
/// The connection is driven by hyper's HTTP/1 server with the following options:
/// - header case preserved;
/// - title-case headers;
/// - keep-alive;
/// - upgrades.
///
/// Every request is routed to `handle::handle_req` together with the peer
/// address and the shared config. `permit` is held by the spawned task until the
/// connection ends, which is what bounds the number of concurrent connections.
#[cfg_attr(feature = "logging", tracing::instrument(skip(stream, config)))]
fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
    stream: S,
    peer_addr: SocketAddr,
    config: Arc<Config>,
    permit: OwnedSemaphorePermit,
) {
    let service = service_fn(move |req: Request<Incoming>| {
        handle::handle_req(req, peer_addr, Arc::clone(&config))
    });
    let connection = async move {
        if let Err(err) = server::conn::http1::Builder::new()
            .http1_preserve_header_case(true)
            .http1_title_case_headers(true)
            .http1_keep_alive(true)
            .serve_connection(stream, service)
            .with_upgrades()
            .await
        {
            cfg_logging! {trace!("Error handling connection: {err:?}");}
        }
        cfg_logging! {
            trace!("Closing connection to {}", peer_addr);
        }
        // Naming `permit` here moves it into the task, so the slot stays
        // reserved until the connection has finished. Without this, it would be
        // released as soon as `handle_connection` returned.
        drop(permit);
    };
    // `#[instrument]` only spans this synchronous function, which returns
    // immediately. Attach the same span to the spawned task so the events it
    // emits are recorded inside the connection's span.
    #[cfg(feature = "logging")]
    let connection = tracing::Instrument::instrument(connection, tracing::Span::current());
    tokio::spawn(connection);
}
/// Binds a TCP listener on `addr` and registers it with the Tokio runtime.
///
/// `tokio::net::TcpListener::from_std` leaves the socket's blocking mode
/// untouched and requires it to be non-blocking. Otherwise the runtime's
/// `accept` calls block the worker thread once no connection is pending.
///
/// # Errors
/// Returns the I/O error from binding, switching to non-blocking mode, or
/// registering the socket with the runtime.
#[cfg(not(target_os = "wasi"))]
#[inline]
fn tcp_listener(addr: SocketAddr) -> std::io::Result<tokio::net::TcpListener> {
    let listener = std::net::TcpListener::bind(addr)?;
    listener.set_nonblocking(true)?;
    tokio::net::TcpListener::from_std(listener)
}
/// WASI variant: binds through `wasmedge_wasi_socket` and hands the socket to Tokio.
///
/// # Errors
/// Returns the I/O error from binding or from registering with the runtime.
#[cfg(target_os = "wasi")]
#[inline]
fn tcp_listener(addr: SocketAddr) -> std::io::Result<tokio::net::TcpListener> {
    // NOTE(review): the `true` argument is presumably the non-blocking flag
    // (Tokio needs a non-blocking socket); confirm against wasmedge_wasi_socket.
    let socket = wasmedge_wasi_socket::TcpListener::bind(addr, true)?;
    tokio::net::TcpListener::from_std(socket)
}
/// Opens a TCP connection to `addr`.
///
/// `addr` is rendered with `to_string()` (e.g. `"host:port"`) and resolved
/// through Tokio's `ToSocketAddrs` implementation for `String`.
///
/// # Errors
/// Returns the I/O error from address resolution or from connecting.
// The WASI and non-WASI builds previously had two cfg-gated copies of this
// function with identical bodies; a single definition serves both targets.
async fn tcp_connect(addr: impl ToString) -> std::io::Result<tokio::net::TcpStream> {
    tokio::net::TcpStream::connect(addr.to_string()).await
}