1#![allow(clippy::disallowed_types)]
3
4use std::{
41 borrow::Cow,
42 convert::Infallible,
43 fmt::{self, Display},
44 future::Future,
45 net::{SocketAddr, TcpListener},
46 str::FromStr,
47 sync::Arc,
48 time::Duration,
49};
50
51use anyhow::Context;
52use axum::{
53 Router, ServiceExt as AxumServiceExt,
54 error_handling::HandleErrorLayer,
55 extract::{
56 DefaultBodyLimit, FromRequest, OptionalFromRequest,
57 rejection::{
58 BytesRejection, JsonRejection, PathRejection, QueryRejection,
59 },
60 },
61 response::IntoResponse,
62 routing::RouterIntoService,
63};
64use axum_server::tls_rustls::RustlsConfig;
65use bytes::Bytes;
66use http::{HeaderValue, StatusCode, header::CONTENT_TYPE};
67use lexe_api_core::{
68 axum_helpers,
69 error::{CommonApiError, CommonErrorKind},
70};
71use lexe_common::api::auth::{self, Scope};
72use lexe_crypto::ed25519;
73use lexe_tokio::{notify_once::NotifyOnce, task::LxTask};
74use serde::{Serialize, de::DeserializeOwned};
75use tower::{
76 Layer, buffer::BufferLayer, limit::ConcurrencyLimitLayer,
77 load_shed::LoadShedLayer, timeout::TimeoutLayer, util::MapRequestLayer,
78};
79use tracing::{Instrument, debug, error, info, warn};
80
81use crate::{rest, trace};
82
83const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(3);
87pub const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
89lexe_std::const_assert!(
90 SHUTDOWN_GRACE_PERIOD.as_secs() < SERVER_SHUTDOWN_TIMEOUT.as_secs()
91);
92
93pub const SERVER_HANDLER_TIMEOUT: Duration = Duration::from_secs(25);
95lexe_std::const_assert!(
96 rest::API_REQUEST_TIMEOUT.as_secs() > SERVER_HANDLER_TIMEOUT.as_secs()
97);
98
/// Settings for the middleware layers that `build_server_fut_with_listener`
/// applies to every server it builds.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LayerConfig {
    /// Maximum request body size in bytes. Requests with a parseable
    /// `Content-Length` header above this are rejected early, and the body is
    /// also limited via `DefaultBodyLimit`.
    pub body_limit: usize,
    /// Capacity of the tower `BufferLayer`. When the buffer cannot accept more
    /// requests, `LoadShedLayer` rejects them with an `AtCapacity` error.
    pub buffer_size: usize,
    /// Maximum number of in-flight requests (tower `ConcurrencyLimitLayer`).
    pub concurrency: usize,
    /// How long a request may be handled before it fails with a
    /// `CommonErrorKind::Server` "timed out" error (tower `TimeoutLayer`).
    pub handling_timeout: Duration,
    /// Whether `default_fallback` is installed for requests matching no route.
    pub default_fallback: bool,
}
144
145impl Default for LayerConfig {
146 fn default() -> Self {
147 Self {
148 body_limit: 16384,
150 buffer_size: 4096,
154 concurrency: 4096,
155 handling_timeout: SERVER_HANDLER_TIMEOUT,
156 default_fallback: true,
157 }
158 }
159}
160
/// Builds the URL at which a server listening on `listener_addr` is reachable.
///
/// - Without a DNS name: `http://{listener_addr}` (plaintext).
/// - With a DNS name: `https://{dns_name}`, followed by `:{port}` unless the
///   port is the HTTPS default, 443.
pub fn build_server_url(
    listener_addr: SocketAddr,
    maybe_dns: Option<&str>,
) -> String {
    let Some(dns_name) = maybe_dns else {
        return format!("http://{listener_addr}");
    };

    match listener_addr.port() {
        443 => format!("https://{dns_name}"),
        port => format!("https://{dns_name}:{port}"),
    }
}
189
190pub fn build_server_fut(
202 bind_addr: SocketAddr,
203 router: Router<()>,
204 layer_config: LayerConfig,
205 maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
207 server_span_name: &str,
208 server_span: tracing::Span,
209 shutdown: NotifyOnce,
211) -> anyhow::Result<(impl Future<Output = ()>, String)> {
212 let listener =
213 TcpListener::bind(bind_addr).context("Could not bind TCP listener")?;
214 let (server_fut, primary_server_url) = build_server_fut_with_listener(
215 listener,
216 router,
217 layer_config,
218 maybe_tls_and_dns,
219 server_span_name,
220 server_span,
221 shutdown,
222 )
223 .context("Could not build server future")?;
224 Ok((server_fut, primary_server_url))
225}
226
/// Builds, but does not spawn, the API server future on an already-bound
/// `listener`.
///
/// - `router`: the application's routes. If `layer_config.default_fallback`
///   is set, `default_fallback` handles requests that match no route.
/// - `layer_config`: body limit, buffer size, concurrency limit and handler
///   timeout for the middleware stack built below.
/// - `maybe_tls_and_dns`: if `Some`, the server serves TLS with the given
///   rustls config, and the DNS name is used for the `https://` URL. If
///   `None`, the server serves plaintext HTTP.
/// - `server_span_name`: only used in the startup log line.
/// - `server_span`: the returned future is instrumented with this span; a
///   clone is also handed to `trace::server::trace_layer`.
/// - `shutdown`: once notified, graceful shutdown begins.
///
/// Returns the server future and the primary URL from `build_server_url`.
/// The future completes when any of these happens:
/// - shutdown finishes;
/// - shutdown exceeds `SERVER_SHUTDOWN_TIMEOUT` (a warning is logged);
/// - the server exits before shutdown was requested (an error is logged).
///
/// # Errors
///
/// Errors if the listener's local address cannot be read.
///
/// # Panics
///
/// The returned future panics if the `axum_server` serve call returns an
/// error (see the `expect` below).
pub fn build_server_fut_with_listener(
    listener: TcpListener,
    router: Router<()>,
    layer_config: LayerConfig,
    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
    server_span_name: &str,
    server_span: tracing::Span,
    mut shutdown: NotifyOnce,
    // `use<>`: the future captures no lifetimes, so it does not borrow the
    // `&str` arguments.
) -> anyhow::Result<(impl Future<Output = ()> + use<>, String)> {
    let (maybe_tls_config, maybe_dns) = maybe_tls_and_dns.unzip();
    let listener_addr = listener
        .local_addr()
        .context("Could not get listener local address")?;
    let primary_server_url = build_server_url(listener_addr, maybe_dns);
    info!("Url for {server_span_name}: {primary_server_url}");

    // Unmatched requests get an `LxRejection` (BadEndpoint) unless the caller
    // opted out via `layer_config.default_fallback`.
    let router = if layer_config.default_fallback {
        router.fallback(default_fallback)
    } else {
        router
    };

    // Type aliases for the `check_service` calls below. Those calls are
    // compile-time assertions: the layers added so far, applied to the named
    // service, must accept the named request type and yield the named
    // response type with an `Infallible` error.
    type HyperService = RouterIntoService<hyper::body::Incoming, ()>;
    type AxumService = RouterIntoService<axum::body::Body, ()>;
    type HyperReq = http::Request<hyper::body::Incoming>;
    type AxumReq = http::Request<axum::body::Body>;
    type AxumResp = http::Response<axum::body::Body>;
    type TraceResp = http::Response<
        tower_http::trace::ResponseBody<
            axum::body::Body,
            tower_http::classify::NeverClassifyEos<anyhow::Error>,
            (),
            trace::server::LxOnEos,
            trace::server::LxOnFailure,
        >,
    >;

    // Middleware wrapping the whole router service. With `ServiceBuilder`,
    // layers added first are outermost. So the trace layer sees every request
    // first. `post_process_response` rewrites the router's response before
    // the trace layer wraps its body.
    let outer_middleware = tower::ServiceBuilder::new()
        .check_service::<HyperService, HyperReq, AxumResp, Infallible>()
        .layer(trace::server::trace_layer(server_span.clone()))
        .check_service::<HyperService, HyperReq, TraceResp, Infallible>()
        .layer(tower::util::MapResponseLayer::new(
            middleware::post_process_response,
        ))
        .check_service::<HyperService, HyperReq, TraceResp, Infallible>();

    // Middleware applied inside the router via `Router::layer`, listed from
    // outermost to innermost.
    let inner_middleware = tower::ServiceBuilder::new()
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Reject early if a parseable `Content-Length` exceeds the body limit.
        .layer(axum::middleware::map_request_with_state(
            layer_config.body_limit,
            middleware::check_content_length_header,
        ))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Set the default body limit used by body extractors...
        .layer(DefaultBodyLimit::max(layer_config.body_limit))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // ...and apply it to the request body itself.
        .layer(MapRequestLayer::new(axum::RequestExt::with_limited_body))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // `LoadShedLayer` fails requests immediately when the buffer below is
        // not ready. Every error reaching this `HandleErrorLayer` is reported
        // as `AtCapacity`.
        .layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
            CommonApiError {
                kind: CommonErrorKind::AtCapacity,
                msg: "Service is at capacity; retry later".to_owned(),
            }
        }))
        .layer(LoadShedLayer::new())
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Queue up to `buffer_size` requests in front of the concurrency limit.
        .layer(BufferLayer::new(layer_config.buffer_size))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        .layer(ConcurrencyLimitLayer::new(layer_config.concurrency))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Handler timeout; the elapsed error is turned into a Server error.
        .layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
            CommonApiError {
                kind: CommonErrorKind::Server,
                msg: "Server timed out handling request".to_owned(),
            }
        }))
        .layer(TimeoutLayer::new(layer_config.handling_timeout))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>();

    let layered_router = router.layer(inner_middleware);
    // Turn the router into a service over hyper's body type so the outer
    // middleware (typed against `HyperReq`) can wrap it.
    let router_service = layered_router.into_service::<hyper::body::Incoming>();
    let layered_service = Layer::layer(&outer_middleware, router_service);
    let make_service = layered_service.into_make_service();

    // `handle` is kept to trigger graceful shutdown; the clone goes to the
    // server.
    let handle = axum_server::Handle::new();
    let handle_clone = handle.clone();
    let server_fut = async {
        let serve_result = match maybe_tls_config {
            Some(tls_config) => {
                let axum_tls_config = RustlsConfig::from_config(tls_config);
                axum_server::from_tcp_rustls(listener, axum_tls_config)
                    .handle(handle_clone)
                    .serve(make_service)
                    .await
            }
            None =>
                axum_server::from_tcp(listener)
                    .handle(handle_clone)
                    .serve(make_service)
                    .await,
        };

        serve_result
            .expect("No binding + axum MakeService::poll_ready never errors");
    };

    // Waits for the shutdown signal, then asks the server to shut down
    // gracefully within `SHUTDOWN_GRACE_PERIOD`.
    let graceful_shutdown_fut = async move {
        shutdown.recv().await;
        info!("Shutting down API server");
        handle.graceful_shutdown(Some(SHUTDOWN_GRACE_PERIOD));
    };

    let combined_fut = async {
        tokio::pin!(server_fut);
        tokio::select! {
            // `biased`: branches are polled in declaration order, shutdown
            // first.
            biased;
            () = graceful_shutdown_fut => (),
            // `error!` evaluates to `()`, so this logs and returns early.
            _ = &mut server_fut => return error!("Server exited early"),
        }
        // Shutdown was requested; give the server a bounded time to finish.
        match tokio::time::timeout(SERVER_SHUTDOWN_TIMEOUT, server_fut).await {
            Ok(()) => info!("API server finished"),
            Err(_) => warn!("API server timed out during shutdown"),
        }
    }
    .instrument(server_span);

    Ok((combined_fut, primary_server_url))
}
409
410pub fn spawn_server_task(
414 bind_addr: SocketAddr,
415 router: Router<()>,
416 layer_config: LayerConfig,
417 maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
419 server_span_name: Cow<'static, str>,
420 server_span: tracing::Span,
421 shutdown: NotifyOnce,
423) -> anyhow::Result<(LxTask<()>, String)> {
424 let listener = TcpListener::bind(bind_addr)
425 .context(bind_addr)
426 .context("Failed to bind TcpListener")?;
427
428 let (server_task, primary_server_url) = spawn_server_task_with_listener(
429 listener,
430 router,
431 layer_config,
432 maybe_tls_and_dns,
433 server_span_name,
434 server_span,
435 shutdown,
436 )
437 .context("spawn_server_task_with_listener failed")?;
438
439 Ok((server_task, primary_server_url))
440}
441
442pub fn spawn_server_task_with_listener(
444 listener: TcpListener,
445 router: Router<()>,
446 layer_config: LayerConfig,
447 maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
449 server_span_name: Cow<'static, str>,
450 server_span: tracing::Span,
451 shutdown: NotifyOnce,
453) -> anyhow::Result<(LxTask<()>, String)> {
454 let (server_fut, primary_server_url) = build_server_fut_with_listener(
455 listener,
456 router,
457 layer_config,
458 maybe_tls_and_dns,
459 &server_span_name,
460 server_span.clone(),
461 shutdown,
462 )
463 .context("Failed to build server future")?;
464
465 let server_task =
466 LxTask::spawn_with_span(server_span_name, server_span, server_fut);
467
468 Ok((server_task, primary_server_url))
469}
470
/// JSON extractor and response type, analogous to `axum::Json`.
///
/// Extraction delegates to `axum::Json` and converts failures into an
/// [`LxRejection`]. Responses are built with
/// `axum_helpers::build_json_response` and status `200 OK`.
pub struct LxJson<T>(pub T);
493
494impl<T: DeserializeOwned, S: Send + Sync> FromRequest<S> for LxJson<T> {
495 type Rejection = LxRejection;
496
497 async fn from_request(
498 req: http::Request<axum::body::Body>,
499 state: &S,
500 ) -> Result<Self, Self::Rejection> {
501 <axum::Json<T> as FromRequest<S>>::from_request(req, state)
503 .await
504 .map(|axum::Json(t)| Self(t))
505 .map_err(LxRejection::from)
506 }
507}
508
509impl<T: DeserializeOwned, S: Send + Sync> OptionalFromRequest<S> for LxJson<T> {
510 type Rejection = LxRejection;
511
512 async fn from_request(
513 req: http::Request<axum::body::Body>,
514 state: &S,
515 ) -> Result<Option<Self>, Self::Rejection> {
516 <axum::Json<T> as OptionalFromRequest<S>>::from_request(req, state)
517 .await
518 .map(|opt| opt.map(|axum::Json(t)| Self(t)))
519 .map_err(LxRejection::from)
520 }
521}
522
523impl<T: Serialize> IntoResponse for LxJson<T> {
524 fn into_response(self) -> http::Response<axum::body::Body> {
525 axum_helpers::build_json_response(StatusCode::OK, &self.0)
526 }
527}
528
529impl<T: Clone> Clone for LxJson<T> {
530 fn clone(&self) -> Self {
531 Self(self.0.clone())
532 }
533}
534
535impl<T: Copy> Copy for LxJson<T> {}
536
537impl<T: fmt::Debug> fmt::Debug for LxJson<T> {
538 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
539 T::fmt(&self.0, f)
540 }
541}
542
543impl<T: Eq + PartialEq> Eq for LxJson<T> {}
544
545impl<T: PartialEq> PartialEq for LxJson<T> {
546 fn eq(&self, other: &Self) -> bool {
547 self.0.eq(&other.0)
548 }
549}
550
/// Raw-bytes extractor and response type wrapping [`Bytes`].
///
/// Extraction failures become an [`LxRejection`]. Responses are sent as
/// `200 OK` with `Content-Type: application/octet-stream`.
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
pub struct LxBytes(pub Bytes);
579
580impl<S: Send + Sync> FromRequest<S> for LxBytes {
581 type Rejection = LxRejection;
582
583 async fn from_request(
584 req: http::Request<axum::body::Body>,
585 state: &S,
586 ) -> Result<Self, Self::Rejection> {
587 Bytes::from_request(req, state)
589 .await
590 .map(Self)
591 .map_err(LxRejection::from)
592 }
593}
594
595impl IntoResponse for LxBytes {
598 fn into_response(self) -> http::Response<axum::body::Body> {
599 let http_body = http_body_util::Full::new(self.0);
600 let axum_body = axum::body::Body::new(http_body);
601
602 axum_helpers::default_response_builder()
603 .header(
604 CONTENT_TYPE,
605 HeaderValue::from_static("application/octet-stream"),
607 )
608 .status(StatusCode::OK)
609 .body(axum_body)
610 .expect("All operations here should be infallible")
611 }
612}
613
614impl<T: Into<Bytes>> From<T> for LxBytes {
615 fn from(t: T) -> Self {
616 Self(t.into())
617 }
618}
619
/// Rejection produced by Lexe's extractors, middleware and fallback handler.
///
/// When converted into a response, it is logged at `warn` level and returned
/// as a `CommonApiError` of kind `CommonErrorKind::Rejection`.
pub struct LxRejection {
    /// Category of the rejection; selects the message prefix (`to_msg`).
    kind: LxRejectionKind,
    /// Details from the underlying error, e.g. axum's rejection body text.
    source_msg: String,
}
630
/// Categories of [`LxRejection`].
enum LxRejectionKind {
    // --- Wrapped axum extractor rejections --- //
    /// From a `BytesRejection`.
    Bytes,
    /// From a `JsonRejection`.
    Json,
    /// From a `PathRejection`.
    Path,
    /// From a `QueryRejection`.
    Query,

