lychee_lib/collector.rs

use crate::ErrorKind;
use crate::InputSource;
use crate::Preprocessor;
use crate::filter::PathExcludes;

use crate::types::resolver::UrlContentResolver;
use crate::{
    Base, Input, LycheeResult, Request, RequestError, basic_auth::BasicAuthExtractor,
    extract::Extractor, types::FileExtensions, types::uri::raw::RawUri, utils::request,
};
use futures::TryStreamExt;
use futures::{
    StreamExt,
    stream::{self, Stream},
};
use http::HeaderMap;
use par_stream::ParStreamExt;
use reqwest::Client;
use std::collections::HashSet;
use std::path::PathBuf;

/// A `Collector` keeps the state of link collection.
/// It drives the link extraction from the given inputs.
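///
/// A minimal usage sketch. The crate name `lychee_lib` and its root
/// re-exports of `Collector`, `Input`, and `LycheeResult` are assumptions;
/// the builder methods and `collect_links` are defined below.
///
/// ```no_run
/// # use std::collections::HashSet;
/// # use lychee_lib::{Collector, Input, LycheeResult};
/// # async fn example() -> LycheeResult<()> {
/// let inputs = HashSet::from_iter([Input::new("https://example.com/", None, true)?]);
/// let _requests = Collector::new(None, None)?
///     .skip_missing_inputs(true) // don't error on inputs that don't exist
///     .include_verbatim(false) // leave links inside code blocks out
///     .collect_links(inputs);
/// # Ok(())
/// # }
/// ```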
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone)]
pub struct Collector {
    basic_auth_extractor: Option<BasicAuthExtractor>,
    skip_missing_inputs: bool,
    skip_ignored: bool,
    skip_hidden: bool,
    include_verbatim: bool,
    include_wikilinks: bool,
    use_html5ever: bool,
    root_dir: Option<PathBuf>,
    base: Option<Base>,
    excluded_paths: PathExcludes,
    headers: HeaderMap,
    client: Client,
    preprocessor: Option<Preprocessor>,
}

impl Default for Collector {
    /// # Panics
    ///
    /// Panics if the default reqwest [`Client`] cannot be built, since this
    /// implementation calls [`Client::new()`].
    ///
    /// Use [`Collector::new()`] instead if you need to handle
    /// [`ClientBuilder`](crate::ClientBuilder) errors gracefully.
    fn default() -> Self {
        Collector {
            basic_auth_extractor: None,
            skip_missing_inputs: false,
            include_verbatim: false,
            include_wikilinks: false,
            use_html5ever: false,
            skip_hidden: true,
            skip_ignored: true,
            root_dir: None,
            base: None,
            headers: HeaderMap::new(),
            client: Client::new(),
            excluded_paths: PathExcludes::empty(),
            preprocessor: None,
        }
    }
}

impl Collector {
    /// Create a new collector with an empty cache.
    ///
    /// # Errors
    ///
    /// Returns an `Err` if the `root_dir` cannot be canonicalized
    /// or if the reqwest `Client` fails to build.
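    ///
    /// A sketch of typical construction; the crate name `lychee_lib` and its
    /// root re-exports are assumptions:
    ///
    /// ```no_run
    /// # use std::path::PathBuf;
    /// # use lychee_lib::{Base, Collector, LycheeResult};
    /// # fn example() -> LycheeResult<()> {
    /// let collector = Collector::new(
    ///     Some(PathBuf::from("/var/www/site")),       // optional local root directory
    ///     Base::try_from("https://example.com").ok(), // optional base for relative links
    /// )?;
    /// # drop(collector);
    /// # Ok(())
    /// # }
    /// ```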
    pub fn new(root_dir: Option<PathBuf>, base: Option<Base>) -> LycheeResult<Self> {
        let root_dir = match root_dir {
            Some(root_dir) if base.is_some() => Some(root_dir),
            Some(root_dir) => Some(
                root_dir
                    .canonicalize()
                    .map_err(|e| ErrorKind::InvalidRootDir(root_dir, e))?,
            ),
            None => None,
        };
        Ok(Collector {
            basic_auth_extractor: None,
            skip_missing_inputs: false,
            include_verbatim: false,
            include_wikilinks: false,
            use_html5ever: false,
            skip_hidden: true,
            skip_ignored: true,
            preprocessor: None,
            headers: HeaderMap::new(),
            client: Client::builder()
                .build()
                .map_err(ErrorKind::BuildRequestClient)?,
            excluded_paths: PathExcludes::empty(),
            root_dir,
            base,
        })
    }

    /// Skip missing input files (default is to error if they don't exist)
    #[must_use]
    pub const fn skip_missing_inputs(mut self, yes: bool) -> Self {
        self.skip_missing_inputs = yes;
        self
    }

    /// Skip files that are hidden
    #[must_use]
    pub const fn skip_hidden(mut self, yes: bool) -> Self {
        self.skip_hidden = yes;
        self
    }

    /// Skip files that are ignored
    #[must_use]
    pub const fn skip_ignored(mut self, yes: bool) -> Self {
        self.skip_ignored = yes;
        self
    }

    /// Set headers to use when resolving input URLs
    #[must_use]
    pub fn headers(mut self, headers: HeaderMap) -> Self {
        self.headers = headers;
        self
    }

    /// Set client to use for checking input URLs
    #[must_use]
    pub fn client(mut self, client: Client) -> Self {
        self.client = client;
        self
    }

    /// Use `html5ever` to parse HTML instead of `html5gum`.
    #[must_use]
    pub const fn use_html5ever(mut self, yes: bool) -> Self {
        self.use_html5ever = yes;
        self
    }

    /// Include links in verbatim sections (like Markdown code blocks).
    /// Such links are skipped by default.
    #[must_use]
    pub const fn include_verbatim(mut self, yes: bool) -> Self {
        self.include_verbatim = yes;
        self
    }

    /// Check WikiLinks in Markdown files
    #[allow(clippy::doc_markdown)]
    #[must_use]
    pub const fn include_wikilinks(mut self, yes: bool) -> Self {
        self.include_wikilinks = yes;
        self
    }

    /// Configure a file [`Preprocessor`]
    #[must_use]
    pub fn preprocessor(mut self, preprocessor: Option<Preprocessor>) -> Self {
        self.preprocessor = preprocessor;
        self
    }

    /// Pass a [`BasicAuthExtractor`] which can match found URIs to basic
    /// auth credentials. These credentials get attached to the matching
    /// requests.
    #[must_use]
    #[allow(clippy::missing_const_for_fn)]
    pub fn basic_auth_extractor(mut self, extractor: BasicAuthExtractor) -> Self {
        self.basic_auth_extractor = Some(extractor);
        self
    }

    /// Configure which paths to exclude
    #[must_use]
    pub fn excluded_paths(mut self, excluded_paths: PathExcludes) -> Self {
        self.excluded_paths = excluded_paths;
        self
    }

    /// Convenience method to fetch all unique links from inputs
    /// with the default extensions.
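    ///
    /// The returned stream can be consumed with the usual `futures`
    /// combinators, e.g. collected into a set of URIs as the tests below do.
    /// A sketch; the crate name `lychee_lib` and its root re-exports are
    /// assumptions:
    ///
    /// ```no_run
    /// # use std::collections::HashSet;
    /// # use futures::StreamExt;
    /// # use lychee_lib::{Collector, Input, LycheeResult, Uri};
    /// # async fn example(inputs: HashSet<Input>) -> LycheeResult<()> {
    /// let links: HashSet<Uri> = Collector::new(None, None)?
    ///     .collect_links(inputs)
    ///     .map(|request| request.unwrap().uri)
    ///     .collect()
    ///     .await;
    /// # drop(links);
    /// # Ok(())
    /// # }
    /// ```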
    pub fn collect_links(
        self,
        inputs: HashSet<Input>,
    ) -> impl Stream<Item = Result<Request, RequestError>> {
        self.collect_links_from_file_types(inputs, crate::types::FileType::default_extensions())
    }

    /// Fetch all unique links from the given inputs.
    /// All relative URLs get prefixed with `base` (if given),
    /// which can be a directory or a base URL.
    ///
    /// # Errors
    ///
    /// Will yield an `Err` item if links cannot be extracted from an input
    pub fn collect_links_from_file_types(
        self,
        inputs: HashSet<Input>,
        extensions: FileExtensions,
    ) -> impl Stream<Item = Result<Request, RequestError>> {
        let skip_missing_inputs = self.skip_missing_inputs;
        let skip_hidden = self.skip_hidden;
        let skip_ignored = self.skip_ignored;
        let global_base = self.base;
        let excluded_paths = self.excluded_paths;

        let resolver = UrlContentResolver {
            basic_auth_extractor: self.basic_auth_extractor.clone(),
            headers: self.headers.clone(),
            client: self.client,
        };

        let extractor = Extractor::new(
            self.use_html5ever,
            self.include_verbatim,
            self.include_wikilinks,
        );

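        // Stage 1: resolve every input to its contents in parallel (local
        // files, glob matches, or remote URLs fetched via `resolver`), and
        // pair each piece of content with the base used for relative links.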
        stream::iter(inputs)
            .par_then_unordered(None, move |input| {
                let default_base = global_base.clone();
                let extensions = extensions.clone();
                let resolver = resolver.clone();
                let excluded_paths = excluded_paths.clone();
                let preprocessor = self.preprocessor.clone();

                async move {
                    let base = match &input.source {
                        InputSource::RemoteUrl(url) => Base::try_from(url.as_str()).ok(),
                        _ => default_base,
                    };

                    input
                        .get_contents(
                            skip_missing_inputs,
                            skip_hidden,
                            skip_ignored,
                            extensions,
                            resolver,
                            excluded_paths,
                            preprocessor,
                        )
                        .map(move |content| (content, base.clone()))
                }
            })
            .flatten()
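            // Stage 2: extract raw URIs from each piece of content and turn
            // them into checkable requests.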
            .par_then_unordered(None, move |(content, base)| {
                let root_dir = self.root_dir.clone();
                let basic_auth_extractor = self.basic_auth_extractor.clone();
                async move {
                    let content = content?;
                    let uris: Vec<RawUri> = extractor.extract(&content);
                    let requests = request::create(
                        uris,
                        &content.source,
                        root_dir.as_ref(),
                        base.as_ref(),
                        basic_auth_extractor.as_ref(),
                    );
                    Result::Ok(stream::iter(requests))
                }
            })
            .try_flatten()
    }
}

#[cfg(test)]
mod tests {
    use std::borrow::Cow;
    use std::{collections::HashSet, convert::TryFrom, fs::File, io::Write};
    use test_utils::{fixtures_path, load_fixture, mail, mock_server, path, website};

    use http::StatusCode;
    use reqwest::Url;

    use super::*;
    use crate::{
        LycheeResult, Uri,
        filter::PathExcludes,
        types::{FileType, Input, InputSource},
    };

    // Helper function to run the collector on the given inputs
    async fn collect(
        inputs: HashSet<Input>,
        root_dir: Option<PathBuf>,
        base: Option<Base>,
    ) -> LycheeResult<HashSet<Uri>> {
        let responses = Collector::new(root_dir, base)?.collect_links(inputs);
        Ok(responses.map(|r| r.unwrap().uri).collect().await)
    }

    /// Helper function for collecting verbatim links
    ///
    /// A verbatim link is a link inside a verbatim section,
    /// for example a code block or a script tag.
    async fn collect_verbatim(
        inputs: HashSet<Input>,
        root_dir: Option<PathBuf>,
        base: Option<Base>,
        extensions: FileExtensions,
    ) -> LycheeResult<HashSet<Uri>> {
        let responses = Collector::new(root_dir, base)?
            .include_verbatim(true)
            .collect_links_from_file_types(inputs, extensions);
        Ok(responses.map(|r| r.unwrap().uri).collect().await)
    }

    const TEST_STRING: &str = "http://test-string.com";
    const TEST_URL: &str = "https://test-url.org";
    const TEST_FILE: &str = "https://test-file.io";
    const TEST_GLOB_1: &str = "https://test-glob-1.io";
    const TEST_GLOB_2_MAIL: &str = "test@glob-2.io";

    #[tokio::test]
    async fn test_file_without_extension_is_plaintext() -> LycheeResult<()> {
        let temp_dir = tempfile::tempdir().unwrap();
        // Treat as plaintext file (no extension)
        let file_path = temp_dir.path().join("README");
        let _file = File::create(&file_path).unwrap();
        let input = Input::new(&file_path.as_path().display().to_string(), None, true)?;
        let contents: Vec<_> = input
            .get_contents(
                true,
                true,
                true,
                FileType::default_extensions(),
                UrlContentResolver::default(),
                PathExcludes::empty(),
                None,
            )
            .collect::<Vec<_>>()
            .await;

        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].as_ref().unwrap().file_type, FileType::Plaintext);
        Ok(())
    }

    #[tokio::test]
    async fn test_url_without_extension_is_html() -> LycheeResult<()> {
        let input = Input::new("https://example.com/", None, true)?;
        let contents: Vec<_> = input
            .get_contents(
                true,
                true,
                true,
                FileType::default_extensions(),
                UrlContentResolver::default(),
                PathExcludes::empty(),
                None,
            )
            .collect::<Vec<_>>()
            .await;

        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0].as_ref().unwrap().file_type, FileType::Html);
        Ok(())
    }

    #[tokio::test]
    async fn test_collect_links() -> LycheeResult<()> {
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_dir_path = temp_dir.path();

        let file_path = temp_dir_path.join("f");
        let file_glob_1_path = temp_dir_path.join("glob-1");
        let file_glob_2_path = temp_dir_path.join("glob-2");

        let mut file = File::create(&file_path).unwrap();
        let mut file_glob_1 = File::create(file_glob_1_path).unwrap();
        let mut file_glob_2 = File::create(file_glob_2_path).unwrap();

        writeln!(file, "{TEST_FILE}").unwrap();
        writeln!(file_glob_1, "{TEST_GLOB_1}").unwrap();
        writeln!(file_glob_2, "{TEST_GLOB_2_MAIL}").unwrap();

        let mock_server = mock_server!(StatusCode::OK, set_body_string(TEST_URL));

        let inputs = HashSet::from_iter([
            Input::from_input_source(InputSource::String(Cow::Borrowed(TEST_STRING))),
            Input::from_input_source(InputSource::RemoteUrl(Box::new(
                Url::parse(&mock_server.uri())
                    .map_err(|e| (mock_server.uri(), e))
                    .unwrap(),
            ))),
            Input::from_input_source(InputSource::FsPath(file_path)),
            Input::from_input_source(InputSource::FsGlob {
                pattern: glob::Pattern::new(&temp_dir_path.join("glob*").to_string_lossy())?,
                ignore_case: true,
            }),
        ]);

        let links = collect_verbatim(inputs, None, None, FileType::default_extensions())
            .await
            .ok()
            .unwrap();

        let expected_links = HashSet::from_iter([
            website!(TEST_STRING),
            website!(TEST_URL),
            website!(TEST_FILE),
            website!(TEST_GLOB_1),
            mail!(TEST_GLOB_2_MAIL),
        ]);

        assert_eq!(links, expected_links);

        Ok(())
    }

    #[tokio::test]
    async fn test_collect_markdown_links() {
        let base = Base::try_from("https://github.com/hello-rust/lychee/").unwrap();
        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                "This is [a test](https://endler.dev). This is a relative link test [Relative Link Test](relative_link)",
            )),
            file_type_hint: Some(FileType::Markdown),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.ok().unwrap();

        let expected_links = HashSet::from_iter([
            website!("https://endler.dev"),
            website!("https://github.com/hello-rust/lychee/relative_link"),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_collect_html_links() {
        let base = Base::try_from("https://github.com/lycheeverse/").unwrap();
        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                r#"<html>
                <div class="row">
                    <a href="https://github.com/lycheeverse/lychee/">
                    <a href="blob/master/README.md">README</a>
                </div>
            </html>"#,
            )),
            file_type_hint: Some(FileType::Html),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.ok().unwrap();

        let expected_links = HashSet::from_iter([
            website!("https://github.com/lycheeverse/lychee/"),
            website!("https://github.com/lycheeverse/blob/master/README.md"),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_collect_html_srcset() {
        let base = Base::try_from("https://example.com/").unwrap();
        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                r#"
            <img
                src="/static/image.png"
                srcset="
                /static/image300.png  300w,
                /static/image600.png  600w,
                "
            />
          "#,
            )),
            file_type_hint: Some(FileType::Html),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.ok().unwrap();

        let expected_links = HashSet::from_iter([
            website!("https://example.com/static/image.png"),
            website!("https://example.com/static/image300.png"),
            website!("https://example.com/static/image600.png"),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_markdown_internal_url() {
        let base = Base::try_from("https://localhost.com/").unwrap();

        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                "This is [an internal url](@/internal.md)
        This is [an internal url](@/internal.markdown)
        This is [an internal url](@/internal.markdown#example)
        This is [an internal url](@/internal.md#example)",
            )),
            file_type_hint: Some(FileType::Markdown),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.ok().unwrap();

        let expected = HashSet::from_iter([
            website!("https://localhost.com/@/internal.md"),
            website!("https://localhost.com/@/internal.markdown"),
            website!("https://localhost.com/@/internal.md#example"),
            website!("https://localhost.com/@/internal.markdown#example"),
        ]);

        assert_eq!(links, expected);
    }

    #[tokio::test]
    async fn test_extract_html5_not_valid_xml_relative_links() {
        let base = Base::try_from("https://example.com").unwrap();
        let input = load_fixture!("TEST_HTML5.html");

        let input = Input {
            source: InputSource::String(Cow::Owned(input)),
            file_type_hint: Some(FileType::Html),
        };
        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.ok().unwrap();

        let expected_links = HashSet::from_iter([
            // the body links wouldn't be present if the file was parsed strictly as XML
            website!("https://example.com/body/a"),
            website!("https://example.com/body/div_empty_a"),
            website!("https://example.com/css/style_full_url.css"),
            website!("https://example.com/css/style_relative_url.css"),
            website!("https://example.com/head/home"),
            website!("https://example.com/images/icon.png"),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_relative_url_with_base_extracted_from_input() {
        let contents = r#"<html>
            <div class="row">
                <a href="https://github.com/lycheeverse/lychee/">GitHub</a>
                <a href="/about">About</a>
            </div>
        </html>"#;
        let mock_server = mock_server!(StatusCode::OK, set_body_string(contents));

        let server_uri = Url::parse(&mock_server.uri()).unwrap();

        let input = Input::from_input_source(InputSource::RemoteUrl(Box::new(server_uri.clone())));

        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, None).await.ok().unwrap();

        let expected_urls = HashSet::from_iter([
            website!("https://github.com/lycheeverse/lychee/"),
            website!(&format!("{server_uri}about")),
        ]);

        assert_eq!(links, expected_urls);
    }

    #[tokio::test]
    async fn test_email_with_query_params() {
        let input = Input::from_input_source(InputSource::String(Cow::Borrowed(
            "This is a mailto:user@example.com?subject=Hello link",
        )));

        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, None).await.ok().unwrap();

        let expected_links = HashSet::from_iter([mail!("user@example.com")]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_multiple_remote_urls() {
        let mock_server_1 = mock_server!(
            StatusCode::OK,
            set_body_string(r#"<a href="relative.html">Link</a>"#)
        );
        let mock_server_2 = mock_server!(
            StatusCode::OK,
            set_body_string(r#"<a href="relative.html">Link</a>"#)
        );

        let inputs = HashSet::from_iter([
            Input {
                source: InputSource::RemoteUrl(Box::new(
                    Url::parse(&format!(
                        "{}/foo/index.html",
                        mock_server_1.uri().trim_end_matches('/')
                    ))
                    .unwrap(),
                )),
                file_type_hint: Some(FileType::Html),
            },
            Input {
                source: InputSource::RemoteUrl(Box::new(
                    Url::parse(&format!(
                        "{}/bar/index.html",
                        mock_server_2.uri().trim_end_matches('/')
                    ))
                    .unwrap(),
                )),
                file_type_hint: Some(FileType::Html),
            },
        ]);

        let links = collect(inputs, None, None).await.ok().unwrap();

        let expected_links = HashSet::from_iter([
            website!(&format!(
                "{}/foo/relative.html",
                mock_server_1.uri().trim_end_matches('/')
            )),
            website!(&format!(
                "{}/bar/relative.html",
                mock_server_2.uri().trim_end_matches('/')
            )),
        ]);

        assert_eq!(links, expected_links);
    }

    #[tokio::test]
    async fn test_file_path_with_base() {
        let base = Base::try_from("/path/to/root").unwrap();
        assert_eq!(base, Base::Local("/path/to/root".into()));

        let input = Input {
            source: InputSource::String(Cow::Borrowed(
                r#"
                <a href="index.html">Index</a>
                <a href="about.html">About</a>
                <a href="/another.html">Another</a>
            "#,
            )),
            file_type_hint: Some(FileType::Html),
        };

        let inputs = HashSet::from_iter([input]);

        let links = collect(inputs, None, Some(base)).await.ok().unwrap();

        let expected_links = HashSet::from_iter([
            path!("/path/to/root/index.html"),
            path!("/path/to/root/about.html"),
            path!("/another.html"),
        ]);

        assert_eq!(links, expected_links);
    }
}