crabler_tokio/
lib.rs

//! The goal of this library is to help crabs with web crawling.
//!
//!```rust
//!use crabler_tokio::*;
//!
//!#[derive(WebScraper)]
//!#[on_response(response_handler)]
//!#[on_html("a[href]", print_handler)]
//!struct Scraper {}
//!
//!impl Scraper {
//!    async fn response_handler(&self, response: Response) -> Result<()> {
//!        println!("Status {}", response.status);
//!        Ok(())
//!    }
//!
//!    async fn print_handler(&self, response: Response, a: Element) -> Result<()> {
//!        if let Some(href) = a.attr("href") {
//!            println!("Found link {} on {}", href, response.url);
//!        }
//!
//!        Ok(())
//!    }
//!}
//!
//!#[tokio::main]
//!async fn main() -> Result<()> {
//!    let scraper = Scraper {};
//!
//!    scraper.run(Opts::default().with_urls(vec!["https://www.rust-lang.org/"])).await
//!}
//!```
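//!
//! Enabling the crate's `debug` feature turns on log output (via `femme`)
//! at the `Info` level.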

mod opts;
pub use opts::*;

mod errors;
pub use errors::*;

pub use crabquery::{Document, Element};
use flume::{unbounded, Receiver, Sender};
use log::{debug, error, info, warn};
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::{fs::File, io::AsyncWriteExt, sync::RwLock};

pub use async_trait::async_trait;
pub use crabler_tokio_derive::WebScraper;

#[cfg(feature = "debug")]
fn enable_logging() {
    femme::with_level(femme::LevelFilter::Info);
}

#[cfg(not(feature = "debug"))]
fn enable_logging() {}

/// Implemented for your struct by `#[derive(WebScraper)]`. `run` is the
/// user-facing entry point; the `dispatch_*` methods and `all_html_selectors`
/// are called by the internal event loop.
#[async_trait(?Send)]
pub trait WebScraper {
    async fn dispatch_on_page(&mut self, page: String) -> Result<()>;
    async fn dispatch_on_html(
        &mut self,
        selector: &str,
        response: Response,
        element: Element,
    ) -> Result<()>;
    async fn dispatch_on_response(&mut self, response: Response) -> Result<()>;
    fn all_html_selectors(&self) -> Vec<&str>;
    async fn run(self, opts: Opts) -> Result<()>;
}

/// Work scheduled onto worker tasks.
#[derive(Debug)]
enum WorkInput {
    Navigate(String),
    Download { url: String, destination: String },
    Exit,
}

#[derive(Debug)]
pub struct Response {
    pub url: String,
    pub status: u16,
    pub download_destination: Option<String>,
    workinput_tx: Sender<WorkInput>,
    counter: Arc<AtomicUsize>,
}

impl Response {
    fn new(
        status: u16,
        url: String,
        download_destination: Option<String>,
        workinput_tx: Sender<WorkInput>,
        counter: Arc<AtomicUsize>,
    ) -> Self {
        Response {
            status,
            url,
            download_destination,
            workinput_tx,
            counter,
        }
    }

    /// Schedule the scraper to visit the given URL;
    /// the request will be executed on one of the worker tasks.
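    ///
    /// A sketch of use inside a derived `on_html` handler (the handler name
    /// and selector are illustrative):
    ///
    /// ```no_run
    /// # use crabler_tokio::*;
    /// #[derive(WebScraper)]
    /// #[on_html("a[href]", follow_links)]
    /// struct LinkFollower {}
    ///
    /// impl LinkFollower {
    ///     async fn follow_links(&self, mut response: Response, a: Element) -> Result<()> {
    ///         if let Some(href) = a.attr("href") {
    ///             // Queue the linked page; a worker task will fetch it.
    ///             response.navigate(href).await?;
    ///         }
    ///         Ok(())
    ///     }
    /// }
    /// ```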
    pub async fn navigate(&mut self, url: String) -> Result<()> {
        debug!("Increasing counter by 1");
        self.counter.fetch_add(1, Ordering::SeqCst);
        self.workinput_tx
            .send_async(WorkInput::Navigate(url))
            .await?;

        Ok(())
    }

    /// Schedule the scraper to download the file at `url` into the
    /// `destination` path; the download runs on one of the worker tasks.
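    ///
    /// A sketch of saving every image on a page (the handler name, selector,
    /// and destination path are illustrative):
    ///
    /// ```no_run
    /// # use crabler_tokio::*;
    /// #[derive(WebScraper)]
    /// #[on_html("img[src]", save_images)]
    /// struct ImageSaver {}
    ///
    /// impl ImageSaver {
    ///     async fn save_images(&self, mut response: Response, img: Element) -> Result<()> {
    ///         if let Some(src) = img.attr("src") {
    ///             let name = src.split('/').last().unwrap_or("image").to_string();
    ///             // A worker will fetch the file and write it to ./<name>.
    ///             response.download_file(src, format!("./{}", name)).await?;
    ///         }
    ///         Ok(())
    ///     }
    /// }
    /// ```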
    pub async fn download_file(&mut self, url: String, destination: String) -> Result<()> {
        debug!("Increasing counter by 1");
        self.counter.fetch_add(1, Ordering::SeqCst);
        self.workinput_tx
            .send_async(WorkInput::Download { url, destination })
            .await?;

        Ok(())
    }
}

/// Sender/receiver pair over a single unbounded flume channel.
#[derive(Clone)]
struct Channels<T> {
    tx: Sender<T>,
    rx: Receiver<T>,
}

impl<T> Channels<T> {
    fn new() -> Self {
        let (tx, rx) = unbounded();

        Self { tx, rx }
    }
}

/// Coordinates worker tasks: tracks visited links and in-flight work, and
/// dispatches fetched pages to the wrapped `WebScraper`.
pub struct Crabler<T>
where
    T: WebScraper,
{
    visited_links: Arc<RwLock<HashSet<String>>>,
    workinput_ch: Channels<WorkInput>,
    workoutput_ch: Channels<WorkOutput>,
    scraper: T,
    counter: Arc<AtomicUsize>,
    workers: Vec<tokio::task::JoinHandle<()>>,
    reqwest_client: reqwest::Client,
}

impl<T> Crabler<T>
where
    T: WebScraper,
{
    /// Create a new `Crabler` from the given scraper struct.
    pub fn new(scraper: T, opts: &Opts) -> Self {
        let visited_links = Arc::new(RwLock::new(HashSet::new()));
        let workinput_ch = Channels::new();
        let workoutput_ch = Channels::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let workers = vec![];
        let reqwest_client = if opts.follow_redirects {
            reqwest::Client::new()
        } else {
            reqwest::Client::builder()
                .redirect(reqwest::redirect::Policy::none())
                .build()
                .expect("failed to build reqwest client")
        };

        Crabler {
            visited_links,
            workinput_ch,
            workoutput_ch,
            scraper,
            counter,
            workers,
            reqwest_client,
        }
    }

    /// Ask every spawned worker to exit.
    async fn shutdown(&mut self) -> Result<()> {
        for _ in self.workers.iter() {
            self.workinput_ch.tx.send_async(WorkInput::Exit).await?;
        }

        Ok(())
    }

    /// Schedule the scraper to visit the given URL;
    /// the request will be executed on one of the worker tasks.
    pub async fn navigate(&mut self, url: &str) -> Result<()> {
        debug!("Increasing counter by 1");
        self.counter.fetch_add(1, Ordering::SeqCst);
        Ok(self
            .workinput_ch
            .tx
            .send_async(WorkInput::Navigate(url.to_string()))
            .await?)
    }
    /// Run the processing loop for the wrapped `WebScraper`; returns once all
    /// scheduled work has drained.
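    ///
    /// A sketch of driving `Crabler` by hand instead of going through the
    /// derived `WebScraper::run` (the scraper, worker count, and start URL
    /// are illustrative):
    ///
    /// ```no_run
    /// # use crabler_tokio::*;
    /// # #[derive(WebScraper)]
    /// # #[on_response(response_handler)]
    /// # struct MyScraper {}
    /// # impl MyScraper {
    /// #     async fn response_handler(&self, _response: Response) -> Result<()> { Ok(()) }
    /// # }
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let mut crabler = Crabler::new(MyScraper {}, &Opts::default());
    ///
    /// // Spawn a few worker tasks to execute fetches and downloads.
    /// for _ in 0..4 {
    ///     crabler.start_worker();
    /// }
    ///
    /// // Seed the queue, then block until all scheduled work has drained.
    /// crabler.navigate("https://www.rust-lang.org/").await?;
    /// crabler.run().await
    /// # }
    /// ```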
    pub async fn run(&mut self) -> Result<()> {
        enable_logging();

        let ret = self.event_loop().await;
        self.shutdown().await?;
        ret
    }

    async fn event_loop(&mut self) -> Result<()> {
        loop {
            let output = self.workoutput_ch.rx.recv_async().await?;
            let response_url;
            let response_status;
            let mut response_destination = None;

            match output {
                WorkOutput::Markup { text, url, status } => {
                    info!("Fetched markup from: {}", url);
                    self.scraper.dispatch_on_page(text.clone()).await?;
                    let document = Document::from(text);
                    response_url = url.clone();
                    response_status = status;

                    let selectors = self
                        .scraper
                        .all_html_selectors()
                        .iter()
                        .map(|s| s.to_string())
                        .collect::<Vec<_>>();

                    for selector in selectors {
                        for el in document.select(selector.as_str()) {
                            let response = Response::new(
                                status,
                                url.clone(),
                                None,
                                self.workinput_ch.tx.clone(),
                                self.counter.clone(),
                            );
                            self.scraper
                                .dispatch_on_html(selector.as_str(), response, el)
                                .await?;
                        }
                    }
                }
                WorkOutput::Download { url, destination } => {
                    debug!("Downloaded: {} -> {}", url, destination);
                    response_url = url;
                    response_destination = Some(destination);
                    response_status = 200;
                }
                WorkOutput::Noop(url) => {
                    debug!("Noop: {}", url);
                    response_url = url;
                    response_status = 304;
                }
                WorkOutput::Error(url, e) => {
                    error!("Error from {}: {}", url, e);
                    response_url = url;
                    response_status = 500;
                }
                WorkOutput::Exit => {
                    error!("Received exit output");
                    response_url = "".to_string();
                    response_status = 500;
                }
            }

            let response = Response::new(
                response_status,
                response_url,
                response_destination,
                self.workinput_ch.tx.clone(),
                self.counter.clone(),
            );
            self.scraper.dispatch_on_response(response).await?;

            debug!("Decreasing counter by 1");
            self.counter.fetch_sub(1, Ordering::SeqCst);

            let cur_count = self.counter.load(Ordering::SeqCst);
            debug!("Done processing work output, counter is at {}", cur_count);
            debug!("Queue len: {}", self.workoutput_ch.rx.len());

            // The counter tracks in-flight work; once it reaches zero there
            // is nothing left to fetch or dispatch, so the crawl is complete.
            if cur_count == 0 {
                return Ok(());
            }
        }
    }

    /// Create and start a new worker task.
    /// The worker task exits automatically once the scraper instance is dropped.
    pub fn start_worker(&mut self) {
        let visited_links = self.visited_links.clone();
        let workinput_rx = self.workinput_ch.rx.clone();
        let workoutput_tx = self.workoutput_ch.tx.clone();
        let reqwest_client = self.reqwest_client.clone();

        let worker = Worker::new(visited_links, workinput_rx, workoutput_tx, reqwest_client);

        let handle = tokio::task::spawn(async move {
            loop {
                info!("🐿️ Starting http worker");

                match worker.start().await {
                    Ok(()) => {
                        info!("Shutting down worker");
                        break;
                    }
                    Err(e) => warn!("❌ Restarting worker: {}", e),
                }
            }
        });

        self.workers.push(handle);
    }
}

/// Executes `WorkInput` messages and reports results back to the event loop
/// as `WorkOutput` messages.
struct Worker {
    visited_links: Arc<RwLock<HashSet<String>>>,
    workinput_rx: Receiver<WorkInput>,
    workoutput_tx: Sender<WorkOutput>,
    reqwest_client: reqwest::Client,
}

impl Worker {
    fn new(
        visited_links: Arc<RwLock<HashSet<String>>>,
        workinput_rx: Receiver<WorkInput>,
        workoutput_tx: Sender<WorkOutput>,
        reqwest_client: reqwest::Client,
    ) -> Self {
        Worker {
            visited_links,
            workinput_rx,
            workoutput_tx,
            reqwest_client,
        }
    }

    async fn start(&self) -> Result<()> {
        let workoutput_tx = self.workoutput_tx.clone();

        loop {
            // A recv error means every sender is gone and no more work can
            // arrive, so shut down cleanly instead of spinning on a closed
            // channel.
            let Ok(workinput) = self.workinput_rx.recv_async().await else {
                return Ok(());
            };

            let payload = self.process_message(workinput).await;

            match payload {
                Ok(WorkOutput::Exit) => return Ok(()),
                _ => workoutput_tx.send_async(payload?).await?,
            }
        }
    }

    async fn process_message(&self, workinput: WorkInput) -> Result<WorkOutput> {
        match workinput {
            WorkInput::Navigate(url) => {
                let workoutput = self.navigate(url.clone()).await;

                if let Err(e) = workoutput {
                    Ok(WorkOutput::Error(url, e))
                } else {
                    workoutput
                }
            }
            WorkInput::Download { url, destination } => {
                let workoutput = self.download(url.clone(), destination).await;

                if let Err(e) = workoutput {
                    Ok(WorkOutput::Error(url, e))
                } else {
                    workoutput
                }
            }
            WorkInput::Exit => Ok(WorkOutput::Exit),
        }
    }

    async fn navigate(&self, url: String) -> Result<WorkOutput> {
        // `insert` returns true only on the first visit; doing the check and
        // the insert under a single write lock avoids racing duplicate fetches.
        let first_visit = self.visited_links.write().await.insert(url.clone());

        if first_visit {
            let response = self.reqwest_client.get(&url).send().await?;

            WorkOutput::try_from_response(response, url).await
        } else {
            Ok(WorkOutput::Noop(url))
        }
    }

    async fn download(&self, url: String, destination: String) -> Result<WorkOutput> {
        // Same single-lock check-and-insert as in `navigate`, so repeated
        // download requests for one url are deduplicated.
        let first_visit = self.visited_links.write().await.insert(url.clone());

        if first_visit {
            // Fetch the whole body and write it to the destination file;
            // the resulting WorkOutput notifies the parent that work is done.
            let response = self.reqwest_client.get(&url).send().await?.bytes().await?;
            let mut dest = File::create(&destination).await?;
            dest.write_all(&response).await?;

            Ok(WorkOutput::Download { url, destination })
        } else {
            Ok(WorkOutput::Noop(url))
        }
    }
}

/// Results reported by workers back to the event loop.
#[derive(Debug)]
enum WorkOutput {
    Markup {
        url: String,
        text: String,
        status: u16,
    },
    Download {
        url: String,
        destination: String,
    },
    Noop(String),
    Error(String, CrablerError),
    Exit,
}

impl WorkOutput {
    async fn try_from_response(response: reqwest::Response, url: String) -> Result<Self> {
        let status = response.status().into();
        let text = response.text().await?;

        if text.is_empty() {
            error!("body of {} is empty", url);
        }

        Ok(WorkOutput::Markup { status, url, text })
    }
}