use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::RwLock;
use url::Url;
use crate::types::config::CrawlStrategy;
/// A single URL queued for crawling, together with the metadata used to
/// order and filter it.
///
/// Equality is by `url` only, while ordering is by `priority` and then
/// `depth` (see the trait impls below).
#[derive(Debug, Clone)]
pub struct CrawlEntry {
/// The URL this entry refers to; its string form is the dedup key in `Frontier`.
pub url: Url,
/// Depth of this entry; `Frontier::push` rejects entries whose depth exceeds
/// its `max_depth`. Presumably the link distance from a seed URL — confirm with callers.
pub depth: u32,
/// Scheduling priority; larger values compare greater and are popped first
/// from the `BinaryHeap`.
pub priority: i32,
/// URL recorded via `with_parent`; `None` when built with `new` alone.
pub parent_url: Option<Url>,
/// Instant captured when `new` constructed this entry.
pub added_at: std::time::Instant,
}
impl CrawlEntry {
    /// Builds an entry for `url` at the given `depth` and `priority`.
    ///
    /// The entry starts without a parent and is stamped with the instant at
    /// which this constructor ran.
    pub fn new(url: Url, depth: u32, priority: i32) -> Self {
        let added_at = std::time::Instant::now();
        Self {
            url,
            depth,
            priority,
            parent_url: None,
            added_at,
        }
    }

    /// Consumes the entry and returns it with `parent_url` set to `parent`,
    /// replacing any parent recorded earlier.
    pub fn with_parent(self, parent: Url) -> Self {
        Self {
            parent_url: Some(parent),
            ..self
        }
    }
}
// Marker impl: the URL-only equality defined by `PartialEq` below is declared
// to be a full equivalence relation.
impl Eq for CrawlEntry {}
/// Two entries are equal when they point at the same URL; depth, priority,
/// parent and timestamp are ignored.
impl PartialEq for CrawlEntry {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.url, &other.url);
        lhs.eq(rhs)
    }
}
/// Partial ordering always succeeds and simply defers to the total order
/// defined by `Ord`.
impl PartialOrd for CrawlEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Heap ordering: a higher `priority` compares greater; among equal
/// priorities the entry with the *smaller* `depth` compares greater, so a
/// max-heap pops high-priority, shallow entries first.
///
/// NOTE(review): this ordering ignores `url`, whereas `PartialEq` compares
/// only `url`, so `a.cmp(&b) == Ordering::Equal` does not imply `a == b`.
/// The `Ord` documentation asks for the two to agree — confirm this is
/// intentional before relying on it in ordered collections.
impl Ord for CrawlEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.priority.cmp(&other.priority) {
            // Tie on priority: operands are reversed so that lower depth wins.
            Ordering::Equal => other.depth.cmp(&self.depth),
            decided => decided,
        }
    }
}
/// Crawl frontier: a priority queue of pending `CrawlEntry` values plus the
/// bookkeeping used to reject duplicates and enforce crawl limits.
///
/// Each collection sits behind its own `RwLock`, and all methods take `&self`.
// NOTE(review): `strategy` is stored by `new` but not read by any method
// visible here, which is presumably why `dead_code` is allowed — confirm.
#[allow(dead_code)]
pub struct Frontier {
/// Pending entries; `BinaryHeap` pops the greatest entry per `CrawlEntry`'s `Ord`.
queue: RwLock<BinaryHeap<CrawlEntry>>,
/// String forms of URLs inserted by `push` or `mark_seen`.
seen: RwLock<HashSet<String>>,
/// Host name -> number of entries `push` has accepted for that host
/// (URLs without a host are counted under the empty string).
domain_counts: RwLock<HashMap<String, u32>>,
/// Crawl strategy supplied at construction.
strategy: CrawlStrategy,
/// Entries with `depth` greater than this are rejected by `push`.
max_depth: u32,
/// `push` rejects an entry once its host's count has reached this value.
max_per_domain: u32,
/// `push` rejects entries once `seen` holds at least this many URLs.
max_total: u32,
}
// Lock ordering: any method that holds more than one lock at a time acquires
// them in the order `seen` -> `domain_counts` -> `queue`. Keeping a single
// order everywhere is what prevents `push` and `clear` from deadlocking.
impl Frontier {
    /// Creates an empty frontier with the given limits.
    ///
    /// * `strategy` – stored as-is; no method in this block reads it.
    /// * `max_depth` – [`push`](Self::push) rejects entries deeper than this.
    /// * `max_per_domain` – maximum number of entries accepted per host.
    /// * `max_total` – `push` rejects entries once this many URLs are in the
    ///   seen set.
    pub fn new(strategy: CrawlStrategy, max_depth: u32, max_per_domain: u32, max_total: u32) -> Self {
        Self {
            queue: RwLock::new(BinaryHeap::new()),
            seen: RwLock::new(HashSet::new()),
            domain_counts: RwLock::new(HashMap::new()),
            strategy,
            max_depth,
            max_per_domain,
            max_total,
        }
    }

