http_acl_reqwest/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::future;
5use std::net::{SocketAddr, ToSocketAddrs};
6use std::sync::Arc;
7
8use anyhow::anyhow;
9use http::Extensions;
10use http_acl::utils::authority::{Authority, Host};
11use reqwest::{
12    Request, Response,
13    dns::{Name, Resolve, Resolving},
14};
15use reqwest_middleware::{Error, Middleware, Next};
16use thiserror::Error;
17
18pub use http_acl::{self, HttpAcl, HttpAclBuilder};
19
20#[derive(Debug, Clone)]
21/// A reqwest middleware that enforces an HTTP ACL.
22pub struct HttpAclMiddleware {
23    acl: Arc<HttpAcl>,
24}
25
26impl HttpAclMiddleware {
27    /// Create a new HTTP ACL middleware.
28    pub fn new(acl: HttpAcl) -> Self {
29        Self { acl: Arc::new(acl) }
30    }
31
32    /// Get the ACL.
33    pub fn acl(&self) -> Arc<HttpAcl> {
34        self.acl.clone()
35    }
36
37    /// Create a DNS resolver that enforces the ACL.
38    pub fn dns_resolver(&self) -> Arc<HttpAclDnsResolver> {
39        Arc::new(HttpAclDnsResolver::new(self))
40    }
41
42    /// Create a DNS resolver that enforces the ACL with a custom DNS resolver.
43    pub fn with_dns_resolver(&self, dns_resolver: Arc<dyn Resolve>) -> Arc<HttpAclDnsResolver> {
44        Arc::new(HttpAclDnsResolver::with_dns_resolver(self, dns_resolver))
45    }
46}
47
48#[async_trait::async_trait]
49impl Middleware for HttpAclMiddleware {
50    async fn handle(
51        &self,
52        req: Request,
53        extensions: &mut Extensions,
54        next: Next<'_>,
55    ) -> std::result::Result<Response, Error> {
56        let scheme = req.url().scheme();
57        let acl_scheme_match = self.acl.is_scheme_allowed(scheme);
58        if acl_scheme_match.is_denied() {
59            return Err(Error::Middleware(anyhow!(
60                "scheme {} is denied - {}",
61                scheme,
62                acl_scheme_match
63            )));
64        }
65
66        let method = req.method().as_str();
67        let acl_method_match = self.acl.is_method_allowed(method);
68        if acl_method_match.is_denied() {
69            return Err(Error::Middleware(anyhow!(
70                "method {} is denied - {}",
71                method,
72                acl_method_match
73            )));
74        }
75
76        if let Some(host) = req.url().host_str() {
77            let authority = Authority::parse(host)
78                .map_err(|_| Error::Middleware(anyhow!("invalid host: {}", host)))?;
79
80            match authority.host {
81                Host::Ip(ip) => {
82                    let acl_ip_match = self.acl.is_ip_allowed(&ip);
83                    if acl_ip_match.is_denied() {
84                        return Err(Error::Middleware(anyhow!(
85                            "ip {} is denied - {}",
86                            ip,
87                            acl_ip_match
88                        )));
89                    }
90                }
91                Host::Domain(domain) => {
92                    let acl_host_match = self.acl.is_host_allowed(&domain);
93                    if acl_host_match.is_denied() {
94                        return Err(Error::Middleware(anyhow!(
95                            "host {} is denied - {}",
96                            domain,
97                            acl_host_match
98                        )));
99                    }
100                }
101            }
102
103            if let Some(port) = req.url().port_or_known_default() {
104                let acl_port_match = self.acl.is_port_allowed(port);
105                if acl_port_match.is_denied() {
106                    return Err(Error::Middleware(anyhow!(
107                        "port {} is denied - {}",
108                        port,
109                        acl_port_match
110                    )));
111                }
112            }
113
114            for (key, value) in req.headers() {
115                let header_name = key.as_str();
116                let header_value = value.to_str().map_err(|_| {
117                    Error::Middleware(anyhow!("invalid header value for {}", header_name))
118                })?;
119                let acl_header_match = self.acl.is_header_allowed(header_name, header_value);
120                if acl_header_match.is_denied() {
121                    return Err(Error::Middleware(anyhow!(
122                        "header {}: {} is denied - {}",
123                        header_name,
124                        header_value,
125                        acl_header_match
126                    )));
127                }
128            }
129
130            let acl_url_path_match = self.acl.is_url_path_allowed(req.url().path());
131            if acl_url_path_match.is_denied() {
132                return Err(Error::Middleware(anyhow!(
133                    "path {} is denied - {}",
134                    req.url().path(),
135                    acl_url_path_match
136                )));
137            }
138
139            next.run(req, extensions).await
140        } else {
141            return Err(Error::Middleware(anyhow!("missing host")));
142        }
143    }
144}
145
146type BoxError = Box<dyn std::error::Error + Send + Sync>;
147
148struct GaiResolver;
149
150impl Resolve for GaiResolver {
151    fn resolve(&self, name: Name) -> Resolving {
152        Box::pin(async move {
153            let addresses = name
154                .as_str()
155                .to_socket_addrs()
156                .map_err(|e| Box::new(e) as BoxError)?;
157            Ok(Box::new(addresses.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
158        })
159    }
160}
161
162/// A DNS resolver that enforces an HTTP ACL.
163pub struct HttpAclDnsResolver {
164    dns_resolver: Arc<dyn Resolve>,
165    acl: Arc<HttpAcl>,
166}
167
168impl HttpAclDnsResolver {
169    /// Create a new ACL resolver.
170    pub fn new(middleware: &HttpAclMiddleware) -> Self {
171        Self {
172            dns_resolver: Arc::new(GaiResolver),
173            acl: middleware.acl(),
174        }
175    }
176
177    /// Create a new ACL resolver with a custom DNS resolver.
178    pub fn with_dns_resolver(
179        middleware: &HttpAclMiddleware,
180        dns_resolver: Arc<dyn Resolve>,
181    ) -> Self {
182        Self {
183            dns_resolver,
184            acl: middleware.acl(),
185        }
186    }
187}
188
189impl Resolve for HttpAclDnsResolver {
190    fn resolve(&self, name: Name) -> Resolving {
191        if self.acl.is_host_allowed(name.as_str()).is_denied() {
192            let err: BoxError = Box::new(std::io::Error::other("Host denied by ACL"));
193            return Box::pin(future::ready(Err(err)));
194        }
195
196        let acl = self.acl.clone();
197        let resolver = self.dns_resolver.clone();
198
199        Box::pin(async move {
200            if let Some(tcp_address) = acl.resolve_static_dns_mapping(name.as_str()) {
201                Ok(Box::new(std::iter::once(tcp_address))
202                    as Box<dyn Iterator<Item = SocketAddr> + Send>)
203            } else {
204                let resolved = resolver.resolve(name).await;
205                match resolved {
206                    Ok(addresses) => {
207                        let filtered = addresses
208                            .into_iter()
209                            .filter(|addr| {
210                                acl.is_ip_allowed(&addr.ip()).is_allowed()
211                                    && acl.is_port_allowed(addr.port()).is_allowed()
212                            })
213                            .collect::<Vec<_>>();
214                        Ok(Box::new(filtered.into_iter())
215                            as Box<dyn Iterator<Item = SocketAddr> + Send>)
216                    }
217                    Err(e) => Err(e),
218                }
219            }
220        })
221    }
222}
223
224#[derive(Error, Debug)]
225/// An error that can occur when resolving a host.
226pub enum HttpAclError {
227    /// Host resolution denied by ACL.
228    #[error("Host resolution denied by ACL: {host}")]
229    HostDenied {
230        /// The host that was denied.
231        host: String,
232    },
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// A request to a host on the deny list must be rejected by the
    /// middleware with the host-denied message.
    #[tokio::test]
    async fn test_http_acl_middleware() {
        let deny_example = HttpAcl::builder()
            .add_denied_host("example.com".to_string())
            .unwrap()
            .build();
        let middleware = HttpAclMiddleware::new(deny_example);

        // The inner client resolves through the ACL-aware resolver as well.
        let inner_client = reqwest::Client::builder()
            .dns_resolver(middleware.dns_resolver())
            .build()
            .unwrap();
        let client = reqwest_middleware::ClientBuilder::new(inner_client)
            .with(middleware)
            .build();

        let outcome = client.get("http://example.com/").send().await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert_eq!(
            message,
            "host example.com is denied - The entiy is denied according to the denied ACL."
        );
    }
}