http_acl_reqwest/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::future;
5use std::net::{SocketAddr, ToSocketAddrs};
6use std::sync::Arc;
7
8use anyhow::anyhow;
9use http::Extensions;
10use http_acl::utils::authority::{Authority, Host};
11use reqwest::{
12    Request, Response,
13    dns::{Name, Resolve, Resolving},
14};
15use reqwest_middleware::{Error, Middleware, Next};
16use thiserror::Error;
17
18pub use http_acl::{self, HttpAcl, HttpAclBuilder};
19
20#[derive(Debug, Clone)]
21/// A reqwest middleware that enforces an HTTP ACL.
22pub struct HttpAclMiddleware {
23    acl: Arc<HttpAcl>,
24}
25
26impl HttpAclMiddleware {
27    /// Create a new HTTP ACL middleware.
28    pub fn new(acl: HttpAcl) -> Self {
29        Self { acl: Arc::new(acl) }
30    }
31
32    /// Get the ACL.
33    pub fn acl(&self) -> Arc<HttpAcl> {
34        self.acl.clone()
35    }
36
37    /// Create a DNS resolver that enforces the ACL.
38    pub fn dns_resolver(&self) -> Arc<HttpAclDnsResolver> {
39        Arc::new(HttpAclDnsResolver::new(self))
40    }
41
42    /// Create a DNS resolver that enforces the ACL with a custom DNS resolver.
43    pub fn with_dns_resolver(&self, dns_resolver: Arc<dyn Resolve>) -> Arc<HttpAclDnsResolver> {
44        Arc::new(HttpAclDnsResolver::with_dns_resolver(self, dns_resolver))
45    }
46}
47
48#[async_trait::async_trait]
49impl Middleware for HttpAclMiddleware {
50    async fn handle(
51        &self,
52        req: Request,
53        extensions: &mut Extensions,
54        next: Next<'_>,
55    ) -> std::result::Result<Response, Error> {
56        let scheme = req.url().scheme();
57        let acl_scheme_match = self.acl.is_scheme_allowed(scheme);
58        if acl_scheme_match.is_denied() {
59            return Err(Error::Middleware(anyhow!(
60                "scheme {} is denied - {}",
61                scheme,
62                acl_scheme_match
63            )));
64        }
65
66        let method = req.method().as_str();
67        let acl_method_match = self.acl.is_method_allowed(method);
68        if acl_method_match.is_denied() {
69            return Err(Error::Middleware(anyhow!(
70                "method {} is denied - {}",
71                method,
72                acl_method_match
73            )));
74        }
75
76        if let Some(host) = req.url().host_str() {
77            let authority = Authority::parse(host)
78                .map_err(|_| Error::Middleware(anyhow!("invalid host: {}", host)))?;
79
80            match authority.host {
81                Host::Ip(ip) => {
82                    let acl_ip_match = self.acl.is_ip_allowed(&ip);
83                    if acl_ip_match.is_denied() {
84                        return Err(Error::Middleware(anyhow!(
85                            "ip {} is denied - {}",
86                            ip,
87                            acl_ip_match
88                        )));
89                    }
90                }
91                Host::Domain(domain) => {
92                    let acl_host_match = self.acl.is_host_allowed(&domain);
93                    if acl_host_match.is_denied() {
94                        return Err(Error::Middleware(anyhow!(
95                            "host {} is denied - {}",
96                            domain,
97                            acl_host_match
98                        )));
99                    }
100                }
101            }
102
103            if let Some(port) = req.url().port_or_known_default() {
104                let acl_port_match = self.acl.is_port_allowed(port);
105                if acl_port_match.is_denied() {
106                    return Err(Error::Middleware(anyhow!(
107                        "port {} is denied - {}",
108                        port,
109                        acl_port_match
110                    )));
111                }
112            }
113
114            let acl_url_path_match = self.acl.is_url_path_allowed(req.url().path());
115            if acl_url_path_match.is_denied() {
116                return Err(Error::Middleware(anyhow!(
117                    "path {} is denied - {}",
118                    req.url().path(),
119                    acl_url_path_match
120                )));
121            }
122
123            next.run(req, extensions).await
124        } else {
125            return Err(Error::Middleware(anyhow!("missing host")));
126        }
127    }
128}
129
/// Boxed, thread-safe error used for DNS resolution failures in this file.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
131
132struct GaiResolver;
133
134impl Resolve for GaiResolver {
135    fn resolve(&self, name: Name) -> Resolving {
136        Box::pin(async move {
137            let addresses = name
138                .as_str()
139                .to_socket_addrs()
140                .map_err(|e| Box::new(e) as BoxError)?;
141            Ok(Box::new(addresses.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
142        })
143    }
144}
145
146/// A DNS resolver that enforces an HTTP ACL.
147pub struct HttpAclDnsResolver {
148    dns_resolver: Arc<dyn Resolve>,
149    acl: Arc<HttpAcl>,
150}
151
152impl HttpAclDnsResolver {
153    /// Create a new ACL resolver.
154    pub fn new(middleware: &HttpAclMiddleware) -> Self {
155        Self {
156            dns_resolver: Arc::new(GaiResolver),
157            acl: middleware.acl(),
158        }
159    }
160
161    /// Create a new ACL resolver with a custom DNS resolver.
162    pub fn with_dns_resolver(
163        middleware: &HttpAclMiddleware,
164        dns_resolver: Arc<dyn Resolve>,
165    ) -> Self {
166        Self {
167            dns_resolver,
168            acl: middleware.acl(),
169        }
170    }
171}
172
173impl Resolve for HttpAclDnsResolver {
174    fn resolve(&self, name: Name) -> Resolving {
175        if self.acl.is_host_allowed(name.as_str()).is_denied() {
176            let err: BoxError = Box::new(std::io::Error::new(
177                std::io::ErrorKind::Other,
178                "Host denied by ACL",
179            ));
180            return Box::pin(future::ready(Err(err)));
181        }
182
183        let acl = self.acl.clone();
184        let resolver = self.dns_resolver.clone();
185
186        Box::pin(async move {
187            let resolved = resolver.resolve(name).await;
188            match resolved {
189                Ok(addresses) => {
190                    let filtered = addresses
191                        .into_iter()
192                        .filter(|addr| {
193                            acl.is_ip_allowed(&addr.ip()).is_allowed()
194                                && acl.is_port_allowed(addr.port()).is_allowed()
195                        })
196                        .collect::<Vec<_>>();
197                    Ok(Box::new(filtered.into_iter())
198                        as Box<dyn Iterator<Item = SocketAddr> + Send>)
199                }
200                Err(e) => Err(e),
201            }
202        })
203    }
204}
205
206#[derive(Error, Debug)]
207/// An error that can occur when resolving a host.
208pub enum HttpAclError {
209    /// Host resolution denied by ACL.
210    #[error("Host resolution denied by ACL: {host}")]
211    HostDenied {
212        /// The host that was denied.
213        host: String,
214    },
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// A request to a host on the ACL's deny list must fail with the
    /// middleware's "host … is denied" error.
    #[tokio::test]
    async fn test_http_acl_middleware() {
        let acl = HttpAcl::builder()
            .add_denied_host("example.com".to_string())
            .unwrap()
            .build();
        let middleware = HttpAclMiddleware::new(acl);

        // The resolver is created before the middleware is moved into the
        // client builder.
        let inner_client = reqwest::Client::builder()
            .dns_resolver(middleware.dns_resolver())
            .build()
            .unwrap();
        let client = reqwest_middleware::ClientBuilder::new(inner_client)
            .with(middleware)
            .build();

        let outcome = client.get("http://example.com/").send().await;

        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        // NOTE(review): the spelling "entiy" has to match the emitted message
        // exactly; presumably it comes from `http_acl`'s own text — confirm
        // before "correcting" it here.
        assert_eq!(
            err.to_string(),
            "host example.com is denied - The entiy is denied according to the denied ACL."
        );
    }
}
247}