use anyhow::Result;
use futures::stream::Stream;
use futures::FutureExt;
use reqwest::IntoUrl;
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
mod domain;
pub mod error;
mod requests;
pub mod response;
pub mod robots;
pub use crate::requests::RequestDelay;
use crate::requests::{response_info, QueuedRequestBuilder};
pub use crate::response::Response;
pub use domain::{AllowList, AllowListConfig, BlockList, DomainListing};
pub use scraper;
/// A stream that drives a [`Crawler`] and hands every successfully crawled
/// response to the wrapped [`Scraper`], yielding the scraper's outputs.
#[must_use = "Collector does nothing until polled."]
pub struct Collector<T: Scraper> {
    /// The crawler that schedules and performs the requests.
    crawler: Crawler<T>,
    /// The scraper that turns crawled responses into outputs.
    pub scraper: T,
}
// `Scraper::State` is already required to be `fmt::Debug` by the trait
// itself, so no extra bound on the associated type is needed here.
impl<T> Collector<T>
where
    T: Scraper,
{
    /// Creates a collector that owns a new [`Crawler`] built from `config`
    /// and passes crawled responses to `scraper`.
    pub fn new(scraper: T, config: CrawlerConfig) -> Self {
        Self {
            crawler: Crawler::new(config),
            scraper,
        }
    }

    /// Shared access to the wrapped scraper.
    pub fn scraper(&self) -> &T {
        &self.scraper
    }

    /// Mutable access to the wrapped scraper.
    pub fn scraper_mut(&mut self) -> &mut T {
        &mut self.scraper
    }

    /// Shared access to the underlying crawler.
    pub fn crawler(&self) -> &Crawler<T> {
        &self.crawler
    }

    /// Mutable access to the underlying crawler, e.g. to queue the initial
    /// requests before the collector is polled.
    pub fn crawler_mut(&mut self) -> &mut Crawler<T> {
        &mut self.crawler
    }

    /// Request/response counters collected by the crawler so far.
    pub fn stats(&self) -> &Stats {
        &self.crawler.stats
    }
}
impl<T> Stream for Collector<T>
where
    T: Scraper + Unpin + 'static,
    <T as Scraper>::State: Unpin + Send + Sync + 'static,
    <T as Scraper>::Output: Unpin,
{
    type Item = Result<T::Output>;

    /// Drives the crawler and yields scraper outputs and errors.
    ///
    /// Outputs of `complete` jobs and all errors are yielded as-is. Every
    /// successfully crawled response bumps `response_count` and is passed to
    /// [`Scraper::scrape`] while the crawler's depth is set to the response's
    /// depth; requests queued from inside `scrape` therefore land one level
    /// deeper. A scrape returning `Ok(None)` yields nothing, and polling
    /// continues with the next crawl result.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            let crawl_result = match this.crawler.poll(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(crawl_result)) => crawl_result,
            };

            let response = match crawl_result {
                CrawlResult::Finished(finished) => return Poll::Ready(Some(finished)),
                CrawlResult::Crawled(Err(err)) => return Poll::Ready(Some(Err(err))),
                CrawlResult::Crawled(Ok(response)) => response,
            };

            let stats = &mut this.crawler.stats;
            stats.response_count = stats.response_count.wrapping_add(1);

            // Scope the crawler's depth to this response while it is scraped.
            this.crawler.current_depth = response.depth;
            let scraped = this.scraper.scrape(response, &mut this.crawler);
            this.crawler.current_depth = 0;

            // `Ok(None)` becomes `None` (keep polling); anything else is yielded.
            if let Some(item) = scraped.transpose() {
                return Poll::Ready(Some(item));
            }
        }
    }
}
/// Boxed future submitted via `Crawler::complete`; resolves to an optional
/// final output. No `Send` bound is required.
type OutputRequest<T> = Pin<Box<dyn Future<Output = Result<Option<T>>>>>;
/// Boxed future submitted via `Crawler::crawl`; resolves to a fetched
/// [`Response`] carrying the request's state. No `Send` bound is required.
type CrawlRequest<T> = Pin<Box<dyn Future<Output = Result<Response<T>>>>>;
/// Schedules requests and drives in-flight futures on behalf of a [`Scraper`].
///
/// Requests queued through `visit`/`request` are handed to the domain list,
/// while futures submitted through `crawl`/`complete` are polled directly.
pub struct Crawler<T: Scraper> {
    /// Futures submitted via `complete`, polled until they finish.
    in_progress_complete_requests: Vec<OutputRequest<T::Output>>,
    /// Futures submitted via `crawl`, each resolving to a fetched response.
    in_progress_crawl_requests: Vec<CrawlRequest<T::State>>,
    /// Results ready to be returned from `poll`, oldest first.
    queued_results: VecDeque<CrawlResult<T>>,
    /// HTTP client; clones are given to the domain list, and a reference is
    /// passed to the `crawl`/`complete` closures.
    client: reqwest::Client,
    /// Depth of the response currently being scraped (0 outside of scraping);
    /// newly queued requests get `current_depth + 1`.
    current_depth: usize,
    /// Allow- or block-list that receives and schedules queued requests.
    list: DomainListing<T::State>,
    /// Request/response counters.
    stats: Stats,
    /// Configured maximum depth (`usize::MAX` when unset).
    max_depth: usize,
    /// Whether robots.txt handling was requested in the config.
    respect_robots_txt: bool,
    /// Whether non-successful responses are skipped.
    skip_non_successful_responses: bool,
}
impl<T: Scraper> Crawler<T> {
    /// Creates a crawler from `config`.
    ///
    /// If `config` allows no domains, a [`BlockList`] is used: every domain
    /// except the disallowed ones may be crawled, sharing the whole
    /// concurrent-request budget. Otherwise an [`AllowList`] restricts crawling
    /// to the allowed domains, and the budget is split evenly among them. Each
    /// domain always gets at least one slot.
    pub fn new(config: CrawlerConfig) -> Self {
        let client = config.client.unwrap_or_default();
        let max_depth = config.max_depth.unwrap_or(usize::MAX);
        let max_requests = config
            .max_requests
            .unwrap_or(CrawlerConfig::MAX_CONCURRENT_REQUESTS);

        let list = if config.allowed_domains.is_empty() {
            DomainListing::BlockList(BlockList::new(
                config.disallowed_domains,
                client.clone(),
                config.respect_robots_txt,
                config.skip_non_successful_responses,
                max_depth,
                max_requests,
            ))
        } else {
            // Integer division rounds down. With more allowed domains than the
            // budget, each domain would otherwise be configured with a limit
            // of 0 concurrent requests, so clamp the limit to at least one.
            let per_domain_requests = (max_requests / config.allowed_domains.len()).max(1);
            let mut allow_list = AllowList::default();
            for (domain, delay) in config.allowed_domains {
                let allow = AllowListConfig {
                    delay,
                    respect_robots_txt: config.respect_robots_txt,
                    client: client.clone(),
                    skip_non_successful_responses: config.skip_non_successful_responses,
                    max_depth,
                    max_requests: per_domain_requests,
                };
                allow_list.allow(domain, allow);
            }
            DomainListing::AllowList(allow_list)
        };

        Self {
            in_progress_complete_requests: Default::default(),
            in_progress_crawl_requests: Default::default(),
            queued_results: Default::default(),
            client,
            current_depth: 0,
            list,
            stats: Default::default(),
            max_depth,
            respect_robots_txt: config.respect_robots_txt,
            skip_non_successful_responses: config.skip_non_successful_responses,
        }
    }

