#![allow(clippy::disallowed_types)]
use std::{
borrow::Cow,
convert::Infallible,
fmt::{self, Display},
future::Future,
net::{SocketAddr, TcpListener},
str::FromStr,
sync::Arc,
time::Duration,
};
use anyhow::Context;
use axum::{
Router, ServiceExt as AxumServiceExt,
error_handling::HandleErrorLayer,
extract::{
DefaultBodyLimit, FromRequest,
rejection::{
BytesRejection, JsonRejection, PathRejection, QueryRejection,
},
},
response::IntoResponse,
routing::RouterIntoService,
};
use axum_server::tls_rustls::RustlsConfig;
use bytes::Bytes;
use http::{HeaderValue, StatusCode, header::CONTENT_TYPE};
use lexe_api_core::{
axum_helpers,
error::{CommonApiError, CommonErrorKind},
};
use lexe_common::api::auth::{self, Scope};
use lexe_crypto::ed25519;
use lexe_tokio::{notify_once::NotifyOnce, task::LxTask};
use serde::{Serialize, de::DeserializeOwned};
use tower::{
Layer, buffer::BufferLayer, limit::ConcurrencyLimitLayer,
load_shed::LoadShedLayer, timeout::TimeoutLayer, util::MapRequestLayer,
};
use tracing::{Instrument, debug, error, info, warn};
use crate::{rest, trace};
/// Grace period handed to `axum_server::Handle::graceful_shutdown` once the
/// shutdown signal has been received.
const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(3);
/// Upper bound on how long the server future is awaited after shutdown has
/// been signalled; once it elapses a warning is logged and the future ends.
pub const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
// The grace period must be strictly shorter than the overall shutdown timeout.
lexe_std::const_assert!(
SHUTDOWN_GRACE_PERIOD.as_secs() < SERVER_SHUTDOWN_TIMEOUT.as_secs()
);
/// Default value of [`LayerConfig::handling_timeout`], i.e. the duration given
/// to the request-handling `TimeoutLayer`.
pub const SERVER_HANDLER_TIMEOUT: Duration = Duration::from_secs(25);
// `rest::API_REQUEST_TIMEOUT` must be strictly longer than the handler timeout.
// NOTE(review): presumably so the server-side timeout fires before the REST
// client gives up on the request; confirm against the `rest` module.
lexe_std::const_assert!(
rest::API_REQUEST_TIMEOUT.as_secs() > SERVER_HANDLER_TIMEOUT.as_secs()
);
/// Tunable parameters for the middleware stack assembled by
/// `build_server_fut_with_listener`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LayerConfig {
/// Request body limit. Requests whose `Content-Length` header exceeds this
/// value are rejected, and the same value is passed to
/// `DefaultBodyLimit::max`.
pub body_limit: usize,
/// Bound passed to `BufferLayer::new`.
pub buffer_size: usize,
/// Limit passed to `ConcurrencyLimitLayer::new`.
pub concurrency: usize,
/// Duration passed to `TimeoutLayer::new`; when it fires the client
/// receives a "Server timed out handling request" error.
pub handling_timeout: Duration,
/// When `true`, `default_fallback` is installed as the router's fallback
/// handler; when `false` the router is used as given.
pub default_fallback: bool,
}
impl Default for LayerConfig {
fn default() -> Self {
Self {
body_limit: 16384,
buffer_size: 4096,
concurrency: 4096,
handling_timeout: SERVER_HANDLER_TIMEOUT,
default_fallback: true,
}
}
}
/// Computes the URL at which a server bound to `listener_addr` is reachable.
///
/// - With a DNS name the scheme is `https` and the host is the DNS name; the
///   port is appended unless it is 443.
/// - Without a DNS name the result is `http://` followed by the socket
///   address (IP and port).
pub fn build_server_url(
    listener_addr: SocketAddr,
    maybe_dns: Option<&str>,
) -> String {
    const HTTPS_DEFAULT_PORT: u16 = 443;

    // No DNS name: plain HTTP addressed by socket address.
    let Some(host) = maybe_dns else {
        return format!("http://{listener_addr}");
    };

    match listener_addr.port() {
        HTTPS_DEFAULT_PORT => format!("https://{host}"),
        explicit_port => format!("https://{host}:{explicit_port}"),
    }
}
/// Binds a [`TcpListener`] at `bind_addr` and builds the API server future
/// via [`build_server_fut_with_listener`].
///
/// Returns the server future together with the primary server URL.
///
/// The returned future is declared with `use<>`, matching
/// [`build_server_fut_with_listener`]: it does not capture the lifetimes of
/// the borrowed `&str` arguments, so callers may drop those borrows before
/// driving or spawning the future.
///
/// # Errors
///
/// Returns an error if the listener cannot be bound or if
/// [`build_server_fut_with_listener`] fails.
pub fn build_server_fut(
    bind_addr: SocketAddr,
    router: Router<()>,
    layer_config: LayerConfig,
    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
    server_span_name: &str,
    server_span: tracing::Span,
    shutdown: NotifyOnce,
) -> anyhow::Result<(impl Future<Output = ()> + use<>, String)> {
    let listener =
        TcpListener::bind(bind_addr).context("Could not bind TCP listener")?;

    build_server_fut_with_listener(
        listener,
        router,
        layer_config,
        maybe_tls_and_dns,
        server_span_name,
        server_span,
        shutdown,
    )
    .context("Could not build server future")
}
/// Builds the API server future on top of an already-bound [`TcpListener`].
///
/// Returns the future together with the primary server URL, which is computed
/// by [`build_server_url`] from the listener's local address and the optional
/// DNS name. The URL is also logged at `info` level under `server_span_name`.
///
/// The returned future serves `router` (wrapped in the middleware configured
/// by `layer_config`) over TLS when a TLS config is supplied and over plain
/// TCP otherwise. When `shutdown` fires it triggers a graceful shutdown with
/// [`SHUTDOWN_GRACE_PERIOD`] and then waits up to [`SERVER_SHUTDOWN_TIMEOUT`]
/// for the server to finish. The whole future is instrumented with
/// `server_span`.
///
/// # Errors
///
/// Returns an error if the listener's local address cannot be read.
///
/// # Panics
///
/// The returned future panics if the `axum_server` `serve` call returns an
/// error.
pub fn build_server_fut_with_listener(
listener: TcpListener,
router: Router<()>,
layer_config: LayerConfig,
maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
server_span_name: &str,
server_span: tracing::Span,
mut shutdown: NotifyOnce,
) -> anyhow::Result<(impl Future<Output = ()> + use<>, String)> {
// Split the optional (TLS config, DNS name) pair into two separate options.
let (maybe_tls_config, maybe_dns) = maybe_tls_and_dns.unzip();
let listener_addr = listener
.local_addr()
.context("Could not get listener local address")?;
let primary_server_url = build_server_url(listener_addr, maybe_dns);
info!("Url for {server_span_name}: {primary_server_url}");
// Optionally install this module's `default_fallback` as the router fallback.
let router = if layer_config.default_fallback {
router.fallback(default_fallback)
} else {
router
};
// Type aliases used only by the `check_service` calls below, which pin the
// service / request / response / error types expected at each point of the
// middleware stacks.
type HyperService = RouterIntoService<hyper::body::Incoming, ()>;
type AxumService = RouterIntoService<axum::body::Body, ()>;
type HyperReq = http::Request<hyper::body::Incoming>;
type AxumReq = http::Request<axum::body::Body>;
type AxumResp = http::Response<axum::body::Body>;
type TraceResp = http::Response<
tower_http::trace::ResponseBody<
axum::body::Body,
tower_http::classify::NeverClassifyEos<anyhow::Error>,
(),
trace::server::LxOnEos,
trace::server::LxOnFailure,
>,
>;
// Outer middleware: wraps the router service whose requests carry hyper's
// `Incoming` body. The tracing layer is listed first, followed by a
// response-mapping layer that runs `middleware::post_process_response`.
let outer_middleware = tower::ServiceBuilder::new()
.check_service::<HyperService, HyperReq, AxumResp, Infallible>()
.layer(trace::server::trace_layer(server_span.clone()))
.check_service::<HyperService, HyperReq, TraceResp, Infallible>()
.layer(tower::util::MapResponseLayer::new(
middleware::post_process_response,
))
.check_service::<HyperService, HyperReq, TraceResp, Infallible>();
// Inner middleware: applied to the router itself, so requests here carry
// `axum::body::Body`.
let inner_middleware = tower::ServiceBuilder::new()
.check_service::<AxumService, AxumReq, AxumResp, Infallible>()
// Reject requests whose `Content-Length` header exceeds the body limit.
.layer(axum::middleware::map_request_with_state(
layer_config.body_limit,
middleware::check_content_length_header,
))
.check_service::<AxumService, AxumReq, AxumResp, Infallible>()
// Configure the body limit, then wrap the request body so the limit applies.
.layer(DefaultBodyLimit::max(layer_config.body_limit))
.check_service::<AxumService, AxumReq, AxumResp, Infallible>()
.layer(MapRequestLayer::new(axum::RequestExt::with_limited_body))
.check_service::<AxumService, AxumReq, AxumResp, Infallible>()
// Errors surfacing from the load-shed / buffer / concurrency-limit layers
// listed next are converted into an `AtCapacity` API error response.
.layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
CommonApiError {
kind: CommonErrorKind::AtCapacity,
msg: "Service is at capacity; retry later".to_owned(),
}
}))
.layer(LoadShedLayer::new())
.check_service::<AxumService, AxumReq, AxumResp, Infallible>()
.layer(BufferLayer::new(layer_config.buffer_size))
.check_service::<AxumService, AxumReq, AxumResp, Infallible>()
.layer(ConcurrencyLimitLayer::new(layer_config.concurrency))
.check_service::<AxumService, AxumReq, AxumResp, Infallible>()
// Errors from the timeout layer listed next are converted into a `Server`
// API error response.
.layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
CommonApiError {
kind: CommonErrorKind::Server,
msg: "Server timed out handling request".to_owned(),
}
}))
.layer(TimeoutLayer::new(layer_config.handling_timeout))
.check_service::<AxumService, AxumReq, AxumResp, Infallible>();
// Assemble: inner layers on the router, convert the router into a service
// over hyper's body type, wrap it with the outer layers, then build the
// make-service handed to `axum_server`.
let layered_router = router.layer(inner_middleware);
let router_service = layered_router.into_service::<hyper::body::Incoming>();
let layered_service = Layer::layer(&outer_middleware, router_service);
let make_service = layered_service.into_make_service();
// One handle is given to the server; the other is kept to trigger graceful
// shutdown from the shutdown future below.
let handle = axum_server::Handle::new();
let handle_clone = handle.clone();
// Serves until the server stops; uses rustls when a TLS config was supplied,
// plain TCP otherwise.
let server_fut = async {
let serve_result = match maybe_tls_config {
Some(tls_config) => {
let axum_tls_config = RustlsConfig::from_config(tls_config);
axum_server::from_tcp_rustls(listener, axum_tls_config)
.handle(handle_clone)
.serve(make_service)
.await
}
None =>
axum_server::from_tcp(listener)
.handle(handle_clone)
.serve(make_service)
.await,
};
serve_result
.expect("No binding + axum MakeService::poll_ready never errors");
};
// Waits for the shutdown signal, then asks the server to shut down
// gracefully, passing the grace period to `axum_server`.
let graceful_shutdown_fut = async move {
shutdown.recv().await;
info!("Shutting down API server");
handle.graceful_shutdown(Some(SHUTDOWN_GRACE_PERIOD));
};
let combined_fut = async {
tokio::pin!(server_fut);
// `biased` makes the select poll the shutdown branch first. If the server
// future completes before shutdown was signalled, log an error and return.
tokio::select! {
biased; () = graceful_shutdown_fut => (),
_ = &mut server_fut => return error!("Server exited early"),
}
// Shutdown was signalled: give the server a bounded amount of time to finish.
match tokio::time::timeout(SERVER_SHUTDOWN_TIMEOUT, server_fut).await {
Ok(()) => info!("API server finished"),
Err(_) => warn!("API server timed out during shutdown"),
}
}
.instrument(server_span);
Ok((combined_fut, primary_server_url))
}
/// Binds a [`TcpListener`] at `bind_addr` and spawns the API server as a task
/// via [`spawn_server_task_with_listener`].
///
/// Returns the spawned task together with the primary server URL.
///
/// # Errors
///
/// Returns an error (annotated with `bind_addr`) if binding fails, or if
/// [`spawn_server_task_with_listener`] fails.
pub fn spawn_server_task(
    bind_addr: SocketAddr,
    router: Router<()>,
    layer_config: LayerConfig,
    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
    server_span_name: Cow<'static, str>,
    server_span: tracing::Span,
    shutdown: NotifyOnce,
) -> anyhow::Result<(LxTask<()>, String)> {
    let bind_result = TcpListener::bind(bind_addr);
    let listener = bind_result
        .context(bind_addr)
        .context("Failed to bind TcpListener")?;

    let spawn_result = spawn_server_task_with_listener(
        listener,
        router,
        layer_config,
        maybe_tls_and_dns,
        server_span_name,
        server_span,
        shutdown,
    );

    spawn_result.context("spawn_server_task_with_listener failed")
}
/// Builds the API server future on an already-bound [`TcpListener`] and
/// spawns it as an [`LxTask`] named `server_span_name` inside `server_span`.
///
/// Returns the spawned task together with the primary server URL.
///
/// # Errors
///
/// Returns an error if [`build_server_fut_with_listener`] fails.
pub fn spawn_server_task_with_listener(
    listener: TcpListener,
    router: Router<()>,
    layer_config: LayerConfig,
    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
    server_span_name: Cow<'static, str>,
    server_span: tracing::Span,
    shutdown: NotifyOnce,
) -> anyhow::Result<(LxTask<()>, String)> {
    // The span is needed twice: once by the future, once by the task.
    let fut_span = server_span.clone();

    let build_result = build_server_fut_with_listener(
        listener,
        router,
        layer_config,
        maybe_tls_and_dns,
        &*server_span_name,
        fut_span,
        shutdown,
    );
    let (fut, url) = build_result.context("Failed to build server future")?;

    let task = LxTask::spawn_with_span(server_span_name, server_span, fut);
    Ok((task, url))
}
/// JSON request extractor / response wrapper.
///
/// As an extractor it delegates to `axum::Json` but reports failures as
/// [`LxRejection`]; as a response it serializes the inner value through
/// `axum_helpers::build_json_response` with status 200.
pub struct LxJson<T>(pub T);
impl<T: DeserializeOwned, S: Send + Sync> FromRequest<S> for LxJson<T> {
    type Rejection = LxRejection;

