1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode, Uri};
20use http_body::Body;
21use http_body_util::{BodyExt, Full};
22use hyper::body::Incoming;
23use hyper::server::conn::http1;
24use hyper::upgrade::Upgraded;
25use hyper_util::rt::TokioIo;
26use hyper_util::service::TowerToHyperService;
27use ranvier_core::event::{EventSink, EventSource};
28use ranvier_core::prelude::*;
29use ranvier_runtime::Axon;
30use serde::Serialize;
31use serde::de::DeserializeOwned;
32use sha1::{Digest, Sha1};
33use std::collections::HashMap;
34use std::convert::Infallible;
35use std::future::Future;
36use std::net::SocketAddr;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::time::Duration;
40use tokio::net::TcpListener;
41use tokio::sync::Mutex;
42use tokio_tungstenite::WebSocketStream;
43use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
44use tower::util::BoxCloneService;
45use tower::{Layer, Service, ServiceExt, service_fn};
46use tower_http::compression::CompressionLayer;
47use tower_http::services::{ServeDir, ServeFile};
48use tracing::Instrument;
49
50use crate::response::{IntoResponse, outcome_to_response_with_error};
51
/// Entry point for building Ranvier ingress servers.
///
/// The type carries no state; it only namespaces constructors such as
/// [`Ranvier::http`].
pub struct Ranvier;
57
58impl Ranvier {
59 pub fn http<R>() -> HttpIngress<R>
61 where
62 R: ranvier_core::transition::ResourceRequirement + Clone,
63 {
64 HttpIngress::new()
65 }
66}
67
/// Type-erased per-route handler: takes the raw request plus a borrow of the
/// shared resources and resolves to a fully buffered (`Full<Bytes>`) response.
type RouteHandler<R> = Arc<
    dyn Fn(Request<Incoming>, &R) -> Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
        + Send
        + Sync,
>;

/// Boxed, cloneable tower service: the common type every layer wraps and returns.
type BoxHttpService = BoxCloneService<Request<Incoming>, Response<Full<Bytes>>, Infallible>;
/// Type-erased layer: a function wrapping one boxed service in another.
type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
/// Callback used for the `on_start` / `on_shutdown` lifecycle hooks.
type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
/// Hook that may copy data from an incoming request into that request's `Bus`.
type BusInjector = Arc<dyn Fn(&Request<Incoming>, &mut Bus) + Send + Sync>;
/// Boxed future driving one upgraded WebSocket session.
type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased WebSocket session handler registered through `HttpIngress::ws`.
type WsSessionHandler<R> =
    Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
