1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::future;
5use std::net::{SocketAddr, ToSocketAddrs};
6use std::sync::Arc;
7
8use anyhow::anyhow;
9use http::Extensions;
10use http_acl::utils::authority::{Authority, Host};
11use reqwest::{
12 Request, Response,
13 dns::{Name, Resolve, Resolving},
14};
15use reqwest_middleware::{Error, Middleware, Next};
16use thiserror::Error;
17
18pub use http_acl::{self, HttpAcl, HttpAclBuilder};
19
/// A `reqwest_middleware` middleware that checks each outgoing request against
/// an [`HttpAcl`] before passing it on.
#[derive(Debug, Clone)]
pub struct HttpAclMiddleware {
    /// The ACL consulted for every request. Stored behind an [`Arc`] so that
    /// handles to it can be given out without copying the ACL itself.
    acl: Arc<HttpAcl>,
}
25
26impl HttpAclMiddleware {
27 pub fn new(acl: HttpAcl) -> Self {
29 Self { acl: Arc::new(acl) }
30 }
31
32 pub fn acl(&self) -> Arc<HttpAcl> {
34 self.acl.clone()
35 }
36
37 pub fn dns_resolver(&self) -> Arc<HttpAclDnsResolver> {
39 Arc::new(HttpAclDnsResolver::new(self))
40 }
41
42 pub fn with_dns_resolver(&self, dns_resolver: Arc<dyn Resolve>) -> Arc<HttpAclDnsResolver> {
44 Arc::new(HttpAclDnsResolver::with_dns_resolver(self, dns_resolver))
45 }
46}
47
48#[async_trait::async_trait]
49impl Middleware for HttpAclMiddleware {
50 async fn handle(
51 &self,
52 req: Request,
53 extensions: &mut Extensions,
54 next: Next<'_>,
55 ) -> std::result::Result<Response, Error> {
56 let scheme = req.url().scheme();
57 let acl_scheme_match = self.acl.is_scheme_allowed(scheme);
58 if acl_scheme_match.is_denied() {
59 return Err(Error::Middleware(anyhow!(
60 "scheme {} is denied - {}",
61 scheme,
62 acl_scheme_match
63 )));
64 }
65
66 let method = req.method().as_str();
67 let acl_method_match = self.acl.is_method_allowed(method);
68 if acl_method_match.is_denied() {
69 return Err(Error::Middleware(anyhow!(
70 "method {} is denied - {}",
71 method,
72 acl_method_match
73 )));
74 }
75
76 if let Some(host) = req.url().host_str() {
77 let authority = Authority::parse(host)
78 .map_err(|_| Error::Middleware(anyhow!("invalid host: {}", host)))?;
79
80 match authority.host {
81 Host::Ip(ip) => {
82 let acl_ip_match = self.acl.is_ip_allowed(&ip);
83 if acl_ip_match.is_denied() {
84 return Err(Error::Middleware(anyhow!(
85 "ip {} is denied - {}",
86 ip,
87 acl_ip_match
88 )));
89 }
90 }
91 Host::Domain(domain) => {
92 let acl_host_match = self.acl.is_host_allowed(&domain);
93 if acl_host_match.is_denied() {
94 return Err(Error::Middleware(anyhow!(
95 "host {} is denied - {}",
96 domain,
97 acl_host_match
98 )));
99 }
100 }
101 }
102
103 if let Some(port) = req.url().port_or_known_default() {
104 let acl_port_match = self.acl.is_port_allowed(port);
105 if acl_port_match.is_denied() {
106 return Err(Error::Middleware(anyhow!(
107 "port {} is denied - {}",
108 port,
109 acl_port_match
110 )));
111 }
112 }
113
114 for (key, value) in req.headers() {
115 let header_name = key.as_str();
116 let header_value = value.to_str().map_err(|_| {
117 Error::Middleware(anyhow!("invalid header value for {}", header_name))
118 })?;
119 let acl_header_match = self.acl.is_header_allowed(header_name, header_value);
120 if acl_header_match.is_denied() {
121 return Err(Error::Middleware(anyhow!(
122 "header {}: {} is denied - {}",
123 header_name,
124 header_value,
125 acl_header_match
126 )));
127 }
128 }
129
130 let acl_url_path_match = self.acl.is_url_path_allowed(req.url().path());
131 if acl_url_path_match.is_denied() {
132 return Err(Error::Middleware(anyhow!(
133 "path {} is denied - {}",
134 req.url().path(),
135 acl_url_path_match
136 )));
137 }
138
139 next.run(req, extensions).await
140 } else {
141 return Err(Error::Middleware(anyhow!("missing host")));
142 }
143 }
144}
145
146type BoxError = Box<dyn std::error::Error + Send + Sync>;
147
148struct GaiResolver;
149
150impl Resolve for GaiResolver {
151 fn resolve(&self, name: Name) -> Resolving {
152 Box::pin(async move {
153 let addresses = name
154 .as_str()
155 .to_socket_addrs()
156 .map_err(|e| Box::new(e) as BoxError)?;
157 Ok(Box::new(addresses.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
158 })
159 }
160}
161
/// A DNS resolver that applies an [`HttpAcl`] to the host name being resolved
/// and to the addresses the lookup returns.
pub struct HttpAclDnsResolver {
    /// Resolver that performs the actual lookup for hosts without a static
    /// mapping in the ACL.
    dns_resolver: Arc<dyn Resolve>,
    /// ACL used to vet host names and resolved addresses.
    acl: Arc<HttpAcl>,
}
167
168impl HttpAclDnsResolver {
169 pub fn new(middleware: &HttpAclMiddleware) -> Self {
171 Self {
172 dns_resolver: Arc::new(GaiResolver),
173 acl: middleware.acl(),
174 }
175 }
176
177 pub fn with_dns_resolver(
179 middleware: &HttpAclMiddleware,
180 dns_resolver: Arc<dyn Resolve>,
181 ) -> Self {
182 Self {
183 dns_resolver,
184 acl: middleware.acl(),
185 }
186 }
187}
188
189impl Resolve for HttpAclDnsResolver {
190 fn resolve(&self, name: Name) -> Resolving {
191 if self.acl.is_host_allowed(name.as_str()).is_denied() {
192 let err: BoxError = Box::new(std::io::Error::other("Host denied by ACL"));
193 return Box::pin(future::ready(Err(err)));
194 }
195
196 let acl = self.acl.clone();
197 let resolver = self.dns_resolver.clone();
198
199 Box::pin(async move {
200 if let Some(tcp_address) = acl.resolve_static_dns_mapping(name.as_str()) {
201 Ok(Box::new(std::iter::once(tcp_address))
202 as Box<dyn Iterator<Item = SocketAddr> + Send>)
203 } else {
204 let resolved = resolver.resolve(name).await;
205 match resolved {
206 Ok(addresses) => {
207 let filtered = addresses
208 .into_iter()
209 .filter(|addr| {
210 acl.is_ip_allowed(&addr.ip()).is_allowed()
211 && acl.is_port_allowed(addr.port()).is_allowed()
212 })
213 .collect::<Vec<_>>();
214 Ok(Box::new(filtered.into_iter())
215 as Box<dyn Iterator<Item = SocketAddr> + Send>)
216 }
217 Err(e) => Err(e),
218 }
219 }
220 })
221 }
222}
223
/// Errors describing a denial by the ACL.
#[derive(Error, Debug)]
pub enum HttpAclError {
    /// Resolution of a host was denied by the ACL.
    #[error("Host resolution denied by ACL: {host}")]
    HostDenied {
        /// The host that was denied; included in the error message.
        host: String,
    },
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// A request to a host on the denied list must fail with the middleware's
    /// "host ... is denied" error.
    #[tokio::test]
    async fn test_http_acl_middleware() {
        let middleware = HttpAclMiddleware::new(
            HttpAcl::builder()
                .add_denied_host("example.com".to_string())
                .unwrap()
                .build(),
        );

        // The inner client resolves through the ACL-aware resolver; the
        // middleware itself is layered on top of that client.
        let inner = reqwest::Client::builder()
            .dns_resolver(middleware.dns_resolver())
            .build()
            .unwrap();
        let client = reqwest_middleware::ClientBuilder::new(inner)
            .with(middleware)
            .build();

        let outcome = client.get("http://example.com/").send().await;

        assert!(outcome.is_err());
        // The text after " - " is produced by the ACL match value at runtime
        // (including the spelling "entiy"), so it is reproduced verbatim here.
        let message = outcome.unwrap_err().to_string();
        assert_eq!(
            message,
            "host example.com is denied - The entiy is denied according to the denied ACL."
        );
    }
}