use flate2::read::GzDecoder;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use regex::Regex;
use reqwest::{Client, Url, header};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{Jitter, RetryTransientMiddleware, policies::ExponentialBackoff};
use std::{
fs::File,
io::{BufRead, BufReader},
path::{Path, PathBuf},
process, str,
sync::Arc,
time::Duration,
};
use tokio::{
io::{AsyncWriteExt, BufWriter},
sync::Semaphore,
task::JoinSet,
};
use crate::errors::DownloadError;
/// URL prefix of the Common Crawl data host. It already ends with a slash, so
/// relative paths are appended to it directly when building download URLs.
const BASE_URL: &str = "https://data.commoncrawl.org/";
/// `User-Agent` string sent by the HTTP client, formatted as
/// `<crate name>/<crate version>` from Cargo's compile-time environment.
static APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
/// User-facing configuration shared by [`download_paths`] and [`download`].
///
/// The borrowed fields keep the options cheap to build from already-parsed
/// command-line arguments; use `..Default::default()` to fill in the rest.
#[derive(Debug, Clone)]
pub struct DownloadOptions<'a> {
    /// Crawl reference inserted into the `crawl-data/<snapshot>/...` URL.
    /// A value of the form `CC-NEWS-YYYY-MM` is rewritten to `CC-NEWS/YYYY/MM`
    /// by [`download_paths`].
    pub snapshot: String,
    /// Subset name; the path listing is fetched from `<data_type>.paths.gz`.
    pub data_type: &'a str,
    /// Gzip-compressed file listing one relative path per line, read by
    /// [`download`].
    pub paths: &'a Path,
    /// Directory under which downloaded files are written.
    pub dst: &'a Path,
    /// Number of semaphore permits, i.e. the maximum number of downloads that
    /// run at the same time.
    pub threads: usize,
    /// Maximum number of retries handed to the HTTP retry policy.
    pub max_retries: usize,
    /// Name each file `<index>.txt.gz` after its position in the path listing.
    pub numbered: bool,
    /// Keep only the last URL path segment as the file name instead of
    /// recreating the full URL path below `dst`.
    pub files_only: bool,
    /// Draw progress bars instead of printing one line per file.
    pub progress: bool,
}
/// Owned, per-file settings handed to a single download task.
///
/// Everything is owned so the value can be moved into a spawned task.
#[derive(Debug, Clone)]
pub struct TaskOptions {
    /// Zero-based position of this entry in the path listing; used as the
    /// file stem when `numbered` is set.
    pub number: usize,
    /// Absolute URL of the file to download.
    pub path: String,
    /// Directory under which the file is written.
    pub dst: PathBuf,
    /// Name the file `<number>.txt.gz`.
    pub numbered: bool,
    /// Keep only the last URL path segment as the file name.
    pub files_only: bool,
    /// Report progress through a progress bar instead of printed lines.
    pub progress: bool,
}
impl Default for DownloadOptions<'_> {
fn default() -> Self {
DownloadOptions {
snapshot: "".to_string(),
data_type: "",
paths: Path::new(""),
dst: Path::new(""),
threads: 1,
max_retries: 1000,
numbered: false,
files_only: false,
progress: false,
}
}
}
/// Builds the HTTP client used for every request in this module.
///
/// The client sends [`APP_USER_AGENT`] as its `User-Agent` and retries
/// transient failures with bounded-jitter exponential backoff (base 2, waits
/// between 1 second and 1 hour), giving up after `max_retries` retries.
///
/// # Errors
///
/// Returns an error if the underlying `reqwest` client cannot be built.
fn new_client(max_retries: usize) -> Result<ClientWithMiddleware, DownloadError> {
    // The retry policy counts in `u32`. Saturate instead of panicking when a
    // caller passes a larger `usize`; either value means "keep retrying".
    let max_retries = u32::try_from(max_retries).unwrap_or(u32::MAX);
    let retry_policy = ExponentialBackoff::builder()
        .retry_bounds(Duration::from_secs(1), Duration::from_secs(3600))
        .jitter(Jitter::Bounded)
        .base(2)
        .build_with_max_retries(max_retries);
    let client_base = Client::builder().user_agent(APP_USER_AGENT).build()?;
    Ok(ClientBuilder::new(client_base)
        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
        .build())
}
/// Downloads the `<data_type>.paths.gz` listing of a crawl into `options.dst`.
///
/// The listing URL is
/// `<BASE_URL>crawl-data/<snapshot>/<data_type>.paths.gz`. A snapshot written
/// as `CC-NEWS-YYYY-MM` is first rewritten to `CC-NEWS/YYYY/MM`. A `HEAD`
/// request is issued before the download so that an unknown crawl/subset
/// combination yields a descriptive error.
///
/// # Errors
///
/// Returns an error when the URL cannot be parsed, the client cannot be
/// built, either request fails or answers with a non-success status, or the
/// destination file cannot be created or written.
pub async fn download_paths(mut options: DownloadOptions<'_>) -> Result<(), DownloadError> {
    let news_re = Regex::new(r"^(CC\-NEWS)\-([0-9]{4})\-([0-9]{2})$")
        .expect("hard-coded CC-NEWS pattern is a valid regex");
    // Kept for the error message: users should see the reference they typed,
    // not the rewritten one.
    let snapshot_original_ref = options.snapshot.clone();
    // Run the regex once and finish borrowing `options.snapshot` before it is
    // overwritten.
    let rewritten = news_re
        .captures(&options.snapshot)
        .map(|caps| format!("{}/{}/{}", &caps[1], &caps[2], &caps[3]));
    if let Some(snapshot) = rewritten {
        options.snapshot = snapshot;
    }

    let paths = format!(
        "{}crawl-data/{}/{}.paths.gz",
        BASE_URL, options.snapshot, options.data_type
    );
    println!("Downloading paths from: {}", paths);
    let url = Url::parse(&paths)?;
    let client = new_client(options.max_retries)?;
    let filename = url
        .path_segments()
        .and_then(|mut segments| segments.next_back())
        .unwrap_or("file.download");

    // Probe with HEAD first so a missing listing is reported clearly.
    let status = client.head(url.as_str()).send().await?.status();
    if status.as_u16() == 404 {
        return Err(format!(
            "\n\nThe reference combination you requested:\n\tCRAWL: {}\n\tSUBSET: {}\n\tURL: {}\n\nDoesn't seem to exist or it is currently not accessible.\n\tError code: {} {}",
            snapshot_original_ref,
            options.data_type,
            url,
            status.as_str(),
            status.canonical_reason().unwrap_or("")
        )
        .into());
    }
    if !status.is_success() {
        return Err(format!(
            "Couldn't download URL: {}. Error code: {} {}",
            url,
            status.as_str(),
            status.canonical_reason().unwrap_or("")
        )
        .into());
    }

    // The GET can still fail after a successful HEAD; check it before creating
    // the output file so an error body is never saved as the listing.
    let mut download = client.get(url.as_str()).send().await?;
    let status = download.status();
    if !status.is_success() {
        return Err(format!(
            "Couldn't download URL: {}. Error code: {} {}",
            url,
            status.as_str(),
            status.canonical_reason().unwrap_or("")
        )
        .into());
    }

    let dst = options.dst.join(filename);
    let mut outfile = BufWriter::new(tokio::fs::File::create(&dst).await?);
    while let Some(chunk) = download.chunk().await? {
        outfile.write_all(&chunk).await?;
    }
    outfile.flush().await?;
    // `display()` cannot panic on non-UTF-8 paths, unlike `to_str().unwrap()`.
    println!("Downloaded paths to: {}", dst.display());
    Ok(())
}
async fn download_task(
client: ClientWithMiddleware,
multibar: Arc<MultiProgress>,
task_options: TaskOptions,
) -> Result<(), DownloadError> {
let url = Url::parse(&task_options.path)?;
let download_size = {
let resp = client.head(url.as_str()).send().await?;
if resp.status().is_success() {
resp.headers() .get(header::CONTENT_LENGTH) .and_then(|ct_len| ct_len.to_str().ok()) .and_then(|ct_len| ct_len.parse().ok()) .unwrap_or(0) } else {
return Err(
format!("Couldn't download URL: {}. Error: {:?}", url, resp.status()).into(),
);
}
};
let filename = if task_options.numbered {
&format!("{}{}", task_options.number, ".txt.gz")
} else if task_options.files_only {
url.path_segments()
.and_then(|segments| segments.last())
.unwrap_or("file.download")
} else {
url.path().strip_prefix("/").unwrap_or("file.download")
};
let mut dst = task_options.dst.clone();
dst.push(filename);
let request = client.get(url.as_str());
let progress_bar = multibar.add(ProgressBar::new(download_size));
if task_options.progress {
progress_bar.set_style(
ProgressStyle::default_bar()
.template("[{bar:40.cyan/blue}] {bytes}/{total_bytes} - {msg}")?
.progress_chars("#>-"),
);
progress_bar.set_message(filename.to_owned());
} else {
println!("Downloading: {}", url.as_str());
}
if !task_options.numbered {
if let Some(parent) = dst.parent() {
tokio::fs::create_dir_all(parent).await?;
}
}
let outfile = tokio::fs::File::create(dst.clone()).await?;
let mut outfile = BufWriter::new(outfile);
let mut download = request.send().await?;
while let Some(chunk) = download.chunk().await? {
if task_options.progress {
progress_bar.inc(chunk.len() as u64); }
outfile.write_all(&chunk).await?; }
if task_options.progress {
progress_bar.finish();
multibar.remove(&progress_bar);
} else {
multibar.remove(&progress_bar);
println!("Downloaded file to: {}", dst.to_str().unwrap());
}
outfile.flush().await?;
Ok(())
}
pub async fn download(options: DownloadOptions<'_>) -> Result<(), DownloadError> {
let file = {
let gzip_file = match File::open(options.paths) {
Ok(file) => file,
Err(e) => {
eprintln!(
"Could not open file {}\nError: {}",
options.paths.display(),
e
);
process::exit(1)
}
};
let file_decoded = GzDecoder::new(gzip_file);
BufReader::new(file_decoded)
};
let paths: Vec<(usize, String)> = file
.lines()
.map(|line| {
let line = line.unwrap();
format!("{}{}", BASE_URL, line)
})
.enumerate()
.collect();
let multibar = std::sync::Arc::new(indicatif::MultiProgress::new());
let main_pb = std::sync::Arc::new(
multibar
.clone()
.add(indicatif::ProgressBar::new(paths.len() as u64)),
);
if options.progress {
main_pb.set_style(
indicatif::ProgressStyle::default_bar().template("{msg} {bar:10} {pos}/{len}")?,
);
main_pb.set_message("total ");
main_pb.tick();
}
let client = new_client(options.max_retries)?;
let semaphore = Arc::new(Semaphore::new(options.threads));
let mut set = JoinSet::new();
for (number, path) in paths {
let multibar = multibar.clone();
let main_pb = main_pb.clone();
let client = client.clone();
let dst = options.dst.to_path_buf();
let semaphore = semaphore.clone();
set.spawn(async move {
let _permit = semaphore.acquire().await;
let task_options = TaskOptions {
path,
number,
dst,
numbered: options.numbered,
files_only: options.files_only,
progress: options.progress,
};
let res = download_task(client, multibar, task_options).await;
if options.progress {
main_pb.inc(1);
}
res
});
}
let multibar = {
let multibar = multibar.clone();
tokio::task::spawn_blocking(move || multibar)
};
while let Some(result) = set.join_next().await {
match result {
Ok(Ok(())) => {}
Ok(Err(e)) => eprintln!("Error: {:?}", e),
Err(e) => eprintln!("Error: {:?}", e),
}
}
if options.progress {
main_pb.finish_with_message("done");
multibar.await?;
} else {
println!("All downloads completed");
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use std::collections::HashMap;

    /// Deserialization target for the JSON answer of
    /// `http://httpbin.org/headers`; only its `headers` map is read.
    #[derive(Deserialize, Debug)]
    pub struct HeadersEcho {
        pub headers: HashMap<String, String>,
    }

    /// The user agent is `<crate name>/<crate version>`.
    #[test]
    fn user_agent_format() {
        let expected = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
        assert_eq!(APP_USER_AGENT, expected);
    }

    /// The client built by `new_client` really sends the configured user
    /// agent. Requires network access to httpbin.org.
    #[tokio::test]
    async fn user_agent_test() -> Result<(), DownloadError> {
        let echoed: HeadersEcho = new_client(1000)?
            .get("http://httpbin.org/headers")
            .send()
            .await?
            .json()
            .await?;
        assert_eq!(echoed.headers["User-Agent"], APP_USER_AGENT);
        Ok(())
    }
}