1use anyhow::Result;
2use async_trait::async_trait;
3use nest_rs_core::{Container, DiscoveryService, Transport};
4use poem::endpoint::BoxEndpoint;
5use poem::http::header::{HeaderValue, SERVER};
6use poem::listener::{Listener, TcpListener};
7use poem::middleware::{Cors, SetHeader};
8use poem::{EndpointExt, IntoEndpoint, Response, Route, Server};
9use tokio_util::sync::CancellationToken;
10
11use crate::boot_check::{GlobalGuardsActive, HttpBootCheck};
12use crate::controller::HttpControllerMeta;
13use crate::endpoint::{EdgePosture, HttpEndpointMeta, SelfMountGuardWrap};
14use crate::interceptor::HttpEndpointWrap;
15use crate::raw_body::RawBodyLimit;
16use crate::tls::TlsConfig;
17
/// Deferred imperative mount: given the DI container and the route tree built
/// so far, returns the tree with one more endpoint attached.
type MountFn = Box<dyn Fn(&Container, Route) -> Route + Send + Sync>;
/// A mount path paired with the closure that installs it; the path is kept
/// separately so it can be reported in diagnostics.
type NamedMount = (String, MountFn);
22
/// Joins a route prefix and a sub-path with exactly one `/` between them.
///
/// Trailing slashes on `prefix` and leading slashes on `rest` are discarded
/// before joining. When nothing is left on either side the result is `"/"`;
/// when only `rest` is empty the trimmed prefix is returned unchanged.
pub fn join_path(prefix: &str, rest: &str) -> String {
    let head = prefix.trim_end_matches('/');
    let tail = rest.trim_start_matches('/');

    if tail.is_empty() {
        // Nothing to append: fall back to the bare prefix, or the root.
        if head.is_empty() {
            String::from("/")
        } else {
            head.to_owned()
        }
    } else {
        // An empty head naturally yields "/tail", so one format covers both cases.
        format!("{head}/{tail}")
    }
}
37
38pub fn version_path(version: Option<&str>, path: &str) -> String {
43 match version {
44 Some(v) => join_path(&format!("/v{v}"), path),
45 None => path.to_string(),
46 }
47}
48
/// HTTP transport: collects configuration through builder methods, assembles
/// the endpoint stack in `configure`, and runs a poem server in `serve`.
pub struct HttpTransport {
    /// Socket address handed to `TcpListener::bind` when serving.
    bind: String,
    /// Imperative mounts registered through `mount`, applied during `configure`.
    mounts: Vec<NamedMount>,
    /// Optional CORS middleware wrapped around the endpoint stack.
    cors: Option<Cors>,
    /// When set, the listener is wrapped with rustls using this configuration.
    tls: Option<TlsConfig>,
    /// Value written to the `Server` response header (overriding any existing one).
    server_header: Option<&'static str>,
    /// Normalised prefix (`/segment`, no trailing slash) the whole route tree is nested under.
    global_prefix: Option<String>,
    /// Limit attached to each request as `RawBodyLimit` data.
    max_body_bytes: Option<usize>,
    /// Per-request deadline; a request exceeding it is answered with 504.
    request_timeout: Option<std::time::Duration>,
    /// Whether imperative mounts combined with active global guards abort
    /// `configure` (`true`) or only log a warning (`false`).
    fail_secure_strict: bool,
    /// Fully wrapped endpoint produced by `configure`; consumed by `serve` or `take_endpoint`.
    endpoint: Option<BoxEndpoint<'static, Response>>,
}
72
/// Canonicalises a user-supplied global prefix.
///
/// Surrounding whitespace and slashes are stripped together, in any
/// interleaving, so inputs such as `" /api/ "` or `"/ api /"` both become
/// `"/api"`. Returns `None` when nothing but slashes and whitespace was
/// supplied (meaning "no prefix"); otherwise returns the remaining text with a
/// single leading `/` and no trailing slash. Interior characters are left
/// untouched.
fn normalize_global_prefix(raw: &str) -> Option<String> {
    // Trim both character classes in one pass: trimming whitespace first and
    // slashes second would leave whitespace that sits inside the slashes
    // (e.g. "/ /" would turn into the bogus prefix "/ ").
    let trimmed = raw.trim_matches(|c: char| c == '/' || c.is_whitespace());
    if trimmed.is_empty() {
        return None;
    }
    Some(format!("/{trimmed}"))
}
83
84impl Default for HttpTransport {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl HttpTransport {
91 pub fn new() -> Self {
92 Self {
93 bind: "0.0.0.0:3000".into(),
94 mounts: Vec::new(),
95 cors: None,
96 tls: None,
97 server_header: None,
98 global_prefix: None,
99 max_body_bytes: None,
100 request_timeout: None,
101 fail_secure_strict: true,
106 endpoint: None,
107 }
108 }
109
110 pub fn fail_secure_strict(mut self, strict: bool) -> Self {
114 self.fail_secure_strict = strict;
115 self
116 }
117
118 pub fn global_prefix(mut self, prefix: impl Into<String>) -> Self {
123 self.global_prefix = normalize_global_prefix(&prefix.into());
124 self
125 }
126
127 pub fn server_header(mut self, value: &'static str) -> Self {
131 self.server_header = Some(value);
132 self
133 }
134
135 pub fn bind(mut self, addr: impl Into<String>) -> Self {
136 self.bind = addr.into();
137 self
138 }
139
140 pub fn max_body_bytes(mut self, limit: usize) -> Self {
144 self.max_body_bytes = Some(limit);
145 self
146 }
147
148 pub fn request_timeout(mut self, timeout: std::time::Duration) -> Self {
152 self.request_timeout = Some(timeout);
153 self
154 }
155
156 pub fn cors(mut self, cors: Cors) -> Self {
160 self.cors = Some(cors);
161 self
162 }
163
164 pub fn tls(mut self, tls: TlsConfig) -> Self {
167 self.tls = Some(tls);
168 self
169 }
170
171 pub fn mount<F, E>(mut self, path: impl Into<String>, build: F) -> Self
175 where
176 F: Fn(&Container) -> E + Send + Sync + 'static,
177 E: IntoEndpoint,
178 E::Endpoint: 'static,
179 <E::Endpoint as poem::Endpoint>::Output: poem::IntoResponse,
180 {
181 let path = path.into();
182 let mount_path = path.clone();
183 self.mounts.push((
184 path,
185 Box::new(move |container, route| {
186 let endpoint = build(container).into_endpoint().map_to_response().boxed();
187 route.nest(mount_path.clone(), endpoint)
188 }),
189 ));
190 self
191 }
192
193 pub fn take_endpoint(&mut self) -> Option<BoxEndpoint<'static, Response>> {
197 self.endpoint.take()
198 }
199}
200
201#[async_trait]
202impl Transport for HttpTransport {
203 async fn configure(&mut self, container: &Container) -> Result<()> {
204 let discovery = DiscoveryService::new(container);
205 for d in discovery.meta::<HttpBootCheck>() {
209 d.meta.run(container).map_err(|msg| anyhow::anyhow!(msg))?;
210 }
211 let mut route = Route::new();
212
213 for d in discovery.meta::<HttpControllerMeta>() {
214 let prefix = d.meta.effective_prefix();
215 for r in &d.meta.routes {
216 tracing::info!(
217 target: "nest_rs::routes",
218 "{:<6} {} ({})",
219 r.verb.as_str(),
220 join_path(&prefix, r.path),
221 r.handler,
222 );
223 }
224 route = d.meta.mount(container, route);
225 }
226 let self_mount_guard = discovery
231 .meta::<SelfMountGuardWrap>()
232 .into_iter()
233 .next()
234 .map(|d| d.meta);
235 for d in discovery.meta::<HttpEndpointMeta>() {
236 tracing::info!(
237 target: "nest_rs::routes",
238 "{:<6} {} ({})",
239 "*",
240 d.meta.path(),
241 d.meta.label(),
242 );
243 match (d.meta.posture(), &self_mount_guard) {
244 (EdgePosture::Guarded, Some(wrap)) => {
245 let isolated: BoxEndpoint<'static, Response> =
249 d.meta.mount(container, Route::new()).boxed();
250 let wrapped = wrap.apply(container, isolated);
251 route = route.nest_no_strip(d.meta.path(), wrapped);
252 }
253 _ => {
254 route = d.meta.mount(container, route);
258 }
259 }
260 }
261 if !self.mounts.is_empty() && container.get::<GlobalGuardsActive>().is_some() {
269 let paths: Vec<&str> = self.mounts.iter().map(|(p, _)| p.as_str()).collect();
270 if self.fail_secure_strict {
271 anyhow::bail!(
272 "fail-secure: imperative mount(...) endpoints bypass the global guard pool: \
273 {} — route them through a #[controller], guard them explicitly, or opt out \
274 with HttpTransport::fail_secure_strict(false) / \
275 NESTRS_HTTP__FAIL_SECURE_STRICT=false",
276 paths.join(", "),
277 );
278 }
279 tracing::warn!(
280 target: "nest_rs::http",
281 paths = paths.join(", ").as_str(),
282 "imperative mount(...) endpoints bypass the global guard pool — route them through a #[controller] or guard them explicitly",
283 );
284 }
285 for (_, mount) in self.mounts.drain(..) {
286 route = mount(container, route);
287 }
288
289 if let Some(prefix) = self.global_prefix.take() {
293 route = Route::new().nest(prefix, route);
294 }
295
296 let mut endpoint: BoxEndpoint<'static, Response> = route.map_to_response().boxed();
297 let mut metas: Vec<std::sync::Arc<HttpEndpointWrap>> = discovery
304 .meta::<HttpEndpointWrap>()
305 .into_iter()
306 .map(|d| d.meta)
307 .collect();
308 metas.sort_by_key(|m| m.priority());
309 for meta in metas {
310 endpoint = meta.wrap(container, endpoint);
311 }
312 if let Some(timeout) = self.request_timeout.take() {
317 endpoint = endpoint
318 .around(move |ep, req| async move {
319 match tokio::time::timeout(timeout, ep.call(req)).await {
320 Ok(res) => res,
321 Err(_) => {
322 tracing::warn!(target: "nest_rs::http", ?timeout, "request timed out");
323 Ok(Response::builder()
324 .status(poem::http::StatusCode::GATEWAY_TIMEOUT)
325 .finish())
326 }
327 }
328 })
329 .map_to_response()
330 .boxed();
331 }
332 if let Some(limit) = self.max_body_bytes.take() {
339 endpoint = endpoint.data(RawBodyLimit(limit)).map_to_response().boxed();
340 }
341 if let Some(value) = self.server_header.take() {
344 let header_value = HeaderValue::from_static(value);
345 let set = SetHeader::new().overriding(SERVER, header_value);
346 endpoint = endpoint.with(set).map_to_response().boxed();
347 }
348 if let Some(cors) = self.cors.take() {
350 endpoint = endpoint.with(cors).map_to_response().boxed();
351 }
352 endpoint = crate::RequestScopeEndpoint::new(endpoint, container.clone())
355 .map_to_response()
356 .boxed();
357
358 self.endpoint = Some(endpoint);
359 Ok(())
360 }
361
362 async fn serve(self: Box<Self>, cancel: CancellationToken) -> Result<()> {
363 let endpoint = self
364 .endpoint
365 .expect("HttpTransport::configure must run before serve");
366 let bind = self.bind;
367 let listener = match self.tls {
368 Some(tls) => {
369 tracing::debug!(addr = %bind, "https transport listening (TLS)");
370 TcpListener::bind(bind).rustls(tls.into_rustls()).boxed()
371 }
372 None => {
373 tracing::debug!(addr = %bind, "http transport listening");
374 TcpListener::bind(bind).boxed()
375 }
376 };
377 Server::new(listener)
378 .run_with_graceful_shutdown(endpoint, async move { cancel.cancelled().await }, None)
379 .await?;
380 Ok(())
381 }
382}
383
#[cfg(test)]
mod tests {
    //! Unit tests for the path helpers and the `HttpTransport` builder surface.
    use super::*;

