use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, mpsc};
use std::thread;
use std::time::Duration;
/// A source of links: given a page URL, reports the URLs that page links to.
///
/// Implementors must be `Send + Sync + 'static` because the concurrent crawlers
/// share them between spawned threads behind an `Arc`.
///
/// # Errors
///
/// `fetch` returns `Err` with a message when the page cannot be retrieved.
trait Fetcher: 'static + Send + Sync {
    /// Retrieves `page_url` and returns the links found on it.
    fn fetch(&self, page_url: &str) -> Result<Vec<String>, String>;
}
/// A [`Fetcher`] that answers from an in-memory table of canned responses.
struct FakeFetcher {
    /// Canned response per URL; URLs absent from the map yield a "not found" error.
    results: HashMap<String, Result<Vec<String>, String>>,
}

impl Fetcher for FakeFetcher {
    /// Logs the request, sleeps 100 ms (presumably to imitate network latency), and then
    /// returns a copy of the canned response for `url`, or a "not found" error if none exists.
    fn fetch(&self, url: &str) -> Result<Vec<String>, String> {
        println!("Fetching: {}", url);
        thread::sleep(Duration::from_millis(100));

        let canned = match self.results.get(url) {
            Some(entry) => entry,
            None => {
                println!("Missing: {}", url);
                return Err(format!("not found: {}", url));
            }
        };

        if canned.is_ok() {
            println!("Found: {}", url);
        } else {
            println!("Error on: {}", url);
        }
        canned.clone()
    }
}
/// Builds a [`FakeFetcher`] serving a small five-page site rooted at
/// `https://example.com/`, with cycles between pages and no fetch errors.
fn get_fake_fetcher() -> FakeFetcher {
    // Each row is one page and the links it contains.
    let site: [(&str, &[&str]); 5] = [
        (
            "https://example.com/",
            &["https://example.com/page1", "https://example.com/page2"],
        ),
        (
            "https://example.com/page1",
            &["https://example.com/", "https://example.com/page3"],
        ),
        ("https://example.com/page2", &["https://example.com/"]),
        (
            "https://example.com/page3",
            &["https://example.com/page1", "https://example.com/page4"],
        ),
        ("https://example.com/page4", &[]),
    ];

    let results: HashMap<String, Result<Vec<String>, String>> = site
        .iter()
        .map(|(page, links)| {
            let owned_links: Vec<String> = links.iter().map(|link| link.to_string()).collect();
            (page.to_string(), Ok(owned_links))
        })
        .collect();

    FakeFetcher { results }
}
/// Crawls every page reachable from `url`, one fetch at a time, depth-first.
///
/// URLs already in `fetched` are skipped, so a caller may pre-seed the set to
/// exclude pages. Every URL visited is added to `fetched`, including URLs whose
/// fetch failed, and no URL is fetched more than once.
///
/// An explicit stack replaces recursion, so long chains of links cannot overflow
/// the call stack. Pages are still visited in the same pre-order as a recursive
/// depth-first walk.
fn serial_crawler(url: String, fetcher: &impl Fetcher, fetched: &mut HashSet<String>) {
    let mut pending = vec![url];
    while let Some(current) = pending.pop() {
        // `insert` returns false when the URL was already seen: one lookup instead of two.
        if !fetched.insert(current.clone()) {
            continue;
        }
        // A failed fetch leaves the URL recorded as visited but contributes no links.
        if let Ok(links) = fetcher.fetch(&current) {
            // Push in reverse so the page's first link is popped (and crawled) first.
            pending.extend(links.into_iter().rev());
        }
    }
}
fn concurrent_mutex_crawler(
url: String,
fetcher: Arc<impl Fetcher>,
fetched: Arc<Mutex<HashSet<String>>>,
) {
{
let mut cache = fetched.lock().unwrap();
if cache.contains(&url) {
return;
}
cache.insert(url.clone());
}
let mut threads = vec![];
if let Ok(fetched_urls) = fetcher.fetch(&url) {
for url in fetched_urls {
let clone_fetcher = fetcher.clone();
let fetched_clone = fetched.clone();
threads.push(thread::spawn(move || {
concurrent_mutex_crawler(url, clone_fetcher, fetched_clone);
}));
}
}
for thread in threads {
thread.join().unwrap();
}
}
/// Crawls every page reachable from `url` and returns the set of URLs visited.
///
/// A single coordinating loop owns the visited set. Each fetch runs on its own
/// thread and reports its outcome back over an `mpsc` channel. The returned set
/// contains every reachable URL, including URLs whose fetch failed, and each URL is
/// fetched exactly once.
///
/// # Panics
///
/// If the fetcher panics on a worker thread, that panic is re-raised on the calling
/// thread instead of leaving the coordinator waiting forever.
fn concurrent_channel_crawler(url: String, fetcher: Arc<impl Fetcher>) -> HashSet<String> {
    // Each message is either a fetch result or the payload of a panic caught in the worker.
    let (tx, rx) = mpsc::channel::<thread::Result<Result<Vec<String>, String>>>();

    // Starts fetching `target` on a new thread; the outcome arrives on `rx`.
    let spawn_fetch = |target: String| {
        let tx = tx.clone();
        let fetcher = Arc::clone(&fetcher);
        thread::spawn(move || {
            let outcome =
                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| fetcher.fetch(&target)));
            // The coordinator stops listening early only after re-raising another
            // worker's panic, so a failed send here can be ignored.
            let _ = tx.send(outcome);
        });
    };

    let mut fetched = HashSet::new();
    fetched.insert(url.clone());
    spawn_fetch(url);
    let mut outstanding_fetches = 1usize;

    while outstanding_fetches > 0 {
        // The coordinator keeps `tx` alive, so the channel cannot disconnect here.
        let outcome = rx.recv().expect("sender held by coordinator");
        outstanding_fetches -= 1;
        let result = match outcome {
            Ok(result) => result,
            // A panicked worker never sends a result, so without this `recv` would block forever.
            Err(payload) => std::panic::resume_unwind(payload),
        };
        // A failed fetch leaves the URL recorded as visited but contributes no links.
        if let Ok(links) = result {
            for link in links {
                // `insert` returns true only for URLs not seen before: one lookup instead of two.
                if fetched.insert(link.clone()) {
                    outstanding_fetches += 1;
                    spawn_fetch(link);
                }
            }
        }
    }
    fetched
}
/// Runs each crawler in turn against the fake site, starting from its root page,
/// and prints a banner around each run.
fn main() {
    let start = "https://example.com/";
    let shared_fetcher = Arc::new(get_fake_fetcher());

    println!("--- Serial Crawler ---");
    let mut serial_visited = HashSet::new();
    serial_crawler(start.to_string(), &*shared_fetcher, &mut serial_visited);
    println!("----------------------\n");

    println!("--- Concurrent Mutex Crawler ---");
    let mutex_visited = Arc::new(Mutex::new(HashSet::new()));
    concurrent_mutex_crawler(start.to_string(), Arc::clone(&shared_fetcher), mutex_visited);
    println!("----------------------\n");

    println!("--- Concurrent Channel Crawler ---");
    concurrent_channel_crawler(start.to_string(), shared_fetcher);
    println!("----------------------\n");
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every URL reachable from `https://example.com/` in the site built by
    /// `get_fake_fetcher`.
    fn get_expected_urls() -> HashSet<String> {
        [
            "https://example.com/",
            "https://example.com/page1",
            "https://example.com/page2",
            "https://example.com/page3",
            "https://example.com/page4",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect()
    }

    /// Test double that forwards to `inner` and records every URL it is asked to
    /// fetch, so tests can check that no page is fetched twice.
    struct RecordingFetcher<F> {
        inner: F,
        calls: Mutex<Vec<String>>,
    }

    impl<F> RecordingFetcher<F> {
        fn new(inner: F) -> Self {
            RecordingFetcher {
                inner,
                calls: Mutex::new(Vec::new()),
            }
        }

        /// URLs fetched so far, sorted so that concurrent crawls compare deterministically.
        fn sorted_calls(&self) -> Vec<String> {
            let mut calls = self.calls.lock().unwrap().clone();
            calls.sort();
            calls
        }
    }

    impl<F: Fetcher> Fetcher for RecordingFetcher<F> {
        fn fetch(&self, url: &str) -> Result<Vec<String>, String> {
            self.calls.lock().unwrap().push(url.to_string());
            self.inner.fetch(url)
        }
    }

    /// Runs the serial, mutex and channel crawlers from `start` against `fetcher`
    /// and returns the set each one reports, in that order.
    fn crawl_with_all(
        start: &str,
        fetcher: FakeFetcher,
    ) -> (HashSet<String>, HashSet<String>, HashSet<String>) {
        let mut serial_set = HashSet::new();
        serial_crawler(start.to_string(), &fetcher, &mut serial_set);

        let fetcher = Arc::new(fetcher);
        let shared = Arc::new(Mutex::new(HashSet::new()));
        concurrent_mutex_crawler(start.to_string(), Arc::clone(&fetcher), Arc::clone(&shared));
        let mutex_set = shared.lock().unwrap().clone();

        let channel_set = concurrent_channel_crawler(start.to_string(), fetcher);
        (serial_set, mutex_set, channel_set)
    }

    #[test]
    fn test_serial_crawler_happy_path() {
        let fetcher = get_fake_fetcher();
        let mut fetched = HashSet::new();
        serial_crawler("https://example.com/".to_string(), &fetcher, &mut fetched);
        assert_eq!(fetched, get_expected_urls());
    }

    #[test]
    fn test_concurrent_mutex_crawler_happy_path() {
        let fetcher = Arc::new(get_fake_fetcher());
        let fetched_arc = Arc::new(Mutex::new(HashSet::new()));
        concurrent_mutex_crawler(
            "https://example.com/".to_string(),
            fetcher,
            Arc::clone(&fetched_arc),
        );
        let fetched_guard = fetched_arc.lock().unwrap();
        assert_eq!(*fetched_guard, get_expected_urls());
    }

    #[test]
    fn test_concurrent_channel_crawler_happy_path() {
        let fetcher = Arc::new(get_fake_fetcher());
        let fetched = concurrent_channel_crawler("https://example.com/".to_string(), fetcher);
        assert_eq!(fetched, get_expected_urls());
    }

    #[test]
    fn test_start_with_nonexistent_url() {
        let url = "https://nonexistent.com/".to_string();
        let expected: HashSet<String> = [url.clone()].iter().cloned().collect();
        let (serial, mutex, channel) = crawl_with_all(&url, get_fake_fetcher());
        assert_eq!(serial, expected, "Serial crawler failed on non-existent URL");
        assert_eq!(mutex, expected, "Mutex crawler failed on non-existent URL");
        assert_eq!(channel, expected, "Channel crawler failed on non-existent URL");
    }

    #[test]
    fn test_single_url_no_links() {
        let url = "https://single.com/".to_string();
        let mut results = HashMap::new();
        results.insert(url.clone(), Ok(vec![]));
        let fetcher = FakeFetcher { results };
        let expected: HashSet<String> = [url.clone()].iter().cloned().collect();
        let (serial, mutex, channel) = crawl_with_all(&url, fetcher);
        assert_eq!(serial, expected, "Serial crawler failed on single URL");
        assert_eq!(mutex, expected, "Mutex crawler failed on single URL");
        assert_eq!(channel, expected, "Channel crawler failed on single URL");
    }

    #[test]
    fn test_crawler_with_fetch_error() {
        let start_url = "https://start.com/".to_string();
        let ok_url = "https://ok.com/".to_string();
        let bad_url = "https://bad.com/".to_string();
        let mut results = HashMap::new();
        results.insert(start_url.clone(), Ok(vec![ok_url.clone(), bad_url.clone()]));
        results.insert(ok_url.clone(), Ok(vec![]));
        results.insert(bad_url.clone(), Err("permanent failure".to_string()));
        let fetcher = FakeFetcher { results };
        // A URL whose fetch fails is still recorded as visited.
        let expected: HashSet<String> = [start_url.clone(), ok_url, bad_url]
            .iter()
            .cloned()
            .collect();
        let (serial, mutex, channel) = crawl_with_all(&start_url, fetcher);
        assert_eq!(serial, expected, "Serial crawler failed with fetch error");
        assert_eq!(mutex, expected, "Mutex crawler failed with fetch error");
        assert_eq!(channel, expected, "Channel crawler failed with fetch error");
    }

    #[test]
    fn test_each_url_fetched_exactly_once() {
        let start = "https://example.com/".to_string();
        let mut expected: Vec<String> = get_expected_urls().into_iter().collect();
        expected.sort();

        let serial_fetcher = RecordingFetcher::new(get_fake_fetcher());
        let mut fetched_serial = HashSet::new();
        serial_crawler(start.clone(), &serial_fetcher, &mut fetched_serial);
        assert_eq!(
            serial_fetcher.sorted_calls(),
            expected,
            "Serial crawler fetched a URL more than once"
        );

        let mutex_fetcher = Arc::new(RecordingFetcher::new(get_fake_fetcher()));
        concurrent_mutex_crawler(
            start.clone(),
            Arc::clone(&mutex_fetcher),
            Arc::new(Mutex::new(HashSet::new())),
        );
        assert_eq!(
            mutex_fetcher.sorted_calls(),
            expected,
            "Mutex crawler fetched a URL more than once"
        );

        let channel_fetcher = Arc::new(RecordingFetcher::new(get_fake_fetcher()));
        let _ = concurrent_channel_crawler(start, Arc::clone(&channel_fetcher));
        assert_eq!(
            channel_fetcher.sorted_calls(),
            expected,
            "Channel crawler fetched a URL more than once"
        );
    }
}