    /// Configured maximum depth (`usize::MAX` when none was set).
    pub fn max_depth(&self) -> usize {
        self.max_depth
    }

    /// Whether robots.txt handling was requested in the config.
    pub fn respects_robots_txt(&self) -> bool {
        self.respect_robots_txt
    }

    /// Whether non-successful responses are skipped.
    pub fn skips_non_successful_responses(&self) -> bool {
        self.skip_non_successful_responses
    }
}
impl<T> Crawler<T>
where
    T: Scraper + Unpin + 'static,
    <T as Scraper>::State: Unpin + Send + Sync + 'static,
    <T as Scraper>::Output: Unpin,
{
    /// Runs a custom request future and feeds its response to the scraper.
    ///
    /// `fun` receives the crawler's client and returns a future yielding a raw
    /// `reqwest::Response` plus optional state. After the body is read, the
    /// resulting [`Response`] (at depth `current depth + 1`) is delivered like
    /// any other crawled page. The future is polled directly by the crawler,
    /// not routed through the domain list, so none of the list's filtering or
    /// scheduling applies. Counts towards `Stats::request_count`.
    pub fn crawl<TCrawlFunction, TCrawlFuture>(&mut self, fun: TCrawlFunction)
    where
        TCrawlFunction: FnOnce(&reqwest::Client) -> TCrawlFuture,
        TCrawlFuture: Future<Output = Result<(reqwest::Response, Option<T::State>)>> + 'static,
    {
        let depth = self.current_depth + 1;
        let fut = (fun)(&self.client);
        self.stats.request_count = self.stats.request_count.wrapping_add(1);
        let fut = Box::pin(async move {
            let (mut resp, state) = fut.await?;
            let (status, url, headers) = response_info(&mut resp);
            let text = resp.text().await?;
            Ok(Response {
                depth,
                request_url: url.clone(),
                response_url: url,
                response_status: status,
                response_headers: headers,
                text,
                state,
            })
        });
        self.in_progress_crawl_requests.push(fut)
    }

    /// Submits a job that produces a final output without going through the
    /// scraper.
    ///
    /// `Ok(Some(output))` and `Err(_)` are yielded directly by the collector;
    /// `Ok(None)` is discarded.
    pub fn complete<TCrawlFunction, TCrawlFuture>(&mut self, fun: TCrawlFunction)
    where
        TCrawlFunction: FnOnce(&reqwest::Client) -> TCrawlFuture,
        TCrawlFuture: Future<Output = Result<Option<T::Output>>> + 'static,
    {
        let fut = (fun)(&self.client);
        self.in_progress_complete_requests.push(Box::pin(fut))
    }

    /// Queues a `GET` request for `url` without state.
    pub fn visit(&mut self, url: impl IntoUrl) {
        self.request(self.client.request(reqwest::Method::GET, url))
    }

    /// Queues a `GET` request for `url`, carrying `state` to the response.
    pub fn visit_with_state(&mut self, url: impl IntoUrl, state: T::State) {
        self.request_with_state(self.client.request(reqwest::Method::GET, url), state)
    }

    /// Queues an arbitrary request without state.
    pub fn request(&mut self, req: reqwest::RequestBuilder) {
        self.queue_request(req, None)
    }

    /// Queues an arbitrary request, carrying `state` to the response.
    pub fn request_with_state(&mut self, req: reqwest::RequestBuilder, state: T::State) {
        self.queue_request(req, Some(state))
    }

    /// Hands a request to the domain list at depth `current_depth + 1`.
    ///
    /// If the list refuses it, the error is queued as a crawl result so it
    /// surfaces from the collector. Accepted requests are counted in
    /// `Stats::request_count`.
    fn queue_request(&mut self, request: reqwest::RequestBuilder, state: Option<T::State>) {
        let req = QueuedRequestBuilder {
            request,
            state,
            depth: self.current_depth + 1,
        };
        if let Err(err) = self.list.add_request(req) {
            self.queued_results
                .push_back(CrawlResult::Crawled(Err(err.into())))
        } else {
            self.stats.request_count = self.stats.request_count.wrapping_add(1);
        }
    }

    /// The HTTP client used for all requests.
    pub fn client(&self) -> &reqwest::Client {
        &self.client
    }

    /// Returns the next crawl result.
    ///
    /// Returns `Ready(None)` once nothing is queued or in flight and the
    /// domain list is exhausted.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<CrawlResult<T>>> {
        loop {
            if let Some(result) = self.queued_results.pop_front() {
                return Poll::Ready(Some(result));
            }
            self.poll_complete_requests(cx);
            self.poll_crawl_requests(cx);
            let list_pending = self.poll_domain_list(cx);
            if self.queued_results.is_empty() {
                let exhausted = !list_pending
                    && self.in_progress_crawl_requests.is_empty()
                    && self.in_progress_complete_requests.is_empty();
                return if exhausted {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                };
            }
            // Something was queued during this pass: loop to return it.
        }
    }

    /// Polls every in-flight `complete` future once, queueing finished outputs.
    fn poll_complete_requests(&mut self, cx: &mut Context<'_>) {
        // Iterate backwards: `swap_remove(n)` refills slot `n` with an
        // already-polled element from the tail, so each future is polled
        // exactly once per pass. Pending futures are pushed back.
        for n in (0..self.in_progress_complete_requests.len()).rev() {
            let mut request = self.in_progress_complete_requests.swap_remove(n);
            match request.poll_unpin(cx) {
                Poll::Ready(Ok(Some(output))) => self
                    .queued_results
                    .push_back(CrawlResult::Finished(Ok(output))),
                Poll::Ready(Err(err)) => self
                    .queued_results
                    .push_back(CrawlResult::Finished(Err(err))),
                // Finished without producing an output: nothing to report.
                Poll::Ready(Ok(None)) => {}
                Poll::Pending => self.in_progress_complete_requests.push(request),
            }
        }
    }

    /// Polls every in-flight `crawl` future once, queueing fetched responses.
    fn poll_crawl_requests(&mut self, cx: &mut Context<'_>) {
        // Same backwards `swap_remove` walk as `poll_complete_requests`.
        for n in (0..self.in_progress_crawl_requests.len()).rev() {
            let mut request = self.in_progress_crawl_requests.swap_remove(n);
            match request.poll_unpin(cx) {
                Poll::Ready(resp) => self.queued_results.push_back(CrawlResult::Crawled(resp)),
                Poll::Pending => self.in_progress_crawl_requests.push(request),
            }
        }
    }

    /// Drains every ready response from the domain list into the queue.
    ///
    /// Returns `true` if the list ended the pass `Pending` (it may still yield
    /// responses later), and `false` if it reported end of stream.
    fn poll_domain_list(&mut self, cx: &mut Context<'_>) -> bool {
        loop {
            match Stream::poll_next(Pin::new(&mut self.list), cx) {
                Poll::Ready(Some(resp)) => {
                    self.queued_results.push_back(CrawlResult::Crawled(resp))
                }
                Poll::Ready(None) => return false,
                Poll::Pending => return true,
            }
        }
    }
}
/// An item produced by `Crawler::poll`.
enum CrawlResult<T: Scraper> {
    /// Result of a `complete` job; yielded by the collector without scraping.
    Finished(Result<T::Output>),
    /// A fetched response to scrape, or an error from fetching/queueing it.
    Crawled(Result<Response<T::State>>),
}
/// Turns crawled responses into outputs and decides what to crawl next.
pub trait Scraper: Sized {
    /// The value yielded by the collector for each scraped result.
    type Output;
    /// Per-request state carried from a queued request to its response.
    type State: fmt::Debug;

