use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use anyhow::Result;
use axum::Router;
use axum::body::Body;
use axum::extract::{ConnectInfo, State};
use axum::http::{self, Request, Response};
use axum::routing::any;
use ipnet::IpNet;
use tokio::net::TcpListener;
use tokio::sync::watch;
use tower::ServiceBuilder;
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::timeout::TimeoutLayer;
use tracing::{error, info};
use crate::config::HttpConfig;
use crate::payload::{decode_response, encode_request};
/// Shared state handed to the catch-all request handler via axum's `State`
/// extractor. It must be `Clone` for that; every field is an `Arc`, so clones
/// share the same underlying objects.
#[derive(Clone)]
struct AppState {
    /// Backend every request is forwarded to (`execute_value("http.handle", ..)`).
    executor: Arc<dyn folk_api::Executor>,
    /// Server settings read per request (body limit, read timeout, access log,
    /// trusted proxies).
    config: Arc<HttpConfig>,
    /// In-flight request counter: incremented on entry to `handle` and
    /// decremented when its `ConnectionGuard` is dropped.
    active_connections: Arc<AtomicU64>,
}
/// HTTP front-end that routes every request to a `folk_api::Executor`.
///
/// Construct with [`HttpServer::new`] and start serving with [`HttpServer::run`].
pub struct HttpServer {
    /// Listen address, limits, timeouts and TLS/h2c/compression settings.
    config: HttpConfig,
    /// Backend that handles each request; shared with the per-request state.
    executor: Arc<dyn folk_api::Executor>,
    /// In-flight request counter shared with whoever created the server.
    /// NOTE(review): presumably read elsewhere for metrics/health — confirm.
    active_connections: Arc<AtomicU64>,
}
impl HttpServer {
    /// Creates a server that forwards every request to `executor`.
    ///
    /// `active_connections` is shared with the caller. Request handlers
    /// increment it for the duration of each request.
    pub fn new(
        config: HttpConfig,
        executor: Arc<dyn folk_api::Executor>,
        active_connections: Arc<AtomicU64>,
    ) -> Self {
        Self {
            config,
            executor,
            active_connections,
        }
    }

