Skip to main content

nest_rs_http/
transport.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use nest_rs_core::{Container, DiscoveryService, Transport};
4use poem::endpoint::BoxEndpoint;
5use poem::http::header::{HeaderValue, SERVER};
6use poem::listener::{Listener, TcpListener};
7use poem::middleware::{Cors, SetHeader};
8use poem::{EndpointExt, IntoEndpoint, Response, Route, Server};
9use tokio_util::sync::CancellationToken;
10
11use crate::boot_check::{GlobalGuardsActive, HttpBootCheck};
12use crate::controller::HttpControllerMeta;
13use crate::endpoint::{EdgePosture, HttpEndpointMeta, SelfMountGuardWrap};
14use crate::interceptor::HttpEndpointWrap;
15use crate::raw_body::RawBodyLimit;
16use crate::tls::TlsConfig;
17
/// Deferred mount step for an imperative [`HttpTransport::mount`]: called at
/// configure time with the live container and the route tree assembled so
/// far, it returns the tree with its endpoint nested in.
type MountFn = Box<dyn Fn(&Container, Route) -> Route + Send + Sync>;
/// Imperative mount paired with its path — kept so the fail-secure boot
/// check can name the endpoints that bypass the layer pool.
type NamedMount = (String, MountFn);
22
/// Glue a controller prefix and a route path together the same way poem's
/// nesting does, e.g. `("/health", "/live") -> "/health/live"`.
///
/// Trailing slashes are dropped from `prefix` and leading slashes from
/// `rest` before joining. If both sides end up empty the result is `"/"`;
/// if only `rest` is empty the trimmed prefix is returned as-is.
///
/// Public so `nestrs-openapi` composes paths identically to how this
/// transport mounts them — the served path and the documented path must
/// not drift.
pub fn join_path(prefix: &str, rest: &str) -> String {
    let head = prefix.trim_end_matches('/');
    let tail = rest.trim_start_matches('/');

    // Nothing to append: the answer is the prefix alone, or the root.
    if tail.is_empty() {
        return if head.is_empty() {
            String::from("/")
        } else {
            head.to_string()
        };
    }

    // An empty head yields "/tail", a non-empty one "head/tail".
    format!("{head}/{tail}")
}
37
38/// Apply URI API versioning: `Some("1"), "/users"` → `"/v1/users"`. The single
39/// place the URI strategy lives — `#[routes]`, the boot route log, and the
40/// OpenAPI document all route through it so the served/logged/documented paths
41/// can never drift.
42pub fn version_path(version: Option<&str>, path: &str) -> String {
43    match version {
44        Some(v) => join_path(&format!("/v{v}"), path),
45        None => path.to_string(),
46    }
47}
48
/// HTTP [`Transport`] backed by poem. At [`Transport::configure`] time, runs
/// every discovered [`HttpBootCheck`], mounts every
/// `#[module(providers = [...])]`-declared [`HttpControllerMeta`] and
/// [`HttpEndpointMeta`], then any imperative [`HttpTransport::mount`], then
/// folds every discovered [`HttpEndpointWrap`] wrap around the assembled
/// endpoint. Transport-edge wraps (the global interceptor / filter pools,
/// infra `#[interceptor]`s like `DbContext`) attach themselves through
/// [`HttpEndpointWrap`] from their own crates — this transport stays free
/// of the cross-transport trait crates and only knows about poem. Guards
/// and pipes never wrap here: they execute in the per-route shaper
/// (post-routing) or at a `Guarded` self-mount's edge.
pub struct HttpTransport {
    // Address handed to `TcpListener::bind` in `serve`; `new` sets
    // `"0.0.0.0:3000"`.
    bind: String,
    // Imperative mounts registered through `mount`, drained by `configure`.
    mounts: Vec<NamedMount>,
    // Optional CORS middleware; consumed by `configure`.
    cors: Option<Cors>,
    // When set, `serve` listens through poem's rustls listener instead of
    // plain TCP.
    tls: Option<TlsConfig>,
    // Value for the `Server` response header; `None` means no header.
    server_header: Option<&'static str>,
    // Normalized prefix (`/api` form) the whole route tree is nested under.
    global_prefix: Option<String>,
    // Raw-body cap published to requests as a `RawBodyLimit` data entry.
    max_body_bytes: Option<usize>,
    // Wall-clock budget per request; overruns are answered with `504`.
    request_timeout: Option<std::time::Duration>,
    // `true`: imperative mounts fail boot when global guards are active;
    // `false`: the violation is only logged as a warning.
    fail_secure_strict: bool,
    // Fully assembled endpoint; `None` until `configure` has run (or after
    // `take_endpoint` removed it).
    endpoint: Option<BoxEndpoint<'static, Response>>,
}
72
/// Normalize a global prefix into its canonical form, `Some("/api/v1")`.
///
/// Slashes and whitespace are stripped together from both ends, so input
/// like `" /api/ "` or `"/ api /"` yields `"/api"`. Interior characters are
/// left untouched. Input that is empty, only slashes, only whitespace, or
/// any mix of the two (`""`, `"/"`, `"/ /"`) collapses to `None`, so a blank
/// segment can never become a mount prefix.
fn normalize_global_prefix(raw: &str) -> Option<String> {
    // Trim both character classes in one pass: trimming whitespace first and
    // slashes second would leave whitespace that sat between slashes
    // (`"/ /"` -> `" "`) and produce the bogus prefix `"/ "`.
    let core = raw.trim_matches(|c: char| c == '/' || c.is_whitespace());
    if core.is_empty() {
        None
    } else {
        Some(format!("/{core}"))
    }
}
83
84impl Default for HttpTransport {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl HttpTransport {
91    pub fn new() -> Self {
92        Self {
93            bind: "0.0.0.0:3000".into(),
94            mounts: Vec::new(),
95            cors: None,
96            tls: None,
97            server_header: None,
98            global_prefix: None,
99            max_body_bytes: None,
100            request_timeout: None,
101            // Fail-secure by default: when global guards are active, an
102            // endpoint the transport cannot shape fails boot instead of
103            // mounting unguarded. Opt out via `fail_secure_strict(false)` /
104            // `NESTRS_HTTP__FAIL_SECURE_STRICT=false`.
105            fail_secure_strict: true,
106            endpoint: None,
107        }
108    }
109
110    /// `true` (the default) makes `configure` **fail** when global guards are
111    /// registered and an imperative [`mount`](Self::mount) endpoint would
112    /// bypass the guard pool; `false` downgrades the violation to a `warn`.
113    pub fn fail_secure_strict(mut self, strict: bool) -> Self {
114        self.fail_secure_strict = strict;
115        self
116    }
117
118    /// Mount every controller under a shared prefix (e.g. `/api`). Useful
119    /// behind a reverse proxy that hands off a sub-path. Empty / `"/"`
120    /// collapse to no-op; a missing leading `/` is added; a trailing `/` is
121    /// stripped.
122    pub fn global_prefix(mut self, prefix: impl Into<String>) -> Self {
123        self.global_prefix = normalize_global_prefix(&prefix.into());
124        self
125    }
126
127    /// Emit `Server: <value>` on every response — off by default
128    /// (production-safe). [`HttpModule`](crate::HttpModule) sets this when
129    /// `HttpConfig.server_header` is `true`, using `nestrs/<crate version>`.
130    pub fn server_header(mut self, value: &'static str) -> Self {
131        self.server_header = Some(value);
132        self
133    }
134
135    pub fn bind(mut self, addr: impl Into<String>) -> Self {
136        self.bind = addr.into();
137        self
138    }
139
140    /// Cap each request's raw body to `limit` bytes. Read back by the
141    /// [`RawBody`](crate::RawBody) extractor via the
142    /// [`RawBodyLimit`](crate::RawBodyLimit) request extension.
143    pub fn max_body_bytes(mut self, limit: usize) -> Self {
144        self.max_body_bytes = Some(limit);
145        self
146    }
147
148    /// Abort any request that runs longer than `timeout`, answering the client
149    /// with `504 Gateway Timeout`. Bounds connection hold time against slow or
150    /// stuck handlers. Without this call no timeout is enforced.
151    pub fn request_timeout(mut self, timeout: std::time::Duration) -> Self {
152        self.request_timeout = Some(timeout);
153        self
154    }
155
156    /// Enable CORS with a configured poem [`Cors`] middleware. Wraps the route
157    /// tree outermost so a preflight (`OPTIONS`) is answered before any guard
158    /// or interceptor runs.
159    pub fn cors(mut self, cors: Cors) -> Self {
160        self.cors = Some(cors);
161        self
162    }
163
164    /// Serve HTTPS directly from [`TlsConfig`] (poem's `rustls` listener)
165    /// instead of plain HTTP. Without this call the transport stays plaintext.
166    pub fn tls(mut self, tls: TlsConfig) -> Self {
167        self.tls = Some(tls);
168        self
169    }
170
171    /// Mount an extra endpoint at `path`. The builder closure runs at
172    /// [`Transport::configure`] time with the live container, so it can
173    /// resolve services to construct framework-specific endpoints.
174    pub fn mount<F, E>(mut self, path: impl Into<String>, build: F) -> Self
175    where
176        F: Fn(&Container) -> E + Send + Sync + 'static,
177        E: IntoEndpoint,
178        E::Endpoint: 'static,
179        <E::Endpoint as poem::Endpoint>::Output: poem::IntoResponse,
180    {
181        let path = path.into();
182        let mount_path = path.clone();
183        self.mounts.push((
184            path,
185            Box::new(move |container, route| {
186                let endpoint = build(container).into_endpoint().map_to_response().boxed();
187                route.nest(mount_path.clone(), endpoint)
188            }),
189        ));
190        self
191    }
192
193    /// Take the assembled endpoint for in-process testing (drive with poem's
194    /// `TestClient`). Returns `None` before `configure` has run, and leaves
195    /// the transport without an endpoint (so it must not also be `serve`d).
196    pub fn take_endpoint(&mut self) -> Option<BoxEndpoint<'static, Response>> {
197        self.endpoint.take()
198    }
199}
200
201#[async_trait]
202impl Transport for HttpTransport {
203    async fn configure(&mut self, container: &Container) -> Result<()> {
204        let discovery = DiscoveryService::new(container);
205        // Boot checks first — a misconfigured global layer pool (a spec whose
206        // provider was never registered) must fail boot before anything
207        // mounts; resolved-at-configure means dropped-silently otherwise.
208        for d in discovery.meta::<HttpBootCheck>() {
209            d.meta.run(container).map_err(|msg| anyhow::anyhow!(msg))?;
210        }
211        let mut route = Route::new();
212
213        for d in discovery.meta::<HttpControllerMeta>() {
214            let prefix = d.meta.effective_prefix();
215            for r in &d.meta.routes {
216                tracing::info!(
217                    target: "nest_rs::routes",
218                    "{:<6} {}  ({})",
219                    r.verb.as_str(),
220                    join_path(&prefix, r.path),
221                    r.handler,
222                );
223            }
224            route = d.meta.mount(container, route);
225        }
226        // Provided by `use_guards_global` (which can see the `Guard` trait);
227        // absent when no global guard is registered. Applied below to every
228        // `Guarded` self-mount — they have no per-route shaper to carry the
229        // global guard pool, so the transport runs it at their edge.
230        let self_mount_guard = discovery
231            .meta::<SelfMountGuardWrap>()
232            .into_iter()
233            .next()
234            .map(|d| d.meta);
235        for d in discovery.meta::<HttpEndpointMeta>() {
236            tracing::info!(
237                target: "nest_rs::routes",
238                "{:<6} {}  ({})",
239                "*",
240                d.meta.path(),
241                d.meta.label(),
242            );
243            match (d.meta.posture(), &self_mount_guard) {
244                (EdgePosture::Guarded, Some(wrap)) => {
245                    // Isolate this self-mount into a fresh sub-route, wrap it
246                    // with the global guard chain, and nest it back without
247                    // stripping its own path (so the inner route still matches).
248                    let isolated: BoxEndpoint<'static, Response> =
249                        d.meta.mount(container, Route::new()).boxed();
250                    let wrapped = wrap.apply(container, isolated);
251                    route = route.nest_no_strip(d.meta.path(), wrapped);
252                }
253                _ => {
254                    // `Exempt` surfaces gate in-band (GraphQL operation guard,
255                    // MCP per-request guard) or are deliberately public
256                    // (OpenAPI docs) — no edge wrap.
257                    route = d.meta.mount(container, route);
258                }
259            }
260        }
261        // Fail-secure completeness check: every controller route is shaped
262        // (its `RouteShaper` runs the global guard pool) and every self-mount
263        // declares an `EdgePosture`, but an imperative `mount(...)` is an
264        // opaque poem endpoint the transport can neither shape nor introspect.
265        // When global guards are active, those endpoints bypass the pool —
266        // strict mode (the default) fails boot, the same posture as the
267        // access graph; opting out downgrades to a warn.
268        if !self.mounts.is_empty() && container.get::<GlobalGuardsActive>().is_some() {
269            let paths: Vec<&str> = self.mounts.iter().map(|(p, _)| p.as_str()).collect();
270            if self.fail_secure_strict {
271                anyhow::bail!(
272                    "fail-secure: imperative mount(...) endpoints bypass the global guard pool: \
273                     {} — route them through a #[controller], guard them explicitly, or opt out \
274                     with HttpTransport::fail_secure_strict(false) / \
275                     NESTRS_HTTP__FAIL_SECURE_STRICT=false",
276                    paths.join(", "),
277                );
278            }
279            tracing::warn!(
280                target: "nest_rs::http",
281                paths = paths.join(", ").as_str(),
282                "imperative mount(...) endpoints bypass the global guard pool — route them through a #[controller] or guard them explicitly",
283            );
284        }
285        for (_, mount) in self.mounts.drain(..) {
286            route = mount(container, route);
287        }
288
289        // Apply the global prefix once around the fully-assembled tree so
290        // every controller, every self-mounting endpoint, and every imperative
291        // `mount(...)` lands under it.
292        if let Some(prefix) = self.global_prefix.take() {
293            route = Route::new().nest(prefix, route);
294        }
295
296        let mut endpoint: BoxEndpoint<'static, Response> = route.map_to_response().boxed();
297        // Layer-System globals (guards / interceptors / filters / pipes /
298        // exception filters) attach a `HttpEndpointWrap` from their own
299        // crate. The transport sorts by priority ascending so the
300        // documented HTTP order is enforced regardless of AppBuilder call
301        // sequence: Guards (innermost) → Filters → Interceptors
302        // (outermost). Insertion order is the tiebreaker within a band.
303        let mut metas: Vec<std::sync::Arc<HttpEndpointWrap>> = discovery
304            .meta::<HttpEndpointWrap>()
305            .into_iter()
306            .map(|d| d.meta)
307            .collect();
308        metas.sort_by_key(|m| m.priority());
309        for meta in metas {
310            endpoint = meta.wrap(container, endpoint);
311        }
312        // Wrap the whole Layer System in a wall-clock budget: a handler that
313        // overruns is aborted and the client gets `504`. Outside the globals
314        // so guards/interceptors are themselves bounded; inside body-limit /
315        // header / CORS so a preflight is still answered without the timer.
316        if let Some(timeout) = self.request_timeout.take() {
317            endpoint = endpoint
318                .around(move |ep, req| async move {
319                    match tokio::time::timeout(timeout, ep.call(req)).await {
320                        Ok(res) => res,
321                        Err(_) => {
322                            tracing::warn!(target: "nest_rs::http", ?timeout, "request timed out");
323                            Ok(Response::builder()
324                                .status(poem::http::StatusCode::GATEWAY_TIMEOUT)
325                                .finish())
326                        }
327                    }
328                })
329                .map_to_response()
330                .boxed();
331        }
332        // Apply the body-byte cap, if any, as a request-data entry the
333        // `RawBody` extractor reads back. Installed OUTSIDE the Layer
334        // System globals so every interceptor / filter / guard that
335        // inspects `req.extensions().get::<RawBodyLimit>()` before calling
336        // `next` sees the configured value — pre-v5 behavior, preserved.
337        // No `Interceptor` trait needed — `EndpointExt::data` is enough.
338        if let Some(limit) = self.max_body_bytes.take() {
339            endpoint = endpoint.data(RawBodyLimit(limit)).map_to_response().boxed();
340        }
341        // Server header is purely cosmetic — apply before CORS so the
342        // preflight short-circuit (no body) still carries it for observability.
343        if let Some(value) = self.server_header.take() {
344            let header_value = HeaderValue::from_static(value);
345            let set = SetHeader::new().overriding(SERVER, header_value);
346            endpoint = endpoint.with(set).map_to_response().boxed();
347        }
348        // CORS wraps outermost, so a preflight is handled before guards run.
349        if let Some(cors) = self.cors.take() {
350            endpoint = endpoint.with(cors).map_to_response().boxed();
351        }
352        // Request scope installs before anything else so guards/handlers can
353        // resolve `#[injectable(scope = request)]` providers via `Scoped<T>`.
354        endpoint = crate::RequestScopeEndpoint::new(endpoint, container.clone())
355            .map_to_response()
356            .boxed();
357
358        self.endpoint = Some(endpoint);
359        Ok(())
360    }
361
362    async fn serve(self: Box<Self>, cancel: CancellationToken) -> Result<()> {
363        let endpoint = self
364            .endpoint
365            .expect("HttpTransport::configure must run before serve");
366        let bind = self.bind;
367        let listener = match self.tls {
368            Some(tls) => {
369                tracing::debug!(addr = %bind, "https transport listening (TLS)");
370                TcpListener::bind(bind).rustls(tls.into_rustls()).boxed()
371            }
372            None => {
373                tracing::debug!(addr = %bind, "http transport listening");
374                TcpListener::bind(bind).boxed()
375            }
376        };
377        Server::new(listener)
378            .run_with_graceful_shutdown(endpoint, async move { cancel.cancelled().await }, None)
379            .await?;
380        Ok(())
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    // `join_path` is the single source of truth shared with `nestrs-openapi`
    // and the boot route log — a drift here means the served path and the
    // documented path disagree, so the cases are exhaustive on purpose.
    #[test]
    fn join_path_concatenates_clean_segments() {
        assert_eq!(join_path("/health", "/live"), "/health/live");
        assert_eq!(join_path("/users", "/:id"), "/users/:id");
    }

    #[test]
    fn join_path_strips_redundant_slashes_on_either_side() {
        assert_eq!(join_path("/health/", "/live"), "/health/live");
        assert_eq!(join_path("/health", "live"), "/health/live");
        assert_eq!(join_path("/health/", "live"), "/health/live");
    }

    #[test]
    fn join_path_handles_empty_or_root_segments() {
        assert_eq!(join_path("", ""), "/");
        assert_eq!(join_path("/", ""), "/");
        assert_eq!(join_path("/", "/"), "/");
        assert_eq!(join_path("", "/users"), "/users");
        assert_eq!(join_path("/users", ""), "/users");
    }

    #[test]
    fn version_path_prefixes_when_a_version_is_supplied() {
        assert_eq!(version_path(Some("1"), "/users"), "/v1/users");
        assert_eq!(version_path(Some("2"), "/users/:id"), "/v2/users/:id");
        // Version + root.
        assert_eq!(version_path(Some("1"), "/"), "/v1");
    }

    #[test]
    fn version_path_leaves_an_unversioned_path_alone() {
        assert_eq!(version_path(None, "/users"), "/users");
        assert_eq!(version_path(None, "/"), "/");
    }

    // The global prefix is nested around the whole route tree, so a sloppy
    // normalization silently moves every endpoint.
    #[test]
    fn normalize_global_prefix_canonicalizes_slashes_and_whitespace() {
        assert_eq!(normalize_global_prefix("api"), Some("/api".to_string()));
        assert_eq!(normalize_global_prefix("/api"), Some("/api".to_string()));
        assert_eq!(normalize_global_prefix("/api/"), Some("/api".to_string()));
        assert_eq!(
            normalize_global_prefix("  /api/v1/  "),
            Some("/api/v1".to_string())
        );
    }

    #[test]
    fn normalize_global_prefix_drops_empty_and_root() {
        assert_eq!(normalize_global_prefix(""), None);
        assert_eq!(normalize_global_prefix("   "), None);
        assert_eq!(normalize_global_prefix("/"), None);
        assert_eq!(normalize_global_prefix("//"), None);
    }

    #[test]
    fn http_transport_defaults_match_an_empty_new() {
        let d = HttpTransport::default();
        let n = HttpTransport::new();
        assert_eq!(d.bind, n.bind);
        assert_eq!(d.bind, "0.0.0.0:3000");
        assert!(d.mounts.is_empty());
        assert!(d.cors.is_none());
        assert!(d.tls.is_none());
        assert!(d.server_header.is_none());
        assert!(d.global_prefix.is_none());
        assert!(d.max_body_bytes.is_none());
        assert!(d.request_timeout.is_none());
        // Fail-secure must be on unless explicitly disabled.
        assert!(d.fail_secure_strict);
        assert!(d.endpoint.is_none());
    }

    #[test]
    fn bind_overrides_the_default_address() {
        let t = HttpTransport::new().bind("127.0.0.1:9000");
        assert_eq!(t.bind, "127.0.0.1:9000");
    }

    #[test]
    fn tls_pins_the_supplied_config() {
        // TlsConfig is opaque, so just check the option flips on.
        let t = HttpTransport::new().tls(TlsConfig::new(b"cert".to_vec(), b"key".to_vec()));
        assert!(t.tls.is_some());
    }

    #[test]
    fn server_header_pins_the_supplied_static_str() {
        let t = HttpTransport::new().server_header("nestrs/0.1.0");
        assert_eq!(t.server_header, Some("nestrs/0.1.0"));
    }

    #[test]
    fn global_prefix_stores_the_normalized_form() {
        let t = HttpTransport::new().global_prefix("api/");
        assert_eq!(t.global_prefix.as_deref(), Some("/api"));
        // Root collapses back to "no prefix", even after a prior value.
        let t = t.global_prefix("/");
        assert!(t.global_prefix.is_none());
    }

    #[test]
    fn fail_secure_strict_can_be_switched_off() {
        let t = HttpTransport::new().fail_secure_strict(false);
        assert!(!t.fail_secure_strict);
    }

    #[test]
    fn max_body_bytes_and_request_timeout_pin_their_values() {
        let budget = std::time::Duration::from_secs(5);
        let t = HttpTransport::new()
            .max_body_bytes(1024)
            .request_timeout(budget);
        assert_eq!(t.max_body_bytes, Some(1024));
        assert_eq!(t.request_timeout, Some(budget));
    }

    #[test]
    fn take_endpoint_returns_none_before_configure_has_run() {
        let mut t = HttpTransport::new();
        assert!(t.take_endpoint().is_none(), "no endpoint before configure");
    }
}