#![doc = include_str!("../examples/rewrite_uri.rs")]
use std::task::{Context, Poll};
use futures_util::future::{Either, Ready, ready};
use tower_layer::Layer;
use tower_service::Service;
/// Produces the URI a request should be forwarded to.
///
/// [`RewriteUriService`] calls [`rewrite_uri`](RewriteUri::rewrite_uri) once per request,
/// passing the request's current URI, and installs the returned URI in its place.
pub trait RewriteUri {
    /// Error produced when a URI cannot be rewritten.
    ///
    /// [`RewriteUriService`] requires this to convert (via `Into`) into the wrapped service's
    /// error type.
    type Error;

    /// Returns the replacement for `uri`, or an error that stops the request from being
    /// forwarded to the inner service.
    ///
    /// Takes `&mut self`, so implementations may keep mutable state between calls.
    fn rewrite_uri(&mut self, uri: &http::Uri) -> Result<http::Uri, Self::Error>;
}
/// Lets plain closures and functions act as URI rewriters.
///
/// Any `FnMut(&http::Uri) -> Result<http::Uri, E>` qualifies; its error type `E` becomes
/// [`RewriteUri::Error`].
impl<E, F> RewriteUri for F
where
    F: FnMut(&http::Uri) -> Result<http::Uri, E>,
{
    type Error = E;

    fn rewrite_uri(&mut self, uri: &http::Uri) -> Result<http::Uri, Self::Error> {
        (*self)(uri)
    }
}
/// [`Layer`] that wraps services in a [`RewriteUriService`] driven by the rewriter `R`.
///
/// Every service produced by this layer gets its own clone of the stored rewriter.
#[derive(Debug, Clone)]
pub struct RewriteUriLayer<R> {
    /// Template rewriter, cloned into each wrapped service.
    rewrite: R,
}

impl<R> RewriteUriLayer<R> {
    /// Creates a layer that hands a clone of `rewrite` to every service it wraps.
    pub fn new(rewrite: R) -> Self {
        RewriteUriLayer { rewrite }
    }
}
/// Wraps `inner` in a [`RewriteUriService`] carrying a fresh clone of this layer's rewriter.
impl<S, R> Layer<S> for RewriteUriLayer<R>
where
    R: Clone,
{
    type Service = RewriteUriService<S, R>;

    fn layer(&self, inner: S) -> Self::Service {
        // The layer may be applied many times, so it keeps its own copy and hands out clones.
        let rewrite = R::clone(&self.rewrite);
        RewriteUriService::new(inner, rewrite)
    }
}
/// Middleware that replaces each request's URI (using the rewriter `R`) before passing the
/// request on to the wrapped service `S`.
#[derive(Debug, Clone)]
pub struct RewriteUriService<S, R> {
    /// The wrapped service that receives the rewritten request.
    inner: S,
    /// The rewriter consulted once per request.
    rewrite: R,
}

impl<S, R> RewriteUriService<S, R> {
    /// Wraps `inner` so that every request's URI is passed through `rewrite` first.
    pub fn new(inner: S, rewrite: R) -> Self {
        RewriteUriService { inner, rewrite }
    }
}
impl<S, R, ReqBody> Service<http::Request<ReqBody>> for RewriteUriService<S, R>
where
S: Service<http::Request<ReqBody>>,
R: RewriteUri,
R::Error: Into<S::Error>,
{
type Response = S::Response;
type Error = S::Error;
type Future = Either<Ready<Result<S::Response, S::Error>>, S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
match self.rewrite.rewrite_uri(req.uri()) {
Ok(new_uri) => {
*req.uri_mut() = new_uri;
Either::Right(self.inner.call(req))
}
Err(e) => Either::Left(ready(Err(e.into()))),
}
}
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use http::{Request, Response, Uri};
    use tower::{ServiceBuilder, service_fn};
    use tower_layer::Layer as _;
    use tower_service::Service as _;

    use super::{RewriteUri, RewriteUriLayer, RewriteUriService};

    /// Inner service that echoes the URI it was handed back as the response body. This lets
    /// each test observe exactly what the middleware forwarded.
    fn capture_uri_service()
    -> impl tower_service::Service<Request<()>, Response = Response<String>, Error = Infallible>
    {
        service_fn(|req: Request<()>| async move {
            let forwarded = req.uri().to_string();
            Ok::<_, Infallible>(Response::new(forwarded))
        })
    }

    /// Builds an empty-bodied request aimed at `uri`.
    fn request_to(uri: &str) -> Request<()> {
        Request::builder().uri(uri).body(()).unwrap()
    }

    #[tokio::test]
    async fn test_rewrite_uri_with_closure() {
        let mut svc = RewriteUriService::new(capture_uri_service(), |_uri: &Uri| {
            Ok::<_, Infallible>(Uri::from_static("http://example.com/rewritten"))
        });

        let body = svc.call(request_to("/original")).await.unwrap().into_body();

        assert_eq!(body, "http://example.com/rewritten");
    }

    #[tokio::test]
    async fn test_rewrite_uri_layer() {
        let layer = RewriteUriLayer::new(|_uri: &Uri| {
            Ok::<_, Infallible>(Uri::from_static("http://example.com/via-layer"))
        });
        let mut svc = layer.layer(capture_uri_service());

        let body = svc.call(request_to("/original")).await.unwrap().into_body();

        assert_eq!(body, "http://example.com/via-layer");
    }

    #[tokio::test]
    async fn test_rewrite_uri_service_builder() {
        let mut svc = ServiceBuilder::new()
            .layer(RewriteUriLayer::new(|uri: &Uri| {
                let path = match uri.path_and_query() {
                    Some(pq) => pq.as_str(),
                    None => "/",
                };
                let new_uri: Uri = format!("http://example.com{path}").parse().unwrap();
                Ok::<_, Infallible>(new_uri)
            }))
            .service(capture_uri_service());

        let body = svc.call(request_to("/hello")).await.unwrap().into_body();

        assert_eq!(body, "http://example.com/hello");
    }

    #[tokio::test]
    async fn test_rewrite_uri_error_propagates() {
        let inner =
            service_fn(|_: Request<()>| async { Ok::<_, String>(Response::new("ok".to_string())) });
        let mut svc = RewriteUriService::new(inner, |_uri: &Uri| {
            Err::<Uri, String>("rewrite failed".to_string())
        });

        let outcome = svc.call(request_to("/original")).await;

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), "rewrite failed");
    }

    #[tokio::test]
    async fn test_rewrite_uri_struct_impl() {
        #[derive(Clone)]
        struct PrependBase {
            base: &'static str,
        }

        impl RewriteUri for PrependBase {
            type Error = Infallible;

            fn rewrite_uri(&mut self, uri: &Uri) -> Result<Uri, Self::Error> {
                let tail = uri.path_and_query().map_or("/", |pq| pq.as_str());
                let joined = format!("{}{tail}", self.base);
                Ok(joined.parse().unwrap())
            }
        }

        let rewriter = PrependBase {
            base: "http://backend.internal",
        };
        let mut svc = RewriteUriLayer::new(rewriter).layer(capture_uri_service());

        let body = svc.call(request_to("/api/users")).await.unwrap().into_body();

        assert_eq!(body, "http://backend.internal/api/users");
    }
}