1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode, Uri};
20use http_body::Body;
21use http_body_util::{BodyExt, Full};
22use hyper::body::Incoming;
23use hyper::server::conn::http1;
24use hyper::upgrade::Upgraded;
25use hyper_util::rt::TokioIo;
26use hyper_util::service::TowerToHyperService;
27use ranvier_core::event::{EventSink, EventSource};
28use ranvier_core::prelude::*;
29use ranvier_runtime::Axon;
30use serde::Serialize;
31use serde::de::DeserializeOwned;
32use sha1::{Digest, Sha1};
33use std::collections::HashMap;
34use std::convert::Infallible;
35use std::future::Future;
36use std::net::SocketAddr;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::time::Duration;
40use tokio::net::TcpListener;
41use tokio::sync::Mutex;
42use tokio_tungstenite::WebSocketStream;
43use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
44use tower::util::BoxCloneService;
45use tower::{Layer, Service, ServiceExt, service_fn};
46use tower_http::compression::CompressionLayer;
47use tower_http::services::{ServeDir, ServeFile};
48use tracing::Instrument;
49
50use crate::response::{HttpResponse, IntoResponse, outcome_to_response_with_error};
51
52pub struct Ranvier;
57
58impl Ranvier {
59 pub fn http<R>() -> HttpIngress<R>
61 where
62 R: ranvier_core::transition::ResourceRequirement + Clone,
63 {
64 HttpIngress::new()
65 }
66}
67
68type RouteHandler<R> = Arc<
70 dyn Fn(http::request::Parts, &R) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>>
71 + Send
72 + Sync,
73>;
74
75type BoxHttpService = BoxCloneService<Request<Incoming>, HttpResponse, Infallible>;
76type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
77type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
78type BusInjector = Arc<dyn Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static>;
79type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
80type WsSessionHandler<R> =
81 Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
82type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
83type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
84const REQUEST_ID_HEADER: &str = "x-request-id";
85const WS_UPGRADE_TOKEN: &str = "websocket";
86const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
87
88#[derive(Clone)]
89struct NamedHealthCheck<R> {
90 name: String,
91 check: HealthCheckFn<R>,
92}
93
94#[derive(Clone)]
95struct HealthConfig<R> {
96 health_path: Option<String>,
97 readiness_path: Option<String>,
98 liveness_path: Option<String>,
99 checks: Vec<NamedHealthCheck<R>>,
100}
101
102impl<R> Default for HealthConfig<R> {
103 fn default() -> Self {
104 Self {
105 health_path: None,
106 readiness_path: None,
107 liveness_path: None,
108 checks: Vec::new(),
109 }
110 }
111}
112
/// Static-file serving settings accumulated by the builder.
#[derive(Clone, Default)]
struct StaticAssetsConfig {
    /// Directory mounts, in registration order.
    mounts: Vec<StaticMount>,
    /// File to serve for unmatched paths, if configured via `spa_fallback`.
    spa_fallback: Option<String>,
    /// `Cache-Control` value; `serve_dir` fills in a default when unset.
    cache_control: Option<String>,
    /// Set by `compression_layer`.
    enable_compression: bool,
}

/// One URL prefix mapped onto a filesystem directory.
#[derive(Clone)]
struct StaticMount {
    /// Normalized to start with `/` by `serve_dir`.
    route_prefix: String,
    directory: String,
}
126
127#[derive(Serialize)]
128struct HealthReport {
129 status: &'static str,
130 probe: &'static str,
131 checks: Vec<HealthCheckReport>,
132}
133
134#[derive(Serialize)]
135struct HealthCheckReport {
136 name: String,
137 status: &'static str,
138 #[serde(skip_serializing_if = "Option::is_none")]
139 error: Option<String>,
140}
141
142#[derive(Clone)]
143struct TimeoutService {
144 inner: BoxHttpService,
145 timeout: Duration,
146}
147
148impl Service<Request<Incoming>> for TimeoutService {
149 type Response = HttpResponse;
150 type Error = Infallible;
151 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
152
153 fn poll_ready(
154 &mut self,
155 cx: &mut std::task::Context<'_>,
156 ) -> std::task::Poll<Result<(), Self::Error>> {
157 self.inner.poll_ready(cx)
158 }
159
160 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
161 let timeout = self.timeout;
162 let fut = self.inner.call(req);
163 Box::pin(async move {
164 match tokio::time::timeout(timeout, fut).await {
165 Ok(response) => response,
166 Err(_) => Ok(Response::builder()
167 .status(StatusCode::REQUEST_TIMEOUT)
168 .body(
169 Full::new(Bytes::from("Request Timeout"))
170 .map_err(|never| match never {})
171 .boxed(),
172 )
173 .unwrap()),
174 }
175 })
176 }
177}
178
179#[derive(Clone)]
180struct RequestIdService {
181 inner: BoxHttpService,
182}
183
184impl Service<Request<Incoming>> for RequestIdService {
185 type Response = HttpResponse;
186 type Error = Infallible;
187 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
188
189 fn poll_ready(
190 &mut self,
191 cx: &mut std::task::Context<'_>,
192 ) -> std::task::Poll<Result<(), Self::Error>> {
193 self.inner.poll_ready(cx)
194 }
195
196 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
197 let mut req = req;
198 let request_id = req
199 .headers()
200 .get(REQUEST_ID_HEADER)
201 .cloned()
202 .unwrap_or_else(|| {
203 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
204 .unwrap_or_else(|_| http::HeaderValue::from_static("request-id-unavailable"))
205 });
206
207 req.headers_mut()
208 .insert(REQUEST_ID_HEADER, request_id.clone());
209
210 let fut = self.inner.call(req);
211 Box::pin(async move {
212 let mut response = fut.await?;
213 response.headers_mut().insert(REQUEST_ID_HEADER, request_id);
214 Ok(response)
215 })
216 }
217}
218
219fn to_service_layer<L>(layer: L) -> ServiceLayer
220where
221 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
222 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
223 + Clone
224 + Send
225 + 'static,
226 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
227{
228 Arc::new(move |service: BoxHttpService| BoxCloneService::new(layer.clone().layer(service)))
229}
230
/// Named path parameters captured when a route pattern matched the request path.
///
/// Keys are the names from `:name` / `*name` pattern segments.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PathParams {
    values: HashMap<String, String>,
}
235
236#[derive(Clone, Debug, PartialEq, Eq)]
238pub struct HttpRouteDescriptor {
239 method: Method,
240 path_pattern: String,
241}
242
243impl HttpRouteDescriptor {
244 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
245 Self {
246 method,
247 path_pattern: path_pattern.into(),
248 }
249 }
250
251 pub fn method(&self) -> &Method {
252 &self.method
253 }
254
255 pub fn path_pattern(&self) -> &str {
256 &self.path_pattern
257 }
258}
259
260#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
262pub struct WebSocketSessionContext {
263 connection_id: uuid::Uuid,
264 path: String,
265 query: Option<String>,
266}
267
268impl WebSocketSessionContext {
269 pub fn connection_id(&self) -> uuid::Uuid {
270 self.connection_id
271 }
272
273 pub fn path(&self) -> &str {
274 &self.path
275 }
276
277 pub fn query(&self) -> Option<&str> {
278 self.query.as_deref()
279 }
280}
281
282#[derive(Clone, Debug, PartialEq, Eq)]
284pub enum WebSocketEvent {
285 Text(String),
286 Binary(Vec<u8>),
287 Ping(Vec<u8>),
288 Pong(Vec<u8>),
289 Close,
290}
291
292impl WebSocketEvent {
293 pub fn text(value: impl Into<String>) -> Self {
294 Self::Text(value.into())
295 }
296
297 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
298 Self::Binary(value.into())
299 }
300
301 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
302 where
303 T: Serialize,
304 {
305 let text = serde_json::to_string(value)?;
306 Ok(Self::Text(text))
307 }
308}
309
310#[derive(Debug, thiserror::Error)]
311pub enum WebSocketError {
312 #[error("websocket wire error: {0}")]
313 Wire(#[from] WsWireError),
314 #[error("json serialization failed: {0}")]
315 JsonSerialize(#[source] serde_json::Error),
316 #[error("json deserialization failed: {0}")]
317 JsonDeserialize(#[source] serde_json::Error),
318 #[error("expected text or binary frame for json payload")]
319 NonDataFrame,
320}
321
322type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
323type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
324type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
325
326pub struct WebSocketConnection {
328 sink: Mutex<WsServerSink>,
329 source: Mutex<WsServerSource>,
330 session: WebSocketSessionContext,
331}
332
333impl WebSocketConnection {
334 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
335 let (sink, source) = stream.split();
336 Self {
337 sink: Mutex::new(sink),
338 source: Mutex::new(source),
339 session,
340 }
341 }
342
343 pub fn session(&self) -> &WebSocketSessionContext {
344 &self.session
345 }
346
347 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
348 let mut sink = self.sink.lock().await;
349 sink.send(event.into_wire_message()).await?;
350 Ok(())
351 }
352
353 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
354 where
355 T: Serialize,
356 {
357 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
358 self.send(event).await
359 }
360
361 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
362 where
363 T: DeserializeOwned,
364 {
365 let Some(event) = self.recv_event().await? else {
366 return Ok(None);
367 };
368 match event {
369 WebSocketEvent::Text(text) => serde_json::from_str(&text)
370 .map(Some)
371 .map_err(WebSocketError::JsonDeserialize),
372 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
373 .map(Some)
374 .map_err(WebSocketError::JsonDeserialize),
375 _ => Err(WebSocketError::NonDataFrame),
376 }
377 }
378
379 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
380 let mut source = self.source.lock().await;
381 while let Some(item) = source.next().await {
382 let message = item?;
383 if let Some(event) = WebSocketEvent::from_wire_message(message) {
384 return Ok(Some(event));
385 }
386 }
387 Ok(None)
388 }
389}
390
391impl WebSocketEvent {
392 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
393 match message {
394 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
395 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
396 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
397 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
398 WsWireMessage::Close(_) => Some(Self::Close),
399 WsWireMessage::Frame(_) => None,
400 }
401 }
402
403 fn into_wire_message(self) -> WsWireMessage {
404 match self {
405 Self::Text(value) => WsWireMessage::Text(value),
406 Self::Binary(value) => WsWireMessage::Binary(value),
407 Self::Ping(value) => WsWireMessage::Ping(value),
408 Self::Pong(value) => WsWireMessage::Pong(value),
409 Self::Close => WsWireMessage::Close(None),
410 }
411 }
412}
413
414#[async_trait::async_trait]
415impl EventSource<WebSocketEvent> for WebSocketConnection {
416 async fn next_event(&mut self) -> Option<WebSocketEvent> {
417 match self.recv_event().await {
418 Ok(event) => event,
419 Err(error) => {
420 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
421 None
422 }
423 }
424 }
425}
426
427#[async_trait::async_trait]
428impl EventSink<WebSocketEvent> for WebSocketConnection {
429 type Error = WebSocketError;
430
431 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
432 self.send(event).await
433 }
434}
435
436#[async_trait::async_trait]
437impl EventSink<String> for WebSocketConnection {
438 type Error = WebSocketError;
439
440 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
441 self.send(WebSocketEvent::Text(event)).await
442 }
443}
444
445#[async_trait::async_trait]
446impl EventSink<Vec<u8>> for WebSocketConnection {
447 type Error = WebSocketError;
448
449 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
450 self.send(WebSocketEvent::Binary(event)).await
451 }
452}
453
454impl PathParams {
455 pub fn new(values: HashMap<String, String>) -> Self {
456 Self { values }
457 }
458
459 pub fn get(&self, key: &str) -> Option<&str> {
460 self.values.get(key).map(String::as_str)
461 }
462
463 pub fn as_map(&self) -> &HashMap<String, String> {
464 &self.values
465 }
466
467 pub fn into_inner(self) -> HashMap<String, String> {
468 self.values
469 }
470}
471
/// One segment of a parsed route pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum RouteSegment {
    /// Must equal the path segment exactly.
    Static(String),
    /// `:name` — captures exactly one path segment.
    Param(String),
    /// `*name` — captures the remaining segments joined by `/`.
    Wildcard(String),
}

