lychee_lib/
collector.rs

1use crate::ErrorKind;
2use crate::InputSource;
3use crate::filter::PathExcludes;
4use crate::types::resolver::UrlContentResolver;
5use crate::{
6    Base, Input, InputResolver, Request, Result, basic_auth::BasicAuthExtractor,
7    extract::Extractor, types::FileExtensions, types::uri::raw::RawUri, utils::request,
8};
9use dashmap::DashSet;
10use futures::TryStreamExt;
11use futures::{
12    StreamExt,
13    stream::{self, Stream},
14};
15use http::HeaderMap;
16use par_stream::ParStreamExt;
17use reqwest::Client;
18use std::collections::HashSet;
19use std::path::PathBuf;
20use std::sync::Arc;
21
/// Collector keeps the state of link collection
/// It drives the link extraction from inputs
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone)]
pub struct Collector {
    /// Optional extractor used to match found URIs to basic auth credentials.
    basic_auth_extractor: Option<BasicAuthExtractor>,
    /// Skip missing input files instead of returning an error for them.
    skip_missing_inputs: bool,
    /// Skip files that are ignored.
    skip_ignored: bool,
    /// Skip files that are hidden.
    skip_hidden: bool,
    /// Include links found in verbatim sections (like Markdown code blocks).
    include_verbatim: bool,
    /// Check WikiLinks in Markdown files.
    include_wikilinks: bool,
    /// Use `html5ever` to parse HTML instead of `html5gum`.
    use_html5ever: bool,
    /// Optional root directory; `Collector::new` rejects relative paths here.
    root_dir: Option<PathBuf>,
    /// Optional base (a directory or a base URL) used for relative URLs.
    base: Option<Base>,
    /// Paths to exclude while resolving inputs.
    excluded_paths: PathExcludes,
    /// Headers to use when resolving input URLs.
    headers: HeaderMap,
    /// HTTP client used when resolving input URLs.
    client: Client,
}
40
41impl Default for Collector {
42    /// # Panics
43    ///
44    /// We call [`Collector::new()`] which can panic in certain scenarios.
45    ///
46    /// Use `Collector::new()` instead if you need to handle
47    /// [`ClientBuilder`](crate::ClientBuilder) errors gracefully.
48    fn default() -> Self {
49        Collector {
50            basic_auth_extractor: None,
51            skip_missing_inputs: false,
52            include_verbatim: false,
53            include_wikilinks: false,
54            use_html5ever: false,
55            skip_hidden: true,
56            skip_ignored: true,
57            root_dir: None,
58            base: None,
59            headers: HeaderMap::new(),
60            client: Client::new(),
61            excluded_paths: PathExcludes::empty(),
62        }
63    }
64}
65
66impl Collector {
67    /// Create a new collector with an empty cache
68    ///
69    /// # Errors
70    ///
71    /// Returns an `Err` if the `root_dir` is not an absolute path
72    /// or if the reqwest `Client` fails to build
73    pub fn new(root_dir: Option<PathBuf>, base: Option<Base>) -> Result<Self> {
74        if let Some(root_dir) = &root_dir {
75            if root_dir.is_relative() {
76                return Err(ErrorKind::RootDirMustBeAbsolute(root_dir.clone()));
77            }
78        }
79        Ok(Collector {
80            basic_auth_extractor: None,
81            skip_missing_inputs: false,
82            include_verbatim: false,
83            include_wikilinks: false,
84            use_html5ever: false,
85            skip_hidden: true,
86            skip_ignored: true,
87            headers: HeaderMap::new(),
88            client: Client::builder()
89                .build()
90                .map_err(ErrorKind::BuildRequestClient)?,
91            excluded_paths: PathExcludes::empty(),
92            root_dir,
93            base,
94        })
95    }
96
97    /// Skip missing input files (default is to error if they don't exist)
98    #[must_use]
99    pub const fn skip_missing_inputs(mut self, yes: bool) -> Self {
100        self.skip_missing_inputs = yes;
101        self
102    }
103
104    /// Skip files that are hidden
105    #[must_use]
106    pub const fn skip_hidden(mut self, yes: bool) -> Self {
107        self.skip_hidden = yes;
108        self
109    }
110
111    /// Skip files that are ignored
112    #[must_use]
113    pub const fn skip_ignored(mut self, yes: bool) -> Self {
114        self.skip_ignored = yes;
115        self
116    }
117
118    /// Set headers to use when resolving input URLs
119    #[must_use]
120    pub fn headers(mut self, headers: HeaderMap) -> Self {
121        self.headers = headers;
122        self
123    }
124
125    /// Set client to use for checking input URLs
126    #[must_use]
127    pub fn client(mut self, client: Client) -> Self {
128        self.client = client;
129        self
130    }
131
132    /// Use `html5ever` to parse HTML instead of `html5gum`.
133    #[must_use]
134    pub const fn use_html5ever(mut self, yes: bool) -> Self {
135        self.use_html5ever = yes;
136        self
137    }
138
139    /// Skip over links in verbatim sections (like Markdown code blocks)
140    #[must_use]
141    pub const fn include_verbatim(mut self, yes: bool) -> Self {
142        self.include_verbatim = yes;
143        self
144    }
145
146    #[allow(clippy::doc_markdown)]
147    /// Check WikiLinks in Markdown files
148    #[must_use]
149    pub const fn include_wikilinks(mut self, yes: bool) -> Self {
150        self.include_wikilinks = yes;
151        self
152    }
153
154    /// Pass a [`BasicAuthExtractor`] which is capable to match found
155    /// URIs to basic auth credentials. These credentials get passed to the
156    /// request in question.
157    #[must_use]
158    #[allow(clippy::missing_const_for_fn)]
159    pub fn basic_auth_extractor(mut self, extractor: BasicAuthExtractor) -> Self {
160        self.basic_auth_extractor = Some(extractor);
161        self
162    }
163
164    /// Configure which paths to exclude
165    #[must_use]
166    pub fn excluded_paths(mut self, excluded_paths: PathExcludes) -> Self {
167        self.excluded_paths = excluded_paths;
168        self
169    }
170
171    /// Collect all sources from a list of [`Input`]s. For further details,
172    /// see also [`Input::get_sources`](crate::Input#method.get_sources).
173    pub fn collect_sources(self, inputs: HashSet<Input>) -> impl Stream<Item = Result<String>> {
174        self.collect_sources_with_file_types(inputs, crate::types::FileType::default_extensions())
175    }
176
177    /// Collect all sources from a list of [`Input`]s with specific file extensions.
178    pub fn collect_sources_with_file_types(
179        self,
180        inputs: HashSet<Input>,
181        file_extensions: FileExtensions,
182    ) -> impl Stream<Item = Result<String>> + 'static {
183        let seen = Arc::new(DashSet::new());
184        let skip_hidden = self.skip_hidden;
185        let skip_ignored = self.skip_ignored;
186        let excluded_paths = self.excluded_paths;
187
188        stream::iter(inputs)
189            .par_then_unordered(None, move |input| {
190                let excluded_paths = excluded_paths.clone();
191                let file_extensions = file_extensions.clone();
192                async move {
193                    let input_sources = InputResolver::resolve(
194                        &input,
195                        file_extensions,
196                        skip_hidden,
197                        skip_ignored,
198                        &excluded_paths,
199                    );
200
201                    input_sources
202                        .map(|source| source.map(|source| source.to_string()))
203                        .collect::<Vec<_>>()
204                        .await
205                }
206            })
207            .map(stream::iter)
208            .flatten()
209            .filter_map({
210                move |source: Result<String>| {
211                    let seen = Arc::clone(&seen);
212                    async move {
213                        if let Ok(s) = &source {
214                            if !seen.insert(s.clone()) {
215                                return None;
216                            }
217                        }
218                        Some(source)
219                    }
220                }
221            })
222    }
223
224    /// Convenience method to fetch all unique links from inputs
225    /// with the default extensions.
226    pub fn collect_links(self, inputs: HashSet<Input>) -> impl Stream<Item = Result<Request>> {
227        self.collect_links_from_file_types(inputs, crate::types::FileType::default_extensions())
228    }
229
230    /// Fetch all unique links from inputs
231    /// All relative URLs get prefixed with `base` (if given).
232    /// (This can be a directory or a base URL)
233    ///
234    /// # Errors
235    ///
236    /// Will return `Err` if links cannot be extracted from an input
237    pub fn collect_links_from_file_types(
238        self,
239        inputs: HashSet<Input>,
240        extensions: FileExtensions,
241    ) -> impl Stream<Item = Result<Request>> {
242        let skip_missing_inputs = self.skip_missing_inputs;
243        let skip_hidden = self.skip_hidden;
244        let skip_ignored = self.skip_ignored;
245        let global_base = self.base;
246        let excluded_paths = self.excluded_paths;
247
248        let resolver = UrlContentResolver {
249            basic_auth_extractor: self.basic_auth_extractor.clone(),
250            headers: self.headers.clone(),
251            client: self.client,
252        };
253
254        let extractor = Extractor::new(
255            self.use_html5ever,
256            self.include_verbatim,
257            self.include_wikilinks,
258        );
259
260        stream::iter(inputs)
261            .par_then_unordered(None, move |input| {
262                let default_base = global_base.clone();
263                let extensions = extensions.clone();
264                let resolver = resolver.clone();
265                let excluded_paths = excluded_paths.clone();
266
267                async move {
268                    let base = match &input.source {
269                        InputSource::RemoteUrl(url) => Base::try_from(url.as_str()).ok(),
270                        _ => default_base,
271                    };
272
273                    input
274                        .get_contents(
275                            skip_missing_inputs,
276                            skip_hidden,
277                            skip_ignored,
278                            extensions,
279                            resolver,
280                            excluded_paths,
281                        )
282                        .map(move |content| (content, base.clone()))
283                }
284            })
285            .flatten()
286            .par_then_unordered(None, move |(content, base)| {
287                let root_dir = self.root_dir.clone();
288                let basic_auth_extractor = self.basic_auth_extractor.clone();
289                async move {
290                    let content = content?;
291                    let uris: Vec<RawUri> = extractor.extract(&content);
292                    let requests = request::create(
293                        uris,
294                        &content.source,
295                        root_dir.as_ref(),
296                        base.as_ref(),
297                        basic_auth_extractor.as_ref(),
298                    );
299                    Result::Ok(stream::iter(requests.into_iter().map(Ok)))
300                }
301            })
302            .try_flatten()
303    }
304}
305
#[cfg(test)]
mod tests {
    use std::borrow::Cow;
    use std::{collections::HashSet, convert::TryFrom, fs::File, io::Write};
    use test_utils::{fixtures_path, load_fixture, mail, mock_server, path, website};

