1use base64::Engine;
17use bytes::Bytes;
18use futures_util::{SinkExt, StreamExt};
19use http::{Method, Request, Response, StatusCode};
20use http_body_util::{BodyExt, Full};
21use hyper::body::Incoming;
22use hyper::server::conn::http1;
23use hyper::upgrade::Upgraded;
24use hyper_util::rt::TokioIo;
25use ranvier_core::event::{EventSink, EventSource};
26use ranvier_core::prelude::*;
27use ranvier_runtime::Axon;
28use serde::Serialize;
29use serde::de::DeserializeOwned;
30use sha1::{Digest, Sha1};
31use std::collections::HashMap;
32use std::convert::Infallible;
33use std::future::Future;
34use std::net::SocketAddr;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::time::Duration;
38use tokio::net::TcpListener;
39use tokio::sync::Mutex;
40use tokio_tungstenite::WebSocketStream;
41use tokio_tungstenite::tungstenite::{Error as WsWireError, Message as WsWireMessage};
42use tracing::Instrument;
43
44use crate::guard_integration::{
45 GuardExec, GuardIntegration, PreflightConfig, RegisteredGuard, ResponseBodyTransformFn,
46 ResponseExtractorFn,
47};
48use crate::response::{HttpResponse, IntoResponse, json_error_response, outcome_to_response_with_error};
49
50pub struct Ranvier;
55
56impl Ranvier {
57 pub fn http<R>() -> HttpIngress<R>
59 where
60 R: ranvier_core::transition::ResourceRequirement + Clone,
61 {
62 HttpIngress::new()
63 }
64}
65
/// Type-erased, shareable route handler.
///
/// Receives the request head (`Parts`) and a borrow of the resources `R`, and
/// returns a boxed, `Send` future that resolves to the final [`HttpResponse`].
type RouteHandler<R> = Arc<
    dyn Fn(http::request::Parts, &R) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>>
        + Send
        + Sync,
