use futures::future::join_all;
use indicatif::ProgressStyle;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use tokio::task;
use crate::cli::Args;
use crate::network::{NetworkScope, NetworkSettings};
use crate::progress::ProgressManager;
use crate::providers::Provider;
use crate::utils::verbose_print;
pub fn apply_network_settings_to_provider(provider: &mut dyn Provider, settings: &NetworkSettings) {
if settings.scope == NetworkScope::Testers {
return;
}
provider.with_subdomains(settings.include_subdomains);
provider.with_timeout(settings.timeout);
provider.with_retries(settings.retries);
provider.with_random_agent(settings.random_agent);
provider.with_insecure(settings.insecure);
provider.with_parallel(settings.parallel);
if let Some(proxy) = &settings.proxy {
provider.with_proxy(Some(proxy.clone()));
if let Some(auth) = &settings.proxy_auth {
provider.with_proxy_auth(Some(auth.clone()));
}
}
if let Some(rate) = settings.rate_limit {
provider.with_rate_limit(Some(rate));
}
}
pub fn add_provider<T: Provider + 'static>(
args: &Args,
network_settings: &NetworkSettings,
providers: &mut Vec<Box<dyn Provider>>,
provider_names: &mut Vec<String>,
provider_name: String,
provider_builder: impl FnOnce() -> T,
) {
if args.verbose && !args.silent {
println!("Adding {provider_name} provider");
if network_settings.include_subdomains {
println!("Subdomain inclusion enabled for {provider_name}");
}
if network_settings.proxy.is_some() {
println!(
"Using proxy for {provider_name}: {}",
network_settings.proxy.as_ref().unwrap()
);
}
if network_settings.random_agent && !args.silent {
println!("Random User-Agent enabled for {provider_name}");
}
println!(
"Timeout set to {} seconds for {provider_name}",
network_settings.timeout
);
println!(
"Retries set to {} for {provider_name}",
network_settings.retries
);
println!(
"Parallel requests set to {} for {provider_name}",
network_settings.parallel
);
if let Some(rate) = network_settings.rate_limit {
println!("Rate limit set to {rate} requests/second for {provider_name}");
}
}
let mut provider = provider_builder();
apply_network_settings_to_provider(&mut provider, network_settings);
providers.push(Box::new(provider));
provider_names.push(provider_name);
}
pub async fn process_domains(
domains: Vec<String>,
args: &Args,
progress_manager: &ProgressManager,
providers: &[Box<dyn Provider>],
provider_names: &[String],
) -> HashSet<String> {
let all_urls = Arc::new(Mutex::new(HashSet::new()));
let total_domains = domains.len();
let total_providers = providers.len();
let overall_bar = progress_manager.create_domain_bar(total_domains);
overall_bar.set_message("Processing domains");
let processed_domains = Arc::new(Mutex::new(0usize));
let provider_bars = progress_manager.create_provider_bars(provider_names);
let provider_queues: Vec<Arc<Mutex<Vec<String>>>> = (0..providers.len())
.map(|_| Arc::new(Mutex::new(domains.clone())))
.collect();
let domain_completion = Arc::new(Mutex::new(
domains
.iter()
.map(|d| (d.clone(), 0))
.collect::<HashMap<String, usize>>(),
));
verbose_print(
args,
format!("Using provider-based concurrency with {total_providers} providers"),
);
let provider_data: Vec<_> = providers
.iter()
.enumerate()
.map(|(idx, provider)| (provider.clone_box(), provider_names[idx].clone(), idx))
.collect();
let mut provider_futures = Vec::new();
let timeout = args.timeout;
let verbose = args.verbose;
let silent = args.silent;
let no_progress = args.no_progress;
for (p_idx, (provider_clone, provider_name, original_idx)) in
provider_data.into_iter().enumerate()
{
let all_urls_clone = Arc::clone(&all_urls);
let processed_domains_clone = Arc::clone(&processed_domains);
let queue = Arc::clone(&provider_queues[p_idx]);
let domain_completion_clone = Arc::clone(&domain_completion);
let overall_bar_clone = overall_bar.clone();
let provider_bar = provider_bars[original_idx].clone();
let provider_future = task::spawn(async move {
let mut current_domain_idx = 0;
loop {
let domain = {
let mut queue = queue.lock().unwrap();
if queue.is_empty() {
break; }
current_domain_idx += 1;
queue.remove(0)
};
provider_bar.set_message(format!(
"({current_domain_idx}/{total_domains}) Fetching data for {domain}"
));
let bar_clone = provider_bar.clone();
if !no_progress && !silent {
provider_bar.tick();
}
let ticker_handle = tokio::spawn(async move {
let start_time = std::time::Instant::now();
let total_duration_ms = timeout * 1000;
let spinner_phase_duration =
std::time::Duration::from_millis(total_duration_ms / 10);
tokio::time::sleep(spinner_phase_duration).await;
let progress_style = ProgressStyle::with_template(
"{prefix:.bold.dim} [{bar:40.green/white}] {spinner} {wide_msg}",
)
.unwrap()
.progress_chars("=> ")
.tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]);
bar_clone.set_style(progress_style);
let update_interval_ms = 100;
let end_time = start_time + std::time::Duration::from_millis(total_duration_ms);
while std::time::Instant::now() < end_time {
let now = std::time::Instant::now();
let elapsed = now.duration_since(start_time).as_millis() as u64;
let progress = (elapsed * 100) / total_duration_ms;
bar_clone.set_position(progress.min(99));
tokio::time::sleep(std::time::Duration::from_millis(update_interval_ms))
.await;
}
bar_clone.set_position(100);
});
let result = match provider_clone.fetch_urls(&domain).await {
Ok(urls) => {
provider_bar.set_position(100);
provider_bar.set_message(format!(
"({}/{}) Found {} URLs for {}",
current_domain_idx,
total_domains,
urls.len(),
domain
));
ticker_handle.abort();
provider_bar.set_style(
ProgressStyle::with_template(
"{prefix:.bold.dim} [{bar:40.green/white}] ✓ {wide_msg}",
)
.unwrap()
.progress_chars("=>"),
);
provider_bar.tick();
{
let mut url_set = all_urls_clone.lock().unwrap();
for url in &urls {
url_set.insert(url.clone());
}
}
let mut is_domain_complete = false;
{
let mut completion_map = domain_completion_clone.lock().unwrap();
if let Some(count) = completion_map.get_mut(&domain) {
*count += 1;
is_domain_complete = *count >= total_providers;
}
}
if is_domain_complete {
let mut count = processed_domains_clone.lock().unwrap();
*count += 1;
overall_bar_clone.set_position(*count as u64);
overall_bar_clone.set_message(format!(
"Completed {}/{} domains",
*count, total_domains
));
if verbose && !silent {
println!(
"Domain completed: {} ({}/{})",
domain, *count, total_domains
);
}
}
if verbose && !silent {
println!(
" - {}: Found {} URLs for {}",
provider_name,
urls.len(),
domain
);
}
Ok(urls.len())
}
Err(e) => {
provider_bar.set_position(100);
provider_bar.set_message(format!(
"({current_domain_idx}/{total_domains}) Error: for {domain}"
));
ticker_handle.abort();
provider_bar.set_style(
ProgressStyle::with_template(
"{prefix:.bold.dim} [{bar:40.red/white}] ✗ {wide_msg}",
)
.unwrap()
.progress_chars("=>"),
);
provider_bar.tick();
let mut is_domain_complete = false;
{
let mut completion_map = domain_completion_clone.lock().unwrap();
if let Some(count) = completion_map.get_mut(&domain) {
*count += 1;
is_domain_complete = *count >= total_providers;
}
}
if is_domain_complete {
let mut count = processed_domains_clone.lock().unwrap();
*count += 1;
overall_bar_clone.set_position(*count as u64);
overall_bar_clone.set_message(format!(
"Completed {}/{} domains",
*count, total_domains
));
if verbose && !silent {
println!(
"Domain completed: {} ({}/{})",
domain, *count, total_domains
);
}
}
if verbose && !silent {
eprintln!("Error fetching URLs for {domain} from {provider_name}: {e}");
}
Err(e.to_string())
}
};
if let Err(err) = result {
if verbose && !silent {
println!(" - {provider_name}: Error - {err} for {domain}");
}
}
if current_domain_idx < total_domains {
provider_bar.set_position(0); }
}
if current_domain_idx >= total_domains {
provider_bar.finish_with_message(format!(
"({total_domains}/{total_domains}) Completed all domains"
));
}
if verbose && !silent {
println!("Provider {provider_name} has completed processing all domains");
}
});
provider_futures.push(provider_future);
}
join_all(provider_futures).await;
overall_bar.finish_with_message("All domains processed");
Arc::try_unwrap(all_urls).unwrap().into_inner().unwrap()
}