1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode};
20use http_body_util::{BodyExt, Full};
21use hyper::body::Incoming;
22use hyper::server::conn::http1;
23use hyper::upgrade::Upgraded;
24use hyper_util::rt::TokioIo;
25use ranvier_core::event::{EventSink, EventSource};
26use ranvier_core::prelude::*;
27use ranvier_runtime::Axon;
28use serde::Serialize;
29use serde::de::DeserializeOwned;
30use sha1::{Digest, Sha1};
31use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::net::SocketAddr;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::net::TcpListener;
39use tokio::sync::Mutex;
40use tokio_tungstenite::WebSocketStream;
41use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
42use tracing::Instrument;
43
44use crate::response::{HttpResponse, IntoResponse, outcome_to_response_with_error};
45
46pub struct Ranvier;
51
52impl Ranvier {
53 pub fn http<R>() -> HttpIngress<R>
55 where
56 R: ranvier_core::transition::ResourceRequirement + Clone,
57 {
58 HttpIngress::new()
59 }
60}
61
62type RouteHandler<R> = Arc<
64 dyn Fn(http::request::Parts, &R) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>>
65 + Send
66 + Sync,
67>;
68
69#[derive(Clone)]
71struct BoxService(
72 Arc<
73 dyn Fn(Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>
74 + Send
75 + Sync,
76 >,
77);
78
79impl BoxService {
80 fn new<F, Fut>(f: F) -> Self
81 where
82 F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
83 Fut: Future<Output = Result<HttpResponse, Infallible>> + Send + 'static,
84 {
85 Self(Arc::new(move |req| Box::pin(f(req))))
86 }
87
88 fn call(&self, req: Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>> {
89 (self.0)(req)
90 }
91}
92
93impl hyper::service::Service<Request<Incoming>> for BoxService {
94 type Response = HttpResponse;
95 type Error = Infallible;
96 type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>;
97
98 fn call(&self, req: Request<Incoming>) -> Self::Future {
99 (self.0)(req)
100 }
101}
102
103type BoxHttpService = BoxService;
104type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
105type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
106type BusInjector = Arc<dyn Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static>;
107type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
108type WsSessionHandler<R> =
109 Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
110type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
111type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
112const REQUEST_ID_HEADER: &str = "x-request-id";
113const WS_UPGRADE_TOKEN: &str = "websocket";
114const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
115
116#[derive(Clone)]
117struct NamedHealthCheck<R> {
118 name: String,
119 check: HealthCheckFn<R>,
120}
121
122#[derive(Clone)]
123struct HealthConfig<R> {
124 health_path: Option<String>,
125 readiness_path: Option<String>,
126 liveness_path: Option<String>,
127 checks: Vec<NamedHealthCheck<R>>,
128}
129
130impl<R> Default for HealthConfig<R> {
131 fn default() -> Self {
132 Self {
133 health_path: None,
134 readiness_path: None,
135 liveness_path: None,
136 checks: Vec::new(),
137 }
138 }
139}
140
/// Static file serving configuration collected by the builder.
#[derive(Default, Clone)]
struct StaticAssetsConfig {
    /// Directory mounts in registration order.
    mounts: Vec<StaticMount>,
    /// Optional file served as a single-page-app fallback.
    spa_fallback: Option<String>,
    /// Optional `Cache-Control` value for static responses.
    cache_control: Option<String>,
    /// Set by `compression_layer`.
    enable_compression: bool,
}

/// Maps a URL prefix to a filesystem directory.
#[derive(Clone)]
struct StaticMount {
    route_prefix: String,
    directory: String,
}
154
155#[derive(Serialize)]
156struct HealthReport {
157 status: &'static str,
158 probe: &'static str,
159 checks: Vec<HealthCheckReport>,
160}
161
162#[derive(Serialize)]
163struct HealthCheckReport {
164 name: String,
165 status: &'static str,
166 #[serde(skip_serializing_if = "Option::is_none")]
167 error: Option<String>,
168}
169
170fn timeout_middleware(timeout: Duration) -> ServiceLayer {
171 Arc::new(move |inner: BoxHttpService| {
172 BoxService::new(move |req: Request<Incoming>| {
173 let inner = inner.clone();
174 async move {
175 match tokio::time::timeout(timeout, inner.call(req)).await {
176 Ok(response) => response,
177 Err(_) => Ok(Response::builder()
178 .status(StatusCode::REQUEST_TIMEOUT)
179 .body(
180 Full::new(Bytes::from("Request Timeout"))
181 .map_err(|never| match never {})
182 .boxed(),
183 )
184 .unwrap()),
185 }
186 }
187 })
188 })
189}
190
191fn request_id_middleware() -> ServiceLayer {
192 Arc::new(move |inner: BoxHttpService| {
193 BoxService::new(move |req: Request<Incoming>| {
194 let inner = inner.clone();
195 async move {
196 let mut req = req;
197 let request_id = req
198 .headers()
199 .get(REQUEST_ID_HEADER)
200 .cloned()
201 .unwrap_or_else(|| {
202 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
203 .unwrap_or_else(|_| {
204 http::HeaderValue::from_static("request-id-unavailable")
205 })
206 });
207 req.headers_mut()
208 .insert(REQUEST_ID_HEADER, request_id.clone());
209 let mut response = inner.call(req).await?;
210 response
211 .headers_mut()
212 .insert(REQUEST_ID_HEADER, request_id);
213 Ok(response)
214 }
215 })
216 })
217}
218
/// Named path parameters captured while matching a route pattern (`:name` / `*name`).
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct PathParams {
    values: HashMap<String, String>,
}
223
224#[derive(Clone, Debug, PartialEq, Eq)]
226pub struct HttpRouteDescriptor {
227 method: Method,
228 path_pattern: String,
229}
230
231impl HttpRouteDescriptor {
232 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
233 Self {
234 method,
235 path_pattern: path_pattern.into(),
236 }
237 }
238
239 pub fn method(&self) -> &Method {
240 &self.method
241 }
242
243 pub fn path_pattern(&self) -> &str {
244 &self.path_pattern
245 }
246}
247
248#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
250pub struct WebSocketSessionContext {
251 connection_id: uuid::Uuid,
252 path: String,
253 query: Option<String>,
254}
255
256impl WebSocketSessionContext {
257 pub fn connection_id(&self) -> uuid::Uuid {
258 self.connection_id
259 }
260
261 pub fn path(&self) -> &str {
262 &self.path
263 }
264
265 pub fn query(&self) -> Option<&str> {
266 self.query.as_deref()
267 }
268}
269
270#[derive(Clone, Debug, PartialEq, Eq)]
272pub enum WebSocketEvent {
273 Text(String),
274 Binary(Vec<u8>),
275 Ping(Vec<u8>),
276 Pong(Vec<u8>),
277 Close,
278}
279
280impl WebSocketEvent {
281 pub fn text(value: impl Into<String>) -> Self {
282 Self::Text(value.into())
283 }
284
285 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
286 Self::Binary(value.into())
287 }
288
289 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
290 where
291 T: Serialize,
292 {
293 let text = serde_json::to_string(value)?;
294 Ok(Self::Text(text))
295 }
296}
297
298#[derive(Debug, thiserror::Error)]
299pub enum WebSocketError {
300 #[error("websocket wire error: {0}")]
301 Wire(#[from] WsWireError),
302 #[error("json serialization failed: {0}")]
303 JsonSerialize(#[source] serde_json::Error),
304 #[error("json deserialization failed: {0}")]
305 JsonDeserialize(#[source] serde_json::Error),
306 #[error("expected text or binary frame for json payload")]
307 NonDataFrame,
308}
309
310type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
311type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
312type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
313
314pub struct WebSocketConnection {
316 sink: Mutex<WsServerSink>,
317 source: Mutex<WsServerSource>,
318 session: WebSocketSessionContext,
319}
320
321impl WebSocketConnection {
322 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
323 let (sink, source) = stream.split();
324 Self {
325 sink: Mutex::new(sink),
326 source: Mutex::new(source),
327 session,
328 }
329 }
330
331 pub fn session(&self) -> &WebSocketSessionContext {
332 &self.session
333 }
334
335 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
336 let mut sink = self.sink.lock().await;
337 sink.send(event.into_wire_message()).await?;
338 Ok(())
339 }
340
341 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
342 where
343 T: Serialize,
344 {
345 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
346 self.send(event).await
347 }
348
349 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
350 where
351 T: DeserializeOwned,
352 {
353 let Some(event) = self.recv_event().await? else {
354 return Ok(None);
355 };
356 match event {
357 WebSocketEvent::Text(text) => serde_json::from_str(&text)
358 .map(Some)
359 .map_err(WebSocketError::JsonDeserialize),
360 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
361 .map(Some)
362 .map_err(WebSocketError::JsonDeserialize),
363 _ => Err(WebSocketError::NonDataFrame),
364 }
365 }
366
367 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
368 let mut source = self.source.lock().await;
369 while let Some(item) = source.next().await {
370 let message = item?;
371 if let Some(event) = WebSocketEvent::from_wire_message(message) {
372 return Ok(Some(event));
373 }
374 }
375 Ok(None)
376 }
377}
378
379impl WebSocketEvent {
380 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
381 match message {
382 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
383 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
384 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
385 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
386 WsWireMessage::Close(_) => Some(Self::Close),
387 WsWireMessage::Frame(_) => None,
388 }
389 }
390
391 fn into_wire_message(self) -> WsWireMessage {
392 match self {
393 Self::Text(value) => WsWireMessage::Text(value),
394 Self::Binary(value) => WsWireMessage::Binary(value),
395 Self::Ping(value) => WsWireMessage::Ping(value),
396 Self::Pong(value) => WsWireMessage::Pong(value),
397 Self::Close => WsWireMessage::Close(None),
398 }
399 }
400}
401
402#[async_trait::async_trait]
403impl EventSource<WebSocketEvent> for WebSocketConnection {
404 async fn next_event(&mut self) -> Option<WebSocketEvent> {
405 match self.recv_event().await {
406 Ok(event) => event,
407 Err(error) => {
408 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
409 None
410 }
411 }
412 }
413}
414
415#[async_trait::async_trait]
416impl EventSink<WebSocketEvent> for WebSocketConnection {
417 type Error = WebSocketError;
418
419 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
420 self.send(event).await
421 }
422}
423
424#[async_trait::async_trait]
425impl EventSink<String> for WebSocketConnection {
426 type Error = WebSocketError;
427
428 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
429 self.send(WebSocketEvent::Text(event)).await
430 }
431}
432
433#[async_trait::async_trait]
434impl EventSink<Vec<u8>> for WebSocketConnection {
435 type Error = WebSocketError;
436
437 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
438 self.send(WebSocketEvent::Binary(event)).await
439 }
440}
441
442impl PathParams {
443 pub fn new(values: HashMap<String, String>) -> Self {
444 Self { values }
445 }
446
447 pub fn get(&self, key: &str) -> Option<&str> {
448 self.values.get(key).map(String::as_str)
449 }
450
451 pub fn as_map(&self) -> &HashMap<String, String> {
452 &self.values
453 }
454
455 pub fn into_inner(self) -> HashMap<String, String> {
456 self.values
457 }
458}
459
/// One parsed segment of a route pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RouteSegment {
    /// Must equal the path segment exactly.
    Static(String),
    /// `:name` — captures exactly one path segment.
    Param(String),
    /// `*name` — captures the remaining segments joined by `/`.
    Wildcard(String),
}