>;
72
73#[derive(Clone)]
75struct BoxService(
76 Arc<
77 dyn Fn(Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>
78 + Send
79 + Sync,
80 >,
81);
82
83impl BoxService {
84 fn new<F, Fut>(f: F) -> Self
85 where
86 F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
87 Fut: Future<Output = Result<HttpResponse, Infallible>> + Send + 'static,
88 {
89 Self(Arc::new(move |req| Box::pin(f(req))))
90 }
91
92 fn call(&self, req: Request<Incoming>) -> Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>> {
93 (self.0)(req)
94 }
95}
96
97impl hyper::service::Service<Request<Incoming>> for BoxService {
98 type Response = HttpResponse;
99 type Error = Infallible;
100 type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Infallible>> + Send>>;
101
102 fn call(&self, req: Request<Incoming>) -> Self::Future {
103 (self.0)(req)
104 }
105}
106
/// Alias used where a boxed HTTP service is expected.
type BoxHttpService = BoxService;
/// Middleware constructor: wraps one boxed service in another.
type ServiceLayer = Arc<dyn Fn(BoxHttpService) -> BoxHttpService + Send + Sync>;
/// Argument-less callback stored for the `on_start` / `on_shutdown` hooks.
type LifecycleHook = Arc<dyn Fn() + Send + Sync>;
/// Callback that may populate the per-request `Bus` from the request head.
type BusInjector = Arc<dyn Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static>;
/// Boxed future driving a single WebSocket session to completion.
type WsSessionFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Type-erased WebSocket session handler: connection, shared resources and bus in,
/// session future out.
type WsSessionHandler<R> =
    Arc<dyn Fn(WebSocketConnection, Arc<R>, Bus) -> WsSessionFuture + Send + Sync>;
/// Boxed future produced by a health check; `Err` carries a textual failure reason.
type HealthCheckFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send>>;
/// Type-erased health check invoked with the shared resources.
type HealthCheckFn<R> = Arc<dyn Fn(Arc<R>) -> HealthCheckFuture + Send + Sync>;
/// Header name used by `request_id_middleware` to read and echo the request id.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Token expected in (and sent back through) the `Upgrade` header for WebSockets.
const WS_UPGRADE_TOKEN: &str = "websocket";
/// Fixed GUID hashed together with the client key to derive `Sec-WebSocket-Accept`
/// (see `websocket_accept_key`).
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
119
/// A health check paired with the name it was registered under.
#[derive(Clone)]
struct NamedHealthCheck<R> {
    /// Name supplied at registration time.
    name: String,
    /// The type-erased check itself.
    check: HealthCheckFn<R>,
}
125
/// Health-probe configuration collected by the builder.
#[derive(Clone)]
struct HealthConfig<R> {
    /// Path of the general health endpoint, if enabled.
    health_path: Option<String>,
    /// Path of the readiness probe, if enabled.
    readiness_path: Option<String>,
    /// Path of the liveness probe, if enabled.
    liveness_path: Option<String>,
    /// Registered checks, in registration order.
    checks: Vec<NamedHealthCheck<R>>,
}

// Implemented by hand: `#[derive(Default)]` would add an `R: Default` bound
// even though no `R` value is stored here.
impl<R> Default for HealthConfig<R> {
    /// Returns a configuration with no endpoints and no checks.
    fn default() -> Self {
        Self {
            health_path: None,
            readiness_path: None,
            liveness_path: None,
            checks: Vec::new(),
        }
    }
}
144
/// Static-file serving options collected by the builder.
///
/// All flags default to `false` and all optional values to `None`.
#[derive(Clone, Default)]
struct StaticAssetsConfig {
    /// Directory mounts added through `serve_dir`.
    mounts: Vec<StaticMount>,
    /// File path set through `spa_fallback`.
    spa_fallback: Option<String>,
    /// `Cache-Control` value; `serve_dir` fills in a default when unset.
    cache_control: Option<String>,
    /// Set by `compression_layer`.
    enable_compression: bool,
    /// File name set through `directory_index`.
    directory_index: Option<String>,
    /// Set by `immutable_cache`.
    immutable_cache: bool,
    /// Set by `serve_precompressed`.
    serve_precompressed: bool,
    /// Set by `enable_range_requests`.
    enable_range_requests: bool,
}
161
/// One static directory mounted under a route prefix.
#[derive(Clone)]
struct StaticMount {
    /// Route prefix, normalised to start with `/` by `serve_dir`.
    route_prefix: String,
    /// Directory passed to `serve_dir`, stored as given.
    directory: String,
}
167
/// Certificate and key locations recorded by `HttpIngress::tls`.
///
/// Only compiled when the `tls` feature is enabled.
#[cfg(feature = "tls")]
#[derive(Clone)]
struct TlsAcceptorConfig {
    /// Certificate path as supplied by the caller.
    cert_path: String,
    /// Private-key path as supplied by the caller.
    key_path: String,
}
175
/// Serializable body of a health-probe response.
///
/// Field order is part of the serialized output; keep it stable.
#[derive(Serialize)]
struct HealthReport {
    /// Overall status label.
    status: &'static str,
    /// Label identifying which probe produced the report.
    probe: &'static str,
    /// Per-check results.
    checks: Vec<HealthCheckReport>,
}
182
/// Serializable result of a single named health check.
#[derive(Serialize)]
struct HealthCheckReport {
    /// Name of the check.
    name: String,
    /// Status label for this check.
    status: &'static str,
    /// Failure detail; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
190
191fn timeout_middleware(timeout: Duration) -> ServiceLayer {
192 Arc::new(move |inner: BoxHttpService| {
193 BoxService::new(move |req: Request<Incoming>| {
194 let inner = inner.clone();
195 async move {
196 match tokio::time::timeout(timeout, inner.call(req)).await {
197 Ok(response) => response,
198 Err(_) => Ok(Response::builder()
199 .status(StatusCode::REQUEST_TIMEOUT)
200 .body(
201 Full::new(Bytes::from("Request Timeout"))
202 .map_err(|never| match never {})
203 .boxed(),
204 )
205 .expect("valid HTTP response construction")),
206 }
207 }
208 })
209 })
210}
211
212fn request_id_middleware() -> ServiceLayer {
213 Arc::new(move |inner: BoxHttpService| {
214 BoxService::new(move |req: Request<Incoming>| {
215 let inner = inner.clone();
216 async move {
217 let mut req = req;
218 let request_id = req
219 .headers()
220 .get(REQUEST_ID_HEADER)
221 .cloned()
222 .unwrap_or_else(|| {
223 http::HeaderValue::from_str(&uuid::Uuid::new_v4().to_string())
224 .unwrap_or_else(|_| {
225 http::HeaderValue::from_static("request-id-unavailable")
226 })
227 });
228 req.headers_mut()
229 .insert(REQUEST_ID_HEADER, request_id.clone());
230 let mut response = inner.call(req).await?;
231 response
232 .headers_mut()
233 .insert(REQUEST_ID_HEADER, request_id);
234 Ok(response)
235 }
236 })
237 })
238}
239
/// Named values captured from a request path by a route pattern
/// (`:name` and `*name` segments).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PathParams {
    /// Captured values keyed by parameter name.
    values: HashMap<String, String>,
}
244
245#[derive(Clone, Debug)]
247pub struct HttpRouteDescriptor {
248 method: Method,
249 path_pattern: String,
250 pub body_schema: Option<serde_json::Value>,
252}
253
254impl HttpRouteDescriptor {
255 pub fn new(method: Method, path_pattern: impl Into<String>) -> Self {
256 Self {
257 method,
258 path_pattern: path_pattern.into(),
259 body_schema: None,
260 }
261 }
262
263 pub fn method(&self) -> &Method {
264 &self.method
265 }
266
267 pub fn path_pattern(&self) -> &str {
268 &self.path_pattern
269 }
270
271 pub fn body_schema(&self) -> Option<&serde_json::Value> {
276 self.body_schema.as_ref()
277 }
278}
279
280#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
282pub struct WebSocketSessionContext {
283 connection_id: uuid::Uuid,
284 path: String,
285 query: Option<String>,
286}
287
288impl WebSocketSessionContext {
289 pub fn connection_id(&self) -> uuid::Uuid {
290 self.connection_id
291 }
292
293 pub fn path(&self) -> &str {
294 &self.path
295 }
296
297 pub fn query(&self) -> Option<&str> {
298 self.query.as_deref()
299 }
300}
301
302#[derive(Clone, Debug, PartialEq, Eq)]
304pub enum WebSocketEvent {
305 Text(String),
306 Binary(Vec<u8>),
307 Ping(Vec<u8>),
308 Pong(Vec<u8>),
309 Close,
310}
311
312impl WebSocketEvent {
313 pub fn text(value: impl Into<String>) -> Self {
314 Self::Text(value.into())
315 }
316
317 pub fn binary(value: impl Into<Vec<u8>>) -> Self {
318 Self::Binary(value.into())
319 }
320
321 pub fn json<T>(value: &T) -> Result<Self, serde_json::Error>
322 where
323 T: Serialize,
324 {
325 let text = serde_json::to_string(value)?;
326 Ok(Self::Text(text))
327 }
328}
329
/// Errors surfaced by [`WebSocketConnection`] send/receive helpers.
#[derive(Debug, thiserror::Error)]
pub enum WebSocketError {
    /// Failure reported by the underlying WebSocket transport.
    #[error("websocket wire error: {0}")]
    Wire(#[from] WsWireError),
    /// An outgoing value could not be serialized to JSON.
    #[error("json serialization failed: {0}")]
    JsonSerialize(#[source] serde_json::Error),
    /// An incoming payload could not be deserialized from JSON.
    #[error("json deserialization failed: {0}")]
    JsonDeserialize(#[source] serde_json::Error),
    /// A JSON payload was expected but a ping, pong or close frame arrived.
    #[error("expected text or binary frame for json payload")]
    NonDataFrame,
}
341
/// Server-side WebSocket stream running over an upgraded hyper connection.
type WsServerStream = WebSocketStream<TokioIo<Upgraded>>;
/// Write half of a split [`WsServerStream`].
type WsServerSink = futures_util::stream::SplitSink<WsServerStream, WsWireMessage>;
/// Read half of a split [`WsServerStream`].
type WsServerSource = futures_util::stream::SplitStream<WsServerStream>;
345
346pub struct WebSocketConnection {
348 sink: Mutex<WsServerSink>,
349 source: Mutex<WsServerSource>,
350 session: WebSocketSessionContext,
351}
352
353impl WebSocketConnection {
354 fn new(stream: WsServerStream, session: WebSocketSessionContext) -> Self {
355 let (sink, source) = stream.split();
356 Self {
357 sink: Mutex::new(sink),
358 source: Mutex::new(source),
359 session,
360 }
361 }
362
363 pub fn session(&self) -> &WebSocketSessionContext {
364 &self.session
365 }
366
367 pub async fn send(&self, event: WebSocketEvent) -> Result<(), WebSocketError> {
368 let mut sink = self.sink.lock().await;
369 sink.send(event.into_wire_message()).await?;
370 Ok(())
371 }
372
373 pub async fn send_json<T>(&self, value: &T) -> Result<(), WebSocketError>
374 where
375 T: Serialize,
376 {
377 let event = WebSocketEvent::json(value).map_err(WebSocketError::JsonSerialize)?;
378 self.send(event).await
379 }
380
381 pub async fn next_json<T>(&mut self) -> Result<Option<T>, WebSocketError>
382 where
383 T: DeserializeOwned,
384 {
385 let Some(event) = self.recv_event().await? else {
386 return Ok(None);
387 };
388 match event {
389 WebSocketEvent::Text(text) => serde_json::from_str(&text)
390 .map(Some)
391 .map_err(WebSocketError::JsonDeserialize),
392 WebSocketEvent::Binary(bytes) => serde_json::from_slice(&bytes)
393 .map(Some)
394 .map_err(WebSocketError::JsonDeserialize),
395 _ => Err(WebSocketError::NonDataFrame),
396 }
397 }
398
399 async fn recv_event(&mut self) -> Result<Option<WebSocketEvent>, WsWireError> {
400 let mut source = self.source.lock().await;
401 while let Some(item) = source.next().await {
402 let message = item?;
403 if let Some(event) = WebSocketEvent::from_wire_message(message) {
404 return Ok(Some(event));
405 }
406 }
407 Ok(None)
408 }
409}
410
411impl WebSocketEvent {
412 fn from_wire_message(message: WsWireMessage) -> Option<Self> {
413 match message {
414 WsWireMessage::Text(value) => Some(Self::Text(value.to_string())),
415 WsWireMessage::Binary(value) => Some(Self::Binary(value.to_vec())),
416 WsWireMessage::Ping(value) => Some(Self::Ping(value.to_vec())),
417 WsWireMessage::Pong(value) => Some(Self::Pong(value.to_vec())),
418 WsWireMessage::Close(_) => Some(Self::Close),
419 WsWireMessage::Frame(_) => None,
420 }
421 }
422
423 fn into_wire_message(self) -> WsWireMessage {
424 match self {
425 Self::Text(value) => WsWireMessage::Text(value),
426 Self::Binary(value) => WsWireMessage::Binary(value),
427 Self::Ping(value) => WsWireMessage::Ping(value),
428 Self::Pong(value) => WsWireMessage::Pong(value),
429 Self::Close => WsWireMessage::Close(None),
430 }
431 }
432}
433
434#[async_trait::async_trait]
435impl EventSource<WebSocketEvent> for WebSocketConnection {
436 async fn next_event(&mut self) -> Option<WebSocketEvent> {
437 match self.recv_event().await {
438 Ok(event) => event,
439 Err(error) => {
440 tracing::warn!(ranvier.ws.error = %error, "websocket source read failed");
441 None
442 }
443 }
444 }
445}
446
447#[async_trait::async_trait]
448impl EventSink<WebSocketEvent> for WebSocketConnection {
449 type Error = WebSocketError;
450
451 async fn send_event(&self, event: WebSocketEvent) -> Result<(), Self::Error> {
452 self.send(event).await
453 }
454}
455
456#[async_trait::async_trait]
457impl EventSink<String> for WebSocketConnection {
458 type Error = WebSocketError;
459
460 async fn send_event(&self, event: String) -> Result<(), Self::Error> {
461 self.send(WebSocketEvent::Text(event)).await
462 }
463}
464
465#[async_trait::async_trait]
466impl EventSink<Vec<u8>> for WebSocketConnection {
467 type Error = WebSocketError;
468
469 async fn send_event(&self, event: Vec<u8>) -> Result<(), Self::Error> {
470 self.send(WebSocketEvent::Binary(event)).await
471 }
472}
473
/// Query-string parameters of a request, keyed by name.
///
/// Keys and values are stored *decoded* (form-urlencoded rules: `+` becomes a
/// space and `%XX` becomes the corresponding byte). When a key occurs more
/// than once, the last occurrence wins.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct QueryParams {
    values: HashMap<String, String>,
}

impl QueryParams {
    /// Parses a raw query string (the part after `?`, without the `?`).
    ///
    /// * Pairs are separated by `&`; empty pairs are ignored.
    /// * A pair without `=` is stored with an empty value.
    /// * Keys and values are percent-decoded; malformed escapes such as a
    ///   trailing `%` are kept literally rather than rejected.
    pub fn from_query(query: &str) -> Self {
        let values = query
            .split('&')
            .filter(|pair| !pair.is_empty())
            .map(|pair| {
                let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
                (Self::decode_component(key), Self::decode_component(value))
            })
            .collect();
        Self { values }
    }

    /// Decodes one form-urlencoded component.
    ///
    /// Invalid UTF-8 produced by decoding is replaced lossily so the result is
    /// always a valid `String`.
    fn decode_component(raw: &str) -> String {
        // Fast path: nothing to decode, avoid the byte-wise rebuild.
        if !raw.bytes().any(|byte| byte == b'%' || byte == b'+') {
            return raw.to_string();
        }

        let bytes = raw.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut index = 0;
        while index < bytes.len() {
            match bytes[index] {
                b'+' => {
                    decoded.push(b' ');
                    index += 1;
                }
                b'%' => {
                    let escaped = bytes.get(index + 1..index + 3).and_then(|pair| {
                        let high = char::from(pair[0]).to_digit(16)?;
                        let low = char::from(pair[1]).to_digit(16)?;
                        // Two hex digits are at most 0xFF, so this always fits.
                        u8::try_from(high * 16 + low).ok()
                    });
                    match escaped {
                        Some(byte) => {
                            decoded.push(byte);
                            index += 3;
                        }
                        // Not a valid escape: keep the `%` as-is.
                        None => {
                            decoded.push(b'%');
                            index += 1;
                        }
                    }
                }
                other => {
                    decoded.push(other);
                    index += 1;
                }
            }
        }
        String::from_utf8_lossy(&decoded).into_owned()
    }

    /// Returns the value stored for `key`, if any.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.values.get(key).map(String::as_str)
    }

    /// Returns the value for `key` parsed as `T`.
    ///
    /// `None` is returned both when the key is absent and when parsing fails.
    pub fn get_parsed<T: std::str::FromStr>(&self, key: &str) -> Option<T> {
        self.values.get(key).and_then(|value| value.parse().ok())
    }

    /// Like [`Self::get_parsed`], falling back to `default` when the key is
    /// absent or unparsable.
    pub fn get_or<T: std::str::FromStr>(&self, key: &str, default: T) -> T {
        self.get_parsed(key).unwrap_or(default)
    }

    /// Whether `key` is present (even with an empty value).
    pub fn contains(&self, key: &str) -> bool {
        self.values.contains_key(key)
    }

    /// Borrowed view of all parameters.
    pub fn as_map(&self) -> &HashMap<String, String> {
        &self.values
    }
}
534
535impl PathParams {
536 pub fn new(values: HashMap<String, String>) -> Self {
537 Self { values }
538 }
539
540 pub fn get(&self, key: &str) -> Option<&str> {
541 self.values.get(key).map(String::as_str)
542 }
543
544 pub fn get_parsed<T: std::str::FromStr>(&self, key: &str) -> Option<T> {
554 self.values.get(key).and_then(|v| v.parse().ok())
555 }
556
557 pub fn as_map(&self) -> &HashMap<String, String> {
558 &self.values
559 }
560
561 pub fn into_inner(self) -> HashMap<String, String> {
562 self.values
563 }
564}
565
566fn inject_query_params(parts: &http::request::Parts, bus: &mut ranvier_core::bus::Bus) {
570 if let Some(query) = parts.uri.query() {
571 bus.insert(QueryParams::from_query(query));
572 } else {
573 bus.insert(QueryParams::default());
574 }
575}
576
/// One parsed segment of a route pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum RouteSegment {
    /// Literal segment that must match exactly.
    Static(String),
    /// `:name` — captures exactly one path segment under `name`.
    Param(String),
    /// `*name` — captures all remaining path segments, joined by `/`, under `name`.
    Wildcard(String),
}
583
584#[derive(Clone, Debug, PartialEq, Eq)]
585struct RoutePattern {
586 raw: String,
587 segments: Vec<RouteSegment>,
588}
589
590impl RoutePattern {
591 fn parse(path: &str) -> Self {
592 let segments = path_segments(path)
593 .into_iter()
594 .map(|segment| {
595 if let Some(name) = segment.strip_prefix(':') {
596 if !name.is_empty() {
597 return RouteSegment::Param(name.to_string());
598 }
599 }
600 if let Some(name) = segment.strip_prefix('*') {
601 if !name.is_empty() {
602 return RouteSegment::Wildcard(name.to_string());
603 }
604 }
605 RouteSegment::Static(segment.to_string())
606 })
607 .collect();
608
609 Self {
610 raw: path.to_string(),
611 segments,
612 }
613 }
614
615 fn match_path(&self, path: &str) -> Option<PathParams> {
616 let mut params = HashMap::new();
617 let path_segments = path_segments(path);
618 let mut pattern_index = 0usize;
619 let mut path_index = 0usize;
620
621 while pattern_index < self.segments.len() {
622 match &self.segments[pattern_index] {
623 RouteSegment::Static(expected) => {
624 let actual = path_segments.get(path_index)?;
625 if actual != expected {
626 return None;
627 }
628 pattern_index += 1;
629 path_index += 1;
630 }
631 RouteSegment::Param(name) => {
632 let actual = path_segments.get(path_index)?;
633 params.insert(name.clone(), (*actual).to_string());
634 pattern_index += 1;
635 path_index += 1;
636 }
637 RouteSegment::Wildcard(name) => {
638 let remaining = path_segments[path_index..].join("/");
639 params.insert(name.clone(), remaining);
640 pattern_index += 1;
641 path_index = path_segments.len();
642 break;
643 }
644 }
645 }
646
647 if pattern_index == self.segments.len() && path_index == path_segments.len() {
648 Some(PathParams::new(params))
649 } else {
650 None
651 }
652 }
653}
654
/// Newtype wrapper around [`Bytes`].
// NOTE(review): presumably carries a buffered request body to handlers —
// confirm at the use site, which is outside this chunk.
#[derive(Clone)]
struct BodyBytes(Bytes);
659
/// One registered route as stored by the builder.
#[derive(Clone)]
struct RouteEntry<R> {
    /// HTTP method the route answers to.
    method: Method,
    /// Parsed path pattern used for matching.
    pattern: RoutePattern,
    /// Type-erased handler invoked on a match.
    handler: RouteHandler<R>,
    /// Layers attached to this route only.
    layers: Arc<Vec<ServiceLayer>>,
    /// Whether the builder-wide layers should wrap this route as well.
    apply_global_layers: bool,
    // NOTE(review): presumably signals that the request body must be read
    // before the handler runs — confirm in the dispatch code.
    needs_body: bool,
    /// Optional JSON schema for the request body, surfaced by `route_descriptors`.
    body_schema: Option<serde_json::Value>,
}
673
/// Splits a path into its non-empty segments.
///
/// Leading, trailing and repeated slashes are ignored, so `"/"`, `""` and
/// `"///"` all yield an empty vector and `"/a//b/"` yields `["a", "b"]`.
fn path_segments(path: &str) -> Vec<&str> {
    path.split('/')
        .filter(|piece| !piece.is_empty())
        .collect()
}
684
/// Ensures a route path starts with `/`.
///
/// An empty path becomes `"/"`, a path already starting with `/` is returned
/// unchanged, and anything else gets a single `/` prepended. Trailing slashes
/// are left as they are.
fn normalize_route_path(path: String) -> String {
    if path.starts_with('/') {
        return path;
    }

    // Empty or relative: prefix with the root slash.
    let mut rooted = String::with_capacity(path.len() + 1);
    rooted.push('/');
    rooted.push_str(&path);
    rooted
}
695
696fn find_matching_route<'a, R>(
697 routes: &'a [RouteEntry<R>],
698 method: &Method,
699 path: &str,
700) -> Option<(&'a RouteEntry<R>, PathParams)> {
701 for entry in routes {
702 if entry.method != *method {
703 continue;
704 }
705 if let Some(params) = entry.pattern.match_path(path) {
706 return Some((entry, params));
707 }
708 }
709 None
710}
711
712fn header_contains_token(
713 headers: &http::HeaderMap,
714 name: http::header::HeaderName,
715 token: &str,
716) -> bool {
717 headers
718 .get(name)
719 .and_then(|value| value.to_str().ok())
720 .map(|value| {
721 value
722 .split(',')
723 .any(|part| part.trim().eq_ignore_ascii_case(token))
724 })
725 .unwrap_or(false)
726}
727
728fn websocket_session_from_request<B>(req: &Request<B>) -> WebSocketSessionContext {
729 WebSocketSessionContext {
730 connection_id: uuid::Uuid::new_v4(),
731 path: req.uri().path().to_string(),
732 query: req.uri().query().map(str::to_string),
733 }
734}
735
736fn websocket_accept_key(client_key: &str) -> String {
737 let mut hasher = Sha1::new();
738 hasher.update(client_key.as_bytes());
739 hasher.update(WS_GUID.as_bytes());
740 let digest = hasher.finalize();
741 base64::engine::general_purpose::STANDARD.encode(digest)
742}
743
744fn websocket_bad_request(message: &'static str) -> HttpResponse {
745 Response::builder()
746 .status(StatusCode::BAD_REQUEST)
747 .body(
748 Full::new(Bytes::from(message))
749 .map_err(|never| match never {})
750 .boxed(),
751 )
752 .unwrap_or_else(|_| {
753 Response::new(
754 Full::new(Bytes::new())
755 .map_err(|never| match never {})
756 .boxed(),
757 )
758 })
759}
760
761fn websocket_upgrade_response<B>(
762 req: &mut Request<B>,
763) -> Result<(HttpResponse, hyper::upgrade::OnUpgrade), HttpResponse> {
764 if req.method() != Method::GET {
765 return Err(websocket_bad_request(
766 "WebSocket upgrade requires GET method",
767 ));
768 }
769
770 if !header_contains_token(req.headers(), http::header::CONNECTION, "upgrade") {
771 return Err(websocket_bad_request(
772 "Missing Connection: upgrade header for WebSocket",
773 ));
774 }
775
776 if !header_contains_token(req.headers(), http::header::UPGRADE, WS_UPGRADE_TOKEN) {
777 return Err(websocket_bad_request("Missing Upgrade: websocket header"));
778 }
779
780 if let Some(version) = req.headers().get("sec-websocket-version") {
781 if version != "13" {
782 return Err(websocket_bad_request(
783 "Unsupported Sec-WebSocket-Version (expected 13)",
784 ));
785 }
786 }
787
788 let Some(client_key) = req
789 .headers()
790 .get("sec-websocket-key")
791 .and_then(|value| value.to_str().ok())
792 else {
793 return Err(websocket_bad_request(
794 "Missing Sec-WebSocket-Key header for WebSocket",
795 ));
796 };
797
798 let accept_key = websocket_accept_key(client_key);
799 let on_upgrade = hyper::upgrade::on(req);
800 let response = Response::builder()
801 .status(StatusCode::SWITCHING_PROTOCOLS)
802 .header(http::header::UPGRADE, WS_UPGRADE_TOKEN)
803 .header(http::header::CONNECTION, "Upgrade")
804 .header("sec-websocket-accept", accept_key)
805 .body(
806 Full::new(Bytes::new())
807 .map_err(|never| match never {})
808 .boxed(),
809 )
810 .unwrap_or_else(|_| {
811 Response::new(
812 Full::new(Bytes::new())
813 .map_err(|never| match never {})
814 .boxed(),
815 )
816 });
817
818 Ok((response, on_upgrade))
819}
820
/// Builder that collects routes, layers, guards and server options for an
/// HTTP ingress, parameterised over the shared resource type `R`.
pub struct HttpIngress<R = ()> {
    /// Bind address, set through `bind` or `config`.
    addr: Option<String>,
    /// Registered routes in registration order.
    routes: Vec<RouteEntry<R>>,
    /// Optional fallback handler.
    fallback: Option<RouteHandler<R>>,
    /// Builder-wide service layers.
    layers: Vec<ServiceLayer>,
    /// Hook stored by `on_start`.
    on_start: Option<LifecycleHook>,
    /// Hook stored by `on_shutdown`.
    on_shutdown: Option<LifecycleHook>,
    /// Graceful shutdown timeout; `new` initialises it to 30 seconds.
    graceful_shutdown_timeout: Duration,
    /// Bus injectors; routes snapshot this list when they are registered.
    bus_injectors: Vec<BusInjector>,
    /// Static-file serving options.
    static_assets: StaticAssetsConfig,
    /// Health-probe configuration.
    health: HealthConfig<R>,
    /// HTTP/3 configuration (feature `http3`).
    #[cfg(feature = "http3")]
    http3_config: Option<crate::http3::Http3Config>,
    /// Port recorded by `alt_svc_h3` (feature `http3`).
    #[cfg(feature = "http3")]
    alt_svc_h3_port: Option<u16>,
    /// TLS certificate/key paths (feature `tls`).
    #[cfg(feature = "tls")]
    tls_config: Option<TlsAcceptorConfig>,
    /// Flag set by `active_intervention`.
    active_intervention: bool,
    /// Registry stored by `policy_registry`.
    policy_registry: Option<ranvier_core::policy::PolicyRegistry>,
    /// Guard executors contributed by `guard`.
    guard_execs: Vec<Arc<dyn GuardExec>>,
    /// Response extractors contributed by guards (and `htmx_support`).
    guard_response_extractors: Vec<ResponseExtractorFn>,
    /// Response body transforms contributed by guards.
    guard_body_transforms: Vec<ResponseBodyTransformFn>,
    /// Preflight configuration from the most recent guard that supplied one.
    preflight_config: Option<PreflightConfig>,
    /// Marker tying the builder to `R`.
    _phantom: std::marker::PhantomData<R>,
}
868
869impl<R> HttpIngress<R>
870where
871 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
872{
873 pub fn new() -> Self {
875 Self {
876 addr: None,
877 routes: Vec::new(),
878 fallback: None,
879 layers: Vec::new(),
880 on_start: None,
881 on_shutdown: None,
882 graceful_shutdown_timeout: Duration::from_secs(30),
883 bus_injectors: Vec::new(),
884 static_assets: StaticAssetsConfig::default(),
885 health: HealthConfig::default(),
886 #[cfg(feature = "tls")]
887 tls_config: None,
888 #[cfg(feature = "http3")]
889 http3_config: None,
890 #[cfg(feature = "http3")]
891 alt_svc_h3_port: None,
892 active_intervention: false,
893 policy_registry: None,
894 guard_execs: Vec::new(),
895 guard_response_extractors: Vec::new(),
896 guard_body_transforms: Vec::new(),
897 preflight_config: None,
898 _phantom: std::marker::PhantomData,
899 }
900 }
901
902 pub fn bind(mut self, addr: impl Into<String>) -> Self {
906 self.addr = Some(addr.into());
907 self
908 }
909
910 pub fn active_intervention(mut self) -> Self {
916 self.active_intervention = true;
917 self
918 }
919
920 pub fn policy_registry(mut self, registry: ranvier_core::policy::PolicyRegistry) -> Self {
922 self.policy_registry = Some(registry);
923 self
924 }
925
926 pub fn on_start<F>(mut self, callback: F) -> Self
930 where
931 F: Fn() + Send + Sync + 'static,
932 {
933 self.on_start = Some(Arc::new(callback));
934 self
935 }
936
937 pub fn on_shutdown<F>(mut self, callback: F) -> Self
939 where
940 F: Fn() + Send + Sync + 'static,
941 {
942 self.on_shutdown = Some(Arc::new(callback));
943 self
944 }
945
946 pub fn graceful_shutdown(mut self, timeout: Duration) -> Self {
948 self.graceful_shutdown_timeout = timeout;
949 self
950 }
951
952 pub fn config(mut self, config: &ranvier_core::config::RanvierConfig) -> Self {
958 self.addr = Some(config.bind_addr());
959 self.graceful_shutdown_timeout = config.shutdown_timeout();
960 config.init_telemetry();
961 self
962 }
963
964 #[cfg(feature = "tls")]
966 pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
967 self.tls_config = Some(TlsAcceptorConfig {
968 cert_path: cert_path.into(),
969 key_path: key_path.into(),
970 });
971 self
972 }
973
974 pub fn timeout_layer(mut self, timeout: Duration) -> Self {
979 self.layers.push(timeout_middleware(timeout));
980 self
981 }
982
983 pub fn request_id_layer(mut self) -> Self {
987 self.layers.push(request_id_middleware());
988 self
989 }
990
991 pub fn bus_injector<F>(mut self, injector: F) -> Self
996 where
997 F: Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static,
998 {
999 self.bus_injectors.push(Arc::new(injector));
1000 self
1001 }
1002
1003 #[cfg(feature = "htmx")]
1011 pub fn htmx_support(mut self) -> Self {
1012 self.bus_injectors
1013 .push(Arc::new(crate::htmx::inject_htmx_headers));
1014 self.guard_response_extractors
1015 .push(Arc::new(crate::htmx::extract_htmx_response_headers));
1016 self
1017 }
1018
1019 pub fn guard(mut self, guard: impl GuardIntegration) -> Self {
1039 let registration = guard.register();
1040 for injector in registration.bus_injectors {
1041 self.bus_injectors.push(injector);
1042 }
1043 self.guard_execs.push(registration.exec);
1044 if let Some(extractor) = registration.response_extractor {
1045 self.guard_response_extractors.push(extractor);
1046 }
1047 if let Some(transform) = registration.response_body_transform {
1048 self.guard_body_transforms.push(transform);
1049 }
1050 if registration.handles_preflight {
1051 if let Some(config) = registration.preflight_config {
1052 self.preflight_config = Some(config);
1053 }
1054 }
1055 self
1056 }
1057
1058 #[cfg(feature = "http3")]
1060 pub fn enable_http3(mut self, config: crate::http3::Http3Config) -> Self {
1061 self.http3_config = Some(config);
1062 self
1063 }
1064
1065 #[cfg(feature = "http3")]
1067 pub fn alt_svc_h3(mut self, port: u16) -> Self {
1068 self.alt_svc_h3_port = Some(port);
1069 self
1070 }
1071
1072 pub fn route_descriptors(&self) -> Vec<HttpRouteDescriptor> {
1076 let mut descriptors = self
1077 .routes
1078 .iter()
1079 .map(|entry| {
1080 let mut desc = HttpRouteDescriptor::new(entry.method.clone(), entry.pattern.raw.clone());
1081 desc.body_schema = entry.body_schema.clone();
1082 desc
1083 })
1084 .collect::<Vec<_>>();
1085
1086 if let Some(path) = &self.health.health_path {
1087 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1088 }
1089 if let Some(path) = &self.health.readiness_path {
1090 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1091 }
1092 if let Some(path) = &self.health.liveness_path {
1093 descriptors.push(HttpRouteDescriptor::new(Method::GET, path.clone()));
1094 }
1095
1096 descriptors
1097 }
1098
1099 pub fn serve_dir(
1105 mut self,
1106 route_prefix: impl Into<String>,
1107 directory: impl Into<String>,
1108 ) -> Self {
1109 self.static_assets.mounts.push(StaticMount {
1110 route_prefix: normalize_route_path(route_prefix.into()),
1111 directory: directory.into(),
1112 });
1113 if self.static_assets.cache_control.is_none() {
1114 self.static_assets.cache_control = Some("public, max-age=3600".to_string());
1115 }
1116 self
1117 }
1118
1119 pub fn spa_fallback(mut self, file_path: impl Into<String>) -> Self {
1123 self.static_assets.spa_fallback = Some(file_path.into());
1124 self
1125 }
1126
1127 pub fn static_cache_control(mut self, cache_control: impl Into<String>) -> Self {
1129 self.static_assets.cache_control = Some(cache_control.into());
1130 self
1131 }
1132
1133 pub fn directory_index(mut self, filename: impl Into<String>) -> Self {
1141 self.static_assets.directory_index = Some(filename.into());
1142 self
1143 }
1144
1145 pub fn immutable_cache(mut self) -> Self {
1150 self.static_assets.immutable_cache = true;
1151 self
1152 }
1153
1154 pub fn serve_precompressed(mut self) -> Self {
1160 self.static_assets.serve_precompressed = true;
1161 self
1162 }
1163
1164 pub fn enable_range_requests(mut self) -> Self {
1169 self.static_assets.enable_range_requests = true;
1170 self
1171 }
1172
1173 pub fn compression_layer(mut self) -> Self {
1175 self.static_assets.enable_compression = true;
1176 self
1177 }
1178
    /// Registers a WebSocket endpoint.
    ///
    /// Adds a `GET` route at `path` whose handler validates the upgrade
    /// request, answers with the handshake response and, once the connection
    /// is upgraded, runs `handler` on a spawned task with the connection, the
    /// shared resources and the per-request `Bus`.
    ///
    /// The bus injectors and guards are cloned when this method is called, so
    /// only those registered on the builder *before* `ws` apply to this route.
    /// If a guard rejects the request, a JSON error response is returned and
    /// no upgrade takes place.
    pub fn ws<H, Fut>(mut self, path: impl Into<String>, handler: H) -> Self
    where
        H: Fn(WebSocketConnection, Arc<R>, Bus) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let path_str: String = path.into();
        // Erase the handler's concrete future type so it can be stored and cloned.
        let ws_handler: WsSessionHandler<R> = Arc::new(move |connection, resources, bus| {
            Box::pin(handler(connection, resources, bus))
        });
        // Snapshot of the injectors/guards registered so far.
        let bus_injectors = Arc::new(self.bus_injectors.clone());
        let ws_guard_execs = Arc::new(self.guard_execs.clone());
        let path_for_pattern = path_str.clone();
        let path_for_handler = path_str;

        let route_handler: RouteHandler<R> =
            Arc::new(move |parts: http::request::Parts, res: &R| {
                // Per-request clones, moved into the async block below.
                let ws_handler = ws_handler.clone();
                let bus_injectors = bus_injectors.clone();
                let ws_guard_execs = ws_guard_execs.clone();
                let resources = Arc::new(res.clone());
                let path = path_for_handler.clone();

                Box::pin(async move {
                    // Id used for the tracing span only; generated fresh per upgrade.
                    let request_id = uuid::Uuid::new_v4().to_string();
                    let span = tracing::info_span!(
                        "WebSocketUpgrade",
                        ranvier.ws.path = %path,
                        ranvier.ws.request_id = %request_id
                    );

                    async move {
                        // Populate the bus, then let the guards decide.
                        let mut bus = Bus::new();
                        inject_query_params(&parts, &mut bus);
                        for injector in bus_injectors.iter() {
                            injector(&parts, &mut bus);
                        }
                        for guard_exec in ws_guard_execs.iter() {
                            if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
                                return json_error_response(rejection.status, &rejection.message);
                            }
                        }

                        // Rebuild a body-less request so the upgrade helpers can use it.
                        let mut req = Request::from_parts(parts, ());
                        let session = websocket_session_from_request(&req);
                        bus.insert(session.clone());

                        let (response, on_upgrade) = match websocket_upgrade_response(&mut req) {
                            Ok(result) => result,
                            Err(error_response) => return error_response,
                        };

                        // The session runs on its own task; the handshake
                        // response is returned without waiting for it.
                        tokio::spawn(async move {
                            match on_upgrade.await {
                                Ok(upgraded) => {
                                    let stream = WebSocketStream::from_raw_socket(
                                        TokioIo::new(upgraded),
                                        tokio_tungstenite::tungstenite::protocol::Role::Server,
                                        None,
                                    )
                                    .await;
                                    let connection = WebSocketConnection::new(stream, session);
                                    ws_handler(connection, resources, bus).await;
                                }
                                Err(error) => {
                                    tracing::warn!(
                                        ranvier.ws.path = %path,
                                        ranvier.ws.error = %error,
                                        "websocket upgrade failed"
                                    );
                                }
                            }
                        });

                        response
                    }
                    .instrument(span)
                    .await
                }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
            });

        self.routes.push(RouteEntry {
            method: Method::GET,
            pattern: RoutePattern::parse(&path_for_pattern),
            handler: route_handler,
            layers: Arc::new(Vec::new()),
            apply_global_layers: true,
            needs_body: false,
            body_schema: None,
        });

        self
    }