    /// Called by the collector for every successfully crawled response.
    ///
    /// `crawler` can be used to queue follow-up requests; these are placed
    /// one level deeper than `response`. Return `Ok(Some(output))` to yield
    /// an output, `Ok(None)` to yield nothing, or `Err` to yield an error.
    fn scrape(
        &mut self,
        response: Response<Self::State>,
        crawler: &mut Crawler<Self>,
    ) -> Result<Option<Self::Output>>;
}
/// Counters describing the crawler's traffic so far.
///
/// Counters wrap on overflow instead of panicking.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Stats {
    /// Number of requests made by the crawler.
    pub request_count: usize,
    /// Number of successfully crawled responses handed to the scraper.
    pub response_count: usize,
}
/// Builder-style configuration consumed by `Crawler::new`.
pub struct CrawlerConfig {
    /// Maximum request depth; `None` means `usize::MAX`.
    max_depth: Option<usize>,
    /// Total concurrent-request budget; `None` uses `MAX_CONCURRENT_REQUESTS`.
    max_requests: Option<usize>,
    /// Whether non-successful responses are skipped.
    skip_non_successful_responses: bool,
    /// Domains to restrict crawling to, each with an optional request delay.
    /// When non-empty, `disallowed_domains` is not used.
    allowed_domains: HashMap<String, Option<RequestDelay>>,
    /// Domains to exclude; only used when `allowed_domains` is empty.
    disallowed_domains: HashSet<String>,
    /// Whether robots.txt handling is requested.
    respect_robots_txt: bool,
    /// Client to use; `None` means a default `reqwest::Client`.
    client: Option<reqwest::Client>,
}
impl Default for CrawlerConfig {
fn default() -> Self {
Self {
max_depth: None,
max_requests: None,
skip_non_successful_responses: true,
allowed_domains: Default::default(),
disallowed_domains: Default::default(),
respect_robots_txt: false,
client: None,
}
}
}
impl CrawlerConfig {
    /// Concurrent-request budget used when `max_concurrent_requests` is not called.
    const MAX_CONCURRENT_REQUESTS: usize = 100;

