Skip to main content

gateway_runtime/
gateway.rs

1//! # Gateway Builder and Service
2//!
3//! ## Purpose
4//! This module provides the `Gateway` builder struct, which is the primary entry point for configuring
5//! and constructing the runtime service stack. It orchestrates the various middleware layers
6//! (routing, authentication, error handling, metadata, etc.) into a cohesive `tower::Service`.
7//!
8//! ## Scope
9//! This module defines:
10//! -   `Gateway`: A builder for configuring the runtime.
11//! -   `RouterService`: The core service responsible for routing and authentication.
12//! -   `UnescapingMode`: Configuration for path unescaping.
13//! -   Type aliases for various handler callbacks (`ErrorHandler`, `AuthVerifier`, etc.).
14//!
15//! ## Middleware Stack
16//! The `Gateway::into_service()` method constructs a `tower::Service` with the following layer order
17//! (outer to inner):
18//! 1.  **Tracing**: Request/Response tracing.
19//! 2.  **Metrics**: Request duration and status recording.
20//! 3.  **Error Handling**: Catches errors from inner layers and converts them to HTTP responses.
21//! 4.  **Response Modifiers**: modifying the response before sending it back.
22//! 5.  **Headers**: Filtering/Transforming incoming and outgoing headers.
23//! 6.  **Metadata**: Extracting and injecting metadata (e.g., from headers or annotators).
24//! 7.  **RouterService**: The core logic (Path matching -> Auth -> Dispatch).
25//!
26//! ## Position in the Architecture
27//! The `Gateway` is the glue that binds the `Router` (generated code registry) with the
28//! runtime features (handlers, defaults). The resulting service is typically passed to
29//! an HTTP server (like `hyper`).
30
31use crate::alloc::boxed::Box;
32use crate::alloc::string::String;
33use crate::alloc::sync::Arc;
34use crate::alloc::vec::Vec;
35use crate::defaults;
36use crate::layers::{
37    error::ErrorLayer, headers::HeaderLayer, metadata::MetadataLayer, response::ResponseLayer,
38};
39use crate::metadata::MetadataForwardingConfig;
40use crate::router::{RouteMetadata, Router};
41use crate::{GatewayError, GatewayRequest, GatewayResponse, GatewayResult};
42use core::task::{Context, Poll};
43use core::time::Duration;
44use percent_encoding::percent_decode_str;
45use std::future::Future;
46use std::pin::Pin;
47use tonic::metadata::MetadataMap;
48use tower::{Service, ServiceBuilder};
49
50/// A handler for converting errors into HTTP responses.
51pub type ErrorHandler = Arc<dyn Fn(&GatewayRequest, GatewayError) -> GatewayResponse + Send + Sync>;
52
53/// A handler for annotating requests with metadata.
54pub type MetadataAnnotator = Arc<dyn Fn(&GatewayRequest) -> MetadataMap + Send + Sync>;
55
56/// A handler for modifying HTTP responses before they are sent.
57pub type ResponseModifier = Arc<dyn Fn(&GatewayRequest, &mut GatewayResponse) + Send + Sync>;
58
59/// A handler for matching and transforming headers.
60pub type HeaderMatcher = Arc<dyn Fn(&str) -> Option<String> + Send + Sync>;
61
62/// A handler for verifying authentication requirements.
63pub type AuthVerifier =
64    Arc<dyn Fn(&GatewayRequest, &RouteMetadata) -> Result<(), GatewayError> + Send + Sync>;
65
66/// A handler for recording metrics.
67pub type MetricsRecorder = Arc<dyn Fn(&GatewayRequest, &GatewayResult, Duration) + Send + Sync>;
68
69/// A handler for tracing start. Returns an opaque token (TraceContext) to be passed to end.
70pub type TracingStartHandler =
71    Arc<dyn Fn(&GatewayRequest) -> Box<dyn core::any::Any + Send> + Send + Sync>;
72
73/// A handler for tracing end.
74pub type TracingEndHandler =
75    Arc<dyn Fn(Box<dyn core::any::Any + Send>, &GatewayResult) + Send + Sync>;
76
/// Controls whether the request path is percent-decoded before it is matched against routes.
///
/// Decoding is applied to the whole request path by `RouterService`, not to individual
/// captured parameters. In [`UnescapingMode::AllCharacters`] mode an encoded `/` (`%2F`) is
/// therefore decoded before the router sees the path.
/// NOTE(review): the router will then presumably treat it as a segment separator.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum UnescapingMode {
    /// Percent-decode every escape sequence in the path before matching. If the decoded bytes
    /// are not valid UTF-8, the raw path is used unchanged.
    AllCharacters,
    /// Match against the path exactly as received (no decoding). This is the default.
    #[default]
    Default,
}