/// A route pattern as registered (`raw`) and its parsed segments.
#[derive(Clone, Debug, PartialEq, Eq)]
struct RoutePattern {
    raw: String,
    segments: Vec<RouteSegment>,
}
484
485impl RoutePattern {
486 fn parse(path: &str) -> Self {
487 let segments = path_segments(path)
488 .into_iter()
489 .map(|segment| {
490 if let Some(name) = segment.strip_prefix(':') {
491 if !name.is_empty() {
492 return RouteSegment::Param(name.to_string());
493 }
494 }
495 if let Some(name) = segment.strip_prefix('*') {
496 if !name.is_empty() {
497 return RouteSegment::Wildcard(name.to_string());
498 }
499 }
500 RouteSegment::Static(segment.to_string())
501 })
502 .collect();
503
504 Self {
505 raw: path.to_string(),
506 segments,
507 }
508 }
509
510 fn match_path(&self, path: &str) -> Option<PathParams> {
511 let mut params = HashMap::new();
512 let path_segments = path_segments(path);
513 let mut pattern_index = 0usize;
514 let mut path_index = 0usize;
515
516 while pattern_index < self.segments.len() {
517 match &self.segments[pattern_index] {
518 RouteSegment::Static(expected) => {
519 let actual = path_segments.get(path_index)?;
520 if actual != expected {
521 return None;
522 }
523 pattern_index += 1;
524 path_index += 1;
525 }
526 RouteSegment::Param(name) => {
527 let actual = path_segments.get(path_index)?;
528 params.insert(name.clone(), (*actual).to_string());
529 pattern_index += 1;
530 path_index += 1;
531 }
532 RouteSegment::Wildcard(name) => {
533 let remaining = path_segments[path_index..].join("/");
534 params.insert(name.clone(), remaining);
535 pattern_index += 1;
536 path_index = path_segments.len();
537 break;
538 }
539 }
540 }
541
542 if pattern_index == self.segments.len() && path_index == path_segments.len() {
543 Some(PathParams::new(params))
544 } else {
545 None
546 }
547 }
548}
549
550#[derive(Clone)]
551struct RouteEntry<R> {
552 method: Method,
553 pattern: RoutePattern,
554 handler: RouteHandler<R>,
555 layers: Arc<Vec<ServiceLayer>>,
556 apply_global_layers: bool,
557}
558
/// Splits a URL path into its non-empty `/`-separated segments.
///
/// Leading, trailing and repeated slashes produce no segments, so `/` and the
/// empty string both yield an empty list.
fn path_segments(path: &str) -> Vec<&str> {
    path.split('/')
        .filter(|segment| !segment.is_empty())
        .collect()
}
569
/// Ensures a route path starts with `/`; the empty string becomes `/`.
fn normalize_route_path(path: String) -> String {
    if path.starts_with('/') {
        path
    } else if path.is_empty() {
        String::from("/")
    } else {
        format!("/{path}")
    }
}
580
581fn find_matching_route<'a, R>(
582 routes: &'a [RouteEntry<R>],
583 method: &Method,
584 path: &str,
585) -> Option<(&'a RouteEntry<R>, PathParams)> {
586 for entry in routes {
587 if entry.method != *method {
588 continue;
589 }
590 if let Some(params) = entry.pattern.match_path(path) {
591 return Some((entry, params));
592 }
593 }
594 None
595}
596
597fn header_contains_token(
598 headers: &http::HeaderMap,
599 name: http::header::HeaderName,
600 token: &str,
601) -> bool {
602 headers
603 .get(name)
604 .and_then(|value| value.to_str().ok())
605 .map(|value| {
606 value
607 .split(',')
608 .any(|part| part.trim().eq_ignore_ascii_case(token))
609 })
610 .unwrap_or(false)
611}
612
613fn websocket_session_from_request<B>(req: &Request<B>) -> WebSocketSessionContext {
614 WebSocketSessionContext {
615 connection_id: uuid::Uuid::new_v4(),
616 path: req.uri().path().to_string(),
617 query: req.uri().query().map(str::to_string),
618 }
619}
620
621fn websocket_accept_key(client_key: &str) -> String {
622 let mut hasher = Sha1::new();
623 hasher.update(client_key.as_bytes());
624 hasher.update(WS_GUID.as_bytes());
625 let digest = hasher.finalize();
626 base64::engine::general_purpose::STANDARD.encode(digest)
627}
628
629fn websocket_bad_request(message: &'static str) -> HttpResponse {
630 Response::builder()
631 .status(StatusCode::BAD_REQUEST)
632 .body(
633 Full::new(Bytes::from(message))
634 .map_err(|never| match never {})
635 .boxed(),
636 )
637 .unwrap_or_else(|_| {
638 Response::new(
639 Full::new(Bytes::new())
640 .map_err(|never| match never {})
641 .boxed(),
642 )
643 })
644}
645
646fn websocket_upgrade_response<B>(
647 req: &mut Request<B>,
648) -> Result<(HttpResponse, hyper::upgrade::OnUpgrade), HttpResponse> {
649 if req.method() != Method::GET {
650 return Err(websocket_bad_request(
651 "WebSocket upgrade requires GET method",
652 ));
653 }
654
655 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
656 return Err(websocket_bad_request(
657 "Missing Connection: upgrade header for WebSocket",
658 ));
659 }
660
661 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
662 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
663 }
664
665 if let Some(version) = req.headers().get("sec-websocket-version") {
666 if version != "13" {
667 return Err(websocket_bad_request(
668 "Unsupported Sec-WebSocket-Version (expected 13)",
669 ));
670 }
671 }
672
673 let Some(client_key) = req
674 .headers()
675 .get("sec-websocket-key")
676 .and_then(|value| value.to_str().ok())
677 else {
678 return Err(websocket_bad_request(
679 "Missing Sec-WebSocket-Key header for WebSocket",
680 ));
681 };
682
683 let accept_key = websocket_accept_key(client_key);
684 let on_upgrade = hyper::upgrade::on(req);
685 let response = Response::builder()
686 .status(StatusCode::SWITCHING_PROTOCOLS)
687 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
688 .header(http::header::CONNECTION, "Upgrade")
689 .header("sec-websocket-accept", accept_key)
690 .body(
691 Full::new(Bytes::new())
692 .map_err(|never| match never {})
693 .boxed(),
694 )
695 .unwrap_or_else(|_| {
696 Response::new(
697 Full::new(Bytes::new())
698 .map_err(|never| match never {})
699 .boxed(),
700 )
701 });
702
703 Ok((response, on_upgrade))
704}
705
706pub struct HttpIngress<R = ()> {
712 addr: Option<String>,
714 routes: Vec<RouteEntry<R>>,
716 fallback: Option<RouteHandler<R>>,
718 layers: Vec<ServiceLayer>,
720 on_start: Option<LifecycleHook>,
722 on_shutdown: Option<LifecycleHook>,
724 graceful_shutdown_timeout: Duration,
726 bus_injectors: Vec<BusInjector>,
728 static_assets: StaticAssetsConfig,
730 health: HealthConfig<R>,
732 #[cfg(feature = "http3")]
733 http3_config: Option<crate::http3::Http3Config>,
734 #[cfg(feature = "http3")]
735 alt_svc_h3_port: Option<u16>,
736 active_intervention: bool,
738 policy_registry: Option<ranvier_core::policy::PolicyRegistry>,
740 _phantom: std::marker::PhantomData<R>,
741}
742
743impl<R> HttpIngress<R>
744where
745 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
746{
747 pub fn new() -> Self {
749 Self {
750 addr: None,
751 routes: Vec::new(),
752 fallback: None,
753 layers: Vec::new(),
754 on_start: None,
755 on_shutdown: None,
756 graceful_shutdown_timeout: Duration::from_secs(30),
757 bus_injectors: Vec::new(),
758 static_assets: StaticAssetsConfig::default(),
759 health: HealthConfig::default(),
760 #[cfg(feature = "http3")]
761 http3_config: None,
762 #[cfg(feature = "http3")]
763 alt_svc_h3_port: None,
764 active_intervention: false,
765 policy_registry: None,
766 _phantom: std::marker::PhantomData,
767 }
768 }
769
770 pub fn bind(mut self, addr: impl Into<String>) -> Self {
772 self.addr = Some(addr.into());
773 self
774 }
775
776 pub fn active_intervention(mut self) -> Self {
780 self.active_intervention = true;
781 self
782 }
783
784 pub fn policy_registry(mut self, registry: ranvier_core::policy::PolicyRegistry) -> Self {
786 self.policy_registry = Some(registry);
787 self
788 }
789
790 pub fn on_start<F>(mut self, callback: F) -> Self
792 where
793 F: Fn() + Send + Sync + 'static,
794 {
795 self.on_start = Some(Arc::new(callback));
796 self
797 }
798
799 pub fn on_shutdown<F>(mut self, callback: F) -> Self
801 where
802 F: Fn() + Send + Sync + 'static,
803 {
804 self.on_shutdown = Some(Arc::new(callback));
805 self
806 }
807
808 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
810 self.graceful_shutdown_timeout = timeout;
811 self
812 }
813
814 pub fn layer<L>(mut self, layer: L) -> Self
819 where
820 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
821 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
822 + Clone
823 + Send
824 + 'static,
825 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
826 {
827 self.layers.push(to_service_layer(layer));
828 self
829 }
830
831 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
834 self.layers.push(Arc::new(move |service: BoxHttpService| {
835 BoxCloneService::new(TimeoutService {
836 inner: service,
837 timeout,
838 })
839 }));
840 self
841 }
842
843 pub fn request_id_layer(mut self) -> Self {
847 self.layers.push(Arc::new(move |service: BoxHttpService| {
848 BoxCloneService::new(RequestIdService { inner: service })
849 }));
850 self
851 }
852
853 pub fn bus_injector<F>(mut self, injector: F) -> Self
858 where
859 F: Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static,
860 {
861 self.bus_injectors.push(Arc::new(injector));
862 self
863 }
864
865 #[cfg(feature = "http3")]
867 pub fn enable_http3(mut self, config: crate::http3::Http3Config) -> Self {
868 self.http3_config = Some(config);
869 self
870 }
871
872 #[cfg(feature = "http3")]
874 pub fn alt_svc_h3(mut self, port: u16) -> Self {
875 self.alt_svc_h3_port = Some(port);
876 self
877 }
878
879 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
881 let mut descriptors = self
882 .routes
883 .iter()
884 .map(|entry| HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone()))
885 .collect::<Vec<_>>();
886
887 if let Some(path) = &self.health.health_path {
888 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
889 }
890 if let Some(path) = &self.health.readiness_path {
891 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
892 }
893 if let Some(path) = &self.health.liveness_path {
894 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
895 }
896
897 descriptors
898 }
899
900 pub fn serve_dir(
904 mut self,
905 route_prefix: impl Into<String>,
906 directory: impl Into<String>,
907 ) -> Self {
908 self.static_assets.mounts.push(StaticMount {
909 route_prefix: normalize_route_path(route_prefix.into()),
910 directory: directory.into(),
911 });
912 if self.static_assets.cache_control.is_none() {
913 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
914 }
915 self
916 }
917
918 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
922 self.static_assets.spa_fallback = Some(file_path.into());
923 self
924 }
925
926 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
928 self.static_assets.cache_control = Some(cache_control.into());
929 self
930 }
931
932 pub fn compression_layer(mut self) -> Self {
934 self.static_assets.enable_compression = true;
935 self
936 }
937
938 pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
945 where
946 H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
947 Fut: Future<Output = ()> + Send + 'static,
948 {
949 let path_str: String = path.into();
950 let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
951 Box::pin(handler(connection, resources, bus))
952 });
953 let bus_injectors = Arc::new(self.bus_injectors.clone());
954 let path_for_pattern = path_str.clone();
955 let path_for_handler = path_str;
956
957 let route_handler: RouteHandler<R> =
958 Arc::new(move |parts: http::request::Parts, res: &R| {
959 let ws_handler = ws_handler.clone();
960 let bus_injectors = bus_injectors.clone();
961 let resources = Arc::new(res.clone());
962 let path = path_for_handler.clone();
963
964 Box::pin(async move {
965 let request_id = uuid::Uuid::new_v4().to_string();
966 let span = tracing::info_span!(
967 "WebSocketUpgrade",
968 ranvier.ws.path = %path,
969 ranvier.ws.request_id = %request_id
970 );
971
972 async move {
973 let mut bus = Bus::new();
974 for injector in bus_injectors.iter() {
975 injector(&parts, &mut bus);
976 }
977
978 let mut req = Request::from_parts(parts, ());
980 let session = websocket_session_from_request(&req);
981 bus.insert(session.clone());
982
983 let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
984 Ok(result) => result,
985 Err(error_response) => return error_response,
986 };
987
988 tokio::spawn(async move {
989 match on_upgrade.await {
990 Ok(upgraded) => {
991 let stream = WebSocketStream::from_raw_socket(
992 TokioIo::new(upgraded),
993 tokio_tungstenite::tungstenite::protocol::Role::Server,
994 None,
995 )
996 .await;
997 let connection = WebSocketConnection::new(stream, session);
998 ws_handler(connection, resources, bus).await;
999 }
1000 Err(error) => {
1001 tracing::warn!(
1002 ranvier.ws.path = %path,
1003 ranvier.ws.error = %error,
1004 "websocket upgrade failed"
1005 );
1006 }
1007 }
1008 });
1009
1010 response
1011 }
1012 .instrument(span)
1013 .await
1014 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1015 });
1016
1017 self.routes.push(RouteEntry {
1018 method: Method::GET,
1019 pattern: RoutePattern::parse(&path_for_pattern),
1020 handler: route_handler,
1021 layers: Arc::new(Vec::new()),
1022 apply_global_layers: true,
1023 });
1024
1025 self
1026 }
1027
1028 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
1033 self.health.health_path = Some(normalize_route_path(path.into()));
1034 self
1035 }
1036
1037 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
1041 where
1042 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
1043 Fut: Future<Output = Result<(), Err>> + Send + 'static,
1044 Err: ToString + Send + 'static,
1045 {
1046 if self.health.health_path.is_none() {
1047 self.health.health_path = Some("/health".to_string());
1048 }
1049
1050 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
1051 let fut = check(resources);
1052 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
1053 });
1054
1055 self.health.checks.push(NamedHealthCheck {
1056 name: name.into(),
1057 check: check_fn,
1058 });
1059 self
1060 }
1061
1062 pub fn readiness_liveness(
1064 mut self,
1065 readiness_path: impl Into<String>,
1066 liveness_path: impl Into<String>,
1067 ) -> Self {
1068 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1069 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1070 self
1071 }
1072
1073 pub fn readiness_liveness_default(self) -> Self {
1075 self.readiness_liveness("/ready", "/live")
1076 }
1077
1078 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1080 where
1081 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1082 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1083 {
1084 self.route_method(Method::GET, path, circuit)
1085 }
1086 pub fn route_method<Out, E>(
1095 self,
1096 method: Method,
1097 path: impl Into<String>,
1098 circuit: Axon<(), Out, E, R>,
1099 ) -> Self
1100 where
1101 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1102 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1103 {
1104 self.route_method_with_error(method, path, circuit, |error| {
1105 (
1106 StatusCode::INTERNAL_SERVER_ERROR,
1107 format!("Error: {:?}", error),
1108 )
1109 .into_response()
1110 })
1111 }
1112
1113 pub fn route_method_with_error<Out, E, H>(
1114 self,
1115 method: Method,
1116 path: impl Into<String>,
1117 circuit: Axon<(), Out, E, R>,
1118 error_handler: H,
1119 ) -> Self
1120 where
1121 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1122 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1123 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1124 {
1125 self.route_method_with_error_and_layers(
1126 method,
1127 path,
1128 circuit,
1129 error_handler,
1130 Arc::new(Vec::new()),
1131 true,
1132 )
1133 }
1134
1135 pub fn route_method_with_layer<Out, E, L>(
1136 self,
1137 method: Method,
1138 path: impl Into<String>,
1139 circuit: Axon<(), Out, E, R>,
1140 layer: L,
1141 ) -> Self
1142 where
1143 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1144 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1145 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1146 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
1147 + Clone
1148 + Send
1149 + 'static,
1150 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1151 {
1152 self.route_method_with_error_and_layers(
1153 method,
1154 path,
1155 circuit,
1156 |error| {
1157 (
1158 StatusCode::INTERNAL_SERVER_ERROR,
1159 format!("Error: {:?}", error),
1160 )
1161 .into_response()
1162 },
1163 Arc::new(vec![to_service_layer(layer)]),
1164 true,
1165 )
1166 }
1167
1168 pub fn route_method_with_layer_override<Out, E, L>(
1169 self,
1170 method: Method,
1171 path: impl Into<String>,
1172 circuit: Axon<(), Out, E, R>,
1173 layer: L,
1174 ) -> Self
1175 where
1176 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1177 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1178 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1179 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
1180 + Clone
1181 + Send
1182 + 'static,
1183 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1184 {
1185 self.route_method_with_error_and_layers(
1186 method,
1187 path,
1188 circuit,
1189 |error| {
1190 (
1191 StatusCode::INTERNAL_SERVER_ERROR,
1192 format!("Error: {:?}", error),
1193 )
1194 .into_response()
1195 },
1196 Arc::new(vec![to_service_layer(layer)]),
1197 false,
1198 )
1199 }
1200
1201 fn route_method_with_error_and_layers<Out, E, H>(
1202 mut self,
1203 method: Method,
1204 path: impl Into<String>,
1205 circuit: Axon<(), Out, E, R>,
1206 error_handler: H,
1207 route_layers: Arc<Vec<ServiceLayer>>,
1208 apply_global_layers: bool,
1209 ) -> Self
1210 where
1211 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1212 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1213 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1214 {
1215 let path_str: String = path.into();
1216 let circuit = Arc::new(circuit);
1217 let error_handler = Arc::new(error_handler);
1218 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1219 let path_for_pattern = path_str.clone();
1220 let path_for_handler = path_str;
1221 let method_for_pattern = method.clone();
1222 let method_for_handler = method;
1223
1224 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1225 let circuit = circuit.clone();
1226 let error_handler = error_handler.clone();
1227 let route_bus_injectors = route_bus_injectors.clone();
1228 let res = res.clone();
1229 let path = path_for_handler.clone();
1230 let method = method_for_handler.clone();
1231
1232 Box::pin(async move {
1233 let request_id = uuid::Uuid::new_v4().to_string();
1234 let span = tracing::info_span!(
1235 "HTTPRequest",
1236 ranvier.http.method = %method,
1237 ranvier.http.path = %path,
1238 ranvier.http.request_id = %request_id
1239 );
1240
1241 async move {
1242 let mut bus = Bus::new();
1243 for injector in route_bus_injectors.iter() {
1244 injector(&parts, &mut bus);
1245 }
1246 let result = circuit.execute((), &res, &mut bus).await;
1247 outcome_to_response_with_error(result, |error| error_handler(error))
1248 }
1249 .instrument(span)
1250 .await
1251 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1252 });
1253
1254 self.routes.push(RouteEntry {
1255 method: method_for_pattern,
1256 pattern: RoutePattern::parse(&path_for_pattern),
1257 handler,
1258 layers: route_layers,
1259 apply_global_layers,
1260 });
1261 self
1262 }
1263
1264 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1265 where
1266 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1267 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1268 {
1269 self.route_method(Method::GET, path, circuit)
1270 }
1271
1272 pub fn get_with_error<Out, E, H>(
1273 self,
1274 path: impl Into<String>,
1275 circuit: Axon<(), Out, E, R>,
1276 error_handler: H,
1277 ) -> Self
1278 where
1279 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1280 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1281 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1282 {
1283 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1284 }
1285
1286 pub fn get_with_layer<Out, E, L>(
1287 self,
1288 path: impl Into<String>,
1289 circuit: Axon<(), Out, E, R>,
1290 layer: L,
1291 ) -> Self
1292 where
1293 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1294 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1295 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1296 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
1297 + Clone
1298 + Send
1299 + 'static,
1300 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1301 {
1302 self.route_method_with_layer(Method::GET, path, circuit, layer)
1303 }
1304
1305 pub fn get_with_layer_override<Out, E, L>(
1306 self,
1307 path: impl Into<String>,
1308 circuit: Axon<(), Out, E, R>,
1309 layer: L,
1310 ) -> Self
1311 where
1312 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1313 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1314 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1315 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
1316 + Clone
1317 + Send
1318 + 'static,
1319 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1320 {
1321 self.route_method_with_layer_override(Method::GET, path, circuit, layer)
1322 }
1323
1324 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1325 where
1326 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1327 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1328 {
1329 self.route_method(Method::POST, path, circuit)
1330 }
1331
1332 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1333 where
1334 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1335 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1336 {
1337 self.route_method(Method::PUT, path, circuit)
1338 }
1339
1340 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1341 where
1342 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1343 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1344 {
1345 self.route_method(Method::DELETE, path, circuit)
1346 }
1347
1348 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1349 where
1350 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1351 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1352 {
1353 self.route_method(Method::PATCH, path, circuit)
1354 }
1355
1356 pub fn post_with_error<Out, E, H>(
1357 self,
1358 path: impl Into<String>,
1359 circuit: Axon<(), Out, E, R>,
1360 error_handler: H,
1361 ) -> Self
1362 where
1363 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1364 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1365 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1366 {
1367 self.route_method_with_error(Method::POST, path, circuit, error_handler)
1368 }
1369
1370 pub fn put_with_error<Out, E, H>(
1371 self,
1372 path: impl Into<String>,
1373 circuit: Axon<(), Out, E, R>,
1374 error_handler: H,
1375 ) -> Self
1376 where
1377 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1378 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1379 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1380 {
1381 self.route_method_with_error(Method::PUT, path, circuit, error_handler)
1382 }
1383
1384 pub fn delete_with_error<Out, E, H>(
1385 self,
1386 path: impl Into<String>,
1387 circuit: Axon<(), Out, E, R>,
1388 error_handler: H,
1389 ) -> Self
1390 where
1391 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1392 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1393 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1394 {
1395 self.route_method_with_error(Method::DELETE, path, circuit, error_handler)
1396 }
1397
1398 pub fn patch_with_error<Out, E, H>(
1399 self,
1400 path: impl Into<String>,
1401 circuit: Axon<(), Out, E, R>,
1402 error_handler: H,
1403 ) -> Self
1404 where
1405 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1406 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1407 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1408 {
1409 self.route_method_with_error(Method::PATCH, path, circuit, error_handler)
1410 }
1411
1412 pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
1423 where
1424 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1425 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1426 {
1427 let circuit = Arc::new(circuit);
1428 let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
1429
1430 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1431 let circuit = circuit.clone();
1432 let fallback_bus_injectors = fallback_bus_injectors.clone();
1433 let res = res.clone();
1434 Box::pin(async move {
1435 let request_id = uuid::Uuid::new_v4().to_string();
1436 let span = tracing::info_span!(
1437 "HTTPRequest",
1438 ranvier.http.method = "FALLBACK",
1439 ranvier.http.request_id = %request_id
1440 );
1441
1442 async move {
1443 let mut bus = Bus::new();
1444 for injector in fallback_bus_injectors.iter() {
1445 injector(&parts, &mut bus);
1446 }
1447 let result: ranvier_core::Outcome<Out, E> =
1448 circuit.execute((), &res, &mut bus).await;
1449
1450 match result {
1451 Outcome::Next(output) => {
1452 let mut response = output.into_response();
1453 *response.status_mut() = StatusCode::NOT_FOUND;
1454 response
1455 }
1456 _ => Response::builder()
1457 .status(StatusCode::NOT_FOUND)
1458 .body(
1459 Full::new(Bytes::from("Not Found"))
1460 .map_err(|never| match never {})
1461 .boxed(),
1462 )
1463 .unwrap(),
1464 }
1465 }
1466 .instrument(span)
1467 .await
1468 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1469 });
1470
1471 self.fallback = Some(handler);
1472 self
1473 }
1474
1475 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1477 self.run_with_shutdown_signal(resources, shutdown_signal())
1478 .await
1479 }
1480
1481 async fn run_with_shutdown_signal<S>(
1482 self,
1483 resources: R,
1484 shutdown_signal: S,
1485 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
1486 where
1487 S: Future<Output = ()> + Send,
1488 {
1489 let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
1490 let addr: SocketAddr = addr_str.parse()?;
1491
1492 let mut raw_routes = self.routes;
1493 if self.active_intervention {
1494 let handler: RouteHandler<R> = Arc::new(|_parts, _res| {
1495 Box::pin(async move {
1496 Response::builder()
1497 .status(StatusCode::OK)
1498 .body(
1499 Full::new(Bytes::from("Intervention accepted"))
1500 .map_err(|never| match never {} as Infallible)
1501 .boxed(),
1502 )
1503 .unwrap()
1504 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1505 });
1506
1507 raw_routes.push(RouteEntry {
1508 method: Method::POST,
1509 pattern: RoutePattern::parse("/_system/intervene/force_resume"),
1510 handler,
1511 layers: Arc::new(Vec::new()),
1512 apply_global_layers: true,
1513 });
1514 }
1515
1516 if let Some(registry) = self.policy_registry.clone() {
1517 let handler: RouteHandler<R> = Arc::new(move |_parts, _res| {
1518 let _registry = registry.clone();
1519 Box::pin(async move {
1520 Response::builder()
1524 .status(StatusCode::OK)
1525 .body(
1526 Full::new(Bytes::from("Policy registry active"))
1527 .map_err(|never| match never {} as Infallible)
1528 .boxed(),
1529 )
1530 .unwrap()
1531 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1532 });
1533
1534 raw_routes.push(RouteEntry {
1535 method: Method::POST,
1536 pattern: RoutePattern::parse("/_system/policy/reload"),
1537 handler,
1538 layers: Arc::new(Vec::new()),
1539 apply_global_layers: true,
1540 });
1541 }
1542 let routes = Arc::new(raw_routes);
1543 let fallback = self.fallback;
1544 let layers = Arc::new(self.layers);
1545 let health = Arc::new(self.health);
1546 let static_assets = Arc::new(self.static_assets);
1547 let on_start = self.on_start;
1548 let on_shutdown = self.on_shutdown;
1549 let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
1550 let resources = Arc::new(resources);
1551
1552 let listener = TcpListener::bind(addr).await?;
1553 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
1554 if let Some(callback) = on_start.as_ref() {
1555 callback();
1556 }
1557
1558 tokio::pin!(shutdown_signal);
1559 let mut connections = tokio::task::JoinSet::new();
1560
1561 loop {
1562 tokio::select! {
1563 _ = &mut shutdown_signal => {
1564 tracing::info!("Shutdown signal received. Draining in-flight connections.");
1565 break;
1566 }
1567 accept_result = listener.accept() => {
1568 let (stream, _) = accept_result?;
1569 let io = TokioIo::new(stream);
1570
1571 let routes = routes.clone();
1572 let fallback = fallback.clone();
1573 let resources = resources.clone();
1574 let layers = layers.clone();
1575 let health = health.clone();
1576 let static_assets = static_assets.clone();
1577 #[cfg(feature = "http3")]
1578 let alt_svc_h3_port = self.alt_svc_h3_port;
1579
1580 connections.spawn(async move {
1581 let service = build_http_service(
1582 routes,
1583 fallback,
1584 resources,
1585 layers,
1586 health,
1587 static_assets,
1588 #[cfg(feature = "http3")] alt_svc_h3_port,
1589 );
1590 let hyper_service = TowerToHyperService::new(service);
1591 if let Err(err) = http1::Builder::new()
1592 .serve_connection(io, hyper_service)
1593 .with_upgrades()
1594 .await
1595 {
1596 tracing::error!("Error serving connection: {:?}", err);
1597 }
1598 });
1599 }
1600 Some(join_result) = connections.join_next(), if !connections.is_empty() => {
1601 if let Err(err) = join_result {
1602 tracing::warn!("Connection task join error: {:?}", err);
1603 }
1604 }
1605 }
1606 }
1607
1608 let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;
1609
1610 drop(resources);
1611 if let Some(callback) = on_shutdown.as_ref() {
1612 callback();
1613 }
1614
1615 Ok(())
1616 }
1617
1618 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
1634 let routes = Arc::new(self.routes);
1635 let fallback = self.fallback;
1636 let layers = Arc::new(self.layers);
1637 let health = Arc::new(self.health);
1638 let static_assets = Arc::new(self.static_assets);
1639 let resources = Arc::new(resources);
1640
1641 RawIngressService {
1642 routes,
1643 fallback,
1644 layers,
1645 health,
1646 static_assets,
1647 resources,
1648 }
1649 }
1650}
1651
1652fn build_http_service<R>(
1653 routes: Arc<Vec<RouteEntry<R>>>,
1654 fallback: Option<RouteHandler<R>>,
1655 resources: Arc<R>,
1656 layers: Arc<Vec<ServiceLayer>>,
1657 health: Arc<HealthConfig<R>>,
1658 static_assets: Arc<StaticAssetsConfig>,
1659 #[cfg(feature = "http3")] alt_svc_port: Option<u16>,
1660) -> BoxHttpService
1661where
1662 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1663{
1664 let base_service = service_fn(move |req: Request<Incoming>| {
1665 let routes = routes.clone();
1666 let fallback = fallback.clone();
1667 let resources = resources.clone();
1668 let layers = layers.clone();
1669 let health = health.clone();
1670 let static_assets = static_assets.clone();
1671
1672 async move {
1673 let mut req = req;
1674 let method = req.method().clone();
1675 let path = req.uri().path().to_string();
1676
1677 if let Some(response) =
1678 maybe_handle_health_request(&method, &path, &health, resources.clone()).await
1679 {
1680 return Ok::<_, Infallible>(response.into_response());
1681 }
1682
1683 if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
1684 req.extensions_mut().insert(params);
1685 let effective_layers = if entry.apply_global_layers {
1686 merge_layers(&layers, &entry.layers)
1687 } else {
1688 entry.layers.clone()
1689 };
1690
1691 if effective_layers.is_empty() {
1692 let (parts, _) = req.into_parts();
1693 #[allow(unused_mut)]
1694 let mut res = (entry.handler)(parts, &resources).await;
1695 #[cfg(feature = "http3")]
1696 if let Some(port) = alt_svc_port {
1697 if let Ok(val) =
1698 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
1699 {
1700 res.headers_mut().insert(http::header::ALT_SVC, val);
1701 }
1702 }
1703 Ok::<_, Infallible>(res)
1704 } else {
1705 let route_service = build_route_service(
1706 entry.handler.clone(),
1707 resources.clone(),
1708 effective_layers,
1709 );
1710 #[allow(unused_mut)]
1711 let mut res = route_service.oneshot(req).await;
1712 #[cfg(feature = "http3")]
1713 #[allow(irrefutable_let_patterns)]
1714 if let Ok(ref mut r) = res {
1715 if let Some(port) = alt_svc_port {
1716 if let Ok(val) =
1717 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
1718 {
1719 r.headers_mut().insert(http::header::ALT_SVC, val);
1720 }
1721 }
1722 }
1723 res
1724 }
1725 } else {
1726 let req =
1727 match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
1728 .await
1729 {
1730 Ok(req) => req,
1731 Err(response) => return Ok(response),
1732 };
1733
1734 #[allow(unused_mut)]
1735 let mut fallback_res = if let Some(ref fb) = fallback {
1736 if layers.is_empty() {
1737 let (parts, _) = req.into_parts();
1738 Ok(fb(parts, &resources).await)
1739 } else {
1740 let fallback_service =
1741 build_route_service(fb.clone(), resources.clone(), layers.clone());
1742 fallback_service.oneshot(req).await
1743 }
1744 } else {
1745 Ok(Response::builder()
1746 .status(StatusCode::NOT_FOUND)
1747 .body(
1748 Full::new(Bytes::from("Not Found"))
1749 .map_err(|never| match never {})
1750 .boxed(),
1751 )
1752 .unwrap())
1753 };
1754
1755 #[cfg(feature = "http3")]
1756 if let Ok(r) = fallback_res.as_mut() {
1757 if let Some(port) = alt_svc_port {
1758 if let Ok(val) =
1759 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
1760 {
1761 r.headers_mut().insert(http::header::ALT_SVC, val);
1762 }
1763 }
1764 }
1765
1766 fallback_res
1767 }
1768 }
1769 });
1770
1771 BoxCloneService::new(base_service)
1772}
1773
1774fn build_route_service<R>(
1775 handler: RouteHandler<R>,
1776 resources: Arc<R>,
1777 layers: Arc<Vec<ServiceLayer>>,
1778) -> BoxHttpService
1779where
1780 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1781{
1782 let base_service = service_fn(move |req: Request<Incoming>| {
1783 let handler = handler.clone();
1784 let resources = resources.clone();
1785 let (parts, _) = req.into_parts();
1786 async move { Ok::<_, Infallible>(handler(parts, &resources).await) }
1787 });
1788
1789 let mut service = BoxCloneService::new(base_service);
1790 for layer in layers.iter() {
1791 service = layer(service);
1792 }
1793 service
1794}
1795
1796fn merge_layers(
1797 global_layers: &Arc<Vec<ServiceLayer>>,
1798 route_layers: &Arc<Vec<ServiceLayer>>,
1799) -> Arc<Vec<ServiceLayer>> {
1800 if global_layers.is_empty() {
1801 return route_layers.clone();
1802 }
1803 if route_layers.is_empty() {
1804 return global_layers.clone();
1805 }
1806
1807 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
1808 combined.extend(global_layers.iter().cloned());
1809 combined.extend(route_layers.iter().cloned());
1810 Arc::new(combined)
1811}
1812
1813async fn maybe_handle_health_request<R>(
1814 method: &Method,
1815 path: &str,
1816 health: &HealthConfig<R>,
1817 resources: Arc<R>,
1818) -> Option<HttpResponse>
1819where
1820 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1821{
1822 if method != Method::GET {
1823 return None;
1824 }
1825
1826 if let Some(liveness_path) = health.liveness_path.as_ref() {
1827 if path == liveness_path {
1828 return Some(health_json_response("liveness", true, Vec::new()));
1829 }
1830 }
1831
1832 if let Some(readiness_path) = health.readiness_path.as_ref() {
1833 if path == readiness_path {
1834 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1835 return Some(health_json_response("readiness", healthy, checks));
1836 }
1837 }
1838
1839 if let Some(health_path) = health.health_path.as_ref() {
1840 if path == health_path {
1841 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1842 return Some(health_json_response("health", healthy, checks));
1843 }
1844 }
1845
1846 None
1847}
1848
1849async fn maybe_handle_static_request(
1850 req: Request<Incoming>,
1851 method: &Method,
1852 path: &str,
1853 static_assets: &StaticAssetsConfig,
1854) -> Result<Request<Incoming>, HttpResponse> {
1855 if method != Method::GET && method != Method::HEAD {
1856 return Ok(req);
1857 }
1858
1859 if let Some(mount) = static_assets
1860 .mounts
1861 .iter()
1862 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
1863 {
1864 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1865 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
1866 return Ok(req);
1867 };
1868 let rewritten = rewrite_request_path(req, &stripped_path);
1869 let service = ServeDir::new(&mount.directory);
1870 let response = match service.oneshot(rewritten).await {
1871 Ok(response) => response,
1872 Err(_) => {
1873 return Err(Response::builder()
1874 .status(StatusCode::INTERNAL_SERVER_ERROR)
1875 .body(
1876 Full::new(Bytes::from("Failed to serve static asset"))
1877 .map_err(|never| match never {})
1878 .boxed(),
1879 )
1880 .unwrap_or_else(|_| {
1881 Response::new(
1882 Full::new(Bytes::new())
1883 .map_err(|never| match never {})
1884 .boxed(),
1885 )
1886 }));
1887 }
1888 };
1889 let response =
1890 collect_static_response(response, static_assets.cache_control.as_deref()).await;
1891 let response = maybe_compress_static_response(
1892 response,
1893 accept_encoding,
1894 static_assets.enable_compression,
1895 )
1896 .await;
1897 let (parts, body) = response.into_parts();
1898 return Err(Response::from_parts(
1899 parts,
1900 body.map_err(|never| match never {}).boxed(),
1901 ));
1902 }
1903
1904 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
1905 if looks_like_spa_request(path) {
1906 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1907 let service = ServeFile::new(spa_file);
1908 let response = match service.oneshot(req).await {
1909 Ok(response) => response,
1910 Err(_) => {
1911 return Err(Response::builder()
1912 .status(StatusCode::INTERNAL_SERVER_ERROR)
1913 .body(
1914 Full::new(Bytes::from("Failed to serve SPA fallback"))
1915 .map_err(|never| match never {})
1916 .boxed(),
1917 )
1918 .unwrap_or_else(|_| {
1919 Response::new(
1920 Full::new(Bytes::new())
1921 .map_err(|never| match never {})
1922 .boxed(),
1923 )
1924 }));
1925 }
1926 };
1927 let response =
1928 collect_static_response(response, static_assets.cache_control.as_deref()).await;
1929 let response = maybe_compress_static_response(
1930 response,
1931 accept_encoding,
1932 static_assets.enable_compression,
1933 )
1934 .await;
1935 let (parts, body) = response.into_parts();
1936 return Err(Response::from_parts(
1937 parts,
1938 body.map_err(|never| match never {}).boxed(),
1939 ));
1940 }
1941 }
1942
1943 Ok(req)
1944}
1945
/// Maps a request `path` into a mount's namespace by removing `prefix`.
///
/// Returns `None` when `path` is not under `prefix`. A `/` mount matches everything
/// unchanged. Other prefixes ignore trailing slashes and only match on a segment
/// boundary, so `/static` matches `/static` and `/static/x` but not `/staticfoo`.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    let mount = match prefix {
        "/" => "/",
        other => other.trim_end_matches('/'),
    };

    if mount == "/" {
        return Some(path.to_owned());
    }
    if path == mount {
        return Some(String::from("/"));
    }

    let remainder = path.strip_prefix(mount)?.strip_prefix('/')?;
    Some(format!("/{remainder}"))
}
1965
1966fn rewrite_request_path(mut req: Request<Incoming>, new_path: &str) -> Request<Incoming> {
1967 let query = req.uri().query().map(str::to_string);
1968 let path_and_query = match query {
1969 Some(query) => format!("{new_path}?{query}"),
1970 None => new_path.to_string(),
1971 };
1972
1973 let mut parts = req.uri().clone().into_parts();
1974 if let Ok(parsed_path_and_query) = path_and_query.parse() {
1975 parts.path_and_query = Some(parsed_path_and_query);
1976 if let Ok(uri) = Uri::from_parts(parts) {
1977 *req.uri_mut() = uri;
1978 }
1979 }
1980
1981 req
1982}
1983
1984async fn collect_static_response<B>(
1985 response: Response<B>,
1986 cache_control: Option<&str>,
1987) -> Response<Full<Bytes>>
1988where
1989 B: Body<Data = Bytes> + Send + 'static,
1990 B::Error: std::fmt::Display,
1991{
1992 let status = response.status();
1993 let headers = response.headers().clone();
1994 let body = response.into_body();
1995 let collected = body.collect().await;
1996
1997 let bytes = match collected {
1998 Ok(value) => value.to_bytes(),
1999 Err(error) => Bytes::from(error.to_string()),
2000 };
2001
2002 let mut builder = Response::builder().status(status);
2003 for (name, value) in headers.iter() {
2004 builder = builder.header(name, value);
2005 }
2006
2007 let mut response = builder
2008 .body(Full::new(bytes))
2009 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())));
2010
2011 if status == StatusCode::OK {
2012 if let Some(value) = cache_control {
2013 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
2014 if let Ok(header_value) = http::HeaderValue::from_str(value) {
2015 response
2016 .headers_mut()
2017 .insert(http::header::CACHE_CONTROL, header_value);
2018 }
2019 }
2020 }
2021 }
2022
2023 response
2024}
2025
/// Heuristic for SPA client-side routes: the final path segment has no `.`.
///
/// A segment with a `.` is assumed to be a file name with an extension.
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = path.rsplit_once('/').map_or(path, |(_, tail)| tail);
    !last_segment.contains('.')
}
2030
2031async fn maybe_compress_static_response(
2032 response: Response<Full<Bytes>>,
2033 accept_encoding: Option<http::HeaderValue>,
2034 enable_compression: bool,
2035) -> Response<Full<Bytes>> {
2036 if !enable_compression {
2037 return response;
2038 }
2039
2040 let Some(accept_encoding) = accept_encoding else {
2041 return response;
2042 };
2043
2044 let mut request = Request::builder()
2045 .uri("/")
2046 .body(Full::new(Bytes::new()))
2047 .unwrap_or_else(|_| Request::new(Full::new(Bytes::new())));
2048 request
2049 .headers_mut()
2050 .insert(http::header::ACCEPT_ENCODING, accept_encoding);
2051
2052 let service = CompressionLayer::new().layer(service_fn({
2053 let response = response.clone();
2054 move |_req: Request<Full<Bytes>>| {
2055 let response = response.clone();
2056 async move { Ok::<_, Infallible>(response) }
2057 }
2058 }));
2059
2060 match service.oneshot(request).await {
2061 Ok(compressed) => collect_static_response(compressed, None).await,
2062 Err(_) => response,
2063 }
2064}
2065
2066async fn run_named_health_checks<R>(
2067 checks: &[NamedHealthCheck<R>],
2068 resources: Arc<R>,
2069) -> (bool, Vec<HealthCheckReport>)
2070where
2071 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2072{
2073 let mut reports = Vec::with_capacity(checks.len());
2074 let mut healthy = true;
2075
2076 for check in checks {
2077 match (check.check)(resources.clone()).await {
2078 Ok(()) => reports.push(HealthCheckReport {
2079 name: check.name.clone(),
2080 status: "ok",
2081 error: None,
2082 }),
2083 Err(error) => {
2084 healthy = false;
2085 reports.push(HealthCheckReport {
2086 name: check.name.clone(),
2087 status: "error",
2088 error: Some(error),
2089 });
2090 }
2091 }
2092 }
2093
2094 (healthy, reports)
2095}
2096
2097fn health_json_response(
2098 probe: &'static str,
2099 healthy: bool,
2100 checks: Vec<HealthCheckReport>,
2101) -> HttpResponse {
2102 let status_code = if healthy {
2103 StatusCode::OK
2104 } else {
2105 StatusCode::SERVICE_UNAVAILABLE
2106 };
2107 let status = if healthy { "ok" } else { "degraded" };
2108 let payload = HealthReport {
2109 status,
2110 probe,
2111 checks,
2112 };
2113
2114 let body = serde_json::to_vec(&payload)
2115 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
2116
2117 Response::builder()
2118 .status(status_code)
2119 .header(http::header::CONTENT_TYPE, "application/json")
2120 .body(
2121 Full::new(Bytes::from(body))
2122 .map_err(|never| match never {})
2123 .boxed(),
2124 )
2125 .unwrap()
2126}
2127
2128async fn shutdown_signal() {
2129 #[cfg(unix)]
2130 {
2131 use tokio::signal::unix::{SignalKind, signal};
2132
2133 match signal(SignalKind::terminate()) {
2134 Ok(mut terminate) => {
2135 tokio::select! {
2136 _ = tokio::signal::ctrl_c() => {}
2137 _ = terminate.recv() => {}
2138 }
2139 }
2140 Err(err) => {
2141 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
2142 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
2143 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
2144 }
2145 }
2146 }
2147 }
2148
2149 #[cfg(not(unix))]
2150 {
2151 if let Err(err) = tokio::signal::ctrl_c().await {
2152 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
2153 }
2154 }
2155}
2156
2157async fn drain_connections(
2158 connections: &mut tokio::task::JoinSet<()>,
2159 graceful_shutdown_timeout: Duration,
2160) -> bool {
2161 if connections.is_empty() {
2162 return false;
2163 }
2164
2165 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
2166 while let Some(join_result) = connections.join_next().await {
2167 if let Err(err) = join_result {
2168 tracing::warn!("Connection task join error during shutdown: {:?}", err);
2169 }
2170 }
2171 })
2172 .await;
2173
2174 if drain_result.is_err() {
2175 tracing::warn!(
2176 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
2177 graceful_shutdown_timeout
2178 );
2179 connections.abort_all();
2180 while let Some(join_result) = connections.join_next().await {
2181 if let Err(err) = join_result {
2182 tracing::warn!("Connection task abort join error: {:?}", err);
2183 }
2184 }
2185 true
2186 } else {
2187 false
2188 }
2189}
2190
2191impl<R> Default for HttpIngress<R>
2192where
2193 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2194{
2195 fn default() -> Self {
2196 Self::new()
2197 }
2198}
2199
/// Tower [`Service`] form of an [`HttpIngress`], produced by
/// `HttpIngress::into_raw_service`, for embedding the ingress routing in another
/// server instead of calling `run`.
///
/// Cloning is cheap: every field is an `Arc` or an `Option` of one.
#[derive(Clone)]
pub struct RawIngressService<R> {
    /// Registered routes.
    routes: Arc<Vec<RouteEntry<R>>>,
    /// Handler for requests that match no route and no static asset.
    fallback: Option<RouteHandler<R>>,
    /// Ingress-wide Tower layers.
    layers: Arc<Vec<ServiceLayer>>,
    /// Health / readiness / liveness probe configuration.
    health: Arc<HealthConfig<R>>,
    /// Static mounts and SPA fallback configuration.
    static_assets: Arc<StaticAssetsConfig>,
    /// Shared resources handed to every handler.
    resources: Arc<R>,
}
2210
2211impl<R> Service<Request<Incoming>> for RawIngressService<R>
2212where
2213 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2214{
2215 type Response = HttpResponse;
2216 type Error = Infallible;
2217 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2218
2219 fn poll_ready(
2220 &mut self,
2221 _cx: &mut std::task::Context<'_>,
2222 ) -> std::task::Poll<Result<(), Self::Error>> {
2223 std::task::Poll::Ready(Ok(()))
2224 }
2225
2226 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
2227 let routes = self.routes.clone();
2228 let fallback = self.fallback.clone();
2229 let layers = self.layers.clone();
2230 let health = self.health.clone();
2231 let static_assets = self.static_assets.clone();
2232 let resources = self.resources.clone();
2233
2234 Box::pin(async move {
2235 let service = build_http_service(
2236 routes,
2237 fallback,
2238 resources,
2239 layers,
2240 health,
2241 static_assets,
2242 #[cfg(feature = "http3")]
2243 None,
2244 );
2245 service.oneshot(req).await
2246 })
2247 }
2248}
2249
2250#[cfg(test)]
2251mod tests {
2252 use super::*;
2253 use async_trait::async_trait;
2254 use futures_util::{SinkExt, StreamExt};
2255 use ranvier_observe::{HttpMetrics, HttpMetricsLayer, IncomingTraceContext, TraceContextLayer};
2256 use serde::Deserialize;
2257 use std::fs;
2258 use std::sync::atomic::{AtomicBool, Ordering};
2259 use tempfile::tempdir;
2260 use tokio::io::{AsyncReadExt, AsyncWriteExt};
2261 use tokio_tungstenite::tungstenite::Message as WsClientMessage;
2262 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
2263
2264 async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
2265 let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
2266
2267 loop {
2268 match tokio::net::TcpStream::connect(addr).await {
2269 Ok(stream) => return stream,
2270 Err(error) => {
2271 if tokio::time::Instant::now() >= deadline {
2272 panic!("connect server: {error}");
2273 }
2274 tokio::time::sleep(Duration::from_millis(25)).await;
2275 }
2276 }
2277 }
2278 }
2279
2280 #[test]
2281 fn route_pattern_matches_static_path() {
2282 let pattern = RoutePattern::parse("/orders/list");
2283 let params = pattern.match_path("/orders/list").expect("should match");
2284 assert!(params.into_inner().is_empty());
2285 }
2286
2287 #[test]
2288 fn route_pattern_matches_param_segments() {
2289 let pattern = RoutePattern::parse("/orders/:id/items/:item_id");
2290 let params = pattern
2291 .match_path("/orders/42/items/sku-123")
2292 .expect("should match");
2293 assert_eq!(params.get("id"), Some("42"));
2294 assert_eq!(params.get("item_id"), Some("sku-123"));
2295 }
2296
2297 #[test]
2298 fn route_pattern_matches_wildcard_segment() {
2299 let pattern = RoutePattern::parse("/assets/*path");
2300 let params = pattern
2301 .match_path("/assets/css/theme/light.css")
2302 .expect("should match");
2303 assert_eq!(params.get("path"), Some("css/theme/light.css"));
2304 }
2305
2306 #[test]
2307 fn route_pattern_rejects_non_matching_path() {
2308 let pattern = RoutePattern::parse("/orders/:id");
2309 assert!(pattern.match_path("/users/42").is_none());
2310 }
2311
/// A freshly constructed ingress uses a 30 s graceful-shutdown window. It starts
/// with no global layers, bus injectors, static mounts or lifecycle hooks.
#[test]
fn graceful_shutdown_timeout_defaults_to_30_seconds() {
    let fresh = HttpIngress::<()>::new();

    assert_eq!(fresh.graceful_shutdown_timeout, Duration::from_secs(30));

    // Nothing else is registered until the builder methods are called.
    assert!(fresh.layers.is_empty());
    assert!(fresh.bus_injectors.is_empty());
    assert!(fresh.static_assets.mounts.is_empty());
    assert!(fresh.on_start.is_none() && fresh.on_shutdown.is_none());
}
2322
/// Every `.layer(..)` call appends another global middleware layer.
#[test]
fn layer_registration_stacks_globally() {
    let ingress = (0..2).fold(HttpIngress::<()>::new(), |builder, _| {
        builder.layer(tower::layer::util::Identity::new())
    });
    assert_eq!(ingress.layers.len(), 2);
}
2330
/// Third-party tower layers such as `tower_http`'s CORS layer can be registered.
#[test]
fn layer_accepts_tower_http_cors_layer() {
    let cors = tower_http::cors::CorsLayer::permissive();
    let ingress = HttpIngress::<()>::new().layer(cors);
    assert_eq!(ingress.layers.len(), 1);
}
2336
/// A plain `.get(..)` route has no per-route layers and still opts into global layers.
#[test]
fn route_without_layer_keeps_empty_route_middleware_stack() {
    let builder = HttpIngress::<()>::new();
    let ingress = builder.get("/ping", Axon::<(), (), String, ()>::new("Ping"));

    assert_eq!(ingress.routes.len(), 1);
    let route = &ingress.routes[0];
    assert!(route.layers.is_empty());
    assert!(route.apply_global_layers);
}
2345
/// `get_with_layer` attaches one per-route layer on top of the global layers.
#[test]
fn route_with_layer_registers_route_middleware_stack() {
    let builder = HttpIngress::<()>::new();
    let ingress = builder.get_with_layer(
        "/ping",
        Axon::<(), (), String, ()>::new("Ping"),
        tower::layer::util::Identity::new(),
    );

    assert_eq!(ingress.routes.len(), 1);
    let route = &ingress.routes[0];
    assert_eq!(route.layers.len(), 1);
    assert!(route.apply_global_layers);
}
2357
/// `get_with_layer_override` attaches one per-route layer and opts the route out
/// of the global layers.
#[test]
fn route_with_layer_override_disables_global_layers() {
    let builder = HttpIngress::<()>::new();
    let ingress = builder.get_with_layer_override(
        "/ping",
        Axon::<(), (), String, ()>::new("Ping"),
        tower::layer::util::Identity::new(),
    );

    assert_eq!(ingress.routes.len(), 1);
    let route = &ingress.routes[0];
    assert_eq!(route.layers.len(), 1);
    assert!(!route.apply_global_layers);
}
2369
/// `timeout_layer` registers exactly one global middleware layer.
#[test]
fn timeout_layer_registers_builtin_middleware() {
    let ingress = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
    let registered = ingress.layers.len();
    assert_eq!(registered, 1);
}
2375
/// `request_id_layer` registers exactly one global middleware layer.
#[test]
fn request_id_layer_registers_builtin_middleware() {
    let ingress = HttpIngress::<()>::new().request_id_layer();
    let registered = ingress.layers.len();
    assert_eq!(registered, 1);
}
2381
/// `compression_layer` enables compression for static assets. This test checks
/// the `static_assets.enable_compression` flag rather than the global layer list.
#[test]
fn compression_layer_registers_builtin_middleware() {
    let ingress = HttpIngress::<()>::new().compression_layer();
    let compression_enabled = ingress.static_assets.enable_compression;
    assert!(compression_enabled);
}
2387
/// Each `bus_injector` closure is stored as one injector hook.
#[test]
fn bus_injector_registration_adds_hook() {
    let ingress = HttpIngress::<()>::new().bus_injector(|_req, bus| {
        bus.insert(String::from("ok"));
    });
    let hooks = ingress.bus_injectors.len();
    assert_eq!(hooks, 1);
}
2395
/// A `.ws(..)` route is registered as a single `GET` route on the given pattern.
#[test]
fn ws_route_registers_get_route_pattern() {
    let builder = HttpIngress::<()>::new();
    let ingress = builder.ws("/ws/events", |_socket, _resources, _bus| async {});

    assert_eq!(ingress.routes.len(), 1);
    let route = &ingress.routes[0];
    assert_eq!(route.method, Method::GET);
    assert_eq!(route.pattern.raw, "/ws/events");
}
2404
2405 #[derive(Debug, Deserialize)]
2406 struct WsWelcomeFrame {
2407 connection_id: String,
2408 path: String,
2409 tenant: String,
2410 }
2411
2412 #[tokio::test]
2413 async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
2414 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2415 let addr = probe.local_addr().expect("local addr");
2416 drop(probe);
2417
2418 let ingress = HttpIngress::<()>::new()
2419 .bind(addr.to_string())
2420 .bus_injector(|req, bus| {
2421 if let Some(value) = req.headers.get("x-tenant-id").and_then(|v| v.to_str().ok()) {
2422 bus.insert(value.to_string());
2423 }
2424 })
2425 .ws("/ws/echo", |mut socket, _resources, bus| async move {
2426 let tenant = bus
2427 .read::<String>()
2428 .cloned()
2429 .unwrap_or_else(|| "unknown".to_string());
2430 if let Some(session) = bus.read::<WebSocketSessionContext>() {
2431 let welcome = serde_json::json!({
2432 "connection_id": session.connection_id().to_string(),
2433 "path": session.path(),
2434 "tenant": tenant,
2435 });
2436 let _ = socket.send_json(&welcome).await;
2437 }
2438
2439 while let Some(event) = socket.next_event().await {
2440 match event {
2441 WebSocketEvent::Text(text) => {
2442 let _ = socket.send_event(format!("echo:{text}")).await;
2443 }
2444 WebSocketEvent::Binary(bytes) => {
2445 let _ = socket.send_event(bytes).await;
2446 }
2447 WebSocketEvent::Close => break,
2448 WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
2449 }
2450 }
2451 });
2452
2453 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2454 let server = tokio::spawn(async move {
2455 ingress
2456 .run_with_shutdown_signal((), async move {
2457 let _ = shutdown_rx.await;
2458 })
2459 .await
2460 });
2461
2462 let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
2463 let mut ws_request = ws_uri
2464 .as_str()
2465 .into_client_request()
2466 .expect("ws client request");
2467 ws_request
2468 .headers_mut()
2469 .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
2470 let (mut client, _response) = tokio_tungstenite::connect_async(ws_request)
2471 .await
2472 .expect("websocket connect");
2473
2474 let welcome = client
2475 .next()
2476 .await
2477 .expect("welcome frame")
2478 .expect("welcome frame ok");
2479 let welcome_text = match welcome {
2480 WsClientMessage::Text(text) => text.to_string(),
2481 other => panic!("expected text welcome frame, got {other:?}"),
2482 };
2483 let welcome_payload: WsWelcomeFrame =
2484 serde_json::from_str(&welcome_text).expect("welcome json");
2485 assert_eq!(welcome_payload.path, "/ws/echo");
2486 assert_eq!(welcome_payload.tenant, "acme");
2487 assert!(!welcome_payload.connection_id.is_empty());
2488
2489 client
2490 .send(WsClientMessage::Text("hello".into()))
2491 .await
2492 .expect("send text");
2493 let echo_text = client
2494 .next()
2495 .await
2496 .expect("echo text frame")
2497 .expect("echo text frame ok");
2498 assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));
2499
2500 client
2501 .send(WsClientMessage::Binary(vec![1, 2, 3, 4].into()))
2502 .await
2503 .expect("send binary");
2504 let echo_binary = client
2505 .next()
2506 .await
2507 .expect("echo binary frame")
2508 .expect("echo binary frame ok");
2509 assert_eq!(
2510 echo_binary,
2511 WsClientMessage::Binary(vec![1, 2, 3, 4].into())
2512 );
2513
2514 client.close(None).await.expect("close websocket");
2515
2516 let _ = shutdown_tx.send(());
2517 server
2518 .await
2519 .expect("server join")
2520 .expect("server shutdown should succeed");
2521 }
2522
2523 #[derive(Clone)]
2524 struct EchoTrace;
2525
2526 #[async_trait]
2527 impl Transition<(), String> for EchoTrace {
2528 type Error = String;
2529 type Resources = ();
2530
2531 async fn run(
2532 &self,
2533 _state: (),
2534 _resources: &Self::Resources,
2535 bus: &mut Bus,
2536 ) -> Outcome<String, Self::Error> {
2537 let trace_id = bus
2538 .read::<String>()
2539 .cloned()
2540 .unwrap_or_else(|| "missing-trace".to_string());
2541 Outcome::next(trace_id)
2542 }
2543 }
2544
2545 #[tokio::test]
2546 async fn observe_trace_context_and_metrics_layers_work_with_ingress() {
2547 let metrics = HttpMetrics::default();
2548 let ingress = HttpIngress::<()>::new()
2549 .layer(TraceContextLayer::new())
2550 .layer(HttpMetricsLayer::new(metrics.clone()))
2551 .bus_injector(|req, bus| {
2552 if let Some(trace) = req.extensions.get::<IncomingTraceContext>() {
2553 bus.insert(trace.trace_id().to_string());
2554 }
2555 })
2556 .get(
2557 "/trace",
2558 Axon::<(), (), String, ()>::new("EchoTrace").then(EchoTrace),
2559 );
2560
2561 let app = crate::test_harness::TestApp::new(ingress, ());
2562 let response = app
2563 .send(crate::test_harness::TestRequest::get("/trace").header(
2564 "traceparent",
2565 "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
2566 ))
2567 .await
2568 .expect("request should succeed");
2569
2570 assert_eq!(response.status(), StatusCode::OK);
2571 assert_eq!(
2572 response.text().expect("utf8 response"),
2573 "4bf92f3577b34da6a3ce929d0e0e4736"
2574 );
2575
2576 let snapshot = metrics.snapshot();
2577 assert_eq!(snapshot.requests_total, 1);
2578 assert_eq!(snapshot.requests_error, 0);
2579 }
2580
/// `route_descriptors` lists user routes as well as the built-in health,
/// readiness and liveness endpoints, all as `GET`.
#[test]
fn route_descriptors_export_http_and_health_paths() {
    let ingress = HttpIngress::<()>::new()
        .get(
            "/orders/:id",
            Axon::<(), (), String, ()>::new("OrderById"),
        )
        .health_endpoint("/healthz")
        .readiness_liveness("/readyz", "/livez");

    let descriptors = ingress.route_descriptors();
    let exposes_get = |path: &str| {
        descriptors
            .iter()
            .any(|descriptor| descriptor.method() == Method::GET && descriptor.path_pattern() == path)
    };

    for path in ["/orders/:id", "/healthz", "/readyz", "/livez"] {
        assert!(exposes_get(path), "missing GET descriptor for {path}");
    }
}
2618
2619 #[tokio::test]
2620 async fn lifecycle_hooks_fire_on_start_and_shutdown() {
2621 let started = Arc::new(AtomicBool::new(false));
2622 let shutdown = Arc::new(AtomicBool::new(false));
2623
2624 let started_flag = started.clone();
2625 let shutdown_flag = shutdown.clone();
2626
2627 let ingress = HttpIngress::<()>::new()
2628 .bind("127.0.0.1:0")
2629 .on_start(move || {
2630 started_flag.store(true, Ordering::SeqCst);
2631 })
2632 .on_shutdown(move || {
2633 shutdown_flag.store(true, Ordering::SeqCst);
2634 })
2635 .graceful_shutdown(Duration::from_millis(50));
2636
2637 ingress
2638 .run_with_shutdown_signal((), async {
2639 tokio::time::sleep(Duration::from_millis(20)).await;
2640 })
2641 .await
2642 .expect("server should exit gracefully");
2643
2644 assert!(started.load(Ordering::SeqCst));
2645 assert!(shutdown.load(Ordering::SeqCst));
2646 }
2647
2648 #[tokio::test]
2649 async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
2650 #[derive(Clone)]
2651 struct SlowDrainRoute;
2652
2653 #[async_trait]
2654 impl Transition<(), String> for SlowDrainRoute {
2655 type Error = String;
2656 type Resources = ();
2657
2658 async fn run(
2659 &self,
2660 _state: (),
2661 _resources: &Self::Resources,
2662 _bus: &mut Bus,
2663 ) -> Outcome<String, Self::Error> {
2664 tokio::time::sleep(Duration::from_millis(120)).await;
2665 Outcome::next("drained-ok".to_string())
2666 }
2667 }
2668
2669 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2670 let addr = probe.local_addr().expect("local addr");
2671 drop(probe);
2672
2673 let ingress = HttpIngress::<()>::new()
2674 .bind(addr.to_string())
2675 .graceful_shutdown(Duration::from_millis(500))
2676 .get(
2677 "/drain",
2678 Axon::<(), (), String, ()>::new("SlowDrain").then(SlowDrainRoute),
2679 );
2680
2681 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2682 let server = tokio::spawn(async move {
2683 ingress
2684 .run_with_shutdown_signal((), async move {
2685 let _ = shutdown_rx.await;
2686 })
2687 .await
2688 });
2689
2690 let mut stream = connect_with_retry(addr).await;
2691 stream
2692 .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2693 .await
2694 .expect("write request");
2695
2696 tokio::time::sleep(Duration::from_millis(20)).await;
2697 let _ = shutdown_tx.send(());
2698
2699 let mut buf = Vec::new();
2700 stream.read_to_end(&mut buf).await.expect("read response");
2701 let response = String::from_utf8_lossy(&buf);
2702 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
2703 assert!(response.contains("drained-ok"), "{response}");
2704
2705 server
2706 .await
2707 .expect("server join")
2708 .expect("server shutdown should succeed");
2709 }
2710
2711 #[tokio::test]
2712 async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
2713 let temp = tempdir().expect("tempdir");
2714 let root = temp.path().join("public");
2715 fs::create_dir_all(&root).expect("create dir");
2716 let file = root.join("hello.txt");
2717 fs::write(&file, "hello static").expect("write file");
2718
2719 let ingress =
2720 Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
2721 let app = crate::test_harness::TestApp::new(ingress, ());
2722 let response = app
2723 .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
2724 .await
2725 .expect("request should succeed");
2726
2727 assert_eq!(response.status(), StatusCode::OK);
2728 assert_eq!(response.text().expect("utf8"), "hello static");
2729 assert!(response.header("cache-control").is_some());
2730 let has_metadata_header =
2731 response.header("etag").is_some() || response.header("last-modified").is_some();
2732 assert!(has_metadata_header);
2733 }
2734
2735 #[tokio::test]
2736 async fn spa_fallback_returns_index_for_unmatched_path() {
2737 let temp = tempdir().expect("tempdir");
2738 let index = temp.path().join("index.html");
2739 fs::write(&index, "<html><body>spa</body></html>").expect("write index");
2740
2741 let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
2742 let app = crate::test_harness::TestApp::new(ingress, ());
2743 let response = app
2744 .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
2745 .await
2746 .expect("request should succeed");
2747
2748 assert_eq!(response.status(), StatusCode::OK);
2749 assert!(response.text().expect("utf8").contains("spa"));
2750 }
2751
2752 #[tokio::test]
2753 async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
2754 let temp = tempdir().expect("tempdir");
2755 let root = temp.path().join("public");
2756 fs::create_dir_all(&root).expect("create dir");
2757 let file = root.join("compressed.txt");
2758 fs::write(&file, "compress me ".repeat(400)).expect("write file");
2759
2760 let ingress = Ranvier::http::<()>()
2761 .serve_dir("/static", root.to_string_lossy().to_string())
2762 .compression_layer();
2763 let app = crate::test_harness::TestApp::new(ingress, ());
2764 let response = app
2765 .send(
2766 crate::test_harness::TestRequest::get("/static/compressed.txt")
2767 .header("accept-encoding", "gzip"),
2768 )
2769 .await
2770 .expect("request should succeed");
2771
2772 assert_eq!(response.status(), StatusCode::OK);
2773 assert_eq!(
2774 response
2775 .header("content-encoding")
2776 .and_then(|value| value.to_str().ok()),
2777 Some("gzip")
2778 );
2779 }
2780
2781 #[tokio::test]
2782 async fn drain_connections_completes_before_timeout() {
2783 let mut connections = tokio::task::JoinSet::new();
2784 connections.spawn(async {
2785 tokio::time::sleep(Duration::from_millis(20)).await;
2786 });
2787
2788 let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
2789 assert!(!timed_out);
2790 assert!(connections.is_empty());
2791 }
2792
2793 #[tokio::test]
2794 async fn drain_connections_times_out_and_aborts() {
2795 let mut connections = tokio::task::JoinSet::new();
2796 connections.spawn(async {
2797 tokio::time::sleep(Duration::from_secs(10)).await;
2798 });
2799
2800 let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
2801 assert!(timed_out);
2802 assert!(connections.is_empty());
2803 }
2804
2805 #[tokio::test]
2806 async fn timeout_layer_returns_408_for_slow_route() {
2807 #[derive(Clone)]
2808 struct SlowRoute;
2809
2810 #[async_trait]
2811 impl Transition<(), String> for SlowRoute {
2812 type Error = String;
2813 type Resources = ();
2814
2815 async fn run(
2816 &self,
2817 _state: (),
2818 _resources: &Self::Resources,
2819 _bus: &mut Bus,
2820 ) -> Outcome<String, Self::Error> {
2821 tokio::time::sleep(Duration::from_millis(80)).await;
2822 Outcome::next("slow-ok".to_string())
2823 }
2824 }
2825
2826 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2827 let addr = probe.local_addr().expect("local addr");
2828 drop(probe);
2829
2830 let ingress = HttpIngress::<()>::new()
2831 .bind(addr.to_string())
2832 .timeout_layer(Duration::from_millis(10))
2833 .get(
2834 "/slow",
2835 Axon::<(), (), String, ()>::new("Slow").then(SlowRoute),
2836 );
2837
2838 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2839 let server = tokio::spawn(async move {
2840 ingress
2841 .run_with_shutdown_signal((), async move {
2842 let _ = shutdown_rx.await;
2843 })
2844 .await
2845 });
2846
2847 let mut stream = connect_with_retry(addr).await;
2848 stream
2849 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2850 .await
2851 .expect("write request");
2852
2853 let mut buf = Vec::new();
2854 stream.read_to_end(&mut buf).await.expect("read response");
2855 let response = String::from_utf8_lossy(&buf);
2856 assert!(response.starts_with("HTTP/1.1 408"), "{response}");
2857
2858 let _ = shutdown_tx.send(());
2859 server
2860 .await
2861 .expect("server join")
2862 .expect("server shutdown should succeed");
2863 }
2864
2865 #[tokio::test]
2866 async fn route_layer_override_bypasses_global_timeout() {
2867 #[derive(Clone)]
2868 struct SlowRoute;
2869
2870 #[async_trait]
2871 impl Transition<(), String> for SlowRoute {
2872 type Error = String;
2873 type Resources = ();
2874
2875 async fn run(
2876 &self,
2877 _state: (),
2878 _resources: &Self::Resources,
2879 _bus: &mut Bus,
2880 ) -> Outcome<String, Self::Error> {
2881 tokio::time::sleep(Duration::from_millis(60)).await;
2882 Outcome::next("override-ok".to_string())
2883 }
2884 }
2885
2886 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2887 let addr = probe.local_addr().expect("local addr");
2888 drop(probe);
2889
2890 let ingress = HttpIngress::<()>::new()
2891 .bind(addr.to_string())
2892 .timeout_layer(Duration::from_millis(10))
2893 .get_with_layer_override(
2894 "/slow",
2895 Axon::<(), (), String, ()>::new("SlowOverride").then(SlowRoute),
2896 tower::layer::util::Identity::new(),
2897 );
2898
2899 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2900 let server = tokio::spawn(async move {
2901 ingress
2902 .run_with_shutdown_signal((), async move {
2903 let _ = shutdown_rx.await;
2904 })
2905 .await
2906 });
2907
2908 let mut stream = connect_with_retry(addr).await;
2909 stream
2910 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2911 .await
2912 .expect("write request");
2913
2914 let mut buf = Vec::new();
2915 stream.read_to_end(&mut buf).await.expect("read response");
2916 let response = String::from_utf8_lossy(&buf);
2917 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
2918 assert!(response.contains("override-ok"), "{response}");
2919
2920 let _ = shutdown_tx.send(());
2921 server
2922 .await
2923 .expect("server join")
2924 .expect("server shutdown should succeed");
2925 }
2926}