use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use futures_util::{StreamExt, stream::FuturesUnordered};
use tokio::sync::Semaphore;
use crate::error::{Error, Result};
use crate::{DownloadSource, FetchOptions, SourceType};
use crate::{Fetcher, HttpClient};
/// Tunables for one batch run.
#[derive(Clone, Debug)]
pub struct BatchOptions {
    /// Number of permits in the semaphore that gates how many jobs run at once.
    pub max_concurrent: usize,
    /// When `true`, the batch returns an error as soon as a job reports failure.
    pub fail_fast: bool,
    /// Retry behaviour requested for the batch.
    ///
    /// NOTE(review): nothing in this file's visible code reads this field —
    /// confirm where (or whether) it is applied.
    pub retry_policy: BatchRetryPolicy,
}
impl Default for BatchOptions {
fn default() -> Self {
Self {
max_concurrent: 4,
fail_fast: false,
retry_policy: BatchRetryPolicy::RetryCount(3),
}
}
}
/// Retry behaviour selectable through [`BatchOptions::retry_policy`].
#[derive(Clone, Debug)]
pub enum BatchRetryPolicy {
    /// Carries a `u32` count (presumably the maximum number of retries — confirm
    /// against the code that consumes this policy).
    RetryCount(u32),
    /// Unbounded variant; carries no data.
    Infinite,
    /// No-retry variant; carries no data.
    None,
}
/// One download request inside a batch.
#[derive(Clone, Debug)]
pub struct BatchDownloadJob {
    /// Identifier other jobs use in `dependencies` to refer to this job.
    pub id: String,
    /// URL copied into the `DownloadSource` built for this job.
    pub url: String,
    /// Destination path handed to the fetcher.
    pub destination: PathBuf,
    /// Optional 32-byte checksum forwarded unchanged to the `DownloadSource`.
    pub checksum: Option<[u8; 32]>,
    /// Ids of jobs that must have succeeded before this job is started.
    pub dependencies: Vec<String>,
    /// Per-job fetch options; `FetchOptions::default()` is used when `None`.
    pub options: Option<FetchOptions>,
}
/// Outcome of a single job in a batch.
#[derive(Clone, Debug)]
pub struct BatchResult {
    /// Id of the job this result belongs to.
    pub id: String,
    /// `true` when the download completed without error.
    pub success: bool,
    /// Path reported by the fetcher on success; `None` on failure.
    pub path: Option<PathBuf>,
    /// Stringified error on failure; `None` on success.
    pub error: Option<String>,
    /// Milliseconds elapsed, measured from just after the concurrency permit
    /// was obtained.
    pub duration_ms: u64,
}
/// Runs batches of download jobs on top of a shared [`Fetcher`].
pub struct BatchFetcher<C: HttpClient> {
    /// Fetcher shared (via `Arc`) with every job future.
    fetcher: Arc<Fetcher<C>>,
    /// Workspace root supplied at construction; stored but not read in the
    /// code visible in this file.
    _workspace_root: PathBuf,
}
/// Boxed, sendable future for one job; resolves to the job id paired with its result.
type JobFuture = Pin<Box<dyn Future<Output = (String, BatchResult)> + Send>>;
impl<C: HttpClient + 'static> BatchFetcher<C> {
/// Creates a batch fetcher that owns `fetcher` behind an `Arc`.
///
/// `workspace_root` is converted to a `PathBuf` and stored; the code visible
/// in this file does not read it afterwards.
pub fn new(fetcher: Fetcher<C>, workspace_root: impl Into<PathBuf>) -> Self {
    let fetcher = Arc::new(fetcher);
    let _workspace_root: PathBuf = workspace_root.into();
    Self {
        fetcher,
        _workspace_root,
    }
}
/// Downloads every job in `jobs`, honouring inter-job dependencies.
///
/// The batch is first checked for dependency cycles, then ordered
/// topologically, then executed with the concurrency limit from `options`.
///
/// # Errors
///
/// Returns the error produced by validation or sorting (cycles, unknown
/// dependency ids), or the error produced by execution (for example a
/// failed job while `options.fail_fast` is set).
pub async fn fetch_batch(
    &self,
    jobs: Vec<BatchDownloadJob>,
    options: BatchOptions,
) -> Result<Vec<BatchResult>> {
    // Cycle detection runs first so the error can name an offending job.
    let ordered = self
        .validate_dependencies(&jobs)
        .and_then(|()| self.topological_sort(&jobs))?;
    self.execute_with_concurrency(ordered, options).await
}
/// Checks that the dependency graph described by `jobs` has no cycle.
///
/// Runs a depth-first search from every job not yet fully explored.
/// Dependency ids that match no job are ignored here.
///
/// # Errors
///
/// Propagates the `Error::InvalidState` raised by the search when a cycle
/// is found.
fn validate_dependencies(&self, jobs: &[BatchDownloadJob]) -> Result<()> {
    // When two jobs share an id the later one wins, as with repeated `insert`.
    let by_id: HashMap<&str, &BatchDownloadJob> =
        jobs.iter().map(|job| (job.id.as_str(), job)).collect();
    let mut on_stack: HashSet<&str> = HashSet::new();
    let mut finished: HashSet<&str> = HashSet::new();

    for job in jobs {
        if finished.contains(job.id.as_str()) {
            continue;
        }
        self.dfs_check_cycles(&job.id, &by_id, &mut on_stack, &mut finished)?;
    }
    Ok(())
}
fn dfs_check_cycles<'a>(
&self,
job_id: &'a str,
job_map: &HashMap<&str, &'a BatchDownloadJob>,
visiting: &mut HashSet<&'a str>,
visited: &mut HashSet<&'a str>,
) -> Result<()> {
if visiting.contains(job_id) {
return Err(Error::InvalidState(format!(
"Circular dependency detected involving job: {}",
job_id
)));
}
if visited.contains(job_id) {
return Ok(());
}
visiting.insert(job_id);
if let Some(job) = job_map.get(job_id) {
for dep in &job.dependencies {
self.dfs_check_cycles(dep, job_map, visiting, visited)?;
}
}
visiting.remove(job_id);
visited.insert(job_id);
Ok(())
}
/// Returns clones of `jobs` ordered so that every job appears after all of
/// its dependencies (Kahn's algorithm).
///
/// The ready queue is seeded and extended in input order, so the result is
/// deterministic: jobs that do not constrain each other keep the relative
/// order in which they become ready, starting from the order of `jobs`.
///
/// # Errors
///
/// Returns `Error::InvalidState` when
/// * two jobs share the same id,
/// * a job names a dependency id that is not part of `jobs`, or
/// * the dependency graph contains a cycle.
fn topological_sort(&self, jobs: &[BatchDownloadJob]) -> Result<Vec<BatchDownloadJob>> {
    let mut job_map: HashMap<&str, &BatchDownloadJob> = HashMap::with_capacity(jobs.len());
    for job in jobs {
        // Duplicate ids would silently collapse in the map and later surface
        // as a misleading "circular dependency" error, so reject them here.
        if job_map.insert(job.id.as_str(), job).is_some() {
            return Err(Error::InvalidState(format!(
                "Duplicate job id '{}' in batch",
                job.id
            )));
        }
    }

    // in_degree: number of unmet dependencies per job.
    // dependents: for each job, the jobs waiting on it.
    let mut in_degree: HashMap<&str, usize> =
        jobs.iter().map(|job| (job.id.as_str(), 0)).collect();
    let mut dependents: HashMap<&str, Vec<&str>> = HashMap::with_capacity(jobs.len());

    for job in jobs {
        for dep in &job.dependencies {
            if !job_map.contains_key(dep.as_str()) {
                return Err(Error::InvalidState(format!(
                    "Dependency '{}' not found for job '{}'",
                    dep, job.id
                )));
            }
            *in_degree.entry(job.id.as_str()).or_insert(0) += 1;
            dependents
                .entry(dep.as_str())
                .or_default()
                .push(job.id.as_str());
        }
    }

    // Seed from `jobs`, not from the map: HashMap iteration order is arbitrary
    // and would make the output order vary between runs.
    let mut queue: std::collections::VecDeque<&str> = jobs
        .iter()
        .map(|job| job.id.as_str())
        .filter(|id| in_degree.get(id).is_some_and(|degree| *degree == 0))
        .collect();
    let mut sorted = Vec::with_capacity(jobs.len());

    while let Some(job_id) = queue.pop_front() {
        if let Some(job) = job_map.get(job_id) {
            sorted.push((*job).clone());
        }
        if let Some(children) = dependents.get(job_id) {
            for &child in children {
                if let Some(degree) = in_degree.get_mut(child) {
                    *degree -= 1;
                    if *degree == 0 {
                        queue.push_back(child);
                    }
                }
            }
        }
    }

    // Jobs on a cycle never reach in-degree zero and are therefore missing.
    if sorted.len() != jobs.len() {
        return Err(Error::InvalidState(
            "Circular dependency detected in batch jobs".to_string(),
        ));
    }
    Ok(sorted)
}
async fn execute_with_concurrency(
&self,
jobs: Vec<BatchDownloadJob>,
options: BatchOptions,
) -> Result<Vec<BatchResult>> {
let semaphore = Arc::new(Semaphore::new(options.max_concurrent));
let mut futures: FuturesUnordered<JobFuture> = FuturesUnordered::new();
let mut results = Vec::new();
let mut job_results = HashMap::new();
let mut pending_jobs = jobs.into_iter().enumerate().collect::<Vec<_>>();
while !pending_jobs.is_empty() || !futures.is_empty() {
let mut i = 0;
while i < pending_jobs.len() {
let (_index, job) = &pending_jobs[i];
let deps_satisfied = job.dependencies.iter().all(|dep| {
job_results
.get(dep)
.is_some_and(|r: &BatchResult| r.success)
});
if deps_satisfied {
let job = pending_jobs.remove(i).1;
let fetcher = Arc::clone(&self.fetcher);
let semaphore = Arc::clone(&semaphore);
let _fail_fast = options.fail_fast;
let future: JobFuture = Box::pin(async move {
let permit = semaphore.acquire().await;
let start = std::time::Instant::now();
let result = match permit {
Ok(_permit) => match Self::execute_single_job(&fetcher, &job).await {
Ok(path) => BatchResult {
id: job.id.clone(),
success: true,
path: Some(path),
error: None,
duration_ms: start.elapsed().as_millis() as u64,
},
Err(e) => BatchResult {
id: job.id.clone(),
success: false,
path: None,
error: Some(e.to_string()),
duration_ms: start.elapsed().as_millis() as u64,
},
},
Err(e) => BatchResult {
id: job.id.clone(),
success: false,
path: None,
error: Some(format!("semaphore acquire error: {e}")),
duration_ms: start.elapsed().as_millis() as u64,
},
};
(job.id, result)
});
futures.push(future);
} else {
i += 1;
}
}
if let Some(result) = futures.next().await {
let (job_id, job_result): (String, BatchResult) = result;
job_results.insert(job_id.clone(), job_result.clone());
results.push(job_result.clone());
if options.fail_fast && !job_result.success {
return Err(Error::Network(format!(
"Batch download failed (fail_fast enabled): {}",
job_result.error.as_deref().unwrap_or_default()
)));
}
}
}
Ok(results)
}
/// Downloads a single job through `fetcher` and returns the resulting path.
///
/// The job is wrapped in a `DownloadSource` with priority `0`, type
/// `SourceType::Primary`, no region, and the job's URL and checksum.
/// `FetchOptions::default()` is used when the job carries no options.
///
/// # Errors
///
/// Propagates whatever error `Fetcher::try_source` returns.
async fn execute_single_job(
    fetcher: &Arc<Fetcher<C>>,
    job: &BatchDownloadJob,
) -> Result<PathBuf> {
    let fetch_options = job.options.clone().unwrap_or_default();
    let source = DownloadSource {
        url: job.url.clone(),
        priority: 0,
        checksum: job.checksum,
        source_type: SourceType::Primary,
        region: None,
    };
    let fetched = fetcher
        .try_source(&source, &job.destination, &fetch_options)
        .await?;
    Ok(fetched.destination)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // NOTE(review): the dependency and sort tests below call methods on local
    // stub structs, not on `BatchFetcher`, so they do not exercise the real
    // `validate_dependencies` / `topological_sort` implementations.

    /// Builds the fixture job `job<n>` with the given dependency ids.
    fn job(n: u32, dependencies: &[&str]) -> BatchDownloadJob {
        BatchDownloadJob {
            id: format!("job{n}"),
            url: format!("http://example.com/{n}"),
            destination: PathBuf::from(format!("/tmp/{n}")),
            checksum: None,
            dependencies: dependencies.iter().map(|dep| dep.to_string()).collect(),
            options: None,
        }
    }

    #[test]
    fn test_batch_options_default() {
        let options = BatchOptions::default();
        assert_eq!(options.max_concurrent, 4);
        assert!(!options.fail_fast);
        assert!(matches!(
            options.retry_policy,
            BatchRetryPolicy::RetryCount(3)
        ));
    }

    #[test]
    fn test_validate_dependencies_no_cycle() {
        let jobs = vec![job(1, &[]), job(2, &["job1"])];

        // Stub that always accepts the batch.
        struct MockFetcher;
        impl MockFetcher {
            fn validate_dependencies(&self, _jobs: &[BatchDownloadJob]) -> Result<()> {
                Ok(())
            }
        }

        let outcome = MockFetcher.validate_dependencies(&jobs);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_dependencies_cycle() {
        let jobs = vec![job(1, &["job2"]), job(2, &["job1"])];

        // Stub that always rejects the batch.
        struct MockFetcher;
        impl MockFetcher {
            fn validate_dependencies(&self, _jobs: &[BatchDownloadJob]) -> Result<()> {
                Err(Error::InvalidState(
                    "Circular dependency detected".to_string(),
                ))
            }
        }

        let outcome = MockFetcher.validate_dependencies(&jobs);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_topological_sort() {
        let jobs = vec![job(1, &[]), job(2, &["job1"]), job(3, &["job2"])];

        // Stub that returns the jobs in input order.
        struct MockFetcher;
        impl MockFetcher {
            fn topological_sort(&self, jobs: &[BatchDownloadJob]) -> Result<Vec<BatchDownloadJob>> {
                Ok(jobs.to_vec())
            }
        }

        let sorted = MockFetcher.topological_sort(&jobs).unwrap();
        assert_eq!(sorted[0].id, "job1");
        assert_eq!(sorted[1].id, "job2");
        assert_eq!(sorted[2].id, "job3");
    }
}