Skip to main content

lexe_api/
server.rs

1// This is the only place where we are allowed to use e.g. `Json` and `Query`.
2#![allow(clippy::disallowed_types)]
3
4//! This module provides various API server utilities.
5//!
6//! # Serving
7//!
8//! Methods to serve a [`Router`] with a fallback handler (for unmatched paths),
9//! tracing / request instrumentation, backpressure, load shedding, concurrency
10//! limits, server-side timeouts, TLS, and graceful shutdown:
11//!
12//! - [`build_server_fut`]
13//! - [`build_server_fut_with_listener`]
14//! - [`spawn_server_task`]
15//! - [`spawn_server_task_with_listener`]
16//!
17//! # Extractors to get data from requests:
18//!
19//! - [`LxJson`] to deserialize from HTTP body JSON
20//! - [`LxQuery`] to deserialize from query strings
21//!
22//! # [`IntoResponse`] types / impls for building Lexe API-conformant responses:
23//!
24//! - [`LxJson`] type for returning success responses as JSON
25//! - All [`ApiError`]s and [`CommonApiError`] impl [`IntoResponse`]
26//! - [`LxRejection`] for notifying clients of bad JSON, query strings, etc.
27//!
28//! [`ApiError`]: lexe_api_core::error::ApiError
29//! [`CommonApiError`]: lexe_api_core::error::CommonApiError
30//! [`Router`]: axum::Router
31//! [`IntoResponse`]: axum::response::IntoResponse
32//! [`LxJson`]: crate::server::LxJson
33//! [`LxQuery`]: crate::server::extract::LxQuery
34//! [`LxRejection`]: crate::server::LxRejection
35//! [`build_server_fut`]: crate::server::build_server_fut
36//! [`build_server_fut_with_listener`]: crate::server::build_server_fut_with_listener
37//! [`spawn_server_task`]: crate::server::spawn_server_task
38//! [`spawn_server_task_with_listener`]: crate::server::spawn_server_task_with_listener
39
40use std::{
41    borrow::Cow,
42    convert::Infallible,
43    fmt::{self, Display},
44    future::Future,
45    net::{SocketAddr, TcpListener},
46    str::FromStr,
47    sync::Arc,
48    time::Duration,
49};
50
51use anyhow::Context;
52use axum::{
53    Router, ServiceExt as AxumServiceExt,
54    error_handling::HandleErrorLayer,
55    extract::{
56        DefaultBodyLimit, FromRequest,
57        rejection::{
58            BytesRejection, JsonRejection, PathRejection, QueryRejection,
59        },
60    },
61    response::IntoResponse,
62    routing::RouterIntoService,
63};
64use axum_server::tls_rustls::RustlsConfig;
65use bytes::Bytes;
66use http::{HeaderValue, StatusCode, header::CONTENT_TYPE};
67use lexe_api_core::{
68    axum_helpers,
69    error::{CommonApiError, CommonErrorKind},
70};
71use lexe_common::api::auth::{self, Scope};
72use lexe_crypto::ed25519;
73use lexe_tokio::{notify_once::NotifyOnce, task::LxTask};
74use serde::{Serialize, de::DeserializeOwned};
75use tower::{
76    Layer, buffer::BufferLayer, limit::ConcurrencyLimitLayer,
77    load_shed::LoadShedLayer, timeout::TimeoutLayer, util::MapRequestLayer,
78};
79use tracing::{Instrument, debug, error, info, warn};
80
81use crate::{rest, trace};
82
/// The grace period passed to [`axum_server::Handle::graceful_shutdown`] during
/// which new connections are refused and we wait for existing connections to
/// terminate before initiating a hard shutdown.
const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(3);
/// The maximum time we'll wait for a server to complete shutdown.
///
/// Applied by the server future as a timeout on the remaining server work
/// once the shutdown signal has been received.
pub const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
// Compile-time check: the grace period must be strictly shorter than the
// overall shutdown timeout. NOTE: the comparison is made at whole-second
// granularity (`as_secs`), so sub-second differences are not detected.
lexe_std::const_assert!(
    SHUTDOWN_GRACE_PERIOD.as_secs() < SERVER_SHUTDOWN_TIMEOUT.as_secs()
);