    /// Sets the maximum depth of requests the crawler will follow.
    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.max_depth = Some(max_depth);
        self
    }

    /// Requests robots.txt handling. The flag is forwarded to the domain
    /// allow/block lists, which presumably enforce it.
    pub fn respect_robots_txt(mut self) -> Self {
        self.respect_robots_txt = true;
        self
    }

    /// Also scrape responses with a non-successful status instead of skipping them.
    pub fn scrape_non_success_response(mut self) -> Self {
        self.skip_non_successful_responses = false;
        self
    }

    /// Uses `client` for all requests instead of a default `reqwest::Client`.
    pub fn set_client(mut self, client: reqwest::Client) -> Self {
        self.client = Some(client);
        self
    }

    /// Uses a clone of the shared `client` for all requests.
    #[deprecated(
        since = "0.2.0",
        note = "You do not have to wrap the Client in an `Arc` to reuse it, because it already uses an `Arc` internally. Users should use `set_client` instead."
    )]
    pub fn with_shared_client(mut self, client: std::sync::Arc<reqwest::Client>) -> Self {
        self.client = Some(client.as_ref().clone());
        self
    }

    /// Excludes `domain` from crawling.
    ///
    /// Ignored once any domain has been allowed: allow-lists take precedence.
    pub fn disallow_domain(mut self, domain: impl Into<String>) -> Self {
        self.disallowed_domains.insert(domain.into());
        self
    }

    /// Excludes every domain in `domains`; see [`CrawlerConfig::disallow_domain`].
    pub fn disallow_domains<I, T>(self, domains: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<String>,
    {
        domains
            .into_iter()
            .fold(self, |config, domain| config.disallow_domain(domain))
    }

    /// Allows `domain`, throttled by `delay`.
    ///
    /// Once any domain is allowed, only allowed domains are crawled.
    pub fn allow_domain_with_delay(
        mut self,
        domain: impl Into<String>,
        delay: RequestDelay,
    ) -> Self {
        self.allowed_domains.insert(domain.into(), Some(delay));
        self
    }

    /// Allows `domain` without a request delay.
    ///
    /// Once any domain is allowed, only allowed domains are crawled.
    pub fn allow_domain(mut self, domain: impl Into<String>) -> Self {
        self.allowed_domains.insert(domain.into(), None);
        self
    }

    /// Allows every domain in `domains` without a delay; see [`CrawlerConfig::allow_domain`].
    pub fn allow_domains<I, T>(self, domains: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<String>,
    {
        domains
            .into_iter()
            .fold(self, |config, domain| config.allow_domain(domain))
    }

    /// Allows every `(domain, delay)` pair; see [`CrawlerConfig::allow_domain_with_delay`].
    pub fn allow_domains_with_delay<I, T>(self, domains: I) -> Self
    where
        I: IntoIterator<Item = (T, RequestDelay)>,
        T: Into<String>,
    {
        domains.into_iter().fold(self, |config, (domain, delay)| {
            config.allow_domain_with_delay(domain, delay)
        })
    }

    /// Sets the total concurrent-request budget.
    ///
    /// When domains are allowed, the budget is divided among them.
    pub fn max_concurrent_requests(mut self, max_requests: usize) -> Self {
        self.max_requests = Some(max_requests);
        self
    }
}