/// Boxed future produced by a health check; `Err` carries a human-readable reason.
type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
/// Type-erased health check over the shared resources.
type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
/// Header carrying the per-request correlation id (see `RequestIdService`).
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Token required in the `Upgrade` header of a WebSocket handshake.
const WS_UPGRADE_TOKEN: &str = "websocket";
/// Fixed GUID appended to `Sec-WebSocket-Key` when deriving the accept key (RFC 6455).
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
87
/// A registered health check together with its display name.
#[derive(Clone)]
struct NamedHealthCheck<R> {
    /// Name given to `HttpIngress::health_check`; presumably reported per check
    /// in the health response.
    name: String,
    /// The check itself, run against the shared resources.
    check: HealthCheckFn<R>,
}
93
/// Health-probe configuration collected by the `HttpIngress` builder.
#[derive(Clone)]
struct HealthConfig<R> {
    /// Set by `health_endpoint`; defaults to `/health` once a check is registered.
    health_path: Option<String>,
    /// Set by `readiness_liveness`.
    readiness_path: Option<String>,
    /// Set by `readiness_liveness`.
    liveness_path: Option<String>,
    /// Checks registered via `health_check`, in registration order.
    checks: Vec<NamedHealthCheck<R>>,
}
101
102impl<R> Default for HealthConfig<R> {
103 fn default() -> Self {
104 Self {
105 health_path: None,
106 readiness_path: None,
107 liveness_path: None,
108 checks: Vec::new(),
109 }
110 }
111}
112
/// Static-file options collected by `serve_dir`, `spa_fallback`,
/// `static_cache_control` and `compression_layer`.
#[derive(Clone, Default)]
struct StaticAssetsConfig {
    /// Directory mounts, in registration order.
    mounts: Vec<StaticMount>,
    /// File path set by `spa_fallback` (presumably served for unmatched paths — confirm).
    spa_fallback: Option<String>,
    /// Cache policy string; `serve_dir` defaults it to `public, max-age=3600`.
    cache_control: Option<String>,
    /// Set to `true` by `compression_layer`.
    enable_compression: bool,
}
120
/// One `serve_dir` mount: a URL prefix mapped to a filesystem directory.
#[derive(Clone)]
struct StaticMount {
    /// URL prefix, normalized to start with `/`.
    route_prefix: String,
    /// Directory path exactly as given to `serve_dir`.
    directory: String,
}
126
/// Serialized body of a health probe response.
///
/// Field order here is the key order of the serialized output; keep it stable.
#[derive(Serialize)]
struct HealthReport {
    /// Overall status label.
    status: &'static str,
    /// Which probe produced the report.
    probe: &'static str,
    /// Per-check outcomes.
    checks: Vec<HealthCheckReport>,
}
133
/// Outcome of a single named health check inside a [`HealthReport`].
#[derive(Serialize)]
struct HealthCheckReport {
    /// Name the check was registered under.
    name: String,
    /// Status label for this check.
    status: &'static str,
    /// Failure reason; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
141
/// Service wrapper (installed by `HttpIngress::timeout_layer`) that bounds how
/// long the inner service may take to answer.
#[derive(Clone)]
struct TimeoutService {
    /// Wrapped service.
    inner: BoxHttpService,
    /// Maximum duration allowed for one call.
    timeout: Duration,
}
147
148impl Service<Request<Incoming>> for TimeoutService {
149 type Response = Response<Full<Bytes>>;
150 type Error = Infallible;
151 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
152
153 fn poll_ready(
154 &mut self,
155 cx: &mut std::task::Context<'_>,
156 ) -> std::task::Poll<Result<(), Self::Error>> {
157 self.inner.poll_ready(cx)
158 }
159
160 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
161 let timeout = self.timeout;
162 let fut = self.inner.call(req);
163 Box::pin(async move {
164 match tokio::time::timeout(timeout, fut).await {
165 Ok(response) => response,
166 Err(_) => Ok(Response::builder()
167 .status(StatusCode::REQUEST_TIMEOUT)
168 .body(Full::new(Bytes::from("Request Timeout")))
169 .unwrap()),
170 }
171 })
172 }
173}
174
/// Service wrapper (installed by `HttpIngress::request_id_layer`) that makes
/// sure each request and its response carry an `x-request-id` header.
#[derive(Clone)]
struct RequestIdService {
    /// Wrapped service.
    inner: BoxHttpService,
}
179
180impl Service<Request<Incoming>> for RequestIdService {
181 type Response = Response<Full<Bytes>>;
182 type Error = Infallible;
183 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
184
185 fn poll_ready(
186 &mut self,
187 cx: &mut std::task::Context<'_>,
188 ) -> std::task::Poll<Result<(), Self::Error>> {
189 self.inner.poll_ready(cx)
190 }
191
192 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
193 let mut req = req;
194 let request_id = req
195 .headers()
196 .get(REQUEST_ID_HEADER)
197 .cloned()
198 .unwrap_or_else(|| {
199 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
200 .unwrap_or_else(|_| http::HeaderValue::from_static("request-id-unavailable"))
201 });
202
203 req.headers_mut()
204 .insert(REQUEST_ID_HEADER, request_id.clone());
205
206 let fut = self.inner.call(req);
207 Box::pin(async move {
208 let mut response = fut.await?;
209 response.headers_mut().insert(REQUEST_ID_HEADER, request_id);
210 Ok(response)
211 })
212 }
213}
214
215fn to_service_layer<L>(layer: L) -> ServiceLayer
216where
217 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
218 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
219 + Clone
220 + Send
221 + 'static,
222 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
223{
224 Arc::new(move |service: BoxHttpService| BoxCloneService::new(layer.clone().layer(service)))
225}
226
/// Values captured from `:param` and `*wildcard` segments when a request path
/// matches a route pattern, keyed by segment name.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PathParams {
    values: HashMap<String, String>,
}
231
/// Public description of a registered route (method plus raw path pattern),
/// as produced by `HttpIngress::route_descriptors`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpRouteDescriptor {
    method: Method,
    path_pattern: String,
}
238
239impl HttpRouteDescriptor {
240 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
241 Self {
242 method,
243 path_pattern: path_pattern.into(),
244 }
245 }
246
247 pub fn method(&self) -> &Method {
248 &self.method
249 }
250
251 pub fn path_pattern(&self) -> &str {
252 &self.path_pattern
253 }
254}
255
/// Per-connection metadata taken from the WebSocket upgrade request.
///
/// `HttpIngress::ws` also inserts a copy into the session's `Bus`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct WebSocketSessionContext {
    /// Random UUID v4 generated for each upgrade request.
    connection_id: uuid::Uuid,
    /// Request path of the upgrade request.
    path: String,
    /// Raw query string of the upgrade request, if any.
    query: Option<String>,
}
263
264impl WebSocketSessionContext {
265 pub fn connection_id(&self) -> uuid::Uuid {
266 self.connection_id
267 }
268
269 pub fn path(&self) -> &str {
270 &self.path
271 }
272
273 pub fn query(&self) -> Option<&str> {
274 self.query.as_deref()
275 }
276}
277
/// Transport-neutral view of a WebSocket message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WebSocketEvent {
    /// UTF-8 text frame.
    Text(String),
    /// Binary frame.
    Binary(Vec<u8>),
    /// Ping control frame with its payload.
    Ping(Vec<u8>),
    /// Pong control frame with its payload.
    Pong(Vec<u8>),
    /// Close frame (any close code or reason is discarded).
    Close,
}
287
288impl WebSocketEvent {
289 pub fn text(value: impl Into<String>) -> Self {
290 Self::Text(value.into())
291 }
292
293 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
294 Self::Binary(value.into())
295 }
296
297 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
298 where
299 T: Serialize,
300 {
301 let text = serde_json::to_string(value)?;
302 Ok(Self::Text(text))
303 }
304}
305
/// Errors raised by [`WebSocketConnection`] operations.
#[derive(Debug, thiserror::Error)]
pub enum WebSocketError {
    /// Transport/protocol failure reported by tungstenite.
    #[error("websocket wire error: {0}")]
    Wire(#[from] WsWireError),
    /// Outgoing value could not be serialized to JSON.
    #[error("json serialization failed: {0}")]
    JsonSerialize(#[source] serde_json::Error),
    /// Incoming payload could not be deserialized from JSON.
    #[error("json deserialization failed: {0}")]
    JsonDeserialize(#[source] serde_json::Error),
    /// A JSON payload was expected but the frame was neither text nor binary.
    #[error("expected text or binary frame for json payload")]
    NonDataFrame,
}
317
/// Server-side WebSocket stream running over a hyper-upgraded connection.
type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
/// Write half of a split [`WsServerStream`].
type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
/// Read half of a split [`WsServerStream`].
type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
321
/// An upgraded WebSocket connection handed to `HttpIngress::ws` handlers.
///
/// The read and write halves sit behind separate async mutexes, so a send
/// does not wait on a pending receive.
pub struct WebSocketConnection {
    /// Write half.
    sink: Mutex<WsServerSink>,
    /// Read half.
    source: Mutex<WsServerSource>,
    /// Metadata captured from the upgrade request.
    session: WebSocketSessionContext,
}
328
329impl WebSocketConnection {
330 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
331 let (sink, source) = stream.split();
332 Self {
333 sink: Mutex::new(sink),
334 source: Mutex::new(source),
335 session,
336 }
337 }
338
339 pub fn session(&self) -> &WebSocketSessionContext {
340 &self.session
341 }
342
343 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
344 let mut sink = self.sink.lock().await;
345 sink.send(event.into_wire_message()).await?;
346 Ok(())
347 }
348
349 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
350 where
351 T: Serialize,
352 {
353 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
354 self.send(event).await
355 }
356
357 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
358 where
359 T: DeserializeOwned,
360 {
361 let Some(event) = self.recv_event().await? else {
362 return Ok(None);
363 };
364 match event {
365 WebSocketEvent::Text(text) => serde_json::from_str(&text)
366 .map(Some)
367 .map_err(WebSocketError::JsonDeserialize),
368 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
369 .map(Some)
370 .map_err(WebSocketError::JsonDeserialize),
371 _ => Err(WebSocketError::NonDataFrame),
372 }
373 }
374
375 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
376 let mut source = self.source.lock().await;
377 while let Some(item) = source.next().await {
378 let message = item?;
379 if let Some(event) = WebSocketEvent::from_wire_message(message) {
380 return Ok(Some(event));
381 }
382 }
383 Ok(None)
384 }
385}
386
387impl WebSocketEvent {
388 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
389 match message {
390 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
391 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
392 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
393 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
394 WsWireMessage::Close(_) => Some(Self::Close),
395 WsWireMessage::Frame(_) => None,
396 }
397 }
398
399 fn into_wire_message(self) -> WsWireMessage {
400 match self {
401 Self::Text(value) => WsWireMessage::Text(value.into()),
402 Self::Binary(value) => WsWireMessage::Binary(value.into()),
403 Self::Ping(value) => WsWireMessage::Ping(value.into()),
404 Self::Pong(value) => WsWireMessage::Pong(value.into()),
405 Self::Close => WsWireMessage::Close(None),
406 }
407 }
408}
409
410#[async_trait::async_trait]
411impl EventSource<WebSocketEvent> for WebSocketConnection {
412 async fn next_event(&mut self) -> Option<WebSocketEvent> {
413 match self.recv_event().await {
414 Ok(event) => event,
415 Err(error) => {
416 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
417 None
418 }
419 }
420 }
421}
422
423#[async_trait::async_trait]
424impl EventSink<WebSocketEvent> for WebSocketConnection {
425 type Error = WebSocketError;
426
427 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
428 self.send(event).await
429 }
430}
431
432#[async_trait::async_trait]
433impl EventSink<String> for WebSocketConnection {
434 type Error = WebSocketError;
435
436 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
437 self.send(WebSocketEvent::Text(event)).await
438 }
439}
440
441#[async_trait::async_trait]
442impl EventSink<Vec<u8>> for WebSocketConnection {
443 type Error = WebSocketError;
444
445 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
446 self.send(WebSocketEvent::Binary(event)).await
447 }
448}
449
450impl PathParams {
451 pub fn new(values: HashMap<String, String>) -> Self {
452 Self { values }
453 }
454
455 pub fn get(&self, key: &str) -> Option<&str> {
456 self.values.get(key).map(String::as_str)
457 }
458
459 pub fn as_map(&self) -> &HashMap<String, String> {
460 &self.values
461 }
462
463 pub fn into_inner(self) -> HashMap<String, String> {
464 self.values
465 }
466}
467
/// One `/`-separated piece of a parsed route pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum RouteSegment {
    /// Literal text that must match the path segment exactly.
    Static(String),
    /// `:name` — captures exactly one path segment under `name`.
    Param(String),
    /// `*name` — captures all remaining segments (joined with `/`, possibly
    /// empty) under `name`; only matches when it is the last pattern segment.
    Wildcard(String),
}
474
/// A route path pattern such as `/users/:id/*rest`.
#[derive(Clone, Debug, PartialEq, Eq)]
struct RoutePattern {
    /// Pattern text as registered (reported by `route_descriptors`).
    raw: String,
    /// Parsed segments, in order.
    segments: Vec<RouteSegment>,
}
480
481impl RoutePattern {
482 fn parse(path: &str) -> Self {
483 let segments = path_segments(path)
484 .into_iter()
485 .map(|segment| {
486 if let Some(name) = segment.strip_prefix(':') {
487 if !name.is_empty() {
488 return RouteSegment::Param(name.to_string());
489 }
490 }
491 if let Some(name) = segment.strip_prefix('*') {
492 if !name.is_empty() {
493 return RouteSegment::Wildcard(name.to_string());
494 }
495 }
496 RouteSegment::Static(segment.to_string())
497 })
498 .collect();
499
500 Self {
501 raw: path.to_string(),
502 segments,
503 }
504 }
505
506 fn match_path(&self, path: &str) -> Option<PathParams> {
507 let mut params = HashMap::new();
508 let path_segments = path_segments(path);
509 let mut pattern_index = 0usize;
510 let mut path_index = 0usize;
511
512 while pattern_index < self.segments.len() {
513 match &self.segments[pattern_index] {
514 RouteSegment::Static(expected) => {
515 let actual = path_segments.get(path_index)?;
516 if actual != expected {
517 return None;
518 }
519 pattern_index += 1;
520 path_index += 1;
521 }
522 RouteSegment::Param(name) => {
523 let actual = path_segments.get(path_index)?;
524 params.insert(name.clone(), (*actual).to_string());
525 pattern_index += 1;
526 path_index += 1;
527 }
528 RouteSegment::Wildcard(name) => {
529 let remaining = path_segments[path_index..].join("/");
530 params.insert(name.clone(), remaining);
531 pattern_index += 1;
532 path_index = path_segments.len();
533 break;
534 }
535 }
536 }
537
538 if pattern_index == self.segments.len() && path_index == path_segments.len() {
539 Some(PathParams::new(params))
540 } else {
541 None
542 }
543 }
544}
545
/// A fully resolved route: method, pattern, handler and per-route layers.
#[derive(Clone)]
struct RouteEntry<R> {
    /// Method this entry answers.
    method: Method,
    /// Full path pattern (group prefixes already applied).
    pattern: RoutePattern,
    /// Type-erased handler.
    handler: RouteHandler<R>,
    /// Layers specific to this route.
    layers: Arc<Vec<ServiceLayer>>,
    /// `false` for `route_method_with_layer_override`; presumably skips the
    /// ingress-wide layers for this route — confirm in the dispatch code.
    apply_global_layers: bool,
}
554
/// Splits a URL path into its non-empty `/`-separated segments.
///
/// Leading, trailing and repeated slashes are ignored, so `"/"`, `""` and
/// `"//"` all produce no segments and `"/a//b/"` produces `["a", "b"]`.
fn path_segments(path: &str) -> Vec<&str> {
    path.split('/')
        .filter(|segment| !segment.is_empty())
        .collect()
}
565
/// Ensures a route path starts with `/`; an empty path becomes `/`.
///
/// Trailing slashes are left untouched.
fn normalize_route_path(path: String) -> String {
    if path.is_empty() {
        String::from("/")
    } else if !path.starts_with('/') {
        format!("/{path}")
    } else {
        path
    }
}
576
577fn find_matching_route<'a, R>(
578 routes: &'a [RouteEntry<R>],
579 method: &Method,
580 path: &str,
581) -> Option<(&'a RouteEntry<R>, PathParams)> {
582 for entry in routes {
583 if &entry.method != method {
584 continue;
585 }
586 if let Some(params) = entry.pattern.match_path(path) {
587 return Some((entry, params));
588 }
589 }
590 None
591}
592
593fn header_contains_token(
594 headers: &http::HeaderMap,
595 name: http::header::HeaderName,
596 token: &str,
597) -> bool {
598 headers
599 .get(name)
600 .and_then(|value| value.to_str().ok())
601 .map(|value| {
602 value
603 .split(',')
604 .any(|part| part.trim().eq_ignore_ascii_case(token))
605 })
606 .unwrap_or(false)
607}
608
609fn websocket_session_from_request(req: &Request<Incoming>) -> WebSocketSessionContext {
610 WebSocketSessionContext {
611 connection_id: uuid::Uuid::new_v4(),
612 path: req.uri().path().to_string(),
613 query: req.uri().query().map(str::to_string),
614 }
615}
616
617fn websocket_accept_key(client_key: &str) -> String {
618 let mut hasher = Sha1::new();
619 hasher.update(client_key.as_bytes());
620 hasher.update(WS_GUID.as_bytes());
621 let digest = hasher.finalize();
622 base64::engine::general_purpose::STANDARD.encode(digest)
623}
624
625fn websocket_bad_request(message: &'static str) -> Response<Full<Bytes>> {
626 Response::builder()
627 .status(StatusCode::BAD_REQUEST)
628 .body(Full::new(Bytes::from(message)))
629 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())))
630}
631
632fn websocket_upgrade_response(
633 req: &mut Request<Incoming>,
634) -> Result<(Response<Full<Bytes>>, hyper::upgrade::OnUpgrade), Response<Full<Bytes>>> {
635 if req.method() != Method::GET {
636 return Err(websocket_bad_request(
637 "WebSocket upgrade requires GET method",
638 ));
639 }
640
641 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
642 return Err(websocket_bad_request(
643 "Missing Connection: upgrade header for WebSocket",
644 ));
645 }
646
647 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
648 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
649 }
650
651 if let Some(version) = req.headers().get("sec-websocket-version") {
652 if version != "13" {
653 return Err(websocket_bad_request(
654 "Unsupported Sec-WebSocket-Version (expected 13)",
655 ));
656 }
657 }
658
659 let Some(client_key) = req
660 .headers()
661 .get("sec-websocket-key")
662 .and_then(|value| value.to_str().ok())
663 else {
664 return Err(websocket_bad_request(
665 "Missing Sec-WebSocket-Key header for WebSocket",
666 ));
667 };
668
669 let accept_key = websocket_accept_key(client_key);
670 let on_upgrade = hyper::upgrade::on(req);
671 let response = Response::builder()
672 .status(StatusCode::SWITCHING_PROTOCOLS)
673 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
674 .header(http::header::CONNECTION, "Upgrade")
675 .header("sec-websocket-accept", accept_key)
676 .body(Full::new(Bytes::new()))
677 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())));
678
679 Ok((response, on_upgrade))
680}
681
/// One declaration inside a [`RouteGroup`], resolved later by `into_entries`.
enum RouteGroupSpec<R> {
    /// A route whose `sub_path` is relative to the group prefix.
    Route {
        method: Method,
        sub_path: String,
        handler: RouteHandler<R>,
    },
    /// A child group whose prefix is appended to this group's prefix.
    Nested(RouteGroup<R>),
}
705
/// A set of routes (and nested groups) sharing a common path prefix.
pub struct RouteGroup<R = ()> {
    /// Normalized prefix: begins with `/` (or is empty) and has no trailing `/`.
    prefix: String,
    /// Layers attached to every route of this group; `RouteGroup::new` starts it empty.
    layers: Arc<Vec<ServiceLayer>>,
    /// Declarations in the order they were added.
    specs: Vec<RouteGroupSpec<R>>,
}
721
722impl<R> RouteGroup<R>
723where
724 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
725{
726 pub fn new(prefix: impl Into<String>) -> Self {
731 Self {
732 prefix: normalize_route_path(prefix.into()).trim_end_matches('/').to_string(),
733 layers: Arc::new(Vec::new()),
734 specs: Vec::new(),
735 }
736 }
737
738 pub fn get<Out, E>(self, sub_path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
740 where
741 Out: IntoResponse + Send + Sync + 'static,
742 E: Send + 'static + std::fmt::Debug,
743 {
744 self.method(Method::GET, sub_path, circuit)
745 }
746
747 pub fn post<Out, E>(self, sub_path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
749 where
750 Out: IntoResponse + Send + Sync + 'static,
751 E: Send + 'static + std::fmt::Debug,
752 {
753 self.method(Method::POST, sub_path, circuit)
754 }
755
756 pub fn put<Out, E>(self, sub_path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
758 where
759 Out: IntoResponse + Send + Sync + 'static,
760 E: Send + 'static + std::fmt::Debug,
761 {
762 self.method(Method::PUT, sub_path, circuit)
763 }
764
765 pub fn patch<Out, E>(self, sub_path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
767 where
768 Out: IntoResponse + Send + Sync + 'static,
769 E: Send + 'static + std::fmt::Debug,
770 {
771 self.method(Method::PATCH, sub_path, circuit)
772 }
773
774 pub fn delete<Out, E>(self, sub_path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
776 where
777 Out: IntoResponse + Send + Sync + 'static,
778 E: Send + 'static + std::fmt::Debug,
779 {
780 self.method(Method::DELETE, sub_path, circuit)
781 }
782
783 pub fn method<Out, E>(
785 mut self,
786 method: Method,
787 sub_path: impl Into<String>,
788 circuit: Axon<(), Out, E, R>,
789 ) -> Self
790 where
791 Out: IntoResponse + Send + Sync + 'static,
792 E: Send + 'static + std::fmt::Debug,
793 {
794 let sub_path_str: String = sub_path.into();
795 let circuit = Arc::new(circuit);
796 let error_handler = Arc::new(|error: &E| {
797 (
798 StatusCode::INTERNAL_SERVER_ERROR,
799 format!("Error: {:?}", error),
800 )
801 .into_response()
802 });
803 let method_for_handler = method.clone();
804 let sub_path_for_handler = sub_path_str.clone();
805
806 let handler: RouteHandler<R> = Arc::new(move |req: Request<Incoming>, res: &R| {
807 let circuit = circuit.clone();
808 let error_handler = error_handler.clone();
809 let res = res.clone();
810 let path = sub_path_for_handler.clone();
811 let method = method_for_handler.clone();
812
813 Box::pin(async move {
814 let request_id = uuid::Uuid::new_v4().to_string();
815 let span = tracing::info_span!(
816 "HTTPRequest",
817 ranvier.http.method = %method,
818 ranvier.http.path = %path,
819 ranvier.http.request_id = %request_id
820 );
821
822 async move {
823 let mut bus = Bus::new();
824 let result = circuit.execute((), &res, &mut bus).await;
825 outcome_to_response_with_error(result, |error| error_handler(error))
826 }
827 .instrument(span)
828 .await
829 }) as Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
830 });
831
832 self.specs.push(RouteGroupSpec::Route {
833 method,
834 sub_path: sub_path_str,
835 handler,
836 });
837 self
838 }
839
840 pub fn group(mut self, child: RouteGroup<R>) -> Self {
844 self.specs.push(RouteGroupSpec::Nested(child));
845 self
846 }
847
848 pub(crate) fn into_entries(self, parent_prefix: &str) -> Vec<RouteEntry<R>> {
854 let full_prefix = format!(
855 "{}{}",
856 parent_prefix.trim_end_matches('/'),
857 self.prefix
858 );
859 let layers = self.layers.clone();
860 let mut entries = Vec::new();
861
862 for spec in self.specs {
863 match spec {
864 RouteGroupSpec::Route {
865 method,
866 sub_path,
867 handler,
868 } => {
869 let full_path = if sub_path.is_empty() || sub_path == "/" {
870 if full_prefix.is_empty() {
871 "/".to_string()
872 } else {
873 full_prefix.clone()
874 }
875 } else {
876 let sub = if sub_path.starts_with('/') {
877 sub_path.clone()
878 } else {
879 format!("/{sub_path}")
880 };
881 format!("{full_prefix}{sub}")
882 };
883 entries.push(RouteEntry {
884 method,
885 pattern: RoutePattern::parse(&full_path),
886 handler,
887 layers: layers.clone(),
888 apply_global_layers: true,
889 });
890 }
891 RouteGroupSpec::Nested(child) => {
892 entries.extend(child.into_entries(&full_prefix));
893 }
894 }
895 }
896
897 entries
898 }
899}
900
/// Builder for an HTTP ingress: routes, layers, lifecycle hooks, static assets
/// and health probes, run against shared resources of type `R`.
pub struct HttpIngress<R = ()> {
    /// Listen address set by `bind`.
    addr: Option<String>,
    /// Registered routes, in registration order.
    routes: Vec<RouteEntry<R>>,
    /// Optional fallback handler (presumably for requests matching no route — confirm).
    fallback: Option<RouteHandler<R>>,
    /// Ingress-wide layers, in registration order.
    layers: Vec<ServiceLayer>,
    /// Hook registered by `on_start`.
    on_start: Option<LifecycleHook>,
    /// Hook registered by `on_shutdown`.
    on_shutdown: Option<LifecycleHook>,
    /// Grace period for shutdown; `new` sets 30 seconds.
    graceful_shutdown_timeout: Duration,
    /// Hooks that populate each request's `Bus`.
    bus_injectors: Vec<BusInjector>,
    /// Static file serving options.
    static_assets: StaticAssetsConfig,
    /// Health / readiness / liveness probe options.
    health: HealthConfig<R>,
    // NOTE(review): `R` is already used by `routes`/`health`; this marker looks redundant.
    _phantom: std::marker::PhantomData<R>,
}
932
933impl<R> HttpIngress<R>
934where
935 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
936{
937 pub fn new() -> Self {
939 Self {
940 addr: None,
941 routes: Vec::new(),
942 fallback: None,
943 layers: Vec::new(),
944 on_start: None,
945 on_shutdown: None,
946 graceful_shutdown_timeout: Duration::from_secs(30),
947 bus_injectors: Vec::new(),
948 static_assets: StaticAssetsConfig::default(),
949 health: HealthConfig::default(),
950 _phantom: std::marker::PhantomData,
951 }
952 }
953
954 pub fn bind(mut self, addr: impl Into<String>) -> Self {
956 self.addr = Some(addr.into());
957 self
958 }
959
960 pub fn on_start<F>(mut self, callback: F) -> Self
962 where
963 F: Fn() + Send + Sync + 'static,
964 {
965 self.on_start = Some(Arc::new(callback));
966 self
967 }
968
969 pub fn on_shutdown<F>(mut self, callback: F) -> Self
971 where
972 F: Fn() + Send + Sync + 'static,
973 {
974 self.on_shutdown = Some(Arc::new(callback));
975 self
976 }
977
978 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
980 self.graceful_shutdown_timeout = timeout;
981 self
982 }
983
984 pub fn layer<L>(mut self, layer: L) -> Self
989 where
990 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
991 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
992 + Clone
993 + Send
994 + 'static,
995 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
996 {
997 self.layers.push(to_service_layer(layer));
998 self
999 }
1000
1001 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
1004 self.layers.push(Arc::new(move |service: BoxHttpService| {
1005 BoxCloneService::new(TimeoutService {
1006 inner: service,
1007 timeout,
1008 })
1009 }));
1010 self
1011 }
1012
1013 pub fn request_id_layer(mut self) -> Self {
1017 self.layers.push(Arc::new(move |service: BoxHttpService| {
1018 BoxCloneService::new(RequestIdService { inner: service })
1019 }));
1020 self
1021 }
1022
1023 pub fn bus_injector<F>(mut self, injector: F) -> Self
1028 where
1029 F: Fn(&Request<Incoming>, &mut Bus) + Send + Sync + 'static,
1030 {
1031 self.bus_injectors.push(Arc::new(injector));
1032 self
1033 }
1034
1035 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
1037 let mut descriptors = self
1038 .routes
1039 .iter()
1040 .map(|entry| HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone()))
1041 .collect::<Vec<_>>();
1042
1043 if let Some(path) = &self.health.health_path {
1044 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1045 }
1046 if let Some(path) = &self.health.readiness_path {
1047 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1048 }
1049 if let Some(path) = &self.health.liveness_path {
1050 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1051 }
1052
1053 descriptors
1054 }
1055
1056 pub fn serve_dir(
1060 mut self,
1061 route_prefix: impl Into<String>,
1062 directory: impl Into<String>,
1063 ) -> Self {
1064 self.static_assets.mounts.push(StaticMount {
1065 route_prefix: normalize_route_path(route_prefix.into()),
1066 directory: directory.into(),
1067 });
1068 if self.static_assets.cache_control.is_none() {
1069 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
1070 }
1071 self
1072 }
1073
1074 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
1078 self.static_assets.spa_fallback = Some(file_path.into());
1079 self
1080 }
1081
1082 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
1084 self.static_assets.cache_control = Some(cache_control.into());
1085 self
1086 }
1087
1088 pub fn compression_layer(mut self) -> Self {
1090 self.static_assets.enable_compression = true;
1091 self
1092 }
1093
1094 pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
1101 where
1102 H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
1103 Fut: Future<Output = ()> + Send + 'static,
1104 {
1105 let path_str: String = path.into();
1106 let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
1107 Box::pin(handler(connection, resources, bus))
1108 });
1109 let bus_injectors = Arc::new(self.bus_injectors.clone());
1110 let path_for_pattern = path_str.clone();
1111 let path_for_handler = path_str;
1112
1113 let route_handler: RouteHandler<R> =
1114 Arc::new(move |mut req: Request<Incoming>, res: &R| {
1115 let ws_handler = ws_handler.clone();
1116 let bus_injectors = bus_injectors.clone();
1117 let resources = Arc::new(res.clone());
1118 let path = path_for_handler.clone();
1119
1120 Box::pin(async move {
1121 let request_id = uuid::Uuid::new_v4().to_string();
1122 let span = tracing::info_span!(
1123 "WebSocketUpgrade",
1124 ranvier.ws.path = %path,
1125 ranvier.ws.request_id = %request_id
1126 );
1127
1128 async move {
1129 let mut bus = Bus::new();
1130 for injector in bus_injectors.iter() {
1131 injector(&req, &mut bus);
1132 }
1133
1134 let session = websocket_session_from_request(&req);
1135 bus.insert(session.clone());
1136
1137 let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
1138 Ok(result) => result,
1139 Err(error_response) => return error_response,
1140 };
1141
1142 tokio::spawn(async move {
1143 match on_upgrade.await {
1144 Ok(upgraded) => {
1145 let stream = WebSocketStream::from_raw_socket(
1146 TokioIo::new(upgraded),
1147 tokio_tungstenite::tungstenite::protocol::Role::Server,
1148 None,
1149 )
1150 .await;
1151 let connection = WebSocketConnection::new(stream, session);
1152 ws_handler(connection, resources, bus).await;
1153 }
1154 Err(error) => {
1155 tracing::warn!(
1156 ranvier.ws.path = %path,
1157 ranvier.ws.error = %error,
1158 "websocket upgrade failed"
1159 );
1160 }
1161 }
1162 });
1163
1164 response
1165 }
1166 .instrument(span)
1167 .await
1168 }) as Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
1169 });
1170
1171 self.routes.push(RouteEntry {
1172 method: Method::GET,
1173 pattern: RoutePattern::parse(&path_for_pattern),
1174 handler: route_handler,
1175 layers: Arc::new(Vec::new()),
1176 apply_global_layers: true,
1177 });
1178
1179 self
1180 }
1181
1182 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
1187 self.health.health_path = Some(normalize_route_path(path.into()));
1188 self
1189 }
1190
1191 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
1195 where
1196 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
1197 Fut: Future<Output = Result<(), Err>> + Send + 'static,
1198 Err: ToString + Send + 'static,
1199 {
1200 if self.health.health_path.is_none() {
1201 self.health.health_path = Some("/health".to_string());
1202 }
1203
1204 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
1205 let fut = check(resources);
1206 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
1207 });
1208
1209 self.health.checks.push(NamedHealthCheck {
1210 name: name.into(),
1211 check: check_fn,
1212 });
1213 self
1214 }
1215
1216 pub fn readiness_liveness(
1218 mut self,
1219 readiness_path: impl Into<String>,
1220 liveness_path: impl Into<String>,
1221 ) -> Self {
1222 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1223 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1224 self
1225 }
1226
1227 pub fn readiness_liveness_default(self) -> Self {
1229 self.readiness_liveness("/ready", "/live")
1230 }
1231
1232 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1234 where
1235 Out: IntoResponse + Send + Sync + 'static,
1236 E: Send + 'static + std::fmt::Debug,
1237 {
1238 self.route_method(Method::GET, path, circuit)
1239 }
1240
1241 pub fn route_group(mut self, group: RouteGroup<R>) -> Self {
1277 for entry in group.into_entries("") {
1278 self.routes.push(entry);
1279 }
1280 self
1281 }
1282
1283
1284 pub fn route_method<Out, E>(
1292 self,
1293 method: Method,
1294 path: impl Into<String>,
1295 circuit: Axon<(), Out, E, R>,
1296 ) -> Self
1297 where
1298 Out: IntoResponse + Send + Sync + 'static,
1299 E: Send + 'static + std::fmt::Debug,
1300 {
1301 self.route_method_with_error(method, path, circuit, |error| {
1302 (
1303 StatusCode::INTERNAL_SERVER_ERROR,
1304 format!("Error: {:?}", error),
1305 )
1306 .into_response()
1307 })
1308 }
1309
1310 pub fn route_method_with_error<Out, E, H>(
1311 self,
1312 method: Method,
1313 path: impl Into<String>,
1314 circuit: Axon<(), Out, E, R>,
1315 error_handler: H,
1316 ) -> Self
1317 where
1318 Out: IntoResponse + Send + Sync + 'static,
1319 E: Send + 'static + std::fmt::Debug,
1320 H: Fn(&E) -> Response<Full<Bytes>> + Send + Sync + 'static,
1321 {
1322 self.route_method_with_error_and_layers(
1323 method,
1324 path,
1325 circuit,
1326 error_handler,
1327 Arc::new(Vec::new()),
1328 true,
1329 )
1330 }
1331
1332 pub fn route_method_with_layer<Out, E, L>(
1333 self,
1334 method: Method,
1335 path: impl Into<String>,
1336 circuit: Axon<(), Out, E, R>,
1337 layer: L,
1338 ) -> Self
1339 where
1340 Out: IntoResponse + Send + Sync + 'static,
1341 E: Send + 'static + std::fmt::Debug,
1342 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1343 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
1344 + Clone
1345 + Send
1346 + 'static,
1347 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1348 {
1349 self.route_method_with_error_and_layers(
1350 method,
1351 path,
1352 circuit,
1353 |error| {
1354 (
1355 StatusCode::INTERNAL_SERVER_ERROR,
1356 format!("Error: {:?}", error),
1357 )
1358 .into_response()
1359 },
1360 Arc::new(vec![to_service_layer(layer)]),
1361 true,
1362 )
1363 }
1364
1365 pub fn route_method_with_layer_override<Out, E, L>(
1366 self,
1367 method: Method,
1368 path: impl Into<String>,
1369 circuit: Axon<(), Out, E, R>,
1370 layer: L,
1371 ) -> Self
1372 where
1373 Out: IntoResponse + Send + Sync + 'static,
1374 E: Send + 'static + std::fmt::Debug,
1375 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1376 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
1377 + Clone
1378 + Send
1379 + 'static,
1380 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1381 {
1382 self.route_method_with_error_and_layers(
1383 method,
1384 path,
1385 circuit,
1386 |error| {
1387 (
1388 StatusCode::INTERNAL_SERVER_ERROR,
1389 format!("Error: {:?}", error),
1390 )
1391 .into_response()
1392 },
1393 Arc::new(vec![to_service_layer(layer)]),
1394 false,
1395 )
1396 }
1397
1398 fn route_method_with_error_and_layers<Out, E, H>(
1399 mut self,
1400 method: Method,
1401 path: impl Into<String>,
1402 circuit: Axon<(), Out, E, R>,
1403 error_handler: H,
1404 route_layers: Arc<Vec<ServiceLayer>>,
1405 apply_global_layers: bool,
1406 ) -> Self
1407 where
1408 Out: IntoResponse + Send + Sync + 'static,
1409 E: Send + 'static + std::fmt::Debug,
1410 H: Fn(&E) -> Response<Full<Bytes>> + Send + Sync + 'static,
1411 {
1412 let path_str: String = path.into();
1413 let circuit = Arc::new(circuit);
1414 let error_handler = Arc::new(error_handler);
1415 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1416 let path_for_pattern = path_str.clone();
1417 let path_for_handler = path_str;
1418 let method_for_pattern = method.clone();
1419 let method_for_handler = method;
1420
1421 let handler: RouteHandler<R> = Arc::new(move |req: Request<Incoming>, res: &R| {
1422 let circuit = circuit.clone();
1423 let error_handler = error_handler.clone();
1424 let route_bus_injectors = route_bus_injectors.clone();
1425 let res = res.clone();
1426 let path = path_for_handler.clone();
1427 let method = method_for_handler.clone();
1428
1429 Box::pin(async move {
1430 let request_id = uuid::Uuid::new_v4().to_string();
1431 let span = tracing::info_span!(
1432 "HTTPRequest",
1433 ranvier.http.method = %method,
1434 ranvier.http.path = %path,
1435 ranvier.http.request_id = %request_id
1436 );
1437
1438 async move {
1439 let mut bus = Bus::new();
1440 for injector in route_bus_injectors.iter() {
1441 injector(&req, &mut bus);
1442 }
1443 let result = circuit.execute((), &res, &mut bus).await;
1444 outcome_to_response_with_error(result, |error| error_handler(error))
1445 }
1446 .instrument(span)
1447 .await
1448 }) as Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
1449 });
1450
1451 self.routes.push(RouteEntry {
1452 method: method_for_pattern,
1453 pattern: RoutePattern::parse(&path_for_pattern),
1454 handler,
1455 layers: route_layers,
1456 apply_global_layers,
1457 });
1458 self
1459 }
1460
1461 fn route_method_with_body<Out, E>(
1463 mut self,
1464 method: Method,
1465 path: impl Into<String>,
1466 circuit: Axon<(), Out, E, R>,
1467 ) -> Self
1468 where
1469 Out: IntoResponse + Send + Sync + 'static,
1470 E: Send + 'static + std::fmt::Debug,
1471 {
1472 use crate::extract::HttpRequestBody;
1473
1474 let path_str: String = path.into();
1475 let circuit = Arc::new(circuit);
1476 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1477 let path_for_pattern = path_str.clone();
1478 let path_for_handler = path_str;
1479 let method_for_pattern = method.clone();
1480 let method_for_handler = method;
1481
1482 let handler: RouteHandler<R> = Arc::new(move |req: Request<Incoming>, res: &R| {
1483 let circuit = circuit.clone();
1484 let route_bus_injectors = route_bus_injectors.clone();
1485 let res = res.clone();
1486 let path = path_for_handler.clone();
1487 let method = method_for_handler.clone();
1488
1489 Box::pin(async move {
1490 let request_id = uuid::Uuid::new_v4().to_string();
1491 let span = tracing::info_span!(
1492 "HTTPRequest",
1493 ranvier.http.method = %method,
1494 ranvier.http.path = %path,
1495 ranvier.http.request_id = %request_id
1496 );
1497
1498 async move {
1499 let (parts, body) = req.into_parts();
1501
1502 let body_bytes = match body.collect().await {
1504 Ok(collected) => collected.to_bytes(),
1505 Err(err) => {
1506 tracing::warn!("Failed to collect request body: {:?}", err);
1507 Bytes::new()
1508 }
1509 };
1510
1511 let mut bus = Bus::new();
1512 let stub_req = {
1516 let mut builder = Request::builder()
1517 .method(&parts.method)
1518 .uri(parts.uri.clone());
1519 for (k, v) in &parts.headers {
1520 builder = builder.header(k, v);
1521 }
1522 let mut stub = builder
1523 .body(Full::new(Bytes::new()))
1524 .unwrap_or_else(|_| Request::new(Full::new(Bytes::new())));
1525 *stub.extensions_mut() = parts.extensions.clone();
1526 stub
1527 };
1528 if let Some(params) = stub_req.extensions().get::<PathParams>() {
1536 bus.insert(params.clone());
1537 }
1538
1539 bus.insert(HttpRequestBody::new(body_bytes));
1541
1542 let result = circuit.execute((), &res, &mut bus).await;
1543 outcome_to_response_with_error(result, |error| {
1544 Response::builder()
1545 .status(StatusCode::INTERNAL_SERVER_ERROR)
1546 .body(Full::new(Bytes::from(format!("Error: {:?}", error))))
1547 .unwrap()
1548 })
1549 }
1550 .instrument(span)
1551 .await
1552 }) as Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
1553 });
1554
1555 self.routes.push(RouteEntry {
1556 method: method_for_pattern,
1557 pattern: RoutePattern::parse(&path_for_pattern),
1558 handler,
1559 layers: Arc::new(Vec::new()),
1560 apply_global_layers: true,
1561 });
1562 self
1563 }
1564
1565 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1566 where
1567 Out: IntoResponse + Send + Sync + 'static,
1568 E: Send + 'static + std::fmt::Debug,
1569 {
1570 self.route_method(Method::GET, path, circuit)
1571 }
1572
1573 pub fn get_with_error<Out, E, H>(
1574 self,
1575 path: impl Into<String>,
1576 circuit: Axon<(), Out, E, R>,
1577 error_handler: H,
1578 ) -> Self
1579 where
1580 Out: IntoResponse + Send + Sync + 'static,
1581 E: Send + 'static + std::fmt::Debug,
1582 H: Fn(&E) -> Response<Full<Bytes>> + Send + Sync + 'static,
1583 {
1584 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1585 }
1586
1587 pub fn get_with_layer<Out, E, L>(
1588 self,
1589 path: impl Into<String>,
1590 circuit: Axon<(), Out, E, R>,
1591 layer: L,
1592 ) -> Self
1593 where
1594 Out: IntoResponse + Send + Sync + 'static,
1595 E: Send + 'static + std::fmt::Debug,
1596 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1597 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
1598 + Clone
1599 + Send
1600 + 'static,
1601 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1602 {
1603 self.route_method_with_layer(Method::GET, path, circuit, layer)
1604 }
1605
1606 pub fn get_with_layer_override<Out, E, L>(
1607 self,
1608 path: impl Into<String>,
1609 circuit: Axon<(), Out, E, R>,
1610 layer: L,
1611 ) -> Self
1612 where
1613 Out: IntoResponse + Send + Sync + 'static,
1614 E: Send + 'static + std::fmt::Debug,
1615 L: Layer<BoxHttpService> + Clone + Send + Sync + 'static,
1616 L::Service: Service<Request<Incoming>, Response = Response<Full<Bytes>>, Error = Infallible>
1617 + Clone
1618 + Send
1619 + 'static,
1620 <L::Service as Service<Request<Incoming>>>::Future: Send + 'static,
1621 {
1622 self.route_method_with_layer_override(Method::GET, path, circuit, layer)
1623 }
1624
1625 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1626 where
1627 Out: IntoResponse + Send + Sync + 'static,
1628 E: Send + 'static + std::fmt::Debug,
1629 {
1630 self.route_method(Method::POST, path, circuit)
1631 }
1632
1633 pub fn post_body<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1654 where
1655 Out: IntoResponse + Send + Sync + 'static,
1656 E: Send + 'static + std::fmt::Debug,
1657 {
1658 self.route_method_with_body(Method::POST, path, circuit)
1659 }
1660
1661 pub fn put_body<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1663 where
1664 Out: IntoResponse + Send + Sync + 'static,
1665 E: Send + 'static + std::fmt::Debug,
1666 {
1667 self.route_method_with_body(Method::PUT, path, circuit)
1668 }
1669
1670 pub fn patch_body<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1672 where
1673 Out: IntoResponse + Send + Sync + 'static,
1674 E: Send + 'static + std::fmt::Debug,
1675 {
1676 self.route_method_with_body(Method::PATCH, path, circuit)
1677 }
1678
1679 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1680 where
1681 Out: IntoResponse + Send + Sync + 'static,
1682 E: Send + 'static + std::fmt::Debug,
1683 {
1684 self.route_method(Method::PUT, path, circuit)
1685 }
1686
1687 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1688 where
1689 Out: IntoResponse + Send + Sync + 'static,
1690 E: Send + 'static + std::fmt::Debug,
1691 {
1692 self.route_method(Method::DELETE, path, circuit)
1693 }
1694
1695 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1696 where
1697 Out: IntoResponse + Send + Sync + 'static,
1698 E: Send + 'static + std::fmt::Debug,
1699 {
1700 self.route_method(Method::PATCH, path, circuit)
1701 }
1702
1703 pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
1714 where
1715 Out: IntoResponse + Send + Sync + 'static,
1716 E: Send + 'static + std::fmt::Debug,
1717 {
1718 let circuit = Arc::new(circuit);
1719 let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
1720
1721 let handler: RouteHandler<R> = Arc::new(move |req: Request<Incoming>, res: &R| {
1722 let circuit = circuit.clone();
1723 let fallback_bus_injectors = fallback_bus_injectors.clone();
1724 let res = res.clone();
1725 Box::pin(async move {
1726 let request_id = uuid::Uuid::new_v4().to_string();
1727 let span = tracing::info_span!(
1728 "HTTPRequest",
1729 ranvier.http.method = "FALLBACK",
1730 ranvier.http.request_id = %request_id
1731 );
1732
1733 async move {
1734 let mut bus = Bus::new();
1735 for injector in fallback_bus_injectors.iter() {
1736 injector(&req, &mut bus);
1737 }
1738 let result = circuit.execute((), &res, &mut bus).await;
1739
1740 match result {
1741 Outcome::Next(output) => {
1742 let mut response = output.into_response();
1743 *response.status_mut() = StatusCode::NOT_FOUND;
1744 response
1745 }
1746 _ => Response::builder()
1747 .status(StatusCode::NOT_FOUND)
1748 .body(Full::new(Bytes::from("Not Found")))
1749 .unwrap(),
1750 }
1751 }
1752 .instrument(span)
1753 .await
1754 }) as Pin<Box<dyn Future<Output = Response<Full<Bytes>>> + Send>>
1755 });
1756
1757 self.fallback = Some(handler);
1758 self
1759 }
1760
1761 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
1763 self.run_with_shutdown_signal(resources, shutdown_signal())
1764 .await
1765 }
1766
1767 async fn run_with_shutdown_signal<S>(
1768 self,
1769 resources: R,
1770 shutdown_signal: S,
1771 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
1772 where
1773 S: Future<Output = ()> + Send,
1774 {
1775 let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
1776 let addr: SocketAddr = addr_str.parse()?;
1777
1778 let routes = Arc::new(self.routes);
1779 let fallback = self.fallback;
1780 let layers = Arc::new(self.layers);
1781 let health = Arc::new(self.health);
1782 let static_assets = Arc::new(self.static_assets);
1783 let on_start = self.on_start;
1784 let on_shutdown = self.on_shutdown;
1785 let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
1786 let resources = Arc::new(resources);
1787
1788 let listener = TcpListener::bind(addr).await?;
1789 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
1790 if let Some(callback) = on_start.as_ref() {
1791 callback();
1792 }
1793
1794 tokio::pin!(shutdown_signal);
1795 let mut connections = tokio::task::JoinSet::new();
1796
1797 loop {
1798 tokio::select! {
1799 _ = &mut shutdown_signal => {
1800 tracing::info!("Shutdown signal received. Draining in-flight connections.");
1801 break;
1802 }
1803 accept_result = listener.accept() => {
1804 let (stream, _) = accept_result?;
1805 let io = TokioIo::new(stream);
1806
1807 let routes = routes.clone();
1808 let fallback = fallback.clone();
1809 let resources = resources.clone();
1810 let layers = layers.clone();
1811 let health = health.clone();
1812 let static_assets = static_assets.clone();
1813
1814 connections.spawn(async move {
1815 let service = build_http_service(
1816 routes,
1817 fallback,
1818 resources,
1819 layers,
1820 health,
1821 static_assets,
1822 );
1823 let hyper_service = TowerToHyperService::new(service);
1824 if let Err(err) = http1::Builder::new()
1825 .serve_connection(io, hyper_service)
1826 .with_upgrades()
1827 .await
1828 {
1829 tracing::error!("Error serving connection: {:?}", err);
1830 }
1831 });
1832 }
1833 Some(join_result) = connections.join_next(), if !connections.is_empty() => {
1834 if let Err(err) = join_result {
1835 tracing::warn!("Connection task join error: {:?}", err);
1836 }
1837 }
1838 }
1839 }
1840
1841 let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;
1842
1843 drop(resources);
1844 if let Some(callback) = on_shutdown.as_ref() {
1845 callback();
1846 }
1847
1848 Ok(())
1849 }
1850
1851 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
1867 let routes = Arc::new(self.routes);
1868 let fallback = self.fallback;
1869 let layers = Arc::new(self.layers);
1870 let health = Arc::new(self.health);
1871 let static_assets = Arc::new(self.static_assets);
1872 let resources = Arc::new(resources);
1873
1874 RawIngressService {
1875 routes,
1876 fallback,
1877 layers,
1878 health,
1879 static_assets,
1880 resources,
1881 }
1882 }
1883}
1884
1885fn build_http_service<R>(
1886 routes: Arc<Vec<RouteEntry<R>>>,
1887 fallback: Option<RouteHandler<R>>,
1888 resources: Arc<R>,
1889 layers: Arc<Vec<ServiceLayer>>,
1890 health: Arc<HealthConfig<R>>,
1891 static_assets: Arc<StaticAssetsConfig>,
1892) -> BoxHttpService
1893where
1894 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1895{
1896 let base_service = service_fn(move |req: Request<Incoming>| {
1897 let routes = routes.clone();
1898 let fallback = fallback.clone();
1899 let resources = resources.clone();
1900 let layers = layers.clone();
1901 let health = health.clone();
1902 let static_assets = static_assets.clone();
1903
1904 async move {
1905 let mut req = req;
1906 let method = req.method().clone();
1907 let path = req.uri().path().to_string();
1908
1909 if let Some(response) =
1910 maybe_handle_health_request(&method, &path, &health, resources.clone()).await
1911 {
1912 return Ok::<_, Infallible>(response);
1913 }
1914
1915 if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
1916 req.extensions_mut().insert(params);
1917 let effective_layers = if entry.apply_global_layers {
1918 merge_layers(&layers, &entry.layers)
1919 } else {
1920 entry.layers.clone()
1921 };
1922
1923 if effective_layers.is_empty() {
1924 Ok::<_, Infallible>((entry.handler)(req, &resources).await)
1925 } else {
1926 let route_service = build_route_service(
1927 entry.handler.clone(),
1928 resources.clone(),
1929 effective_layers,
1930 );
1931 route_service.oneshot(req).await
1932 }
1933 } else {
1934 let req =
1935 match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
1936 .await
1937 {
1938 Ok(req) => req,
1939 Err(response) => return Ok(response),
1940 };
1941
1942 if let Some(ref fb) = fallback {
1943 if layers.is_empty() {
1944 Ok(fb(req, &resources).await)
1945 } else {
1946 let fallback_service =
1947 build_route_service(fb.clone(), resources.clone(), layers.clone());
1948 fallback_service.oneshot(req).await
1949 }
1950 } else {
1951 Ok(Response::builder()
1952 .status(StatusCode::NOT_FOUND)
1953 .body(Full::new(Bytes::from("Not Found")))
1954 .unwrap())
1955 }
1956 }
1957 }
1958 });
1959
1960 BoxCloneService::new(base_service)
1961}
1962
1963fn build_route_service<R>(
1964 handler: RouteHandler<R>,
1965 resources: Arc<R>,
1966 layers: Arc<Vec<ServiceLayer>>,
1967) -> BoxHttpService
1968where
1969 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
1970{
1971 let base_service = service_fn(move |req: Request<Incoming>| {
1972 let handler = handler.clone();
1973 let resources = resources.clone();
1974 async move { Ok::<_, Infallible>(handler(req, &resources).await) }
1975 });
1976
1977 let mut service = BoxCloneService::new(base_service);
1978 for layer in layers.iter() {
1979 service = layer(service);
1980 }
1981 service
1982}
1983
1984fn merge_layers(
1985 global_layers: &Arc<Vec<ServiceLayer>>,
1986 route_layers: &Arc<Vec<ServiceLayer>>,
1987) -> Arc<Vec<ServiceLayer>> {
1988 if global_layers.is_empty() {
1989 return route_layers.clone();
1990 }
1991 if route_layers.is_empty() {
1992 return global_layers.clone();
1993 }
1994
1995 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
1996 combined.extend(global_layers.iter().cloned());
1997 combined.extend(route_layers.iter().cloned());
1998 Arc::new(combined)
1999}
2000
2001async fn maybe_handle_health_request<R>(
2002 method: &Method,
2003 path: &str,
2004 health: &HealthConfig<R>,
2005 resources: Arc<R>,
2006) -> Option<Response<Full<Bytes>>>
2007where
2008 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2009{
2010 if method != Method::GET {
2011 return None;
2012 }
2013
2014 if let Some(liveness_path) = health.liveness_path.as_ref() {
2015 if path == liveness_path {
2016 return Some(health_json_response("liveness", true, Vec::new()));
2017 }
2018 }
2019
2020 if let Some(readiness_path) = health.readiness_path.as_ref() {
2021 if path == readiness_path {
2022 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
2023 return Some(health_json_response("readiness", healthy, checks));
2024 }
2025 }
2026
2027 if let Some(health_path) = health.health_path.as_ref() {
2028 if path == health_path {
2029 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
2030 return Some(health_json_response("health", healthy, checks));
2031 }
2032 }
2033
2034 None
2035}
2036
2037async fn maybe_handle_static_request(
2038 req: Request<Incoming>,
2039 method: &Method,
2040 path: &str,
2041 static_assets: &StaticAssetsConfig,
2042) -> Result<Request<Incoming>, Response<Full<Bytes>>> {
2043 if method != Method::GET && method != Method::HEAD {
2044 return Ok(req);
2045 }
2046
2047 if let Some(mount) = static_assets
2048 .mounts
2049 .iter()
2050 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
2051 {
2052 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
2053 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
2054 return Ok(req);
2055 };
2056 let rewritten = rewrite_request_path(req, &stripped_path);
2057 let service = ServeDir::new(&mount.directory);
2058 let response = match service.oneshot(rewritten).await {
2059 Ok(response) => response,
2060 Err(_) => {
2061 return Err(Response::builder()
2062 .status(StatusCode::INTERNAL_SERVER_ERROR)
2063 .body(Full::new(Bytes::from("Failed to serve static asset")))
2064 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new()))));
2065 }
2066 };
2067 let response =
2068 collect_static_response(response, static_assets.cache_control.as_deref()).await;
2069 return Err(maybe_compress_static_response(
2070 response,
2071 accept_encoding,
2072 static_assets.enable_compression,
2073 )
2074 .await);
2075 }
2076
2077 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
2078 if looks_like_spa_request(path) {
2079 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
2080 let service = ServeFile::new(spa_file);
2081 let response = match service.oneshot(req).await {
2082 Ok(response) => response,
2083 Err(_) => {
2084 return Err(Response::builder()
2085 .status(StatusCode::INTERNAL_SERVER_ERROR)
2086 .body(Full::new(Bytes::from("Failed to serve SPA fallback")))
2087 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new()))));
2088 }
2089 };
2090 let response =
2091 collect_static_response(response, static_assets.cache_control.as_deref()).await;
2092 return Err(maybe_compress_static_response(
2093 response,
2094 accept_encoding,
2095 static_assets.enable_compression,
2096 )
2097 .await);
2098 }
2099 }
2100
2101 Ok(req)
2102}
2103
/// Returns the part of `path` below the mount `prefix`, re-rooted at `/`.
///
/// * A prefix of exactly `/` mounts at the root and returns `path` unchanged.
/// * Otherwise trailing slashes are ignored: `path` equal to the prefix maps to `/`,
///   and `prefix/rest` maps to `/rest`.
/// * Returns `None` when `path` is not under the prefix. A path that merely starts with
///   the same characters (`/assetsx` vs `/assets`) does not count.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    if prefix == "/" {
        return Some(path.to_owned());
    }

    let base = prefix.trim_end_matches('/');
    if path == base {
        return Some(String::from("/"));
    }

    let rest = path.strip_prefix(base)?.strip_prefix('/')?;
    Some(format!("/{rest}"))
}
2123
2124fn rewrite_request_path(mut req: Request<Incoming>, new_path: &str) -> Request<Incoming> {
2125 let query = req.uri().query().map(str::to_string);
2126 let path_and_query = match query {
2127 Some(query) => format!("{new_path}?{query}"),
2128 None => new_path.to_string(),
2129 };
2130
2131 let mut parts = req.uri().clone().into_parts();
2132 if let Ok(parsed_path_and_query) = path_and_query.parse() {
2133 parts.path_and_query = Some(parsed_path_and_query);
2134 if let Ok(uri) = Uri::from_parts(parts) {
2135 *req.uri_mut() = uri;
2136 }
2137 }
2138
2139 req
2140}
2141
2142async fn collect_static_response<B>(
2143 response: Response<B>,
2144 cache_control: Option<&str>,
2145) -> Response<Full<Bytes>>
2146where
2147 B: Body<Data = Bytes> + Send + 'static,
2148 B::Error: std::fmt::Display,
2149{
2150 let status = response.status();
2151 let headers = response.headers().clone();
2152 let body = response.into_body();
2153 let collected = body.collect().await;
2154
2155 let bytes = match collected {
2156 Ok(value) => value.to_bytes(),
2157 Err(error) => Bytes::from(error.to_string()),
2158 };
2159
2160 let mut builder = Response::builder().status(status);
2161 for (name, value) in headers.iter() {
2162 builder = builder.header(name, value);
2163 }
2164
2165 let mut response = builder
2166 .body(Full::new(bytes))
2167 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())));
2168
2169 if status == StatusCode::OK {
2170 if let Some(value) = cache_control {
2171 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
2172 if let Ok(header_value) = http::HeaderValue::from_str(value) {
2173 response
2174 .headers_mut()
2175 .insert(http::header::CACHE_CONTROL, header_value);
2176 }
2177 }
2178 }
2179 }
2180
2181 response
2182}
2183
/// Heuristic for client-side routes: the last path segment contains no `.`, so it does
/// not look like a file name such as `app.js`.
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = match path.rfind('/') {
        Some(slash) => &path[slash + 1..],
        None => path,
    };
    !last_segment.contains('.')
}
2188
2189async fn maybe_compress_static_response(
2190 response: Response<Full<Bytes>>,
2191 accept_encoding: Option<http::HeaderValue>,
2192 enable_compression: bool,
2193) -> Response<Full<Bytes>> {
2194 if !enable_compression {
2195 return response;
2196 }
2197
2198 let Some(accept_encoding) = accept_encoding else {
2199 return response;
2200 };
2201
2202 let mut request = Request::builder()
2203 .uri("/")
2204 .body(Full::new(Bytes::new()))
2205 .unwrap_or_else(|_| Request::new(Full::new(Bytes::new())));
2206 request
2207 .headers_mut()
2208 .insert(http::header::ACCEPT_ENCODING, accept_encoding);
2209
2210 let service = CompressionLayer::new().layer(service_fn({
2211 let response = response.clone();
2212 move |_req: Request<Full<Bytes>>| {
2213 let response = response.clone();
2214 async move { Ok::<_, Infallible>(response) }
2215 }
2216 }));
2217
2218 match service.oneshot(request).await {
2219 Ok(compressed) => collect_static_response(compressed, None).await,
2220 Err(_) => response,
2221 }
2222}
2223
2224async fn run_named_health_checks<R>(
2225 checks: &[NamedHealthCheck<R>],
2226 resources: Arc<R>,
2227) -> (bool, Vec<HealthCheckReport>)
2228where
2229 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2230{
2231 let mut reports = Vec::with_capacity(checks.len());
2232 let mut healthy = true;
2233
2234 for check in checks {
2235 match (check.check)(resources.clone()).await {
2236 Ok(()) => reports.push(HealthCheckReport {
2237 name: check.name.clone(),
2238 status: "ok",
2239 error: None,
2240 }),
2241 Err(error) => {
2242 healthy = false;
2243 reports.push(HealthCheckReport {
2244 name: check.name.clone(),
2245 status: "error",
2246 error: Some(error),
2247 });
2248 }
2249 }
2250 }
2251
2252 (healthy, reports)
2253}
2254
2255fn health_json_response(
2256 probe: &'static str,
2257 healthy: bool,
2258 checks: Vec<HealthCheckReport>,
2259) -> Response<Full<Bytes>> {
2260 let status_code = if healthy {
2261 StatusCode::OK
2262 } else {
2263 StatusCode::SERVICE_UNAVAILABLE
2264 };
2265 let status = if healthy { "ok" } else { "degraded" };
2266 let payload = HealthReport {
2267 status,
2268 probe,
2269 checks,
2270 };
2271
2272 let body = serde_json::to_vec(&payload)
2273 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
2274
2275 Response::builder()
2276 .status(status_code)
2277 .header(http::header::CONTENT_TYPE, "application/json")
2278 .body(Full::new(Bytes::from(body)))
2279 .unwrap()
2280}
2281
2282async fn shutdown_signal() {
2283 #[cfg(unix)]
2284 {
2285 use tokio::signal::unix::{SignalKind, signal};
2286
2287 match signal(SignalKind::terminate()) {
2288 Ok(mut terminate) => {
2289 tokio::select! {
2290 _ = tokio::signal::ctrl_c() => {}
2291 _ = terminate.recv() => {}
2292 }
2293 }
2294 Err(err) => {
2295 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
2296 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
2297 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
2298 }
2299 }
2300 }
2301 }
2302
2303 #[cfg(not(unix))]
2304 {
2305 if let Err(err) = tokio::signal::ctrl_c().await {
2306 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
2307 }
2308 }
2309}
2310
2311async fn drain_connections(
2312 connections: &mut tokio::task::JoinSet<()>,
2313 graceful_shutdown_timeout: Duration,
2314) -> bool {
2315 if connections.is_empty() {
2316 return false;
2317 }
2318
2319 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
2320 while let Some(join_result) = connections.join_next().await {
2321 if let Err(err) = join_result {
2322 tracing::warn!("Connection task join error during shutdown: {:?}", err);
2323 }
2324 }
2325 })
2326 .await;
2327
2328 if drain_result.is_err() {
2329 tracing::warn!(
2330 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
2331 graceful_shutdown_timeout
2332 );
2333 connections.abort_all();
2334 while let Some(join_result) = connections.join_next().await {
2335 if let Err(err) = join_result {
2336 tracing::warn!("Connection task abort join error: {:?}", err);
2337 }
2338 }
2339 true
2340 } else {
2341 false
2342 }
2343}
2344
2345impl<R> Default for HttpIngress<R>
2346where
2347 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2348{
2349 fn default() -> Self {
2350 Self::new()
2351 }
2352}
2353
2354#[deprecated(since = "0.9.0", note = "Internal service type")]
2356#[derive(Clone)]
2357pub struct RawIngressService<R> {
2358 routes: Arc<Vec<RouteEntry<R>>>,
2359 fallback: Option<RouteHandler<R>>,
2360 layers: Arc<Vec<ServiceLayer>>,
2361 health: Arc<HealthConfig<R>>,
2362 static_assets: Arc<StaticAssetsConfig>,
2363 resources: Arc<R>,
2364}
2365
2366impl<R> Service<Request<Incoming>> for RawIngressService<R>
2367where
2368 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2369{
2370 type Response = Response<Full<Bytes>>;
2371 type Error = Infallible;
2372 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2373
2374 fn poll_ready(
2375 &mut self,
2376 _cx: &mut std::task::Context<'_>,
2377 ) -> std::task::Poll<Result<(), Self::Error>> {
2378 std::task::Poll::Ready(Ok(()))
2379 }
2380
2381 fn call(&mut self, req: Request<Incoming>) -> Self::Future {
2382 let routes = self.routes.clone();
2383 let fallback = self.fallback.clone();
2384 let layers = self.layers.clone();
2385 let health = self.health.clone();
2386 let static_assets = self.static_assets.clone();
2387 let resources = self.resources.clone();
2388
2389 Box::pin(async move {
2390 let service =
2391 build_http_service(routes, fallback, resources, layers, health, static_assets);
2392 service.oneshot(req).await
2393 })
2394 }
2395}
2396
2397#[cfg(test)]
2398mod tests {
2399 use super::*;
2400 use async_trait::async_trait;
2401 use futures_util::{SinkExt, StreamExt};
2402 use ranvier_observe::{HttpMetrics, HttpMetricsLayer, IncomingTraceContext, TraceContextLayer};
2403 use serde::Deserialize;
2404 use std::fs;
2405 use std::sync::atomic::{AtomicBool, Ordering};
2406 use tempfile::tempdir;
2407 use tokio::io::{AsyncReadExt, AsyncWriteExt};
2408 use tokio_tungstenite::tungstenite::Message as WsClientMessage;
2409 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
2410
2411 async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
2412 let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
2413
2414 loop {
2415 match tokio::net::TcpStream::connect(addr).await {
2416 Ok(stream) => return stream,
2417 Err(error) => {
2418 if tokio::time::Instant::now() >= deadline {
2419 panic!("connect server: {error}");
2420 }
2421 tokio::time::sleep(Duration::from_millis(25)).await;
2422 }
2423 }
2424 }
2425 }
2426
2427 #[test]
2428 fn route_pattern_matches_static_path() {
2429 let pattern = RoutePattern::parse("/orders/list");
2430 let params = pattern.match_path("/orders/list").expect("should match");
2431 assert!(params.into_inner().is_empty());
2432 }
2433
2434 #[test]
2435 fn route_pattern_matches_param_segments() {
2436 let pattern = RoutePattern::parse("/orders/:id/items/:item_id");
2437 let params = pattern
2438 .match_path("/orders/42/items/sku-123")
2439 .expect("should match");
2440 assert_eq!(params.get("id"), Some("42"));
2441 assert_eq!(params.get("item_id"), Some("sku-123"));
2442 }
2443
2444 #[test]
2445 fn route_pattern_matches_wildcard_segment() {
2446 let pattern = RoutePattern::parse("/assets/*path");
2447 let params = pattern
2448 .match_path("/assets/css/theme/light.css")
2449 .expect("should match");
2450 assert_eq!(params.get("path"), Some("css/theme/light.css"));
2451 }
2452
2453 #[test]
2454 fn route_pattern_rejects_non_matching_path() {
2455 let pattern = RoutePattern::parse("/orders/:id");
2456 assert!(pattern.match_path("/users/42").is_none());
2457 }
2458
2459 #[test]
2460 fn graceful_shutdown_timeout_defaults_to_30_seconds() {
2461 let ingress = HttpIngress::<()>::new();
2462 assert_eq!(ingress.graceful_shutdown_timeout, Duration::from_secs(30));
2463 assert!(ingress.layers.is_empty());
2464 assert!(ingress.bus_injectors.is_empty());
2465 assert!(ingress.static_assets.mounts.is_empty());
2466 assert!(ingress.on_start.is_none());
2467 assert!(ingress.on_shutdown.is_none());
2468 }
2469
2470 #[test]
2471 fn layer_registration_stacks_globally() {
2472 let ingress = HttpIngress::<()>::new()
2473 .layer(tower::layer::util::Identity::new())
2474 .layer(tower::layer::util::Identity::new());
2475 assert_eq!(ingress.layers.len(), 2);
2476 }
2477
2478 #[test]
2479 fn layer_accepts_tower_http_cors_layer() {
2480 let ingress = HttpIngress::<()>::new().layer(tower_http::cors::CorsLayer::permissive());
2481 assert_eq!(ingress.layers.len(), 1);
2482 }
2483
2484 #[test]
2485 fn route_without_layer_keeps_empty_route_middleware_stack() {
2486 let ingress =
2487 HttpIngress::<()>::new().get("/ping", Axon::<(), (), Infallible, ()>::new("Ping"));
2488 assert_eq!(ingress.routes.len(), 1);
2489 assert!(ingress.routes[0].layers.is_empty());
2490 assert!(ingress.routes[0].apply_global_layers);
2491 }
2492
2493 #[test]
2494 fn route_with_layer_registers_route_middleware_stack() {
2495 let ingress = HttpIngress::<()>::new().get_with_layer(
2496 "/ping",
2497 Axon::<(), (), Infallible, ()>::new("Ping"),
2498 tower::layer::util::Identity::new(),
2499 );
2500 assert_eq!(ingress.routes.len(), 1);
2501 assert_eq!(ingress.routes[0].layers.len(), 1);
2502 assert!(ingress.routes[0].apply_global_layers);
2503 }
2504
2505 #[test]
2506 fn route_with_layer_override_disables_global_layers() {
2507 let ingress = HttpIngress::<()>::new().get_with_layer_override(
2508 "/ping",
2509 Axon::<(), (), Infallible, ()>::new("Ping"),
2510 tower::layer::util::Identity::new(),
2511 );
2512 assert_eq!(ingress.routes.len(), 1);
2513 assert_eq!(ingress.routes[0].layers.len(), 1);
2514 assert!(!ingress.routes[0].apply_global_layers);
2515 }
2516
2517 #[test]
2518 fn timeout_layer_registers_builtin_middleware() {
2519 let ingress = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
2520 assert_eq!(ingress.layers.len(), 1);
2521 }
2522
2523 #[test]
2524 fn request_id_layer_registers_builtin_middleware() {
2525 let ingress = HttpIngress::<()>::new().request_id_layer();
2526 assert_eq!(ingress.layers.len(), 1);
2527 }
2528
2529 #[test]
2530 fn compression_layer_registers_builtin_middleware() {
2531 let ingress = HttpIngress::<()>::new().compression_layer();
2532 assert!(ingress.static_assets.enable_compression);
2533 }
2534
2535 #[test]
2536 fn bus_injector_registration_adds_hook() {
2537 let ingress = HttpIngress::<()>::new().bus_injector(|_req, bus| {
2538 bus.insert("ok".to_string());
2539 });
2540 assert_eq!(ingress.bus_injectors.len(), 1);
2541 }
2542
2543 #[test]
2544 fn ws_route_registers_get_route_pattern() {
2545 let ingress =
2546 HttpIngress::<()>::new().ws("/ws/events", |_socket, _resources, _bus| async {});
2547 assert_eq!(ingress.routes.len(), 1);
2548 assert_eq!(ingress.routes[0].method, Method::GET);
2549 assert_eq!(ingress.routes[0].pattern.raw, "/ws/events");
2550 }
2551
    /// JSON welcome frame that the `/ws/echo` handler in
    /// `ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus`
    /// sends right after the upgrade; the client deserializes it to check the
    /// connection bus contents.
    #[derive(Debug, Deserialize)]
    struct WsWelcomeFrame {
        // `WebSocketSessionContext::connection_id()` rendered with `to_string()`.
        connection_id: String,
        // `WebSocketSessionContext::path()` as reported by the server.
        path: String,
        // Tenant string the bus injector copied from the `x-tenant-id` header.
        tenant: String,
    }