/// The default maximum time a server can spend handling a request.
///
/// Used as the default for [`LayerConfig::handling_timeout`].
pub const SERVER_HANDLER_TIMEOUT: Duration = Duration::from_secs(25);
// Compile-time check: `rest::API_REQUEST_TIMEOUT` must be strictly greater
// than the server handler timeout (compared at whole-second granularity).
// NOTE(review): presumably so that the server-side timeout response reaches
// the client before the client's own request timeout fires -- confirm.
lexe_std::const_assert!(
    rest::API_REQUEST_TIMEOUT.as_secs() > SERVER_HANDLER_TIMEOUT.as_secs()
);
98
/// A configuration object for Axum / Tower middleware.
///
/// Consumed by [`build_server_fut_with_listener`] (and the helpers which wrap
/// it) when assembling the middleware stack around a [`Router`].
///
/// Defaults:
///
/// ```
/// # use std::time::Duration;
/// # use lexe_api::server::LayerConfig;
/// assert_eq!(
///     LayerConfig::default(),
///     LayerConfig {
///         body_limit: 16384,
///         buffer_size: 4096,
///         concurrency: 4096,
///         handling_timeout: Duration::from_secs(25),
///         default_fallback: true,
///     }
/// );
/// ```
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LayerConfig {
    /// The maximum size of the request body in bytes.
    /// Helps prevent DoS, but may need to be increased for some services.
    ///
    /// Used both for the early `CONTENT_LENGTH` header check and as the
    /// [`DefaultBodyLimit`] applied to request bodies.
    pub body_limit: usize,
    /// The size of the work buffer for our service.
    /// Allows the server to immediately work on more queued requests when a
    /// request completes, and prevents a large backlog from building up.
    ///
    /// Passed to [`BufferLayer::new`].
    pub buffer_size: usize,
    /// The maximum # of requests we'll process at once.
    /// Helps prevent the CPU from maxing out, resulting in thrashing.
    ///
    /// Passed to [`ConcurrencyLimitLayer::new`].
    pub concurrency: usize,
    /// The maximum time a server can spend handling a request. Helps prevent
    /// degenerate cases which take abnormally long to process from crowding
    /// out normal workloads.
    ///
    /// Passed to [`TimeoutLayer::new`]; defaults to
    /// [`SERVER_HANDLER_TIMEOUT`].
    pub handling_timeout: Duration,
    /// Whether to add Lexe's default [`Router::fallback`] to the [`Router`].
    /// The [`Router::fallback`] is called if no routes were matched;
    /// Lexe's [`default_fallback`] returns a "bad endpoint" rejection along
    /// with the requested method and path.
    ///
    /// If you need to set a custom fallback, set this to [`false`], otherwise
    /// your custom fallback will be clobbered by Lexe's [`default_fallback`].
    /// NOTE, however, that the caller is responsible for ensuring that the
    /// [`Router`] has a fallback configured in this case.
    pub default_fallback: bool,
}
144
145impl Default for LayerConfig {
146    fn default() -> Self {
147        Self {
148            // 16KiB is sufficient for most Lexe services.
149            body_limit: 16384,
150            // TODO(max): We are using very high values right now because it
151            // doesn't make sense to constrain anything until we have run some
152            // load tests to profile performance and see what breaks.
153            buffer_size: 4096,
154            concurrency: 4096,
155            handling_timeout: SERVER_HANDLER_TIMEOUT,
156            default_fallback: true,
157        }
158    }
159}
160
161// --- Server helpers --- //
162
/// Builds the URL at which a server can be reached, given the address
/// returned by its [`TcpListener::local_addr`] and, optionally, its primary
/// DNS name.
///
/// - With a DNS name the scheme is `https` and the host is the DNS name; the
///   port is omitted when it is the HTTPS default (443).
/// - Without a DNS name the scheme is `http` and the authority is the
///   listener's socket address verbatim.
///
/// ex: `https://lexe.app` (port=443)
/// ex: `https://relay.lexe.app:4396`
/// ex: `http://[::1]:8080`
//
// We have a fn to build the url because it's easy to mess up.
pub fn build_server_url(
    // The output of `TcpListener::local_addr`
    listener_addr: SocketAddr,
    // Primary DNS name
    maybe_dns: Option<&str>,
) -> String {
    // No DNS name: address the listener's socket directly.
    let Some(dns_name) = maybe_dns else {
        return format!("http://{listener_addr}");
    };

    match listener_addr.port() {
        // 443 is implied by the `https` scheme, so leave it out.
        443 => format!("https://{dns_name}"),
        port => format!("https://{dns_name}:{port}"),
    }
}
189
190/// Constructs an API server future which can be spawned into a task.
191/// Additionally returns the server url.
192///
193/// Use this helper when it is useful to poll multiple futures in a single task
194/// to reduce the amount of task nesting / indirection. If there is only one
195/// future that needs to be driven, use [`spawn_server_task`] instead.
196///
197/// Errors if the [`TcpListener`] failed to bind or return its local address.
198/// Returns the server future along with the bound socket address.
199// Avoids generic parameters to prevent binary bloat.
200// Returns unnamed `impl Future` to avoid Pin<Box<T>> deref cost.
201pub fn build_server_fut(
202    bind_addr: SocketAddr,
203    router: Router<()>,
204    layer_config: LayerConfig,
205    // TLS config + primary DNS name
206    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
207    server_span_name: &str,
208    server_span: tracing::Span,
209    // Send on this channel to begin a graceful shutdown of the server.
210    shutdown: NotifyOnce,
211) -> anyhow::Result<(impl Future<Output = ()>, String)> {
212    let listener =
213        TcpListener::bind(bind_addr).context("Could not bind TCP listener")?;
214    let (server_fut, primary_server_url) = build_server_fut_with_listener(
215        listener,
216        router,
217        layer_config,
218        maybe_tls_and_dns,
219        server_span_name,
220        server_span,
221        shutdown,
222    )
223    .context("Could not build server future")?;
224    Ok((server_fut, primary_server_url))
225}
226
/// [`build_server_fut`] but takes a [`TcpListener`] instead of [`SocketAddr`].
///
/// Steps performed:
///
/// 1. Builds the primary server url with [`build_server_url`] from the
///    listener's local address and the optional DNS name, and logs it.
/// 2. Adds Lexe's [`default_fallback`] to the [`Router`] if enabled in the
///    [`LayerConfig`].
/// 3. Applies the inner middleware stack to the router's routes and wraps the
///    resulting service in the outer middleware stack.
/// 4. Builds a future which serves the result on `listener`, using the given
///    TLS config if `maybe_tls_and_dns` is [`Some`].
///
/// The returned future is instrumented with `server_span` and completes when
/// either:
///
/// - the server exits before a shutdown signal was received (this is logged
///   as an error), or
/// - a shutdown signal was received on `shutdown` and the server then either
///   finishes or fails to finish within [`SERVER_SHUTDOWN_TIMEOUT`].
///
/// # Errors
///
/// Errors if the listener's local address could not be determined.
///
/// # Panics
///
/// The returned future panics if serving returns an error; see the `expect`
/// below for why this is not expected to happen.
// Avoids generic parameters to prevent binary bloat.
// Returns unnamed `impl Future` to avoid Pin<Box<T>> deref cost.
pub fn build_server_fut_with_listener(
    listener: TcpListener,
    router: Router<()>,
    layer_config: LayerConfig,
    // TLS config + primary DNS name
    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
    server_span_name: &str,
    server_span: tracing::Span,
    // Send on this channel to begin a graceful shutdown of the server.
    mut shutdown: NotifyOnce,
) -> anyhow::Result<(impl Future<Output = ()> + use<>, String)> {
    // Split the optional (TLS config, DNS name) pair into two options.
    let (maybe_tls_config, maybe_dns) = maybe_tls_and_dns.unzip();
    let listener_addr = listener
        .local_addr()
        .context("Could not get listener local address")?;
    let primary_server_url = build_server_url(listener_addr, maybe_dns);
    info!("Url for {server_span_name}: {primary_server_url}");

    // Add Lexe's default fallback if it is enabled in the LayerConfig.
    let router = if layer_config.default_fallback {
        router.fallback(default_fallback)
    } else {
        router
    };

    // Used to annotate the service / request / response types
    // at each point in the ServiceBuilder chains.
    type HyperService = RouterIntoService<hyper::body::Incoming, ()>;
    type AxumService = RouterIntoService<axum::body::Body, ()>;
    type HyperReq = http::Request<hyper::body::Incoming>;
    type AxumReq = http::Request<axum::body::Body>;
    type AxumResp = http::Response<axum::body::Body>;
    type TraceResp = http::Response<
        tower_http::trace::ResponseBody<
            axum::body::Body,
            tower_http::classify::NeverClassifyEos<anyhow::Error>,
            (),
            trace::server::LxOnEos,
            trace::server::LxOnFailure,
        >,
    >;

    // The outer middleware stack which wraps the entire Router.
    //
    // Axum docs explain ordering better than tower's ServiceBuilder docs do:
    // https://docs.rs/axum/latest/axum/middleware/index.html#ordering
    // Basically, requests go from top to bottom and responses bottom to top.
    let outer_middleware = tower::ServiceBuilder::new()
        .check_service::<HyperService, HyperReq, AxumResp, Infallible>()
        // Log everything on its way in and out, even load-shedded requests.
        // This layer changes the response type.
        .layer(trace::server::trace_layer(server_span.clone()))
        .check_service::<HyperService, HyperReq, TraceResp, Infallible>()
        // Run our post-processor which can modify responses *after* the Axum
        // Router has constructed the response.
        .layer(tower::util::MapResponseLayer::new(
            middleware::post_process_response,
        ))
        .check_service::<HyperService, HyperReq, TraceResp, Infallible>();

    // The inner middleware stack which is cloned to each route in the Router.
    // We put most of the layers here because it is a lot easier to work with
    // axum types; moving these outside quickly degenerates into type hell.
    let inner_middleware = tower::ServiceBuilder::new()
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Immediately reject anything with a CONTENT_LENGTH over the limit.
        .layer(axum::middleware::map_request_with_state(
            layer_config.body_limit,
            middleware::check_content_length_header,
        ))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Set the default request body limit for all requests. This adds a
        // `DefaultBodyLimitKind` (private axum type) into the request
        // extensions so that any inner layers or extractors which call
        // `axum::RequestExt::[with|into]_limited_body` will pick it up.
        // NOTE that many of our extractors transitively rely on the Bytes
        // extractor which will default to a 2MB limit if this is not set.
        .layer(DefaultBodyLimit::max(layer_config.body_limit))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Here, we explicitly apply the body limit from the request extensions,
        // transforming the request body type into `http_body_util::Limited`.
        .layer(MapRequestLayer::new(axum::RequestExt::with_limited_body))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Handles errors from the load_shed, buffer, and concurrency layers.
        // Any such error is reported to the client as an "at capacity" error.
        .layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
            CommonApiError {
                kind: CommonErrorKind::AtCapacity,
                msg: "Service is at capacity; retry later".to_owned(),
            }
        }))
        // Returns an `Err` if the inner service returns `Poll::Pending`.
        // Helps prevent OOM when combined with the buffer or concurrency layer.
        .layer(LoadShedLayer::new())
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Returns Poll::Pending when the buffer is full (backpressure).
        // Allows the server to immediately work on more queued requests when a
        // request completes, and prevents a large backlog from building up.
        // Note that while the layer is often cloned, the buffer itself is not.
        .layer(BufferLayer::new(layer_config.buffer_size))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Returns `Poll::Pending` when the concurrency limit has been reached.
        // Helps prevent the CPU from maxing out, resulting in thrashing.
        .layer(ConcurrencyLimitLayer::new(layer_config.concurrency))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Handles errors generated by the timeout layer.
        // These are reported to the client as a generic server error.
        .layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
            CommonApiError {
                kind: CommonErrorKind::Server,
                msg: "Server timed out handling request".to_owned(),
            }
        }))
        // Returns an error if the inner service takes longer than the timeout
        // to handle the request. Prevents degenerate cases which take
        // abnormally long to process from crowding out normal workloads.
        .layer(TimeoutLayer::new(layer_config.handling_timeout))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>();

    // Apply inner middleware
    let layered_router = router.layer(inner_middleware);
    // Convert into Service
    let router_service = layered_router.into_service::<hyper::body::Incoming>();
    // Apply outer middleware
    let layered_service = Layer::layer(&outer_middleware, router_service);
    // Convert into MakeService
    let make_service = layered_service.into_make_service();

    // The handle is shared between the server (which is controlled by it) and
    // the graceful shutdown future (which triggers the shutdown through it).
    let handle = axum_server::Handle::new();
    let handle_clone = handle.clone();
    // Serves requests on `listener` until the server exits.
    let server_fut = async {
        let serve_result = match maybe_tls_config {
            Some(tls_config) => {
                let axum_tls_config = RustlsConfig::from_config(tls_config);
                axum_server::from_tcp_rustls(listener, axum_tls_config)
                    .handle(handle_clone)
                    .serve(make_service)
                    .await
            }
            None =>
                axum_server::from_tcp(listener)
                    .handle(handle_clone)
                    .serve(make_service)
                    .await,
        };

        serve_result
            // See axum_server::Server::serve docs for why this can't error
            .expect("No binding + axum MakeService::poll_ready never errors");
    };

    // Waits for the shutdown signal, then asks the server to shut down.
    let graceful_shutdown_fut = async move {
        shutdown.recv().await;
        info!("Shutting down API server");
        // The 'grace period' is a period of time during which new connections
        // are refused and `axum_server::Server::serve` waits for all current
        // connections to terminate. If `None`, the server waits indefinitely
        // for current connections to terminate; if `Some`, the server will
        // initiate a hard shutdown after the grace period has elapsed. We use
        // Some(_) with a relatively short grace period because (1) our handlers
        // shouldn't take long to return and (2) we sometimes see connections
        // failing to terminate for servers which have a /shutdown endpoint.
        handle.graceful_shutdown(Some(SHUTDOWN_GRACE_PERIOD));
    };

    let combined_fut = async {
        // Pinned so that the server future can be polled by reference in the
        // `select!` and then awaited to completion afterwards.
        tokio::pin!(server_fut);
        tokio::select! {
            biased; // Ensure graceful shutdown future finishes first
            () = graceful_shutdown_fut => (),
            // The server finished without a shutdown signal; nothing is left
            // to wait for, so log the problem and complete.
            _ = &mut server_fut => return error!("Server exited early"),
        }
        // A shutdown has been requested: give the server a bounded amount of
        // time to finish before giving up on it.
        match tokio::time::timeout(SERVER_SHUTDOWN_TIMEOUT, server_fut).await {
            Ok(()) => info!("API server finished"),
            Err(_) => warn!("API server timed out during shutdown"),
        }
    }
    .instrument(server_span);

    Ok((combined_fut, primary_server_url))
}
409
410/// [`build_server_fut`] but additionally spawns the server future into an
411/// instrumented server task and logs the full URL used to access the server.
412/// Returns the server task and server url.
413pub fn spawn_server_task(
414    bind_addr: SocketAddr,
415    router: Router<()>,
416    layer_config: LayerConfig,
417    // TLS config + primary DNS name
418    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
419    server_span_name: Cow<'static, str>,
420    server_span: tracing::Span,
421    // Send on this channel to begin a graceful shutdown of the server.
422    shutdown: NotifyOnce,
423) -> anyhow::Result<(LxTask<()>, String)> {
424    let listener = TcpListener::bind(bind_addr)
425        .context(bind_addr)
426        .context("Failed to bind TcpListener")?;
427
428    let (server_task, primary_server_url) = spawn_server_task_with_listener(
429        listener,
430        router,
431        layer_config,
432        maybe_tls_and_dns,
433        server_span_name,
434        server_span,
435        shutdown,
436    )
437    .context("spawn_server_task_with_listener failed")?;
438
439    Ok((server_task, primary_server_url))
440}
441
442/// [`spawn_server_task`] but takes [`TcpListener`] instead of [`SocketAddr`].
443pub fn spawn_server_task_with_listener(
444    listener: TcpListener,
445    router: Router<()>,
446    layer_config: LayerConfig,
447    // TLS config + primary DNS name
448    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
449    server_span_name: Cow<'static, str>,
450    server_span: tracing::Span,
451    // Send on this channel to begin a graceful shutdown of the server.
452    shutdown: NotifyOnce,
453) -> anyhow::Result<(LxTask<()>, String)> {
454    let (server_fut, primary_server_url) = build_server_fut_with_listener(
455        listener,
456        router,
457        layer_config,
458        maybe_tls_and_dns,
459        &server_span_name,
460        server_span.clone(),
461        shutdown,
462    )
463    .context("Failed to build server future")?;
464
465    let server_task =
466        LxTask::spawn_with_span(server_span_name, server_span, server_fut);
467
468    Ok((server_task, primary_server_url))
469}
470
471// --- LxJson --- //
472
/// A version of [`axum::Json`] which conforms to Lexe's (JSON) API.
/// It can be used as either an extractor or a response.
///
/// - As an extractor: rejections return [`LxRejection`].
/// - As a success response:
///   - Serialization success returns an [`http::Response`] with JSON body.
///   - Serialization failure returns an [`ErrorResponse`].
///
/// [`axum::Json`] is banned because:
///
/// - Rejections return [`JsonRejection`] which is just a string HTTP body.
/// - Response serialization failures likewise return just a string body.
///
/// NOTE: This must only be used for forming *success* API responses,
/// i.e. `T` in `Result<T, E>`, because its [`IntoResponse`] impl uses
/// [`StatusCode::OK`]. Our API error types, while also serialized as JSON,
/// have separate [`IntoResponse`] impls which return error statuses.
///
/// The wrapped value is the single public tuple field, so it can be
/// destructured directly in handler arguments: `LxJson(req): LxJson<Req>`.
///
/// [`ErrorResponse`]: lexe_api_core::error::ErrorResponse
pub struct LxJson<T>(pub T);
493
494impl<T: DeserializeOwned, S: Send + Sync> FromRequest<S> for LxJson<T> {
495    type Rejection = LxRejection;
496
497    async fn from_request(
498        req: http::Request<axum::body::Body>,
499        state: &S,
500    ) -> Result<Self, Self::Rejection> {
501        // `axum::Json`'s from_request impl is fine but its rejection is not
502        axum::Json::from_request(req, state)
503            .await
504            .map(|axum::Json(t)| Self(t))
505            .map_err(LxRejection::from)
506    }
507}
508
509impl<T: Serialize> IntoResponse for LxJson<T> {
510    fn into_response(self) -> http::Response<axum::body::Body> {
511        axum_helpers::build_json_response(StatusCode::OK, &self.0)
512    }
513}
514
515impl<T: Clone> Clone for LxJson<T> {
516    fn clone(&self) -> Self {
517        Self(self.0.clone())
518    }
519}
520
521impl<T: Copy> Copy for LxJson<T> {}
522
523impl<T: fmt::Debug> fmt::Debug for LxJson<T> {
524    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
525        T::fmt(&self.0, f)
526    }
527}
528
529impl<T: Eq + PartialEq> Eq for LxJson<T> {}
530
531impl<T: PartialEq> PartialEq for LxJson<T> {
532    fn eq(&self, other: &Self) -> bool {
533        self.0.eq(&other.0)
534    }
535}
536
537// --- LxBytes --- //
538
/// A version of [`Bytes`] which conforms to Lexe's (binary) API.
///
/// - [`axum`] has implementations of [`FromRequest`] and [`IntoResponse`] for
///   [`Bytes`], but these implementations are not Lexe API-conformant.
/// - This type can be used as either an extractor or a success response, and
///   should always be used instead of [`Bytes`] in these server contexts.
/// - It is still fine to use [`Bytes`] on the client side.
///
/// Behavior:
///
/// - As an extractor: rejections return [`LxRejection`].
/// - As a success response: returns an [`http::Response`] with a binary body.
/// - Any failure encountered in extraction or creation should produce an
///   [`ErrorResponse`].
///
/// The regular impls are non-conformant because:
///
/// - Rejections return [`BytesRejection`] which is just a string HTTP body.
///
/// NOTE: This must only be used for forming *success* API responses,
/// i.e. `LxBytes` in `Result<LxBytes, E>`, because its [`IntoResponse`] impl
/// uses [`StatusCode::OK`]. Our API error types are serialized as JSON and
/// have separate [`IntoResponse`] impls which return error statuses.
///
/// [`ErrorResponse`]: lexe_api_core::error::ErrorResponse
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
pub struct LxBytes(pub Bytes);
565
566impl<S: Send + Sync> FromRequest<S> for LxBytes {
567    type Rejection = LxRejection;
568
569    async fn from_request(
570        req: http::Request<axum::body::Body>,
571        state: &S,
572    ) -> Result<Self, Self::Rejection> {
573        // `Bytes`'s from_request impl is fine but its rejection is not
574        Bytes::from_request(req, state)
575            .await
576            .map(Self)
577            .map_err(LxRejection::from)
578    }
579}
580
581/// The [`Bytes`] [`IntoResponse`] impl is almost exactly the same,
582/// except it returns the wrong HTTP version.
583impl IntoResponse for LxBytes {
584    fn into_response(self) -> http::Response<axum::body::Body> {
585        let http_body = http_body_util::Full::new(self.0);
586        let axum_body = axum::body::Body::new(http_body);
587
588        axum_helpers::default_response_builder()
589            .header(
590                CONTENT_TYPE,
591                // Or `HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM)`
592                HeaderValue::from_static("application/octet-stream"),
593            )
594            .status(StatusCode::OK)
595            .body(axum_body)
596            .expect("All operations here should be infallible")
597    }
598}
599
600impl<T: Into<Bytes>> From<T> for LxBytes {
601    fn from(t: T) -> Self {
602        Self(t.into())
603    }
604}
605
606// --- LxRejection --- //
607
/// Our own [`axum::extract::rejection`] type with an [`IntoResponse`] impl
/// which conforms to Lexe's API. Contains the source rejection's error text.
///
/// Implements [`Debug`](std::fmt::Debug) so that results carrying a rejection
/// can be `unwrap`ped, asserted on, and logged with `{:?}`.
#[derive(Debug)]
pub struct LxRejection {
    /// Which [`axum::extract::rejection`] this [`LxRejection`] was built from.
    kind: LxRejectionKind,
    /// The error text of the source rejection, or additional context.
    source_msg: String,
}

