1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::future;
5use std::net::{SocketAddr, ToSocketAddrs};
6use std::sync::Arc;
7
8use anyhow::anyhow;
9use http::Extensions;
10use http_acl::utils::authority::{Authority, Host};
11use reqwest::{
12 Request, Response,
13 dns::{Name, Resolve, Resolving},
14};
15use reqwest_middleware::{Error, Middleware, Next};
16use thiserror::Error;
17
18pub use http_acl::{self, HttpAcl, HttpAclBuilder};
19
20#[derive(Debug, Clone)]
21pub struct HttpAclMiddleware {
23 acl: Arc<HttpAcl>,
24}
25
26impl HttpAclMiddleware {
27 pub fn new(acl: HttpAcl) -> Self {
29 Self { acl: Arc::new(acl) }
30 }
31
32 pub fn acl(&self) -> Arc<HttpAcl> {
34 self.acl.clone()
35 }
36
37 pub fn dns_resolver(&self) -> Arc<HttpAclDnsResolver> {
39 Arc::new(HttpAclDnsResolver::new(self))
40 }
41
42 pub fn with_dns_resolver(&self, dns_resolver: Arc<dyn Resolve>) -> Arc<HttpAclDnsResolver> {
44 Arc::new(HttpAclDnsResolver::with_dns_resolver(self, dns_resolver))
45 }
46}
47
48#[async_trait::async_trait]
49impl Middleware for HttpAclMiddleware {
50 async fn handle(
51 &self,
52 req: Request,
53 extensions: &mut Extensions,
54 next: Next<'_>,
55 ) -> std::result::Result<Response, Error> {
56 let scheme = req.url().scheme();
57 let acl_scheme_match = self.acl.is_scheme_allowed(scheme);
58 if acl_scheme_match.is_denied() {
59 return Err(Error::Middleware(anyhow!(
60 "scheme {} is denied - {}",
61 scheme,
62 acl_scheme_match
63 )));
64 }
65
66 let method = req.method().as_str();
67 let acl_method_match = self.acl.is_method_allowed(method);
68 if acl_method_match.is_denied() {
69 return Err(Error::Middleware(anyhow!(
70 "method {} is denied - {}",
71 method,
72 acl_method_match
73 )));
74 }
75
76 if let Some(host) = req.url().host_str() {
77 let authority = Authority::parse(host)
78 .map_err(|_| Error::Middleware(anyhow!("invalid host: {}", host)))?;
79
80 match &authority.host {
81 Host::Ip(ip) => {
82 let acl_ip_match = self.acl.is_ip_allowed(ip);
83 if acl_ip_match.is_denied() {
84 return Err(Error::Middleware(anyhow!(
85 "ip {} is denied - {}",
86 ip,
87 acl_ip_match
88 )));
89 }
90 }
91 Host::Domain(domain) => {
92 let acl_host_match = self.acl.is_host_allowed(domain);
93 if acl_host_match.is_denied() {
94 return Err(Error::Middleware(anyhow!(
95 "host {} is denied - {}",
96 domain,
97 acl_host_match
98 )));
99 }
100 }
101 }
102
103 if let Some(port) = req.url().port_or_known_default() {
104 let acl_port_match = self.acl.is_port_allowed(port);
105 if acl_port_match.is_denied() {
106 return Err(Error::Middleware(anyhow!(
107 "port {} is denied - {}",
108 port,
109 acl_port_match
110 )));
111 }
112 }
113
114 for (key, value) in req.headers() {
115 let header_name = key.as_str();
116 let header_value = value.to_str().map_err(|_| {
117 Error::Middleware(anyhow!("invalid header value for {}", header_name))
118 })?;
119 let acl_header_match = self.acl.is_header_allowed(header_name, header_value);
120 if acl_header_match.is_denied() {
121 return Err(Error::Middleware(anyhow!(
122 "header {}: {} is denied - {}",
123 header_name,
124 header_value,
125 acl_header_match
126 )));
127 }
128 }
129
130 let acl_url_path_match = self.acl.is_url_path_allowed(req.url().path());
131 if acl_url_path_match.is_denied() {
132 return Err(Error::Middleware(anyhow!(
133 "path {} is denied - {}",
134 req.url().path(),
135 acl_url_path_match
136 )));
137 }
138
139 let valid_match = self.acl.is_valid(
140 scheme,
141 &authority,
142 req.headers()
143 .iter()
144 .filter_map(|(k, v)| Some((k.as_str(), v.to_str().ok()?))),
145 req.body().and_then(|b| b.as_bytes()),
146 );
147 if valid_match.is_denied() {
148 return Err(Error::Middleware(anyhow!(
149 "request is denied - {}",
150 valid_match
151 )));
152 }
153
154 next.run(req, extensions).await
155 } else {
156 return Err(Error::Middleware(anyhow!("missing host")));
157 }
158 }
159}
160
161type BoxError = Box<dyn std::error::Error + Send + Sync>;
162
163struct GaiResolver;
164
165impl Resolve for GaiResolver {
166 fn resolve(&self, name: Name) -> Resolving {
167 Box::pin(async move {
168 let addresses = name
169 .as_str()
170 .to_socket_addrs()
171 .map_err(|e| Box::new(e) as BoxError)?;
172 Ok(Box::new(addresses.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
173 })
174 }
175}
176
177pub struct HttpAclDnsResolver {
179 dns_resolver: Arc<dyn Resolve>,
180 acl: Arc<HttpAcl>,
181}
182
183impl HttpAclDnsResolver {
184 pub fn new(middleware: &HttpAclMiddleware) -> Self {
186 Self {
187 dns_resolver: Arc::new(GaiResolver),
188 acl: middleware.acl(),
189 }
190 }
191
192 pub fn with_dns_resolver(
194 middleware: &HttpAclMiddleware,
195 dns_resolver: Arc<dyn Resolve>,
196 ) -> Self {
197 Self {
198 dns_resolver,
199 acl: middleware.acl(),
200 }
201 }
202}
203
204impl Resolve for HttpAclDnsResolver {
205 fn resolve(&self, name: Name) -> Resolving {
206 if self.acl.is_host_allowed(name.as_str()).is_denied() {
207 let err: BoxError = Box::new(std::io::Error::other("Host denied by ACL"));
208 return Box::pin(future::ready(Err(err)));
209 }
210
211 let acl = self.acl.clone();
212 let resolver = self.dns_resolver.clone();
213
214 Box::pin(async move {
215 if let Some(tcp_address) = acl.resolve_static_dns_mapping(name.as_str()) {
216 Ok(Box::new(std::iter::once(tcp_address))
217 as Box<dyn Iterator<Item = SocketAddr> + Send>)
218 } else {
219 let resolved = resolver.resolve(name).await;
220 match resolved {
221 Ok(addresses) => {
222 let filtered = addresses
223 .into_iter()
224 .filter(|addr| {
225 acl.is_ip_allowed(&addr.ip()).is_allowed()
226 && acl.is_port_allowed(addr.port()).is_allowed()
227 })
228 .collect::<Vec<_>>();
229 Ok(Box::new(filtered.into_iter())
230 as Box<dyn Iterator<Item = SocketAddr> + Send>)
231 }
232 Err(e) => Err(e),
233 }
234 }
235 })
236 }
237}
238
239#[derive(Error, Debug)]
240pub enum HttpAclError {
242 #[error("Host resolution denied by ACL: {host}")]
244 HostDenied {
245 host: String,
247 },
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// A request to a host on the deny list must be rejected by the middleware
    /// with the host-denial message.
    #[tokio::test]
    async fn test_http_acl_middleware() {
        let middleware = HttpAclMiddleware::new(
            HttpAcl::builder()
                .add_denied_host("example.com".to_string())
                .unwrap()
                .build(),
        );

        let inner = reqwest::Client::builder()
            .dns_resolver(middleware.dns_resolver())
            .build()
            .unwrap();
        let client = reqwest_middleware::ClientBuilder::new(inner)
            .with(middleware)
            .build();

        let err = client
            .get("http://example.com/")
            .send()
            .await
            .expect_err("a request to a denied host must fail");

        assert_eq!(
            err.to_string(),
            "host example.com is denied - The entity is denied according to the denied ACL."
        );
    }
}