Skip to main content

lexe_api/
server.rs

1// This is the only place where we are allowed to use e.g. `Json` and `Query`.
2#![allow(clippy::disallowed_types)]
3
4//! This module provides various API server utilities.
5//!
6//! # Serving
7//!
8//! Methods to serve a [`Router`] with a fallback handler (for unmatched paths),
9//! tracing / request instrumentation, backpressure, load shedding, concurrency
10//! limits, server-side timeouts, TLS, and graceful shutdown:
11//!
12//! - [`build_server_fut`]
13//! - [`build_server_fut_with_listener`]
14//! - [`spawn_server_task`]
15//! - [`spawn_server_task_with_listener`]
16//!
17//! # Extractors to get data from requests:
18//!
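//! - [`LxJson`] to deserialize from HTTP body JSON
//! - [`LxBytes`](crate::server::LxBytes) to extract a binary HTTP body
//! - [`LxQuery`] to deserialize from query strings
//! - [`LxPath`](crate::server::extract::LxPath) to deserialize from path
//!   parameters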
21//!
22//! # [`IntoResponse`] types / impls for building Lexe API-conformant responses:
23//!
24//! - [`LxJson`] type for returning success responses as JSON
25//! - All [`ApiError`]s and [`CommonApiError`] impl [`IntoResponse`]
26//! - [`LxRejection`] for notifying clients of bad JSON, query strings, etc.
27//!
28//! [`ApiError`]: lexe_api_core::error::ApiError
29//! [`CommonApiError`]: lexe_api_core::error::CommonApiError
30//! [`Router`]: axum::Router
31//! [`IntoResponse`]: axum::response::IntoResponse
32//! [`LxJson`]: crate::server::LxJson
33//! [`LxQuery`]: crate::server::extract::LxQuery
34//! [`LxRejection`]: crate::server::LxRejection
35//! [`build_server_fut`]: crate::server::build_server_fut
36//! [`build_server_fut_with_listener`]: crate::server::build_server_fut_with_listener
37//! [`spawn_server_task`]: crate::server::spawn_server_task
38//! [`spawn_server_task_with_listener`]: crate::server::spawn_server_task_with_listener
39
40use std::{
41    borrow::Cow,
42    convert::Infallible,
43    fmt::{self, Display},
44    future::Future,
45    net::{SocketAddr, TcpListener},
46    str::FromStr,
47    sync::Arc,
48    time::Duration,
49};
50
51use anyhow::Context;
52use axum::{
53    Router, ServiceExt as AxumServiceExt,
54    error_handling::HandleErrorLayer,
55    extract::{
56        DefaultBodyLimit, FromRequest, OptionalFromRequest,
57        rejection::{
58            BytesRejection, JsonRejection, PathRejection, QueryRejection,
59        },
60    },
61    response::IntoResponse,
62    routing::RouterIntoService,
63};
64use axum_server::tls_rustls::RustlsConfig;
65use bytes::Bytes;
66use http::{HeaderValue, StatusCode, header::CONTENT_TYPE};
67use lexe_api_core::{
68    axum_helpers,
69    error::{CommonApiError, CommonErrorKind},
70};
71use lexe_common::api::auth::{self, Scope};
72use lexe_crypto::ed25519;
73use lexe_tokio::{notify_once::NotifyOnce, task::LxTask};
74use serde::{Serialize, de::DeserializeOwned};
75use tower::{
76    Layer, buffer::BufferLayer, limit::ConcurrencyLimitLayer,
77    load_shed::LoadShedLayer, timeout::TimeoutLayer, util::MapRequestLayer,
78};
79use tracing::{Instrument, debug, error, info, warn};
80
81use crate::{rest, trace};
82
83/// The grace period passed to [`axum_server::Handle::graceful_shutdown`] during
84/// which new connections are refused and we wait for existing connections to
85/// terminate before initiating a hard shutdown.
86const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(3);
87/// The maximum time we'll wait for a server to complete shutdown.
88pub const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
89lexe_std::const_assert!(
90    SHUTDOWN_GRACE_PERIOD.as_secs() < SERVER_SHUTDOWN_TIMEOUT.as_secs()
91);
92
93/// The default maximum time a server can spend handling a request.
94pub const SERVER_HANDLER_TIMEOUT: Duration = Duration::from_secs(25);
95lexe_std::const_assert!(
96    rest::API_REQUEST_TIMEOUT.as_secs() > SERVER_HANDLER_TIMEOUT.as_secs()
97);
98
/// Tuning knobs for the Axum / Tower middleware stack that the server builders
/// in this module wrap around a [`Router`].
///
/// The [`Default`] values are:
///
/// ```
/// # use std::time::Duration;
/// # use lexe_api::server::LayerConfig;
/// assert_eq!(
///     LayerConfig::default(),
///     LayerConfig {
///         body_limit: 16384,
///         buffer_size: 4096,
///         concurrency: 4096,
///         handling_timeout: Duration::from_secs(25),
///         default_fallback: true,
///     }
/// );
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LayerConfig {
    /// Upper bound, in bytes, on the size of a request body. Guards against
    /// DoS; some services may need to raise it.
    pub body_limit: usize,
    /// Capacity of the request buffer placed in front of the service. Lets the
    /// server pick up queued work as soon as a request finishes, while keeping
    /// the backlog from growing without bound.
    pub buffer_size: usize,
    /// Cap on how many requests are processed at the same time, so the CPU
    /// doesn't max out and start thrashing.
    pub concurrency: usize,
    /// Upper bound on how long the server may spend handling one request, so
    /// pathologically slow requests can't crowd out normal traffic.
    pub handling_timeout: Duration,
    /// When `true`, Lexe's [`default_fallback`] is installed as the
    /// [`Router::fallback`]. The fallback runs when no route matches, and
    /// Lexe's version returns a "bad endpoint" rejection that includes the
    /// requested method and path.
    ///
    /// Set this to [`false`] to use your own fallback; otherwise it gets
    /// overwritten by Lexe's [`default_fallback`]. In that case it is up to
    /// the caller to make sure the [`Router`] actually has a fallback set.
    pub default_fallback: bool,
}
144
145impl Default for LayerConfig {
146    fn default() -> Self {
147        Self {
148            // 16KiB is sufficient for most Lexe services.
149            body_limit: 16384,
150            // TODO(max): We are using very high values right now because it
151            // doesn't make sense to constrain anything until we have run some
152            // load tests to profile performance and see what breaks.
153            buffer_size: 4096,
154            concurrency: 4096,
155            handling_timeout: SERVER_HANDLER_TIMEOUT,
156            default_fallback: true,
157        }
158    }
159}
160
161// --- Server helpers --- //
162
/// Builds the URL clients use to reach a server, from the address returned by
/// the server's [`TcpListener::local_addr`] and its primary DNS name, if any.
///
/// - With a DNS name, the URL is `https://` plus the DNS name. The port is
///   appended unless it is the default HTTPS port (443).
/// - Without a DNS name, the URL is `http://` plus the socket address. IPv6
///   addresses are bracketed by [`SocketAddr`]'s `Display` impl.
///
/// ex: `https://lexe.app` (port=443)
/// ex: `https://relay.lexe.app:4396`
/// ex: `http://[::1]:8080`
//
// We have a fn to build the url because it's easy to mess up.
pub fn build_server_url(
    // The output of `TcpListener::local_addr`
    listener_addr: SocketAddr,
    // Primary DNS name
    maybe_dns: Option<&str>,
) -> String {
    /// The default HTTPS port, which is left implicit in URLs.
    const DEFAULT_HTTPS_PORT: u16 = 443;

    let Some(dns_name) = maybe_dns else {
        return format!("http://{listener_addr}");
    };

    match listener_addr.port() {
        DEFAULT_HTTPS_PORT => format!("https://{dns_name}"),
        port => format!("https://{dns_name}:{port}"),
    }
}
189
190/// Constructs an API server future which can be spawned into a task.
191/// Additionally returns the server url.
192///
193/// Use this helper when it is useful to poll multiple futures in a single task
194/// to reduce the amount of task nesting / indirection. If there is only one
195/// future that needs to be driven, use [`spawn_server_task`] instead.
196///
197/// Errors if the [`TcpListener`] failed to bind or return its local address.
198/// Returns the server future along with the bound socket address.
199// Avoids generic parameters to prevent binary bloat.
200// Returns unnamed `impl Future` to avoid Pin<Box<T>> deref cost.
201pub fn build_server_fut(
202    bind_addr: SocketAddr,
203    router: Router<()>,
204    layer_config: LayerConfig,
205    // TLS config + primary DNS name
206    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
207    server_span_name: &str,
208    server_span: tracing::Span,
209    // Send on this channel to begin a graceful shutdown of the server.
210    shutdown: NotifyOnce,
211) -> anyhow::Result<(impl Future<Output = ()>, String)> {
212    let listener =
213        TcpListener::bind(bind_addr).context("Could not bind TCP listener")?;
214    let (server_fut, primary_server_url) = build_server_fut_with_listener(
215        listener,
216        router,
217        layer_config,
218        maybe_tls_and_dns,
219        server_span_name,
220        server_span,
221        shutdown,
222    )
223    .context("Could not build server future")?;
224    Ok((server_fut, primary_server_url))
225}
226
/// [`build_server_fut`] but takes a [`TcpListener`] instead of [`SocketAddr`].
///
/// Wraps `router` in Lexe's middleware stack (configured by `layer_config`),
/// then serves it on `listener`. TLS is used if `maybe_tls_and_dns` is `Some`.
/// Returns the server future along with the primary server URL, which is also
/// logged.
///
/// The returned future drives the server until `shutdown` is notified. It then
/// initiates a graceful shutdown with a `SHUTDOWN_GRACE_PERIOD` grace period
/// and waits at most [`SERVER_SHUTDOWN_TIMEOUT`] for the server to finish. If
/// the server exits before a shutdown is requested, an error is logged and the
/// future completes.
///
/// # Errors
///
/// Errors if the listener's local address could not be read.
///
/// NOTE(review): [`BufferLayer`] is applied to the router here. Tower's
/// `Buffer` spawns its worker onto the Tokio runtime, so this presumably must
/// be called from within a Tokio runtime context. Confirm before calling from
/// sync code.
// Avoids generic parameters to prevent binary bloat.
// Returns unnamed `impl Future` to avoid Pin<Box<T>> deref cost.
pub fn build_server_fut_with_listener(
    listener: TcpListener,
    router: Router<()>,
    layer_config: LayerConfig,
    // TLS config + primary DNS name
    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
    server_span_name: &str,
    server_span: tracing::Span,
    // Send on this channel to begin a graceful shutdown of the server.
    mut shutdown: NotifyOnce,
) -> anyhow::Result<(impl Future<Output = ()> + use<>, String)> {
    // Split the optional (TLS config, DNS name) pair into two Options.
    let (maybe_tls_config, maybe_dns) = maybe_tls_and_dns.unzip();
    let listener_addr = listener
        .local_addr()
        .context("Could not get listener local address")?;
    let primary_server_url = build_server_url(listener_addr, maybe_dns);
    info!("Url for {server_span_name}: {primary_server_url}");

    // Add Lexe's default fallback if it is enabled in the LayerConfig.
    let router = if layer_config.default_fallback {
        router.fallback(default_fallback)
    } else {
        router
    };

    // Used to annotate the service / request / response types
    // at each point in the ServiceBuilder chains.
    type HyperService = RouterIntoService<hyper::body::Incoming, ()>;
    type AxumService = RouterIntoService<axum::body::Body, ()>;
    type HyperReq = http::Request<hyper::body::Incoming>;
    type AxumReq = http::Request<axum::body::Body>;
    type AxumResp = http::Response<axum::body::Body>;
    type TraceResp = http::Response<
        tower_http::trace::ResponseBody<
            axum::body::Body,
            tower_http::classify::NeverClassifyEos<anyhow::Error>,
            (),
            trace::server::LxOnEos,
            trace::server::LxOnFailure,
        >,
    >;

    // The outer middleware stack which wraps the entire Router.
    //
    // Axum docs explain ordering better than tower's ServiceBuilder docs do:
    // https://docs.rs/axum/latest/axum/middleware/index.html#ordering
    // Basically, requests go from top to bottom and responses bottom to top.
    let outer_middleware = tower::ServiceBuilder::new()
        .check_service::<HyperService, HyperReq, AxumResp, Infallible>()
        // Log everything on its way in and out, even load-shedded requests.
        // This layer changes the response type.
        .layer(trace::server::trace_layer(server_span.clone()))
        .check_service::<HyperService, HyperReq, TraceResp, Infallible>()
        // Run our post-processor which can modify responses *after* the Axum
        // Router has constructed the response.
        .layer(tower::util::MapResponseLayer::new(
            middleware::post_process_response,
        ))
        .check_service::<HyperService, HyperReq, TraceResp, Infallible>();

    // The inner middleware stack which is cloned to each route in the Router.
    // We put most of the layers here because it is a lot easier to work with
    // axum types; moving these outside quickly degenerates into type hell.
    let inner_middleware = tower::ServiceBuilder::new()
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Immediately reject anything with a CONTENT_LENGTH over the limit.
        .layer(axum::middleware::map_request_with_state(
            layer_config.body_limit,
            middleware::check_content_length_header,
        ))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Set the default request body limit for all requests. This adds a
        // `DefaultBodyLimitKind` (private axum type) into the request
        // extensions so that any inner layers or extractors which call
        // `axum::RequestExt::[with|into]_limited_body` will pick it up.
        // NOTE that many of our extractors transitively rely on the Bytes
        // extractor which will default to a 2MB limit if this is not set.
        .layer(DefaultBodyLimit::max(layer_config.body_limit))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Here, we explicitly apply the body limit from the request extensions,
        // transforming the request body type into `http_body_util::Limited`.
        .layer(MapRequestLayer::new(axum::RequestExt::with_limited_body))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Handles errors from the load_shed, buffer, and concurrency layers.
        .layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
            CommonApiError {
                kind: CommonErrorKind::AtCapacity,
                msg: "Service is at capacity; retry later".to_owned(),
            }
        }))
        // Returns an `Err` if the inner service returns `Poll::Pending`.
        // Helps prevent OOM when combined with the buffer or concurrency layer.
        .layer(LoadShedLayer::new())
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Returns Poll::Pending when the buffer is full (backpressure).
        // Allows the server to immediately work on more queued requests when a
        // request completes, and prevents a large backlog from building up.
        // Note that while the layer is often cloned, the buffer itself is not.
        .layer(BufferLayer::new(layer_config.buffer_size))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Returns `Poll::Pending` when the concurrency limit has been reached.
        // Helps prevent the CPU from maxing out, resulting in thrashing.
        .layer(ConcurrencyLimitLayer::new(layer_config.concurrency))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>()
        // Handles errors generated by the timeout layer.
        .layer(HandleErrorLayer::new(|_: tower::BoxError| async move {
            CommonApiError {
                kind: CommonErrorKind::Server,
                msg: "Server timed out handling request".to_owned(),
            }
        }))
        // Returns an error if the inner service takes longer than the timeout
        // to handle the request. Prevents degenerate cases which take
        // abnormally long to process from crowding out normal workloads.
        .layer(TimeoutLayer::new(layer_config.handling_timeout))
        .check_service::<AxumService, AxumReq, AxumResp, Infallible>();

    // Apply inner middleware
    let layered_router = router.layer(inner_middleware);
    // Convert into Service
    let router_service = layered_router.into_service::<hyper::body::Incoming>();
    // Apply outer middleware
    let layered_service = Layer::layer(&outer_middleware, router_service);
    // Convert into MakeService
    let make_service = layered_service.into_make_service();

    // `handle_clone` is moved into the server below, while `handle` is kept
    // so the shutdown future can trigger a graceful shutdown on the server.
    let handle = axum_server::Handle::new();
    let handle_clone = handle.clone();
    let server_fut = async {
        let serve_result = match maybe_tls_config {
            Some(tls_config) => {
                let axum_tls_config = RustlsConfig::from_config(tls_config);
                axum_server::from_tcp_rustls(listener, axum_tls_config)
                    .handle(handle_clone)
                    .serve(make_service)
                    .await
            }
            None =>
                axum_server::from_tcp(listener)
                    .handle(handle_clone)
                    .serve(make_service)
                    .await,
        };

        serve_result
            // See axum_server::Server::serve docs for why this can't error
            .expect("No binding + axum MakeService::poll_ready never errors");
    };

    let graceful_shutdown_fut = async move {
        shutdown.recv().await;
        info!("Shutting down API server");
        // The 'grace period' is a period of time during which new connections
        // are refused and `axum_server::Server::serve` waits for all current
        // connections to terminate. If `None`, the server waits indefinitely
        // for current connections to terminate; if `Some`, the server will
        // initiate a hard shutdown after the grace period has elapsed. We use
        // Some(_) with a relatively short grace period because (1) our handlers
        // shouldn't take long to return and (2) we sometimes see connections
        // failing to terminate for servers which have a /shutdown endpoint.
        handle.graceful_shutdown(Some(SHUTDOWN_GRACE_PERIOD));
    };

    // Drive the server until a shutdown is requested, then give it at most
    // `SERVER_SHUTDOWN_TIMEOUT` to finish. If the server exits first, log an
    // error and return early (`error!` evaluates to `()`).
    let combined_fut = async {
        tokio::pin!(server_fut);
        tokio::select! {
            biased; // Ensure graceful shutdown future finishes first
            () = graceful_shutdown_fut => (),
            _ = &mut server_fut => return error!("Server exited early"),
        }
        match tokio::time::timeout(SERVER_SHUTDOWN_TIMEOUT, server_fut).await {
            Ok(()) => info!("API server finished"),
            Err(_) => warn!("API server timed out during shutdown"),
        }
    }
    .instrument(server_span);

    Ok((combined_fut, primary_server_url))
}
409
410/// [`build_server_fut`] but additionally spawns the server future into an
411/// instrumented server task and logs the full URL used to access the server.
412/// Returns the server task and server url.
413pub fn spawn_server_task(
414    bind_addr: SocketAddr,
415    router: Router<()>,
416    layer_config: LayerConfig,
417    // TLS config + primary DNS name
418    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
419    server_span_name: Cow<'static, str>,
420    server_span: tracing::Span,
421    // Send on this channel to begin a graceful shutdown of the server.
422    shutdown: NotifyOnce,
423) -> anyhow::Result<(LxTask<()>, String)> {
424    let listener = TcpListener::bind(bind_addr)
425        .context(bind_addr)
426        .context("Failed to bind TcpListener")?;
427
428    let (server_task, primary_server_url) = spawn_server_task_with_listener(
429        listener,
430        router,
431        layer_config,
432        maybe_tls_and_dns,
433        server_span_name,
434        server_span,
435        shutdown,
436    )
437    .context("spawn_server_task_with_listener failed")?;
438
439    Ok((server_task, primary_server_url))
440}
441
442/// [`spawn_server_task`] but takes [`TcpListener`] instead of [`SocketAddr`].
443pub fn spawn_server_task_with_listener(
444    listener: TcpListener,
445    router: Router<()>,
446    layer_config: LayerConfig,
447    // TLS config + primary DNS name
448    maybe_tls_and_dns: Option<(Arc<rustls::ServerConfig>, &str)>,
449    server_span_name: Cow<'static, str>,
450    server_span: tracing::Span,
451    // Send on this channel to begin a graceful shutdown of the server.
452    shutdown: NotifyOnce,
453) -> anyhow::Result<(LxTask<()>, String)> {
454    let (server_fut, primary_server_url) = build_server_fut_with_listener(
455        listener,
456        router,
457        layer_config,
458        maybe_tls_and_dns,
459        &server_span_name,
460        server_span.clone(),
461        shutdown,
462    )
463    .context("Failed to build server future")?;
464
465    let server_task =
466        LxTask::spawn_with_span(server_span_name, server_span, server_fut);
467
468    Ok((server_task, primary_server_url))
469}
470
471// --- LxJson --- //
472
/// A version of [`axum::Json`] which conforms to Lexe's (JSON) API.
/// It can be used as either an extractor or a response.
///
/// - As an extractor: rejections return [`LxRejection`].
/// - As a success response:
///   - Serialization success returns an [`http::Response`] with a JSON body.
///   - Serialization failure returns an [`ErrorResponse`].
///
/// [`axum::Json`] is banned (via `clippy::disallowed_types`) because:
///
/// - Rejections return [`JsonRejection`] which is just a string HTTP body.
/// - Response serialization failures likewise return just a string body.
///
/// NOTE: This must only be used for forming *success* API responses,
/// i.e. `T` in `Result<T, E>`, because its [`IntoResponse`] impl uses
/// [`StatusCode::OK`]. Our API error types, while also serialized as JSON,
/// have separate [`IntoResponse`] impls which return error statuses.
///
/// [`ErrorResponse`]: lexe_api_core::error::ErrorResponse
pub struct LxJson<T>(pub T);
493
494impl<T: DeserializeOwned, S: Send + Sync> FromRequest<S> for LxJson<T> {
495    type Rejection = LxRejection;
496
497    async fn from_request(
498        req: http::Request<axum::body::Body>,
499        state: &S,
500    ) -> Result<Self, Self::Rejection> {
501        // `axum::Json`'s from_request impl is fine but its rejection is not
502        <axum::Json<T> as FromRequest<S>>::from_request(req, state)
503            .await
504            .map(|axum::Json(t)| Self(t))
505            .map_err(LxRejection::from)
506    }
507}
508
509impl<T: DeserializeOwned, S: Send + Sync> OptionalFromRequest<S> for LxJson<T> {
510    type Rejection = LxRejection;
511
512    async fn from_request(
513        req: http::Request<axum::body::Body>,
514        state: &S,
515    ) -> Result<Option<Self>, Self::Rejection> {
516        <axum::Json<T> as OptionalFromRequest<S>>::from_request(req, state)
517            .await
518            .map(|opt| opt.map(|axum::Json(t)| Self(t)))
519            .map_err(LxRejection::from)
520    }
521}
522
523impl<T: Serialize> IntoResponse for LxJson<T> {
524    fn into_response(self) -> http::Response<axum::body::Body> {
525        axum_helpers::build_json_response(StatusCode::OK, &self.0)
526    }
527}
528
529impl<T: Clone> Clone for LxJson<T> {
530    fn clone(&self) -> Self {
531        Self(self.0.clone())
532    }
533}
534
535impl<T: Copy> Copy for LxJson<T> {}
536
537impl<T: fmt::Debug> fmt::Debug for LxJson<T> {
538    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
539        T::fmt(&self.0, f)
540    }
541}
542
543impl<T: Eq + PartialEq> Eq for LxJson<T> {}
544
545impl<T: PartialEq> PartialEq for LxJson<T> {
546    fn eq(&self, other: &Self) -> bool {
547        self.0.eq(&other.0)
548    }
549}
550
551// --- LxBytes --- //
552
553/// A version of [`Bytes`] which conforms to Lexe's (binary) API.
554/// - [`axum`] has implementations of [`FromRequest`] and [`IntoResponse`] for
555///   [`Bytes`], but these implementations are not Lexe API-conformant.
556/// - This type can be used as either an extractor or a success response, and
557///   should always be used instead of [`Bytes`] in these server contexts.
558/// - It is still fine to use [`Bytes`] on the client side.
559///
560/// - As an extractor: rejections return [`LxRejection`].
561/// - As a success response:
562///   - Returns an [`http::Response`] with a binary body.
563///
564///   - Any failure encountered in extraction or creation should produce an
565///     [`ErrorResponse`].
566///
567/// The regular impls are non-conformant because:
568///
569/// - Rejections return [`BytesRejection`] which is just a string HTTP body.
570///
571/// NOTE: This must only be used for forming *success* API responses,
572/// i.e. `LxBytes` in `Result<LxBytes, E>`, because its [`IntoResponse`] impl
573/// uses [`StatusCode::OK`]. Our API error types are serialized as JSON and
574/// have separate [`IntoResponse`] impls which return error statuses.
575///
576/// [`ErrorResponse`]: lexe_api_core::error::ErrorResponse
577#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
578pub struct LxBytes(pub Bytes);
579
580impl<S: Send + Sync> FromRequest<S> for LxBytes {
581    type Rejection = LxRejection;
582
583    async fn from_request(
584        req: http::Request<axum::body::Body>,
585        state: &S,
586    ) -> Result<Self, Self::Rejection> {
587        // `Bytes`'s from_request impl is fine but its rejection is not
588        Bytes::from_request(req, state)
589            .await
590            .map(Self)
591            .map_err(LxRejection::from)
592    }
593}
594
595/// The [`Bytes`] [`IntoResponse`] impl is almost exactly the same,
596/// except it returns the wrong HTTP version.
597impl IntoResponse for LxBytes {
598    fn into_response(self) -> http::Response<axum::body::Body> {
599        let http_body = http_body_util::Full::new(self.0);
600        let axum_body = axum::body::Body::new(http_body);
601
602        axum_helpers::default_response_builder()
603            .header(
604                CONTENT_TYPE,
605                // Or `HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM)`
606                HeaderValue::from_static("application/octet-stream"),
607            )
608            .status(StatusCode::OK)
609            .body(axum_body)
610            .expect("All operations here should be infallible")
611    }
612}
613
614impl<T: Into<Bytes>> From<T> for LxBytes {
615    fn from(t: T) -> Self {
616        Self(t.into())
617    }
618}
619
620// --- LxRejection --- //
621
622/// Our own [`axum::extract::rejection`] type with an [`IntoResponse`] impl
623/// which conforms to Lexe's API. Contains the source rejection's error text.
624pub struct LxRejection {
625    /// Which [`axum::extract::rejection`] this [`LxRejection`] was built from.
626    kind: LxRejectionKind,
627    /// The error text of the source rejection, or additional context.
628    source_msg: String,
629}
630
/// Identifies where an [`LxRejection`] originated.
enum LxRejectionKind {
    // -- Converted from `axum::extract::rejection` types -- //
    /// Built from a [`BytesRejection`].
    Bytes,
    /// Built from a [`JsonRejection`].
    Json,
    /// Built from a [`PathRejection`].
    Path,
    /// Built from a [`QueryRejection`].
    Query,

    // -- Lexe-specific rejections -- //
    /// Bearer authentication failed.
    Unauthenticated,
    /// The client isn't authorized to access the requested resource.
    Unauthorized,
    /// The request didn't match any path in the [`Router`].
    BadEndpoint,
    /// The request body length exceeded the configured limit.
    BodyLengthOverLimit,
    /// An [`ed25519::Error`] occurred.
    Ed25519,
    /// An error from the gateway proxy.
    Proxy,
}
657
658// Use explicit `.map_err()`s instead of From impls for non-obvious conversions
659impl LxRejection {
660    pub fn from_ed25519(error: ed25519::Error) -> Self {
661        Self {
662            kind: LxRejectionKind::Ed25519,
663            source_msg: format!("{error:#}"),
664        }
665    }
666
667    pub fn from_bearer_auth(error: auth::Error) -> Self {
668        Self {
669            kind: LxRejectionKind::Unauthenticated,
670            source_msg: format!("{error:#}"),
671        }
672    }
673
674    pub fn scope_unauthorized(
675        granted_scope: &Scope,
676        requested_scope: &Scope,
677    ) -> Self {
678        Self {
679            kind: LxRejectionKind::Unauthorized,
680            source_msg: format!(
681                "granted scope: {granted_scope:?}, requested scope: {requested_scope:?}"
682            ),
683        }
684    }
685
686    pub fn proxy(error: impl Display) -> Self {
687        Self {
688            kind: LxRejectionKind::Proxy,
689            source_msg: format!("{error:#}"),
690        }
691    }
692}
693
694impl From<BytesRejection> for LxRejection {
695    fn from(bytes_rejection: BytesRejection) -> Self {
696        Self {
697            kind: LxRejectionKind::Bytes,
698            source_msg: bytes_rejection.body_text(),
699        }
700    }
701}
702
703impl From<JsonRejection> for LxRejection {
704    fn from(json_rejection: JsonRejection) -> Self {
705        Self {
706            kind: LxRejectionKind::Json,
707            source_msg: json_rejection.body_text(),
708        }
709    }
710}
711
712impl From<PathRejection> for LxRejection {
713    fn from(path_rejection: PathRejection) -> Self {
714        Self {
715            kind: LxRejectionKind::Path,
716            source_msg: path_rejection.body_text(),
717        }
718    }
719}
720
721impl From<QueryRejection> for LxRejection {
722    fn from(query_rejection: QueryRejection) -> Self {
723        Self {
724            kind: LxRejectionKind::Query,
725            source_msg: query_rejection.body_text(),
726        }
727    }
728}
729
730impl IntoResponse for LxRejection {
731    fn into_response(self) -> http::Response<axum::body::Body> {
732        // TODO(phlip9): authn+authz+badendpoint rejections should return
733        // standard status codes, not just 400.
734        let kind = CommonErrorKind::Rejection;
735        // "Bad JSON: Failed to deserialize the JSON body into the target type"
736        let kind_msg = self.kind.to_msg();
737        let source_msg = &self.source_msg;
738        let msg = format!("Rejection: {kind_msg}: {source_msg}");
739        // Log the rejection now since our trace layer can't access this info
740        warn!("{msg}");
741        let common_error = CommonApiError { kind, msg };
742        common_error.into_response()
743    }
744}
745
746impl LxRejectionKind {
747    /// A generic error message for this rejection kind.
748    fn to_msg(&self) -> &'static str {
749        match self {
750            Self::Bytes => "Bad request bytes",
751            Self::Json => "Client provided bad JSON",
752            Self::Path => "Client provided bad path parameter",
753            Self::Query => "Client provided bad query string",
754
755            Self::Unauthenticated => "Invalid bearer auth",
756            Self::Unauthorized => "Not authorized to access this resource",
757            Self::BadEndpoint => "Client requested a non-existent endpoint",
758            Self::BodyLengthOverLimit => "Request body length over limit",
759            Self::Ed25519 => "Ed25519 error",
760            Self::Proxy => "Proxy error",
761        }
762    }
763}
764
765// --- Extractors --- //
766
767pub mod extract {
768    use axum::extract::FromRequestParts;
769
770    use super::*;
771
772    /// Lexe API-compliant version of [`axum::extract::Query`].
773    pub struct LxQuery<T>(pub T);
774
775    impl<T: DeserializeOwned, S: Send + Sync> FromRequestParts<S> for LxQuery<T> {
776        type Rejection = LxRejection;
777
778        async fn from_request_parts(
779            parts: &mut http::request::Parts,
780            state: &S,
781        ) -> Result<Self, Self::Rejection> {
782            axum::extract::Query::from_request_parts(parts, state)
783                .await
784                .map(|axum::extract::Query(t)| Self(t))
785                .map_err(LxRejection::from)
786        }
787    }
788
789    impl<T: Clone> Clone for LxQuery<T> {
790        fn clone(&self) -> Self {
791            Self(self.0.clone())
792        }
793    }
794
795    impl<T: fmt::Debug> fmt::Debug for LxQuery<T> {
796        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
797            T::fmt(&self.0, f)
798        }
799    }
800
801    impl<T: Eq + PartialEq> Eq for LxQuery<T> {}
802
803    impl<T: PartialEq> PartialEq for LxQuery<T> {
804        fn eq(&self, other: &Self) -> bool {
805            self.0.eq(&other.0)
806        }
807    }
808
809    /// Lexe API-compliant version of [`axum::extract::Path`].
810    pub struct LxPath<T>(pub T);
811
812    impl<T: DeserializeOwned + Send, S: Send + Sync> FromRequestParts<S>
813        for LxPath<T>
814    {
815        type Rejection = LxRejection;
816
817        async fn from_request_parts(
818            parts: &mut http::request::Parts,
819            state: &S,
820        ) -> Result<Self, Self::Rejection> {
821            axum::extract::Path::from_request_parts(parts, state)
822                .await
823                .map(|axum::extract::Path(t)| Self(t))
824                .map_err(LxRejection::from)
825        }
826    }
827
828    impl<T: Clone> Clone for LxPath<T> {
829        fn clone(&self) -> Self {
830            Self(self.0.clone())
831        }
832    }
833
834    impl<T: fmt::Debug> fmt::Debug for LxPath<T> {
835        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
836            T::fmt(&self.0, f)
837        }
838    }
839
840    impl<T: Eq + PartialEq> Eq for LxPath<T> {}
841
842    impl<T: PartialEq> PartialEq for LxPath<T> {
843        fn eq(&self, other: &Self) -> bool {
844            self.0.eq(&other.0)
845        }
846    }
847}
848
849// --- Custom middleware --- //
850
851pub mod middleware {
852    use axum::extract::State;
853    use http::HeaderName;
854
855    use super::*;
856
857    /// The header name used for response post-processing signals.
858    pub static POST_PROCESS_HEADER: HeaderName =
859        HeaderName::from_static("lx-post-process");
860
861    /// Checks the `CONTENT_LENGTH` header and returns an early rejection if the
862    /// contained value exceeds our configured body limit. This optimization
863    /// allows us to avoid unnecessary work processing the request further.
864    ///
865    /// NOTE: This does not enforce the body length!! Use [`DefaultBodyLimit`]
866    /// in combination with [`axum::RequestExt::with_limited_body`] to do so.
867    pub async fn check_content_length_header<B>(
868        // `LayerConfig::body_limit`
869        State(config_body_limit): State<usize>,
870        request: http::Request<B>,
871    ) -> Result<http::Request<B>, LxRejection> {
872        let maybe_content_length = request
873            .headers()
874            .get(http::header::CONTENT_LENGTH)
875            .and_then(|value| value.to_str().ok())
876            .and_then(|value_str| usize::from_str(value_str).ok());
877
878        // If a limit is configured and the header value exceeds it, reject.
879        if let Some(content_length) = maybe_content_length
880            && content_length > config_body_limit
881        {
882            return Err(LxRejection {
883                kind: LxRejectionKind::BodyLengthOverLimit,
884                source_msg: "Content length header over limit".to_owned(),
885            });
886        }
887
888        Ok(request)
889    }
890
891    /// A post-processor which can be used to modify the [`http::Response`]s
892    /// returned by an [`axum::Router`]. This is done by signalling the desired
893    /// modification in a fake [`POST_PROCESS_HEADER`] which is also removed
894    /// during post-processing. This can be used to override Axum defaults
895    /// which one does not have access to from within the [`Router`]. Currently,
896    /// this only supports a "remove-content-length" command which removes the
897    /// content-length header set by Axum, but can be easily extended.
898    pub(super) fn post_process_response(
899        mut response: http::Response<axum::body::Body>,
900    ) -> http::Response<axum::body::Body> {
901        let value = match response.headers_mut().remove(&POST_PROCESS_HEADER) {
902            Some(v) => v,
903            None => return response,
904        };
905
906        match value.as_bytes() {
907            b"remove-content-length" => {
908                response.headers_mut().remove(http::header::CONTENT_LENGTH);
909                debug!("Post process: Removed content-length header");
910            }
911            unknown => {
912                let unknown_str = String::from_utf8_lossy(unknown);
913                warn!("Post process: Invalid header value: {unknown_str}");
914            }
915        }
916
917        response
918    }
919}
920
921// --- Helpers --- //
922
923/// Lexe's default fallback [`Handler`](axum::handler::Handler).
924/// Returns a "bad endpoint" rejection along with the requested method and path.
925pub async fn default_fallback(
926    method: http::Method,
927    uri: http::Uri,
928) -> LxRejection {
929    let path = uri.path();
930    LxRejection {
931        kind: LxRejectionKind::BadEndpoint,
932        // e.g. "POST /app/node_info"
933        source_msg: format!("{method} {path}"),
934    }
935}