1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode, Uri};
20use http_body::Body;
21use http_body_util::{BodyExt, Full};
22use hyper::body::Incoming;
23use hyper::server::conn::http1;
24use hyper::upgrade::Upgraded;
25use hyper_util::rt::TokioIo;
26use hyper_util::service::TowerToHyperService;
27use ranvier_core::event::{EventSink, EventSource};
28use ranvier_core::prelude::*;
29use ranvier_runtime::Axon;
30use serde::Serialize;
31use serde::de::DeserializeOwned;
32use sha1::{Digest, Sha1};
33use std::collections::HashMap;
34use std::convert::Infallible;
35use std::future::Future;
36use std::net::SocketAddr;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::time::Duration;
40use tokio::net::TcpListener;
41use tokio::sync::Mutex;
42use tokio_tungstenite::WebSocketStream;
43use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
44use tower::util::BoxCloneService;
45use tower::{Layer, Service, ServiceExt, service_fn};
46use tower_http::compression::CompressionLayer;
47use tower_http::services::{ServeDir, ServeFile};
48use tracing::Instrument;
49
50use crate::response::{IntoResponse, outcome_to_response_with_error};
51
/// Zero-sized entry point for building Ranvier ingress adapters.
///
/// Call [`Ranvier::http`] to obtain an [`HttpIngress`] builder.
pub struct Ranvier;
57
58impl Ranvier {
59 pub fn http<R>() -> HttpIngress<R>
61 where
62 R: ranvier_core::transition::ResourceRequirement + Clone,
63 {
64 HttpIngress::new()
65 }
66}
67
68type RouteHandler<R> = Arc<
70 dyn Fn(Request<Incoming>, &R) -> Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
71 + Send
72 + Sync,
73>;
74
75type BoxHttpService = BoxCloneService<Request<Incoming>, Response<Full<Bytes>>, Infallible>;
76type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
77type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
78type BusInjector = Arc<dyn Fn(&Request<Incoming>, &mut Bus) + Send + Sync>;
79type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
80type WsSessionHandler<R> =
81 Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
82type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
83type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
/// Magic GUID appended to `Sec-WebSocket-Key` before hashing (RFC 6455 §1.3).
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Token expected in the `Upgrade` header and echoed in the 101 response.
const WS_UPGRADE_TOKEN: &str = "websocket";
/// Header used to carry and echo the per-request identifier.
const REQUEST_ID_HEADER: &str = "x-request-id";
87
88#[derive(Clone)]
89struct NamedHealthCheck<R> {
90 name: String,
91 check: HealthCheckFn<R>,
92}
93
94#[derive(Clone)]
95struct HealthConfig<R> {
96 health_path: Option<String>,
97 readiness_path: Option<String>,
98 liveness_path: Option<String>,
99 checks: Vec<NamedHealthCheck<R>>,
100}
101
102impl<R> Default for HealthConfig<R> {
103 fn default() -> Self {
104 Self {
105 health_path: None,
106 readiness_path: None,
107 liveness_path: None,
108 checks: Vec::new(),
109 }
110 }
111}
112
113#[derive(Clone, Default)]
114struct StaticAssetsConfig {
115 mounts: Vec<StaticMount>,
116 spa_fallback: Option<String>,
117 cache_control: Option<String>,
118 enable_compression: bool,
119}
120
/// One directory exposed under a URL prefix.
#[derive(Clone)]
struct StaticMount {
    /// URL prefix, normalized to start with `/`.
    route_prefix: String,
    /// Filesystem directory exactly as supplied by the caller.
    directory: String,
}
126
127#[derive(Serialize)]
128struct HealthReport {
129 status: &'static str,
130 probe: &'static str,
131 checks: Vec<HealthCheckReport>,
132}
133
134#[derive(Serialize)]
135struct HealthCheckReport {
136 name: String,
137 status: &'static str,
138 #[serde(skip_serializing_if = "Option::is_none")]
139 error: Option<String>,
140}
141
142#[derive(Clone)]
143struct TimeoutService {
144 inner: BoxHttpService,
145 timeout: Duration,
146}
147
148impl Service<Request<Incoming>> for TimeoutService {
149 type Response = Response<Full<Bytes>>;
150 type Error = Infallible;
151 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
152
153 fn poll_ready(
154 &mut self,
155 cx: &mut std::task::Context<'_>,
156 ) -> std::task::Poll<Result<(), Self::Error>> {
157 self.inner.poll_ready(cx)
158 }
159
160 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
161 let timeout = self.timeout;
162 let fut = self.inner.call(req);
163 Box::pin(async move {
164 match tokio::time::timeout(timeout, fut).await {
165 Ok(response) => response,
166 Err(_) => Ok(Response::builder()
167 .status(StatusCode::REQUEST_TIMEOUT)
168 .body(Full::new(Bytes::from("Request Timeout")))
169 .unwrap()),
170 }
171 })
172 }
173}
174
175#[derive(Clone)]
176struct RequestIdService {
177 inner: BoxHttpService,
178}
179
180impl Service<Request<Incoming>> for RequestIdService {
181 type Response = Response<Full<Bytes>>;
182 type Error = Infallible;
183 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
184
185 fn poll_ready(
186 &mut self,
187 cx: &mut std::task::Context<'_>,
188 ) -> std::task::Poll<Result<(), Self::Error>> {
189 self.inner.poll_ready(cx)
190 }
191
192 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
193 let mut req = req;
194 let request_id = req
195 .headers()
196 .get(REQUEST_ID_HEADER)
197 .cloned()
198 .unwrap_or_else(|| {
199 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
200 .unwrap_or_else(|_| http::HeaderValue::from_static("request-id-unavailable"))
201 });
202
203 req.headers_mut()
204 .insert(REQUEST_ID_HEADER, request_id.clone());
205
206 let fut = self.inner.call(req);
207 Box::pin(async move {
208 let mut response = fut.await?;
209 response.headers_mut().insert(REQUEST_ID_HEADER, request_id);
210 Ok(response)
211 })
212 }
213}
214
215fn to_service_layer<L>(layer: L) -> ServiceLayer
216where
217 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
218 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
219 + Clone
220 + Send
221 + 'static,
222 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
223{
224 Arc::new(move |service: BoxHttpService| BoxCloneService::new(layer.clone().layer(service)))
225}
226
/// Named path parameters captured while matching a route pattern.
#[derive(Default, Debug, Clone, Eq, PartialEq)]
pub struct PathParams {
    /// Parameter name (without the `:`/`*` sigil) mapped to the captured text.
    values: HashMap<String, String>,
}
231
232#[derive(Clone, Debug, PartialEq, Eq)]
234pub struct HttpRouteDescriptor {
235 method: Method,
236 path_pattern: String,
237}
238
239impl HttpRouteDescriptor {
240 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
241 Self {
242 method,
243 path_pattern: path_pattern.into(),
244 }
245 }
246
247 pub fn method(&self) -> &Method {
248 &self.method
249 }
250
251 pub fn path_pattern(&self) -> &str {
252 &self.path_pattern
253 }
254}
255
256#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
258pub struct WebSocketSessionContext {
259 connection_id: uuid::Uuid,
260 path: String,
261 query: Option<String>,
262}
263
264impl WebSocketSessionContext {
265 pub fn connection_id(&self) -> uuid::Uuid {
266 self.connection_id
267 }
268
269 pub fn path(&self) -> &str {
270 &self.path
271 }
272
273 pub fn query(&self) -> Option<&str> {
274 self.query.as_deref()
275 }
276}
277
/// Frame-level WebSocket message exchanged with session handlers.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum WebSocketEvent {
    /// UTF-8 text frame.
    Text(String),
    /// Binary frame.
    Binary(Vec<u8>),
    /// Ping control frame with its payload.
    Ping(Vec<u8>),
    /// Pong control frame with its payload.
    Pong(Vec<u8>),
    /// Close frame; any close code/reason is not carried.
    Close,
}
287
288impl WebSocketEvent {
289 pub fn text(value: impl Into<String>) -> Self {
290 Self::Text(value.into())
291 }
292
293 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
294 Self::Binary(value.into())
295 }
296
297 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
298 where
299 T: Serialize,
300 {
301 let text = serde_json::to_string(value)?;
302 Ok(Self::Text(text))
303 }
304}
305
306#[derive(Debug, thiserror::Error)]
307pub enum WebSocketError {
308 #[error("websocket wire error: {0}")]
309 Wire(#[from] WsWireError),
310 #[error("json serialization failed: {0}")]
311 JsonSerialize(#[source] serde_json::Error),
312 #[error("json deserialization failed: {0}")]
313 JsonDeserialize(#[source] serde_json::Error),
314 #[error("expected text or binary frame for json payload")]
315 NonDataFrame,
316}
317
318type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
319type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
320type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
321
322pub struct WebSocketConnection {
324 sink: Mutex<WsServerSink>,
325 source: Mutex<WsServerSource>,
326 session: WebSocketSessionContext,
327}
328
329impl WebSocketConnection {
330 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
331 let (sink, source) = stream.split();
332 Self {
333 sink: Mutex::new(sink),
334 source: Mutex::new(source),
335 session,
336 }
337 }
338
339 pub fn session(&self) -> &WebSocketSessionContext {
340 &self.session
341 }
342
343 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
344 let mut sink = self.sink.lock().await;
345 sink.send(event.into_wire_message()).await?;
346 Ok(())
347 }
348
349 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
350 where
351 T: Serialize,
352 {
353 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
354 self.send(event).await
355 }
356
357 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
358 where
359 T: DeserializeOwned,
360 {
361 let Some(event) = self.recv_event().await? else {
362 return Ok(None);
363 };
364 match event {
365 WebSocketEvent::Text(text) => serde_json::from_str(&text)
366 .map(Some)
367 .map_err(WebSocketError::JsonDeserialize),
368 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
369 .map(Some)
370 .map_err(WebSocketError::JsonDeserialize),
371 _ => Err(WebSocketError::NonDataFrame),
372 }
373 }
374
375 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
376 let mut source = self.source.lock().await;
377 while let Some(item) = source.next().await {
378 let message = item?;
379 if let Some(event) = WebSocketEvent::from_wire_message(message) {
380 return Ok(Some(event));
381 }
382 }
383 Ok(None)
384 }
385}
386
387impl WebSocketEvent {
388 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
389 match message {
390 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
391 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
392 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
393 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
394 WsWireMessage::Close(_) => Some(Self::Close),
395 WsWireMessage::Frame(_) => None,
396 }
397 }
398
399 fn into_wire_message(self) -> WsWireMessage {
400 match self {
401 Self::Text(value) => WsWireMessage::Text(value.into()),
402 Self::Binary(value) => WsWireMessage::Binary(value.into()),
403 Self::Ping(value) => WsWireMessage::Ping(value.into()),
404 Self::Pong(value) => WsWireMessage::Pong(value.into()),
405 Self::Close => WsWireMessage::Close(None),
406 }
407 }
408}
409
410#[async_trait::async_trait]
411impl EventSource<WebSocketEvent> for WebSocketConnection {
412 async fn next_event(&mut self) -> Option<WebSocketEvent> {
413 match self.recv_event().await {
414 Ok(event) => event,
415 Err(error) => {
416 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
417 None
418 }
419 }
420 }
421}
422
423#[async_trait::async_trait]
424impl EventSink<WebSocketEvent> for WebSocketConnection {
425 type Error = WebSocketError;
426
427 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
428 self.send(event).await
429 }
430}
431
432#[async_trait::async_trait]
433impl EventSink<String> for WebSocketConnection {
434 type Error = WebSocketError;
435
436 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
437 self.send(WebSocketEvent::Text(event)).await
438 }
439}
440
441#[async_trait::async_trait]
442impl EventSink<Vec<u8>> for WebSocketConnection {
443 type Error = WebSocketError;
444
445 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
446 self.send(WebSocketEvent::Binary(event)).await
447 }
448}
449
450impl PathParams {
451 pub fn new(values: HashMap<String, String>) -> Self {
452 Self { values }
453 }
454
455 pub fn get(&self, key: &str) -> Option<&str> {
456 self.values.get(key).map(String::as_str)
457 }
458
459 pub fn as_map(&self) -> &HashMap<String, String> {
460 &self.values
461 }
462
463 pub fn into_inner(self) -> HashMap<String, String> {
464 self.values
465 }
466}
467
/// One `/`-separated piece of a parsed route pattern.
#[derive(Debug, Clone, Eq, PartialEq)]
enum RouteSegment {
    /// Literal text that must match exactly.
    Static(String),
    /// `:name` — captures exactly one path segment.
    Param(String),
    /// `*name` — captures all remaining segments joined with `/`.
    Wildcard(String),
}
474
475#[derive(Clone, Debug, PartialEq, Eq)]
476struct RoutePattern {
477 raw: String,
478 segments: Vec<RouteSegment>,
479}
480
481impl RoutePattern {
482 fn parse(path: &str) -> Self {
483 let segments = path_segments(path)
484 .into_iter()
485 .map(|segment| {
486 if let Some(name) = segment.strip_prefix(':') {
487 if !name.is_empty() {
488 return RouteSegment::Param(name.to_string());
489 }
490 }
491 if let Some(name) = segment.strip_prefix('*') {
492 if !name.is_empty() {
493 return RouteSegment::Wildcard(name.to_string());
494 }
495 }
496 RouteSegment::Static(segment.to_string())
497 })
498 .collect();
499
500 Self {
501 raw: path.to_string(),
502 segments,
503 }
504 }
505
506 fn match_path(&self, path: &str) -> Option<PathParams> {
507 let mut params = HashMap::new();
508 let path_segments = path_segments(path);
509 let mut pattern_index = 0usize;
510 let mut path_index = 0usize;
511
512 while pattern_index < self.segments.len() {
513 match &self.segments[pattern_index] {
514 RouteSegment::Static(expected) => {
515 let actual = path_segments.get(path_index)?;
516 if actual != expected {
517 return None;
518 }
519 pattern_index += 1;
520 path_index += 1;
521 }
522 RouteSegment::Param(name) => {
523 let actual = path_segments.get(path_index)?;
524 params.insert(name.clone(), (*actual).to_string());
525 pattern_index += 1;
526 path_index += 1;
527 }
528 RouteSegment::Wildcard(name) => {
529 let remaining = path_segments[path_index..].join("/");
530 params.insert(name.clone(), remaining);
531 pattern_index += 1;
532 path_index = path_segments.len();
533 break;
534 }
535 }
536 }
537
538 if pattern_index == self.segments.len() && path_index == path_segments.len() {
539 Some(PathParams::new(params))
540 } else {
541 None
542 }
543 }
544}
545
546#[derive(Clone)]
547struct RouteEntry<R> {
548 method: Method,
549 pattern: RoutePattern,
550 handler: RouteHandler<R>,
551 layers: Arc<Vec<ServiceLayer>>,
552 apply_global_layers: bool,
553}
554
/// Splits a path on `/`, dropping empty pieces.
///
/// Leading, trailing and repeated slashes are therefore ignored, and `/` yields no
/// segments.
fn path_segments(path: &str) -> Vec<&str> {
    path.split('/').filter(|segment| !segment.is_empty()).collect()
}
565
/// Ensures a route path starts with `/`; the empty string becomes `/`.
///
/// Trailing slashes are left untouched.
fn normalize_route_path(path: String) -> String {
    if path.starts_with('/') {
        return path;
    }
    if path.is_empty() {
        "/".to_owned()
    } else {
        format!("/{path}")
    }
}
576
577fn find_matching_route<'a, R>(
578 routes: &'a [RouteEntry<R>],
579 method: &Method,
580 path: &str,
581) -> Option<(&'a RouteEntry<R>, PathParams)> {
582 for entry in routes {
583 if &entry.method != method {
584 continue;
585 }
586 if let Some(params) = entry.pattern.match_path(path) {
587 return Some((entry, params));
588 }
589 }
590 None
591}
592
593fn header_contains_token(
594 headers: &http::HeaderMap,
595 name: http::header::HeaderName,
596 token: &str,
597) -> bool {
598 headers
599 .get(name)
600 .and_then(|value| value.to_str().ok())
601 .map(|value| {
602 value
603 .split(',')
604 .any(|part| part.trim().eq_ignore_ascii_case(token))
605 })
606 .unwrap_or(false)
607}
608
609fn websocket_session_from_request(req: &Request<Incoming>) -> WebSocketSessionContext {
610 WebSocketSessionContext {
611 connection_id: uuid::Uuid::new_v4(),
612 path: req.uri().path().to_string(),
613 query: req.uri().query().map(str::to_string),
614 }
615}
616
617fn websocket_accept_key(client_key: &str) -> String {
618 let mut hasher = Sha1::new();
619 hasher.update(client_key.as_bytes());
620 hasher.update(WS_GUID.as_bytes());
621 let digest = hasher.finalize();
622 base64::engine::general_purpose::STANDARD.encode(digest)
623}
624
625fn websocket_bad_request(message: &'static str) -> Response<Full<Bytes>> {
626 Response::builder()
627 .status(StatusCode::BAD_REQUEST)
628 .body(Full::new(Bytes::from(message)))
629 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())))
630}
631
632fn websocket_upgrade_response(
633 req: &mut Request<Incoming>,
634) -> Result<(Response<Full<Bytes>>, hyper::upgrade::OnUpgrade), Response<Full<Bytes>>> {
635 if req.method() != Method::GET {
636 return Err(websocket_bad_request(
637 "WebSocket upgrade requires GET method",
638 ));
639 }
640
641 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
642 return Err(websocket_bad_request(
643 "Missing Connection: upgrade header for WebSocket",
644 ));
645 }
646
647 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
648 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
649 }
650
651 if let Some(version) = req.headers().get("sec-websocket-version") {
652 if version != "13" {
653 return Err(websocket_bad_request(
654 "Unsupported Sec-WebSocket-Version (expected 13)",
655 ));
656 }
657 }
658
659 let Some(client_key) = req
660 .headers()
661 .get("sec-websocket-key")
662 .and_then(|value| value.to_str().ok())
663 else {
664 return Err(websocket_bad_request(
665 "Missing Sec-WebSocket-Key header for WebSocket",
666 ));
667 };
668
669 let accept_key = websocket_accept_key(client_key);
670 let on_upgrade = hyper::upgrade::on(req);
671 let response = Response::builder()
672 .status(StatusCode::SWITCHING_PROTOCOLS)
673 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
674 .header(http::header::CONNECTION, "Upgrade")
675 .header("sec-websocket-accept", accept_key)
676 .body(Full::new(Bytes::new()))
677 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())));
678
679 Ok((response, on_upgrade))
680}
681
682pub struct HttpIngress<R = ()> {
688 addr: Option<String>,
690 routes: Vec<RouteEntry<R>>,
692 fallback: Option<RouteHandler<R>>,
694 layers: Vec<ServiceLayer>,
696 on_start: Option<LifecycleHook>,
698 on_shutdown: Option<LifecycleHook>,
700 graceful_shutdown_timeout: Duration,
702 bus_injectors: Vec<BusInjector>,
704 static_assets: StaticAssetsConfig,
706 health: HealthConfig<R>,
708 _phantom: std::marker::PhantomData<R>,
709}
710
711impl<R> HttpIngress<R>
712where
713 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
714{
715 pub fn new() -> Self {
717 Self {
718 addr: None,
719 routes: Vec::new(),
720 fallback: None,
721 layers: Vec::new(),
722 on_start: None,
723 on_shutdown: None,
724 graceful_shutdown_timeout: Duration::from_secs(30),
725 bus_injectors: Vec::new(),
726 static_assets: StaticAssetsConfig::default(),
727 health: HealthConfig::default(),
728 _phantom: std::marker::PhantomData,
729 }
730 }
731
732 pub fn bind(mut self, addr: impl Into<String>) -> Self {
734 self.addr = Some(addr.into());
735 self
736 }
737
738 pub fn on_start<F>(mut self, callback: F) -> Self
740 where
741 F: Fn() + Send + Sync + 'static,
742 {
743 self.on_start = Some(Arc::new(callback));
744 self
745 }
746
747 pub fn on_shutdown<F>(mut self, callback: F) -> Self
749 where
750 F: Fn() + Send + Sync + 'static,
751 {
752 self.on_shutdown = Some(Arc::new(callback));
753 self
754 }
755
756 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
758 self.graceful_shutdown_timeout = timeout;
759 self
760 }
761
762 pub fn layer<L>(mut self, layer: L) -> Self
767 where
768 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
769 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
770 + Clone
771 + Send
772 + 'static,
773 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
774 {
775 self.layers.push(to_service_layer(layer));
776 self
777 }
778
779 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
782 self.layers.push(Arc::new(move |service: BoxHttpService| {
783 BoxCloneService::new(TimeoutService {
784 inner: service,
785 timeout,
786 })
787 }));
788 self
789 }
790
791 pub fn request_id_layer(mut self) -> Self {
795 self.layers.push(Arc::new(move |service: BoxHttpService| {
796 BoxCloneService::new(RequestIdService { inner: service })
797 }));
798 self
799 }
800
801 pub fn bus_injector<F>(mut self, injector: F) -> Self
806 where
807 F: Fn(&Request<Incoming>, &mut Bus) + Send + Sync + 'static,
808 {
809 self.bus_injectors.push(Arc::new(injector));
810 self
811 }
812
813 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
815 let mut descriptors = self
816 .routes
817 .iter()
818 .map(|entry| HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone()))
819 .collect::<Vec<_>>();
820
821 if let Some(path) = &self.health.health_path {
822 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
823 }
824 if let Some(path) = &self.health.readiness_path {
825 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
826 }
827 if let Some(path) = &self.health.liveness_path {
828 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
829 }
830
831 descriptors
832 }
833
834 pub fn serve_dir(
838 mut self,
839 route_prefix: impl Into<String>,
840 directory: impl Into<String>,
841 ) -> Self {
842 self.static_assets.mounts.push(StaticMount {
843 route_prefix: normalize_route_path(route_prefix.into()),
844 directory: directory.into(),
845 });
846 if self.static_assets.cache_control.is_none() {
847 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
848 }
849 self
850 }
851
852 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
856 self.static_assets.spa_fallback = Some(file_path.into());
857 self
858 }
859
860 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
862 self.static_assets.cache_control = Some(cache_control.into());
863 self
864 }
865
866 pub fn compression_layer(mut self) -> Self {
868 self.static_assets.enable_compression = true;
869 self
870 }
871
872 pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
879 where
880 H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
881 Fut: Future<Output = ()> + Send + 'static,
882 {
883 let path_str: String = path.into();
884 let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
885 Box::pin(handler(connection, resources, bus))
886 });
887 let bus_injectors = Arc::new(self.bus_injectors.clone());
888 let path_for_pattern = path_str.clone();
889 let path_for_handler = path_str;
890
891 let route_handler: RouteHandler<R> =
892 Arc::new(move |mut req: Request<Incoming>, res: &R| {
893 let ws_handler = ws_handler.clone();
894 let bus_injectors = bus_injectors.clone();
895 let resources = Arc::new(res.clone());
896 let path = path_for_handler.clone();
897
898 Box::pin(async move {
899 let request_id = uuid::Uuid::new_v4().to_string();
900 let span = tracing::info_span!(
901 "WebSocketUpgrade",
902 ranvier.ws.path = %path,
903 ranvier.ws.request_id = %request_id
904 );
905
906 async move {
907 let mut bus = Bus::new();
908 for injector in bus_injectors.iter() {
909 injector(&req, &mut bus);
910 }
911
912 let session = websocket_session_from_request(&req);
913 bus.insert(session.clone());
914
915 let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
916 Ok(result) => result,
917 Err(error_response) => return error_response,
918 };
919
920 tokio::spawn(async move {
921 match on_upgrade.await {
922 Ok(upgraded) => {
923 let stream = WebSocketStream::from_raw_socket(
924 TokioIo::new(upgraded),
925 tokio_tungstenite::tungstenite::protocol::Role::Server,
926 None,
927 )
928 .await;
929 let connection = WebSocketConnection::new(stream, session);
930 ws_handler(connection, resources, bus).await;
931 }
932 Err(error) => {
933 tracing::warn!(
934 ranvier.ws.path = %path,
935 ranvier.ws.error = %error,
936 "websocket upgrade failed"
937 );
938 }
939 }
940 });
941
942 response
943 }
944 .instrument(span)
945 .await
946 }) as Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
947 });
948
949 self.routes.push(RouteEntry {
950 method: Method::GET,
951 pattern: RoutePattern::parse(&path_for_pattern),
952 handler: route_handler,
953 layers: Arc::new(Vec::new()),
954 apply_global_layers: true,
955 });
956
957 self
958 }
959
960 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
965 self.health.health_path = Some(normalize_route_path(path.into()));
966 self
967 }
968
969 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
973 where
974 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
975 Fut: Future<Output = Result<(), Err>> + Send + 'static,
976 Err: ToString + Send + 'static,
977 {
978 if self.health.health_path.is_none() {
979 self.health.health_path = Some("/health".to_string());
980 }
981
982 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
983 let fut = check(resources);
984 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
985 });
986
987 self.health.checks.push(NamedHealthCheck {
988 name: name.into(),
989 check: check_fn,
990 });
991 self
992 }
993
994 pub fn readiness_liveness(
996 mut self,
997 readiness_path: impl Into<String>,
998 liveness_path: impl Into<String>,
999 ) -> Self {
1000 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1001 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1002 self
1003 }
1004
1005 pub fn readiness_liveness_default(self) -> Self {
1007 self.readiness_liveness("/ready", "/live")
1008 }
1009
1010 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1012 where
1013 Out: IntoResponse + Send + Sync + 'static,
1014 E: Send + 'static + std::fmt::Debug,
1015 {
1016 self.route_method(Method::GET, path, circuit)
1017 }
1018 pub fn route_method<Out, E>(
1027 self,
1028 method: Method,
1029 path: impl Into<String>,
1030 circuit: Axon<(), Out, E, R>,
1031 ) -> Self
1032 where
1033 Out: IntoResponse + Send + Sync + 'static,
1034 E: Send + 'static + std::fmt::Debug,
1035 {
1036 self.route_method_with_error(method, path, circuit, |error| {
1037 (
1038 StatusCode::INTERNAL_SERVER_ERROR,
1039 format!("Error: {:?}", error),
1040 )
1041 .into_response()
1042 })
1043 }
1044
1045 pub fn route_method_with_error<Out, E, H>(
1046 self,
1047 method: Method,
1048 path: impl Into<String>,
1049 circuit: Axon<(), Out, E, R>,
1050 error_handler: H,
1051 ) -> Self
1052 where
1053 Out: IntoResponse + Send + Sync + 'static,
1054 E: Send + 'static + std::fmt::Debug,
1055 H: Fn(&E) -> Response<Full<Bytes>> + Send + Sync + 'static,
1056 {
1057 self.route_method_with_error_and_layers(
1058 method,
1059 path,
1060 circuit,
1061 error_handler,
1062 Arc::new(Vec::new()),
1063 true,
1064 )
1065 }
1066
1067 pub fn route_method_with_layer<Out, E, L>(
1068 self,
1069 method: Method,
1070 path: impl Into<String>,
1071 circuit: Axon<(), Out, E, R>,
1072 layer: L,
1073 ) -> Self
1074 where
1075 Out: IntoResponse + Send + Sync + 'static,
1076 E: Send + 'static + std::fmt::Debug,
1077 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1078 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
1079 + Clone
1080 + Send
1081 + 'static,
1082 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1083 {
1084 self.route_method_with_error_and_layers(
1085 method,
1086 path,
1087 circuit,
1088 |error| {
1089 (
1090 StatusCode::INTERNAL_SERVER_ERROR,
1091 format!("Error: {:?}", error),
1092 )
1093 .into_response()
1094 },
1095 Arc::new(vec![to_service_layer(layer)]),
1096 true,
1097 )
1098 }
1099
1100 pub fn route_method_with_layer_override<Out, E, L>(
1101 self,
1102 method: Method,
1103 path: impl Into<String>,
1104 circuit: Axon<(), Out, E, R>,
1105 layer: L,
1106 ) -> Self
1107 where
1108 Out: IntoResponse + Send + Sync + 'static,
1109 E: Send + 'static + std::fmt::Debug,
1110 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1111 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
1112 + Clone
1113 + Send
1114 + 'static,
1115 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1116 {
1117 self.route_method_with_error_and_layers(
1118 method,
1119 path,
1120 circuit,
1121 |error| {
1122 (
1123 StatusCode::INTERNAL_SERVER_ERROR,
1124 format!("Error: {:?}", error),
1125 )
1126 .into_response()
1127 },
1128 Arc::new(vec![to_service_layer(layer)]),
1129 false,
1130 )
1131 }
1132
1133 fn route_method_with_error_and_layers<Out, E, H>(
1134 mut self,
1135 method: Method,
1136 path: impl Into<String>,
1137 circuit: Axon<(), Out, E, R>,
1138 error_handler: H,
1139 route_layers: Arc<Vec<ServiceLayer>>,
1140 apply_global_layers: bool,
1141 ) -> Self
1142 where
1143 Out: IntoResponse + Send + Sync + 'static,
1144 E: Send + 'static + std::fmt::Debug,
1145 H: Fn(&E) -> Response<Full<Bytes>> + Send + Sync + 'static,
1146 {
1147 let path_str: String = path.into();
1148 let circuit = Arc::new(circuit);
1149 let error_handler = Arc::new(error_handler);
1150 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1151 let path_for_pattern = path_str.clone();
1152 let path_for_handler = path_str;
1153 let method_for_pattern = method.clone();
1154 let method_for_handler = method;
1155
1156 let handler: RouteHandler<R> = Arc::new(move |req: Request<Incoming>, res: &R| {
1157 let circuit = circuit.clone();
1158 let error_handler = error_handler.clone();
1159 let route_bus_injectors = route_bus_injectors.clone();
1160 let res = res.clone();
1161 let path = path_for_handler.clone();
1162 let method = method_for_handler.clone();
1163
1164 Box::pin(async move {
1165 let request_id = uuid::Uuid::new_v4().to_string();
1166 let span = tracing::info_span!(
1167 "HTTPRequest",
1168 ranvier.http.method = %method,
1169 ranvier.http.path = %path,
1170 ranvier.http.request_id = %request_id
1171 );
1172
1173 async move {
1174 let mut bus = Bus::new();
1175 for injector in route_bus_injectors.iter() {
1176 injector(&req, &mut bus);
1177 }
1178 let result = circuit.execute((), &res, &mut bus).await;
1179 outcome_to_response_with_error(result, |error| error_handler(error))
1180 }
1181 .instrument(span)
1182 .await
1183 }) as Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
1184 });
1185
1186 self.routes.push(RouteEntry {
1187 method: method_for_pattern,
1188 pattern: RoutePattern::parse(&path_for_pattern),
1189 handler,
1190 layers: route_layers,
1191 apply_global_layers,
1192 });
1193 self
1194 }
1195
1196 fn route_method_with_body<Out, E>(
1198 mut self,
1199 method: Method,
1200 path: impl Into<String>,
1201 circuit: Axon<(), Out, E, R>,
1202 ) -> Self
1203 where
1204 Out: IntoResponse + Send + Sync + 'static,
1205 E: Send + 'static + std::fmt::Debug,
1206 {
1207 use crate::extract::HttpRequestBody;
1208
1209 let path_str: String = path.into();
1210 let circuit = Arc::new(circuit);
1211 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1212 let path_for_pattern = path_str.clone();
1213 let path_for_handler = path_str;
1214 let method_for_pattern = method.clone();
1215 let method_for_handler = method;
1216
1217 let handler: RouteHandler<R> = Arc::new(move |req: Request<Incoming>, res: &R| {
1218 let circuit = circuit.clone();
1219 let route_bus_injectors = route_bus_injectors.clone();
1220 let res = res.clone();
1221 let path = path_for_handler.clone();
1222 let method = method_for_handler.clone();
1223
1224 Box::pin(async move {
1225 let request_id = uuid::Uuid::new_v4().to_string();
1226 let span = tracing::info_span!(
1227 "HTTPRequest",
1228 ranvier.http.method = %method,
1229 ranvier.http.path = %path,
1230 ranvier.http.request_id = %request_id
1231 );
1232
1233 async move {
1234 let (parts, body) = req.into_parts();
1236
1237 let body_bytes = match body.collect().await {
1239 Ok(collected) => collected.to_bytes(),
1240 Err(err) => {
1241 tracing::warn!("Failed to collect request body: {:?}", err);
1242 Bytes::new()
1243 }
1244 };
1245
1246 let mut bus = Bus::new();
1247 let stub_req = {
1251 let mut builder = Request::builder()
1252 .method(&parts.method)
1253 .uri(parts.uri.clone());
1254 for (k, v) in &parts.headers {
1255 builder = builder.header(k, v);
1256 }
1257 let mut stub = builder
1258 .body(Full::new(Bytes::new()))
1259 .unwrap_or_else(|_| Request::new(Full::new(Bytes::new())));
1260 *stub.extensions_mut() = parts.extensions.clone();
1261 stub
1262 };
1263 if let Some(params) = stub_req.extensions().get::<PathParams>() {
1271 bus.insert(params.clone());
1272 }
1273
1274 bus.insert(HttpRequestBody::new(body_bytes));
1276
1277 let result = circuit.execute((), &res, &mut bus).await;
1278 outcome_to_response_with_error(result, |error| {
1279 Response::builder()
1280 .status(StatusCode::INTERNAL_SERVER_ERROR)
1281 .body(Full::new(Bytes::from(format!("Error: {:?}", error))))
1282 .unwrap()
1283 })
1284 }
1285 .instrument(span)
1286 .await
1287 }) as Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
1288 });
1289
1290 self.routes.push(RouteEntry {
1291 method: method_for_pattern,
1292 pattern: RoutePattern::parse(&path_for_pattern),
1293 handler,
1294 layers: Arc::new(Vec::new()),
1295 apply_global_layers: true,
1296 });
1297 self
1298 }
1299
1300 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1301 where
1302 Out: IntoResponse + Send + Sync + 'static,
1303 E: Send + 'static + std::fmt::Debug,
1304 {
1305 self.route_method(Method::GET, path, circuit)
1306 }
1307
1308 pub fn get_with_error<Out, E, H>(
1309 self,
1310 path: impl Into<String>,
1311 circuit: Axon<(), Out, E, R>,
1312 error_handler: H,
1313 ) -> Self
1314 where
1315 Out: IntoResponse + Send + Sync + 'static,
1316 E: Send + 'static + std::fmt::Debug,
1317 H: Fn(&E) -> Response<Full<Bytes>> + Send + Sync + 'static,
1318 {
1319 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1320 }
1321
1322 pub fn get_with_layer<Out, E, L>(
1323 self,
1324 path: impl Into<String>,
1325 circuit: Axon<(), Out, E, R>,
1326 layer: L,
1327 ) -> Self
1328 where
1329 Out: IntoResponse + Send + Sync + 'static,
1330 E: Send + 'static + std::fmt::Debug,
1331 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1332 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
1333 + Clone
1334 + Send
1335 + 'static,
1336 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1337 {
1338 self.route_method_with_layer(Method::GET, path, circuit, layer)
1339 }
1340
1341 pub fn get_with_layer_override<Out, E, L>(
1342 self,
1343 path: impl Into<String>,
1344 circuit: Axon<(), Out, E, R>,
1345 layer: L,
1346 ) -> Self
1347 where
1348 Out: IntoResponse + Send + Sync + 'static,
1349 E: Send + 'static + std::fmt::Debug,
1350 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1351 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
1352 + Clone
1353 + Send
1354 + 'static,
1355 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1356 {
1357 self.route_method_with_layer_override(Method::GET, path, circuit, layer)
1358 }
1359
1360 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1361 where
1362 Out: IntoResponse + Send + Sync + 'static,
1363 E: Send + 'static + std::fmt::Debug,
1364 {
1365 self.route_method(Method::POST, path, circuit)
1366 }
1367
1368 pub fn post_body<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1389 where
1390 Out: IntoResponse + Send + Sync + 'static,
1391 E: Send + 'static + std::fmt::Debug,
1392 {
1393 self.route_method_with_body(Method::POST, path, circuit)
1394 }
1395
1396 pub fn put_body<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1398 where
1399 Out: IntoResponse + Send + Sync + 'static,
1400 E: Send + 'static + std::fmt::Debug,
1401 {
1402 self.route_method_with_body(Method::PUT, path, circuit)
1403 }
1404
1405 pub fn patch_body<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1407 where
1408 Out: IntoResponse + Send + Sync + 'static,
1409 E: Send + 'static + std::fmt::Debug,
1410 {
1411 self.route_method_with_body(Method::PATCH, path, circuit)
1412 }
1413
1414 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1415 where
1416 Out: IntoResponse + Send + Sync + 'static,
1417 E: Send + 'static + std::fmt::Debug,
1418 {
1419 self.route_method(Method::PUT, path, circuit)
1420 }
1421
1422 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1423 where
1424 Out: IntoResponse + Send + Sync + 'static,
1425 E: Send + 'static + std::fmt::Debug,
1426 {
1427 self.route_method(Method::DELETE, path, circuit)
1428 }
1429
1430 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1431 where
1432 Out: IntoResponse + Send + Sync + 'static,
1433 E: Send + 'static + std::fmt::Debug,
1434 {
1435 self.route_method(Method::PATCH, path, circuit)
1436 }
1437
1438 pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
1449 where
1450 Out: IntoResponse + Send + Sync + 'static,
1451 E: Send + 'static + std::fmt::Debug,
1452 {
1453 let circuit = Arc::new(circuit);
1454 let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
1455
1456 let handler: RouteHandler<R> = Arc::new(move |req: Request<Incoming>, res: &R| {
1457 let circuit = circuit.clone();
1458 let fallback_bus_injectors = fallback_bus_injectors.clone();
1459 let res = res.clone();
1460 Box::pin(async move {
1461 let request_id = uuid::Uuid::new_v4().to_string();
1462 let span = tracing::info_span!(
1463 "HTTPRequest",
1464 ranvier.http.method = "FALLBACK",
1465 ranvier.http.request_id = %request_id
1466 );
1467
1468 async move {
1469 let mut bus = Bus::new();
1470 for injector in fallback_bus_injectors.iter() {
1471 injector(&req, &mut bus);
1472 }
1473 let result = circuit.execute((), &res, &mut bus).await;
1474
1475 match result {
1476 Outcome::Next(output) => {
1477 let mut response = output.into_response();
1478 *response.status_mut() = StatusCode::NOT_FOUND;
1479 response
1480 }
1481 _ => Response::builder()
1482 .status(StatusCode::NOT_FOUND)
1483 .body(Full::new(Bytes::from("Not Found")))
1484 .unwrap(),
1485 }
1486 }
1487 .instrument(span)
1488 .await
1489 }) as Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
1490 });
1491
1492 self.fallback = Some(handler);
1493 self
1494 }
1495
1496 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1498 self.run_with_shutdown_signal(resources, shutdown_signal())
1499 .await
1500 }
1501
1502 async fn run_with_shutdown_signal<S>(
1503 self,
1504 resources: R,
1505 shutdown_signal: S,
1506 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
1507 where
1508 S: Future<Output = ()> + Send,
1509 {
1510 let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
1511 let addr: SocketAddr = addr_str.parse()?;
1512
1513 let routes = Arc::new(self.routes);
1514 let fallback = self.fallback;
1515 let layers = Arc::new(self.layers);
1516 let health = Arc::new(self.health);
1517 let static_assets = Arc::new(self.static_assets);
1518 let on_start = self.on_start;
1519 let on_shutdown = self.on_shutdown;
1520 let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
1521 let resources = Arc::new(resources);
1522
1523 let listener = TcpListener::bind(addr).await?;
1524 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
1525 if let Some(callback) = on_start.as_ref() {
1526 callback();
1527 }
1528
1529 tokio::pin!(shutdown_signal);
1530 let mut connections = tokio::task::JoinSet::new();
1531
1532 loop {
1533 tokio::select! {
1534 _ = &mut shutdown_signal => {
1535 tracing::info!("Shutdown signal received. Draining in-flight connections.");
1536 break;
1537 }
1538 accept_result = listener.accept() => {
1539 let (stream, _) = accept_result?;
1540 let io = TokioIo::new(stream);
1541
1542 let routes = routes.clone();
1543 let fallback = fallback.clone();
1544 let resources = resources.clone();
1545 let layers = layers.clone();
1546 let health = health.clone();
1547 let static_assets = static_assets.clone();
1548
1549 connections.spawn(async move {
1550 let service = build_http_service(
1551 routes,
1552 fallback,
1553 resources,
1554 layers,
1555 health,
1556 static_assets,
1557 );
1558 let hyper_service = TowerToHyperService::new(service);
1559 if let Err(err) = http1::Builder::new()
1560 .serve_connection(io, hyper_service)
1561 .with_upgrades()
1562 .await
1563 {
1564 tracing::error!("Error serving connection: {:?}", err);
1565 }
1566 });
1567 }
1568 Some(join_result) = connections.join_next(), if !connections.is_empty() => {
1569 if let Err(err) = join_result {
1570 tracing::warn!("Connection task join error: {:?}", err);
1571 }
1572 }
1573 }
1574 }
1575
1576 let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;
1577
1578 drop(resources);
1579 if let Some(callback) = on_shutdown.as_ref() {
1580 callback();
1581 }
1582
1583 Ok(())
1584 }
1585
1586 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
1602 let routes = Arc::new(self.routes);
1603 let fallback = self.fallback;
1604 let layers = Arc::new(self.layers);
1605 let health = Arc::new(self.health);
1606 let static_assets = Arc::new(self.static_assets);
1607 let resources = Arc::new(resources);
1608
1609 RawIngressService {
1610 routes,
1611 fallback,
1612 layers,
1613 health,
1614 static_assets,
1615 resources,
1616 }
1617 }
1618}
1619
1620fn build_http_service<R>(
1621 routes: Arc<Vec<RouteEntry<R>>>,
1622 fallback: Option<RouteHandler<R>>,
1623 resources: Arc<R>,
1624 layers: Arc<Vec<ServiceLayer>>,
1625 health: Arc<HealthConfig<R>>,
1626 static_assets: Arc<StaticAssetsConfig>,
1627) -> BoxHttpService
1628where
1629 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1630{
1631 let base_service = service_fn(move |req: Request<Incoming>| {
1632 let routes = routes.clone();
1633 let fallback = fallback.clone();
1634 let resources = resources.clone();
1635 let layers = layers.clone();
1636 let health = health.clone();
1637 let static_assets = static_assets.clone();
1638
1639 async move {
1640 let mut req = req;
1641 let method = req.method().clone();
1642 let path = req.uri().path().to_string();
1643
1644 if let Some(response) =
1645 maybe_handle_health_request(&method, &path, &health, resources.clone()).await
1646 {
1647 return Ok::<_, Infallible>(response);
1648 }
1649
1650 if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
1651 req.extensions_mut().insert(params);
1652 let effective_layers = if entry.apply_global_layers {
1653 merge_layers(&layers, &entry.layers)
1654 } else {
1655 entry.layers.clone()
1656 };
1657
1658 if effective_layers.is_empty() {
1659 Ok::<_, Infallible>((entry.handler)(req, &resources).await)
1660 } else {
1661 let route_service = build_route_service(
1662 entry.handler.clone(),
1663 resources.clone(),
1664 effective_layers,
1665 );
1666 route_service.oneshot(req).await
1667 }
1668 } else {
1669 let req =
1670 match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
1671 .await
1672 {
1673 Ok(req) => req,
1674 Err(response) => return Ok(response),
1675 };
1676
1677 if let Some(ref fb) = fallback {
1678 if layers.is_empty() {
1679 Ok(fb(req, &resources).await)
1680 } else {
1681 let fallback_service =
1682 build_route_service(fb.clone(), resources.clone(), layers.clone());
1683 fallback_service.oneshot(req).await
1684 }
1685 } else {
1686 Ok(Response::builder()
1687 .status(StatusCode::NOT_FOUND)
1688 .body(Full::new(Bytes::from("Not Found")))
1689 .unwrap())
1690 }
1691 }
1692 }
1693 });
1694
1695 BoxCloneService::new(base_service)
1696}
1697
1698fn build_route_service<R>(
1699 handler: RouteHandler<R>,
1700 resources: Arc<R>,
1701 layers: Arc<Vec<ServiceLayer>>,
1702) -> BoxHttpService
1703where
1704 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1705{
1706 let base_service = service_fn(move |req: Request<Incoming>| {
1707 let handler = handler.clone();
1708 let resources = resources.clone();
1709 async move { Ok::<_, Infallible>(handler(req, &resources).await) }
1710 });
1711
1712 let mut service = BoxCloneService::new(base_service);
1713 for layer in layers.iter() {
1714 service = layer(service);
1715 }
1716 service
1717}
1718
1719fn merge_layers(
1720 global_layers: &Arc<Vec<ServiceLayer>>,
1721 route_layers: &Arc<Vec<ServiceLayer>>,
1722) -> Arc<Vec<ServiceLayer>> {
1723 if global_layers.is_empty() {
1724 return route_layers.clone();
1725 }
1726 if route_layers.is_empty() {
1727 return global_layers.clone();
1728 }
1729
1730 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
1731 combined.extend(global_layers.iter().cloned());
1732 combined.extend(route_layers.iter().cloned());
1733 Arc::new(combined)
1734}
1735
1736async fn maybe_handle_health_request<R>(
1737 method: &Method,
1738 path: &str,
1739 health: &HealthConfig<R>,
1740 resources: Arc<R>,
1741) -> Option<Response<Full<Bytes>>>
1742where
1743 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1744{
1745 if method != Method::GET {
1746 return None;
1747 }
1748
1749 if let Some(liveness_path) = health.liveness_path.as_ref() {
1750 if path == liveness_path {
1751 return Some(health_json_response("liveness", true, Vec::new()));
1752 }
1753 }
1754
1755 if let Some(readiness_path) = health.readiness_path.as_ref() {
1756 if path == readiness_path {
1757 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1758 return Some(health_json_response("readiness", healthy, checks));
1759 }
1760 }
1761
1762 if let Some(health_path) = health.health_path.as_ref() {
1763 if path == health_path {
1764 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
1765 return Some(health_json_response("health", healthy, checks));
1766 }
1767 }
1768
1769 None
1770}
1771
1772async fn maybe_handle_static_request(
1773 req: Request<Incoming>,
1774 method: &Method,
1775 path: &str,
1776 static_assets: &StaticAssetsConfig,
1777) -> Result<Request<Incoming>, Response<Full<Bytes>>> {
1778 if method != Method::GET && method != Method::HEAD {
1779 return Ok(req);
1780 }
1781
1782 if let Some(mount) = static_assets
1783 .mounts
1784 .iter()
1785 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
1786 {
1787 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1788 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
1789 return Ok(req);
1790 };
1791 let rewritten = rewrite_request_path(req, &stripped_path);
1792 let service = ServeDir::new(&mount.directory);
1793 let response = match service.oneshot(rewritten).await {
1794 Ok(response) => response,
1795 Err(_) => {
1796 return Err(Response::builder()
1797 .status(StatusCode::INTERNAL_SERVER_ERROR)
1798 .body(Full::new(Bytes::from("Failed to serve static asset")))
1799 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new()))));
1800 }
1801 };
1802 let response =
1803 collect_static_response(response, static_assets.cache_control.as_deref()).await;
1804 return Err(maybe_compress_static_response(
1805 response,
1806 accept_encoding,
1807 static_assets.enable_compression,
1808 )
1809 .await);
1810 }
1811
1812 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
1813 if looks_like_spa_request(path) {
1814 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
1815 let service = ServeFile::new(spa_file);
1816 let response = match service.oneshot(req).await {
1817 Ok(response) => response,
1818 Err(_) => {
1819 return Err(Response::builder()
1820 .status(StatusCode::INTERNAL_SERVER_ERROR)
1821 .body(Full::new(Bytes::from("Failed to serve SPA fallback")))
1822 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new()))));
1823 }
1824 };
1825 let response =
1826 collect_static_response(response, static_assets.cache_control.as_deref()).await;
1827 return Err(maybe_compress_static_response(
1828 response,
1829 accept_encoding,
1830 static_assets.enable_compression,
1831 )
1832 .await);
1833 }
1834 }
1835
1836 Ok(req)
1837}
1838
/// Maps a request `path` into the namespace of a mount registered at `prefix`.
///
/// Trailing slashes on `prefix` are ignored. A `/` prefix passes `path` through
/// unchanged. An exact match of the prefix yields `/`. A path under the prefix yields the
/// remainder with a leading `/`. Anything else (including `/staticfoo` for prefix
/// `/static`) yields `None`.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    if prefix == "/" {
        return Some(path.to_owned());
    }

    let base = prefix.trim_end_matches('/');
    if path == base {
        return Some(String::from("/"));
    }

    path.strip_prefix(base)
        .and_then(|rest| rest.strip_prefix('/'))
        .map(|tail| format!("/{tail}"))
}
1858
1859fn rewrite_request_path(mut req: Request<Incoming>, new_path: &str) -> Request<Incoming> {
1860 let query = req.uri().query().map(str::to_string);
1861 let path_and_query = match query {
1862 Some(query) => format!("{new_path}?{query}"),
1863 None => new_path.to_string(),
1864 };
1865
1866 let mut parts = req.uri().clone().into_parts();
1867 if let Ok(parsed_path_and_query) = path_and_query.parse() {
1868 parts.path_and_query = Some(parsed_path_and_query);
1869 if let Ok(uri) = Uri::from_parts(parts) {
1870 *req.uri_mut() = uri;
1871 }
1872 }
1873
1874 req
1875}
1876
1877async fn collect_static_response<B>(
1878 response: Response<B>,
1879 cache_control: Option<&str>,
1880) -> Response<Full<Bytes>>
1881where
1882 B: Body<Data = Bytes> + Send + 'static,
1883 B::Error: std::fmt::Display,
1884{
1885 let status = response.status();
1886 let headers = response.headers().clone();
1887 let body = response.into_body();
1888 let collected = body.collect().await;
1889
1890 let bytes = match collected {
1891 Ok(value) => value.to_bytes(),
1892 Err(error) => Bytes::from(error.to_string()),
1893 };
1894
1895 let mut builder = Response::builder().status(status);
1896 for (name, value) in headers.iter() {
1897 builder = builder.header(name, value);
1898 }
1899
1900 let mut response = builder
1901 .body(Full::new(bytes))
1902 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())));
1903
1904 if status == StatusCode::OK {
1905 if let Some(value) = cache_control {
1906 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
1907 if let Ok(header_value) = http::HeaderValue::from_str(value) {
1908 response
1909 .headers_mut()
1910 .insert(http::header::CACHE_CONTROL, header_value);
1911 }
1912 }
1913 }
1914 }
1915
1916 response
1917}
1918
/// Returns `true` when the final path segment has no `.`.
///
/// Such a path looks like a client-side route (`/app/settings`) rather than a file
/// request (`/app.js`).
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = match path.rfind('/') {
        Some(slash) => &path[slash + 1..],
        None => path,
    };
    !last_segment.contains('.')
}
1923
1924async fn maybe_compress_static_response(
1925 response: Response<Full<Bytes>>,
1926 accept_encoding: Option<http::HeaderValue>,
1927 enable_compression: bool,
1928) -> Response<Full<Bytes>> {
1929 if !enable_compression {
1930 return response;
1931 }
1932
1933 let Some(accept_encoding) = accept_encoding else {
1934 return response;
1935 };
1936
1937 let mut request = Request::builder()
1938 .uri("/")
1939 .body(Full::new(Bytes::new()))
1940 .unwrap_or_else(|_| Request::new(Full::new(Bytes::new())));
1941 request
1942 .headers_mut()
1943 .insert(http::header::ACCEPT_ENCODING, accept_encoding);
1944
1945 let service = CompressionLayer::new().layer(service_fn({
1946 let response = response.clone();
1947 move |_req: Request<Full<Bytes>>| {
1948 let response = response.clone();
1949 async move { Ok::<_, Infallible>(response) }
1950 }
1951 }));
1952
1953 match service.oneshot(request).await {
1954 Ok(compressed) => collect_static_response(compressed, None).await,
1955 Err(_) => response,
1956 }
1957}
1958
1959async fn run_named_health_checks<R>(
1960 checks: &[NamedHealthCheck<R>],
1961 resources: Arc<R>,
1962) -> (bool, Vec<HealthCheckReport>)
1963where
1964 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1965{
1966 let mut reports = Vec::with_capacity(checks.len());
1967 let mut healthy = true;
1968
1969 for check in checks {
1970 match (check.check)(resources.clone()).await {
1971 Ok(()) => reports.push(HealthCheckReport {
1972 name: check.name.clone(),
1973 status: "ok",
1974 error: None,
1975 }),
1976 Err(error) => {
1977 healthy = false;
1978 reports.push(HealthCheckReport {
1979 name: check.name.clone(),
1980 status: "error",
1981 error: Some(error),
1982 });
1983 }
1984 }
1985 }
1986
1987 (healthy, reports)
1988}
1989
1990fn health_json_response(
1991 probe: &'static str,
1992 healthy: bool,
1993 checks: Vec<HealthCheckReport>,
1994) -> Response<Full<Bytes>> {
1995 let status_code = if healthy {
1996 StatusCode::OK
1997 } else {
1998 StatusCode::SERVICE_UNAVAILABLE
1999 };
2000 let status = if healthy { "ok" } else { "degraded" };
2001 let payload = HealthReport {
2002 status,
2003 probe,
2004 checks,
2005 };
2006
2007 let body = serde_json::to_vec(&payload)
2008 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
2009
2010 Response::builder()
2011 .status(status_code)
2012 .header(http::header::CONTENT_TYPE, "application/json")
2013 .body(Full::new(Bytes::from(body)))
2014 .unwrap()
2015}
2016
2017async fn shutdown_signal() {
2018 #[cfg(unix)]
2019 {
2020 use tokio::signal::unix::{SignalKind, signal};
2021
2022 match signal(SignalKind::terminate()) {
2023 Ok(mut terminate) => {
2024 tokio::select! {
2025 _ = tokio::signal::ctrl_c() => {}
2026 _ = terminate.recv() => {}
2027 }
2028 }
2029 Err(err) => {
2030 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
2031 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
2032 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
2033 }
2034 }
2035 }
2036 }
2037
2038 #[cfg(not(unix))]
2039 {
2040 if let Err(err) = tokio::signal::ctrl_c().await {
2041 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
2042 }
2043 }
2044}
2045
2046async fn drain_connections(
2047 connections: &mut tokio::task::JoinSet<()>,
2048 graceful_shutdown_timeout: Duration,
2049) -> bool {
2050 if connections.is_empty() {
2051 return false;
2052 }
2053
2054 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
2055 while let Some(join_result) = connections.join_next().await {
2056 if let Err(err) = join_result {
2057 tracing::warn!("Connection task join error during shutdown: {:?}", err);
2058 }
2059 }
2060 })
2061 .await;
2062
2063 if drain_result.is_err() {
2064 tracing::warn!(
2065 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
2066 graceful_shutdown_timeout
2067 );
2068 connections.abort_all();
2069 while let Some(join_result) = connections.join_next().await {
2070 if let Err(err) = join_result {
2071 tracing::warn!("Connection task abort join error: {:?}", err);
2072 }
2073 }
2074 true
2075 } else {
2076 false
2077 }
2078}
2079
2080impl<R> Default for HttpIngress<R>
2081where
2082 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2083{
2084 fn default() -> Self {
2085 Self::new()
2086 }
2087}
2088
2089#[deprecated(since = "0.9.0", note = "Internal service type")]
2091#[derive(Clone)]
2092pub struct RawIngressService<R> {
2093 routes: Arc<Vec<RouteEntry<R>>>,
2094 fallback: Option<RouteHandler<R>>,
2095 layers: Arc<Vec<ServiceLayer>>,
2096 health: Arc<HealthConfig<R>>,
2097 static_assets: Arc<StaticAssetsConfig>,
2098 resources: Arc<R>,
2099}
2100
2101impl<R> Service<Request<Incoming>> for RawIngressService<R>
2102where
2103 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2104{
2105 type Response = Response<Full<Bytes>>;
2106 type Error = Infallible;
2107 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2108
2109 fn poll_ready(
2110 &mut self,
2111 _cx: &mut std::task::Context<'_>,
2112 ) -> std::task::Poll<Result<(), Self::Error>> {
2113 std::task::Poll::Ready(Ok(()))
2114 }
2115
2116 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
2117 let routes = self.routes.clone();
2118 let fallback = self.fallback.clone();
2119 let layers = self.layers.clone();
2120 let health = self.health.clone();
2121 let static_assets = self.static_assets.clone();
2122 let resources = self.resources.clone();
2123
2124 Box::pin(async move {
2125 let service =
2126 build_http_service(routes, fallback, resources, layers, health, static_assets);
2127 service.oneshot(req).await
2128 })
2129 }
2130}
2131
2132#[cfg(test)]
2133mod tests {
2134 use super::*;
2135 use async_trait::async_trait;
2136 use futures_util::{SinkExt, StreamExt};
2137 use ranvier_observe::{HttpMetrics, HttpMetricsLayer, IncomingTraceContext, TraceContextLayer};
2138 use serde::Deserialize;
2139 use std::fs;
2140 use std::sync::atomic::{AtomicBool, Ordering};
2141 use tempfile::tempdir;
2142 use tokio::io::{AsyncReadExt, AsyncWriteExt};
2143 use tokio_tungstenite::tungstenite::Message as WsClientMessage;
2144 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
2145
2146 async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
2147 let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
2148
2149 loop {
2150 match tokio::net::TcpStream::connect(addr).await {
2151 Ok(stream) => return stream,
2152 Err(error) => {
2153 if tokio::time::Instant::now() >= deadline {
2154 panic!("connect server: {error}");
2155 }
2156 tokio::time::sleep(Duration::from_millis(25)).await;
2157 }
2158 }
2159 }
2160 }
2161
2162 #[test]
2163 fn route_pattern_matches_static_path() {
2164 let pattern = RoutePattern::parse("/orders/list");
2165 let params = pattern.match_path("/orders/list").expect("should match");
2166 assert!(params.into_inner().is_empty());
2167 }
2168
2169 #[test]
2170 fn route_pattern_matches_param_segments() {
2171 let pattern = RoutePattern::parse("/orders/:id/items/:item_id");
2172 let params = pattern
2173 .match_path("/orders/42/items/sku-123")
2174 .expect("should match");
2175 assert_eq!(params.get("id"), Some("42"));
2176 assert_eq!(params.get("item_id"), Some("sku-123"));
2177 }
2178
2179 #[test]
2180 fn route_pattern_matches_wildcard_segment() {
2181 let pattern = RoutePattern::parse("/assets/*path");
2182 let params = pattern
2183 .match_path("/assets/css/theme/light.css")
2184 .expect("should match");
2185 assert_eq!(params.get("path"), Some("css/theme/light.css"));
2186 }
2187
2188 #[test]
2189 fn route_pattern_rejects_non_matching_path() {
2190 let pattern = RoutePattern::parse("/orders/:id");
2191 assert!(pattern.match_path("/users/42").is_none());
2192 }
2193
2194 #[test]
2195 fn graceful_shutdown_timeout_defaults_to_30_seconds() {
2196 let ingress = HttpIngress::<()>::new();
2197 assert_eq!(ingress.graceful_shutdown_timeout, Duration::from_secs(30));
2198 assert!(ingress.layers.is_empty());
2199 assert!(ingress.bus_injectors.is_empty());
2200 assert!(ingress.static_assets.mounts.is_empty());
2201 assert!(ingress.on_start.is_none());
2202 assert!(ingress.on_shutdown.is_none());
2203 }
2204
2205 #[test]
2206 fn layer_registration_stacks_globally() {
2207 let ingress = HttpIngress::<()>::new()
2208 .layer(tower::layer::util::Identity::new())
2209 .layer(tower::layer::util::Identity::new());
2210 assert_eq!(ingress.layers.len(), 2);
2211 }
2212
2213 #[test]
2214 fn layer_accepts_tower_http_cors_layer() {
2215 let ingress = HttpIngress::<()>::new().layer(tower_http::cors::CorsLayer::permissive());
2216 assert_eq!(ingress.layers.len(), 1);
2217 }
2218
2219 #[test]
2220 fn route_without_layer_keeps_empty_route_middleware_stack() {
2221 let ingress =
2222 HttpIngress::<()>::new().get("/ping", Axon::<(), (), Infallible, ()>::new("Ping"));
2223 assert_eq!(ingress.routes.len(), 1);
2224 assert!(ingress.routes[0].layers.is_empty());
2225 assert!(ingress.routes[0].apply_global_layers);
2226 }
2227
2228 #[test]
2229 fn route_with_layer_registers_route_middleware_stack() {
2230 let ingress = HttpIngress::<()>::new().get_with_layer(
2231 "/ping",
2232 Axon::<(), (), Infallible, ()>::new("Ping"),
2233 tower::layer::util::Identity::new(),
2234 );
2235 assert_eq!(ingress.routes.len(), 1);
2236 assert_eq!(ingress.routes[0].layers.len(), 1);
2237 assert!(ingress.routes[0].apply_global_layers);
2238 }
2239
2240 #[test]
2241 fn route_with_layer_override_disables_global_layers() {
2242 let ingress = HttpIngress::<()>::new().get_with_layer_override(
2243 "/ping",
2244 Axon::<(), (), Infallible, ()>::new("Ping"),
2245 tower::layer::util::Identity::new(),
2246 );
2247 assert_eq!(ingress.routes.len(), 1);
2248 assert_eq!(ingress.routes[0].layers.len(), 1);
2249 assert!(!ingress.routes[0].apply_global_layers);
2250 }
2251
2252 #[test]
2253 fn timeout_layer_registers_builtin_middleware() {
2254 let ingress = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
2255 assert_eq!(ingress.layers.len(), 1);
2256 }
2257
2258 #[test]
2259 fn request_id_layer_registers_builtin_middleware() {
2260 let ingress = HttpIngress::<()>::new().request_id_layer();
2261 assert_eq!(ingress.layers.len(), 1);
2262 }
2263
2264 #[test]
2265 fn compression_layer_registers_builtin_middleware() {
2266 let ingress = HttpIngress::<()>::new().compression_layer();
2267 assert!(ingress.static_assets.enable_compression);
2268 }
2269
2270 #[test]
2271 fn bus_injector_registration_adds_hook() {
2272 let ingress = HttpIngress::<()>::new().bus_injector(|_req, bus| {
2273 bus.insert("ok".to_string());
2274 });
2275 assert_eq!(ingress.bus_injectors.len(), 1);
2276 }
2277
2278 #[test]
2279 fn ws_route_registers_get_route_pattern() {
2280 let ingress =
2281 HttpIngress::<()>::new().ws("/ws/events", |_socket, _resources, _bus| async {});
2282 assert_eq!(ingress.routes.len(), 1);
2283 assert_eq!(ingress.routes[0].method, Method::GET);
2284 assert_eq!(ingress.routes[0].pattern.raw, "/ws/events");
2285 }
2286
2287 #[derive(Debug, Deserialize)]
2288 struct WsWelcomeFrame {
2289 connection_id: String,
2290 path: String,
2291 tenant: String,
2292 }
2293
2294 #[tokio::test]
2295 async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
2296 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2297 let addr = probe.local_addr().expect("local addr");
2298 drop(probe);
2299
2300 let ingress = HttpIngress::<()>::new()
2301 .bind(addr.to_string())
2302 .bus_injector(|req, bus| {
2303 if let Some(value) = req
2304 .headers()
2305 .get("x-tenant-id")
2306 .and_then(|v| v.to_str().ok())
2307 {
2308 bus.insert(value.to_string());
2309 }
2310 })
2311 .ws("/ws/echo", |mut socket, _resources, bus| async move {
2312 let tenant = bus
2313 .read::<String>()
2314 .cloned()
2315 .unwrap_or_else(|| "unknown".to_string());
2316 if let Some(session) = bus.read::<WebSocketSessionContext>() {
2317 let welcome = serde_json::json!({
2318 "connection_id": session.connection_id().to_string(),
2319 "path": session.path(),
2320 "tenant": tenant,
2321 });
2322 let _ = socket.send_json(&welcome).await;
2323 }
2324
2325 while let Some(event) = socket.next_event().await {
2326 match event {
2327 WebSocketEvent::Text(text) => {
2328 let _ = socket.send_event(format!("echo:{text}")).await;
2329 }
2330 WebSocketEvent::Binary(bytes) => {
2331 let _ = socket.send_event(bytes).await;
2332 }
2333 WebSocketEvent::Close => break,
2334 WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
2335 }
2336 }
2337 });
2338
2339 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2340 let server = tokio::spawn(async move {
2341 ingress
2342 .run_with_shutdown_signal((), async move {
2343 let _ = shutdown_rx.await;
2344 })
2345 .await
2346 });
2347
2348 let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
2349 let mut ws_request = ws_uri
2350 .as_str()
2351 .into_client_request()
2352 .expect("ws client request");
2353 ws_request
2354 .headers_mut()
2355 .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
2356 let (mut client, _response) = tokio_tungstenite::connect_async(ws_request)
2357 .await
2358 .expect("websocket connect");
2359
2360 let welcome = client
2361 .next()
2362 .await
2363 .expect("welcome frame")
2364 .expect("welcome frame ok");
2365 let welcome_text = match welcome {
2366 WsClientMessage::Text(text) => text.to_string(),
2367 other => panic!("expected text welcome frame, got {other:?}"),
2368 };
2369 let welcome_payload: WsWelcomeFrame =
2370 serde_json::from_str(&welcome_text).expect("welcome json");
2371 assert_eq!(welcome_payload.path, "/ws/echo");
2372 assert_eq!(welcome_payload.tenant, "acme");
2373 assert!(!welcome_payload.connection_id.is_empty());
2374
2375 client
2376 .send(WsClientMessage::Text("hello".into()))
2377 .await
2378 .expect("send text");
2379 let echo_text = client
2380 .next()
2381 .await
2382 .expect("echo text frame")
2383 .expect("echo text frame ok");
2384 assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));
2385
2386 client
2387 .send(WsClientMessage::Binary(vec![1, 2, 3, 4].into()))
2388 .await
2389 .expect("send binary");
2390 let echo_binary = client
2391 .next()
2392 .await
2393 .expect("echo binary frame")
2394 .expect("echo binary frame ok");
2395 assert_eq!(
2396 echo_binary,
2397 WsClientMessage::Binary(vec![1, 2, 3, 4].into())
2398 );
2399
2400 client.close(None).await.expect("close websocket");
2401
2402 let _ = shutdown_tx.send(());
2403 server
2404 .await
2405 .expect("server join")
2406 .expect("server shutdown should succeed");
2407 }
2408
2409 #[derive(Clone)]
2410 struct EchoTrace;
2411
2412 #[async_trait]
2413 impl Transition<(), String> for EchoTrace {
2414 type Error = Infallible;
2415 type Resources = ();
2416
2417 async fn run(
2418 &self,
2419 _state: (),
2420 _resources: &Self::Resources,
2421 bus: &mut Bus,
2422 ) -> Outcome<String, Self::Error> {
2423 let trace_id = bus
2424 .read::<String>()
2425 .cloned()
2426 .unwrap_or_else(|| "missing-trace".to_string());
2427 Outcome::next(trace_id)
2428 }
2429 }
2430
2431 #[tokio::test]
2432 async fn observe_trace_context_and_metrics_layers_work_with_ingress() {
2433 let metrics = HttpMetrics::default();
2434 let ingress = HttpIngress::<()>::new()
2435 .layer(TraceContextLayer::new())
2436 .layer(HttpMetricsLayer::new(metrics.clone()))
2437 .bus_injector(|req, bus| {
2438 if let Some(trace) = req.extensions().get::<IncomingTraceContext>() {
2439 bus.insert(trace.trace_id().to_string());
2440 }
2441 })
2442 .get(
2443 "/trace",
2444 Axon::<(), (), Infallible, ()>::new("EchoTrace").then(EchoTrace),
2445 );
2446
2447 let app = crate::test_harness::TestApp::new(ingress, ());
2448 let response = app
2449 .send(crate::test_harness::TestRequest::get("/trace").header(
2450 "traceparent",
2451 "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
2452 ))
2453 .await
2454 .expect("request should succeed");
2455
2456 assert_eq!(response.status(), StatusCode::OK);
2457 assert_eq!(
2458 response.text().expect("utf8 response"),
2459 "4bf92f3577b34da6a3ce929d0e0e4736"
2460 );
2461
2462 let snapshot = metrics.snapshot();
2463 assert_eq!(snapshot.requests_total, 1);
2464 assert_eq!(snapshot.requests_error, 0);
2465 }
2466
/// Checks that `route_descriptors()` reports both the user route and the
/// health/readiness/liveness endpoints as `GET` descriptors.
#[test]
fn route_descriptors_export_http_and_health_paths() {
    let ingress = HttpIngress::<()>::new()
        .get(
            "/orders/:id",
            Axon::<(), (), Infallible, ()>::new("OrderById"),
        )
        .health_endpoint("/healthz")
        .readiness_liveness("/readyz", "/livez");

    let descriptors = ingress.route_descriptors();

    // Every expected path must be exported as GET. The message names the missing
    // path so a failure is actionable without re-running under a debugger.
    for expected in ["/orders/:id", "/healthz", "/readyz", "/livez"] {
        assert!(
            descriptors
                .iter()
                .any(|descriptor| descriptor.method() == Method::GET
                    && descriptor.path_pattern() == expected),
            "expected a GET route descriptor for `{expected}`"
        );
    }
}
2504
2505 #[tokio::test]
2506 async fn lifecycle_hooks_fire_on_start_and_shutdown() {
2507 let started = Arc::new(AtomicBool::new(false));
2508 let shutdown = Arc::new(AtomicBool::new(false));
2509
2510 let started_flag = started.clone();
2511 let shutdown_flag = shutdown.clone();
2512
2513 let ingress = HttpIngress::<()>::new()
2514 .bind("127.0.0.1:0")
2515 .on_start(move || {
2516 started_flag.store(true, Ordering::SeqCst);
2517 })
2518 .on_shutdown(move || {
2519 shutdown_flag.store(true, Ordering::SeqCst);
2520 })
2521 .graceful_shutdown(Duration::from_millis(50));
2522
2523 ingress
2524 .run_with_shutdown_signal((), async {
2525 tokio::time::sleep(Duration::from_millis(20)).await;
2526 })
2527 .await
2528 .expect("server should exit gracefully");
2529
2530 assert!(started.load(Ordering::SeqCst));
2531 assert!(shutdown.load(Ordering::SeqCst));
2532 }
2533
2534 #[tokio::test]
2535 async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
2536 #[derive(Clone)]
2537 struct SlowDrainRoute;
2538
2539 #[async_trait]
2540 impl Transition<(), &'static str> for SlowDrainRoute {
2541 type Error = Infallible;
2542 type Resources = ();
2543
2544 async fn run(
2545 &self,
2546 _state: (),
2547 _resources: &Self::Resources,
2548 _bus: &mut Bus,
2549 ) -> Outcome<&'static str, Self::Error> {
2550 tokio::time::sleep(Duration::from_millis(120)).await;
2551 Outcome::next("drained-ok")
2552 }
2553 }
2554
2555 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2556 let addr = probe.local_addr().expect("local addr");
2557 drop(probe);
2558
2559 let ingress = HttpIngress::<()>::new()
2560 .bind(addr.to_string())
2561 .graceful_shutdown(Duration::from_millis(500))
2562 .get(
2563 "/drain",
2564 Axon::<(), (), Infallible, ()>::new("SlowDrain").then(SlowDrainRoute),
2565 );
2566
2567 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2568 let server = tokio::spawn(async move {
2569 ingress
2570 .run_with_shutdown_signal((), async move {
2571 let _ = shutdown_rx.await;
2572 })
2573 .await
2574 });
2575
2576 let mut stream = connect_with_retry(addr).await;
2577 stream
2578 .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2579 .await
2580 .expect("write request");
2581
2582 tokio::time::sleep(Duration::from_millis(20)).await;
2583 let _ = shutdown_tx.send(());
2584
2585 let mut buf = Vec::new();
2586 stream.read_to_end(&mut buf).await.expect("read response");
2587 let response = String::from_utf8_lossy(&buf);
2588 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
2589 assert!(response.contains("drained-ok"), "{response}");
2590
2591 server
2592 .await
2593 .expect("server join")
2594 .expect("server shutdown should succeed");
2595 }
2596
2597 #[tokio::test]
2598 async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
2599 let temp = tempdir().expect("tempdir");
2600 let root = temp.path().join("public");
2601 fs::create_dir_all(&root).expect("create dir");
2602 let file = root.join("hello.txt");
2603 fs::write(&file, "hello static").expect("write file");
2604
2605 let ingress =
2606 Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
2607 let app = crate::test_harness::TestApp::new(ingress, ());
2608 let response = app
2609 .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
2610 .await
2611 .expect("request should succeed");
2612
2613 assert_eq!(response.status(), StatusCode::OK);
2614 assert_eq!(response.text().expect("utf8"), "hello static");
2615 assert!(response.header("cache-control").is_some());
2616 let has_metadata_header =
2617 response.header("etag").is_some() || response.header("last-modified").is_some();
2618 assert!(has_metadata_header);
2619 }
2620
2621 #[tokio::test]
2622 async fn spa_fallback_returns_index_for_unmatched_path() {
2623 let temp = tempdir().expect("tempdir");
2624 let index = temp.path().join("index.html");
2625 fs::write(&index, "<html><body>spa</body></html>").expect("write index");
2626
2627 let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
2628 let app = crate::test_harness::TestApp::new(ingress, ());
2629 let response = app
2630 .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
2631 .await
2632 .expect("request should succeed");
2633
2634 assert_eq!(response.status(), StatusCode::OK);
2635 assert!(response.text().expect("utf8").contains("spa"));
2636 }
2637
2638 #[tokio::test]
2639 async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
2640 let temp = tempdir().expect("tempdir");
2641 let root = temp.path().join("public");
2642 fs::create_dir_all(&root).expect("create dir");
2643 let file = root.join("compressed.txt");
2644 fs::write(&file, "compress me ".repeat(400)).expect("write file");
2645
2646 let ingress = Ranvier::http::<()>()
2647 .serve_dir("/static", root.to_string_lossy().to_string())
2648 .compression_layer();
2649 let app = crate::test_harness::TestApp::new(ingress, ());
2650 let response = app
2651 .send(
2652 crate::test_harness::TestRequest::get("/static/compressed.txt")
2653 .header("accept-encoding", "gzip"),
2654 )
2655 .await
2656 .expect("request should succeed");
2657
2658 assert_eq!(response.status(), StatusCode::OK);
2659 assert_eq!(
2660 response
2661 .header("content-encoding")
2662 .and_then(|value| value.to_str().ok()),
2663 Some("gzip")
2664 );
2665 }
2666
2667 #[tokio::test]
2668 async fn drain_connections_completes_before_timeout() {
2669 let mut connections = tokio::task::JoinSet::new();
2670 connections.spawn(async {
2671 tokio::time::sleep(Duration::from_millis(20)).await;
2672 });
2673
2674 let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
2675 assert!(!timed_out);
2676 assert!(connections.is_empty());
2677 }
2678
2679 #[tokio::test]
2680 async fn drain_connections_times_out_and_aborts() {
2681 let mut connections = tokio::task::JoinSet::new();
2682 connections.spawn(async {
2683 tokio::time::sleep(Duration::from_secs(10)).await;
2684 });
2685
2686 let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
2687 assert!(timed_out);
2688 assert!(connections.is_empty());
2689 }
2690
2691 #[tokio::test]
2692 async fn timeout_layer_returns_408_for_slow_route() {
2693 #[derive(Clone)]
2694 struct SlowRoute;
2695
2696 #[async_trait]
2697 impl Transition<(), &'static str> for SlowRoute {
2698 type Error = Infallible;
2699 type Resources = ();
2700
2701 async fn run(
2702 &self,
2703 _state: (),
2704 _resources: &Self::Resources,
2705 _bus: &mut Bus,
2706 ) -> Outcome<&'static str, Self::Error> {
2707 tokio::time::sleep(Duration::from_millis(80)).await;
2708 Outcome::next("slow-ok")
2709 }
2710 }
2711
2712 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2713 let addr = probe.local_addr().expect("local addr");
2714 drop(probe);
2715
2716 let ingress = HttpIngress::<()>::new()
2717 .bind(addr.to_string())
2718 .timeout_layer(Duration::from_millis(10))
2719 .get(
2720 "/slow",
2721 Axon::<(), (), Infallible, ()>::new("Slow").then(SlowRoute),
2722 );
2723
2724 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2725 let server = tokio::spawn(async move {
2726 ingress
2727 .run_with_shutdown_signal((), async move {
2728 let _ = shutdown_rx.await;
2729 })
2730 .await
2731 });
2732
2733 let mut stream = connect_with_retry(addr).await;
2734 stream
2735 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2736 .await
2737 .expect("write request");
2738
2739 let mut buf = Vec::new();
2740 stream.read_to_end(&mut buf).await.expect("read response");
2741 let response = String::from_utf8_lossy(&buf);
2742 assert!(response.starts_with("HTTP/1.1 408"), "{response}");
2743
2744 let _ = shutdown_tx.send(());
2745 server
2746 .await
2747 .expect("server join")
2748 .expect("server shutdown should succeed");
2749 }
2750
2751 #[tokio::test]
2752 async fn route_layer_override_bypasses_global_timeout() {
2753 #[derive(Clone)]
2754 struct SlowRoute;
2755
2756 #[async_trait]
2757 impl Transition<(), &'static str> for SlowRoute {
2758 type Error = Infallible;
2759 type Resources = ();
2760
2761 async fn run(
2762 &self,
2763 _state: (),
2764 _resources: &Self::Resources,
2765 _bus: &mut Bus,
2766 ) -> Outcome<&'static str, Self::Error> {
2767 tokio::time::sleep(Duration::from_millis(60)).await;
2768 Outcome::next("override-ok")
2769 }
2770 }
2771
2772 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2773 let addr = probe.local_addr().expect("local addr");
2774 drop(probe);
2775
2776 let ingress = HttpIngress::<()>::new()
2777 .bind(addr.to_string())
2778 .timeout_layer(Duration::from_millis(10))
2779 .get_with_layer_override(
2780 "/slow",
2781 Axon::<(), (), Infallible, ()>::new("SlowOverride").then(SlowRoute),
2782 tower::layer::util::Identity::new(),
2783 );
2784
2785 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2786 let server = tokio::spawn(async move {
2787 ingress
2788 .run_with_shutdown_signal((), async move {
2789 let _ = shutdown_rx.await;
2790 })
2791 .await
2792 });
2793
2794 let mut stream = connect_with_retry(addr).await;
2795 stream
2796 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2797 .await
2798 .expect("write request");
2799
2800 let mut buf = Vec::new();
2801 stream.read_to_end(&mut buf).await.expect("read response");
2802 let response = String::from_utf8_lossy(&buf);
2803 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
2804 assert!(response.contains("override-ok"), "{response}");
2805
2806 let _ = shutdown_tx.send(());
2807 server
2808 .await
2809 .expect("server join")
2810 .expect("server shutdown should succeed");
2811 }
2812}