    #[test]
    fn join_path_concatenates_clean_segments() {
        assert_eq!(join_path("/health", "/live"), "/health/live");
        assert_eq!(join_path("/users", "/:id"), "/users/:id");
    }

    #[test]
    fn join_path_strips_redundant_slashes_on_either_side() {
        assert_eq!(join_path("/health/", "/live"), "/health/live");
        assert_eq!(join_path("/health", "live"), "/health/live");
        assert_eq!(join_path("/health/", "live"), "/health/live");
    }

    #[test]
    fn join_path_handles_empty_or_root_segments() {
        assert_eq!(join_path("", ""), "/");
        assert_eq!(join_path("/", ""), "/");
        assert_eq!(join_path("/", "/"), "/");
        assert_eq!(join_path("", "/users"), "/users");
        assert_eq!(join_path("/users", ""), "/users");
    }

    #[test]
    fn version_path_prefixes_when_a_version_is_supplied() {
        assert_eq!(version_path(Some("1"), "/users"), "/v1/users");
        assert_eq!(version_path(Some("2"), "/users/:id"), "/v2/users/:id");
        assert_eq!(version_path(Some("1"), "/"), "/v1");
    }

    #[test]
    fn version_path_leaves_an_unversioned_path_alone() {
        assert_eq!(version_path(None, "/users"), "/users");
        assert_eq!(version_path(None, "/"), "/");
    }