    /// Extracts a JSON body using `axum::Json`, converting any rejection
    /// into an [`LxRejection`].
    async fn from_request(
        req: http::Request<axum::body::Body>,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        match axum::Json::<T>::from_request(req, state).await {
            Ok(axum::Json(value)) => Ok(Self(value)),
            Err(rejection) => Err(LxRejection::from(rejection)),
        }
    }
}
impl<T: Serialize> IntoResponse for LxJson<T> {
    /// Serializes the wrapped value as a JSON response with status 200.
    fn into_response(self) -> http::Response<axum::body::Body> {
        let Self(payload) = self;
        axum_helpers::build_json_response(StatusCode::OK, &payload)
    }
}
impl<T: Clone> Clone for LxJson<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<T: Copy> Copy for LxJson<T> {}
impl<T: fmt::Debug> fmt::Debug for LxJson<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
T::fmt(&self.0, f)
}
}
impl<T: Eq + PartialEq> Eq for LxJson<T> {}
impl<T: PartialEq> PartialEq for LxJson<T> {
fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0)
}
}
/// Raw-bytes request extractor / response wrapper around [`Bytes`].
///
/// As an extractor it delegates to the `Bytes` extractor but reports failures
/// as [`LxRejection`]; as a response it sends the bytes with content type
/// `application/octet-stream` and status 200.
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
pub struct LxBytes(pub Bytes);
impl<S: Send + Sync> FromRequest<S> for LxBytes {
    type Rejection = LxRejection;

