1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode};
20use http_body_util::{BodyExt, Full};
21use hyper::body::Incoming;
22use hyper::server::conn::http1;
23use hyper::upgrade::Upgraded;
24use hyper_util::rt::TokioIo;
25use ranvier_core::event::{EventSink, EventSource};
26use ranvier_core::prelude::*;
27use ranvier_runtime::Axon;
28use serde::Serialize;
29use serde::de::DeserializeOwned;
30use sha1::{Digest, Sha1};
31use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::net::SocketAddr;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::net::TcpListener;
39use tokio::sync::Mutex;
40use tokio_tungstenite::WebSocketStream;
41use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
42use tracing::Instrument;
43
44use crate::guard_integration::{
45 GuardExec, GuardIntegration, PreflightConfig, RegisteredGuard, ResponseBodyTransformFn,
46 ResponseExtractorFn,
47};
48use crate::response::{HttpResponse, IntoResponse, json_error_response, outcome_to_response_with_error};
49
50pub struct Ranvier;
55
56impl Ranvier {
57 pub fn http<R>() -> HttpIngress<R>
59 where
60 R: ranvier_core::transition::ResourceRequirement + Clone,
61 {
62 HttpIngress::new()
63 }
64}
65
66type RouteHandler<R> = Arc<
68 dyn Fn(http::request::Parts, &R) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>>
69 + Send
70 + Sync,
71>;
72
73#[derive(Clone)]
75struct BoxService(
76 Arc<
77 dyn Fn(Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>
78 + Send
79 + Sync,
80 >,
81);
82
83impl BoxService {
84 fn new<F, Fut>(f: F) -> Self
85 where
86 F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
87 Fut: Future<Output = Result<HttpResponse, Infallible>> + Send + 'static,
88 {
89 Self(Arc::new(move |req| Box::pin(f(req))))
90 }
91
92 fn call(&self, req: Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>> {
93 (self.0)(req)
94 }
95}
96
97impl hyper::service::Service<Request<Incoming>> for BoxService {
98 type Response = HttpResponse;
99 type Error = Infallible;
100 type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>;
101
102 fn call(&self, req: Request<Incoming>) -> Self::Future {
103 (self.0)(req)
104 }
105}
106
107type BoxHttpService = BoxService;
108type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
109type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
110type BusInjector = Arc<dyn Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static>;
111type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
112type WsSessionHandler<R> =
113 Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
114type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
115type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
116const REQUEST_ID_HEADER: &str = "x-request-id";
117const WS_UPGRADE_TOKEN: &str = "websocket";
118const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
119
120#[derive(Clone)]
121struct NamedHealthCheck<R> {
122 name: String,
123 check: HealthCheckFn<R>,
124}
125
126#[derive(Clone)]
127struct HealthConfig<R> {
128 health_path: Option<String>,
129 readiness_path: Option<String>,
130 liveness_path: Option<String>,
131 checks: Vec<NamedHealthCheck<R>>,
132}
133
134impl<R> Default for HealthConfig<R> {
135 fn default() -> Self {
136 Self {
137 health_path: None,
138 readiness_path: None,
139 liveness_path: None,
140 checks: Vec::new(),
141 }
142 }
143}
144
/// Static-file serving options collected by the `serve_dir` family of builder methods.
#[derive(Clone, Default)]
struct StaticAssetsConfig {
    /// Directories mounted under URL prefixes, in registration order.
    mounts: Vec<StaticMount>,
    /// File served as a single-page-app fallback, set by `spa_fallback`.
    spa_fallback: Option<String>,
    /// `Cache-Control` value; `serve_dir` defaults it to `public, max-age=3600` when unset.
    cache_control: Option<String>,
    /// Set by `compression_layer`.
    enable_compression: bool,
    /// Set by `directory_index`.
    directory_index: Option<String>,
    /// Set by `immutable_cache`.
    immutable_cache: bool,
    /// Set by `serve_precompressed`.
    serve_precompressed: bool,
    /// Set by `enable_range_requests`.
    enable_range_requests: bool,
}

/// One directory exposed under a URL prefix.
#[derive(Clone)]
struct StaticMount {
    /// URL prefix, normalized to start with `/`.
    route_prefix: String,
    /// Filesystem directory that backs the prefix.
    directory: String,
}
167
/// Certificate and private-key paths recorded by `HttpIngress::tls`.
#[cfg(feature = "tls")]
#[derive(Clone)]
struct TlsAcceptorConfig {
    /// Path passed as `cert_path` to `HttpIngress::tls`.
    cert_path: String,
    /// Path passed as `key_path` to `HttpIngress::tls`.
    key_path: String,
}
175
176#[derive(Serialize)]
177struct HealthReport {
178 status: &'static str,
179 probe: &'static str,
180 checks: Vec<HealthCheckReport>,
181}
182
183#[derive(Serialize)]
184struct HealthCheckReport {
185 name: String,
186 status: &'static str,
187 #[serde(skip_serializing_if = "Option::is_none")]
188 error: Option<String>,
189}
190
191fn timeout_middleware(timeout: Duration) -> ServiceLayer {
192 Arc::new(move |inner: BoxHttpService| {
193 BoxService::new(move |req: Request<Incoming>| {
194 let inner = inner.clone();
195 async move {
196 match tokio::time::timeout(timeout, inner.call(req)).await {
197 Ok(response) => response,
198 Err(_) => Ok(Response::builder()
199 .status(StatusCode::REQUEST_TIMEOUT)
200 .body(
201 Full::new(Bytes::from("Request Timeout"))
202 .map_err(|never| match never {})
203 .boxed(),
204 )
205 .expect("valid HTTP response construction")),
206 }
207 }
208 })
209 })
210}
211
212fn request_id_middleware() -> ServiceLayer {
213 Arc::new(move |inner: BoxHttpService| {
214 BoxService::new(move |req: Request<Incoming>| {
215 let inner = inner.clone();
216 async move {
217 let mut req = req;
218 let request_id = req
219 .headers()
220 .get(REQUEST_ID_HEADER)
221 .cloned()
222 .unwrap_or_else(|| {
223 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
224 .unwrap_or_else(|_| {
225 http::HeaderValue::from_static("request-id-unavailable")
226 })
227 });
228 req.headers_mut()
229 .insert(REQUEST_ID_HEADER, request_id.clone());
230 let mut response = inner.call(req).await?;
231 response
232 .headers_mut()
233 .insert(REQUEST_ID_HEADER, request_id);
234 Ok(response)
235 }
236 })
237 })
238}
239
/// Named parameters captured from a matched route pattern (`:name` and `*name` segments),
/// keyed by parameter name.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PathParams {
    values: HashMap<String, String>,
}
244
245#[derive(Clone, Debug)]
247pub struct HttpRouteDescriptor {
248 method: Method,
249 path_pattern: String,
250 pub body_schema: Option<serde_json::Value>,
252}
253
254impl HttpRouteDescriptor {
255 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
256 Self {
257 method,
258 path_pattern: path_pattern.into(),
259 body_schema: None,
260 }
261 }
262
263 pub fn method(&self) -> &Method {
264 &self.method
265 }
266
267 pub fn path_pattern(&self) -> &str {
268 &self.path_pattern
269 }
270
271 pub fn body_schema(&self) -> Option<&serde_json::Value> {
276 self.body_schema.as_ref()
277 }
278}
279
280#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
282pub struct WebSocketSessionContext {
283 connection_id: uuid::Uuid,
284 path: String,
285 query: Option<String>,
286}
287
288impl WebSocketSessionContext {
289 pub fn connection_id(&self) -> uuid::Uuid {
290 self.connection_id
291 }
292
293 pub fn path(&self) -> &str {
294 &self.path
295 }
296
297 pub fn query(&self) -> Option<&str> {
298 self.query.as_deref()
299 }
300}
301
302#[derive(Clone, Debug, PartialEq, Eq)]
304pub enum WebSocketEvent {
305 Text(String),
306 Binary(Vec<u8>),
307 Ping(Vec<u8>),
308 Pong(Vec<u8>),
309 Close,
310}
311
312impl WebSocketEvent {
313 pub fn text(value: impl Into<String>) -> Self {
314 Self::Text(value.into())
315 }
316
317 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
318 Self::Binary(value.into())
319 }
320
321 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
322 where
323 T: Serialize,
324 {
325 let text = serde_json::to_string(value)?;
326 Ok(Self::Text(text))
327 }
328}
329
330#[derive(Debug, thiserror::Error)]
331pub enum WebSocketError {
332 #[error("websocket wire error: {0}")]
333 Wire(#[from] WsWireError),
334 #[error("json serialization failed: {0}")]
335 JsonSerialize(#[source] serde_json::Error),
336 #[error("json deserialization failed: {0}")]
337 JsonDeserialize(#[source] serde_json::Error),
338 #[error("expected text or binary frame for json payload")]
339 NonDataFrame,
340}
341
342type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
343type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
344type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
345
346pub struct WebSocketConnection {
348 sink: Mutex<WsServerSink>,
349 source: Mutex<WsServerSource>,
350 session: WebSocketSessionContext,
351}
352
353impl WebSocketConnection {
354 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
355 let (sink, source) = stream.split();
356 Self {
357 sink: Mutex::new(sink),
358 source: Mutex::new(source),
359 session,
360 }
361 }
362
363 pub fn session(&self) -> &WebSocketSessionContext {
364 &self.session
365 }
366
367 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
368 let mut sink = self.sink.lock().await;
369 sink.send(event.into_wire_message()).await?;
370 Ok(())
371 }
372
373 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
374 where
375 T: Serialize,
376 {
377 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
378 self.send(event).await
379 }
380
381 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
382 where
383 T: DeserializeOwned,
384 {
385 let Some(event) = self.recv_event().await? else {
386 return Ok(None);
387 };
388 match event {
389 WebSocketEvent::Text(text) => serde_json::from_str(&text)
390 .map(Some)
391 .map_err(WebSocketError::JsonDeserialize),
392 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
393 .map(Some)
394 .map_err(WebSocketError::JsonDeserialize),
395 _ => Err(WebSocketError::NonDataFrame),
396 }
397 }
398
399 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
400 let mut source = self.source.lock().await;
401 while let Some(item) = source.next().await {
402 let message = item?;
403 if let Some(event) = WebSocketEvent::from_wire_message(message) {
404 return Ok(Some(event));
405 }
406 }
407 Ok(None)
408 }
409}
410
411impl WebSocketEvent {
412 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
413 match message {
414 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
415 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
416 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
417 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
418 WsWireMessage::Close(_) => Some(Self::Close),
419 WsWireMessage::Frame(_) => None,
420 }
421 }
422
423 fn into_wire_message(self) -> WsWireMessage {
424 match self {
425 Self::Text(value) => WsWireMessage::Text(value),
426 Self::Binary(value) => WsWireMessage::Binary(value),
427 Self::Ping(value) => WsWireMessage::Ping(value),
428 Self::Pong(value) => WsWireMessage::Pong(value),
429 Self::Close => WsWireMessage::Close(None),
430 }
431 }
432}
433
434#[async_trait::async_trait]
435impl EventSource<WebSocketEvent> for WebSocketConnection {
436 async fn next_event(&mut self) -> Option<WebSocketEvent> {
437 match self.recv_event().await {
438 Ok(event) => event,
439 Err(error) => {
440 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
441 None
442 }
443 }
444 }
445}
446
447#[async_trait::async_trait]
448impl EventSink<WebSocketEvent> for WebSocketConnection {
449 type Error = WebSocketError;
450
451 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
452 self.send(event).await
453 }
454}
455
456#[async_trait::async_trait]
457impl EventSink<String> for WebSocketConnection {
458 type Error = WebSocketError;
459
460 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
461 self.send(WebSocketEvent::Text(event)).await
462 }
463}
464
465#[async_trait::async_trait]
466impl EventSink<Vec<u8>> for WebSocketConnection {
467 type Error = WebSocketError;
468
469 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
470 self.send(WebSocketEvent::Binary(event)).await
471 }
472}
473
474impl PathParams {
475 pub fn new(values: HashMap<String, String>) -> Self {
476 Self { values }
477 }
478
479 pub fn get(&self, key: &str) -> Option<&str> {
480 self.values.get(key).map(String::as_str)
481 }
482
483 pub fn as_map(&self) -> &HashMap<String, String> {
484 &self.values
485 }
486
487 pub fn into_inner(self) -> HashMap<String, String> {
488 self.values
489 }
490}
491
492#[derive(Clone, Debug, PartialEq, Eq)]
493enum RouteSegment {
494 Static(String),
495 Param(String),
496 Wildcard(String),
497}
498
499#[derive(Clone, Debug, PartialEq, Eq)]
500struct RoutePattern {
501 raw: String,
502 segments: Vec<RouteSegment>,
503}
504
505impl RoutePattern {
506 fn parse(path: &str) -> Self {
507 let segments = path_segments(path)
508 .into_iter()
509 .map(|segment| {
510 if let Some(name) = segment.strip_prefix(':') {
511 if !name.is_empty() {
512 return RouteSegment::Param(name.to_string());
513 }
514 }
515 if let Some(name) = segment.strip_prefix('*') {
516 if !name.is_empty() {
517 return RouteSegment::Wildcard(name.to_string());
518 }
519 }
520 RouteSegment::Static(segment.to_string())
521 })
522 .collect();
523
524 Self {
525 raw: path.to_string(),
526 segments,
527 }
528 }
529
530 fn match_path(&self, path: &str) -> Option<PathParams> {
531 let mut params = HashMap::new();
532 let path_segments = path_segments(path);
533 let mut pattern_index = 0usize;
534 let mut path_index = 0usize;
535
536 while pattern_index < self.segments.len() {
537 match &self.segments[pattern_index] {
538 RouteSegment::Static(expected) => {
539 let actual = path_segments.get(path_index)?;
540 if actual != expected {
541 return None;
542 }
543 pattern_index += 1;
544 path_index += 1;
545 }
546 RouteSegment::Param(name) => {
547 let actual = path_segments.get(path_index)?;
548 params.insert(name.clone(), (*actual).to_string());
549 pattern_index += 1;
550 path_index += 1;
551 }
552 RouteSegment::Wildcard(name) => {
553 let remaining = path_segments[path_index..].join("/");
554 params.insert(name.clone(), remaining);
555 pattern_index += 1;
556 path_index = path_segments.len();
557 break;
558 }
559 }
560 }
561
562 if pattern_index == self.segments.len() && path_index == path_segments.len() {
563 Some(PathParams::new(params))
564 } else {
565 None
566 }
567 }
568}
569
570#[derive(Clone)]
573struct BodyBytes(Bytes);
574
575#[derive(Clone)]
576struct RouteEntry<R> {
577 method: Method,
578 pattern: RoutePattern,
579 handler: RouteHandler<R>,
580 layers: Arc<Vec<ServiceLayer>>,
581 apply_global_layers: bool,
582 needs_body: bool,
585 body_schema: Option<serde_json::Value>,
587}
588
/// Split a URL path into its non-empty `/`-separated segments.
///
/// Leading, trailing and repeated slashes produce no segments, so `"/"`, `""` and `"//"`
/// all yield an empty list and `"/a//b/"` yields `["a", "b"]`.
fn path_segments(path: &str) -> Vec<&str> {
    path.split('/').filter(|segment| !segment.is_empty()).collect()
}
599
/// Ensure a route path is absolute: empty becomes `/`, and a missing leading `/` is added.
fn normalize_route_path(path: String) -> String {
    if path.starts_with('/') {
        path
    } else if path.is_empty() {
        "/".to_owned()
    } else {
        format!("/{path}")
    }
}
610
611fn find_matching_route<'a, R>(
612 routes: &'a [RouteEntry<R>],
613 method: &Method,
614 path: &str,
615) -> Option<(&'a RouteEntry<R>, PathParams)> {
616 for entry in routes {
617 if entry.method != *method {
618 continue;
619 }
620 if let Some(params) = entry.pattern.match_path(path) {
621 return Some((entry, params));
622 }
623 }
624 None
625}
626
627fn header_contains_token(
628 headers: &http::HeaderMap,
629 name: http::header::HeaderName,
630 token: &str,
631) -> bool {
632 headers
633 .get(name)
634 .and_then(|value| value.to_str().ok())
635 .map(|value| {
636 value
637 .split(',')
638 .any(|part| part.trim().eq_ignore_ascii_case(token))
639 })
640 .unwrap_or(false)
641}
642
643fn websocket_session_from_request<B>(req: &Request<B>) -> WebSocketSessionContext {
644 WebSocketSessionContext {
645 connection_id: uuid::Uuid::new_v4(),
646 path: req.uri().path().to_string(),
647 query: req.uri().query().map(str::to_string),
648 }
649}
650
651fn websocket_accept_key(client_key: &str) -> String {
652 let mut hasher = Sha1::new();
653 hasher.update(client_key.as_bytes());
654 hasher.update(WS_GUID.as_bytes());
655 let digest = hasher.finalize();
656 base64::engine::general_purpose::STANDARD.encode(digest)
657}
658
659fn websocket_bad_request(message: &'static str) -> HttpResponse {
660 Response::builder()
661 .status(StatusCode::BAD_REQUEST)
662 .body(
663 Full::new(Bytes::from(message))
664 .map_err(|never| match never {})
665 .boxed(),
666 )
667 .unwrap_or_else(|_| {
668 Response::new(
669 Full::new(Bytes::new())
670 .map_err(|never| match never {})
671 .boxed(),
672 )
673 })
674}
675
676fn websocket_upgrade_response<B>(
677 req: &mut Request<B>,
678) -> Result<(HttpResponse, hyper::upgrade::OnUpgrade), HttpResponse> {
679 if req.method() != Method::GET {
680 return Err(websocket_bad_request(
681 "WebSocket upgrade requires GET method",
682 ));
683 }
684
685 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
686 return Err(websocket_bad_request(
687 "Missing Connection: upgrade header for WebSocket",
688 ));
689 }
690
691 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
692 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
693 }
694
695 if let Some(version) = req.headers().get("sec-websocket-version") {
696 if version != "13" {
697 return Err(websocket_bad_request(
698 "Unsupported Sec-WebSocket-Version (expected 13)",
699 ));
700 }
701 }
702
703 let Some(client_key) = req
704 .headers()
705 .get("sec-websocket-key")
706 .and_then(|value| value.to_str().ok())
707 else {
708 return Err(websocket_bad_request(
709 "Missing Sec-WebSocket-Key header for WebSocket",
710 ));
711 };
712
713 let accept_key = websocket_accept_key(client_key);
714 let on_upgrade = hyper::upgrade::on(req);
715 let response = Response::builder()
716 .status(StatusCode::SWITCHING_PROTOCOLS)
717 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
718 .header(http::header::CONNECTION, "Upgrade")
719 .header("sec-websocket-accept", accept_key)
720 .body(
721 Full::new(Bytes::new())
722 .map_err(|never| match never {})
723 .boxed(),
724 )
725 .unwrap_or_else(|_| {
726 Response::new(
727 Full::new(Bytes::new())
728 .map_err(|never| match never {})
729 .boxed(),
730 )
731 });
732
733 Ok((response, on_upgrade))
734}
735
736pub struct HttpIngress<R = ()> {
742 addr: Option<String>,
744 routes: Vec<RouteEntry<R>>,
746 fallback: Option<RouteHandler<R>>,
748 layers: Vec<ServiceLayer>,
750 on_start: Option<LifecycleHook>,
752 on_shutdown: Option<LifecycleHook>,
754 graceful_shutdown_timeout: Duration,
756 bus_injectors: Vec<BusInjector>,
758 static_assets: StaticAssetsConfig,
760 health: HealthConfig<R>,
762 #[cfg(feature = "http3")]
763 http3_config: Option<crate::http3::Http3Config>,
764 #[cfg(feature = "http3")]
765 alt_svc_h3_port: Option<u16>,
766 #[cfg(feature = "tls")]
768 tls_config: Option<TlsAcceptorConfig>,
769 active_intervention: bool,
771 policy_registry: Option<ranvier_core::policy::PolicyRegistry>,
773 guard_execs: Vec<Arc<dyn GuardExec>>,
775 guard_response_extractors: Vec<ResponseExtractorFn>,
777 guard_body_transforms: Vec<ResponseBodyTransformFn>,
779 preflight_config: Option<PreflightConfig>,
781 _phantom: std::marker::PhantomData<R>,
782}
783
784impl<R> HttpIngress<R>
785where
786 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
787{
788 pub fn new() -> Self {
790 Self {
791 addr: None,
792 routes: Vec::new(),
793 fallback: None,
794 layers: Vec::new(),
795 on_start: None,
796 on_shutdown: None,
797 graceful_shutdown_timeout: Duration::from_secs(30),
798 bus_injectors: Vec::new(),
799 static_assets: StaticAssetsConfig::default(),
800 health: HealthConfig::default(),
801 #[cfg(feature = "tls")]
802 tls_config: None,
803 #[cfg(feature = "http3")]
804 http3_config: None,
805 #[cfg(feature = "http3")]
806 alt_svc_h3_port: None,
807 active_intervention: false,
808 policy_registry: None,
809 guard_execs: Vec::new(),
810 guard_response_extractors: Vec::new(),
811 guard_body_transforms: Vec::new(),
812 preflight_config: None,
813 _phantom: std::marker::PhantomData,
814 }
815 }
816
817 pub fn bind(mut self, addr: impl Into<String>) -> Self {
821 self.addr = Some(addr.into());
822 self
823 }
824
825 pub fn active_intervention(mut self) -> Self {
831 self.active_intervention = true;
832 self
833 }
834
835 pub fn policy_registry(mut self, registry: ranvier_core::policy::PolicyRegistry) -> Self {
837 self.policy_registry = Some(registry);
838 self
839 }
840
841 pub fn on_start<F>(mut self, callback: F) -> Self
845 where
846 F: Fn() + Send + Sync + 'static,
847 {
848 self.on_start = Some(Arc::new(callback));
849 self
850 }
851
852 pub fn on_shutdown<F>(mut self, callback: F) -> Self
854 where
855 F: Fn() + Send + Sync + 'static,
856 {
857 self.on_shutdown = Some(Arc::new(callback));
858 self
859 }
860
861 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
863 self.graceful_shutdown_timeout = timeout;
864 self
865 }
866
867 pub fn config(mut self, config: &ranvier_core::config::RanvierConfig) -> Self {
873 self.addr = Some(config.bind_addr());
874 self.graceful_shutdown_timeout = config.shutdown_timeout();
875 config.init_telemetry();
876 self
877 }
878
879 #[cfg(feature = "tls")]
881 pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
882 self.tls_config = Some(TlsAcceptorConfig {
883 cert_path: cert_path.into(),
884 key_path: key_path.into(),
885 });
886 self
887 }
888
889 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
894 self.layers.push(timeout_middleware(timeout));
895 self
896 }
897
898 pub fn request_id_layer(mut self) -> Self {
902 self.layers.push(request_id_middleware());
903 self
904 }
905
906 pub fn bus_injector<F>(mut self, injector: F) -> Self
911 where
912 F: Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static,
913 {
914 self.bus_injectors.push(Arc::new(injector));
915 self
916 }
917
918 #[cfg(feature = "htmx")]
926 pub fn htmx_support(mut self) -> Self {
927 self.bus_injectors
928 .push(Arc::new(crate::htmx::inject_htmx_headers));
929 self.guard_response_extractors
930 .push(Arc::new(crate::htmx::extract_htmx_response_headers));
931 self
932 }
933
934 pub fn guard(mut self, guard: impl GuardIntegration) -> Self {
954 let registration = guard.register();
955 for injector in registration.bus_injectors {
956 self.bus_injectors.push(injector);
957 }
958 self.guard_execs.push(registration.exec);
959 if let Some(extractor) = registration.response_extractor {
960 self.guard_response_extractors.push(extractor);
961 }
962 if let Some(transform) = registration.response_body_transform {
963 self.guard_body_transforms.push(transform);
964 }
965 if registration.handles_preflight {
966 if let Some(config) = registration.preflight_config {
967 self.preflight_config = Some(config);
968 }
969 }
970 self
971 }
972
973 #[cfg(feature = "http3")]
975 pub fn enable_http3(mut self, config: crate::http3::Http3Config) -> Self {
976 self.http3_config = Some(config);
977 self
978 }
979
980 #[cfg(feature = "http3")]
982 pub fn alt_svc_h3(mut self, port: u16) -> Self {
983 self.alt_svc_h3_port = Some(port);
984 self
985 }
986
987 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
991 let mut descriptors = self
992 .routes
993 .iter()
994 .map(|entry| {
995 let mut desc = HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone());
996 desc.body_schema = entry.body_schema.clone();
997 desc
998 })
999 .collect::<Vec<_>>();
1000
1001 if let Some(path) = &self.health.health_path {
1002 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1003 }
1004 if let Some(path) = &self.health.readiness_path {
1005 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1006 }
1007 if let Some(path) = &self.health.liveness_path {
1008 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1009 }
1010
1011 descriptors
1012 }
1013
1014 pub fn serve_dir(
1020 mut self,
1021 route_prefix: impl Into<String>,
1022 directory: impl Into<String>,
1023 ) -> Self {
1024 self.static_assets.mounts.push(StaticMount {
1025 route_prefix: normalize_route_path(route_prefix.into()),
1026 directory: directory.into(),
1027 });
1028 if self.static_assets.cache_control.is_none() {
1029 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
1030 }
1031 self
1032 }
1033
1034 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
1038 self.static_assets.spa_fallback = Some(file_path.into());
1039 self
1040 }
1041
1042 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
1044 self.static_assets.cache_control = Some(cache_control.into());
1045 self
1046 }
1047
1048 pub fn directory_index(mut self, filename: impl Into<String>) -> Self {
1056 self.static_assets.directory_index = Some(filename.into());
1057 self
1058 }
1059
1060 pub fn immutable_cache(mut self) -> Self {
1065 self.static_assets.immutable_cache = true;
1066 self
1067 }
1068
1069 pub fn serve_precompressed(mut self) -> Self {
1075 self.static_assets.serve_precompressed = true;
1076 self
1077 }
1078
1079 pub fn enable_range_requests(mut self) -> Self {
1084 self.static_assets.enable_range_requests = true;
1085 self
1086 }
1087
1088 pub fn compression_layer(mut self) -> Self {
1090 self.static_assets.enable_compression = true;
1091 self
1092 }
1093
1094 pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
1103 where
1104 H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
1105 Fut: Future<Output = ()> + Send + 'static,
1106 {
1107 let path_str: String = path.into();
1108 let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
1109 Box::pin(handler(connection, resources, bus))
1110 });
1111 let bus_injectors = Arc::new(self.bus_injectors.clone());
1112 let ws_guard_execs = Arc::new(self.guard_execs.clone());
1113 let path_for_pattern = path_str.clone();
1114 let path_for_handler = path_str;
1115
1116 let route_handler: RouteHandler<R> =
1117 Arc::new(move |parts: http::request::Parts, res: &R| {
1118 let ws_handler = ws_handler.clone();
1119 let bus_injectors = bus_injectors.clone();
1120 let ws_guard_execs = ws_guard_execs.clone();
1121 let resources = Arc::new(res.clone());
1122 let path = path_for_handler.clone();
1123
1124 Box::pin(async move {
1125 let request_id = uuid::Uuid::new_v4().to_string();
1126 let span = tracing::info_span!(
1127 "WebSocketUpgrade",
1128 ranvier.ws.path = %path,
1129 ranvier.ws.request_id = %request_id
1130 );
1131
1132 async move {
1133 let mut bus = Bus::new();
1134 for injector in bus_injectors.iter() {
1135 injector(&parts, &mut bus);
1136 }
1137 for guard_exec in ws_guard_execs.iter() {
1138 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1139 return json_error_response(rejection.status, &rejection.message);
1140 }
1141 }
1142
1143 let mut req = Request::from_parts(parts, ());
1145 let session = websocket_session_from_request(&req);
1146 bus.insert(session.clone());
1147
1148 let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
1149 Ok(result) => result,
1150 Err(error_response) => return error_response,
1151 };
1152
1153 tokio::spawn(async move {
1154 match on_upgrade.await {
1155 Ok(upgraded) => {
1156 let stream = WebSocketStream::from_raw_socket(
1157 TokioIo::new(upgraded),
1158 tokio_tungstenite::tungstenite::protocol::Role::Server,
1159 None,
1160 )
1161 .await;
1162 let connection = WebSocketConnection::new(stream, session);
1163 ws_handler(connection, resources, bus).await;
1164 }
1165 Err(error) => {
1166 tracing::warn!(
1167 ranvier.ws.path = %path,
1168 ranvier.ws.error = %error,
1169 "websocket upgrade failed"
1170 );
1171 }
1172 }
1173 });
1174
1175 response
1176 }
1177 .instrument(span)
1178 .await
1179 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1180 });
1181
1182 self.routes.push(RouteEntry {
1183 method: Method::GET,
1184 pattern: RoutePattern::parse(&path_for_pattern),
1185 handler: route_handler,
1186 layers: Arc::new(Vec::new()),
1187 apply_global_layers: true,
1188 needs_body: false,
1189 body_schema: None,
1190 });
1191
1192 self
1193 }
1194
1195 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
1202 self.health.health_path = Some(normalize_route_path(path.into()));
1203 self
1204 }
1205
1206 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
1210 where
1211 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
1212 Fut: Future<Output = Result<(), Err>> + Send + 'static,
1213 Err: ToString + Send + 'static,
1214 {
1215 if self.health.health_path.is_none() {
1216 self.health.health_path = Some("/health".to_string());
1217 }
1218
1219 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
1220 let fut = check(resources);
1221 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
1222 });
1223
1224 self.health.checks.push(NamedHealthCheck {
1225 name: name.into(),
1226 check: check_fn,
1227 });
1228 self
1229 }
1230
1231 pub fn readiness_liveness(
1233 mut self,
1234 readiness_path: impl Into<String>,
1235 liveness_path: impl Into<String>,
1236 ) -> Self {
1237 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1238 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1239 self
1240 }
1241
1242 pub fn readiness_liveness_default(self) -> Self {
1244 self.readiness_liveness("/ready", "/live")
1245 }
1246
1247 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1251 where
1252 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1253 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1254 {
1255 self.route_method(Method::GET, path, circuit)
1256 }
1257 pub fn route_method<Out, E>(
1266 self,
1267 method: Method,
1268 path: impl Into<String>,
1269 circuit: Axon<(), Out, E, R>,
1270 ) -> Self
1271 where
1272 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1273 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1274 {
1275 self.route_method_with_error(method, path, circuit, |error| {
1276 (
1277 StatusCode::INTERNAL_SERVER_ERROR,
1278 format!("Error: {:?}", error),
1279 )
1280 .into_response()
1281 })
1282 }
1283
1284 pub fn route_method_with_error<Out, E, H>(
1285 self,
1286 method: Method,
1287 path: impl Into<String>,
1288 circuit: Axon<(), Out, E, R>,
1289 error_handler: H,
1290 ) -> Self
1291 where
1292 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1293 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1294 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1295 {
1296 self.route_method_with_error_and_layers(
1297 method,
1298 path,
1299 circuit,
1300 error_handler,
1301 Arc::new(Vec::new()),
1302 true,
1303 )
1304 }
1305
1306
1307
    /// Core registration routine behind the untyped `route_method*` builders.
    ///
    /// Snapshots the ingress's current bus injectors, guard executors, response header
    /// extractors and response body transforms, then pushes a `RouteEntry` whose handler,
    /// for each request:
    ///
    /// 1. opens an `HTTPRequest` tracing span tagged with the method, path and a fresh UUIDv4
    ///    request id;
    /// 2. creates a fresh `Bus` and runs every bus injector against the request parts;
    /// 3. runs the guard executors in order, answering the first rejection with a JSON error
    ///    carrying the rejection's status and message;
    /// 4. if an `IdempotencyCachedResponse` is on the bus, returns its body as
    ///    `200 OK` / `application/json` without executing the circuit;
    /// 5. executes the circuit with `()` input, bounded by the remaining time of a
    ///    `TimeoutDeadline` if one is on the bus (`408 Request Timeout` when the deadline is
    ///    already spent or runs out);
    /// 6. converts the outcome into a response (faults via `error_handler`), then applies the
    ///    response header extractors and, if any are registered, the body transforms.
    ///
    /// Header extractors run on every response produced after the bus exists (guard
    /// rejections, cached responses, timeouts, circuit results). Body transforms run only on
    /// the circuit-result response.
    ///
    /// * `route_layers` – service layers specific to this route, stored on the entry.
    /// * `apply_global_layers` – stored on the entry; tells the dispatcher whether to put the
    ///   ingress-wide layers in front of `route_layers`.
    ///
    /// The entry is registered with `needs_body: false` (the circuit takes `()`, so the body is
    /// never collected) and without a body schema.
    fn route_method_with_error_and_layers<Out, E, H>(
        mut self,
        method: Method,
        path: impl Into<String>,
        circuit: Axon<(), Out, E, R>,
        error_handler: H,
        route_layers: Arc<Vec<ServiceLayer>>,
        apply_global_layers: bool,
    ) -> Self
    where
        Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
        E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
        H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
    {
        let path_str: String = path.into();
        let circuit = Arc::new(circuit);
        let error_handler = Arc::new(error_handler);
        // Snapshot the guard configuration as of registration time: guards registered on
        // the ingress after this call do not apply to this route.
        let route_bus_injectors = Arc::new(self.bus_injectors.clone());
        let route_guard_execs = Arc::new(self.guard_execs.clone());
        let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
        let route_body_transforms = Arc::new(self.guard_body_transforms.clone());
        let path_for_pattern = path_str.clone();
        let path_for_handler = path_str;
        let method_for_pattern = method.clone();
        let method_for_handler = method;

        let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
            // Per-request clones so the returned future owns everything it uses.
            let circuit = circuit.clone();
            let error_handler = error_handler.clone();
            let route_bus_injectors = route_bus_injectors.clone();
            let route_guard_execs = route_guard_execs.clone();
            let route_response_extractors = route_response_extractors.clone();
            let route_body_transforms = route_body_transforms.clone();
            let res = res.clone();
            let path = path_for_handler.clone();
            let method = method_for_handler.clone();

            Box::pin(async move {
                let request_id = uuid::Uuid::new_v4().to_string();
                let span = tracing::info_span!(
                    "HTTPRequest",
                    ranvier.http.method = %method,
                    ranvier.http.path = %path,
                    ranvier.http.request_id = %request_id
                );

                async move {
                    let mut bus = Bus::new();
                    for injector in route_bus_injectors.iter() {
                        injector(&parts, &mut bus);
                    }
                    // Guards run in registration order; the first rejection ends the request.
                    for guard_exec in route_guard_execs.iter() {
                        if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
                            let mut response = json_error_response(rejection.status, &rejection.message);
                            for extractor in route_response_extractors.iter() {
                                extractor(&bus, response.headers_mut());
                            }
                            return response;
                        }
                    }
                    // A guard left a cached idempotent reply on the bus: replay it without
                    // running the circuit. NOTE(review): always replayed as 200 +
                    // application/json regardless of the original response — confirm intended.
                    if let Some(cached) = bus.read::<ranvier_guard::IdempotencyCachedResponse>() {
                        let body = Bytes::from(cached.body.clone());
                        let mut response = Response::builder()
                            .status(StatusCode::OK)
                            .header("content-type", "application/json")
                            .body(Full::new(body).map_err(|n: Infallible| match n {}).boxed())
                            .unwrap();
                        for extractor in route_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                    // A deadline on the bus bounds circuit execution to its remaining time.
                    let result = if let Some(td) = bus.read::<ranvier_guard::TimeoutDeadline>() {
                        let remaining = td.remaining();
                        // Deadline already spent: reject without starting the circuit at all.
                        if remaining.is_zero() {
                            let mut response = json_error_response(
                                StatusCode::REQUEST_TIMEOUT,
                                "Request timeout: pipeline deadline exceeded",
                            );
                            for extractor in route_response_extractors.iter() {
                                extractor(&bus, response.headers_mut());
                            }
                            return response;
                        }
                        match tokio::time::timeout(remaining, circuit.execute((), &res, &mut bus)).await {
                            Ok(result) => result,
                            Err(_) => {
                                let mut response = json_error_response(
                                    StatusCode::REQUEST_TIMEOUT,
                                    "Request timeout: pipeline deadline exceeded",
                                );
                                for extractor in route_response_extractors.iter() {
                                    extractor(&bus, response.headers_mut());
                                }
                                return response;
                            }
                        }
                    } else {
                        circuit.execute((), &res, &mut bus).await
                    };
                    let mut response = outcome_to_response_with_error(result, |error| error_handler(error));
                    for extractor in route_response_extractors.iter() {
                        extractor(&bus, response.headers_mut());
                    }
                    // Body transforms buffer the whole body, so skip them when none exist.
                    if !route_body_transforms.is_empty() {
                        response = apply_body_transforms(response, &bus, &route_body_transforms).await;
                    }
                    response
                }
                .instrument(span)
                .await
            }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
        });

        self.routes.push(RouteEntry {
            method: method_for_pattern,
            pattern: RoutePattern::parse(&path_for_pattern),
            handler,
            layers: route_layers,
            apply_global_layers,
            // The circuit input is `()`, so the request body is never needed.
            needs_body: false,
            body_schema: None,
        });
        self
    }
