Skip to main content

lychee_lib/
collector.rs

1use crate::ErrorKind;
2use crate::InputSource;
3use crate::Preprocessor;
4use crate::filter::PathExcludes;
5
6use crate::types::resolver::UrlContentResolver;
7use crate::{
8    Base, Input, LycheeResult, Request, RequestError, basic_auth::BasicAuthExtractor,
9    extract::Extractor, types::FileExtensions, types::uri::raw::RawUri, utils::request,
10};
11use futures::TryStreamExt;
12use futures::{
13    StreamExt,
14    stream::{self, Stream},
15};
16use http::HeaderMap;
17use par_stream::ParStreamExt;
18use reqwest::Client;
19use std::collections::HashSet;
20use std::path::PathBuf;
21
22/// Collector keeps the state of link collection
23/// It drives the link extraction from inputs
24#[allow(clippy::struct_excessive_bools)]
25#[derive(Debug, Clone)]
26pub struct Collector {
27    basic_auth_extractor: Option<BasicAuthExtractor>,
28    skip_missing_inputs: bool,
29    skip_ignored: bool,
30    skip_hidden: bool,
31    include_verbatim: bool,
32    include_wikilinks: bool,
33    use_html5ever: bool,
34    root_dir: Option<PathBuf>,
35    base: Option<Base>,
36    excluded_paths: PathExcludes,
37    headers: HeaderMap,
38    client: Client,
39    preprocessor: Option<Preprocessor>,
40}
41
42impl Default for Collector {
43    /// # Panics
44    ///
45    /// We call [`Collector::new()`] which can panic in certain scenarios.
46    ///
47    /// Use `Collector::new()` instead if you need to handle
48    /// [`ClientBuilder`](crate::ClientBuilder) errors gracefully.
49    fn default() -> Self {
50        Collector {
51            basic_auth_extractor: None,
52            skip_missing_inputs: false,
53            include_verbatim: false,
54            include_wikilinks: false,
55            use_html5ever: false,
56            skip_hidden: true,
57            skip_ignored: true,
58            root_dir: None,
59            base: None,
60            headers: HeaderMap::new(),
61            client: Client::new(),
62            excluded_paths: PathExcludes::empty(),
63            preprocessor: None,
64        }
65    }
66}
67
68impl Collector {
69    /// Create a new collector with an empty cache
70    ///
71    /// # Errors
72    ///
73    /// Returns an `Err` if the `root_dir` is not an absolute path
74    /// or if the reqwest `Client` fails to build
75    pub fn new(root_dir: Option<PathBuf>, base: Option<Base>) -> LycheeResult<Self> {
76        let root_dir = match root_dir {
77            Some(root_dir) if base.is_some() => Some(root_dir),
78            Some(root_dir) => {
79                let root_dir_exists = root_dir.read_dir().map(|_| ());
80                let root_dir = root_dir_exists
81                    .and_then(|()| std::path::absolute(&root_dir))
82                    .map_err(|e| ErrorKind::InvalidRootDir(root_dir, e))?;
83                Some(root_dir)
84            }
85            None => None,
86        };
87        Ok(Collector {
88            basic_auth_extractor: None,
89            skip_missing_inputs: false,
90            include_verbatim: false,
91            include_wikilinks: false,
92            use_html5ever: false,
93            skip_hidden: true,
94            skip_ignored: true,
95            preprocessor: None,
96            headers: HeaderMap::new(),
97            client: Client::builder()
98                .build()
99                .map_err(ErrorKind::BuildRequestClient)?,
100            excluded_paths: PathExcludes::empty(),
101            root_dir,
102            base,
103        })
104    }
105
106    /// Skip missing input files (default is to error if they don't exist)
107    #[must_use]
108    pub const fn skip_missing_inputs(mut self, yes: bool) -> Self {
109        self.skip_missing_inputs = yes;
110        self
111    }
112
113    /// Skip files that are hidden
114    #[must_use]
115    pub const fn skip_hidden(mut self, yes: bool) -> Self {
116        self.skip_hidden = yes;
117        self
118    }
119
120    /// Skip files that are ignored
121    #[must_use]
122    pub const fn skip_ignored(mut self, yes: bool) -> Self {
123        self.skip_ignored = yes;
124        self
125    }
126
127    /// Set headers to use when resolving input URLs
128    #[must_use]
129    pub fn headers(mut self, headers: HeaderMap) -> Self {
130        self.headers = headers;
131        self
132    }
133
134    /// Set client to use for checking input URLs
135    #[must_use]
136    pub fn client(mut self, client: Client) -> Self {
137        self.client = client;
138        self
139    }
140
141    /// Use `html5ever` to parse HTML instead of `html5gum`.
142    #[must_use]
143    pub const fn use_html5ever(mut self, yes: bool) -> Self {
144        self.use_html5ever = yes;
145        self
146    }
147
148    /// Skip over links in verbatim sections (like Markdown code blocks)
149    #[must_use]
150    pub const fn include_verbatim(mut self, yes: bool) -> Self {
151        self.include_verbatim = yes;
152        self
153    }
154
155    /// Check WikiLinks in Markdown files
156    #[allow(clippy::doc_markdown)]
157    #[must_use]
158    pub const fn include_wikilinks(mut self, yes: bool) -> Self {
159        self.include_wikilinks = yes;
160        self
161    }
162
163    /// Configure a file [`Preprocessor`]
164    #[must_use]
165    pub fn preprocessor(mut self, preprocessor: Option<Preprocessor>) -> Self {
166        self.preprocessor = preprocessor;
167        self
168    }
169
170    /// Pass a [`BasicAuthExtractor`] which is capable to match found
171    /// URIs to basic auth credentials. These credentials get passed to the
172    /// request in question.
173    #[must_use]
174    #[allow(clippy::missing_const_for_fn)]
175    pub fn basic_auth_extractor(mut self, extractor: BasicAuthExtractor) -> Self {
176        self.basic_auth_extractor = Some(extractor);
177        self
178    }
179
180    /// Configure which paths to exclude
181    #[must_use]
182    pub fn excluded_paths(mut self, excluded_paths: PathExcludes) -> Self {
183        self.excluded_paths = excluded_paths;
184        self
185    }
186
187    /// Convenience method to fetch all unique links from inputs
188    /// with the default extensions.
189    pub fn collect_links(
190        self,
191        inputs: HashSet<Input>,
192    ) -> impl Stream<Item = Result<Request, RequestError>> {
193        self.collect_links_from_file_types(inputs, crate::types::FileType::default_extensions())
194    }
195
196    /// Fetch all unique links from inputs
197    /// All relative URLs get prefixed with `base` (if given).
198    /// (This can be a directory or a base URL)
199    ///
200    /// # Errors
201    ///
202    /// Will return `Err` if links cannot be extracted from an input
203    pub fn collect_links_from_file_types(
204        self,
205        inputs: HashSet<Input>,
206        extensions: FileExtensions,
207    ) -> impl Stream<Item = Result<Request, RequestError>> {
208        let skip_missing_inputs = self.skip_missing_inputs;
209        let skip_hidden = self.skip_hidden;
210        let skip_ignored = self.skip_ignored;
211        let global_base = self.base;
212        let excluded_paths = self.excluded_paths;
213
214        let resolver = UrlContentResolver {
215            basic_auth_extractor: self.basic_auth_extractor.clone(),
216            headers: self.headers.clone(),
217            client: self.client,
218        };
219
220        let extractor = Extractor::new(
221            self.use_html5ever,
222            self.include_verbatim,
223            self.include_wikilinks,
224        );
225
226        stream::iter(inputs)
227            .par_then_unordered(None, move |input| {
228                let default_base = global_base.clone();
229                let extensions = extensions.clone();
230                let resolver = resolver.clone();
231                let excluded_paths = excluded_paths.clone();
232                let preprocessor = self.preprocessor.clone();
233
234                async move {
235                    let base = match &input.source {
236                        InputSource::RemoteUrl(url) => Base::try_from(url.as_str()).ok(),
237                        _ => default_base,
238                    };
239
240                    input
241                        .get_contents(
242                            skip_missing_inputs,
243                            skip_hidden,
244                            skip_ignored,
245                            extensions,
246                            resolver,
247                            excluded_paths,
248                            preprocessor,
249                        )
250                        .map(move |content| (content, base.clone()))
251                }
252            })
253            .flatten()
254            .par_then_unordered(None, move |(content, base)| {
255                let root_dir = self.root_dir.clone();
256                let basic_auth_extractor = self.basic_auth_extractor.clone();
257                async move {
258                    let content = content?;
259                    let uris: Vec<RawUri> = extractor.extract(&content);
260                    let requests = request::create(
261                        uris,
262                        &content.source,
263                        root_dir.as_ref(),
264                        base.as_ref(),
265                        basic_auth_extractor.as_ref(),
266                    );
267                    Result::Ok(stream::iter(requests))
268                }
269            })
270            .try_flatten()
271    }
272}
273
#[cfg(test)]
mod tests {
    use std::borrow::Cow;
    use std::{collections::HashSet, convert::TryFrom, fs::File, io::Write};
    use test_utils::{fixtures_path, load_fixture, mail, mock_server, path, website};

