use std::{
collections::HashMap,
env::{self, current_dir},
fs::{self, File, read_to_string},
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
};
use color_eyre::eyre::{Error as ColorError, eyre};
use futures::future::try_join_all;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use jiff::Timestamp;
use log::{debug, info, warn};
use prettytable::{Table, row};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::task::JoinHandle;
use crate::{
EntryError, RegistryError, ValidationError,
data::{DownloadStatus, RefDataset},
downloads::{check_url, request_dataset},
validate::UnvalidatedFile,
};
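/// A single `refman` project, wrapping the [`Registry`] of reference datasets that is
/// serialized to and from `refman.toml`.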
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Project {
project: Registry,
}
type MultiDownloadResults = Vec<Result<UnvalidatedFile, ColorError>>;
impl Project {
fn new(title: Option<String>, description: Option<String>, global: bool) -> Self {
let registry = Registry {
title,
description,
global,
..Registry::default()
};
Self { project: registry }
}
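    /// Returns a shared slice of all datasets currently registered in the project.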
#[inline]
#[must_use]
pub fn datasets(&self) -> &[RefDataset] {
self.project.datasets.as_slice()
}
#[inline]
pub fn datasets_mut(&mut self) -> &mut [RefDataset] {
self.project.datasets.as_mut_slice()
}
#[inline]
#[must_use]
pub fn datasets_owned(self) -> Vec<RefDataset> {
self.project.datasets
}
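    /// Looks up the dataset registered under `label`, returning an error if no such
    /// dataset exists.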
#[inline]
pub fn get_dataset(&self, label: &str) -> Result<&RefDataset, EntryError> {
        let matches: Vec<&RefDataset> = self
            .datasets()
            .iter()
            .filter(|dataset| dataset.label == label)
            .collect();
        if matches.is_empty() {
            return Err(EntryError::LabelNotFound(label.to_string()));
        }
        assert_eq!(matches.len(), 1);
        Ok(matches[0])
}
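    /// Collects the registered URLs for the FASTA, Genbank, GFA, GFF, GTF, and BED files
    /// of the dataset registered under `label`.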
#[inline]
pub fn get_dataset_urls(&self, label: &str) -> Result<Vec<String>, EntryError> {
let dataset = self.get_dataset(label)?;
let urls = vec![
dataset.fasta.clone(),
dataset.genbank.clone(),
dataset.gfa.clone(),
dataset.gff.clone(),
dataset.gtf.clone(),
dataset.bed.clone(),
]
.into_iter()
.flatten()
.map(|download| download.url_owned())
.collect::<Vec<String>>();
Ok(urls)
}
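    /// Collects the registered file URLs for every dataset in the project, asserting that
    /// each URL is non-empty and uses an HTTP(S) scheme.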
#[inline]
pub fn get_all_urls(&self) -> Result<Vec<String>, EntryError> {
let datasets = self.datasets();
let mut all_urls = Vec::new();
for dataset in datasets {
let urls = vec![
dataset.fasta.clone(),
dataset.genbank.clone(),
dataset.gfa.clone(),
dataset.gff.clone(),
dataset.gtf.clone(),
dataset.bed.clone(),
]
.into_iter()
.flatten()
.map(|download| download.url_owned())
.collect::<Vec<String>>();
all_urls.extend(urls);
}
assert!(
all_urls.iter().all(|url| !url.is_empty()),
"Found empty URLs in dataset"
);
assert!(
all_urls
.iter()
.all(|url| url.starts_with("http://") || url.starts_with("https://")),
"Found invalid URL protocols"
);
Ok(all_urls)
}
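    /// Returns `true` if a dataset with the given label is already registered.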
#[must_use]
pub fn is_registered(&self, label: &str) -> bool {
        self.datasets()
            .iter()
            .any(|dataset| dataset.label == label)
}
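    /// Registers `new_dataset` in the project. If no dataset with the same label exists,
    /// it is appended as-is. Otherwise, the file entry provided on `new_dataset` is
    /// validated (remote URLs are probed with `check_url`) and merged into the existing
    /// dataset; a label with no files at all is rejected with `EntryError::LabelButNoFiles`.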
pub async fn register(mut self, new_dataset: RefDataset) -> Result<Self, EntryError> {
let Some(dataset_match_idx) = self.get_dataset_idx(&new_dataset.label) else {
self.project.datasets.push(new_dataset);
return Ok(self);
};
let previous_datasets = self.datasets_mut();
let dataset_to_update = &mut previous_datasets[dataset_match_idx];
match new_dataset {
RefDataset {
fasta: Some(ref fasta),
..
} => {
let url_str = fasta.url();
if is_likely_url(url_str) {
let _ = check_url(url_str).await?;
} else if !PathBuf::from(url_str).is_file() {
return Err(EntryError::InvalidURL(eyre!(
"The provided uri {url_str} was not a web link, nor was it a local file path pointing to something that exists."
)));
}
dataset_to_update.fasta = new_dataset.fasta;
},
RefDataset {
genbank: Some(ref genbank),
..
} => {
let url_str = genbank.url();
if is_likely_url(url_str) {
let _ = check_url(url_str).await?;
}
dataset_to_update.genbank = new_dataset.genbank;
},
RefDataset {
gfa: Some(ref gfa), ..
} => {
let url_str = gfa.url();
if is_likely_url(url_str) {
let _ = check_url(url_str).await?;
}
dataset_to_update.gfa = new_dataset.gfa;
},
RefDataset {
gff: Some(ref gff), ..
} => {
let url_str = gff.url();
if is_likely_url(url_str) {
let _ = check_url(url_str).await?;
}
dataset_to_update.gff = new_dataset.gff;
},
RefDataset {
gtf: Some(ref gtf), ..
} => {
let url_str = gtf.url();
if is_likely_url(url_str) {
let _ = check_url(url_str).await?;
}
dataset_to_update.gtf = new_dataset.gtf;
},
RefDataset {
bed: Some(ref bed), ..
} => {
let url_str = bed.url();
if is_likely_url(url_str) {
let _ = check_url(url_str).await?;
}
dataset_to_update.bed = new_dataset.bed;
},
RefDataset {
tar: Some(ref tar), ..
} => {
let url_str = tar.url();
if is_likely_url(url_str) {
let _ = check_url(url_str).await?;
}
dataset_to_update.tar = new_dataset.tar;
},
RefDataset {
label: _,
fasta: None,
genbank: None,
gfa: None,
gff: None,
gtf: None,
bed: None,
tar: None,
} => return Err(EntryError::LabelButNoFiles),
}
Ok(self)
}
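    /// Finds the index of the dataset registered under `label`, if any, asserting that the
    /// label is not registered more than once.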
#[inline]
fn get_dataset_idx(&self, label: &str) -> Option<usize> {
let dataset_match_indices: Vec<_> = self
.datasets()
.iter()
.enumerate()
.filter(|(_i, dataset)| dataset.label == label)
.map(|(i, _)| i)
.collect();
if dataset_match_indices.is_empty() {
return None;
}
assert_eq!(
dataset_match_indices.len(),
1,
"Invalid state slipped through the cracks when identifying which dataset should be updated with the new information for dataset '{}'. Somehow, multiple indices were returned: {:?}",
label,
&dataset_match_indices
);
Some(dataset_match_indices[0])
}
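    /// Gathers the files that still need to be downloaded, either for the single dataset
    /// matching `label` or for every registered dataset, pairing each dataset with its
    /// pending files.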
#[allow(clippy::similar_names)]
pub(crate) fn collect_downloads(
&self,
label: Option<&str>,
target_dir: &Path,
) -> Vec<(RefDataset, Vec<UnvalidatedFile>)> {
let datasets = if let Some(label) = label {
self.clone()
.datasets_owned()
.into_iter()
.filter(|dataset| dataset.label == label)
.collect::<Vec<_>>()
        } else {
            self.clone().datasets_owned()
        };
assert_ne!(0, datasets.len());
datasets
.into_iter()
.map(|dataset| {
let fasta = dataset.get_fasta_download(target_dir);
let genbank = dataset.get_genbank_download(target_dir);
let gfa = dataset.get_gfa_download(target_dir);
let gtf = dataset.get_gtf_download(target_dir);
let gff = dataset.get_gff_download(target_dir);
let bed = dataset.get_bed_download(target_dir);
let tar = dataset.get_tar_download(target_dir);
info!(
"Preparing to download these files:\n{:?}",
                    [&fasta, &genbank, &gfa, &gff, &gtf, &bed, &tar]
);
let files = [fasta, genbank, gfa, gff, gtf, bed, tar]
.into_iter()
.flatten()
.collect::<Vec<_>>();
(dataset, files)
})
.collect::<Vec<_>>()
}
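    /// Downloads all files still pending for the requested label (or for every registered
    /// dataset when no label is given) into `target_dir`, displaying progress bars along
    /// the way, and returns the project updated with the results of those downloads.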
#[allow(clippy::too_many_lines)]
pub async fn download_dataset(
self,
label: Option<&str>,
target_dir: PathBuf,
) -> color_eyre::Result<Self> {
let shared_client = Client::new();
let dataset_files: Vec<(RefDataset, Vec<UnvalidatedFile>)> =
self.collect_downloads(label, &target_dir);
let num_to_download = count_downloads(&dataset_files);
if num_to_download == 0 {
info!(
"All requested files were previously downloaded and still passed checksums, so no downloads will be performed."
);
return Ok(self);
}
let (mut toplevel_pb, multiprog) = setup_progress_tracking(label, num_to_download);
let dataset_task_handles =
submit_download_requests(dataset_files, &shared_client, &target_dir, &multiprog);
let updated_datasets =
update_project_datasets(dataset_task_handles, &mut toplevel_pb).await?;
toplevel_pb.finish_with_message(format!(
"Done! {num_to_download} files successfully downloaded to {target_dir:?}."
));
let updated_project = self.update_registry(&updated_datasets);
Ok(updated_project)
}
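    /// Merges `new_datasets` into the registry by label, replacing any datasets that were
    /// updated, keeping the rest as-is, and refreshing the registry's last-modified
    /// timestamp.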
#[must_use]
pub fn update_registry(self, new_datasets: &[RefDataset]) -> Project {
let old_datasets: HashMap<&str, &RefDataset> = self
.datasets()
.iter()
.map(|dataset| (dataset.label.as_str(), dataset))
.collect();
let updated_datasets: HashMap<&str, &RefDataset> = new_datasets
.iter()
.map(|dataset| (dataset.label.as_str(), dataset))
.collect();
let merged_datasets: Vec<RefDataset> = old_datasets
.into_iter()
.map(|(label, dataset)| match updated_datasets.get(label) {
Some(aha) => (*aha).to_owned(),
None => dataset.clone(),
})
.collect();
let updated_registry = Registry {
datasets: merged_datasets,
last_modified: Timestamp::now(),
..self.project
};
Self {
project: updated_registry,
}
}
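    /// Removes the dataset registered under `label`, erroring if no such label exists or
    /// if removing it would leave the registry empty.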
pub fn remove(mut self, label: &str) -> Result<Self, EntryError> {
        if !self.is_registered(label) {
            return Err(EntryError::LabelNotFound(label.to_string()));
        }
self.project
.filter_datasets(|dataset| dataset.label != label);
if self.datasets().is_empty() {
return Err(EntryError::FinalEntry(label.to_string()));
}
Ok(self)
}
fn print_single_label_data(self, label: &str) {
let datasets = self.datasets();
let dataset: Vec<_> = datasets
.iter()
.filter(|dataset| dataset.label == label)
.collect();
        assert_eq!(
            dataset.len(),
            1,
            "No dataset with the label '{label}' has been registered. Run `refman list` without a label to see which datasets are registered."
        );
let unwrapped_dataset = dataset[0];
eprintln!("URLs registered for {label}:");
eprintln!("--------------------{}", "-".repeat(label.len()));
eprintln!(
" - FASTA: {}",
unwrapped_dataset
.fasta
.clone()
.unwrap_or(DownloadStatus::default())
);
eprintln!(
" - Genbank: {}",
unwrapped_dataset
.genbank
.clone()
.unwrap_or(DownloadStatus::default())
);
eprintln!(
" - GFA: {}",
unwrapped_dataset
.gfa
.clone()
.unwrap_or(DownloadStatus::default())
);
eprintln!(
" - GFF: {}",
unwrapped_dataset
.gff
.clone()
.unwrap_or(DownloadStatus::default())
);
eprintln!(
" - GTF: {}",
unwrapped_dataset
.gtf
.clone()
.unwrap_or(DownloadStatus::default())
);
eprintln!(
" - BED: {}",
unwrapped_dataset
.bed
.clone()
.unwrap_or(DownloadStatus::default())
);
eprintln!(
" - TAR: {}",
unwrapped_dataset
.tar
.clone()
.unwrap_or(DownloadStatus::default())
);
}
fn print_all_labels(self) {
let title_field = &self.project.title;
if let Some(title) = title_field {
info!("Showing available data registered for {title}:");
}
let mut pretty_table = Table::new();
pretty_table.add_row(row![
"Label", "FASTA", "Genbank", "GFA", "GFF", "GTF", "BED", "TAR",
]);
let datasets = self.datasets();
for dataset in datasets {
pretty_table.add_row(row![
dataset.label,
abbreviate_str(
dataset
.fasta
.clone()
.unwrap_or(DownloadStatus::default())
.url_owned(),
20,
8,
25
),
abbreviate_str(
dataset
.genbank
.clone()
.unwrap_or(DownloadStatus::default())
.url_owned(),
20,
8,
25
),
abbreviate_str(
dataset
.gfa
.clone()
.unwrap_or(DownloadStatus::default())
.url_owned(),
20,
8,
25
),
abbreviate_str(
dataset
.gff
.clone()
.unwrap_or(DownloadStatus::default())
.url_owned(),
20,
8,
25
),
abbreviate_str(
dataset
.gtf
.clone()
.unwrap_or(DownloadStatus::default())
.url_owned(),
20,
8,
25
),
abbreviate_str(
dataset
.bed
.clone()
.unwrap_or(DownloadStatus::default())
.url_owned(),
20,
8,
25
),
abbreviate_str(
dataset
.tar
.clone()
.unwrap_or(DownloadStatus::default())
.url_owned(),
20,
8,
25
),
]);
}
pretty_table.printstd();
}
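    /// Prints the registry's contents: a per-file breakdown for a single label when one is
    /// given, otherwise a table summarizing every registered dataset.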
pub fn prettyprint(self, label: Option<String>) {
if let Some(label_str) = label {
self.print_single_label_data(&label_str);
return;
}
self.print_all_labels();
}
}
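/// Shortens strings longer than `max_chars` characters to their first `head_chars` and last
/// `tail_chars` characters, joined by an ellipsis, so long URLs fit in the printed table.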
#[inline]
fn abbreviate_str(s: String, max_chars: usize, head_chars: usize, tail_chars: usize) -> String {
let char_count = s.chars().count();
if char_count <= max_chars {
return s;
}
let head: String = s.chars().take(head_chars).collect();
let tail: String = s
.chars()
.rev()
.take(tail_chars)
.collect::<String>()
.chars()
.rev()
.collect();
format!("{head}...{tail}")
}
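/// The registry data serialized to and from `refman.toml`: optional project metadata, a
/// last-modified timestamp, whether the registry is global, and the registered datasets.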
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Registry {
title: Option<String>,
description: Option<String>,
last_modified: Timestamp,
global: bool,
datasets: Vec<RefDataset>,
}
impl Default for Registry {
fn default() -> Self {
Registry {
title: None,
description: None,
last_modified: Timestamp::now(),
global: false,
datasets: vec![],
}
}
}
impl Registry {
fn filter_datasets<F>(&mut self, predicate: F)
where
F: FnMut(&RefDataset) -> bool,
{
self.datasets.retain(predicate);
}
}
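/// Where the registry file lives on disk, together with the project metadata used when
/// initializing it.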
pub struct RegistryOptions {
resolved_path: PathBuf,
title: Option<String>,
description: Option<String>,
global: bool,
}
impl RegistryOptions {
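    /// Resolves the registry path from an explicitly requested directory, falling back to
    /// the current working directory or the global refman home, and bundles it with the
    /// provided title and description.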
pub fn try_new(
title: Option<String>,
description: Option<String>,
requested_path: &Option<String>,
global: bool,
) -> Result<RegistryOptions, RegistryError> {
if let Some(possible_path) = requested_path.as_deref() {
let maybe_path = PathBuf::from_str(possible_path).ok();
let resolved_path = resolve_registry_path(maybe_path, global)?;
Ok(Self {
resolved_path,
title,
description,
global,
})
} else {
let resolved_path = resolve_registry_path(None, global)?;
Ok(Self {
resolved_path,
title,
description,
global,
})
}
}
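    /// Creates a fresh `refman.toml` registry at the resolved path, leaving any existing
    /// registry untouched.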
pub fn init(&self) -> Result<(), RegistryError> {
if self.resolved_path.exists() {
info!("A refman registry already exists. Start filling it with `refman register`.");
} else {
let mut new_project =
Project::new(self.title.clone(), self.description.clone(), self.global);
File::create(&self.resolved_path)?;
self.write_registry(&mut new_project)?;
}
Ok(())
}
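    /// Reads the registry from disk, returning an empty default project if the file is
    /// missing or empty.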
pub fn read_registry(&self) -> Result<Project, RegistryError> {
if !self.resolved_path.exists() {
let new_project = Project::default();
return Ok(new_project);
}
if fs::metadata(&self.resolved_path)?.len() == 0 {
let new_project = Project::default();
return Ok(new_project);
}
let toml_contents = read_to_string(self.resolved_path.clone())?;
let project: Project = toml::from_str(&toml_contents)?;
Ok(project)
}
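    /// Updates the project's last-modified timestamp, serializes it to pretty TOML, and
    /// writes it to the resolved registry path.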
pub fn write_registry(&self, project: &mut Project) -> Result<(), RegistryError> {
project.project.last_modified = Timestamp::now();
let toml_text = toml::to_string_pretty(project)?;
fs::write(&self.resolved_path, toml_text)?;
Ok(())
}
}
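/// Resolves where `refman.toml` should live: an explicitly requested directory, the current
/// working directory for local registries, or, for global registries, a `.refman` directory
/// under `$REFMAN_HOME`, the user's home directory, or the current working directory as a
/// last resort.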
fn resolve_registry_path(
maybe_path: Option<PathBuf>,
global: bool,
) -> Result<PathBuf, RegistryError> {
let registry_path = match maybe_path {
Some(valid_path) => {
if let Some(path_str) = valid_path.to_str() {
debug!("Setting the refman home to '{path_str}'");
set_refman_home(path_str);
}
valid_path.join("refman.toml")
},
None => {
if !global {
let current_dir = current_dir()?;
if let Some(current_dir_string) = current_dir.to_str() {
debug!("Setting the refman home to '{current_dir_string}'");
set_refman_home(current_dir_string);
}
return Ok(current_dir.join("refman.toml"));
}
let refman_home: Option<PathBuf> = match env::var("REFMAN_HOME") {
Ok(path_str) => {
debug!(
"Desired file path detected in the REFMAN_HOME environment variable: '{}'. A global registry will be placed there.",
path_str
);
let path = PathBuf::from(path_str);
Some(path)
},
Err(_) => {
debug!(
"The REFMAN_HOME variable is not set. The registry will thus be placed in its default location in the user's home directory."
);
dirs::home_dir()
},
};
            if let Some(dir) = refman_home {
                let resolved_home = dir.join(".refman");
                debug!("Setting the refman home to '{:?}'", resolved_home);
                resolved_home
            } else {
                warn!(
                    "Unable to access the home directory, so `refman` will place its registry in the current working directory. Unless this path is provided in the next `refman` run, `refman` may be unable to pick up where it left off."
                );
                let current_dir = current_dir()?;
                if let Some(current_dir_string) = current_dir.to_str() {
                    debug!("Setting the refman home to '{current_dir_string}'");
                    set_refman_home(current_dir_string);
                }
                let resolved_home = current_dir.join(".refman");
                debug!("Setting the refman home to '{:?}'", resolved_home);
                resolved_home
            }
            .join("refman.toml")
        },
    };
Ok(registry_path)
}
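/// Points the `REFMAN_HOME` environment variable at the desired directory, warning if a
/// previously set value is being overwritten.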
fn set_refman_home(desired_dir: &str) {
if let Ok(old_home) = env::var("REFMAN_HOME") {
warn!(
"The environment variable $REFMAN_HOME was previously set to {}, but a new location at {} was requested. `refman` will overwrite the old $REFMAN_HOME value and proceed.",
old_home, desired_dir
);
unsafe { env::set_var("REFMAN_HOME", desired_dir) }
} else {
debug!(
"The REFMAN_HOME environment variable has not previously been set. Now setting it to the requested directory, {}",
desired_dir
);
unsafe { env::set_var("REFMAN_HOME", desired_dir) }
}
}
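/// Heuristically decides whether a string looks like a remote URL rather than a local path.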
fn is_likely_url(url: &str) -> bool {
url.starts_with("http") || url.starts_with("ftp") || url.starts_with("sftp")
}
#[inline]
fn count_downloads(dataset_files: &[(RefDataset, Vec<UnvalidatedFile>)]) -> usize {
let mut num_to_download = 0;
for (_, files) in dataset_files {
num_to_download += files.len();
}
info!("{num_to_download} downloads are confirmed. Proceeding...");
num_to_download
}
#[allow(clippy::expect_used)]
fn setup_progress_tracking(
label: Option<&str>,
num_to_download: usize,
) -> (ProgressBar, Arc<MultiProgress>) {
let message = match label {
Some(label_str) => {
format!("Downloading {num_to_download} files for project labeled '{label_str}'...")
},
None => format!("Downloading all {num_to_download} files listed in the refman registry..."),
};
let multi_pb = Arc::new(MultiProgress::new());
let toplevel_pb = multi_pb.add(ProgressBar::new(num_to_download as u64));
toplevel_pb.set_style(
ProgressStyle::default_bar()
.template("{msg} [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
.expect("Failed to set template"),
);
toplevel_pb.set_message(message);
(toplevel_pb, multi_pb)
}
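/// Spawns one Tokio task per dataset, each of which spawns per-file download tasks sharing
/// the same HTTP client, target directory, and progress bars.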
fn submit_download_requests(
dataset_files: Vec<(RefDataset, Vec<UnvalidatedFile>)>,
shared_client: &Client,
target_dir: &Path,
mp: &Arc<MultiProgress>,
) -> Vec<JoinHandle<Result<(RefDataset, MultiDownloadResults), ColorError>>> {
let num_to_download = dataset_files.len();
let mut dataset_task_handles: Vec<
JoinHandle<Result<(RefDataset, MultiDownloadResults), ColorError>>,
> = Vec::with_capacity(num_to_download);
for (dataset, files) in dataset_files {
let shared_client = shared_client.clone();
let mp = mp.clone();
let target_dir = Arc::new(target_dir.to_path_buf());
let handle: JoinHandle<_> = tokio::spawn(async move {
let file_task_handles = files.into_iter().map(|file| {
let client = shared_client.clone();
let dir = target_dir.clone();
let mp = mp.clone();
tokio::spawn(async move { request_dataset(file, client, dir, mp).await })
});
let file_results = try_join_all(file_task_handles).await?;
Ok((dataset, file_results))
});
dataset_task_handles.push(handle);
}
dataset_task_handles
}
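/// Awaits all dataset download tasks, logs and skips any that failed, and folds each
/// dataset's successfully downloaded files back into its `RefDataset` entry.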
async fn update_project_datasets(
dataset_task_handles: Vec<JoinHandle<Result<(RefDataset, MultiDownloadResults), ColorError>>>,
toplevel_pb: &mut ProgressBar,
) -> color_eyre::Result<Vec<RefDataset>> {
let updated_datasets: Vec<RefDataset> = try_join_all(dataset_task_handles)
.await?
.into_iter()
.filter_map(|dataset_result| {
toplevel_pb.inc(1);
match dataset_result {
Ok((dataset, file_results)) => {
match file_results.into_iter().collect::<Result<Vec<_>, _>>() {
Ok(successful_files) => Some((dataset, successful_files)),
Err(msg) => {
warn!("Failed to download files because of this error: {}", msg);
None
}
}
}
Err(msg) => {
warn!("Failed to download files because of this error: {}", msg);
None
}
}
})
.map(
|(mut dataset, files)| -> Result<RefDataset, ValidationError> {
for file in files {
dataset.update_with_download(&file)?;
}
Ok(dataset)
},
)
.collect::<Result<Vec<RefDataset>, ValidationError>>()?;
Ok(updated_datasets)
}
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used, clippy::unwrap_used)]
use super::*;
use tempfile::tempdir;
#[test]
fn test_new_project() {
let title = Some("Test Project".to_string());
let desc = Some("A test project".to_string());
let project = Project::new(title.clone(), desc.clone(), false);
assert_eq!(project.project.title, title);
assert_eq!(project.project.description, desc);
assert!(!project.project.global);
assert!(project.project.datasets.is_empty());
}
#[test]
fn test_registry_options_new() {
let temp_dir = tempdir().unwrap();
let dir_path = temp_dir.path().to_str().unwrap();
let options = RegistryOptions::try_new(
Some("Test Registry".to_string()),
Some("Test Description".to_string()),
&Some(dir_path.to_string()),
false,
)
.unwrap();
assert_eq!(
options.resolved_path,
PathBuf::from(dir_path).join("refman.toml")
);
assert_eq!(options.title, Some("Test Registry".to_string()));
assert_eq!(options.description, Some("Test Description".to_string()));
assert!(!options.global);
}
#[test]
fn test_read_write_registry() {
let temp_dir = tempdir().unwrap();
let dir_path = temp_dir.path().to_str().unwrap();
let options =
RegistryOptions::try_new(None, None, &Some(dir_path.to_string()), false).unwrap();
let mut project = Project::new(None, None, false);
options.write_registry(&mut project).unwrap();
assert!(options.resolved_path.exists());
let read_project = options.read_registry().unwrap();
assert_eq!(read_project.datasets().len(), 0);
}
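    // A couple of extra sketches exercising the pure helpers defined in this module.
    // They only touch `abbreviate_str` and `is_likely_url`, so they make no assumptions
    // about the rest of the crate.
    #[test]
    fn test_abbreviate_str() {
        // Strings at or under the character limit are returned unchanged.
        assert_eq!(abbreviate_str("short.fa".to_string(), 20, 8, 25), "short.fa");
        // Longer strings collapse to the first `head_chars` and last `tail_chars`
        // characters around an ellipsis.
        let long = "https://example.com/a/very/long/path/to/reference.fasta".to_string();
        let abbreviated = abbreviate_str(long.clone(), 20, 8, 25);
        assert!(abbreviated.starts_with("https://"));
        assert!(abbreviated.ends_with("reference.fasta"));
        assert!(abbreviated.contains("..."));
        assert!(abbreviated.chars().count() < long.chars().count());
    }
    #[test]
    fn test_is_likely_url() {
        assert!(is_likely_url("https://example.com/ref.fasta"));
        assert!(is_likely_url("ftp://ftp.example.org/ref.gff"));
        assert!(!is_likely_url("/local/path/to/ref.fasta"));
    }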
}