    /// Tries to enqueue `entry`, returning `true` if it was accepted.
    ///
    /// The entry is rejected (`false`) when any of these hold:
    /// * its depth exceeds `max_depth`;
    /// * its URL is already in the seen set;
    /// * its host has already had `max_per_domain` entries accepted
    ///   (per-host counts are not decremented by [`pop`](Self::pop));
    /// * the seen set already holds `max_total` URLs (URLs added through
    ///   [`mark_seen`](Self::mark_seen) count toward this limit).
    ///
    /// The limit checks and the insertion happen under the same write locks,
    /// so concurrent callers cannot enqueue the same URL twice or overshoot
    /// a limit.
    ///
    /// # Panics
    ///
    /// Panics if any of the internal locks is poisoned.
    pub fn push(&self, entry: CrawlEntry) -> bool {
        if entry.depth > self.max_depth {
            return false;
        }

        let url_key = entry.url.as_str();

        // Fast path: duplicates are rejected under a shared read lock. The
        // guard is a temporary of the condition and is released before the
        // write lock below is requested.
        if self.seen.read().unwrap().contains(url_key) {
            return false;
        }

        let mut seen = self.seen.write().unwrap();
        // Re-check: another thread may have inserted the URL between the
        // read lock being released and the write lock being acquired.
        if seen.contains(url_key) {
            return false;
        }

        let mut counts = self.domain_counts.write().unwrap();
        // URLs without a host are all accounted under the empty-string domain.
        let domain = entry.url.host_str().unwrap_or("");
        if matches!(counts.get(domain), Some(&count) if count >= self.max_per_domain) {
            return false;
        }
        if seen.len() >= self.max_total as usize {
            return false;
        }

        seen.insert(url_key.to_owned());
        *counts.entry(domain.to_owned()).or_insert(0) += 1;
        // `seen` and `counts` are still held here so the three structures
        // change together.
        self.queue.write().unwrap().push(entry);
        true
    }

    /// Pushes every entry in turn and returns how many were accepted.
    ///
    /// Each entry is submitted through [`push`](Self::push) independently, so
    /// the batch as a whole is not atomic.
    pub fn push_many(&self, entries: Vec<CrawlEntry>) -> usize {
        entries
            .into_iter()
            .map(|entry| self.push(entry))
            .filter(|&accepted| accepted)
            .count()
    }

    /// Removes and returns the greatest queued entry, or `None` if the queue
    /// is empty. The seen set and per-host counts are left untouched.
    pub fn pop(&self) -> Option<CrawlEntry> {
        self.queue.write().unwrap().pop()
    }

    /// Number of entries currently waiting in the queue.
    pub fn len(&self) -> usize {
        self.queue.read().unwrap().len()
    }

    /// Returns `true` when no entries are waiting in the queue.
    pub fn is_empty(&self) -> bool {
        self.queue.read().unwrap().is_empty()
    }

    /// Number of distinct URLs recorded in the seen set.
    pub fn seen_count(&self) -> usize {
        self.seen.read().unwrap().len()
    }

    /// Returns `true` if `url` has been accepted by `push` or recorded with
    /// [`mark_seen`](Self::mark_seen).
    pub fn has_seen(&self, url: &Url) -> bool {
        // `as_str` borrows the serialized URL, avoiding a `String` allocation.
        self.seen.read().unwrap().contains(url.as_str())
    }

    /// Records `url` as seen without queueing it, so later pushes of the same
    /// URL are rejected. Per-host counts are not affected.
    pub fn mark_seen(&self, url: &Url) {
        self.seen.write().unwrap().insert(url.as_str().to_owned());
    }

    /// Empties the queue, the seen set and the per-host counts.
    pub fn clear(&self) {
        // Same acquisition order as `push`; see the note above this impl.
        let mut seen = self.seen.write().unwrap();
        let mut counts = self.domain_counts.write().unwrap();
        let mut queue = self.queue.write().unwrap();
        seen.clear();
        counts.clear();
        queue.clear();
    }

    /// Returns a snapshot (clone) of the per-host accepted-entry counts.
    pub fn domain_stats(&self) -> HashMap<String, u32> {
        self.domain_counts.read().unwrap().clone()
    }
}