1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode, Uri};
20use http_body::Body;
21use http_body_util::{BodyExt, Full};
22use hyper::body::Incoming;
23use hyper::server::conn::http1;
24use hyper::upgrade::Upgraded;
25use hyper_util::rt::TokioIo;
26use hyper_util::service::TowerToHyperService;
27use ranvier_core::event::{EventSink, EventSource};
28use ranvier_core::prelude::*;
29use ranvier_runtime::Axon;
30use serde::Serialize;
31use serde::de::DeserializeOwned;
32use sha1::{Digest, Sha1};
33use std::collections::HashMap;
34use std::convert::Infallible;
35use std::future::Future;
36use std::net::SocketAddr;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::time::Duration;
40use tokio::net::TcpListener;
41use tokio::sync::Mutex;
42use tokio_tungstenite::WebSocketStream;
43use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
44use tower::util::BoxCloneService;
45use tower::{Layer, Service, ServiceExt, service_fn};
46use tower_http::compression::CompressionLayer;
47use tower_http::services::{ServeDir, ServeFile};
48use tracing::Instrument;
49
50use crate::response::{HttpResponse, IntoResponse, outcome_to_response_with_error};
51
/// Namespace type exposing builder entry points such as [`Ranvier::http`].
pub struct Ranvier;
57
58impl Ranvier {
59 pub fn http<R>() -> HttpIngress<R>
61 where
62 R: ranvier_core::transition::ResourceRequirement + Clone,
63 {
64 HttpIngress::new()
65 }
66}
67
/// Type-erased per-route handler: receives the request head and a borrow of the shared
/// resources, and returns a boxed future that yields the HTTP response.
type RouteHandler<R> = Arc<
    dyn Fn(http::request::Parts, &R) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>>
        + Send
        + Sync,
>;

/// Boxed, cloneable tower service over hyper requests. The error type is `Infallible`,
/// so every failure must be expressed as a response.
type BoxHttpService = BoxCloneService<Request<Incoming>, HttpResponse, Infallible>;
/// Type-erased middleware: wraps one boxed service in another.
type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
/// Callback registered through `on_start` / `on_shutdown`.
type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
/// Populates a per-request `Bus` from the request head before a route runs.
type BusInjector = Arc<dyn Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static>;
/// Boxed future driving one WebSocket session to completion.
type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased WebSocket session handler, invoked after a successful upgrade.
type WsSessionHandler<R> =
    Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