/// A route pattern: its raw registration string plus the parsed segments.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RoutePattern {
    raw: String,
    segments: Vec<RouteSegment>,
}
472
473impl RoutePattern {
474 fn parse(path: &str) -> Self {
475 let segments = path_segments(path)
476 .into_iter()
477 .map(|segment| {
478 if let Some(name) = segment.strip_prefix(':') {
479 if !name.is_empty() {
480 return RouteSegment::Param(name.to_string());
481 }
482 }
483 if let Some(name) = segment.strip_prefix('*') {
484 if !name.is_empty() {
485 return RouteSegment::Wildcard(name.to_string());
486 }
487 }
488 RouteSegment::Static(segment.to_string())
489 })
490 .collect();
491
492 Self {
493 raw: path.to_string(),
494 segments,
495 }
496 }
497
498 fn match_path(&self, path: &str) -> Option<PathParams> {
499 let mut params = HashMap::new();
500 let path_segments = path_segments(path);
501 let mut pattern_index = 0usize;
502 let mut path_index = 0usize;
503
504 while pattern_index < self.segments.len() {
505 match &self.segments[pattern_index] {
506 RouteSegment::Static(expected) => {
507 let actual = path_segments.get(path_index)?;
508 if actual != expected {
509 return None;
510 }
511 pattern_index += 1;
512 path_index += 1;
513 }
514 RouteSegment::Param(name) => {
515 let actual = path_segments.get(path_index)?;
516 params.insert(name.clone(), (*actual).to_string());
517 pattern_index += 1;
518 path_index += 1;
519 }
520 RouteSegment::Wildcard(name) => {
521 let remaining = path_segments[path_index..].join("/");
522 params.insert(name.clone(), remaining);
523 pattern_index += 1;
524 path_index = path_segments.len();
525 break;
526 }
527 }
528 }
529
530 if pattern_index == self.segments.len() && path_index == path_segments.len() {
531 Some(PathParams::new(params))
532 } else {
533 None
534 }
535 }
536}
537
538#[derive(Clone)]
539struct RouteEntry<R> {
540 method: Method,
541 pattern: RoutePattern,
542 handler: RouteHandler<R>,
543 layers: Arc<Vec<ServiceLayer>>,
544 apply_global_layers: bool,
545}
546
/// Splits a URL path into its non-empty `/`-separated segments.
///
/// Leading, trailing and repeated slashes produce no segments, so `"/"` and `""` both yield
/// an empty list.
fn path_segments(path: &str) -> Vec<&str> {
    path.split('/')
        .filter(|segment| !segment.is_empty())
        .collect()
}
557
/// Ensures a route path starts with `/`; an empty path becomes `/`.
fn normalize_route_path(path: String) -> String {
    if path.starts_with('/') {
        path
    } else if path.is_empty() {
        "/".to_string()
    } else {
        format!("/{path}")
    }
}
568
569fn find_matching_route<'a, R>(
570 routes: &'a [RouteEntry<R>],
571 method: &Method,
572 path: &str,
573) -> Option<(&'a RouteEntry<R>, PathParams)> {
574 for entry in routes {
575 if entry.method != *method {
576 continue;
577 }
578 if let Some(params) = entry.pattern.match_path(path) {
579 return Some((entry, params));
580 }
581 }
582 None
583}
584
585fn header_contains_token(
586 headers: &http::HeaderMap,
587 name: http::header::HeaderName,
588 token: &str,
589) -> bool {
590 headers
591 .get(name)
592 .and_then(|value| value.to_str().ok())
593 .map(|value| {
594 value
595 .split(',')
596 .any(|part| part.trim().eq_ignore_ascii_case(token))
597 })
598 .unwrap_or(false)
599}
600
601fn websocket_session_from_request<B>(req: &Request<B>) -> WebSocketSessionContext {
602 WebSocketSessionContext {
603 connection_id: uuid::Uuid::new_v4(),
604 path: req.uri().path().to_string(),
605 query: req.uri().query().map(str::to_string),
606 }
607}
608
609fn websocket_accept_key(client_key: &str) -> String {
610 let mut hasher = Sha1::new();
611 hasher.update(client_key.as_bytes());
612 hasher.update(WS_GUID.as_bytes());
613 let digest = hasher.finalize();
614 base64::engine::general_purpose::STANDARD.encode(digest)
615}
616
617fn websocket_bad_request(message: &'static str) -> HttpResponse {
618 Response::builder()
619 .status(StatusCode::BAD_REQUEST)
620 .body(
621 Full::new(Bytes::from(message))
622 .map_err(|never| match never {})
623 .boxed(),
624 )
625 .unwrap_or_else(|_| {
626 Response::new(
627 Full::new(Bytes::new())
628 .map_err(|never| match never {})
629 .boxed(),
630 )
631 })
632}
633
634fn websocket_upgrade_response<B>(
635 req: &mut Request<B>,
636) -> Result<(HttpResponse, hyper::upgrade::OnUpgrade), HttpResponse> {
637 if req.method() != Method::GET {
638 return Err(websocket_bad_request(
639 "WebSocket upgrade requires GET method",
640 ));
641 }
642
643 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
644 return Err(websocket_bad_request(
645 "Missing Connection: upgrade header for WebSocket",
646 ));
647 }
648
649 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
650 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
651 }
652
653 if let Some(version) = req.headers().get("sec-websocket-version") {
654 if version != "13" {
655 return Err(websocket_bad_request(
656 "Unsupported Sec-WebSocket-Version (expected 13)",
657 ));
658 }
659 }
660
661 let Some(client_key) = req
662 .headers()
663 .get("sec-websocket-key")
664 .and_then(|value| value.to_str().ok())
665 else {
666 return Err(websocket_bad_request(
667 "Missing Sec-WebSocket-Key header for WebSocket",
668 ));
669 };
670
671 let accept_key = websocket_accept_key(client_key);
672 let on_upgrade = hyper::upgrade::on(req);
673 let response = Response::builder()
674 .status(StatusCode::SWITCHING_PROTOCOLS)
675 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
676 .header(http::header::CONNECTION, "Upgrade")
677 .header("sec-websocket-accept", accept_key)
678 .body(
679 Full::new(Bytes::new())
680 .map_err(|never| match never {})
681 .boxed(),
682 )
683 .unwrap_or_else(|_| {
684 Response::new(
685 Full::new(Bytes::new())
686 .map_err(|never| match never {})
687 .boxed(),
688 )
689 });
690
691 Ok((response, on_upgrade))
692}
693
694pub struct HttpIngress<R = ()> {
700 addr: Option<String>,
702 routes: Vec<RouteEntry<R>>,
704 fallback: Option<RouteHandler<R>>,
706 layers: Vec<ServiceLayer>,
708 on_start: Option<LifecycleHook>,
710 on_shutdown: Option<LifecycleHook>,
712 graceful_shutdown_timeout: Duration,
714 bus_injectors: Vec<BusInjector>,
716 static_assets: StaticAssetsConfig,
718 health: HealthConfig<R>,
720 #[cfg(feature = "http3")]
721 http3_config: Option<crate::http3::Http3Config>,
722 #[cfg(feature = "http3")]
723 alt_svc_h3_port: Option<u16>,
724 active_intervention: bool,
726 policy_registry: Option<ranvier_core::policy::PolicyRegistry>,
728 _phantom: std::marker::PhantomData<R>,
729}
730
731impl<R> HttpIngress<R>
732where
733 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
734{
735 pub fn new() -> Self {
737 Self {
738 addr: None,
739 routes: Vec::new(),
740 fallback: None,
741 layers: Vec::new(),
742 on_start: None,
743 on_shutdown: None,
744 graceful_shutdown_timeout: Duration::from_secs(30),
745 bus_injectors: Vec::new(),
746 static_assets: StaticAssetsConfig::default(),
747 health: HealthConfig::default(),
748 #[cfg(feature = "http3")]
749 http3_config: None,
750 #[cfg(feature = "http3")]
751 alt_svc_h3_port: None,
752 active_intervention: false,
753 policy_registry: None,
754 _phantom: std::marker::PhantomData,
755 }
756 }
757
758 pub fn bind(mut self, addr: impl Into<String>) -> Self {
760 self.addr = Some(addr.into());
761 self
762 }
763
764 pub fn active_intervention(mut self) -> Self {
768 self.active_intervention = true;
769 self
770 }
771
772 pub fn policy_registry(mut self, registry: ranvier_core::policy::PolicyRegistry) -> Self {
774 self.policy_registry = Some(registry);
775 self
776 }
777
778 pub fn on_start<F>(mut self, callback: F) -> Self
780 where
781 F: Fn() + Send + Sync + 'static,
782 {
783 self.on_start = Some(Arc::new(callback));
784 self
785 }
786
787 pub fn on_shutdown<F>(mut self, callback: F) -> Self
789 where
790 F: Fn() + Send + Sync + 'static,
791 {
792 self.on_shutdown = Some(Arc::new(callback));
793 self
794 }
795
796 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
798 self.graceful_shutdown_timeout = timeout;
799 self
800 }
801
802 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
805 self.layers.push(timeout_middleware(timeout));
806 self
807 }
808
809 pub fn request_id_layer(mut self) -> Self {
813 self.layers.push(request_id_middleware());
814 self
815 }
816
817 pub fn bus_injector<F>(mut self, injector: F) -> Self
822 where
823 F: Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static,
824 {
825 self.bus_injectors.push(Arc::new(injector));
826 self
827 }
828
829 #[cfg(feature = "http3")]
831 pub fn enable_http3(mut self, config: crate::http3::Http3Config) -> Self {
832 self.http3_config = Some(config);
833 self
834 }
835
836 #[cfg(feature = "http3")]
838 pub fn alt_svc_h3(mut self, port: u16) -> Self {
839 self.alt_svc_h3_port = Some(port);
840 self
841 }
842
843 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
845 let mut descriptors = self
846 .routes
847 .iter()
848 .map(|entry| HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone()))
849 .collect::<Vec<_>>();
850
851 if let Some(path) = &self.health.health_path {
852 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
853 }
854 if let Some(path) = &self.health.readiness_path {
855 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
856 }
857 if let Some(path) = &self.health.liveness_path {
858 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
859 }
860
861 descriptors
862 }
863
864 pub fn serve_dir(
868 mut self,
869 route_prefix: impl Into<String>,
870 directory: impl Into<String>,
871 ) -> Self {
872 self.static_assets.mounts.push(StaticMount {
873 route_prefix: normalize_route_path(route_prefix.into()),
874 directory: directory.into(),
875 });
876 if self.static_assets.cache_control.is_none() {
877 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
878 }
879 self
880 }
881
882 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
886 self.static_assets.spa_fallback = Some(file_path.into());
887 self
888 }
889
890 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
892 self.static_assets.cache_control = Some(cache_control.into());
893 self
894 }
895
896 pub fn compression_layer(mut self) -> Self {
898 self.static_assets.enable_compression = true;
899 self
900 }
901
902 pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
909 where
910 H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
911 Fut: Future<Output = ()> + Send + 'static,
912 {
913 let path_str: String = path.into();
914 let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
915 Box::pin(handler(connection, resources, bus))
916 });
917 let bus_injectors = Arc::new(self.bus_injectors.clone());
918 let path_for_pattern = path_str.clone();
919 let path_for_handler = path_str;
920
921 let route_handler: RouteHandler<R> =
922 Arc::new(move |parts: http::request::Parts, res: &R| {
923 let ws_handler = ws_handler.clone();
924 let bus_injectors = bus_injectors.clone();
925 let resources = Arc::new(res.clone());
926 let path = path_for_handler.clone();
927
928 Box::pin(async move {
929 let request_id = uuid::Uuid::new_v4().to_string();
930 let span = tracing::info_span!(
931 "WebSocketUpgrade",
932 ranvier.ws.path = %path,
933 ranvier.ws.request_id = %request_id
934 );
935
936 async move {
937 let mut bus = Bus::new();
938 for injector in bus_injectors.iter() {
939 injector(&parts, &mut bus);
940 }
941
942 let mut req = Request::from_parts(parts, ());
944 let session = websocket_session_from_request(&req);
945 bus.insert(session.clone());
946
947 let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
948 Ok(result) => result,
949 Err(error_response) => return error_response,
950 };
951
952 tokio::spawn(async move {
953 match on_upgrade.await {
954 Ok(upgraded) => {
955 let stream = WebSocketStream::from_raw_socket(
956 TokioIo::new(upgraded),
957 tokio_tungstenite::tungstenite::protocol::Role::Server,
958 None,
959 )
960 .await;
961 let connection = WebSocketConnection::new(stream, session);
962 ws_handler(connection, resources, bus).await;
963 }
964 Err(error) => {
965 tracing::warn!(
966 ranvier.ws.path = %path,
967 ranvier.ws.error = %error,
968 "websocket upgrade failed"
969 );
970 }
971 }
972 });
973
974 response
975 }
976 .instrument(span)
977 .await
978 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
979 });
980
981 self.routes.push(RouteEntry {
982 method: Method::GET,
983 pattern: RoutePattern::parse(&path_for_pattern),
984 handler: route_handler,
985 layers: Arc::new(Vec::new()),
986 apply_global_layers: true,
987 });
988
989 self
990 }
991
992 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
997 self.health.health_path = Some(normalize_route_path(path.into()));
998 self
999 }
1000
1001 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
1005 where
1006 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
1007 Fut: Future<Output = Result<(), Err>> + Send + 'static,
1008 Err: ToString + Send + 'static,
1009 {
1010 if self.health.health_path.is_none() {
1011 self.health.health_path = Some("/health".to_string());
1012 }
1013
1014 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
1015 let fut = check(resources);
1016 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
1017 });
1018
1019 self.health.checks.push(NamedHealthCheck {
1020 name: name.into(),
1021 check: check_fn,
1022 });
1023 self
1024 }
1025
1026 pub fn readiness_liveness(
1028 mut self,
1029 readiness_path: impl Into<String>,
1030 liveness_path: impl Into<String>,
1031 ) -> Self {
1032 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1033 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1034 self
1035 }
1036
1037 pub fn readiness_liveness_default(self) -> Self {
1039 self.readiness_liveness("/ready", "/live")
1040 }
1041
1042 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1044 where
1045 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1046 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1047 {
1048 self.route_method(Method::GET, path, circuit)
1049 }
1050 pub fn route_method<Out, E>(
1059 self,
1060 method: Method,
1061 path: impl Into<String>,
1062 circuit: Axon<(), Out, E, R>,
1063 ) -> Self
1064 where
1065 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1066 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1067 {
1068 self.route_method_with_error(method, path, circuit, |error| {
1069 (
1070 StatusCode::INTERNAL_SERVER_ERROR,
1071 format!("Error: {:?}", error),
1072 )
1073 .into_response()
1074 })
1075 }
1076
1077 pub fn route_method_with_error<Out, E, H>(
1078 self,
1079 method: Method,
1080 path: impl Into<String>,
1081 circuit: Axon<(), Out, E, R>,
1082 error_handler: H,
1083 ) -> Self
1084 where
1085 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1086 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1087 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1088 {
1089 self.route_method_with_error_and_layers(
1090 method,
1091 path,
1092 circuit,
1093 error_handler,
1094 Arc::new(Vec::new()),
1095 true,
1096 )
1097 }
1098
1099
1100
1101 fn route_method_with_error_and_layers<Out, E, H>(
1102 mut self,
1103 method: Method,
1104 path: impl Into<String>,
1105 circuit: Axon<(), Out, E, R>,
1106 error_handler: H,
1107 route_layers: Arc<Vec<ServiceLayer>>,
1108 apply_global_layers: bool,
1109 ) -> Self
1110 where
1111 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1112 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1113 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1114 {
1115 let path_str: String = path.into();
1116 let circuit = Arc::new(circuit);
1117 let error_handler = Arc::new(error_handler);
1118 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1119 let path_for_pattern = path_str.clone();
1120 let path_for_handler = path_str;
1121 let method_for_pattern = method.clone();
1122 let method_for_handler = method;
1123
1124 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1125 let circuit = circuit.clone();
1126 let error_handler = error_handler.clone();
1127 let route_bus_injectors = route_bus_injectors.clone();
1128 let res = res.clone();
1129 let path = path_for_handler.clone();
1130 let method = method_for_handler.clone();
1131
1132 Box::pin(async move {
1133 let request_id = uuid::Uuid::new_v4().to_string();
1134 let span = tracing::info_span!(
1135 "HTTPRequest",
1136 ranvier.http.method = %method,
1137 ranvier.http.path = %path,
1138 ranvier.http.request_id = %request_id
1139 );
1140
1141 async move {
1142 let mut bus = Bus::new();
1143 for injector in route_bus_injectors.iter() {
1144 injector(&parts, &mut bus);
1145 }
1146 let result = circuit.execute((), &res, &mut bus).await;
1147 outcome_to_response_with_error(result, |error| error_handler(error))
1148 }
1149 .instrument(span)
1150 .await
1151 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1152 });
1153
1154 self.routes.push(RouteEntry {
1155 method: method_for_pattern,
1156 pattern: RoutePattern::parse(&path_for_pattern),
1157 handler,
1158 layers: route_layers,
1159 apply_global_layers,
1160 });
1161 self
1162 }
1163
1164 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1165 where
1166 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1167 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1168 {
1169 self.route_method(Method::GET, path, circuit)
1170 }
1171
1172 pub fn get_with_error<Out, E, H>(
1173 self,
1174 path: impl Into<String>,
1175 circuit: Axon<(), Out, E, R>,
1176 error_handler: H,
1177 ) -> Self
1178 where
1179 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1180 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1181 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1182 {
1183 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1184 }
1185
1186 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1187 where
1188 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1189 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1190 {
1191 self.route_method(Method::POST, path, circuit)
1192 }
1193
1194 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1195 where
1196 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1197 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1198 {
1199 self.route_method(Method::PUT, path, circuit)
1200 }
1201
1202 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1203 where
1204 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1205 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1206 {
1207 self.route_method(Method::DELETE, path, circuit)
1208 }
1209
1210 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1211 where
1212 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1213 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1214 {
1215 self.route_method(Method::PATCH, path, circuit)
1216 }
1217
1218 pub fn post_with_error<Out, E, H>(
1219 self,
1220 path: impl Into<String>,
1221 circuit: Axon<(), Out, E, R>,
1222 error_handler: H,
1223 ) -> Self
1224 where
1225 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1226 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1227 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1228 {
1229 self.route_method_with_error(Method::POST, path, circuit, error_handler)
1230 }
1231
1232 pub fn put_with_error<Out, E, H>(
1233 self,
1234 path: impl Into<String>,
1235 circuit: Axon<(), Out, E, R>,
1236 error_handler: H,
1237 ) -> Self
1238 where
1239 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1240 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1241 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1242 {
1243 self.route_method_with_error(Method::PUT, path, circuit, error_handler)
1244 }
1245
1246 pub fn delete_with_error<Out, E, H>(
1247 self,
1248 path: impl Into<String>,
1249 circuit: Axon<(), Out, E, R>,
1250 error_handler: H,
1251 ) -> Self
1252 where
1253 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1254 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1255 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1256 {
1257 self.route_method_with_error(Method::DELETE, path, circuit, error_handler)
1258 }
1259
1260 pub fn patch_with_error<Out, E, H>(
1261 self,
1262 path: impl Into<String>,
1263 circuit: Axon<(), Out, E, R>,
1264 error_handler: H,
1265 ) -> Self
1266 where
1267 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1268 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1269 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1270 {
1271 self.route_method_with_error(Method::PATCH, path, circuit, error_handler)
1272 }
1273
1274 pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
1285 where
1286 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1287 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1288 {
1289 let circuit = Arc::new(circuit);
1290 let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
1291
1292 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1293 let circuit = circuit.clone();
1294 let fallback_bus_injectors = fallback_bus_injectors.clone();
1295 let res = res.clone();
1296 Box::pin(async move {
1297 let request_id = uuid::Uuid::new_v4().to_string();
1298 let span = tracing::info_span!(
1299 "HTTPRequest",
1300 ranvier.http.method = "FALLBACK",
1301 ranvier.http.request_id = %request_id
1302 );
1303
1304 async move {
1305 let mut bus = Bus::new();
1306 for injector in fallback_bus_injectors.iter() {
1307 injector(&parts, &mut bus);
1308 }
1309 let result: ranvier_core::Outcome<Out, E> =
1310 circuit.execute((), &res, &mut bus).await;
1311
1312 match result {
1313 Outcome::Next(output) => {
1314 let mut response = output.into_response();
1315 *response.status_mut() = StatusCode::NOT_FOUND;
1316 response
1317 }
1318 _ => Response::builder()
1319 .status(StatusCode::NOT_FOUND)
1320 .body(
1321 Full::new(Bytes::from("Not Found"))
1322 .map_err(|never| match never {})
1323 .boxed(),
1324 )
1325 .unwrap(),
1326 }
1327 }
1328 .instrument(span)
1329 .await
1330 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1331 });
1332
1333 self.fallback = Some(handler);
1334 self
1335 }
1336
1337 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1339 self.run_with_shutdown_signal(resources, shutdown_signal())
1340 .await
1341 }
1342
1343 async fn run_with_shutdown_signal<S>(
1344 self,
1345 resources: R,
1346 shutdown_signal: S,
1347 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
1348 where
1349 S: Future<Output = ()> + Send,
1350 {
1351 let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
1352 let addr: SocketAddr = addr_str.parse()?;
1353
1354 let mut raw_routes = self.routes;
1355 if self.active_intervention {
1356 let handler: RouteHandler<R> = Arc::new(|_parts, _res| {
1357 Box::pin(async move {
1358 Response::builder()
1359 .status(StatusCode::OK)
1360 .body(
1361 Full::new(Bytes::from("Intervention accepted"))
1362 .map_err(|never| match never {} as Infallible)
1363 .boxed(),
1364 )
1365 .unwrap()
1366 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1367 });
1368
1369 raw_routes.push(RouteEntry {
1370 method: Method::POST,
1371 pattern: RoutePattern::parse("/_system/intervene/force_resume"),
1372 handler,
1373 layers: Arc::new(Vec::new()),
1374 apply_global_layers: true,
1375 });
1376 }
1377
1378 if let Some(registry) = self.policy_registry.clone() {
1379 let handler: RouteHandler<R> = Arc::new(move |_parts, _res| {
1380 let _registry = registry.clone();
1381 Box::pin(async move {
1382 Response::builder()
1386 .status(StatusCode::OK)
1387 .body(
1388 Full::new(Bytes::from("Policy registry active"))
1389 .map_err(|never| match never {} as Infallible)
1390 .boxed(),
1391 )
1392 .unwrap()
1393 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1394 });
1395
1396 raw_routes.push(RouteEntry {
1397 method: Method::POST,
1398 pattern: RoutePattern::parse("/_system/policy/reload"),
1399 handler,
1400 layers: Arc::new(Vec::new()),
1401 apply_global_layers: true,
1402 });
1403 }
1404 let routes = Arc::new(raw_routes);
1405 let fallback = self.fallback;
1406 let layers = Arc::new(self.layers);
1407 let health = Arc::new(self.health);
1408 let static_assets = Arc::new(self.static_assets);
1409 let on_start = self.on_start;
1410 let on_shutdown = self.on_shutdown;
1411 let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
1412 let resources = Arc::new(resources);
1413
1414 let listener = TcpListener::bind(addr).await?;
1415 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
1416 if let Some(callback) = on_start.as_ref() {
1417 callback();
1418 }
1419
1420 tokio::pin!(shutdown_signal);
1421 let mut connections = tokio::task::JoinSet::new();
1422
1423 loop {
1424 tokio::select! {
1425 _ = &mut shutdown_signal => {
1426 tracing::info!("Shutdown signal received. Draining in-flight connections.");
1427 break;
1428 }
1429 accept_result = listener.accept() => {
1430 let (stream, _) = accept_result?;
1431 let io = TokioIo::new(stream);
1432
1433 let routes = routes.clone();
1434 let fallback = fallback.clone();
1435 let resources = resources.clone();
1436 let layers = layers.clone();
1437 let health = health.clone();
1438 let static_assets = static_assets.clone();
1439 #[cfg(feature = "http3")]
1440 let alt_svc_h3_port = self.alt_svc_h3_port;
1441
1442 connections.spawn(async move {
1443 let service = build_http_service(
1444 routes,
1445 fallback,
1446 resources,
1447 layers,
1448 health,
1449 static_assets,
1450 #[cfg(feature = "http3")] alt_svc_h3_port,
1451 );
1452 if let Err(err) = http1::Builder::new()
1453 .serve_connection(io, service)
1454 .with_upgrades()
1455 .await
1456 {
1457 tracing::error!("Error serving connection: {:?}", err);
1458 }
1459 });
1460 }
1461 Some(join_result) = connections.join_next(), if !connections.is_empty() => {
1462 if let Err(err) = join_result {
1463 tracing::warn!("Connection task join error: {:?}", err);
1464 }
1465 }
1466 }
1467 }
1468
1469 let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;
1470
1471 drop(resources);
1472 if let Some(callback) = on_shutdown.as_ref() {
1473 callback();
1474 }
1475
1476 Ok(())
1477 }
1478
1479 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
1495 let routes = Arc::new(self.routes);
1496 let fallback = self.fallback;
1497 let layers = Arc::new(self.layers);
1498 let health = Arc::new(self.health);
1499 let static_assets = Arc::new(self.static_assets);
1500 let resources = Arc::new(resources);
1501
1502 RawIngressService {
1503 routes,
1504 fallback,
1505 layers,
1506 health,
1507 static_assets,
1508 resources,
1509 }
1510 }
1511}
1512
1513fn build_http_service<R>(
1514 routes: Arc<Vec<RouteEntry<R>>>,
1515 fallback: Option<RouteHandler<R>>,
1516 resources: Arc<R>,
1517 layers: Arc<Vec<ServiceLayer>>,
1518 health: Arc<HealthConfig<R>>,
1519 static_assets: Arc<StaticAssetsConfig>,
1520 #[cfg(feature = "http3")] alt_svc_port: Option<u16>,
1521) -> BoxHttpService
1522where
1523 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1524{
1525 BoxService::new(move |req: Request<Incoming>| {
1526 let routes = routes.clone();
1527 let fallback = fallback.clone();
1528 let resources = resources.clone();
1529 let layers = layers.clone();
1530 let health = health.clone();
1531 let static_assets = static_assets.clone();
1532
1533 async move {
1534 let mut req = req;
1535 let method = req.method().clone();
1536 let path = req.uri().path().to_string();
1537
1538 if let Some(response) =
1539 maybe_handle_health_request(&method, &path, &health, resources.clone()).await
1540 {
1541 return Ok::<_, Infallible>(response.into_response());
1542 }
1543
1544 if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
1545 req.extensions_mut().insert(params);
1546 let effective_layers = if entry.apply_global_layers {
1547 merge_layers(&layers, &entry.layers)
1548 } else {
1549 entry.layers.clone()
1550 };
1551
1552 if effective_layers.is_empty() {
1553 let (parts, _) = req.into_parts();
1554 #[allow(unused_mut)]
1555 let mut res = (entry.handler)(parts, &resources).await;
1556 #[cfg(feature = "http3")]
1557 if let Some(port) = alt_svc_port {
1558 if let Ok(val) =
1559 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
1560 {
1561 res.headers_mut().insert(http::header::ALT_SVC, val);
1562 }
1563 }
1564 Ok::<_, Infallible>(res)
1565 } else {
1566 let route_service = build_route_service(
1567 entry.handler.clone(),
1568 resources.clone(),
1569 effective_layers,
1570 );
1571 #[allow(unused_mut)]
1572 let mut res = route_service.call(req).await;
1573 #[cfg(feature = "http3")]
1574 #[allow(irrefutable_let_patterns)]
1575 if let Ok(ref mut r) = res {
1576 if let Some(port) = alt_svc_port {
1577 if let Ok(val) =
1578 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
1579 {
1580 r.headers_mut().insert(http::header::ALT_SVC, val);
1581 }
1582 }
1583 }
1584 res
1585 }
1586 } else {
1587 let req =
1588 match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
1589 .await
1590 {
1591 Ok(req) => req,
1592 Err(response) => return Ok(response),
1593 };
1594
1595 #[allow(unused_mut)]
1596 let mut fallback_res = if let Some(ref fb) = fallback {
1597 if layers.is_empty() {
1598 let (parts, _) = req.into_parts();
1599 Ok(fb(parts, &resources).await)
1600 } else {
1601 let fallback_service =
1602 build_route_service(fb.clone(), resources.clone(), layers.clone());
1603 fallback_service.call(req).await
1604 }
1605 } else {
1606 Ok(Response::builder()
1607 .status(StatusCode::NOT_FOUND)
1608 .body(
1609 Full::new(Bytes::from("Not Found"))
1610 .map_err(|never| match never {})
1611 .boxed(),
1612 )
1613 .unwrap())
1614 };
1615
1616 #[cfg(feature = "http3")]
1617 if let Ok(r) = fallback_res.as_mut() {
1618 if let Some(port) = alt_svc_port {
1619 if let Ok(val) =
1620 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
1621 {
1622 r.headers_mut().insert(http::header::ALT_SVC, val);
1623 }
1624 }
1625 }
1626
1627 fallback_res
1628 }
1629 }
1630 })
1631}
1632
1633fn build_route_service<R>(
1634 handler: RouteHandler<R>,
1635 resources: Arc<R>,
1636 layers: Arc<Vec<ServiceLayer>>,
1637) -> BoxHttpService
1638where
1639 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1640{
1641 let mut service = BoxService::new(move |req: Request<Incoming>| {
1642 let handler = handler.clone();
1643 let resources = resources.clone();
1644 async move {
1645 let (parts, _) = req.into_parts();
1646 Ok::<_, Infallible>(handler(parts, &resources).await)
1647 }
1648 });
1649
1650 for layer in layers.iter() {
1651 service = layer(service);
1652 }
1653 service
1654}
1655
1656fn merge_layers(
1657 global_layers: &Arc<Vec<ServiceLayer>>,
1658 route_layers: &Arc<Vec<ServiceLayer>>,
1659) -> Arc<Vec<ServiceLayer>> {
1660 if global_layers.is_empty() {
1661 return route_layers.clone();
1662 }
1663 if route_layers.is_empty() {
1664 return global_layers.clone();
1665 }
1666
1667 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
1668 combined.extend(global_layers.iter().cloned());
1669 combined.extend(route_layers.iter().cloned());
1670 Arc::new(combined)
1671}
1672
1673async fn maybe_handle_health_request<R>(
1674 method: &Method,
1675 path: &str,
1676 health: &HealthConfig<R>,
1677 resources: Arc<R>,
1678) -> Option<HttpResponse>
1679where
1680 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1681{
1682 if method != Method::GET {
1683 return None;
1684 }
1685
1686 if let Some(liveness_path) = health.liveness_path.as_ref() {
1687 if path == liveness_path {
1688 return Some(health_json_response("liveness", true, Vec::new()));
1689 }
1690 }
1691
1692 if let Some(readiness_path) = health.readiness_path.as_ref() {
1693 if path == readiness_path {
1694 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1695 return Some(health_json_response("readiness", healthy, checks));
1696 }
1697 }
1698
1699 if let Some(health_path) = health.health_path.as_ref() {
1700 if path == health_path {
1701 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1702 return Some(health_json_response("health", healthy, checks));
1703 }
1704 }
1705
1706 None
1707}
1708
1709async fn serve_single_file(file_path: &str) -> Result<Response<Full<Bytes>>, std::io::Error> {
1711 let path = std::path::Path::new(file_path);
1712 let content = tokio::fs::read(path).await?;
1713 let mime = guess_mime(file_path);
1714 let mut response = Response::new(Full::new(Bytes::from(content)));
1715 if let Ok(value) = http::HeaderValue::from_str(mime) {
1716 response
1717 .headers_mut()
1718 .insert(http::header::CONTENT_TYPE, value);
1719 }
1720 if let Ok(metadata) = tokio::fs::metadata(path).await {
1721 if let Ok(modified) = metadata.modified() {
1722 if let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH) {
1723 let etag = format!("\"{}\"", duration.as_secs());
1724 if let Ok(value) = http::HeaderValue::from_str(&etag) {
1725 response.headers_mut().insert(http::header::ETAG, value);
1726 }
1727 }
1728 }
1729 }
1730 Ok(response)
1731}
1732
1733async fn serve_static_file(
1735 directory: &str,
1736 file_subpath: &str,
1737) -> Result<Response<Full<Bytes>>, std::io::Error> {
1738 let subpath = file_subpath.trim_start_matches('/');
1739 if subpath.is_empty() || subpath == "/" {
1740 return Err(std::io::Error::new(
1741 std::io::ErrorKind::NotFound,
1742 "empty path",
1743 ));
1744 }
1745 let full_path = std::path::Path::new(directory).join(subpath);
1746 let canonical = tokio::fs::canonicalize(&full_path).await?;
1748 let dir_canonical = tokio::fs::canonicalize(directory).await?;
1749 if !canonical.starts_with(&dir_canonical) {
1750 return Err(std::io::Error::new(
1751 std::io::ErrorKind::PermissionDenied,
1752 "path traversal detected",
1753 ));
1754 }
1755 let content = tokio::fs::read(&canonical).await?;
1756 let mime = guess_mime(canonical.to_str().unwrap_or(""));
1757 let mut response = Response::new(Full::new(Bytes::from(content)));
1758 if let Ok(value) = http::HeaderValue::from_str(mime) {
1759 response
1760 .headers_mut()
1761 .insert(http::header::CONTENT_TYPE, value);
1762 }
1763 if let Ok(metadata) = tokio::fs::metadata(&canonical).await {
1764 if let Ok(modified) = metadata.modified() {
1765 if let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH) {
1766 let etag = format!("\"{}\"", duration.as_secs());
1767 if let Ok(value) = http::HeaderValue::from_str(&etag) {
1768 response.headers_mut().insert(http::header::ETAG, value);
1769 }
1770 }
1771 }
1772 }
1773 Ok(response)
1774}
1775
/// Maps a file path to a `Content-Type` value based on its extension.
///
/// The extension is taken from the final path component and matched case-insensitively, so
/// `LOGO.PNG` and `logo.png` get the same type. Dots in parent directory names are ignored.
/// Unknown or missing extensions yield `application/octet-stream`.
fn guess_mime(path: &str) -> &'static str {
    let extension = std::path::Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .map(|ext| ext.to_ascii_lowercase())
        .unwrap_or_default();

    match extension.as_str() {
        "html" | "htm" => "text/html; charset=utf-8",
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" => "application/javascript; charset=utf-8",
        // Source maps are JSON documents.
        "json" | "map" => "application/json; charset=utf-8",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "avif" => "image/avif",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "otf" => "font/otf",
        "txt" => "text/plain; charset=utf-8",
        "xml" => "application/xml; charset=utf-8",
        "wasm" => "application/wasm",
        "pdf" => "application/pdf",
        _ => "application/octet-stream",
    }
}
1797
1798fn apply_cache_control(
1799 mut response: Response<Full<Bytes>>,
1800 cache_control: Option<&str>,
1801) -> Response<Full<Bytes>> {
1802 if response.status() == StatusCode::OK {
1803 if let Some(value) = cache_control {
1804 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
1805 if let Ok(header_value) = http::HeaderValue::from_str(value) {
1806 response
1807 .headers_mut()
1808 .insert(http::header::CACHE_CONTROL, header_value);
1809 }
1810 }
1811 }
1812 }
1813 response
1814}
1815
1816async fn maybe_handle_static_request(
1817 req: Request<Incoming>,
1818 method: &Method,
1819 path: &str,
1820 static_assets: &StaticAssetsConfig,
1821) -> Result<Request<Incoming>, HttpResponse> {
1822 if method != Method::GET && method != Method::HEAD {
1823 return Ok(req);
1824 }
1825
1826 if let Some(mount) = static_assets
1827 .mounts
1828 .iter()
1829 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
1830 {
1831 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1832 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
1833 return Ok(req);
1834 };
1835 let response = match serve_static_file(&mount.directory, &stripped_path).await {
1836 Ok(response) => response,
1837 Err(_) => {
1838 return Err(Response::builder()
1839 .status(StatusCode::INTERNAL_SERVER_ERROR)
1840 .body(
1841 Full::new(Bytes::from("Failed to serve static asset"))
1842 .map_err(|never| match never {})
1843 .boxed(),
1844 )
1845 .unwrap_or_else(|_| {
1846 Response::new(
1847 Full::new(Bytes::new())
1848 .map_err(|never| match never {})
1849 .boxed(),
1850 )
1851 }));
1852 }
1853 };
1854 let mut response = apply_cache_control(response, static_assets.cache_control.as_deref());
1855 response = maybe_compress_static_response(
1856 response,
1857 accept_encoding,
1858 static_assets.enable_compression,
1859 );
1860 let (parts, body) = response.into_parts();
1861 return Err(Response::from_parts(
1862 parts,
1863 body.map_err(|never| match never {}).boxed(),
1864 ));
1865 }
1866
1867 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
1868 if looks_like_spa_request(path) {
1869 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1870 let response = match serve_single_file(spa_file).await {
1871 Ok(response) => response,
1872 Err(_) => {
1873 return Err(Response::builder()
1874 .status(StatusCode::INTERNAL_SERVER_ERROR)
1875 .body(
1876 Full::new(Bytes::from("Failed to serve SPA fallback"))
1877 .map_err(|never| match never {})
1878 .boxed(),
1879 )
1880 .unwrap_or_else(|_| {
1881 Response::new(
1882 Full::new(Bytes::new())
1883 .map_err(|never| match never {})
1884 .boxed(),
1885 )
1886 }));
1887 }
1888 };
1889 let mut response =
1890 apply_cache_control(response, static_assets.cache_control.as_deref());
1891 response = maybe_compress_static_response(
1892 response,
1893 accept_encoding,
1894 static_assets.enable_compression,
1895 );
1896 let (parts, body) = response.into_parts();
1897 return Err(Response::from_parts(
1898 parts,
1899 body.map_err(|never| match never {}).boxed(),
1900 ));
1901 }
1902 }
1903
1904 Ok(req)
1905}
1906
/// Returns the part of `path` below the mount `prefix` (always starting with `/`), or `None`.
///
/// Trailing slashes on `prefix` are ignored, and a prefix of exactly `/` matches every path
/// unchanged. A path equal to the prefix maps to `/`. Matching is per segment, so `/static`
/// does not match `/staticx/...`.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    let base = match prefix {
        "/" => "/",
        other => other.trim_end_matches('/'),
    };

    if base == "/" {
        return Some(path.to_owned());
    }
    if path == base {
        return Some(String::from("/"));
    }

    let rest = path.strip_prefix(base)?.strip_prefix('/')?;
    Some(format!("/{rest}"))
}
1926
/// True when the last path segment has no `.`, i.e. the request looks like a client-side
/// route rather than a file (so `/app/settings` qualifies, `/app.js` does not).
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = path.rsplit_once('/').map_or(path, |(_, tail)| tail);
    !last_segment.contains('.')
}
1931
1932fn maybe_compress_static_response(
1933 response: Response<Full<Bytes>>,
1934 accept_encoding: Option<http::HeaderValue>,
1935 enable_compression: bool,
1936) -> Response<Full<Bytes>> {
1937 if !enable_compression {
1938 return response;
1939 }
1940
1941 let Some(accept_encoding) = accept_encoding else {
1942 return response;
1943 };
1944
1945 let accept_str = accept_encoding.to_str().unwrap_or("");
1946 if !accept_str.contains("gzip") {
1947 return response;
1948 }
1949
1950 let status = response.status();
1951 let headers = response.headers().clone();
1952 let body = response.into_body();
1953
1954 let data = futures_util::FutureExt::now_or_never(BodyExt::collect(body))
1956 .and_then(|r| r.ok())
1957 .map(|collected| collected.to_bytes())
1958 .unwrap_or_default();
1959
1960 let compressed = {
1962 use flate2::write::GzEncoder;
1963 use flate2::Compression;
1964 use std::io::Write;
1965 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
1966 let _ = encoder.write_all(&data);
1967 encoder.finish().unwrap_or_default()
1968 };
1969
1970 let mut builder = Response::builder().status(status);
1971 for (name, value) in headers.iter() {
1972 if name != http::header::CONTENT_LENGTH && name != http::header::CONTENT_ENCODING {
1973 builder = builder.header(name, value);
1974 }
1975 }
1976 builder
1977 .header(http::header::CONTENT_ENCODING, "gzip")
1978 .body(Full::new(Bytes::from(compressed)))
1979 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())))
1980}
1981
1982async fn run_named_health_checks<R>(
1983 checks: &[NamedHealthCheck<R>],
1984 resources: Arc<R>,
1985) -> (bool, Vec<HealthCheckReport>)
1986where
1987 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1988{
1989 let mut reports = Vec::with_capacity(checks.len());
1990 let mut healthy = true;
1991
1992 for check in checks {
1993 match (check.check)(resources.clone()).await {
1994 Ok(()) => reports.push(HealthCheckReport {
1995 name: check.name.clone(),
1996 status: "ok",
1997 error: None,
1998 }),
1999 Err(error) => {
2000 healthy = false;
2001 reports.push(HealthCheckReport {
2002 name: check.name.clone(),
2003 status: "error",
2004 error: Some(error),
2005 });
2006 }
2007 }
2008 }
2009
2010 (healthy, reports)
2011}
2012
2013fn health_json_response(
2014 probe: &'static str,
2015 healthy: bool,
2016 checks: Vec<HealthCheckReport>,
2017) -> HttpResponse {
2018 let status_code = if healthy {
2019 StatusCode::OK
2020 } else {
2021 StatusCode::SERVICE_UNAVAILABLE
2022 };
2023 let status = if healthy { "ok" } else { "degraded" };
2024 let payload = HealthReport {
2025 status,
2026 probe,
2027 checks,
2028 };
2029
2030 let body = serde_json::to_vec(&payload)
2031 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
2032
2033 Response::builder()
2034 .status(status_code)
2035 .header(http::header::CONTENT_TYPE, "application/json")
2036 .body(
2037 Full::new(Bytes::from(body))
2038 .map_err(|never| match never {})
2039 .boxed(),
2040 )
2041 .unwrap()
2042}
2043
2044async fn shutdown_signal() {
2045 #[cfg(unix)]
2046 {
2047 use tokio::signal::unix::{SignalKind, signal};
2048
2049 match signal(SignalKind::terminate()) {
2050 Ok(mut terminate) => {
2051 tokio::select! {
2052 _ = tokio::signal::ctrl_c() => {}
2053 _ = terminate.recv() => {}
2054 }
2055 }
2056 Err(err) => {
2057 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
2058 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
2059 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
2060 }
2061 }
2062 }
2063 }
2064
2065 #[cfg(not(unix))]
2066 {
2067 if let Err(err) = tokio::signal::ctrl_c().await {
2068 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
2069 }
2070 }
2071}
2072
2073async fn drain_connections(
2074 connections: &mut tokio::task::JoinSet<()>,
2075 graceful_shutdown_timeout: Duration,
2076) -> bool {
2077 if connections.is_empty() {
2078 return false;
2079 }
2080
2081 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
2082 while let Some(join_result) = connections.join_next().await {
2083 if let Err(err) = join_result {
2084 tracing::warn!("Connection task join error during shutdown: {:?}", err);
2085 }
2086 }
2087 })
2088 .await;
2089
2090 if drain_result.is_err() {
2091 tracing::warn!(
2092 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
2093 graceful_shutdown_timeout
2094 );
2095 connections.abort_all();
2096 while let Some(join_result) = connections.join_next().await {
2097 if let Err(err) = join_result {
2098 tracing::warn!("Connection task abort join error: {:?}", err);
2099 }
2100 }
2101 true
2102 } else {
2103 false
2104 }
2105}
2106
2107impl<R> Default for HttpIngress<R>
2108where
2109 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2110{
2111 fn default() -> Self {
2112 Self::new()
2113 }
2114}
2115
/// Router state produced by `HttpIngress::into_raw_service`, usable directly as a
/// `hyper::service::Service` by callers that drive connections themselves.
///
/// Each `call` dispatches through `build_http_service` (health probes, routes, static assets,
/// fallback) without an HTTP/3 `Alt-Svc` port. All state sits behind `Arc`, so cloning only
/// bumps reference counts.
#[derive(Clone)]
pub struct RawIngressService<R> {
    // Registered routes, matched in order.
    routes: Arc<Vec<RouteEntry<R>>>,
    // Handler for requests that match no route and no static asset.
    fallback: Option<RouteHandler<R>>,
    // Global middleware layers.
    layers: Arc<Vec<ServiceLayer>>,
    // Health/readiness/liveness probe configuration.
    health: Arc<HealthConfig<R>>,
    // Static mounts, SPA fallback, cache-control and compression settings.
    static_assets: Arc<StaticAssetsConfig>,
    // Shared resources handed to every handler.
    resources: Arc<R>,
}
2126
2127impl<R> hyper::service::Service<Request<Incoming>> for RawIngressService<R>
2128where
2129 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2130{
2131 type Response = HttpResponse;
2132 type Error = Infallible;
2133 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2134
2135 fn call(&self, req: Request<Incoming>) -> Self::Future {
2136 let routes = self.routes.clone();
2137 let fallback = self.fallback.clone();
2138 let layers = self.layers.clone();
2139 let health = self.health.clone();
2140 let static_assets = self.static_assets.clone();
2141 let resources = self.resources.clone();
2142
2143 Box::pin(async move {
2144 let service = build_http_service(
2145 routes,
2146 fallback,
2147 resources,
2148 layers,
2149 health,
2150 static_assets,
2151 #[cfg(feature = "http3")]
2152 None,
2153 );
2154 service.call(req).await
2155 })
2156 }
2157}
2158
2159#[cfg(test)]
2160mod tests {
2161 use super::*;
2162 use async_trait::async_trait;
2163 use futures_util::{SinkExt, StreamExt};
2164 use serde::Deserialize;
2165 use std::fs;
2166 use std::sync::atomic::{AtomicBool, Ordering};
2167 use tempfile::tempdir;
2168 use tokio::io::{AsyncReadExt, AsyncWriteExt};
2169 use tokio_tungstenite::tungstenite::Message as WsClientMessage;
2170 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
2171
2172 async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
2173 let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
2174
2175 loop {
2176 match tokio::net::TcpStream::connect(addr).await {
2177 Ok(stream) => return stream,
2178 Err(error) => {
2179 if tokio::time::Instant::now() >= deadline {
2180 panic!("connect server: {error}");
2181 }
2182 tokio::time::sleep(Duration::from_millis(25)).await;
2183 }
2184 }
2185 }
2186 }
2187
2188 #[test]
2189 fn route_pattern_matches_static_path() {
2190 let pattern = RoutePattern::parse("/orders/list");
2191 let params = pattern.match_path("/orders/list").expect("should match");
2192 assert!(params.into_inner().is_empty());
2193 }
2194
2195 #[test]
2196 fn route_pattern_matches_param_segments() {
2197 let pattern = RoutePattern::parse("/orders/:id/items/:item_id");
2198 let params = pattern
2199 .match_path("/orders/42/items/sku-123")
2200 .expect("should match");
2201 assert_eq!(params.get("id"), Some("42"));
2202 assert_eq!(params.get("item_id"), Some("sku-123"));
2203 }
2204
2205 #[test]
2206 fn route_pattern_matches_wildcard_segment() {
2207 let pattern = RoutePattern::parse("/assets/*path");
2208 let params = pattern
2209 .match_path("/assets/css/theme/light.css")
2210 .expect("should match");
2211 assert_eq!(params.get("path"), Some("css/theme/light.css"));
2212 }
2213
2214 #[test]
2215 fn route_pattern_rejects_non_matching_path() {
2216 let pattern = RoutePattern::parse("/orders/:id");
2217 assert!(pattern.match_path("/users/42").is_none());
2218 }
2219
2220 #[test]
2221 fn graceful_shutdown_timeout_defaults_to_30_seconds() {
2222 let ingress = HttpIngress::<()>::new();
2223 assert_eq!(ingress.graceful_shutdown_timeout, Duration::from_secs(30));
2224 assert!(ingress.layers.is_empty());
2225 assert!(ingress.bus_injectors.is_empty());
2226 assert!(ingress.static_assets.mounts.is_empty());
2227 assert!(ingress.on_start.is_none());
2228 assert!(ingress.on_shutdown.is_none());
2229 }
2230
2231 #[test]
2232 fn route_without_layer_keeps_empty_route_middleware_stack() {
2233 let ingress =
2234 HttpIngress::<()>::new().get("/ping", Axon::<(), (), String, ()>::new("Ping"));
2235 assert_eq!(ingress.routes.len(), 1);
2236 assert!(ingress.routes[0].layers.is_empty());
2237 assert!(ingress.routes[0].apply_global_layers);
2238 }
2239
2240 #[test]
2241 fn timeout_layer_registers_builtin_middleware() {
2242 let ingress = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
2243 assert_eq!(ingress.layers.len(), 1);
2244 }
2245
2246 #[test]
2247 fn request_id_layer_registers_builtin_middleware() {
2248 let ingress = HttpIngress::<()>::new().request_id_layer();
2249 assert_eq!(ingress.layers.len(), 1);
2250 }
2251
2252 #[test]
2253 fn compression_layer_registers_builtin_middleware() {
2254 let ingress = HttpIngress::<()>::new().compression_layer();
2255 assert!(ingress.static_assets.enable_compression);
2256 }
2257
2258 #[test]
2259 fn bus_injector_registration_adds_hook() {
2260 let ingress = HttpIngress::<()>::new().bus_injector(|_req, bus| {
2261 bus.insert("ok".to_string());
2262 });
2263 assert_eq!(ingress.bus_injectors.len(), 1);
2264 }
2265
2266 #[test]
2267 fn ws_route_registers_get_route_pattern() {
2268 let ingress =
2269 HttpIngress::<()>::new().ws("/ws/events", |_socket, _resources, _bus| async {});
2270 assert_eq!(ingress.routes.len(), 1);
2271 assert_eq!(ingress.routes[0].method, Method::GET);
2272 assert_eq!(ingress.routes[0].pattern.raw, "/ws/events");
2273 }
2274
    /// JSON payload of the first frame the echo handler in the WebSocket test sends.
    #[derive(Debug, Deserialize)]
    struct WsWelcomeFrame {
        // Read from `WebSocketSessionContext::connection_id` on the server side.
        connection_id: String,
        // Request path of the upgraded connection.
        path: String,
        // Taken from the `x-tenant-id` header via the bus injector, or "unknown".
        tenant: String,
    }
