1use crate::alloc::boxed::Box;
32use crate::alloc::string::String;
33use crate::alloc::sync::Arc;
34use crate::alloc::vec::Vec;
35use crate::defaults;
36use crate::layers::{
37 error::ErrorLayer, headers::HeaderLayer, metadata::MetadataLayer, response::ResponseLayer,
38};
39use crate::metadata::MetadataForwardingConfig;
40use crate::router::{RouteMetadata, Router};
41use crate::{GatewayError, GatewayRequest, GatewayResponse, GatewayResult};
42use core::task::{Context, Poll};
43use core::time::Duration;
44use percent_encoding::percent_decode_str;
45use std::future::Future;
46use std::pin::Pin;
47use tonic::metadata::MetadataMap;
48use tower::{Service, ServiceBuilder};
49
/// Callback that maps a [`GatewayError`], together with the request it occurred on, to a
/// [`GatewayResponse`]. The gateway hands it to `ErrorLayer` when the service stack is built.
pub type ErrorHandler = Arc<dyn Fn(&GatewayRequest, GatewayError) -> GatewayResponse + Send + Sync>;
52
/// Callback that produces a tonic [`MetadataMap`] from a request. The gateway hands the
/// configured annotators to `MetadataLayer` when the service stack is built.
pub type MetadataAnnotator = Arc<dyn Fn(&GatewayRequest) -> MetadataMap + Send + Sync>;
55
/// Callback that may mutate a [`GatewayResponse`] in place; it also receives the request.
/// The gateway hands the configured modifiers to `ResponseLayer` when the stack is built.
pub type ResponseModifier = Arc<dyn Fn(&GatewayRequest, &mut GatewayResponse) + Send + Sync>;
58
/// Callback that maps a string key to an optional owned string. It is used for both the
/// incoming and the outgoing header matcher handed to `HeaderLayer`.
// NOTE(review): presumably the input is a header name and `None` means "do not forward";
// confirm against `HeaderLayer`.
pub type HeaderMatcher = Arc<dyn Fn(&str) -> Option<String> + Send + Sync>;
61
/// Callback run after a route has matched and before the request is dispatched. It
/// receives the request and the matched route's [`RouteMetadata`]; returning `Err` makes
/// the router fail the request with that error instead of calling the route service.
pub type AuthVerifier =
    Arc<dyn Fn(&GatewayRequest, &RouteMetadata) -> Result<(), GatewayError> + Send + Sync>;
65
/// Callback invoked once the inner service has produced a result. It receives a snapshot
/// of the request (the body is not included), the result, and the elapsed time.
pub type MetricsRecorder = Arc<dyn Fn(&GatewayRequest, &GatewayResult, Duration) + Send + Sync>;
68
/// Callback invoked before a request is passed to the inner service. The opaque token it
/// returns is later handed to the [`TracingEndHandler`].
pub type TracingStartHandler =
    Arc<dyn Fn(&GatewayRequest) -> Box<dyn core::any::Any + Send> + Send + Sync>;
72
/// Callback invoked after the inner service has produced a result. It receives the token
/// created by the [`TracingStartHandler`] and a reference to the result.
pub type TracingEndHandler =
    Arc<dyn Fn(Box<dyn core::any::Any + Send>, &GatewayResult) + Send + Sync>;