    /// Extracts the request body as [`Bytes`], converting any rejection into
    /// an [`LxRejection`].
    async fn from_request(
        req: http::Request<axum::body::Body>,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        match Bytes::from_request(req, state).await {
            Ok(body) => Ok(Self(body)),
            Err(rejection) => Err(LxRejection::from(rejection)),
        }
    }
}
impl IntoResponse for LxBytes {
    /// Builds a 200 response carrying the wrapped bytes with content type
    /// `application/octet-stream`.
    fn into_response(self) -> http::Response<axum::body::Body> {
        let Self(payload) = self;
        let body = axum::body::Body::new(http_body_util::Full::new(payload));
        let octet_stream = HeaderValue::from_static("application/octet-stream");

        axum_helpers::default_response_builder()
            .header(CONTENT_TYPE, octet_stream)
            .status(StatusCode::OK)
            .body(body)
            .expect("All operations here should be infallible")
    }
}
impl<T: Into<Bytes>> From<T> for LxBytes {
fn from(t: T) -> Self {
Self(t.into())
}
}
/// Rejection type returned by this module's extractors, middleware and
/// fallback handler. When converted into a response it is logged at `warn`
/// level and rendered as a `CommonApiError` of kind `Rejection`.
pub struct LxRejection {
// Category of the rejection; selects the fixed message prefix.
kind: LxRejectionKind,
// Free-form detail from the underlying error, appended after the prefix.
source_msg: String,
}
/// Category of an [`LxRejection`]; each variant maps to a fixed message in
/// `LxRejectionKind::to_msg`.
enum LxRejectionKind {
/// Built from an axum `BytesRejection`.
Bytes,
/// Built from an axum `JsonRejection`.
Json,
/// Built from an axum `PathRejection`.
Path,
/// Built from an axum `QueryRejection`.
Query,
/// Built by `LxRejection::from_bearer_auth`.
Unauthenticated,
/// Built by `LxRejection::scope_unauthorized`.
Unauthorized,
/// Built by the `default_fallback` handler.
BadEndpoint,
/// Built by `middleware::check_content_length_header`.
BodyLengthOverLimit,
/// Built by `LxRejection::from_ed25519`.
Ed25519,
/// Built by `LxRejection::proxy`.
Proxy,
}
impl LxRejection {
    /// Builds an `Ed25519` rejection whose detail is the alternate-format
    /// (`{:#}`) rendering of `error`.
    pub fn from_ed25519(error: ed25519::Error) -> Self {
        let source_msg = format!("{error:#}");
        let kind = LxRejectionKind::Ed25519;
        Self { kind, source_msg }
    }

