tower_request_modifier/
lib.rs

1#![doc(html_root_url = "https://docs.rs/tower-request-modifier/0.1.0")]
2#![deny(missing_docs, missing_debug_implementations, unreachable_pub)]
3#![cfg_attr(test, deny(warnings))]
4
5//! A `tower::Service` middleware to modify the request.
6
7use futures::Poll;
8use http::header::{HeaderName, HeaderValue};
9use http::uri::{self, Uri};
10use http::{HttpTryFrom, Request};
11use std::fmt;
12use std::sync::Arc;
13use tower_service::Service;
14
15/// Wraps an HTTP service, injecting authority and scheme on every request.
16pub struct RequestModifier<T, B> {
17    inner: T,
18    modifiers: Arc<Vec<Box<dyn Fn(Request<B>) -> Request<B> + Send + Sync>>>,
19}
20
21impl<T, B> std::fmt::Debug for RequestModifier<T, B> {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
23        writeln!(f, "RequestModifier with {} modifiers", self.modifiers.len())
24    }
25}
26
27/// Configure an `RequestModifier` instance
28pub struct Builder<B> {
29    modifiers: Vec<Result<Box<dyn Fn(Request<B>) -> Request<B> + Send + Sync>, BuilderError>>,
30}
31
32impl<B> Default for Builder<B> {
33    fn default() -> Self {
34        Builder {
35            modifiers: Vec::default(),
36        }
37    }
38}
39
/// Errors that can happen when building a `RequestModifier`.
///
/// Produced when a header name or value given to `add_header` cannot be
/// converted, or when the URI given to `set_origin` is unparsable, lacks a
/// scheme or authority, or has a path other than `/`.
#[derive(Debug)]
pub struct BuilderError {
    _p: (),
}

impl fmt::Display for BuilderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid RequestModifier configuration")
    }
}

// Lets callers propagate this with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for BuilderError {}
45
46// ===== impl RequestModifier ======
47
48impl<T, B> RequestModifier<T, B> {
49    /// Create a new `RequestModifier`
50    pub fn new(
51        inner: T,
52        modifiers: Arc<Vec<Box<Fn(Request<B>) -> Request<B> + Send + Sync>>>,
53    ) -> Self {
54        RequestModifier {
55            inner: inner,
56            modifiers: modifiers,
57        }
58    }
59
60    /// Returns a reference to the inner service.
61    pub fn get_ref(&self) -> &T {
62        &self.inner
63    }
64
65    /// Returns a mutable reference to the inner service.
66    pub fn get_mut(&mut self) -> &mut T {
67        &mut self.inner
68    }
69
70    /// Consumes `self`, returning the inner service.
71    pub fn into_inner(self) -> T {
72        self.inner
73    }
74}
75
76impl<T, B> Service<Request<B>> for RequestModifier<T, B>
77where
78    T: Service<Request<B>>,
79{
80    type Response = T::Response;
81    type Error = T::Error;
82    type Future = T::Future;
83
84    fn poll_ready(&mut self) -> Poll<(), Self::Error> {
85        self.inner.poll_ready()
86    }
87
88    fn call(&mut self, mut req: Request<B>) -> Self::Future {
89        let mods = &self.modifiers;
90        for m in mods.iter() {
91            req = m(req);
92        }
93
94        // Call the inner service
95        self.inner.call(req)
96    }
97}
98
99impl<T, B> Clone for RequestModifier<T, B>
100where
101    T: Clone,
102{
103    fn clone(&self) -> Self {
104        RequestModifier {
105            inner: self.inner.clone(),
106            modifiers: self.modifiers.clone(),
107        }
108    }
109}
110
111// ===== impl Builder ======
112
113impl<B> Builder<B> {
114    /// Return a new, default builder
115    pub fn new() -> Self {
116        Builder::default()
117    }
118
119    /// Build a Fn to add desired header
120    fn make_add_header(
121        name: HeaderName,
122        val: HeaderValue,
123    ) -> Box<Fn(Request<B>) -> Request<B> + Send + Sync> {
124        Box::new(move |mut req: Request<B>| {
125            req.headers_mut().append(name.clone(), val.clone());
126            req
127        })
128    }
129
130    /// Set a header on all requests.
131    pub fn add_header<T: ToString, R>(mut self, name: T, val: R) -> Self
132    where
133        HeaderName: HttpTryFrom<T>,
134        HeaderValue: HttpTryFrom<R>,
135    {
136        let name = HeaderName::try_from(name);
137        let val = HeaderValue::try_from(val);
138
139        let err = BuilderError { _p: () };
140
141        let modification = match (name, val) {
142            (Ok(name), Ok(val)) => Ok(Self::make_add_header(name, val)),
143            (_, _) => Err(err),
144        };
145
146        self.modifiers.push(modification);
147        self
148    }
149
150    /// Build a Fn to perform desired Request origin modification
151    fn make_set_origin(
152        scheme: uri::Scheme,
153        authority: uri::Authority,
154    ) -> Box<Fn(Request<B>) -> Request<B> + Send + Sync> {
155        Box::new(move |req: Request<B>| {
156            // Split the request into the head and the body.
157            let (mut head, body) = req.into_parts();
158
159            // Split the request URI into parts.
160            let mut uri: http::uri::Parts = head.uri.into();
161
162            // Update the URI parts, setting the scheme and authority
163            uri.authority = Some(authority.clone());
164            uri.scheme = Some(scheme.clone());
165
166            // Update the the request URI
167            head.uri = http::Uri::from_parts(uri).expect("valid uri");
168
169            Request::from_parts(head, body)
170        })
171    }
172
173    /// Set the URI to use as the origin for all requests.
174    pub fn set_origin<T>(mut self, uri: T) -> Self
175    where
176        Uri: HttpTryFrom<T>,
177    {
178        let modification = Uri::try_from(uri)
179            .map_err(|_| BuilderError { _p: () })
180            .and_then(|u| {
181                let parts = uri::Parts::from(u);
182
183                let scheme = parts.scheme.ok_or(BuilderError { _p: () })?;
184                let authority = parts.authority.ok_or(BuilderError { _p: () })?;
185
186                let check = match parts.path_and_query {
187                    None => Ok(()),
188                    Some(ref path) if path == "/" => Ok(()),
189                    _ => Err(BuilderError { _p: () }),
190                };
191
192                check.and_then(|_| Ok(Self::make_set_origin(scheme, authority)))
193            });
194
195        self.modifiers.push(modification);
196        self
197    }
198
199    /// Run an arbitrary modifier on all requests
200    pub fn add_modifier(
201        mut self,
202        modifier: Box<Fn(Request<B>) -> Request<B> + Send + Sync>,
203    ) -> Self {
204        self.modifiers.push(Ok(modifier));
205        self
206    }
207
208    /// Build the `RequestModifier` from the provided settings.
209    pub fn build<T>(self, inner: T) -> Result<RequestModifier<T, B>, BuilderError> {
210        let modifiers = self.modifiers.into_iter().collect::<Result<Vec<_>, _>>()?;
211        Ok(RequestModifier::new(inner, Arc::new(modifiers)))
212    }
213}
214
215impl<B> fmt::Debug for Builder<B> {
216    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
217        write!(f, "RequestModifierBuilder")
218    }
219}