mdbook_linkcheck2/
config.rs

1#![allow(clippy::mutable_key_type)]
2
3use crate::hashed_regex::HashedRegex;
4use anyhow::Error;
5use http::header::{HeaderName, HeaderValue};
6use log::Level;
7use reqwest::Client;
8use serde_derive::{Deserialize, Serialize};
9use std::{
10    collections::HashMap,
11    convert::TryFrom,
12    fmt::{self, Display, Formatter},
13    str::FromStr,
14    time::Duration,
15};
16
17/// The configuration options available with this backend.
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19#[serde(default, rename_all = "kebab-case")]
20pub struct Config {
21    /// If a link on the internet is encountered, should we still try to check
22    /// if it's valid? Defaults to `false` because this has a big performance
23    /// impact.
24    pub follow_web_links: bool,
25    /// Are we allowed to link to files outside of the book's source directory?
26    pub traverse_parent_directories: bool,
27    /// Turns on support for latex. If true, then the latex fragments will be
28    /// cut off before the file is processed for link consistency.
29    pub latex_support: bool,
30    /// A list of URL patterns to ignore when checking remote links.
31    #[serde(default)]
32    pub exclude: Vec<HashedRegex>,
33    /// The user-agent used whenever any web requests are made.
34    #[serde(default = "default_user_agent")]
35    pub user_agent: String,
36    /// The number of seconds a cached result is valid for.
37    #[serde(default = "default_cache_timeout")]
38    pub cache_timeout: u64,
39    /// The policy to use when warnings are encountered.
40    #[serde(default)]
41    pub warning_policy: WarningPolicy,
42    /// The map of regexes representing sets of web sites and
43    /// the list of HTTP headers that must be sent to matching sites.
44    #[serde(default)]
45    pub http_headers: HashMap<HashedRegex, Vec<HttpHeader>>,
46}
47
48#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
49#[serde(try_from = "String", into = "String")]
50pub struct HttpHeader {
51    pub name: HeaderName,
52    pub value: String,
53}
54
55impl HttpHeader {
56    pub(crate) fn interpolate(&self) -> Result<HeaderValue, Error> {
57        interpolate_env(&self.value)
58    }
59}
60
61impl Display for HttpHeader {
62    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
63        write!(f, "{}: {}", self.name, self.value)
64    }
65}
66
67impl Config {
68    /// The default cache timeout (around 12 hours).
69    pub const DEFAULT_CACHE_TIMEOUT: Duration = Duration::from_secs(60 * 60 * 12);
70    /// The default user-agent.
71    pub const DEFAULT_USER_AGENT: &'static str =
72        concat!(env!("CARGO_PKG_NAME"), "-", env!("CARGO_PKG_VERSION"));
73
74    /// Checks [`Config::exclude`] to see if the provided link should be
75    /// skipped.
76    pub fn should_skip(&self, link: &str) -> bool {
77        self.exclude.iter().any(|pat| pat.find(link).is_some())
78    }
79
80    pub(crate) fn client(&self) -> Client {
81        let mut headers = http::HeaderMap::new();
82        headers.insert(http::header::USER_AGENT, self.user_agent.parse().unwrap());
83        Client::builder().default_headers(headers).build().unwrap()
84    }
85
86    pub(crate) fn interpolate_headers(
87        &self,
88        warning_policy: WarningPolicy,
89    ) -> Vec<(HashedRegex, Vec<(HeaderName, HeaderValue)>)> {
90        let mut all_headers = Vec::new();
91        let log_level = warning_policy.to_log_level();
92
93        for (pattern, headers) in &self.http_headers {
94            let mut interpolated = Vec::new();
95
96            for header in headers {
97                match header.interpolate() {
98                    Ok(value) => interpolated.push((header.name.clone(), value)),
99                    Err(e) => {
100                        // We don't want failed interpolation (i.e. due to a
101                        // missing env variable) to abort the whole
102                        // linkchecking, so emit a warning and keep going.
103                        //
104                        // If it was important, the user would notice a "broken"
105                        // link and read back through the logs.
106                        log::log!(
107                            log_level,
108                            "Unable to interpolate \"{}\" because {}",
109                            header,
110                            e
111                        );
112                    }
113                }
114            }
115
116            all_headers.push((pattern.clone(), interpolated));
117        }
118
119        all_headers
120    }
121}
122
123impl Default for Config {
124    fn default() -> Config {
125        Config {
126            follow_web_links: false,
127            traverse_parent_directories: false,
128            latex_support: false,
129            exclude: Vec::new(),
130            user_agent: default_user_agent(),
131            http_headers: HashMap::new(),
132            warning_policy: WarningPolicy::Warn,
133            cache_timeout: Config::DEFAULT_CACHE_TIMEOUT.as_secs(),
134        }
135    }
136}
137
138impl FromStr for HttpHeader {
139    type Err = Error;
140
141    fn from_str(s: &str) -> Result<Self, Self::Err> {
142        match s.find(": ") {
143            Some(idx) => {
144                let name = s[..idx].parse()?;
145                let value = s[idx + 2..].to_string();
146                Ok(HttpHeader { name, value })
147            }
148
149            None => Err(Error::msg(format!(
150                "The `{}` HTTP header must be in the form `key: value` but it isn't",
151                s
152            ))),
153        }
154    }
155}
156
157impl TryFrom<&'_ str> for HttpHeader {
158    type Error = Error;
159
160    fn try_from(s: &'_ str) -> Result<Self, Error> {
161        HttpHeader::from_str(s)
162    }
163}
164
165impl TryFrom<String> for HttpHeader {
166    type Error = Error;
167
168    fn try_from(s: String) -> Result<Self, Error> {
169        HttpHeader::try_from(s.as_str())
170    }
171}
172
173impl From<HttpHeader> for String {
174    fn from(val: HttpHeader) -> Self {
175        let HttpHeader { name, value, .. } = val;
176        format!("{}: {}", name, value)
177    }
178}
179
180fn default_cache_timeout() -> u64 {
181    Config::DEFAULT_CACHE_TIMEOUT.as_secs()
182}
183fn default_user_agent() -> String {
184    Config::DEFAULT_USER_AGENT.to_string()
185}
186
187fn interpolate_env(value: &str) -> Result<HeaderValue, Error> {
188    use std::{iter::Peekable, str::CharIndices};
189
190    fn is_ident(ch: char) -> bool {
191        ch.is_ascii_alphanumeric() || ch == '_'
192    }
193
194    fn ident_end(start: usize, iter: &mut Peekable<CharIndices>) -> usize {
195        let mut end = start;
196        while let Some(&(i, ch)) = iter.peek() {
197            if !is_ident(ch) {
198                return i;
199            }
200            end = i + ch.len_utf8();
201            iter.next();
202        }
203
204        end
205    }
206
207    let mut res = String::with_capacity(value.len());
208    let mut backslash = false;
209    let mut iter = value.char_indices().peekable();
210
211    while let Some((i, ch)) = iter.next() {
212        if backslash {
213            match ch {
214                '$' | '\\' => res.push(ch),
215                _ => {
216                    res.push('\\');
217                    res.push(ch);
218                }
219            }
220
221            backslash = false;
222        } else {
223            match ch {
224                '\\' => backslash = true,
225                '$' => {
226                    iter.next();
227                    let start = i + 1;
228                    let end = ident_end(start, &mut iter);
229                    let name = &value[start..end];
230
231                    match std::env::var(name) {
232                        Ok(env) => res.push_str(&env),
233                        Err(e) => {
234                            return Err(Error::msg(format!(
235                                "Failed to retrieve `{}` env var: {}",
236                                name, e
237                            )))
238                        }
239                    }
240                }
241
242                _ => res.push(ch),
243            }
244        }
245    }
246
247    // trailing backslash
248    if backslash {
249        res.push('\\');
250    }
251
252    Ok(res.parse()?)
253}
254
255/// How should warnings be treated?
256#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
257#[serde(rename_all = "kebab-case")]
258#[derive(Default)]
259pub enum WarningPolicy {
260    /// Silently ignore them.
261    Ignore,
262    /// Warn the user, but don't fail the linkcheck.
263    #[default]
264    Warn,
265    /// Treat warnings as errors.
266    Error,
267}
268
269impl WarningPolicy {
270    pub(crate) fn to_log_level(self) -> Level {
271        match self {
272            WarningPolicy::Error => Level::Error,
273            WarningPolicy::Warn => Level::Warn,
274            WarningPolicy::Ignore => Level::Debug,
275        }
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::{convert::TryInto, iter::FromIterator};

    /// A configuration which sets every available key.
    const CONFIG: &str = r#"follow-web-links = true
traverse-parent-directories = true
latex-support = true
exclude = ['google\.com']
user-agent = "Internet Explorer"
cache-timeout = 3600
warning-policy = "error"

[http-headers]
https = ["accept: html/text", "authorization: Basic $TOKEN"]
"#;

    /// The value given to the `TOKEN` variable referenced by [`CONFIG`].
    const TOKEN_VALUE: &str = "QWxhZGRpbjpPcGVuU2VzYW1l";

    /// Makes sure the `TOKEN` variable referenced by [`CONFIG`] exists.
    fn export_token() {
        std::env::set_var("TOKEN", TOKEN_VALUE);
    }

    #[test]
    fn deserialize_a_config() {
        export_token();

        let https_headers: Vec<HttpHeader> = vec![
            "Accept: html/text".try_into().unwrap(),
            "Authorization: Basic $TOKEN".try_into().unwrap(),
        ];
        let expected = Config {
            follow_web_links: true,
            traverse_parent_directories: true,
            latex_support: true,
            exclude: vec![HashedRegex::new(r"google\.com").unwrap()],
            user_agent: String::from("Internet Explorer"),
            cache_timeout: 3600,
            warning_policy: WarningPolicy::Error,
            http_headers: HashMap::from_iter(vec![(
                HashedRegex::new("https").unwrap(),
                https_headers,
            )]),
        };

        let actual: Config = toml::from_str(CONFIG).unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn round_trip_config() {
        // A check that a value of an env var is not leaked in the
        // deserialization
        export_token();

        let parsed: Config = toml::from_str(CONFIG).unwrap();
        let written = toml::to_string(&parsed).unwrap();

        assert_eq!(written, CONFIG);
    }

    #[test]
    fn interpolation() {
        std::env::set_var("SUPER_SECRET_TOKEN", "abcdefg123456");
        let header = HttpHeader {
            name: "Authorization".parse().unwrap(),
            value: "Basic $SUPER_SECRET_TOKEN".into(),
        };
        let expected: HeaderValue = "Basic abcdefg123456".parse().unwrap();

        let actual = header.interpolate().unwrap();

        assert_eq!(actual, expected);
    }
}