http_acl_reqwest/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::future;
5use std::net::{SocketAddr, ToSocketAddrs};
6use std::sync::Arc;
7
8use anyhow::anyhow;
9use http::Extensions;
10use http_acl::utils::authority::{Authority, Host};
11use reqwest::{
12    Request, Response,
13    dns::{Name, Resolve, Resolving},
14};
15use reqwest_middleware::{Error, Middleware, Next};
16use thiserror::Error;
17
18pub use http_acl::{self, HttpAcl, HttpAclBuilder};
19
20#[derive(Debug, Clone)]
21/// A reqwest middleware that enforces an HTTP ACL.
22pub struct HttpAclMiddleware {
23    acl: Arc<HttpAcl>,
24}
25
26impl HttpAclMiddleware {
27    /// Create a new HTTP ACL middleware.
28    pub fn new(acl: HttpAcl) -> Self {
29        Self { acl: Arc::new(acl) }
30    }
31
32    /// Get the ACL.
33    pub fn acl(&self) -> Arc<HttpAcl> {
34        self.acl.clone()
35    }
36
37    /// Create a DNS resolver that enforces the ACL.
38    pub fn dns_resolver(&self) -> Arc<HttpAclDnsResolver> {
39        Arc::new(HttpAclDnsResolver::new(self))
40    }
41
42    /// Create a DNS resolver that enforces the ACL with a custom DNS resolver.
43    pub fn with_dns_resolver(&self, dns_resolver: Arc<dyn Resolve>) -> Arc<HttpAclDnsResolver> {
44        Arc::new(HttpAclDnsResolver::with_dns_resolver(self, dns_resolver))
45    }
46}
47
48#[async_trait::async_trait]
49impl Middleware for HttpAclMiddleware {
50    async fn handle(
51        &self,
52        req: Request,
53        extensions: &mut Extensions,
54        next: Next<'_>,
55    ) -> std::result::Result<Response, Error> {
56        let scheme = req.url().scheme();
57        let acl_scheme_match = self.acl.is_scheme_allowed(scheme);
58        if acl_scheme_match.is_denied() {
59            return Err(Error::Middleware(anyhow!(
60                "scheme {} is denied - {}",
61                scheme,
62                acl_scheme_match
63            )));
64        }
65
66        let method = req.method().as_str();
67        let acl_method_match = self.acl.is_method_allowed(method);
68        if acl_method_match.is_denied() {
69            return Err(Error::Middleware(anyhow!(
70                "method {} is denied - {}",
71                method,
72                acl_method_match
73            )));
74        }
75
76        if let Some(host) = req.url().host_str() {
77            let authority = Authority::parse(host)
78                .map_err(|_| Error::Middleware(anyhow!("invalid host: {}", host)))?;
79
80            match &authority.host {
81                Host::Ip(ip) => {
82                    let acl_ip_match = self.acl.is_ip_allowed(ip);
83                    if acl_ip_match.is_denied() {
84                        return Err(Error::Middleware(anyhow!(
85                            "ip {} is denied - {}",
86                            ip,
87                            acl_ip_match
88                        )));
89                    }
90                }
91                Host::Domain(domain) => {
92                    let acl_host_match = self.acl.is_host_allowed(domain);
93                    if acl_host_match.is_denied() {
94                        return Err(Error::Middleware(anyhow!(
95                            "host {} is denied - {}",
96                            domain,
97                            acl_host_match
98                        )));
99                    }
100                }
101            }
102
103            if let Some(port) = req.url().port_or_known_default() {
104                let acl_port_match = self.acl.is_port_allowed(port);
105                if acl_port_match.is_denied() {
106                    return Err(Error::Middleware(anyhow!(
107                        "port {} is denied - {}",
108                        port,
109                        acl_port_match
110                    )));
111                }
112            }
113
114            for (key, value) in req.headers() {
115                let header_name = key.as_str();
116                let header_value = value.to_str().map_err(|_| {
117                    Error::Middleware(anyhow!("invalid header value for {}", header_name))
118                })?;
119                let acl_header_match = self.acl.is_header_allowed(header_name, header_value);
120                if acl_header_match.is_denied() {
121                    return Err(Error::Middleware(anyhow!(
122                        "header {}: {} is denied - {}",
123                        header_name,
124                        header_value,
125                        acl_header_match
126                    )));
127                }
128            }
129
130            let acl_url_path_match = self.acl.is_url_path_allowed(req.url().path());
131            if acl_url_path_match.is_denied() {
132                return Err(Error::Middleware(anyhow!(
133                    "path {} is denied - {}",
134                    req.url().path(),
135                    acl_url_path_match
136                )));
137            }
138
139            let valid_match = self.acl.is_valid(
140                scheme,
141                &authority,
142                req.headers()
143                    .iter()
144                    .filter_map(|(k, v)| Some((k.as_str(), v.to_str().ok()?))),
145                req.body().and_then(|b| b.as_bytes()),
146            );
147            if valid_match.is_denied() {
148                return Err(Error::Middleware(anyhow!(
149                    "request is denied - {}",
150                    valid_match
151                )));
152            }
153
154            next.run(req, extensions).await
155        } else {
156            return Err(Error::Middleware(anyhow!("missing host")));
157        }
158    }
159}
160
161type BoxError = Box<dyn std::error::Error + Send + Sync>;
162
163struct GaiResolver;
164
165impl Resolve for GaiResolver {
166    fn resolve(&self, name: Name) -> Resolving {
167        Box::pin(async move {
168            let addresses = name
169                .as_str()
170                .to_socket_addrs()
171                .map_err(|e| Box::new(e) as BoxError)?;
172            Ok(Box::new(addresses.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
173        })
174    }
175}
176
177/// A DNS resolver that enforces an HTTP ACL.
178pub struct HttpAclDnsResolver {
179    dns_resolver: Arc<dyn Resolve>,
180    acl: Arc<HttpAcl>,
181}
182
183impl HttpAclDnsResolver {
184    /// Create a new ACL resolver.
185    pub fn new(middleware: &HttpAclMiddleware) -> Self {
186        Self {
187            dns_resolver: Arc::new(GaiResolver),
188            acl: middleware.acl(),
189        }
190    }
191
192    /// Create a new ACL resolver with a custom DNS resolver.
193    pub fn with_dns_resolver(
194        middleware: &HttpAclMiddleware,
195        dns_resolver: Arc<dyn Resolve>,
196    ) -> Self {
197        Self {
198            dns_resolver,
199            acl: middleware.acl(),
200        }
201    }
202}
203
204impl Resolve for HttpAclDnsResolver {
205    fn resolve(&self, name: Name) -> Resolving {
206        if self.acl.is_host_allowed(name.as_str()).is_denied() {
207            let err: BoxError = Box::new(std::io::Error::other("Host denied by ACL"));
208            return Box::pin(future::ready(Err(err)));
209        }
210
211        let acl = self.acl.clone();
212        let resolver = self.dns_resolver.clone();
213
214        Box::pin(async move {
215            if let Some(tcp_address) = acl.resolve_static_dns_mapping(name.as_str()) {
216                Ok(Box::new(std::iter::once(tcp_address))
217                    as Box<dyn Iterator<Item = SocketAddr> + Send>)
218            } else {
219                let resolved = resolver.resolve(name).await;
220                match resolved {
221                    Ok(addresses) => {
222                        let filtered = addresses
223                            .into_iter()
224                            .filter(|addr| {
225                                acl.is_ip_allowed(&addr.ip()).is_allowed()
226                                    && acl.is_port_allowed(addr.port()).is_allowed()
227                            })
228                            .collect::<Vec<_>>();
229                        Ok(Box::new(filtered.into_iter())
230                            as Box<dyn Iterator<Item = SocketAddr> + Send>)
231                    }
232                    Err(e) => Err(e),
233                }
234            }
235        })
236    }
237}
238
239#[derive(Error, Debug)]
240/// An error that can occur when resolving a host.
241pub enum HttpAclError {
242    /// Host resolution denied by ACL.
243    #[error("Host resolution denied by ACL: {host}")]
244    HostDenied {
245        /// The host that was denied.
246        host: String,
247    },
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_http_acl_middleware() {
        let middleware = HttpAclMiddleware::new(
            HttpAcl::builder()
                .add_denied_host("example.com".to_string())
                .unwrap()
                .build(),
        );

        let http_client = reqwest::Client::builder()
            .dns_resolver(middleware.dns_resolver())
            .build()
            .unwrap();
        let client = reqwest_middleware::ClientBuilder::new(http_client)
            .with(middleware)
            .build();

        // The middleware must reject the denied host before any DNS lookup or
        // network traffic happens.
        let err = client
            .get("http://example.com/")
            .send()
            .await
            .expect_err("request to a denied host must fail");

        assert_eq!(
            err.to_string(),
            "host example.com is denied - The entity is denied according to the denied ACL."
        );
    }
}