1use crate::hashed_regex::HashedRegex;
2use anyhow::Error;
3use http::header::{HeaderName, HeaderValue};
4use log::Level;
5use reqwest::Client;
6use serde_derive::{Deserialize, Serialize};
7use std::{
8 collections::HashMap,
9 convert::TryFrom,
10 fmt::{self, Display, Formatter},
11 str::FromStr,
12 time::Duration,
13};
14
/// Link-checker configuration, (de)serialized with kebab-case keys
/// (e.g. `follow-web-links`, `http-headers`).
///
/// Keys missing from the input take their default values: the struct-level
/// `#[serde(default)]` falls back to `Config::default()`, and some fields
/// also name an explicit default function below.
///
/// NOTE: the toml serializer emits fields in declaration order, and the
/// `round_trip_config` test compares against a fixed string, so reordering
/// fields is a visible change.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default, rename_all = "kebab-case")]
pub struct Config {
    /// Presumably enables checking of web (non-local) links; not read in
    /// this file — TODO confirm at the call site.
    pub follow_web_links: bool,
    /// Presumably allows links that point above the book's root directory;
    /// not read in this file — TODO confirm at the call site.
    pub traverse_parent_directories: bool,
    /// Regexes; `should_skip` returns `true` for any link that one of them
    /// matches anywhere.
    #[serde(default)]
    pub exclude: Vec<HashedRegex>,
    /// Sent as the `User-Agent` header by the client built in `client()`.
    #[serde(default = "default_user_agent")]
    pub user_agent: String,
    /// A timeout in whole seconds (its default is
    /// `DEFAULT_CACHE_TIMEOUT.as_secs()`); what it bounds is decided
    /// outside this file — presumably cache entry age, confirm.
    #[serde(default = "default_cache_timeout")]
    pub cache_timeout: u64,
    /// How warnings are reported; see `WarningPolicy::to_log_level`.
    #[serde(default)]
    pub warning_policy: WarningPolicy,
    /// Extra headers keyed by a regex (presumably matched against link
    /// URLs — confirm at the call site). Values are raw templates that may
    /// contain `$VAR` references, expanded by `interpolate_headers`.
    #[serde(default)]
    pub http_headers: HashMap<HashedRegex, Vec<HttpHeader>>,
}
42
43#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
44#[serde(try_from = "String", into = "String")]
45pub struct HttpHeader {
46 pub name: HeaderName,
47 pub value: String,
48}
49
50impl HttpHeader {
51 pub(crate) fn interpolate(&self) -> Result<HeaderValue, Error> {
52 interpolate_env(&self.value)
53 }
54}
55
56impl Display for HttpHeader {
57 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
58 write!(f, "{}: {}", self.name, self.value)
59 }
60}
61
62impl Config {
63 pub const DEFAULT_CACHE_TIMEOUT: Duration =
65 Duration::from_secs(60 * 60 * 12);
66 pub const DEFAULT_USER_AGENT: &'static str =
68 concat!(env!("CARGO_PKG_NAME"), "-", env!("CARGO_PKG_VERSION"));
69
70 pub fn should_skip(&self, link: &str) -> bool {
73 self.exclude.iter().any(|pat| pat.find(link).is_some())
74 }
75
76 pub(crate) fn client(&self) -> Client {
77 let mut headers = http::HeaderMap::new();
78 headers
79 .insert(http::header::USER_AGENT, self.user_agent.parse().unwrap());
80 Client::builder().default_headers(headers).build().unwrap()
81 }
82
83 pub(crate) fn interpolate_headers(
84 &self,
85 warning_policy: WarningPolicy,
86 ) -> Vec<(HashedRegex, Vec<(HeaderName, HeaderValue)>)> {
87 let mut all_headers = Vec::new();
88 let log_level = warning_policy.to_log_level();
89
90 for (pattern, headers) in &self.http_headers {
91 let mut interpolated = Vec::new();
92
93 for header in headers {
94 match header.interpolate() {
95 Ok(value) => {
96 interpolated.push((header.name.clone(), value))
97 },
98 Err(e) => {
99 log::log!(
106 log_level,
107 "Unable to interpolate \"{}\" because {}",
108 header,
109 e
110 );
111 },
112 }
113 }
114
115 all_headers.push((pattern.clone(), interpolated));
116 }
117
118 all_headers
119 }
120}
121
122impl Default for Config {
123 fn default() -> Config {
124 Config {
125 follow_web_links: false,
126 traverse_parent_directories: false,
127 exclude: Vec::new(),
128 user_agent: default_user_agent(),
129 http_headers: HashMap::new(),
130 warning_policy: WarningPolicy::Warn,
131 cache_timeout: Config::DEFAULT_CACHE_TIMEOUT.as_secs(),
132 }
133 }
134}
135
136impl FromStr for HttpHeader {
137 type Err = Error;
138
139 fn from_str(s: &str) -> Result<Self, Self::Err> {
140 match s.find(": ") {
141 Some(idx) => {
142 let name = s[..idx].parse()?;
143 let value = s[idx + 2..].to_string();
144 Ok(HttpHeader {
145 name,
146 value,
147 })
148 },
149
150 None => Err(Error::msg(format!(
151 "The `{}` HTTP header must be in the form `key: value` but it isn't",
152 s
153 ))),
154 }
155 }
156}
157
158impl TryFrom<&'_ str> for HttpHeader {
159 type Error = Error;
160
161 fn try_from(s: &'_ str) -> Result<Self, Error> { HttpHeader::from_str(s) }
162}
163
164impl TryFrom<String> for HttpHeader {
165 type Error = Error;
166
167 fn try_from(s: String) -> Result<Self, Error> {
168 HttpHeader::try_from(s.as_str())
169 }
170}
171
172impl Into<String> for HttpHeader {
173 fn into(self) -> String {
174 let HttpHeader { name, value, .. } = self;
175 format!("{}: {}", name, value)
176 }
177}
178
179fn default_cache_timeout() -> u64 { Config::DEFAULT_CACHE_TIMEOUT.as_secs() }
180fn default_user_agent() -> String { Config::DEFAULT_USER_AGENT.to_string() }
181
182fn interpolate_env(value: &str) -> Result<HeaderValue, Error> {
183 use std::{iter::Peekable, str::CharIndices};
184
185 fn is_ident(ch: char) -> bool { ch.is_ascii_alphanumeric() || ch == '_' }
186
187 fn ident_end(start: usize, iter: &mut Peekable<CharIndices>) -> usize {
188 let mut end = start;
189 while let Some(&(i, ch)) = iter.peek() {
190 if !is_ident(ch) {
191 return i;
192 }
193 end = i + ch.len_utf8();
194 iter.next();
195 }
196
197 end
198 }
199
200 let mut res = String::with_capacity(value.len());
201 let mut backslash = false;
202 let mut iter = value.char_indices().peekable();
203
204 while let Some((i, ch)) = iter.next() {
205 if backslash {
206 match ch {
207 '$' | '\\' => res.push(ch),
208 _ => {
209 res.push('\\');
210 res.push(ch);
211 },
212 }
213
214 backslash = false;
215 } else {
216 match ch {
217 '\\' => backslash = true,
218 '$' => {
219 iter.next();
220 let start = i + 1;
221 let end = ident_end(start, &mut iter);
222 let name = &value[start..end];
223
224 match std::env::var(name) {
225 Ok(env) => res.push_str(&env),
226 Err(e) => {
227 return Err(Error::msg(format!(
228 "Failed to retrieve `{}` env var: {}",
229 name, e
230 )))
231 },
232 }
233 },
234
235 _ => res.push(ch),
236 }
237 }
238 }
239
240 if backslash {
242 res.push('\\');
243 }
244
245 Ok(res.parse()?)
246}
247
/// How reported problems (e.g. headers that fail to interpolate) are
/// surfaced.
///
/// Serialized in kebab-case: `"ignore"`, `"warn"`, `"error"`. This file only
/// shows the mapping to a log level (`to_log_level`); any further effect,
/// such as failing the run, is decided elsewhere — TODO confirm.
///
/// NOTE: variant order fixes the discriminants and serde variant indices.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum WarningPolicy {
    /// Still logged, but only at `Debug` level.
    Ignore,
    /// Logged at `Warn` level. This is the `Default`.
    Warn,
    /// Logged at `Error` level.
    Error,
}
259
260impl WarningPolicy {
261 pub(crate) fn to_log_level(self) -> Level {
262 match self {
263 WarningPolicy::Error => Level::Error,
264 WarningPolicy::Warn => Level::Warn,
265 WarningPolicy::Ignore => Level::Debug,
266 }
267 }
268}
269
270impl Default for WarningPolicy {
271 fn default() -> WarningPolicy { WarningPolicy::Warn }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::{convert::TryInto, iter::FromIterator};
    use toml;

    /// A config that sets every field, written exactly as the serializer
    /// emits it, so it doubles as the expected round-trip output.
    const CONFIG: &str = r#"follow-web-links = true
traverse-parent-directories = true
exclude = ["google\\.com"]
user-agent = "Internet Explorer"
cache-timeout = 3600
warning-policy = "error"

[http-headers]
https = ["accept: html/text", "authorization: Basic $TOKEN"]
"#;

    #[test]
    fn deserialize_a_config() {
        // `TOKEN` is set on purpose: deserializing must keep the raw
        // `$TOKEN` template instead of expanding it.
        std::env::set_var("TOKEN", "QWxhZGRpbjpPcGVuU2VzYW1l");

        let should_be = Config {
            follow_web_links: true,
            warning_policy: WarningPolicy::Error,
            traverse_parent_directories: true,
            exclude: vec![HashedRegex::new(r"google\.com").unwrap()],
            user_agent: String::from("Internet Explorer"),
            http_headers: HashMap::from_iter(vec![(
                HashedRegex::new("https").unwrap(),
                vec![
                    "Accept: html/text".try_into().unwrap(),
                    "Authorization: Basic $TOKEN".try_into().unwrap(),
                ],
            )]),
            cache_timeout: 3600,
        };

        let got: Config = toml::from_str(CONFIG).unwrap();

        assert_eq!(got, should_be);
    }

    #[test]
    fn round_trip_config() {
        // Even with `TOKEN` set, its value must not leak into the
        // serialized output.
        std::env::set_var("TOKEN", "QWxhZGRpbjpPcGVuU2VzYW1l");

        let deserialized: Config = toml::from_str(CONFIG).unwrap();
        let reserialized = toml::to_string(&deserialized).unwrap();

        assert_eq!(reserialized, CONFIG);
    }

    #[test]
    fn interpolation() {
        std::env::set_var("SUPER_SECRET_TOKEN", "abcdefg123456");
        let header = HttpHeader {
            name: "Authorization".parse().unwrap(),
            value: "Basic $SUPER_SECRET_TOKEN".into(),
        };
        let should_be: HeaderValue = "Basic abcdefg123456".parse().unwrap();

        let got = header.interpolate().unwrap();

        assert_eq!(got, should_be);
    }

    #[test]
    fn backslash_escapes() {
        // `\$` is a literal dollar sign, so no variable is looked up.
        assert_eq!(
            interpolate_env(r"Basic \$NOT_EXPANDED").unwrap(),
            HeaderValue::from_static("Basic $NOT_EXPANDED")
        );
        // `\\` collapses to a single backslash.
        assert_eq!(
            interpolate_env(r"a\\b").unwrap(),
            HeaderValue::from_static(r"a\b")
        );
        // Any other escape is kept verbatim.
        assert_eq!(
            interpolate_env(r"a\b").unwrap(),
            HeaderValue::from_static(r"a\b")
        );
        // A trailing backslash is kept too.
        assert_eq!(
            interpolate_env(r"trailing\").unwrap(),
            HeaderValue::from_static(r"trailing\")
        );
    }
}