    #[test]
    fn http_transport_defaults_match_an_empty_new() {
        let d = HttpTransport::default();
        let n = HttpTransport::new();
        assert_eq!(d.bind, n.bind);
        assert_eq!(d.bind, "0.0.0.0:3000");
        assert!(d.mounts.is_empty());
        assert!(d.cors.is_none());
        assert!(d.tls.is_none());
        assert!(d.server_header.is_none());
        assert!(d.global_prefix.is_none());
        assert!(d.max_body_bytes.is_none());
        assert!(d.request_timeout.is_none());
        // Strict fail-secure must be on unless a caller explicitly opts out.
        assert!(d.fail_secure_strict);
        assert!(d.endpoint.is_none());
    }

    #[test]
    fn bind_overrides_the_default_address() {
        let t = HttpTransport::new().bind("127.0.0.1:9000");
        assert_eq!(t.bind, "127.0.0.1:9000");
    }

    #[test]
    fn tls_pins_the_supplied_config() {
        let t = HttpTransport::new().tls(TlsConfig::new(b"cert".to_vec(), b"key".to_vec()));
        assert!(t.tls.is_some());
    }

    #[test]
    fn server_header_pins_the_supplied_static_str() {
        let t = HttpTransport::new().server_header("nestrs/0.1.0");
        assert_eq!(t.server_header, Some("nestrs/0.1.0"));
    }