1435
1436 fn route_method_typed<T, Out, E>(
1442 mut self,
1443 method: Method,
1444 path: impl Into<String>,
1445 circuit: Axon<T, Out, E, R>,
1446 ) -> Self
1447 where
1448 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1449 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1450 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1451 {
1452 let body_schema = serde_json::to_value(schemars::schema_for!(T)).ok();
1453 let path_str: String = path.into();
1454 let circuit = Arc::new(circuit);
1455 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1456 let route_guard_execs = Arc::new(self.guard_execs.clone());
1457 let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
1458 let route_body_transforms = Arc::new(self.guard_body_transforms.clone());
1459 let path_for_pattern = path_str.clone();
1460 let path_for_handler = path_str;
1461 let method_for_pattern = method.clone();
1462 let method_for_handler = method;
1463
1464 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1465 let circuit = circuit.clone();
1466 let route_bus_injectors = route_bus_injectors.clone();
1467 let route_guard_execs = route_guard_execs.clone();
1468 let route_response_extractors = route_response_extractors.clone();
1469 let route_body_transforms = route_body_transforms.clone();
1470 let res = res.clone();
1471 let path = path_for_handler.clone();
1472 let method = method_for_handler.clone();
1473
1474 Box::pin(async move {
1475 let request_id = uuid::Uuid::new_v4().to_string();
1476 let span = tracing::info_span!(
1477 "HTTPRequest",
1478 ranvier.http.method = %method,
1479 ranvier.http.path = %path,
1480 ranvier.http.request_id = %request_id
1481 );
1482
1483 async move {
1484 let body_bytes = parts
1486 .extensions
1487 .get::<BodyBytes>()
1488 .map(|b| b.0.clone())
1489 .unwrap_or_default();
1490
1491 let input: T = match serde_json::from_slice(&body_bytes) {
1493 Ok(v) => v,
1494 Err(e) => {
1495 return json_error_response(
1496 StatusCode::BAD_REQUEST,
1497 &format!("Invalid request body: {}", e),
1498 );
1499 }
1500 };
1501
1502 let mut bus = Bus::new();
1503 for injector in route_bus_injectors.iter() {
1504 injector(&parts, &mut bus);
1505 }
1506 for guard_exec in route_guard_execs.iter() {
1507 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1508 let mut response = json_error_response(rejection.status, &rejection.message);
1509 for extractor in route_response_extractors.iter() {
1510 extractor(&bus, response.headers_mut());
1511 }
1512 return response;
1513 }
1514 }
1515 if let Some(cached) = bus.read::<ranvier_guard::IdempotencyCachedResponse>() {
1517 let body = Bytes::from(cached.body.clone());
1518 let mut response = Response::builder()
1519 .status(StatusCode::OK)
1520 .header("content-type", "application/json")
1521 .body(Full::new(body).map_err(|n: Infallible| match n {}).boxed())
1522 .unwrap();
1523 for extractor in route_response_extractors.iter() {
1524 extractor(&bus, response.headers_mut());
1525 }
1526 return response;
1527 }
1528 let result = if let Some(td) = bus.read::<ranvier_guard::TimeoutDeadline>() {
1530 let remaining = td.remaining();
1531 if remaining.is_zero() {
1532 let mut response = json_error_response(
1533 StatusCode::REQUEST_TIMEOUT,
1534 "Request timeout: pipeline deadline exceeded",
1535 );
1536 for extractor in route_response_extractors.iter() {
1537 extractor(&bus, response.headers_mut());
1538 }
1539 return response;
1540 }
1541 match tokio::time::timeout(remaining, circuit.execute(input, &res, &mut bus)).await {
1542 Ok(result) => result,
1543 Err(_) => {
1544 let mut response = json_error_response(
1545 StatusCode::REQUEST_TIMEOUT,
1546 "Request timeout: pipeline deadline exceeded",
1547 );
1548 for extractor in route_response_extractors.iter() {
1549 extractor(&bus, response.headers_mut());
1550 }
1551 return response;
1552 }
1553 }
1554 } else {
1555 circuit.execute(input, &res, &mut bus).await
1556 };
1557 let mut response = outcome_to_response_with_error(result, |error| {
1558 if cfg!(debug_assertions) {
1559 (
1560 StatusCode::INTERNAL_SERVER_ERROR,
1561 format!("Error: {:?}", error),
1562 )
1563 .into_response()
1564 } else {
1565 json_error_response(
1566 StatusCode::INTERNAL_SERVER_ERROR,
1567 "Internal server error",
1568 )
1569 }
1570 });
1571 for extractor in route_response_extractors.iter() {
1572 extractor(&bus, response.headers_mut());
1573 }
1574 if !route_body_transforms.is_empty() {
1575 response = apply_body_transforms(response, &bus, &route_body_transforms).await;
1576 }
1577 response
1578 }
1579 .instrument(span)
1580 .await
1581 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1582 });
1583
1584 self.routes.push(RouteEntry {
1585 method: method_for_pattern,
1586 pattern: RoutePattern::parse(&path_for_pattern),
1587 handler,
1588 layers: Arc::new(Vec::new()),
1589 apply_global_layers: true,
1590 needs_body: true,
1591 body_schema,
1592 });
1593 self
1594 }
1595
1596 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1597 where
1598 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1599 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1600 {
1601 self.route_method(Method::GET, path, circuit)
1602 }
1603
1604 pub fn get_with_error<Out, E, H>(
1605 self,
1606 path: impl Into<String>,
1607 circuit: Axon<(), Out, E, R>,
1608 error_handler: H,
1609 ) -> Self
1610 where
1611 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1612 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1613 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1614 {
1615 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1616 }
1617
1618 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1619 where
1620 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1621 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1622 {
1623 self.route_method(Method::POST, path, circuit)
1624 }
1625
1626 pub fn post_typed<T, Out, E>(
1642 self,
1643 path: impl Into<String>,
1644 circuit: Axon<T, Out, E, R>,
1645 ) -> Self
1646 where
1647 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1648 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1649 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1650 {
1651 self.route_method_typed::<T, Out, E>(Method::POST, path, circuit)
1652 }
1653
1654 pub fn put_typed<T, Out, E>(
1658 self,
1659 path: impl Into<String>,
1660 circuit: Axon<T, Out, E, R>,
1661 ) -> Self
1662 where
1663 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1664 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1665 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1666 {
1667 self.route_method_typed::<T, Out, E>(Method::PUT, path, circuit)
1668 }
1669
1670 pub fn patch_typed<T, Out, E>(
1674 self,
1675 path: impl Into<String>,
1676 circuit: Axon<T, Out, E, R>,
1677 ) -> Self
1678 where
1679 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1680 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1681 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1682 {
1683 self.route_method_typed::<T, Out, E>(Method::PATCH, path, circuit)
1684 }
1685
1686 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1687 where
1688 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1689 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1690 {
1691 self.route_method(Method::PUT, path, circuit)
1692 }
1693
1694 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1695 where
1696 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1697 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1698 {
1699 self.route_method(Method::DELETE, path, circuit)
1700 }
1701
1702 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1703 where
1704 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1705 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1706 {
1707 self.route_method(Method::PATCH, path, circuit)
1708 }
1709
1710 pub fn post_with_error<Out, E, H>(
1711 self,
1712 path: impl Into<String>,
1713 circuit: Axon<(), Out, E, R>,
1714 error_handler: H,
1715 ) -> Self
1716 where
1717 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1718 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1719 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1720 {
1721 self.route_method_with_error(Method::POST, path, circuit, error_handler)
1722 }
1723
1724 pub fn put_with_error<Out, E, H>(
1725 self,
1726 path: impl Into<String>,
1727 circuit: Axon<(), Out, E, R>,
1728 error_handler: H,
1729 ) -> Self
1730 where
1731 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1732 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1733 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1734 {
1735 self.route_method_with_error(Method::PUT, path, circuit, error_handler)
1736 }
1737
1738 pub fn delete_with_error<Out, E, H>(
1739 self,
1740 path: impl Into<String>,
1741 circuit: Axon<(), Out, E, R>,
1742 error_handler: H,
1743 ) -> Self
1744 where
1745 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1746 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1747 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1748 {
1749 self.route_method_with_error(Method::DELETE, path, circuit, error_handler)
1750 }
1751
1752 pub fn patch_with_error<Out, E, H>(
1753 self,
1754 path: impl Into<String>,
1755 circuit: Axon<(), Out, E, R>,
1756 error_handler: H,
1757 ) -> Self
1758 where
1759 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1760 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1761 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1762 {
1763 self.route_method_with_error(Method::PATCH, path, circuit, error_handler)
1764 }
1765
1766 fn route_method_with_extra_guards<Out, E>(
1772 mut self,
1773 method: Method,
1774 path: impl Into<String>,
1775 circuit: Axon<(), Out, E, R>,
1776 extra_guards: Vec<RegisteredGuard>,
1777 ) -> Self
1778 where
1779 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1780 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1781 {
1782 let saved_injectors = self.bus_injectors.len();
1784 let saved_execs = self.guard_execs.len();
1785 let saved_extractors = self.guard_response_extractors.len();
1786 let saved_transforms = self.guard_body_transforms.len();
1787
1788 for registration in extra_guards {
1790 for injector in registration.bus_injectors {
1791 self.bus_injectors.push(injector);
1792 }
1793 self.guard_execs.push(registration.exec);
1794 if let Some(extractor) = registration.response_extractor {
1795 self.guard_response_extractors.push(extractor);
1796 }
1797 if let Some(transform) = registration.response_body_transform {
1798 self.guard_body_transforms.push(transform);
1799 }
1800 }
1801
1802 self = self.route_method(method, path, circuit);
1804
1805 self.bus_injectors.truncate(saved_injectors);
1807 self.guard_execs.truncate(saved_execs);
1808 self.guard_response_extractors.truncate(saved_extractors);
1809 self.guard_body_transforms.truncate(saved_transforms);
1810
1811 self
1812 }
1813
1814 pub fn get_with_guards<Out, E>(
1832 self,
1833 path: impl Into<String>,
1834 circuit: Axon<(), Out, E, R>,
1835 extra_guards: Vec<RegisteredGuard>,
1836 ) -> Self
1837 where
1838 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1839 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1840 {
1841 self.route_method_with_extra_guards(Method::GET, path, circuit, extra_guards)
1842 }
1843
1844 pub fn post_with_guards<Out, E>(
1865 self,
1866 path: impl Into<String>,
1867 circuit: Axon<(), Out, E, R>,
1868 extra_guards: Vec<RegisteredGuard>,
1869 ) -> Self
1870 where
1871 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1872 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1873 {
1874 self.route_method_with_extra_guards(Method::POST, path, circuit, extra_guards)
1875 }
1876
1877 pub fn put_with_guards<Out, E>(
1879 self,
1880 path: impl Into<String>,
1881 circuit: Axon<(), Out, E, R>,
1882 extra_guards: Vec<RegisteredGuard>,
1883 ) -> Self
1884 where
1885 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1886 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1887 {
1888 self.route_method_with_extra_guards(Method::PUT, path, circuit, extra_guards)
1889 }
1890
1891 pub fn delete_with_guards<Out, E>(
1893 self,
1894 path: impl Into<String>,
1895 circuit: Axon<(), Out, E, R>,
1896 extra_guards: Vec<RegisteredGuard>,
1897 ) -> Self
1898 where
1899 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1900 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1901 {
1902 self.route_method_with_extra_guards(Method::DELETE, path, circuit, extra_guards)
1903 }
1904
1905 pub fn patch_with_guards<Out, E>(
1907 self,
1908 path: impl Into<String>,
1909 circuit: Axon<(), Out, E, R>,
1910 extra_guards: Vec<RegisteredGuard>,
1911 ) -> Self
1912 where
1913 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1914 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1915 {
1916 self.route_method_with_extra_guards(Method::PATCH, path, circuit, extra_guards)
1917 }
1918
1919 pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
1930 where
1931 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1932 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1933 {
1934 let circuit = Arc::new(circuit);
1935 let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
1936 let fallback_guard_execs = Arc::new(self.guard_execs.clone());
1937 let fallback_response_extractors = Arc::new(self.guard_response_extractors.clone());
1938 let fallback_body_transforms = Arc::new(self.guard_body_transforms.clone());
1939
1940 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1941 let circuit = circuit.clone();
1942 let fallback_bus_injectors = fallback_bus_injectors.clone();
1943 let fallback_guard_execs = fallback_guard_execs.clone();
1944 let fallback_response_extractors = fallback_response_extractors.clone();
1945 let fallback_body_transforms = fallback_body_transforms.clone();
1946 let res = res.clone();
1947 Box::pin(async move {
1948 let request_id = uuid::Uuid::new_v4().to_string();
1949 let span = tracing::info_span!(
1950 "HTTPRequest",
1951 ranvier.http.method = "FALLBACK",
1952 ranvier.http.request_id = %request_id
1953 );
1954
1955 async move {
1956 let mut bus = Bus::new();
1957 for injector in fallback_bus_injectors.iter() {
1958 injector(&parts, &mut bus);
1959 }
1960 for guard_exec in fallback_guard_execs.iter() {
1961 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1962 let mut response = json_error_response(rejection.status, &rejection.message);
1963 for extractor in fallback_response_extractors.iter() {
1964 extractor(&bus, response.headers_mut());
1965 }
1966 return response;
1967 }
1968 }
1969 let result: ranvier_core::Outcome<Out, E> =
1970 circuit.execute((), &res, &mut bus).await;
1971
1972 let mut response = match result {
1973 Outcome::Next(output) => {
1974 let mut response = output.into_response();
1975 *response.status_mut() = StatusCode::NOT_FOUND;
1976 response
1977 }
1978 _ => Response::builder()
1979 .status(StatusCode::NOT_FOUND)
1980 .body(
1981 Full::new(Bytes::from("Not Found"))
1982 .map_err(|never| match never {})
1983 .boxed(),
1984 )
1985 .expect("valid HTTP response construction"),
1986 };
1987 for extractor in fallback_response_extractors.iter() {
1988 extractor(&bus, response.headers_mut());
1989 }
1990 if !fallback_body_transforms.is_empty() {
1991 response = apply_body_transforms(response, &bus, &fallback_body_transforms).await;
1992 }
1993 response
1994 }
1995 .instrument(span)
1996 .await
1997 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1998 });
1999
2000 self.fallback = Some(handler);
2001 self
2002 }
2003
2004 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
2008 self.run_with_shutdown_signal(resources, shutdown_signal())
2009 .await
2010 }
2011
2012 async fn run_with_shutdown_signal<S>(
2013 self,
2014 resources: R,
2015 shutdown_signal: S,
2016 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
2017 where
2018 S: Future<Output = ()> + Send,
2019 {
2020 let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
2021 let addr: SocketAddr = addr_str.parse()?;
2022
2023 let mut raw_routes = self.routes;
2024 if self.active_intervention {
2025 let handler: RouteHandler<R> = Arc::new(|_parts, _res| {
2026 Box::pin(async move {
2027 Response::builder()
2028 .status(StatusCode::OK)
2029 .body(
2030 Full::new(Bytes::from("Intervention accepted"))
2031 .map_err(|never| match never {} as Infallible)
2032 .boxed(),
2033 )
2034 .expect("valid HTTP response construction")
2035 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
2036 });
2037
2038 raw_routes.push(RouteEntry {
2039 method: Method::POST,
2040 pattern: RoutePattern::parse("/_system/intervene/force_resume"),
2041 handler,
2042 layers: Arc::new(Vec::new()),
2043 apply_global_layers: true,
2044 needs_body: false,
2045 body_schema: None,
2046 });
2047 }
2048
2049 if let Some(registry) = self.policy_registry.clone() {
2050 let handler: RouteHandler<R> = Arc::new(move |_parts, _res| {
2051 let _registry = registry.clone();
2052 Box::pin(async move {
2053 Response::builder()
2057 .status(StatusCode::OK)
2058 .body(
2059 Full::new(Bytes::from("Policy registry active"))
2060 .map_err(|never| match never {} as Infallible)
2061 .boxed(),
2062 )
2063 .expect("valid HTTP response construction")
2064 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
2065 });
2066
2067 raw_routes.push(RouteEntry {
2068 method: Method::POST,
2069 pattern: RoutePattern::parse("/_system/policy/reload"),
2070 handler,
2071 layers: Arc::new(Vec::new()),
2072 apply_global_layers: true,
2073 needs_body: false,
2074 body_schema: None,
2075 });
2076 }
2077 let routes = Arc::new(raw_routes);
2078 let fallback = self.fallback;
2079 let layers = Arc::new(self.layers);
2080 let health = Arc::new(self.health);
2081 let static_assets = Arc::new(self.static_assets);
2082 let preflight_config = Arc::new(self.preflight_config);
2083 let on_start = self.on_start;
2084 let on_shutdown = self.on_shutdown;
2085 let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
2086 let resources = Arc::new(resources);
2087
2088 let listener = TcpListener::bind(addr).await?;
2089
2090 #[cfg(feature = "tls")]
2092 let tls_acceptor = if let Some(ref tls_cfg) = self.tls_config {
2093 let acceptor = build_tls_acceptor(&tls_cfg.cert_path, &tls_cfg.key_path)?;
2094 tracing::info!("Ranvier HTTP Ingress listening on https://{}", addr);
2095 Some(acceptor)
2096 } else {
2097 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
2098 None
2099 };
2100 #[cfg(not(feature = "tls"))]
2101 tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
2102
2103 if let Some(callback) = on_start.as_ref() {
2104 callback();
2105 }
2106
2107 tokio::pin!(shutdown_signal);
2108 let mut connections = tokio::task::JoinSet::new();
2109
2110 loop {
2111 tokio::select! {
2112 _ = &mut shutdown_signal => {
2113 tracing::info!("Shutdown signal received. Draining in-flight connections.");
2114 break;
2115 }
2116 accept_result = listener.accept() => {
2117 let (stream, _) = accept_result?;
2118
2119 let routes = routes.clone();
2120 let fallback = fallback.clone();
2121 let resources = resources.clone();
2122 let layers = layers.clone();
2123 let health = health.clone();
2124 let static_assets = static_assets.clone();
2125 let preflight_config = preflight_config.clone();
2126 #[cfg(feature = "http3")]
2127 let alt_svc_h3_port = self.alt_svc_h3_port;
2128
2129 #[cfg(feature = "tls")]
2130 let tls_acceptor = tls_acceptor.clone();
2131
2132 connections.spawn(async move {
2133 let service = build_http_service(
2134 routes,
2135 fallback,
2136 resources,
2137 layers,
2138 health,
2139 static_assets,
2140 preflight_config,
2141 #[cfg(feature = "http3")] alt_svc_h3_port,
2142 );
2143
2144 #[cfg(feature = "tls")]
2145 if let Some(acceptor) = tls_acceptor {
2146 match acceptor.accept(stream).await {
2147 Ok(tls_stream) => {
2148 let io = TokioIo::new(tls_stream);
2149 if let Err(err) = http1::Builder::new()
2150 .serve_connection(io, service)
2151 .with_upgrades()
2152 .await
2153 {
2154 tracing::error!("Error serving TLS connection: {:?}", err);
2155 }
2156 }
2157 Err(err) => {
2158 tracing::warn!("TLS handshake failed: {:?}", err);
2159 }
2160 }
2161 return;
2162 }
2163
2164 let io = TokioIo::new(stream);
2165 if let Err(err) = http1::Builder::new()
2166 .serve_connection(io, service)
2167 .with_upgrades()
2168 .await
2169 {
2170 tracing::error!("Error serving connection: {:?}", err);
2171 }
2172 });
2173 }
2174 Some(join_result) = connections.join_next(), if !connections.is_empty() => {
2175 if let Err(err) = join_result {
2176 tracing::warn!("Connection task join error: {:?}", err);
2177 }
2178 }
2179 }
2180 }
2181
2182 let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;
2183
2184 drop(resources);
2185 if let Some(callback) = on_shutdown.as_ref() {
2186 callback();
2187 }
2188
2189 Ok(())
2190 }
2191
2192 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
2208 let routes = Arc::new(self.routes);
2209 let fallback = self.fallback;
2210 let layers = Arc::new(self.layers);
2211 let health = Arc::new(self.health);
2212 let static_assets = Arc::new(self.static_assets);
2213 let preflight_config = Arc::new(self.preflight_config);
2214 let resources = Arc::new(resources);
2215
2216 RawIngressService {
2217 routes,
2218 fallback,
2219 layers,
2220 health,
2221 static_assets,
2222 preflight_config,
2223 resources,
2224 }
2225 }
2226}
2227
2228async fn apply_body_transforms(
2233 response: HttpResponse,
2234 bus: &Bus,
2235 transforms: &[ResponseBodyTransformFn],
2236) -> HttpResponse {
2237 use http_body_util::BodyExt;
2238
2239 let (parts, body) = response.into_parts();
2240
2241 let collected = match body.collect().await {
2243 Ok(c) => c.to_bytes(),
2244 Err(_) => {
2245 return Response::builder()
2247 .status(StatusCode::INTERNAL_SERVER_ERROR)
2248 .body(
2249 Full::new(Bytes::from("body collection failed"))
2250 .map_err(|never| match never {})
2251 .boxed(),
2252 )
2253 .expect("valid response");
2254 }
2255 };
2256
2257 let mut transformed = collected;
2258 for transform in transforms {
2259 transformed = transform(bus, transformed);
2260 }
2261
2262 Response::from_parts(
2263 parts,
2264 Full::new(transformed)
2265 .map_err(|never| match never {})
2266 .boxed(),
2267 )
2268}
2269
/// Builds the top-level HTTP service used for each accepted connection.
///
/// Dispatch order for every request:
/// 1. health endpoints (`maybe_handle_health_request`);
/// 2. CORS preflight: when a `PreflightConfig` is set, an `OPTIONS` request whose `Origin` is
///    allowed (or absent) is answered directly with `204 No Content` and the configured
///    `access-control-*` headers. Otherwise it falls through to normal routing;
/// 3. the route matching method and path (`find_matching_route`), with its path params
///    inserted into the request extensions. It is called directly when no layers apply,
///    otherwise through a layered route service;
/// 4. static assets (`maybe_handle_static_request`) for requests that matched no route;
/// 5. the fallback handler (wrapped in the global layers, if any), or a plain-text
///    `404 Not Found` when no fallback is registered.
///
/// With the `http3` feature and `alt_svc_port` set, an `Alt-Svc` header advertising HTTP/3
/// on that port is added to route and fallback responses. Health, preflight and static-asset
/// responses return earlier and do not get it.
fn build_http_service<R>(
    routes: Arc<Vec<RouteEntry<R>>>,
    fallback: Option<RouteHandler<R>>,
    resources: Arc<R>,
    layers: Arc<Vec<ServiceLayer>>,
    health: Arc<HealthConfig<R>>,
    static_assets: Arc<StaticAssetsConfig>,
    preflight_config: Arc<Option<PreflightConfig>>,
    #[cfg(feature = "http3")] alt_svc_port: Option<u16>,
) -> BoxHttpService
where
    R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
{
    BoxService::new(move |req: Request<Incoming>| {
        // Per-request clones of the shared configuration for the returned future.
        let routes = routes.clone();
        let fallback = fallback.clone();
        let resources = resources.clone();
        let layers = layers.clone();
        let health = health.clone();
        let static_assets = static_assets.clone();
        let preflight_config = preflight_config.clone();

        async move {
            // Rebound as mutable: matched path params are inserted into its extensions.
            let mut req = req;
            let method = req.method().clone();
            let path = req.uri().path().to_string();

            if let Some(response) =
                maybe_handle_health_request(&method, &path, &health, resources.clone()).await
            {
                return Ok::<_, Infallible>(response.into_response());
            }

            if method == Method::OPTIONS {
                if let Some(ref config) = *preflight_config {
                    let origin = req
                        .headers()
                        .get("origin")
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("");
                    let is_wildcard = config.allowed_origins.iter().any(|o| o == "*");
                    let is_allowed = is_wildcard
                        || config.allowed_origins.iter().any(|o| o == origin);

                    // NOTE(review): an OPTIONS request with no (or non-UTF-8) Origin header is
                    // answered here too. With a non-wildcard config `allow_origin` is then the
                    // empty string, which presumably produces an empty
                    // `access-control-allow-origin` header — confirm intended.
                    if is_allowed || origin.is_empty() {
                        let allow_origin = if is_wildcard {
                            "*".to_string()
                        } else {
                            origin.to_string()
                        };
                        let mut response = Response::builder()
                            .status(StatusCode::NO_CONTENT)
                            .body(
                                Full::new(Bytes::new())
                                    .map_err(|never| match never {})
                                    .boxed(),
                            )
                            .expect("valid preflight response");
                        // Config values that are not valid header values are silently skipped.
                        let headers = response.headers_mut();
                        if let Ok(v) = allow_origin.parse() {
                            headers.insert("access-control-allow-origin", v);
                        }
                        if let Ok(v) = config.allowed_methods.parse() {
                            headers.insert("access-control-allow-methods", v);
                        }
                        if let Ok(v) = config.allowed_headers.parse() {
                            headers.insert("access-control-allow-headers", v);
                        }
                        if let Ok(v) = config.max_age.parse() {
                            headers.insert("access-control-max-age", v);
                        }
                        if config.allow_credentials {
                            headers.insert(
                                "access-control-allow-credentials",
                                "true".parse().expect("valid header value"),
                            );
                        }
                        return Ok(response);
                    }
                }
            }

            if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
                req.extensions_mut().insert(params);
                // Routes opted into global layers get them merged with their own layers.
                let effective_layers = if entry.apply_global_layers {
                    merge_layers(&layers, &entry.layers)
                } else {
                    entry.layers.clone()
                };

                if effective_layers.is_empty() {
                    // Fast path: no layers, call the route handler directly. Routes that need
                    // the body (typed routes) get it buffered into `BodyBytes` first; a body
                    // read failure is answered with 400.
                    let (mut parts, body) = req.into_parts();
                    if entry.needs_body {
                        match BodyExt::collect(body).await {
                            Ok(collected) => { parts.extensions.insert(BodyBytes(collected.to_bytes())); }
                            Err(_) => {
                                return Ok(json_error_response(
                                    StatusCode::BAD_REQUEST,
                                    "Failed to read request body",
                                ));
                            }
                        }
                    }
                    // `mut` is only needed when the http3 feature adds the Alt-Svc header.
                    #[allow(unused_mut)]
                    let mut res = (entry.handler)(parts, &resources).await;
                    #[cfg(feature = "http3")]
                    if let Some(port) = alt_svc_port {
                        // Advertise HTTP/3 on `port`; `ma` is the advertisement's max age in
                        // seconds (86400 = one day).
                        if let Ok(val) =
                            http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
                        {
                            res.headers_mut().insert(http::header::ALT_SVC, val);
                        }
                    }
                    Ok::<_, Infallible>(res)
                } else {
                    // Layered path: the route service receives `needs_body` and is
                    // presumably responsible for buffering the body — see
                    // `build_route_service`.
                    let route_service = build_route_service(
                        entry.handler.clone(),
                        resources.clone(),
                        effective_layers,
                        entry.needs_body,
                    );
                    #[allow(unused_mut)]
                    let mut res = route_service.call(req).await;
                    #[cfg(feature = "http3")]
                    // `res` is `Result<_, Infallible>`, so this pattern can never fail.
                    #[allow(irrefutable_let_patterns)]
                    if let Ok(ref mut r) = res {
                        if let Some(port) = alt_svc_port {
                            if let Ok(val) =
                                http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
                            {
                                r.headers_mut().insert(http::header::ALT_SVC, val);
                            }
                        }
                    }
                    res
                }
            } else {
                // No route matched: try static assets; the request is handed back when no
                // asset applies.
                let req =
                    match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
                        .await
                    {
                        Ok(req) => req,
                        Err(response) => return Ok(response),
                    };

                #[allow(unused_mut)]
                let mut fallback_res = if let Some(ref fb) = fallback {
                    if layers.is_empty() {
                        // The fallback handler never receives the request body.
                        let (parts, _) = req.into_parts();
                        Ok(fb(parts, &resources).await)
                    } else {
                        // The fallback is wrapped in the global layers only.
                        let fallback_service =
                            build_route_service(fb.clone(), resources.clone(), layers.clone(), false);
                        fallback_service.call(req).await
                    }
                } else {
                    Ok(Response::builder()
                        .status(StatusCode::NOT_FOUND)
                        .body(
                            Full::new(Bytes::from("Not Found"))
                                .map_err(|never| match never {})
                                .boxed(),
                        )
                        .expect("valid HTTP response construction"))
                };

                #[cfg(feature = "http3")]
                if let Ok(r) = fallback_res.as_mut() {
                    if let Some(port) = alt_svc_port {
                        if let Ok(val) =
                            http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
                        {
                            r.headers_mut().insert(http::header::ALT_SVC, val);
                        }
                    }
                }

                fallback_res
            }
        }
    })
}
2453
2454fn build_route_service<R>(
2455 handler: RouteHandler<R>,
2456 resources: Arc<R>,
2457 layers: Arc<Vec<ServiceLayer>>,
2458 needs_body: bool,
2459) -> BoxHttpService
2460where
2461 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2462{
2463 let mut service = BoxService::new(move |req: Request<Incoming>| {
2464 let handler = handler.clone();
2465 let resources = resources.clone();
2466 async move {
2467 let (mut parts, body) = req.into_parts();
2468 if needs_body {
2469 match BodyExt::collect(body).await {
2470 Ok(collected) => { parts.extensions.insert(BodyBytes(collected.to_bytes())); }
2471 Err(_) => {
2472 return Ok(json_error_response(
2473 StatusCode::BAD_REQUEST,
2474 "Failed to read request body",
2475 ));
2476 }
2477 }
2478 }
2479 Ok::<_, Infallible>(handler(parts, &resources).await)
2480 }
2481 });
2482
2483 for layer in layers.iter() {
2484 service = layer(service);
2485 }
2486 service
2487}
2488
2489fn merge_layers(
2490 global_layers: &Arc<Vec<ServiceLayer>>,
2491 route_layers: &Arc<Vec<ServiceLayer>>,
2492) -> Arc<Vec<ServiceLayer>> {
2493 if global_layers.is_empty() {
2494 return route_layers.clone();
2495 }
2496 if route_layers.is_empty() {
2497 return global_layers.clone();
2498 }
2499
2500 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
2501 combined.extend(global_layers.iter().cloned());
2502 combined.extend(route_layers.iter().cloned());
2503 Arc::new(combined)
2504}
2505
2506async fn maybe_handle_health_request<R>(
2507 method: &Method,
2508 path: &str,
2509 health: &HealthConfig<R>,
2510 resources: Arc<R>,
2511) -> Option<HttpResponse>
2512where
2513 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2514{
2515 if method != Method::GET {
2516 return None;
2517 }
2518
2519 if let Some(liveness_path) = health.liveness_path.as_ref() {
2520 if path == liveness_path {
2521 return Some(health_json_response("liveness", true, Vec::new()));
2522 }
2523 }
2524
2525 if let Some(readiness_path) = health.readiness_path.as_ref() {
2526 if path == readiness_path {
2527 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
2528 return Some(health_json_response("readiness", healthy, checks));
2529 }
2530 }
2531
2532 if let Some(health_path) = health.health_path.as_ref() {
2533 if path == health_path {
2534 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
2535 return Some(health_json_response("health", healthy, checks));
2536 }
2537 }
2538
2539 None
2540}
2541
2542async fn serve_single_file(file_path: &str) -> Result<Response<Full<Bytes>>, std::io::Error> {
2544 let path = std::path::Path::new(file_path);
2545 let content = tokio::fs::read(path).await?;
2546 let mime = guess_mime(file_path);
2547 let mut response = Response::new(Full::new(Bytes::from(content)));
2548 if let Ok(value) = http::HeaderValue::from_str(mime) {
2549 response
2550 .headers_mut()
2551 .insert(http::header::CONTENT_TYPE, value);
2552 }
2553 if let Ok(metadata) = tokio::fs::metadata(path).await {
2554 if let Ok(modified) = metadata.modified() {
2555 if let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH) {
2556 let etag = format!("\"{}\"", duration.as_secs());
2557 if let Ok(value) = http::HeaderValue::from_str(&etag) {
2558 response.headers_mut().insert(http::header::ETAG, value);
2559 }
2560 }
2561 }
2562 }
2563 Ok(response)
2564}
2565
2566async fn serve_static_file(
2568 directory: &str,
2569 file_subpath: &str,
2570 config: &StaticAssetsConfig,
2571 if_none_match: Option<&http::HeaderValue>,
2572 accept_encoding: Option<&http::HeaderValue>,
2573 range_header: Option<&http::HeaderValue>,
2574) -> Result<Response<Full<Bytes>>, std::io::Error> {
2575 let subpath = file_subpath.trim_start_matches('/');
2576
2577 let resolved_subpath;
2579 if subpath.is_empty() || subpath.ends_with('/') {
2580 if let Some(ref index) = config.directory_index {
2581 resolved_subpath = if subpath.is_empty() {
2582 index.clone()
2583 } else {
2584 format!("{}{}", subpath, index)
2585 };
2586 } else {
2587 return Err(std::io::Error::new(
2588 std::io::ErrorKind::NotFound,
2589 "empty path",
2590 ));
2591 }
2592 } else {
2593 resolved_subpath = subpath.to_string();
2594 }
2595
2596 let full_path = std::path::Path::new(directory).join(&resolved_subpath);
2597 let canonical = tokio::fs::canonicalize(&full_path).await?;
2599 let dir_canonical = tokio::fs::canonicalize(directory).await?;
2600 if !canonical.starts_with(&dir_canonical) {
2601 return Err(std::io::Error::new(
2602 std::io::ErrorKind::PermissionDenied,
2603 "path traversal detected",
2604 ));
2605 }
2606
2607 let etag = if let Ok(metadata) = tokio::fs::metadata(&canonical).await {
2609 metadata
2610 .modified()
2611 .ok()
2612 .and_then(|m| m.duration_since(std::time::UNIX_EPOCH).ok())
2613 .map(|d| format!("\"{}\"", d.as_secs()))
2614 } else {
2615 None
2616 };
2617
2618 if let (Some(client_etag), Some(server_etag)) = (if_none_match, &etag) {
2620 if client_etag.as_bytes() == server_etag.as_bytes() {
2621 let mut response = Response::new(Full::new(Bytes::new()));
2622 *response.status_mut() = StatusCode::NOT_MODIFIED;
2623 if let Ok(value) = http::HeaderValue::from_str(server_etag) {
2624 response.headers_mut().insert(http::header::ETAG, value);
2625 }
2626 return Ok(response);
2627 }
2628 }
2629
2630 let (serve_path, content_encoding) = if config.serve_precompressed {
2632 let client_accepts = accept_encoding
2633 .and_then(|v| v.to_str().ok())
2634 .unwrap_or("");
2635 let canonical_str = canonical.to_str().unwrap_or("");
2636
2637 if client_accepts.contains("br") {
2638 let br_path = format!("{}.br", canonical_str);
2639 if tokio::fs::metadata(&br_path).await.is_ok() {
2640 (std::path::PathBuf::from(br_path), Some("br"))
2641 } else if client_accepts.contains("gzip") {
2642 let gz_path = format!("{}.gz", canonical_str);
2643 if tokio::fs::metadata(&gz_path).await.is_ok() {
2644 (std::path::PathBuf::from(gz_path), Some("gzip"))
2645 } else {
2646 (canonical.clone(), None)
2647 }
2648 } else {
2649 (canonical.clone(), None)
2650 }
2651 } else if client_accepts.contains("gzip") {
2652 let gz_path = format!("{}.gz", canonical_str);
2653 if tokio::fs::metadata(&gz_path).await.is_ok() {
2654 (std::path::PathBuf::from(gz_path), Some("gzip"))
2655 } else {
2656 (canonical.clone(), None)
2657 }
2658 } else {
2659 (canonical.clone(), None)
2660 }
2661 } else {
2662 (canonical.clone(), None)
2663 };
2664
2665 let content = tokio::fs::read(&serve_path).await?;
2666 let mime = guess_mime(canonical.to_str().unwrap_or(""));
2668
2669 if config.enable_range_requests {
2671 if let Some(range_val) = range_header {
2672 if let Some(response) = handle_range_request(
2673 range_val,
2674 &content,
2675 mime,
2676 etag.as_deref(),
2677 content_encoding,
2678 ) {
2679 return Ok(response);
2680 }
2681 }
2682 }
2683
2684 let mut response = Response::new(Full::new(Bytes::from(content)));
2685 if let Ok(value) = http::HeaderValue::from_str(mime) {
2686 response
2687 .headers_mut()
2688 .insert(http::header::CONTENT_TYPE, value);
2689 }
2690 if let Some(ref etag_val) = etag {
2691 if let Ok(value) = http::HeaderValue::from_str(etag_val) {
2692 response.headers_mut().insert(http::header::ETAG, value);
2693 }
2694 }
2695 if let Some(encoding) = content_encoding {
2696 if let Ok(value) = http::HeaderValue::from_str(encoding) {
2697 response
2698 .headers_mut()
2699 .insert(http::header::CONTENT_ENCODING, value);
2700 }
2701 }
2702 if config.enable_range_requests {
2703 response
2704 .headers_mut()
2705 .insert(http::header::ACCEPT_RANGES, http::HeaderValue::from_static("bytes"));
2706 }
2707
2708 if config.immutable_cache {
2710 if let Some(filename) = canonical.file_name().and_then(|n| n.to_str()) {
2711 if is_hashed_filename(filename) {
2712 if let Ok(value) = http::HeaderValue::from_str(
2713 "public, max-age=31536000, immutable",
2714 ) {
2715 response
2716 .headers_mut()
2717 .insert(http::header::CACHE_CONTROL, value);
2718 }
2719 }
2720 }
2721 }
2722
2723 Ok(response)
2724}
2725
2726fn handle_range_request(
2730 range_header: &http::HeaderValue,
2731 content: &[u8],
2732 mime: &str,
2733 etag: Option<&str>,
2734 content_encoding: Option<&str>,
2735) -> Option<Response<Full<Bytes>>> {
2736 let range_str = range_header.to_str().ok()?;
2737 let range_spec = range_str.strip_prefix("bytes=")?;
2738 let total = content.len();
2739 if total == 0 {
2740 return None;
2741 }
2742
2743 let (start, end) = if let Some(suffix) = range_spec.strip_prefix('-') {
2744 let n: usize = suffix.parse().ok()?;
2746 if n == 0 || n > total {
2747 return Some(range_not_satisfiable(total));
2748 }
2749 (total - n, total - 1)
2750 } else if range_spec.ends_with('-') {
2751 let start: usize = range_spec.trim_end_matches('-').parse().ok()?;
2753 if start >= total {
2754 return Some(range_not_satisfiable(total));
2755 }
2756 (start, total - 1)
2757 } else {
2758 let mut parts = range_spec.splitn(2, '-');
2760 let start: usize = parts.next()?.parse().ok()?;
2761 let end: usize = parts.next()?.parse().ok()?;
2762 if start > end || start >= total {
2763 return Some(range_not_satisfiable(total));
2764 }
2765 (start, end.min(total - 1))
2766 };
2767
2768 let slice = &content[start..=end];
2769 let content_range = format!("bytes {}-{}/{}", start, end, total);
2770
2771 let mut response = Response::new(Full::new(Bytes::copy_from_slice(slice)));
2772 *response.status_mut() = StatusCode::PARTIAL_CONTENT;
2773 if let Ok(v) = http::HeaderValue::from_str(&content_range) {
2774 response.headers_mut().insert(http::header::CONTENT_RANGE, v);
2775 }
2776 if let Ok(v) = http::HeaderValue::from_str(mime) {
2777 response
2778 .headers_mut()
2779 .insert(http::header::CONTENT_TYPE, v);
2780 }
2781 response
2782 .headers_mut()
2783 .insert(http::header::ACCEPT_RANGES, http::HeaderValue::from_static("bytes"));
2784 if let Some(etag_val) = etag {
2785 if let Ok(v) = http::HeaderValue::from_str(etag_val) {
2786 response.headers_mut().insert(http::header::ETAG, v);
2787 }
2788 }
2789 if let Some(encoding) = content_encoding {
2790 if let Ok(v) = http::HeaderValue::from_str(encoding) {
2791 response
2792 .headers_mut()
2793 .insert(http::header::CONTENT_ENCODING, v);
2794 }
2795 }
2796 Some(response)
2797}
2798
2799fn range_not_satisfiable(total: usize) -> Response<Full<Bytes>> {
2801 let content_range = format!("bytes */{}", total);
2802 let mut response = Response::new(Full::new(Bytes::from("Range Not Satisfiable")));
2803 *response.status_mut() = StatusCode::RANGE_NOT_SATISFIABLE;
2804 if let Ok(v) = http::HeaderValue::from_str(&content_range) {
2805 response.headers_mut().insert(http::header::CONTENT_RANGE, v);
2806 }
2807 response
2808}
2809
/// Heuristically detects fingerprinted asset names such as `app.3f9a1c2b.js`.
///
/// The name must have at least two dots. The segment between the last two dots must
/// be six or more ASCII hex digits.
fn is_hashed_filename(filename: &str) -> bool {
    let Some((without_extension, _extension)) = filename.rsplit_once('.') else {
        return false;
    };
    let Some((_stem, fingerprint)) = without_extension.rsplit_once('.') else {
        return false;
    };
    fingerprint.len() >= 6 && fingerprint.bytes().all(|b| b.is_ascii_hexdigit())
}
2821
/// Maps a file path to a `Content-Type` value based on its extension.
///
/// The extension is whatever follows the last `.` in `path`. It is compared
/// ASCII-case-insensitively, so `logo.PNG` and `logo.png` get the same type.
/// Unknown or missing extensions fall back to `application/octet-stream`.
fn guess_mime(path: &str) -> &'static str {
    const MIME_TYPES: &[(&str, &str)] = &[
        ("html", "text/html; charset=utf-8"),
        ("htm", "text/html; charset=utf-8"),
        ("css", "text/css; charset=utf-8"),
        ("js", "application/javascript; charset=utf-8"),
        ("mjs", "application/javascript; charset=utf-8"),
        ("cjs", "application/javascript; charset=utf-8"),
        ("ts", "application/javascript; charset=utf-8"),
        ("tsx", "application/javascript; charset=utf-8"),
        ("json", "application/json; charset=utf-8"),
        ("png", "image/png"),
        ("jpg", "image/jpeg"),
        ("jpeg", "image/jpeg"),
        ("gif", "image/gif"),
        ("svg", "image/svg+xml"),
        ("ico", "image/x-icon"),
        ("avif", "image/avif"),
        ("webp", "image/webp"),
        ("webm", "video/webm"),
        ("mp4", "video/mp4"),
        ("mp3", "audio/mpeg"),
        ("wav", "audio/wav"),
        ("ogg", "audio/ogg"),
        ("woff", "font/woff"),
        ("woff2", "font/woff2"),
        ("ttf", "font/ttf"),
        ("otf", "font/otf"),
        ("txt", "text/plain; charset=utf-8"),
        ("csv", "text/csv; charset=utf-8"),
        ("xml", "application/xml; charset=utf-8"),
        ("yaml", "application/yaml"),
        ("yml", "application/yaml"),
        ("wasm", "application/wasm"),
        ("pdf", "application/pdf"),
        ("map", "application/json"),
        ("webmanifest", "application/manifest+json"),
    ];

    let extension = path.rsplit('.').next().unwrap_or("");
    MIME_TYPES
        .iter()
        .find(|(ext, _)| ext.eq_ignore_ascii_case(extension))
        .map_or("application/octet-stream", |(_, mime)| *mime)
}
2849
2850fn apply_cache_control(
2851 mut response: Response<Full<Bytes>>,
2852 cache_control: Option<&str>,
2853) -> Response<Full<Bytes>> {
2854 if response.status() == StatusCode::OK {
2855 if let Some(value) = cache_control {
2856 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
2857 if let Ok(header_value) = http::HeaderValue::from_str(value) {
2858 response
2859 .headers_mut()
2860 .insert(http::header::CACHE_CONTROL, header_value);
2861 }
2862 }
2863 }
2864 }
2865 response
2866}
2867
2868async fn maybe_handle_static_request(
2869 req: Request<Incoming>,
2870 method: &Method,
2871 path: &str,
2872 static_assets: &StaticAssetsConfig,
2873) -> Result<Request<Incoming>, HttpResponse> {
2874 if method != Method::GET && method != Method::HEAD {
2875 return Ok(req);
2876 }
2877
2878 if let Some(mount) = static_assets
2879 .mounts
2880 .iter()
2881 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
2882 {
2883 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
2884 let if_none_match = req.headers().get(http::header::IF_NONE_MATCH).cloned();
2885 let range_header = req.headers().get(http::header::RANGE).cloned();
2886 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
2887 return Ok(req);
2888 };
2889 let response = match serve_static_file(
2890 &mount.directory,
2891 &stripped_path,
2892 static_assets,
2893 if_none_match.as_ref(),
2894 accept_encoding.as_ref(),
2895 range_header.as_ref(),
2896 )
2897 .await
2898 {
2899 Ok(response) => response,
2900 Err(_) => {
2901 return Err(Response::builder()
2902 .status(StatusCode::INTERNAL_SERVER_ERROR)
2903 .body(
2904 Full::new(Bytes::from("Failed to serve static asset"))
2905 .map_err(|never| match never {})
2906 .boxed(),
2907 )
2908 .unwrap_or_else(|_| {
2909 Response::new(
2910 Full::new(Bytes::new())
2911 .map_err(|never| match never {})
2912 .boxed(),
2913 )
2914 }));
2915 }
2916 };
2917 let mut response = apply_cache_control(response, static_assets.cache_control.as_deref());
2918 response = maybe_compress_static_response(
2919 response,
2920 accept_encoding,
2921 static_assets.enable_compression,
2922 );
2923 let (parts, body) = response.into_parts();
2924 return Err(Response::from_parts(
2925 parts,
2926 body.map_err(|never| match never {}).boxed(),
2927 ));
2928 }
2929
2930 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
2931 if looks_like_spa_request(path) {
2932 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
2933 let response = match serve_single_file(spa_file).await {
2934 Ok(response) => response,
2935 Err(_) => {
2936 return Err(Response::builder()
2937 .status(StatusCode::INTERNAL_SERVER_ERROR)
2938 .body(
2939 Full::new(Bytes::from("Failed to serve SPA fallback"))
2940 .map_err(|never| match never {})
2941 .boxed(),
2942 )
2943 .unwrap_or_else(|_| {
2944 Response::new(
2945 Full::new(Bytes::new())
2946 .map_err(|never| match never {})
2947 .boxed(),
2948 )
2949 }));
2950 }
2951 };
2952 let mut response =
2953 apply_cache_control(response, static_assets.cache_control.as_deref());
2954 response = maybe_compress_static_response(
2955 response,
2956 accept_encoding,
2957 static_assets.enable_compression,
2958 );
2959 let (parts, body) = response.into_parts();
2960 return Err(Response::from_parts(
2961 parts,
2962 body.map_err(|never| match never {}).boxed(),
2963 ));
2964 }
2965 }
2966
2967 Ok(req)
2968}
2969
/// Removes a mount's route prefix from a request path.
///
/// Trailing slashes on `prefix` are ignored, except that `"/"` mounts everything and
/// returns `path` unchanged. An exact match yields `"/"`. Otherwise `path` must
/// continue with `/` after the prefix, and the remainder is returned with a leading
/// `/`. Returns `None` when the path is outside the mount, so `/staticfoo` does not
/// match `/static`.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    let base = match prefix {
        "/" => "/",
        other => other.trim_end_matches('/'),
    };

    if base == "/" {
        return Some(path.to_owned());
    }
    if path == base {
        return Some(String::from("/"));
    }

    let remainder = path.strip_prefix(base)?.strip_prefix('/')?;
    Some(format!("/{remainder}"))
}
2989
/// Returns `true` when the last path segment has no `.`, meaning the request looks
/// like a client-side route rather than a file such as `app.js`.
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = match path.rfind('/') {
        Some(slash) => &path[slash + 1..],
        None => path,
    };
    !last_segment.contains('.')
}
2994
2995fn maybe_compress_static_response(
2996 response: Response<Full<Bytes>>,
2997 accept_encoding: Option<http::HeaderValue>,
2998 enable_compression: bool,
2999) -> Response<Full<Bytes>> {
3000 if !enable_compression {
3001 return response;
3002 }
3003
3004 let Some(accept_encoding) = accept_encoding else {
3005 return response;
3006 };
3007
3008 let accept_str = accept_encoding.to_str().unwrap_or("");
3009 if !accept_str.contains("gzip") {
3010 return response;
3011 }
3012
3013 let status = response.status();
3014 let headers = response.headers().clone();
3015 let body = response.into_body();
3016
3017 let data = futures_util::FutureExt::now_or_never(BodyExt::collect(body))
3019 .and_then(|r| r.ok())
3020 .map(|collected| collected.to_bytes())
3021 .unwrap_or_default();
3022
3023 let compressed = {
3025 use flate2::write::GzEncoder;
3026 use flate2::Compression;
3027 use std::io::Write;
3028 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
3029 let _ = encoder.write_all(&data);
3030 encoder.finish().unwrap_or_default()
3031 };
3032
3033 let mut builder = Response::builder().status(status);
3034 for (name, value) in headers.iter() {
3035 if name != http::header::CONTENT_LENGTH && name != http::header::CONTENT_ENCODING {
3036 builder = builder.header(name, value);
3037 }
3038 }
3039 builder
3040 .header(http::header::CONTENT_ENCODING, "gzip")
3041 .body(Full::new(Bytes::from(compressed)))
3042 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())))
3043}
3044
3045async fn run_named_health_checks<R>(
3046 checks: &[NamedHealthCheck<R>],
3047 resources: Arc<R>,
3048) -> (bool, Vec<HealthCheckReport>)
3049where
3050 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3051{
3052 let mut reports = Vec::with_capacity(checks.len());
3053 let mut healthy = true;
3054
3055 for check in checks {
3056 match (check.check)(resources.clone()).await {
3057 Ok(()) => reports.push(HealthCheckReport {
3058 name: check.name.clone(),
3059 status: "ok",
3060 error: None,
3061 }),
3062 Err(error) => {
3063 healthy = false;
3064 reports.push(HealthCheckReport {
3065 name: check.name.clone(),
3066 status: "error",
3067 error: Some(error),
3068 });
3069 }
3070 }
3071 }
3072
3073 (healthy, reports)
3074}
3075
3076fn health_json_response(
3077 probe: &'static str,
3078 healthy: bool,
3079 checks: Vec<HealthCheckReport>,
3080) -> HttpResponse {
3081 let status_code = if healthy {
3082 StatusCode::OK
3083 } else {
3084 StatusCode::SERVICE_UNAVAILABLE
3085 };
3086 let status = if healthy { "ok" } else { "degraded" };
3087 let payload = HealthReport {
3088 status,
3089 probe,
3090 checks,
3091 };
3092
3093 let body = serde_json::to_vec(&payload)
3094 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
3095
3096 Response::builder()
3097 .status(status_code)
3098 .header(http::header::CONTENT_TYPE, "application/json")
3099 .body(
3100 Full::new(Bytes::from(body))
3101 .map_err(|never| match never {})
3102 .boxed(),
3103 )
3104 .expect("valid HTTP response construction")
3105}
3106
3107async fn shutdown_signal() {
3108 #[cfg(unix)]
3109 {
3110 use tokio::signal::unix::{SignalKind, signal};
3111
3112 match signal(SignalKind::terminate()) {
3113 Ok(mut terminate) => {
3114 tokio::select! {
3115 _ = tokio::signal::ctrl_c() => {}
3116 _ = terminate.recv() => {}
3117 }
3118 }
3119 Err(err) => {
3120 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
3121 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
3122 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
3123 }
3124 }
3125 }
3126 }
3127
3128 #[cfg(not(unix))]
3129 {
3130 if let Err(err) = tokio::signal::ctrl_c().await {
3131 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
3132 }
3133 }
3134}
3135
3136async fn drain_connections(
3137 connections: &mut tokio::task::JoinSet<()>,
3138 graceful_shutdown_timeout: Duration,
3139) -> bool {
3140 if connections.is_empty() {
3141 return false;
3142 }
3143
3144 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
3145 while let Some(join_result) = connections.join_next().await {
3146 if let Err(err) = join_result {
3147 tracing::warn!("Connection task join error during shutdown: {:?}", err);
3148 }
3149 }
3150 })
3151 .await;
3152
3153 if drain_result.is_err() {
3154 tracing::warn!(
3155 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
3156 graceful_shutdown_timeout
3157 );
3158 connections.abort_all();
3159 while let Some(join_result) = connections.join_next().await {
3160 if let Err(err) = join_result {
3161 tracing::warn!("Connection task abort join error: {:?}", err);
3162 }
3163 }
3164 true
3165 } else {
3166 false
3167 }
3168}
3169
/// Builds a rustls-backed TLS acceptor from PEM certificate and key files.
///
/// Both files are opened first. Then the whole certificate chain and the first
/// private key are parsed. Client authentication is not requested.
///
/// # Errors
/// Returns a descriptive error if a file cannot be opened, PEM parsing fails, the key
/// file holds no private key, or rustls rejects the certificate/key pair.
#[cfg(feature = "tls")]
fn build_tls_acceptor(
    cert_path: &str,
    key_path: &str,
) -> Result<tokio_rustls::TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    use rustls::ServerConfig;
    use rustls_pemfile::{certs, private_key};
    use std::io::BufReader;
    use tokio_rustls::TlsAcceptor;

    let mut cert_reader = std::fs::File::open(cert_path)
        .map(BufReader::new)
        .map_err(|e| format!("Failed to open certificate file '{}': {}", cert_path, e))?;
    let mut key_reader = std::fs::File::open(key_path)
        .map(BufReader::new)
        .map_err(|e| format!("Failed to open key file '{}': {}", key_path, e))?;

    let cert_chain = certs(&mut cert_reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| format!("Failed to parse certificate PEM: {}", e))?;

    let parsed_key = private_key(&mut key_reader)
        .map_err(|e| format!("Failed to parse private key PEM: {}", e))?;
    let key = match parsed_key {
        Some(key) => key,
        None => return Err("No private key found in key file".into()),
    };

    let server_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, key)
        .map_err(|e| format!("TLS configuration error: {}", e))?;

    Ok(TlsAcceptor::from(Arc::new(server_config)))
}
3201
3202impl<R> Default for HttpIngress<R>
3203where
3204 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3205{
3206 fn default() -> Self {
3207 Self::new()
3208 }
3209}
3210
3211#[derive(Clone)]
3213pub struct RawIngressService<R> {
3214 routes: Arc<Vec<RouteEntry<R>>>,
3215 fallback: Option<RouteHandler<R>>,
3216 layers: Arc<Vec<ServiceLayer>>,
3217 health: Arc<HealthConfig<R>>,
3218 static_assets: Arc<StaticAssetsConfig>,
3219 preflight_config: Arc<Option<PreflightConfig>>,
3220 resources: Arc<R>,
3221}
3222
3223impl<R> hyper::service::Service<Request<Incoming>> for RawIngressService<R>
3224where
3225 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3226{
3227 type Response = HttpResponse;
3228 type Error = Infallible;
3229 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
3230
3231 fn call(&self, req: Request<Incoming>) -> Self::Future {
3232 let routes = self.routes.clone();
3233 let fallback = self.fallback.clone();
3234 let layers = self.layers.clone();
3235 let health = self.health.clone();
3236 let static_assets = self.static_assets.clone();
3237 let preflight_config = self.preflight_config.clone();
3238 let resources = self.resources.clone();
3239
3240 Box::pin(async move {
3241 let service = build_http_service(
3242 routes,
3243 fallback,
3244 resources,
3245 layers,
3246 health,
3247 static_assets,
3248 preflight_config,
3249 #[cfg(feature = "http3")]
3250 None,
3251 );
3252 service.call(req).await
3253 })
3254 }
3255}
3256
3257#[cfg(test)]
3258mod tests {
3259 use super::*;
3260 use async_trait::async_trait;
3261 use futures_util::{SinkExt, StreamExt};
3262 use serde::Deserialize;
3263 use std::fs;
3264 use std::sync::atomic::{AtomicBool, Ordering};
3265 use tempfile::tempdir;
3266 use tokio::io::{AsyncReadExt, AsyncWriteExt};
3267 use tokio_tungstenite::tungstenite::Message as WsClientMessage;
3268 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
3269
3270 async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
3271 let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
3272
3273 loop {
3274 match tokio::net::TcpStream::connect(addr).await {
3275 Ok(stream) => return stream,
3276 Err(error) => {
3277 if tokio::time::Instant::now() >= deadline {
3278 panic!("connect server: {error}");
3279 }
3280 tokio::time::sleep(Duration::from_millis(25)).await;
3281 }
3282 }
3283 }
3284 }
3285
3286 #[test]
3287 fn route_pattern_matches_static_path() {
3288 let pattern = RoutePattern::parse("/orders/list");
3289 let params = pattern.match_path("/orders/list").expect("should match");
3290 assert!(params.into_inner().is_empty());
3291 }
3292
3293 #[test]
3294 fn route_pattern_matches_param_segments() {
3295 let pattern = RoutePattern::parse("/orders/:id/items/:item_id");
3296 let params = pattern
3297 .match_path("/orders/42/items/sku-123")
3298 .expect("should match");
3299 assert_eq!(params.get("id"), Some("42"));
3300 assert_eq!(params.get("item_id"), Some("sku-123"));
3301 }
3302
3303 #[test]
3304 fn route_pattern_matches_wildcard_segment() {
3305 let pattern = RoutePattern::parse("/assets/*path");
3306 let params = pattern
3307 .match_path("/assets/css/theme/light.css")
3308 .expect("should match");
3309 assert_eq!(params.get("path"), Some("css/theme/light.css"));
3310 }
3311
3312 #[test]
3313 fn route_pattern_rejects_non_matching_path() {
3314 let pattern = RoutePattern::parse("/orders/:id");
3315 assert!(pattern.match_path("/users/42").is_none());
3316 }
3317
3318 #[test]
3319 fn graceful_shutdown_timeout_defaults_to_30_seconds() {
3320 let ingress = HttpIngress::<()>::new();
3321 assert_eq!(ingress.graceful_shutdown_timeout, Duration::from_secs(30));
3322 assert!(ingress.layers.is_empty());
3323 assert!(ingress.bus_injectors.is_empty());
3324 assert!(ingress.static_assets.mounts.is_empty());
3325 assert!(ingress.on_start.is_none());
3326 assert!(ingress.on_shutdown.is_none());
3327 }
3328
3329 #[test]
3330 fn route_without_layer_keeps_empty_route_middleware_stack() {
3331 let ingress =
3332 HttpIngress::<()>::new().get("/ping", Axon::<(), (), String, ()>::new("Ping"));
3333 assert_eq!(ingress.routes.len(), 1);
3334 assert!(ingress.routes[0].layers.is_empty());
3335 assert!(ingress.routes[0].apply_global_layers);
3336 }
3337
3338 #[test]
3339 fn timeout_layer_registers_builtin_middleware() {
3340 let ingress = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
3341 assert_eq!(ingress.layers.len(), 1);
3342 }
3343
3344 #[test]
3345 fn request_id_layer_registers_builtin_middleware() {
3346 let ingress = HttpIngress::<()>::new().request_id_layer();
3347 assert_eq!(ingress.layers.len(), 1);
3348 }
3349
3350 #[test]
3351 fn compression_layer_registers_builtin_middleware() {
3352 let ingress = HttpIngress::<()>::new().compression_layer();
3353 assert!(ingress.static_assets.enable_compression);
3354 }
3355
3356 #[test]
3357 fn bus_injector_registration_adds_hook() {
3358 let ingress = HttpIngress::<()>::new().bus_injector(|_req, bus| {
3359 bus.insert("ok".to_string());
3360 });
3361 assert_eq!(ingress.bus_injectors.len(), 1);
3362 }
3363
3364 #[test]
3365 fn ws_route_registers_get_route_pattern() {
3366 let ingress =
3367 HttpIngress::<()>::new().ws("/ws/events", |_socket, _resources, _bus| async {});
3368 assert_eq!(ingress.routes.len(), 1);
3369 assert_eq!(ingress.routes[0].method, Method::GET);
3370 assert_eq!(ingress.routes[0].pattern.raw, "/ws/events");
3371 }
3372
3373 #[derive(Debug, Deserialize)]
3374 struct WsWelcomeFrame {
3375 connection_id: String,
3376 path: String,
3377 tenant: String,
3378 }
3379
3380 #[tokio::test]
3381 async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
3382 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
3383 let addr = probe.local_addr().expect("local addr");
3384 drop(probe);
3385
3386 let ingress = HttpIngress::<()>::new()
3387 .bind(addr.to_string())
3388 .bus_injector(|req, bus| {
3389 if let Some(value) = req.headers.get("x-tenant-id").and_then(|v| v.to_str().ok()) {
3390 bus.insert(value.to_string());
3391 }
3392 })
3393 .ws("/ws/echo", |mut socket, _resources, bus| async move {
3394 let tenant = bus
3395 .read::<String>()
3396 .cloned()
3397 .unwrap_or_else(|| "unknown".to_string());
3398 if let Some(session) = bus.read::<WebSocketSessionContext>() {
3399 let welcome = serde_json::json!({
3400 "connection_id": session.connection_id().to_string(),
3401 "path": session.path(),
3402 "tenant": tenant,
3403 });
3404 let _ = socket.send_json(&welcome).await;
3405 }
3406
3407 while let Some(event) = socket.next_event().await {
3408 match event {
3409 WebSocketEvent::Text(text) => {
3410 let _ = socket.send_event(format!("echo:{text}")).await;
3411 }
3412 WebSocketEvent::Binary(bytes) => {
3413 let _ = socket.send_event(bytes).await;
3414 }
3415 WebSocketEvent::Close => break,
3416 WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
3417 }
3418 }
3419 });
3420
3421 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
3422 let server = tokio::spawn(async move {
3423 ingress
3424 .run_with_shutdown_signal((), async move {
3425 let _ = shutdown_rx.await;
3426 })
3427 .await
3428 });
3429
3430 let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
3431 let mut ws_request = ws_uri
3432 .as_str()
3433 .into_client_request()
3434 .expect("ws client request");
3435 ws_request
3436 .headers_mut()
3437 .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
3438 let (mut client, _response) = tokio_tungstenite::connect_async(ws_request)
3439 .await
3440 .expect("websocket connect");
3441
3442 let welcome = client
3443 .next()
3444 .await
3445 .expect("welcome frame")
3446 .expect("welcome frame ok");
3447 let welcome_text = match welcome {
3448 WsClientMessage::Text(text) => text.to_string(),
3449 other => panic!("expected text welcome frame, got {other:?}"),
3450 };
3451 let welcome_payload: WsWelcomeFrame =
3452 serde_json::from_str(&welcome_text).expect("welcome json");
3453 assert_eq!(welcome_payload.path, "/ws/echo");
3454 assert_eq!(welcome_payload.tenant, "acme");
3455 assert!(!welcome_payload.connection_id.is_empty());
3456
3457 client
3458 .send(WsClientMessage::Text("hello".into()))
3459 .await
3460 .expect("send text");
3461 let echo_text = client
3462 .next()
3463 .await
3464 .expect("echo text frame")
3465 .expect("echo text frame ok");
3466 assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));
3467
3468 client
3469 .send(WsClientMessage::Binary(vec![1, 2, 3, 4].into()))
3470 .await
3471 .expect("send binary");
3472 let echo_binary = client
3473 .next()
3474 .await
3475 .expect("echo binary frame")
3476 .expect("echo binary frame ok");
3477 assert_eq!(
3478 echo_binary,
3479 WsClientMessage::Binary(vec![1, 2, 3, 4].into())
3480 );
3481
3482 client.close(None).await.expect("close websocket");
3483
3484 let _ = shutdown_tx.send(());
3485 server
3486 .await
3487 .expect("server join")
3488 .expect("server shutdown should succeed");
3489 }
3490
#[test]
fn route_descriptors_export_http_and_health_paths() {
    let ingress = HttpIngress::<()>::new()
        .get(
            "/orders/:id",
            Axon::<(), (), String, ()>::new("OrderById"),
        )
        .health_endpoint("/healthz")
        .readiness_liveness("/readyz", "/livez");

    let descriptors = ingress.route_descriptors();

    // The user route plus the three probe endpoints registered via `health_endpoint`
    // and `readiness_liveness` must all be exported as GET descriptors.
    for expected in ["/orders/:id", "/healthz", "/readyz", "/livez"] {
        let found = descriptors.iter().any(|descriptor| {
            descriptor.method() == Method::GET && descriptor.path_pattern() == expected
        });
        assert!(found, "missing GET route descriptor for {expected}");
    }
}
3528
3529 #[tokio::test]
3530 async fn lifecycle_hooks_fire_on_start_and_shutdown() {
3531 let started = Arc::new(AtomicBool::new(false));
3532 let shutdown = Arc::new(AtomicBool::new(false));
3533
3534 let started_flag = started.clone();
3535 let shutdown_flag = shutdown.clone();
3536
3537 let ingress = HttpIngress::<()>::new()
3538 .bind("127.0.0.1:0")
3539 .on_start(move || {
3540 started_flag.store(true, Ordering::SeqCst);
3541 })
3542 .on_shutdown(move || {
3543 shutdown_flag.store(true, Ordering::SeqCst);
3544 })
3545 .graceful_shutdown(Duration::from_millis(50));
3546
3547 ingress
3548 .run_with_shutdown_signal((), async {
3549 tokio::time::sleep(Duration::from_millis(20)).await;
3550 })
3551 .await
3552 .expect("server should exit gracefully");
3553
3554 assert!(started.load(Ordering::SeqCst));
3555 assert!(shutdown.load(Ordering::SeqCst));
3556 }
3557
3558 #[tokio::test]
3559 async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
3560 #[derive(Clone)]
3561 struct SlowDrainRoute;
3562
3563 #[async_trait]
3564 impl Transition<(), String> for SlowDrainRoute {
3565 type Error = String;
3566 type Resources = ();
3567
3568 async fn run(
3569 &self,
3570 _state: (),
3571 _resources: &Self::Resources,
3572 _bus: &mut Bus,
3573 ) -> Outcome<String, Self::Error> {
3574 tokio::time::sleep(Duration::from_millis(120)).await;
3575 Outcome::next("drained-ok".to_string())
3576 }
3577 }
3578
3579 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
3580 let addr = probe.local_addr().expect("local addr");
3581 drop(probe);
3582
3583 let ingress = HttpIngress::<()>::new()
3584 .bind(addr.to_string())
3585 .graceful_shutdown(Duration::from_millis(500))
3586 .get(
3587 "/drain",
3588 Axon::<(), (), String, ()>::new("SlowDrain").then(SlowDrainRoute),
3589 );
3590
3591 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
3592 let server = tokio::spawn(async move {
3593 ingress
3594 .run_with_shutdown_signal((), async move {
3595 let _ = shutdown_rx.await;
3596 })
3597 .await
3598 });
3599
3600 let mut stream = connect_with_retry(addr).await;
3601 stream
3602 .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
3603 .await
3604 .expect("write request");
3605
3606 tokio::time::sleep(Duration::from_millis(20)).await;
3607 let _ = shutdown_tx.send(());
3608
3609 let mut buf = Vec::new();
3610 stream.read_to_end(&mut buf).await.expect("read response");
3611 let response = String::from_utf8_lossy(&buf);
3612 assert!(response.starts_with("HTTP/1.1 200"), "{response}");
3613 assert!(response.contains("drained-ok"), "{response}");
3614
3615 server
3616 .await
3617 .expect("server join")
3618 .expect("server shutdown should succeed");
3619 }
3620
3621 #[tokio::test]
3622 async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
3623 let temp = tempdir().expect("tempdir");
3624 let root = temp.path().join("public");
3625 fs::create_dir_all(&root).expect("create dir");
3626 let file = root.join("hello.txt");
3627 fs::write(&file, "hello static").expect("write file");
3628
3629 let ingress =
3630 Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
3631 let app = crate::test_harness::TestApp::new(ingress, ());
3632 let response = app
3633 .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
3634 .await
3635 .expect("request should succeed");
3636
3637 assert_eq!(response.status(), StatusCode::OK);
3638 assert_eq!(response.text().expect("utf8"), "hello static");
3639 assert!(response.header("cache-control").is_some());
3640 let has_metadata_header =
3641 response.header("etag").is_some() || response.header("last-modified").is_some();
3642 assert!(has_metadata_header);
3643 }
3644
3645 #[tokio::test]
3646 async fn spa_fallback_returns_index_for_unmatched_path() {
3647 let temp = tempdir().expect("tempdir");
3648 let index = temp.path().join("index.html");
3649 fs::write(&index, "<html><body>spa</body></html>").expect("write index");
3650
3651 let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
3652 let app = crate::test_harness::TestApp::new(ingress, ());
3653 let response = app
3654 .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
3655 .await
3656 .expect("request should succeed");
3657
3658 assert_eq!(response.status(), StatusCode::OK);
3659 assert!(response.text().expect("utf8").contains("spa"));
3660 }
3661
3662 #[tokio::test]
3663 async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
3664 let temp = tempdir().expect("tempdir");
3665 let root = temp.path().join("public");
3666 fs::create_dir_all(&root).expect("create dir");
3667 let file = root.join("compressed.txt");
3668 fs::write(&file, "compress me ".repeat(400)).expect("write file");
3669
3670 let ingress = Ranvier::http::<()>()
3671 .serve_dir("/static", root.to_string_lossy().to_string())
3672 .compression_layer();
3673 let app = crate::test_harness::TestApp::new(ingress, ());
3674 let response = app
3675 .send(
3676 crate::test_harness::TestRequest::get("/static/compressed.txt")
3677 .header("accept-encoding", "gzip"),
3678 )
3679 .await
3680 .expect("request should succeed");
3681
3682 assert_eq!(response.status(), StatusCode::OK);
3683 assert_eq!(
3684 response
3685 .header("content-encoding")
3686 .and_then(|value| value.to_str().ok()),
3687 Some("gzip")
3688 );
3689 }
3690
3691 #[tokio::test]
3692 async fn drain_connections_completes_before_timeout() {
3693 let mut connections = tokio::task::JoinSet::new();
3694 connections.spawn(async {
3695 tokio::time::sleep(Duration::from_millis(20)).await;
3696 });
3697
3698 let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
3699 assert!(!timed_out);
3700 assert!(connections.is_empty());
3701 }
3702
3703 #[tokio::test]
3704 async fn drain_connections_times_out_and_aborts() {
3705 let mut connections = tokio::task::JoinSet::new();
3706 connections.spawn(async {
3707 tokio::time::sleep(Duration::from_secs(10)).await;
3708 });
3709
3710 let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
3711 assert!(timed_out);
3712 assert!(connections.is_empty());
3713 }
3714
3715 #[tokio::test]
3716 async fn timeout_layer_returns_408_for_slow_route() {
3717 #[derive(Clone)]
3718 struct SlowRoute;
3719
3720 #[async_trait]
3721 impl Transition<(), String> for SlowRoute {
3722 type Error = String;
3723 type Resources = ();
3724
3725 async fn run(
3726 &self,
3727 _state: (),
3728 _resources: &Self::Resources,
3729 _bus: &mut Bus,
3730 ) -> Outcome<String, Self::Error> {
3731 tokio::time::sleep(Duration::from_millis(80)).await;
3732 Outcome::next("slow-ok".to_string())
3733 }
3734 }
3735
3736 let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
3737 let addr = probe.local_addr().expect("local addr");
3738 drop(probe);
3739
3740 let ingress = HttpIngress::<()>::new()
3741 .bind(addr.to_string())
3742 .timeout_layer(Duration::from_millis(10))
3743 .get(
3744 "/slow",
3745 Axon::<(), (), String, ()>::new("Slow").then(SlowRoute),
3746 );
3747
3748 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
3749 let server = tokio::spawn(async move {
3750 ingress
3751 .run_with_shutdown_signal((), async move {
3752 let _ = shutdown_rx.await;
3753 })
3754 .await
3755 });
3756
3757 let mut stream = connect_with_retry(addr).await;
3758 stream
3759 .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
3760 .await
3761 .expect("write request");
3762
3763 let mut buf = Vec::new();
3764 stream.read_to_end(&mut buf).await.expect("read response");
3765 let response = String::from_utf8_lossy(&buf);
3766 assert!(response.starts_with("HTTP/1.1 408"), "{response}");
3767
3768 let _ = shutdown_tx.send(());
3769 server
3770 .await
3771 .expect("server join")
3772 .expect("server shutdown should succeed");
3773 }
3774
3775 fn extract_body(response: Response<Full<Bytes>>) -> Bytes {
3778 use http_body_util::BodyExt;
3779 let rt = tokio::runtime::Builder::new_current_thread()
3780 .build()
3781 .unwrap();
3782 rt.block_on(async {
3783 let collected = response.into_body().collect().await.unwrap();
3784 collected.to_bytes()
3785 })
3786 }
3787
#[test]
fn handle_range_bytes_start_end() {
    // `bytes=0-4` is an inclusive closed range: the first five bytes of a 13-byte body.
    let payload = b"Hello, World!";
    let range_header = http::HeaderValue::from_static("bytes=0-4");

    let response =
        super::handle_range_request(&range_header, payload, "text/plain", None, None).unwrap();

    assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
    let content_range = response.headers().get(http::header::CONTENT_RANGE).unwrap();
    assert_eq!(content_range, "bytes 0-4/13");
    assert_eq!(extract_body(response), "Hello");
}
3801
#[test]
fn handle_range_suffix() {
    // `bytes=-6` is a suffix range: the final six bytes (offsets 7..=12) of the 13-byte body.
    let content = b"Hello, World!";
    let range = http::HeaderValue::from_static("bytes=-6");
    let response =
        super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
    assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
    assert_eq!(
        response.headers().get(http::header::CONTENT_RANGE).unwrap(),
        "bytes 7-12/13"
    );
    // The header alone does not prove the right slice was sent, so check the payload too.
    assert_eq!(extract_body(response), "World!");
}
3814
#[test]
fn handle_range_from_offset() {
    // `bytes=7-` is an open-ended range: everything from offset 7 to the end of the body.
    let content = b"Hello, World!";
    let range = http::HeaderValue::from_static("bytes=7-");
    let response =
        super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
    assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
    assert_eq!(
        response.headers().get(http::header::CONTENT_RANGE).unwrap(),
        "bytes 7-12/13"
    );
    // The header alone does not prove the right slice was sent, so check the payload too.
    assert_eq!(extract_body(response), "World!");
}
3827
#[test]
fn handle_range_out_of_bounds_returns_416() {
    // Start offset 10 lies past the end of a 5-byte body, so the range is unsatisfiable
    // and Content-Range reports only the complete length (`*/5`).
    let payload = b"Hello";
    let beyond_end = http::HeaderValue::from_static("bytes=10-20");

    let response =
        super::handle_range_request(&beyond_end, payload, "text/plain", None, None).unwrap();

    assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE);
    let content_range = response.headers().get(http::header::CONTENT_RANGE).unwrap();
    assert_eq!(content_range, "bytes */5");
}
3840
#[test]
fn handle_range_includes_accept_ranges_header() {
    // Even a single-byte range response must advertise byte-range support.
    let payload = b"Hello, World!";
    let single_byte = http::HeaderValue::from_static("bytes=0-0");

    let response =
        super::handle_range_request(&single_byte, payload, "text/plain", None, None).unwrap();

    let accept_ranges = response.headers().get(http::header::ACCEPT_RANGES).unwrap();
    assert_eq!(accept_ranges, "bytes");
}
3852
3853}