    /// Binds `config.listen` and serves requests until shutdown is requested.
    ///
    /// Every path (and `/`) goes to the catch-all handler. The handler is
    /// wrapped in a request-body size limit, a `write_timeout`-based timeout
    /// that answers 504, and optional response compression.
    ///
    /// The transport is chosen in this order:
    /// - TLS when the `tls` feature is built and `config.tls` is set;
    /// - otherwise h2c when the `h2c` feature is built and `config.h2c` is set;
    /// - otherwise plain HTTP.
    ///
    /// Shutdown is requested by sending `true` on `shutdown` or dropping its
    /// sender. Open connections are then drained before this returns.
    ///
    /// # Errors
    /// Returns an error if binding the listener, loading TLS material, or the
    /// underlying server fails.
    pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
        let state = AppState {
            executor: self.executor.clone(),
            config: Arc::new(self.config.clone()),
            active_connections: self.active_connections.clone(),
        };
        let mut app = Router::new()
            .route("/{*path}", any(handle))
            .route("/", any(handle))
            .with_state(state)
            .layer(
                ServiceBuilder::new()
                    .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
                    .layer(TimeoutLayer::with_status_code(
                        http::StatusCode::GATEWAY_TIMEOUT,
                        self.config.write_timeout,
                    )),
            );
        if self.config.compression.enabled {
            app = app.layer(build_compression_layer(&self.config.compression));
        }
        #[cfg(feature = "tls")]
        if let Some(ref tls) = self.config.tls {
            return self.run_tls(app, tls, shutdown).await;
        }
        #[cfg(feature = "h2c")]
        if self.config.h2c {
            return self.run_h2c(app, shutdown).await;
        }
        self.run_plain(app, shutdown).await
    }

    /// Serves plain HTTP through `axum::serve` with graceful shutdown.
    async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
        let listener = TcpListener::bind(self.config.listen).await?;
        axum::serve(
            listener,
            app.into_make_service_with_connect_info::<SocketAddr>(),
        )
        .with_graceful_shutdown(shutdown_signal(shutdown))
        .await?;
        Ok(())
    }

    /// Serves HTTPS through `axum_server` using the PEM cert/key from `tls`.
    #[cfg(feature = "tls")]
    async fn run_tls(
        &self,
        app: Router,
        tls: &crate::config::TlsConfig,
        shutdown: watch::Receiver<bool>,
    ) -> Result<()> {
        use axum_server::Handle;
        use axum_server::tls_rustls::RustlsConfig;

        let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
        info!(cert = %tls.cert.display(), "TLS enabled");
        let handle = Handle::new();
        let shutdown_handle = handle.clone();
        let watcher = tokio::spawn(async move {
            shutdown_signal(shutdown).await;
            // `None`: no deadline, so open connections are allowed to finish.
            shutdown_handle.graceful_shutdown(None);
        });
        let served = axum_server::bind_rustls(self.config.listen, rustls_config)
            .handle(handle)
            .serve(app.into_make_service_with_connect_info::<SocketAddr>())
            .await;
        // If the server stopped on its own (for example, the bind failed), the
        // watcher would otherwise stay parked until shutdown with nothing to do.
        watcher.abort();
        served?;
        Ok(())
    }

    /// Serves HTTP/1 and HTTP/2 cleartext (prior knowledge) with a hand-rolled
    /// accept loop on top of `hyper_util`'s auto builder.
    ///
    /// Each connection runs in its own task. On shutdown:
    /// - the accept loop stops;
    /// - every connection is asked to shut down gracefully;
    /// - this waits for all connection tasks to finish.
    #[cfg(feature = "h2c")]
    async fn run_h2c(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
        use hyper_util::rt::{TokioExecutor, TokioIo};
        use hyper_util::server::conn::auto::Builder as AutoBuilder;

        info!("h2c (HTTP/2 cleartext) enabled");
        let listener = TcpListener::bind(self.config.listen).await?;
        let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
        let mut tasks = tokio::task::JoinSet::new();

        // One long-lived shutdown future for the accept loop. `shutdown` itself is
        // never polled here, so the per-connection clones taken from it below still
        // observe the shutdown change.
        let stop = shutdown_signal(shutdown.clone());
        tokio::pin!(stop);

        loop {
            let (stream, remote_addr) = tokio::select! {
                () = &mut stop => break,
                // A JoinSet keeps every finished task until it is joined. Reap them
                // while serving so a long-running server doesn't accumulate one entry
                // per connection ever accepted.
                Some(_) = tasks.join_next(), if !tasks.is_empty() => continue,
                accepted = listener.accept() => match accepted {
                    Ok(pair) => pair,
                    Err(e) => {
                        // Accept errors are usually transient (peer reset, fd
                        // exhaustion), so keep serving instead of tearing the whole
                        // server down. Back off on non-connection errors so a
                        // persistent condition doesn't spin this loop.
                        let per_connection = matches!(
                            e.kind(),
                            std::io::ErrorKind::ConnectionAborted
                                | std::io::ErrorKind::ConnectionRefused
                                | std::io::ErrorKind::ConnectionReset
                        );
                        if !per_connection {
                            error!(error = ?e, "accept connection");
                            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                        }
                        continue;
                    }
                },
            };

            let app = app.clone();
            let builder = Arc::clone(&builder);
            let conn_shutdown = shutdown.clone();
            tasks.spawn(async move {
                let svc = hyper::service::service_fn(
                    move |mut req: Request<hyper::body::Incoming>| {
                        // Make the peer address available to the `ConnectInfo`
                        // extractor used by `handle`.
                        req.extensions_mut().insert(ConnectInfo(remote_addr));
                        let mut app = app.clone();
                        async move {
                            tower::Service::call(&mut app, req)
                                .await
                                .map_err(|e| match e {})
                        }
                    },
                );
                let conn = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc);
                tokio::pin!(conn);
                let outcome = tokio::select! {
                    res = conn.as_mut() => res,
                    () = shutdown_signal(conn_shutdown) => {
                        // Stop taking new requests on this connection, let in-flight
                        // ones complete, and keep polling until it actually closes.
                        conn.as_mut().graceful_shutdown();
                        conn.await
                    }
                };
                if let Err(e) = outcome {
                    tracing::debug!(error = ?e, peer = %remote_addr, "h2c connection error");
                }
            });
        }

        while tasks.join_next().await.is_some() {}
        Ok(())
    }
}
/// Resolves once shutdown has been requested on `shutdown`.
///
/// Shutdown counts as requested when the channel holds `true` or when its
/// sender has been dropped. The current value is checked before waiting, so a
/// receiver that has already observed `true` returns immediately instead of
/// blocking until some later change that will never come.
async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
    loop {
        if *shutdown.borrow_and_update() {
            return;
        }
        if shutdown.changed().await.is_err() {
            // Sender dropped: nobody can cancel anymore, so treat it as shutdown.
            return;
        }
    }
}
/// Drop guard that decrements a shared counter by one.
///
/// The creator performs the matching `fetch_add` before constructing it.
/// Because the decrement lives in `Drop`, it also runs when the owning future
/// is dropped early (e.g. cancelled) rather than completed.
struct ConnectionGuard(Arc<AtomicU64>);
impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
async fn handle(
State(state): State<AppState>,
connect_info: ConnectInfo<SocketAddr>,
req: Request<Body>,
) -> Response<Body> {
state.active_connections.fetch_add(1, Ordering::Relaxed);
let _conn_guard = ConnectionGuard(state.active_connections.clone());
let start = Instant::now();
let method = req.method().clone();
let uri = req.uri().clone();
let peer_addr = connect_info.0;
let client_ip = resolve_client_ip(
peer_addr.ip(),
req.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok()),
&state.config.trusted_proxies,
);
let response = handle_inner(&state, req).await;
if state.config.access_log {
let duration = start.elapsed();
let status = response.status().as_u16();
let response_bytes = response
.headers()
.get(http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0);
info!(
client_ip = %client_ip,
method = %method,
uri = %uri,
status = status,
duration_ms = duration.as_millis() as u64,
response_bytes = response_bytes,
"http request",
);
}
response
}
/// Runs one request through the worker: encode, execute `"http.handle"`, decode.
///
/// Failures become short plain-text responses:
/// - encoding the request fails → 500 `encode error`
/// - the body is not read within `config.read_timeout` → 408 `request body read timeout`
/// - the executor returns an error → 502 `worker error`
/// - the worker's reply cannot be decoded → 500 `decode error`
async fn handle_inner(state: &AppState, req: Request<Body>) -> Response<Body> {
    /// Builds a bare response with a static text body.
    fn plain(status: u16, body: &'static str) -> Response<Body> {
        Response::builder()
            .status(status)
            .body(Body::from(body))
            .unwrap()
    }

    let encoded = tokio::time::timeout(
        state.config.read_timeout,
        encode_request(req, state.config.max_request_size),
    )
    .await;
    let payload = match encoded {
        Err(_elapsed) => return plain(408, "request body read timeout"),
        Ok(Err(e)) => {
            error!(error = ?e, "encode request");
            return plain(500, "encode error");
        }
        Ok(Ok(payload)) => payload,
    };

    let reply = match state.executor.execute_value("http.handle", payload).await {
        Err(e) => {
            error!(error = ?e, "dispatch to worker");
            return plain(502, "worker error");
        }
        Ok(value) => value,
    };

    decode_response(reply).unwrap_or_else(|e| {
        error!(error = ?e, "decode response");
        plain(500, "decode error")
    })
}
pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
if trusted.is_empty() {
return peer_ip;
}
if !is_trusted(peer_ip, trusted) {
return peer_ip;
}
let Some(xff) = xff else {
return peer_ip;
};
let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();
for addr_str in addrs.iter().rev() {
if let Ok(ip) = addr_str.parse::<IpAddr>() {
if !is_trusted(ip, trusted) {
return ip;
}
}
}
peer_ip
}
/// Reports whether `ip` falls inside any of the `trusted` networks.
///
/// An empty `trusted` slice trusts nothing.
fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
    for network in trusted {
        if network.contains(&ip) {
            return true;
        }
    }
    false
}
/// Predicate used by `build_compression_layer`.
///
/// It combines a minimum body size with tower-http's standard exclusions:
/// gRPC, images and server-sent events.
type CompressionPredicate = tower_http::compression::predicate::And<
    tower_http::compression::predicate::And<
        tower_http::compression::predicate::And<
            tower_http::compression::predicate::SizeAbove,
            tower_http::compression::predicate::NotForContentType,
        >,
        tower_http::compression::predicate::NotForContentType,
    >,
    tower_http::compression::predicate::NotForContentType,
