1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode};
20use http_body_util::{BodyExt, Full};
21use hyper::body::Incoming;
22use hyper::server::conn::http1;
23use hyper::upgrade::Upgraded;
24use hyper_util::rt::TokioIo;
25use ranvier_core::event::{EventSink, EventSource};
26use ranvier_core::prelude::*;
27use ranvier_runtime::Axon;
28use serde::Serialize;
29use serde::de::DeserializeOwned;
30use sha1::{Digest, Sha1};
31use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::net::SocketAddr;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::net::TcpListener;
39use tokio::sync::Mutex;
40use tokio_tungstenite::WebSocketStream;
41use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
42use tracing::Instrument;
43
44use crate::guard_integration::{
45 GuardExec, GuardIntegration, PreflightConfig, RegisteredGuard, ResponseBodyTransformFn,
46 ResponseExtractorFn,
47};
48use crate::response::{HttpResponse, IntoResponse, json_error_response, outcome_to_response_with_error};
49
50pub struct Ranvier;
55
56impl Ranvier {
57 pub fn http<R>() -> HttpIngress<R>
59 where
60 R: ranvier_core::transition::ResourceRequirement + Clone,
61 {
62 HttpIngress::new()
63 }
64}
65
/// Type-erased route handler.
///
/// Receives the request head (`Parts`) and a borrow of the shared resources `R`, and returns a
/// boxed, `Send` future that resolves to the final HTTP response. Stored behind an `Arc` so the
/// same handler can be cloned into every connection.
type RouteHandler<R> = Arc<
    dyn Fn(http::request::Parts, &R) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>>
        + Send
        + Sync,
>;
72
73#[derive(Clone)]
75struct BoxService(
76 Arc<
77 dyn Fn(Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>
78 + Send
79 + Sync,
80 >,
81);
82
83impl BoxService {
84 fn new<F, Fut>(f: F) -> Self
85 where
86 F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
87 Fut: Future<Output = Result<HttpResponse, Infallible>> + Send + 'static,
88 {
89 Self(Arc::new(move |req| Box::pin(f(req))))
90 }
91
92 fn call(&self, req: Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>> {
93 (self.0)(req)
94 }
95}
96
97impl hyper::service::Service<Request<Incoming>> for BoxService {
98 type Response = HttpResponse;
99 type Error = Infallible;
100 type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>;
101
102 fn call(&self, req: Request<Incoming>) -> Self::Future {
103 (self.0)(req)
104 }
105}
106
/// Service type that middleware layers consume and produce.
type BoxHttpService = BoxService;
/// A middleware layer: wraps one service and returns the wrapped service.
type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
/// Parameterless callback stored by `on_start` / `on_shutdown`.
type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
/// Copies information from the request head into the per-request `Bus`.
/// Route handlers run these before the guards.
type BusInjector = Arc<dyn Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static>;
/// Boxed future driving one WebSocket session to completion.
type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased user WebSocket handler: connection, shared resources, and prepared `Bus`.
type WsSessionHandler<R> =
    Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
/// Boxed future of a health check; `Err` carries a human-readable failure message.
type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
/// Type-erased health check receiving the shared resources.
type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
/// Header used by the request-id middleware to propagate or assign a request id.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Token expected in the `Upgrade` header of a WebSocket handshake.
const WS_UPGRADE_TOKEN: &str = "websocket";
/// Fixed GUID from RFC 6455. It is appended to `Sec-WebSocket-Key` before SHA-1 hashing to
/// derive `Sec-WebSocket-Accept`.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
119
120#[derive(Clone)]
121struct NamedHealthCheck<R> {
122 name: String,
123 check: HealthCheckFn<R>,
124}
125
126#[derive(Clone)]
127struct HealthConfig<R> {
128 health_path: Option<String>,
129 readiness_path: Option<String>,
130 liveness_path: Option<String>,
131 checks: Vec<NamedHealthCheck<R>>,
132}
133
134impl<R> Default for HealthConfig<R> {
135 fn default() -> Self {
136 Self {
137 health_path: None,
138 readiness_path: None,
139 liveness_path: None,
140 checks: Vec::new(),
141 }
142 }
143}
144
/// Static-file serving options accumulated by the `HttpIngress` builder.
#[derive(Clone, Default)]
struct StaticAssetsConfig {
    // Route-prefix → directory mounts added by `serve_dir`.
    mounts: Vec<StaticMount>,
    // File path set by `spa_fallback` (presumably served for unmatched SPA routes — confirm).
    spa_fallback: Option<String>,
    // Cache-Control value; `serve_dir` defaults it to "public, max-age=3600" when unset.
    cache_control: Option<String>,
    // Set by `compression_layer`.
    enable_compression: bool,
    // Index file name set by `directory_index`.
    directory_index: Option<String>,
    // Set by `immutable_cache`.
    immutable_cache: bool,
    // Set by `serve_precompressed`.
    serve_precompressed: bool,
    // Set by `enable_range_requests`.
    enable_range_requests: bool,
}

/// One static mount: requests under `route_prefix` map to files in `directory`.
#[derive(Clone)]
struct StaticMount {
    // Normalized to start with '/' by `serve_dir`.
    route_prefix: String,
    directory: String,
}

/// Certificate and private-key file paths recorded by `HttpIngress::tls`.
#[cfg(feature = "tls")]
#[derive(Clone)]
struct TlsAcceptorConfig {
    cert_path: String,
    key_path: String,
}
175
/// Serializable health-probe report.
///
/// NOTE(review): presumably rendered as the JSON body of the health/readiness/liveness
/// endpoints; the producer is not visible in this chunk.
// Field order here is the serialized key order; do not reorder.
#[derive(Serialize)]
struct HealthReport {
    status: &'static str,
    probe: &'static str,
    checks: Vec<HealthCheckReport>,
}

/// Outcome of one named health check inside a `HealthReport`.
#[derive(Serialize)]
struct HealthCheckReport {
    name: String,
    status: &'static str,
    // The `error` key is omitted entirely when the check has no error message.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
190
191fn timeout_middleware(timeout: Duration) -> ServiceLayer {
192 Arc::new(move |inner: BoxHttpService| {
193 BoxService::new(move |req: Request<Incoming>| {
194 let inner = inner.clone();
195 async move {
196 match tokio::time::timeout(timeout, inner.call(req)).await {
197 Ok(response) => response,
198 Err(_) => Ok(Response::builder()
199 .status(StatusCode::REQUEST_TIMEOUT)
200 .body(
201 Full::new(Bytes::from("Request Timeout"))
202 .map_err(|never| match never {})
203 .boxed(),
204 )
205 .expect("valid HTTP response construction")),
206 }
207 }
208 })
209 })
210}
211
212fn request_id_middleware() -> ServiceLayer {
213 Arc::new(move |inner: BoxHttpService| {
214 BoxService::new(move |req: Request<Incoming>| {
215 let inner = inner.clone();
216 async move {
217 let mut req = req;
218 let request_id = req
219 .headers()
220 .get(REQUEST_ID_HEADER)
221 .cloned()
222 .unwrap_or_else(|| {
223 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
224 .unwrap_or_else(|_| {
225 http::HeaderValue::from_static("request-id-unavailable")
226 })
227 });
228 req.headers_mut()
229 .insert(REQUEST_ID_HEADER, request_id.clone());
230 let mut response = inner.call(req).await?;
231 response
232 .headers_mut()
233 .insert(REQUEST_ID_HEADER, request_id);
234 Ok(response)
235 }
236 })
237 })
238}
239
/// Parameters captured from the request path by a route pattern.
///
/// Keys are the names of `:name` and `*name` pattern segments; values are the matching text
/// from the request path.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PathParams {
    values: HashMap<String, String>,
}
244
245#[derive(Clone, Debug)]
247pub struct HttpRouteDescriptor {
248 method: Method,
249 path_pattern: String,
250 pub body_schema: Option<serde_json::Value>,
252}
253
254impl HttpRouteDescriptor {
255 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
256 Self {
257 method,
258 path_pattern: path_pattern.into(),
259 body_schema: None,
260 }
261 }
262
263 pub fn method(&self) -> &Method {
264 &self.method
265 }
266
267 pub fn path_pattern(&self) -> &str {
268 &self.path_pattern
269 }
270
271 pub fn body_schema(&self) -> Option<&serde_json::Value> {
276 self.body_schema.as_ref()
277 }
278}
279
280#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
282pub struct WebSocketSessionContext {
283 connection_id: uuid::Uuid,
284 path: String,
285 query: Option<String>,
286}
287
288impl WebSocketSessionContext {
289 pub fn connection_id(&self) -> uuid::Uuid {
290 self.connection_id
291 }
292
293 pub fn path(&self) -> &str {
294 &self.path
295 }
296
297 pub fn query(&self) -> Option<&str> {
298 self.query.as_deref()
299 }
300}
301
302#[derive(Clone, Debug, PartialEq, Eq)]
304pub enum WebSocketEvent {
305 Text(String),
306 Binary(Vec<u8>),
307 Ping(Vec<u8>),
308 Pong(Vec<u8>),
309 Close,
310}
311
312impl WebSocketEvent {
313 pub fn text(value: impl Into<String>) -> Self {
314 Self::Text(value.into())
315 }
316
317 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
318 Self::Binary(value.into())
319 }
320
321 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
322 where
323 T: Serialize,
324 {
325 let text = serde_json::to_string(value)?;
326 Ok(Self::Text(text))
327 }
328}
329
/// Errors surfaced by `WebSocketConnection` send/receive helpers.
#[derive(Debug, thiserror::Error)]
pub enum WebSocketError {
    /// Transport/protocol error from the underlying tungstenite stream.
    #[error("websocket wire error: {0}")]
    Wire(#[from] WsWireError),
    /// Outgoing value could not be serialized to JSON (`send_json`).
    #[error("json serialization failed: {0}")]
    JsonSerialize(#[source] serde_json::Error),
    /// Incoming frame payload was not valid JSON for the requested type (`next_json`).
    #[error("json deserialization failed: {0}")]
    JsonDeserialize(#[source] serde_json::Error),
    /// `next_json` received a ping/pong/close event instead of a text or binary frame.
    #[error("expected text or binary frame for json payload")]
    NonDataFrame,
}
341
/// Server-role WebSocket stream over a hyper-upgraded connection.
type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
/// Write half of a split `WsServerStream`.
type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
/// Read half of a split `WsServerStream`.
type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
345
346pub struct WebSocketConnection {
348 sink: Mutex<WsServerSink>,
349 source: Mutex<WsServerSource>,
350 session: WebSocketSessionContext,
351}
352
353impl WebSocketConnection {
354 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
355 let (sink, source) = stream.split();
356 Self {
357 sink: Mutex::new(sink),
358 source: Mutex::new(source),
359 session,
360 }
361 }
362
363 pub fn session(&self) -> &WebSocketSessionContext {
364 &self.session
365 }
366
367 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
368 let mut sink = self.sink.lock().await;
369 sink.send(event.into_wire_message()).await?;
370 Ok(())
371 }
372
373 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
374 where
375 T: Serialize,
376 {
377 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
378 self.send(event).await
379 }
380
381 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
382 where
383 T: DeserializeOwned,
384 {
385 let Some(event) = self.recv_event().await? else {
386 return Ok(None);
387 };
388 match event {
389 WebSocketEvent::Text(text) => serde_json::from_str(&text)
390 .map(Some)
391 .map_err(WebSocketError::JsonDeserialize),
392 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
393 .map(Some)
394 .map_err(WebSocketError::JsonDeserialize),
395 _ => Err(WebSocketError::NonDataFrame),
396 }
397 }
398
399 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
400 let mut source = self.source.lock().await;
401 while let Some(item) = source.next().await {
402 let message = item?;
403 if let Some(event) = WebSocketEvent::from_wire_message(message) {
404 return Ok(Some(event));
405 }
406 }
407 Ok(None)
408 }
409}
410
411impl WebSocketEvent {
412 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
413 match message {
414 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
415 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
416 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
417 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
418 WsWireMessage::Close(_) => Some(Self::Close),
419 WsWireMessage::Frame(_) => None,
420 }
421 }
422
423 fn into_wire_message(self) -> WsWireMessage {
424 match self {
425 Self::Text(value) => WsWireMessage::Text(value),
426 Self::Binary(value) => WsWireMessage::Binary(value),
427 Self::Ping(value) => WsWireMessage::Ping(value),
428 Self::Pong(value) => WsWireMessage::Pong(value),
429 Self::Close => WsWireMessage::Close(None),
430 }
431 }
432}
433
434#[async_trait::async_trait]
435impl EventSource<WebSocketEvent> for WebSocketConnection {
436 async fn next_event(&mut self) -> Option<WebSocketEvent> {
437 match self.recv_event().await {
438 Ok(event) => event,
439 Err(error) => {
440 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
441 None
442 }
443 }
444 }
445}
446
447#[async_trait::async_trait]
448impl EventSink<WebSocketEvent> for WebSocketConnection {
449 type Error = WebSocketError;
450
451 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
452 self.send(event).await
453 }
454}
455
456#[async_trait::async_trait]
457impl EventSink<String> for WebSocketConnection {
458 type Error = WebSocketError;
459
460 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
461 self.send(WebSocketEvent::Text(event)).await
462 }
463}
464
465#[async_trait::async_trait]
466impl EventSink<Vec<u8>> for WebSocketConnection {
467 type Error = WebSocketError;
468
469 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
470 self.send(WebSocketEvent::Binary(event)).await
471 }
472}
473
474impl PathParams {
475 pub fn new(values: HashMap<String, String>) -> Self {
476 Self { values }
477 }
478
479 pub fn get(&self, key: &str) -> Option<&str> {
480 self.values.get(key).map(String::as_str)
481 }
482
483 pub fn as_map(&self) -> &HashMap<String, String> {
484 &self.values
485 }
486
487 pub fn into_inner(self) -> HashMap<String, String> {
488 self.values
489 }
490}
491
492#[derive(Clone, Debug, PartialEq, Eq)]
493enum RouteSegment {
494 Static(String),
495 Param(String),
496 Wildcard(String),
497}
498
499#[derive(Clone, Debug, PartialEq, Eq)]
500struct RoutePattern {
501 raw: String,
502 segments: Vec<RouteSegment>,
503}
504
505impl RoutePattern {
506 fn parse(path: &str) -> Self {
507 let segments = path_segments(path)
508 .into_iter()
509 .map(|segment| {
510 if let Some(name) = segment.strip_prefix(':') {
511 if !name.is_empty() {
512 return RouteSegment::Param(name.to_string());
513 }
514 }
515 if let Some(name) = segment.strip_prefix('*') {
516 if !name.is_empty() {
517 return RouteSegment::Wildcard(name.to_string());
518 }
519 }
520 RouteSegment::Static(segment.to_string())
521 })
522 .collect();
523
524 Self {
525 raw: path.to_string(),
526 segments,
527 }
528 }
529
530 fn match_path(&self, path: &str) -> Option<PathParams> {
531 let mut params = HashMap::new();
532 let path_segments = path_segments(path);
533 let mut pattern_index = 0usize;
534 let mut path_index = 0usize;
535
536 while pattern_index < self.segments.len() {
537 match &self.segments[pattern_index] {
538 RouteSegment::Static(expected) => {
539 let actual = path_segments.get(path_index)?;
540 if actual != expected {
541 return None;
542 }
543 pattern_index += 1;
544 path_index += 1;
545 }
546 RouteSegment::Param(name) => {
547 let actual = path_segments.get(path_index)?;
548 params.insert(name.clone(), (*actual).to_string());
549 pattern_index += 1;
550 path_index += 1;
551 }
552 RouteSegment::Wildcard(name) => {
553 let remaining = path_segments[path_index..].join("/");
554 params.insert(name.clone(), remaining);
555 pattern_index += 1;
556 path_index = path_segments.len();
557 break;
558 }
559 }
560 }
561
562 if pattern_index == self.segments.len() && path_index == path_segments.len() {
563 Some(PathParams::new(params))
564 } else {
565 None
566 }
567 }
568}
569
/// Buffered request body bytes.
// NOTE(review): the producer/consumer is not visible in this chunk; presumably placed where
// handlers of routes with `needs_body` set can read it — confirm.
#[derive(Clone)]
struct BodyBytes(Bytes);