2558
2559 #[tokio::test]
2560 async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
2561 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2562 let addr = probe.local_addr().expect("local addr");
2563 drop(probe);
2564
2565 let ingress = HttpIngress::<()>::new()
2566 .bind(addr.to_string())
2567 .bus_injector(|req, bus| {
2568 if let Some(value) = req
2569 .headers()
2570 .get("x-tenant-id")
2571 .and_then(|v| v.to_str().ok())
2572 {
2573 bus.insert(value.to_string());
2574 }
2575 })
2576 .ws("/ws/echo", |mut socket, _resources, bus| async move {
2577 let tenant = bus
2578 .read::<String>()
2579 .cloned()
2580 .unwrap_or_else(|| "unknown".to_string());
2581 if let Some(session) = bus.read::<WebSocketSessionContext>() {
2582 let welcome = serde_json::json!({
2583 "connection_id": session.connection_id().to_string(),
2584 "path": session.path(),
2585 "tenant": tenant,
2586 });
2587 let _ = socket.send_json(&welcome).await;
2588 }
2589
2590 while let Some(event) = socket.next_event().await {
2591 match event {
2592 WebSocketEvent::Text(text) => {
2593 let _ = socket.send_event(format!("echo:{text}")).await;
2594 }
2595 WebSocketEvent::Binary(bytes) => {
2596 let _ = socket.send_event(bytes).await;
2597 }
2598 WebSocketEvent::Close => break,
2599 WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
2600 }
2601 }
2602 });
2603
2604 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2605 let server = tokio::spawn(async move {
2606 ingress
2607 .run_with_shutdown_signal((), async move {
2608 let _ = shutdown_rx.await;
2609 })
2610 .await
2611 });
2612
2613 let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
2614 let mut ws_request = ws_uri
2615 .as_str()
2616 .into_client_request()
2617 .expect("ws client request");
2618 ws_request
2619 .headers_mut()
2620 .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
2621 let (mut client, _response) = tokio_tungstenite::connect_async(ws_request)
2622 .await
2623 .expect("websocket connect");
2624
2625 let welcome = client
2626 .next()
2627 .await
2628 .expect("welcome frame")
2629 .expect("welcome frame ok");
2630 let welcome_text = match welcome {
2631 WsClientMessage::Text(text) => text.to_string(),
2632 other => panic!("expected text welcome frame, got {other:?}"),
2633 };
2634 let welcome_payload: WsWelcomeFrame =
2635 serde_json::from_str(&welcome_text).expect("welcome json");
2636 assert_eq!(welcome_payload.path, "/ws/echo");
2637 assert_eq!(welcome_payload.tenant, "acme");
2638 assert!(!welcome_payload.connection_id.is_empty());
2639
2640 client
2641 .send(WsClientMessage::Text("hello".into()))
2642 .await
2643 .expect("send text");
2644 let echo_text = client
2645 .next()
2646 .await
2647 .expect("echo text frame")
2648 .expect("echo text frame ok");
2649 assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));
2650
2651 client
2652 .send(WsClientMessage::Binary(vec![1, 2, 3, 4].into()))
2653 .await
2654 .expect("send binary");
2655 let echo_binary = client
2656 .next()
2657 .await
2658 .expect("echo binary frame")
2659 .expect("echo binary frame ok");
2660 assert_eq!(
2661 echo_binary,
2662 WsClientMessage::Binary(vec![1, 2, 3, 4].into())
2663 );
2664
2665 client.close(None).await.expect("close websocket");
2666
2667 let _ = shutdown_tx.send(());
2668 server
2669 .await
2670 .expect("server join")
2671 .expect("server shutdown should succeed");
2672 }
2673
2674 #[derive(Clone)]
2675 struct EchoTrace;
2676
2677 #[async_trait]
2678 impl Transition<(), String> for EchoTrace {
2679 type Error = Infallible;
2680 type Resources = ();
2681
2682 async fn run(
2683 &self,
2684 _state: (),
2685 _resources: &Self::Resources,
2686 bus: &mut Bus,
2687 ) -> Outcome<String, Self::Error> {
2688 let trace_id = bus
2689 .read::<String>()
2690 .cloned()
2691 .unwrap_or_else(|| "missing-trace".to_string());
2692 Outcome::next(trace_id)
2693 }
2694 }
2695
2696 #[tokio::test]
2697 async fn observe_trace_context_and_metrics_layers_work_with_ingress() {
2698 let metrics = HttpMetrics::default();
2699 let ingress = HttpIngress::<()>::new()
2700 .layer(TraceContextLayer::new())
2701 .layer(HttpMetricsLayer::new(metrics.clone()))
2702 .bus_injector(|req, bus| {
2703 if let Some(trace) = req.extensions().get::<IncomingTraceContext>() {
2704 bus.insert(trace.trace_id().to_string());
2705 }
2706 })
2707 .get(
2708 "/trace",
2709 Axon::<(), (), Infallible, ()>::new("EchoTrace").then(EchoTrace),
2710 );
2711
2712 let app = crate::test_harness::TestApp::new(ingress, ());
2713 let response = app
2714 .send(crate::test_harness::TestRequest::get("/trace").header(
2715 "traceparent",
2716 "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
2717 ))
2718 .await
2719 .expect("request should succeed");
2720
2721 assert_eq!(response.status(), StatusCode::OK);
2722 assert_eq!(
2723 response.text().expect("utf8 response"),
2724 "4bf92f3577b34da6a3ce929d0e0e4736"
2725 );
2726
2727 let snapshot = metrics.snapshot();
2728 assert_eq!(snapshot.requests_total, 1);
2729 assert_eq!(snapshot.requests_error, 0);
2730 }
2731
2732 #[test]
2733 fn route_descriptors_export_http_and_health_paths() {
2734 let ingress = HttpIngress::<()>::new()
2735 .get(
2736 "/orders/:id",
2737 Axon::<(), (), Infallible, ()>::new("OrderById"),
2738 )
2739 .health_endpoint("/healthz")
2740 .readiness_liveness("/readyz", "/livez");
2741
2742 let descriptors = ingress.route_descriptors();
2743
2744 assert!(
2745 descriptors
2746 .iter()
2747 .any(|descriptor| descriptor.method() == Method::GET
2748 && descriptor.path_pattern() == "/orders/:id")
2749 );
2750 assert!(
2751 descriptors
2752 .iter()
2753 .any(|descriptor| descriptor.method() == Method::GET
2754 && descriptor.path_pattern() == "/healthz")
2755 );
2756 assert!(
2757 descriptors
2758 .iter()
2759 .any(|descriptor| descriptor.method() == Method::GET
2760 && descriptor.path_pattern() == "/readyz")
2761 );
2762 assert!(
2763 descriptors
2764 .iter()
2765 .any(|descriptor| descriptor.method() == Method::GET
2766 && descriptor.path_pattern() == "/livez")
2767 );
2768 }
2769
2770 #[tokio::test]
2771 async fn lifecycle_hooks_fire_on_start_and_shutdown() {
2772 let started = Arc::new(AtomicBool::new(false));
2773 let shutdown = Arc::new(AtomicBool::new(false));
2774
2775 let started_flag = started.clone();
2776 let shutdown_flag = shutdown.clone();
2777
2778 let ingress = HttpIngress::<()>::new()
2779 .bind("127.0.0.1:0")
2780 .on_start(move || {
2781 started_flag.store(true, Ordering::SeqCst);
2782 })
2783 .on_shutdown(move || {
2784 shutdown_flag.store(true, Ordering::SeqCst);
2785 })
2786 .graceful_shutdown(Duration::from_millis(50));
2787
2788 ingress
2789 .run_with_shutdown_signal((), async {
2790 tokio::time::sleep(Duration::from_millis(20)).await;
2791 })
2792 .await
2793 .expect("server should exit gracefully");
2794
2795 assert!(started.load(Ordering::SeqCst));
2796 assert!(shutdown.load(Ordering::SeqCst));
2797 }
2798
2799 #[tokio::test]
2800 async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
2801 #[derive(Clone)]
2802 struct SlowDrainRoute;
2803
2804 #[async_trait]
2805 impl Transition<(), &'static str> for SlowDrainRoute {
2806 type Error = Infallible;
2807 type Resources = ();
2808
2809 async fn run(
2810 &self,
2811 _state: (),
2812 _resources: &Self::Resources,
2813 _bus: &mut Bus,
2814 ) -> Outcome<&'static str, Self::Error> {
2815 tokio::time::sleep(Duration::from_millis(120)).await;
2816 Outcome::next("drained-ok")
2817 }
2818 }
2819
2820 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
2821 let addr = probe.local_addr().expect("local addr");
2822 drop(probe);
2823
2824 let ingress = HttpIngress::<()>::new()
2825 .bind(addr.to_string())
2826 .graceful_shutdown(Duration::from_millis(500))
2827 .get(
2828 "/drain",
2829 Axon::<(), (), Infallible, ()>::new("SlowDrain").then(SlowDrainRoute),
2830 );
2831
2832 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
2833 let server = tokio::spawn(async move {
2834 ingress
2835 .run_with_shutdown_signal((), async move {
2836 let _ = shutdown_rx.await;
2837 })
2838 .await
2839 });
2840
2841 let mut stream = connect_with_retry(addr).await;
2842 stream
2843 .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
2844 .await
2845 .expect("write request");
2846
2847 tokio::time::sleep(Duration::from_millis(20)).await;
2848 let _ = shutdown_tx.send(());
2849
2850 let mut buf = Vec::new();
2851 stream.read_to_end(&mut buf).await.expect("read response");
2852 let response = String::from_utf8_lossy(&buf);
2853 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
2854 assert!(response.contains("drained-ok"), "{response}");
2855
2856 server
2857 .await
2858 .expect("server join")
2859 .expect("server shutdown should succeed");
2860 }
2861
2862 #[tokio::test]
2863 async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
2864 let temp = tempdir().expect("tempdir");
2865 let root = temp.path().join("public");
2866 fs::create_dir_all(&root).expect("create dir");
2867 let file = root.join("hello.txt");
2868 fs::write(&file, "hello static").expect("write file");
2869
2870 let ingress =
2871 Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
2872 let app = crate::test_harness::TestApp::new(ingress, ());
2873 let response = app
2874 .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
2875 .await
2876 .expect("request should succeed");
2877
2878 assert_eq!(response.status(), StatusCode::OK);
2879 assert_eq!(response.text().expect("utf8"), "hello static");
2880 assert!(response.header("cache-control").is_some());
2881 let has_metadata_header =
2882 response.header("etag").is_some() || response.header("last-modified").is_some();
2883 assert!(has_metadata_header);
2884 }
2885
2886 #[tokio::test]
2887 async fn spa_fallback_returns_index_for_unmatched_path() {
2888 let temp = tempdir().expect("tempdir");
2889 let index = temp.path().join("index.html");
2890 fs::write(&index, "<html><body>spa</body></html>").expect("write index");
2891
2892 let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
2893 let app = crate::test_harness::TestApp::new(ingress, ());
2894 let response = app
2895 .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
2896 .await
2897 .expect("request should succeed");
2898
2899 assert_eq!(response.status(), StatusCode::OK);
2900 assert!(response.text().expect("utf8").contains("spa"));
2901 }
2902
2903 #[tokio::test]
2904 async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
2905 let temp = tempdir().expect("tempdir");
2906 let root = temp.path().join("public");
2907 fs::create_dir_all(&root).expect("create dir");
2908 let file = root.join("compressed.txt");
2909 fs::write(&file, "compress me ".repeat(400)).expect("write file");
2910
2911 let ingress = Ranvier::http::<()>()
2912 .serve_dir("/static", root.to_string_lossy().to_string())
2913 .compression_layer();
2914 let app = crate::test_harness::TestApp::new(ingress, ());
2915 let response = app
2916 .send(
2917 crate::test_harness::TestRequest::get("/static/compressed.txt")
2918 .header("accept-encoding", "gzip"),
2919 )
2920 .await
2921 .expect("request should succeed");
2922
2923 assert_eq!(response.status(), StatusCode::OK);
2924 assert_eq!(
2925 response
2926 .header("content-encoding")
2927 .and_then(|value| value.to_str().ok()),
2928 Some("gzip")
2929 );
2930 }
2931
2932 #[tokio::test]
2933 async fn drain_connections_completes_before_timeout() {
2934 let mut connections = tokio::task::JoinSet::new();
2935 connections.spawn(async {
2936 tokio::time::sleep(Duration::from_millis(20)).await;
2937 });
2938
2939 let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
2940 assert!(!timed_out);
2941 assert!(connections.is_empty());
2942 }
2943
2944 #[tokio::test]
2945 async fn drain_connections_times_out_and_aborts() {
2946 let mut connections = tokio::task::JoinSet::new();
2947 connections.spawn(async {
2948 tokio::time::sleep(Duration::from_secs(10)).await;
2949 });
2950
2951 let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
2952 assert!(timed_out);
2953 assert!(connections.is_empty());
2954 }
2955
    /// A global `timeout_layer` shorter than the route's work turns the response
    /// into `408` when exercised over a real TCP connection.
    #[tokio::test]
    async fn timeout_layer_returns_408_for_slow_route() {
        // Route that sleeps 80ms, well past the 10ms timeout configured below.
        #[derive(Clone)]
        struct SlowRoute;

        #[async_trait]
        impl Transition<(), &'static str> for SlowRoute {
            type Error = Infallible;
            type Resources = ();

            async fn run(
                &self,
                _state: (),
                _resources: &Self::Resources,
                _bus: &mut Bus,
            ) -> Outcome<&'static str, Self::Error> {
                tokio::time::sleep(Duration::from_millis(80)).await;
                Outcome::next("slow-ok")
            }
        }

        // Reserve a free port, then release it for the ingress to bind.
        // NOTE(review): another process could claim the port between `drop` and
        // the ingress bind; tolerable in tests but a possible flakiness source.
        let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
        let addr = probe.local_addr().expect("local addr");
        drop(probe);

        let ingress = HttpIngress::<()>::new()
            .bind(addr.to_string())
            .timeout_layer(Duration::from_millis(10))
            .get(
                "/slow",
                Axon::<(), (), Infallible, ()>::new("Slow").then(SlowRoute),
            );

        // Serve in the background until `shutdown_tx` fires.
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let server = tokio::spawn(async move {
            ingress
                .run_with_shutdown_signal((), async move {
                    let _ = shutdown_rx.await;
                })
                .await
        });

        // Raw HTTP/1.1 request. `Connection: close` asks the server to close the
        // socket after responding, so `read_to_end` below can finish.
        let mut stream = connect_with_retry(addr).await;
        stream
            .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .expect("write request");

        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.expect("read response");
        let response = String::from_utf8_lossy(&buf);
        assert!(response.starts_with("HTTP/1.1 408"), "{response}");

        // The server must still shut down cleanly after a timed-out request.
        let _ = shutdown_tx.send(());
        server
            .await
            .expect("server join")
            .expect("server shutdown should succeed");
    }