>;

/// Builds the response-compression layer described by `config`.
///
/// All encodings start disabled, and only those listed in `config.algorithms`
/// are re-enabled. A response is compressed only when both hold:
/// - its body is larger than `config.min_size` bytes, saturated to `u16::MAX`;
/// - it is not gRPC, an image, or `text/event-stream`.
///
/// `compress_when` replaces tower-http's default predicate wholesale. Without
/// the explicit exclusions, SSE streams would be compressed (and potentially
/// buffered), and already-compressed images would be recompressed.
fn build_compression_layer(
    config: &crate::config::CompressionConfig,
) -> tower_http::compression::CompressionLayer<CompressionPredicate> {
    use crate::config::CompressionAlgorithm;
    use tower_http::compression::CompressionLayer;
    use tower_http::compression::Predicate as _;
    use tower_http::compression::predicate::{NotForContentType, SizeAbove};

    let mut layer = CompressionLayer::new()
        .no_gzip()
        .no_br()
        .no_zstd()
        .no_deflate();
    for algo in &config.algorithms {
        layer = match algo {
            CompressionAlgorithm::Gzip => layer.gzip(true),
            CompressionAlgorithm::Br => layer.br(true),
            CompressionAlgorithm::Zstd => layer.zstd(true),
            CompressionAlgorithm::Deflate => layer.deflate(true),
        };
    }
    // `SizeAbove` takes a `u16`. Saturate instead of truncating so a large
    // threshold such as 65_536 doesn't wrap to 0 and compress everything.
    let min_size = u16::try_from(config.min_size).unwrap_or(u16::MAX);
    let predicate = SizeAbove::new(min_size)
        .and(NotForContentType::GRPC)
        .and(NotForContentType::IMAGES)
        .and(NotForContentType::SSE);
    layer.compress_when(predicate)
}