    use http::StatusCode;
    use reqwest::Url;

    use super::*;
    use crate::{
        LycheeResult, Uri,
        filter::PathExcludes,
        types::{FileType, Input, InputSource},
    };

    /// Run a default collector on `inputs` and return the URIs of all
    /// collected requests.
    ///
    /// Panics if any individual item of the request stream is an `Err`.
    async fn collect(
        inputs: HashSet<Input>,
        root_dir: Option<PathBuf>,
        base: Option<Base>,
    ) -> LycheeResult<HashSet<Uri>> {
        let responses = Collector::new(root_dir, base)?.collect_links(inputs);
        Ok(responses.map(|r| r.unwrap().uri).collect().await)
    }

    /// Like `collect`, but with verbatim links included and a custom set of
    /// file extensions.
    ///
    /// A verbatim link is a link inside a verbatim section, for example a
    /// link in a code block or a script tag.
    async fn collect_verbatim(
        inputs: HashSet<Input>,
        root_dir: Option<PathBuf>,
        base: Option<Base>,
        extensions: FileExtensions,
    ) -> LycheeResult<HashSet<Uri>> {
        let responses = Collector::new(root_dir, base)?
            .include_verbatim(true)
            .collect_links_from_file_types(inputs, extensions);
        Ok(responses.map(|r| r.unwrap().uri).collect().await)
    }

