1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::future;
5use std::net::{SocketAddr, ToSocketAddrs};
6use std::sync::Arc;
7
8use anyhow::anyhow;
9use http::Extensions;
10use http_acl::utils::authority::{Authority, Host};
11use reqwest::{
12 Request, Response,
13 dns::{Name, Resolve, Resolving},
14};
15use reqwest_middleware::{Error, Middleware, Next};
16use thiserror::Error;
17
18pub use http_acl::{self, HttpAcl, HttpAclBuilder};
19
20#[derive(Debug, Clone)]
21pub struct HttpAclMiddleware {
23 acl: Arc<HttpAcl>,
24}
25
26impl HttpAclMiddleware {
27 pub fn new(acl: HttpAcl) -> Self {
29 Self { acl: Arc::new(acl) }
30 }
31
32 pub fn acl(&self) -> Arc<HttpAcl> {
34 self.acl.clone()
35 }
36
37 pub fn dns_resolver(&self) -> Arc<HttpAclDnsResolver> {
39 Arc::new(HttpAclDnsResolver::new(self))
40 }
41
42 pub fn with_dns_resolver(&self, dns_resolver: Arc<dyn Resolve>) -> Arc<HttpAclDnsResolver> {
44 Arc::new(HttpAclDnsResolver::with_dns_resolver(self, dns_resolver))
45 }
46}
47
48#[async_trait::async_trait]
49impl Middleware for HttpAclMiddleware {
50 async fn handle(
51 &self,
52 req: Request,
53 extensions: &mut Extensions,
54 next: Next<'_>,
55 ) -> std::result::Result<Response, Error> {
56 let scheme = req.url().scheme();
57 let acl_scheme_match = self.acl.is_scheme_allowed(scheme);
58 if acl_scheme_match.is_denied() {
59 return Err(Error::Middleware(anyhow!(
60 "scheme {} is denied - {}",
61 scheme,
62 acl_scheme_match
63 )));
64 }
65
66 let method = req.method().as_str();
67 let acl_method_match = self.acl.is_method_allowed(method);
68 if acl_method_match.is_denied() {
69 return Err(Error::Middleware(anyhow!(
70 "method {} is denied - {}",
71 method,
72 acl_method_match
73 )));
74 }
75
76 if let Some(host) = req.url().host_str() {
77 let authority = Authority::parse(host)
78 .map_err(|_| Error::Middleware(anyhow!("invalid host: {}", host)))?;
79
80 match authority.host {
81 Host::Ip(ip) => {
82 let acl_ip_match = self.acl.is_ip_allowed(&ip);
83 if acl_ip_match.is_denied() {
84 return Err(Error::Middleware(anyhow!(
85 "ip {} is denied - {}",
86 ip,
87 acl_ip_match
88 )));
89 }
90 }
91 Host::Domain(domain) => {
92 let acl_host_match = self.acl.is_host_allowed(&domain);
93 if acl_host_match.is_denied() {
94 return Err(Error::Middleware(anyhow!(
95 "host {} is denied - {}",
96 domain,
97 acl_host_match
98 )));
99 }
100 }
101 }
102
103 if let Some(port) = req.url().port_or_known_default() {
104 let acl_port_match = self.acl.is_port_allowed(port);
105 if acl_port_match.is_denied() {
106 return Err(Error::Middleware(anyhow!(
107 "port {} is denied - {}",
108 port,
109 acl_port_match
110 )));
111 }
112 }
113
114 let acl_url_path_match = self.acl.is_url_path_allowed(req.url().path());
115 if acl_url_path_match.is_denied() {
116 return Err(Error::Middleware(anyhow!(
117 "path {} is denied - {}",
118 req.url().path(),
119 acl_url_path_match
120 )));
121 }
122
123 next.run(req, extensions).await
124 } else {
125 return Err(Error::Middleware(anyhow!("missing host")));
126 }
127 }
128}
129
/// Boxed, thread-safe error type used for failures returned by the resolvers in this file.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
131
132struct GaiResolver;
133
134impl Resolve for GaiResolver {
135 fn resolve(&self, name: Name) -> Resolving {
136 Box::pin(async move {
137 let addresses = name
138 .as_str()
139 .to_socket_addrs()
140 .map_err(|e| Box::new(e) as BoxError)?;
141 Ok(Box::new(addresses.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
142 })
143 }
144}
145
146pub struct HttpAclDnsResolver {
148 dns_resolver: Arc<dyn Resolve>,
149 acl: Arc<HttpAcl>,
150}
151
152impl HttpAclDnsResolver {
153 pub fn new(middleware: &HttpAclMiddleware) -> Self {
155 Self {
156 dns_resolver: Arc::new(GaiResolver),
157 acl: middleware.acl(),
158 }
159 }
160
161 pub fn with_dns_resolver(
163 middleware: &HttpAclMiddleware,
164 dns_resolver: Arc<dyn Resolve>,
165 ) -> Self {
166 Self {
167 dns_resolver,
168 acl: middleware.acl(),
169 }
170 }
171}
172
173impl Resolve for HttpAclDnsResolver {
174 fn resolve(&self, name: Name) -> Resolving {
175 if self.acl.is_host_allowed(name.as_str()).is_denied() {
176 let err: BoxError = Box::new(std::io::Error::new(
177 std::io::ErrorKind::Other,
178 "Host denied by ACL",
179 ));
180 return Box::pin(future::ready(Err(err)));
181 }
182
183 let acl = self.acl.clone();
184 let resolver = self.dns_resolver.clone();
185
186 Box::pin(async move {
187 let resolved = resolver.resolve(name).await;
188 match resolved {
189 Ok(addresses) => {
190 let filtered = addresses
191 .into_iter()
192 .filter(|addr| {
193 acl.is_ip_allowed(&addr.ip()).is_allowed()
194 && acl.is_port_allowed(addr.port()).is_allowed()
195 })
196 .collect::<Vec<_>>();
197 Ok(Box::new(filtered.into_iter())
198 as Box<dyn Iterator<Item = SocketAddr> + Send>)
199 }
200 Err(e) => Err(e),
201 }
202 })
203 }
204}
205
206#[derive(Error, Debug)]
207pub enum HttpAclError {
209 #[error("Host resolution denied by ACL: {host}")]
211 HostDenied {
212 host: String,
214 },
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// A request to a host on the ACL's deny list is rejected by the middleware
    /// before it is forwarded to the inner client.
    #[tokio::test]
    async fn test_http_acl_middleware() {
        let denied_acl = HttpAcl::builder()
            .add_denied_host("example.com".to_string())
            .unwrap()
            .build();
        let middleware = HttpAclMiddleware::new(denied_acl);

        let inner = reqwest::Client::builder()
            .dns_resolver(middleware.dns_resolver())
            .build()
            .unwrap();
        let client = reqwest_middleware::ClientBuilder::new(inner)
            .with(middleware)
            .build();

        let outcome = client.get("http://example.com/").send().await;
        assert!(outcome.is_err());

        // NOTE(review): "entiy" presumably reproduces the upstream http_acl
        // message verbatim — confirm before correcting the spelling here.
        let message = outcome.unwrap_err().to_string();
        assert_eq!(
            message,
            "host example.com is denied - The entiy is denied according to the denied ACL."
        );
    }
}