    use http::StatusCode;
    use reqwest::Url;

    use super::*;
    use crate::{
        Result, Uri,
        filter::PathExcludes,
        types::{FileType, Input, InputSource},
    };

    /// Helper function to run the collector on the given inputs
    ///
    /// Panics if any collected request is an error.
    async fn collect(
        inputs: HashSet<Input>,
        root_dir: Option<PathBuf>,
        base: Option<Base>,
    ) -> Result<HashSet<Uri>> {
        let responses = Collector::new(root_dir, base)?.collect_links(inputs);
        Ok(responses.map(|r| r.unwrap().uri).collect().await)
    }

    /// Helper function for collecting verbatim links
    ///
    /// A verbatim link is a link that is not parsed by the HTML parser.
    /// For example, a link in a code block or a script tag.
    async fn collect_verbatim(
        inputs: HashSet<Input>,
        root_dir: Option<PathBuf>,
        base: Option<Base>,
        extensions: FileExtensions,
    ) -> Result<HashSet<Uri>> {
        let responses = Collector::new(root_dir, base)?
            .include_verbatim(true)
            .collect_links_from_file_types(inputs, extensions);
        Ok(responses.map(|r| r.unwrap().uri).collect().await)
    }

    const TEST_STRING: &str = "http://test-string.com";
    const TEST_URL: &str = "https://test-url.org";
    const TEST_FILE: &str = "https://test-file.io";
    const TEST_GLOB_1: &str = "https://test-glob-1.io";
    const TEST_GLOB_2_MAIL: &str = "test@glob-2.io";

