mdbook_linkcheck/
config.rs

1use crate::hashed_regex::HashedRegex;
2use anyhow::Error;
3use http::header::{HeaderName, HeaderValue};
4use log::Level;
5use reqwest::Client;
6use serde_derive::{Deserialize, Serialize};
7use std::{
8    collections::HashMap,
9    convert::TryFrom,
10    fmt::{self, Display, Formatter},
11    str::FromStr,
12    time::Duration,
13};
14
/// The configuration options available with this backend.
///
/// Keys are written in `kebab-case` when (de)serialized (for example
/// `follow-web-links`), and any key missing from the input falls back to its
/// default value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default, rename_all = "kebab-case")]
pub struct Config {
    /// If a link on the internet is encountered, should we still try to check
    /// if it's valid? Defaults to `false` because this has a big performance
    /// impact.
    pub follow_web_links: bool,
    /// Are we allowed to link to files outside of the book's source directory?
    /// Defaults to `false`.
    pub traverse_parent_directories: bool,
    /// A list of URL patterns to ignore when checking remote links.
    ///
    /// A link is skipped as soon as any one of these patterns finds a match in
    /// it (see [`Config::should_skip()`]). Defaults to an empty list.
    #[serde(default)]
    pub exclude: Vec<HashedRegex>,
    /// The user-agent used whenever any web requests are made. Defaults to
    /// [`Config::DEFAULT_USER_AGENT`].
    #[serde(default = "default_user_agent")]
    pub user_agent: String,
    /// The number of seconds a cached result is valid for. Defaults to
    /// [`Config::DEFAULT_CACHE_TIMEOUT`], expressed in seconds.
    #[serde(default = "default_cache_timeout")]
    pub cache_timeout: u64,
    /// The policy to use when warnings are encountered. Defaults to
    /// [`WarningPolicy::Warn`].
    #[serde(default)]
    pub warning_policy: WarningPolicy,
    /// The map of regexes representing sets of web sites and
    /// the list of HTTP headers that must be sent to matching sites.
    ///
    /// Header values are stored exactly as written; `$NAME` references to
    /// environment variables are only expanded when the headers are
    /// interpolated, so secrets never end up in the serialized config.
    #[serde(default)]
    pub http_headers: HashMap<HashedRegex, Vec<HttpHeader>>,
}
42
/// A single HTTP header, written in the configuration as one `name: value`
/// string.
///
/// (De)serialization goes through `String`: parsing uses the `TryFrom<String>`
/// impl and serializing uses the conversion back into a `String`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
#[serde(try_from = "String", into = "String")]
pub struct HttpHeader {
    /// The header's name.
    pub name: HeaderName,
    /// The header's value exactly as written in the configuration, i.e.
    /// *before* any `$NAME` environment variable references are expanded.
    pub value: String,
}
49
50impl HttpHeader {
51    pub(crate) fn interpolate(&self) -> Result<HeaderValue, Error> {
52        interpolate_env(&self.value)
53    }
54}
55
56impl Display for HttpHeader {
57    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
58        write!(f, "{}: {}", self.name, self.value)
59    }
60}
61
62impl Config {
63    /// The default cache timeout (around 12 hours).
64    pub const DEFAULT_CACHE_TIMEOUT: Duration =
65        Duration::from_secs(60 * 60 * 12);
66    /// The default user-agent.
67    pub const DEFAULT_USER_AGENT: &'static str =
68        concat!(env!("CARGO_PKG_NAME"), "-", env!("CARGO_PKG_VERSION"));
69
70    /// Checks [`Config::exclude`] to see if the provided link should be
71    /// skipped.
72    pub fn should_skip(&self, link: &str) -> bool {
73        self.exclude.iter().any(|pat| pat.find(link).is_some())
74    }
75
76    pub(crate) fn client(&self) -> Client {
77        let mut headers = http::HeaderMap::new();
78        headers
79            .insert(http::header::USER_AGENT, self.user_agent.parse().unwrap());
80        Client::builder().default_headers(headers).build().unwrap()
81    }
82
83    pub(crate) fn interpolate_headers(
84        &self,
85        warning_policy: WarningPolicy,
86    ) -> Vec<(HashedRegex, Vec<(HeaderName, HeaderValue)>)> {
87        let mut all_headers = Vec::new();
88        let log_level = warning_policy.to_log_level();
89
90        for (pattern, headers) in &self.http_headers {
91            let mut interpolated = Vec::new();
92
93            for header in headers {
94                match header.interpolate() {
95                    Ok(value) => {
96                        interpolated.push((header.name.clone(), value))
97                    },
98                    Err(e) => {
99                        // We don't want failed interpolation (i.e. due to a
100                        // missing env variable) to abort the whole
101                        // linkchecking, so emit a warning and keep going.
102                        //
103                        // If it was important, the user would notice a "broken"
104                        // link and read back through the logs.
105                        log::log!(
106                            log_level,
107                            "Unable to interpolate \"{}\" because {}",
108                            header,
109                            e
110                        );
111                    },
112                }
113            }
114
115            all_headers.push((pattern.clone(), interpolated));
116        }
117
118        all_headers
119    }
120}
121
122impl Default for Config {
123    fn default() -> Config {
124        Config {
125            follow_web_links: false,
126            traverse_parent_directories: false,
127            exclude: Vec::new(),
128            user_agent: default_user_agent(),
129            http_headers: HashMap::new(),
130            warning_policy: WarningPolicy::Warn,
131            cache_timeout: Config::DEFAULT_CACHE_TIMEOUT.as_secs(),
132        }
133    }
134}
135
136impl FromStr for HttpHeader {
137    type Err = Error;
138
139    fn from_str(s: &str) -> Result<Self, Self::Err> {
140        match s.find(": ") {
141            Some(idx) => {
142                let name = s[..idx].parse()?;
143                let value = s[idx + 2..].to_string();
144                Ok(HttpHeader {
145                    name,
146                    value,
147                })
148            },
149
150            None => Err(Error::msg(format!(
151                "The `{}` HTTP header must be in the form `key: value` but it isn't",
152                s
153            ))),
154        }
155    }
156}
157
158impl TryFrom<&'_ str> for HttpHeader {
159    type Error = Error;
160
161    fn try_from(s: &'_ str) -> Result<Self, Error> { HttpHeader::from_str(s) }
162}
163
164impl TryFrom<String> for HttpHeader {
165    type Error = Error;
166
167    fn try_from(s: String) -> Result<Self, Error> {
168        HttpHeader::try_from(s.as_str())
169    }
170}
171
172impl Into<String> for HttpHeader {
173    fn into(self) -> String {
174        let HttpHeader { name, value, .. } = self;
175        format!("{}: {}", name, value)
176    }
177}
178
179fn default_cache_timeout() -> u64 { Config::DEFAULT_CACHE_TIMEOUT.as_secs() }
180fn default_user_agent() -> String { Config::DEFAULT_USER_AGENT.to_string() }
181
182fn interpolate_env(value: &str) -> Result<HeaderValue, Error> {
183    use std::{iter::Peekable, str::CharIndices};
184
185    fn is_ident(ch: char) -> bool { ch.is_ascii_alphanumeric() || ch == '_' }
186
187    fn ident_end(start: usize, iter: &mut Peekable<CharIndices>) -> usize {
188        let mut end = start;
189        while let Some(&(i, ch)) = iter.peek() {
190            if !is_ident(ch) {
191                return i;
192            }
193            end = i + ch.len_utf8();
194            iter.next();
195        }
196
197        end
198    }
199
200    let mut res = String::with_capacity(value.len());
201    let mut backslash = false;
202    let mut iter = value.char_indices().peekable();
203
204    while let Some((i, ch)) = iter.next() {
205        if backslash {
206            match ch {
207                '$' | '\\' => res.push(ch),
208                _ => {
209                    res.push('\\');
210                    res.push(ch);
211                },
212            }
213
214            backslash = false;
215        } else {
216            match ch {
217                '\\' => backslash = true,
218                '$' => {
219                    iter.next();
220                    let start = i + 1;
221                    let end = ident_end(start, &mut iter);
222                    let name = &value[start..end];
223
224                    match std::env::var(name) {
225                        Ok(env) => res.push_str(&env),
226                        Err(e) => {
227                            return Err(Error::msg(format!(
228                                "Failed to retrieve `{}` env var: {}",
229                                name, e
230                            )))
231                        },
232                    }
233                },
234
235                _ => res.push(ch),
236            }
237        }
238    }
239
240    // trailing backslash
241    if backslash {
242        res.push('\\');
243    }
244
245    Ok(res.parse()?)
246}
247
/// How should warnings be treated?
///
/// Variants are (de)serialized in `kebab-case`, e.g.
/// `warning-policy = "error"`. The default is [`WarningPolicy::Warn`].
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum WarningPolicy {
    /// Silently ignore them.
    Ignore,
    /// Warn the user, but don't fail the linkcheck.
    Warn,
    /// Treat warnings as errors.
    Error,
}
259
260impl WarningPolicy {
261    pub(crate) fn to_log_level(self) -> Level {
262        match self {
263            WarningPolicy::Error => Level::Error,
264            WarningPolicy::Warn => Level::Warn,
265            WarningPolicy::Ignore => Level::Debug,
266        }
267    }
268}
269
270impl Default for WarningPolicy {
271    fn default() -> WarningPolicy { WarningPolicy::Warn }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;

    /// A configuration that sets every available option, laid out exactly
    /// the way the serializer writes it back out.
    const CONFIG: &str = r#"follow-web-links = true
traverse-parent-directories = true
exclude = ["google\\.com"]
user-agent = "Internet Explorer"
cache-timeout = 3600
warning-policy = "error"

[http-headers]
https = ["accept: html/text", "authorization: Basic $TOKEN"]
"#;

    /// Every key in `CONFIG` must end up in the matching field, with header
    /// values left un-interpolated.
    #[test]
    fn deserialize_a_config() {
        std::env::set_var("TOKEN", "QWxhZGRpbjpPcGVuU2VzYW1l");

        let mut http_headers: HashMap<HashedRegex, Vec<HttpHeader>> =
            HashMap::new();
        http_headers.insert(
            HashedRegex::new("https").unwrap(),
            vec![
                "Accept: html/text".try_into().unwrap(),
                "Authorization: Basic $TOKEN".try_into().unwrap(),
            ],
        );

        let expected = Config {
            follow_web_links: true,
            traverse_parent_directories: true,
            exclude: vec![HashedRegex::new(r"google\.com").unwrap()],
            user_agent: String::from("Internet Explorer"),
            cache_timeout: 3600,
            warning_policy: WarningPolicy::Error,
            http_headers,
        };

        let parsed: Config = toml::from_str(CONFIG).unwrap();

        assert_eq!(parsed, expected);
    }

    /// Serializing a parsed config must reproduce the input verbatim. In
    /// particular the value of an env var must not leak into the output.
    #[test]
    fn round_trip_config() {
        std::env::set_var("TOKEN", "QWxhZGRpbjpPcGVuU2VzYW1l");

        let parsed: Config = toml::from_str(CONFIG).unwrap();
        let written = toml::to_string(&parsed).unwrap();

        assert_eq!(written, CONFIG);
    }

    /// A `$NAME` reference in a header value is replaced by the variable's
    /// contents.
    #[test]
    fn interpolation() {
        std::env::set_var("SUPER_SECRET_TOKEN", "abcdefg123456");

        let header = HttpHeader {
            name: "Authorization".parse().unwrap(),
            value: "Basic $SUPER_SECRET_TOKEN".into(),
        };
        let expected: HeaderValue = "Basic abcdefg123456".parse().unwrap();

        let interpolated = header.interpolate().unwrap();

        assert_eq!(interpolated, expected);
    }
}