    #[test]
    fn global_prefix_is_normalized_to_a_single_leading_slash() {
        let bare = HttpTransport::new().global_prefix("api");
        assert_eq!(bare.global_prefix.as_deref(), Some("/api"));

        let slashed = HttpTransport::new().global_prefix("/api/");
        assert_eq!(slashed.global_prefix.as_deref(), Some("/api"));

        let padded = HttpTransport::new().global_prefix("  /api/v1/  ");
        assert_eq!(padded.global_prefix.as_deref(), Some("/api/v1"));
    }

    #[test]
    fn global_prefix_treats_empty_or_root_input_as_unset() {
        for raw in ["", "/", "//", "   "] {
            let t = HttpTransport::new().global_prefix(raw);
            assert!(t.global_prefix.is_none(), "{raw:?} should clear the prefix");
        }
    }

    #[test]
    fn fail_secure_strict_can_be_switched_off_and_back_on() {
        let relaxed = HttpTransport::new().fail_secure_strict(false);
        assert!(!relaxed.fail_secure_strict);

        let strict = relaxed.fail_secure_strict(true);
        assert!(strict.fail_secure_strict);
    }

    #[test]
    fn max_body_bytes_and_request_timeout_store_the_supplied_limits() {
        let timeout = std::time::Duration::from_secs(5);
        let t = HttpTransport::new()
            .max_body_bytes(1024)
            .request_timeout(timeout);
        assert_eq!(t.max_body_bytes, Some(1024));
        assert_eq!(t.request_timeout, Some(timeout));
    }

    #[test]
    fn take_endpoint_returns_none_before_configure_has_run() {
        let mut t = HttpTransport::new();
        assert!(t.take_endpoint().is_none(), "no endpoint before configure");
    }
}