1280
1281 pub fn health_endpoint(mut self, path: impl Into<String>) -> Self {
1288 self.health.health_path = Some(normalize_route_path(path.into()));
1289 self
1290 }
1291
1292 pub fn health_check<F, Fut, Err>(mut self, name: impl Into<String>, check: F) -> Self
1296 where
1297 F: Fn(Arc<R>) -> Fut + Send + Sync + 'static,
1298 Fut: Future<Output = Result<(), Err>> + Send + 'static,
1299 Err: ToString + Send + 'static,
1300 {
1301 if self.health.health_path.is_none() {
1302 self.health.health_path = Some("/health".to_string());
1303 }
1304
1305 let check_fn: HealthCheckFn<R> = Arc::new(move |resources: Arc<R>| {
1306 let fut = check(resources);
1307 Box::pin(async move { fut.await.map_err(|error| error.to_string()) })
1308 });
1309
1310 self.health.checks.push(NamedHealthCheck {
1311 name: name.into(),
1312 check: check_fn,
1313 });
1314 self
1315 }
1316
1317 pub fn readiness_liveness(
1319 mut self,
1320 readiness_path: impl Into<String>,
1321 liveness_path: impl Into<String>,
1322 ) -> Self {
1323 self.health.readiness_path = Some(normalize_route_path(readiness_path.into()));
1324 self.health.liveness_path = Some(normalize_route_path(liveness_path.into()));
1325 self
1326 }
1327
1328 pub fn readiness_liveness_default(self) -> Self {
1330 self.readiness_liveness("/ready", "/live")
1331 }
1332
1333 pub fn route<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1337 where
1338 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1339 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1340 {
1341 self.route_method(Method::GET, path, circuit)
1342 }
1343 pub fn route_method<Out, E>(
1352 self,
1353 method: Method,
1354 path: impl Into<String>,
1355 circuit: Axon<(), Out, E, R>,
1356 ) -> Self
1357 where
1358 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1359 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1360 {
1361 self.route_method_with_error(method, path, circuit, |error| {
1362 (
1363 StatusCode::INTERNAL_SERVER_ERROR,
1364 format!("Error: {:?}", error),
1365 )
1366 .into_response()
1367 })
1368 }
1369
1370 pub fn route_method_with_error<Out, E, H>(
1371 self,
1372 method: Method,
1373 path: impl Into<String>,
1374 circuit: Axon<(), Out, E, R>,
1375 error_handler: H,
1376 ) -> Self
1377 where
1378 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1379 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1380 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1381 {
1382 self.route_method_with_error_and_layers(
1383 method,
1384 path,
1385 circuit,
1386 error_handler,
1387 Arc::new(Vec::new()),
1388 true,
1389 )
1390 }
1391
1392
1393
1394 fn route_method_with_error_and_layers<Out, E, H>(
1395 mut self,
1396 method: Method,
1397 path: impl Into<String>,
1398 circuit: Axon<(), Out, E, R>,
1399 error_handler: H,
1400 route_layers: Arc<Vec<ServiceLayer>>,
1401 apply_global_layers: bool,
1402 ) -> Self
1403 where
1404 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1405 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1406 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1407 {
1408 let path_str: String = path.into();
1409 let circuit = Arc::new(circuit);
1410 let error_handler = Arc::new(error_handler);
1411 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1412 let route_guard_execs = Arc::new(self.guard_execs.clone());
1413 let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
1414 let route_body_transforms = Arc::new(self.guard_body_transforms.clone());
1415 let path_for_pattern = path_str.clone();
1416 let path_for_handler = path_str;
1417 let method_for_pattern = method.clone();
1418 let method_for_handler = method;
1419
1420 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1421 let circuit = circuit.clone();
1422 let error_handler = error_handler.clone();
1423 let route_bus_injectors = route_bus_injectors.clone();
1424 let route_guard_execs = route_guard_execs.clone();
1425 let route_response_extractors = route_response_extractors.clone();
1426 let route_body_transforms = route_body_transforms.clone();
1427 let res = res.clone();
1428 let path = path_for_handler.clone();
1429 let method = method_for_handler.clone();
1430
1431 Box::pin(async move {
1432 let request_id = uuid::Uuid::new_v4().to_string();
1433 let span = tracing::info_span!(
1434 "HTTPRequest",
1435 ranvier.http.method = %method,
1436 ranvier.http.path = %path,
1437 ranvier.http.request_id = %request_id
1438 );
1439
1440 async move {
1441 let mut bus = Bus::new();
1442 inject_query_params(&parts, &mut bus);
1443 for injector in route_bus_injectors.iter() {
1444 injector(&parts, &mut bus);
1445 }
1446 for guard_exec in route_guard_execs.iter() {
1447 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1448 let mut response = json_error_response(rejection.status, &rejection.message);
1449 for extractor in route_response_extractors.iter() {
1450 extractor(&bus, response.headers_mut());
1451 }
1452 return response;
1453 }
1454 }
1455 if let Some(cached) = bus.read::<ranvier_guard::IdempotencyCachedResponse>() {
1457 let body = Bytes::from(cached.body.clone());
1458 let mut response = Response::builder()
1459 .status(StatusCode::OK)
1460 .header("content-type", "application/json")
1461 .body(Full::new(body).map_err(|n: Infallible| match n {}).boxed())
1462 .unwrap();
1463 for extractor in route_response_extractors.iter() {
1464 extractor(&bus, response.headers_mut());
1465 }
1466 return response;
1467 }
1468 let result = if let Some(td) = bus.read::<ranvier_guard::TimeoutDeadline>() {
1470 let remaining = td.remaining();
1471 if remaining.is_zero() {
1472 let mut response = json_error_response(
1473 StatusCode::REQUEST_TIMEOUT,
1474 "Request timeout: pipeline deadline exceeded",
1475 );
1476 for extractor in route_response_extractors.iter() {
1477 extractor(&bus, response.headers_mut());
1478 }
1479 return response;
1480 }
1481 match tokio::time::timeout(remaining, circuit.execute((), &res, &mut bus)).await {
1482 Ok(result) => result,
1483 Err(_) => {
1484 let mut response = json_error_response(
1485 StatusCode::REQUEST_TIMEOUT,
1486 "Request timeout: pipeline deadline exceeded",
1487 );
1488 for extractor in route_response_extractors.iter() {
1489 extractor(&bus, response.headers_mut());
1490 }
1491 return response;
1492 }
1493 }
1494 } else {
1495 circuit.execute((), &res, &mut bus).await
1496 };
1497 let mut response = outcome_to_response_with_error(result, |error| error_handler(error));
1498 for extractor in route_response_extractors.iter() {
1499 extractor(&bus, response.headers_mut());
1500 }
1501 if !route_body_transforms.is_empty() {
1502 response = apply_body_transforms(response, &bus, &route_body_transforms).await;
1503 }
1504 response
1505 }
1506 .instrument(span)
1507 .await
1508 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1509 });
1510
1511 self.routes.push(RouteEntry {
1512 method: method_for_pattern,
1513 pattern: RoutePattern::parse(&path_for_pattern),
1514 handler,
1515 layers: route_layers,
1516 apply_global_layers,
1517 needs_body: false,
1518 body_schema: None,
1519 });
1520 self
1521 }
1522
1523 fn route_method_typed<T, Out, E>(
1529 mut self,
1530 method: Method,
1531 path: impl Into<String>,
1532 circuit: Axon<T, Out, E, R>,
1533 ) -> Self
1534 where
1535 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1536 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1537 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1538 {
1539 let body_schema = serde_json::to_value(schemars::schema_for!(T)).ok();
1540 let path_str: String = path.into();
1541 let circuit = Arc::new(circuit);
1542 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1543 let route_guard_execs = Arc::new(self.guard_execs.clone());
1544 let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
1545 let route_body_transforms = Arc::new(self.guard_body_transforms.clone());
1546 let path_for_pattern = path_str.clone();
1547 let path_for_handler = path_str;
1548 let method_for_pattern = method.clone();
1549 let method_for_handler = method;
1550
1551 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1552 let circuit = circuit.clone();
1553 let route_bus_injectors = route_bus_injectors.clone();
1554 let route_guard_execs = route_guard_execs.clone();
1555 let route_response_extractors = route_response_extractors.clone();
1556 let route_body_transforms = route_body_transforms.clone();
1557 let res = res.clone();
1558 let path = path_for_handler.clone();
1559 let method = method_for_handler.clone();
1560
1561 Box::pin(async move {
1562 let request_id = uuid::Uuid::new_v4().to_string();
1563 let span = tracing::info_span!(
1564 "HTTPRequest",
1565 ranvier.http.method = %method,
1566 ranvier.http.path = %path,
1567 ranvier.http.request_id = %request_id
1568 );
1569
1570 async move {
1571 let body_bytes = parts
1573 .extensions
1574 .get::<BodyBytes>()
1575 .map(|b| b.0.clone())
1576 .unwrap_or_default();
1577
1578 let input: T = match serde_json::from_slice(&body_bytes) {
1580 Ok(v) => v,
1581 Err(e) => {
1582 return json_error_response(
1583 StatusCode::BAD_REQUEST,
1584 &format!("Invalid request body: {}", e),
1585 );
1586 }
1587 };
1588
1589 let mut bus = Bus::new();
1590 inject_query_params(&parts, &mut bus);
1591 for injector in route_bus_injectors.iter() {
1592 injector(&parts, &mut bus);
1593 }
1594 for guard_exec in route_guard_execs.iter() {
1595 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1596 let mut response = json_error_response(rejection.status, &rejection.message);
1597 for extractor in route_response_extractors.iter() {
1598 extractor(&bus, response.headers_mut());
1599 }
1600 return response;
1601 }
1602 }
1603 if let Some(cached) = bus.read::<ranvier_guard::IdempotencyCachedResponse>() {
1605 let body = Bytes::from(cached.body.clone());
1606 let mut response = Response::builder()
1607 .status(StatusCode::OK)
1608 .header("content-type", "application/json")
1609 .body(Full::new(body).map_err(|n: Infallible| match n {}).boxed())
1610 .unwrap();
1611 for extractor in route_response_extractors.iter() {
1612 extractor(&bus, response.headers_mut());
1613 }
1614 return response;
1615 }
1616 let result = if let Some(td) = bus.read::<ranvier_guard::TimeoutDeadline>() {
1618 let remaining = td.remaining();
1619 if remaining.is_zero() {
1620 let mut response = json_error_response(
1621 StatusCode::REQUEST_TIMEOUT,
1622 "Request timeout: pipeline deadline exceeded",
1623 );
1624 for extractor in route_response_extractors.iter() {
1625 extractor(&bus, response.headers_mut());
1626 }
1627 return response;
1628 }
1629 match tokio::time::timeout(remaining, circuit.execute(input, &res, &mut bus)).await {
1630 Ok(result) => result,
1631 Err(_) => {
1632 let mut response = json_error_response(
1633 StatusCode::REQUEST_TIMEOUT,
1634 "Request timeout: pipeline deadline exceeded",
1635 );
1636 for extractor in route_response_extractors.iter() {
1637 extractor(&bus, response.headers_mut());
1638 }
1639 return response;
1640 }
1641 }
1642 } else {
1643 circuit.execute(input, &res, &mut bus).await
1644 };
1645 let mut response = outcome_to_response_with_error(result, |error| {
1646 if cfg!(debug_assertions) {
1647 (
1648 StatusCode::INTERNAL_SERVER_ERROR,
1649 format!("Error: {:?}", error),
1650 )
1651 .into_response()
1652 } else {
1653 json_error_response(
1654 StatusCode::INTERNAL_SERVER_ERROR,
1655 "Internal server error",
1656 )
1657 }
1658 });
1659 for extractor in route_response_extractors.iter() {
1660 extractor(&bus, response.headers_mut());
1661 }
1662 if !route_body_transforms.is_empty() {
1663 response = apply_body_transforms(response, &bus, &route_body_transforms).await;
1664 }
1665 response
1666 }
1667 .instrument(span)
1668 .await
1669 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1670 });
1671
1672 self.routes.push(RouteEntry {
1673 method: method_for_pattern,
1674 pattern: RoutePattern::parse(&path_for_pattern),
1675 handler,
1676 layers: Arc::new(Vec::new()),
1677 apply_global_layers: true,
1678 needs_body: true,
1679 body_schema,
1680 });
1681 self
1682 }
1683
1684 pub fn get<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1685 where
1686 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1687 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1688 {
1689 self.route_method(Method::GET, path, circuit)
1690 }
1691
1692 pub fn get_with_error<Out, E, H>(
1693 self,
1694 path: impl Into<String>,
1695 circuit: Axon<(), Out, E, R>,
1696 error_handler: H,
1697 ) -> Self
1698 where
1699 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1700 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1701 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
1702 {
1703 self.route_method_with_error(Method::GET, path, circuit, error_handler)
1704 }
1705
1706 pub fn post<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1707 where
1708 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1709 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1710 {
1711 self.route_method(Method::POST, path, circuit)
1712 }
1713
1714 pub fn post_typed<T, Out, E>(
1730 self,
1731 path: impl Into<String>,
1732 circuit: Axon<T, Out, E, R>,
1733 ) -> Self
1734 where
1735 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1736 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1737 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1738 {
1739 self.route_method_typed::<T, Out, E>(Method::POST, path, circuit)
1740 }
1741
1742 pub fn put_typed<T, Out, E>(
1746 self,
1747 path: impl Into<String>,
1748 circuit: Axon<T, Out, E, R>,
1749 ) -> Self
1750 where
1751 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1752 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1753 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1754 {
1755 self.route_method_typed::<T, Out, E>(Method::PUT, path, circuit)
1756 }
1757
1758 pub fn patch_typed<T, Out, E>(
1762 self,
1763 path: impl Into<String>,
1764 circuit: Axon<T, Out, E, R>,
1765 ) -> Self
1766 where
1767 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
1768 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1769 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1770 {
1771 self.route_method_typed::<T, Out, E>(Method::PATCH, path, circuit)
1772 }
1773
1774 fn route_method_json<T, Out, E>(
1780 mut self,
1781 method: Method,
1782 path: impl Into<String>,
1783 circuit: Axon<T, Out, E, R>,
1784 ) -> Self
1785 where
1786 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
1787 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1788 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1789 {
1790 let path_str: String = path.into();
1791 let circuit = Arc::new(circuit);
1792 let route_bus_injectors = Arc::new(self.bus_injectors.clone());
1793 let route_guard_execs = Arc::new(self.guard_execs.clone());
1794 let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
1795 let route_body_transforms = Arc::new(self.guard_body_transforms.clone());
1796 let path_for_pattern = path_str.clone();
1797 let path_for_handler = path_str;
1798 let method_for_pattern = method.clone();
1799 let method_for_handler = method;
1800
1801 let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
1802 let circuit = circuit.clone();
1803 let route_bus_injectors = route_bus_injectors.clone();
1804 let route_guard_execs = route_guard_execs.clone();
1805 let route_response_extractors = route_response_extractors.clone();
1806 let route_body_transforms = route_body_transforms.clone();
1807 let res = res.clone();
1808 let path = path_for_handler.clone();
1809 let method = method_for_handler.clone();
1810
1811 Box::pin(async move {
1812 let request_id = uuid::Uuid::new_v4().to_string();
1813 let span = tracing::info_span!(
1814 "HTTPRequest",
1815 ranvier.http.method = %method,
1816 ranvier.http.path = %path,
1817 ranvier.http.request_id = %request_id
1818 );
1819
1820 async move {
1821 let body_bytes = parts
1822 .extensions
1823 .get::<BodyBytes>()
1824 .map(|b| b.0.clone())
1825 .unwrap_or_default();
1826
1827 let input: T = match serde_json::from_slice(&body_bytes) {
1828 Ok(v) => v,
1829 Err(e) => {
1830 return json_error_response(
1831 StatusCode::BAD_REQUEST,
1832 &format!("Invalid request body: {}", e),
1833 );
1834 }
1835 };
1836
1837 let mut bus = Bus::new();
1838 inject_query_params(&parts, &mut bus);
1839 for injector in route_bus_injectors.iter() {
1840 injector(&parts, &mut bus);
1841 }
1842 for guard_exec in route_guard_execs.iter() {
1843 if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
1844 let mut response = json_error_response(rejection.status, &rejection.message);
1845 for extractor in route_response_extractors.iter() {
1846 extractor(&bus, response.headers_mut());
1847 }
1848 return response;
1849 }
1850 }
1851 if let Some(cached) = bus.read::<ranvier_guard::IdempotencyCachedResponse>() {
1852 let body = Bytes::from(cached.body.clone());
1853 let mut response = Response::builder()
1854 .status(StatusCode::OK)
1855 .header("content-type", "application/json")
1856 .body(Full::new(body).map_err(|n: Infallible| match n {}).boxed())
1857 .unwrap();
1858 for extractor in route_response_extractors.iter() {
1859 extractor(&bus, response.headers_mut());
1860 }
1861 return response;
1862 }
1863 let result = if let Some(td) = bus.read::<ranvier_guard::TimeoutDeadline>() {
1864 let remaining = td.remaining();
1865 if remaining.is_zero() {
1866 let mut response = json_error_response(
1867 StatusCode::REQUEST_TIMEOUT,
1868 "Request timeout: pipeline deadline exceeded",
1869 );
1870 for extractor in route_response_extractors.iter() {
1871 extractor(&bus, response.headers_mut());
1872 }
1873 return response;
1874 }
1875 match tokio::time::timeout(remaining, circuit.execute(input, &res, &mut bus)).await {
1876 Ok(result) => result,
1877 Err(_) => {
1878 let mut response = json_error_response(
1879 StatusCode::REQUEST_TIMEOUT,
1880 "Request timeout: pipeline deadline exceeded",
1881 );
1882 for extractor in route_response_extractors.iter() {
1883 extractor(&bus, response.headers_mut());
1884 }
1885 return response;
1886 }
1887 }
1888 } else {
1889 circuit.execute(input, &res, &mut bus).await
1890 };
1891 let mut response = outcome_to_response_with_error(result, |error| {
1892 if cfg!(debug_assertions) {
1893 (
1894 StatusCode::INTERNAL_SERVER_ERROR,
1895 format!("Error: {:?}", error),
1896 )
1897 .into_response()
1898 } else {
1899 json_error_response(
1900 StatusCode::INTERNAL_SERVER_ERROR,
1901 "Internal server error",
1902 )
1903 }
1904 });
1905 for extractor in route_response_extractors.iter() {
1906 extractor(&bus, response.headers_mut());
1907 }
1908 if !route_body_transforms.is_empty() {
1909 response = apply_body_transforms(response, &bus, &route_body_transforms).await;
1910 }
1911 response
1912 }
1913 .instrument(span)
1914 .await
1915 }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
1916 });
1917
1918 self.routes.push(RouteEntry {
1919 method: method_for_pattern,
1920 pattern: RoutePattern::parse(&path_for_pattern),
1921 handler,
1922 layers: Arc::new(Vec::new()),
1923 apply_global_layers: true,
1924 needs_body: true,
1925 body_schema: None,
1926 });
1927 self
1928 }
1929
1930 pub fn post_json<T, Out, E>(
1943 self,
1944 path: impl Into<String>,
1945 circuit: Axon<T, Out, E, R>,
1946 ) -> Self
1947 where
1948 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
1949 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1950 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1951 {
1952 self.route_method_json::<T, Out, E>(Method::POST, path, circuit)
1953 }
1954
1955 pub fn put_json<T, Out, E>(
1959 self,
1960 path: impl Into<String>,
1961 circuit: Axon<T, Out, E, R>,
1962 ) -> Self
1963 where
1964 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
1965 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1966 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1967 {
1968 self.route_method_json::<T, Out, E>(Method::PUT, path, circuit)
1969 }
1970
1971 pub fn patch_json<T, Out, E>(
1975 self,
1976 path: impl Into<String>,
1977 circuit: Axon<T, Out, E, R>,
1978 ) -> Self
1979 where
1980 T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
1981 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1982 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1983 {
1984 self.route_method_json::<T, Out, E>(Method::PATCH, path, circuit)
1985 }
1986
1987 pub fn put<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1988 where
1989 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1990 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1991 {
1992 self.route_method(Method::PUT, path, circuit)
1993 }
1994
1995 pub fn delete<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
1996 where
1997 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
1998 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
1999 {
2000 self.route_method(Method::DELETE, path, circuit)
2001 }
2002
2003 pub fn patch<Out, E>(self, path: impl Into<String>, circuit: Axon<(), Out, E, R>) -> Self
2004 where
2005 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2006 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2007 {
2008 self.route_method(Method::PATCH, path, circuit)
2009 }
2010
2011 pub fn post_with_error<Out, E, H>(
2012 self,
2013 path: impl Into<String>,
2014 circuit: Axon<(), Out, E, R>,
2015 error_handler: H,
2016 ) -> Self
2017 where
2018 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2019 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2020 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
2021 {
2022 self.route_method_with_error(Method::POST, path, circuit, error_handler)
2023 }
2024
2025 pub fn put_with_error<Out, E, H>(
2026 self,
2027 path: impl Into<String>,
2028 circuit: Axon<(), Out, E, R>,
2029 error_handler: H,
2030 ) -> Self
2031 where
2032 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2033 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2034 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
2035 {
2036 self.route_method_with_error(Method::PUT, path, circuit, error_handler)
2037 }
2038
2039 pub fn delete_with_error<Out, E, H>(
2040 self,
2041 path: impl Into<String>,
2042 circuit: Axon<(), Out, E, R>,
2043 error_handler: H,
2044 ) -> Self
2045 where
2046 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2047 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2048 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
2049 {
2050 self.route_method_with_error(Method::DELETE, path, circuit, error_handler)
2051 }
2052
2053 pub fn patch_with_error<Out, E, H>(
2054 self,
2055 path: impl Into<String>,
2056 circuit: Axon<(), Out, E, R>,
2057 error_handler: H,
2058 ) -> Self
2059 where
2060 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2061 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2062 H: Fn(&E) -> HttpResponse + Send + Sync + 'static,
2063 {
2064 self.route_method_with_error(Method::PATCH, path, circuit, error_handler)
2065 }
2066
/// Registers a `POST` Server-Sent Events route driven by a streaming circuit
/// that takes no input.
///
/// The request body is not deserialized (`needs_body` is `false`). Delegates
/// to `route_sse_internal`. Only available with the `streaming` feature.
#[cfg(feature = "streaming")]
pub fn post_sse<Item, E>(
    self,
    path: impl Into<String>,
    circuit: ranvier_runtime::StreamingAxon<(), Item, E, R>,
) -> Self
where
    Item: serde::Serialize + Send + Sync + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    let needs_body = false;
    Self::route_sse_internal::<(), Item, E>(self, Method::POST, path, circuit, needs_body)
}
2100
/// Registers a `POST` Server-Sent Events route whose JSON request body is
/// deserialized into `T` and passed to the streaming circuit.
///
/// Delegates to `route_sse_internal` with `needs_body` set to `true`. Only
/// available with the `streaming` feature.
#[cfg(feature = "streaming")]
pub fn post_sse_typed<T, Item, E>(
    self,
    path: impl Into<String>,
    circuit: ranvier_runtime::StreamingAxon<T, Item, E, R>,
) -> Self
where
    T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + schemars::JsonSchema + 'static,
    Item: serde::Serialize + Send + Sync + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    let needs_body = true;
    Self::route_sse_internal::<T, Item, E>(self, Method::POST, path, circuit, needs_body)
}
2130
/// Registers a Server-Sent Events route backed by a streaming circuit.
///
/// When `needs_body` is `true` the circuit input `T` is deserialized from the
/// JSON request body (read from the `BodyBytes` extension; a malformed body is
/// answered with `400`). Otherwise the input is built by deserializing JSON
/// `null`, which fails with `500` for types that cannot be built that way.
///
/// Per request the handler runs the bus injectors and guards registered at the
/// time of this call, executes the circuit to obtain a stream, and forwards
/// each item as a `data: <json>\n\n` event through a bounded channel fed by a
/// spawned task. An item that fails to serialize produces one `event: error`
/// frame and ends the stream; `data: [DONE]\n\n` is sent last.
///
/// Guard response extractors are applied to guard rejections, to the `500`
/// returned when the circuit fails to start, and to the streaming response
/// itself, so guard-provided headers are present on SSE responses just as on
/// the non-streaming routes.
///
/// The channel capacity is `circuit.buffer_size`, clamped to at least 1
/// because `tokio::sync::mpsc::channel` panics on a capacity of 0.
#[cfg(feature = "streaming")]
fn route_sse_internal<T, Item, E>(
    mut self,
    method: Method,
    path: impl Into<String>,
    circuit: ranvier_runtime::StreamingAxon<T, Item, E, R>,
    needs_body: bool,
) -> Self
where
    T: serde::de::DeserializeOwned + Send + Sync + serde::Serialize + 'static,
    Item: serde::Serialize + Send + Sync + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    let path_str: String = path.into();
    let circuit = Arc::new(circuit);
    // Snapshot the guard configuration as it stands right now.
    let route_bus_injectors = Arc::new(self.bus_injectors.clone());
    let route_guard_execs = Arc::new(self.guard_execs.clone());
    let route_response_extractors = Arc::new(self.guard_response_extractors.clone());
    let path_for_pattern = path_str.clone();
    let path_for_handler = path_str;
    let method_for_pattern = method.clone();
    let method_for_handler = method;

    let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
        let circuit = circuit.clone();
        let route_bus_injectors = route_bus_injectors.clone();
        let route_guard_execs = route_guard_execs.clone();
        let route_response_extractors = route_response_extractors.clone();
        let res = res.clone();
        let path = path_for_handler.clone();
        let method = method_for_handler.clone();

        Box::pin(async move {
            let request_id = uuid::Uuid::new_v4().to_string();
            let span = tracing::info_span!(
                "SSERequest",
                ranvier.http.method = %method,
                ranvier.http.path = %path,
                ranvier.http.request_id = %request_id
            );

            async move {
                let input: T = if needs_body {
                    // Body bytes are expected in the request extensions; absent means empty.
                    let body_bytes = parts
                        .extensions
                        .get::<BodyBytes>()
                        .map(|b| b.0.clone())
                        .unwrap_or_default();

                    match serde_json::from_slice(&body_bytes) {
                        Ok(v) => v,
                        Err(e) => {
                            return json_error_response(
                                StatusCode::BAD_REQUEST,
                                &format!("Invalid request body: {}", e),
                            );
                        }
                    }
                } else {
                    // Body-less routes build their input from JSON `null` (e.g. `()`).
                    match serde_json::from_str("null") {
                        Ok(v) => v,
                        Err(_) => {
                            return json_error_response(
                                StatusCode::INTERNAL_SERVER_ERROR,
                                "Internal: failed to construct default input",
                            );
                        }
                    }
                };

                let mut bus = Bus::new();
                inject_query_params(&parts, &mut bus);
                for injector in route_bus_injectors.iter() {
                    injector(&parts, &mut bus);
                }
                // Guards run in registration order; the first rejection wins.
                for guard_exec in route_guard_execs.iter() {
                    if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
                        let mut response =
                            json_error_response(rejection.status, &rejection.message);
                        for extractor in route_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                }

                let stream = match circuit.execute(input, &res, &mut bus).await {
                    Ok(s) => s,
                    Err(e) => {
                        tracing::error!("Streaming pipeline error: {}", e);
                        // Only debug builds expose the error text to the client.
                        let mut response = if cfg!(debug_assertions) {
                            json_error_response(
                                StatusCode::INTERNAL_SERVER_ERROR,
                                &format!("Streaming error: {}", e),
                            )
                        } else {
                            json_error_response(
                                StatusCode::INTERNAL_SERVER_ERROR,
                                "Internal server error",
                            )
                        };
                        // Keep guard-provided headers on the error, as other routes do.
                        for extractor in route_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                };

                // A capacity of 0 would make `mpsc::channel` panic; use at least 1.
                let buffer_size = circuit.buffer_size.max(1);
                let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(buffer_size);

                // Producer: serialize stream items into SSE frames.
                tokio::spawn(async move {
                    let mut pinned = Box::pin(stream);
                    while let Some(item) = futures_util::StreamExt::next(&mut pinned).await {
                        let text = match serde_json::to_string(&item) {
                            Ok(json) => format!("data: {}\n\n", json),
                            Err(e) => {
                                tracing::error!("SSE item serialization error: {}", e);
                                let err_text = "event: error\ndata: {\"message\":\"serialization error\",\"code\":\"serialize_error\"}\n\n".to_string();
                                let _ = tx.send(Bytes::from(err_text)).await;
                                break;
                            }
                        };
                        // A failed send means the receiver (response body) is gone.
                        if tx.send(Bytes::from(text)).await.is_err() {
                            tracing::info!("SSE client disconnected");
                            break;
                        }
                    }
                    // Best-effort terminator; ignored if the client already left.
                    let _ = tx.send(Bytes::from("data: [DONE]\n\n")).await;
                });

                // Consumer: expose received frames as the response body.
                let frame_stream = async_stream::stream! {
                    while let Some(bytes) = rx.recv().await {
                        yield Ok::<http_body::Frame<Bytes>, std::convert::Infallible>(
                            http_body::Frame::data(bytes)
                        );
                    }
                };

                let body = http_body_util::StreamBody::new(frame_stream);
                let mut response = Response::builder()
                    .status(StatusCode::OK)
                    .header(http::header::CONTENT_TYPE, "text/event-stream")
                    .header(http::header::CACHE_CONTROL, "no-cache")
                    .header(http::header::CONNECTION, "keep-alive")
                    .body(http_body_util::BodyExt::boxed(body))
                    .expect("Valid SSE response");
                // Apply guard-provided headers to the stream response as well.
                for extractor in route_response_extractors.iter() {
                    extractor(&bus, response.headers_mut());
                }
                response
            }
            .instrument(span)
            .await
        }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
    });

    self.routes.push(RouteEntry {
        method: method_for_pattern,
        pattern: RoutePattern::parse(&path_for_pattern),
        handler,
        layers: Arc::new(Vec::new()),
        apply_global_layers: true,
        needs_body,
        body_schema: None,
    });
    self
}
2300
2301 fn route_method_with_extra_guards<Out, E>(
2307 mut self,
2308 method: Method,
2309 path: impl Into<String>,
2310 circuit: Axon<(), Out, E, R>,
2311 extra_guards: Vec<RegisteredGuard>,
2312 ) -> Self
2313 where
2314 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2315 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2316 {
2317 let saved_injectors = self.bus_injectors.len();
2319 let saved_execs = self.guard_execs.len();
2320 let saved_extractors = self.guard_response_extractors.len();
2321 let saved_transforms = self.guard_body_transforms.len();
2322
2323 for registration in extra_guards {
2325 for injector in registration.bus_injectors {
2326 self.bus_injectors.push(injector);
2327 }
2328 self.guard_execs.push(registration.exec);
2329 if let Some(extractor) = registration.response_extractor {
2330 self.guard_response_extractors.push(extractor);
2331 }
2332 if let Some(transform) = registration.response_body_transform {
2333 self.guard_body_transforms.push(transform);
2334 }
2335 }
2336
2337 self = self.route_method(method, path, circuit);
2339
2340 self.bus_injectors.truncate(saved_injectors);
2342 self.guard_execs.truncate(saved_execs);
2343 self.guard_response_extractors.truncate(saved_extractors);
2344 self.guard_body_transforms.truncate(saved_transforms);
2345
2346 self
2347 }
2348
2349 pub fn get_with_guards<Out, E>(
2367 self,
2368 path: impl Into<String>,
2369 circuit: Axon<(), Out, E, R>,
2370 extra_guards: Vec<RegisteredGuard>,
2371 ) -> Self
2372 where
2373 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2374 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2375 {
2376 self.route_method_with_extra_guards(Method::GET, path, circuit, extra_guards)
2377 }
2378
2379 pub fn post_with_guards<Out, E>(
2400 self,
2401 path: impl Into<String>,
2402 circuit: Axon<(), Out, E, R>,
2403 extra_guards: Vec<RegisteredGuard>,
2404 ) -> Self
2405 where
2406 Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
2407 E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
2408 {
2409 self.route_method_with_extra_guards(Method::POST, path, circuit, extra_guards)
2410 }
2411
/// Registers a `PUT` route at `path` that runs `circuit` with additional guards.
///
/// Thin wrapper that forwards to `route_method_with_extra_guards` with `Method::PUT`.
///
/// * `path` - route path to register (anything convertible into a `String`).
/// * `circuit` - the `Axon` executed for requests matching the route.
/// * `extra_guards` - guards registered for this route in addition to the guards
///   already present on the ingress.
///
/// Returns the ingress builder so calls can be chained.
pub fn put_with_guards<Out, E>(
    self,
    path: impl Into<String>,
    circuit: Axon<(), Out, E, R>,
    extra_guards: Vec<RegisteredGuard>,
) -> Self
where
    Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    self.route_method_with_extra_guards(Method::PUT, path, circuit, extra_guards)
}
2425
/// Registers a `DELETE` route at `path` that runs `circuit` with additional guards.
///
/// Thin wrapper that forwards to `route_method_with_extra_guards` with `Method::DELETE`.
///
/// * `path` - route path to register (anything convertible into a `String`).
/// * `circuit` - the `Axon` executed for requests matching the route.
/// * `extra_guards` - guards registered for this route in addition to the guards
///   already present on the ingress.
///
/// Returns the ingress builder so calls can be chained.
pub fn delete_with_guards<Out, E>(
    self,
    path: impl Into<String>,
    circuit: Axon<(), Out, E, R>,
    extra_guards: Vec<RegisteredGuard>,
) -> Self
where
    Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    self.route_method_with_extra_guards(Method::DELETE, path, circuit, extra_guards)
}
2439
/// Registers a `PATCH` route at `path` that runs `circuit` with additional guards.
///
/// Thin wrapper that forwards to `route_method_with_extra_guards` with `Method::PATCH`.
///
/// * `path` - route path to register (anything convertible into a `String`).
/// * `circuit` - the `Axon` executed for requests matching the route.
/// * `extra_guards` - guards registered for this route in addition to the guards
///   already present on the ingress.
///
/// Returns the ingress builder so calls can be chained.
pub fn patch_with_guards<Out, E>(
    self,
    path: impl Into<String>,
    circuit: Axon<(), Out, E, R>,
    extra_guards: Vec<RegisteredGuard>,
) -> Self
where
    Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    self.route_method_with_extra_guards(Method::PATCH, path, circuit, extra_guards)
}
2453
/// Installs `circuit` as the fallback handler stored in `self.fallback`.
///
/// The handler captures copies of the bus injectors, guard executors, guard response
/// extractors and body transforms that are registered on the ingress at the moment this
/// method is called; anything registered afterwards is not seen by the fallback.
///
/// Per request the handler:
/// 1. creates a fresh `Bus`, injects query parameters and runs every captured injector;
/// 2. runs every captured guard in order; the first rejection is turned into a JSON error
///    response (with the response extractors applied) and returned immediately;
/// 3. executes `circuit`; an `Outcome::Next` value is converted with `into_response`, any
///    other outcome becomes a plain `"Not Found"` body;
/// 4. applies the response extractors and, if any exist, the body transforms.
///
/// The status of a non-rejected fallback response is always `404 Not Found`, even when
/// the circuit produced an output.
///
/// Returns the ingress builder so calls can be chained.
pub fn fallback<Out, E>(mut self, circuit: Axon<(), Out, E, R>) -> Self
where
    Out: IntoResponse + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static,
    E: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + 'static,
{
    // Everything the handler needs is wrapped in `Arc` so each invocation only clones
    // reference-counted handles.
    let circuit = Arc::new(circuit);
    let fallback_bus_injectors = Arc::new(self.bus_injectors.clone());
    let fallback_guard_execs = Arc::new(self.guard_execs.clone());
    let fallback_response_extractors = Arc::new(self.guard_response_extractors.clone());
    let fallback_body_transforms = Arc::new(self.guard_body_transforms.clone());

    let handler: RouteHandler<R> = Arc::new(move |parts: http::request::Parts, res: &R| {
        // Clone the shared handles (and the resources) so the returned future owns them.
        let circuit = circuit.clone();
        let fallback_bus_injectors = fallback_bus_injectors.clone();
        let fallback_guard_execs = fallback_guard_execs.clone();
        let fallback_response_extractors = fallback_response_extractors.clone();
        let fallback_body_transforms = fallback_body_transforms.clone();
        let res = res.clone();
        Box::pin(async move {
            // Each fallback invocation gets its own request id and tracing span.
            let request_id = uuid::Uuid::new_v4().to_string();
            let span = tracing::info_span!(
                "HTTPRequest",
                ranvier.http.method = "FALLBACK",
                ranvier.http.request_id = %request_id
            );

            async move {
                let mut bus = Bus::new();
                inject_query_params(&parts, &mut bus);
                for injector in fallback_bus_injectors.iter() {
                    injector(&parts, &mut bus);
                }
                // Guards run in registration order; the first rejection short-circuits.
                for guard_exec in fallback_guard_execs.iter() {
                    if let Err(rejection) = guard_exec.exec_guard(&mut bus).await {
                        let mut response = json_error_response(rejection.status, &rejection.message);
                        // Extractors still run on rejections; body transforms do not.
                        for extractor in fallback_response_extractors.iter() {
                            extractor(&bus, response.headers_mut());
                        }
                        return response;
                    }
                }
                let result: ranvier_core::Outcome<Out, E> =
                    circuit.execute((), &res, &mut bus).await;

                let mut response = match result {
                    Outcome::Next(output) => {
                        // Keep the circuit's body/headers but force the 404 status.
                        let mut response = output.into_response();
                        *response.status_mut() = StatusCode::NOT_FOUND;
                        response
                    }
                    // Every other outcome yields the generic plain-text 404.
                    _ => Response::builder()
                        .status(StatusCode::NOT_FOUND)
                        .body(
                            Full::new(Bytes::from("Not Found"))
                                .map_err(|never| match never {})
                                .boxed(),
                        )
                        .expect("valid HTTP response construction"),
                };
                for extractor in fallback_response_extractors.iter() {
                    extractor(&bus, response.headers_mut());
                }
                if !fallback_body_transforms.is_empty() {
                    response = apply_body_transforms(response, &bus, &fallback_body_transforms).await;
                }
                response
            }
            .instrument(span)
            .await
        }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
    });

    self.fallback = Some(handler);
    self
}
2539
2540 pub async fn run(self, resources: R) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
2544 self.run_with_shutdown_signal(resources, shutdown_signal())
2545 .await
2546 }
2547
/// Binds the listener and serves connections until `shutdown_signal` completes.
///
/// Steps, in order:
/// 1. Parse the configured address, defaulting to `127.0.0.1:3000`.
/// 2. Append the built-in `/_system/...` routes when `active_intervention` is set or a
///    policy registry is configured.
/// 3. Bind the TCP listener, build the TLS acceptor when the `tls` feature is enabled and
///    a TLS config is present, and invoke the `on_start` callback.
/// 4. Accept connections in a loop, spawning one task per connection into a `JoinSet`.
/// 5. Once the signal fires, wait for in-flight connections via `drain_connections`,
///    drop this function's handle to the shared resources and invoke `on_shutdown`.
///
/// # Errors
/// Returns an error when the address cannot be parsed, the listener cannot be bound,
/// the TLS acceptor cannot be built, or `accept` fails.
async fn run_with_shutdown_signal<S>(
    self,
    resources: R,
    shutdown_signal: S,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    S: Future<Output = ()> + Send,
{
    let addr_str = self.addr.as_deref().unwrap_or("127.0.0.1:3000");
    let addr: SocketAddr = addr_str.parse()?;

    let mut raw_routes = self.routes;
    if self.active_intervention {
        // Built-in endpoint that always answers 200 with a fixed body.
        let handler: RouteHandler<R> = Arc::new(|_parts, _res| {
            Box::pin(async move {
                Response::builder()
                    .status(StatusCode::OK)
                    .body(
                        Full::new(Bytes::from("Intervention accepted"))
                            .map_err(|never| match never {} as Infallible)
                            .boxed(),
                    )
                    .expect("valid HTTP response construction")
            }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
        });

        raw_routes.push(RouteEntry {
            method: Method::POST,
            pattern: RoutePattern::parse("/_system/intervene/force_resume"),
            handler,
            layers: Arc::new(Vec::new()),
            apply_global_layers: true,
            needs_body: false,
            body_schema: None,
        });
    }

    if let Some(registry) = self.policy_registry.clone() {
        // The registry is captured but the handler currently only reports that it exists.
        let handler: RouteHandler<R> = Arc::new(move |_parts, _res| {
            let _registry = registry.clone();
            Box::pin(async move {
                Response::builder()
                    .status(StatusCode::OK)
                    .body(
                        Full::new(Bytes::from("Policy registry active"))
                            .map_err(|never| match never {} as Infallible)
                            .boxed(),
                    )
                    .expect("valid HTTP response construction")
            }) as Pin<Box<dyn Future<Output = HttpResponse> + Send>>
        });

        raw_routes.push(RouteEntry {
            method: Method::POST,
            pattern: RoutePattern::parse("/_system/policy/reload"),
            handler,
            layers: Arc::new(Vec::new()),
            apply_global_layers: true,
            needs_body: false,
            body_schema: None,
        });
    }
    // Move the remaining configuration into `Arc`s so every connection task can share it.
    let routes = Arc::new(raw_routes);
    let fallback = self.fallback;
    let layers = Arc::new(self.layers);
    let health = Arc::new(self.health);
    let static_assets = Arc::new(self.static_assets);
    let preflight_config = Arc::new(self.preflight_config);
    let on_start = self.on_start;
    let on_shutdown = self.on_shutdown;
    let graceful_shutdown_timeout = self.graceful_shutdown_timeout;
    let resources = Arc::new(resources);

    let listener = TcpListener::bind(addr).await?;

    #[cfg(feature = "tls")]
    let tls_acceptor = if let Some(ref tls_cfg) = self.tls_config {
        let acceptor = build_tls_acceptor(&tls_cfg.cert_path, &tls_cfg.key_path)?;
        tracing::info!("Ranvier HTTP Ingress listening on https://{}", addr);
        Some(acceptor)
    } else {
        tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);
        None
    };
    #[cfg(not(feature = "tls"))]
    tracing::info!("Ranvier HTTP Ingress listening on http://{}", addr);

    // The start callback runs after the socket is bound, before the first accept.
    if let Some(callback) = on_start.as_ref() {
        callback();
    }

    tokio::pin!(shutdown_signal);
    let mut connections = tokio::task::JoinSet::new();

    loop {
        tokio::select! {
            _ = &mut shutdown_signal => {
                tracing::info!("Shutdown signal received. Draining in-flight connections.");
                break;
            }
            accept_result = listener.accept() => {
                // NOTE(review): an accept error returns from this function right here,
                // skipping the drain and the `on_shutdown` callback below — confirm
                // that this is intended.
                let (stream, _) = accept_result?;

                let routes = routes.clone();
                let fallback = fallback.clone();
                let resources = resources.clone();
                let layers = layers.clone();
                let health = health.clone();
                let static_assets = static_assets.clone();
                let preflight_config = preflight_config.clone();
                #[cfg(feature = "http3")]
                let alt_svc_h3_port = self.alt_svc_h3_port;

                #[cfg(feature = "tls")]
                let tls_acceptor = tls_acceptor.clone();

                connections.spawn(async move {
                    // A service is built per connection from the shared configuration.
                    let service = build_http_service(
                        routes,
                        fallback,
                        resources,
                        layers,
                        health,
                        static_assets,
                        preflight_config,
                        #[cfg(feature = "http3")] alt_svc_h3_port,
                    );

                    // With TLS configured the handshake happens inside the task; the
                    // early `return` keeps the plain-TCP path below from running.
                    #[cfg(feature = "tls")]
                    if let Some(acceptor) = tls_acceptor {
                        match acceptor.accept(stream).await {
                            Ok(tls_stream) => {
                                let io = TokioIo::new(tls_stream);
                                if let Err(err) = http1::Builder::new()
                                    .serve_connection(io, service)
                                    .with_upgrades()
                                    .await
                                {
                                    tracing::error!("Error serving TLS connection: {:?}", err);
                                }
                            }
                            Err(err) => {
                                tracing::warn!("TLS handshake failed: {:?}", err);
                            }
                        }
                        return;
                    }

                    let io = TokioIo::new(stream);
                    if let Err(err) = http1::Builder::new()
                        .serve_connection(io, service)
                        .with_upgrades()
                        .await
                    {
                        tracing::error!("Error serving connection: {:?}", err);
                    }
                });
            }
            // Reap finished connection tasks so the set does not grow without bound.
            Some(join_result) = connections.join_next(), if !connections.is_empty() => {
                if let Err(err) = join_result {
                    tracing::warn!("Connection task join error: {:?}", err);
                }
            }
        }
    }

    // The drain result is currently ignored.
    let _timed_out = drain_connections(&mut connections, graceful_shutdown_timeout).await;

    drop(resources);
    if let Some(callback) = on_shutdown.as_ref() {
        callback();
    }

    Ok(())
}
2727
2728 pub fn into_raw_service(self, resources: R) -> RawIngressService<R> {
2744 let routes = Arc::new(self.routes);
2745 let fallback = self.fallback;
2746 let layers = Arc::new(self.layers);
2747 let health = Arc::new(self.health);
2748 let static_assets = Arc::new(self.static_assets);
2749 let preflight_config = Arc::new(self.preflight_config);
2750 let resources = Arc::new(resources);
2751
2752 RawIngressService {
2753 routes,
2754 fallback,
2755 layers,
2756 health,
2757 static_assets,
2758 preflight_config,
2759 resources,
2760 }
2761 }
2762}
2763
2764async fn apply_body_transforms(
2769 response: HttpResponse,
2770 bus: &Bus,
2771 transforms: &[ResponseBodyTransformFn],
2772) -> HttpResponse {
2773 use http_body_util::BodyExt;
2774
2775 let (parts, body) = response.into_parts();
2776
2777 let collected = match body.collect().await {
2779 Ok(c) => c.to_bytes(),
2780 Err(_) => {
2781 return Response::builder()
2783 .status(StatusCode::INTERNAL_SERVER_ERROR)
2784 .body(
2785 Full::new(Bytes::from("body collection failed"))
2786 .map_err(|never| match never {})
2787 .boxed(),
2788 )
2789 .expect("valid response");
2790 }
2791 };
2792
2793 let mut transformed = collected;
2794 for transform in transforms {
2795 transformed = transform(bus, transformed);
2796 }
2797
2798 Response::from_parts(
2799 parts,
2800 Full::new(transformed)
2801 .map_err(|never| match never {})
2802 .boxed(),
2803 )
2804}
2805
/// Builds the per-connection HTTP service that dispatches every incoming request.
///
/// Dispatch order for a request:
/// 1. health endpoints (`maybe_handle_health_request`);
/// 2. `OPTIONS` preflight handling when a `PreflightConfig` is present;
/// 3. the first matching entry in `routes`, wrapped in its effective layers;
/// 4. static assets / SPA fallback (`maybe_handle_static_request`);
/// 5. the `fallback` handler (wrapped in the global layers), or a plain `404 Not Found`.
///
/// With the `http3` feature enabled and `alt_svc_port` set, an `Alt-Svc` header is added
/// to responses produced by steps 3 and 5.
fn build_http_service<R>(
    routes: Arc<Vec<RouteEntry<R>>>,
    fallback: Option<RouteHandler<R>>,
    resources: Arc<R>,
    layers: Arc<Vec<ServiceLayer>>,
    health: Arc<HealthConfig<R>>,
    static_assets: Arc<StaticAssetsConfig>,
    preflight_config: Arc<Option<PreflightConfig>>,
    #[cfg(feature = "http3")] alt_svc_port: Option<u16>,
) -> BoxHttpService
where
    R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
{
    BoxService::new(move |req: Request<Incoming>| {
        // Per-request clones of the shared handles, moved into the future below.
        let routes = routes.clone();
        let fallback = fallback.clone();
        let resources = resources.clone();
        let layers = layers.clone();
        let health = health.clone();
        let static_assets = static_assets.clone();
        let preflight_config = preflight_config.clone();

        async move {
            let mut req = req;
            let method = req.method().clone();
            let path = req.uri().path().to_string();

            // Health endpoints are answered before anything else.
            if let Some(response) =
                maybe_handle_health_request(&method, &path, &health, resources.clone()).await
            {
                return Ok::<_, Infallible>(response.into_response());
            }

            if method == Method::OPTIONS {
                if let Some(ref config) = *preflight_config {
                    let origin = req
                        .headers()
                        .get("origin")
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("");
                    let is_wildcard = config.allowed_origins.iter().any(|o| o == "*");
                    let is_allowed = is_wildcard
                        || config.allowed_origins.iter().any(|o| o == origin);

                    // A disallowed origin falls through to normal routing below.
                    // NOTE(review): when no Origin header is present and no wildcard is
                    // configured, `allow_origin` is the empty string — confirm an empty
                    // Access-Control-Allow-Origin value is intended.
                    if is_allowed || origin.is_empty() {
                        let allow_origin = if is_wildcard {
                            "*".to_string()
                        } else {
                            origin.to_string()
                        };
                        let mut response = Response::builder()
                            .status(StatusCode::NO_CONTENT)
                            .body(
                                Full::new(Bytes::new())
                                    .map_err(|never| match never {})
                                    .boxed(),
                            )
                            .expect("valid preflight response");
                        let headers = response.headers_mut();
                        // Header values that fail to parse are silently skipped.
                        if let Ok(v) = allow_origin.parse() {
                            headers.insert("access-control-allow-origin", v);
                        }
                        if let Ok(v) = config.allowed_methods.parse() {
                            headers.insert("access-control-allow-methods", v);
                        }
                        if let Ok(v) = config.allowed_headers.parse() {
                            headers.insert("access-control-allow-headers", v);
                        }
                        if let Ok(v) = config.max_age.parse() {
                            headers.insert("access-control-max-age", v);
                        }
                        // NOTE(review): credentials may be combined with a "*" origin
                        // here — verify that combination is acceptable to clients.
                        if config.allow_credentials {
                            headers.insert(
                                "access-control-allow-credentials",
                                "true".parse().expect("valid header value"),
                            );
                        }
                        return Ok(response);
                    }
                }
            }

            if let Some((entry, params)) = find_matching_route(routes.as_slice(), &method, &path) {
                // Matched path parameters travel with the request as an extension.
                req.extensions_mut().insert(params);
                let effective_layers = if entry.apply_global_layers {
                    merge_layers(&layers, &entry.layers)
                } else {
                    entry.layers.clone()
                };

                if effective_layers.is_empty() {
                    // No layers: call the handler directly without building a service.
                    let (mut parts, body) = req.into_parts();
                    if entry.needs_body {
                        match BodyExt::collect(body).await {
                            Ok(collected) => { parts.extensions.insert(BodyBytes(collected.to_bytes())); }
                            Err(_) => {
                                return Ok(json_error_response(
                                    StatusCode::BAD_REQUEST,
                                    "Failed to read request body",
                                ));
                            }
                        }
                    }
                    #[allow(unused_mut)]
                    let mut res = (entry.handler)(parts, &resources).await;
                    #[cfg(feature = "http3")]
                    if let Some(port) = alt_svc_port {
                        if let Ok(val) =
                            http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
                        {
                            res.headers_mut().insert(http::header::ALT_SVC, val);
                        }
                    }
                    Ok::<_, Infallible>(res)
                } else {
                    // Layers present: wrap the handler and let the stack process the request.
                    let route_service = build_route_service(
                        entry.handler.clone(),
                        resources.clone(),
                        effective_layers,
                        entry.needs_body,
                    );
                    #[allow(unused_mut)]
                    let mut res = route_service.call(req).await;
                    #[cfg(feature = "http3")]
                    #[allow(irrefutable_let_patterns)]
                    if let Ok(ref mut r) = res {
                        if let Some(port) = alt_svc_port {
                            if let Ok(val) =
                                http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
                            {
                                r.headers_mut().insert(http::header::ALT_SVC, val);
                            }
                        }
                    }
                    res
                }
            } else {
                // No route matched: `Err` from the static handler is a finished response,
                // `Ok` hands the untouched request back for the fallback.
                let req =
                    match maybe_handle_static_request(req, &method, &path, static_assets.as_ref())
                        .await
                    {
                        Ok(req) => req,
                        Err(response) => return Ok(response),
                    };

                #[allow(unused_mut)]
                let mut fallback_res = if let Some(ref fb) = fallback {
                    if layers.is_empty() {
                        // The request body is dropped; the fallback only sees the head.
                        let (parts, _) = req.into_parts();
                        Ok(fb(parts, &resources).await)
                    } else {
                        let fallback_service =
                            build_route_service(fb.clone(), resources.clone(), layers.clone(), false);
                        fallback_service.call(req).await
                    }
                } else {
                    Ok(Response::builder()
                        .status(StatusCode::NOT_FOUND)
                        .body(
                            Full::new(Bytes::from("Not Found"))
                                .map_err(|never| match never {})
                                .boxed(),
                        )
                        .expect("valid HTTP response construction"))
                };

                #[cfg(feature = "http3")]
                if let Ok(r) = fallback_res.as_mut() {
                    if let Some(port) = alt_svc_port {
                        if let Ok(val) =
                            http::HeaderValue::from_str(&format!("h3=\":{}\"; ma=86400", port))
                        {
                            r.headers_mut().insert(http::header::ALT_SVC, val);
                        }
                    }
                }

                fallback_res
            }
        }
    })
}
2989
2990fn build_route_service<R>(
2991 handler: RouteHandler<R>,
2992 resources: Arc<R>,
2993 layers: Arc<Vec<ServiceLayer>>,
2994 needs_body: bool,
2995) -> BoxHttpService
2996where
2997 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
2998{
2999 let mut service = BoxService::new(move |req: Request<Incoming>| {
3000 let handler = handler.clone();
3001 let resources = resources.clone();
3002 async move {
3003 let (mut parts, body) = req.into_parts();
3004 if needs_body {
3005 match BodyExt::collect(body).await {
3006 Ok(collected) => { parts.extensions.insert(BodyBytes(collected.to_bytes())); }
3007 Err(_) => {
3008 return Ok(json_error_response(
3009 StatusCode::BAD_REQUEST,
3010 "Failed to read request body",
3011 ));
3012 }
3013 }
3014 }
3015 Ok::<_, Infallible>(handler(parts, &resources).await)
3016 }
3017 });
3018
3019 for layer in layers.iter() {
3020 service = layer(service);
3021 }
3022 service
3023}
3024
3025fn merge_layers(
3026 global_layers: &Arc<Vec<ServiceLayer>>,
3027 route_layers: &Arc<Vec<ServiceLayer>>,
3028) -> Arc<Vec<ServiceLayer>> {
3029 if global_layers.is_empty() {
3030 return route_layers.clone();
3031 }
3032 if route_layers.is_empty() {
3033 return global_layers.clone();
3034 }
3035
3036 let mut combined = Vec::with_capacity(global_layers.len() + route_layers.len());
3037 combined.extend(global_layers.iter().cloned());
3038 combined.extend(route_layers.iter().cloned());
3039 Arc::new(combined)
3040}
3041
3042async fn maybe_handle_health_request<R>(
3043 method: &Method,
3044 path: &str,
3045 health: &HealthConfig<R>,
3046 resources: Arc<R>,
3047) -> Option<HttpResponse>
3048where
3049 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3050{
3051 if method != Method::GET {
3052 return None;
3053 }
3054
3055 if let Some(liveness_path) = health.liveness_path.as_ref() {
3056 if path == liveness_path {
3057 return Some(health_json_response("liveness", true, Vec::new()));
3058 }
3059 }
3060
3061 if let Some(readiness_path) = health.readiness_path.as_ref() {
3062 if path == readiness_path {
3063 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
3064 return Some(health_json_response("readiness", healthy, checks));
3065 }
3066 }
3067
3068 if let Some(health_path) = health.health_path.as_ref() {
3069 if path == health_path {
3070 let (healthy, checks) = run_named_health_checks(&health.checks, resources).await;
3071 return Some(health_json_response("health", healthy, checks));
3072 }
3073 }
3074
3075 None
3076}
3077
3078async fn serve_single_file(file_path: &str) -> Result<Response<Full<Bytes>>, std::io::Error> {
3080 let path = std::path::Path::new(file_path);
3081 let content = tokio::fs::read(path).await?;
3082 let mime = guess_mime(file_path);
3083 let mut response = Response::new(Full::new(Bytes::from(content)));
3084 if let Ok(value) = http::HeaderValue::from_str(mime) {
3085 response
3086 .headers_mut()
3087 .insert(http::header::CONTENT_TYPE, value);
3088 }
3089 if let Ok(metadata) = tokio::fs::metadata(path).await {
3090 if let Ok(modified) = metadata.modified() {
3091 if let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH) {
3092 let etag = format!("\"{}\"", duration.as_secs());
3093 if let Ok(value) = http::HeaderValue::from_str(&etag) {
3094 response.headers_mut().insert(http::header::ETAG, value);
3095 }
3096 }
3097 }
3098 }
3099 Ok(response)
3100}
3101
3102async fn serve_static_file(
3104 directory: &str,
3105 file_subpath: &str,
3106 config: &StaticAssetsConfig,
3107 if_none_match: Option<&http::HeaderValue>,
3108 accept_encoding: Option<&http::HeaderValue>,
3109 range_header: Option<&http::HeaderValue>,
3110) -> Result<Response<Full<Bytes>>, std::io::Error> {
3111 let subpath = file_subpath.trim_start_matches('/');
3112
3113 let resolved_subpath;
3115 if subpath.is_empty() || subpath.ends_with('/') {
3116 if let Some(ref index) = config.directory_index {
3117 resolved_subpath = if subpath.is_empty() {
3118 index.clone()
3119 } else {
3120 format!("{}{}", subpath, index)
3121 };
3122 } else {
3123 return Err(std::io::Error::new(
3124 std::io::ErrorKind::NotFound,
3125 "empty path",
3126 ));
3127 }
3128 } else {
3129 resolved_subpath = subpath.to_string();
3130 }
3131
3132 let full_path = std::path::Path::new(directory).join(&resolved_subpath);
3133 let canonical = tokio::fs::canonicalize(&full_path).await?;
3135 let dir_canonical = tokio::fs::canonicalize(directory).await?;
3136 if !canonical.starts_with(&dir_canonical) {
3137 return Err(std::io::Error::new(
3138 std::io::ErrorKind::PermissionDenied,
3139 "path traversal detected",
3140 ));
3141 }
3142
3143 let etag = if let Ok(metadata) = tokio::fs::metadata(&canonical).await {
3145 metadata
3146 .modified()
3147 .ok()
3148 .and_then(|m| m.duration_since(std::time::UNIX_EPOCH).ok())
3149 .map(|d| format!("\"{}\"", d.as_secs()))
3150 } else {
3151 None
3152 };
3153
3154 if let (Some(client_etag), Some(server_etag)) = (if_none_match, &etag) {
3156 if client_etag.as_bytes() == server_etag.as_bytes() {
3157 let mut response = Response::new(Full::new(Bytes::new()));
3158 *response.status_mut() = StatusCode::NOT_MODIFIED;
3159 if let Ok(value) = http::HeaderValue::from_str(server_etag) {
3160 response.headers_mut().insert(http::header::ETAG, value);
3161 }
3162 return Ok(response);
3163 }
3164 }
3165
3166 let (serve_path, content_encoding) = if config.serve_precompressed {
3168 let client_accepts = accept_encoding
3169 .and_then(|v| v.to_str().ok())
3170 .unwrap_or("");
3171 let canonical_str = canonical.to_str().unwrap_or("");
3172
3173 if client_accepts.contains("br") {
3174 let br_path = format!("{}.br", canonical_str);
3175 if tokio::fs::metadata(&br_path).await.is_ok() {
3176 (std::path::PathBuf::from(br_path), Some("br"))
3177 } else if client_accepts.contains("gzip") {
3178 let gz_path = format!("{}.gz", canonical_str);
3179 if tokio::fs::metadata(&gz_path).await.is_ok() {
3180 (std::path::PathBuf::from(gz_path), Some("gzip"))
3181 } else {
3182 (canonical.clone(), None)
3183 }
3184 } else {
3185 (canonical.clone(), None)
3186 }
3187 } else if client_accepts.contains("gzip") {
3188 let gz_path = format!("{}.gz", canonical_str);
3189 if tokio::fs::metadata(&gz_path).await.is_ok() {
3190 (std::path::PathBuf::from(gz_path), Some("gzip"))
3191 } else {
3192 (canonical.clone(), None)
3193 }
3194 } else {
3195 (canonical.clone(), None)
3196 }
3197 } else {
3198 (canonical.clone(), None)
3199 };
3200
3201 let content = tokio::fs::read(&serve_path).await?;
3202 let mime = guess_mime(canonical.to_str().unwrap_or(""));
3204
3205 if config.enable_range_requests {
3207 if let Some(range_val) = range_header {
3208 if let Some(response) = handle_range_request(
3209 range_val,
3210 &content,
3211 mime,
3212 etag.as_deref(),
3213 content_encoding,
3214 ) {
3215 return Ok(response);
3216 }
3217 }
3218 }
3219
3220 let mut response = Response::new(Full::new(Bytes::from(content)));
3221 if let Ok(value) = http::HeaderValue::from_str(mime) {
3222 response
3223 .headers_mut()
3224 .insert(http::header::CONTENT_TYPE, value);
3225 }
3226 if let Some(ref etag_val) = etag {
3227 if let Ok(value) = http::HeaderValue::from_str(etag_val) {
3228 response.headers_mut().insert(http::header::ETAG, value);
3229 }
3230 }
3231 if let Some(encoding) = content_encoding {
3232 if let Ok(value) = http::HeaderValue::from_str(encoding) {
3233 response
3234 .headers_mut()
3235 .insert(http::header::CONTENT_ENCODING, value);
3236 }
3237 }
3238 if config.enable_range_requests {
3239 response
3240 .headers_mut()
3241 .insert(http::header::ACCEPT_RANGES, http::HeaderValue::from_static("bytes"));
3242 }
3243
3244 if config.immutable_cache {
3246 if let Some(filename) = canonical.file_name().and_then(|n| n.to_str()) {
3247 if is_hashed_filename(filename) {
3248 if let Ok(value) = http::HeaderValue::from_str(
3249 "public, max-age=31536000, immutable",
3250 ) {
3251 response
3252 .headers_mut()
3253 .insert(http::header::CACHE_CONTROL, value);
3254 }
3255 }
3256 }
3257 }
3258
3259 Ok(response)
3260}
3261
3262fn handle_range_request(
3266 range_header: &http::HeaderValue,
3267 content: &[u8],
3268 mime: &str,
3269 etag: Option<&str>,
3270 content_encoding: Option<&str>,
3271) -> Option<Response<Full<Bytes>>> {
3272 let range_str = range_header.to_str().ok()?;
3273 let range_spec = range_str.strip_prefix("bytes=")?;
3274 let total = content.len();
3275 if total == 0 {
3276 return None;
3277 }
3278
3279 let (start, end) = if let Some(suffix) = range_spec.strip_prefix('-') {
3280 let n: usize = suffix.parse().ok()?;
3282 if n == 0 || n > total {
3283 return Some(range_not_satisfiable(total));
3284 }
3285 (total - n, total - 1)
3286 } else if range_spec.ends_with('-') {
3287 let start: usize = range_spec.trim_end_matches('-').parse().ok()?;
3289 if start >= total {
3290 return Some(range_not_satisfiable(total));
3291 }
3292 (start, total - 1)
3293 } else {
3294 let mut parts = range_spec.splitn(2, '-');
3296 let start: usize = parts.next()?.parse().ok()?;
3297 let end: usize = parts.next()?.parse().ok()?;
3298 if start > end || start >= total {
3299 return Some(range_not_satisfiable(total));
3300 }
3301 (start, end.min(total - 1))
3302 };
3303
3304 let slice = &content[start..=end];
3305 let content_range = format!("bytes {}-{}/{}", start, end, total);
3306
3307 let mut response = Response::new(Full::new(Bytes::copy_from_slice(slice)));
3308 *response.status_mut() = StatusCode::PARTIAL_CONTENT;
3309 if let Ok(v) = http::HeaderValue::from_str(&content_range) {
3310 response.headers_mut().insert(http::header::CONTENT_RANGE, v);
3311 }
3312 if let Ok(v) = http::HeaderValue::from_str(mime) {
3313 response
3314 .headers_mut()
3315 .insert(http::header::CONTENT_TYPE, v);
3316 }
3317 response
3318 .headers_mut()
3319 .insert(http::header::ACCEPT_RANGES, http::HeaderValue::from_static("bytes"));
3320 if let Some(etag_val) = etag {
3321 if let Ok(v) = http::HeaderValue::from_str(etag_val) {
3322 response.headers_mut().insert(http::header::ETAG, v);
3323 }
3324 }
3325 if let Some(encoding) = content_encoding {
3326 if let Ok(v) = http::HeaderValue::from_str(encoding) {
3327 response
3328 .headers_mut()
3329 .insert(http::header::CONTENT_ENCODING, v);
3330 }
3331 }
3332 Some(response)
3333}
3334
3335fn range_not_satisfiable(total: usize) -> Response<Full<Bytes>> {
3337 let content_range = format!("bytes */{}", total);
3338 let mut response = Response::new(Full::new(Bytes::from("Range Not Satisfiable")));
3339 *response.status_mut() = StatusCode::RANGE_NOT_SATISFIABLE;
3340 if let Ok(v) = http::HeaderValue::from_str(&content_range) {
3341 response.headers_mut().insert(http::header::CONTENT_RANGE, v);
3342 }
3343 response
3344}
3345
/// Reports whether `filename` looks like `<stem>.<hash>.<ext>`, where `<hash>` is the
/// second-to-last dot-separated segment, at least six characters long and made up
/// solely of ASCII hexadecimal digits.
///
/// Names with fewer than two dots never qualify.
fn is_hashed_filename(filename: &str) -> bool {
    // Walk the segments from the right: extension, hash candidate, then at least one
    // more segment must exist for the name to have the three-part shape.
    let mut segments = filename.rsplit('.');
    match (segments.next(), segments.next(), segments.next()) {
        (Some(_extension), Some(candidate), Some(_stem)) => {
            candidate.len() >= 6 && candidate.bytes().all(|byte| byte.is_ascii_hexdigit())
        }
        _ => false,
    }
}
3357
/// Maps the extension of `path` to a MIME type, falling back to
/// `application/octet-stream` for unknown extensions.
///
/// The extension is the text after the last `.` and is matched case-insensitively, so
/// `PHOTO.JPG` and `photo.jpg` resolve to the same type. A path without any `.` has no
/// extension and yields the fallback type.
fn guess_mime(path: &str) -> &'static str {
    const FALLBACK: &str = "application/octet-stream";

    let extension = match path.rsplit_once('.') {
        Some((_, extension)) => extension.to_ascii_lowercase(),
        None => return FALLBACK,
    };

    match extension.as_str() {
        "html" | "htm" => "text/html; charset=utf-8",
        "css" => "text/css; charset=utf-8",
        "js" | "mjs" | "ts" | "tsx" => "application/javascript; charset=utf-8",
        "json" => "application/json; charset=utf-8",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        "ico" => "image/x-icon",
        "avif" => "image/avif",
        "webp" => "image/webp",
        "webm" => "video/webm",
        "mp4" => "video/mp4",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        "ttf" => "font/ttf",
        "txt" => "text/plain; charset=utf-8",
        "xml" => "application/xml; charset=utf-8",
        "yaml" | "yml" => "application/yaml",
        "wasm" => "application/wasm",
        "pdf" => "application/pdf",
        "map" => "application/json",
        _ => FALLBACK,
    }
}
3385
3386fn apply_cache_control(
3387 mut response: Response<Full<Bytes>>,
3388 cache_control: Option<&str>,
3389) -> Response<Full<Bytes>> {
3390 if response.status() == StatusCode::OK {
3391 if let Some(value) = cache_control {
3392 if !response.headers().contains_key(http::header::CACHE_CONTROL) {
3393 if let Ok(header_value) = http::HeaderValue::from_str(value) {
3394 response
3395 .headers_mut()
3396 .insert(http::header::CACHE_CONTROL, header_value);
3397 }
3398 }
3399 }
3400 }
3401 response
3402}
3403
3404async fn maybe_handle_static_request(
3405 req: Request<Incoming>,
3406 method: &Method,
3407 path: &str,
3408 static_assets: &StaticAssetsConfig,
3409) -> Result<Request<Incoming>, HttpResponse> {
3410 if method != Method::GET && method != Method::HEAD {
3411 return Ok(req);
3412 }
3413
3414 if let Some(mount) = static_assets
3415 .mounts
3416 .iter()
3417 .find(|mount| strip_mount_prefix(path, &mount.route_prefix).is_some())
3418 {
3419 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
3420 let if_none_match = req.headers().get(http::header::IF_NONE_MATCH).cloned();
3421 let range_header = req.headers().get(http::header::RANGE).cloned();
3422 let Some(stripped_path) = strip_mount_prefix(path, &mount.route_prefix) else {
3423 return Ok(req);
3424 };
3425 let response = match serve_static_file(
3426 &mount.directory,
3427 &stripped_path,
3428 static_assets,
3429 if_none_match.as_ref(),
3430 accept_encoding.as_ref(),
3431 range_header.as_ref(),
3432 )
3433 .await
3434 {
3435 Ok(response) => response,
3436 Err(_) => {
3437 return Err(Response::builder()
3438 .status(StatusCode::INTERNAL_SERVER_ERROR)
3439 .body(
3440 Full::new(Bytes::from("Failed to serve static asset"))
3441 .map_err(|never| match never {})
3442 .boxed(),
3443 )
3444 .unwrap_or_else(|_| {
3445 Response::new(
3446 Full::new(Bytes::new())
3447 .map_err(|never| match never {})
3448 .boxed(),
3449 )
3450 }));
3451 }
3452 };
3453 let mut response = apply_cache_control(response, static_assets.cache_control.as_deref());
3454 response = maybe_compress_static_response(
3455 response,
3456 accept_encoding,
3457 static_assets.enable_compression,
3458 );
3459 let (parts, body) = response.into_parts();
3460 return Err(Response::from_parts(
3461 parts,
3462 body.map_err(|never| match never {}).boxed(),
3463 ));
3464 }
3465
3466 if let Some(spa_file) = static_assets.spa_fallback.as_ref() {
3467 if looks_like_spa_request(path) {
3468 let accept_encoding = req.headers().get(http::header::ACCEPT_ENCODING).cloned();
3469 let response = match serve_single_file(spa_file).await {
3470 Ok(response) => response,
3471 Err(_) => {
3472 return Err(Response::builder()
3473 .status(StatusCode::INTERNAL_SERVER_ERROR)
3474 .body(
3475 Full::new(Bytes::from("Failed to serve SPA fallback"))
3476 .map_err(|never| match never {})
3477 .boxed(),
3478 )
3479 .unwrap_or_else(|_| {
3480 Response::new(
3481 Full::new(Bytes::new())
3482 .map_err(|never| match never {})
3483 .boxed(),
3484 )
3485 }));
3486 }
3487 };
3488 let mut response =
3489 apply_cache_control(response, static_assets.cache_control.as_deref());
3490 response = maybe_compress_static_response(
3491 response,
3492 accept_encoding,
3493 static_assets.enable_compression,
3494 );
3495 let (parts, body) = response.into_parts();
3496 return Err(Response::from_parts(
3497 parts,
3498 body.map_err(|never| match never {}).boxed(),
3499 ));
3500 }
3501 }
3502
3503 Ok(req)
3504}
3505
/// Removes a mount `prefix` from `path`, returning the remainder with a leading `/`.
///
/// * A prefix of exactly `/` matches everything and returns `path` unchanged.
/// * Trailing slashes on any other prefix are ignored.
/// * A path equal to the prefix maps to `/`.
/// * The prefix only matches on a segment boundary (`/static` does not match
///   `/staticx/...`); non-matching paths yield `None`.
fn strip_mount_prefix(path: &str, prefix: &str) -> Option<String> {
    let mount = match prefix {
        "/" => "/",
        other => other.trim_end_matches('/'),
    };

    match mount {
        "/" => Some(path.to_string()),
        exact if exact == path => Some("/".to_string()),
        _ => {
            let remainder = path.strip_prefix(mount)?.strip_prefix('/')?;
            Some(format!("/{}", remainder))
        }
    }
}
3525
/// Heuristic for client-side routes: a path qualifies when its final segment (the text
/// after the last `/`) contains no `.`, i.e. it does not look like a file name.
fn looks_like_spa_request(path: &str) -> bool {
    let last_segment = match path.rfind('/') {
        Some(slash) => &path[slash + 1..],
        None => path,
    };
    last_segment.find('.').is_none()
}
3530
3531fn maybe_compress_static_response(
3532 response: Response<Full<Bytes>>,
3533 accept_encoding: Option<http::HeaderValue>,
3534 enable_compression: bool,
3535) -> Response<Full<Bytes>> {
3536 if !enable_compression {
3537 return response;
3538 }
3539
3540 let Some(accept_encoding) = accept_encoding else {
3541 return response;
3542 };
3543
3544 let accept_str = accept_encoding.to_str().unwrap_or("");
3545 if !accept_str.contains("gzip") {
3546 return response;
3547 }
3548
3549 let status = response.status();
3550 let headers = response.headers().clone();
3551 let body = response.into_body();
3552
3553 let data = futures_util::FutureExt::now_or_never(BodyExt::collect(body))
3555 .and_then(|r| r.ok())
3556 .map(|collected| collected.to_bytes())
3557 .unwrap_or_default();
3558
3559 let compressed = {
3561 use flate2::write::GzEncoder;
3562 use flate2::Compression;
3563 use std::io::Write;
3564 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
3565 let _ = encoder.write_all(&data);
3566 encoder.finish().unwrap_or_default()
3567 };
3568
3569 let mut builder = Response::builder().status(status);
3570 for (name, value) in headers.iter() {
3571 if name != http::header::CONTENT_LENGTH && name != http::header::CONTENT_ENCODING {
3572 builder = builder.header(name, value);
3573 }
3574 }
3575 builder
3576 .header(http::header::CONTENT_ENCODING, "gzip")
3577 .body(Full::new(Bytes::from(compressed)))
3578 .unwrap_or_else(|_| Response::new(Full::new(Bytes::new())))
3579}
3580
3581async fn run_named_health_checks<R>(
3582 checks: &[NamedHealthCheck<R>],
3583 resources: Arc<R>,
3584) -> (bool, Vec<HealthCheckReport>)
3585where
3586 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3587{
3588 let mut reports = Vec::with_capacity(checks.len());
3589 let mut healthy = true;
3590
3591 for check in checks {
3592 match (check.check)(resources.clone()).await {
3593 Ok(()) => reports.push(HealthCheckReport {
3594 name: check.name.clone(),
3595 status: "ok",
3596 error: None,
3597 }),
3598 Err(error) => {
3599 healthy = false;
3600 reports.push(HealthCheckReport {
3601 name: check.name.clone(),
3602 status: "error",
3603 error: Some(error),
3604 });
3605 }
3606 }
3607 }
3608
3609 (healthy, reports)
3610}
3611
3612fn health_json_response(
3613 probe: &'static str,
3614 healthy: bool,
3615 checks: Vec<HealthCheckReport>,
3616) -> HttpResponse {
3617 let status_code = if healthy {
3618 StatusCode::OK
3619 } else {
3620 StatusCode::SERVICE_UNAVAILABLE
3621 };
3622 let status = if healthy { "ok" } else { "degraded" };
3623 let payload = HealthReport {
3624 status,
3625 probe,
3626 checks,
3627 };
3628
3629 let body = serde_json::to_vec(&payload)
3630 .unwrap_or_else(|_| br#"{"status":"error","probe":"health"}"#.to_vec());
3631
3632 Response::builder()
3633 .status(status_code)
3634 .header(http::header::CONTENT_TYPE, "application/json")
3635 .body(
3636 Full::new(Bytes::from(body))
3637 .map_err(|never| match never {})
3638 .boxed(),
3639 )
3640 .expect("valid HTTP response construction")
3641}
3642
3643async fn shutdown_signal() {
3644 #[cfg(unix)]
3645 {
3646 use tokio::signal::unix::{SignalKind, signal};
3647
3648 match signal(SignalKind::terminate()) {
3649 Ok(mut terminate) => {
3650 tokio::select! {
3651 _ = tokio::signal::ctrl_c() => {}
3652 _ = terminate.recv() => {}
3653 }
3654 }
3655 Err(err) => {
3656 tracing::warn!("Failed to install SIGTERM handler: {:?}", err);
3657 if let Err(ctrl_c_err) = tokio::signal::ctrl_c().await {
3658 tracing::warn!("Failed to listen for Ctrl+C: {:?}", ctrl_c_err);
3659 }
3660 }
3661 }
3662 }
3663
3664 #[cfg(not(unix))]
3665 {
3666 if let Err(err) = tokio::signal::ctrl_c().await {
3667 tracing::warn!("Failed to listen for Ctrl+C: {:?}", err);
3668 }
3669 }
3670}
3671
3672async fn drain_connections(
3673 connections: &mut tokio::task::JoinSet<()>,
3674 graceful_shutdown_timeout: Duration,
3675) -> bool {
3676 if connections.is_empty() {
3677 return false;
3678 }
3679
3680 let drain_result = tokio::time::timeout(graceful_shutdown_timeout, async {
3681 while let Some(join_result) = connections.join_next().await {
3682 if let Err(err) = join_result {
3683 tracing::warn!("Connection task join error during shutdown: {:?}", err);
3684 }
3685 }
3686 })
3687 .await;
3688
3689 if drain_result.is_err() {
3690 tracing::warn!(
3691 "Graceful shutdown timeout reached ({:?}). Aborting remaining connections.",
3692 graceful_shutdown_timeout
3693 );
3694 connections.abort_all();
3695 while let Some(join_result) = connections.join_next().await {
3696 if let Err(err) = join_result {
3697 tracing::warn!("Connection task abort join error: {:?}", err);
3698 }
3699 }
3700 true
3701 } else {
3702 false
3703 }
3704}
3705
/// Builds a TLS acceptor from a PEM certificate chain and a PEM private key.
///
/// # Errors
///
/// Returns a descriptive error when either file cannot be opened, when the PEM
/// data cannot be parsed, when the key file contains no private key, or when
/// rustls rejects the certificate/key pair.
#[cfg(feature = "tls")]
fn build_tls_acceptor(
    cert_path: &str,
    key_path: &str,
) -> Result<tokio_rustls::TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
    use rustls::ServerConfig;
    use rustls_pemfile::{certs, private_key};
    use std::io::BufReader;
    use tokio_rustls::TlsAcceptor;

    // Both files are opened before any parsing so a missing file is reported first.
    let mut cert_reader = BufReader::new(
        std::fs::File::open(cert_path)
            .map_err(|e| format!("Failed to open certificate file '{}': {}", cert_path, e))?,
    );
    let mut key_reader = BufReader::new(
        std::fs::File::open(key_path)
            .map_err(|e| format!("Failed to open key file '{}': {}", key_path, e))?,
    );

    let mut cert_chain = Vec::new();
    for parsed in certs(&mut cert_reader) {
        let cert = parsed.map_err(|e| format!("Failed to parse certificate PEM: {}", e))?;
        cert_chain.push(cert);
    }

    let key = match private_key(&mut key_reader) {
        Ok(Some(key)) => key,
        Ok(None) => return Err("No private key found in key file".into()),
        Err(e) => return Err(format!("Failed to parse private key PEM: {}", e).into()),
    };

    let server_config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, key)
        .map_err(|e| format!("TLS configuration error: {}", e))?;

    Ok(TlsAcceptor::from(Arc::new(server_config)))
}
3737
3738impl<R> Default for HttpIngress<R>
3739where
3740 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3741{
3742 fn default() -> Self {
3743 Self::new()
3744 }
3745}
3746
/// Cloneable bundle of ingress state that implements `hyper::service::Service`.
///
/// Every field is shared behind an `Arc` (or is an optional shared handler), so
/// cloning the service only bumps reference counts. The service implementation
/// passes clones of these fields to `build_http_service` for each request.
#[derive(Clone)]
pub struct RawIngressService<R> {
    // Registered route entries.
    routes: Arc<Vec<RouteEntry<R>>>,
    // Optional fallback handler — presumably used when no route matches; confirm in `build_http_service`.
    fallback: Option<RouteHandler<R>>,
    // Service layers handed to `build_http_service`.
    layers: Arc<Vec<ServiceLayer>>,
    // Health endpoint configuration.
    health: Arc<HealthConfig<R>>,
    // Static asset serving configuration.
    static_assets: Arc<StaticAssetsConfig>,
    // Optional preflight configuration.
    preflight_config: Arc<Option<PreflightConfig>>,
    // Shared application resources handed to `build_http_service`.
    resources: Arc<R>,
}
3758
3759impl<R> hyper::service::Service<Request<Incoming>> for RawIngressService<R>
3760where
3761 R: ranvier_core::transition::ResourceRequirement + Clone + Send + Sync + 'static,
3762{
3763 type Response = HttpResponse;
3764 type Error = Infallible;
3765 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
3766
3767 fn call(&self, req: Request<Incoming>) -> Self::Future {
3768 let routes = self.routes.clone();
3769 let fallback = self.fallback.clone();
3770 let layers = self.layers.clone();
3771 let health = self.health.clone();
3772 let static_assets = self.static_assets.clone();
3773 let preflight_config = self.preflight_config.clone();
3774 let resources = self.resources.clone();
3775
3776 Box::pin(async move {
3777 let service = build_http_service(
3778 routes,
3779 fallback,
3780 resources,
3781 layers,
3782 health,
3783 static_assets,
3784 preflight_config,
3785 #[cfg(feature = "http3")]
3786 None,
3787 );
3788 service.call(req).await
3789 })
3790 }
3791}
3792
3793#[cfg(test)]
3794mod tests {
3795 use super::*;
3796 use async_trait::async_trait;
3797 use futures_util::{SinkExt, StreamExt};
3798 use serde::Deserialize;
3799 use std::fs;
3800 use std::sync::atomic::{AtomicBool, Ordering};
3801 use tempfile::tempdir;
3802 use tokio::io::{AsyncReadExt, AsyncWriteExt};
3803 use tokio_tungstenite::tungstenite::Message as WsClientMessage;
3804 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
3805
3806 async fn connect_with_retry(addr: std::net::SocketAddr) -> tokio::net::TcpStream {
3807 let deadline = tokio::time::Instant::now() + Duration::from_secs(2);
3808
3809 loop {
3810 match tokio::net::TcpStream::connect(addr).await {
3811 Ok(stream) => return stream,
3812 Err(error) => {
3813 if tokio::time::Instant::now() >= deadline {
3814 panic!("connect server: {error}");
3815 }
3816 tokio::time::sleep(Duration::from_millis(25)).await;
3817 }
3818 }
3819 }
3820 }
3821
3822 #[test]
3823 fn route_pattern_matches_static_path() {
3824 let pattern = RoutePattern::parse("/orders/list");
3825 let params = pattern.match_path("/orders/list").expect("should match");
3826 assert!(params.into_inner().is_empty());
3827 }
3828
3829 #[test]
3830 fn route_pattern_matches_param_segments() {
3831 let pattern = RoutePattern::parse("/orders/:id/items/:item_id");
3832 let params = pattern
3833 .match_path("/orders/42/items/sku-123")
3834 .expect("should match");
3835 assert_eq!(params.get("id"), Some("42"));
3836 assert_eq!(params.get("item_id"), Some("sku-123"));
3837 }
3838
3839 #[test]
3840 fn route_pattern_matches_wildcard_segment() {
3841 let pattern = RoutePattern::parse("/assets/*path");
3842 let params = pattern
3843 .match_path("/assets/css/theme/light.css")
3844 .expect("should match");
3845 assert_eq!(params.get("path"), Some("css/theme/light.css"));
3846 }
3847
3848 #[test]
3849 fn route_pattern_rejects_non_matching_path() {
3850 let pattern = RoutePattern::parse("/orders/:id");
3851 assert!(pattern.match_path("/users/42").is_none());
3852 }
3853
3854 #[test]
3855 fn graceful_shutdown_timeout_defaults_to_30_seconds() {
3856 let ingress = HttpIngress::<()>::new();
3857 assert_eq!(ingress.graceful_shutdown_timeout, Duration::from_secs(30));
3858 assert!(ingress.layers.is_empty());
3859 assert!(ingress.bus_injectors.is_empty());
3860 assert!(ingress.static_assets.mounts.is_empty());
3861 assert!(ingress.on_start.is_none());
3862 assert!(ingress.on_shutdown.is_none());
3863 }
3864
3865 #[test]
3866 fn route_without_layer_keeps_empty_route_middleware_stack() {
3867 let ingress =
3868 HttpIngress::<()>::new().get("/ping", Axon::<(), (), String, ()>::new("Ping"));
3869 assert_eq!(ingress.routes.len(), 1);
3870 assert!(ingress.routes[0].layers.is_empty());
3871 assert!(ingress.routes[0].apply_global_layers);
3872 }
3873
3874 #[test]
3875 fn timeout_layer_registers_builtin_middleware() {
3876 let ingress = HttpIngress::<()>::new().timeout_layer(Duration::from_secs(1));
3877 assert_eq!(ingress.layers.len(), 1);
3878 }
3879
3880 #[test]
3881 fn request_id_layer_registers_builtin_middleware() {
3882 let ingress = HttpIngress::<()>::new().request_id_layer();
3883 assert_eq!(ingress.layers.len(), 1);
3884 }
3885
3886 #[test]
3887 fn compression_layer_registers_builtin_middleware() {
3888 let ingress = HttpIngress::<()>::new().compression_layer();
3889 assert!(ingress.static_assets.enable_compression);
3890 }
3891
3892 #[test]
3893 fn bus_injector_registration_adds_hook() {
3894 let ingress = HttpIngress::<()>::new().bus_injector(|_req, bus| {
3895 bus.insert("ok".to_string());
3896 });
3897 assert_eq!(ingress.bus_injectors.len(), 1);
3898 }
3899
3900 #[test]
3901 fn ws_route_registers_get_route_pattern() {
3902 let ingress =
3903 HttpIngress::<()>::new().ws("/ws/events", |_socket, _resources, _bus| async {});
3904 assert_eq!(ingress.routes.len(), 1);
3905 assert_eq!(ingress.routes[0].method, Method::GET);
3906 assert_eq!(ingress.routes[0].pattern.raw, "/ws/events");
3907 }
3908
    /// JSON shape of the first frame sent by the echo handler in the websocket
    /// round-trip test.
    #[derive(Debug, Deserialize)]
    struct WsWelcomeFrame {
        // Connection id taken from the session context in the bus.
        connection_id: String,
        // Request path reported by the session context.
        path: String,
        // Tenant string the bus injector copied from the `x-tenant-id` header.
        tenant: String,
    }
3915
    /// End-to-end websocket check: upgrades a real connection, verifies the
    /// welcome frame built from bus data, echoes a text and a binary frame,
    /// then closes the socket and shuts the server down.
    #[tokio::test]
    async fn ws_route_upgrades_and_bridges_event_source_sink_with_connection_bus() {
        // Reserve a free local port, then release it so the ingress can bind it.
        let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
        let addr = probe.local_addr().expect("local addr");
        drop(probe);

        let ingress = HttpIngress::<()>::new()
            .bind(addr.to_string())
            // Copy the tenant header into the bus so the handler can read it.
            .bus_injector(|req, bus| {
                if let Some(value) = req.headers.get("x-tenant-id").and_then(|v| v.to_str().ok()) {
                    bus.insert(value.to_string());
                }
            })
            .ws("/ws/echo", |mut socket, _resources, bus| async move {
                let tenant = bus
                    .read::<String>()
                    .cloned()
                    .unwrap_or_else(|| "unknown".to_string());
                // The welcome frame is only sent when a session context is in the bus.
                if let Some(session) = bus.read::<WebSocketSessionContext>() {
                    let welcome = serde_json::json!({
                        "connection_id": session.connection_id().to_string(),
                        "path": session.path(),
                        "tenant": tenant,
                    });
                    let _ = socket.send_json(&welcome).await;
                }

                // Echo loop: text is prefixed with "echo:", binary is returned as is.
                while let Some(event) = socket.next_event().await {
                    match event {
                        WebSocketEvent::Text(text) => {
                            let _ = socket.send_event(format!("echo:{text}")).await;
                        }
                        WebSocketEvent::Binary(bytes) => {
                            let _ = socket.send_event(bytes).await;
                        }
                        WebSocketEvent::Close => break,
                        WebSocketEvent::Ping(_) | WebSocketEvent::Pong(_) => {}
                    }
                }
            });

        // Run the server in the background until the oneshot fires.
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let server = tokio::spawn(async move {
            ingress
                .run_with_shutdown_signal((), async move {
                    let _ = shutdown_rx.await;
                })
                .await
        });

        let ws_uri = format!("ws://{addr}/ws/echo?room=alpha");
        let mut ws_request = ws_uri
            .as_str()
            .into_client_request()
            .expect("ws client request");
        ws_request
            .headers_mut()
            .insert("x-tenant-id", http::HeaderValue::from_static("acme"));
        let (mut client, _response) = tokio_tungstenite::connect_async(ws_request)
            .await
            .expect("websocket connect");

        // First frame: the JSON welcome message.
        let welcome = client
            .next()
            .await
            .expect("welcome frame")
            .expect("welcome frame ok");
        let welcome_text = match welcome {
            WsClientMessage::Text(text) => text.to_string(),
            other => panic!("expected text welcome frame, got {other:?}"),
        };
        let welcome_payload: WsWelcomeFrame =
            serde_json::from_str(&welcome_text).expect("welcome json");
        // The reported path excludes the query string.
        assert_eq!(welcome_payload.path, "/ws/echo");
        assert_eq!(welcome_payload.tenant, "acme");
        assert!(!welcome_payload.connection_id.is_empty());

        // Text round trip.
        client
            .send(WsClientMessage::Text("hello".into()))
            .await
            .expect("send text");
        let echo_text = client
            .next()
            .await
            .expect("echo text frame")
            .expect("echo text frame ok");
        assert_eq!(echo_text, WsClientMessage::Text("echo:hello".into()));

        // Binary round trip.
        client
            .send(WsClientMessage::Binary(vec![1, 2, 3, 4].into()))
            .await
            .expect("send binary");
        let echo_binary = client
            .next()
            .await
            .expect("echo binary frame")
            .expect("echo binary frame ok");
        assert_eq!(
            echo_binary,
            WsClientMessage::Binary(vec![1, 2, 3, 4].into())
        );

        client.close(None).await.expect("close websocket");

        // Stop the server and make sure it exits cleanly.
        let _ = shutdown_tx.send(());
        server
            .await
            .expect("server join")
            .expect("server shutdown should succeed");
    }
4026
4027 #[test]
4028 fn route_descriptors_export_http_and_health_paths() {
4029 let ingress = HttpIngress::<()>::new()
4030 .get(
4031 "/orders/:id",
4032 Axon::<(), (), String, ()>::new("OrderById"),
4033 )
4034 .health_endpoint("/healthz")
4035 .readiness_liveness("/readyz", "/livez");
4036
4037 let descriptors = ingress.route_descriptors();
4038
4039 assert!(
4040 descriptors
4041 .iter()
4042 .any(|descriptor| descriptor.method() == Method::GET
4043 && descriptor.path_pattern() == "/orders/:id")
4044 );
4045 assert!(
4046 descriptors
4047 .iter()
4048 .any(|descriptor| descriptor.method() == Method::GET
4049 && descriptor.path_pattern() == "/healthz")
4050 );
4051 assert!(
4052 descriptors
4053 .iter()
4054 .any(|descriptor| descriptor.method() == Method::GET
4055 && descriptor.path_pattern() == "/readyz")
4056 );
4057 assert!(
4058 descriptors
4059 .iter()
4060 .any(|descriptor| descriptor.method() == Method::GET
4061 && descriptor.path_pattern() == "/livez")
4062 );
4063 }
4064
4065 #[tokio::test]
4066 async fn lifecycle_hooks_fire_on_start_and_shutdown() {
4067 let started = Arc::new(AtomicBool::new(false));
4068 let shutdown = Arc::new(AtomicBool::new(false));
4069
4070 let started_flag = started.clone();
4071 let shutdown_flag = shutdown.clone();
4072
4073 let ingress = HttpIngress::<()>::new()
4074 .bind("127.0.0.1:0")
4075 .on_start(move || {
4076 started_flag.store(true, Ordering::SeqCst);
4077 })
4078 .on_shutdown(move || {
4079 shutdown_flag.store(true, Ordering::SeqCst);
4080 })
4081 .graceful_shutdown(Duration::from_millis(50));
4082
4083 ingress
4084 .run_with_shutdown_signal((), async {
4085 tokio::time::sleep(Duration::from_millis(20)).await;
4086 })
4087 .await
4088 .expect("server should exit gracefully");
4089
4090 assert!(started.load(Ordering::SeqCst));
4091 assert!(shutdown.load(Ordering::SeqCst));
4092 }
4093
    /// A request that is still running when shutdown is signalled must be
    /// answered in full before the server exits.
    #[tokio::test]
    async fn graceful_shutdown_drains_in_flight_requests_before_exit() {
        // Route that takes 120 ms, longer than the 20 ms delay before shutdown.
        #[derive(Clone)]
        struct SlowDrainRoute;

        #[async_trait]
        impl Transition<(), String> for SlowDrainRoute {
            type Error = String;
            type Resources = ();

            async fn run(
                &self,
                _state: (),
                _resources: &Self::Resources,
                _bus: &mut Bus,
            ) -> Outcome<String, Self::Error> {
                tokio::time::sleep(Duration::from_millis(120)).await;
                Outcome::next("drained-ok".to_string())
            }
        }

        // Reserve a free local port, then release it so the ingress can bind it.
        let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
        let addr = probe.local_addr().expect("local addr");
        drop(probe);

        // The 500 ms drain window comfortably covers the 120 ms handler.
        let ingress = HttpIngress::<()>::new()
            .bind(addr.to_string())
            .graceful_shutdown(Duration::from_millis(500))
            .get(
                "/drain",
                Axon::<(), (), String, ()>::new("SlowDrain").then(SlowDrainRoute),
            );

        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let server = tokio::spawn(async move {
            ingress
                .run_with_shutdown_signal((), async move {
                    let _ = shutdown_rx.await;
                })
                .await
        });

        let mut stream = connect_with_retry(addr).await;
        stream
            .write_all(b"GET /drain HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .expect("write request");

        // Give the request time to reach the handler, then signal shutdown
        // while it is still sleeping.
        tokio::time::sleep(Duration::from_millis(20)).await;
        let _ = shutdown_tx.send(());

        // `Connection: close` lets us read until the server closes the socket.
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.expect("read response");
        let response = String::from_utf8_lossy(&buf);
        assert!(response.starts_with("HTTP/1.1 200"), "{response}");
        assert!(response.contains("drained-ok"), "{response}");

        server
            .await
            .expect("server join")
            .expect("server shutdown should succeed");
    }
4156
4157 #[tokio::test]
4158 async fn serve_dir_serves_static_file_with_cache_and_metadata_headers() {
4159 let temp = tempdir().expect("tempdir");
4160 let root = temp.path().join("public");
4161 fs::create_dir_all(&root).expect("create dir");
4162 let file = root.join("hello.txt");
4163 fs::write(&file, "hello static").expect("write file");
4164
4165 let ingress =
4166 Ranvier::http::<()>().serve_dir("/static", root.to_string_lossy().to_string());
4167 let app = crate::test_harness::TestApp::new(ingress, ());
4168 let response = app
4169 .send(crate::test_harness::TestRequest::get("/static/hello.txt"))
4170 .await
4171 .expect("request should succeed");
4172
4173 assert_eq!(response.status(), StatusCode::OK);
4174 assert_eq!(response.text().expect("utf8"), "hello static");
4175 assert!(response.header("cache-control").is_some());
4176 let has_metadata_header =
4177 response.header("etag").is_some() || response.header("last-modified").is_some();
4178 assert!(has_metadata_header);
4179 }
4180
4181 #[tokio::test]
4182 async fn spa_fallback_returns_index_for_unmatched_path() {
4183 let temp = tempdir().expect("tempdir");
4184 let index = temp.path().join("index.html");
4185 fs::write(&index, "<html><body>spa</body></html>").expect("write index");
4186
4187 let ingress = Ranvier::http::<()>().spa_fallback(index.to_string_lossy().to_string());
4188 let app = crate::test_harness::TestApp::new(ingress, ());
4189 let response = app
4190 .send(crate::test_harness::TestRequest::get("/dashboard/settings"))
4191 .await
4192 .expect("request should succeed");
4193
4194 assert_eq!(response.status(), StatusCode::OK);
4195 assert!(response.text().expect("utf8").contains("spa"));
4196 }
4197
4198 #[tokio::test]
4199 async fn static_compression_layer_sets_content_encoding_for_gzip_client() {
4200 let temp = tempdir().expect("tempdir");
4201 let root = temp.path().join("public");
4202 fs::create_dir_all(&root).expect("create dir");
4203 let file = root.join("compressed.txt");
4204 fs::write(&file, "compress me ".repeat(400)).expect("write file");
4205
4206 let ingress = Ranvier::http::<()>()
4207 .serve_dir("/static", root.to_string_lossy().to_string())
4208 .compression_layer();
4209 let app = crate::test_harness::TestApp::new(ingress, ());
4210 let response = app
4211 .send(
4212 crate::test_harness::TestRequest::get("/static/compressed.txt")
4213 .header("accept-encoding", "gzip"),
4214 )
4215 .await
4216 .expect("request should succeed");
4217
4218 assert_eq!(response.status(), StatusCode::OK);
4219 assert_eq!(
4220 response
4221 .header("content-encoding")
4222 .and_then(|value| value.to_str().ok()),
4223 Some("gzip")
4224 );
4225 }
4226
4227 #[tokio::test]
4228 async fn drain_connections_completes_before_timeout() {
4229 let mut connections = tokio::task::JoinSet::new();
4230 connections.spawn(async {
4231 tokio::time::sleep(Duration::from_millis(20)).await;
4232 });
4233
4234 let timed_out = drain_connections(&mut connections, Duration::from_millis(200)).await;
4235 assert!(!timed_out);
4236 assert!(connections.is_empty());
4237 }
4238
4239 #[tokio::test]
4240 async fn drain_connections_times_out_and_aborts() {
4241 let mut connections = tokio::task::JoinSet::new();
4242 connections.spawn(async {
4243 tokio::time::sleep(Duration::from_secs(10)).await;
4244 });
4245
4246 let timed_out = drain_connections(&mut connections, Duration::from_millis(10)).await;
4247 assert!(timed_out);
4248 assert!(connections.is_empty());
4249 }
4250
    /// A route that runs longer than the configured timeout layer must be
    /// answered with `408 Request Timeout`.
    #[tokio::test]
    async fn timeout_layer_returns_408_for_slow_route() {
        // Route that takes 80 ms, well above the 10 ms timeout set below.
        #[derive(Clone)]
        struct SlowRoute;

        #[async_trait]
        impl Transition<(), String> for SlowRoute {
            type Error = String;
            type Resources = ();

            async fn run(
                &self,
                _state: (),
                _resources: &Self::Resources,
                _bus: &mut Bus,
            ) -> Outcome<String, Self::Error> {
                tokio::time::sleep(Duration::from_millis(80)).await;
                Outcome::next("slow-ok".to_string())
            }
        }

        // Reserve a free local port, then release it so the ingress can bind it.
        let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("bind probe");
        let addr = probe.local_addr().expect("local addr");
        drop(probe);

        let ingress = HttpIngress::<()>::new()
            .bind(addr.to_string())
            .timeout_layer(Duration::from_millis(10))
            .get(
                "/slow",
                Axon::<(), (), String, ()>::new("Slow").then(SlowRoute),
            );

        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
        let server = tokio::spawn(async move {
            ingress
                .run_with_shutdown_signal((), async move {
                    let _ = shutdown_rx.await;
                })
                .await
        });

        let mut stream = connect_with_retry(addr).await;
        stream
            .write_all(b"GET /slow HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
            .await
            .expect("write request");

        // `Connection: close` lets us read until the server closes the socket.
        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await.expect("read response");
        let response = String::from_utf8_lossy(&buf);
        assert!(response.starts_with("HTTP/1.1 408"), "{response}");

        // Only stop the server once the response has been checked.
        let _ = shutdown_tx.send(());
        server
            .await
            .expect("server join")
            .expect("server shutdown should succeed");
    }
4310
4311 fn extract_body(response: Response<Full<Bytes>>) -> Bytes {
4314 use http_body_util::BodyExt;
4315 let rt = tokio::runtime::Builder::new_current_thread()
4316 .build()
4317 .unwrap();
4318 rt.block_on(async {
4319 let collected = response.into_body().collect().await.unwrap();
4320 collected.to_bytes()
4321 })
4322 }
4323
4324 #[test]
4325 fn handle_range_bytes_start_end() {
4326 let content = b"Hello, World!";
4327 let range = http::HeaderValue::from_static("bytes=0-4");
4328 let response =
4329 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4330 assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
4331 assert_eq!(
4332 response.headers().get(http::header::CONTENT_RANGE).unwrap(),
4333 "bytes 0-4/13"
4334 );
4335 assert_eq!(extract_body(response), "Hello");
4336 }
4337
4338 #[test]
4339 fn handle_range_suffix() {
4340 let content = b"Hello, World!";
4341 let range = http::HeaderValue::from_static("bytes=-6");
4342 let response =
4343 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4344 assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
4345 assert_eq!(
4346 response.headers().get(http::header::CONTENT_RANGE).unwrap(),
4347 "bytes 7-12/13"
4348 );
4349 }
4350
4351 #[test]
4352 fn handle_range_from_offset() {
4353 let content = b"Hello, World!";
4354 let range = http::HeaderValue::from_static("bytes=7-");
4355 let response =
4356 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4357 assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
4358 assert_eq!(
4359 response.headers().get(http::header::CONTENT_RANGE).unwrap(),
4360 "bytes 7-12/13"
4361 );
4362 }
4363
4364 #[test]
4365 fn handle_range_out_of_bounds_returns_416() {
4366 let content = b"Hello";
4367 let range = http::HeaderValue::from_static("bytes=10-20");
4368 let response =
4369 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4370 assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE);
4371 assert_eq!(
4372 response.headers().get(http::header::CONTENT_RANGE).unwrap(),
4373 "bytes */5"
4374 );
4375 }
4376
4377 #[test]
4378 fn handle_range_includes_accept_ranges_header() {
4379 let content = b"Hello, World!";
4380 let range = http::HeaderValue::from_static("bytes=0-0");
4381 let response =
4382 super::handle_range_request(&range, content, "text/plain", None, None).unwrap();
4383 assert_eq!(
4384 response.headers().get(http::header::ACCEPT_RANGES).unwrap(),
4385 "bytes"
4386 );
4387 }
4388
4389}