// langchain_rust/tools/scraper/scraper.rs

1use async_trait::async_trait;
2use regex::Regex;
3use scraper::{ElementRef, Html, Selector};
4use serde_json::Value;
5use std::{error::Error, sync::Arc};
6
7use crate::tools::Tool;
8
/// Tool that fetches a URL and returns the page's visible text content.
///
/// Stateless unit struct; derives `Default` so it satisfies clippy's
/// `new_without_default` alongside the existing `new()` constructor,
/// and `Debug`/`Clone`/`Copy` since it carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct WebScrapper {}
10
11impl WebScrapper {
12    pub fn new() -> Self {
13        Self {}
14    }
15}
16
17#[async_trait]
18impl Tool for WebScrapper {
19    fn name(&self) -> String {
20        String::from("Web Scraper")
21    }
22    fn description(&self) -> String {
23        String::from(
24            "Web Scraper will scan a url and return the content of the web page.
25		Input should be a working url.",
26        )
27    }
28    async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
29        let input = input.as_str().ok_or("Invalid input")?;
30        match scrape_url(input).await {
31            Ok(content) => Ok(content),
32            Err(e) => Ok(format!("Error scraping {}: {}\n", input, e)),
33        }
34    }
35}
36
37impl Into<Arc<dyn Tool>> for WebScrapper {
38    fn into(self) -> Arc<dyn Tool> {
39        Arc::new(self)
40    }
41}
42
43async fn scrape_url(url: &str) -> Result<String, Box<dyn Error>> {
44    let res = reqwest::get(url).await?.text().await?;
45
46    let document = Html::parse_document(&res);
47    let body_selector = Selector::parse("body").unwrap();
48
49    let mut text = Vec::new();
50    for element in document.select(&body_selector) {
51        collect_text_not_in_script(&element, &mut text);
52    }
53
54    let joined_text = text.join(" ");
55    let cleaned_text = joined_text.replace(['\n', '\t'], " ");
56    let re = Regex::new(r"\s+").unwrap();
57    let final_text = re.replace_all(&cleaned_text, " ");
58    Ok(final_text.to_string())
59}
60
61fn collect_text_not_in_script(element: &ElementRef, text: &mut Vec<String>) {
62    for node in element.children() {
63        if node.value().is_element() {
64            let tag_name = node.value().as_element().unwrap().name();
65            if tag_name == "script" {
66                continue;
67            }
68            collect_text_not_in_script(&ElementRef::wrap(node).unwrap(), text);
69        } else if node.value().is_text() {
70            text.push(node.value().as_text().unwrap().text.to_string());
71        }
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use tokio;

    #[tokio::test]
    async fn test_scrape_url() {
        // Spin up a local mock HTTP server for the test.
        let mut server = mockito::Server::new_async().await;

        // Serve a minimal HTML page at the root path.
        let mock = server
            .mock("GET", "/")
            .with_status(200)
            .with_header("content-type", "text/plain")
            .with_body("<html><body>Hello World</body></html>")
            .create();

        // Scrape the mocked URL through the tool's public entry point.
        let scraper = WebScrapper::new();
        let result = scraper.call(&server.url()).await;

        // The scrape succeeds and yields exactly the page's text content.
        assert!(result.is_ok());
        assert_eq!(result.unwrap().trim(), "Hello World");

        // The mock endpoint must have been hit as configured.
        mock.assert();
    }
}
110}