2281
2282 #[tokio::test]
2283 async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
2284 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2285 let addr = probe.local_addr().expect("local addr");
2286 drop(probe);
2287
2288 let ingress = HttpIngress::<()>::new()
2289 .bind(addr.to_string())
2290 .bus_injector(|req, bus| {
2291 if let Some(value) = req.headers.get("x-tenant-id").and_then(|v| v.to_str().ok()) {
2292 bus.insert(value.to_string());
2293 }
2294 })
2295 .ws("/ws/echo", |mut socket, _resources, bus| async move {
2296 let tenant = bus
2297 .read::<String>()
2298 .cloned()
2299 .unwrap_or_else(|| "unknown".to_string());
2300 if let Some(session) = bus.read::<WebSocketSessionContext>() {
2301 let welcome = serde_json::json!({
2302 "connection_id": session.connection_id().to_string(),
2303 "path": session.path(),
2304 "tenant": tenant,
2305 });
2306 let _ = socket.send_json(&welcome).await;
2307 }
2308
2309 while let Some(event) = socket.next_event().await {
2310 match event {
2311 WebSocketEvent::Text(text) => {
2312 let _ = socket.send_event(format!("echo:{text}")).await;
2313 }
2314 WebSocketEvent::Binary(bytes) => {
2315 let _ = socket.send_event(bytes).await;
2316 }
2317 WebSocketEvent::Close => break,
2318 WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
2319 }
2320 }
2321 });
2322
2323 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2324 let server = tokio::spawn(async move {
2325 ingress
2326 .run_with_shutdown_signal((), async move {
2327 let _ = shutdown_rx.await;
2328 })
2329 .await
2330 });
2331
2332 let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
2333 let mut ws_request = ws_uri
2334 .as_str()
2335 .into_client_request()
2336 .expect("ws client request");
2337 ws_request
2338 .headers_mut()
2339 .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
2340 let (mut client, _response) = tokio_tungstenite::connect_async(ws_request)
2341 .await
2342 .expect("websocket connect");
2343
2344 let welcome = client
2345 .next()
2346 .await
2347 .expect("welcome frame")
2348 .expect("welcome frame ok");
2349 let welcome_text = match welcome {
2350 WsClientMessage::Text(text) => text.to_string(),
2351 other => panic!("expected text welcome frame, got {other:?}"),
2352 };
2353 let welcome_payload: WsWelcomeFrame =
2354 serde_json::from_str(&welcome_text).expect("welcome json");
2355 assert_eq!(welcome_payload.path, "/ws/echo");
2356 assert_eq!(welcome_payload.tenant, "acme");
2357 assert!(!welcome_payload.connection_id.is_empty());
2358
2359 client
2360 .send(WsClientMessage::Text("hello".into()))
2361 .await
2362 .expect("send text");
2363 let echo_text = client
2364 .next()
2365 .await
2366 .expect("echo text frame")
2367 .expect("echo text frame ok");
2368 assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));
2369
2370 client
2371 .send(WsClientMessage::Binary(vec![1, 2, 3, 4].into()))
2372 .await
2373 .expect("send binary");
2374 let echo_binary = client
2375 .next()
2376 .await
2377 .expect("echo binary frame")
2378 .expect("echo binary frame ok");
2379 assert_eq!(
2380 echo_binary,
2381 WsClientMessage::Binary(vec![1, 2, 3, 4].into())
2382 );
2383
2384 client.close(None).await.expect("close websocket");
2385
2386 let _ = shutdown_tx.send(());
2387 server
2388 .await
2389 .expect("server join")
2390 .expect("server shutdown should succeed");
2391 }
2392
#[test]
fn route_descriptors_export_http_and_health_paths() {
    let ingress = HttpIngress::<()>::new()
        .get(
            "/orders/:id",
            Axon::<(), (), String, ()>::new("OrderById"),
        )
        .health_endpoint("/healthz")
        .readiness_liveness("/readyz", "/livez");

    let descriptors = ingress.route_descriptors();

    // The user route and every configured probe endpoint must each be exported
    // as a GET descriptor. A table-driven check replaces four copy-pasted
    // assertions and names the missing path when it fails.
    for expected in ["/orders/:id", "/healthz", "/readyz", "/livez"] {
        assert!(
            descriptors
                .iter()
                .any(|descriptor| descriptor.method() == Method::GET
                    && descriptor.path_pattern() == expected),
            "missing GET route descriptor for {expected}"
        );
    }
}
2430
2431 #[tokio::test]
2432 async fn lifecycle_hooks_fire_on_start_and_shutdown() {
2433 let started = Arc::new(AtomicBool::new(false));
2434 let shutdown = Arc::new(AtomicBool::new(false));
2435
2436 let started_flag = started.clone();
2437 let shutdown_flag = shutdown.clone();
2438
2439 let ingress = HttpIngress::<()>::new()
2440 .bind("127.0.0.1:0")
2441 .on_start(move || {
2442 started_flag.store(true, Ordering::SeqCst);
2443 })
2444 .on_shutdown(move || {
2445 shutdown_flag.store(true, Ordering::SeqCst);
2446 })
2447 .graceful_shutdown(Duration::from_millis(50));
2448
2449 ingress
2450 .run_with_shutdown_signal((), async {
2451 tokio::time::sleep(Duration::from_millis(20)).await;
2452 })
2453 .await
2454 .expect("server should exit gracefully");
2455
2456 assert!(started.load(Ordering::SeqCst));
2457 assert!(shutdown.load(Ordering::SeqCst));
2458 }
2459
2460 #[tokio::test]
2461 async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
2462 #[derive(Clone)]
2463 struct SlowDrainRoute;
2464
2465 #[async_trait]
2466 impl Transition<(), String> for SlowDrainRoute {
2467 type Error = String;
2468 type Resources = ();
2469
2470 async fn run(
2471 &self,
2472 _state: (),
2473 _resources: &Self::Resources,
2474 _bus: &mut Bus,
2475 ) -> Outcome<String, Self::Error> {
2476 tokio::time::sleep(Duration::from_millis(120)).await;
2477 Outcome::next("drained-ok".to_string())
2478 }
2479 }
2480
2481 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2482 let addr = probe.local_addr().expect("local addr");
2483 drop(probe);
2484
2485 let ingress = HttpIngress::<()>::new()
2486 .bind(addr.to_string())
2487 .graceful_shutdown(Duration::from_millis(500))
2488 .get(
2489 "/drain",
2490 Axon::<(), (), String, ()>::new("SlowDrain").then(SlowDrainRoute),
2491 );
2492
2493 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2494 let server = tokio::spawn(async move {
2495 ingress
2496 .run_with_shutdown_signal((), async move {
2497 let _ = shutdown_rx.await;
2498 })
2499 .await
2500 });
2501
2502 let mut stream = connect_with_retry(addr).await;
2503 stream
2504 .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2505 .await
2506 .expect("write request");
2507
2508 tokio::time::sleep(Duration::from_millis(20)).await;
2509 let _ = shutdown_tx.send(());
2510
2511 let mut buf = Vec::new();
2512 stream.read_to_end(&mut buf).await.expect("read response");
2513 let response = String::from_utf8_lossy(&buf);
2514 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
2515 assert!(response.contains("drained-ok"), "{response}");
2516
2517 server
2518 .await
2519 .expect("server join")
2520 .expect("server shutdown should succeed");
2521 }
2522
2523 #[tokio::test]
2524 async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
2525 let temp = tempdir().expect("tempdir");
2526 let root = temp.path().join("public");
2527 fs::create_dir_all(&root).expect("create dir");
2528 let file = root.join("hello.txt");
2529 fs::write(&file, "hello static").expect("write file");
2530
2531 let ingress =
2532 Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
2533 let app = crate::test_harness::TestApp::new(ingress, ());
2534 let response = app
2535 .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
2536 .await
2537 .expect("request should succeed");
2538
2539 assert_eq!(response.status(), StatusCode::OK);
2540 assert_eq!(response.text().expect("utf8"), "hello static");
2541 assert!(response.header("cache-control").is_some());
2542 let has_metadata_header =
2543 response.header("etag").is_some() || response.header("last-modified").is_some();
2544 assert!(has_metadata_header);
2545 }
2546
2547 #[tokio::test]
2548 async fn spa_fallback_returns_index_for_unmatched_path() {
2549 let temp = tempdir().expect("tempdir");
2550 let index = temp.path().join("index.html");
2551 fs::write(&index, "<html><body>spa</body></html>").expect("write index");
2552
2553 let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
2554 let app = crate::test_harness::TestApp::new(ingress, ());
2555 let response = app
2556 .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
2557 .await
2558 .expect("request should succeed");
2559
2560 assert_eq!(response.status(), StatusCode::OK);
2561 assert!(response.text().expect("utf8").contains("spa"));
2562 }
2563
2564 #[tokio::test]
2565 async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
2566 let temp = tempdir().expect("tempdir");
2567 let root = temp.path().join("public");
2568 fs::create_dir_all(&root).expect("create dir");
2569 let file = root.join("compressed.txt");
2570 fs::write(&file, "compress me ".repeat(400)).expect("write file");
2571
2572 let ingress = Ranvier::http::<()>()
2573 .serve_dir("/static", root.to_string_lossy().to_string())
2574 .compression_layer();
2575 let app = crate::test_harness::TestApp::new(ingress, ());
2576 let response = app
2577 .send(
2578 crate::test_harness::TestRequest::get("/static/compressed.txt")
2579 .header("accept-encoding", "gzip"),
2580 )
2581 .await
2582 .expect("request should succeed");
2583
2584 assert_eq!(response.status(), StatusCode::OK);
2585 assert_eq!(
2586 response
2587 .header("content-encoding")
2588 .and_then(|value| value.to_str().ok()),
2589 Some("gzip")
2590 );
2591 }
2592
2593 #[tokio::test]
2594 async fn drain_connections_completes_before_timeout() {
2595 let mut connections = tokio::task::JoinSet::new();
2596 connections.spawn(async {
2597 tokio::time::sleep(Duration::from_millis(20)).await;
2598 });
2599
2600 let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
2601 assert!(!timed_out);
2602 assert!(connections.is_empty());
2603 }
2604
2605 #[tokio::test]
2606 async fn drain_connections_times_out_and_aborts() {
2607 let mut connections = tokio::task::JoinSet::new();
2608 connections.spawn(async {
2609 tokio::time::sleep(Duration::from_secs(10)).await;
2610 });
2611
2612 let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
2613 assert!(timed_out);
2614 assert!(connections.is_empty());
2615 }
2616
2617 #[tokio::test]
2618 async fn timeout_layer_returns_408_for_slow_route() {
2619 #[derive(Clone)]
2620 struct SlowRoute;
2621
2622 #[async_trait]
2623 impl Transition<(), String> for SlowRoute {
2624 type Error = String;
2625 type Resources = ();
2626
2627 async fn run(
2628 &self,
2629 _state: (),
2630 _resources: &Self::Resources,
2631 _bus: &mut Bus,
2632 ) -> Outcome<String, Self::Error> {
2633 tokio::time::sleep(Duration::from_millis(80)).await;
2634 Outcome::next("slow-ok".to_string())
2635 }
2636 }
2637
2638 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2639 let addr = probe.local_addr().expect("local addr");
2640 drop(probe);
2641
2642 let ingress = HttpIngress::<()>::new()
2643 .bind(addr.to_string())
2644 .timeout_layer(Duration::from_millis(10))
2645 .get(
2646 "/slow",
2647 Axon::<(), (), String, ()>::new("Slow").then(SlowRoute),
2648 );
2649
2650 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2651 let server = tokio::spawn(async move {
2652 ingress
2653 .run_with_shutdown_signal((), async move {
2654 let _ = shutdown_rx.await;
2655 })
2656 .await
2657 });
2658
2659 let mut stream = connect_with_retry(addr).await;
2660 stream
2661 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2662 .await
2663 .expect("write request");
2664
2665 let mut buf = Vec::new();
2666 stream.read_to_end(&mut buf).await.expect("read response");
2667 let response = String::from_utf8_lossy(&buf);
2668 assert!(response.starts_with("HTTP/1.1 408"), "{response}");
2669
2670 let _ = shutdown_tx.send(());
2671 server
2672 .await
2673 .expect("server join")
2674 .expect("server shutdown should succeed");
2675 }
2676
2677}