1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode};
20use http_body_util::{BodyExt, Full};
21use hyper::body::Incoming;
22use hyper::server::conn::http1;
23use hyper::upgrade::Upgraded;
24use hyper_util::rt::TokioIo;
25use ranvier_core::event::{EventSink, EventSource};
26use ranvier_core::prelude::*;
27use ranvier_runtime::Axon;
28use serde::Serialize;
29use serde::de::DeserializeOwned;
30use sha1::{Digest, Sha1};
31use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::net::SocketAddr;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::net::TcpListener;
39use tokio::sync::Mutex;
40use tokio_tungstenite::WebSocketStream;
41use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
42use tracing::Instrument;
43
44use crate::guard_integration::{
45 GuardExec, GuardIntegration, PreflightConfig, RegisteredGuard, ResponseBodyTransformFn,
46 ResponseExtractorFn,
47};
48use crate::response::{HttpResponse, IntoResponse, json_error_response, outcome_to_response_with_error};
49
50pub struct Ranvier;
55
56impl Ranvier {
57 pub fn http<R>() -> HttpIngress<R>
59 where
60 R: ranvier_core::transition::ResourceRequirement + Clone,
61 {
62 HttpIngress::new()
63 }
64}
65
/// Type-erased route handler.
///
/// Receives the request head (`Parts`) and a borrowed resource bundle and
/// returns a boxed, sendable future resolving to the final [`HttpResponse`].
type RouteHandler<R> = Arc<
    dyn Fn(http::request::Parts, &R) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>>
        + Send
        + Sync,
