use std::fmt;
use crate::Request;
use rama_core::{Service, telemetry::tracing};
use rama_net::http::uri::{UriMatchError, UriMatchReplace};
use rama_utils::macros::define_inner_service_accessors;
use std::borrow::Cow;
/// Middleware that passes each request's URI through a match-replace rule
/// before handing the request to the wrapped service.
///
/// When the rule produces a URI, it replaces the request URI. When the rule
/// reports no match or fails, the request is forwarded with its URI untouched.
pub struct RewriteUriService<R, S> {
    /// Rule used to match and rewrite the request URI.
    match_replace: R,
    /// Wrapped service that receives the (possibly rewritten) request.
    inner: S,
}
// Debug output names both the match-replace rule and the wrapped service.
// It requires both type parameters to be `Debug`.
impl<R, S> fmt::Debug for RewriteUriService<R, S>
where
    R: fmt::Debug,
    S: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            match_replace,
            inner,
        } = self;
        let mut out = f.debug_struct("RewriteUriService");
        out.field("match_replace", match_replace);
        out.field("inner", inner);
        out.finish()
    }
}
// Cloning produces an independent service by cloning both the rule and the
// wrapped service. It requires both type parameters to be `Clone`.
impl<R, S> Clone for RewriteUriService<R, S>
where
    R: Clone,
    S: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            match_replace,
            inner,
        } = self;
        Self {
            match_replace: match_replace.clone(),
            inner: inner.clone(),
        }
    }
}
impl<R, S> RewriteUriService<R, S> {
    /// Creates a new [`RewriteUriService`].
    ///
    /// - `match_replace` is the rule applied to the URI of every request served.
    /// - `service` is the inner service that receives the request after the
    ///   rewrite attempt, whether or not the rule matched.
    #[must_use]
    pub fn new(match_replace: R, service: S) -> Self {
        Self {
            match_replace,
            inner: service,
        }
    }
}
impl<R, S> RewriteUriService<R, S> {
    define_inner_service_accessors!();

    /// Borrows the match-replace rule used by this service.
    #[must_use]
    pub fn match_replace_ref(&self) -> &R {
        let Self { match_replace, .. } = self;
        match_replace
    }

    /// Mutably borrows the match-replace rule, allowing it to be adjusted in place.
    #[must_use]
    pub fn match_replace_mut(&mut self) -> &mut R {
        let Self { match_replace, .. } = self;
        match_replace
    }
}
impl<ReqBody, R, S> Service<Request<ReqBody>> for RewriteUriService<R, S>
where
S: Service<Request<ReqBody>>,
R: UriMatchReplace + Send + Sync + 'static,
ReqBody: Send + 'static,
{
type Output = S::Output;
type Error = S::Error;
async fn serve(&self, mut req: Request<ReqBody>) -> Result<Self::Output, Self::Error> {
let full_uri = req.request_uri();
if let Ok(uri) = self
.match_replace
.match_replace_uri(Cow::Owned(full_uri))
.inspect_err(|err| match err {
UriMatchError::NoMatch(uri) => {
tracing::trace!("no match found for uri: {uri}; ignore")
}
UriMatchError::Unexpected(err) => {
tracing::trace!("unexpected error while trying to match uri: {err}; ignore")
}
})
{
*req.uri_mut() = uri.into_owned()
}
self.inner.serve(req).await
}
}