    /// Builds an `Unauthenticated` rejection from a bearer auth error.
    pub fn from_bearer_auth(error: auth::Error) -> Self {
        let source_msg = format!("{error:#}");
        let kind = LxRejectionKind::Unauthenticated;
        Self { kind, source_msg }
    }

    /// Builds an `Unauthorized` rejection whose detail records both the
    /// granted and the requested scope.
    pub fn scope_unauthorized(
        granted_scope: &Scope,
        requested_scope: &Scope,
    ) -> Self {
        let source_msg = format!(
            "granted scope: {granted_scope:?}, requested scope: {requested_scope:?}"
        );
        let kind = LxRejectionKind::Unauthorized;
        Self { kind, source_msg }
    }

    /// Builds a `Proxy` rejection from any displayable error.
    pub fn proxy(error: impl Display) -> Self {
        let source_msg = format!("{error:#}");
        let kind = LxRejectionKind::Proxy;
        Self { kind, source_msg }
    }
}
impl From<BytesRejection> for LxRejection {
fn from(bytes_rejection: BytesRejection) -> Self {
Self {
kind: LxRejectionKind::Bytes,
source_msg: bytes_rejection.body_text(),
}
}
}
impl From<JsonRejection> for LxRejection {
fn from(json_rejection: JsonRejection) -> Self {
Self {
kind: LxRejectionKind::Json,
source_msg: json_rejection.body_text(),
}
}
}
impl From<PathRejection> for LxRejection {
fn from(path_rejection: PathRejection) -> Self {
Self {
kind: LxRejectionKind::Path,
source_msg: path_rejection.body_text(),
}
}
}
impl From<QueryRejection> for LxRejection {
fn from(query_rejection: QueryRejection) -> Self {
Self {
kind: LxRejectionKind::Query,
source_msg: query_rejection.body_text(),
}
}
}
impl IntoResponse for LxRejection {
    /// Logs the rejection at `warn` level and renders it as a
    /// `CommonApiError` of kind `Rejection`, with a message of the form
    /// `Rejection: <kind message>: <source message>`.
    fn into_response(self) -> http::Response<axum::body::Body> {
        let msg = format!(
            "Rejection: {}: {}",
            self.kind.to_msg(),
            self.source_msg,
        );
        warn!("{msg}");

        CommonApiError {
            kind: CommonErrorKind::Rejection,
            msg,
        }
        .into_response()
    }
}
impl LxRejectionKind {
fn to_msg(&self) -> &'static str {
match self {
Self::Bytes => "Bad request bytes",
Self::Json => "Client provided bad JSON",
Self::Path => "Client provided bad path parameter",
Self::Query => "Client provided bad query string",
Self::Unauthenticated => "Invalid bearer auth",
Self::Unauthorized => "Not authorized to access this resource",
Self::BadEndpoint => "Client requested a non-existent endpoint",
Self::BodyLengthOverLimit => "Request body length over limit",
Self::Ed25519 => "Ed25519 error",
Self::Proxy => "Proxy error",
}
}
}
/// Request-parts extractors that mirror axum's `Query` and `Path` but reject
/// with [`LxRejection`].
pub mod extract {
    use axum::extract::FromRequestParts;

