// langchain_rust/tools/scraper/scraper.rs
use async_trait::async_trait;
2use regex::Regex;
3use scraper::{ElementRef, Html, Selector};
4use serde_json::Value;
5use std::{error::Error, sync::Arc};
6
7use crate::tools::Tool;
8
/// Tool that fetches a web page and returns its visible text content.
///
/// The struct is stateless; construct it with [`WebScrapper::new`] or
/// [`WebScrapper::default`].
pub struct WebScrapper {}

impl WebScrapper {
    /// Creates a new `WebScrapper`.
    pub fn new() -> Self {
        Self {}
    }
}

// Types with a no-argument `new` should also implement `Default`
// (clippy `new_without_default`).
impl Default for WebScrapper {
    fn default() -> Self {
        Self::new()
    }
}
16
17#[async_trait]
18impl Tool for WebScrapper {
19 fn name(&self) -> String {
20 String::from("Web Scraper")
21 }
22 fn description(&self) -> String {
23 String::from(
24 "Web Scraper will scan a url and return the content of the web page.
25 Input should be a working url.",
26 )
27 }
28 async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
29 let input = input.as_str().ok_or("Invalid input")?;
30 match scrape_url(input).await {
31 Ok(content) => Ok(content),
32 Err(e) => Ok(format!("Error scraping {}: {}\n", input, e)),
33 }
34 }
35}
36
37impl Into<Arc<dyn Tool>> for WebScrapper {
38 fn into(self) -> Arc<dyn Tool> {
39 Arc::new(self)
40 }
41}
42
43async fn scrape_url(url: &str) -> Result<String, Box<dyn Error>> {
44 let res = reqwest::get(url).await?.text().await?;
45
46 let document = Html::parse_document(&res);
47 let body_selector = Selector::parse("body").unwrap();
48
49 let mut text = Vec::new();
50 for element in document.select(&body_selector) {
51 collect_text_not_in_script(&element, &mut text);
52 }
53
54 let joined_text = text.join(" ");
55 let cleaned_text = joined_text.replace(['\n', '\t'], " ");
56 let re = Regex::new(r"\s+").unwrap();
57 let final_text = re.replace_all(&cleaned_text, " ");
58 Ok(final_text.to_string())
59}
60
61fn collect_text_not_in_script(element: &ElementRef, text: &mut Vec<String>) {
62 for node in element.children() {
63 if node.value().is_element() {
64 let tag_name = node.value().as_element().unwrap().name();
65 if tag_name == "script" {
66 continue;
67 }
68 collect_text_not_in_script(&ElementRef::wrap(node).unwrap(), text);
69 } else if node.value().is_text() {
70 text.push(node.value().as_text().unwrap().text.to_string());
71 }
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_scrape_url() {
        let mut server = mockito::Server::new_async().await;

        // Use mockito's `_async` variants throughout: the synchronous
        // `create()` / `assert()` can block or panic when invoked from
        // within an async runtime.
        let mock = server
            .mock("GET", "/")
            .with_status(200)
            // The body is HTML, so advertise it as such.
            .with_header("content-type", "text/html")
            .with_body("<html><body>Hello World</body></html>")
            .create_async()
            .await;

        let scraper = WebScrapper::new();
        let url = server.url();

        let result = scraper.call(&url).await;

        assert!(result.is_ok());
        let content = result.unwrap();
        assert_eq!(content.trim(), "Hello World");

        // Verify the mock endpoint was actually hit.
        mock.assert_async().await;
    }
}