    // --- Lexe-specific rejections --- //
    /// Bearer auth failed (`LxRejection::from_bearer_auth`).
    Unauthenticated,
    /// Granted and requested scopes did not match
    /// (`LxRejection::scope_unauthorized`).
    Unauthorized,
    /// No route matched the request (`default_fallback`).
    BadEndpoint,
    /// The `Content-Length` header exceeded the configured body limit.
    BodyLengthOverLimit,
    /// An `ed25519::Error` (`LxRejection::from_ed25519`).
    Ed25519,
    /// Constructed via `LxRejection::proxy`.
    Proxy,
}
657
658impl LxRejection {
660 pub fn from_ed25519(error: ed25519::Error) -> Self {
661 Self {
662 kind: LxRejectionKind::Ed25519,
663 source_msg: format!("{error:#}"),
664 }
665 }
666
667 pub fn from_bearer_auth(error: auth::Error) -> Self {
668 Self {
669 kind: LxRejectionKind::Unauthenticated,
670 source_msg: format!("{error:#}"),
671 }
672 }
673
674 pub fn scope_unauthorized(
675 granted_scope: &Scope,
676 requested_scope: &Scope,
677 ) -> Self {
678 Self {
679 kind: LxRejectionKind::Unauthorized,
680 source_msg: format!(
681 "granted scope: {granted_scope:?}, requested scope: {requested_scope:?}"
682 ),
683 }
684 }
685
686 pub fn proxy(error: impl Display) -> Self {
687 Self {
688 kind: LxRejectionKind::Proxy,
689 source_msg: format!("{error:#}"),
690 }
691 }
692}
693
694impl From<BytesRejection> for LxRejection {
695 fn from(bytes_rejection: BytesRejection) -> Self {
696 Self {
697 kind: LxRejectionKind::Bytes,
698 source_msg: bytes_rejection.body_text(),
699 }
700 }
701}
702
703impl From<JsonRejection> for LxRejection {
704 fn from(json_rejection: JsonRejection) -> Self {
705 Self {
706 kind: LxRejectionKind::Json,
707 source_msg: json_rejection.body_text(),
708 }
709 }
710}
711
712impl From<PathRejection> for LxRejection {
713 fn from(path_rejection: PathRejection) -> Self {
714 Self {
715 kind: LxRejectionKind::Path,
716 source_msg: path_rejection.body_text(),
717 }
718 }
719}
720
721impl From<QueryRejection> for LxRejection {
722 fn from(query_rejection: QueryRejection) -> Self {
723 Self {
724 kind: LxRejectionKind::Query,
725 source_msg: query_rejection.body_text(),
726 }
727 }
728}
729
730impl IntoResponse for LxRejection {
731 fn into_response(self) -> http::Response<axum::body::Body> {
732 let kind = CommonErrorKind::Rejection;
735 let kind_msg = self.kind.to_msg();
737 let source_msg = &self.source_msg;
738 let msg = format!("Rejection: {kind_msg}: {source_msg}");
739 warn!("{msg}");
741 let common_error = CommonApiError { kind, msg };
742 common_error.into_response()
743 }
744}
745
746impl LxRejectionKind {
747 fn to_msg(&self) -> &'static str {
749 match self {
750 Self::Bytes => "Bad request bytes",
751 Self::Json => "Client provided bad JSON",
752 Self::Path => "Client provided bad path parameter",
753 Self::Query => "Client provided bad query string",
754
755 Self::Unauthenticated => "Invalid bearer auth",
756 Self::Unauthorized => "Not authorized to access this resource",
757 Self::BadEndpoint => "Client requested a non-existent endpoint",
758 Self::BodyLengthOverLimit => "Request body length over limit",
759 Self::Ed25519 => "Ed25519 error",
760 Self::Proxy => "Proxy error",
761 }
762 }
763}
764
765pub mod extract {
768 use axum::extract::FromRequestParts;
769
770 use super::*;
771
772 pub struct LxQuery<T>(pub T);
774
775 impl<T: DeserializeOwned, S: Send + Sync> FromRequestParts<S> for LxQuery<T> {
776 type Rejection = LxRejection;
777
778 async fn from_request_parts(
779 parts: &mut http::request::Parts,
780 state: &S,
781 ) -> Result<Self, Self::Rejection> {
782 axum::extract::Query::from_request_parts(parts, state)
783 .await
784 .map(|axum::extract::Query(t)| Self(t))
785 .map_err(LxRejection::from)
786 }
787 }
788
789 impl<T: Clone> Clone for LxQuery<T> {
790 fn clone(&self) -> Self {
791 Self(self.0.clone())
792 }
793 }
794
795 impl<T: fmt::Debug> fmt::Debug for LxQuery<T> {
796 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
797 T::fmt(&self.0, f)
798 }
799 }
800
801 impl<T: Eq + PartialEq> Eq for LxQuery<T> {}
802
803 impl<T: PartialEq> PartialEq for LxQuery<T> {
804 fn eq(&self, other: &Self) -> bool {
805 self.0.eq(&other.0)
806 }
807 }
808
809 pub struct LxPath<T>(pub T);
811
812 impl<T: DeserializeOwned + Send, S: Send + Sync> FromRequestParts<S>
813 for LxPath<T>
814 {
815 type Rejection = LxRejection;
816
817 async fn from_request_parts(
818 parts: &mut http::request::Parts,
819 state: &S,
820 ) -> Result<Self, Self::Rejection> {
821 axum::extract::Path::from_request_parts(parts, state)
822 .await
823 .map(|axum::extract::Path(t)| Self(t))
824 .map_err(LxRejection::from)
825 }
826 }
827
828 impl<T: Clone> Clone for LxPath<T> {
829 fn clone(&self) -> Self {
830 Self(self.0.clone())
831 }
832 }
833
834 impl<T: fmt::Debug> fmt::Debug for LxPath<T> {
835 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
836 T::fmt(&self.0, f)
837 }
838 }
839
840 impl<T: Eq + PartialEq> Eq for LxPath<T> {}
841
842 impl<T: PartialEq> PartialEq for LxPath<T> {
843 fn eq(&self, other: &Self) -> bool {
844 self.0.eq(&other.0)
845 }
846 }
847}
848
849pub mod middleware {
852 use axum::extract::State;
853 use http::HeaderName;
854
855 use super::*;
856
857 pub static POST_PROCESS_HEADER: HeaderName =
859 HeaderName::from_static("lx-post-process");
860
861 pub async fn check_content_length_header<B>(
868 State(config_body_limit): State<usize>,
870 request: http::Request<B>,
871 ) -> Result<http::Request<B>, LxRejection> {
872 let maybe_content_length = request
873 .headers()
874 .get(http::header::CONTENT_LENGTH)
875 .and_then(|value| value.to_str().ok())
876 .and_then(|value_str| usize::from_str(value_str).ok());
877
878 if let Some(content_length) = maybe_content_length
880 && content_length > config_body_limit
881 {
882 return Err(LxRejection {
883 kind: LxRejectionKind::BodyLengthOverLimit,
884 source_msg: "Content length header over limit".to_owned(),
885 });
886 }
887
888 Ok(request)
889 }
890
891 pub(super) fn post_process_response(
899 mut response: http::Response<axum::body::Body>,
900 ) -> http::Response<axum::body::Body> {
901 let value = match response.headers_mut().remove(&POST_PROCESS_HEADER) {
902 Some(v) => v,
903 None => return response,
904 };
905
906 match value.as_bytes() {
907 b"remove-content-length" => {
908 response.headers_mut().remove(http::header::CONTENT_LENGTH);
909 debug!("Post process: Removed content-length header");
910 }
911 unknown => {
912 let unknown_str = String::from_utf8_lossy(unknown);
913 warn!("Post process: Invalid header value: {unknown_str}");
914 }
915 }
916
917 response
918 }
919}
920
921pub async fn default_fallback(
926 method: http::Method,
927 uri: http::Uri,
928) -> LxRejection {
929 let path = uri.path();
930 LxRejection {
931 kind: LxRejectionKind::BadEndpoint,
932 source_msg: format!("{method} {path}"),
934 }
935}