    use super::*;

    /// Query-string extractor; delegates to `axum::extract::Query`.
    pub struct LxQuery<T>(pub T);

    impl<T: DeserializeOwned, S: Send + Sync> FromRequestParts<S> for LxQuery<T> {
        type Rejection = LxRejection;

        async fn from_request_parts(
            parts: &mut http::request::Parts,
            state: &S,
        ) -> Result<Self, Self::Rejection> {
            let extracted =
                axum::extract::Query::<T>::from_request_parts(parts, state)
                    .await;
            match extracted {
                Ok(axum::extract::Query(value)) => Ok(Self(value)),
                Err(rejection) => Err(LxRejection::from(rejection)),
            }
        }
    }

    impl<T: Clone> Clone for LxQuery<T> {
        fn clone(&self) -> Self {
            let Self(inner) = self;
            LxQuery(inner.clone())
        }
    }

    /// Transparent `Debug`: prints exactly what the inner value prints.
    impl<T: fmt::Debug> fmt::Debug for LxQuery<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Debug::fmt(&self.0, f)
        }
    }

    impl<T: Eq> Eq for LxQuery<T> {}

    impl<T: PartialEq> PartialEq for LxQuery<T> {
        fn eq(&self, other: &Self) -> bool {
            self.0 == other.0
        }
    }

    /// Path-parameter extractor; delegates to `axum::extract::Path`.
    pub struct LxPath<T>(pub T);

    impl<T: DeserializeOwned + Send, S: Send + Sync> FromRequestParts<S>
        for LxPath<T>
    {
        type Rejection = LxRejection;

        async fn from_request_parts(
            parts: &mut http::request::Parts,
            state: &S,
        ) -> Result<Self, Self::Rejection> {
            let extracted =
                axum::extract::Path::<T>::from_request_parts(parts, state)
                    .await;
            match extracted {
                Ok(axum::extract::Path(value)) => Ok(Self(value)),
                Err(rejection) => Err(LxRejection::from(rejection)),
            }
        }
    }

    impl<T: Clone> Clone for LxPath<T> {
        fn clone(&self) -> Self {
            let Self(inner) = self;
            LxPath(inner.clone())
        }
    }

    /// Transparent `Debug`: prints exactly what the inner value prints.
    impl<T: fmt::Debug> fmt::Debug for LxPath<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Debug::fmt(&self.0, f)
        }
    }

    impl<T: Eq> Eq for LxPath<T> {}

    impl<T: PartialEq> PartialEq for LxPath<T> {
        fn eq(&self, other: &Self) -> bool {
            self.0 == other.0
        }
    }
}
/// Request / response middleware functions used by the server's layer stack.
pub mod middleware {
    use axum::extract::State;
    use http::HeaderName;

