1#[cfg(feature = "graphql")]
2use async_graphql::{InputValueError, InputValueResult, Scalar, ScalarType, Value};
3
4#[cfg(feature = "serialize")]
5mod serialize;
6
7use cidr::IpCidr;
8use std::collections::{hash_set::IntoIter, HashSet};
9use std::net::IpAddr;
10use std::str::FromStr;
11
12#[derive(Clone, Debug, Eq, Hash, PartialEq)]
13pub enum NoProxyItem {
14 Wildcard,
15 IpCidr(String, IpCidr),
16 WithDot(String, bool, bool),
17 Plain(String),
18}
19
// GraphQL scalar representation of a single `NoProxyItem`: it travels as a plain string.
// (Plain `//` comments are used so the `#[Scalar]` macro sees the same attributes as before.)
#[cfg(feature = "graphql")]
#[Scalar]
impl ScalarType for NoProxyItem {
    // Only string literals are accepted; the text is classified by `NoProxyItem::from`.
    fn parse(value: Value) -> InputValueResult<Self> {
        if let Value::String(raw) = value {
            Ok(NoProxyItem::from(raw))
        } else {
            Err(InputValueError::expected_type(value))
        }
    }

    fn is_valid(value: &Value) -> bool {
        matches!(value, Value::String(_))
    }