>;
72
73#[derive(Clone)]
75struct BoxService(
76 Arc<
77 dyn Fn(Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>
78 + Send
79 + Sync,
80 >,
81);
82
83impl BoxService {
84 fn new<F, Fut>(f: F) -> Self
85 where
86 F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
87 Fut: Future<Output = Result<HttpResponse, Infallible>> + Send + 'static,
88 {
89 Self(Arc::new(move |req| Box::pin(f(req))))
90 }
91
92 fn call(&self, req: Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>> {
93 (self.0)(req)
94 }
95}
96
97impl hyper::service::Service<Request<Incoming>> for BoxService {
98 type Response = HttpResponse;
99 type Error = Infallible;
100 type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>;
101
102 fn call(&self, req: Request<Incoming>) -> Self::Future {
103 (self.0)(req)
104 }
105}
106
/// Alias for [`BoxService`] used in layer signatures.
type BoxHttpService = BoxService;
/// Middleware layer: takes a boxed service and returns the wrapped service.
type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
/// Argument-less callback stored for the `on_start` / `on_shutdown` hooks.
type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
/// Callback that populates a `Bus` from the request head.
type BusInjector = Arc<dyn Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static>;
/// Boxed future driving one WebSocket session to completion.
type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased WebSocket session handler (connection, shared resources, bus).
type WsSessionHandler<R> =
    Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
/// Boxed future of a single health check; `Err` carries the failure text.
type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
/// Type-erased health check run against the shared resources.
type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
/// Header read and written by the request-id middleware.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Token expected in (and echoed back through) the `Upgrade` header.
const WS_UPGRADE_TOKEN: &str = "websocket";
/// Fixed GUID appended to the client key when computing `Sec-WebSocket-Accept`.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
119
/// A health check paired with the name it was registered under.
#[derive(Clone)]
struct NamedHealthCheck<R> {
    /// Name supplied to `health_check`.
    name: String,
    /// The type-erased check itself.
    check: HealthCheckFn<R>,
}
125
/// Health-probe configuration collected by the ingress builder.
#[derive(Clone)]
struct HealthConfig<R> {
    /// Path of the general health endpoint, when enabled.
    health_path: Option<String>,
    /// Path of the readiness probe, when enabled.
    readiness_path: Option<String>,
    /// Path of the liveness probe, when enabled.
    liveness_path: Option<String>,
    /// Registered checks, in registration order.
    checks: Vec<NamedHealthCheck<R>>,
}

// Implemented by hand so that no `R: Default` bound is required, which a
// derived impl would add.
impl<R> Default for HealthConfig<R> {
    /// No endpoints enabled and no checks registered.
    fn default() -> Self {
        Self {
            health_path: None,
            readiness_path: None,
            liveness_path: None,
            checks: Vec::new(),
        }
    }
}
144
/// Static-file serving options accumulated by the builder methods.
///
/// How these values are consumed is outside this part of the file.
#[derive(Clone, Default)]
struct StaticAssetsConfig {
    /// Directory mounts added through `serve_dir`.
    mounts: Vec<StaticMount>,
    /// File path set through `spa_fallback`.
    spa_fallback: Option<String>,
    /// `Cache-Control` value; `serve_dir` fills in a default when unset.
    cache_control: Option<String>,
    /// Set by `compression_layer`.
    enable_compression: bool,
    /// File name set through `directory_index`.
    directory_index: Option<String>,
    /// Set by `immutable_cache`.
    immutable_cache: bool,
}
157
/// One static directory mount: a route prefix mapped to a directory.
#[derive(Clone)]
struct StaticMount {
    /// Route prefix, normalized by `serve_dir` to start with `/`.
    route_prefix: String,
    /// Directory string exactly as supplied by the caller.
    directory: String,
}
163
/// Certificate and key locations recorded by `HttpIngress::tls`.
///
/// Only compiled when the `tls` feature is enabled.
#[cfg(feature = "tls")]
#[derive(Clone)]
struct TlsAcceptorConfig {
    /// Certificate path as given by the caller.
    cert_path: String,
    /// Private-key path as given by the caller.
    key_path: String,
}
171
/// Serializable summary of a health probe.
#[derive(Serialize)]
struct HealthReport {
    /// Overall status label.
    status: &'static str,
    /// Label of the probe that produced the report.
    probe: &'static str,
    /// Per-check results.
    checks: Vec<HealthCheckReport>,
}
178
/// Serializable result of a single named health check.
#[derive(Serialize)]
struct HealthCheckReport {
    /// Name of the check.
    name: String,
    /// Status label of the check.
    status: &'static str,
    /// Failure text; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
186
187fn timeout_middleware(timeout: Duration) -> ServiceLayer {
188 Arc::new(move |inner: BoxHttpService| {
189 BoxService::new(move |req: Request<Incoming>| {
190 let inner = inner.clone();
191 async move {
192 match tokio::time::timeout(timeout, inner.call(req)).await {
193 Ok(response) => response,
194 Err(_) => Ok(Response::builder()
195 .status(StatusCode::REQUEST_TIMEOUT)
196 .body(
197 Full::new(Bytes::from("Request Timeout"))
198 .map_err(|never| match never {})
199 .boxed(),
200 )
201 .expect("valid HTTP response construction")),
202 }
203 }
204 })
205 })
206}
207
208fn request_id_middleware() -> ServiceLayer {
209 Arc::new(move |inner: BoxHttpService| {
210 BoxService::new(move |req: Request<Incoming>| {
211 let inner = inner.clone();
212 async move {
213 let mut req = req;
214 let request_id = req
215 .headers()
216 .get(REQUEST_ID_HEADER)
217 .cloned()
218 .unwrap_or_else(|| {
219 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
220 .unwrap_or_else(|_| {
221 http::HeaderValue::from_static("request-id-unavailable")
222 })
223 });
224 req.headers_mut()
225 .insert(REQUEST_ID_HEADER, request_id.clone());
226 let mut response = inner.call(req).await?;
227 response
228 .headers_mut()
229 .insert(REQUEST_ID_HEADER, request_id);
230 Ok(response)
231 }
232 })
233 })
234}
235
/// Named values captured from a request path by a route pattern.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PathParams {
    /// Parameter name -> captured text.
    values: HashMap<String, String>,
}
240
241#[derive(Clone, Debug, PartialEq, Eq)]
243pub struct HttpRouteDescriptor {
244 method: Method,
245 path_pattern: String,
246}
247
248impl HttpRouteDescriptor {
249 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
250 Self {
251 method,
252 path_pattern: path_pattern.into(),
253 }
254 }
255
256 pub fn method(&self) -> &Method {
257 &self.method
258 }
259
260 pub fn path_pattern(&self) -> &str {
261 &self.path_pattern
262 }
263}
264
265#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
267pub struct WebSocketSessionContext {
268 connection_id: uuid::Uuid,
269 path: String,
270 query: Option<String>,
271}
272
273impl WebSocketSessionContext {
274 pub fn connection_id(&self) -> uuid::Uuid {
275 self.connection_id
276 }
277
278 pub fn path(&self) -> &str {
279 &self.path
280 }
281
282 pub fn query(&self) -> Option<&str> {
283 self.query.as_deref()
284 }
285}
286
287#[derive(Clone, Debug, PartialEq, Eq)]
289pub enum WebSocketEvent {
290 Text(String),
291 Binary(Vec<u8>),
292 Ping(Vec<u8>),
293 Pong(Vec<u8>),
294 Close,
295}
296
297impl WebSocketEvent {
298 pub fn text(value: impl Into<String>) -> Self {
299 Self::Text(value.into())
300 }
301
302 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
303 Self::Binary(value.into())
304 }
305
306 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
307 where
308 T: Serialize,
309 {
310 let text = serde_json::to_string(value)?;
311 Ok(Self::Text(text))
312 }
313}
314
/// Errors surfaced by [`WebSocketConnection`] operations.
#[derive(Debug, thiserror::Error)]
pub enum WebSocketError {
    /// Failure reported by the underlying WebSocket transport.
    #[error("websocket wire error: {0}")]
    Wire(#[from] WsWireError),
    /// An outgoing value could not be serialized to JSON.
    #[error("json serialization failed: {0}")]
    JsonSerialize(#[source] serde_json::Error),
    /// An incoming payload could not be deserialized from JSON.
    #[error("json deserialization failed: {0}")]
    JsonDeserialize(#[source] serde_json::Error),
    /// A JSON payload was expected but a non-data event (ping, pong, close) arrived.
    #[error("expected text or binary frame for json payload")]
    NonDataFrame,
}
326
/// Server-side WebSocket stream over an upgraded hyper connection.
type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
/// Write half obtained by splitting [`WsServerStream`].
type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
/// Read half obtained by splitting [`WsServerStream`].
type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
330
331pub struct WebSocketConnection {
333 sink: Mutex<WsServerSink>,
334 source: Mutex<WsServerSource>,
335 session: WebSocketSessionContext,
336}
337
338impl WebSocketConnection {
339 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
340 let (sink, source) = stream.split();
341 Self {
342 sink: Mutex::new(sink),
343 source: Mutex::new(source),
344 session,
345 }
346 }
347
348 pub fn session(&self) -> &WebSocketSessionContext {
349 &self.session
350 }
351
352 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
353 let mut sink = self.sink.lock().await;
354 sink.send(event.into_wire_message()).await?;
355 Ok(())
356 }
357
358 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
359 where
360 T: Serialize,
361 {
362 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
363 self.send(event).await
364 }
365
366 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
367 where
368 T: DeserializeOwned,
369 {
370 let Some(event) = self.recv_event().await? else {
371 return Ok(None);
372 };
373 match event {
374 WebSocketEvent::Text(text) => serde_json::from_str(&text)
375 .map(Some)
376 .map_err(WebSocketError::JsonDeserialize),
377 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
378 .map(Some)
379 .map_err(WebSocketError::JsonDeserialize),
380 _ => Err(WebSocketError::NonDataFrame),
381 }
382 }
383
384 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
385 let mut source = self.source.lock().await;
386 while let Some(item) = source.next().await {
387 let message = item?;
388 if let Some(event) = WebSocketEvent::from_wire_message(message) {
389 return Ok(Some(event));
390 }
391 }
392 Ok(None)
393 }
394}
395
396impl WebSocketEvent {
397 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
398 match message {
399 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
400 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
401 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
402 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
403 WsWireMessage::Close(_) => Some(Self::Close),
404 WsWireMessage::Frame(_) => None,
405 }
406 }
407
408 fn into_wire_message(self) -> WsWireMessage {
409 match self {
410 Self::Text(value) => WsWireMessage::Text(value),
411 Self::Binary(value) => WsWireMessage::Binary(value),
412 Self::Ping(value) => WsWireMessage::Ping(value),
413 Self::Pong(value) => WsWireMessage::Pong(value),
414 Self::Close => WsWireMessage::Close(None),
415 }
416 }
417}
418
419#[async_trait::async_trait]
420impl EventSource<WebSocketEvent> for WebSocketConnection {
421 async fn next_event(&mut self) -> Option<WebSocketEvent> {
422 match self.recv_event().await {
423 Ok(event) => event,
424 Err(error) => {
425 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
426 None
427 }
428 }
429 }
430}
431
432#[async_trait::async_trait]
433impl EventSink<WebSocketEvent> for WebSocketConnection {
434 type Error = WebSocketError;
435
436 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
437 self.send(event).await
438 }
439}
440
441#[async_trait::async_trait]
442impl EventSink<String> for WebSocketConnection {
443 type Error = WebSocketError;
444
445 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
446 self.send(WebSocketEvent::Text(event)).await
447 }
448}
449
450#[async_trait::async_trait]
451impl EventSink<Vec<u8>> for WebSocketConnection {
452 type Error = WebSocketError;
453
454 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
455 self.send(WebSocketEvent::Binary(event)).await
456 }
457}
458
459impl PathParams {
460 pub fn new(values: HashMap<String, String>) -> Self {
461 Self { values }
462 }
463
464 pub fn get(&self, key: &str) -> Option<&str> {
465 self.values.get(key).map(String::as_str)
466 }
467
468 pub fn as_map(&self) -> &HashMap<String, String> {
469 &self.values
470 }
471
472 pub fn into_inner(self) -> HashMap<String, String> {
473 self.values
474 }
475}
476
/// One `/`-separated component of a parsed route pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum RouteSegment {
    /// Literal text that must match the path segment exactly.
    Static(String),
    /// `:name` — captures exactly one path segment under `name`.
    Param(String),
    /// `*name` — captures all remaining segments, joined by `/`, under `name`.
    Wildcard(String),
}
483
484#[derive(Clone, Debug, PartialEq, Eq)]
485struct RoutePattern {
486 raw: String,
487 segments: Vec<RouteSegment>,
488}
489
490impl RoutePattern {
491 fn parse(path: &str) -> Self {
492 let segments = path_segments(path)
493 .into_iter()
494 .map(|segment| {
495 if let Some(name) = segment.strip_prefix(':') {
496 if !name.is_empty() {
497 return RouteSegment::Param(name.to_string());
498 }
499 }
500 if let Some(name) = segment.strip_prefix('*') {
501 if !name.is_empty() {
502 return RouteSegment::Wildcard(name.to_string());
503 }
504 }
505 RouteSegment::Static(segment.to_string())
506 })
507 .collect();
508
509 Self {
510 raw: path.to_string(),
511 segments,
512 }
513 }
514
515 fn match_path(&self, path: &str) -> Option<PathParams> {
516 let mut params = HashMap::new();
517 let path_segments = path_segments(path);
518 let mut pattern_index = 0usize;
519 let mut path_index = 0usize;
520
521 while pattern_index < self.segments.len() {
522 match &self.segments[pattern_index] {
523 RouteSegment::Static(expected) => {
524 let actual = path_segments.get(path_index)?;
525 if actual != expected {
526 return None;
527 }
528 pattern_index += 1;
529 path_index += 1;
530 }
531 RouteSegment::Param(name) => {
532 let actual = path_segments.get(path_index)?;
533 params.insert(name.clone(), (*actual).to_string());
534 pattern_index += 1;
535 path_index += 1;
536 }
537 RouteSegment::Wildcard(name) => {
538 let remaining = path_segments[path_index..].join("/");
539 params.insert(name.clone(), remaining);
540 pattern_index += 1;
541 path_index = path_segments.len();
542 break;
543 }
544 }
545 }
546
547 if pattern_index == self.segments.len() && path_index == path_segments.len() {
548 Some(PathParams::new(params))
549 } else {
550 None
551 }
552 }
553}
554
/// Newtype wrapper around `Bytes`.
// NOTE(review): presumably carries a buffered request body; no use is
// visible in this part of the file — confirm at the use site.
#[derive(Clone)]
struct BodyBytes(Bytes);
559
/// One registered route: what it matches and how it is served.
#[derive(Clone)]
struct RouteEntry<R> {
    /// HTTP method the route responds to.
    method: Method,
    /// Parsed path pattern matched against the request path.
    pattern: RoutePattern,
    /// Handler invoked for matching requests.
    handler: RouteHandler<R>,
    /// Layers attached to this route only.
    layers: Arc<Vec<ServiceLayer>>,
    /// Flag recorded at registration; NOTE(review): presumably controls
    /// whether the ingress-wide layers wrap this route — confirm at dispatch.
    apply_global_layers: bool,
    /// Flag recorded at registration (`false` for WebSocket routes);
    /// NOTE(review): presumably whether the body is read before dispatch.
    needs_body: bool,
}
571
/// Splits a path into its non-empty `/`-separated segments.
///
/// Leading, trailing and repeated slashes produce no segments, so `"/"` and
/// `""` both yield an empty vector.
fn path_segments(path: &str) -> Vec<&str> {
    let mut segments = Vec::new();
    for piece in path.split('/') {
        if !piece.is_empty() {
            segments.push(piece);
        }
    }
    segments
}
582
/// Ensures a route path starts with `/`.
///
/// An empty path becomes `"/"`, a path already starting with `/` is returned
/// untouched, and anything else gets a single `/` prepended.
fn normalize_route_path(path: String) -> String {
    match path.chars().next() {
        None => String::from("/"),
        Some('/') => path,
        Some(_) => format!("/{path}"),
    }
}
593
594fn find_matching_route<'a, R>(
595 routes: &'a [RouteEntry<R>],
596 method: &Method,
597 path: &str,
598) -> Option<(&'a RouteEntry<R>, PathParams)> {
599 for entry in routes {
600 if entry.method != *method {
601 continue;
602 }
603 if let Some(params) = entry.pattern.match_path(path) {
604 return Some((entry, params));
605 }
606 }
607 None
608}
609
610fn header_contains_token(
611 headers: &http::HeaderMap,
612 name: http::header::HeaderName,
613 token: &str,
614) -> bool {
615 headers
616 .get(name)
617 .and_then(|value| value.to_str().ok())
618 .map(|value| {
619 value
620 .split(',')
621 .any(|part| part.trim().eq_ignore_ascii_case(token))
622 })
623 .unwrap_or(false)
624}
625
626fn websocket_session_from_request<B>(req: &Request<B>) -> WebSocketSessionContext {
627 WebSocketSessionContext {
628 connection_id: uuid::Uuid::new_v4(),
629 path: req.uri().path().to_string(),
630 query: req.uri().query().map(str::to_string),
631 }
632}
633
634fn websocket_accept_key(client_key: &str) -> String {
635 let mut hasher = Sha1::new();
636 hasher.update(client_key.as_bytes());
637 hasher.update(WS_GUID.as_bytes());
638 let digest = hasher.finalize();
639 base64::engine::general_purpose::STANDARD.encode(digest)
640}
641
642fn websocket_bad_request(message: &'static str) -> HttpResponse {
643 Response::builder()
644 .status(StatusCode::BAD_REQUEST)
645 .body(
646 Full::new(Bytes::from(message))
647 .map_err(|never| match never {})
648 .boxed(),
649 )
650 .unwrap_or_else(|_| {
651 Response::new(
652 Full::new(Bytes::new())
653 .map_err(|never| match never {})
654 .boxed(),
655 )
656 })
657}
658
659fn websocket_upgrade_response<B>(
660 req: &mut Request<B>,
661) -> Result<(HttpResponse, hyper::upgrade::OnUpgrade), HttpResponse> {
662 if req.method() != Method::GET {
663 return Err(websocket_bad_request(
664 "WebSocket upgrade requires GET method",
665 ));
666 }
667
668 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
669 return Err(websocket_bad_request(
670 "Missing Connection: upgrade header for WebSocket",
671 ));
672 }
673
674 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
675 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
676 }
677
678 if let Some(version) = req.headers().get("sec-websocket-version") {
679 if version != "13" {
680 return Err(websocket_bad_request(
681 "Unsupported Sec-WebSocket-Version (expected 13)",
682 ));
683 }
684 }
685
686 let Some(client_key) = req
687 .headers()
688 .get("sec-websocket-key")
689 .and_then(|value| value.to_str().ok())
690 else {
691 return Err(websocket_bad_request(
692 "Missing Sec-WebSocket-Key header for WebSocket",
693 ));
694 };
695
696 let accept_key = websocket_accept_key(client_key);
697 let on_upgrade = hyper::upgrade::on(req);
698 let response = Response::builder()
699 .status(StatusCode::SWITCHING_PROTOCOLS)
700 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
701 .header(http::header::CONNECTION, "Upgrade")
702 .header("sec-websocket-accept", accept_key)
703 .body(
704 Full::new(Bytes::new())
705 .map_err(|never| match never {})
706 .boxed(),
707 )
708 .unwrap_or_else(|_| {
709 Response::new(
710 Full::new(Bytes::new())
711 .map_err(|never| match never {})
712 .boxed(),
713 )
714 });
715
716 Ok((response, on_upgrade))
717}
718
/// Builder-style HTTP ingress configuration over a resource type `R`.
///
/// Each field is filled in by the builder method noted on it; how the fields
/// are consumed when serving is outside this part of the file.
pub struct HttpIngress<R = ()> {
    /// Bind address string (`bind`, `config`).
    addr: Option<String>,
    /// Registered routes, in registration order.
    routes: Vec<RouteEntry<R>>,
    /// Optional fallback handler.
    fallback: Option<RouteHandler<R>>,
    /// Ingress-wide service layers (`timeout_layer`, `request_id_layer`).
    layers: Vec<ServiceLayer>,
    /// Hook registered with `on_start`.
    on_start: Option<LifecycleHook>,
    /// Hook registered with `on_shutdown`.
    on_shutdown: Option<LifecycleHook>,
    /// Graceful-shutdown timeout; 30 seconds unless overridden.
    graceful_shutdown_timeout: Duration,
    /// Callbacks that populate the `Bus` from the request head.
    bus_injectors: Vec<BusInjector>,
    /// Static-file serving options.
    static_assets: StaticAssetsConfig,
    /// Health-probe configuration.
    health: HealthConfig<R>,
    /// HTTP/3 settings (`enable_http3`); `http3` feature only.
    #[cfg(feature = "http3")]
    http3_config: Option<crate::http3::Http3Config>,
    /// Port recorded by `alt_svc_h3`; `http3` feature only.
    #[cfg(feature = "http3")]
    alt_svc_h3_port: Option<u16>,
    /// Certificate/key paths (`tls`); `tls` feature only.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsAcceptorConfig>,
    /// Flag set by `active_intervention`.
    active_intervention: bool,
    /// Registry supplied through `policy_registry`.
    policy_registry: Option<ranvier_core::policy::PolicyRegistry>,
    /// Guard executors, in registration order (`guard`).
    guard_execs: Vec<Arc<dyn GuardExec>>,
    /// Response extractors contributed by guards.
    guard_response_extractors: Vec<ResponseExtractorFn>,
    /// Response body transforms contributed by guards.
    guard_body_transforms: Vec<ResponseBodyTransformFn>,
    /// Preflight configuration contributed by a guard that handles preflight.
    preflight_config: Option<PreflightConfig>,
    /// Marker tying the builder to `R`.
    _phantom: std::marker::PhantomData<R>,
}
766
767impl<R> HttpIngress<R>
768where
769 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
770{
771 pub fn new() -> Self {
773 Self {
774 addr: None,
775 routes: Vec::new(),
776 fallback: None,
777 layers: Vec::new(),
778 on_start: None,
779 on_shutdown: None,
780 graceful_shutdown_timeout: Duration::from_secs(30),
781 bus_injectors: Vec::new(),
782 static_assets: StaticAssetsConfig::default(),
783 health: HealthConfig::default(),
784 #[cfg(feature = "tls")]
785 tls_config: None,
786 #[cfg(feature = "http3")]
787 http3_config: None,
788 #[cfg(feature = "http3")]
789 alt_svc_h3_port: None,
790 active_intervention: false,
791 policy_registry: None,
792 guard_execs: Vec::new(),
793 guard_response_extractors: Vec::new(),
794 guard_body_transforms: Vec::new(),
795 preflight_config: None,
796 _phantom: std::marker::PhantomData,
797 }
798 }
799
800 pub fn bind(mut self, addr: impl Into<String>) -> Self {
804 self.addr = Some(addr.into());
805 self
806 }
807
/// Turns on the `active_intervention` flag.
///
/// The code that reads the flag is outside this part of the file.
pub fn active_intervention(mut self) -> Self {
    self.active_intervention = true;
    self
}
817
818 pub fn policy_registry(mut self, registry: ranvier_core::policy::PolicyRegistry) -> Self {
820 self.policy_registry = Some(registry);
821 self
822 }
823
824 pub fn on_start<F>(mut self, callback: F) -> Self
828 where
829 F: Fn() + Send + Sync + 'static,
830 {
831 self.on_start = Some(Arc::new(callback));
832 self
833 }
834
835 pub fn on_shutdown<F>(mut self, callback: F) -> Self
837 where
838 F: Fn() + Send + Sync + 'static,
839 {
840 self.on_shutdown = Some(Arc::new(callback));
841 self
842 }
843
844 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
846 self.graceful_shutdown_timeout = timeout;
847 self
848 }
849
850 pub fn config(mut self, config: &ranvier_core::config::RanvierConfig) -> Self {
856 self.addr = Some(config.bind_addr());
857 self.graceful_shutdown_timeout = config.shutdown_timeout();
858 config.init_telemetry();
859 self
860 }
861
/// Records the certificate and key paths for TLS (`tls` feature only).
#[cfg(feature = "tls")]
pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
    let cert_path = cert_path.into();
    let key_path = key_path.into();
    self.tls_config = Some(TlsAcceptorConfig {
        cert_path,
        key_path,
    });
    self
}
871
872 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
877 self.layers.push(timeout_middleware(timeout));
878 self
879 }
880
881 pub fn request_id_layer(mut self) -> Self {
885 self.layers.push(request_id_middleware());
886 self
887 }
888
889 pub fn bus_injector<F>(mut self, injector: F) -> Self
894 where
895 F: Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static,
896 {
897 self.bus_injectors.push(Arc::new(injector));
898 self
899 }
900
901 pub fn guard(mut self, guard: impl GuardIntegration) -> Self {
921 let registration = guard.register();
922 for injector in registration.bus_injectors {
923 self.bus_injectors.push(injector);
924 }
925 self.guard_execs.push(registration.exec);
926 if let Some(extractor) = registration.response_extractor {
927 self.guard_response_extractors.push(extractor);
928 }
929 if let Some(transform) = registration.response_body_transform {
930 self.guard_body_transforms.push(transform);
931 }
932 if registration.handles_preflight {
933 if let Some(config) = registration.preflight_config {
934 self.preflight_config = Some(config);
935 }
936 }
937 self
938 }
939
/// Stores the HTTP/3 configuration (`http3` feature only).
#[cfg(feature = "http3")]
pub fn enable_http3(mut self, config: crate::http3::Http3Config) -> Self {
    let http3 = Some(config);
    self.http3_config = http3;
    self
}
946
/// Records the HTTP/3 port for `Alt-Svc` (`http3` feature only).
#[cfg(feature = "http3")]
pub fn alt_svc_h3(mut self, port: u16) -> Self {
    let advertised = Some(port);
    self.alt_svc_h3_port = advertised;
    self
}
953
954 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
958 let mut descriptors = self
959 .routes
960 .iter()
961 .map(|entry| HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone()))
962 .collect::<Vec<_>>();
963
964 if let Some(path) = &self.health.health_path {
965 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
966 }
967 if let Some(path) = &self.health.readiness_path {
968 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
969 }
970 if let Some(path) = &self.health.liveness_path {
971 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
972 }
973
974 descriptors
975 }
976
977 pub fn serve_dir(
983 mut self,
984 route_prefix: impl Into<String>,
985 directory: impl Into<String>,
986 ) -> Self {
987 self.static_assets.mounts.push(StaticMount {
988 route_prefix: normalize_route_path(route_prefix.into()),
989 directory: directory.into(),
990 });
991 if self.static_assets.cache_control.is_none() {
992 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
993 }
994 self
995 }
996
997 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
1001 self.static_assets.spa_fallback = Some(file_path.into());
1002 self
1003 }
1004
1005 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
1007 self.static_assets.cache_control = Some(cache_control.into());
1008 self
1009 }
1010
1011 pub fn directory_index(mut self, filename: impl Into<String>) -> Self {
1019 self.static_assets.directory_index = Some(filename.into());
1020 self
1021 }
1022
1023 pub fn immutable_cache(mut self) -> Self {
1028 self.static_assets.immutable_cache = true;
1029 self
1030 }
1031
1032 pub fn compression_layer(mut self) -> Self {
1034 self.static_assets.enable_compression = true;
1035 self
1036 }
1037
/// Registers a WebSocket endpoint at `path` (as a `GET` route).
///
/// For each request the generated handler:
/// 1. builds a fresh `Bus` and runs the bus injectors,
/// 2. runs the guards, answering with a JSON error on the first rejection,
/// 3. stores a `WebSocketSessionContext` in the bus,
/// 4. validates the handshake (an invalid one yields the error response),
/// 5. spawns a task that awaits the upgrade and then runs `handler`,
/// 6. returns the `101` response.
///
/// Injectors and guards are snapshotted when this method is called, so ones
/// registered later do not apply to this endpoint. No guard response
/// extractors are applied on this path.
pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
where
    H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let path_str: String = path.into();
    // Erase the handler's concrete future type.
    let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
        Box::pin(handler(connection, resources, bus))
    });
    // Snapshot of the injectors and guards registered so far.
    let bus_injectors = Arc::new(self.bus_injectors.clone());
    let ws_guard_execs = Arc::new(self.guard_execs.clone());
    let path_for_pattern = path_str.clone();
    let path_for_handler = path_str;

    let route_handler: RouteHandler<R> =
        Arc::new(move |parts: http::request::Parts, res: &R| {
            // Per-request clones so the returned future owns everything it uses.
            let ws_handler = ws_handler.clone();
            let bus_injectors = bus_injectors.clone();
            let ws_guard_execs = ws_guard_execs.clone();
            let resources = Arc::new(res.clone());
            let path = path_for_handler.clone();

            Box::pin(async move {
                // Used only as a tracing field; generated fresh for each request.
                let request_id = uuid::Uuid::new_v4().to_string();
                let span = tracing::info_span!(
                    "WebSocketUpgrade",
                    ranvier.ws.path = %path,
                    ranvier.ws.request_id = %request_id
                );

                async move {
                    let mut bus = Bus::new();
                    for injector in bus_injectors.iter() {
                        injector(&parts, &mut bus);
                    }
                    // Guards run before the handshake is examined.
                    for guard_exec in ws_guard_execs.iter() {
                        if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
                            return json_error_response(rejection.status, &rejection.message);
                        }
                    }

                    // Rebuild a body-less request so the upgrade helpers can use it.
                    let mut req = Request::from_parts(parts, ());
                    let session = websocket_session_from_request(&req);
                    bus.insert(session.clone());

                    let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
                        Ok(result) => result,
                        Err(error_response) => return error_response,
                    };

                    // The session runs in its own task; the spawned task's
                    // handle is dropped, so it is detached.
                    tokio::spawn(async move {
                        match on_upgrade.await {
                            Ok(upgraded) => {
                                let stream = WebSocketStream::from_raw_socket(
                                    TokioIo::new(upgraded),
                                    tokio_tungstenite::tungstenite::protocol::Role::Server,
                                    None,
                                )
                                .await;
                                let connection = WebSocketConnection::new(stream, session);
                                ws_handler(connection, resources, bus).await;
                            }
                            Err(error) => {
                                tracing::warn!(
                                    ranvier.ws.path = %path,
                                    ranvier.ws.error = %error,
                                    "websocket upgrade failed"
                                );
                            }
                        }
                    });

                    response
                }
                .instrument(span)
                .await
            }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
        });

    self.routes.push(RouteEntry {
        method: Method::GET,
        pattern: RoutePattern::parse(&path_for_pattern),
        handler: route_handler,
        layers: Arc::new(Vec::new()),
        apply_global_layers: true,
        needs_body: false,
    });

    self
}
1137
1138 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
1145 self.health.health_path = Some(normalize_route_path(path.into()));
1146 self
1147 }
1148
1149 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
1153 where
1154 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
1155 Fut: Future<Output = Result<(), Err>> + Send + 'static,
1156 Err: ToString + Send + 'static,
1157 {
1158 if self.health.health_path.is_none() {
1159 self.health.health_path = Some("/health".to_string());
1160 }
1161
1162 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
1163 let fut = check(resources);
1164 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
1165 });
1166
1167 self.health.checks.push(NamedHealthCheck {
1168 name: name.into(),
1169 check: check_fn,
1170 });
1171 self
1172 }
1173
1174 pub fn readiness_liveness(
1176 mut self,
1177 readiness_path: impl Into<String>,
1178 liveness_path: impl Into<String>,
1179 ) -> Self {
1180 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1181 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1182 self
1183 }
1184
1185 pub fn readiness_liveness_default(self) -> Self {
1187 self.readiness_liveness("/ready", "/live")
1188 }
1189
1190 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1194 where
1195 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1196 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1197 {
1198 self.route_method(Method::GET, path, circuit)
1199 }
1200 pub fn route_method<Out, E>(
1209 self,
1210 method: Method,
1211 path: impl Into<String>,
1212 circuit: Axon<(), Out, E, R>,
1213 ) -> Self
1214 where
1215 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1216 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1217 {
1218 self.route_method_with_error(method, path, circuit, |error| {
1219 (
1220 StatusCode::INTERNAL_SERVER_ERROR,
1221 format!("Error: {:?}", error),
1222 )
1223 .into_response()
1224 })
1225 }
1226
1227 pub fn route_method_with_error<Out, E, H>(
1228 self,
1229 method: Method,
1230 path: impl Into<String>,
1231 circuit: Axon<(), Out, E, R>,
1232 error_handler: H,
1233 ) -> Self
1234 where
1235 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1236 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1237 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1238 {
1239 self.route_method_with_error_and_layers(
1240 method,
1241 path,
1242 circuit,
1243 error_handler,
1244 Arc::new(Vec::new()),
1245 true,
1246 )
1247 }
1248
1249
1250
1251 fn route_method_with_error_and_layers<Out, E, H>(
1252 mut self,
1253 method: Method,
1254 path: impl Into<String>,
1255 circuit: Axon<(), Out, E, R>,
1256 error_handler: H,
1257 route_layers: Arc<Vec<ServiceLayer>>,
1258 apply_global_layers: bool,
1259 ) -> Self
1260 where
1261 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1262 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1263 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1264 {
1265 let path_str: String = path.into();
1266 let circuit = Arc::new(circuit);
1267 let error_handler = Arc::new(error_handler);
1268 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1269 let route_guard_execs = Arc::new(self.guard_execs.clone());
1270 let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
1271 let route_body_transforms = Arc::new(self.guard_body_transforms.clone());
1272 let path_for_pattern = path_str.clone();
1273 let path_for_handler = path_str;
1274 let method_for_pattern = method.clone();
1275 let method_for_handler = method;
1276
1277 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1278 let circuit = circuit.clone();
1279 let error_handler = error_handler.clone();
1280 let route_bus_injectors = route_bus_injectors.clone();
1281 let route_guard_execs = route_guard_execs.clone();
1282 let route_response_extractors = route_response_extractors.clone();
1283 let route_body_transforms = route_body_transforms.clone();
1284 let res = res.clone();
1285 let path = path_for_handler.clone();
1286 let method = method_for_handler.clone();
1287
1288 Box::pin(async move {
1289 let request_id = uuid::Uuid::new_v4().to_string();
1290 let span = tracing::info_span!(
1291 "HTTPRequest",
1292 ranvier.http.method = %method,
1293 ranvier.http.path = %path,
1294 ranvier.http.request_id = %request_id
1295 );
1296
1297 async move {
1298 let mut bus = Bus::new();
1299 for injector in route_bus_injectors.iter() {
1300 injector(&parts, &mut bus);
1301 }
1302 for guard_exec in route_guard_execs.iter() {
1303 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1304 let mut response = json_error_response(rejection.status, &rejection.message);
1305 for extractor in route_response_extractors.iter() {
1306 extractor(&bus, response.headers_mut());
1307 }
1308 return response;
1309 }
1310 }
1311 if let Some(cached) = bus.read::<ranvier_guard::IdempotencyCachedResponse>() {
1313 let body = Bytes::from(cached.body.clone());
1314 let mut response = Response::builder()
1315 .status(StatusCode::OK)
1316 .header("content-type", "application/json")
1317 .body(Full::new(body).map_err(|n: Infallible| match n {}).boxed())
1318 .unwrap();
1319 for extractor in route_response_extractors.iter() {
1320 extractor(&bus, response.headers_mut());
1321 }
1322 return response;
1323 }
1324 let result = if let Some(td) = bus.read::<ranvier_guard::TimeoutDeadline>() {
1326 let remaining = td.remaining();
1327 if remaining.is_zero() {
1328 let mut response = json_error_response(
1329 StatusCode::REQUEST_TIMEOUT,
1330 "Request timeout: pipeline deadline exceeded",
1331 );
1332 for extractor in route_response_extractors.iter() {
1333 extractor(&bus, response.headers_mut());
1334 }
1335 return response;
1336 }
1337 match tokio::time::timeout(remaining, circuit.execute((), &res, &mut bus)).await {
1338 Ok(result) => result,
1339 Err(_) => {
1340 let mut response = json_error_response(
1341 StatusCode::REQUEST_TIMEOUT,
1342 "Request timeout: pipeline deadline exceeded",
1343 );
1344 for extractor in route_response_extractors.iter() {
1345 extractor(&bus, response.headers_mut());
1346 }
1347 return response;
1348 }
1349 }
1350 } else {
1351 circuit.execute((), &res, &mut bus).await
1352 };
1353 let mut response = outcome_to_response_with_error(result, |error| error_handler(error));
1354 for extractor in route_response_extractors.iter() {
1355 extractor(&bus, response.headers_mut());
1356 }
1357 if !route_body_transforms.is_empty() {
1358 response = apply_body_transforms(response, &bus, &route_body_transforms).await;
1359 }
1360 response
1361 }
1362 .instrument(span)
1363 .await
1364 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1365 });
1366
1367 self.routes.push(RouteEntry {
1368 method: method_for_pattern,
1369 pattern: RoutePattern::parse(&path_for_pattern),
1370 handler,
1371 layers: route_layers,
1372 apply_global_layers,
1373 needs_body: false,
1374 });
1375 self
1376 }
1377
1378 fn route_method_typed<T, Out, E>(
1382 mut self,
1383 method: Method,
1384 path: impl Into<String>,
1385 circuit: Axon<T, Out, E, R>,
1386 ) -> Self
1387 where
1388 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
1389 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1390 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1391 {
1392 let path_str: String = path.into();
1393 let circuit = Arc::new(circuit);
1394 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1395 let route_guard_execs = Arc::new(self.guard_execs.clone());
1396 let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
1397 let route_body_transforms = Arc::new(self.guard_body_transforms.clone());
1398 let path_for_pattern = path_str.clone();
1399 let path_for_handler = path_str;
1400 let method_for_pattern = method.clone();
1401 let method_for_handler = method;
1402
1403 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1404 let circuit = circuit.clone();
1405 let route_bus_injectors = route_bus_injectors.clone();
1406 let route_guard_execs = route_guard_execs.clone();
1407 let route_response_extractors = route_response_extractors.clone();
1408 let route_body_transforms = route_body_transforms.clone();
1409 let res = res.clone();
1410 let path = path_for_handler.clone();
1411 let method = method_for_handler.clone();
1412
1413 Box::pin(async move {
1414 let request_id = uuid::Uuid::new_v4().to_string();
1415 let span = tracing::info_span!(
1416 "HTTPRequest",
1417 ranvier.http.method = %method,
1418 ranvier.http.path = %path,
1419 ranvier.http.request_id = %request_id
1420 );
1421
1422 async move {
1423 let body_bytes = parts
1425 .extensions
1426 .get::<BodyBytes>()
1427 .map(|b| b.0.clone())
1428 .unwrap_or_default();
1429
1430 let input: T = match serde_json::from_slice(&body_bytes) {
1432 Ok(v) => v,
1433 Err(e) => {
1434 return json_error_response(
1435 StatusCode::BAD_REQUEST,
1436 &format!("Invalid request body: {}", e),
1437 );
1438 }
1439 };
1440
1441 let mut bus = Bus::new();
1442 for injector in route_bus_injectors.iter() {
1443 injector(&parts, &mut bus);
1444 }
1445 for guard_exec in route_guard_execs.iter() {
1446 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1447 let mut response = json_error_response(rejection.status, &rejection.message);
1448 for extractor in route_response_extractors.iter() {
1449 extractor(&bus, response.headers_mut());
1450 }
1451 return response;
1452 }
1453 }
1454 if let Some(cached) = bus.read::<ranvier_guard::IdempotencyCachedResponse>() {
1456 let body = Bytes::from(cached.body.clone());
1457 let mut response = Response::builder()
1458 .status(StatusCode::OK)
1459 .header("content-type", "application/json")
1460 .body(Full::new(body).map_err(|n: Infallible| match n {}).boxed())
1461 .unwrap();
1462 for extractor in route_response_extractors.iter() {
1463 extractor(&bus, response.headers_mut());
1464 }
1465 return response;
1466 }
1467 let result = if let Some(td) = bus.read::<ranvier_guard::TimeoutDeadline>() {
1469 let remaining = td.remaining();
1470 if remaining.is_zero() {
1471 let mut response = json_error_response(
1472 StatusCode::REQUEST_TIMEOUT,
1473 "Request timeout: pipeline deadline exceeded",
1474 );
1475 for extractor in route_response_extractors.iter() {
1476 extractor(&bus, response.headers_mut());
1477 }
1478 return response;
1479 }
1480 match tokio::time::timeout(remaining, circuit.execute(input, &res, &mut bus)).await {
1481 Ok(result) => result,
1482 Err(_) => {
1483 let mut response = json_error_response(
1484 StatusCode::REQUEST_TIMEOUT,
1485 "Request timeout: pipeline deadline exceeded",
1486 );
1487 for extractor in route_response_extractors.iter() {
1488 extractor(&bus, response.headers_mut());
1489 }
1490 return response;
1491 }
1492 }
1493 } else {
1494 circuit.execute(input, &res, &mut bus).await
1495 };
1496 let mut response = outcome_to_response_with_error(result, |error| {
1497 if cfg!(debug_assertions) {
1498 (
1499 StatusCode::INTERNAL_SERVER_ERROR,
1500 format!("Error: {:?}", error),
1501 )
1502 .into_response()
1503 } else {
1504 json_error_response(
1505 StatusCode::INTERNAL_SERVER_ERROR,
1506 "Internal server error",
1507 )
1508 }
1509 });
1510 for extractor in route_response_extractors.iter() {
1511 extractor(&bus, response.headers_mut());
1512 }
1513 if !route_body_transforms.is_empty() {
1514 response = apply_body_transforms(response, &bus, &route_body_transforms).await;
1515 }
1516 response
1517 }
1518 .instrument(span)
1519 .await
1520 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1521 });
1522
1523 self.routes.push(RouteEntry {
1524 method: method_for_pattern,
1525 pattern: RoutePattern::parse(&path_for_pattern),
1526 handler,
1527 layers: Arc::new(Vec::new()),
1528 apply_global_layers: true,
1529 needs_body: true,
1530 });
1531 self
1532 }
1533
1534 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1535 where
1536 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1537 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1538 {
1539 self.route_method(Method::GET, path, circuit)
1540 }
1541
1542 pub fn get_with_error<Out, E, H>(
1543 self,
1544 path: impl Into<String>,
1545 circuit: Axon<(), Out, E, R>,
1546 error_handler: H,
1547 ) -> Self
1548 where
1549 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1550 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1551 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1552 {
1553 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1554 }
1555
1556 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1557 where
1558 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1559 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1560 {
1561 self.route_method(Method::POST, path, circuit)
1562 }
1563
1564 pub fn post_typed<T, Out, E>(
1580 self,
1581 path: impl Into<String>,
1582 circuit: Axon<T, Out, E, R>,
1583 ) -> Self
1584 where
1585 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
1586 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1587 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1588 {
1589 self.route_method_typed::<T, Out, E>(Method::POST, path, circuit)
1590 }
1591
1592 pub fn put_typed<T, Out, E>(
1596 self,
1597 path: impl Into<String>,
1598 circuit: Axon<T, Out, E, R>,
1599 ) -> Self
1600 where
1601 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
1602 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1603 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1604 {
1605 self.route_method_typed::<T, Out, E>(Method::PUT, path, circuit)
1606 }
1607
1608 pub fn patch_typed<T, Out, E>(
1612 self,
1613 path: impl Into<String>,
1614 circuit: Axon<T, Out, E, R>,
1615 ) -> Self
1616 where
1617 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
1618 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1619 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1620 {
1621 self.route_method_typed::<T, Out, E>(Method::PATCH, path, circuit)
1622 }
1623
1624 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1625 where
1626 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1627 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1628 {
1629 self.route_method(Method::PUT, path, circuit)
1630 }
1631
1632 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1633 where
1634 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1635 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1636 {
1637 self.route_method(Method::DELETE, path, circuit)
1638 }
1639
1640 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1641 where
1642 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1643 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1644 {
1645 self.route_method(Method::PATCH, path, circuit)
1646 }
1647
1648 pub fn post_with_error<Out, E, H>(
1649 self,
1650 path: impl Into<String>,
1651 circuit: Axon<(), Out, E, R>,
1652 error_handler: H,
1653 ) -> Self
1654 where
1655 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1656 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1657 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1658 {
1659 self.route_method_with_error(Method::POST, path, circuit, error_handler)
1660 }
1661
1662 pub fn put_with_error<Out, E, H>(
1663 self,
1664 path: impl Into<String>,
1665 circuit: Axon<(), Out, E, R>,
1666 error_handler: H,
1667 ) -> Self
1668 where
1669 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1670 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1671 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1672 {
1673 self.route_method_with_error(Method::PUT, path, circuit, error_handler)
1674 }
1675
1676 pub fn delete_with_error<Out, E, H>(
1677 self,
1678 path: impl Into<String>,
1679 circuit: Axon<(), Out, E, R>,
1680 error_handler: H,
1681 ) -> Self
1682 where
1683 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1684 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1685 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1686 {
1687 self.route_method_with_error(Method::DELETE, path, circuit, error_handler)
1688 }
1689
1690 pub fn patch_with_error<Out, E, H>(
1691 self,
1692 path: impl Into<String>,
1693 circuit: Axon<(), Out, E, R>,
1694 error_handler: H,
1695 ) -> Self
1696 where
1697 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1698 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1699 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1700 {
1701 self.route_method_with_error(Method::PATCH, path, circuit, error_handler)
1702 }
1703
1704 fn route_method_with_extra_guards<Out, E>(
1710 mut self,
1711 method: Method,
1712 path: impl Into<String>,
1713 circuit: Axon<(), Out, E, R>,
1714 extra_guards: Vec<RegisteredGuard>,
1715 ) -> Self
1716 where
1717 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1718 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1719 {
1720 let saved_injectors = self.bus_injectors.len();
1722 let saved_execs = self.guard_execs.len();
1723 let saved_extractors = self.guard_response_extractors.len();
1724 let saved_transforms = self.guard_body_transforms.len();
1725
1726 for registration in extra_guards {
1728 for injector in registration.bus_injectors {
1729 self.bus_injectors.push(injector);
1730 }
1731 self.guard_execs.push(registration.exec);
1732 if let Some(extractor) = registration.response_extractor {
1733 self.guard_response_extractors.push(extractor);
1734 }
1735 if let Some(transform) = registration.response_body_transform {
1736 self.guard_body_transforms.push(transform);
1737 }
1738 }
1739
1740 self = self.route_method(method, path, circuit);
1742
1743 self.bus_injectors.truncate(saved_injectors);
1745 self.guard_execs.truncate(saved_execs);
1746 self.guard_response_extractors.truncate(saved_extractors);
1747 self.guard_body_transforms.truncate(saved_transforms);
1748
1749 self
1750 }
1751
1752 pub fn get_with_guards<Out, E>(
1770 self,
1771 path: impl Into<String>,
1772 circuit: Axon<(), Out, E, R>,
1773 extra_guards: Vec<RegisteredGuard>,
1774 ) -> Self
1775 where
1776 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1777 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1778 {
1779 self.route_method_with_extra_guards(Method::GET, path, circuit, extra_guards)
1780 }
1781
1782 pub fn post_with_guards<Out, E>(
1803 self,
1804 path: impl Into<String>,
1805 circuit: Axon<(), Out, E, R>,
1806 extra_guards: Vec<RegisteredGuard>,
1807 ) -> Self
1808 where
1809 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1810 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1811 {
1812 self.route_method_with_extra_guards(Method::POST, path, circuit, extra_guards)
1813 }
1814
1815 pub fn put_with_guards<Out, E>(
1817 self,
1818 path: impl Into<String>,
1819 circuit: Axon<(), Out, E, R>,
1820 extra_guards: Vec<RegisteredGuard>,
1821 ) -> Self
1822 where
1823 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1824 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1825 {
1826 self.route_method_with_extra_guards(Method::PUT, path, circuit, extra_guards)
1827 }
1828
1829 pub fn delete_with_guards<Out, E>(
1831 self,
1832 path: impl Into<String>,
1833 circuit: Axon<(), Out, E, R>,
1834 extra_guards: Vec<RegisteredGuard>,
1835 ) -> Self
1836 where
1837 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1838 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1839 {
1840 self.route_method_with_extra_guards(Method::DELETE, path, circuit, extra_guards)
1841 }
1842
1843 pub fn patch_with_guards<Out, E>(
1845 self,
1846 path: impl Into<String>,
1847 circuit: Axon<(), Out, E, R>,
1848 extra_guards: Vec<RegisteredGuard>,
1849 ) -> Self
1850 where
1851 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1852 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1853 {
1854 self.route_method_with_extra_guards(Method::PATCH, path, circuit, extra_guards)
1855 }
1856
1857 pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
1868 where
1869 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1870 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1871 {
1872 let circuit = Arc::new(circuit);
1873 let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
1874 let fallback_guard_execs = Arc::new(self.guard_execs.clone());
1875 let fallback_response_extractors = Arc::new(self.guard_response_extractors.clone());
1876 let fallback_body_transforms = Arc::new(self.guard_body_transforms.clone());
1877
1878 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1879 let circuit = circuit.clone();
1880 let fallback_bus_injectors = fallback_bus_injectors.clone();
1881 let fallback_guard_execs = fallback_guard_execs.clone();
1882 let fallback_response_extractors = fallback_response_extractors.clone();
1883 let fallback_body_transforms = fallback_body_transforms.clone();
1884 let res = res.clone();
1885 Box::pin(async move {
1886 let request_id = uuid::Uuid::new_v4().to_string();
1887 let span = tracing::info_span!(
1888 "HTTPRequest",
1889 ranvier.http.method = "FALLBACK",
1890 ranvier.http.request_id = %request_id
1891 );
1892
1893 async move {
1894 let mut bus = Bus::new();
1895 for injector in fallback_bus_injectors.iter() {
1896 injector(&parts, &mut bus);
1897 }
1898 for guard_exec in fallback_guard_execs.iter() {
1899 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1900 let mut response = json_error_response(rejection.status, &rejection.message);
1901 for extractor in fallback_response_extractors.iter() {
1902 extractor(&bus, response.headers_mut());
1903 }
1904 return response;
1905 }
1906 }
1907 let result: ranvier_core::Outcome<Out, E> =
1908 circuit.execute((), &res, &mut bus).await;
1909
1910 let mut response = match result {
1911 Outcome::Next(output) => {
1912 let mut response = output.into_response();
1913 *response.status_mut() = StatusCode::NOT_FOUND;
1914 response
1915 }
1916 _ => Response::builder()
1917 .status(StatusCode::NOT_FOUND)
1918 .body(
1919 Full::new(Bytes::from("Not Found"))
1920 .map_err(|never| match never {})
1921 .boxed(),
1922 )
1923 .expect("valid HTTP response construction"),
1924 };
1925 for extractor in fallback_response_extractors.iter() {
1926 extractor(&bus, response.headers_mut());
1927 }
1928 if !fallback_body_transforms.is_empty() {
1929 response = apply_body_transforms(response, &bus, &fallback_body_transforms).await;
1930 }
1931 response
1932 }
1933 .instrument(span)
1934 .await
1935 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1936 });
1937
1938 self.fallback = Some(handler);
1939 self
1940 }
1941
1942 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1946 self.run_with_shutdown_signal(resources, shutdown_signal())
1947 .await
1948 }
1949
1950 async fn run_with_shutdown_signal<S>(
1951 self,
1952 resources: R,
1953 shutdown_signal: S,
1954 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
1955 where
1956 S: Future<Output = ()> + Send,
1957 {
1958 let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
1959 let addr: SocketAddr = addr_str.parse()?;
1960
1961 let mut raw_routes = self.routes;
1962 if self.active_intervention {
1963 let handler: RouteHandler<R> = Arc::new(|_parts, _res| {
1964 Box::pin(async move {
1965 Response::builder()
1966 .status(StatusCode::OK)
1967 .body(
1968 Full::new(Bytes::from("Intervention accepted"))
1969 .map_err(|never| match never {} as Infallible)
1970 .boxed(),
1971 )
1972 .expect("valid HTTP response construction")
1973 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1974 });
1975
1976 raw_routes.push(RouteEntry {
1977 method: Method::POST,
1978 pattern: RoutePattern::parse("/_system/intervene/force_resume"),
1979 handler,
1980 layers: Arc::new(Vec::new()),
1981 apply_global_layers: true,
1982 needs_body: false,
1983 });
1984 }
1985
1986 if let Some(registry) = self.policy_registry.clone() {
1987 let handler: RouteHandler<R> = Arc::new(move |_parts, _res| {
1988 let _registry = registry.clone();
1989 Box::pin(async move {
1990 Response::builder()
1994 .status(StatusCode::OK)
1995 .body(
1996 Full::new(Bytes::from("Policy registry active"))
1997 .map_err(|never| match never {} as Infallible)
1998 .boxed(),
1999 )
2000 .expect("valid HTTP response construction")
2001 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
2002 });
2003
2004 raw_routes.push(RouteEntry {
2005 method: Method::POST,
2006 pattern: RoutePattern::parse("/_system/policy/reload"),
2007 handler,
2008 layers: Arc::new(Vec::new()),
2009 apply_global_layers: true,
2010 needs_body: false,
2011 });
2012 }
2013 let routes = Arc::new(raw_routes);
2014 let fallback = self.fallback;
2015 let layers = Arc::new(self.layers);
2016 let health = Arc::new(self.health);
2017 let static_assets = Arc::new(self.static_assets);
2018 let preflight_config = Arc::new(self.preflight_config);
2019 let on_start = self.on_start;
2020 let on_shutdown = self.on_shutdown;
2021 let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
2022 let resources = Arc::new(resources);
2023
2024 let listener = TcpListener::bind(addr).await?;
2025
2026 #[cfg(feature = "tls")]
2028 let tls_acceptor = if let Some(ref tls_cfg) = self.tls_config {
2029 let acceptor = build_tls_acceptor(&tls_cfg.cert_path, &tls_cfg.key_path)?;
2030 tracing::info!("Ranvier HTTP Ingress listening on https://{}", addr);
2031 Some(acceptor)
2032 } else {
2033 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
2034 None
2035 };
2036 #[cfg(not(feature = "tls"))]
2037 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
2038
2039 if let Some(callback) = on_start.as_ref() {
2040 callback();
2041 }
2042
2043 tokio::pin!(shutdown_signal);
2044 let mut connections = tokio::task::JoinSet::new();
2045
2046 loop {
2047 tokio::select! {
2048 _ = &mut shutdown_signal => {
2049 tracing::info!("Shutdown signal received. Draining in-flight connections.");
2050 break;
2051 }
2052 accept_result = listener.accept() => {
2053 let (stream, _) = accept_result?;
2054
2055 let routes = routes.clone();
2056 let fallback = fallback.clone();
2057 let resources = resources.clone();
2058 let layers = layers.clone();
2059 let health = health.clone();
2060 let static_assets = static_assets.clone();
2061 let preflight_config = preflight_config.clone();
2062 #[cfg(feature = "http3")]
2063 let alt_svc_h3_port = self.alt_svc_h3_port;
2064
2065 #[cfg(feature = "tls")]
2066 let tls_acceptor = tls_acceptor.clone();
2067
2068 connections.spawn(async move {
2069 let service = build_http_service(
2070 routes,
2071 fallback,
2072 resources,
2073 layers,
2074 health,
2075 static_assets,
2076 preflight_config,
2077 #[cfg(feature = "http3")] alt_svc_h3_port,
2078 );
2079
2080 #[cfg(feature = "tls")]
2081 if let Some(acceptor) = tls_acceptor {
2082 match acceptor.accept(stream).await {
2083 Ok(tls_stream) => {
2084 let io = TokioIo::new(tls_stream);
2085 if let Err(err) = http1::Builder::new()
2086 .serve_connection(io, service)
2087 .with_upgrades()
2088 .await
2089 {
2090 tracing::error!("Error serving TLS connection: {:?}", err);
2091 }
2092 }
2093 Err(err) => {
2094 tracing::warn!("TLS handshake failed: {:?}", err);
2095 }
2096 }
2097 return;
2098 }
2099
2100 let io = TokioIo::new(stream);
2101 if let Err(err) = http1::Builder::new()
2102 .serve_connection(io, service)
2103 .with_upgrades()
2104 .await
2105 {
2106 tracing::error!("Error serving connection: {:?}", err);
2107 }
2108 });
2109 }
2110 Some(join_result) = connections.join_next(), if !connections.is_empty() => {
2111 if let Err(err) = join_result {
2112 tracing::warn!("Connection task join error: {:?}", err);
2113 }
2114 }
2115 }
2116 }
2117
2118 let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;
2119
2120 drop(resources);
2121 if let Some(callback) = on_shutdown.as_ref() {
2122 callback();
2123 }
2124
2125 Ok(())
2126 }
2127
2128 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
2144 let routes = Arc::new(self.routes);
2145 let fallback = self.fallback;
2146 let layers = Arc::new(self.layers);
2147 let health = Arc::new(self.health);
2148 let static_assets = Arc::new(self.static_assets);
2149 let preflight_config = Arc::new(self.preflight_config);
2150 let resources = Arc::new(resources);
2151
2152 RawIngressService {
2153 routes,
2154 fallback,
2155 layers,
2156 health,
2157 static_assets,
2158 preflight_config,
2159 resources,
2160 }
2161 }
2162}
2163
2164async fn apply_body_transforms(
2169 response: HttpResponse,
2170 bus: &Bus,
2171 transforms: &[ResponseBodyTransformFn],
2172) -> HttpResponse {
2173 use http_body_util::BodyExt;
2174
2175 let (parts, body) = response.into_parts();
2176
2177 let collected = match body.collect().await {
2179 Ok(c) => c.to_bytes(),
2180 Err(_) => {
2181 return Response::builder()
2183 .status(StatusCode::INTERNAL_SERVER_ERROR)
2184 .body(
2185 Full::new(Bytes::from("body collection failed"))
2186 .map_err(|never| match never {})
2187 .boxed(),
2188 )
2189 .expect("valid response");
2190 }
2191 };
2192
2193 let mut transformed = collected;
2194 for transform in transforms {
2195 transformed = transform(bus, transformed);
2196 }
2197
2198 Response::from_parts(
2199 parts,
2200 Full::new(transformed)
2201 .map_err(|never| match never {})
2202 .boxed(),
2203 )
2204}
2205
2206fn build_http_service<R>(
2207 routes: Arc<Vec<RouteEntry<R>>>,
2208 fallback: Option<RouteHandler<R>>,
2209 resources: Arc<R>,
2210 layers: Arc<Vec<ServiceLayer>>,
2211 health: Arc<HealthConfig<R>>,
2212 static_assets: Arc<StaticAssetsConfig>,
2213 preflight_config: Arc<Option<PreflightConfig>>,
2214 #[cfg(feature = "http3")] alt_svc_port: Option<u16>,
2215) -> BoxHttpService
2216where
2217 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2218{
2219 BoxService::new(move |req: Request<Incoming>| {
2220 let routes = routes.clone();
2221 let fallback = fallback.clone();
2222 let resources = resources.clone();
2223 let layers = layers.clone();
2224 let health = health.clone();
2225 let static_assets = static_assets.clone();
2226 let preflight_config = preflight_config.clone();
2227
2228 async move {
2229 let mut req = req;
2230 let method = req.method().clone();
2231 let path = req.uri().path().to_string();
2232
2233 if let Some(response) =
2234 maybe_handle_health_request(&method, &path, &health, resources.clone()).await
2235 {
2236 return Ok::<_, Infallible>(response.into_response());
2237 }
2238
2239 if method == Method::OPTIONS {
2241 if let Some(ref config) = *preflight_config {
2242 let origin = req
2243 .headers()
2244 .get("origin")
2245 .and_then(|v| v.to_str().ok())
2246 .unwrap_or("");
2247 let is_wildcard = config.allowed_origins.iter().any(|o| o == "*");
2248 let is_allowed = is_wildcard
2249 || config.allowed_origins.iter().any(|o| o == origin);
2250
2251 if is_allowed || origin.is_empty() {
2252 let allow_origin = if is_wildcard {
2253 "*".to_string()
2254 } else {
2255 origin.to_string()
2256 };
2257 let mut response = Response::builder()
2258 .status(StatusCode::NO_CONTENT)
2259 .body(
2260 Full::new(Bytes::new())
2261 .map_err(|never| match never {})
2262 .boxed(),
2263 )
2264 .expect("valid preflight response");
2265 let headers = response.headers_mut();
2266 if let Ok(v) = allow_origin.parse() {
2267 headers.insert("access-control-allow-origin", v);
2268 }
2269 if let Ok(v) = config.allowed_methods.parse() {
2270 headers.insert("access-control-allow-methods", v);
2271 }
2272 if let Ok(v) = config.allowed_headers.parse() {
2273 headers.insert("access-control-allow-headers", v);
2274 }
2275 if let Ok(v) = config.max_age.parse() {
2276 headers.insert("access-control-max-age", v);
2277 }
2278 if config.allow_credentials {
2279 headers.insert(
2280 "access-control-allow-credentials",
2281 "true".parse().expect("valid header value"),
2282 );
2283 }
2284 return Ok(response);
2285 }
2286 }
2287 }
2288
2289 if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
2290 req.extensions_mut().insert(params);
2291 let effective_layers = if entry.apply_global_layers {
2292 merge_layers(&layers, &entry.layers)
2293 } else {
2294 entry.layers.clone()
2295 };
2296
2297 if effective_layers.is_empty() {
2298 let (mut parts, body) = req.into_parts();
2299 if entry.needs_body {
2300 match BodyExt::collect(body).await {
2301 Ok(collected) => { parts.extensions.insert(BodyBytes(collected.to_bytes())); }
2302 Err(_) => {
2303 return Ok(json_error_response(
2304 StatusCode::BAD_REQUEST,
2305 "Failed to read request body",
2306 ));
2307 }
2308 }
2309 }
2310 #[allow(unused_mut)]
2311 let mut res = (entry.handler)(parts, &resources).await;
2312 #[cfg(feature = "http3")]
2313 if let Some(port) = alt_svc_port {
2314 if let Ok(val) =
2315 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
2316 {
2317 res.headers_mut().insert(http::header::ALT_SVC, val);
2318 }
2319 }
2320 Ok::<_, Infallible>(res)
2321 } else {
2322 let route_service = build_route_service(
2323 entry.handler.clone(),
2324 resources.clone(),
2325 effective_layers,
2326 entry.needs_body,
2327 );
2328 #[allow(unused_mut)]
2329 let mut res = route_service.call(req).await;
2330 #[cfg(feature = "http3")]
2331 #[allow(irrefutable_let_patterns)]
2332 if let Ok(ref mut r) = res {
2333 if let Some(port) = alt_svc_port {
2334 if let Ok(val) =
2335 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
2336 {
2337 r.headers_mut().insert(http::header::ALT_SVC, val);
2338 }
2339 }
2340 }
2341 res
2342 }
2343 } else {
2344 let req =
2345 match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
2346 .await
2347 {
2348 Ok(req) => req,
2349 Err(response) => return Ok(response),
2350 };
2351
2352 #[allow(unused_mut)]
2353 let mut fallback_res = if let Some(ref fb) = fallback {
2354 if layers.is_empty() {
2355 let (parts, _) = req.into_parts();
2356 Ok(fb(parts, &resources).await)
2357 } else {
2358 let fallback_service =
2359 build_route_service(fb.clone(), resources.clone(), layers.clone(), false);
2360 fallback_service.call(req).await
2361 }
2362 } else {
2363 Ok(Response::builder()
2364 .status(StatusCode::NOT_FOUND)
2365 .body(
2366 Full::new(Bytes::from("Not Found"))
2367 .map_err(|never| match never {})
2368 .boxed(),
2369 )
2370 .expect("valid HTTP response construction"))
2371 };
2372
2373 #[cfg(feature = "http3")]
2374 if let Ok(r) = fallback_res.as_mut() {
2375 if let Some(port) = alt_svc_port {
2376 if let Ok(val) =
2377 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
2378 {
2379 r.headers_mut().insert(http::header::ALT_SVC, val);
2380 }
2381 }
2382 }
2383
2384 fallback_res
2385 }
2386 }
2387 })
2388}
2389
2390fn build_route_service<R>(
2391 handler: RouteHandler<R>,
2392 resources: Arc<R>,
2393 layers: Arc<Vec<ServiceLayer>>,
2394 needs_body: bool,
2395) -> BoxHttpService
2396where
2397 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2398{
2399 let mut service = BoxService::new(move |req: Request<Incoming>| {
2400 let handler = handler.clone();
2401 let resources = resources.clone();
2402 async move {
2403 let (mut parts, body) = req.into_parts();
2404 if needs_body {
2405 match BodyExt::collect(body).await {
2406 Ok(collected) => { parts.extensions.insert(BodyBytes(collected.to_bytes())); }
2407 Err(_) => {
2408 return Ok(json_error_response(
2409 StatusCode::BAD_REQUEST,
2410 "Failed to read request body",
2411 ));
2412 }
2413 }
2414 }
2415 Ok::<_, Infallible>(handler(parts, &resources).await)
2416 }
2417 });
2418
2419 for layer in layers.iter() {
2420 service = layer(service);
2421 }
2422 service
2423}
2424
2425fn merge_layers(
2426 global_layers: &Arc<Vec<ServiceLayer>>,
2427 route_layers: &Arc<Vec<ServiceLayer>>,
2428) -> Arc<Vec<ServiceLayer>> {
2429 if global_layers.is_empty() {
2430 return route_layers.clone();
2431 }
2432 if route_layers.is_empty() {
2433 return global_layers.clone();
2434 }
2435
2436 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
2437 combined.extend(global_layers.iter().cloned());
2438 combined.extend(route_layers.iter().cloned());
2439 Arc::new(combined)
2440}
2441
2442async fn maybe_handle_health_request<R>(
2443 method: &Method,
2444 path: &str,
2445 health: &HealthConfig<R>,
2446 resources: Arc<R>,
2447) -> Option<HttpResponse>
2448where
2449 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2450{
2451 if method != Method::GET {
2452 return None;
2453 }
2454
2455 if let Some(liveness_path) = health.liveness_path.as_ref() {
2456 if path == liveness_path {
2457 return Some(health_json_response("liveness", true, Vec::new()));
2458 }
2459 }
2460
2461 if let Some(readiness_path) = health.readiness_path.as_ref() {
2462 if path == readiness_path {
2463 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
2464 return Some(health_json_response("readiness", healthy, checks));
2465 }
2466 }
2467
2468 if let Some(health_path) = health.health_path.as_ref() {
2469 if path == health_path {
2470 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
2471 return Some(health_json_response("health", healthy, checks));
2472 }
2473 }
2474
2475 None
2476}
2477
2478async fn serve_single_file(file_path: &str) -> Result<Response<Full<Bytes>>, std::io::Error> {
2480 let path = std::path::Path::new(file_path);
2481 let content = tokio::fs::read(path).await?;
2482 let mime = guess_mime(file_path);
2483 let mut response = Response::new(Full::new(Bytes::from(content)));
2484 if let Ok(value) = http::HeaderValue::from_str(mime) {
2485 response
2486 .headers_mut()
2487 .insert(http::header::CONTENT_TYPE, value);
2488 }
2489 if let Ok(metadata) = tokio::fs::metadata(path).await {
2490 if let Ok(modified) = metadata.modified() {
2491 if let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH) {
2492 let etag = format!("\"{}\"", duration.as_secs());
2493 if let Ok(value) = http::HeaderValue::from_str(&etag) {
2494 response.headers_mut().insert(http::header::ETAG, value);
2495 }
2496 }
2497 }
2498 }
2499 Ok(response)
2500}
2501
2502async fn serve_static_file(
2504 directory: &str,
2505 file_subpath: &str,
2506 config: &StaticAssetsConfig,
2507 if_none_match: Option<&http::HeaderValue>,
2508) -> Result<Response<Full<Bytes>>, std::io::Error> {
2509 let subpath = file_subpath.trim_start_matches('/');
2510
2511 let resolved_subpath;
2513 if subpath.is_empty() || subpath.ends_with('/') {
2514 if let Some(ref index) = config.directory_index {
2515 resolved_subpath = if subpath.is_empty() {
2516 index.clone()
2517 } else {
2518 format!("{}{}", subpath, index)
2519 };
2520 } else {
2521 return Err(std::io::Error::new(
2522 std::io::ErrorKind::NotFound,
2523 "empty path",
2524 ));
2525 }
2526 } else {
2527 resolved_subpath = subpath.to_string();
2528 }
2529
2530 let full_path = std::path::Path::new(directory).join(&resolved_subpath);
2531 let canonical = tokio::fs::canonicalize(&full_path).await?;
2533 let dir_canonical = tokio::fs::canonicalize(directory).await?;
2534 if !canonical.starts_with(&dir_canonical) {
2535 return Err(std::io::Error::new(
2536 std::io::ErrorKind::PermissionDenied,
2537 "path traversal detected",
2538 ));
2539 }
2540
2541 let etag = if let Ok(metadata) = tokio::fs::metadata(&canonical).await {
2543 metadata
2544 .modified()
2545 .ok()
2546 .and_then(|m| m.duration_since(std::time::UNIX_EPOCH).ok())
2547 .map(|d| format!("\"{}\"", d.as_secs()))
2548 } else {
2549 None
2550 };
2551
2552 if let (Some(client_etag), Some(server_etag)) = (if_none_match, &etag) {
2554 if client_etag.as_bytes() == server_etag.as_bytes() {
2555 let mut response = Response::new(Full::new(Bytes::new()));
2556 *response.status_mut() = StatusCode::NOT_MODIFIED;
2557 if let Ok(value) = http::HeaderValue::from_str(server_etag) {
2558 response.headers_mut().insert(http::header::ETAG, value);
2559 }
2560 return Ok(response);
2561 }
2562 }
2563
2564 let content = tokio::fs::read(&canonical).await?;
2565 let mime = guess_mime(canonical.to_str().unwrap_or(""));
2566 let mut response = Response::new(Full::new(Bytes::from(content)));
2567 if let Ok(value) = http::HeaderValue::from_str(mime) {
2568 response
2569 .headers_mut()
2570 .insert(http::header::CONTENT_TYPE, value);
2571 }
2572 if let Some(ref etag_val) = etag {
2573 if let Ok(value) = http::HeaderValue::from_str(etag_val) {
2574 response.headers_mut().insert(http::header::ETAG, value);
2575 }
2576 }
2577
2578 if config.immutable_cache {
2580 if let Some(filename) = canonical.file_name().and_then(|n| n.to_str()) {
2581 if is_hashed_filename(filename) {
2582 if let Ok(value) = http::HeaderValue::from_str(
2583 "public, max-age=31536000, immutable",
2584 ) {
2585 response
2586 .headers_mut()
2587 .insert(http::header::CACHE_CONTROL, value);
2588 }
2589 }
2590 }
2591 }
2592
2593 Ok(response)
2594}
2595
/// Reports whether `filename` looks like `name.<hash>.ext`.
///
/// The segment between the last two dots must be at least six characters long and
/// consist solely of ASCII hexadecimal digits. Names with fewer than two dots never
/// qualify.
fn is_hashed_filename(filename: &str) -> bool {
    let Some((stem, _extension)) = filename.rsplit_once('.') else {
        return false;
    };
    let Some((_name, hash)) = stem.rsplit_once('.') else {
        return false;
    };
    hash.len() >= 6 && hash.bytes().all(|byte| byte.is_ascii_hexdigit())
}
2607
/// Maps a file path to a MIME type based on its extension.
///
/// The extension is the text after the last `.` of the final path component and is
/// matched case-insensitively, so `LOGO.PNG` and `logo.png` resolve identically.
/// Paths whose final component has no `.` at all, and unknown extensions, yield
/// `application/octet-stream`.
fn guess_mime(path: &str) -> &'static str {
    const FALLBACK: &str = "application/octet-stream";

    // Look only at the final component so dots in directory names are ignored.
    let file_name = path.rsplit(['/', '\\']).next().unwrap_or(path);
    let Some((_, extension)) = file_name.rsplit_once('.') else {
        return FALLBACK;
    };

    match extension.to_ascii_lowercase().as_str() {
        "html" | "htm" => "text/html; charset=utf-8",
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" | "ts" | "tsx" => "application/javascript; charset=utf-8",
        "json" => "application/json; charset=utf-8",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "avif" => "image/avif",
        "webp" => "image/webp",
        "webm" => "video/webm",
        "mp4" => "video/mp4",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "txt" => "text/plain; charset=utf-8",
        "xml" => "application/xml; charset=utf-8",
        "yaml" | "yml" => "application/yaml",
        "wasm" => "application/wasm",
        "pdf" => "application/pdf",
        "map" => "application/json",
        _ => FALLBACK,
    }
}
2635
2636fn apply_cache_control(
2637 mut response: Response<Full<Bytes>>,
2638 cache_control: Option<&str>,
2639) -> Response<Full<Bytes>> {
2640 if response.status() == StatusCode::OK {
2641 if let Some(value) = cache_control {
2642 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
2643 if let Ok(header_value) = http::HeaderValue::from_str(value) {
2644 response
2645 .headers_mut()
2646 .insert(http::header::CACHE_CONTROL, header_value);
2647 }
2648 }
2649 }
2650 }
2651 response
2652}
2653
2654async fn maybe_handle_static_request(
2655 req: Request<Incoming>,
2656 method: &Method,
2657 path: &str,
2658 static_assets: &StaticAssetsConfig,
2659) -> Result<Request<Incoming>, HttpResponse> {
2660 if method != Method::GET && method != Method::HEAD {
2661 return Ok(req);
2662 }
2663
2664 if let Some(mount) = static_assets
2665 .mounts
2666 .iter()
2667 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
2668 {
2669 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
2670 let if_none_match = req.headers().get(http::header::IF_NONE_MATCH).cloned();
2671 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
2672 return Ok(req);
2673 };
2674 let response = match serve_static_file(
2675 &mount.directory,
2676 &stripped_path,
2677 static_assets,
2678 if_none_match.as_ref(),
2679 )
2680 .await
2681 {
2682 Ok(response) => response,
2683 Err(_) => {
2684 return Err(Response::builder()
2685 .status(StatusCode::INTERNAL_SERVER_ERROR)
2686 .body(
2687 Full::new(Bytes::from("Failed to serve static asset"))
2688 .map_err(|never| match never {})
2689 .boxed(),
2690 )
2691 .unwrap_or_else(|_| {
2692 Response::new(
2693 Full::new(Bytes::new())
2694 .map_err(|never| match never {})
2695 .boxed(),
2696 )
2697 }));
2698 }
2699 };
2700 let mut response = apply_cache_control(response, static_assets.cache_control.as_deref());
2701 response = maybe_compress_static_response(
2702 response,
2703 accept_encoding,
2704 static_assets.enable_compression,
2705 );
2706 let (parts, body) = response.into_parts();
2707 return Err(Response::from_parts(
2708 parts,
2709 body.map_err(|never| match never {}).boxed(),
2710 ));
2711 }
2712
2713 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
2714 if looks_like_spa_request(path) {
2715 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
2716 let response = match serve_single_file(spa_file).await {
2717 Ok(response) => response,
2718 Err(_) => {
2719 return Err(Response::builder()
2720 .status(StatusCode::INTERNAL_SERVER_ERROR)
2721 .body(
2722 Full::new(Bytes::from("Failed to serve SPA fallback"))
2723 .map_err(|never| match never {})
2724 .boxed(),
2725 )
2726 .unwrap_or_else(|_| {
2727 Response::new(
2728 Full::new(Bytes::new())
2729 .map_err(|never| match never {})
2730 .boxed(),
2731 )
2732 }));
2733 }
2734 };
2735 let mut response =
2736 apply_cache_control(response, static_assets.cache_control.as_deref());
2737 response = maybe_compress_static_response(
2738 response,
2739 accept_encoding,
2740 static_assets.enable_compression,
2741 );
2742 let (parts, body) = response.into_parts();
2743 return Err(Response::from_parts(
2744 parts,
2745 body.map_err(|never| match never {}).boxed(),
2746 ));
2747 }
2748 }
2749
2750 Ok(req)
2751}
2752
/// Removes a mount prefix from a request path.
///
/// Returns the remainder with a leading `/`, or `None` when `path` is not under
/// `prefix`. Trailing slashes on the prefix are ignored, the root prefix `/` matches
/// every path unchanged, and a path equal to the prefix maps to `/`. Matching is done
/// on whole segments: `/static` does not match `/staticfoo`.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    if prefix == "/" {
        return Some(path.to_string());
    }

    let mount = prefix.trim_end_matches('/');
    let remainder = path.strip_prefix(mount)?;
    if remainder.is_empty() {
        return Some("/".to_string());
    }

    // The prefix must end on a segment boundary; the remainder already starts with '/'.
    remainder
        .starts_with('/')
        .then(|| remainder.to_string())
}
2772
/// Heuristic for "this is an application route, not a file": the last path segment
/// (everything after the final `/`) contains no `.`.
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = path.rsplit_once('/').map_or(path, |(_, segment)| segment);
    last_segment.find('.').is_none()
}
2777
2778fn maybe_compress_static_response(
2779 response: Response<Full<Bytes>>,
2780 accept_encoding: Option<http::HeaderValue>,
2781 enable_compression: bool,
2782) -> Response<Full<Bytes>> {
2783 if !enable_compression {
2784 return response;
2785 }
2786
2787 let Some(accept_encoding) = accept_encoding else {
2788 return response;
2789 };
2790
2791 let accept_str = accept_encoding.to_str().unwrap_or("");
2792 if !accept_str.contains("gzip") {
2793 return response;
2794 }
2795
2796 let status = response.status();
2797 let headers = response.headers().clone();
2798 let body = response.into_body();
2799
2800 let data = futures_util::FutureExt::now_or_never(BodyExt::collect(body))
2802 .and_then(|r| r.ok())
2803 .map(|collected| collected.to_bytes())
2804 .unwrap_or_default();
2805
2806 let compressed = {
2808 use flate2::write::GzEncoder;
2809 use flate2::Compression;
2810 use std::io::Write;
2811 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
2812 let _ = encoder.write_all(&data);
2813 encoder.finish().unwrap_or_default()
2814 };
2815
2816 let mut builder = Response::builder().status(status);
2817 for (name, value) in headers.iter() {
2818 if name != http::header::CONTENT_LENGTH && name != http::header::CONTENT_ENCODING {
2819 builder = builder.header(name, value);
2820 }
2821 }
2822 builder
2823 .header(http::header::CONTENT_ENCODING, "gzip")
2824 .body(Full::new(Bytes::from(compressed)))
2825 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())))
2826}
2827
2828async fn run_named_health_checks<R>(
2829 checks: &[NamedHealthCheck<R>],
2830 resources: Arc<R>,
2831) -> (bool, Vec<HealthCheckReport>)
2832where
2833 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2834{
2835 let mut reports = Vec::with_capacity(checks.len());
2836 let mut healthy = true;
2837
2838 for check in checks {
2839 match (check.check)(resources.clone()).await {
2840 Ok(()) => reports.push(HealthCheckReport {
2841 name: check.name.clone(),
2842 status: "ok",
2843 error: None,
2844 }),
2845 Err(error) => {
2846 healthy = false;
2847 reports.push(HealthCheckReport {
2848 name: check.name.clone(),
2849 status: "error",
2850 error: Some(error),
2851 });
2852 }
2853 }
2854 }
2855
2856 (healthy, reports)
2857}
2858
2859fn health_json_response(
2860 probe: &'static str,
2861 healthy: bool,
2862 checks: Vec<HealthCheckReport>,
2863) -> HttpResponse {
2864 let status_code = if healthy {
2865 StatusCode::OK
2866 } else {
2867 StatusCode::SERVICE_UNAVAILABLE
2868 };
2869 let status = if healthy { "ok" } else { "degraded" };
2870 let payload = HealthReport {
2871 status,
2872 probe,
2873 checks,
2874 };
2875
2876 let body = serde_json::to_vec(&payload)
2877 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
2878
2879 Response::builder()
2880 .status(status_code)
2881 .header(http::header::CONTENT_TYPE, "application/json")
2882 .body(
2883 Full::new(Bytes::from(body))
2884 .map_err(|never| match never {})
2885 .boxed(),
2886 )
2887 .expect("valid HTTP response construction")
2888}
2889
2890async fn shutdown_signal() {
2891 #[cfg(unix)]
2892 {
2893 use tokio::signal::unix::{SignalKind, signal};
2894
2895 match signal(SignalKind::terminate()) {
2896 Ok(mut terminate) => {
2897 tokio::select! {
2898 _ = tokio::signal::ctrl_c() => {}
2899 _ = terminate.recv() => {}
2900 }
2901 }
2902 Err(err) => {
2903 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
2904 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
2905 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
2906 }
2907 }
2908 }
2909 }
2910
2911 #[cfg(not(unix))]
2912 {
2913 if let Err(err) = tokio::signal::ctrl_c().await {
2914 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
2915 }
2916 }
2917}
2918
2919async fn drain_connections(
2920 connections: &mut tokio::task::JoinSet<()>,
2921 graceful_shutdown_timeout: Duration,
2922) -> bool {
2923 if connections.is_empty() {
2924 return false;
2925 }
2926
2927 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
2928 while let Some(join_result) = connections.join_next().await {
2929 if let Err(err) = join_result {
2930 tracing::warn!("Connection task join error during shutdown: {:?}", err);
2931 }
2932 }
2933 })
2934 .await;
2935
2936 if drain_result.is_err() {
2937 tracing::warn!(
2938 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
2939 graceful_shutdown_timeout
2940 );
2941 connections.abort_all();
2942 while let Some(join_result) = connections.join_next().await {
2943 if let Err(err) = join_result {
2944 tracing::warn!("Connection task abort join error: {:?}", err);
2945 }
2946 }
2947 true
2948 } else {
2949 false
2950 }
2951}
2952
/// Creates a TLS acceptor from PEM-encoded certificate and private key files.
///
/// Both files are opened before either is parsed. The server is configured without
/// client authentication and with the single certificate chain found in `cert_path`.
///
/// # Errors
/// Returns a descriptive error when a file cannot be opened, the PEM data cannot be
/// parsed, the key file contains no private key, or rustls rejects the
/// certificate/key pair.
#[cfg(feature = "tls")]
fn build_tls_acceptor(
    cert_path: &str,
    key_path: &str,
) -> Result<tokio_rustls::TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    use rustls::ServerConfig;
    use rustls_pemfile::{certs, private_key};
    use std::io::BufReader;
    use tokio_rustls::TlsAcceptor;

    // `label` is "certificate" or "key" and only affects the error text.
    let open_pem = |label: &str, path: &str| {
        std::fs::File::open(path)
            .map(BufReader::new)
            .map_err(|e| format!("Failed to open {} file '{}': {}", label, path, e))
    };
    let mut cert_reader = open_pem("certificate", cert_path)?;
    let mut key_reader = open_pem("key", key_path)?;

    let mut cert_chain = Vec::new();
    for parsed in certs(&mut cert_reader) {
        let certificate =
            parsed.map_err(|e| format!("Failed to parse certificate PEM: {}", e))?;
        cert_chain.push(certificate);
    }

    let key = match private_key(&mut key_reader) {
        Ok(Some(found)) => found,
        Ok(None) => return Err("No private key found in key file".into()),
        Err(e) => return Err(format!("Failed to parse private key PEM: {}", e).into()),
    };

    let server_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, key)
        .map_err(|e| format!("TLS configuration error: {}", e))?;

    Ok(TlsAcceptor::from(Arc::new(server_config)))
}
2984
2985impl<R> Default for HttpIngress<R>
2986where
2987 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2988{
2989 fn default() -> Self {
2990 Self::new()
2991 }
2992}
2993
/// Cloneable `hyper` service built from an ingress configuration.
///
/// Every field is reference-counted, so clones are cheap. The
/// `hyper::service::Service` implementation clones the fields on each request and
/// passes them to `build_http_service`.
#[derive(Clone)]
pub struct RawIngressService<R> {
    /// Registered route table.
    routes: Arc<Vec<RouteEntry<R>>>,
    /// Optional handler for requests no route matched.
    fallback: Option<RouteHandler<R>>,
    /// Layer stack passed to `build_http_service` — presumably the global middleware;
    /// confirm against that function.
    layers: Arc<Vec<ServiceLayer>>,
    /// Health probe configuration (paths and checks).
    health: Arc<HealthConfig<R>>,
    /// Static asset configuration (mounts, SPA fallback, caching, compression).
    static_assets: Arc<StaticAssetsConfig>,
    /// Optional preflight configuration — NOTE(review): looks CORS-related; verify.
    preflight_config: Arc<Option<PreflightConfig>>,
    /// Shared application resources handed to handlers.
    resources: Arc<R>,
}
3005
3006impl<R> hyper::service::Service<Request<Incoming>> for RawIngressService<R>
3007where
3008 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3009{
3010 type Response = HttpResponse;
3011 type Error = Infallible;
3012 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
3013
3014 fn call(&self, req: Request<Incoming>) -> Self::Future {
3015 let routes = self.routes.clone();
3016 let fallback = self.fallback.clone();
3017 let layers = self.layers.clone();
3018 let health = self.health.clone();
3019 let static_assets = self.static_assets.clone();
3020 let preflight_config = self.preflight_config.clone();
3021 let resources = self.resources.clone();
3022
3023 Box::pin(async move {
3024 let service = build_http_service(
3025 routes,
3026 fallback,
3027 resources,
3028 layers,
3029 health,
3030 static_assets,
3031 preflight_config,
3032 #[cfg(feature = "http3")]
3033 None,
3034 );
3035 service.call(req).await
3036 })
3037 }
3038}
3039
3040#[cfg(test)]
3041mod tests {
3042 use super::*;
3043 use async_trait::async_trait;
3044 use futures_util::{SinkExt, StreamExt};
3045 use serde::Deserialize;
3046 use std::fs;
3047 use std::sync::atomic::{AtomicBool, Ordering};
3048 use tempfile::tempdir;
3049 use tokio::io::{AsyncReadExt, AsyncWriteExt};
3050 use tokio_tungstenite::tungstenite::Message as WsClientMessage;
3051 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
3052
3053 async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
3054 let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
3055
3056 loop {
3057 match tokio::net::TcpStream::connect(addr).await {
3058 Ok(stream) => return stream,
3059 Err(error) => {
3060 if tokio::time::Instant::now() >= deadline {
3061 panic!("connect server: {error}");
3062 }
3063 tokio::time::sleep(Duration::from_millis(25)).await;
3064 }
3065 }
3066 }
3067 }
3068
/// A pattern without parameters matches its exact path and captures nothing.
#[test]
fn route_pattern_matches_static_path() {
    let literal = RoutePattern::parse("/orders/list");
    let captured = literal.match_path("/orders/list").expect("should match");
    assert!(captured.into_inner().is_empty());
}
3075
/// `:name` segments capture the corresponding path segment under that name.
#[test]
fn route_pattern_matches_param_segments() {
    let with_params = RoutePattern::parse("/orders/:id/items/:item_id");
    let captured = with_params
        .match_path("/orders/42/items/sku-123")
        .expect("should match");

    assert_eq!(captured.get("id"), Some("42"));
    assert_eq!(captured.get("item_id"), Some("sku-123"));
}
3085
/// A trailing `*name` segment captures the rest of the path, slashes included.
#[test]
fn route_pattern_matches_wildcard_segment() {
    let wildcard = RoutePattern::parse("/assets/*path");
    let captured = wildcard
        .match_path("/assets/css/theme/light.css")
        .expect("should match");

    assert_eq!(captured.get("path"), Some("css/theme/light.css"));
}
3094
/// A path with a different literal segment does not match.
#[test]
fn route_pattern_rejects_non_matching_path() {
    let orders = RoutePattern::parse("/orders/:id");
    let outcome = orders.match_path("/users/42");
    assert!(outcome.is_none());
}
3100
/// A freshly constructed ingress has a 30 s drain timeout and no layers, bus
/// injectors, static mounts, or lifecycle hooks.
#[test]
fn graceful_shutdown_timeout_defaults_to_30_seconds() {
    let fresh = HttpIngress::<()>::new();

    assert_eq!(fresh.graceful_shutdown_timeout, Duration::from_secs(30));
    assert!(fresh.layers.is_empty());
    assert!(fresh.bus_injectors.is_empty());
    assert!(fresh.static_assets.mounts.is_empty());
    assert!(fresh.on_start.is_none());
    assert!(fresh.on_shutdown.is_none());
}
3111
/// Registering a plain route adds one entry with no route-level layers and with
/// global layers still enabled.
#[test]
fn route_without_layer_keeps_empty_route_middleware_stack() {
    let ping = Axon::<(), (), String, ()>::new("Ping");
    let ingress = HttpIngress::<()>::new().get("/ping", ping);

    assert_eq!(ingress.routes.len(), 1);
    let route = &ingress.routes[0];
    assert!(route.layers.is_empty());
    assert!(route.apply_global_layers);
}
3120
/// `timeout_layer` adds exactly one global layer.
#[test]
fn timeout_layer_registers_builtin_middleware() {
    let with_timeout = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
    let layer_count = with_timeout.layers.len();
    assert_eq!(layer_count, 1);
}
3126
/// `request_id_layer` adds exactly one global layer.
#[test]
fn request_id_layer_registers_builtin_middleware() {
    let with_request_id = HttpIngress::<()>::new().request_id_layer();
    let layer_count = with_request_id.layers.len();
    assert_eq!(layer_count, 1);
}
3132
/// `compression_layer` switches on compression in the static asset configuration.
#[test]
fn compression_layer_registers_builtin_middleware() {
    let compressed = HttpIngress::<()>::new().compression_layer();
    let enabled = compressed.static_assets.enable_compression;
    assert!(enabled);
}
3138
/// Registering a bus injector stores exactly one hook.
#[test]
fn bus_injector_registration_adds_hook() {
    let with_injector = HttpIngress::<()>::new().bus_injector(|_req, bus| {
        bus.insert("ok".to_string());
    });

    let hook_count = with_injector.bus_injectors.len();
    assert_eq!(hook_count, 1);
}
3146
/// A websocket route is stored as a single `GET` route with the given pattern.
#[test]
fn ws_route_registers_get_route_pattern() {
    let with_ws =
        HttpIngress::<()>::new().ws("/ws/events", |_socket, _resources, _bus| async {});

    assert_eq!(with_ws.routes.len(), 1);
    let route = &with_ws.routes[0];
    assert_eq!(route.method, Method::GET);
    assert_eq!(route.pattern.raw, "/ws/events");
}
3155
/// Shape of the JSON welcome frame sent by the echo handler in the websocket
/// end-to-end test.
#[derive(Debug, Deserialize)]
struct WsWelcomeFrame {
    /// Connection id reported by the session context; the test only checks that it
    /// is non-empty.
    connection_id: String,
    /// Request path reported by the session context.
    path: String,
    /// Tenant string read from the connection bus.
    tenant: String,
}
3162
3163 #[tokio::test]
3164 async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
3165 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
3166 let addr = probe.local_addr().expect("local addr");
3167 drop(probe);
3168
3169 let ingress = HttpIngress::<()>::new()
3170 .bind(addr.to_string())
3171 .bus_injector(|req, bus| {
3172 if let Some(value) = req.headers.get("x-tenant-id").and_then(|v| v.to_str().ok()) {
3173 bus.insert(value.to_string());
3174 }
3175 })
3176 .ws("/ws/echo", |mut socket, _resources, bus| async move {
3177 let tenant = bus
3178 .read::<String>()
3179 .cloned()
3180 .unwrap_or_else(|| "unknown".to_string());
3181 if let Some(session) = bus.read::<WebSocketSessionContext>() {
3182 let welcome = serde_json::json!({
3183 "connection_id": session.connection_id().to_string(),
3184 "path": session.path(),
3185 "tenant": tenant,
3186 });
3187 let _ = socket.send_json(&welcome).await;
3188 }
3189
3190 while let Some(event) = socket.next_event().await {
3191 match event {
3192 WebSocketEvent::Text(text) => {
3193 let _ = socket.send_event(format!("echo:{text}")).await;
3194 }
3195 WebSocketEvent::Binary(bytes) => {
3196 let _ = socket.send_event(bytes).await;
3197 }
3198 WebSocketEvent::Close => break,
3199 WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
3200 }
3201 }
3202 });
3203
3204 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
3205 let server = tokio::spawn(async move {
3206 ingress
3207 .run_with_shutdown_signal((), async move {
3208 let _ = shutdown_rx.await;
3209 })
3210 .await
3211 });
3212
3213 let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
3214 let mut ws_request = ws_uri
3215 .as_str()
3216 .into_client_request()
3217 .expect("ws client request");
3218 ws_request
3219 .headers_mut()
3220 .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
3221 let (mut client, _response) = tokio_tungstenite::connect_async(ws_request)
3222 .await
3223 .expect("websocket connect");
3224
3225 let welcome = client
3226 .next()
3227 .await
3228 .expect("welcome frame")
3229 .expect("welcome frame ok");
3230 let welcome_text = match welcome {
3231 WsClientMessage::Text(text) => text.to_string(),
3232 other => panic!("expected text welcome frame, got {other:?}"),
3233 };
3234 let welcome_payload: WsWelcomeFrame =
3235 serde_json::from_str(&welcome_text).expect("welcome json");
3236 assert_eq!(welcome_payload.path, "/ws/echo");
3237 assert_eq!(welcome_payload.tenant, "acme");
3238 assert!(!welcome_payload.connection_id.is_empty());
3239
3240 client
3241 .send(WsClientMessage::Text("hello".into()))
3242 .await
3243 .expect("send text");
3244 let echo_text = client
3245 .next()
3246 .await
3247 .expect("echo text frame")
3248 .expect("echo text frame ok");
3249 assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));
3250
3251 client
3252 .send(WsClientMessage::Binary(vec![1, 2, 3, 4].into()))
3253 .await
3254 .expect("send binary");
3255 let echo_binary = client
3256 .next()
3257 .await
3258 .expect("echo binary frame")
3259 .expect("echo binary frame ok");
3260 assert_eq!(
3261 echo_binary,
3262 WsClientMessage::Binary(vec![1, 2, 3, 4].into())
3263 );
3264
3265 client.close(None).await.expect("close websocket");
3266
3267 let _ = shutdown_tx.send(());
3268 server
3269 .await
3270 .expect("server join")
3271 .expect("server shutdown should succeed");
3272 }
3273
/// Route descriptors list the registered HTTP route as well as the health,
/// readiness and liveness endpoints, all as `GET`.
#[test]
fn route_descriptors_export_http_and_health_paths() {
    let ingress = HttpIngress::<()>::new()
        .get(
            "/orders/:id",
            Axon::<(), (), String, ()>::new("OrderById"),
        )
        .health_endpoint("/healthz")
        .readiness_liveness("/readyz", "/livez");

    let descriptors = ingress.route_descriptors();

    for expected in ["/orders/:id", "/healthz", "/readyz", "/livez"] {
        assert!(
            descriptors
                .iter()
                .any(|descriptor| descriptor.method() == Method::GET
                    && descriptor.path_pattern() == expected)
        );
    }
}
3311
3312 #[tokio::test]
3313 async fn lifecycle_hooks_fire_on_start_and_shutdown() {
3314 let started = Arc::new(AtomicBool::new(false));
3315 let shutdown = Arc::new(AtomicBool::new(false));
3316
3317 let started_flag = started.clone();
3318 let shutdown_flag = shutdown.clone();
3319
3320 let ingress = HttpIngress::<()>::new()
3321 .bind("127.0.0.1:0")
3322 .on_start(move || {
3323 started_flag.store(true, Ordering::SeqCst);
3324 })
3325 .on_shutdown(move || {
3326 shutdown_flag.store(true, Ordering::SeqCst);
3327 })
3328 .graceful_shutdown(Duration::from_millis(50));
3329
3330 ingress
3331 .run_with_shutdown_signal((), async {
3332 tokio::time::sleep(Duration::from_millis(20)).await;
3333 })
3334 .await
3335 .expect("server should exit gracefully");
3336
3337 assert!(started.load(Ordering::SeqCst));
3338 assert!(shutdown.load(Ordering::SeqCst));
3339 }
3340
3341 #[tokio::test]
3342 async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
3343 #[derive(Clone)]
3344 struct SlowDrainRoute;
3345
3346 #[async_trait]
3347 impl Transition<(), String> for SlowDrainRoute {
3348 type Error = String;
3349 type Resources = ();
3350
3351 async fn run(
3352 &self,
3353 _state: (),
3354 _resources: &Self::Resources,
3355 _bus: &mut Bus,
3356 ) -> Outcome<String, Self::Error> {
3357 tokio::time::sleep(Duration::from_millis(120)).await;
3358 Outcome::next("drained-ok".to_string())
3359 }
3360 }
3361
3362 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
3363 let addr = probe.local_addr().expect("local addr");
3364 drop(probe);
3365
3366 let ingress = HttpIngress::<()>::new()
3367 .bind(addr.to_string())
3368 .graceful_shutdown(Duration::from_millis(500))
3369 .get(
3370 "/drain",
3371 Axon::<(), (), String, ()>::new("SlowDrain").then(SlowDrainRoute),
3372 );
3373
3374 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
3375 let server = tokio::spawn(async move {
3376 ingress
3377 .run_with_shutdown_signal((), async move {
3378 let _ = shutdown_rx.await;
3379 })
3380 .await
3381 });
3382
3383 let mut stream = connect_with_retry(addr).await;
3384 stream
3385 .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
3386 .await
3387 .expect("write request");
3388
3389 tokio::time::sleep(Duration::from_millis(20)).await;
3390 let _ = shutdown_tx.send(());
3391
3392 let mut buf = Vec::new();
3393 stream.read_to_end(&mut buf).await.expect("read response");
3394 let response = String::from_utf8_lossy(&buf);
3395 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
3396 assert!(response.contains("drained-ok"), "{response}");
3397
3398 server
3399 .await
3400 .expect("server join")
3401 .expect("server shutdown should succeed");
3402 }
3403
3404 #[tokio::test]
3405 async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
3406 let temp = tempdir().expect("tempdir");
3407 let root = temp.path().join("public");
3408 fs::create_dir_all(&root).expect("create dir");
3409 let file = root.join("hello.txt");
3410 fs::write(&file, "hello static").expect("write file");
3411
3412 let ingress =
3413 Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
3414 let app = crate::test_harness::TestApp::new(ingress, ());
3415 let response = app
3416 .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
3417 .await
3418 .expect("request should succeed");
3419
3420 assert_eq!(response.status(), StatusCode::OK);
3421 assert_eq!(response.text().expect("utf8"), "hello static");
3422 assert!(response.header("cache-control").is_some());
3423 let has_metadata_header =
3424 response.header("etag").is_some() || response.header("last-modified").is_some();
3425 assert!(has_metadata_header);
3426 }
3427
3428 #[tokio::test]
3429 async fn spa_fallback_returns_index_for_unmatched_path() {
3430 let temp = tempdir().expect("tempdir");
3431 let index = temp.path().join("index.html");
3432 fs::write(&index, "<html><body>spa</body></html>").expect("write index");
3433
3434 let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
3435 let app = crate::test_harness::TestApp::new(ingress, ());
3436 let response = app
3437 .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
3438 .await
3439 .expect("request should succeed");
3440
3441 assert_eq!(response.status(), StatusCode::OK);
3442 assert!(response.text().expect("utf8").contains("spa"));
3443 }
3444
3445 #[tokio::test]
3446 async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
3447 let temp = tempdir().expect("tempdir");
3448 let root = temp.path().join("public");
3449 fs::create_dir_all(&root).expect("create dir");
3450 let file = root.join("compressed.txt");
3451 fs::write(&file, "compress me ".repeat(400)).expect("write file");
3452
3453 let ingress = Ranvier::http::<()>()
3454 .serve_dir("/static", root.to_string_lossy().to_string())
3455 .compression_layer();
3456 let app = crate::test_harness::TestApp::new(ingress, ());
3457 let response = app
3458 .send(
3459 crate::test_harness::TestRequest::get("/static/compressed.txt")
3460 .header("accept-encoding", "gzip"),
3461 )
3462 .await
3463 .expect("request should succeed");
3464
3465 assert_eq!(response.status(), StatusCode::OK);
3466 assert_eq!(
3467 response
3468 .header("content-encoding")
3469 .and_then(|value| value.to_str().ok()),
3470 Some("gzip")
3471 );
3472 }
3473
3474 #[tokio::test]
3475 async fn drain_connections_completes_before_timeout() {
3476 let mut connections = tokio::task::JoinSet::new();
3477 connections.spawn(async {
3478 tokio::time::sleep(Duration::from_millis(20)).await;
3479 });
3480
3481 let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
3482 assert!(!timed_out);
3483 assert!(connections.is_empty());
3484 }
3485
3486 #[tokio::test]
3487 async fn drain_connections_times_out_and_aborts() {
3488 let mut connections = tokio::task::JoinSet::new();
3489 connections.spawn(async {
3490 tokio::time::sleep(Duration::from_secs(10)).await;
3491 });
3492
3493 let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
3494 assert!(timed_out);
3495 assert!(connections.is_empty());
3496 }
3497
3498 #[tokio::test]
3499 async fn timeout_layer_returns_408_for_slow_route() {
3500 #[derive(Clone)]
3501 struct SlowRoute;
3502
3503 #[async_trait]
3504 impl Transition<(), String> for SlowRoute {
3505 type Error = String;
3506 type Resources = ();
3507
3508 async fn run(
3509 &self,
3510 _state: (),
3511 _resources: &Self::Resources,
3512 _bus: &mut Bus,
3513 ) -> Outcome<String, Self::Error> {
3514 tokio::time::sleep(Duration::from_millis(80)).await;
3515 Outcome::next("slow-ok".to_string())
3516 }
3517 }
3518
3519 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
3520 let addr = probe.local_addr().expect("local addr");
3521 drop(probe);
3522
3523 let ingress = HttpIngress::<()>::new()
3524 .bind(addr.to_string())
3525 .timeout_layer(Duration::from_millis(10))
3526 .get(
3527 "/slow",
3528 Axon::<(), (), String, ()>::new("Slow").then(SlowRoute),
3529 );
3530
3531 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
3532 let server = tokio::spawn(async move {
3533 ingress
3534 .run_with_shutdown_signal((), async move {
3535 let _ = shutdown_rx.await;
3536 })
3537 .await
3538 });
3539
3540 let mut stream = connect_with_retry(addr).await;
3541 stream
3542 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
3543 .await
3544 .expect("write request");
3545
3546 let mut buf = Vec::new();
3547 stream.read_to_end(&mut buf).await.expect("read response");
3548 let response = String::from_utf8_lossy(&buf);
3549 assert!(response.starts_with("HTTP/1.1 408"), "{response}");
3550
3551 let _ = shutdown_tx.send(());
3552 server
3553 .await
3554 .expect("server join")
3555 .expect("server shutdown should succeed");
3556 }
3557
3558}