    use super::*;

    /// Response header inspected (and always stripped) by
    /// `post_process_response`; its value selects the post-processing step.
    pub static POST_PROCESS_HEADER: HeaderName =
        HeaderName::from_static("lx-post-process");

    /// Rejects the request when its `Content-Length` header parses as a
    /// number greater than the configured body limit.
    ///
    /// Requests with a missing or unparseable header pass through unchanged.
    pub async fn check_content_length_header<B>(
        State(config_body_limit): State<usize>,
        request: http::Request<B>,
    ) -> Result<http::Request<B>, LxRejection> {
        let declared_len = request
            .headers()
            .get(http::header::CONTENT_LENGTH)
            .and_then(|raw| raw.to_str().ok())
            .and_then(|text| usize::from_str(text).ok());

        match declared_len {
            Some(len) if len > config_body_limit => Err(LxRejection {
                kind: LxRejectionKind::BodyLengthOverLimit,
                source_msg: "Content length header over limit".to_owned(),
            }),
            _ => Ok(request),
        }
    }

    /// Applies the action requested through [`POST_PROCESS_HEADER`], if the
    /// response carries that header, and strips the header itself.
    ///
    /// The only recognized value is `remove-content-length`, which removes
    /// the `Content-Length` header; any other value is logged as invalid.
    pub(super) fn post_process_response(
        mut response: http::Response<axum::body::Body>,
    ) -> http::Response<axum::body::Body> {
        let Some(directive) =
            response.headers_mut().remove(&POST_PROCESS_HEADER)
        else {
            return response;
        };

        match directive.as_bytes() {
            b"remove-content-length" => {
                response.headers_mut().remove(http::header::CONTENT_LENGTH);
                debug!("Post process: Removed content-length header");
            }
            other => {
                let unknown_str = String::from_utf8_lossy(other);
                warn!("Post process: Invalid header value: {unknown_str}");
            }
        }

        response
    }
}
/// Fallback handler: produces a `BadEndpoint` rejection whose detail is the
/// request's method and path.
pub async fn default_fallback(
    method: http::Method,
    uri: http::Uri,
) -> LxRejection {
    let source_msg = format!("{method} {}", uri.path());
    LxRejection {
        source_msg,
        kind: LxRejectionKind::BadEndpoint,
    }
}