    // Emits the entry's text (the same text its `Display` impl writes).
    fn to_value(&self) -> Value {
        Value::String(self.as_str().to_owned())
    }
}
38
39impl NoProxyItem {
40 fn as_str(&self) -> &str {
41 match self {
42 Self::Wildcard => "*",
43 Self::IpCidr(value, _) | Self::WithDot(value, _, _) | Self::Plain(value) => {
44 value.as_str()
45 }
46 }
47 }
48}
49
50impl std::fmt::Display for NoProxyItem {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.write_str(self.as_str())
53 }
54}
55
56impl<V: AsRef<str> + Into<String>> From<V> for NoProxyItem {
57 fn from(value: V) -> Self {
58 let value_str = value.as_ref();
59 if value_str == "*" {
60 Self::Wildcard
61 } else if let Ok(ip_cidr) = IpCidr::from_str(value_str) {
62 Self::IpCidr(value.into(), ip_cidr)
63 } else {
64 let start = value_str.starts_with('.');
65 let end = value_str.ends_with('.');
66 if start || end {
67 Self::WithDot(value.into(), start, end)
68 } else {
69 Self::Plain(value.into())
70 }
71 }
72 }
73}
74
/// Removes the square brackets around an IPv6 literal such as `[::1]`.
///
/// Inputs that do not start with `[` are returned untouched. Otherwise every leading and
/// trailing `[` / `]` character is trimmed.
fn parse_host(input: &str) -> &str {
    if !input.starts_with('[') {
        return input;
    }
    input.trim_matches(|ch: char| ch == '[' || ch == ']')
}
85
86impl NoProxyItem {
87 pub fn matches(&self, value: &str) -> bool {
88 let value = parse_host(value);
89 match self {
90 Self::Wildcard => true,
91 Self::IpCidr(source, ip_cidr) => {
92 if value == source {
93 true
94 } else if let Ok(ip_value) = IpAddr::from_str(value) {
95 ip_cidr.contains(&ip_value)
96 } else {
97 false
98 }
99 }
100 Self::WithDot(source, start, end) => {
101 if *start && *end {
102 value.contains(source)
103 } else if *start {
104 value.ends_with(source)
105 } else if *end {
106 value.starts_with(source)
107 } else {
108 source == value
109 }
110 }
111 Self::Plain(source) => source == value,
112 }
113 }
114}
115
116#[derive(Clone, Debug, Default, Eq, PartialEq)]
117#[cfg_attr(feature = "graphql", derive(async_graphql::SimpleObject))]
118pub struct NoProxy {
119 content: HashSet<NoProxyItem>,
120 has_wildcard: bool,
121}
122
123impl NoProxy {
124 fn from_iterator<V: AsRef<str>, I: Iterator<Item = V>>(iterator: I) -> Self {
125 let content: HashSet<_> = iterator
126 .filter_map(|item| {
127 let short = item.as_ref().trim();
128 if short.is_empty() {
129 None
130 } else {
131 Some(NoProxyItem::from(short))
132 }
133 })
134 .collect();
135 let has_wildcard = content.contains(&NoProxyItem::Wildcard);
136 Self {
137 content,
138 has_wildcard,
139 }
140 }
141}
142
143impl<V: AsRef<str>> From<V> for NoProxy {
144 fn from(value: V) -> Self {
145 Self::from_iterator(value.as_ref().split(','))
146 }
147}
148
149impl IntoIterator for NoProxy {
150 type Item = NoProxyItem;
151 type IntoIter = IntoIter<NoProxyItem>;
152
153 fn into_iter(self) -> Self::IntoIter {
154 self.content.into_iter()
155 }
156}
157
158impl Extend<NoProxyItem> for NoProxy {
159 fn extend<T: IntoIterator<Item = NoProxyItem>>(&mut self, iter: T) {
160 self.content.extend(iter);
161 self.has_wildcard = self.content.contains(&NoProxyItem::Wildcard);
162 }
163}
164
165impl NoProxy {
166 pub fn is_empty(&self) -> bool {
167 self.content.is_empty()
168 }
169
170 pub fn matches(&self, input: &str) -> bool {
171 if self.has_wildcard {
172 return true;
173 }
174 self.content.iter().any(|item| item.matches(input))
175 }
176}
177
178impl std::fmt::Display for NoProxy {
179 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180 for (index, item) in self.content.iter().enumerate() {
181 if index > 0 {
182 write!(f, ",")?;
183 }
184 item.fmt(f)?;
185 }
186 Ok(())
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the list parsed from `pattern` covers `value`.
    fn should_match(pattern: &str, value: &str) {
        let matched = NoProxy::from(pattern).matches(value);
        assert!(matched, "{} should match {}", pattern, value);
    }

    /// Asserts that the list parsed from `pattern` does not cover `value`.
    fn shouldnt_match(pattern: &str, value: &str) {
        let matched = NoProxy::from(pattern).matches(value);
        assert!(!matched, "{} should not match {}", pattern, value);
    }

    #[test]
    fn filter_empty() {
        assert!(NoProxy::from("").is_empty());
    }

    #[test]
    fn wildcard() {
        let cases = [
            ("*", "www.wikipedia.org"),
            ("*", "192.168.0.1"),
            ("localhost , *", "wikipedia.org"),
        ];
        for (pattern, value) in cases.iter() {
            should_match(pattern, value);
        }
    }

    #[test]
    fn cidr() {
        let pattern = "21.19.35.0/24";
        should_match(pattern, "21.19.35.4");
        shouldnt_match(pattern, "127.0.0.1");
    }

    #[test]
    fn leading_dot() {
        let pattern = ".wikipedia.org";
        should_match(pattern, "fr.wikipedia.org");
        for host in ["fr.wikipedia.co.uk", "wikipedia.org", "google.com"].iter() {
            shouldnt_match(pattern, host);
        }

        let ip_suffix = ".168.0.1";
        should_match(ip_suffix, "192.168.0.1");
        shouldnt_match(ip_suffix, "192.169.0.1");
    }

    #[test]
    fn trailing_dot() {
        let pattern = "fr.wikipedia.";
        for host in [
            "fr.wikipedia.com",
            "fr.wikipedia.org",
            "fr.wikipedia.somewhere.dangerous",
        ]
        .iter()
        {
            should_match(pattern, host);
        }
        shouldnt_match(pattern, "www.google.com");

        let ip_prefix = "192.168.0.";
        should_match(ip_prefix, "192.168.0.1");
        shouldnt_match(ip_prefix, "192.169.0.1");
    }

    #[test]
    fn white_space() {
        for host in ["localhost", "somewhere.local"].iter() {
            shouldnt_match("", host);
        }
    }

    #[test]
    fn combination() {
        let pattern = "127.0.0.1,localhost,.local,169.254.169.254,fileshare.company.com";
        for host in ["localhost", "somewhere.local"].iter() {
            should_match(pattern, host);
        }
    }

    #[test]
    fn from_reqwest() {
        let pattern = ".foo.bar,bar.baz,10.42.1.0/24,::1,10.124.7.8,2001::/17";

        let outside = [
            "hyper.rs",
            "foo.bar.baz",
            "10.43.1.1",
            "10.124.7.7",
            "[ffff:db8:a0b:12f0::1]",
            "[2005:db8:a0b:12f0::1]",
        ];
        for host in outside.iter() {
            shouldnt_match(pattern, host);
        }

        let inside = [
            "hello.foo.bar",
            "bar.baz",
            "10.42.1.100",
            "[::1]",
            "[2001:db8:a0b:12f0::1]",
            "10.124.7.8",
        ];
        for host in inside.iter() {
            should_match(pattern, host);
        }
    }

    #[test]
    fn extending() {
        let mut first = NoProxy::from("foo.bar");
        first.extend(NoProxy::from("*"));
        assert!(first.has_wildcard);
    }
}