3015
    /// A route registered with `get_with_layer_override` skips the global
    /// layers. The 10ms global timeout therefore does not apply, and the 60ms
    /// route still answers `200`.
    #[tokio::test]
    async fn route_layer_override_bypasses_global_timeout() {
        // Route that sleeps 60ms, longer than the 10ms global timeout below.
        #[derive(Clone)]
        struct SlowRoute;

        #[async_trait]
        impl Transition<(), &'static str> for SlowRoute {
            type Error = Infallible;
            type Resources = ();

            async fn run(
                &self,
                _state: (),
                _resources: &Self::Resources,
                _bus: &mut Bus,
            ) -> Outcome<&'static str, Self::Error> {
                tokio::time::sleep(Duration::from_millis(60)).await;
                Outcome::next("override-ok")
            }
        }

        // Reserve a free port, then release it for the ingress to bind.
        // NOTE(review): another process could claim the port between `drop` and
        // the ingress bind; tolerable in tests but a possible flakiness source.
        let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
        let addr = probe.local_addr().expect("local addr");
        drop(probe);

        // The global timeout is registered, but the override route replaces the
        // global stack with a no-op `Identity` layer.
        let ingress = HttpIngress::<()>::new()
            .bind(addr.to_string())
            .timeout_layer(Duration::from_millis(10))
            .get_with_layer_override(
                "/slow",
                Axon::<(), (), Infallible, ()>::new("SlowOverride").then(SlowRoute),
                tower::layer::util::Identity::new(),
            );

        // Serve in the background until `shutdown_tx` fires.
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let server = tokio::spawn(async move {
            ingress
                .run_with_shutdown_signal((), async move {
                    let _ = shutdown_rx.await;
                })
                .await
        });

        // Raw HTTP/1.1 request. `Connection: close` asks the server to close the
        // socket after responding, so `read_to_end` below can finish.
        let mut stream = connect_with_retry(addr).await;
        stream
            .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .expect("write request");

        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.expect("read response");
        let response = String::from_utf8_lossy(&buf);
        assert!(response.starts_with("HTTP/1.1 200"), "{response}");
        assert!(response.contains("override-ok"), "{response}");

        let _ = shutdown_tx.send(());
        server
            .await
            .expect("server join")
            .expect("server shutdown should succeed");
    }
3077}