/// A registered route: method, parsed pattern, type-erased handler and per-route settings.
#[derive(Clone)]
struct RouteEntry<R> {
    method: Method,
    pattern: RoutePattern,
    handler: RouteHandler<R>,
    // Layers specific to this route.
    layers: Arc<Vec<ServiceLayer>>,
    // Whether the ingress-wide layers also wrap this route (`ws` routes set `true`).
    apply_global_layers: bool,
    // Presumably: whether the body must be buffered before dispatch — confirm at use site.
    needs_body: bool,
    // Copied into `HttpRouteDescriptor::body_schema` by `route_descriptors`.
    body_schema: Option<serde_json::Value>,
}
588
/// Splits a URL path into its non-empty `/`-separated segments.
///
/// Leading, trailing and repeated slashes produce no segments, so `"/"` and `""` both yield
/// an empty vector and `"//a//b/"` yields `["a", "b"]`.
fn path_segments(path: &str) -> Vec<&str> {
    path.split('/').filter(|piece| !piece.is_empty()).collect()
}
599
/// Ensures a route path starts with `/`.
///
/// An empty path becomes `"/"`; a path that already starts with `/` is returned unchanged.
fn normalize_route_path(path: String) -> String {
    if path.starts_with('/') {
        path
    } else {
        // Also covers the empty string, which becomes "/".
        format!("/{path}")
    }
}
610
611fn find_matching_route<'a, R>(
612 routes: &'a [RouteEntry<R>],
613 method: &Method,
614 path: &str,
615) -> Option<(&'a RouteEntry<R>, PathParams)> {
616 for entry in routes {
617 if entry.method != *method {
618 continue;
619 }
620 if let Some(params) = entry.pattern.match_path(path) {
621 return Some((entry, params));
622 }
623 }
624 None
625}
626
627fn header_contains_token(
628 headers: &http::HeaderMap,
629 name: http::header::HeaderName,
630 token: &str,
631) -> bool {
632 headers
633 .get(name)
634 .and_then(|value| value.to_str().ok())
635 .map(|value| {
636 value
637 .split(',')
638 .any(|part| part.trim().eq_ignore_ascii_case(token))
639 })
640 .unwrap_or(false)
641}
642
643fn websocket_session_from_request<B>(req: &Request<B>) -> WebSocketSessionContext {
644 WebSocketSessionContext {
645 connection_id: uuid::Uuid::new_v4(),
646 path: req.uri().path().to_string(),
647 query: req.uri().query().map(str::to_string),
648 }
649}
650
651fn websocket_accept_key(client_key: &str) -> String {
652 let mut hasher = Sha1::new();
653 hasher.update(client_key.as_bytes());
654 hasher.update(WS_GUID.as_bytes());
655 let digest = hasher.finalize();
656 base64::engine::general_purpose::STANDARD.encode(digest)
657}
658
659fn websocket_bad_request(message: &'static str) -> HttpResponse {
660 Response::builder()
661 .status(StatusCode::BAD_REQUEST)
662 .body(
663 Full::new(Bytes::from(message))
664 .map_err(|never| match never {})
665 .boxed(),
666 )
667 .unwrap_or_else(|_| {
668 Response::new(
669 Full::new(Bytes::new())
670 .map_err(|never| match never {})
671 .boxed(),
672 )
673 })
674}
675
676fn websocket_upgrade_response<B>(
677 req: &mut Request<B>,
678) -> Result<(HttpResponse, hyper::upgrade::OnUpgrade), HttpResponse> {
679 if req.method() != Method::GET {
680 return Err(websocket_bad_request(
681 "WebSocket upgrade requires GET method",
682 ));
683 }
684
685 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
686 return Err(websocket_bad_request(
687 "Missing Connection: upgrade header for WebSocket",
688 ));
689 }
690
691 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
692 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
693 }
694
695 if let Some(version) = req.headers().get("sec-websocket-version") {
696 if version != "13" {
697 return Err(websocket_bad_request(
698 "Unsupported Sec-WebSocket-Version (expected 13)",
699 ));
700 }
701 }
702
703 let Some(client_key) = req
704 .headers()
705 .get("sec-websocket-key")
706 .and_then(|value| value.to_str().ok())
707 else {
708 return Err(websocket_bad_request(
709 "Missing Sec-WebSocket-Key header for WebSocket",
710 ));
711 };
712
713 let accept_key = websocket_accept_key(client_key);
714 let on_upgrade = hyper::upgrade::on(req);
715 let response = Response::builder()
716 .status(StatusCode::SWITCHING_PROTOCOLS)
717 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
718 .header(http::header::CONNECTION, "Upgrade")
719 .header("sec-websocket-accept", accept_key)
720 .body(
721 Full::new(Bytes::new())
722 .map_err(|never| match never {})
723 .boxed(),
724 )
725 .unwrap_or_else(|_| {
726 Response::new(
727 Full::new(Bytes::new())
728 .map_err(|never| match never {})
729 .boxed(),
730 )
731 });
732
733 Ok((response, on_upgrade))
734}
735
/// Builder for an HTTP ingress that dispatches requests to Axon circuits.
///
/// `R` is the resource bundle passed to route handlers, health checks and WebSocket handlers.
/// Most builder methods only record configuration here; serving happens elsewhere.
pub struct HttpIngress<R = ()> {
    // Address set by `bind` or `config`.
    addr: Option<String>,
    // Registered routes, checked in registration order.
    routes: Vec<RouteEntry<R>>,
    // NOTE(review): presumably used when no route matches — confirm at the dispatch site.
    fallback: Option<RouteHandler<R>>,
    // Ingress-wide middleware layers (`timeout_layer`, `request_id_layer`).
    layers: Vec<ServiceLayer>,
    // Lifecycle hooks set by `on_start` / `on_shutdown`.
    on_start: Option<LifecycleHook>,
    on_shutdown: Option<LifecycleHook>,
    // Defaults to 30 seconds in `new`; overridden by `graceful_shutdown` or `config`.
    graceful_shutdown_timeout: Duration,
    // Injectors snapshotted into each route/ws handler when that route is registered.
    bus_injectors: Vec<BusInjector>,
    static_assets: StaticAssetsConfig,
    health: HealthConfig<R>,
    #[cfg(feature = "http3")]
    http3_config: Option<crate::http3::Http3Config>,
    #[cfg(feature = "http3")]
    alt_svc_h3_port: Option<u16>,
    #[cfg(feature = "tls")]
    tls_config: Option<TlsAcceptorConfig>,
    // Flag set by `active_intervention`; its effect is not visible in this chunk.
    active_intervention: bool,
    policy_registry: Option<ranvier_core::policy::PolicyRegistry>,
    // Guard state accumulated by `guard`; snapshotted into routes registered afterwards.
    guard_execs: Vec<Arc<dyn GuardExec>>,
    guard_response_extractors: Vec<ResponseExtractorFn>,
    guard_body_transforms: Vec<ResponseBodyTransformFn>,
    preflight_config: Option<PreflightConfig>,
    _phantom: std::marker::PhantomData<R>,
}
783
784impl<R> HttpIngress<R>
785where
786 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
787{
788 pub fn new() -> Self {
790 Self {
791 addr: None,
792 routes: Vec::new(),
793 fallback: None,
794 layers: Vec::new(),
795 on_start: None,
796 on_shutdown: None,
797 graceful_shutdown_timeout: Duration::from_secs(30),
798 bus_injectors: Vec::new(),
799 static_assets: StaticAssetsConfig::default(),
800 health: HealthConfig::default(),
801 #[cfg(feature = "tls")]
802 tls_config: None,
803 #[cfg(feature = "http3")]
804 http3_config: None,
805 #[cfg(feature = "http3")]
806 alt_svc_h3_port: None,
807 active_intervention: false,
808 policy_registry: None,
809 guard_execs: Vec::new(),
810 guard_response_extractors: Vec::new(),
811 guard_body_transforms: Vec::new(),
812 preflight_config: None,
813 _phantom: std::marker::PhantomData,
814 }
815 }
816
817 pub fn bind(mut self, addr: impl Into<String>) -> Self {
821 self.addr = Some(addr.into());
822 self
823 }
824
825 pub fn active_intervention(mut self) -> Self {
831 self.active_intervention = true;
832 self
833 }
834
835 pub fn policy_registry(mut self, registry: ranvier_core::policy::PolicyRegistry) -> Self {
837 self.policy_registry = Some(registry);
838 self
839 }
840
841 pub fn on_start<F>(mut self, callback: F) -> Self
845 where
846 F: Fn() + Send + Sync + 'static,
847 {
848 self.on_start = Some(Arc::new(callback));
849 self
850 }
851
852 pub fn on_shutdown<F>(mut self, callback: F) -> Self
854 where
855 F: Fn() + Send + Sync + 'static,
856 {
857 self.on_shutdown = Some(Arc::new(callback));
858 self
859 }
860
861 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
863 self.graceful_shutdown_timeout = timeout;
864 self
865 }
866
867 pub fn config(mut self, config: &ranvier_core::config::RanvierConfig) -> Self {
873 self.addr = Some(config.bind_addr());
874 self.graceful_shutdown_timeout = config.shutdown_timeout();
875 config.init_telemetry();
876 self
877 }
878
/// Records the certificate and private-key file paths for TLS (requires feature `tls`).
#[cfg(feature = "tls")]
pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
    let cert_path: String = cert_path.into();
    let key_path: String = key_path.into();
    self.tls_config = Some(TlsAcceptorConfig { cert_path, key_path });
    self
}
888
889 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
894 self.layers.push(timeout_middleware(timeout));
895 self
896 }
897
898 pub fn request_id_layer(mut self) -> Self {
902 self.layers.push(request_id_middleware());
903 self
904 }
905
906 pub fn bus_injector<F>(mut self, injector: F) -> Self
911 where
912 F: Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static,
913 {
914 self.bus_injectors.push(Arc::new(injector));
915 self
916 }
917
/// Enables htmx integration (requires feature `htmx`).
///
/// Registers a bus injector for htmx request headers and a response extractor for htmx
/// response headers.
#[cfg(feature = "htmx")]
pub fn htmx_support(mut self) -> Self {
    let injector: BusInjector = Arc::new(crate::htmx::inject_htmx_headers);
    let extractor: ResponseExtractorFn = Arc::new(crate::htmx::extract_htmx_response_headers);
    self.bus_injectors.push(injector);
    self.guard_response_extractors.push(extractor);
    self
}
933
934 pub fn guard(mut self, guard: impl GuardIntegration) -> Self {
954 let registration = guard.register();
955 for injector in registration.bus_injectors {
956 self.bus_injectors.push(injector);
957 }
958 self.guard_execs.push(registration.exec);
959 if let Some(extractor) = registration.response_extractor {
960 self.guard_response_extractors.push(extractor);
961 }
962 if let Some(transform) = registration.response_body_transform {
963 self.guard_body_transforms.push(transform);
964 }
965 if registration.handles_preflight {
966 if let Some(config) = registration.preflight_config {
967 self.preflight_config = Some(config);
968 }
969 }
970 self
971 }
972
/// Stores an HTTP/3 configuration (requires feature `http3`).
#[cfg(feature = "http3")]
pub fn enable_http3(self, config: crate::http3::Http3Config) -> Self {
    let mut ingress = self;
    ingress.http3_config = Some(config);
    ingress
}
979
/// Records the port to advertise for HTTP/3 (requires feature `http3`).
///
/// NOTE(review): presumably emitted in an `Alt-Svc: h3=":port"` header — confirm where it is
/// used.
#[cfg(feature = "http3")]
pub fn alt_svc_h3(self, port: u16) -> Self {
    let mut ingress = self;
    ingress.alt_svc_h3_port = Some(port);
    ingress
}
986
987 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
991 let mut descriptors = self
992 .routes
993 .iter()
994 .map(|entry| {
995 let mut desc = HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone());
996 desc.body_schema = entry.body_schema.clone();
997 desc
998 })
999 .collect::<Vec<_>>();
1000
1001 if let Some(path) = &self.health.health_path {
1002 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1003 }
1004 if let Some(path) = &self.health.readiness_path {
1005 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1006 }
1007 if let Some(path) = &self.health.liveness_path {
1008 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1009 }
1010
1011 descriptors
1012 }
1013
1014 pub fn serve_dir(
1020 mut self,
1021 route_prefix: impl Into<String>,
1022 directory: impl Into<String>,
1023 ) -> Self {
1024 self.static_assets.mounts.push(StaticMount {
1025 route_prefix: normalize_route_path(route_prefix.into()),
1026 directory: directory.into(),
1027 });
1028 if self.static_assets.cache_control.is_none() {
1029 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
1030 }
1031 self
1032 }
1033
1034 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
1038 self.static_assets.spa_fallback = Some(file_path.into());
1039 self
1040 }
1041
1042 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
1044 self.static_assets.cache_control = Some(cache_control.into());
1045 self
1046 }
1047
1048 pub fn directory_index(mut self, filename: impl Into<String>) -> Self {
1056 self.static_assets.directory_index = Some(filename.into());
1057 self
1058 }
1059
1060 pub fn immutable_cache(mut self) -> Self {
1065 self.static_assets.immutable_cache = true;
1066 self
1067 }
1068
1069 pub fn serve_precompressed(mut self) -> Self {
1075 self.static_assets.serve_precompressed = true;
1076 self
1077 }
1078
1079 pub fn enable_range_requests(mut self) -> Self {
1084 self.static_assets.enable_range_requests = true;
1085 self
1086 }
1087
1088 pub fn compression_layer(mut self) -> Self {
1090 self.static_assets.enable_compression = true;
1091 self
1092 }
1093
1094 pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
1103 where
1104 H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
1105 Fut: Future<Output = ()> + Send + 'static,
1106 {
1107 let path_str: String = path.into();
1108 let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
1109 Box::pin(handler(connection, resources, bus))
1110 });
1111 let bus_injectors = Arc::new(self.bus_injectors.clone());
1112 let ws_guard_execs = Arc::new(self.guard_execs.clone());
1113 let path_for_pattern = path_str.clone();
1114 let path_for_handler = path_str;
1115
1116 let route_handler: RouteHandler<R> =
1117 Arc::new(move |parts: http::request::Parts, res: &R| {
1118 let ws_handler = ws_handler.clone();
1119 let bus_injectors = bus_injectors.clone();
1120 let ws_guard_execs = ws_guard_execs.clone();
1121 let resources = Arc::new(res.clone());
1122 let path = path_for_handler.clone();
1123
1124 Box::pin(async move {
1125 let request_id = uuid::Uuid::new_v4().to_string();
1126 let span = tracing::info_span!(
1127 "WebSocketUpgrade",
1128 ranvier.ws.path = %path,
1129 ranvier.ws.request_id = %request_id
1130 );
1131
1132 async move {
1133 let mut bus = Bus::new();
1134 for injector in bus_injectors.iter() {
1135 injector(&parts, &mut bus);
1136 }
1137 for guard_exec in ws_guard_execs.iter() {
1138 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1139 return json_error_response(rejection.status, &rejection.message);
1140 }
1141 }
1142
1143 let mut req = Request::from_parts(parts, ());
1145 let session = websocket_session_from_request(&req);
1146 bus.insert(session.clone());
1147
1148 let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
1149 Ok(result) => result,
1150 Err(error_response) => return error_response,
1151 };
1152
1153 tokio::spawn(async move {
1154 match on_upgrade.await {
1155 Ok(upgraded) => {
1156 let stream = WebSocketStream::from_raw_socket(
1157 TokioIo::new(upgraded),
1158 tokio_tungstenite::tungstenite::protocol::Role::Server,
1159 None,
1160 )
1161 .await;
1162 let connection = WebSocketConnection::new(stream, session);
1163 ws_handler(connection, resources, bus).await;
1164 }
1165 Err(error) => {
1166 tracing::warn!(
1167 ranvier.ws.path = %path,
1168 ranvier.ws.error = %error,
1169 "websocket upgrade failed"
1170 );
1171 }
1172 }
1173 });
1174
1175 response
1176 }
1177 .instrument(span)
1178 .await
1179 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1180 });
1181
1182 self.routes.push(RouteEntry {
1183 method: Method::GET,
1184 pattern: RoutePattern::parse(&path_for_pattern),
1185 handler: route_handler,
1186 layers: Arc::new(Vec::new()),
1187 apply_global_layers: true,
1188 needs_body: false,
1189 body_schema: None,
1190 });
1191
1192 self
1193 }
1194
1195 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
1202 self.health.health_path = Some(normalize_route_path(path.into()));
1203 self
1204 }
1205
1206 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
1210 where
1211 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
1212 Fut: Future<Output = Result<(), Err>> + Send + 'static,
1213 Err: ToString + Send + 'static,
1214 {
1215 if self.health.health_path.is_none() {
1216 self.health.health_path = Some("/health".to_string());
1217 }
1218
1219 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
1220 let fut = check(resources);
1221 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
1222 });
1223
1224 self.health.checks.push(NamedHealthCheck {
1225 name: name.into(),
1226 check: check_fn,
1227 });
1228 self
1229 }
1230
1231 pub fn readiness_liveness(
1233 mut self,
1234 readiness_path: impl Into<String>,
1235 liveness_path: impl Into<String>,
1236 ) -> Self {
1237 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1238 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1239 self
1240 }
1241
1242 pub fn readiness_liveness_default(self) -> Self {
1244 self.readiness_liveness("/ready", "/live")
1245 }
1246
1247 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1251 where
1252 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1253 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1254 {
1255 self.route_method(Method::GET, path, circuit)
1256 }
1257 pub fn route_method<Out, E>(
1266 self,
1267 method: Method,
1268 path: impl Into<String>,
1269 circuit: Axon<(), Out, E, R>,
1270 ) -> Self
1271 where
1272 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1273 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1274 {
1275 self.route_method_with_error(method, path, circuit, |error| {
1276 (
1277 StatusCode::INTERNAL_SERVER_ERROR,
1278 format!("Error: {:?}", error),
1279 )
1280 .into_response()
1281 })
1282 }
1283
1284 pub fn route_method_with_error<Out, E, H>(
1285 self,
1286 method: Method,
1287 path: impl Into<String>,
1288 circuit: Axon<(), Out, E, R>,
1289 error_handler: H,
1290 ) -> Self
1291 where
1292 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1293 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1294 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1295 {
1296 self.route_method_with_error_and_layers(
1297 method,
1298 path,
1299 circuit,
1300 error_handler,
1301 Arc::new(Vec::new()),
1302 true,
1303 )
1304 }
1305
1306
1307
/// Core registration path for unit-input (`()`) circuits.
///
/// Builds a [`RouteHandler`] that, for every request:
/// 1. opens an `HTTPRequest` tracing span tagged with the method, the registered path and a
///    fresh UUID v4 request id;
/// 2. creates a new `Bus` and runs every bus injector captured at registration time;
/// 3. runs the captured guards in order, answering with a JSON error response built from the
///    first rejection's status and message;
/// 4. if a guard left an `IdempotencyCachedResponse` on the bus, replays its body as a
///    `200 OK` `application/json` response without executing the circuit;
/// 5. executes the circuit. When a `TimeoutDeadline` is on the bus, execution is bounded by
///    its remaining time, and the handler answers `408 Request Timeout` if the deadline is
///    already spent or elapses;
/// 6. converts the outcome with `error_handler` for faults, applies the response extractors
///    and, when any are registered, the response body transforms.
///
/// Response extractors are applied to every response produced after the bus exists:
/// rejections, cached replays, timeouts and normal results. Body transforms run only on the
/// normal result path.
///
/// Guard state (bus injectors, guard executors, response extractors, body transforms) is
/// cloned when this method is called. Guards added to the builder afterwards therefore do not
/// affect this route.
///
/// `route_layers` is stored on the route entry. `apply_global_layers` records whether the
/// ingress-wide layers should also wrap it; how both are applied happens outside this method.
fn route_method_with_error_and_layers<Out, E, H>(
    mut self,
    method: Method,
    path: impl Into<String>,
    circuit: Axon<(), Out, E, R>,
    error_handler: H,
    route_layers: Arc<Vec<ServiceLayer>>,
    apply_global_layers: bool,
) -> Self
where
    Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
    H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
{
    let path_str: String = path.into();
    let circuit = Arc::new(circuit);
    let error_handler = Arc::new(error_handler);
    // Snapshot the builder's guard state so this route is unaffected by later registrations.
    let route_bus_injectors = Arc::new(self.bus_injectors.clone());
    let route_guard_execs = Arc::new(self.guard_execs.clone());
    let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
    let route_body_transforms = Arc::new(self.guard_body_transforms.clone());
    let path_for_pattern = path_str.clone();
    let path_for_handler = path_str;
    let method_for_pattern = method.clone();
    let method_for_handler = method;

    let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
        // Per-request clones so the returned future owns everything it uses.
        let circuit = circuit.clone();
        let error_handler = error_handler.clone();
        let route_bus_injectors = route_bus_injectors.clone();
        let route_guard_execs = route_guard_execs.clone();
        let route_response_extractors = route_response_extractors.clone();
        let route_body_transforms = route_body_transforms.clone();
        let res = res.clone();
        let path = path_for_handler.clone();
        let method = method_for_handler.clone();

        Box::pin(async move {
            let request_id = uuid::Uuid::new_v4().to_string();
            let span = tracing::info_span!(
                "HTTPRequest",
                ranvier.http.method = %method,
                ranvier.http.path = %path,
                ranvier.http.request_id = %request_id
            );

            async move {
                let mut bus = Bus::new();
                for injector in route_bus_injectors.iter() {
                    injector(&parts, &mut bus);
                }
                // The first rejecting guard short-circuits the request.
                for guard_exec in route_guard_execs.iter() {
                    if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
                        let mut response = json_error_response(rejection.status, &rejection.message);
                        for extractor in route_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                }
                // A guard may have found a cached response for this request; replay it instead
                // of executing the circuit.
                // NOTE(review): the replay is always sent as 200 OK with an application/json
                // content type, whatever the cached entry may record — confirm this is intended.
                if let Some(cached) = bus.read::<ranvier_guard::IdempotencyCachedResponse>() {
                    let body = Bytes::from(cached.body.clone());
                    let mut response = Response::builder()
                        .status(StatusCode::OK)
                        .header("content-type", "application/json")
                        .body(Full::new(body).map_err(|n: Infallible| match n {}).boxed())
                        .unwrap();
                    for extractor in route_response_extractors.iter() {
                        extractor(&bus, response.headers_mut());
                    }
                    return response;
                }
                // Bound execution by a guard-provided deadline when one is present.
                let result = if let Some(td) = bus.read::<ranvier_guard::TimeoutDeadline>() {
                    let remaining = td.remaining();
                    if remaining.is_zero() {
                        let mut response = json_error_response(
                            StatusCode::REQUEST_TIMEOUT,
                            "Request timeout: pipeline deadline exceeded",
                        );
                        for extractor in route_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                    // When the timeout elapses, the circuit future is dropped (cancelled).
                    match tokio::time::timeout(remaining, circuit.execute((), &res, &mut bus)).await {
                        Ok(result) => result,
                        Err(_) => {
                            let mut response = json_error_response(
                                StatusCode::REQUEST_TIMEOUT,
                                "Request timeout: pipeline deadline exceeded",
                            );
                            for extractor in route_response_extractors.iter() {
                                extractor(&bus, response.headers_mut());
                            }
                            return response;
                        }
                    }
                } else {
                    circuit.execute((), &res, &mut bus).await
                };
                let mut response = outcome_to_response_with_error(result, |error| error_handler(error));
                for extractor in route_response_extractors.iter() {
                    extractor(&bus, response.headers_mut());
                }
                // Buffering the body for transforms is skipped entirely when none are registered.
                if !route_body_transforms.is_empty() {
                    response = apply_body_transforms(response, &bus, &route_body_transforms).await;
                }
                response
            }
            .instrument(span)
            .await
        }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
    });

    self.routes.push(RouteEntry {
        method: method_for_pattern,
        pattern: RoutePattern::parse(&path_for_pattern),
        handler,
        layers: route_layers,
        apply_global_layers,
        // Unit-input circuits never read the request body.
        needs_body: false,
        body_schema: None,
    });
    self
}
1435
/// Core registration path for routes whose circuit input `T` is decoded from a JSON body.
///
/// At registration time the JSON Schema of `T` is generated with `schemars` and stored on the
/// route entry as `body_schema`; it is `None` if converting the schema to a JSON value fails.
/// The entry is marked `needs_body: true`.
///
/// For every request the handler:
/// 1. opens an `HTTPRequest` tracing span with method, path and a fresh UUID v4 request id;
/// 2. reads the request body from the `BodyBytes` request extension (empty if the extension
///    is absent) and deserializes it into `T`, answering `400 Bad Request` on failure;
/// 3. creates a `Bus`, runs the captured bus injectors, then the captured guards in order,
///    answering with a JSON error response on the first rejection;
/// 4. replays an `IdempotencyCachedResponse` from the bus as `200 OK` JSON if one is present;
/// 5. executes the circuit, bounded by a `TimeoutDeadline` on the bus when present
///    (answering `408 Request Timeout` when it is spent or elapses);
/// 6. maps faults to `500`. In debug builds the body contains the error's `Debug` text; in
///    release builds it is a generic JSON message. Response extractors and, if any,
///    body transforms are then applied.
///
/// The route has no route-specific layers and always participates in the global layers.
fn route_method_typed<T, Out, E>(
    mut self,
    method: Method,
    path: impl Into<String>,
    circuit: Axon<T, Out, E, R>,
) -> Self
where
    T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
    Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    // A schema that fails to serialize is silently dropped (`None`).
    let body_schema = serde_json::to_value(schemars::schema_for!(T)).ok();
    let path_str: String = path.into();
    let circuit = Arc::new(circuit);
    // Snapshot the builder's guard state so this route is unaffected by later registrations.
    let route_bus_injectors = Arc::new(self.bus_injectors.clone());
    let route_guard_execs = Arc::new(self.guard_execs.clone());
    let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
    let route_body_transforms = Arc::new(self.guard_body_transforms.clone());
    let path_for_pattern = path_str.clone();
    let path_for_handler = path_str;
    let method_for_pattern = method.clone();
    let method_for_handler = method;

    let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
        // Per-request clones so the returned future owns everything it uses.
        let circuit = circuit.clone();
        let route_bus_injectors = route_bus_injectors.clone();
        let route_guard_execs = route_guard_execs.clone();
        let route_response_extractors = route_response_extractors.clone();
        let route_body_transforms = route_body_transforms.clone();
        let res = res.clone();
        let path = path_for_handler.clone();
        let method = method_for_handler.clone();

        Box::pin(async move {
            let request_id = uuid::Uuid::new_v4().to_string();
            let span = tracing::info_span!(
                "HTTPRequest",
                ranvier.http.method = %method,
                ranvier.http.path = %path,
                ranvier.http.request_id = %request_id
            );

            async move {
                // Presumably the dispatcher buffers the body into `BodyBytes` because
                // `needs_body` is true — confirm. A missing extension decodes as an empty body.
                let body_bytes = parts
                    .extensions
                    .get::<BodyBytes>()
                    .map(|b| b.0.clone())
                    .unwrap_or_default();

                // NOTE(review): the body is decoded before any guard runs. A malformed body
                // therefore gets a 400 without guard evaluation and without response-extractor
                // headers — confirm this ordering is intended.
                let input: T = match serde_json::from_slice(&body_bytes) {
                    Ok(v) => v,
                    Err(e) => {
                        return json_error_response(
                            StatusCode::BAD_REQUEST,
                            &format!("Invalid request body: {}", e),
                        );
                    }
                };

                let mut bus = Bus::new();
                for injector in route_bus_injectors.iter() {
                    injector(&parts, &mut bus);
                }
                // The first rejecting guard short-circuits the request.
                for guard_exec in route_guard_execs.iter() {
                    if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
                        let mut response = json_error_response(rejection.status, &rejection.message);
                        for extractor in route_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                }
                // Replay a response a guard found in its cache instead of executing the circuit.
                if let Some(cached) = bus.read::<ranvier_guard::IdempotencyCachedResponse>() {
                    let body = Bytes::from(cached.body.clone());
                    let mut response = Response::builder()
                        .status(StatusCode::OK)
                        .header("content-type", "application/json")
                        .body(Full::new(body).map_err(|n: Infallible| match n {}).boxed())
                        .unwrap();
                    for extractor in route_response_extractors.iter() {
                        extractor(&bus, response.headers_mut());
                    }
                    return response;
                }
                // Bound execution by a guard-provided deadline when one is present.
                let result = if let Some(td) = bus.read::<ranvier_guard::TimeoutDeadline>() {
                    let remaining = td.remaining();
                    if remaining.is_zero() {
                        let mut response = json_error_response(
                            StatusCode::REQUEST_TIMEOUT,
                            "Request timeout: pipeline deadline exceeded",
                        );
                        for extractor in route_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                    match tokio::time::timeout(remaining, circuit.execute(input, &res, &mut bus)).await {
                        Ok(result) => result,
                        Err(_) => {
                            let mut response = json_error_response(
                                StatusCode::REQUEST_TIMEOUT,
                                "Request timeout: pipeline deadline exceeded",
                            );
                            for extractor in route_response_extractors.iter() {
                                extractor(&bus, response.headers_mut());
                            }
                            return response;
                        }
                    }
                } else {
                    circuit.execute(input, &res, &mut bus).await
                };
                // Error details are only exposed to clients in debug builds.
                let mut response = outcome_to_response_with_error(result, |error| {
                    if cfg!(debug_assertions) {
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Error: {:?}", error),
                        )
                        .into_response()
                    } else {
                        json_error_response(
                            StatusCode::INTERNAL_SERVER_ERROR,
                            "Internal server error",
                        )
                    }
                });
                for extractor in route_response_extractors.iter() {
                    extractor(&bus, response.headers_mut());
                }
                if !route_body_transforms.is_empty() {
                    response = apply_body_transforms(response, &bus, &route_body_transforms).await;
                }
                response
            }
            .instrument(span)
            .await
        }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
    });

    self.routes.push(RouteEntry {
        method: method_for_pattern,
        pattern: RoutePattern::parse(&path_for_pattern),
        handler,
        layers: Arc::new(Vec::new()),
        apply_global_layers: true,
        needs_body: true,
        body_schema,
    });
    self
}
1595
1596 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1597 where
1598 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1599 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1600 {
1601 self.route_method(Method::GET, path, circuit)
1602 }
1603
1604 pub fn get_with_error<Out, E, H>(
1605 self,
1606 path: impl Into<String>,
1607 circuit: Axon<(), Out, E, R>,
1608 error_handler: H,
1609 ) -> Self
1610 where
1611 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1612 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1613 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1614 {
1615 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1616 }
1617
1618 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1619 where
1620 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1621 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1622 {
1623 self.route_method(Method::POST, path, circuit)
1624 }
1625
1626 pub fn post_typed<T, Out, E>(
1642 self,
1643 path: impl Into<String>,
1644 circuit: Axon<T, Out, E, R>,
1645 ) -> Self
1646 where
1647 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1648 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1649 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1650 {
1651 self.route_method_typed::<T, Out, E>(Method::POST, path, circuit)
1652 }
1653
1654 pub fn put_typed<T, Out, E>(
1658 self,
1659 path: impl Into<String>,
1660 circuit: Axon<T, Out, E, R>,
1661 ) -> Self
1662 where
1663 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1664 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1665 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1666 {
1667 self.route_method_typed::<T, Out, E>(Method::PUT, path, circuit)
1668 }
1669
1670 pub fn patch_typed<T, Out, E>(
1674 self,
1675 path: impl Into<String>,
1676 circuit: Axon<T, Out, E, R>,
1677 ) -> Self
1678 where
1679 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1680 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1681 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1682 {
1683 self.route_method_typed::<T, Out, E>(Method::PATCH, path, circuit)
1684 }
1685
1686 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1687 where
1688 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1689 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1690 {
1691 self.route_method(Method::PUT, path, circuit)
1692 }
1693
1694 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1695 where
1696 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1697 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1698 {
1699 self.route_method(Method::DELETE, path, circuit)
1700 }
1701
1702 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1703 where
1704 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1705 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1706 {
1707 self.route_method(Method::PATCH, path, circuit)
1708 }
1709
1710 pub fn post_with_error<Out, E, H>(
1711 self,
1712 path: impl Into<String>,
1713 circuit: Axon<(), Out, E, R>,
1714 error_handler: H,
1715 ) -> Self
1716 where
1717 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1718 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1719 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1720 {
1721 self.route_method_with_error(Method::POST, path, circuit, error_handler)
1722 }
1723
1724 pub fn put_with_error<Out, E, H>(
1725 self,
1726 path: impl Into<String>,
1727 circuit: Axon<(), Out, E, R>,
1728 error_handler: H,
1729 ) -> Self
1730 where
1731 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1732 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1733 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1734 {
1735 self.route_method_with_error(Method::PUT, path, circuit, error_handler)
1736 }
1737
1738 pub fn delete_with_error<Out, E, H>(
1739 self,
1740 path: impl Into<String>,
1741 circuit: Axon<(), Out, E, R>,
1742 error_handler: H,
1743 ) -> Self
1744 where
1745 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1746 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1747 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1748 {
1749 self.route_method_with_error(Method::DELETE, path, circuit, error_handler)
1750 }
1751
1752 pub fn patch_with_error<Out, E, H>(
1753 self,
1754 path: impl Into<String>,
1755 circuit: Axon<(), Out, E, R>,
1756 error_handler: H,
1757 ) -> Self
1758 where
1759 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1760 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1761 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1762 {
1763 self.route_method_with_error(Method::PATCH, path, circuit, error_handler)
1764 }
1765
/// Registers a `POST` Server-Sent Events route for a unit-input streaming circuit.
///
/// The request body is not decoded; the circuit's input is built from JSON `null`.
#[cfg(feature = "streaming")]
pub fn post_sse<Item, E>(
    self,
    path: impl Into<String>,
    circuit: ranvier_runtime::StreamingAxon<(), Item, E, R>,
) -> Self
where
    Item: serde::Serialize + Send + Sync + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    let needs_body = false;
    Self::route_sse_internal::<(), Item, E>(self, Method::POST, path, circuit, needs_body)
}
1799
/// Registers a `POST` Server-Sent Events route whose JSON request body is deserialized into
/// `T` before the streaming circuit runs.
#[cfg(feature = "streaming")]
pub fn post_sse_typed<T, Item, E>(
    self,
    path: impl Into<String>,
    circuit: ranvier_runtime::StreamingAxon<T, Item, E, R>,
) -> Self
where
    T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
    Item: serde::Serialize + Send + Sync + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    let needs_body = true;
    Self::route_sse_internal::<T, Item, E>(self, Method::POST, path, circuit, needs_body)
}
1829
/// Shared implementation behind the SSE route registrations.
///
/// For every request the handler:
/// 1. decodes the JSON body from the `BodyBytes` extension into `T` when `needs_body` is true
///    (answering `400 Bad Request` on failure), or builds `T` from JSON `null` otherwise;
/// 2. creates a `Bus`, runs the bus injectors and then the guards captured at registration
///    time, answering with a JSON error on the first rejection;
/// 3. executes the streaming circuit. A spawned task serializes each item as a
///    `data: <json>` SSE event and sends it through a bounded channel sized by the circuit's
///    `buffer_size`. It emits an `event: error` frame on a serialization failure and finishes
///    with `data: [DONE]`.
///
/// Guard response extractors are applied to every response produced after the guards have
/// run: rejections, pipeline errors, and the streaming `200 OK` itself. Guard-provided
/// headers therefore reach SSE clients exactly as they reach non-streaming routes; previously
/// they were dropped on the error and success paths. Response body transforms are not applied
/// because the body is a live stream.
#[cfg(feature = "streaming")]
fn route_sse_internal<T, Item, E>(
    mut self,
    method: Method,
    path: impl Into<String>,
    circuit: ranvier_runtime::StreamingAxon<T, Item, E, R>,
    needs_body: bool,
) -> Self
where
    T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
    Item: serde::Serialize + Send + Sync + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    let path_str: String = path.into();
    let circuit = Arc::new(circuit);
    // Snapshot the builder's guard state so this route is unaffected by later registrations.
    let route_bus_injectors = Arc::new(self.bus_injectors.clone());
    let route_guard_execs = Arc::new(self.guard_execs.clone());
    let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
    let path_for_pattern = path_str.clone();
    let path_for_handler = path_str;
    let method_for_pattern = method.clone();
    let method_for_handler = method;

    let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
        let circuit = circuit.clone();
        let route_bus_injectors = route_bus_injectors.clone();
        let route_guard_execs = route_guard_execs.clone();
        let route_response_extractors = route_response_extractors.clone();
        let res = res.clone();
        let path = path_for_handler.clone();
        let method = method_for_handler.clone();

        Box::pin(async move {
            let request_id = uuid::Uuid::new_v4().to_string();
            let span = tracing::info_span!(
                "SSERequest",
                ranvier.http.method = %method,
                ranvier.http.path = %path,
                ranvier.http.request_id = %request_id
            );

            async move {
                let input: T = if needs_body {
                    let body_bytes = parts
                        .extensions
                        .get::<BodyBytes>()
                        .map(|b| b.0.clone())
                        .unwrap_or_default();

                    match serde_json::from_slice(&body_bytes) {
                        Ok(v) => v,
                        Err(e) => {
                            return json_error_response(
                                StatusCode::BAD_REQUEST,
                                &format!("Invalid request body: {}", e),
                            );
                        }
                    }
                } else {
                    // Body-less routes: `()` (or any null-compatible `T`) decodes from `null`.
                    match serde_json::from_str("null") {
                        Ok(v) => v,
                        Err(_) => {
                            return json_error_response(
                                StatusCode::INTERNAL_SERVER_ERROR,
                                "Internal: failed to construct default input",
                            );
                        }
                    }
                };

                let mut bus = Bus::new();
                for injector in route_bus_injectors.iter() {
                    injector(&parts, &mut bus);
                }
                for guard_exec in route_guard_execs.iter() {
                    if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
                        let mut response = json_error_response(rejection.status, &rejection.message);
                        for extractor in route_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                }

                let stream = match circuit.execute(input, &res, &mut bus).await {
                    Ok(s) => s,
                    Err(e) => {
                        tracing::error!("Streaming pipeline error: {}", e);
                        // Error details are only exposed to clients in debug builds.
                        let mut response = if cfg!(debug_assertions) {
                            json_error_response(
                                StatusCode::INTERNAL_SERVER_ERROR,
                                &format!("Streaming error: {}", e),
                            )
                        } else {
                            json_error_response(
                                StatusCode::INTERNAL_SERVER_ERROR,
                                "Internal server error",
                            )
                        };
                        for extractor in route_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                };

                let buffer_size = circuit.buffer_size;
                let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(buffer_size);

                // The producer runs detached; a closed channel means the client went away.
                tokio::spawn(async move {
                    let mut pinned = Box::pin(stream);
                    while let Some(item) = futures_util::StreamExt::next(&mut pinned).await {
                        let text = match serde_json::to_string(&item) {
                            Ok(json) => format!("data: {}\n\n", json),
                            Err(e) => {
                                tracing::error!("SSE item serialization error: {}", e);
                                let err_text = "event: error\ndata: {\"message\":\"serialization error\",\"code\":\"serialize_error\"}\n\n".to_string();
                                let _ = tx.send(Bytes::from(err_text)).await;
                                break;
                            }
                        };
                        if tx.send(Bytes::from(text)).await.is_err() {
                            tracing::info!("SSE client disconnected");
                            break;
                        }
                    }
                    let _ = tx.send(Bytes::from("data: [DONE]\n\n")).await;
                });

                let frame_stream = async_stream::stream! {
                    while let Some(bytes) = rx.recv().await {
                        yield Ok::<http_body::Frame<Bytes>, std::convert::Infallible>(
                            http_body::Frame::data(bytes)
                        );
                    }
                };

                let body = http_body_util::StreamBody::new(frame_stream);
                let mut response = Response::builder()
                    .status(StatusCode::OK)
                    .header(http::header::CONTENT_TYPE, "text/event-stream")
                    .header(http::header::CACHE_CONTROL, "no-cache")
                    .header(http::header::CONNECTION, "keep-alive")
                    .body(http_body_util::BodyExt::boxed(body))
                    .expect("Valid SSE response");
                // Guard-provided headers must reach streaming clients too.
                for extractor in route_response_extractors.iter() {
                    extractor(&bus, response.headers_mut());
                }
                response
            }
            .instrument(span)
            .await
        }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
    });

    self.routes.push(RouteEntry {
        method: method_for_pattern,
        pattern: RoutePattern::parse(&path_for_pattern),
        handler,
        layers: Arc::new(Vec::new()),
        apply_global_layers: true,
        needs_body,
        body_schema: None,
    });
    self
}
1998
1999 fn route_method_with_extra_guards<Out, E>(
2005 mut self,
2006 method: Method,
2007 path: impl Into<String>,
2008 circuit: Axon<(), Out, E, R>,
2009 extra_guards: Vec<RegisteredGuard>,
2010 ) -> Self
2011 where
2012 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2013 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2014 {
2015 let saved_injectors = self.bus_injectors.len();
2017 let saved_execs = self.guard_execs.len();
2018 let saved_extractors = self.guard_response_extractors.len();
2019 let saved_transforms = self.guard_body_transforms.len();
2020
2021 for registration in extra_guards {
2023 for injector in registration.bus_injectors {
2024 self.bus_injectors.push(injector);
2025 }
2026 self.guard_execs.push(registration.exec);
2027 if let Some(extractor) = registration.response_extractor {
2028 self.guard_response_extractors.push(extractor);
2029 }
2030 if let Some(transform) = registration.response_body_transform {
2031 self.guard_body_transforms.push(transform);
2032 }
2033 }
2034
2035 self = self.route_method(method, path, circuit);
2037
2038 self.bus_injectors.truncate(saved_injectors);
2040 self.guard_execs.truncate(saved_execs);
2041 self.guard_response_extractors.truncate(saved_extractors);
2042 self.guard_body_transforms.truncate(saved_transforms);
2043
2044 self
2045 }
2046
2047 pub fn get_with_guards<Out, E>(
2065 self,
2066 path: impl Into<String>,
2067 circuit: Axon<(), Out, E, R>,
2068 extra_guards: Vec<RegisteredGuard>,
2069 ) -> Self
2070 where
2071 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2072 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2073 {
2074 self.route_method_with_extra_guards(Method::GET, path, circuit, extra_guards)
2075 }
2076
2077 pub fn post_with_guards<Out, E>(
2098 self,
2099 path: impl Into<String>,
2100 circuit: Axon<(), Out, E, R>,
2101 extra_guards: Vec<RegisteredGuard>,
2102 ) -> Self
2103 where
2104 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2105 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2106 {
2107 self.route_method_with_extra_guards(Method::POST, path, circuit, extra_guards)
2108 }
2109
2110 pub fn put_with_guards<Out, E>(
2112 self,
2113 path: impl Into<String>,
2114 circuit: Axon<(), Out, E, R>,
2115 extra_guards: Vec<RegisteredGuard>,
2116 ) -> Self
2117 where
2118 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2119 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2120 {
2121 self.route_method_with_extra_guards(Method::PUT, path, circuit, extra_guards)
2122 }
2123
2124 pub fn delete_with_guards<Out, E>(
2126 self,
2127 path: impl Into<String>,
2128 circuit: Axon<(), Out, E, R>,
2129 extra_guards: Vec<RegisteredGuard>,
2130 ) -> Self
2131 where
2132 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2133 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2134 {
2135 self.route_method_with_extra_guards(Method::DELETE, path, circuit, extra_guards)
2136 }
2137
2138 pub fn patch_with_guards<Out, E>(
2140 self,
2141 path: impl Into<String>,
2142 circuit: Axon<(), Out, E, R>,
2143 extra_guards: Vec<RegisteredGuard>,
2144 ) -> Self
2145 where
2146 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2147 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2148 {
2149 self.route_method_with_extra_guards(Method::PATCH, path, circuit, extra_guards)
2150 }
2151
2152 pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
2163 where
2164 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2165 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2166 {
2167 let circuit = Arc::new(circuit);
2168 let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
2169 let fallback_guard_execs = Arc::new(self.guard_execs.clone());
2170 let fallback_response_extractors = Arc::new(self.guard_response_extractors.clone());
2171 let fallback_body_transforms = Arc::new(self.guard_body_transforms.clone());
2172
2173 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
2174 let circuit = circuit.clone();
2175 let fallback_bus_injectors = fallback_bus_injectors.clone();
2176 let fallback_guard_execs = fallback_guard_execs.clone();
2177 let fallback_response_extractors = fallback_response_extractors.clone();
2178 let fallback_body_transforms = fallback_body_transforms.clone();
2179 let res = res.clone();
2180 Box::pin(async move {
2181 let request_id = uuid::Uuid::new_v4().to_string();
2182 let span = tracing::info_span!(
2183 "HTTPRequest",
2184 ranvier.http.method = "FALLBACK",
2185 ranvier.http.request_id = %request_id
2186 );
2187
2188 async move {
2189 let mut bus = Bus::new();
2190 for injector in fallback_bus_injectors.iter() {
2191 injector(&parts, &mut bus);
2192 }
2193 for guard_exec in fallback_guard_execs.iter() {
2194 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
2195 let mut response = json_error_response(rejection.status, &rejection.message);
2196 for extractor in fallback_response_extractors.iter() {
2197 extractor(&bus, response.headers_mut());
2198 }
2199 return response;
2200 }
2201 }
2202 let result: ranvier_core::Outcome<Out, E> =
2203 circuit.execute((), &res, &mut bus).await;
2204
2205 let mut response = match result {
2206 Outcome::Next(output) => {
2207 let mut response = output.into_response();
2208 *response.status_mut() = StatusCode::NOT_FOUND;
2209 response
2210 }
2211 _ => Response::builder()
2212 .status(StatusCode::NOT_FOUND)
2213 .body(
2214 Full::new(Bytes::from("Not Found"))
2215 .map_err(|never| match never {})
2216 .boxed(),
2217 )
2218 .expect("valid HTTP response construction"),
2219 };
2220 for extractor in fallback_response_extractors.iter() {
2221 extractor(&bus, response.headers_mut());
2222 }
2223 if !fallback_body_transforms.is_empty() {
2224 response = apply_body_transforms(response, &bus, &fallback_body_transforms).await;
2225 }
2226 response
2227 }
2228 .instrument(span)
2229 .await
2230 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
2231 });
2232
2233 self.fallback = Some(handler);
2234 self
2235 }
2236
2237 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
2241 self.run_with_shutdown_signal(resources, shutdown_signal())
2242 .await
2243 }
2244
2245 async fn run_with_shutdown_signal<S>(
2246 self,
2247 resources: R,
2248 shutdown_signal: S,
2249 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
2250 where
2251 S: Future<Output = ()> + Send,
2252 {
2253 let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
2254 let addr: SocketAddr = addr_str.parse()?;
2255
2256 let mut raw_routes = self.routes;
2257 if self.active_intervention {
2258 let handler: RouteHandler<R> = Arc::new(|_parts, _res| {
2259 Box::pin(async move {
2260 Response::builder()
2261 .status(StatusCode::OK)
2262 .body(
2263 Full::new(Bytes::from("Intervention accepted"))
2264 .map_err(|never| match never {} as Infallible)
2265 .boxed(),
2266 )
2267 .expect("valid HTTP response construction")
2268 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
2269 });
2270
2271 raw_routes.push(RouteEntry {
2272 method: Method::POST,
2273 pattern: RoutePattern::parse("/_system/intervene/force_resume"),
2274 handler,
2275 layers: Arc::new(Vec::new()),
2276 apply_global_layers: true,
2277 needs_body: false,
2278 body_schema: None,
2279 });
2280 }
2281
2282 if let Some(registry) = self.policy_registry.clone() {
2283 let handler: RouteHandler<R> = Arc::new(move |_parts, _res| {
2284 let _registry = registry.clone();
2285 Box::pin(async move {
2286 Response::builder()
2290 .status(StatusCode::OK)
2291 .body(
2292 Full::new(Bytes::from("Policy registry active"))
2293 .map_err(|never| match never {} as Infallible)
2294 .boxed(),
2295 )
2296 .expect("valid HTTP response construction")
2297 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
2298 });
2299
2300 raw_routes.push(RouteEntry {
2301 method: Method::POST,
2302 pattern: RoutePattern::parse("/_system/policy/reload"),
2303 handler,
2304 layers: Arc::new(Vec::new()),
2305 apply_global_layers: true,
2306 needs_body: false,
2307 body_schema: None,
2308 });
2309 }
2310 let routes = Arc::new(raw_routes);
2311 let fallback = self.fallback;
2312 let layers = Arc::new(self.layers);
2313 let health = Arc::new(self.health);
2314 let static_assets = Arc::new(self.static_assets);
2315 let preflight_config = Arc::new(self.preflight_config);
2316 let on_start = self.on_start;
2317 let on_shutdown = self.on_shutdown;
2318 let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
2319 let resources = Arc::new(resources);
2320
2321 let listener = TcpListener::bind(addr).await?;
2322
2323 #[cfg(feature = "tls")]
2325 let tls_acceptor = if let Some(ref tls_cfg) = self.tls_config {
2326 let acceptor = build_tls_acceptor(&tls_cfg.cert_path, &tls_cfg.key_path)?;
2327 tracing::info!("Ranvier HTTP Ingress listening on https://{}", addr);
2328 Some(acceptor)
2329 } else {
2330 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
2331 None
2332 };
2333 #[cfg(not(feature = "tls"))]
2334 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
2335
2336 if let Some(callback) = on_start.as_ref() {
2337 callback();
2338 }
2339
2340 tokio::pin!(shutdown_signal);
2341 let mut connections = tokio::task::JoinSet::new();
2342
2343 loop {
2344 tokio::select! {
2345 _ = &mut shutdown_signal => {
2346 tracing::info!("Shutdown signal received. Draining in-flight connections.");
2347 break;
2348 }
2349 accept_result = listener.accept() => {
2350 let (stream, _) = accept_result?;
2351
2352 let routes = routes.clone();
2353 let fallback = fallback.clone();
2354 let resources = resources.clone();
2355 let layers = layers.clone();
2356 let health = health.clone();
2357 let static_assets = static_assets.clone();
2358 let preflight_config = preflight_config.clone();
2359 #[cfg(feature = "http3")]
2360 let alt_svc_h3_port = self.alt_svc_h3_port;
2361
2362 #[cfg(feature = "tls")]
2363 let tls_acceptor = tls_acceptor.clone();
2364
2365 connections.spawn(async move {
2366 let service = build_http_service(
2367 routes,
2368 fallback,
2369 resources,
2370 layers,
2371 health,
2372 static_assets,
2373 preflight_config,
2374 #[cfg(feature = "http3")] alt_svc_h3_port,
2375 );
2376
2377 #[cfg(feature = "tls")]
2378 if let Some(acceptor) = tls_acceptor {
2379 match acceptor.accept(stream).await {
2380 Ok(tls_stream) => {
2381 let io = TokioIo::new(tls_stream);
2382 if let Err(err) = http1::Builder::new()
2383 .serve_connection(io, service)
2384 .with_upgrades()
2385 .await
2386 {
2387 tracing::error!("Error serving TLS connection: {:?}", err);
2388 }
2389 }
2390 Err(err) => {
2391 tracing::warn!("TLS handshake failed: {:?}", err);
2392 }
2393 }
2394 return;
2395 }
2396
2397 let io = TokioIo::new(stream);
2398 if let Err(err) = http1::Builder::new()
2399 .serve_connection(io, service)
2400 .with_upgrades()
2401 .await
2402 {
2403 tracing::error!("Error serving connection: {:?}", err);
2404 }
2405 });
2406 }
2407 Some(join_result) = connections.join_next(), if !connections.is_empty() => {
2408 if let Err(err) = join_result {
2409 tracing::warn!("Connection task join error: {:?}", err);
2410 }
2411 }
2412 }
2413 }
2414
2415 let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;
2416
2417 drop(resources);
2418 if let Some(callback) = on_shutdown.as_ref() {
2419 callback();
2420 }
2421
2422 Ok(())
2423 }
2424
2425 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
2441 let routes = Arc::new(self.routes);
2442 let fallback = self.fallback;
2443 let layers = Arc::new(self.layers);
2444 let health = Arc::new(self.health);
2445 let static_assets = Arc::new(self.static_assets);
2446 let preflight_config = Arc::new(self.preflight_config);
2447 let resources = Arc::new(resources);
2448
2449 RawIngressService {
2450 routes,
2451 fallback,
2452 layers,
2453 health,
2454 static_assets,
2455 preflight_config,
2456 resources,
2457 }
2458 }
2459}
2460
2461async fn apply_body_transforms(
2466 response: HttpResponse,
2467 bus: &Bus,
2468 transforms: &[ResponseBodyTransformFn],
2469) -> HttpResponse {
2470 use http_body_util::BodyExt;
2471
2472 let (parts, body) = response.into_parts();
2473
2474 let collected = match body.collect().await {
2476 Ok(c) => c.to_bytes(),
2477 Err(_) => {
2478 return Response::builder()
2480 .status(StatusCode::INTERNAL_SERVER_ERROR)
2481 .body(
2482 Full::new(Bytes::from("body collection failed"))
2483 .map_err(|never| match never {})
2484 .boxed(),
2485 )
2486 .expect("valid response");
2487 }
2488 };
2489
2490 let mut transformed = collected;
2491 for transform in transforms {
2492 transformed = transform(bus, transformed);
2493 }
2494
2495 Response::from_parts(
2496 parts,
2497 Full::new(transformed)
2498 .map_err(|never| match never {})
2499 .boxed(),
2500 )
2501}
2502
2503fn build_http_service<R>(
2504 routes: Arc<Vec<RouteEntry<R>>>,
2505 fallback: Option<RouteHandler<R>>,
2506 resources: Arc<R>,
2507 layers: Arc<Vec<ServiceLayer>>,
2508 health: Arc<HealthConfig<R>>,
2509 static_assets: Arc<StaticAssetsConfig>,
2510 preflight_config: Arc<Option<PreflightConfig>>,
2511 #[cfg(feature = "http3")] alt_svc_port: Option<u16>,
2512) -> BoxHttpService
2513where
2514 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2515{
2516 BoxService::new(move |req: Request<Incoming>| {
2517 let routes = routes.clone();
2518 let fallback = fallback.clone();
2519 let resources = resources.clone();
2520 let layers = layers.clone();
2521 let health = health.clone();
2522 let static_assets = static_assets.clone();
2523 let preflight_config = preflight_config.clone();
2524
2525 async move {
2526 let mut req = req;
2527 let method = req.method().clone();
2528 let path = req.uri().path().to_string();
2529
2530 if let Some(response) =
2531 maybe_handle_health_request(&method, &path, &health, resources.clone()).await
2532 {
2533 return Ok::<_, Infallible>(response.into_response());
2534 }
2535
2536 if method == Method::OPTIONS {
2538 if let Some(ref config) = *preflight_config {
2539 let origin = req
2540 .headers()
2541 .get("origin")
2542 .and_then(|v| v.to_str().ok())
2543 .unwrap_or("");
2544 let is_wildcard = config.allowed_origins.iter().any(|o| o == "*");
2545 let is_allowed = is_wildcard
2546 || config.allowed_origins.iter().any(|o| o == origin);
2547
2548 if is_allowed || origin.is_empty() {
2549 let allow_origin = if is_wildcard {
2550 "*".to_string()
2551 } else {
2552 origin.to_string()
2553 };
2554 let mut response = Response::builder()
2555 .status(StatusCode::NO_CONTENT)
2556 .body(
2557 Full::new(Bytes::new())
2558 .map_err(|never| match never {})
2559 .boxed(),
2560 )
2561 .expect("valid preflight response");
2562 let headers = response.headers_mut();
2563 if let Ok(v) = allow_origin.parse() {
2564 headers.insert("access-control-allow-origin", v);
2565 }
2566 if let Ok(v) = config.allowed_methods.parse() {
2567 headers.insert("access-control-allow-methods", v);
2568 }
2569 if let Ok(v) = config.allowed_headers.parse() {
2570 headers.insert("access-control-allow-headers", v);
2571 }
2572 if let Ok(v) = config.max_age.parse() {
2573 headers.insert("access-control-max-age", v);
2574 }
2575 if config.allow_credentials {
2576 headers.insert(
2577 "access-control-allow-credentials",
2578 "true".parse().expect("valid header value"),
2579 );
2580 }
2581 return Ok(response);
2582 }
2583 }
2584 }
2585
2586 if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
2587 req.extensions_mut().insert(params);
2588 let effective_layers = if entry.apply_global_layers {
2589 merge_layers(&layers, &entry.layers)
2590 } else {
2591 entry.layers.clone()
2592 };
2593
2594 if effective_layers.is_empty() {
2595 let (mut parts, body) = req.into_parts();
2596 if entry.needs_body {
2597 match BodyExt::collect(body).await {
2598 Ok(collected) => { parts.extensions.insert(BodyBytes(collected.to_bytes())); }
2599 Err(_) => {
2600 return Ok(json_error_response(
2601 StatusCode::BAD_REQUEST,
2602 "Failed to read request body",
2603 ));
2604 }
2605 }
2606 }
2607 #[allow(unused_mut)]
2608 let mut res = (entry.handler)(parts, &resources).await;
2609 #[cfg(feature = "http3")]
2610 if let Some(port) = alt_svc_port {
2611 if let Ok(val) =
2612 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
2613 {
2614 res.headers_mut().insert(http::header::ALT_SVC, val);
2615 }
2616 }
2617 Ok::<_, Infallible>(res)
2618 } else {
2619 let route_service = build_route_service(
2620 entry.handler.clone(),
2621 resources.clone(),
2622 effective_layers,
2623 entry.needs_body,
2624 );
2625 #[allow(unused_mut)]
2626 let mut res = route_service.call(req).await;
2627 #[cfg(feature = "http3")]
2628 #[allow(irrefutable_let_patterns)]
2629 if let Ok(ref mut r) = res {
2630 if let Some(port) = alt_svc_port {
2631 if let Ok(val) =
2632 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
2633 {
2634 r.headers_mut().insert(http::header::ALT_SVC, val);
2635 }
2636 }
2637 }
2638 res
2639 }
2640 } else {
2641 let req =
2642 match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
2643 .await
2644 {
2645 Ok(req) => req,
2646 Err(response) => return Ok(response),
2647 };
2648
2649 #[allow(unused_mut)]
2650 let mut fallback_res = if let Some(ref fb) = fallback {
2651 if layers.is_empty() {
2652 let (parts, _) = req.into_parts();
2653 Ok(fb(parts, &resources).await)
2654 } else {
2655 let fallback_service =
2656 build_route_service(fb.clone(), resources.clone(), layers.clone(), false);
2657 fallback_service.call(req).await
2658 }
2659 } else {
2660 Ok(Response::builder()
2661 .status(StatusCode::NOT_FOUND)
2662 .body(
2663 Full::new(Bytes::from("Not Found"))
2664 .map_err(|never| match never {})
2665 .boxed(),
2666 )
2667 .expect("valid HTTP response construction"))
2668 };
2669
2670 #[cfg(feature = "http3")]
2671 if let Ok(r) = fallback_res.as_mut() {
2672 if let Some(port) = alt_svc_port {
2673 if let Ok(val) =
2674 http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
2675 {
2676 r.headers_mut().insert(http::header::ALT_SVC, val);
2677 }
2678 }
2679 }
2680
2681 fallback_res
2682 }
2683 }
2684 })
2685}
2686
2687fn build_route_service<R>(
2688 handler: RouteHandler<R>,
2689 resources: Arc<R>,
2690 layers: Arc<Vec<ServiceLayer>>,
2691 needs_body: bool,
2692) -> BoxHttpService
2693where
2694 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2695{
2696 let mut service = BoxService::new(move |req: Request<Incoming>| {
2697 let handler = handler.clone();
2698 let resources = resources.clone();
2699 async move {
2700 let (mut parts, body) = req.into_parts();
2701 if needs_body {
2702 match BodyExt::collect(body).await {
2703 Ok(collected) => { parts.extensions.insert(BodyBytes(collected.to_bytes())); }
2704 Err(_) => {
2705 return Ok(json_error_response(
2706 StatusCode::BAD_REQUEST,
2707 "Failed to read request body",
2708 ));
2709 }
2710 }
2711 }
2712 Ok::<_, Infallible>(handler(parts, &resources).await)
2713 }
2714 });
2715
2716 for layer in layers.iter() {
2717 service = layer(service);
2718 }
2719 service
2720}
2721
2722fn merge_layers(
2723 global_layers: &Arc<Vec<ServiceLayer>>,
2724 route_layers: &Arc<Vec<ServiceLayer>>,
2725) -> Arc<Vec<ServiceLayer>> {
2726 if global_layers.is_empty() {
2727 return route_layers.clone();
2728 }
2729 if route_layers.is_empty() {
2730 return global_layers.clone();
2731 }
2732
2733 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
2734 combined.extend(global_layers.iter().cloned());
2735 combined.extend(route_layers.iter().cloned());
2736 Arc::new(combined)
2737}
2738
2739async fn maybe_handle_health_request<R>(
2740 method: &Method,
2741 path: &str,
2742 health: &HealthConfig<R>,
2743 resources: Arc<R>,
2744) -> Option<HttpResponse>
2745where
2746 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2747{
2748 if method != Method::GET {
2749 return None;
2750 }
2751
2752 if let Some(liveness_path) = health.liveness_path.as_ref() {
2753 if path == liveness_path {
2754 return Some(health_json_response("liveness", true, Vec::new()));
2755 }
2756 }
2757
2758 if let Some(readiness_path) = health.readiness_path.as_ref() {
2759 if path == readiness_path {
2760 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
2761 return Some(health_json_response("readiness", healthy, checks));
2762 }
2763 }
2764
2765 if let Some(health_path) = health.health_path.as_ref() {
2766 if path == health_path {
2767 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
2768 return Some(health_json_response("health", healthy, checks));
2769 }
2770 }
2771
2772 None
2773}
2774
2775async fn serve_single_file(file_path: &str) -> Result<Response<Full<Bytes>>, std::io::Error> {
2777 let path = std::path::Path::new(file_path);
2778 let content = tokio::fs::read(path).await?;
2779 let mime = guess_mime(file_path);
2780 let mut response = Response::new(Full::new(Bytes::from(content)));
2781 if let Ok(value) = http::HeaderValue::from_str(mime) {
2782 response
2783 .headers_mut()
2784 .insert(http::header::CONTENT_TYPE, value);
2785 }
2786 if let Ok(metadata) = tokio::fs::metadata(path).await {
2787 if let Ok(modified) = metadata.modified() {
2788 if let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH) {
2789 let etag = format!("\"{}\"", duration.as_secs());
2790 if let Ok(value) = http::HeaderValue::from_str(&etag) {
2791 response.headers_mut().insert(http::header::ETAG, value);
2792 }
2793 }
2794 }
2795 }
2796 Ok(response)
2797}
2798
2799async fn serve_static_file(
2801 directory: &str,
2802 file_subpath: &str,
2803 config: &StaticAssetsConfig,
2804 if_none_match: Option<&http::HeaderValue>,
2805 accept_encoding: Option<&http::HeaderValue>,
2806 range_header: Option<&http::HeaderValue>,
2807) -> Result<Response<Full<Bytes>>, std::io::Error> {
2808 let subpath = file_subpath.trim_start_matches('/');
2809
2810 let resolved_subpath;
2812 if subpath.is_empty() || subpath.ends_with('/') {
2813 if let Some(ref index) = config.directory_index {
2814 resolved_subpath = if subpath.is_empty() {
2815 index.clone()
2816 } else {
2817 format!("{}{}", subpath, index)
2818 };
2819 } else {
2820 return Err(std::io::Error::new(
2821 std::io::ErrorKind::NotFound,
2822 "empty path",
2823 ));
2824 }
2825 } else {
2826 resolved_subpath = subpath.to_string();
2827 }
2828
2829 let full_path = std::path::Path::new(directory).join(&resolved_subpath);
2830 let canonical = tokio::fs::canonicalize(&full_path).await?;
2832 let dir_canonical = tokio::fs::canonicalize(directory).await?;
2833 if !canonical.starts_with(&dir_canonical) {
2834 return Err(std::io::Error::new(
2835 std::io::ErrorKind::PermissionDenied,
2836 "path traversal detected",
2837 ));
2838 }
2839
2840 let etag = if let Ok(metadata) = tokio::fs::metadata(&canonical).await {
2842 metadata
2843 .modified()
2844 .ok()
2845 .and_then(|m| m.duration_since(std::time::UNIX_EPOCH).ok())
2846 .map(|d| format!("\"{}\"", d.as_secs()))
2847 } else {
2848 None
2849 };
2850
2851 if let (Some(client_etag), Some(server_etag)) = (if_none_match, &etag) {
2853 if client_etag.as_bytes() == server_etag.as_bytes() {
2854 let mut response = Response::new(Full::new(Bytes::new()));
2855 *response.status_mut() = StatusCode::NOT_MODIFIED;
2856 if let Ok(value) = http::HeaderValue::from_str(server_etag) {
2857 response.headers_mut().insert(http::header::ETAG, value);
2858 }
2859 return Ok(response);
2860 }
2861 }
2862
2863 let (serve_path, content_encoding) = if config.serve_precompressed {
2865 let client_accepts = accept_encoding
2866 .and_then(|v| v.to_str().ok())
2867 .unwrap_or("");
2868 let canonical_str = canonical.to_str().unwrap_or("");
2869
2870 if client_accepts.contains("br") {
2871 let br_path = format!("{}.br", canonical_str);
2872 if tokio::fs::metadata(&br_path).await.is_ok() {
2873 (std::path::PathBuf::from(br_path), Some("br"))
2874 } else if client_accepts.contains("gzip") {
2875 let gz_path = format!("{}.gz", canonical_str);
2876 if tokio::fs::metadata(&gz_path).await.is_ok() {
2877 (std::path::PathBuf::from(gz_path), Some("gzip"))
2878 } else {
2879 (canonical.clone(), None)
2880 }
2881 } else {
2882 (canonical.clone(), None)
2883 }
2884 } else if client_accepts.contains("gzip") {
2885 let gz_path = format!("{}.gz", canonical_str);
2886 if tokio::fs::metadata(&gz_path).await.is_ok() {
2887 (std::path::PathBuf::from(gz_path), Some("gzip"))
2888 } else {
2889 (canonical.clone(), None)
2890 }
2891 } else {
2892 (canonical.clone(), None)
2893 }
2894 } else {
2895 (canonical.clone(), None)
2896 };
2897
2898 let content = tokio::fs::read(&serve_path).await?;
2899 let mime = guess_mime(canonical.to_str().unwrap_or(""));
2901
2902 if config.enable_range_requests {
2904 if let Some(range_val) = range_header {
2905 if let Some(response) = handle_range_request(
2906 range_val,
2907 &content,
2908 mime,
2909 etag.as_deref(),
2910 content_encoding,
2911 ) {
2912 return Ok(response);
2913 }
2914 }
2915 }
2916
2917 let mut response = Response::new(Full::new(Bytes::from(content)));
2918 if let Ok(value) = http::HeaderValue::from_str(mime) {
2919 response
2920 .headers_mut()
2921 .insert(http::header::CONTENT_TYPE, value);
2922 }
2923 if let Some(ref etag_val) = etag {
2924 if let Ok(value) = http::HeaderValue::from_str(etag_val) {
2925 response.headers_mut().insert(http::header::ETAG, value);
2926 }
2927 }
2928 if let Some(encoding) = content_encoding {
2929 if let Ok(value) = http::HeaderValue::from_str(encoding) {
2930 response
2931 .headers_mut()
2932 .insert(http::header::CONTENT_ENCODING, value);
2933 }
2934 }
2935 if config.enable_range_requests {
2936 response
2937 .headers_mut()
2938 .insert(http::header::ACCEPT_RANGES, http::HeaderValue::from_static("bytes"));
2939 }
2940
2941 if config.immutable_cache {
2943 if let Some(filename) = canonical.file_name().and_then(|n| n.to_str()) {
2944 if is_hashed_filename(filename) {
2945 if let Ok(value) = http::HeaderValue::from_str(
2946 "public, max-age=31536000, immutable",
2947 ) {
2948 response
2949 .headers_mut()
2950 .insert(http::header::CACHE_CONTROL, value);
2951 }
2952 }
2953 }
2954 }
2955
2956 Ok(response)
2957}
2958
2959fn handle_range_request(
2963 range_header: &http::HeaderValue,
2964 content: &[u8],
2965 mime: &str,
2966 etag: Option<&str>,
2967 content_encoding: Option<&str>,
2968) -> Option<Response<Full<Bytes>>> {
2969 let range_str = range_header.to_str().ok()?;
2970 let range_spec = range_str.strip_prefix("bytes=")?;
2971 let total = content.len();
2972 if total == 0 {
2973 return None;
2974 }
2975
2976 let (start, end) = if let Some(suffix) = range_spec.strip_prefix('-') {
2977 let n: usize = suffix.parse().ok()?;
2979 if n == 0 || n > total {
2980 return Some(range_not_satisfiable(total));
2981 }
2982 (total - n, total - 1)
2983 } else if range_spec.ends_with('-') {
2984 let start: usize = range_spec.trim_end_matches('-').parse().ok()?;
2986 if start >= total {
2987 return Some(range_not_satisfiable(total));
2988 }
2989 (start, total - 1)
2990 } else {
2991 let mut parts = range_spec.splitn(2, '-');
2993 let start: usize = parts.next()?.parse().ok()?;
2994 let end: usize = parts.next()?.parse().ok()?;
2995 if start > end || start >= total {
2996 return Some(range_not_satisfiable(total));
2997 }
2998 (start, end.min(total - 1))
2999 };
3000
3001 let slice = &content[start..=end];
3002 let content_range = format!("bytes {}-{}/{}", start, end, total);
3003
3004 let mut response = Response::new(Full::new(Bytes::copy_from_slice(slice)));
3005 *response.status_mut() = StatusCode::PARTIAL_CONTENT;
3006 if let Ok(v) = http::HeaderValue::from_str(&content_range) {
3007 response.headers_mut().insert(http::header::CONTENT_RANGE, v);
3008 }
3009 if let Ok(v) = http::HeaderValue::from_str(mime) {
3010 response
3011 .headers_mut()
3012 .insert(http::header::CONTENT_TYPE, v);
3013 }
3014 response
3015 .headers_mut()
3016 .insert(http::header::ACCEPT_RANGES, http::HeaderValue::from_static("bytes"));
3017 if let Some(etag_val) = etag {
3018 if let Ok(v) = http::HeaderValue::from_str(etag_val) {
3019 response.headers_mut().insert(http::header::ETAG, v);
3020 }
3021 }
3022 if let Some(encoding) = content_encoding {
3023 if let Ok(v) = http::HeaderValue::from_str(encoding) {
3024 response
3025 .headers_mut()
3026 .insert(http::header::CONTENT_ENCODING, v);
3027 }
3028 }
3029 Some(response)
3030}
3031
3032fn range_not_satisfiable(total: usize) -> Response<Full<Bytes>> {
3034 let content_range = format!("bytes */{}", total);
3035 let mut response = Response::new(Full::new(Bytes::from("Range Not Satisfiable")));
3036 *response.status_mut() = StatusCode::RANGE_NOT_SATISFIABLE;
3037 if let Ok(v) = http::HeaderValue::from_str(&content_range) {
3038 response.headers_mut().insert(http::header::CONTENT_RANGE, v);
3039 }
3040 response
3041}
3042
/// Detects bundler-style content-hashed names such as `main.3f2a1b9c.js`.
///
/// The name must have at least three dot-separated parts. The part just before the
/// extension must be at least 6 ASCII hex digits long.
fn is_hashed_filename(filename: &str) -> bool {
    let mut pieces = filename.rsplitn(3, '.');
    let _extension = pieces.next();
    let candidate_hash = pieces.next();
    let has_stem = pieces.next().is_some();

    match (candidate_hash, has_stem) {
        (Some(hash), true) => hash.len() >= 6 && hash.bytes().all(|b| b.is_ascii_hexdigit()),
        _ => false,
    }
}
3054
/// Maps a file path to a `Content-Type` using its extension.
///
/// The extension is the text after the last `.`. A path without any `.` has no
/// extension. Matching is ASCII case-insensitive, so `LOGO.PNG` is `image/png`.
/// Unknown extensions fall back to `application/octet-stream`.
fn guess_mime(path: &str) -> &'static str {
    let extension = path
        .rsplit_once('.')
        .map(|(_, ext)| ext.to_ascii_lowercase())
        .unwrap_or_default();

    match extension.as_str() {
        "html" | "htm" => "text/html; charset=utf-8",
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" | "ts" | "tsx" => "application/javascript; charset=utf-8",
        "json" => "application/json; charset=utf-8",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "avif" => "image/avif",
        "webp" => "image/webp",
        "webm" => "video/webm",
        "mp4" => "video/mp4",
        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "ogg" => "audio/ogg",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "otf" => "font/otf",
        "txt" => "text/plain; charset=utf-8",
        "csv" => "text/csv; charset=utf-8",
        "xml" => "application/xml; charset=utf-8",
        "yaml" | "yml" => "application/yaml",
        "wasm" => "application/wasm",
        "pdf" => "application/pdf",
        "map" => "application/json",
        "webmanifest" => "application/manifest+json",
        _ => "application/octet-stream",
    }
}
3082
3083fn apply_cache_control(
3084 mut response: Response<Full<Bytes>>,
3085 cache_control: Option<&str>,
3086) -> Response<Full<Bytes>> {
3087 if response.status() == StatusCode::OK {
3088 if let Some(value) = cache_control {
3089 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
3090 if let Ok(header_value) = http::HeaderValue::from_str(value) {
3091 response
3092 .headers_mut()
3093 .insert(http::header::CACHE_CONTROL, header_value);
3094 }
3095 }
3096 }
3097 }
3098 response
3099}
3100
3101async fn maybe_handle_static_request(
3102 req: Request<Incoming>,
3103 method: &Method,
3104 path: &str,
3105 static_assets: &StaticAssetsConfig,
3106) -> Result<Request<Incoming>, HttpResponse> {
3107 if method != Method::GET && method != Method::HEAD {
3108 return Ok(req);
3109 }
3110
3111 if let Some(mount) = static_assets
3112 .mounts
3113 .iter()
3114 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
3115 {
3116 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
3117 let if_none_match = req.headers().get(http::header::IF_NONE_MATCH).cloned();
3118 let range_header = req.headers().get(http::header::RANGE).cloned();
3119 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
3120 return Ok(req);
3121 };
3122 let response = match serve_static_file(
3123 &mount.directory,
3124 &stripped_path,
3125 static_assets,
3126 if_none_match.as_ref(),
3127 accept_encoding.as_ref(),
3128 range_header.as_ref(),
3129 )
3130 .await
3131 {
3132 Ok(response) => response,
3133 Err(_) => {
3134 return Err(Response::builder()
3135 .status(StatusCode::INTERNAL_SERVER_ERROR)
3136 .body(
3137 Full::new(Bytes::from("Failed to serve static asset"))
3138 .map_err(|never| match never {})
3139 .boxed(),
3140 )
3141 .unwrap_or_else(|_| {
3142 Response::new(
3143 Full::new(Bytes::new())
3144 .map_err(|never| match never {})
3145 .boxed(),
3146 )
3147 }));
3148 }
3149 };
3150 let mut response = apply_cache_control(response, static_assets.cache_control.as_deref());
3151 response = maybe_compress_static_response(
3152 response,
3153 accept_encoding,
3154 static_assets.enable_compression,
3155 );
3156 let (parts, body) = response.into_parts();
3157 return Err(Response::from_parts(
3158 parts,
3159 body.map_err(|never| match never {}).boxed(),
3160 ));
3161 }
3162
3163 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
3164 if looks_like_spa_request(path) {
3165 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
3166 let response = match serve_single_file(spa_file).await {
3167 Ok(response) => response,
3168 Err(_) => {
3169 return Err(Response::builder()
3170 .status(StatusCode::INTERNAL_SERVER_ERROR)
3171 .body(
3172 Full::new(Bytes::from("Failed to serve SPA fallback"))
3173 .map_err(|never| match never {})
3174 .boxed(),
3175 )
3176 .unwrap_or_else(|_| {
3177 Response::new(
3178 Full::new(Bytes::new())
3179 .map_err(|never| match never {})
3180 .boxed(),
3181 )
3182 }));
3183 }
3184 };
3185 let mut response =
3186 apply_cache_control(response, static_assets.cache_control.as_deref());
3187 response = maybe_compress_static_response(
3188 response,
3189 accept_encoding,
3190 static_assets.enable_compression,
3191 );
3192 let (parts, body) = response.into_parts();
3193 return Err(Response::from_parts(
3194 parts,
3195 body.map_err(|never| match never {}).boxed(),
3196 ));
3197 }
3198 }
3199
3200 Ok(req)
3201}
3202
/// Returns the part of `path` below the mount `prefix`, always starting with `/`, or
/// `None` when `path` is not under the prefix.
///
/// Trailing slashes on `prefix` are ignored. A root mount (`/`) returns `path`
/// unchanged, and the bare prefix itself maps to `/`.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    let base = match prefix {
        "/" => "/",
        other => other.trim_end_matches('/'),
    };

    if base == "/" {
        return Some(path.to_owned());
    }
    if path == base {
        return Some(String::from("/"));
    }

    let remainder = path.strip_prefix(base)?.strip_prefix('/')?;
    Some(format!("/{}", remainder))
}
3222
/// Treats a path as a client-side (SPA) route when its last segment has no `.`,
/// i.e. it does not look like a file such as `app.js` or `favicon.ico`.
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = match path.rsplit_once('/') {
        Some((_, tail)) => tail,
        None => path,
    };
    !last_segment.contains('.')
}
3227
3228fn maybe_compress_static_response(
3229 response: Response<Full<Bytes>>,
3230 accept_encoding: Option<http::HeaderValue>,
3231 enable_compression: bool,
3232) -> Response<Full<Bytes>> {
3233 if !enable_compression {
3234 return response;
3235 }
3236
3237 let Some(accept_encoding) = accept_encoding else {
3238 return response;
3239 };
3240
3241 let accept_str = accept_encoding.to_str().unwrap_or("");
3242 if !accept_str.contains("gzip") {
3243 return response;
3244 }
3245
3246 let status = response.status();
3247 let headers = response.headers().clone();
3248 let body = response.into_body();
3249
3250 let data = futures_util::FutureExt::now_or_never(BodyExt::collect(body))
3252 .and_then(|r| r.ok())
3253 .map(|collected| collected.to_bytes())
3254 .unwrap_or_default();
3255
3256 let compressed = {
3258 use flate2::write::GzEncoder;
3259 use flate2::Compression;
3260 use std::io::Write;
3261 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
3262 let _ = encoder.write_all(&data);
3263 encoder.finish().unwrap_or_default()
3264 };
3265
3266 let mut builder = Response::builder().status(status);
3267 for (name, value) in headers.iter() {
3268 if name != http::header::CONTENT_LENGTH && name != http::header::CONTENT_ENCODING {
3269 builder = builder.header(name, value);
3270 }
3271 }
3272 builder
3273 .header(http::header::CONTENT_ENCODING, "gzip")
3274 .body(Full::new(Bytes::from(compressed)))
3275 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())))
3276}
3277
3278async fn run_named_health_checks<R>(
3279 checks: &[NamedHealthCheck<R>],
3280 resources: Arc<R>,
3281) -> (bool, Vec<HealthCheckReport>)
3282where
3283 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3284{
3285 let mut reports = Vec::with_capacity(checks.len());
3286 let mut healthy = true;
3287
3288 for check in checks {
3289 match (check.check)(resources.clone()).await {
3290 Ok(()) => reports.push(HealthCheckReport {
3291 name: check.name.clone(),
3292 status: "ok",
3293 error: None,
3294 }),
3295 Err(error) => {
3296 healthy = false;
3297 reports.push(HealthCheckReport {
3298 name: check.name.clone(),
3299 status: "error",
3300 error: Some(error),
3301 });
3302 }
3303 }
3304 }
3305
3306 (healthy, reports)
3307}
3308
3309fn health_json_response(
3310 probe: &'static str,
3311 healthy: bool,
3312 checks: Vec<HealthCheckReport>,
3313) -> HttpResponse {
3314 let status_code = if healthy {
3315 StatusCode::OK
3316 } else {
3317 StatusCode::SERVICE_UNAVAILABLE
3318 };
3319 let status = if healthy { "ok" } else { "degraded" };
3320 let payload = HealthReport {
3321 status,
3322 probe,
3323 checks,
3324 };
3325
3326 let body = serde_json::to_vec(&payload)
3327 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
3328
3329 Response::builder()
3330 .status(status_code)
3331 .header(http::header::CONTENT_TYPE, "application/json")
3332 .body(
3333 Full::new(Bytes::from(body))
3334 .map_err(|never| match never {})
3335 .boxed(),
3336 )
3337 .expect("valid HTTP response construction")
3338}
3339
3340async fn shutdown_signal() {
3341 #[cfg(unix)]
3342 {
3343 use tokio::signal::unix::{SignalKind, signal};
3344
3345 match signal(SignalKind::terminate()) {
3346 Ok(mut terminate) => {
3347 tokio::select! {
3348 _ = tokio::signal::ctrl_c() => {}
3349 _ = terminate.recv() => {}
3350 }
3351 }
3352 Err(err) => {
3353 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
3354 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
3355 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
3356 }
3357 }
3358 }
3359 }
3360
3361 #[cfg(not(unix))]
3362 {
3363 if let Err(err) = tokio::signal::ctrl_c().await {
3364 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
3365 }
3366 }
3367}
3368
3369async fn drain_connections(
3370 connections: &mut tokio::task::JoinSet<()>,
3371 graceful_shutdown_timeout: Duration,
3372) -> bool {
3373 if connections.is_empty() {
3374 return false;
3375 }
3376
3377 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
3378 while let Some(join_result) = connections.join_next().await {
3379 if let Err(err) = join_result {
3380 tracing::warn!("Connection task join error during shutdown: {:?}", err);
3381 }
3382 }
3383 })
3384 .await;
3385
3386 if drain_result.is_err() {
3387 tracing::warn!(
3388 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
3389 graceful_shutdown_timeout
3390 );
3391 connections.abort_all();
3392 while let Some(join_result) = connections.join_next().await {
3393 if let Err(err) = join_result {
3394 tracing::warn!("Connection task abort join error: {:?}", err);
3395 }
3396 }
3397 true
3398 } else {
3399 false
3400 }
3401}
3402
#[cfg(feature = "tls")]
/// Loads a PEM certificate chain and private key and builds a rustls-backed
/// `TlsAcceptor` with no client authentication.
///
/// # Errors
/// Returns a descriptive error if either file cannot be opened, if the PEM data cannot
/// be parsed, if the key file contains no private key, or if rustls rejects the
/// certificate/key pair.
fn build_tls_acceptor(
    cert_path: &str,
    key_path: &str,
) -> Result<tokio_rustls::TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    use rustls_pemfile::{certs, private_key};
    use std::io::BufReader;

    let mut cert_reader = BufReader::new(
        std::fs::File::open(cert_path)
            .map_err(|e| format!("Failed to open certificate file '{}': {}", cert_path, e))?,
    );
    let mut key_reader = BufReader::new(
        std::fs::File::open(key_path)
            .map_err(|e| format!("Failed to open key file '{}': {}", key_path, e))?,
    );

    let cert_chain = certs(&mut cert_reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| format!("Failed to parse certificate PEM: {}", e))?;

    let signing_key = private_key(&mut key_reader)
        .map_err(|e| format!("Failed to parse private key PEM: {}", e))?
        .ok_or("No private key found in key file")?;

    let server_config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, signing_key)
        .map_err(|e| format!("TLS configuration error: {}", e))?;

    Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
}
3434
3435impl<R> Default for HttpIngress<R>
3436where
3437 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3438{
3439 fn default() -> Self {
3440 Self::new()
3441 }
3442}
3443
3444#[derive(Clone)]
3446pub struct RawIngressService<R> {
3447 routes: Arc<Vec<RouteEntry<R>>>,
3448 fallback: Option<RouteHandler<R>>,
3449 layers: Arc<Vec<ServiceLayer>>,
3450 health: Arc<HealthConfig<R>>,
3451 static_assets: Arc<StaticAssetsConfig>,
3452 preflight_config: Arc<Option<PreflightConfig>>,
3453 resources: Arc<R>,
3454}
3455
3456impl<R> hyper::service::Service<Request<Incoming>> for RawIngressService<R>
3457where
3458 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3459{
3460 type Response = HttpResponse;
3461 type Error = Infallible;
3462 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
3463
3464 fn call(&self, req: Request<Incoming>) -> Self::Future {
3465 let routes = self.routes.clone();
3466 let fallback = self.fallback.clone();
3467 let layers = self.layers.clone();
3468 let health = self.health.clone();
3469 let static_assets = self.static_assets.clone();
3470 let preflight_config = self.preflight_config.clone();
3471 let resources = self.resources.clone();
3472
3473 Box::pin(async move {
3474 let service = build_http_service(
3475 routes,
3476 fallback,
3477 resources,
3478 layers,
3479 health,
3480 static_assets,
3481 preflight_config,
3482 #[cfg(feature = "http3")]
3483 None,
3484 );
3485 service.call(req).await
3486 })
3487 }
3488}
3489
3490#[cfg(test)]
3491mod tests {
3492 use super::*;
3493 use async_trait::async_trait;
3494 use futures_util::{SinkExt, StreamExt};
3495 use serde::Deserialize;
3496 use std::fs;
3497 use std::sync::atomic::{AtomicBool, Ordering};
3498 use tempfile::tempdir;
3499 use tokio::io::{AsyncReadExt, AsyncWriteExt};
3500 use tokio_tungstenite::tungstenite::Message as WsClientMessage;
3501 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
3502
3503 async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
3504 let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
3505
3506 loop {
3507 match tokio::net::TcpStream::connect(addr).await {
3508 Ok(stream) => return stream,
3509 Err(error) => {
3510 if tokio::time::Instant::now() >= deadline {
3511 panic!("connect server: {error}");
3512 }
3513 tokio::time::sleep(Duration::from_millis(25)).await;
3514 }
3515 }
3516 }
3517 }
3518
3519 #[test]
3520 fn route_pattern_matches_static_path() {
3521 let pattern = RoutePattern::parse("/orders/list");
3522 let params = pattern.match_path("/orders/list").expect("should match");
3523 assert!(params.into_inner().is_empty());
3524 }
3525
3526 #[test]
3527 fn route_pattern_matches_param_segments() {
3528 let pattern = RoutePattern::parse("/orders/:id/items/:item_id");
3529 let params = pattern
3530 .match_path("/orders/42/items/sku-123")
3531 .expect("should match");
3532 assert_eq!(params.get("id"), Some("42"));
3533 assert_eq!(params.get("item_id"), Some("sku-123"));
3534 }
3535
3536 #[test]
3537 fn route_pattern_matches_wildcard_segment() {
3538 let pattern = RoutePattern::parse("/assets/*path");
3539 let params = pattern
3540 .match_path("/assets/css/theme/light.css")
3541 .expect("should match");
3542 assert_eq!(params.get("path"), Some("css/theme/light.css"));
3543 }
3544
3545 #[test]
3546 fn route_pattern_rejects_non_matching_path() {
3547 let pattern = RoutePattern::parse("/orders/:id");
3548 assert!(pattern.match_path("/users/42").is_none());
3549 }
3550
3551 #[test]
3552 fn graceful_shutdown_timeout_defaults_to_30_seconds() {
3553 let ingress = HttpIngress::<()>::new();
3554 assert_eq!(ingress.graceful_shutdown_timeout, Duration::from_secs(30));
3555 assert!(ingress.layers.is_empty());
3556 assert!(ingress.bus_injectors.is_empty());
3557 assert!(ingress.static_assets.mounts.is_empty());
3558 assert!(ingress.on_start.is_none());
3559 assert!(ingress.on_shutdown.is_none());
3560 }
3561
3562 #[test]
3563 fn route_without_layer_keeps_empty_route_middleware_stack() {
3564 let ingress =
3565 HttpIngress::<()>::new().get("/ping", Axon::<(), (), String, ()>::new("Ping"));
3566 assert_eq!(ingress.routes.len(), 1);
3567 assert!(ingress.routes[0].layers.is_empty());
3568 assert!(ingress.routes[0].apply_global_layers);
3569 }
3570
3571 #[test]
3572 fn timeout_layer_registers_builtin_middleware() {
3573 let ingress = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
3574 assert_eq!(ingress.layers.len(), 1);
3575 }
3576
3577 #[test]
3578 fn request_id_layer_registers_builtin_middleware() {
3579 let ingress = HttpIngress::<()>::new().request_id_layer();
3580 assert_eq!(ingress.layers.len(), 1);
3581 }
3582
3583 #[test]
3584 fn compression_layer_registers_builtin_middleware() {
3585 let ingress = HttpIngress::<()>::new().compression_layer();
3586 assert!(ingress.static_assets.enable_compression);
3587 }
3588
3589 #[test]
3590 fn bus_injector_registration_adds_hook() {
3591 let ingress = HttpIngress::<()>::new().bus_injector(|_req, bus| {
3592 bus.insert("ok".to_string());
3593 });
3594 assert_eq!(ingress.bus_injectors.len(), 1);
3595 }
3596
3597 #[test]
3598 fn ws_route_registers_get_route_pattern() {
3599 let ingress =
3600 HttpIngress::<()>::new().ws("/ws/events", |_socket, _resources, _bus| async {});
3601 assert_eq!(ingress.routes.len(), 1);
3602 assert_eq!(ingress.routes[0].method, Method::GET);
3603 assert_eq!(ingress.routes[0].pattern.raw, "/ws/events");
3604 }
3605
3606 #[derive(Debug, Deserialize)]
3607 struct WsWelcomeFrame {
3608 connection_id: String,
3609 path: String,
3610 tenant: String,
3611 }
3612
3613 #[tokio::test]
3614 async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
3615 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
3616 let addr = probe.local_addr().expect("local addr");
3617 drop(probe);
3618
3619 let ingress = HttpIngress::<()>::new()
3620 .bind(addr.to_string())
3621 .bus_injector(|req, bus| {
3622 if let Some(value) = req.headers.get("x-tenant-id").and_then(|v| v.to_str().ok()) {
3623 bus.insert(value.to_string());
3624 }
3625 })
3626 .ws("/ws/echo", |mut socket, _resources, bus| async move {
3627 let tenant = bus
3628 .read::<String>()
3629 .cloned()
3630 .unwrap_or_else(|| "unknown".to_string());
3631 if let Some(session) = bus.read::<WebSocketSessionContext>() {
3632 let welcome = serde_json::json!({
3633 "connection_id": session.connection_id().to_string(),
3634 "path": session.path(),
3635 "tenant": tenant,
3636 });
3637 let _ = socket.send_json(&welcome).await;
3638 }
3639
3640 while let Some(event) = socket.next_event().await {
3641 match event {
3642 WebSocketEvent::Text(text) => {
3643 let _ = socket.send_event(format!("echo:{text}")).await;
3644 }
3645 WebSocketEvent::Binary(bytes) => {
3646 let _ = socket.send_event(bytes).await;
3647 }
3648 WebSocketEvent::Close => break,
3649 WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
3650 }
3651 }
3652 });
3653
3654 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
3655 let server = tokio::spawn(async move {
3656 ingress
3657 .run_with_shutdown_signal((), async move {
3658 let _ = shutdown_rx.await;
3659 })
3660 .await
3661 });
3662
3663 let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
3664 let mut ws_request = ws_uri
3665 .as_str()
3666 .into_client_request()
3667 .expect("ws client request");
3668 ws_request
3669 .headers_mut()
3670 .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
3671 let (mut client, _response) = tokio_tungstenite::connect_async(ws_request)
3672 .await
3673 .expect("websocket connect");
3674
3675 let welcome = client
3676 .next()
3677 .await
3678 .expect("welcome frame")
3679 .expect("welcome frame ok");
3680 let welcome_text = match welcome {
3681 WsClientMessage::Text(text) => text.to_string(),
3682 other => panic!("expected text welcome frame, got {other:?}"),
3683 };
3684 let welcome_payload: WsWelcomeFrame =
3685 serde_json::from_str(&welcome_text).expect("welcome json");
3686 assert_eq!(welcome_payload.path, "/ws/echo");
3687 assert_eq!(welcome_payload.tenant, "acme");
3688 assert!(!welcome_payload.connection_id.is_empty());
3689
3690 client
3691 .send(WsClientMessage::Text("hello".into()))
3692 .await
3693 .expect("send text");
3694 let echo_text = client
3695 .next()
3696 .await
3697 .expect("echo text frame")
3698 .expect("echo text frame ok");
3699 assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));
3700
3701 client
3702 .send(WsClientMessage::Binary(vec![1, 2, 3, 4].into()))
3703 .await
3704 .expect("send binary");
3705 let echo_binary = client
3706 .next()
3707 .await
3708 .expect("echo binary frame")
3709 .expect("echo binary frame ok");
3710 assert_eq!(
3711 echo_binary,
3712 WsClientMessage::Binary(vec![1, 2, 3, 4].into())
3713 );
3714
3715 client.close(None).await.expect("close websocket");
3716
3717 let _ = shutdown_tx.send(());
3718 server
3719 .await
3720 .expect("server join")
3721 .expect("server shutdown should succeed");
3722 }
3723
3724 #[test]
3725 fn route_descriptors_export_http_and_health_paths() {
3726 let ingress = HttpIngress::<()>::new()
3727 .get(
3728 "/orders/:id",
3729 Axon::<(), (), String, ()>::new("OrderById"),
3730 )
3731 .health_endpoint("/healthz")
3732 .readiness_liveness("/readyz", "/livez");
3733
3734 let descriptors = ingress.route_descriptors();
3735
3736 assert!(
3737 descriptors
3738 .iter()
3739 .any(|descriptor| descriptor.method() == Method::GET
3740 && descriptor.path_pattern() == "/orders/:id")
3741 );
3742 assert!(
3743 descriptors
3744 .iter()
3745 .any(|descriptor| descriptor.method() == Method::GET
3746 && descriptor.path_pattern() == "/healthz")
3747 );
3748 assert!(
3749 descriptors
3750 .iter()
3751 .any(|descriptor| descriptor.method() == Method::GET
3752 && descriptor.path_pattern() == "/readyz")
3753 );
3754 assert!(
3755 descriptors
3756 .iter()
3757 .any(|descriptor| descriptor.method() == Method::GET
3758 && descriptor.path_pattern() == "/livez")
3759 );
3760 }
3761
3762 #[tokio::test]
3763 async fn lifecycle_hooks_fire_on_start_and_shutdown() {
3764 let started = Arc::new(AtomicBool::new(false));
3765 let shutdown = Arc::new(AtomicBool::new(false));
3766
3767 let started_flag = started.clone();
3768 let shutdown_flag = shutdown.clone();
3769
3770 let ingress = HttpIngress::<()>::new()
3771 .bind("127.0.0.1:0")
3772 .on_start(move || {
3773 started_flag.store(true, Ordering::SeqCst);
3774 })
3775 .on_shutdown(move || {
3776 shutdown_flag.store(true, Ordering::SeqCst);
3777 })
3778 .graceful_shutdown(Duration::from_millis(50));
3779
3780 ingress
3781 .run_with_shutdown_signal((), async {
3782 tokio::time::sleep(Duration::from_millis(20)).await;
3783 })
3784 .await
3785 .expect("server should exit gracefully");
3786
3787 assert!(started.load(Ordering::SeqCst));
3788 assert!(shutdown.load(Ordering::SeqCst));
3789 }
3790
3791 #[tokio::test]
3792 async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
3793 #[derive(Clone)]
3794 struct SlowDrainRoute;
3795
3796 #[async_trait]
3797 impl Transition<(), String> for SlowDrainRoute {
3798 type Error = String;
3799 type Resources = ();
3800
3801 async fn run(
3802 &self,
3803 _state: (),
3804 _resources: &Self::Resources,
3805 _bus: &mut Bus,
3806 ) -> Outcome<String, Self::Error> {
3807 tokio::time::sleep(Duration::from_millis(120)).await;
3808 Outcome::next("drained-ok".to_string())
3809 }
3810 }
3811
3812 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
3813 let addr = probe.local_addr().expect("local addr");
3814 drop(probe);
3815
3816 let ingress = HttpIngress::<()>::new()
3817 .bind(addr.to_string())
3818 .graceful_shutdown(Duration::from_millis(500))
3819 .get(
3820 "/drain",
3821 Axon::<(), (), String, ()>::new("SlowDrain").then(SlowDrainRoute),
3822 );
3823
3824 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
3825 let server = tokio::spawn(async move {
3826 ingress
3827 .run_with_shutdown_signal((), async move {
3828 let _ = shutdown_rx.await;
3829 })
3830 .await
3831 });
3832
3833 let mut stream = connect_with_retry(addr).await;
3834 stream
3835 .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
3836 .await
3837 .expect("write request");
3838
3839 tokio::time::sleep(Duration::from_millis(20)).await;
3840 let _ = shutdown_tx.send(());
3841
3842 let mut buf = Vec::new();
3843 stream.read_to_end(&mut buf).await.expect("read response");
3844 let response = String::from_utf8_lossy(&buf);
3845 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
3846 assert!(response.contains("drained-ok"), "{response}");
3847
3848 server
3849 .await
3850 .expect("server join")
3851 .expect("server shutdown should succeed");
3852 }
3853
3854 #[tokio::test]
3855 async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
3856 let temp = tempdir().expect("tempdir");
3857 let root = temp.path().join("public");
3858 fs::create_dir_all(&root).expect("create dir");
3859 let file = root.join("hello.txt");
3860 fs::write(&file, "hello static").expect("write file");
3861
3862 let ingress =
3863 Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
3864 let app = crate::test_harness::TestApp::new(ingress, ());
3865 let response = app
3866 .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
3867 .await
3868 .expect("request should succeed");
3869
3870 assert_eq!(response.status(), StatusCode::OK);
3871 assert_eq!(response.text().expect("utf8"), "hello static");
3872 assert!(response.header("cache-control").is_some());
3873 let has_metadata_header =
3874 response.header("etag").is_some() || response.header("last-modified").is_some();
3875 assert!(has_metadata_header);
3876 }
3877
3878 #[tokio::test]
3879 async fn spa_fallback_returns_index_for_unmatched_path() {
3880 let temp = tempdir().expect("tempdir");
3881 let index = temp.path().join("index.html");
3882 fs::write(&index, "<html><body>spa</body></html>").expect("write index");
3883
3884 let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
3885 let app = crate::test_harness::TestApp::new(ingress, ());
3886 let response = app
3887 .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
3888 .await
3889 .expect("request should succeed");
3890
3891 assert_eq!(response.status(), StatusCode::OK);
3892 assert!(response.text().expect("utf8").contains("spa"));
3893 }
3894
3895 #[tokio::test]
3896 async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
3897 let temp = tempdir().expect("tempdir");
3898 let root = temp.path().join("public");
3899 fs::create_dir_all(&root).expect("create dir");
3900 let file = root.join("compressed.txt");
3901 fs::write(&file, "compress me ".repeat(400)).expect("write file");
3902
3903 let ingress = Ranvier::http::<()>()
3904 .serve_dir("/static", root.to_string_lossy().to_string())
3905 .compression_layer();
3906 let app = crate::test_harness::TestApp::new(ingress, ());
3907 let response = app
3908 .send(
3909 crate::test_harness::TestRequest::get("/static/compressed.txt")
3910 .header("accept-encoding", "gzip"),
3911 )
3912 .await
3913 .expect("request should succeed");
3914
3915 assert_eq!(response.status(), StatusCode::OK);
3916 assert_eq!(
3917 response
3918 .header("content-encoding")
3919 .and_then(|value| value.to_str().ok()),
3920 Some("gzip")
3921 );
3922 }
3923
3924 #[tokio::test]
3925 async fn drain_connections_completes_before_timeout() {
3926 let mut connections = tokio::task::JoinSet::new();
3927 connections.spawn(async {
3928 tokio::time::sleep(Duration::from_millis(20)).await;
3929 });
3930
3931 let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
3932 assert!(!timed_out);
3933 assert!(connections.is_empty());
3934 }
3935
3936 #[tokio::test]
3937 async fn drain_connections_times_out_and_aborts() {
3938 let mut connections = tokio::task::JoinSet::new();
3939 connections.spawn(async {
3940 tokio::time::sleep(Duration::from_secs(10)).await;
3941 });
3942
3943 let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
3944 assert!(timed_out);
3945 assert!(connections.is_empty());
3946 }
3947
3948 #[tokio::test]
3949 async fn timeout_layer_returns_408_for_slow_route() {
3950 #[derive(Clone)]
3951 struct SlowRoute;
3952
3953 #[async_trait]
3954 impl Transition<(), String> for SlowRoute {
3955 type Error = String;
3956 type Resources = ();
3957
3958 async fn run(
3959 &self,
3960 _state: (),
3961 _resources: &Self::Resources,
3962 _bus: &mut Bus,
3963 ) -> Outcome<String, Self::Error> {
3964 tokio::time::sleep(Duration::from_millis(80)).await;
3965 Outcome::next("slow-ok".to_string())
3966 }
3967 }
3968
3969 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
3970 let addr = probe.local_addr().expect("local addr");
3971 drop(probe);
3972
3973 let ingress = HttpIngress::<()>::new()
3974 .bind(addr.to_string())
3975 .timeout_layer(Duration::from_millis(10))
3976 .get(
3977 "/slow",
3978 Axon::<(), (), String, ()>::new("Slow").then(SlowRoute),
3979 );
3980
3981 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
3982 let server = tokio::spawn(async move {
3983 ingress
3984 .run_with_shutdown_signal((), async move {
3985 let _ = shutdown_rx.await;
3986 })
3987 .await
3988 });
3989
3990 let mut stream = connect_with_retry(addr).await;
3991 stream
3992 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
3993 .await
3994 .expect("write request");
3995
3996 let mut buf = Vec::new();
3997 stream.read_to_end(&mut buf).await.expect("read response");
3998 let response = String::from_utf8_lossy(&buf);
3999 assert!(response.starts_with("HTTP/1.1 408"), "{response}");
4000
4001 let _ = shutdown_tx.send(());
4002 server
4003 .await
4004 .expect("server join")
4005 .expect("server shutdown should succeed");
4006 }
4007
4008 fn extract_body(response: Response<Full<Bytes>>) -> Bytes {
4011 use http_body_util::BodyExt;
4012 let rt = tokio::runtime::Builder::new_current_thread()
4013 .build()
4014 .unwrap();
4015 rt.block_on(async {
4016 let collected = response.into_body().collect().await.unwrap();
4017 collected.to_bytes()
4018 })
4019 }
4020
4021 #[test]
4022 fn handle_range_bytes_start_end() {
4023 let content = b"Hello, World!";
4024 let range = http::HeaderValue::from_static("bytes=0-4");
4025 let response =
4026 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4027 assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
4028 assert_eq!(
4029 response.headers().get(http::header::CONTENT_RANGE).unwrap(),
4030 "bytes 0-4/13"
4031 );
4032 assert_eq!(extract_body(response), "Hello");
4033 }
4034
4035 #[test]
4036 fn handle_range_suffix() {
4037 let content = b"Hello, World!";
4038 let range = http::HeaderValue::from_static("bytes=-6");
4039 let response =
4040 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4041 assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
4042 assert_eq!(
4043 response.headers().get(http::header::CONTENT_RANGE).unwrap(),
4044 "bytes 7-12/13"
4045 );
4046 }
4047
4048 #[test]
4049 fn handle_range_from_offset() {
4050 let content = b"Hello, World!";
4051 let range = http::HeaderValue::from_static("bytes=7-");
4052 let response =
4053 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4054 assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
4055 assert_eq!(
4056 response.headers().get(http::header::CONTENT_RANGE).unwrap(),
4057 "bytes 7-12/13"
4058 );
4059 }
4060
4061 #[test]
4062 fn handle_range_out_of_bounds_returns_416() {
4063 let content = b"Hello";
4064 let range = http::HeaderValue::from_static("bytes=10-20");
4065 let response =
4066 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4067 assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE);
4068 assert_eq!(
4069 response.headers().get(http::header::CONTENT_RANGE).unwrap(),
4070 "bytes */5"
4071 );
4072 }
4073
4074 #[test]
4075 fn handle_range_includes_accept_ranges_header() {
4076 let content = b"Hello, World!";
4077 let range = http::HeaderValue::from_static("bytes=0-0");
4078 let response =
4079 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4080 assert_eq!(
4081 response.headers().get(http::header::ACCEPT_RANGES).unwrap(),
4082 "bytes"
4083 );
4084 }
4085
4086}