86/// The core service that handles routing and authentication logic.
87///
88/// This service matches the request path against the `Router` and executes the configured
89/// `AuthVerifier`. If successful, it dispatches the request to the matched service.
90#[derive(Clone)]
91struct RouterService<S> {
92    router: Router<S>,
93    auth_verifier: Option<AuthVerifier>,
94    unescaping_mode: UnescapingMode,
95}
96
97impl<S> Service<GatewayRequest> for RouterService<S>
98where
99    S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>
100        + Clone
101        + Send
102        + 'static,
103    S::Future: Send + 'static,
104{
105    type Response = GatewayResponse;
106    type Error = GatewayError;
107    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
108
109    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        Poll::Ready(Ok(()))
111    }
112
113    fn call(&mut self, mut req: GatewayRequest) -> Self::Future {
114        let mut path = req.uri().path().to_string();
115
116        // Handle path unescaping if configured
117        match self.unescaping_mode {
118            UnescapingMode::AllCharacters => {
119                if let Ok(decoded) = percent_decode_str(&path).decode_utf8() {
120                    path = decoded.to_string();
121                }
122            }
123            UnescapingMode::Default => {}
124        }
125
126        let method = req.method().clone();
127
128        // Match request against registered routes
129        let match_result = self.router.match_request(&method, &path);
130
131        if let Some((service, params, metadata)) = match_result {
132            // Auth Verification Phase
133            if let Some(verifier) = &self.auth_verifier {
134                // Check authentication requirements
135                match verifier(&req, metadata) {
136                    Ok(_) => {}
137                    Err(e) => return Box::pin(async move { Err(e) }),
138                }
139            }
140
141            // Store captured path parameters for use by the service (handlers)
142            req.extensions_mut().insert(params);
143
144            let mut service = service.clone();
145            Box::pin(async move { service.call(req).await })
146        } else {
147            // No route matched
148            Box::pin(async move { Err(GatewayError::NotFound) })
149        }
150    }
151}
152
153/// A generic layer for recording metrics around the inner service execution.
154#[derive(Clone)]
155struct MetricsLayer<S> {
156    inner: S,
157    recorder: Option<MetricsRecorder>,
158}
159
160impl<S> Service<GatewayRequest> for MetricsLayer<S>
161where
162    S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>,
163    S::Future: Send + 'static,
164{
165    type Response = GatewayResponse;
166    type Error = GatewayError;
167    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
168
169    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
170        self.inner.poll_ready(cx)
171    }
172
173    fn call(&mut self, req: GatewayRequest) -> Self::Future {
174        let recorder = self.recorder.clone();
175
176        // Capture request metadata for the recorder callback
177        let method = req.method().clone();
178        let uri = req.uri().clone();
179        let headers = req.headers().clone();
180
181        let start_time = std::time::Instant::now();
182        let fut = self.inner.call(req);
183
184        Box::pin(async move {
185            let res = fut.await;
186            let duration = start_time.elapsed();
187
188            if let Some(rec) = recorder {
189                // Reconstruct a partial request for context in the recorder
190                let mut partial_req = http::Request::builder()
191                    .method(method)
192                    .uri(uri)
193                    .body(Vec::new())
194                    .unwrap();
195                *partial_req.headers_mut() = headers;
196
197                rec(&partial_req, &res, duration);
198            }
199            res
200        })
201    }
202}
203
204/// A generic layer for wrapping execution with tracing start/end hooks.
205#[derive(Clone)]
206struct TraceLayer<S> {
207    inner: S,
208    start: Option<TracingStartHandler>,
209    end: Option<TracingEndHandler>,
210}
211
212impl<S> Service<GatewayRequest> for TraceLayer<S>
213where
214    S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>,
215    S::Future: Send + 'static,
216{
217    type Response = GatewayResponse;
218    type Error = GatewayError;
219    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
220
221    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
222        self.inner.poll_ready(cx)
223    }
224
225    fn call(&mut self, req: GatewayRequest) -> Self::Future {
226        let token = if let Some(start) = &self.start {
227            Some(start(&req))
228        } else {
229            None
230        };
231
232        let end = self.end.clone();
233        let fut = self.inner.call(req);
234
235        Box::pin(async move {
236            let res = fut.await;
237            if let Some(end_handler) = end {
238                if let Some(t) = token {
239                    end_handler(t, &res);
240                }
241            }
242            res
243        })
244    }
245}
246
247/// A builder and configuration struct for the Gateway runtime.
248///
249/// Wraps a `Router` and allows attaching various handlers and configuration options.
250/// The `into_service()` method consumes this builder to produce the final `Service`.
251pub struct Gateway<S> {
252    router: Router<S>,
253    error_handler: Option<ErrorHandler>,
254    metadata_annotators: Vec<MetadataAnnotator>,
255    response_modifiers: Vec<ResponseModifier>,
256    incoming_header_matcher: Option<HeaderMatcher>,
257    outgoing_header_matcher: Option<HeaderMatcher>,
258    unescaping_mode: UnescapingMode,
259
260    auth_verifier: Option<AuthVerifier>,
261    metrics_recorder: Option<MetricsRecorder>,
262    tracing_start: Option<TracingStartHandler>,
263    tracing_end: Option<TracingEndHandler>,
264
265    metadata_config: MetadataForwardingConfig,
266}
267
268impl<S> Gateway<S> {
269    /// Creates a new `Gateway` wrapping the given `Router` and initialized with secure defaults.
270    pub fn new(router: Router<S>) -> Self {
271        Self {
272            router,
273            error_handler: Some(Arc::new(defaults::default_error_handler)),
274            metadata_annotators: vec![Arc::new(defaults::default_metadata_annotator)],
275            response_modifiers: vec![Arc::new(defaults::default_response_modifier)],
276            incoming_header_matcher: Some(Arc::new(defaults::default_incoming_header_matcher)),
277            outgoing_header_matcher: Some(Arc::new(defaults::default_outgoing_header_matcher)),
278            unescaping_mode: UnescapingMode::Default,
279            auth_verifier: None, // No default auth verifier, must be explicitly set if needed.
280            metrics_recorder: None,
281            tracing_start: None,
282            tracing_end: None,
283            metadata_config: MetadataForwardingConfig::default(),
284        }
285    }
286
287    /// Sets the custom error handler.
288    pub fn with_error_handler<F>(mut self, handler: F) -> Self
289    where
290        F: Fn(&GatewayRequest, GatewayError) -> GatewayResponse + Send + Sync + 'static,
291    {
292        self.error_handler = Some(Arc::new(handler));
293        self
294    }
295
296    /// Adds a metadata annotator.
297    pub fn with_metadata<F>(mut self, annotator: F) -> Self
298    where
299        F: Fn(&GatewayRequest) -> MetadataMap + Send + Sync + 'static,
300    {
301        self.metadata_annotators.push(Arc::new(annotator));
302        self
303    }
304
305    /// Adds a response modifier.
306    pub fn with_response_modifier<F>(mut self, modifier: F) -> Self
307    where
308        F: Fn(&GatewayRequest, &mut GatewayResponse) + Send + Sync + 'static,
309    {
310        self.response_modifiers.push(Arc::new(modifier));
311        self
312    }
313
314    /// Sets the incoming header matcher.
315    pub fn with_incoming_header_matcher<F>(mut self, matcher: F) -> Self
316    where
317        F: Fn(&str) -> Option<String> + Send + Sync + 'static,
318    {
319        self.incoming_header_matcher = Some(Arc::new(matcher));
320        self
321    }
322
323    /// Sets the outgoing header matcher.
324    pub fn with_outgoing_header_matcher<F>(mut self, matcher: F) -> Self
325    where
326        F: Fn(&str) -> Option<String> + Send + Sync + 'static,
327    {
328        self.outgoing_header_matcher = Some(Arc::new(matcher));
329        self
330    }
331
332    /// Sets the unescaping mode.
333    pub fn with_unescaping_mode(mut self, mode: UnescapingMode) -> Self {
334        self.unescaping_mode = mode;
335        self
336    }
337
338    /// Sets the authentication verifier.
339    pub fn with_auth_verifier<F>(mut self, verifier: F) -> Self
340    where
341        F: Fn(&GatewayRequest, &RouteMetadata) -> Result<(), GatewayError> + Send + Sync + 'static,
342    {
343        self.auth_verifier = Some(Arc::new(verifier));
344        self
345    }
346
347    /// Sets the metrics recorder.
348    pub fn with_metrics_recorder<F>(mut self, recorder: F) -> Self
349    where
350        F: Fn(&GatewayRequest, &GatewayResult, Duration) + Send + Sync + 'static,
351    {
352        self.metrics_recorder = Some(Arc::new(recorder));
353        self
354    }
355
356    /// Sets tracing handlers.
357    pub fn with_tracing<Start, End>(mut self, start: Start, end: End) -> Self
358    where
359        Start: Fn(&GatewayRequest) -> Box<dyn core::any::Any + Send> + Send + Sync + 'static,
360        End: Fn(Box<dyn core::any::Any + Send>, &GatewayResult) + Send + Sync + 'static,
361    {
362        self.tracing_start = Some(Arc::new(start));
363        self.tracing_end = Some(Arc::new(end));
364        self
365    }
366
367    /// Sets the metadata forwarding configuration.
368    pub fn with_metadata_config(mut self, config: MetadataForwardingConfig) -> Self {
369        self.metadata_config = config;
370        self
371    }
372
373    /// Returns a reference to the metadata configuration.
374    pub fn metadata_config(&self) -> &MetadataForwardingConfig {
375        &self.metadata_config
376    }
377}
378
379impl<S> Gateway<S>
380where
381    S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>
382        + Clone
383        + Send
384        + 'static,
385    S::Future: Send + 'static,
386{
387    /// Consumes the Gateway configuration and returns a constructed `tower::BoxCloneService`.
388    ///
389    /// This method assembles the complete middleware stack around the router.
390    pub fn into_service(
391        self,
392    ) -> tower::util::BoxCloneService<GatewayRequest, GatewayResponse, GatewayError> {
393        let router_service = RouterService {
394            router: self.router,
395            auth_verifier: self.auth_verifier.clone(),
396            unescaping_mode: self.unescaping_mode,
397        };
398
399        let service = ServiceBuilder::new()
400            .layer_fn(|inner| TraceLayer {
401                inner,
402                start: self.tracing_start.clone(),
403                end: self.tracing_end.clone(),
404            })
405            .layer_fn(|inner| MetricsLayer {
406                inner,
407                recorder: self.metrics_recorder.clone(),
408            })
409            .layer_fn(|inner| ErrorLayer::new(inner, self.error_handler.clone()))
410            .layer_fn(|inner| ResponseLayer::new(inner, self.response_modifiers.clone()))
411            .layer_fn(|inner| {
412                HeaderLayer::new(
413                    inner,
414                    self.incoming_header_matcher.clone(),
415                    self.outgoing_header_matcher.clone(),
416                )
417            })
418            .layer_fn(|inner| {
419                MetadataLayer::new(
420                    inner,
421                    self.metadata_annotators.clone(),
422                    self.metadata_config.clone(),
423                )
424            })
425            .service(router_service);
426
427        tower::util::BoxCloneService::new(service)
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::string::ToString;
    use crate::router::{AuthConfig, AuthLocation, RouteMetadata, Router};
    use gateway_internal::path_template::{Op, OpCode, Pattern};
    use http::StatusCode;
    use http_body_util::BodyExt;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tower::util::BoxCloneService;

    /// Type-erased route service stored in every test router.
    type TestService = BoxCloneService<GatewayRequest, GatewayResponse, GatewayError>;

    /// Path template matching the single literal segment `test` (i.e. `/test`).
    fn test_pattern() -> Pattern {
        Pattern {
            ops: vec![Op {
                code: OpCode::LitPush,
                operand: 0,
            }],
            pool: vec!["test".to_string()],
            vars: vec![],
            stack_size: 1,
            tail_len: 0,
            verb: None,
        }
    }

    /// Builds a `GET` request for `uri` with an empty body.
    fn get_request(uri: &str) -> GatewayRequest {
        http::Request::builder()
            .method("GET")
            .uri(uri)
            .body(Vec::new())
            .unwrap()
    }

    /// Builds an empty-bodied response with the given status.
    fn empty_response(status: StatusCode) -> GatewayResponse {
        http::Response::builder()
            .status(status)
            .body(BodyExt::boxed_unsync(
                http_body_util::Full::new(crate::bytes::Bytes::new()).map_err(|_| unreachable!()),
            ))
            .unwrap()
    }

    /// Router with one `GET /test` route that echoes the `x-bar` header as `x-echo-bar` and
    /// the `test-key` metadata entry as `x-meta-echo`.
    fn make_router() -> Router<TestService> {
        let mut router = Router::new();
        let service = tower::service_fn(|req: GatewayRequest| async move {
            let mut resp = http::Response::builder().status(StatusCode::OK);
            if let Some(val) = req.headers().get("x-bar") {
                resp = resp.header("x-echo-bar", val);
            }
            if let Some(md) = req.extensions().get::<MetadataMap>() {
                if let Some(val) = md.get("test-key") {
                    resp = resp.header("x-meta-echo", val.to_str().unwrap());
                }
            }
            Ok(resp
                .body(http_body_util::BodyExt::boxed_unsync(
                    http_body_util::Full::new(crate::bytes::Bytes::from("ok"))
                        .map_err(|_| unreachable!()),
                ))
                .unwrap())
        });
        crate::router::route(
            &mut router,
            http::Method::GET,
            test_pattern(),
            BoxCloneService::new(service),
        );
        router
    }

    /// Router with one `GET /test` route (always `200 OK`) whose metadata requires an
    /// `X-Key` API-key header.
    fn make_auth_router() -> Router<TestService> {
        let mut router: Router<TestService> = Router::new();
        let service = tower::service_fn(|_| async { Ok(empty_response(StatusCode::OK)) });

        let meta = RouteMetadata {
            auth_required: Some(AuthConfig {
                scheme: "ApiKey".to_string(),
                location: AuthLocation::Header,
                name: "X-Key".to_string(),
            }),
        };
        crate::router::route_with_metadata(
            &mut router,
            http::Method::GET,
            test_pattern(),
            BoxCloneService::new(service),
            meta,
        );
        router
    }

    #[tokio::test]
    async fn test_gateway_metrics() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = calls.clone();

        let gateway = Gateway::new(make_router()).with_metrics_recorder(move |_, res, dur| {
            calls_clone.fetch_add(1, Ordering::SeqCst);
            assert!(res.is_ok());
            assert!(dur.as_nanos() > 0);
        });

        let mut service = gateway.into_service();
        let _ = service.call(get_request("/test")).await;
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_gateway_tracing() {
        let trace_val = Arc::new(AtomicUsize::new(0));
        let tv1 = trace_val.clone();
        let tv2 = trace_val.clone();

        let gateway = Gateway::new(make_router()).with_tracing(
            move |_| {
                tv1.fetch_add(1, Ordering::SeqCst);
                Box::new(123u32)
            },
            move |token, _| {
                let val = *token.downcast::<u32>().unwrap();
                assert_eq!(val, 123);
                tv2.fetch_add(1, Ordering::SeqCst);
            },
        );

        let mut service = gateway.into_service();
        let _ = service.call(get_request("/test")).await;
        // Both the start and the end hook ran exactly once.
        assert_eq!(trace_val.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_gateway_auth_verifier_success() {
        let gateway = Gateway::new(make_auth_router()).with_auth_verifier(|req, meta| {
            if let Some(auth) = &meta.auth_required {
                if auth.location == AuthLocation::Header {
                    if req.headers().contains_key(&auth.name) {
                        return Ok(());
                    }
                }
            }
            Err(GatewayError::Upstream(tonic::Status::unauthenticated(
                "missing key",
            )))
        });

        let mut service = gateway.into_service();
        let req = http::Request::builder()
            .method("GET")
            .uri("/test")
            .header("X-Key", "secret")
            .body(Vec::new())
            .unwrap();
        let resp = service.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_gateway_auth_verifier_fail() {
        // The error handler maps the verifier's error to a response.
        let gateway = Gateway::new(make_auth_router())
            .with_auth_verifier(|_, _| {
                Err(GatewayError::Upstream(tonic::Status::unauthenticated(
                    "fail",
                )))
            })
            .with_error_handler(|_, _| empty_response(StatusCode::UNAUTHORIZED));

        let mut service = gateway.into_service();
        let resp = service.call(get_request("/test")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_gateway_not_found() {
        // Map `NotFound` to 404 and anything else to 500, so the assertion pins the error kind.
        let gateway = Gateway::new(make_router()).with_error_handler(|_, err| {
            let status = if matches!(err, GatewayError::NotFound) {
                StatusCode::NOT_FOUND
            } else {
                StatusCode::INTERNAL_SERVER_ERROR
            };
            empty_response(status)
        });

        let mut service = gateway.into_service();
        let resp = service.call(get_request("/missing")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_gateway_unescaping_all_characters() {
        let gateway =
            Gateway::new(make_router()).with_unescaping_mode(UnescapingMode::AllCharacters);

        let mut service = gateway.into_service();
        // `%65` is a percent-encoded `e`, so the decoded path is `/test`.
        let resp = service.call(get_request("/t%65st")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}