1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode};
20use http_body_util::{BodyExt, Full};
21use hyper::body::Incoming;
22use hyper::server::conn::http1;
23use hyper::upgrade::Upgraded;
24use hyper_util::rt::TokioIo;
25use ranvier_core::event::{EventSink, EventSource};
26use ranvier_core::prelude::*;
27use ranvier_runtime::Axon;
28use serde::Serialize;
29use serde::de::DeserializeOwned;
30use sha1::{Digest, Sha1};
31use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::net::SocketAddr;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::net::TcpListener;
39use tokio::sync::Mutex;
40use tokio_tungstenite::WebSocketStream;
41use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
42use tracing::Instrument;
43
44use crate::response::{HttpResponse, IntoResponse, outcome_to_response_with_error};
45
/// Entry point namespace for building Ranvier ingresses (see `Ranvier::http`).
pub struct Ranvier;
52impl Ranvier {
53 pub fn http<R>() -> HttpIngress<R>
55 where
56 R: ranvier_core::transition::ResourceRequirement + Clone,
57 {
58 HttpIngress::new()
59 }
60}
61
/// Type-erased route handler: receives the request head (`Parts`) and a
/// borrow of the route resources, and returns a boxed future yielding the
/// HTTP response.
type RouteHandler<R> = Arc<
    dyn Fn(http::request::Parts, &R) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>>
        + Send
        + Sync,
