lexe_api/server.rs
1// This is the only place where we are allowed to use e.g. `Json` and `Query`.
2#![allow(clippy::disallowed_types)]
3
4//! This module provides various API server utilities.
5//!
6//! # Serving
7//!
8//! Methods to serve a [`Router`] with a fallback handler (for unmatched paths),
9//! tracing / request instrumentation, backpressure, load shedding, concurrency
10//! limits, server-side timeouts, TLS, and graceful shutdown:
11//!
12//! - [`build_server_fut`]
13//! - [`build_server_fut_with_listener`]
14//! - [`spawn_server_task`]
15//! - [`spawn_server_task_with_listener`]
16//!
//! # Extractors to get data from requests
//!
//! - [`LxJson`] to deserialize from HTTP body JSON
//! - [`LxBytes`] to extract the raw HTTP body bytes
//! - [`LxQuery`] to deserialize from query strings
//! - [`LxPath`] to deserialize from path parameters
//!
//! # [`IntoResponse`] types / impls for building Lexe API-conformant responses
//!
//! - [`LxJson`] type for returning success responses as JSON
//! - [`LxBytes`] type for returning success responses as raw bytes
//! - All [`ApiError`]s and [`CommonApiError`] impl [`IntoResponse`]
//! - [`LxRejection`] for notifying clients of bad JSON, query strings, etc.
//!
//! [`ApiError`]: lexe_api_core::error::ApiError
//! [`CommonApiError`]: lexe_api_core::error::CommonApiError
//! [`Router`]: axum::Router
//! [`IntoResponse`]: axum::response::IntoResponse
//! [`LxJson`]: crate::server::LxJson
//! [`LxBytes`]: crate::server::LxBytes
//! [`LxQuery`]: crate::server::extract::LxQuery
//! [`LxPath`]: crate::server::extract::LxPath
//! [`LxRejection`]: crate::server::LxRejection
//! [`build_server_fut`]: crate::server::build_server_fut
//! [`build_server_fut_with_listener`]: crate::server::build_server_fut_with_listener
//! [`spawn_server_task`]: crate::server::spawn_server_task
//! [`spawn_server_task_with_listener`]: crate::server::spawn_server_task_with_listener
38//! [`spawn_server_task_with_listener`]: crate::server::spawn_server_task_with_listener
39
40use std::{
41 borrow::Cow,
42 convert::Infallible,
43 fmt::{self, Display},
44 future::Future,
45 net::{SocketAddr, TcpListener},
46 str::FromStr,
47 sync::Arc,
48 time::Duration,
49};
50
51use anyhow::Context;
52use axum::{
53 Router, ServiceExt as AxumServiceExt,
54 error_handling::HandleErrorLayer,
55 extract::{
56 DefaultBodyLimit, FromRequest,
57 rejection::{
58 BytesRejection, JsonRejection, PathRejection, QueryRejection,
59 },
60 },
61 response::IntoResponse,
62 routing::RouterIntoService,
63};
64use axum_server::tls_rustls::RustlsConfig;
65use bytes::Bytes;
66use http::{HeaderValue, StatusCode, header::CONTENT_TYPE};
67use lexe_api_core::{
68 axum_helpers,
69 error::{CommonApiError, CommonErrorKind},
70};
71use lexe_common::api::auth::{self, Scope};
72use lexe_crypto::ed25519;
73use lexe_tokio::{notify_once::NotifyOnce, task::LxTask};
74use serde::{Serialize, de::DeserializeOwned};
75use tower::{
76 Layer, buffer::BufferLayer, limit::ConcurrencyLimitLayer,
77 load_shed::LoadShedLayer, timeout::TimeoutLayer, util::MapRequestLayer,
78};
79use tracing::{Instrument, debug, error, info, warn};
80
81use crate::{rest, trace};
82
83/// The grace period passed to [`axum_server::Handle::graceful_shutdown`] during
84/// which new connections are refused and we wait for existing connections to
85/// terminate before initiating a hard shutdown.
86const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(3);
87/// The maximum time we'll wait for a server to complete shutdown.
88pub const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
89lexe_std::const_assert!(
90 SHUTDOWN_GRACE_PERIOD.as_secs() < SERVER_SHUTDOWN_TIMEOUT.as_secs()
91);
92
93/// The default maximum time a server can spend handling a request.
94pub const SERVER_HANDLER_TIMEOUT: Duration = Duration::from_secs(25);
95lexe_std::const_assert!(
96 rest::API_REQUEST_TIMEOUT.as_secs() > SERVER_HANDLER_TIMEOUT.as_secs()
97);
98
99/// A configuration object for Axum / Tower middleware.
100///
101/// Defaults:
102///
103/// ```
104/// # use std::time::Duration;
105/// # use lexe_api::server::LayerConfig;
106/// assert_eq!(
107/// LayerConfig::default(),
108/// LayerConfig {
109/// body_limit: 16384,
110/// buffer_size: 4096,
111/// concurrency: 4096,
112/// handling_timeout: Duration::from_secs(25),
113/// default_fallback: true,
114/// }
115/// );
116/// ```
117#[derive(Clone, Debug, Eq, PartialEq)]
118pub struct LayerConfig {
119 /// The maximum size of the request body in bytes.
120 /// Helps prevent DoS, but may need to be increased for some services.
121 pub body_limit: usize,
122 /// The size of the work buffer for our service.
123 /// Allows the server to immediately work on more queued requests when a
124 /// request completes, and prevents a large backlog from building up.
125 pub buffer_size: usize,
126 /// The maximum # of requests we'll process at once.
127 /// Helps prevent the CPU from maxing out, resulting in thrashing.
128 pub concurrency: usize,
129 /// The maximum time a server can spend handling a request. Helps prevent
130 /// degenerate cases which take abnormally long to process from crowding
131 /// out normal workloads.
132 pub handling_timeout: Duration,
133 /// Whether to add Lexe's default [`Router::fallback`] to the [`Router`].
134 /// The [`Router::fallback`] is called if no routes were matched;
135 /// Lexe's [`default_fallback`] returns a "bad endpoint" rejection along
136 /// with the requested method and path.
137 ///
138 /// If you need to set a custom fallback, set this to [`false`], otherwise
139 /// your custom fallback will be clobbered by Lexe's [`default_fallback`].
140 /// NOTE, however, that the caller is responsible for ensuring that the
141 /// [`Router`] has a fallback configured in this case.
142 pub default_fallback: bool,
143}
144
145impl Default for LayerConfig {
146 fn default() -> Self {
147 Self {
148 // 16KiB is sufficient for most Lexe services.
149 body_limit: 16384,
150 // TODO(max): We are using very high values right now because it
151 // doesn't make sense to constrain anything until we have run some
152 // load tests to profile performance and see what breaks.
153 buffer_size: 4096,
154 concurrency: 4096,
155 handling_timeout: SERVER_HANDLER_TIMEOUT,
156 default_fallback: true,
157 }
158 }
159}
160
161// --- Server helpers --- //
162
163/// Construct a server URL given the [`TcpListener::local_addr`] from by a
164/// server's [`TcpListener`], and its DNS name.
165///
166/// ex: `https://lexe.app` (port=443)
167/// ex: `https://relay.lexe.app:4396`
168/// ex: `http://[::1]:8080`
169//
170// We have a fn to build the url because it's easy to mess up.
171pub fn build_server_url(
172 // The output of `TcpListener::local_addr`
173 listener_addr: SocketAddr,
174 // Primary DNS name
175 maybe_dns: Option<&str>,
176) -> String {
177 match maybe_dns {
178 Some(dns_name) => {
179 let port = listener_addr.port();
180 if port == 443 {
181 format!("https://{dns_name}")
182 } else {
183 format!("https://{dns_name}:{port}")
184 }
185 }
186 None => format!("http://{listener_addr}"),
187 }
188}
189
190/// Constructs an API server future which can be spawned into a task.
191/// Additionally returns the server url.
192///
193/// Use this helper when it is useful to poll multiple futures in a single task
194/// to reduce the amount of task nesting / indirection. If there is only one
195/// future that needs to be driven, use [`spawn_server_task`] instead.
196///
197/// Errors if the [`TcpListener`] failed to bind or return its local address.
198/// Returns the server future along with the bound socket address.
199// Avoids generic parameters to prevent binary bloat.
200// Returns unnamed `impl Future` to avoid Pin<Box<T>> deref cost.
201pub fn build_server_fut(
202 bind_addr: SocketAddr,
203 router: Router<()>,
204 layer_config: LayerConfig,
205 // TLS config + primary DNS name
206 maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
207 server_span_name: &str,
208 server_span: tracing::Span,
209 // Send on this channel to begin a graceful shutdown of the server.
210 shutdown: NotifyOnce,
211) -> anyhow::Result<(impl Future<Output = ()>, String)> {
212 let listener =
213 TcpListener::bind(bind_addr).context("Could not bind TCP listener")?;
214 let (server_fut, primary_server_url) = build_server_fut_with_listener(
215 listener,
216 router,
217 layer_config,
218 maybe_tls_and_dns,
219 server_span_name,
220 server_span,
221 shutdown,
222 )
223 .context("Could not build server future")?;
224 Ok((server_fut, primary_server_url))
225}
226
227/// [`build_server_fut`] but takes a [`TcpListener`] instead of [`SocketAddr`].
228// Avoids generic parameters to prevent binary bloat.
229// Returns unnamed `impl Future` to avoid Pin<Box<T>> deref cost.
230pub fn build_server_fut_with_listener(
231 listener: TcpListener,
232 router: Router<()>,
233 layer_config: LayerConfig,
234 // TLS config + primary DNS name
235 maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
236 server_span_name: &str,
237 server_span: tracing::Span,
238 // Send on this channel to begin a graceful shutdown of the server.
239 mut shutdown: NotifyOnce,
240) -> anyhow::Result<(impl Future<Output = ()> + use<>, String)> {
241 let (maybe_tls_config, maybe_dns) = maybe_tls_and_dns.unzip();
242 let listener_addr = listener
243 .local_addr()
244 .context("Could not get listener local address")?;
245 let primary_server_url = build_server_url(listener_addr, maybe_dns);
246 info!("Url for {server_span_name}: {primary_server_url}");
247
248 // Add Lexe's default fallback if it is enabled in the LayerConfig.
249 let router = if layer_config.default_fallback {
250 router.fallback(default_fallback)
251 } else {
252 router
253 };
254
255 // Used to annotate the service / request / response types
256 // at each point in the ServiceBuilder chains.
257 type HyperService = RouterIntoService<hyper::body::Incoming, ()>;
258 type AxumService = RouterIntoService<axum::body::Body, ()>;
259 type HyperReq = http::Request<hyper::body::Incoming>;
260 type AxumReq = http::Request<axum::body::Body>;
261 type AxumResp = http::Response<axum::body::Body>;
262 type TraceResp = http::Response<
263 tower_http::trace::ResponseBody<
264 axum::body::Body,
265 tower_http::classify::NeverClassifyEos<anyhow::Error>,
266 (),
267 trace::server::LxOnEos,
268 trace::server::LxOnFailure,
269 >,
270 >;
271
272 // The outer middleware stack which wraps the entire Router.
273 //
274 // Axum docs explain ordering better than tower's ServiceBuilder docs do:
275 // https://docs.rs/axum/latest/axum/middleware/index.html#ordering
276 // Basically, requests go from top to bottom and responses bottom to top.
277 let outer_middleware = tower::ServiceBuilder::new()
278 .check_service::<HyperService, HyperReq, AxumResp, Infallible>()
279 // Log everything on its way in and out, even load-shedded requests.
280 // This layer changes the response type.
281 .layer(trace::server::trace_layer(server_span.clone()))
282 .check_service::<HyperService, HyperReq, TraceResp, Infallible>()
283 // Run our post-processor which can modify responses *after* the Axum
284 // Router has constructed the response.
285 .layer(tower::util::MapResponseLayer::new(
286 middleware::post_process_response,
287 ))
288 .check_service::<HyperService, HyperReq, TraceResp, Infallible>();
289
290 // The inner middleware stack which is cloned to each route in the Router.
291 // We put most of the layers here because it is a lot easier to work with
292 // axum types; moving these outside quickly degenerates into type hell.
293 let inner_middleware = tower::ServiceBuilder::new()
294 .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
295 // Immediately reject anything with a CONTENT_LENGTH over the limit.
296 .layer(axum::middleware::map_request_with_state(
297 layer_config.body_limit,
298 middleware::check_content_length_header,
299 ))
300 .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
301 // Set the default request body limit for all requests. This adds a
302 // `DefaultBodyLimitKind` (private axum type) into the request
303 // extensions so that any inner layers or extractors which call
304 // `axum::RequestExt::[with|into]_limited_body` will pick it up.
305 // NOTE that many of our extractors transitively rely on the Bytes
306 // extractor which will default to a 2MB limit if this is not set.
307 .layer(DefaultBodyLimit::max(layer_config.body_limit))
308 .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
309 // Here, we explicitly apply the body limit from the request extensions,
310 // transforming the request body type into `http_body_util::Limited`.
311 .layer(MapRequestLayer::new(axum::RequestExt::with_limited_body))
312 .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
313 // Handles errors from the load_shed, buffer, and concurrency layers.
314 .layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
315 CommonApiError {
316 kind: CommonErrorKind::AtCapacity,
317 msg: "Service is at capacity; retry later".to_owned(),
318 }
319 }))
320 // Returns an `Err` if the inner service returns `Poll::Pending`.
321 // Helps prevent OOM when combined with the buffer or concurrency layer.
322 .layer(LoadShedLayer::new())
323 .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
324 // Returns Poll::Pending when the buffer is full (backpressure).
325 // Allows the server to immediately work on more queued requests when a
326 // request completes, and prevents a large backlog from building up.
327 // Note that while the layer is often cloned, the buffer itself is not.
328 .layer(BufferLayer::new(layer_config.buffer_size))
329 .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
330 // Returns `Poll::Pending` when the concurrency limit has been reached.
331 // Helps prevent the CPU from maxing out, resulting in thrashing.
332 .layer(ConcurrencyLimitLayer::new(layer_config.concurrency))
333 .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
334 // Handles errors generated by the timeout layer.
335 .layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
336 CommonApiError {
337 kind: CommonErrorKind::Server,
338 msg: "Server timed out handling request".to_owned(),
339 }
340 }))
341 // Returns an error if the inner service takes longer than the timeout
342 // to handle the request. Prevents degenerate cases which take
343 // abnormally long to process from crowding out normal workloads.
344 .layer(TimeoutLayer::new(layer_config.handling_timeout))
345 .check_service::<AxumService, AxumReq, AxumResp, Infallible>();
346
347 // Apply inner middleware
348 let layered_router = router.layer(inner_middleware);
349 // Convert into Service
350 let router_service = layered_router.into_service::<hyper::body::Incoming>();
351 // Apply outer middleware
352 let layered_service = Layer::layer(&outer_middleware, router_service);
353 // Convert into MakeService
354 let make_service = layered_service.into_make_service();
355
356 let handle = axum_server::Handle::new();
357 let handle_clone = handle.clone();
358 let server_fut = async {
359 let serve_result = match maybe_tls_config {
360 Some(tls_config) => {
361 let axum_tls_config = RustlsConfig::from_config(tls_config);
362 axum_server::from_tcp_rustls(listener, axum_tls_config)
363 .handle(handle_clone)
364 .serve(make_service)
365 .await
366 }
367 None =>
368 axum_server::from_tcp(listener)
369 .handle(handle_clone)
370 .serve(make_service)
371 .await,
372 };
373
374 serve_result
375 // See axum_server::Server::serve docs for why this can't error
376 .expect("No binding + axum MakeService::poll_ready never errors");
377 };
378
379 let graceful_shutdown_fut = async move {
380 shutdown.recv().await;
381 info!("Shutting down API server");
382 // The 'grace period' is a period of time during which new connections
383 // are refused and `axum_server::Server::serve` waits for all current
384 // connections to terminate. If `None`, the server waits indefinitely
385 // for current connections to terminate; if `Some`, the server will
386 // initiate a hard shutdown after the grace period has elapsed. We use
387 // Some(_) with a relatively short grace period because (1) our handlers
388 // shouldn't take long to return and (2) we sometimes see connections
389 // failing to terminate for servers which have a /shutdown endpoint.
390 handle.graceful_shutdown(Some(SHUTDOWN_GRACE_PERIOD));
391 };
392
393 let combined_fut = async {
394 tokio::pin!(server_fut);
395 tokio::select! {
396 biased; // Ensure graceful shutdown future finishes first
397 () = graceful_shutdown_fut => (),
398 _ = &mut server_fut => return error!("Server exited early"),
399 }
400 match tokio::time::timeout(SERVER_SHUTDOWN_TIMEOUT, server_fut).await {
401 Ok(()) => info!("API server finished"),
402 Err(_) => warn!("API server timed out during shutdown"),
403 }
404 }
405 .instrument(server_span);
406
407 Ok((combined_fut, primary_server_url))
408}
409
410/// [`build_server_fut`] but additionally spawns the server future into an
411/// instrumented server task and logs the full URL used to access the server.
412/// Returns the server task and server url.
413pub fn spawn_server_task(
414 bind_addr: SocketAddr,
415 router: Router<()>,
416 layer_config: LayerConfig,
417 // TLS config + primary DNS name
418 maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
419 server_span_name: Cow<'static, str>,
420 server_span: tracing::Span,
421 // Send on this channel to begin a graceful shutdown of the server.
422 shutdown: NotifyOnce,
423) -> anyhow::Result<(LxTask<()>, String)> {
424 let listener = TcpListener::bind(bind_addr)
425 .context(bind_addr)
426 .context("Failed to bind TcpListener")?;
427
428 let (server_task, primary_server_url) = spawn_server_task_with_listener(
429 listener,
430 router,
431 layer_config,
432 maybe_tls_and_dns,
433 server_span_name,
434 server_span,
435 shutdown,
436 )
437 .context("spawn_server_task_with_listener failed")?;
438
439 Ok((server_task, primary_server_url))
440}
441
442/// [`spawn_server_task`] but takes [`TcpListener`] instead of [`SocketAddr`].
443pub fn spawn_server_task_with_listener(
444 listener: TcpListener,
445 router: Router<()>,
446 layer_config: LayerConfig,
447 // TLS config + primary DNS name
448 maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
449 server_span_name: Cow<'static, str>,
450 server_span: tracing::Span,
451 // Send on this channel to begin a graceful shutdown of the server.
452 shutdown: NotifyOnce,
453) -> anyhow::Result<(LxTask<()>, String)> {
454 let (server_fut, primary_server_url) = build_server_fut_with_listener(
455 listener,
456 router,
457 layer_config,
458 maybe_tls_and_dns,
459 &server_span_name,
460 server_span.clone(),
461 shutdown,
462 )
463 .context("Failed to build server future")?;
464
465 let server_task =
466 LxTask::spawn_with_span(server_span_name, server_span, server_fut);
467
468 Ok((server_task, primary_server_url))
469}
470
471// --- LxJson --- //
472
473/// A version of [`axum::Json`] which conforms to Lexe's (JSON) API.
474/// It can be used as either an extractor or a response.
475///
476/// - As an extractor: rejections return [`LxRejection`].
477/// - As a success response:
478/// - Serialization success returns an [`http::Response`] with JSON body.
479/// - Serialization failure returns a [`ErrorResponse`].
480///
481/// [`axum::Json`] is banned because:
482///
483/// - Rejections return [`JsonRejection`] which is just a string HTTP body.
484/// - Response serialization failures likewise return just a string body.
485///
486/// NOTE: This must only be used for forming *success* API responses,
487/// i.e. `T` in `Result<T, E>`, because its [`IntoResponse`] impl uses
488/// [`StatusCode::OK`]. Our API error types, while also serialized as JSON,
489/// have separate [`IntoResponse`] impls which return error statuses.
490///
491/// [`ErrorResponse`]: lexe_api_core::error::ErrorResponse
492pub struct LxJson<T>(pub T);
493
494impl<T: DeserializeOwned, S: Send + Sync> FromRequest<S> for LxJson<T> {
495 type Rejection = LxRejection;
496
497 async fn from_request(
498 req: http::Request<axum::body::Body>,
499 state: &S,
500 ) -> Result<Self, Self::Rejection> {
501 // `axum::Json`'s from_request impl is fine but its rejection is not
502 axum::Json::from_request(req, state)
503 .await
504 .map(|axum::Json(t)| Self(t))
505 .map_err(LxRejection::from)
506 }
507}
508
509impl<T: Serialize> IntoResponse for LxJson<T> {
510 fn into_response(self) -> http::Response<axum::body::Body> {
511 axum_helpers::build_json_response(StatusCode::OK, &self.0)
512 }
513}
514
515impl<T: Clone> Clone for LxJson<T> {
516 fn clone(&self) -> Self {
517 Self(self.0.clone())
518 }
519}
520
521impl<T: Copy> Copy for LxJson<T> {}
522
523impl<T: fmt::Debug> fmt::Debug for LxJson<T> {
524 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
525 T::fmt(&self.0, f)
526 }
527}
528
529impl<T: Eq + PartialEq> Eq for LxJson<T> {}
530
531impl<T: PartialEq> PartialEq for LxJson<T> {
532 fn eq(&self, other: &Self) -> bool {
533 self.0.eq(&other.0)
534 }
535}
536
537// --- LxBytes --- //
538
539/// A version of [`Bytes`] which conforms to Lexe's (binary) API.
540/// - [`axum`] has implementations of [`FromRequest`] and [`IntoResponse`] for
541/// [`Bytes`], but these implementations are not Lexe API-conformant.
542/// - This type can be used as either an extractor or a success response, and
543/// should always be used instead of [`Bytes`] in these server contexts.
544/// - It is still fine to use [`Bytes`] on the client side.
545///
546/// - As an extractor: rejections return [`LxRejection`].
547/// - As a success response:
548/// - Returns an [`http::Response`] with a binary body.
549///
550/// - Any failure encountered in extraction or creation should produce an
551/// [`ErrorResponse`].
552///
553/// The regular impls are non-conformant because:
554///
555/// - Rejections return [`BytesRejection`] which is just a string HTTP body.
556///
557/// NOTE: This must only be used for forming *success* API responses,
558/// i.e. `LxBytes` in `Result<LxBytes, E>`, because its [`IntoResponse`] impl
559/// uses [`StatusCode::OK`]. Our API error types are serialized as JSON and
560/// have separate [`IntoResponse`] impls which return error statuses.
561///
562/// [`ErrorResponse`]: lexe_api_core::error::ErrorResponse
563#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
564pub struct LxBytes(pub Bytes);
565
566impl<S: Send + Sync> FromRequest<S> for LxBytes {
567 type Rejection = LxRejection;
568
569 async fn from_request(
570 req: http::Request<axum::body::Body>,
571 state: &S,
572 ) -> Result<Self, Self::Rejection> {
573 // `Bytes`'s from_request impl is fine but its rejection is not
574 Bytes::from_request(req, state)
575 .await
576 .map(Self)
577 .map_err(LxRejection::from)
578 }
579}
580
581/// The [`Bytes`] [`IntoResponse`] impl is almost exactly the same,
582/// except it returns the wrong HTTP version.
583impl IntoResponse for LxBytes {
584 fn into_response(self) -> http::Response<axum::body::Body> {
585 let http_body = http_body_util::Full::new(self.0);
586 let axum_body = axum::body::Body::new(http_body);
587
588 axum_helpers::default_response_builder()
589 .header(
590 CONTENT_TYPE,
591 // Or `HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM)`
592 HeaderValue::from_static("application/octet-stream"),
593 )
594 .status(StatusCode::OK)
595 .body(axum_body)
596 .expect("All operations here should be infallible")
597 }
598}
599
600impl<T: Into<Bytes>> From<T> for LxBytes {
601 fn from(t: T) -> Self {
602 Self(t.into())
603 }
604}
605
606// --- LxRejection --- //
607
608/// Our own [`axum::extract::rejection`] type with an [`IntoResponse`] impl
609/// which conforms to Lexe's API. Contains the source rejection's error text.
610pub struct LxRejection {
611 /// Which [`axum::extract::rejection`] this [`LxRejection`] was built from.
612 kind: LxRejectionKind,
613 /// The error text of the source rejection, or additional context.
614 source_msg: String,
615}
616
617/// The source of this [`LxRejection`].
618enum LxRejectionKind {
619 // -- From `axum::extract::rejection` -- //
620 /// [`BytesRejection`]
621 Bytes,
622 /// [`JsonRejection`]
623 Json,
624 /// [`PathRejection`]
625 Path,
626 /// [`QueryRejection`]
627 Query,
628
629 // -- Other -- //
630 /// Bearer authentication failed
631 Unauthenticated,
632 /// Client is not authorized to access this resource
633 Unauthorized,
634 /// Client request did not match any paths in the [`Router`].
635 BadEndpoint,
636 /// Request body length over limit
637 BodyLengthOverLimit,
638 /// [`ed25519::Error`]
639 Ed25519,
640 /// Gateway proxy
641 Proxy,
642}
643
644// Use explicit `.map_err()`s instead of From impls for non-obvious conversions
645impl LxRejection {
646 pub fn from_ed25519(error: ed25519::Error) -> Self {
647 Self {
648 kind: LxRejectionKind::Ed25519,
649 source_msg: format!("{error:#}"),
650 }
651 }
652
653 pub fn from_bearer_auth(error: auth::Error) -> Self {
654 Self {
655 kind: LxRejectionKind::Unauthenticated,
656 source_msg: format!("{error:#}"),
657 }
658 }
659
660 pub fn scope_unauthorized(
661 granted_scope: &Scope,
662 requested_scope: &Scope,
663 ) -> Self {
664 Self {
665 kind: LxRejectionKind::Unauthorized,
666 source_msg: format!(
667 "granted scope: {granted_scope:?}, requested scope: {requested_scope:?}"
668 ),
669 }
670 }
671
672 pub fn proxy(error: impl Display) -> Self {
673 Self {
674 kind: LxRejectionKind::Proxy,
675 source_msg: format!("{error:#}"),
676 }
677 }
678}
679
680impl From<BytesRejection> for LxRejection {
681 fn from(bytes_rejection: BytesRejection) -> Self {
682 Self {
683 kind: LxRejectionKind::Bytes,
684 source_msg: bytes_rejection.body_text(),
685 }
686 }
687}
688
689impl From<JsonRejection> for LxRejection {
690 fn from(json_rejection: JsonRejection) -> Self {
691 Self {
692 kind: LxRejectionKind::Json,
693 source_msg: json_rejection.body_text(),
694 }
695 }
696}
697
698impl From<PathRejection> for LxRejection {
699 fn from(path_rejection: PathRejection) -> Self {
700 Self {
701 kind: LxRejectionKind::Path,
702 source_msg: path_rejection.body_text(),
703 }
704 }
705}
706
707impl From<QueryRejection> for LxRejection {
708 fn from(query_rejection: QueryRejection) -> Self {
709 Self {
710 kind: LxRejectionKind::Query,
711 source_msg: query_rejection.body_text(),
712 }
713 }
714}
715
716impl IntoResponse for LxRejection {
717 fn into_response(self) -> http::Response<axum::body::Body> {
718 // TODO(phlip9): authn+authz+badendpoint rejections should return
719 // standard status codes, not just 400.
720 let kind = CommonErrorKind::Rejection;
721 // "Bad JSON: Failed to deserialize the JSON body into the target type"
722 let kind_msg = self.kind.to_msg();
723 let source_msg = &self.source_msg;
724 let msg = format!("Rejection: {kind_msg}: {source_msg}");
725 // Log the rejection now since our trace layer can't access this info
726 warn!("{msg}");
727 let common_error = CommonApiError { kind, msg };
728 common_error.into_response()
729 }
730}
731
732impl LxRejectionKind {
733 /// A generic error message for this rejection kind.
734 fn to_msg(&self) -> &'static str {
735 match self {
736 Self::Bytes => "Bad request bytes",
737 Self::Json => "Client provided bad JSON",
738 Self::Path => "Client provided bad path parameter",
739 Self::Query => "Client provided bad query string",
740
741 Self::Unauthenticated => "Invalid bearer auth",
742 Self::Unauthorized => "Not authorized to access this resource",
743 Self::BadEndpoint => "Client requested a non-existent endpoint",
744 Self::BodyLengthOverLimit => "Request body length over limit",
745 Self::Ed25519 => "Ed25519 error",
746 Self::Proxy => "Proxy error",
747 }
748 }
749}
750
751// --- Extractors --- //
752
753pub mod extract {
754 use axum::extract::FromRequestParts;
755
756 use super::*;
757
758 /// Lexe API-compliant version of [`axum::extract::Query`].
759 pub struct LxQuery<T>(pub T);
760
761 impl<T: DeserializeOwned, S: Send + Sync> FromRequestParts<S> for LxQuery<T> {
762 type Rejection = LxRejection;
763
764 async fn from_request_parts(
765 parts: &mut http::request::Parts,
766 state: &S,
767 ) -> Result<Self, Self::Rejection> {
768 axum::extract::Query::from_request_parts(parts, state)
769 .await
770 .map(|axum::extract::Query(t)| Self(t))
771 .map_err(LxRejection::from)
772 }
773 }
774
775 impl<T: Clone> Clone for LxQuery<T> {
776 fn clone(&self) -> Self {
777 Self(self.0.clone())
778 }
779 }
780
781 impl<T: fmt::Debug> fmt::Debug for LxQuery<T> {
782 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
783 T::fmt(&self.0, f)
784 }
785 }
786
787 impl<T: Eq + PartialEq> Eq for LxQuery<T> {}
788
789 impl<T: PartialEq> PartialEq for LxQuery<T> {
790 fn eq(&self, other: &Self) -> bool {
791 self.0.eq(&other.0)
792 }
793 }
794
795 /// Lexe API-compliant version of [`axum::extract::Path`].
796 pub struct LxPath<T>(pub T);
797
798 impl<T: DeserializeOwned + Send, S: Send + Sync> FromRequestParts<S>
799 for LxPath<T>
800 {
801 type Rejection = LxRejection;
802
803 async fn from_request_parts(
804 parts: &mut http::request::Parts,
805 state: &S,
806 ) -> Result<Self, Self::Rejection> {
807 axum::extract::Path::from_request_parts(parts, state)
808 .await
809 .map(|axum::extract::Path(t)| Self(t))
810 .map_err(LxRejection::from)
811 }
812 }
813
814 impl<T: Clone> Clone for LxPath<T> {
815 fn clone(&self) -> Self {
816 Self(self.0.clone())
817 }
818 }
819
820 impl<T: fmt::Debug> fmt::Debug for LxPath<T> {
821 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
822 T::fmt(&self.0, f)
823 }
824 }
825
826 impl<T: Eq + PartialEq> Eq for LxPath<T> {}
827
828 impl<T: PartialEq> PartialEq for LxPath<T> {
829 fn eq(&self, other: &Self) -> bool {
830 self.0.eq(&other.0)
831 }
832 }
833}
834
835// --- Custom middleware --- //
836
837pub mod middleware {
838 use axum::extract::State;
839 use http::HeaderName;
840
841 use super::*;
842
843 /// The header name used for response post-processing signals.
844 pub static POST_PROCESS_HEADER: HeaderName =
845 HeaderName::from_static("lx-post-process");
846
847 /// Checks the `CONTENT_LENGTH` header and returns an early rejection if the
848 /// contained value exceeds our configured body limit. This optimization
849 /// allows us to avoid unnecessary work processing the request further.
850 ///
851 /// NOTE: This does not enforce the body length!! Use [`DefaultBodyLimit`]
852 /// in combination with [`axum::RequestExt::with_limited_body`] to do so.
853 pub async fn check_content_length_header<B>(
854 // `LayerConfig::body_limit`
855 State(config_body_limit): State<usize>,
856 request: http::Request<B>,
857 ) -> Result<http::Request<B>, LxRejection> {
858 let maybe_content_length = request
859 .headers()
860 .get(http::header::CONTENT_LENGTH)
861 .and_then(|value| value.to_str().ok())
862 .and_then(|value_str| usize::from_str(value_str).ok());
863
864 // If a limit is configured and the header value exceeds it, reject.
865 if let Some(content_length) = maybe_content_length
866 && content_length > config_body_limit
867 {
868 return Err(LxRejection {
869 kind: LxRejectionKind::BodyLengthOverLimit,
870 source_msg: "Content length header over limit".to_owned(),
871 });
872 }
873
874 Ok(request)
875 }
876
877 /// A post-processor which can be used to modify the [`http::Response`]s
878 /// returned by an [`axum::Router`]. This is done by signalling the desired
879 /// modification in a fake [`POST_PROCESS_HEADER`] which is also removed
880 /// during post-processing. This can be used to override Axum defaults
881 /// which one does not have access to from within the [`Router`]. Currently,
882 /// this only supports a "remove-content-length" command which removes the
883 /// content-length header set by Axum, but can be easily extended.
884 pub(super) fn post_process_response(
885 mut response: http::Response<axum::body::Body>,
886 ) -> http::Response<axum::body::Body> {
887 let value = match response.headers_mut().remove(&POST_PROCESS_HEADER) {
888 Some(v) => v,
889 None => return response,
890 };
891
892 match value.as_bytes() {
893 b"remove-content-length" => {
894 response.headers_mut().remove(http::header::CONTENT_LENGTH);
895 debug!("Post process: Removed content-length header");
896 }
897 unknown => {
898 let unknown_str = String::from_utf8_lossy(unknown);
899 warn!("Post process: Invalid header value: {unknown_str}");
900 }
901 }
902
903 response
904 }
905}
906
907// --- Helpers --- //
908
909/// Lexe's default fallback [`Handler`](axum::handler::Handler).
910/// Returns a "bad endpoint" rejection along with the requested method and path.
911pub async fn default_fallback(
912 method: http::Method,
913 uri: http::Uri,
914) -> LxRejection {
915 let path = uri.path();
916 LxRejection {
917 kind: LxRejectionKind::BadEndpoint,
918 // e.g. "POST /app/node_info"
919 source_msg: format!("{method} {path}"),
920 }
921}