    const TEST_STRING: &str = "http://test-string.com";
    const TEST_URL: &str = "https://test-url.org";
    const TEST_FILE: &str = "https://test-file.io";
    const TEST_GLOB_1: &str = "https://test-glob-1.io";
    const TEST_GLOB_2_MAIL: &str = "test@glob-2.io";

    #[tokio::test]
    async fn test_file_without_extension_is_plaintext() -> LycheeResult<()> {
        let temp_dir = tempfile::tempdir().unwrap();
        // Treat as plaintext file (no extension)
        let file_path = temp_dir.path().join("README");
        let _file = File::create(&file_path).unwrap();
        let input = Input::new(&file_path.as_path().display().to_string(), None, true)?;
        let contents: Vec<_> = input
            .get_contents(
                true,
                true,
                true,
                FileType::default_extensions(),
                UrlContentResolver::default(),
                PathExcludes::empty(),
                None,
            )
            .collect::<Vec<_>>()
            .await;

        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].as_ref().unwrap().file_type, FileType::Plaintext);
        Ok(())
    }

    #[tokio::test]
    async fn test_url_without_extension_is_html() -> LycheeResult<()> {
        // Serve the input locally so the test does not depend on network
        // access. The mock server's root URL has no file extension.
        let mock_server = mock_server!(StatusCode::OK, set_body_string("<html></html>"));
        let input = Input::new(&mock_server.uri(), None, true)?;
        let contents: Vec<_> = input
            .get_contents(
                true,
                true,
                true,
                FileType::default_extensions(),
                UrlContentResolver::default(),
                PathExcludes::empty(),
                None,
            )
            .collect::<Vec<_>>()
            .await;

        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].as_ref().unwrap().file_type, FileType::Html);
        Ok(())
    }

    #[tokio::test]
    async fn test_collect_links() -> LycheeResult<()> {
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_dir_path = temp_dir.path();

        let file_path = temp_dir_path.join("f");
        let file_glob_1_path = temp_dir_path.join("glob-1");
        let file_glob_2_path = temp_dir_path.join("glob-2");

        let mut file = File::create(&file_path).unwrap();
        let mut file_glob_1 = File::create(file_glob_1_path).unwrap();
        let mut file_glob_2 = File::create(file_glob_2_path).unwrap();

        writeln!(file, "{TEST_FILE}").unwrap();
        writeln!(file_glob_1, "{TEST_GLOB_1}").unwrap();
        writeln!(file_glob_2, "{TEST_GLOB_2_MAIL}").unwrap();

        let mock_server = mock_server!(StatusCode::OK, set_body_string(TEST_URL));

        let inputs = HashSet::from_iter([
            Input::from_input_source(InputSource::String(Cow::Borrowed(TEST_STRING))),
            Input::from_input_source(InputSource::RemoteUrl(Box::new(
                Url::parse(&mock_server.uri())
                    .map_err(|e| (mock_server.uri(), e))
                    .unwrap(),
            ))),
            Input::from_input_source(InputSource::FsPath(file_path)),
            Input::from_input_source(InputSource::FsGlob {
                pattern: glob::Pattern::new(&temp_dir_path.join("glob*").to_string_lossy())?,
                ignore_case: true,
            }),
        ]);

        let links = collect_verbatim(inputs, None, None, FileType::default_extensions()).await?;

        let expected_links = HashSet::from_iter([
            website!(TEST_STRING),
            website!(TEST_URL),
            website!(TEST_FILE),
            website!(TEST_GLOB_1),
            mail!(TEST_GLOB_2_MAIL),
        ]);

        assert_eq!(links, expected_links);

        Ok(())
    }

    #[tokio::test]
    async fn test_collect_markdown_links() {
        let base = Base::try_from("https://github.com/hello-rust/lychee/").unwrap();
        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                "This is [a test](https://endler.dev). This is a relative link test [Relative Link Test](relative_link)",
            )),
            file_type_hint: Some(FileType::Markdown),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            website!("https://endler.dev"),
            website!("https://github.com/hello-rust/lychee/relative_link"),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_collect_html_links() {
        let base = Base::try_from("https://github.com/lycheeverse/").unwrap();
        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                r#"<html>
                <div class="row">
                    <a href="https://github.com/lycheeverse/lychee/">
                    <a href="blob/master/README.md">README</a>
                </div>
            </html>"#,
            )),
            file_type_hint: Some(FileType::Html),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            website!("https://github.com/lycheeverse/lychee/"),
            website!("https://github.com/lycheeverse/blob/master/README.md"),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_collect_html_srcset() {
        let base = Base::try_from("https://example.com/").unwrap();
        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                r#"
            <img
                src="/static/image.png"
                srcset="
                /static/image300.png  300w,
                /static/image600.png  600w,
                "
            />
          "#,
            )),
            file_type_hint: Some(FileType::Html),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            website!("https://example.com/static/image.png"),
            website!("https://example.com/static/image300.png"),
            website!("https://example.com/static/image600.png"),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_markdown_internal_url() {
        let base = Base::try_from("https://localhost.com/").unwrap();

        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                "This is [an internal url](@/internal.md)
        This is [an internal url](@/internal.markdown)
        This is [an internal url](@/internal.markdown#example)
        This is [an internal url](@/internal.md#example)",
            )),
            file_type_hint: Some(FileType::Markdown),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected = HashSet::from_iter([
            website!("https://localhost.com/@/internal.md"),
            website!("https://localhost.com/@/internal.markdown"),
            website!("https://localhost.com/@/internal.md#example"),
            website!("https://localhost.com/@/internal.markdown#example"),
        ]);

        assert_eq!(links, expected);
    }

    #[tokio::test]
    async fn test_extract_html5_not_valid_xml_relative_links() {
        let base = Base::try_from("https://example.com").unwrap();
        let input = load_fixture!("TEST_HTML5.html");

        let input = Input {
            source: InputSource::String(Cow::Owned(input)),
            file_type_hint: Some(FileType::Html),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            // the body links wouldn't be present if the file was parsed strictly as XML
            website!("https://example.com/body/a"),
            website!("https://example.com/body/div_empty_a"),
            website!("https://example.com/css/style_full_url.css"),
            website!("https://example.com/css/style_relative_url.css"),
            website!("https://example.com/head/home"),
            website!("https://example.com/images/icon.png"),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_relative_url_with_base_extracted_from_input() {
        let contents = r#"<html>
            <div class="row">
                <a href="https://github.com/lycheeverse/lychee/">GitHub</a>
                <a href="/about">About</a>
            </div>
        </html>"#;
        let mock_server = mock_server!(StatusCode::OK, set_body_string(contents));

        let server_uri = Url::parse(&mock_server.uri()).unwrap();

        let input = Input::from_input_source(InputSource::RemoteUrl(Box::new(server_uri.clone())));

        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, None).await.unwrap();

        let expected_urls = HashSet::from_iter([
            website!("https://github.com/lycheeverse/lychee/"),
            website!(&format!("{server_uri}about")),
        ]);

        assert_eq!(links, expected_urls);
    }

    #[tokio::test]
    async fn test_email_with_query_params() {
        let input = Input::from_input_source(InputSource::String(Cow::Borrowed(
            "This is a mailto:user@example.com?subject=Hello link",
        )));

        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, None).await.unwrap();

        let expected_links = HashSet::from_iter([mail!("user@example.com")]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_multiple_remote_urls() {
        let mock_server_1 = mock_server!(
            StatusCode::OK,
            set_body_string(r#"<a href="relative.html">Link</a>"#)
        );
        let mock_server_2 = mock_server!(
            StatusCode::OK,
            set_body_string(r#"<a href="relative.html">Link</a>"#)
        );

        let inputs = HashSet::from_iter([
            Input {
                source: InputSource::RemoteUrl(Box::new(
                    Url::parse(&format!(
                        "{}/foo/index.html",
                        mock_server_1.uri().trim_end_matches('/')
                    ))
                    .unwrap(),
                )),
                file_type_hint: Some(FileType::Html),
            },
            Input {
                source: InputSource::RemoteUrl(Box::new(
                    Url::parse(&format!(
                        "{}/bar/index.html",
                        mock_server_2.uri().trim_end_matches('/')
                    ))
                    .unwrap(),
                )),
                file_type_hint: Some(FileType::Html),
            },
        ]);

        let links = collect(inputs, None, None).await.unwrap();

        let expected_links = HashSet::from_iter([
            website!(&format!(
                "{}/foo/relative.html",
                mock_server_1.uri().trim_end_matches('/')
            )),
            website!(&format!(
                "{}/bar/relative.html",
                mock_server_2.uri().trim_end_matches('/')
            )),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_file_path_with_base() {
        let base = Base::try_from("/path/to/root").unwrap();
        assert_eq!(base, Base::Local("/path/to/root".into()));

        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                r#"
                <a href="index.html">Index</a>
                <a href="about.html">About</a>
                <a href="/another.html">Another</a>
            "#,
            )),
            file_type_hint: Some(FileType::Html),
        };

        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            path!("/path/to/root/index.html"),
            path!("/path/to/root/about.html"),
            path!("/another.html"),
        ]);

        assert_eq!(links, expected_links);
    }
}