pub mod config;
mod conn_pool;
pub mod error;
mod handle;
#[macro_use]
pub mod log;
mod cache;
#[cfg(test)]
mod e2e;
#[cfg(feature = "tls")]
pub mod tls;
#[cfg_attr(feature = "logging", macro_use(info, error, debug, trace))]
#[cfg(feature = "logging")]
extern crate tracing;
use std::net::SocketAddr;
use std::sync::Arc;
use cache::Cache;
use config::Upstream;
use conn_pool::ConnPool;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper::Request;
use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(feature = "tls")]
use rustls::ServerConfig;
#[cfg(feature = "tls")]
use tls::stream::TlsStream;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
pub use config::{CacheSettings, Config, Rule};
pub use error::Error;
/// An upstream's shared configuration paired with the connection pool used to reach it.
type UpstreamAndConnPool = (Arc<Upstream>, ConnPool);
/// Every configured upstream. `init_upstreams` stores each upstream's position
/// in this vector as its `key`, and rules refer to their upstream by that index
/// (`rule.upstream_key`).
type Upstreams = Vec<UpstreamAndConnPool>;
/// The Motorx proxy server.
///
/// Build it with [`Server::new`] (plain HTTP) or `Server::new_tls` (requires the
/// `tls` feature), then drive it with [`Server::run`].
pub struct Server {
/// Configuration shared (read-only) with every connection task.
config: Arc<Config>,
/// Cache built from the config by `Cache::from_config`; shared across connections.
cache: Arc<Cache>,
/// Upstreams and their connection pools, indexed by upstream key.
upstreams: Arc<Upstreams>,
/// Socket on which client connections are accepted.
listener: TcpListener,
/// Holds `config.max_connections` permits; one is held per live connection.
semaphore: Arc<Semaphore>,
/// rustls server configuration; `Some` only when built via `new_tls`.
#[cfg(feature = "tls")]
tls_config: Option<Arc<ServerConfig>>,
}
impl Server {
fn common_config(mut config: Config) -> (Arc<Config>, Arc<Cache>, Arc<Upstreams>, TcpListener) {
let upstreams = Arc::new(init_upstreams(&mut config));
let cache = Arc::new(Cache::from_config(&mut config));
config.rules.sort_by(|a, b| a.path.cmp(&b.path));
let config = Arc::new(config);
cfg_logging! {debug!("Starting with config: {:#?}", *config);}
let listener = tcp_listener(config.addr).unwrap();
(config, cache, upstreams, listener)
}
pub fn new(config: Config) -> Self {
let (config, cache, conn_pools, listener) = Self::common_config(config);
cfg_logging! {
info!("Motorx proxy listening on http://{}", {
listener.local_addr().unwrap()
});
}
Self {
semaphore: Arc::new(Semaphore::new(config.max_connections)),
cache,
upstreams: conn_pools,
config,
listener,
#[cfg(feature = "tls")]
tls_config: None,
}
}
#[cfg(feature = "tls")]
pub fn new_tls(config: Config) -> Self {
let (config, cache, conn_pools, listener) = Self::common_config(config);
let tls_config = {
let certs = tls::load_certs(
config
.certs
.as_ref()
.expect("Must provide `certs` in config to use tls."),
)
.unwrap();
let key = tls::load_private_key(
config
.private_key
.as_ref()
.expect("Must provide `private_key` in config to use tls."),
)
.unwrap();
let mut cfg = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.unwrap();
cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(cfg)
};
cfg_logging! {
info!("Motorx proxy listening on https://{}", listener.local_addr().unwrap());
}
Self {
semaphore: Arc::new(Semaphore::new(config.max_connections)),
cache,
upstreams: conn_pools,
config,
listener,
tls_config: Some(tls_config),
}
}
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.listener.local_addr()
}
pub async fn run(self) -> Result<(), hyper::Error> {
loop {
println!("Getting semaphore");
if let Ok(permit) = self.semaphore.clone().acquire_owned().await {
println!("Polling listener");
match self.listener.accept().await {
Ok((stream, peer_addr)) => {
cfg_logging! {
trace!("Accepted connection from {}", peer_addr);
}
#[cfg(feature = "tls")]
if let Some(tls_config) = self.tls_config.as_ref() {
let tls_stream = TlsStream::new(stream, Arc::clone(tls_config));
handle_connection(
tls_stream,
peer_addr,
Arc::clone(&self.config),
Arc::clone(&self.cache),
Arc::clone(&self.upstreams),
permit,
)
} else {
handle_connection(
stream,
peer_addr,
Arc::clone(&self.config),
Arc::clone(&self.cache),
Arc::clone(&self.upstreams),
permit,
)
};
#[cfg(not(feature = "tls"))]
handle_connection(
stream,
peer_addr,
Arc::clone(&self.config),
Arc::clone(&self.cache),
Arc::clone(&self.upstreams),
permit,
);
}
Err(e) => {
cfg_logging! {
error!("Error connecting, {:?}", e);
}
}
}
}
}
}
}
/// Serves one accepted client connection on a newly spawned Tokio task.
///
/// Each request on the connection is passed to `handle::handle_req` along with
/// the client's address and the shared config, cache and upstream pools. The
/// connection is driven by hyper-util's `auto` builder via
/// `serve_connection_with_upgrades`. `permit` is moved into the task and
/// dropped only after the connection finishes, so the connection counts against
/// the server's limit for its whole lifetime.
///
/// The span (with the `logging` feature) records only `peer_addr`. The shared
/// state is skipped so the config and upstream list are not `Debug`-formatted
/// on every connection.
#[cfg_attr(
    feature = "logging",
    tracing::instrument(skip(stream, config, cache, conn_pools, permit))
)]
fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
    stream: S,
    peer_addr: SocketAddr,
    config: Arc<Config>,
    cache: Arc<Cache>,
    conn_pools: Arc<Upstreams>,
    permit: OwnedSemaphorePermit,
) {
    let service = service_fn(move |req: Request<Incoming>| {
        // One refcount bump per shared value per request; the future takes
        // ownership of these clones and passes them straight through.
        let config = Arc::clone(&config);
        let cache = Arc::clone(&cache);
        let conn_pools = Arc::clone(&conn_pools);
        async move {
            let res = handle::handle_req(req, peer_addr, config, cache, conn_pools).await;
            cfg_logging! {
                trace!("Responded to req from {}", peer_addr);
            }
            res
        }
    });
    tokio::spawn(async move {
        cfg_logging! {
            trace!("Handling connection from {}", peer_addr);
        }
        let conn_build = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
        if let Err(err) = conn_build
            .serve_connection_with_upgrades(TokioIo::new(stream), service)
            .await
        {
            cfg_logging! {trace!("Error handling connection: {err:?}");}
        }
        cfg_logging! {
            trace!("Closing connection to {}", peer_addr);
        }
        // Release the connection slot only once the connection is fully done.
        drop(permit);
    });
}
/// Binds a TCP listener on `addr` and converts it into a Tokio listener.
///
/// The bind is done with `std` so callers can stay synchronous. The socket is
/// switched to non-blocking mode first, which `TcpListener::from_std` expects.
#[inline]
fn tcp_listener(addr: SocketAddr) -> std::io::Result<tokio::net::TcpListener> {
    std::net::TcpListener::bind(addr)
        .and_then(|socket| {
            socket.set_nonblocking(true)?;
            Ok(socket)
        })
        .and_then(tokio::net::TcpListener::from_std)
}
/// Opens an outbound TCP connection to `addr`, trying each resolved address
/// as `tokio::net::TcpStream::connect` does.
#[inline]
async fn tcp_connect(
    addr: impl tokio::net::ToSocketAddrs,
) -> std::io::Result<tokio::net::TcpStream> {
    let connected = tokio::net::TcpStream::connect(addr).await?;
    Ok(connected)
}
/// Assigns each upstream an index and builds the upstream/connection-pool list.
///
/// An upstream's index is its position in `config.upstreams`' iteration order
/// at the time of the call. That index is written into:
/// - the upstream's own `key`,
/// - `upstream_key` of every rule whose `upstream` names it,
/// - `key` of every `AuthenticationSource::Upstream` whose `name` names it.
///
/// The returned vector uses the same order, so `upstreams[key]` is the upstream
/// (and its [`ConnPool`]) for a given key. Rules or auth sources that name an
/// upstream not present in the map are left unchanged.
///
/// # Panics
///
/// Panics if any upstream `Arc` in `config.upstreams` is shared, because the
/// upstreams are mutated in place through `Arc::get_mut`.
fn init_upstreams(config: &mut Config) -> Upstreams {
    let upstream_order: Vec<_> = config.upstreams.keys().cloned().collect();

    for (key, upstream_name) in upstream_order.iter().enumerate() {
        // Point every upstream-backed authentication source that names this
        // upstream at its index.
        for (_, upstream) in &mut config.upstreams {
            let upstream = Arc::get_mut(upstream)
                .expect("upstream config must not be shared while indices are assigned");
            if let Some(auth) = upstream.authentication.as_mut() {
                match &mut auth.source {
                    config::authentication::AuthenticationSource::Upstream {
                        name,
                        path: _,
                        key: upstream_key,
                    } => {
                        // Only the source that references this upstream gets
                        // its index; others are handled on their own pass.
                        if *name == *upstream_name {
                            *upstream_key = key;
                        }
                    }
                    config::authentication::AuthenticationSource::Path(_) => {}
                }
            }
        }
        for rule in &mut config.rules {
            if rule.upstream == *upstream_name {
                rule.upstream_key = key;
            }
        }
    }

    upstream_order
        .iter()
        .enumerate()
        .map(|(key, upstream_name)| {
            let upstream = config
                .upstreams
                .get_mut(upstream_name)
                .expect("name was taken from this map's own keys");
            Arc::get_mut(upstream)
                .expect("upstream config must not be shared while indices are assigned")
                .key = key;
            (
                Arc::clone(upstream),
                ConnPool::new(upstream.addr.clone(), upstream.max_connections),
            )
        })
        .collect()
}