1#![allow(clippy::mutable_key_type)]
2
3use crate::hashed_regex::HashedRegex;
4use anyhow::Error;
5use http::header::{HeaderName, HeaderValue};
6use log::Level;
7use reqwest::Client;
8use serde_derive::{Deserialize, Serialize};
9use std::{
10 collections::HashMap,
11 convert::TryFrom,
12 fmt::{self, Display, Formatter},
13 str::FromStr,
14 time::Duration,
15};
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19#[serde(default, rename_all = "kebab-case")]
20pub struct Config {
21 pub follow_web_links: bool,
25 pub traverse_parent_directories: bool,
27 pub latex_support: bool,
30 #[serde(default)]
32 pub exclude: Vec<HashedRegex>,
33 #[serde(default = "default_user_agent")]
35 pub user_agent: String,
36 #[serde(default = "default_cache_timeout")]
38 pub cache_timeout: u64,
39 #[serde(default)]
41 pub warning_policy: WarningPolicy,
42 #[serde(default)]
45 pub http_headers: HashMap<HashedRegex, Vec<HttpHeader>>,
46}
47
48#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
49#[serde(try_from = "String", into = "String")]
50pub struct HttpHeader {
51 pub name: HeaderName,
52 pub value: String,
53}
54
55impl HttpHeader {
56 pub(crate) fn interpolate(&self) -> Result<HeaderValue, Error> {
57 interpolate_env(&self.value)
58 }
59}
60
61impl Display for HttpHeader {
62 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
63 write!(f, "{}: {}", self.name, self.value)
64 }
65}
66
67impl Config {
68 pub const DEFAULT_CACHE_TIMEOUT: Duration = Duration::from_secs(60 * 60 * 12);
70 pub const DEFAULT_USER_AGENT: &'static str =
72 concat!(env!("CARGO_PKG_NAME"), "-", env!("CARGO_PKG_VERSION"));
73
74 pub fn should_skip(&self, link: &str) -> bool {
77 self.exclude.iter().any(|pat| pat.find(link).is_some())
78 }
79
80 pub(crate) fn client(&self) -> Client {
81 let mut headers = http::HeaderMap::new();
82 headers.insert(http::header::USER_AGENT, self.user_agent.parse().unwrap());
83 Client::builder().default_headers(headers).build().unwrap()
84 }
85
86 pub(crate) fn interpolate_headers(
87 &self,
88 warning_policy: WarningPolicy,
89 ) -> Vec<(HashedRegex, Vec<(HeaderName, HeaderValue)>)> {
90 let mut all_headers = Vec::new();
91 let log_level = warning_policy.to_log_level();
92
93 for (pattern, headers) in &self.http_headers {
94 let mut interpolated = Vec::new();
95
96 for header in headers {
97 match header.interpolate() {
98 Ok(value) => interpolated.push((header.name.clone(), value)),
99 Err(e) => {
100 log::log!(
107 log_level,
108 "Unable to interpolate \"{}\" because {}",
109 header,
110 e
111 );
112 }
113 }
114 }
115
116 all_headers.push((pattern.clone(), interpolated));
117 }
118
119 all_headers
120 }
121}
122
123impl Default for Config {
124 fn default() -> Config {
125 Config {
126 follow_web_links: false,
127 traverse_parent_directories: false,
128 latex_support: false,
129 exclude: Vec::new(),
130 user_agent: default_user_agent(),
131 http_headers: HashMap::new(),
132 warning_policy: WarningPolicy::Warn,
133 cache_timeout: Config::DEFAULT_CACHE_TIMEOUT.as_secs(),
134 }
135 }
136}
137
138impl FromStr for HttpHeader {
139 type Err = Error;
140
141 fn from_str(s: &str) -> Result<Self, Self::Err> {
142 match s.find(": ") {
143 Some(idx) => {
144 let name = s[..idx].parse()?;
145 let value = s[idx + 2..].to_string();
146 Ok(HttpHeader { name, value })
147 }
148
149 None => Err(Error::msg(format!(
150 "The `{}` HTTP header must be in the form `key: value` but it isn't",
151 s
152 ))),
153 }
154 }
155}
156
157impl TryFrom<&'_ str> for HttpHeader {
158 type Error = Error;
159
160 fn try_from(s: &'_ str) -> Result<Self, Error> {
161 HttpHeader::from_str(s)
162 }
163}
164
165impl TryFrom<String> for HttpHeader {
166 type Error = Error;
167
168 fn try_from(s: String) -> Result<Self, Error> {
169 HttpHeader::try_from(s.as_str())
170 }
171}
172
173impl From<HttpHeader> for String {
174 fn from(val: HttpHeader) -> Self {
175 let HttpHeader { name, value, .. } = val;
176 format!("{}: {}", name, value)
177 }
178}
179
180fn default_cache_timeout() -> u64 {
181 Config::DEFAULT_CACHE_TIMEOUT.as_secs()
182}
183fn default_user_agent() -> String {
184 Config::DEFAULT_USER_AGENT.to_string()
185}
186
187fn interpolate_env(value: &str) -> Result<HeaderValue, Error> {
188 use std::{iter::Peekable, str::CharIndices};
189
190 fn is_ident(ch: char) -> bool {
191 ch.is_ascii_alphanumeric() || ch == '_'
192 }
193
194 fn ident_end(start: usize, iter: &mut Peekable<CharIndices>) -> usize {
195 let mut end = start;
196 while let Some(&(i, ch)) = iter.peek() {
197 if !is_ident(ch) {
198 return i;
199 }
200 end = i + ch.len_utf8();
201 iter.next();
202 }
203
204 end
205 }
206
207 let mut res = String::with_capacity(value.len());
208 let mut backslash = false;
209 let mut iter = value.char_indices().peekable();
210
211 while let Some((i, ch)) = iter.next() {
212 if backslash {
213 match ch {
214 '$' | '\\' => res.push(ch),
215 _ => {
216 res.push('\\');
217 res.push(ch);
218 }
219 }
220
221 backslash = false;
222 } else {
223 match ch {
224 '\\' => backslash = true,
225 '$' => {
226 iter.next();
227 let start = i + 1;
228 let end = ident_end(start, &mut iter);
229 let name = &value[start..end];
230
231 match std::env::var(name) {
232 Ok(env) => res.push_str(&env),
233 Err(e) => {
234 return Err(Error::msg(format!(
235 "Failed to retrieve `{}` env var: {}",
236 name, e
237 )))
238 }
239 }
240 }
241
242 _ => res.push(ch),
243 }
244 }
245 }
246
247 if backslash {
249 res.push('\\');
250 }
251
252 Ok(res.parse()?)
253}
254
255#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
257#[serde(rename_all = "kebab-case")]
258#[derive(Default)]
259pub enum WarningPolicy {
260 Ignore,
262 #[default]
264 Warn,
265 Error,
267}
268
269impl WarningPolicy {
270 pub(crate) fn to_log_level(self) -> Level {
271 match self {
272 WarningPolicy::Error => Level::Error,
273 WarningPolicy::Warn => Level::Warn,
274 WarningPolicy::Ignore => Level::Debug,
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::{convert::TryInto, iter::FromIterator};

    /// A configuration document which sets every key to a non-default value.
    const CONFIG: &str = r#"follow-web-links = true
traverse-parent-directories = true
latex-support = true
exclude = ['google\.com']
user-agent = "Internet Explorer"
cache-timeout = 3600
warning-policy = "error"

[http-headers]
https = ["accept: html/text", "authorization: Basic $TOKEN"]
"#;

    #[test]
    fn deserialize_a_config() {
        std::env::set_var("TOKEN", "QWxhZGRpbjpPcGVuU2VzYW1l");

        let should_be = Config {
            follow_web_links: true,
            warning_policy: WarningPolicy::Error,
            traverse_parent_directories: true,
            exclude: vec![HashedRegex::new(r"google\.com").unwrap()],
            user_agent: String::from("Internet Explorer"),
            http_headers: HashMap::from_iter(vec![(
                HashedRegex::new("https").unwrap(),
                vec![
                    "Accept: html/text".try_into().unwrap(),
                    "Authorization: Basic $TOKEN".try_into().unwrap(),
                ],
            )]),
            cache_timeout: 3600,
            latex_support: true,
        };

        let got: Config = toml::from_str(CONFIG).unwrap();

        assert_eq!(got, should_be);
    }

    #[test]
    fn round_trip_config() {
        std::env::set_var("TOKEN", "QWxhZGRpbjpPcGVuU2VzYW1l");

        // Serializing must write the raw `$TOKEN` reference back out, not the
        // variable's contents.
        let deserialized: Config = toml::from_str(CONFIG).unwrap();
        let reserialized = toml::to_string(&deserialized).unwrap();

        assert_eq!(reserialized, CONFIG);
    }

    #[test]
    fn interpolation() {
        std::env::set_var("SUPER_SECRET_TOKEN", "abcdefg123456");
        let header = HttpHeader {
            name: "Authorization".parse().unwrap(),
            value: "Basic $SUPER_SECRET_TOKEN".into(),
        };
        let should_be: HeaderValue = "Basic abcdefg123456".parse().unwrap();

        let got = header.interpolate().unwrap();

        assert_eq!(got, should_be);
    }

    #[test]
    fn escaped_dollar_is_not_interpolated() {
        let header = HttpHeader {
            name: "Authorization".parse().unwrap(),
            value: r"Bearer \$NOT_A_VARIABLE".into(),
        };
        let should_be: HeaderValue = "Bearer $NOT_A_VARIABLE".parse().unwrap();

        let got = header.interpolate().unwrap();

        assert_eq!(got, should_be);
    }

    #[test]
    fn missing_variable_is_an_error() {
        let header = HttpHeader {
            name: "Authorization".parse().unwrap(),
            value: "Basic $THIS_VARIABLE_SHOULD_NEVER_BE_SET_1234".into(),
        };

        assert!(header.interpolate().is_err());
    }

    #[test]
    fn header_without_separator_is_rejected() {
        assert!("not-a-header".parse::<HttpHeader>().is_err());
    }
}