76
/// Controls whether the request path is percent-decoded before it is matched against
/// the registered routes.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum UnescapingMode {
    /// Percent-decode every escape sequence in the path before route matching. If the
    /// decoded bytes are not valid UTF-8, the path is matched exactly as received.
    AllCharacters,
    /// Match the path exactly as received, without decoding. This is the mode a newly
    /// constructed gateway uses, and the value returned by `Default::default()`.
    #[default]
    Default,
}
85
86#[derive(Clone)]
91struct RouterService<S> {
92 router: Router<S>,
93 auth_verifier: Option<AuthVerifier>,
94 unescaping_mode: UnescapingMode,
95}
96
97impl<S> Service<GatewayRequest> for RouterService<S>
98where
99 S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>
100 + Clone
101 + Send
102 + 'static,
103 S::Future: Send + 'static,
104{
105 type Response = GatewayResponse;
106 type Error = GatewayError;
107 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
108
109 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 Poll::Ready(Ok(()))
111 }
112
113 fn call(&mut self, mut req: GatewayRequest) -> Self::Future {
114 let mut path = req.uri().path().to_string();
115
116 match self.unescaping_mode {
118 UnescapingMode::AllCharacters => {
119 if let Ok(decoded) = percent_decode_str(&path).decode_utf8() {
120 path = decoded.to_string();
121 }
122 }
123 UnescapingMode::Default => {}
124 }
125
126 let method = req.method().clone();
127
128 let match_result = self.router.match_request(&method, &path);
130
131 if let Some((service, params, metadata)) = match_result {
132 if let Some(verifier) = &self.auth_verifier {
134 match verifier(&req, metadata) {
136 Ok(_) => {}
137 Err(e) => return Box::pin(async move { Err(e) }),
138 }
139 }
140
141 req.extensions_mut().insert(params);
143
144 let mut service = service.clone();
145 Box::pin(async move { service.call(req).await })
146 } else {
147 Box::pin(async move { Err(GatewayError::NotFound) })
149 }
150 }
151}
152
153#[derive(Clone)]
155struct MetricsLayer<S> {
156 inner: S,
157 recorder: Option<MetricsRecorder>,
158}
159
160impl<S> Service<GatewayRequest> for MetricsLayer<S>
161where
162 S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>,
163 S::Future: Send + 'static,
164{
165 type Response = GatewayResponse;
166 type Error = GatewayError;
167 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
168
169 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
170 self.inner.poll_ready(cx)
171 }
172
173 fn call(&mut self, req: GatewayRequest) -> Self::Future {
174 let recorder = self.recorder.clone();
175
176 let method = req.method().clone();
178 let uri = req.uri().clone();
179 let headers = req.headers().clone();
180
181 let start_time = std::time::Instant::now();
182 let fut = self.inner.call(req);
183
184 Box::pin(async move {
185 let res = fut.await;
186 let duration = start_time.elapsed();
187
188 if let Some(rec) = recorder {
189 let mut partial_req = http::Request::builder()
191 .method(method)
192 .uri(uri)
193 .body(Vec::new())
194 .unwrap();
195 *partial_req.headers_mut() = headers;
196
197 rec(&partial_req, &res, duration);
198 }
199 res
200 })
201 }
202}
203
204#[derive(Clone)]
206struct TraceLayer<S> {
207 inner: S,
208 start: Option<TracingStartHandler>,
209 end: Option<TracingEndHandler>,
210}
211
212impl<S> Service<GatewayRequest> for TraceLayer<S>
213where
214 S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>,
215 S::Future: Send + 'static,
216{
217 type Response = GatewayResponse;
218 type Error = GatewayError;
219 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
220
221 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
222 self.inner.poll_ready(cx)
223 }
224
225 fn call(&mut self, req: GatewayRequest) -> Self::Future {
226 let token = if let Some(start) = &self.start {
227 Some(start(&req))
228 } else {
229 None
230 };
231
232 let end = self.end.clone();
233 let fut = self.inner.call(req);
234
235 Box::pin(async move {
236 let res = fut.await;
237 if let Some(end_handler) = end {
238 if let Some(t) = token {
239 end_handler(t, &res);
240 }
241 }
242 res
243 })
244 }
245}
246
247pub struct Gateway<S> {
252 router: Router<S>,
253 error_handler: Option<ErrorHandler>,
254 metadata_annotators: Vec<MetadataAnnotator>,
255 response_modifiers: Vec<ResponseModifier>,
256 incoming_header_matcher: Option<HeaderMatcher>,
257 outgoing_header_matcher: Option<HeaderMatcher>,
258 unescaping_mode: UnescapingMode,
259
260 auth_verifier: Option<AuthVerifier>,
261 metrics_recorder: Option<MetricsRecorder>,
262 tracing_start: Option<TracingStartHandler>,
263 tracing_end: Option<TracingEndHandler>,
264
265 metadata_config: MetadataForwardingConfig,
266}
267
268impl<S> Gateway<S> {
269 pub fn new(router: Router<S>) -> Self {
271 Self {
272 router,
273 error_handler: Some(Arc::new(defaults::default_error_handler)),
274 metadata_annotators: vec![Arc::new(defaults::default_metadata_annotator)],
275 response_modifiers: vec![Arc::new(defaults::default_response_modifier)],
276 incoming_header_matcher: Some(Arc::new(defaults::default_incoming_header_matcher)),
277 outgoing_header_matcher: Some(Arc::new(defaults::default_outgoing_header_matcher)),
278 unescaping_mode: UnescapingMode::Default,
279 auth_verifier: None, metrics_recorder: None,
281 tracing_start: None,
282 tracing_end: None,
283 metadata_config: MetadataForwardingConfig::default(),
284 }
285 }
286
287 pub fn with_error_handler<F>(mut self, handler: F) -> Self
289 where
290 F: Fn(&GatewayRequest, GatewayError) -> GatewayResponse + Send + Sync + 'static,
291 {
292 self.error_handler = Some(Arc::new(handler));
293 self
294 }
295
296 pub fn with_metadata<F>(mut self, annotator: F) -> Self
298 where
299 F: Fn(&GatewayRequest) -> MetadataMap + Send + Sync + 'static,
300 {
301 self.metadata_annotators.push(Arc::new(annotator));
302 self
303 }
304
305 pub fn with_response_modifier<F>(mut self, modifier: F) -> Self
307 where
308 F: Fn(&GatewayRequest, &mut GatewayResponse) + Send + Sync + 'static,
309 {
310 self.response_modifiers.push(Arc::new(modifier));
311 self
312 }
313
314 pub fn with_incoming_header_matcher<F>(mut self, matcher: F) -> Self
316 where
317 F: Fn(&str) -> Option<String> + Send + Sync + 'static,
318 {
319 self.incoming_header_matcher = Some(Arc::new(matcher));
320 self
321 }
322
323 pub fn with_outgoing_header_matcher<F>(mut self, matcher: F) -> Self
325 where
326 F: Fn(&str) -> Option<String> + Send + Sync + 'static,
327 {
328 self.outgoing_header_matcher = Some(Arc::new(matcher));
329 self
330 }
331
332 pub fn with_unescaping_mode(mut self, mode: UnescapingMode) -> Self {
334 self.unescaping_mode = mode;
335 self
336 }
337
338 pub fn with_auth_verifier<F>(mut self, verifier: F) -> Self
340 where
341 F: Fn(&GatewayRequest, &RouteMetadata) -> Result<(), GatewayError> + Send + Sync + 'static,
342 {
343 self.auth_verifier = Some(Arc::new(verifier));
344 self
345 }
346
347 pub fn with_metrics_recorder<F>(mut self, recorder: F) -> Self
349 where
350 F: Fn(&GatewayRequest, &GatewayResult, Duration) + Send + Sync + 'static,
351 {
352 self.metrics_recorder = Some(Arc::new(recorder));
353 self
354 }
355
356 pub fn with_tracing<Start, End>(mut self, start: Start, end: End) -> Self
358 where
359 Start: Fn(&GatewayRequest) -> Box<dyn core::any::Any + Send> + Send + Sync + 'static,
360 End: Fn(Box<dyn core::any::Any + Send>, &GatewayResult) + Send + Sync + 'static,
361 {
362 self.tracing_start = Some(Arc::new(start));
363 self.tracing_end = Some(Arc::new(end));
364 self
365 }
366
367 pub fn with_metadata_config(mut self, config: MetadataForwardingConfig) -> Self {
369 self.metadata_config = config;
370 self
371 }
372
373 pub fn metadata_config(&self) -> &MetadataForwardingConfig {
375 &self.metadata_config
376 }
377}
378
379impl<S> Gateway<S>
380where
381 S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>
382 + Clone
383 + Send
384 + 'static,
385 S::Future: Send + 'static,
386{
387 pub fn into_service(
391 self,
392 ) -> tower::util::BoxCloneService<GatewayRequest, GatewayResponse, GatewayError> {
393 let router_service = RouterService {
394 router: self.router,
395 auth_verifier: self.auth_verifier.clone(),
396 unescaping_mode: self.unescaping_mode,
397 };
398
399 let service = ServiceBuilder::new()
400 .layer_fn(|inner| TraceLayer {
401 inner,
402 start: self.tracing_start.clone(),
403 end: self.tracing_end.clone(),
404 })
405 .layer_fn(|inner| MetricsLayer {
406 inner,
407 recorder: self.metrics_recorder.clone(),
408 })
409 .layer_fn(|inner| ErrorLayer::new(inner, self.error_handler.clone()))
410 .layer_fn(|inner| ResponseLayer::new(inner, self.response_modifiers.clone()))
411 .layer_fn(|inner| {
412 HeaderLayer::new(
413 inner,
414 self.incoming_header_matcher.clone(),
415 self.outgoing_header_matcher.clone(),
416 )
417 })
418 .layer_fn(|inner| {
419 MetadataLayer::new(
420 inner,
421 self.metadata_annotators.clone(),
422 self.metadata_config.clone(),
423 )
424 })
425 .service(router_service);
426
427 tower::util::BoxCloneService::new(service)
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::string::ToString;
    use crate::router::{AuthConfig, AuthLocation, RouteMetadata, Router};
    use gateway_internal::path_template::{Op, OpCode, Pattern};
    use http::StatusCode;
    use http_body_util::BodyExt;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tower::util::BoxCloneService;

    /// Route service type shared by every router built in these tests.
    type TestService = BoxCloneService<GatewayRequest, GatewayResponse, GatewayError>;

    /// Pattern consisting of a single `LitPush` op over the literal `"test"`; the tests
    /// below request `/test` against it.
    fn test_pattern() -> Pattern {
        let push_literal = Op {
            code: OpCode::LitPush,
            operand: 0,
        };
        Pattern {
            ops: vec![push_literal],
            pool: vec!["test".to_string()],
            vars: vec![],
            stack_size: 1,
            tail_len: 0,
            verb: None,
        }
    }

    /// Plain `GET /test` request with an empty body.
    fn get_test_request() -> GatewayRequest {
        http::Request::builder()
            .method("GET")
            .uri("/test")
            .body(Vec::new())
            .unwrap()
    }

    /// Router with one `GET` route whose service answers `200 OK` with body `"ok"`,
    /// echoing the `x-bar` header and the `test-key` metadata entry when present.
    fn make_router() -> Router<TestService> {
        let echo = tower::service_fn(|req: GatewayRequest| async move {
            let mut builder = http::Response::builder().status(StatusCode::OK);
            if let Some(bar) = req.headers().get("x-bar") {
                builder = builder.header("x-echo-bar", bar);
            }
            let forwarded = req
                .extensions()
                .get::<MetadataMap>()
                .and_then(|md| md.get("test-key"));
            if let Some(value) = forwarded {
                builder = builder.header("x-meta-echo", value.to_str().unwrap());
            }
            Ok(builder
                .body(http_body_util::BodyExt::boxed_unsync(
                    http_body_util::Full::new(crate::bytes::Bytes::from("ok"))
                        .map_err(|_| unreachable!()),
                ))
                .unwrap())
        });

        let mut router = Router::new();
        crate::router::route(
            &mut router,
            http::Method::GET,
            test_pattern(),
            BoxCloneService::new(echo),
        );
        router
    }

    /// Router with one `GET` route that answers with an empty default response and is
    /// registered with metadata requiring an `X-Key` API-key header.
    fn auth_protected_router() -> Router<TestService> {
        let mut router: Router<TestService> = Router::new();
        let service = tower::service_fn(|_| async {
            Ok(http::Response::new(BodyExt::boxed_unsync(
                http_body_util::Full::new(crate::bytes::Bytes::new()).map_err(|_| unreachable!()),
            )))
        });

        let meta = RouteMetadata {
            auth_required: Some(AuthConfig {
                scheme: "ApiKey".to_string(),
                location: AuthLocation::Header,
                name: "X-Key".to_string(),
            }),
        };
        crate::router::route_with_metadata(
            &mut router,
            http::Method::GET,
            test_pattern(),
            BoxCloneService::new(service),
            meta,
        );
        router
    }

    /// The metrics recorder is called exactly once per request, with an `Ok` result and
    /// a non-zero duration.
    #[tokio::test]
    async fn test_gateway_metrics() {
        let recorded = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&recorded);

        let gateway = Gateway::new(make_router()).with_metrics_recorder(move |_, res, dur| {
            counter.fetch_add(1, Ordering::SeqCst);
            assert!(res.is_ok());
            assert!(dur.as_nanos() > 0);
        });

        let mut service = gateway.into_service();
        let _ = service.call(get_test_request()).await;
        assert_eq!(recorded.load(Ordering::SeqCst), 1);
    }

    /// Both tracing hooks run once, and the end hook receives the start hook's token.
    #[tokio::test]
    async fn test_gateway_tracing() {
        let hook_calls = Arc::new(AtomicUsize::new(0));
        let on_start_calls = Arc::clone(&hook_calls);
        let on_end_calls = Arc::clone(&hook_calls);

        let gateway = Gateway::new(make_router()).with_tracing(
            move |_| {
                on_start_calls.fetch_add(1, Ordering::SeqCst);
                Box::new(123u32)
            },
            move |token, _| {
                let val = *token.downcast::<u32>().unwrap();
                assert_eq!(val, 123);
                on_end_calls.fetch_add(1, Ordering::SeqCst);
            },
        );

        let mut service = gateway.into_service();
        let _ = service.call(get_test_request()).await;
        assert_eq!(hook_calls.load(Ordering::SeqCst), 2);
    }

    /// A request carrying the required header passes the verifier and reaches the route.
    #[tokio::test]
    async fn test_gateway_auth_verifier_success() {
        let gateway = Gateway::new(auth_protected_router()).with_auth_verifier(|req, meta| {
            let key_present = meta.auth_required.as_ref().map_or(false, |auth| {
                auth.location == AuthLocation::Header && req.headers().contains_key(&auth.name)
            });
            if key_present {
                Ok(())
            } else {
                Err(GatewayError::Upstream(tonic::Status::unauthenticated(
                    "missing key",
                )))
            }
        });

        let mut service = gateway.into_service();
        let req = http::Request::builder()
            .method("GET")
            .uri("/test")
            .header("X-Key", "secret")
            .body(Vec::new())
            .unwrap();
        let resp = service.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// A rejecting verifier short-circuits the request, and the custom error handler
    /// turns the failure into a `401 Unauthorized` response.
    #[tokio::test]
    async fn test_gateway_auth_verifier_fail() {
        let gateway = Gateway::new(auth_protected_router())
            .with_auth_verifier(|_, _| {
                Err(GatewayError::Upstream(tonic::Status::unauthenticated(
                    "fail",
                )))
            })
            .with_error_handler(|_, _| {
                http::Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body(BodyExt::boxed_unsync(
                        http_body_util::Full::new(crate::bytes::Bytes::new())
                            .map_err(|_| unreachable!()),
                    ))
                    .unwrap()
            });

        let mut service = gateway.into_service();
        let resp = service.call(get_test_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}