>;
68
/// Type-erased HTTP service: a shared async function from an incoming hyper
/// request to an `HttpResponse`.
///
/// The error type is `Infallible`, so every request produces a response.
/// Cloning only clones the inner `Arc`.
#[derive(Clone)]
struct BoxService(
    Arc<
        dyn Fn(Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>
            + Send
            + Sync,
    >,
);
78
impl BoxService {
    /// Wraps an async function into a type-erased, cloneable service.
    ///
    /// Each invocation boxes and pins the function's future so every service
    /// shares one concrete future type.
    fn new<F, Fut>(f: F) -> Self
    where
        F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<HttpResponse, Infallible>> + Send + 'static,
    {
        Self(Arc::new(move |req| Box::pin(f(req))))
    }

    /// Invokes the wrapped function.
    ///
    /// Inherent counterpart of `hyper::service::Service::call`; the middleware
    /// in this file calls it as `inner.call(req)`.
    fn call(&self, req: Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>> {
        (self.0)(req)
    }
}
92
93impl hyper::service::Service<Request<Incoming>> for BoxService {
94 type Response = HttpResponse;
95 type Error = Infallible;
96 type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>;
97
98 fn call(&self, req: Request<Incoming>) -> Self::Future {
99 (self.0)(req)
100 }
101}
102
/// Service type that middleware layers wrap.
type BoxHttpService = BoxService;
/// Middleware: takes the inner service and returns a wrapping service.
type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
/// Callback registered through `on_start` / `on_shutdown`.
type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
/// Populates a request's `Bus` from the request head before a route runs.
type BusInjector = Arc<dyn Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static>;
/// Boxed future driving one WebSocket session.
type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased WebSocket session handler registered through `ws`.
type WsSessionHandler<R> =
    Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
/// Boxed result of a health check; `Err` carries a human-readable reason.
type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
/// Type-erased health check registered through `health_check`.
type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
/// Header read and written by `request_id_middleware`.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Token expected in the `Upgrade` header of a WebSocket handshake.
const WS_UPGRADE_TOKEN: &str = "websocket";
/// RFC 6455 handshake GUID, appended to the client key before hashing in
/// `websocket_accept_key`.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
115
/// A health check paired with the name it is registered under.
#[derive(Clone)]
struct NamedHealthCheck<R> {
    name: String,
    check: HealthCheckFn<R>,
}
121
/// Health/readiness/liveness endpoint configuration.
///
/// Each path is `None` until the corresponding builder method enables it.
#[derive(Clone)]
struct HealthConfig<R> {
    health_path: Option<String>,
    readiness_path: Option<String>,
    liveness_path: Option<String>,
    checks: Vec<NamedHealthCheck<R>>,
}
129
130impl<R> Default for HealthConfig<R> {
131 fn default() -> Self {
132 Self {
133 health_path: None,
134 readiness_path: None,
135 liveness_path: None,
136 checks: Vec::new(),
137 }
138 }
139}
140
/// Static-file serving options collected by `serve_dir`, `spa_fallback`,
/// `static_cache_control` and `compression_layer`.
#[derive(Clone, Default)]
struct StaticAssetsConfig {
    mounts: Vec<StaticMount>,
    spa_fallback: Option<String>,
    cache_control: Option<String>,
    enable_compression: bool,
}
148
/// Maps a normalized URL prefix to a directory on disk.
#[derive(Clone)]
struct StaticMount {
    route_prefix: String,
    directory: String,
}
154
/// Certificate and private-key file paths set through `HttpIngress::tls`.
#[cfg(feature = "tls")]
#[derive(Clone)]
struct TlsAcceptorConfig {
    cert_path: String,
    key_path: String,
}
162
/// Serialized body of a health probe response.
#[derive(Serialize)]
struct HealthReport {
    status: &'static str,
    probe: &'static str,
    checks: Vec<HealthCheckReport>,
}
169
/// Per-check entry inside a `HealthReport`; `error` is omitted when `None`.
#[derive(Serialize)]
struct HealthCheckReport {
    name: String,
    status: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
177
178fn timeout_middleware(timeout: Duration) -> ServiceLayer {
179 Arc::new(move |inner: BoxHttpService| {
180 BoxService::new(move |req: Request<Incoming>| {
181 let inner = inner.clone();
182 async move {
183 match tokio::time::timeout(timeout, inner.call(req)).await {
184 Ok(response) => response,
185 Err(_) => Ok(Response::builder()
186 .status(StatusCode::REQUEST_TIMEOUT)
187 .body(
188 Full::new(Bytes::from("Request Timeout"))
189 .map_err(|never| match never {})
190 .boxed(),
191 )
192 .expect("valid HTTP response construction")),
193 }
194 }
195 })
196 })
197}
198
199fn request_id_middleware() -> ServiceLayer {
200 Arc::new(move |inner: BoxHttpService| {
201 BoxService::new(move |req: Request<Incoming>| {
202 let inner = inner.clone();
203 async move {
204 let mut req = req;
205 let request_id = req
206 .headers()
207 .get(REQUEST_ID_HEADER)
208 .cloned()
209 .unwrap_or_else(|| {
210 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
211 .unwrap_or_else(|_| {
212 http::HeaderValue::from_static("request-id-unavailable")
213 })
214 });
215 req.headers_mut()
216 .insert(REQUEST_ID_HEADER, request_id.clone());
217 let mut response = inner.call(req).await?;
218 response
219 .headers_mut()
220 .insert(REQUEST_ID_HEADER, request_id);
221 Ok(response)
222 }
223 })
224 })
225}
226
/// Path parameters captured while matching a route pattern, keyed by the
/// name used after `:` or `*` in the pattern.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PathParams {
    values: HashMap<String, String>,
}
231
/// Public description of one registered route: HTTP method plus the raw
/// path pattern it was registered with.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpRouteDescriptor {
    method: Method,
    path_pattern: String,
}
238
239impl HttpRouteDescriptor {
240 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
241 Self {
242 method,
243 path_pattern: path_pattern.into(),
244 }
245 }
246
247 pub fn method(&self) -> &Method {
248 &self.method
249 }
250
251 pub fn path_pattern(&self) -> &str {
252 &self.path_pattern
253 }
254}
255
/// Per-connection metadata captured from the WebSocket upgrade request.
///
/// Inserted into the session `Bus` by `HttpIngress::ws` and also available
/// through `WebSocketConnection::session`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct WebSocketSessionContext {
    connection_id: uuid::Uuid,
    path: String,
    query: Option<String>,
}
263
264impl WebSocketSessionContext {
265 pub fn connection_id(&self) -> uuid::Uuid {
266 self.connection_id
267 }
268
269 pub fn path(&self) -> &str {
270 &self.path
271 }
272
273 pub fn query(&self) -> Option<&str> {
274 self.query.as_deref()
275 }
276}
277
/// A WebSocket message as seen by application code.
///
/// Mirrors the text, binary, ping, pong and close message kinds of the
/// underlying tungstenite `Message`; close payloads are not preserved.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WebSocketEvent {
    Text(String),
    Binary(Vec<u8>),
    Ping(Vec<u8>),
    Pong(Vec<u8>),
    Close,
}
287
288impl WebSocketEvent {
289 pub fn text(value: impl Into<String>) -> Self {
290 Self::Text(value.into())
291 }
292
293 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
294 Self::Binary(value.into())
295 }
296
297 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
298 where
299 T: Serialize,
300 {
301 let text = serde_json::to_string(value)?;
302 Ok(Self::Text(text))
303 }
304}
305
/// Errors produced by `WebSocketConnection` operations.
#[derive(Debug, thiserror::Error)]
pub enum WebSocketError {
    /// Transport/protocol error from tungstenite.
    #[error("websocket wire error: {0}")]
    Wire(#[from] WsWireError),
    /// Outgoing value could not be encoded as JSON.
    #[error("json serialization failed: {0}")]
    JsonSerialize(#[source] serde_json::Error),
    /// Incoming frame payload was not valid JSON for the requested type.
    #[error("json deserialization failed: {0}")]
    JsonDeserialize(#[source] serde_json::Error),
    /// A JSON payload was expected but a frame other than text/binary arrived.
    #[error("expected text or binary frame for json payload")]
    NonDataFrame,
}
317
/// Server-side WebSocket over a hyper-upgraded connection.
type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
/// Write half of a split `WsServerStream`.
type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
/// Read half of a split `WsServerStream`.
type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
321
/// An upgraded WebSocket connection handed to `ws` session handlers.
///
/// The stream is split into separately locked sink and source halves, so
/// sending only contends with other senders and receiving only with other
/// receivers.
pub struct WebSocketConnection {
    sink: Mutex<WsServerSink>,
    source: Mutex<WsServerSource>,
    session: WebSocketSessionContext,
}
328
329impl WebSocketConnection {
330 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
331 let (sink, source) = stream.split();
332 Self {
333 sink: Mutex::new(sink),
334 source: Mutex::new(source),
335 session,
336 }
337 }
338
339 pub fn session(&self) -> &WebSocketSessionContext {
340 &self.session
341 }
342
343 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
344 let mut sink = self.sink.lock().await;
345 sink.send(event.into_wire_message()).await?;
346 Ok(())
347 }
348
349 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
350 where
351 T: Serialize,
352 {
353 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
354 self.send(event).await
355 }
356
357 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
358 where
359 T: DeserializeOwned,
360 {
361 let Some(event) = self.recv_event().await? else {
362 return Ok(None);
363 };
364 match event {
365 WebSocketEvent::Text(text) => serde_json::from_str(&text)
366 .map(Some)
367 .map_err(WebSocketError::JsonDeserialize),
368 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
369 .map(Some)
370 .map_err(WebSocketError::JsonDeserialize),
371 _ => Err(WebSocketError::NonDataFrame),
372 }
373 }
374
375 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
376 let mut source = self.source.lock().await;
377 while let Some(item) = source.next().await {
378 let message = item?;
379 if let Some(event) = WebSocketEvent::from_wire_message(message) {
380 return Ok(Some(event));
381 }
382 }
383 Ok(None)
384 }
385}
386
387impl WebSocketEvent {
388 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
389 match message {
390 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
391 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
392 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
393 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
394 WsWireMessage::Close(_) => Some(Self::Close),
395 WsWireMessage::Frame(_) => None,
396 }
397 }
398
399 fn into_wire_message(self) -> WsWireMessage {
400 match self {
401 Self::Text(value) => WsWireMessage::Text(value),
402 Self::Binary(value) => WsWireMessage::Binary(value),
403 Self::Ping(value) => WsWireMessage::Ping(value),
404 Self::Pong(value) => WsWireMessage::Pong(value),
405 Self::Close => WsWireMessage::Close(None),
406 }
407 }
408}
409
410#[async_trait::async_trait]
411impl EventSource<WebSocketEvent> for WebSocketConnection {
412 async fn next_event(&mut self) -> Option<WebSocketEvent> {
413 match self.recv_event().await {
414 Ok(event) => event,
415 Err(error) => {
416 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
417 None
418 }
419 }
420 }
421}
422
423#[async_trait::async_trait]
424impl EventSink<WebSocketEvent> for WebSocketConnection {
425 type Error = WebSocketError;
426
427 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
428 self.send(event).await
429 }
430}
431
432#[async_trait::async_trait]
433impl EventSink<String> for WebSocketConnection {
434 type Error = WebSocketError;
435
436 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
437 self.send(WebSocketEvent::Text(event)).await
438 }
439}
440
441#[async_trait::async_trait]
442impl EventSink<Vec<u8>> for WebSocketConnection {
443 type Error = WebSocketError;
444
445 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
446 self.send(WebSocketEvent::Binary(event)).await
447 }
448}
449
450impl PathParams {
451 pub fn new(values: HashMap<String, String>) -> Self {
452 Self { values }
453 }
454
455 pub fn get(&self, key: &str) -> Option<&str> {
456 self.values.get(key).map(String::as_str)
457 }
458
459 pub fn as_map(&self) -> &HashMap<String, String> {
460 &self.values
461 }
462
463 pub fn into_inner(self) -> HashMap<String, String> {
464 self.values
465 }
466}
467
/// One parsed segment of a route pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum RouteSegment {
    /// Must equal the path segment exactly.
    Static(String),
    /// `:name` — captures exactly one path segment.
    Param(String),
    /// `*name` — captures all remaining segments joined with `/`.
    Wildcard(String),
}
474
/// A parsed route pattern together with the raw string it came from.
#[derive(Clone, Debug, PartialEq, Eq)]
struct RoutePattern {
    raw: String,
    segments: Vec<RouteSegment>,
}
480
481impl RoutePattern {
482 fn parse(path: &str) -> Self {
483 let segments = path_segments(path)
484 .into_iter()
485 .map(|segment| {
486 if let Some(name) = segment.strip_prefix(':') {
487 if !name.is_empty() {
488 return RouteSegment::Param(name.to_string());
489 }
490 }
491 if let Some(name) = segment.strip_prefix('*') {
492 if !name.is_empty() {
493 return RouteSegment::Wildcard(name.to_string());
494 }
495 }
496 RouteSegment::Static(segment.to_string())
497 })
498 .collect();
499
500 Self {
501 raw: path.to_string(),
502 segments,
503 }
504 }
505
506 fn match_path(&self, path: &str) -> Option<PathParams> {
507 let mut params = HashMap::new();
508 let path_segments = path_segments(path);
509 let mut pattern_index = 0usize;
510 let mut path_index = 0usize;
511
512 while pattern_index < self.segments.len() {
513 match &self.segments[pattern_index] {
514 RouteSegment::Static(expected) => {
515 let actual = path_segments.get(path_index)?;
516 if actual != expected {
517 return None;
518 }
519 pattern_index += 1;
520 path_index += 1;
521 }
522 RouteSegment::Param(name) => {
523 let actual = path_segments.get(path_index)?;
524 params.insert(name.clone(), (*actual).to_string());
525 pattern_index += 1;
526 path_index += 1;
527 }
528 RouteSegment::Wildcard(name) => {
529 let remaining = path_segments[path_index..].join("/");
530 params.insert(name.clone(), remaining);
531 pattern_index += 1;
532 path_index = path_segments.len();
533 break;
534 }
535 }
536 }
537
538 if pattern_index == self.segments.len() && path_index == path_segments.len() {
539 Some(PathParams::new(params))
540 } else {
541 None
542 }
543 }
544}
545
/// A registered route: method, parsed pattern, handler, and the route's own
/// layers.
///
/// `apply_global_layers` records whether the ingress-wide layers should also
/// wrap this route. It is consumed outside this view.
#[derive(Clone)]
struct RouteEntry<R> {
    method: Method,
    pattern: RoutePattern,
    handler: RouteHandler<R>,
    layers: Arc<Vec<ServiceLayer>>,
    apply_global_layers: bool,
}
554
/// Splits a URL path into its non-empty `/`-separated segments.
///
/// Leading, trailing and repeated slashes are ignored, so `"/"`, `""` and
/// `"//"` all yield no segments.
fn path_segments(path: &str) -> Vec<&str> {
    path.split('/').filter(|part| !part.is_empty()).collect()
}
565
/// Ensures a route path starts with `/`; an empty path becomes `"/"`.
fn normalize_route_path(mut path: String) -> String {
    if !path.starts_with('/') {
        path.insert(0, '/');
    }
    path
}
576
577fn find_matching_route<'a, R>(
578 routes: &'a [RouteEntry<R>],
579 method: &Method,
580 path: &str,
581) -> Option<(&'a RouteEntry<R>, PathParams)> {
582 for entry in routes {
583 if entry.method != *method {
584 continue;
585 }
586 if let Some(params) = entry.pattern.match_path(path) {
587 return Some((entry, params));
588 }
589 }
590 None
591}
592
593fn header_contains_token(
594 headers: &http::HeaderMap,
595 name: http::header::HeaderName,
596 token: &str,
597) -> bool {
598 headers
599 .get(name)
600 .and_then(|value| value.to_str().ok())
601 .map(|value| {
602 value
603 .split(',')
604 .any(|part| part.trim().eq_ignore_ascii_case(token))
605 })
606 .unwrap_or(false)
607}
608
609fn websocket_session_from_request<B>(req: &Request<B>) -> WebSocketSessionContext {
610 WebSocketSessionContext {
611 connection_id: uuid::Uuid::new_v4(),
612 path: req.uri().path().to_string(),
613 query: req.uri().query().map(str::to_string),
614 }
615}
616
617fn websocket_accept_key(client_key: &str) -> String {
618 let mut hasher = Sha1::new();
619 hasher.update(client_key.as_bytes());
620 hasher.update(WS_GUID.as_bytes());
621 let digest = hasher.finalize();
622 base64::engine::general_purpose::STANDARD.encode(digest)
623}
624
625fn websocket_bad_request(message: &'static str) -> HttpResponse {
626 Response::builder()
627 .status(StatusCode::BAD_REQUEST)
628 .body(
629 Full::new(Bytes::from(message))
630 .map_err(|never| match never {})
631 .boxed(),
632 )
633 .unwrap_or_else(|_| {
634 Response::new(
635 Full::new(Bytes::new())
636 .map_err(|never| match never {})
637 .boxed(),
638 )
639 })
640}
641
642fn websocket_upgrade_response<B>(
643 req: &mut Request<B>,
644) -> Result<(HttpResponse, hyper::upgrade::OnUpgrade), HttpResponse> {
645 if req.method() != Method::GET {
646 return Err(websocket_bad_request(
647 "WebSocket upgrade requires GET method",
648 ));
649 }
650
651 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
652 return Err(websocket_bad_request(
653 "Missing Connection: upgrade header for WebSocket",
654 ));
655 }
656
657 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
658 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
659 }
660
661 if let Some(version) = req.headers().get("sec-websocket-version") {
662 if version != "13" {
663 return Err(websocket_bad_request(
664 "Unsupported Sec-WebSocket-Version (expected 13)",
665 ));
666 }
667 }
668
669 let Some(client_key) = req
670 .headers()
671 .get("sec-websocket-key")
672 .and_then(|value| value.to_str().ok())
673 else {
674 return Err(websocket_bad_request(
675 "Missing Sec-WebSocket-Key header for WebSocket",
676 ));
677 };
678
679 let accept_key = websocket_accept_key(client_key);
680 let on_upgrade = hyper::upgrade::on(req);
681 let response = Response::builder()
682 .status(StatusCode::SWITCHING_PROTOCOLS)
683 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
684 .header(http::header::CONNECTION, "Upgrade")
685 .header("sec-websocket-accept", accept_key)
686 .body(
687 Full::new(Bytes::new())
688 .map_err(|never| match never {})
689 .boxed(),
690 )
691 .unwrap_or_else(|_| {
692 Response::new(
693 Full::new(Bytes::new())
694 .map_err(|never| match never {})
695 .boxed(),
696 )
697 });
698
699 Ok((response, on_upgrade))
700}
701
/// Builder and runtime configuration for an HTTP ingress serving Axon
/// circuits, WebSocket sessions, static assets and health probes.
pub struct HttpIngress<R = ()> {
    /// Bind address set by `bind` or `config`.
    addr: Option<String>,
    /// Registered routes, matched in registration order.
    routes: Vec<RouteEntry<R>>,
    /// Handler used when no route matches (set outside this view).
    fallback: Option<RouteHandler<R>>,
    /// Global middleware layers (`timeout_layer`, `request_id_layer`, ...).
    layers: Vec<ServiceLayer>,
    /// Hook registered by `on_start`.
    on_start: Option<LifecycleHook>,
    /// Hook registered by `on_shutdown`.
    on_shutdown: Option<LifecycleHook>,
    /// Graceful-shutdown budget; `new` defaults it to 30 seconds.
    graceful_shutdown_timeout: Duration,
    /// Bus injectors; each route snapshots this list when it is registered.
    bus_injectors: Vec<BusInjector>,
    /// Static file serving configuration.
    static_assets: StaticAssetsConfig,
    /// Health/readiness/liveness endpoints and checks.
    health: HealthConfig<R>,
    #[cfg(feature = "http3")]
    http3_config: Option<crate::http3::Http3Config>,
    #[cfg(feature = "http3")]
    alt_svc_h3_port: Option<u16>,
    #[cfg(feature = "tls")]
    /// Certificate/key paths set by `tls`.
    tls_config: Option<TlsAcceptorConfig>,
    /// Flag set by `active_intervention` (consumed outside this view).
    active_intervention: bool,
    /// Registry set by `policy_registry` (consumed outside this view).
    policy_registry: Option<ranvier_core::policy::PolicyRegistry>,
    _phantom: std::marker::PhantomData<R>,
}
741
742impl<R> HttpIngress<R>
743where
744 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
745{
746 pub fn new() -> Self {
748 Self {
749 addr: None,
750 routes: Vec::new(),
751 fallback: None,
752 layers: Vec::new(),
753 on_start: None,
754 on_shutdown: None,
755 graceful_shutdown_timeout: Duration::from_secs(30),
756 bus_injectors: Vec::new(),
757 static_assets: StaticAssetsConfig::default(),
758 health: HealthConfig::default(),
759 #[cfg(feature = "tls")]
760 tls_config: None,
761 #[cfg(feature = "http3")]
762 http3_config: None,
763 #[cfg(feature = "http3")]
764 alt_svc_h3_port: None,
765 active_intervention: false,
766 policy_registry: None,
767 _phantom: std::marker::PhantomData,
768 }
769 }
770
771 pub fn bind(mut self, addr: impl Into<String>) -> Self {
775 self.addr = Some(addr.into());
776 self
777 }
778
779 pub fn active_intervention(mut self) -> Self {
785 self.active_intervention = true;
786 self
787 }
788
789 pub fn policy_registry(mut self, registry: ranvier_core::policy::PolicyRegistry) -> Self {
791 self.policy_registry = Some(registry);
792 self
793 }
794
795 pub fn on_start<F>(mut self, callback: F) -> Self
799 where
800 F: Fn() + Send + Sync + 'static,
801 {
802 self.on_start = Some(Arc::new(callback));
803 self
804 }
805
806 pub fn on_shutdown<F>(mut self, callback: F) -> Self
808 where
809 F: Fn() + Send + Sync + 'static,
810 {
811 self.on_shutdown = Some(Arc::new(callback));
812 self
813 }
814
815 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
817 self.graceful_shutdown_timeout = timeout;
818 self
819 }
820
821 pub fn config(mut self, config: &ranvier_core::config::RanvierConfig) -> Self {
827 self.addr = Some(config.bind_addr());
828 self.graceful_shutdown_timeout = config.shutdown_timeout();
829 config.init_telemetry();
830 self
831 }
832
833 #[cfg(feature = "tls")]
835 pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
836 self.tls_config = Some(TlsAcceptorConfig {
837 cert_path: cert_path.into(),
838 key_path: key_path.into(),
839 });
840 self
841 }
842
843 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
848 self.layers.push(timeout_middleware(timeout));
849 self
850 }
851
852 pub fn request_id_layer(mut self) -> Self {
856 self.layers.push(request_id_middleware());
857 self
858 }
859
860 pub fn bus_injector<F>(mut self, injector: F) -> Self
865 where
866 F: Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static,
867 {
868 self.bus_injectors.push(Arc::new(injector));
869 self
870 }
871
872 #[cfg(feature = "http3")]
874 pub fn enable_http3(mut self, config: crate::http3::Http3Config) -> Self {
875 self.http3_config = Some(config);
876 self
877 }
878
879 #[cfg(feature = "http3")]
881 pub fn alt_svc_h3(mut self, port: u16) -> Self {
882 self.alt_svc_h3_port = Some(port);
883 self
884 }
885
886 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
890 let mut descriptors = self
891 .routes
892 .iter()
893 .map(|entry| HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone()))
894 .collect::<Vec<_>>();
895
896 if let Some(path) = &self.health.health_path {
897 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
898 }
899 if let Some(path) = &self.health.readiness_path {
900 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
901 }
902 if let Some(path) = &self.health.liveness_path {
903 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
904 }
905
906 descriptors
907 }
908
909 pub fn serve_dir(
915 mut self,
916 route_prefix: impl Into<String>,
917 directory: impl Into<String>,
918 ) -> Self {
919 self.static_assets.mounts.push(StaticMount {
920 route_prefix: normalize_route_path(route_prefix.into()),
921 directory: directory.into(),
922 });
923 if self.static_assets.cache_control.is_none() {
924 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
925 }
926 self
927 }
928
929 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
933 self.static_assets.spa_fallback = Some(file_path.into());
934 self
935 }
936
937 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
939 self.static_assets.cache_control = Some(cache_control.into());
940 self
941 }
942
943 pub fn compression_layer(mut self) -> Self {
945 self.static_assets.enable_compression = true;
946 self
947 }
948
949 pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
958 where
959 H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
960 Fut: Future<Output = ()> + Send + 'static,
961 {
962 let path_str: String = path.into();
963 let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
964 Box::pin(handler(connection, resources, bus))
965 });
966 let bus_injectors = Arc::new(self.bus_injectors.clone());
967 let path_for_pattern = path_str.clone();
968 let path_for_handler = path_str;
969
970 let route_handler: RouteHandler<R> =
971 Arc::new(move |parts: http::request::Parts, res: &R| {
972 let ws_handler = ws_handler.clone();
973 let bus_injectors = bus_injectors.clone();
974 let resources = Arc::new(res.clone());
975 let path = path_for_handler.clone();
976
977 Box::pin(async move {
978 let request_id = uuid::Uuid::new_v4().to_string();
979 let span = tracing::info_span!(
980 "WebSocketUpgrade",
981 ranvier.ws.path = %path,
982 ranvier.ws.request_id = %request_id
983 );
984
985 async move {
986 let mut bus = Bus::new();
987 for injector in bus_injectors.iter() {
988 injector(&parts, &mut bus);
989 }
990
991 let mut req = Request::from_parts(parts, ());
993 let session = websocket_session_from_request(&req);
994 bus.insert(session.clone());
995
996 let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
997 Ok(result) => result,
998 Err(error_response) => return error_response,
999 };
1000
1001 tokio::spawn(async move {
1002 match on_upgrade.await {
1003 Ok(upgraded) => {
1004 let stream = WebSocketStream::from_raw_socket(
1005 TokioIo::new(upgraded),
1006 tokio_tungstenite::tungstenite::protocol::Role::Server,
1007 None,
1008 )
1009 .await;
1010 let connection = WebSocketConnection::new(stream, session);
1011 ws_handler(connection, resources, bus).await;
1012 }
1013 Err(error) => {
1014 tracing::warn!(
1015 ranvier.ws.path = %path,
1016 ranvier.ws.error = %error,
1017 "websocket upgrade failed"
1018 );
1019 }
1020 }
1021 });
1022
1023 response
1024 }
1025 .instrument(span)
1026 .await
1027 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1028 });
1029
1030 self.routes.push(RouteEntry {
1031 method: Method::GET,
1032 pattern: RoutePattern::parse(&path_for_pattern),
1033 handler: route_handler,
1034 layers: Arc::new(Vec::new()),
1035 apply_global_layers: true,
1036 });
1037
1038 self
1039 }
1040
1041 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
1048 self.health.health_path = Some(normalize_route_path(path.into()));
1049 self
1050 }
1051
1052 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
1056 where
1057 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
1058 Fut: Future<Output = Result<(), Err>> + Send + 'static,
1059 Err: ToString + Send + 'static,
1060 {
1061 if self.health.health_path.is_none() {
1062 self.health.health_path = Some("/health".to_string());
1063 }
1064
1065 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
1066 let fut = check(resources);
1067 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
1068 });
1069
1070 self.health.checks.push(NamedHealthCheck {
1071 name: name.into(),
1072 check: check_fn,
1073 });
1074 self
1075 }
1076
1077 pub fn readiness_liveness(
1079 mut self,
1080 readiness_path: impl Into<String>,
1081 liveness_path: impl Into<String>,
1082 ) -> Self {
1083 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1084 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1085 self
1086 }
1087
1088 pub fn readiness_liveness_default(self) -> Self {
1090 self.readiness_liveness("/ready", "/live")
1091 }
1092
1093 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1097 where
1098 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1099 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1100 {
1101 self.route_method(Method::GET, path, circuit)
1102 }
1103 pub fn route_method<Out, E>(
1112 self,
1113 method: Method,
1114 path: impl Into<String>,
1115 circuit: Axon<(), Out, E, R>,
1116 ) -> Self
1117 where
1118 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1119 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1120 {
1121 self.route_method_with_error(method, path, circuit, |error| {
1122 (
1123 StatusCode::INTERNAL_SERVER_ERROR,
1124 format!("Error: {:?}", error),
1125 )
1126 .into_response()
1127 })
1128 }
1129
1130 pub fn route_method_with_error<Out, E, H>(
1131 self,
1132 method: Method,
1133 path: impl Into<String>,
1134 circuit: Axon<(), Out, E, R>,
1135 error_handler: H,
1136 ) -> Self
1137 where
1138 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1139 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1140 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1141 {
1142 self.route_method_with_error_and_layers(
1143 method,
1144 path,
1145 circuit,
1146 error_handler,
1147 Arc::new(Vec::new()),
1148 true,
1149 )
1150 }
1151
1152
1153
1154 fn route_method_with_error_and_layers<Out, E, H>(
1155 mut self,
1156 method: Method,
1157 path: impl Into<String>,
1158 circuit: Axon<(), Out, E, R>,
1159 error_handler: H,
1160 route_layers: Arc<Vec<ServiceLayer>>,
1161 apply_global_layers: bool,
1162 ) -> Self
1163 where
1164 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1165 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1166 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1167 {
1168 let path_str: String = path.into();
1169 let circuit = Arc::new(circuit);
1170 let error_handler = Arc::new(error_handler);
1171 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1172 let path_for_pattern = path_str.clone();
1173 let path_for_handler = path_str;
1174 let method_for_pattern = method.clone();
1175 let method_for_handler = method;
1176
1177 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1178 let circuit = circuit.clone();
1179 let error_handler = error_handler.clone();
1180 let route_bus_injectors = route_bus_injectors.clone();
1181 let res = res.clone();
1182 let path = path_for_handler.clone();
1183 let method = method_for_handler.clone();
1184
1185 Box::pin(async move {
1186 let request_id = uuid::Uuid::new_v4().to_string();
1187 let span = tracing::info_span!(
1188 "HTTPRequest",
1189 ranvier.http.method = %method,
1190 ranvier.http.path = %path,
1191 ranvier.http.request_id = %request_id
1192 );
1193
1194 async move {
1195 let mut bus = Bus::new();
1196 for injector in route_bus_injectors.iter() {
1197 injector(&parts, &mut bus);
1198 }
1199 let result = circuit.execute((), &res, &mut bus).await;
1200 outcome_to_response_with_error(result, |error| error_handler(error))
1201 }
1202 .instrument(span)
1203 .await
1204 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1205 });
1206
1207 self.routes.push(RouteEntry {
1208 method: method_for_pattern,
1209 pattern: RoutePattern::parse(&path_for_pattern),
1210 handler,
1211 layers: route_layers,
1212 apply_global_layers,
1213 });
1214 self
1215 }
1216
1217 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1218 where
1219 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1220 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1221 {
1222 self.route_method(Method::GET, path, circuit)
1223 }
1224
1225 pub fn get_with_error<Out, E, H>(
1226 self,
1227 path: impl Into<String>,
1228 circuit: Axon<(), Out, E, R>,
1229 error_handler: H,
1230 ) -> Self
1231 where
1232 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1233 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1234 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1235 {
1236 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1237 }
1238
1239 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1240 where
1241 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1242 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1243 {
1244 self.route_method(Method::POST, path, circuit)
1245 }
1246
1247 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1248 where
1249 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1250 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1251 {
1252 self.route_method(Method::PUT, path, circuit)
1253 }
1254
1255 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1256 where
1257 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1258 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1259 {
1260 self.route_method(Method::DELETE, path, circuit)
1261 }
1262
1263 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1264 where
1265 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1266 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1267 {
1268 self.route_method(Method::PATCH, path, circuit)
1269 }
1270
1271 pub fn post_with_error<Out, E, H>(
1272 self,
1273 path: impl Into<String>,
1274 circuit: Axon<(), Out, E, R>,
1275 error_handler: H,
1276 ) -> Self
1277 where
1278 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1279 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1280 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1281 {
1282 self.route_method_with_error(Method::POST, path, circuit, error_handler)
1283 }
1284
1285 pub fn put_with_error<Out, E, H>(
1286 self,
1287 path: impl Into<String>,
1288 circuit: Axon<(), Out, E, R>,
1289 error_handler: H,
1290 ) -> Self
1291 where
1292 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1293 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1294 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1295 {
1296 self.route_method_with_error(Method::PUT, path, circuit, error_handler)
1297 }
1298
1299 pub fn delete_with_error<Out, E, H>(
1300 self,
1301 path: impl Into<String>,
1302 circuit: Axon<(), Out, E, R>,
1303 error_handler: H,
1304 ) -> Self
1305 where
1306 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1307 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1308 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1309 {
1310 self.route_method_with_error(Method::DELETE, path, circuit, error_handler)
1311 }
1312
1313 pub fn patch_with_error<Out, E, H>(
1314 self,
1315 path: impl Into<String>,
1316 circuit: Axon<(), Out, E, R>,
1317 error_handler: H,
1318 ) -> Self
1319 where
1320 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1321 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1322 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1323 {
1324 self.route_method_with_error(Method::PATCH, path, circuit, error_handler)
1325 }
1326
1327 pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
1338 where
1339 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1340 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1341 {
1342 let circuit = Arc::new(circuit);
1343 let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
1344
1345 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1346 let circuit = circuit.clone();
1347 let fallback_bus_injectors = fallback_bus_injectors.clone();
1348 let res = res.clone();
1349 Box::pin(async move {
1350 let request_id = uuid::Uuid::new_v4().to_string();
1351 let span = tracing::info_span!(
1352 "HTTPRequest",
1353 ranvier.http.method = "FALLBACK",
1354 ranvier.http.request_id = %request_id
1355 );
1356
1357 async move {
1358 let mut bus = Bus::new();
1359 for injector in fallback_bus_injectors.iter() {
1360 injector(&parts, &mut bus);
1361 }
1362 let result: ranvier_core::Outcome<Out, E> =
1363 circuit.execute((), &res, &mut bus).await;
1364
1365 match result {
1366 Outcome::Next(output) => {
1367 let mut response = output.into_response();
1368 *response.status_mut() = StatusCode::NOT_FOUND;
1369 response
1370 }
1371 _ => Response::builder()
1372 .status(StatusCode::NOT_FOUND)
1373 .body(
1374 Full::new(Bytes::from("Not Found"))
1375 .map_err(|never| match never {})
1376 .boxed(),
1377 )
1378 .expect("valid HTTP response construction"),
1379 }
1380 }
1381 .instrument(span)
1382 .await
1383 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1384 });
1385
1386 self.fallback = Some(handler);
1387 self
1388 }
1389
1390 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1394 self.run_with_shutdown_signal(resources, shutdown_signal())
1395 .await
1396 }
1397
1398 async fn run_with_shutdown_signal<S>(
1399 self,
1400 resources: R,
1401 shutdown_signal: S,
1402 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
1403 where
1404 S: Future<Output = ()> + Send,
1405 {
1406 let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
1407 let addr: SocketAddr = addr_str.parse()?;
1408
1409 let mut raw_routes = self.routes;
1410 if self.active_intervention {
1411 let handler: RouteHandler<R> = Arc::new(|_parts, _res| {
1412 Box::pin(async move {
1413 Response::builder()
1414 .status(StatusCode::OK)
1415 .body(
1416 Full::new(Bytes::from("Intervention accepted"))
1417 .map_err(|never| match never {} as Infallible)
1418 .boxed(),
1419 )
1420 .expect("valid HTTP response construction")
1421 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1422 });
1423
1424 raw_routes.push(RouteEntry {
1425 method: Method::POST,
1426 pattern: RoutePattern::parse("/_system/intervene/force_resume"),
1427 handler,
1428 layers: Arc::new(Vec::new()),
1429 apply_global_layers: true,
1430 });
1431 }
1432
1433 if let Some(registry) = self.policy_registry.clone() {
1434 let handler: RouteHandler<R> = Arc::new(move |_parts, _res| {
1435 let _registry = registry.clone();
1436 Box::pin(async move {
1437 Response::builder()
1441 .status(StatusCode::OK)
1442 .body(
1443 Full::new(Bytes::from("Policy registry active"))
1444 .map_err(|never| match never {} as Infallible)
1445 .boxed(),
1446 )
1447 .expect("valid HTTP response construction")
1448 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1449 });
1450
1451 raw_routes.push(RouteEntry {
1452 method: Method::POST,
1453 pattern: RoutePattern::parse("/_system/policy/reload"),
1454 handler,
1455 layers: Arc::new(Vec::new()),
1456 apply_global_layers: true,
1457 });
1458 }
1459 let routes = Arc::new(raw_routes);
1460 let fallback = self.fallback;
1461 let layers = Arc::new(self.layers);
1462 let health = Arc::new(self.health);
1463 let static_assets = Arc::new(self.static_assets);
1464 let on_start = self.on_start;
1465 let on_shutdown = self.on_shutdown;
1466 let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
1467 let resources = Arc::new(resources);
1468
1469 let listener = TcpListener::bind(addr).await?;
1470
1471 #[cfg(feature = "tls")]
1473 let tls_acceptor = if let Some(ref tls_cfg) = self.tls_config {
1474 let acceptor = build_tls_acceptor(&tls_cfg.cert_path, &tls_cfg.key_path)?;
1475 tracing::info!("Ranvier HTTP Ingress listening on https://{}", addr);
1476 Some(acceptor)
1477 } else {
1478 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
1479 None
1480 };
1481 #[cfg(not(feature = "tls"))]
1482 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
1483
1484 if let Some(callback) = on_start.as_ref() {
1485 callback();
1486 }
1487
1488 tokio::pin!(shutdown_signal);
1489 let mut connections = tokio::task::JoinSet::new();
1490
1491 loop {
1492 tokio::select! {
1493 _ = &mut shutdown_signal => {
1494 tracing::info!("Shutdown signal received. Draining in-flight connections.");
1495 break;
1496 }
1497 accept_result = listener.accept() => {
1498 let (stream, _) = accept_result?;
1499
1500 let routes = routes.clone();
1501 let fallback = fallback.clone();
1502 let resources = resources.clone();
1503 let layers = layers.clone();
1504 let health = health.clone();
1505 let static_assets = static_assets.clone();
1506 #[cfg(feature = "http3")]
1507 let alt_svc_h3_port = self.alt_svc_h3_port;
1508
1509 #[cfg(feature = "tls")]
1510 let tls_acceptor = tls_acceptor.clone();
1511
1512 connections.spawn(async move {
1513 let service = build_http_service(
1514 routes,
1515 fallback,
1516 resources,
1517 layers,
1518 health,
1519 static_assets,
1520 #[cfg(feature = "http3")] alt_svc_h3_port,
1521 );
1522
1523 #[cfg(feature = "tls")]
1524 if let Some(acceptor) = tls_acceptor {
1525 match acceptor.accept(stream).await {
1526 Ok(tls_stream) => {
1527 let io = TokioIo::new(tls_stream);
1528 if let Err(err) = http1::Builder::new()
1529 .serve_connection(io, service)
1530 .with_upgrades()
1531 .await
1532 {
1533 tracing::error!("Error serving TLS connection: {:?}", err);
1534 }
1535 }
1536 Err(err) => {
1537 tracing::warn!("TLS handshake failed: {:?}", err);
1538 }
1539 }
1540 return;
1541 }
1542
1543 let io = TokioIo::new(stream);
1544 if let Err(err) = http1::Builder::new()
1545 .serve_connection(io, service)
1546 .with_upgrades()
1547 .await
1548 {
1549 tracing::error!("Error serving connection: {:?}", err);
1550 }
1551 });
1552 }
1553 Some(join_result) = connections.join_next(), if !connections.is_empty() => {
1554 if let Err(err) = join_result {
1555 tracing::warn!("Connection task join error: {:?}", err);
1556 }
1557 }
1558 }
1559 }
1560
1561 let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;
1562
1563 drop(resources);
1564 if let Some(callback) = on_shutdown.as_ref() {
1565 callback();
1566 }
1567
1568 Ok(())
1569 }
1570
1571 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
1587 let routes = Arc::new(self.routes);
1588 let fallback = self.fallback;
1589 let layers = Arc::new(self.layers);
1590 let health = Arc::new(self.health);
1591 let static_assets = Arc::new(self.static_assets);
1592 let resources = Arc::new(resources);
1593
1594 RawIngressService {
1595 routes,
1596 fallback,
1597 layers,
1598 health,
1599 static_assets,
1600 resources,
1601 }
1602 }
1603}
1604
1605fn build_http_service<R>(
1606 routes: Arc<Vec<RouteEntry<R>>>,
1607 fallback: Option<RouteHandler<R>>,
1608 resources: Arc<R>,
1609 layers: Arc<Vec<ServiceLayer>>,
1610 health: Arc<HealthConfig<R>>,
1611 static_assets: Arc<StaticAssetsConfig>,
1612 #[cfg(feature = "http3")] alt_svc_port: Option<u16>,
1613) -> BoxHttpService
1614where
1615 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1616{
1617 BoxService::new(move |req: Request<Incoming>| {
1618 let routes = routes.clone();
1619 let fallback = fallback.clone();
1620 let resources = resources.clone();
1621 let layers = layers.clone();
1622 let health = health.clone();
1623 let static_assets = static_assets.clone();
1624
1625 async move {
1626 let mut req = req;
1627 let method = req.method().clone();
1628 let path = req.uri().path().to_string();
1629
1630 if let Some(response) =
1631 maybe_handle_health_request(&method, &path, &health, resources.clone()).await
1632 {
1633 return Ok::<_, Infallible>(response.into_response());
1634 }
1635
1636 if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
1637 req.extensions_mut().insert(params);
1638 let effective_layers = if entry.apply_global_layers {
1639 merge_layers(&layers, &entry.layers)
1640 } else {
1641 entry.layers.clone()
1642 };
1643
1644 if effective_layers.is_empty() {
1645 let (parts, _) = req.into_parts();
1646 #[allow(unused_mut)]
1647 let mut res = (entry.handler)(parts, &resources).await;
1648 #[cfg(feature = "http3")]
1649 if let Some(port) = alt_svc_port {
1650 if let Ok(val) =
1651 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
1652 {
1653 res.headers_mut().insert(http::header::ALT_SVC, val);
1654 }
1655 }
1656 Ok::<_, Infallible>(res)
1657 } else {
1658 let route_service = build_route_service(
1659 entry.handler.clone(),
1660 resources.clone(),
1661 effective_layers,
1662 );
1663 #[allow(unused_mut)]
1664 let mut res = route_service.call(req).await;
1665 #[cfg(feature = "http3")]
1666 #[allow(irrefutable_let_patterns)]
1667 if let Ok(ref mut r) = res {
1668 if let Some(port) = alt_svc_port {
1669 if let Ok(val) =
1670 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
1671 {
1672 r.headers_mut().insert(http::header::ALT_SVC, val);
1673 }
1674 }
1675 }
1676 res
1677 }
1678 } else {
1679 let req =
1680 match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
1681 .await
1682 {
1683 Ok(req) => req,
1684 Err(response) => return Ok(response),
1685 };
1686
1687 #[allow(unused_mut)]
1688 let mut fallback_res = if let Some(ref fb) = fallback {
1689 if layers.is_empty() {
1690 let (parts, _) = req.into_parts();
1691 Ok(fb(parts, &resources).await)
1692 } else {
1693 let fallback_service =
1694 build_route_service(fb.clone(), resources.clone(), layers.clone());
1695 fallback_service.call(req).await
1696 }
1697 } else {
1698 Ok(Response::builder()
1699 .status(StatusCode::NOT_FOUND)
1700 .body(
1701 Full::new(Bytes::from("Not Found"))
1702 .map_err(|never| match never {})
1703 .boxed(),
1704 )
1705 .expect("valid HTTP response construction"))
1706 };
1707
1708 #[cfg(feature = "http3")]
1709 if let Ok(r) = fallback_res.as_mut() {
1710 if let Some(port) = alt_svc_port {
1711 if let Ok(val) =
1712 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
1713 {
1714 r.headers_mut().insert(http::header::ALT_SVC, val);
1715 }
1716 }
1717 }
1718
1719 fallback_res
1720 }
1721 }
1722 })
1723}
1724
1725fn build_route_service<R>(
1726 handler: RouteHandler<R>,
1727 resources: Arc<R>,
1728 layers: Arc<Vec<ServiceLayer>>,
1729) -> BoxHttpService
1730where
1731 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1732{
1733 let mut service = BoxService::new(move |req: Request<Incoming>| {
1734 let handler = handler.clone();
1735 let resources = resources.clone();
1736 async move {
1737 let (parts, _) = req.into_parts();
1738 Ok::<_, Infallible>(handler(parts, &resources).await)
1739 }
1740 });
1741
1742 for layer in layers.iter() {
1743 service = layer(service);
1744 }
1745 service
1746}
1747
1748fn merge_layers(
1749 global_layers: &Arc<Vec<ServiceLayer>>,
1750 route_layers: &Arc<Vec<ServiceLayer>>,
1751) -> Arc<Vec<ServiceLayer>> {
1752 if global_layers.is_empty() {
1753 return route_layers.clone();
1754 }
1755 if route_layers.is_empty() {
1756 return global_layers.clone();
1757 }
1758
1759 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
1760 combined.extend(global_layers.iter().cloned());
1761 combined.extend(route_layers.iter().cloned());
1762 Arc::new(combined)
1763}
1764
1765async fn maybe_handle_health_request<R>(
1766 method: &Method,
1767 path: &str,
1768 health: &HealthConfig<R>,
1769 resources: Arc<R>,
1770) -> Option<HttpResponse>
1771where
1772 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1773{
1774 if method != Method::GET {
1775 return None;
1776 }
1777
1778 if let Some(liveness_path) = health.liveness_path.as_ref() {
1779 if path == liveness_path {
1780 return Some(health_json_response("liveness", true, Vec::new()));
1781 }
1782 }
1783
1784 if let Some(readiness_path) = health.readiness_path.as_ref() {
1785 if path == readiness_path {
1786 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1787 return Some(health_json_response("readiness", healthy, checks));
1788 }
1789 }
1790
1791 if let Some(health_path) = health.health_path.as_ref() {
1792 if path == health_path {
1793 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1794 return Some(health_json_response("health", healthy, checks));
1795 }
1796 }
1797
1798 None
1799}
1800
1801async fn serve_single_file(file_path: &str) -> Result<Response<Full<Bytes>>, std::io::Error> {
1803 let path = std::path::Path::new(file_path);
1804 let content = tokio::fs::read(path).await?;
1805 let mime = guess_mime(file_path);
1806 let mut response = Response::new(Full::new(Bytes::from(content)));
1807 if let Ok(value) = http::HeaderValue::from_str(mime) {
1808 response
1809 .headers_mut()
1810 .insert(http::header::CONTENT_TYPE, value);
1811 }
1812 if let Ok(metadata) = tokio::fs::metadata(path).await {
1813 if let Ok(modified) = metadata.modified() {
1814 if let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH) {
1815 let etag = format!("\"{}\"", duration.as_secs());
1816 if let Ok(value) = http::HeaderValue::from_str(&etag) {
1817 response.headers_mut().insert(http::header::ETAG, value);
1818 }
1819 }
1820 }
1821 }
1822 Ok(response)
1823}
1824
1825async fn serve_static_file(
1827 directory: &str,
1828 file_subpath: &str,
1829) -> Result<Response<Full<Bytes>>, std::io::Error> {
1830 let subpath = file_subpath.trim_start_matches('/');
1831 if subpath.is_empty() || subpath == "/" {
1832 return Err(std::io::Error::new(
1833 std::io::ErrorKind::NotFound,
1834 "empty path",
1835 ));
1836 }
1837 let full_path = std::path::Path::new(directory).join(subpath);
1838 let canonical = tokio::fs::canonicalize(&full_path).await?;
1840 let dir_canonical = tokio::fs::canonicalize(directory).await?;
1841 if !canonical.starts_with(&dir_canonical) {
1842 return Err(std::io::Error::new(
1843 std::io::ErrorKind::PermissionDenied,
1844 "path traversal detected",
1845 ));
1846 }
1847 let content = tokio::fs::read(&canonical).await?;
1848 let mime = guess_mime(canonical.to_str().unwrap_or(""));
1849 let mut response = Response::new(Full::new(Bytes::from(content)));
1850 if let Ok(value) = http::HeaderValue::from_str(mime) {
1851 response
1852 .headers_mut()
1853 .insert(http::header::CONTENT_TYPE, value);
1854 }
1855 if let Ok(metadata) = tokio::fs::metadata(&canonical).await {
1856 if let Ok(modified) = metadata.modified() {
1857 if let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH) {
1858 let etag = format!("\"{}\"", duration.as_secs());
1859 if let Ok(value) = http::HeaderValue::from_str(&etag) {
1860 response.headers_mut().insert(http::header::ETAG, value);
1861 }
1862 }
1863 }
1864 }
1865 Ok(response)
1866}
1867
/// Guesses a `Content-Type` from the extension of `path`.
///
/// The extension is matched case-insensitively, so `LOGO.PNG` maps like
/// `logo.png`. Text-like types carry `charset=utf-8`. Paths with no extension,
/// or with an unknown one, map to `application/octet-stream`.
fn guess_mime(path: &str) -> &'static str {
    let extension = std::path::Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();

    match extension.as_str() {
        "html" | "htm" => "text/html; charset=utf-8",
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" => "application/javascript; charset=utf-8",
        "json" | "map" => "application/json; charset=utf-8",
        "csv" => "text/csv; charset=utf-8",
        "txt" => "text/plain; charset=utf-8",
        "xml" => "application/xml; charset=utf-8",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        "webp" => "image/webp",
        "avif" => "image/avif",
        "ico" => "image/x-icon",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "otf" => "font/otf",
        "mp4" => "video/mp4",
        "webm" => "video/webm",
        "mp3" => "audio/mpeg",
        "wasm" => "application/wasm",
        "pdf" => "application/pdf",
        _ => "application/octet-stream",
    }
}
1889
1890fn apply_cache_control(
1891 mut response: Response<Full<Bytes>>,
1892 cache_control: Option<&str>,
1893) -> Response<Full<Bytes>> {
1894 if response.status() == StatusCode::OK {
1895 if let Some(value) = cache_control {
1896 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
1897 if let Ok(header_value) = http::HeaderValue::from_str(value) {
1898 response
1899 .headers_mut()
1900 .insert(http::header::CACHE_CONTROL, header_value);
1901 }
1902 }
1903 }
1904 }
1905 response
1906}
1907
1908async fn maybe_handle_static_request(
1909 req: Request<Incoming>,
1910 method: &Method,
1911 path: &str,
1912 static_assets: &StaticAssetsConfig,
1913) -> Result<Request<Incoming>, HttpResponse> {
1914 if method != Method::GET && method != Method::HEAD {
1915 return Ok(req);
1916 }
1917
1918 if let Some(mount) = static_assets
1919 .mounts
1920 .iter()
1921 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
1922 {
1923 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1924 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
1925 return Ok(req);
1926 };
1927 let response = match serve_static_file(&mount.directory, &stripped_path).await {
1928 Ok(response) => response,
1929 Err(_) => {
1930 return Err(Response::builder()
1931 .status(StatusCode::INTERNAL_SERVER_ERROR)
1932 .body(
1933 Full::new(Bytes::from("Failed to serve static asset"))
1934 .map_err(|never| match never {})
1935 .boxed(),
1936 )
1937 .unwrap_or_else(|_| {
1938 Response::new(
1939 Full::new(Bytes::new())
1940 .map_err(|never| match never {})
1941 .boxed(),
1942 )
1943 }));
1944 }
1945 };
1946 let mut response = apply_cache_control(response, static_assets.cache_control.as_deref());
1947 response = maybe_compress_static_response(
1948 response,
1949 accept_encoding,
1950 static_assets.enable_compression,
1951 );
1952 let (parts, body) = response.into_parts();
1953 return Err(Response::from_parts(
1954 parts,
1955 body.map_err(|never| match never {}).boxed(),
1956 ));
1957 }
1958
1959 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
1960 if looks_like_spa_request(path) {
1961 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1962 let response = match serve_single_file(spa_file).await {
1963 Ok(response) => response,
1964 Err(_) => {
1965 return Err(Response::builder()
1966 .status(StatusCode::INTERNAL_SERVER_ERROR)
1967 .body(
1968 Full::new(Bytes::from("Failed to serve SPA fallback"))
1969 .map_err(|never| match never {})
1970 .boxed(),
1971 )
1972 .unwrap_or_else(|_| {
1973 Response::new(
1974 Full::new(Bytes::new())
1975 .map_err(|never| match never {})
1976 .boxed(),
1977 )
1978 }));
1979 }
1980 };
1981 let mut response =
1982 apply_cache_control(response, static_assets.cache_control.as_deref());
1983 response = maybe_compress_static_response(
1984 response,
1985 accept_encoding,
1986 static_assets.enable_compression,
1987 );
1988 let (parts, body) = response.into_parts();
1989 return Err(Response::from_parts(
1990 parts,
1991 body.map_err(|never| match never {}).boxed(),
1992 ));
1993 }
1994 }
1995
1996 Ok(req)
1997}
1998
/// Returns the part of `path` below the mount `prefix`, always starting with `/`.
///
/// A prefix of exactly `/` mounts at the root and returns `path` unchanged.
/// Otherwise trailing slashes on `prefix` are ignored. The exact prefix maps to
/// `/`, a path continuing with `/` after the prefix maps to the remainder, and
/// anything else (including `/staticfoo` for prefix `/static`) yields `None`.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    if prefix == "/" {
        return Some(path.to_owned());
    }

    let base = prefix.trim_end_matches('/');
    if path == base {
        return Some(String::from("/"));
    }

    let rest = path.strip_prefix(base)?.strip_prefix('/')?;
    Some(format!("/{rest}"))
}
2018
/// Heuristic for client-side (SPA) routes: the last path segment contains no
/// `.`, so it does not look like a request for a file such as `app.js`.
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = match path.rfind('/') {
        Some(slash) => &path[slash + 1..],
        None => path,
    };
    !last_segment.contains('.')
}
2023
2024fn maybe_compress_static_response(
2025 response: Response<Full<Bytes>>,
2026 accept_encoding: Option<http::HeaderValue>,
2027 enable_compression: bool,
2028) -> Response<Full<Bytes>> {
2029 if !enable_compression {
2030 return response;
2031 }
2032
2033 let Some(accept_encoding) = accept_encoding else {
2034 return response;
2035 };
2036
2037 let accept_str = accept_encoding.to_str().unwrap_or("");
2038 if !accept_str.contains("gzip") {
2039 return response;
2040 }
2041
2042 let status = response.status();
2043 let headers = response.headers().clone();
2044 let body = response.into_body();
2045
2046 let data = futures_util::FutureExt::now_or_never(BodyExt::collect(body))
2048 .and_then(|r| r.ok())
2049 .map(|collected| collected.to_bytes())
2050 .unwrap_or_default();
2051
2052 let compressed = {
2054 use flate2::write::GzEncoder;
2055 use flate2::Compression;
2056 use std::io::Write;
2057 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
2058 let _ = encoder.write_all(&data);
2059 encoder.finish().unwrap_or_default()
2060 };
2061
2062 let mut builder = Response::builder().status(status);
2063 for (name, value) in headers.iter() {
2064 if name != http::header::CONTENT_LENGTH && name != http::header::CONTENT_ENCODING {
2065 builder = builder.header(name, value);
2066 }
2067 }
2068 builder
2069 .header(http::header::CONTENT_ENCODING, "gzip")
2070 .body(Full::new(Bytes::from(compressed)))
2071 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())))
2072}
2073
2074async fn run_named_health_checks<R>(
2075 checks: &[NamedHealthCheck<R>],
2076 resources: Arc<R>,
2077) -> (bool, Vec<HealthCheckReport>)
2078where
2079 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2080{
2081 let mut reports = Vec::with_capacity(checks.len());
2082 let mut healthy = true;
2083
2084 for check in checks {
2085 match (check.check)(resources.clone()).await {
2086 Ok(()) => reports.push(HealthCheckReport {
2087 name: check.name.clone(),
2088 status: "ok",
2089 error: None,
2090 }),
2091 Err(error) => {
2092 healthy = false;
2093 reports.push(HealthCheckReport {
2094 name: check.name.clone(),
2095 status: "error",
2096 error: Some(error),
2097 });
2098 }
2099 }
2100 }
2101
2102 (healthy, reports)
2103}
2104
2105fn health_json_response(
2106 probe: &'static str,
2107 healthy: bool,
2108 checks: Vec<HealthCheckReport>,
2109) -> HttpResponse {
2110 let status_code = if healthy {
2111 StatusCode::OK
2112 } else {
2113 StatusCode::SERVICE_UNAVAILABLE
2114 };
2115 let status = if healthy { "ok" } else { "degraded" };
2116 let payload = HealthReport {
2117 status,
2118 probe,
2119 checks,
2120 };
2121
2122 let body = serde_json::to_vec(&payload)
2123 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
2124
2125 Response::builder()
2126 .status(status_code)
2127 .header(http::header::CONTENT_TYPE, "application/json")
2128 .body(
2129 Full::new(Bytes::from(body))
2130 .map_err(|never| match never {})
2131 .boxed(),
2132 )
2133 .expect("valid HTTP response construction")
2134}
2135
2136async fn shutdown_signal() {
2137 #[cfg(unix)]
2138 {
2139 use tokio::signal::unix::{SignalKind, signal};
2140
2141 match signal(SignalKind::terminate()) {
2142 Ok(mut terminate) => {
2143 tokio::select! {
2144 _ = tokio::signal::ctrl_c() => {}
2145 _ = terminate.recv() => {}
2146 }
2147 }
2148 Err(err) => {
2149 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
2150 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
2151 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
2152 }
2153 }
2154 }
2155 }
2156
2157 #[cfg(not(unix))]
2158 {
2159 if let Err(err) = tokio::signal::ctrl_c().await {
2160 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
2161 }
2162 }
2163}
2164
2165async fn drain_connections(
2166 connections: &mut tokio::task::JoinSet<()>,
2167 graceful_shutdown_timeout: Duration,
2168) -> bool {
2169 if connections.is_empty() {
2170 return false;
2171 }
2172
2173 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
2174 while let Some(join_result) = connections.join_next().await {
2175 if let Err(err) = join_result {
2176 tracing::warn!("Connection task join error during shutdown: {:?}", err);
2177 }
2178 }
2179 })
2180 .await;
2181
2182 if drain_result.is_err() {
2183 tracing::warn!(
2184 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
2185 graceful_shutdown_timeout
2186 );
2187 connections.abort_all();
2188 while let Some(join_result) = connections.join_next().await {
2189 if let Err(err) = join_result {
2190 tracing::warn!("Connection task abort join error: {:?}", err);
2191 }
2192 }
2193 true
2194 } else {
2195 false
2196 }
2197}
2198
#[cfg(feature = "tls")]
/// Builds a rustls-backed TLS acceptor from PEM-encoded files.
///
/// Loads the full certificate chain from `cert_path` and the first private key
/// found in `key_path`. The resulting config does not request client certificates.
///
/// # Errors
/// Returns a descriptive message if either file cannot be opened, the PEM
/// cannot be parsed, no private key is present, or rustls rejects the
/// certificate/key pair.
fn build_tls_acceptor(
    cert_path: &str,
    key_path: &str,
) -> Result<tokio_rustls::TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    use rustls::ServerConfig;
    use rustls_pemfile::{certs, private_key};
    use std::io::BufReader;
    use tokio_rustls::TlsAcceptor;

    // Open both files up front so a missing key is reported before PEM problems.
    let cert_source = std::fs::File::open(cert_path)
        .map_err(|e| format!("Failed to open certificate file '{}': {}", cert_path, e))?;
    let key_source = std::fs::File::open(key_path)
        .map_err(|e| format!("Failed to open key file '{}': {}", key_path, e))?;

    let mut cert_reader = BufReader::new(cert_source);
    let chain = certs(&mut cert_reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| format!("Failed to parse certificate PEM: {}", e))?;

    let mut key_reader = BufReader::new(key_source);
    let private = private_key(&mut key_reader)
        .map_err(|e| format!("Failed to parse private key PEM: {}", e))?
        .ok_or("No private key found in key file")?;

    let server_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, private)
        .map_err(|e| format!("TLS configuration error: {}", e))?;

    Ok(TlsAcceptor::from(Arc::new(server_config)))
}
2230
2231impl<R> Default for HttpIngress<R>
2232where
2233 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2234{
2235 fn default() -> Self {
2236 Self::new()
2237 }
2238}
2239
2240#[derive(Clone)]
2242pub struct RawIngressService<R> {
2243 routes: Arc<Vec<RouteEntry<R>>>,
2244 fallback: Option<RouteHandler<R>>,
2245 layers: Arc<Vec<ServiceLayer>>,
2246 health: Arc<HealthConfig<R>>,
2247 static_assets: Arc<StaticAssetsConfig>,
2248 resources: Arc<R>,
2249}
2250
2251impl<R> hyper::service::Service<Request<Incoming>> for RawIngressService<R>
2252where
2253 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2254{
2255 type Response = HttpResponse;
2256 type Error = Infallible;
2257 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2258
2259 fn call(&self, req: Request<Incoming>) -> Self::Future {
2260 let routes = self.routes.clone();
2261 let fallback = self.fallback.clone();
2262 let layers = self.layers.clone();
2263 let health = self.health.clone();
2264 let static_assets = self.static_assets.clone();
2265 let resources = self.resources.clone();
2266
2267 Box::pin(async move {
2268 let service = build_http_service(
2269 routes,
2270 fallback,
2271 resources,
2272 layers,
2273 health,
2274 static_assets,
2275 #[cfg(feature = "http3")]
2276 None,
2277 );
2278 service.call(req).await
2279 })
2280 }
2281}
2282
/// Tests for route pattern matching, `HttpIngress` builder state, WebSocket
/// upgrades, lifecycle hooks, graceful shutdown, static assets and the
/// built-in timeout layer.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use futures_util::{SinkExt, StreamExt};
    use serde::Deserialize;
    use std::fs;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tempfile::tempdir;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio_tungstenite::tungstenite::Message as WsClientMessage;
    use tokio_tungstenite::tungstenite::client::IntoClientRequest;

    /// Upper bound for a single network step (handshake, receive, shutdown,
    /// join) in the socket-level tests below.
    ///
    /// Without a bound, a server regression that never answers would make the
    /// test suite hang forever instead of failing with a useful message.
    const NETWORK_STEP_TIMEOUT: Duration = Duration::from_secs(5);

    /// Awaits `future` and returns its output.
    ///
    /// # Panics
    /// Panics with a message naming `what` if `future` does not complete
    /// within [`NETWORK_STEP_TIMEOUT`].
    async fn within_timeout<F>(what: &str, future: F) -> F::Output
    where
        F: Future,
    {
        match tokio::time::timeout(NETWORK_STEP_TIMEOUT, future).await {
            Ok(output) => output,
            Err(_) => panic!("timed out after {NETWORK_STEP_TIMEOUT:?} waiting for {what}"),
        }
    }

    /// Picks a loopback address with a currently free port by binding port 0
    /// and immediately releasing the listener.
    ///
    /// Another process could claim the port between this release and the
    /// server's own bind. The tests accept that small risk so they know the
    /// address before starting the server.
    fn free_local_addr() -> std::net::SocketAddr {
        let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
        probe.local_addr().expect("local addr")
    }

    /// Connects to `addr`, retrying every 25 ms for up to 2 s.
    ///
    /// Callers use this right after spawning the server task, which may not
    /// have bound its listener yet.
    ///
    /// # Panics
    /// Panics with the last connect error once the 2 s deadline has passed.
    async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
        let deadline = tokio::time::Instant::now() + Duration::from_secs(2);

        loop {
            match tokio::net::TcpStream::connect(addr).await {
                Ok(stream) => return stream,
                Err(error) => {
                    if tokio::time::Instant::now() >= deadline {
                        panic!("connect server: {error}");
                    }
                    tokio::time::sleep(Duration::from_millis(25)).await;
                }
            }
        }
    }

    /// A pattern without parameters matches its exact path and captures nothing.
    #[test]
    fn route_pattern_matches_static_path() {
        let pattern = RoutePattern::parse("/orders/list");
        let params = pattern.match_path("/orders/list").expect("should match");
        assert!(params.into_inner().is_empty());
    }

    /// `:name` segments capture the corresponding path segment by name.
    #[test]
    fn route_pattern_matches_param_segments() {
        let pattern = RoutePattern::parse("/orders/:id/items/:item_id");
        let params = pattern
            .match_path("/orders/42/items/sku-123")
            .expect("should match");
        assert_eq!(params.get("id"), Some("42"));
        assert_eq!(params.get("item_id"), Some("sku-123"));
    }

    /// A `*name` segment captures the remaining path, including `/` separators.
    #[test]
    fn route_pattern_matches_wildcard_segment() {
        let pattern = RoutePattern::parse("/assets/*path");
        let params = pattern
            .match_path("/assets/css/theme/light.css")
            .expect("should match");
        assert_eq!(params.get("path"), Some("css/theme/light.css"));
    }

    /// A path whose static prefix differs from the pattern does not match.
    #[test]
    fn route_pattern_rejects_non_matching_path() {
        let pattern = RoutePattern::parse("/orders/:id");
        assert!(pattern.match_path("/users/42").is_none());
    }

    /// A fresh ingress has a 30 s shutdown budget and no layers, bus
    /// injectors, static mounts or lifecycle hooks.
    #[test]
    fn graceful_shutdown_timeout_defaults_to_30_seconds() {
        let ingress = HttpIngress::<()>::new();
        assert_eq!(ingress.graceful_shutdown_timeout, Duration::from_secs(30));
        assert!(ingress.layers.is_empty());
        assert!(ingress.bus_injectors.is_empty());
        assert!(ingress.static_assets.mounts.is_empty());
        assert!(ingress.on_start.is_none());
        assert!(ingress.on_shutdown.is_none());
    }

    /// A route registered without a layer has no route-local middleware and
    /// still opts into the global layers.
    #[test]
    fn route_without_layer_keeps_empty_route_middleware_stack() {
        let ingress =
            HttpIngress::<()>::new().get("/ping", Axon::<(), (), String, ()>::new("Ping"));
        assert_eq!(ingress.routes.len(), 1);
        assert!(ingress.routes[0].layers.is_empty());
        assert!(ingress.routes[0].apply_global_layers);
    }

    /// `timeout_layer` adds exactly one global layer.
    #[test]
    fn timeout_layer_registers_builtin_middleware() {
        let ingress = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
        assert_eq!(ingress.layers.len(), 1);
    }

    /// `request_id_layer` adds exactly one global layer.
    #[test]
    fn request_id_layer_registers_builtin_middleware() {
        let ingress = HttpIngress::<()>::new().request_id_layer();
        assert_eq!(ingress.layers.len(), 1);
    }

    /// `compression_layer` turns on compression for static assets.
    #[test]
    fn compression_layer_enables_static_asset_compression() {
        let ingress = HttpIngress::<()>::new().compression_layer();
        assert!(ingress.static_assets.enable_compression);
    }

    /// `bus_injector` stores the hook on the ingress.
    #[test]
    fn bus_injector_registration_adds_hook() {
        let ingress = HttpIngress::<()>::new().bus_injector(|_req, bus| {
            bus.insert("ok".to_string());
        });
        assert_eq!(ingress.bus_injectors.len(), 1);
    }

    /// `ws` registers a single `GET` route under the given pattern.
    #[test]
    fn ws_route_registers_get_route_pattern() {
        let ingress =
            HttpIngress::<()>::new().ws("/ws/events", |_socket, _resources, _bus| async {});
        assert_eq!(ingress.routes.len(), 1);
        assert_eq!(ingress.routes[0].method, Method::GET);
        assert_eq!(ingress.routes[0].pattern.raw, "/ws/events");
    }

    /// JSON shape of the first frame the echo handler sends.
    #[derive(Debug, Deserialize)]
    struct WsWelcomeFrame {
        connection_id: String,
        path: String,
        tenant: String,
    }

    /// Checks a live WebSocket connection on `/ws/echo`.
    ///
    /// The per-connection bus must carry both the injector's tenant string and
    /// a `WebSocketSessionContext`. Text and binary frames must be echoed back
    /// through the event bridge.
    #[tokio::test]
    async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
        let addr = free_local_addr();

        let ingress = HttpIngress::<()>::new()
            .bind(addr.to_string())
            .bus_injector(|req, bus| {
                if let Some(value) = req.headers.get("x-tenant-id").and_then(|v| v.to_str().ok()) {
                    bus.insert(value.to_string());
                }
            })
            .ws("/ws/echo", |mut socket, _resources, bus| async move {
                let tenant = bus
                    .read::<String>()
                    .cloned()
                    .unwrap_or_else(|| "unknown".to_string());
                if let Some(session) = bus.read::<WebSocketSessionContext>() {
                    let welcome = serde_json::json!({
                        "connection_id": session.connection_id().to_string(),
                        "path": session.path(),
                        "tenant": tenant,
                    });
                    let _ = socket.send_json(&welcome).await;
                }

                while let Some(event) = socket.next_event().await {
                    match event {
                        WebSocketEvent::Text(text) => {
                            let _ = socket.send_event(format!("echo:{text}")).await;
                        }
                        WebSocketEvent::Binary(bytes) => {
                            let _ = socket.send_event(bytes).await;
                        }
                        WebSocketEvent::Close => break,
                        WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
                    }
                }
            });

        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let server = tokio::spawn(async move {
            ingress
                .run_with_shutdown_signal((), async move {
                    let _ = shutdown_rx.await;
                })
                .await
        });

        // The spawned server task may not have bound its listener yet. Connect
        // with retries first, then run the WebSocket handshake over that stream,
        // like the other socket-level tests do.
        let stream = connect_with_retry(addr).await;

        let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
        let mut ws_request = ws_uri
            .as_str()
            .into_client_request()
            .expect("ws client request");
        ws_request
            .headers_mut()
            .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
        let (mut client, _response) = within_timeout(
            "websocket handshake",
            tokio_tungstenite::client_async(ws_request, stream),
        )
        .await
        .expect("websocket connect");

        let welcome = within_timeout("welcome frame", client.next())
            .await
            .expect("welcome frame")
            .expect("welcome frame ok");
        let welcome_text = match welcome {
            WsClientMessage::Text(text) => text.to_string(),
            other => panic!("expected text welcome frame, got {other:?}"),
        };
        let welcome_payload: WsWelcomeFrame =
            serde_json::from_str(&welcome_text).expect("welcome json");
        assert_eq!(welcome_payload.path, "/ws/echo");
        assert_eq!(welcome_payload.tenant, "acme");
        assert!(!welcome_payload.connection_id.is_empty());

        within_timeout(
            "text frame send",
            client.send(WsClientMessage::Text("hello".into())),
        )
        .await
        .expect("send text");
        let echo_text = within_timeout("echo text frame", client.next())
            .await
            .expect("echo text frame")
            .expect("echo text frame ok");
        assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));

        within_timeout(
            "binary frame send",
            client.send(WsClientMessage::Binary(vec![1, 2, 3, 4].into())),
        )
        .await
        .expect("send binary");
        let echo_binary = within_timeout("echo binary frame", client.next())
            .await
            .expect("echo binary frame")
            .expect("echo binary frame ok");
        assert_eq!(
            echo_binary,
            WsClientMessage::Binary(vec![1, 2, 3, 4].into())
        );

        within_timeout("websocket close", client.close(None))
            .await
            .expect("close websocket");

        let _ = shutdown_tx.send(());
        within_timeout("server shutdown", server)
            .await
            .expect("server join")
            .expect("server shutdown should succeed");
    }

    /// `route_descriptors` lists the user route and the health, readiness and
    /// liveness endpoints as `GET` routes.
    #[test]
    fn route_descriptors_export_http_and_health_paths() {
        let ingress = HttpIngress::<()>::new()
            .get(
                "/orders/:id",
                Axon::<(), (), String, ()>::new("OrderById"),
            )
            .health_endpoint("/healthz")
            .readiness_liveness("/readyz", "/livez");

        let descriptors = ingress.route_descriptors();

        assert!(
            descriptors
                .iter()
                .any(|descriptor| descriptor.method() == Method::GET
                    && descriptor.path_pattern() == "/orders/:id")
        );
        assert!(
            descriptors
                .iter()
                .any(|descriptor| descriptor.method() == Method::GET
                    && descriptor.path_pattern() == "/healthz")
        );
        assert!(
            descriptors
                .iter()
                .any(|descriptor| descriptor.method() == Method::GET
                    && descriptor.path_pattern() == "/readyz")
        );
        assert!(
            descriptors
                .iter()
                .any(|descriptor| descriptor.method() == Method::GET
                    && descriptor.path_pattern() == "/livez")
        );
    }

    /// Both `on_start` and `on_shutdown` hooks have run by the time
    /// `run_with_shutdown_signal` returns.
    #[tokio::test]
    async fn lifecycle_hooks_fire_on_start_and_shutdown() {
        let started = Arc::new(AtomicBool::new(false));
        let shutdown = Arc::new(AtomicBool::new(false));

        let started_flag = started.clone();
        let shutdown_flag = shutdown.clone();

        let ingress = HttpIngress::<()>::new()
            .bind("127.0.0.1:0")
            .on_start(move || {
                started_flag.store(true, Ordering::SeqCst);
            })
            .on_shutdown(move || {
                shutdown_flag.store(true, Ordering::SeqCst);
            })
            .graceful_shutdown(Duration::from_millis(50));

        within_timeout(
            "server lifecycle",
            ingress.run_with_shutdown_signal((), async {
                tokio::time::sleep(Duration::from_millis(20)).await;
            }),
        )
        .await
        .expect("server should exit gracefully");

        assert!(started.load(Ordering::SeqCst));
        assert!(shutdown.load(Ordering::SeqCst));
    }

    /// A request that is still being handled when shutdown is signalled gets
    /// its full `200` response, because the 500 ms drain budget exceeds the
    /// handler's 120 ms runtime.
    #[tokio::test]
    async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
        #[derive(Clone)]
        struct SlowDrainRoute;

        #[async_trait]
        impl Transition<(), String> for SlowDrainRoute {
            type Error = String;
            type Resources = ();

            async fn run(
                &self,
                _state: (),
                _resources: &Self::Resources,
                _bus: &mut Bus,
            ) -> Outcome<String, Self::Error> {
                tokio::time::sleep(Duration::from_millis(120)).await;
                Outcome::next("drained-ok".to_string())
            }
        }

        let addr = free_local_addr();

        let ingress = HttpIngress::<()>::new()
            .bind(addr.to_string())
            .graceful_shutdown(Duration::from_millis(500))
            .get(
                "/drain",
                Axon::<(), (), String, ()>::new("SlowDrain").then(SlowDrainRoute),
            );

        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let server = tokio::spawn(async move {
            ingress
                .run_with_shutdown_signal((), async move {
                    let _ = shutdown_rx.await;
                })
                .await
        });

        let mut stream = connect_with_retry(addr).await;
        stream
            .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .expect("write request");

        // Give the server a moment to pick up the request so it is in flight
        // when shutdown is signalled.
        tokio::time::sleep(Duration::from_millis(20)).await;
        let _ = shutdown_tx.send(());

        let mut buf = Vec::new();
        within_timeout("drained response", stream.read_to_end(&mut buf))
            .await
            .expect("read response");
        let response = String::from_utf8_lossy(&buf);
        assert!(response.starts_with("HTTP/1.1 200"), "{response}");
        assert!(response.contains("drained-ok"), "{response}");

        within_timeout("server shutdown", server)
            .await
            .expect("server join")
            .expect("server shutdown should succeed");
    }

    /// `serve_dir` serves a file's contents with a `cache-control` header and
    /// at least one of `etag` / `last-modified`.
    #[tokio::test]
    async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
        let temp = tempdir().expect("tempdir");
        let root = temp.path().join("public");
        fs::create_dir_all(&root).expect("create dir");
        let file = root.join("hello.txt");
        fs::write(&file, "hello static").expect("write file");

        let ingress =
            Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
        let app = crate::test_harness::TestApp::new(ingress, ());
        let response = app
            .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().expect("utf8"), "hello static");
        assert!(response.header("cache-control").is_some());
        let has_metadata_header =
            response.header("etag").is_some() || response.header("last-modified").is_some();
        assert!(has_metadata_header);
    }

    /// A path that matches no route is answered with the SPA index file.
    #[tokio::test]
    async fn spa_fallback_returns_index_for_unmatched_path() {
        let temp = tempdir().expect("tempdir");
        let index = temp.path().join("index.html");
        fs::write(&index, "<html><body>spa</body></html>").expect("write index");

        let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
        let app = crate::test_harness::TestApp::new(ingress, ());
        let response = app
            .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.text().expect("utf8").contains("spa"));
    }

    /// With `compression_layer`, a client sending `accept-encoding: gzip`
    /// receives a static file with `content-encoding: gzip`.
    #[tokio::test]
    async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
        let temp = tempdir().expect("tempdir");
        let root = temp.path().join("public");
        fs::create_dir_all(&root).expect("create dir");
        let file = root.join("compressed.txt");
        fs::write(&file, "compress me ".repeat(400)).expect("write file");

        let ingress = Ranvier::http::<()>()
            .serve_dir("/static", root.to_string_lossy().to_string())
            .compression_layer();
        let app = crate::test_harness::TestApp::new(ingress, ());
        let response = app
            .send(
                crate::test_harness::TestRequest::get("/static/compressed.txt")
                    .header("accept-encoding", "gzip"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response
                .header("content-encoding")
                .and_then(|value| value.to_str().ok()),
            Some("gzip")
        );
    }

    /// `drain_connections` reports no timeout and leaves the set empty when
    /// every task finishes within the budget.
    #[tokio::test]
    async fn drain_connections_completes_before_timeout() {
        let mut connections = tokio::task::JoinSet::new();
        connections.spawn(async {
            tokio::time::sleep(Duration::from_millis(20)).await;
        });

        let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
        assert!(!timed_out);
        assert!(connections.is_empty());
    }

    /// `drain_connections` reports a timeout and still leaves the set empty
    /// when a task outlives the budget.
    #[tokio::test]
    async fn drain_connections_times_out_and_aborts() {
        let mut connections = tokio::task::JoinSet::new();
        connections.spawn(async {
            tokio::time::sleep(Duration::from_secs(10)).await;
        });

        let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
        assert!(timed_out);
        assert!(connections.is_empty());
    }

    /// With a 10 ms `timeout_layer`, a route that takes 80 ms is answered with
    /// `408 Request Timeout`.
    #[tokio::test]
    async fn timeout_layer_returns_408_for_slow_route() {
        #[derive(Clone)]
        struct SlowRoute;

        #[async_trait]
        impl Transition<(), String> for SlowRoute {
            type Error = String;
            type Resources = ();

            async fn run(
                &self,
                _state: (),
                _resources: &Self::Resources,
                _bus: &mut Bus,
            ) -> Outcome<String, Self::Error> {
                tokio::time::sleep(Duration::from_millis(80)).await;
                Outcome::next("slow-ok".to_string())
            }
        }

        let addr = free_local_addr();

        let ingress = HttpIngress::<()>::new()
            .bind(addr.to_string())
            .timeout_layer(Duration::from_millis(10))
            .get(
                "/slow",
                Axon::<(), (), String, ()>::new("Slow").then(SlowRoute),
            );

        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let server = tokio::spawn(async move {
            ingress
                .run_with_shutdown_signal((), async move {
                    let _ = shutdown_rx.await;
                })
                .await
        });

        let mut stream = connect_with_retry(addr).await;
        stream
            .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .expect("write request");

        let mut buf = Vec::new();
        within_timeout("timeout response", stream.read_to_end(&mut buf))
            .await
            .expect("read response");
        let response = String::from_utf8_lossy(&buf);
        assert!(response.starts_with("HTTP/1.1 408"), "{response}");

        let _ = shutdown_tx.send(());
        within_timeout("server shutdown", server)
            .await
            .expect("server join")
            .expect("server shutdown should succeed");
    }
}