/// The source of this [`LxRejection`].
#[derive(Debug)]
enum LxRejectionKind {
    // -- From `axum::extract::rejection` -- //
    /// [`BytesRejection`]
    Bytes,
    /// [`JsonRejection`]
    Json,
    /// [`PathRejection`]
    Path,
    /// [`QueryRejection`]
    Query,

    // -- Other -- //
    /// Bearer authentication failed
    Unauthenticated,
    /// Client is not authorized to access this resource
    Unauthorized,
    /// Client request did not match any paths in the [`Router`].
    BadEndpoint,
    /// Request body length over limit
    BodyLengthOverLimit,
    /// [`ed25519::Error`]
    Ed25519,
    /// Gateway proxy
    Proxy,
}
643
644// Use explicit `.map_err()`s instead of From impls for non-obvious conversions
645impl LxRejection {
646    pub fn from_ed25519(error: ed25519::Error) -> Self {
647        Self {
648            kind: LxRejectionKind::Ed25519,
649            source_msg: format!("{error:#}"),
650        }
651    }
652
653    pub fn from_bearer_auth(error: auth::Error) -> Self {
654        Self {
655            kind: LxRejectionKind::Unauthenticated,
656            source_msg: format!("{error:#}"),
657        }
658    }
659
660    pub fn scope_unauthorized(
661        granted_scope: &Scope,
662        requested_scope: &Scope,
663    ) -> Self {
664        Self {
665            kind: LxRejectionKind::Unauthorized,
666            source_msg: format!(
667                "granted scope: {granted_scope:?}, requested scope: {requested_scope:?}"
668            ),
669        }
670    }
671
672    pub fn proxy(error: impl Display) -> Self {
673        Self {
674            kind: LxRejectionKind::Proxy,
675            source_msg: format!("{error:#}"),
676        }
677    }
678}
679
680impl From<BytesRejection> for LxRejection {
681    fn from(bytes_rejection: BytesRejection) -> Self {
682        Self {
683            kind: LxRejectionKind::Bytes,
684            source_msg: bytes_rejection.body_text(),
685        }
686    }
687}
688
689impl From<JsonRejection> for LxRejection {
690    fn from(json_rejection: JsonRejection) -> Self {
691        Self {
692            kind: LxRejectionKind::Json,
693            source_msg: json_rejection.body_text(),
694        }
695    }
696}
697
698impl From<PathRejection> for LxRejection {
699    fn from(path_rejection: PathRejection) -> Self {
700        Self {
701            kind: LxRejectionKind::Path,
702            source_msg: path_rejection.body_text(),
703        }
704    }
705}
706
707impl From<QueryRejection> for LxRejection {
708    fn from(query_rejection: QueryRejection) -> Self {
709        Self {
710            kind: LxRejectionKind::Query,
711            source_msg: query_rejection.body_text(),
712        }
713    }
714}
715
716impl IntoResponse for LxRejection {
717    fn into_response(self) -> http::Response<axum::body::Body> {
718        // TODO(phlip9): authn+authz+badendpoint rejections should return
719        // standard status codes, not just 400.
720        let kind = CommonErrorKind::Rejection;
721        // "Bad JSON: Failed to deserialize the JSON body into the target type"
722        let kind_msg = self.kind.to_msg();
723        let source_msg = &self.source_msg;
724        let msg = format!("Rejection: {kind_msg}: {source_msg}");
725        // Log the rejection now since our trace layer can't access this info
726        warn!("{msg}");
727        let common_error = CommonApiError { kind, msg };
728        common_error.into_response()
729    }
730}
731
732impl LxRejectionKind {
733    /// A generic error message for this rejection kind.
734    fn to_msg(&self) -> &'static str {
735        match self {
736            Self::Bytes => "Bad request bytes",
737            Self::Json => "Client provided bad JSON",
738            Self::Path => "Client provided bad path parameter",
739            Self::Query => "Client provided bad query string",
740
741            Self::Unauthenticated => "Invalid bearer auth",
742            Self::Unauthorized => "Not authorized to access this resource",
743            Self::BadEndpoint => "Client requested a non-existent endpoint",
744            Self::BodyLengthOverLimit => "Request body length over limit",
745            Self::Ed25519 => "Ed25519 error",
746            Self::Proxy => "Proxy error",
747        }
748    }
749}
750
751// --- Extractors --- //
752
753pub mod extract {
754    use axum::extract::FromRequestParts;
755
756    use super::*;
757
758    /// Lexe API-compliant version of [`axum::extract::Query`].
759    pub struct LxQuery<T>(pub T);
760
761    impl<T: DeserializeOwned, S: Send + Sync> FromRequestParts<S> for LxQuery<T> {
762        type Rejection = LxRejection;
763
764        async fn from_request_parts(
765            parts: &mut http::request::Parts,
766            state: &S,
767        ) -> Result<Self, Self::Rejection> {
768            axum::extract::Query::from_request_parts(parts, state)
769                .await
770                .map(|axum::extract::Query(t)| Self(t))
771                .map_err(LxRejection::from)
772        }
773    }
774
775    impl<T: Clone> Clone for LxQuery<T> {
776        fn clone(&self) -> Self {
777            Self(self.0.clone())
778        }
779    }
780
781    impl<T: fmt::Debug> fmt::Debug for LxQuery<T> {
782        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
783            T::fmt(&self.0, f)
784        }
785    }
786
787    impl<T: Eq + PartialEq> Eq for LxQuery<T> {}
788
789    impl<T: PartialEq> PartialEq for LxQuery<T> {
790        fn eq(&self, other: &Self) -> bool {
791            self.0.eq(&other.0)
792        }
793    }
794
795    /// Lexe API-compliant version of [`axum::extract::Path`].
796    pub struct LxPath<T>(pub T);
797
798    impl<T: DeserializeOwned + Send, S: Send + Sync> FromRequestParts<S>
799        for LxPath<T>
800    {
801        type Rejection = LxRejection;
802
803        async fn from_request_parts(
804            parts: &mut http::request::Parts,
805            state: &S,
806        ) -> Result<Self, Self::Rejection> {
807            axum::extract::Path::from_request_parts(parts, state)
808                .await
809                .map(|axum::extract::Path(t)| Self(t))
810                .map_err(LxRejection::from)
811        }
812    }
813
814    impl<T: Clone> Clone for LxPath<T> {
815        fn clone(&self) -> Self {
816            Self(self.0.clone())
817        }
818    }
819
820    impl<T: fmt::Debug> fmt::Debug for LxPath<T> {
821        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
822            T::fmt(&self.0, f)
823        }
824    }
825
826    impl<T: Eq + PartialEq> Eq for LxPath<T> {}
827
828    impl<T: PartialEq> PartialEq for LxPath<T> {
829        fn eq(&self, other: &Self) -> bool {
830            self.0.eq(&other.0)
831        }
832    }
833}
834
835// --- Custom middleware --- //
836
837pub mod middleware {
838    use axum::extract::State;
839    use http::HeaderName;
840
841    use super::*;
842
843    /// The header name used for response post-processing signals.
844    pub static POST_PROCESS_HEADER: HeaderName =
845        HeaderName::from_static("lx-post-process");
846
847    /// Checks the `CONTENT_LENGTH` header and returns an early rejection if the
848    /// contained value exceeds our configured body limit. This optimization
849    /// allows us to avoid unnecessary work processing the request further.
850    ///
851    /// NOTE: This does not enforce the body length!! Use [`DefaultBodyLimit`]
852    /// in combination with [`axum::RequestExt::with_limited_body`] to do so.
853    pub async fn check_content_length_header<B>(
854        // `LayerConfig::body_limit`
855        State(config_body_limit): State<usize>,
856        request: http::Request<B>,
857    ) -> Result<http::Request<B>, LxRejection> {
858        let maybe_content_length = request
859            .headers()
860            .get(http::header::CONTENT_LENGTH)
861            .and_then(|value| value.to_str().ok())
862            .and_then(|value_str| usize::from_str(value_str).ok());
863
864        // If a limit is configured and the header value exceeds it, reject.
865        if let Some(content_length) = maybe_content_length
866            && content_length > config_body_limit
867        {
868            return Err(LxRejection {
869                kind: LxRejectionKind::BodyLengthOverLimit,
870                source_msg: "Content length header over limit".to_owned(),
871            });
872        }
873
874        Ok(request)
875    }
876
877    /// A post-processor which can be used to modify the [`http::Response`]s
878    /// returned by an [`axum::Router`]. This is done by signalling the desired
879    /// modification in a fake [`POST_PROCESS_HEADER`] which is also removed
880    /// during post-processing. This can be used to override Axum defaults
881    /// which one does not have access to from within the [`Router`]. Currently,
882    /// this only supports a "remove-content-length" command which removes the
883    /// content-length header set by Axum, but can be easily extended.
884    pub(super) fn post_process_response(
885        mut response: http::Response<axum::body::Body>,
886    ) -> http::Response<axum::body::Body> {
887        let value = match response.headers_mut().remove(&POST_PROCESS_HEADER) {
888            Some(v) => v,
889            None => return response,
890        };
891
892        match value.as_bytes() {
893            b"remove-content-length" => {
894                response.headers_mut().remove(http::header::CONTENT_LENGTH);
895                debug!("Post process: Removed content-length header");
896            }
897            unknown => {
898                let unknown_str = String::from_utf8_lossy(unknown);
899                warn!("Post process: Invalid header value: {unknown_str}");
900            }
901        }
902
903        response
904    }
905}
906
907// --- Helpers --- //
908
909/// Lexe's default fallback [`Handler`](axum::handler::Handler).
910/// Returns a "bad endpoint" rejection along with the requested method and path.
911pub async fn default_fallback(
912    method: http::Method,
913    uri: http::Uri,
914) -> LxRejection {
915    let path = uri.path();
916    LxRejection {
917        kind: LxRejectionKind::BadEndpoint,
918        // e.g. "POST /app/node_info"
919        source_msg: format!("{method} {path}"),
920    }
921}