/// Boxed health-probe future; `Err` carries a stringified failure reason.
type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
/// Type-erased health probe over the shared resources.
type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
/// Header used to carry a per-request correlation id.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Token expected in the `Upgrade` header of a WebSocket handshake.
const WS_UPGRADE_TOKEN: &str = "websocket";
/// Fixed GUID appended to `Sec-WebSocket-Key` to derive `Sec-WebSocket-Accept` (RFC 6455).
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
87
/// A health probe registered via `health_check`, paired with its display name.
#[derive(Clone)]
struct NamedHealthCheck<R> {
    /// Name reported for this check.
    name: String,
    /// The type-erased probe itself.
    check: HealthCheckFn<R>,
}
93
/// Health/readiness/liveness endpoint configuration collected by the builder.
#[derive(Clone)]
struct HealthConfig<R> {
    /// Path of the aggregate health endpoint; `None` means not exposed.
    health_path: Option<String>,
    /// Path of the readiness probe; `None` means not exposed.
    readiness_path: Option<String>,
    /// Path of the liveness probe; `None` means not exposed.
    liveness_path: Option<String>,
    /// Registered named checks, in registration order.
    checks: Vec<NamedHealthCheck<R>>,
}
101
102impl<R> Default for HealthConfig<R> {
103 fn default() -> Self {
104 Self {
105 health_path: None,
106 readiness_path: None,
107 liveness_path: None,
108 checks: Vec::new(),
109 }
110 }
111}
112
/// Static file serving configuration accumulated by `serve_dir` and related builders.
#[derive(Clone, Default)]
struct StaticAssetsConfig {
    /// Directory mounts, in registration order.
    mounts: Vec<StaticMount>,
    /// File to serve when nothing else matches (set by `spa_fallback`).
    spa_fallback: Option<String>,
    /// `Cache-Control` value; `serve_dir` defaults it to `public, max-age=3600` if unset.
    cache_control: Option<String>,
    /// Set by `compression_layer`; how it is applied is outside this chunk — TODO confirm.
    enable_compression: bool,
}
120
/// Maps a URL prefix to a filesystem directory.
#[derive(Clone)]
struct StaticMount {
    /// URL prefix, normalized to start with `/`.
    route_prefix: String,
    /// Directory path as supplied by the caller.
    directory: String,
}
126
/// JSON body of a health/readiness/liveness probe response.
///
/// Field order determines the serialized JSON key order.
#[derive(Serialize)]
struct HealthReport {
    /// Overall status string (values are chosen by code outside this chunk).
    status: &'static str,
    /// Which probe produced the report (values chosen outside this chunk).
    probe: &'static str,
    /// Per-check results.
    checks: Vec<HealthCheckReport>,
}
133
/// Result of a single named health check inside a [`HealthReport`].
#[derive(Serialize)]
struct HealthCheckReport {
    /// Name given to `health_check`.
    name: String,
    /// Status string for this check (values chosen outside this chunk).
    status: &'static str,
    /// Failure message; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
141
/// Middleware service that bounds the wrapped service's response time.
#[derive(Clone)]
struct TimeoutService {
    /// The wrapped service.
    inner: BoxHttpService,
    /// Maximum time to wait for `inner` before answering `408 Request Timeout`.
    timeout: Duration,
}
147
148impl Service<Request<Incoming>> for TimeoutService {
149 type Response = HttpResponse;
150 type Error = Infallible;
151 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
152
153 fn poll_ready(
154 &mut self,
155 cx: &mut std::task::Context<'_>,
156 ) -> std::task::Poll<Result<(), Self::Error>> {
157 self.inner.poll_ready(cx)
158 }
159
160 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
161 let timeout = self.timeout;
162 let fut = self.inner.call(req);
163 Box::pin(async move {
164 match tokio::time::timeout(timeout, fut).await {
165 Ok(response) => response,
166 Err(_) => Ok(Response::builder()
167 .status(StatusCode::REQUEST_TIMEOUT)
168 .body(Full::new(Bytes::from("Request Timeout")).map_err(|never| match never {}).boxed())
169 .unwrap()),
170 }
171 })
172 }
173}
174
/// Middleware service that ensures every request and response carries `x-request-id`.
#[derive(Clone)]
struct RequestIdService {
    /// The wrapped service.
    inner: BoxHttpService,
}
179
180impl Service<Request<Incoming>> for RequestIdService {
181 type Response = HttpResponse;
182 type Error = Infallible;
183 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
184
185 fn poll_ready(
186 &mut self,
187 cx: &mut std::task::Context<'_>,
188 ) -> std::task::Poll<Result<(), Self::Error>> {
189 self.inner.poll_ready(cx)
190 }
191
192 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
193 let mut req = req;
194 let request_id = req
195 .headers()
196 .get(REQUEST_ID_HEADER)
197 .cloned()
198 .unwrap_or_else(|| {
199 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
200 .unwrap_or_else(|_| http::HeaderValue::from_static("request-id-unavailable"))
201 });
202
203 req.headers_mut()
204 .insert(REQUEST_ID_HEADER, request_id.clone());
205
206 let fut = self.inner.call(req);
207 Box::pin(async move {
208 let mut response = fut.await?;
209 response.headers_mut().insert(REQUEST_ID_HEADER, request_id);
210 Ok(response)
211 })
212 }
213}
214
215fn to_service_layer<L>(layer: L) -> ServiceLayer
216where
217 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
218 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
219 + Clone
220 + Send
221 + 'static,
222 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
223{
224 Arc::new(move |service: BoxHttpService| BoxCloneService::new(layer.clone().layer(service)))
225}
226
/// Values captured from a matched route pattern, keyed by `:param` / `*wildcard` name.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PathParams {
    values: HashMap<String, String>,
}
231
/// Public description of a registered route: HTTP method plus the raw path pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpRouteDescriptor {
    method: Method,
    path_pattern: String,
}
238
239impl HttpRouteDescriptor {
240 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
241 Self {
242 method,
243 path_pattern: path_pattern.into(),
244 }
245 }
246
247 pub fn method(&self) -> &Method {
248 &self.method
249 }
250
251 pub fn path_pattern(&self) -> &str {
252 &self.path_pattern
253 }
254}
255
/// Per-connection metadata captured from the upgrade request and inserted into the `Bus`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct WebSocketSessionContext {
    /// Random UUIDv4 generated for this connection.
    connection_id: uuid::Uuid,
    /// Request path of the upgrade request.
    path: String,
    /// Raw query string of the upgrade request, if any.
    query: Option<String>,
}
263
264impl WebSocketSessionContext {
265 pub fn connection_id(&self) -> uuid::Uuid {
266 self.connection_id
267 }
268
269 pub fn path(&self) -> &str {
270 &self.path
271 }
272
273 pub fn query(&self) -> Option<&str> {
274 self.query.as_deref()
275 }
276}
277
/// A WebSocket message as exposed to application handlers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WebSocketEvent {
    /// UTF-8 text frame.
    Text(String),
    /// Binary frame.
    Binary(Vec<u8>),
    /// Ping control frame and its payload.
    Ping(Vec<u8>),
    /// Pong control frame and its payload.
    Pong(Vec<u8>),
    /// Close frame (any close code/reason is discarded on receive).
    Close,
}
287
288impl WebSocketEvent {
289 pub fn text(value: impl Into<String>) -> Self {
290 Self::Text(value.into())
291 }
292
293 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
294 Self::Binary(value.into())
295 }
296
297 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
298 where
299 T: Serialize,
300 {
301 let text = serde_json::to_string(value)?;
302 Ok(Self::Text(text))
303 }
304}
305
/// Errors surfaced by [`WebSocketConnection`] send/receive helpers.
#[derive(Debug, thiserror::Error)]
pub enum WebSocketError {
    /// Underlying tungstenite protocol or I/O failure.
    #[error("websocket wire error: {0}")]
    Wire(#[from] WsWireError),
    /// Outgoing value could not be encoded as JSON.
    #[error("json serialization failed: {0}")]
    JsonSerialize(#[source] serde_json::Error),
    /// Incoming frame payload was not valid JSON for the requested type.
    #[error("json deserialization failed: {0}")]
    JsonDeserialize(#[source] serde_json::Error),
    /// A JSON payload was expected but a non-text/binary frame was received.
    #[error("expected text or binary frame for json payload")]
    NonDataFrame,
}
317
/// Server-side WebSocket stream over a hyper-upgraded connection.
type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
/// Write half of a split [`WsServerStream`].
type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
/// Read half of a split [`WsServerStream`].
type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
321
/// An accepted WebSocket connection handed to `ws` handlers.
///
/// The stream is split, and each half sits behind its own async mutex, so a send does not
/// have to wait for a pending receive (and vice versa).
pub struct WebSocketConnection {
    /// Outgoing half.
    sink: Mutex<WsServerSink>,
    /// Incoming half.
    source: Mutex<WsServerSource>,
    /// Metadata captured from the upgrade request.
    session: WebSocketSessionContext,
}
328
329impl WebSocketConnection {
330 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
331 let (sink, source) = stream.split();
332 Self {
333 sink: Mutex::new(sink),
334 source: Mutex::new(source),
335 session,
336 }
337 }
338
339 pub fn session(&self) -> &WebSocketSessionContext {
340 &self.session
341 }
342
343 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
344 let mut sink = self.sink.lock().await;
345 sink.send(event.into_wire_message()).await?;
346 Ok(())
347 }
348
349 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
350 where
351 T: Serialize,
352 {
353 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
354 self.send(event).await
355 }
356
357 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
358 where
359 T: DeserializeOwned,
360 {
361 let Some(event) = self.recv_event().await? else {
362 return Ok(None);
363 };
364 match event {
365 WebSocketEvent::Text(text) => serde_json::from_str(&text)
366 .map(Some)
367 .map_err(WebSocketError::JsonDeserialize),
368 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
369 .map(Some)
370 .map_err(WebSocketError::JsonDeserialize),
371 _ => Err(WebSocketError::NonDataFrame),
372 }
373 }
374
375 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
376 let mut source = self.source.lock().await;
377 while let Some(item) = source.next().await {
378 let message = item?;
379 if let Some(event) = WebSocketEvent::from_wire_message(message) {
380 return Ok(Some(event));
381 }
382 }
383 Ok(None)
384 }
385}
386
387impl WebSocketEvent {
388 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
389 match message {
390 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
391 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
392 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
393 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
394 WsWireMessage::Close(_) => Some(Self::Close),
395 WsWireMessage::Frame(_) => None,
396 }
397 }
398
399 fn into_wire_message(self) -> WsWireMessage {
400 match self {
401 Self::Text(value) => WsWireMessage::Text(value.into()),
402 Self::Binary(value) => WsWireMessage::Binary(value.into()),
403 Self::Ping(value) => WsWireMessage::Ping(value.into()),
404 Self::Pong(value) => WsWireMessage::Pong(value.into()),
405 Self::Close => WsWireMessage::Close(None),
406 }
407 }
408}
409
410#[async_trait::async_trait]
411impl EventSource<WebSocketEvent> for WebSocketConnection {
412 async fn next_event(&mut self) -> Option<WebSocketEvent> {
413 match self.recv_event().await {
414 Ok(event) => event,
415 Err(error) => {
416 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
417 None
418 }
419 }
420 }
421}
422
423#[async_trait::async_trait]
424impl EventSink<WebSocketEvent> for WebSocketConnection {
425 type Error = WebSocketError;
426
427 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
428 self.send(event).await
429 }
430}
431
432#[async_trait::async_trait]
433impl EventSink<String> for WebSocketConnection {
434 type Error = WebSocketError;
435
436 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
437 self.send(WebSocketEvent::Text(event)).await
438 }
439}
440
441#[async_trait::async_trait]
442impl EventSink<Vec<u8>> for WebSocketConnection {
443 type Error = WebSocketError;
444
445 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
446 self.send(WebSocketEvent::Binary(event)).await
447 }
448}
449
450impl PathParams {
451 pub fn new(values: HashMap<String, String>) -> Self {
452 Self { values }
453 }
454
455 pub fn get(&self, key: &str) -> Option<&str> {
456 self.values.get(key).map(String::as_str)
457 }
458
459 pub fn as_map(&self) -> &HashMap<String, String> {
460 &self.values
461 }
462
463 pub fn into_inner(self) -> HashMap<String, String> {
464 self.values
465 }
466}
467
/// One parsed segment of a route pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum RouteSegment {
    /// Literal segment that must match exactly.
    Static(String),
    /// `:name` — captures exactly one path segment.
    Param(String),
    /// `*name` — captures all remaining segments joined by `/` (possibly empty).
    Wildcard(String),
}
474
/// A compiled route pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
struct RoutePattern {
    /// The pattern string exactly as registered (reported by `route_descriptors`).
    raw: String,
    /// Parsed segments, with empty segments dropped.
    segments: Vec<RouteSegment>,
}
480
481impl RoutePattern {
482 fn parse(path: &str) -> Self {
483 let segments = path_segments(path)
484 .into_iter()
485 .map(|segment| {
486 if let Some(name) = segment.strip_prefix(':') {
487 if !name.is_empty() {
488 return RouteSegment::Param(name.to_string());
489 }
490 }
491 if let Some(name) = segment.strip_prefix('*') {
492 if !name.is_empty() {
493 return RouteSegment::Wildcard(name.to_string());
494 }
495 }
496 RouteSegment::Static(segment.to_string())
497 })
498 .collect();
499
500 Self {
501 raw: path.to_string(),
502 segments,
503 }
504 }
505
506 fn match_path(&self, path: &str) -> Option<PathParams> {
507 let mut params = HashMap::new();
508 let path_segments = path_segments(path);
509 let mut pattern_index = 0usize;
510 let mut path_index = 0usize;
511
512 while pattern_index < self.segments.len() {
513 match &self.segments[pattern_index] {
514 RouteSegment::Static(expected) => {
515 let actual = path_segments.get(path_index)?;
516 if actual != expected {
517 return None;
518 }
519 pattern_index += 1;
520 path_index += 1;
521 }
522 RouteSegment::Param(name) => {
523 let actual = path_segments.get(path_index)?;
524 params.insert(name.clone(), (*actual).to_string());
525 pattern_index += 1;
526 path_index += 1;
527 }
528 RouteSegment::Wildcard(name) => {
529 let remaining = path_segments[path_index..].join("/");
530 params.insert(name.clone(), remaining);
531 pattern_index += 1;
532 path_index = path_segments.len();
533 break;
534 }
535 }
536 }
537
538 if pattern_index == self.segments.len() && path_index == path_segments.len() {
539 Some(PathParams::new(params))
540 } else {
541 None
542 }
543 }
544}
545
/// A registered route: method, pattern, handler and per-route middleware.
#[derive(Clone)]
struct RouteEntry<R> {
    /// HTTP method that must match exactly.
    method: Method,
    /// Compiled path pattern.
    pattern: RoutePattern,
    /// Type-erased handler.
    handler: RouteHandler<R>,
    /// Route-specific layers (empty unless registered via a `*_with_layer*` method).
    layers: Arc<Vec<ServiceLayer>>,
    /// `false` for the `*_layer_override` variants; presumably the serve path then
    /// skips global layers for this route — TODO confirm against the dispatch code.
    apply_global_layers: bool,
}
554
/// Splits a URL path on `/`, discarding empty pieces.
///
/// Leading, trailing and repeated slashes are ignored, so both `/` and `""` yield
/// no segments.
fn path_segments(path: &str) -> Vec<&str> {
    path.split('/').filter(|piece| !piece.is_empty()).collect()
}
565
/// Ensures a route path starts with `/`; an empty path becomes `/`.
fn normalize_route_path(path: String) -> String {
    match path {
        empty if empty.is_empty() => String::from("/"),
        rooted if rooted.starts_with('/') => rooted,
        relative => format!("/{relative}"),
    }
}
576
577fn find_matching_route<'a, R>(
578 routes: &'a [RouteEntry<R>],
579 method: &Method,
580 path: &str,
581) -> Option<(&'a RouteEntry<R>, PathParams)> {
582 for entry in routes {
583 if &entry.method != method {
584 continue;
585 }
586 if let Some(params) = entry.pattern.match_path(path) {
587 return Some((entry, params));
588 }
589 }
590 None
591}
592
593fn header_contains_token(
594 headers: &http::HeaderMap,
595 name: http::header::HeaderName,
596 token: &str,
597) -> bool {
598 headers
599 .get(name)
600 .and_then(|value| value.to_str().ok())
601 .map(|value| {
602 value
603 .split(',')
604 .any(|part| part.trim().eq_ignore_ascii_case(token))
605 })
606 .unwrap_or(false)
607}
608
609fn websocket_session_from_request<B>(req: &Request<B>) -> WebSocketSessionContext {
610 WebSocketSessionContext {
611 connection_id: uuid::Uuid::new_v4(),
612 path: req.uri().path().to_string(),
613 query: req.uri().query().map(str::to_string),
614 }
615}
616
617fn websocket_accept_key(client_key: &str) -> String {
618 let mut hasher = Sha1::new();
619 hasher.update(client_key.as_bytes());
620 hasher.update(WS_GUID.as_bytes());
621 let digest = hasher.finalize();
622 base64::engine::general_purpose::STANDARD.encode(digest)
623}
624
625fn websocket_bad_request(message: &'static str) -> HttpResponse {
626 Response::builder()
627 .status(StatusCode::BAD_REQUEST)
628 .body(Full::new(Bytes::from(message)).map_err(|never| match never {}).boxed())
629 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new()).map_err(|never| match never {}).boxed()))
630}
631
632fn websocket_upgrade_response<B>(
633 req: &mut Request<B>,
634) -> Result<(HttpResponse, hyper::upgrade::OnUpgrade), HttpResponse> {
635 if req.method() != Method::GET {
636 return Err(websocket_bad_request(
637 "WebSocket upgrade requires GET method",
638 ));
639 }
640
641 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
642 return Err(websocket_bad_request(
643 "Missing Connection: upgrade header for WebSocket",
644 ));
645 }
646
647 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
648 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
649 }
650
651 if let Some(version) = req.headers().get("sec-websocket-version") {
652 if version != "13" {
653 return Err(websocket_bad_request(
654 "Unsupported Sec-WebSocket-Version (expected 13)",
655 ));
656 }
657 }
658
659 let Some(client_key) = req
660 .headers()
661 .get("sec-websocket-key")
662 .and_then(|value| value.to_str().ok())
663 else {
664 return Err(websocket_bad_request(
665 "Missing Sec-WebSocket-Key header for WebSocket",
666 ));
667 };
668
669 let accept_key = websocket_accept_key(client_key);
670 let on_upgrade = hyper::upgrade::on(req);
671 let response = Response::builder()
672 .status(StatusCode::SWITCHING_PROTOCOLS)
673 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
674 .header(http::header::CONNECTION, "Upgrade")
675 .header("sec-websocket-accept", accept_key)
676 .body(Full::new(Bytes::new()).map_err(|never| match never {}).boxed())
677 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new()).map_err(|never| match never {}).boxed()));
678
679 Ok((response, on_upgrade))
680}
681
/// Builder for an HTTP server.
///
/// Collects routes, global layers, lifecycle hooks, bus injectors, static-asset mounts
/// and health-probe configuration. `R` is the shared resource type handed to every
/// route circuit and WebSocket handler.
pub struct HttpIngress<R = ()> {
    /// Address given to `bind`; `None` until set.
    addr: Option<String>,
    /// Registered routes, in registration order.
    routes: Vec<RouteEntry<R>>,
    /// Optional fallback handler (no setter in this chunk); presumably used when no
    /// route matches — TODO confirm.
    fallback: Option<RouteHandler<R>>,
    /// Global middleware added via `layer`, `timeout_layer` and `request_id_layer`.
    layers: Vec<ServiceLayer>,
    /// Hook set by `on_start`.
    on_start: Option<LifecycleHook>,
    /// Hook set by `on_shutdown`.
    on_shutdown: Option<LifecycleHook>,
    /// Grace period set by `graceful_shutdown`; `new()` defaults it to 30 seconds.
    graceful_shutdown_timeout: Duration,
    /// Injectors added via `bus_injector`. Routes snapshot this list when they are registered.
    bus_injectors: Vec<BusInjector>,
    /// Static file serving configuration.
    static_assets: StaticAssetsConfig,
    /// Health/readiness/liveness configuration.
    health: HealthConfig<R>,
    /// HTTP/3 settings set by `enable_http3`.
    #[cfg(feature = "http3")]
    http3_config: Option<crate::http3::Http3Config>,
    /// Port advertised via `alt_svc_h3`.
    #[cfg(feature = "http3")]
    alt_svc_h3_port: Option<u16>,
    // NOTE(review): `R` is already used by `routes`/`health`, so this marker looks redundant.
    _phantom: std::marker::PhantomData<R>,
}
714
715impl<R> HttpIngress<R>
716where
717 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
718{
719 pub fn new() -> Self {
721 Self {
722 addr: None,
723 routes: Vec::new(),
724 fallback: None,
725 layers: Vec::new(),
726 on_start: None,
727 on_shutdown: None,
728 graceful_shutdown_timeout: Duration::from_secs(30),
729 bus_injectors: Vec::new(),
730 static_assets: StaticAssetsConfig::default(),
731 health: HealthConfig::default(),
732 #[cfg(feature = "http3")]
733 http3_config: None,
734 #[cfg(feature = "http3")]
735 alt_svc_h3_port: None,
736 _phantom: std::marker::PhantomData,
737 }
738 }
739
740 pub fn bind(mut self, addr: impl Into<String>) -> Self {
742 self.addr = Some(addr.into());
743 self
744 }
745
746 pub fn on_start<F>(mut self, callback: F) -> Self
748 where
749 F: Fn() + Send + Sync + 'static,
750 {
751 self.on_start = Some(Arc::new(callback));
752 self
753 }
754
755 pub fn on_shutdown<F>(mut self, callback: F) -> Self
757 where
758 F: Fn() + Send + Sync + 'static,
759 {
760 self.on_shutdown = Some(Arc::new(callback));
761 self
762 }
763
764 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
766 self.graceful_shutdown_timeout = timeout;
767 self
768 }
769
770 pub fn layer<L>(mut self, layer: L) -> Self
775 where
776 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
777 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
778 + Clone
779 + Send
780 + 'static,
781 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
782 {
783 self.layers.push(to_service_layer(layer));
784 self
785 }
786
787 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
790 self.layers.push(Arc::new(move |service: BoxHttpService| {
791 BoxCloneService::new(TimeoutService {
792 inner: service,
793 timeout,
794 })
795 }));
796 self
797 }
798
799 pub fn request_id_layer(mut self) -> Self {
803 self.layers.push(Arc::new(move |service: BoxHttpService| {
804 BoxCloneService::new(RequestIdService { inner: service })
805 }));
806 self
807 }
808
809 pub fn bus_injector<F>(mut self, injector: F) -> Self
814 where
815 F: Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static,
816 {
817 self.bus_injectors.push(Arc::new(injector));
818 self
819 }
820
821 #[cfg(feature = "http3")]
823 pub fn enable_http3(mut self, config: crate::http3::Http3Config) -> Self {
824 self.http3_config = Some(config);
825 self
826 }
827
828 #[cfg(feature = "http3")]
830 pub fn alt_svc_h3(mut self, port: u16) -> Self {
831 self.alt_svc_h3_port = Some(port);
832 self
833 }
834
835 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
837 let mut descriptors = self
838 .routes
839 .iter()
840 .map(|entry| HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone()))
841 .collect::<Vec<_>>();
842
843 if let Some(path) = &self.health.health_path {
844 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
845 }
846 if let Some(path) = &self.health.readiness_path {
847 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
848 }
849 if let Some(path) = &self.health.liveness_path {
850 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
851 }
852
853 descriptors
854 }
855
856 pub fn serve_dir(
860 mut self,
861 route_prefix: impl Into<String>,
862 directory: impl Into<String>,
863 ) -> Self {
864 self.static_assets.mounts.push(StaticMount {
865 route_prefix: normalize_route_path(route_prefix.into()),
866 directory: directory.into(),
867 });
868 if self.static_assets.cache_control.is_none() {
869 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
870 }
871 self
872 }
873
874 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
878 self.static_assets.spa_fallback = Some(file_path.into());
879 self
880 }
881
882 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
884 self.static_assets.cache_control = Some(cache_control.into());
885 self
886 }
887
888 pub fn compression_layer(mut self) -> Self {
890 self.static_assets.enable_compression = true;
891 self
892 }
893
894 pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
901 where
902 H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
903 Fut: Future<Output = ()> + Send + 'static,
904 {
905 let path_str: String = path.into();
906 let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
907 Box::pin(handler(connection, resources, bus))
908 });
909 let bus_injectors = Arc::new(self.bus_injectors.clone());
910 let path_for_pattern = path_str.clone();
911 let path_for_handler = path_str;
912
913 let route_handler: RouteHandler<R> =
914 Arc::new(move |parts: http::request::Parts, res: &R| {
915 let ws_handler = ws_handler.clone();
916 let bus_injectors = bus_injectors.clone();
917 let resources = Arc::new(res.clone());
918 let path = path_for_handler.clone();
919
920 Box::pin(async move {
921 let request_id = uuid::Uuid::new_v4().to_string();
922 let span = tracing::info_span!(
923 "WebSocketUpgrade",
924 ranvier.ws.path = %path,
925 ranvier.ws.request_id = %request_id
926 );
927
928 async move {
929 let mut bus = Bus::new();
930 for injector in bus_injectors.iter() {
931 injector(&parts, &mut bus);
932 }
933
934 let mut req = Request::from_parts(parts, ());
936 let session = websocket_session_from_request(&req);
937 bus.insert(session.clone());
938
939 let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
940 Ok(result) => result,
941 Err(error_response) => return error_response,
942 };
943
944 tokio::spawn(async move {
945 match on_upgrade.await {
946 Ok(upgraded) => {
947 let stream = WebSocketStream::from_raw_socket(
948 TokioIo::new(upgraded),
949 tokio_tungstenite::tungstenite::protocol::Role::Server,
950 None,
951 )
952 .await;
953 let connection = WebSocketConnection::new(stream, session);
954 ws_handler(connection, resources, bus).await;
955 }
956 Err(error) => {
957 tracing::warn!(
958 ranvier.ws.path = %path,
959 ranvier.ws.error = %error,
960 "websocket upgrade failed"
961 );
962 }
963 }
964 });
965
966 response
967 }
968 .instrument(span)
969 .await
970 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
971 });
972
973 self.routes.push(RouteEntry {
974 method: Method::GET,
975 pattern: RoutePattern::parse(&path_for_pattern),
976 handler: route_handler,
977 layers: Arc::new(Vec::new()),
978 apply_global_layers: true,
979 });
980
981 self
982 }
983
984 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
989 self.health.health_path = Some(normalize_route_path(path.into()));
990 self
991 }
992
993 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
997 where
998 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
999 Fut: Future<Output = Result<(), Err>> + Send + 'static,
1000 Err: ToString + Send + 'static,
1001 {
1002 if self.health.health_path.is_none() {
1003 self.health.health_path = Some("/health".to_string());
1004 }
1005
1006 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
1007 let fut = check(resources);
1008 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
1009 });
1010
1011 self.health.checks.push(NamedHealthCheck {
1012 name: name.into(),
1013 check: check_fn,
1014 });
1015 self
1016 }
1017
1018 pub fn readiness_liveness(
1020 mut self,
1021 readiness_path: impl Into<String>,
1022 liveness_path: impl Into<String>,
1023 ) -> Self {
1024 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1025 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1026 self
1027 }
1028
1029 pub fn readiness_liveness_default(self) -> Self {
1031 self.readiness_liveness("/ready", "/live")
1032 }
1033
1034 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1036 where
1037 Out: IntoResponse + Send + Sync + 'static,
1038 E: Send + 'static + std::fmt::Debug,
1039 {
1040 self.route_method(Method::GET, path, circuit)
1041 }
1042 pub fn route_method<Out, E>(
1051 self,
1052 method: Method,
1053 path: impl Into<String>,
1054 circuit: Axon<(), Out, E, R>,
1055 ) -> Self
1056 where
1057 Out: IntoResponse + Send + Sync + 'static,
1058 E: Send + 'static + std::fmt::Debug,
1059 {
1060 self.route_method_with_error(method, path, circuit, |error| {
1061 (
1062 StatusCode::INTERNAL_SERVER_ERROR,
1063 format!("Error: {:?}", error),
1064 ).into_response()
1065 })
1066 }
1067
1068 pub fn route_method_with_error<Out, E, H>(
1069 self,
1070 method: Method,
1071 path: impl Into<String>,
1072 circuit: Axon<(), Out, E, R>,
1073 error_handler: H,
1074 ) -> Self
1075 where
1076 Out: IntoResponse + Send + Sync + 'static,
1077 E: Send + 'static + std::fmt::Debug,
1078 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1079 {
1080 self.route_method_with_error_and_layers(
1081 method,
1082 path,
1083 circuit,
1084 error_handler,
1085 Arc::new(Vec::new()),
1086 true,
1087 )
1088 }
1089
1090 pub fn route_method_with_layer<Out, E, L>(
1091 self,
1092 method: Method,
1093 path: impl Into<String>,
1094 circuit: Axon<(), Out, E, R>,
1095 layer: L,
1096 ) -> Self
1097 where
1098 Out: IntoResponse + Send + Sync + 'static,
1099 E: Send + 'static + std::fmt::Debug,
1100 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1101 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
1102 + Clone
1103 + Send
1104 + 'static,
1105 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1106 {
1107 self.route_method_with_error_and_layers(
1108 method,
1109 path,
1110 circuit,
1111 |error| {
1112 (
1113 StatusCode::INTERNAL_SERVER_ERROR,
1114 format!("Error: {:?}", error),
1115 ).into_response()
1116 },
1117 Arc::new(vec![to_service_layer(layer)]),
1118 true,
1119 )
1120 }
1121
1122 pub fn route_method_with_layer_override<Out, E, L>(
1123 self,
1124 method: Method,
1125 path: impl Into<String>,
1126 circuit: Axon<(), Out, E, R>,
1127 layer: L,
1128 ) -> Self
1129 where
1130 Out: IntoResponse + Send + Sync + 'static,
1131 E: Send + 'static + std::fmt::Debug,
1132 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1133 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
1134 + Clone
1135 + Send
1136 + 'static,
1137 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1138 {
1139 self.route_method_with_error_and_layers(
1140 method,
1141 path,
1142 circuit,
1143 |error| {
1144 (
1145 StatusCode::INTERNAL_SERVER_ERROR,
1146 format!("Error: {:?}", error),
1147 ).into_response()
1148 },
1149 Arc::new(vec![to_service_layer(layer)]),
1150 false,
1151 )
1152 }
1153
1154 fn route_method_with_error_and_layers<Out, E, H>(
1155 mut self,
1156 method: Method,
1157 path: impl Into<String>,
1158 circuit: Axon<(), Out, E, R>,
1159 error_handler: H,
1160 route_layers: Arc<Vec<ServiceLayer>>,
1161 apply_global_layers: bool,
1162 ) -> Self
1163 where
1164 Out: IntoResponse + Send + Sync + 'static,
1165 E: Send + 'static + std::fmt::Debug,
1166 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1167 {
1168 let path_str: String = path.into();
1169 let circuit = Arc::new(circuit);
1170 let error_handler = Arc::new(error_handler);
1171 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1172 let path_for_pattern = path_str.clone();
1173 let path_for_handler = path_str;
1174 let method_for_pattern = method.clone();
1175 let method_for_handler = method;
1176
1177 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1178 let circuit = circuit.clone();
1179 let error_handler = error_handler.clone();
1180 let route_bus_injectors = route_bus_injectors.clone();
1181 let res = res.clone();
1182 let path = path_for_handler.clone();
1183 let method = method_for_handler.clone();
1184
1185 Box::pin(async move {
1186 let request_id = uuid::Uuid::new_v4().to_string();
1187 let span = tracing::info_span!(
1188 "HTTPRequest",
1189 ranvier.http.method = %method,
1190 ranvier.http.path = %path,
1191 ranvier.http.request_id = %request_id
1192 );
1193
1194 async move {
1195 let mut bus = Bus::new();
1196 for injector in route_bus_injectors.iter() {
1197 injector(&parts, &mut bus);
1198 }
1199 let result = circuit.execute((), &res, &mut bus).await;
1200 outcome_to_response_with_error(result, |error| error_handler(error))
1201 }
1202 .instrument(span)
1203 .await
1204 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1205 });
1206
1207 self.routes.push(RouteEntry {
1208 method: method_for_pattern,
1209 pattern: RoutePattern::parse(&path_for_pattern),
1210 handler,
1211 layers: route_layers,
1212 apply_global_layers,
1213 });
1214 self
1215 }
1216
1217 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1218 where
1219 Out: IntoResponse + Send + Sync + 'static,
1220 E: Send + 'static + std::fmt::Debug,
1221 {
1222 self.route_method(Method::GET, path, circuit)
1223 }
1224
1225 pub fn get_with_error<Out, E, H>(
1226 self,
1227 path: impl Into<String>,
1228 circuit: Axon<(), Out, E, R>,
1229 error_handler: H,
1230 ) -> Self
1231 where
1232 Out: IntoResponse + Send + Sync + 'static,
1233 E: Send + 'static + std::fmt::Debug,
1234 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1235 {
1236 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1237 }
1238
1239 pub fn get_with_layer<Out, E, L>(
1240 self,
1241 path: impl Into<String>,
1242 circuit: Axon<(), Out, E, R>,
1243 layer: L,
1244 ) -> Self
1245 where
1246 Out: IntoResponse + Send + Sync + 'static,
1247 E: Send + 'static + std::fmt::Debug,
1248 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1249 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
1250 + Clone
1251 + Send
1252 + 'static,
1253 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1254 {
1255 self.route_method_with_layer(Method::GET, path, circuit, layer)
1256 }
1257
1258 pub fn get_with_layer_override<Out, E, L>(
1259 self,
1260 path: impl Into<String>,
1261 circuit: Axon<(), Out, E, R>,
1262 layer: L,
1263 ) -> Self
1264 where
1265 Out: IntoResponse + Send + Sync + 'static,
1266 E: Send + 'static + std::fmt::Debug,
1267 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1268 L::Service: Service<Request<Incoming>, Response = HttpResponse, Error = Infallible>
1269 + Clone
1270 + Send
1271 + 'static,
1272 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1273 {
1274 self.route_method_with_layer_override(Method::GET, path, circuit, layer)
1275 }
1276
1277 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1278 where
1279 Out: IntoResponse + Send + Sync + 'static,
1280 E: Send + 'static + std::fmt::Debug,
1281 {
1282 self.route_method(Method::POST, path, circuit)
1283 }
1284
1285 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1286 where
1287 Out: IntoResponse + Send + Sync + 'static,
1288 E: Send + 'static + std::fmt::Debug,
1289 {
1290 self.route_method(Method::PUT, path, circuit)
1291 }
1292
1293 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1294 where
1295 Out: IntoResponse + Send + Sync + 'static,
1296 E: Send + 'static + std::fmt::Debug,
1297 {
1298 self.route_method(Method::DELETE, path, circuit)
1299 }
1300
1301 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1302 where
1303 Out: IntoResponse + Send + Sync + 'static,
1304 E: Send + 'static + std::fmt::Debug,
1305 {
1306 self.route_method(Method::PATCH, path, circuit)
1307 }
1308
1309 pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
1320 where
1321 Out: IntoResponse + Send + Sync + 'static,
1322 E: Send + 'static + std::fmt::Debug,
1323 {
1324 let circuit = Arc::new(circuit);
1325 let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
1326
1327 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1328 let circuit = circuit.clone();
1329 let fallback_bus_injectors = fallback_bus_injectors.clone();
1330 let res = res.clone();
1331 Box::pin(async move {
1332 let request_id = uuid::Uuid::new_v4().to_string();
1333 let span = tracing::info_span!(
1334 "HTTPRequest",
1335 ranvier.http.method = "FALLBACK",
1336 ranvier.http.request_id = %request_id
1337 );
1338
1339 async move {
1340 let mut bus = Bus::new();
1341 for injector in fallback_bus_injectors.iter() {
1342 injector(&parts, &mut bus);
1343 }
1344 let result = circuit.execute((), &res, &mut bus).await;
1345
1346 match result {
1347 Outcome::Next(output) => {
1348 let mut response = output.into_response();
1349 *response.status_mut() = StatusCode::NOT_FOUND;
1350 response
1351 }
1352 _ => Response::builder()
1353 .status(StatusCode::NOT_FOUND)
1354 .body(Full::new(Bytes::from("Not Found")).map_err(|never| match never {}).boxed())
1355 .unwrap(),
1356 }
1357 }
1358 .instrument(span)
1359 .await
1360 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1361 });
1362
1363 self.fallback = Some(handler);
1364 self
1365 }
1366
1367 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1369 self.run_with_shutdown_signal(resources, shutdown_signal())
1370 .await
1371 }
1372
1373 async fn run_with_shutdown_signal<S>(
1374 self,
1375 resources: R,
1376 shutdown_signal: S,
1377 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
1378 where
1379 S: Future<Output = ()> + Send,
1380 {
1381 let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
1382 let addr: SocketAddr = addr_str.parse()?;
1383
1384 let routes = Arc::new(self.routes);
1385 let fallback = self.fallback;
1386 let layers = Arc::new(self.layers);
1387 let health = Arc::new(self.health);
1388 let static_assets = Arc::new(self.static_assets);
1389 let on_start = self.on_start;
1390 let on_shutdown = self.on_shutdown;
1391 let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
1392 let resources = Arc::new(resources);
1393
1394 let listener = TcpListener::bind(addr).await?;
1395 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
1396 if let Some(callback) = on_start.as_ref() {
1397 callback();
1398 }
1399
1400 tokio::pin!(shutdown_signal);
1401 let mut connections = tokio::task::JoinSet::new();
1402
1403 loop {
1404 tokio::select! {
1405 _ = &mut shutdown_signal => {
1406 tracing::info!("Shutdown signal received. Draining in-flight connections.");
1407 break;
1408 }
1409 accept_result = listener.accept() => {
1410 let (stream, _) = accept_result?;
1411 let io = TokioIo::new(stream);
1412
1413 let routes = routes.clone();
1414 let fallback = fallback.clone();
1415 let resources = resources.clone();
1416 let layers = layers.clone();
1417 let health = health.clone();
1418 let static_assets = static_assets.clone();
1419 #[cfg(feature = "http3")]
1420 let alt_svc_h3_port = self.alt_svc_h3_port;
1421
1422 connections.spawn(async move {
1423 let service = build_http_service(
1424 routes,
1425 fallback,
1426 resources,
1427 layers,
1428 health,
1429 static_assets,
1430 #[cfg(feature = "http3")] alt_svc_h3_port,
1431 );
1432 let hyper_service = TowerToHyperService::new(service);
1433 if let Err(err) = http1::Builder::new()
1434 .serve_connection(io, hyper_service)
1435 .with_upgrades()
1436 .await
1437 {
1438 tracing::error!("Error serving connection: {:?}", err);
1439 }
1440 });
1441 }
1442 Some(join_result) = connections.join_next(), if !connections.is_empty() => {
1443 if let Err(err) = join_result {
1444 tracing::warn!("Connection task join error: {:?}", err);
1445 }
1446 }
1447 }
1448 }
1449
1450 let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;
1451
1452 drop(resources);
1453 if let Some(callback) = on_shutdown.as_ref() {
1454 callback();
1455 }
1456
1457 Ok(())
1458 }
1459
1460 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
1476 let routes = Arc::new(self.routes);
1477 let fallback = self.fallback;
1478 let layers = Arc::new(self.layers);
1479 let health = Arc::new(self.health);
1480 let static_assets = Arc::new(self.static_assets);
1481 let resources = Arc::new(resources);
1482
1483 RawIngressService {
1484 routes,
1485 fallback,
1486 layers,
1487 health,
1488 static_assets,
1489 resources,
1490 }
1491 }
1492}
1493
1494fn build_http_service<R>(
1495 routes: Arc<Vec<RouteEntry<R>>>,
1496 fallback: Option<RouteHandler<R>>,
1497 resources: Arc<R>,
1498 layers: Arc<Vec<ServiceLayer>>,
1499 health: Arc<HealthConfig<R>>,
1500 static_assets: Arc<StaticAssetsConfig>,
1501 #[cfg(feature = "http3")] alt_svc_port: Option<u16>,
1502) -> BoxHttpService
1503where
1504 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1505{
1506 let base_service = service_fn(move |req: Request<Incoming>| {
1507 let routes = routes.clone();
1508 let fallback = fallback.clone();
1509 let resources = resources.clone();
1510 let layers = layers.clone();
1511 let health = health.clone();
1512 let static_assets = static_assets.clone();
1513
1514 async move {
1515 let mut req = req;
1516 let method = req.method().clone();
1517 let path = req.uri().path().to_string();
1518
1519 if let Some(response) =
1520 maybe_handle_health_request(&method, &path, &health, resources.clone()).await
1521 {
1522 return Ok::<_, Infallible>(response.into_response());
1523 }
1524
1525 if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
1526 req.extensions_mut().insert(params);
1527 let effective_layers = if entry.apply_global_layers {
1528 merge_layers(&layers, &entry.layers)
1529 } else {
1530 entry.layers.clone()
1531 };
1532
1533 if effective_layers.is_empty() {
1534 let (parts, _) = req.into_parts();
1535 let mut res = (entry.handler)(parts, &resources).await;
1536 #[cfg(feature = "http3")]
1537 if let Some(port) = alt_svc_port {
1538 if let Ok(val) = http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port)) {
1539 res.headers_mut().insert(http::header::ALT_SVC, val);
1540 }
1541 }
1542 Ok::<_, Infallible>(res)
1543 } else {
1544 let route_service = build_route_service(
1545 entry.handler.clone(),
1546 resources.clone(),
1547 effective_layers,
1548 );
1549 let mut res = route_service.oneshot(req).await;
1550 #[cfg(feature = "http3")]
1551 if let Ok(ref mut r) = res {
1552 if let Some(port) = alt_svc_port {
1553 if let Ok(val) = http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port)) {
1554 r.headers_mut().insert(http::header::ALT_SVC, val);
1555 }
1556 }
1557 }
1558 res
1559 }
1560 } else {
1561 let req =
1562 match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
1563 .await
1564 {
1565 Ok(req) => req,
1566 Err(response) => return Ok(response),
1567 };
1568
1569 let mut fallback_res = if let Some(ref fb) = fallback {
1570 if layers.is_empty() {
1571 let (parts, _) = req.into_parts();
1572 Ok(fb(parts, &resources).await)
1573 } else {
1574 let fallback_service =
1575 build_route_service(fb.clone(), resources.clone(), layers.clone());
1576 fallback_service.oneshot(req).await
1577 }
1578 } else {
1579 Ok(Response::builder()
1580 .status(StatusCode::NOT_FOUND)
1581 .body(Full::new(Bytes::from("Not Found")).map_err(|never| match never {}).boxed())
1582 .unwrap())
1583 };
1584
1585 #[cfg(feature = "http3")]
1586 if let Ok(r) = fallback_res.as_mut() {
1587 if let Some(port) = alt_svc_port {
1588 if let Ok(val) = http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port)) {
1589 r.headers_mut().insert(http::header::ALT_SVC, val);
1590 }
1591 }
1592 }
1593
1594 fallback_res
1595 }
1596 }
1597 });
1598
1599 BoxCloneService::new(base_service)
1600}
1601
1602fn build_route_service<R>(
1603 handler: RouteHandler<R>,
1604 resources: Arc<R>,
1605 layers: Arc<Vec<ServiceLayer>>,
1606) -> BoxHttpService
1607where
1608 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1609{
1610 let base_service = service_fn(move |req: Request<Incoming>| {
1611 let handler = handler.clone();
1612 let resources = resources.clone();
1613 let (parts, _) = req.into_parts();
1614 async move { Ok::<_, Infallible>(handler(parts, &resources).await) }
1615 });
1616
1617 let mut service = BoxCloneService::new(base_service);
1618 for layer in layers.iter() {
1619 service = layer(service);
1620 }
1621 service
1622}
1623
1624fn merge_layers(
1625 global_layers: &Arc<Vec<ServiceLayer>>,
1626 route_layers: &Arc<Vec<ServiceLayer>>,
1627) -> Arc<Vec<ServiceLayer>> {
1628 if global_layers.is_empty() {
1629 return route_layers.clone();
1630 }
1631 if route_layers.is_empty() {
1632 return global_layers.clone();
1633 }
1634
1635 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
1636 combined.extend(global_layers.iter().cloned());
1637 combined.extend(route_layers.iter().cloned());
1638 Arc::new(combined)
1639}
1640
1641async fn maybe_handle_health_request<R>(
1642 method: &Method,
1643 path: &str,
1644 health: &HealthConfig<R>,
1645 resources: Arc<R>,
1646) -> Option<HttpResponse>
1647where
1648 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1649{
1650 if method != Method::GET {
1651 return None;
1652 }
1653
1654 if let Some(liveness_path) = health.liveness_path.as_ref() {
1655 if path == liveness_path {
1656 return Some(health_json_response("liveness", true, Vec::new()));
1657 }
1658 }
1659
1660 if let Some(readiness_path) = health.readiness_path.as_ref() {
1661 if path == readiness_path {
1662 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1663 return Some(health_json_response("readiness", healthy, checks));
1664 }
1665 }
1666
1667 if let Some(health_path) = health.health_path.as_ref() {
1668 if path == health_path {
1669 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1670 return Some(health_json_response("health", healthy, checks));
1671 }
1672 }
1673
1674 None
1675}
1676
1677async fn maybe_handle_static_request(
1678 req: Request<Incoming>,
1679 method: &Method,
1680 path: &str,
1681 static_assets: &StaticAssetsConfig,
1682) -> Result<Request<Incoming>, HttpResponse> {
1683 if method != Method::GET && method != Method::HEAD {
1684 return Ok(req);
1685 }
1686
1687 if let Some(mount) = static_assets
1688 .mounts
1689 .iter()
1690 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
1691 {
1692 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1693 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
1694 return Ok(req);
1695 };
1696 let rewritten = rewrite_request_path(req, &stripped_path);
1697 let service = ServeDir::new(&mount.directory);
1698 let response = match service.oneshot(rewritten).await {
1699 Ok(response) => response,
1700 Err(_) => {
1701 return Err(Response::builder()
1702 .status(StatusCode::INTERNAL_SERVER_ERROR)
1703 .body(Full::new(Bytes::from("Failed to serve static asset")).map_err(|never| match never {}).boxed())
1704 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new()).map_err(|never| match never {}).boxed())));
1705 }
1706 };
1707 let response =
1708 collect_static_response(response, static_assets.cache_control.as_deref()).await;
1709 let response = maybe_compress_static_response(
1710 response,
1711 accept_encoding,
1712 static_assets.enable_compression,
1713 )
1714 .await;
1715 let (parts, body) = response.into_parts();
1716 return Err(Response::from_parts(parts, body.map_err(|never| match never {}).boxed()));
1717 }
1718
1719 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
1720 if looks_like_spa_request(path) {
1721 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1722 let service = ServeFile::new(spa_file);
1723 let response = match service.oneshot(req).await {
1724 Ok(response) => response,
1725 Err(_) => {
1726 return Err(Response::builder()
1727 .status(StatusCode::INTERNAL_SERVER_ERROR)
1728 .body(Full::new(Bytes::from("Failed to serve SPA fallback")).map_err(|never| match never {}).boxed())
1729 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new()).map_err(|never| match never {}).boxed())));
1730 }
1731 };
1732 let response =
1733 collect_static_response(response, static_assets.cache_control.as_deref()).await;
1734 let response = maybe_compress_static_response(
1735 response,
1736 accept_encoding,
1737 static_assets.enable_compression,
1738 )
1739 .await;
1740 let (parts, body) = response.into_parts();
1741 return Err(Response::from_parts(parts, body.map_err(|never| match never {}).boxed()));
1742 }
1743 }
1744
1745 Ok(req)
1746}
1747
/// Returns `path` relative to the mount `prefix` (always starting with `/`), or `None` when
/// `path` is not under `prefix`.
///
/// A prefix of `"/"` matches everything unchanged. Other prefixes ignore trailing slashes
/// and match only on whole segment boundaries, so `/static` covers `/static` and
/// `/static/x` but not `/staticx`.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    if prefix == "/" {
        return Some(path.to_owned());
    }

    let base = prefix.trim_end_matches('/');
    if path == base {
        return Some("/".to_owned());
    }

    path.strip_prefix(base)
        .and_then(|rest| rest.strip_prefix('/'))
        .map(|rest| format!("/{rest}"))
}
1767
1768fn rewrite_request_path(mut req: Request<Incoming>, new_path: &str) -> Request<Incoming> {
1769 let query = req.uri().query().map(str::to_string);
1770 let path_and_query = match query {
1771 Some(query) => format!("{new_path}?{query}"),
1772 None => new_path.to_string(),
1773 };
1774
1775 let mut parts = req.uri().clone().into_parts();
1776 if let Ok(parsed_path_and_query) = path_and_query.parse() {
1777 parts.path_and_query = Some(parsed_path_and_query);
1778 if let Ok(uri) = Uri::from_parts(parts) {
1779 *req.uri_mut() = uri;
1780 }
1781 }
1782
1783 req
1784}
1785
1786async fn collect_static_response<B>(
1787 response: Response<B>,
1788 cache_control: Option<&str>,
1789) -> Response<Full<Bytes>>
1790where
1791 B: Body<Data = Bytes> + Send + 'static,
1792 B::Error: std::fmt::Display,
1793{
1794 let status = response.status();
1795 let headers = response.headers().clone();
1796 let body = response.into_body();
1797 let collected = body.collect().await;
1798
1799 let bytes = match collected {
1800 Ok(value) => value.to_bytes(),
1801 Err(error) => Bytes::from(error.to_string()),
1802 };
1803
1804 let mut builder = Response::builder().status(status);
1805 for (name, value) in headers.iter() {
1806 builder = builder.header(name, value);
1807 }
1808
1809 let mut response = builder
1810 .body(Full::new(bytes))
1811 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())));
1812
1813 if status == StatusCode::OK {
1814 if let Some(value) = cache_control {
1815 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
1816 if let Ok(header_value) = http::HeaderValue::from_str(value) {
1817 response
1818 .headers_mut()
1819 .insert(http::header::CACHE_CONTROL, header_value);
1820 }
1821 }
1822 }
1823 }
1824
1825 response
1826}
1827
/// Heuristic for SPA navigation: the final path segment contains no `.`, so it does not look
/// like a file request.
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = path.rsplit_once('/').map_or(path, |(_, tail)| tail);
    !last_segment.contains('.')
}
1832
1833async fn maybe_compress_static_response(
1834 response: Response<Full<Bytes>>,
1835 accept_encoding: Option<http::HeaderValue>,
1836 enable_compression: bool,
1837) -> Response<Full<Bytes>> {
1838 if !enable_compression {
1839 return response;
1840 }
1841
1842 let Some(accept_encoding) = accept_encoding else {
1843 return response;
1844 };
1845
1846 let mut request = Request::builder()
1847 .uri("/")
1848 .body(Full::new(Bytes::new()))
1849 .unwrap_or_else(|_| Request::new(Full::new(Bytes::new())));
1850 request
1851 .headers_mut()
1852 .insert(http::header::ACCEPT_ENCODING, accept_encoding);
1853
1854 let service = CompressionLayer::new().layer(service_fn({
1855 let response = response.clone();
1856 move |_req: Request<Full<Bytes>>| {
1857 let response = response.clone();
1858 async move { Ok::<_, Infallible>(response) }
1859 }
1860 }));
1861
1862 match service.oneshot(request).await {
1863 Ok(compressed) => collect_static_response(compressed, None).await,
1864 Err(_) => response,
1865 }
1866}
1867
1868async fn run_named_health_checks<R>(
1869 checks: &[NamedHealthCheck<R>],
1870 resources: Arc<R>,
1871) -> (bool, Vec<HealthCheckReport>)
1872where
1873 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1874{
1875 let mut reports = Vec::with_capacity(checks.len());
1876 let mut healthy = true;
1877
1878 for check in checks {
1879 match (check.check)(resources.clone()).await {
1880 Ok(()) => reports.push(HealthCheckReport {
1881 name: check.name.clone(),
1882 status: "ok",
1883 error: None,
1884 }),
1885 Err(error) => {
1886 healthy = false;
1887 reports.push(HealthCheckReport {
1888 name: check.name.clone(),
1889 status: "error",
1890 error: Some(error),
1891 });
1892 }
1893 }
1894 }
1895
1896 (healthy, reports)
1897}
1898
1899fn health_json_response(
1900 probe: &'static str,
1901 healthy: bool,
1902 checks: Vec<HealthCheckReport>,
1903) -> HttpResponse {
1904 let status_code = if healthy {
1905 StatusCode::OK
1906 } else {
1907 StatusCode::SERVICE_UNAVAILABLE
1908 };
1909 let status = if healthy { "ok" } else { "degraded" };
1910 let payload = HealthReport {
1911 status,
1912 probe,
1913 checks,
1914 };
1915
1916 let body = serde_json::to_vec(&payload)
1917 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
1918
1919 Response::builder()
1920 .status(status_code)
1921 .header(http::header::CONTENT_TYPE, "application/json")
1922 .body(Full::new(Bytes::from(body)).map_err(|never| match never {}).boxed())
1923 .unwrap()
1924}
1925
1926async fn shutdown_signal() {
1927 #[cfg(unix)]
1928 {
1929 use tokio::signal::unix::{SignalKind, signal};
1930
1931 match signal(SignalKind::terminate()) {
1932 Ok(mut terminate) => {
1933 tokio::select! {
1934 _ = tokio::signal::ctrl_c() => {}
1935 _ = terminate.recv() => {}
1936 }
1937 }
1938 Err(err) => {
1939 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
1940 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
1941 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
1942 }
1943 }
1944 }
1945 }
1946
1947 #[cfg(not(unix))]
1948 {
1949 if let Err(err) = tokio::signal::ctrl_c().await {
1950 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
1951 }
1952 }
1953}
1954
1955async fn drain_connections(
1956 connections: &mut tokio::task::JoinSet<()>,
1957 graceful_shutdown_timeout: Duration,
1958) -> bool {
1959 if connections.is_empty() {
1960 return false;
1961 }
1962
1963 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
1964 while let Some(join_result) = connections.join_next().await {
1965 if let Err(err) = join_result {
1966 tracing::warn!("Connection task join error during shutdown: {:?}", err);
1967 }
1968 }
1969 })
1970 .await;
1971
1972 if drain_result.is_err() {
1973 tracing::warn!(
1974 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
1975 graceful_shutdown_timeout
1976 );
1977 connections.abort_all();
1978 while let Some(join_result) = connections.join_next().await {
1979 if let Err(err) = join_result {
1980 tracing::warn!("Connection task abort join error: {:?}", err);
1981 }
1982 }
1983 true
1984 } else {
1985 false
1986 }
1987}
1988
1989impl<R> Default for HttpIngress<R>
1990where
1991 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1992{
1993 fn default() -> Self {
1994 Self::new()
1995 }
1996}
1997
1998#[derive(Clone)]
2000pub struct RawIngressService<R> {
2001 routes: Arc<Vec<RouteEntry<R>>>,
2002 fallback: Option<RouteHandler<R>>,
2003 layers: Arc<Vec<ServiceLayer>>,
2004 health: Arc<HealthConfig<R>>,
2005 static_assets: Arc<StaticAssetsConfig>,
2006 resources: Arc<R>,
2007}
2008
2009impl<R> Service<Request<Incoming>> for RawIngressService<R>
2010where
2011 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2012{
2013 type Response = HttpResponse;
2014 type Error = Infallible;
2015 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2016
2017 fn poll_ready(
2018 &mut self,
2019 _cx: &mut std::task::Context<'_>,
2020 ) -> std::task::Poll<Result<(), Self::Error>> {
2021 std::task::Poll::Ready(Ok(()))
2022 }
2023
2024 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
2025 let routes = self.routes.clone();
2026 let fallback = self.fallback.clone();
2027 let layers = self.layers.clone();
2028 let health = self.health.clone();
2029 let static_assets = self.static_assets.clone();
2030 let resources = self.resources.clone();
2031
2032 Box::pin(async move {
2033 let service =
2034 build_http_service(routes, fallback, resources, layers, health, static_assets, #[cfg(feature = "http3")] None);
2035 service.oneshot(req).await
2036 })
2037 }
2038}
2039
2040#[cfg(test)]
2041mod tests {
2042 use super::*;
2043 use async_trait::async_trait;
2044 use futures_util::{SinkExt, StreamExt};
2045 use ranvier_observe::{HttpMetrics, HttpMetricsLayer, IncomingTraceContext, TraceContextLayer};
2046 use serde::Deserialize;
2047 use std::fs;
2048 use std::sync::atomic::{AtomicBool, Ordering};
2049 use tempfile::tempdir;
2050 use tokio::io::{AsyncReadExt, AsyncWriteExt};
2051 use tokio_tungstenite::tungstenite::Message as WsClientMessage;
2052 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
2053
2054 async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
2055 let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
2056
2057 loop {
2058 match tokio::net::TcpStream::connect(addr).await {
2059 Ok(stream) => return stream,
2060 Err(error) => {
2061 if tokio::time::Instant::now() >= deadline {
2062 panic!("connect server: {error}");
2063 }
2064 tokio::time::sleep(Duration::from_millis(25)).await;
2065 }
2066 }
2067 }
2068 }
2069
2070 #[test]
2071 fn route_pattern_matches_static_path() {
2072 let pattern = RoutePattern::parse("/orders/list");
2073 let params = pattern.match_path("/orders/list").expect("should match");
2074 assert!(params.into_inner().is_empty());
2075 }
2076
2077 #[test]
2078 fn route_pattern_matches_param_segments() {
2079 let pattern = RoutePattern::parse("/orders/:id/items/:item_id");
2080 let params = pattern
2081 .match_path("/orders/42/items/sku-123")
2082 .expect("should match");
2083 assert_eq!(params.get("id"), Some("42"));
2084 assert_eq!(params.get("item_id"), Some("sku-123"));
2085 }
2086
2087 #[test]
2088 fn route_pattern_matches_wildcard_segment() {
2089 let pattern = RoutePattern::parse("/assets/*path");
2090 let params = pattern
2091 .match_path("/assets/css/theme/light.css")
2092 .expect("should match");
2093 assert_eq!(params.get("path"), Some("css/theme/light.css"));
2094 }
2095
2096 #[test]
2097 fn route_pattern_rejects_non_matching_path() {
2098 let pattern = RoutePattern::parse("/orders/:id");
2099 assert!(pattern.match_path("/users/42").is_none());
2100 }
2101
2102 #[test]
2103 fn graceful_shutdown_timeout_defaults_to_30_seconds() {
2104 let ingress = HttpIngress::<()>::new();
2105 assert_eq!(ingress.graceful_shutdown_timeout, Duration::from_secs(30));
2106 assert!(ingress.layers.is_empty());
2107 assert!(ingress.bus_injectors.is_empty());
2108 assert!(ingress.static_assets.mounts.is_empty());
2109 assert!(ingress.on_start.is_none());
2110 assert!(ingress.on_shutdown.is_none());
2111 }
2112
2113 #[test]
2114 fn layer_registration_stacks_globally() {
2115 let ingress = HttpIngress::<()>::new()
2116 .layer(tower::layer::util::Identity::new())
2117 .layer(tower::layer::util::Identity::new());
2118 assert_eq!(ingress.layers.len(), 2);
2119 }
2120
2121 #[test]
2122 fn layer_accepts_tower_http_cors_layer() {
2123 let ingress = HttpIngress::<()>::new().layer(tower_http::cors::CorsLayer::permissive());
2124 assert_eq!(ingress.layers.len(), 1);
2125 }
2126
2127 #[test]
2128 fn route_without_layer_keeps_empty_route_middleware_stack() {
2129 let ingress =
2130 HttpIngress::<()>::new().get("/ping", Axon::<(), (), Infallible, ()>::new("Ping"));
2131 assert_eq!(ingress.routes.len(), 1);
2132 assert!(ingress.routes[0].layers.is_empty());
2133 assert!(ingress.routes[0].apply_global_layers);
2134 }
2135
2136 #[test]
2137 fn route_with_layer_registers_route_middleware_stack() {
2138 let ingress = HttpIngress::<()>::new().get_with_layer(
2139 "/ping",
2140 Axon::<(), (), Infallible, ()>::new("Ping"),
2141 tower::layer::util::Identity::new(),
2142 );
2143 assert_eq!(ingress.routes.len(), 1);
2144 assert_eq!(ingress.routes[0].layers.len(), 1);
2145 assert!(ingress.routes[0].apply_global_layers);
2146 }
2147
2148 #[test]
2149 fn route_with_layer_override_disables_global_layers() {
2150 let ingress = HttpIngress::<()>::new().get_with_layer_override(
2151 "/ping",
2152 Axon::<(), (), Infallible, ()>::new("Ping"),
2153 tower::layer::util::Identity::new(),
2154 );
2155 assert_eq!(ingress.routes.len(), 1);
2156 assert_eq!(ingress.routes[0].layers.len(), 1);
2157 assert!(!ingress.routes[0].apply_global_layers);
2158 }
2159
2160 #[test]
2161 fn timeout_layer_registers_builtin_middleware() {
2162 let ingress = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
2163 assert_eq!(ingress.layers.len(), 1);
2164 }
2165
2166 #[test]
2167 fn request_id_layer_registers_builtin_middleware() {
2168 let ingress = HttpIngress::<()>::new().request_id_layer();
2169 assert_eq!(ingress.layers.len(), 1);
2170 }
2171
2172 #[test]
2173 fn compression_layer_registers_builtin_middleware() {
2174 let ingress = HttpIngress::<()>::new().compression_layer();
2175 assert!(ingress.static_assets.enable_compression);
2176 }
2177
2178 #[test]
2179 fn bus_injector_registration_adds_hook() {
2180 let ingress = HttpIngress::<()>::new().bus_injector(|_req, bus| {
2181 bus.insert("ok".to_string());
2182 });
2183 assert_eq!(ingress.bus_injectors.len(), 1);
2184 }
2185
2186 #[test]
2187 fn ws_route_registers_get_route_pattern() {
2188 let ingress =
2189 HttpIngress::<()>::new().ws("/ws/events", |_socket, _resources, _bus| async {});
2190 assert_eq!(ingress.routes.len(), 1);
2191 assert_eq!(ingress.routes[0].method, Method::GET);
2192 assert_eq!(ingress.routes[0].pattern.raw, "/ws/events");
2193 }
2194
2195 #[derive(Debug, Deserialize)]
2196 struct WsWelcomeFrame {
2197 connection_id: String,
2198 path: String,
2199 tenant: String,
2200 }
2201
2202 #[tokio::test]
2203 async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
2204 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2205 let addr = probe.local_addr().expect("local addr");
2206 drop(probe);
2207
2208 let ingress = HttpIngress::<()>::new()
2209 .bind(addr.to_string())
2210 .bus_injector(|req, bus| {
2211 if let Some(value) = req
2212 .headers
2213 .get("x-tenant-id")
2214 .and_then(|v| v.to_str().ok())
2215 {
2216 bus.insert(value.to_string());
2217 }
2218 })
2219 .ws("/ws/echo", |mut socket, _resources, bus| async move {
2220 let tenant = bus
2221 .read::<String>()
2222 .cloned()
2223 .unwrap_or_else(|| "unknown".to_string());
2224 if let Some(session) = bus.read::<WebSocketSessionContext>() {
2225 let welcome = serde_json::json!({
2226 "connection_id": session.connection_id().to_string(),
2227 "path": session.path(),
2228 "tenant": tenant,
2229 });
2230 let _ = socket.send_json(&welcome).await;
2231 }
2232
2233 while let Some(event) = socket.next_event().await {
2234 match event {
2235 WebSocketEvent::Text(text) => {
2236 let _ = socket.send_event(format!("echo:{text}")).await;
2237 }
2238 WebSocketEvent::Binary(bytes) => {
2239 let _ = socket.send_event(bytes).await;
2240 }
2241 WebSocketEvent::Close => break,
2242 WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
2243 }
2244 }
2245 });
2246
2247 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2248 let server = tokio::spawn(async move {
2249 ingress
2250 .run_with_shutdown_signal((), async move {
2251 let _ = shutdown_rx.await;
2252 })
2253 .await
2254 });
2255
2256 let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
2257 let mut ws_request = ws_uri
2258 .as_str()
2259 .into_client_request()
2260 .expect("ws client request");
2261 ws_request
2262 .headers_mut()
2263 .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
2264 let (mut client, _response) = tokio_tungstenite::connect_async(ws_request)
2265 .await
2266 .expect("websocket connect");
2267
2268 let welcome = client
2269 .next()
2270 .await
2271 .expect("welcome frame")
2272 .expect("welcome frame ok");
2273 let welcome_text = match welcome {
2274 WsClientMessage::Text(text) => text.to_string(),
2275 other => panic!("expected text welcome frame, got {other:?}"),
2276 };
2277 let welcome_payload: WsWelcomeFrame =
2278 serde_json::from_str(&welcome_text).expect("welcome json");
2279 assert_eq!(welcome_payload.path, "/ws/echo");
2280 assert_eq!(welcome_payload.tenant, "acme");
2281 assert!(!welcome_payload.connection_id.is_empty());
2282
2283 client
2284 .send(WsClientMessage::Text("hello".into()))
2285 .await
2286 .expect("send text");
2287 let echo_text = client
2288 .next()
2289 .await
2290 .expect("echo text frame")
2291 .expect("echo text frame ok");
2292 assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));
2293
2294 client
2295 .send(WsClientMessage::Binary(vec![1, 2, 3, 4].into()))
2296 .await
2297 .expect("send binary");
2298 let echo_binary = client
2299 .next()
2300 .await
2301 .expect("echo binary frame")
2302 .expect("echo binary frame ok");
2303 assert_eq!(
2304 echo_binary,
2305 WsClientMessage::Binary(vec![1, 2, 3, 4].into())
2306 );
2307
2308 client.close(None).await.expect("close websocket");
2309
2310 let _ = shutdown_tx.send(());
2311 server
2312 .await
2313 .expect("server join")
2314 .expect("server shutdown should succeed");
2315 }
2316
2317 #[derive(Clone)]
2318 struct EchoTrace;
2319
2320 #[async_trait]
2321 impl Transition<(), String> for EchoTrace {
2322 type Error = Infallible;
2323 type Resources = ();
2324
2325 async fn run(
2326 &self,
2327 _state: (),
2328 _resources: &Self::Resources,
2329 bus: &mut Bus,
2330 ) -> Outcome<String, Self::Error> {
2331 let trace_id = bus
2332 .read::<String>()
2333 .cloned()
2334 .unwrap_or_else(|| "missing-trace".to_string());
2335 Outcome::next(trace_id)
2336 }
2337 }
2338
2339 #[tokio::test]
2340 async fn observe_trace_context_and_metrics_layers_work_with_ingress() {
2341 let metrics = HttpMetrics::default();
2342 let ingress = HttpIngress::<()>::new()
2343 .layer(TraceContextLayer::new())
2344 .layer(HttpMetricsLayer::new(metrics.clone()))
2345 .bus_injector(|req, bus| {
2346 if let Some(trace) = req.extensions.get::<IncomingTraceContext>() {
2347 bus.insert(trace.trace_id().to_string());
2348 }
2349 })
2350 .get(
2351 "/trace",
2352 Axon::<(), (), Infallible, ()>::new("EchoTrace").then(EchoTrace),
2353 );
2354
2355 let app = crate::test_harness::TestApp::new(ingress, ());
2356 let response = app
2357 .send(crate::test_harness::TestRequest::get("/trace").header(
2358 "traceparent",
2359 "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
2360 ))
2361 .await
2362 .expect("request should succeed");
2363
2364 assert_eq!(response.status(), StatusCode::OK);
2365 assert_eq!(
2366 response.text().expect("utf8 response"),
2367 "4bf92f3577b34da6a3ce929d0e0e4736"
2368 );
2369
2370 let snapshot = metrics.snapshot();
2371 assert_eq!(snapshot.requests_total, 1);
2372 assert_eq!(snapshot.requests_error, 0);
2373 }
2374
2375 #[test]
2376 fn route_descriptors_export_http_and_health_paths() {
2377 let ingress = HttpIngress::<()>::new()
2378 .get(
2379 "/orders/:id",
2380 Axon::<(), (), Infallible, ()>::new("OrderById"),
2381 )
2382 .health_endpoint("/healthz")
2383 .readiness_liveness("/readyz", "/livez");
2384
2385 let descriptors = ingress.route_descriptors();
2386
2387 assert!(
2388 descriptors
2389 .iter()
2390 .any(|descriptor| descriptor.method() == Method::GET
2391 && descriptor.path_pattern() == "/orders/:id")
2392 );
2393 assert!(
2394 descriptors
2395 .iter()
2396 .any(|descriptor| descriptor.method() == Method::GET
2397 && descriptor.path_pattern() == "/healthz")
2398 );
2399 assert!(
2400 descriptors
2401 .iter()
2402 .any(|descriptor| descriptor.method() == Method::GET
2403 && descriptor.path_pattern() == "/readyz")
2404 );
2405 assert!(
2406 descriptors
2407 .iter()
2408 .any(|descriptor| descriptor.method() == Method::GET
2409 && descriptor.path_pattern() == "/livez")
2410 );
2411 }
2412
2413 #[tokio::test]
2414 async fn lifecycle_hooks_fire_on_start_and_shutdown() {
2415 let started = Arc::new(AtomicBool::new(false));
2416 let shutdown = Arc::new(AtomicBool::new(false));
2417
2418 let started_flag = started.clone();
2419 let shutdown_flag = shutdown.clone();
2420
2421 let ingress = HttpIngress::<()>::new()
2422 .bind("127.0.0.1:0")
2423 .on_start(move || {
2424 started_flag.store(true, Ordering::SeqCst);
2425 })
2426 .on_shutdown(move || {
2427 shutdown_flag.store(true, Ordering::SeqCst);
2428 })
2429 .graceful_shutdown(Duration::from_millis(50));
2430
2431 ingress
2432 .run_with_shutdown_signal((), async {
2433 tokio::time::sleep(Duration::from_millis(20)).await;
2434 })
2435 .await
2436 .expect("server should exit gracefully");
2437
2438 assert!(started.load(Ordering::SeqCst));
2439 assert!(shutdown.load(Ordering::SeqCst));
2440 }
2441
2442 #[tokio::test]
2443 async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
2444 #[derive(Clone)]
2445 struct SlowDrainRoute;
2446
2447 #[async_trait]
2448 impl Transition<(), &'static str> for SlowDrainRoute {
2449 type Error = Infallible;
2450 type Resources = ();
2451
2452 async fn run(
2453 &self,
2454 _state: (),
2455 _resources: &Self::Resources,
2456 _bus: &mut Bus,
2457 ) -> Outcome<&'static str, Self::Error> {
2458 tokio::time::sleep(Duration::from_millis(120)).await;
2459 Outcome::next("drained-ok")
2460 }
2461 }
2462
2463 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2464 let addr = probe.local_addr().expect("local addr");
2465 drop(probe);
2466
2467 let ingress = HttpIngress::<()>::new()
2468 .bind(addr.to_string())
2469 .graceful_shutdown(Duration::from_millis(500))
2470 .get(
2471 "/drain",
2472 Axon::<(), (), Infallible, ()>::new("SlowDrain").then(SlowDrainRoute),
2473 );
2474
2475 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2476 let server = tokio::spawn(async move {
2477 ingress
2478 .run_with_shutdown_signal((), async move {
2479 let _ = shutdown_rx.await;
2480 })
2481 .await
2482 });
2483
2484 let mut stream = connect_with_retry(addr).await;
2485 stream
2486 .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2487 .await
2488 .expect("write request");
2489
2490 tokio::time::sleep(Duration::from_millis(20)).await;
2491 let _ = shutdown_tx.send(());
2492
2493 let mut buf = Vec::new();
2494 stream.read_to_end(&mut buf).await.expect("read response");
2495 let response = String::from_utf8_lossy(&buf);
2496 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
2497 assert!(response.contains("drained-ok"), "{response}");
2498
2499 server
2500 .await
2501 .expect("server join")
2502 .expect("server shutdown should succeed");
2503 }
2504
2505 #[tokio::test]
2506 async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
2507 let temp = tempdir().expect("tempdir");
2508 let root = temp.path().join("public");
2509 fs::create_dir_all(&root).expect("create dir");
2510 let file = root.join("hello.txt");
2511 fs::write(&file, "hello static").expect("write file");
2512
2513 let ingress =
2514 Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
2515 let app = crate::test_harness::TestApp::new(ingress, ());
2516 let response = app
2517 .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
2518 .await
2519 .expect("request should succeed");
2520
2521 assert_eq!(response.status(), StatusCode::OK);
2522 assert_eq!(response.text().expect("utf8"), "hello static");
2523 assert!(response.header("cache-control").is_some());
2524 let has_metadata_header =
2525 response.header("etag").is_some() || response.header("last-modified").is_some();
2526 assert!(has_metadata_header);
2527 }
2528
2529 #[tokio::test]
2530 async fn spa_fallback_returns_index_for_unmatched_path() {
2531 let temp = tempdir().expect("tempdir");
2532 let index = temp.path().join("index.html");
2533 fs::write(&index, "<html><body>spa</body></html>").expect("write index");
2534
2535 let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
2536 let app = crate::test_harness::TestApp::new(ingress, ());
2537 let response = app
2538 .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
2539 .await
2540 .expect("request should succeed");
2541
2542 assert_eq!(response.status(), StatusCode::OK);
2543 assert!(response.text().expect("utf8").contains("spa"));
2544 }
2545
2546 #[tokio::test]
2547 async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
2548 let temp = tempdir().expect("tempdir");
2549 let root = temp.path().join("public");
2550 fs::create_dir_all(&root).expect("create dir");
2551 let file = root.join("compressed.txt");
2552 fs::write(&file, "compress me ".repeat(400)).expect("write file");
2553
2554 let ingress = Ranvier::http::<()>()
2555 .serve_dir("/static", root.to_string_lossy().to_string())
2556 .compression_layer();
2557 let app = crate::test_harness::TestApp::new(ingress, ());
2558 let response = app
2559 .send(
2560 crate::test_harness::TestRequest::get("/static/compressed.txt")
2561 .header("accept-encoding", "gzip"),
2562 )
2563 .await
2564 .expect("request should succeed");
2565
2566 assert_eq!(response.status(), StatusCode::OK);
2567 assert_eq!(
2568 response
2569 .header("content-encoding")
2570 .and_then(|value| value.to_str().ok()),
2571 Some("gzip")
2572 );
2573 }
2574
2575 #[tokio::test]
2576 async fn drain_connections_completes_before_timeout() {
2577 let mut connections = tokio::task::JoinSet::new();
2578 connections.spawn(async {
2579 tokio::time::sleep(Duration::from_millis(20)).await;
2580 });
2581
2582 let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
2583 assert!(!timed_out);
2584 assert!(connections.is_empty());
2585 }
2586
2587 #[tokio::test]
2588 async fn drain_connections_times_out_and_aborts() {
2589 let mut connections = tokio::task::JoinSet::new();
2590 connections.spawn(async {
2591 tokio::time::sleep(Duration::from_secs(10)).await;
2592 });
2593
2594 let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
2595 assert!(timed_out);
2596 assert!(connections.is_empty());
2597 }
2598
2599 #[tokio::test]
2600 async fn timeout_layer_returns_408_for_slow_route() {
2601 #[derive(Clone)]
2602 struct SlowRoute;
2603
2604 #[async_trait]
2605 impl Transition<(), &'static str> for SlowRoute {
2606 type Error = Infallible;
2607 type Resources = ();
2608
2609 async fn run(
2610 &self,
2611 _state: (),
2612 _resources: &Self::Resources,
2613 _bus: &mut Bus,
2614 ) -> Outcome<&'static str, Self::Error> {
2615 tokio::time::sleep(Duration::from_millis(80)).await;
2616 Outcome::next("slow-ok")
2617 }
2618 }
2619
2620 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2621 let addr = probe.local_addr().expect("local addr");
2622 drop(probe);
2623
2624 let ingress = HttpIngress::<()>::new()
2625 .bind(addr.to_string())
2626 .timeout_layer(Duration::from_millis(10))
2627 .get(
2628 "/slow",
2629 Axon::<(), (), Infallible, ()>::new("Slow").then(SlowRoute),
2630 );
2631
2632 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2633 let server = tokio::spawn(async move {
2634 ingress
2635 .run_with_shutdown_signal((), async move {
2636 let _ = shutdown_rx.await;
2637 })
2638 .await
2639 });
2640
2641 let mut stream = connect_with_retry(addr).await;
2642 stream
2643 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2644 .await
2645 .expect("write request");
2646
2647 let mut buf = Vec::new();
2648 stream.read_to_end(&mut buf).await.expect("read response");
2649 let response = String::from_utf8_lossy(&buf);
2650 assert!(response.starts_with("HTTP/1.1 408"), "{response}");
2651
2652 let _ = shutdown_tx.send(());
2653 server
2654 .await
2655 .expect("server join")
2656 .expect("server shutdown should succeed");
2657 }
2658
2659 #[tokio::test]
2660 async fn route_layer_override_bypasses_global_timeout() {
2661 #[derive(Clone)]
2662 struct SlowRoute;
2663
2664 #[async_trait]
2665 impl Transition<(), &'static str> for SlowRoute {
2666 type Error = Infallible;
2667 type Resources = ();
2668
2669 async fn run(
2670 &self,
2671 _state: (),
2672 _resources: &Self::Resources,
2673 _bus: &mut Bus,
2674 ) -> Outcome<&'static str, Self::Error> {
2675 tokio::time::sleep(Duration::from_millis(60)).await;
2676 Outcome::next("override-ok")
2677 }
2678 }
2679
2680 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2681 let addr = probe.local_addr().expect("local addr");
2682 drop(probe);
2683
2684 let ingress = HttpIngress::<()>::new()
2685 .bind(addr.to_string())
2686 .timeout_layer(Duration::from_millis(10))
2687 .get_with_layer_override(
2688 "/slow",
2689 Axon::<(), (), Infallible, ()>::new("SlowOverride").then(SlowRoute),
2690 tower::layer::util::Identity::new(),
2691 );
2692
2693 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2694 let server = tokio::spawn(async move {
2695 ingress
2696 .run_with_shutdown_signal((), async move {
2697 let _ = shutdown_rx.await;
2698 })
2699 .await
2700 });
2701
2702 let mut stream = connect_with_retry(addr).await;
2703 stream
2704 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2705 .await
2706 .expect("write request");
2707
2708 let mut buf = Vec::new();
2709 stream.read_to_end(&mut buf).await.expect("read response");
2710 let response = String::from_utf8_lossy(&buf);
2711 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
2712 assert!(response.contains("override-ok"), "{response}");
2713
2714 let _ = shutdown_tx.send(());
2715 server
2716 .await
2717 .expect("server join")
2718 .expect("server shutdown should succeed");
2719 }
2720}