    /// A local file without an extension is expected to be read as plaintext.
    #[tokio::test]
    async fn test_file_without_extension_is_plaintext() -> Result<()> {
        let temp_dir = tempfile::tempdir().unwrap();
        // Treat as plaintext file (no extension)
        let file_path = temp_dir.path().join("README");
        let _file = File::create(&file_path).unwrap();
        let input = Input::new(&file_path.as_path().display().to_string(), None, true)?;
        let contents: Vec<_> = input
            .get_contents(
                true,
                true,
                true,
                FileType::default_extensions(),
                UrlContentResolver::default(),
                PathExcludes::empty(),
            )
            .collect::<Vec<_>>()
            .await;

        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].as_ref().unwrap().file_type, FileType::Plaintext);
        Ok(())
    }

    /// A URL without an extension is expected to be read as HTML.
    #[tokio::test]
    async fn test_url_without_extension_is_html() -> Result<()> {
        let input = Input::new("https://example.com/", None, true)?;
        let contents: Vec<_> = input
            .get_contents(
                true,
                true,
                true,
                FileType::default_extensions(),
                UrlContentResolver::default(),
                PathExcludes::empty(),
            )
            .collect::<Vec<_>>()
            .await;

        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].as_ref().unwrap().file_type, FileType::Html);
        Ok(())
    }

    /// Two globs matching the same file must yield that source only once.
    #[tokio::test]
    async fn test_collect_sources() -> Result<()> {
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_dir_path = temp_dir.path();

        // NOTE(review): the working directory is process-global, so this can
        // affect other tests running in the same process. It is kept because
        // the glob patterns below are relative.
        std::env::set_current_dir(temp_dir_path)?;

        let file_path = temp_dir_path.join("markdown.md");
        File::create(&file_path).unwrap();

        let file_path = temp_dir_path.join("README");
        File::create(&file_path).unwrap();

        let inputs = HashSet::from_iter([
            Input::from_input_source(InputSource::FsGlob {
                pattern: "*.md".to_string(),
                ignore_case: true,
            }),
            Input::from_input_source(InputSource::FsGlob {
                pattern: "markdown.*".to_string(),
                ignore_case: true,
            }),
        ]);

        let collector = Collector::new(Some(temp_dir_path.to_path_buf()), None)?;

        let sources: Vec<_> = collector.collect_sources(inputs).collect().await;

        assert_eq!(sources.len(), 1);
        assert_eq!(sources[0], Ok("markdown.md".to_string()));

        Ok(())
    }

    /// Links are collected from string, remote URL, file path and glob inputs.
    #[tokio::test]
    async fn test_collect_links() -> Result<()> {
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_dir_path = temp_dir.path();

        let file_path = temp_dir_path.join("f");
        let file_glob_1_path = temp_dir_path.join("glob-1");
        let file_glob_2_path = temp_dir_path.join("glob-2");

        let mut file = File::create(&file_path).unwrap();
        let mut file_glob_1 = File::create(file_glob_1_path).unwrap();
        let mut file_glob_2 = File::create(file_glob_2_path).unwrap();

        writeln!(file, "{TEST_FILE}").unwrap();
        writeln!(file_glob_1, "{TEST_GLOB_1}").unwrap();
        writeln!(file_glob_2, "{TEST_GLOB_2_MAIL}").unwrap();

        let mock_server = mock_server!(StatusCode::OK, set_body_string(TEST_URL));

        let inputs = HashSet::from_iter([
            Input::from_input_source(InputSource::String(Cow::Borrowed(TEST_STRING))),
            Input::from_input_source(InputSource::RemoteUrl(Box::new(
                Url::parse(&mock_server.uri())
                    .map_err(|e| (mock_server.uri(), e))
                    .unwrap(),
            ))),
            Input::from_input_source(InputSource::FsPath(file_path)),
            Input::from_input_source(InputSource::FsGlob {
                pattern: temp_dir_path.join("glob*").to_str().unwrap().to_owned(),
                ignore_case: true,
            }),
        ]);

        // Propagate the error so a failure reports what actually went wrong.
        let links = collect_verbatim(inputs, None, None, FileType::default_extensions()).await?;

        let expected_links = HashSet::from_iter([
            website!(TEST_STRING),
            website!(TEST_URL),
            website!(TEST_FILE),
            website!(TEST_GLOB_1),
            mail!(TEST_GLOB_2_MAIL),
        ]);

        assert_eq!(links, expected_links);

        Ok(())
    }

    /// Relative Markdown links are resolved against the given base URL.
    #[tokio::test]
    async fn test_collect_markdown_links() {
        let base = Base::try_from("https://github.com/hello-rust/lychee/").unwrap();
        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                "This is [a test](https://endler.dev). This is a relative link test [Relative Link Test](relative_link)",
            )),
            file_type_hint: Some(FileType::Markdown),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            website!("https://endler.dev"),
            website!("https://github.com/hello-rust/lychee/relative_link"),
        ]);

        assert_eq!(links, expected_links);
    }

    /// Relative HTML links are resolved against the given base URL.
    #[tokio::test]
    async fn test_collect_html_links() {
        let base = Base::try_from("https://github.com/lycheeverse/").unwrap();
        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                r#"<html>
                <div class="row">
                    <a href="https://github.com/lycheeverse/lychee/">
                    <a href="blob/master/README.md">README</a>
                </div>
            </html>"#,
            )),
            file_type_hint: Some(FileType::Html),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            website!("https://github.com/lycheeverse/lychee/"),
            website!("https://github.com/lycheeverse/blob/master/README.md"),
        ]);

        assert_eq!(links, expected_links);
    }

    /// Both `src` and every `srcset` candidate of an image are collected.
    #[tokio::test]
    async fn test_collect_html_srcset() {
        let base = Base::try_from("https://example.com/").unwrap();
        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                r#"
            <img
                src="/static/image.png"
                srcset="
                /static/image300.png  300w,
                /static/image600.png  600w,
                "
            />
          "#,
            )),
            file_type_hint: Some(FileType::Html),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            website!("https://example.com/static/image.png"),
            website!("https://example.com/static/image300.png"),
            website!("https://example.com/static/image600.png"),
        ]);

        assert_eq!(links, expected_links);
    }

    /// `@/`-style internal Markdown links are resolved against the base URL.
    #[tokio::test]
    async fn test_markdown_internal_url() {
        let base = Base::try_from("https://localhost.com/").unwrap();

        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                "This is [an internal url](@/internal.md)
        This is [an internal url](@/internal.markdown)
        This is [an internal url](@/internal.markdown#example)
        This is [an internal url](@/internal.md#example)",
            )),
            file_type_hint: Some(FileType::Markdown),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected = HashSet::from_iter([
            website!("https://localhost.com/@/internal.md"),
            website!("https://localhost.com/@/internal.markdown"),
            website!("https://localhost.com/@/internal.md#example"),
            website!("https://localhost.com/@/internal.markdown#example"),
        ]);

        assert_eq!(links, expected);
    }

    /// HTML5 that is not valid XML still yields its relative links.
    #[tokio::test]
    async fn test_extract_html5_not_valid_xml_relative_links() {
        let base = Base::try_from("https://example.com").unwrap();
        let input = load_fixture!("TEST_HTML5.html");

        let input = Input {
            source: InputSource::String(Cow::Owned(input)),
            file_type_hint: Some(FileType::Html),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            // the body links wouldn't be present if the file was parsed strictly as XML
            website!("https://example.com/body/a"),
            website!("https://example.com/body/div_empty_a"),
            website!("https://example.com/css/style_full_url.css"),
            website!("https://example.com/css/style_relative_url.css"),
            website!("https://example.com/head/home"),
            website!("https://example.com/images/icon.png"),
        ]);

        assert_eq!(links, expected_links);
    }

    /// For a remote URL input, the base is taken from the input URL itself.
    #[tokio::test]
    async fn test_relative_url_with_base_extracted_from_input() {
        let contents = r#"<html>
            <div class="row">
                <a href="https://github.com/lycheeverse/lychee/">GitHub</a>
                <a href="/about">About</a>
            </div>
        </html>"#;
        let mock_server = mock_server!(StatusCode::OK, set_body_string(contents));

        let server_uri = Url::parse(&mock_server.uri()).unwrap();

        let input = Input::from_input_source(InputSource::RemoteUrl(Box::new(server_uri.clone())));

        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, None).await.unwrap();

        let expected_urls = HashSet::from_iter([
            website!("https://github.com/lycheeverse/lychee/"),
            website!(&format!("{server_uri}about")),
        ]);

        assert_eq!(links, expected_urls);
    }

    /// Query parameters of a `mailto:` link are not part of the mail address.
    #[tokio::test]
    async fn test_email_with_query_params() {
        let input = Input::from_input_source(InputSource::String(Cow::Borrowed(
            "This is a mailto:user@example.com?subject=Hello link",
        )));

        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, None).await.unwrap();

        let expected_links = HashSet::from_iter([mail!("user@example.com")]);

        assert_eq!(links, expected_links);
    }

    /// Each remote input resolves relative links against its own URL.
    #[tokio::test]
    async fn test_multiple_remote_urls() {
        let mock_server_1 = mock_server!(
            StatusCode::OK,
            set_body_string(r#"<a href="relative.html">Link</a>"#)
        );
        let mock_server_2 = mock_server!(
            StatusCode::OK,
            set_body_string(r#"<a href="relative.html">Link</a>"#)
        );

        let inputs = HashSet::from_iter([
            Input {
                source: InputSource::RemoteUrl(Box::new(
                    Url::parse(&format!(
                        "{}/foo/index.html",
                        mock_server_1.uri().trim_end_matches('/')
                    ))
                    .unwrap(),
                )),
                file_type_hint: Some(FileType::Html),
            },
            Input {
                source: InputSource::RemoteUrl(Box::new(
                    Url::parse(&format!(
                        "{}/bar/index.html",
                        mock_server_2.uri().trim_end_matches('/')
                    ))
                    .unwrap(),
                )),
                file_type_hint: Some(FileType::Html),
            },
        ]);

        let links = collect(inputs, None, None).await.unwrap();

        let expected_links = HashSet::from_iter([
            website!(&format!(
                "{}/foo/relative.html",
                mock_server_1.uri().trim_end_matches('/')
            )),
            website!(&format!(
                "{}/bar/relative.html",
                mock_server_2.uri().trim_end_matches('/')
            )),
        ]);

        assert_eq!(links, expected_links);
    }

    /// With a local base, relative links are joined onto the base path while
    /// an absolute path link is left as it is.
    #[tokio::test]
    async fn test_file_path_with_base() {
        let base = Base::try_from("/path/to/root").unwrap();
        assert_eq!(base, Base::Local("/path/to/root".into()));

        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                r#"
                <a href="index.html">Index</a>
                <a href="about.html">About</a>
                <a href="/another.html">Another</a>
            "#,
            )),
            file_type_hint: Some(FileType::Html),
        };

        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.unwrap();

        let expected_links = HashSet::from_iter([
            path!("/path/to/root/index.html"),
            path!("/path/to/root/about.html"),
            path!("/another.html"),
        ]);

        assert_eq!(links, expected_links);
    }
}