1use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use futures_util::{StreamExt, stream::FuturesUnordered};
13use tokio::sync::Semaphore;
14
15use crate::error::{Error, Result};
16use crate::{DownloadSource, FetchOptions, SourceType};
17use crate::{Fetcher, HttpClient};
18
19#[derive(Debug, Clone)]
21pub struct BatchOptions {
22 pub max_concurrent: usize,
24 pub fail_fast: bool,
26 pub retry_policy: BatchRetryPolicy,
28}
29
30impl Default for BatchOptions {
31 fn default() -> Self {
32 Self {
33 max_concurrent: 4,
34 fail_fast: false,
35 retry_policy: BatchRetryPolicy::RetryCount(3),
36 }
37 }
38}
39
/// How failed jobs in a batch should be retried.
///
/// NOTE(review): this module never reads the policy; the variant descriptions below are
/// inferred from their names and should be confirmed against the intended retry logic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchRetryPolicy {
    /// Retry a bounded number of times (presumably retries after the first attempt — confirm).
    RetryCount(u32),
    /// Retry without an upper bound.
    Infinite,
    /// Do not retry.
    None,
}
50
51#[derive(Debug, Clone)]
53pub struct BatchDownloadJob {
54 pub id: String,
56 pub url: String,
58 pub destination: PathBuf,
60 pub checksum: Option<[u8; 32]>,
62 pub dependencies: Vec<String>,
64 pub options: Option<FetchOptions>,
66}
67
/// Outcome of one job in a batch.
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Id of the job this result describes.
    pub id: String,
    /// `true` when the job's download completed without error.
    pub success: bool,
    /// Destination reported by the fetcher; set only on success.
    pub path: Option<PathBuf>,
    /// Human-readable failure description; set only on failure.
    pub error: Option<String>,
    /// Elapsed wall-clock milliseconds, measured from when the job obtained its
    /// concurrency permit.
    pub duration_ms: u64,
}
82
83pub struct BatchFetcher<C: HttpClient> {
85 fetcher: Arc<Fetcher<C>>,
86 _workspace_root: PathBuf,
87}
88
89type JobFuture = Pin<Box<dyn Future<Output = (String, BatchResult)> + Send>>;
90
91impl<C: HttpClient + 'static> BatchFetcher<C> {
92 pub fn new(fetcher: Fetcher<C>, workspace_root: impl Into<PathBuf>) -> Self {
94 Self {
95 fetcher: Arc::new(fetcher),
96 _workspace_root: workspace_root.into(),
97 }
98 }
99
100 pub async fn fetch_batch(
102 &self,
103 jobs: Vec<BatchDownloadJob>,
104 options: BatchOptions,
105 ) -> Result<Vec<BatchResult>> {
106 self.validate_dependencies(&jobs)?;
108
109 let sorted_jobs = self.topological_sort(&jobs)?;
111
112 self.execute_with_concurrency(sorted_jobs, options).await
114 }
115
116 fn validate_dependencies(&self, jobs: &[BatchDownloadJob]) -> Result<()> {
118 let mut job_map = HashMap::new();
119 for job in jobs {
120 job_map.insert(job.id.as_str(), job);
121 }
122
123 let mut visiting = HashSet::new();
125 let mut visited = HashSet::new();
126
127 for job in jobs {
128 if !visited.contains(&job.id.as_str()) {
129 self.dfs_check_cycles(&job.id, &job_map, &mut visiting, &mut visited)?;
130 }
131 }
132
133 Ok(())
134 }
135
136 fn dfs_check_cycles<'a>(
138 &self,
139 job_id: &'a str,
140 job_map: &HashMap<&str, &'a BatchDownloadJob>,
141 visiting: &mut HashSet<&'a str>,
142 visited: &mut HashSet<&'a str>,
143 ) -> Result<()> {
144 if visiting.contains(job_id) {
145 return Err(Error::InvalidState(format!(
146 "Circular dependency detected involving job: {}",
147 job_id
148 )));
149 }
150
151 if visited.contains(job_id) {
152 return Ok(());
153 }
154
155 visiting.insert(job_id);
156
157 if let Some(job) = job_map.get(job_id) {
158 for dep in &job.dependencies {
159 self.dfs_check_cycles(dep, job_map, visiting, visited)?;
160 }
161 }
162
163 visiting.remove(job_id);
164 visited.insert(job_id);
165
166 Ok(())
167 }
168
169 fn topological_sort(&self, jobs: &[BatchDownloadJob]) -> Result<Vec<BatchDownloadJob>> {
171 let mut job_map = HashMap::new();
172 for job in jobs {
173 job_map.insert(&job.id, job);
174 }
175
176 let mut in_degree = HashMap::new();
177 let mut adj_list = HashMap::new();
178
179 for job in jobs {
181 in_degree.insert(&job.id, 0);
182 adj_list.insert(&job.id, Vec::new());
183 }
184
185 for job in jobs {
187 for dep in &job.dependencies {
188 if !job_map.contains_key(dep) {
189 return Err(Error::InvalidState(format!(
190 "Dependency '{}' not found for job '{}'",
191 dep, job.id
192 )));
193 }
194 in_degree.entry(&job.id).and_modify(|e| *e += 1);
195 adj_list.entry(dep).or_insert_with(Vec::new).push(&job.id);
196 }
197 }
198
199 let mut queue = std::collections::VecDeque::new();
201 let mut sorted = Vec::new();
202
203 for (job_id, degree) in &in_degree {
205 if *degree == 0 {
206 queue.push_back(*job_id);
207 }
208 }
209
210 while let Some(job_id) = queue.pop_front() {
211 if let Some(job) = job_map.get(&job_id) {
212 sorted.push((*job).clone());
213 }
214
215 if let Some(neighbors) = adj_list.get(&job_id) {
217 for neighbor in neighbors {
218 in_degree.entry(neighbor).and_modify(|e| *e -= 1);
219 if in_degree[neighbor] == 0 {
220 queue.push_back(*neighbor);
221 }
222 }
223 }
224 }
225
226 if sorted.len() != jobs.len() {
228 return Err(Error::InvalidState(
229 "Circular dependency detected in batch jobs".to_string(),
230 ));
231 }
232
233 Ok(sorted)
234 }
235
236 async fn execute_with_concurrency(
238 &self,
239 jobs: Vec<BatchDownloadJob>,
240 options: BatchOptions,
241 ) -> Result<Vec<BatchResult>> {
242 let semaphore = Arc::new(Semaphore::new(options.max_concurrent));
243 let mut futures: FuturesUnordered<JobFuture> = FuturesUnordered::new();
244 let mut results = Vec::new();
245 let mut job_results = HashMap::new();
246 let mut pending_jobs = jobs.into_iter().enumerate().collect::<Vec<_>>();
247
248 while !pending_jobs.is_empty() || !futures.is_empty() {
249 let mut i = 0;
251 while i < pending_jobs.len() {
252 let (_index, job) = &pending_jobs[i];
253
254 let deps_satisfied = job.dependencies.iter().all(|dep| {
256 job_results
257 .get(dep)
258 .is_some_and(|r: &BatchResult| r.success)
259 });
260
261 if deps_satisfied {
262 let job = pending_jobs.remove(i).1;
263 let fetcher = Arc::clone(&self.fetcher);
264 let semaphore = Arc::clone(&semaphore);
265 let _fail_fast = options.fail_fast;
266
267 let future: JobFuture = Box::pin(async move {
268 let permit = semaphore.acquire().await;
269 let start = std::time::Instant::now();
270
271 let result = match permit {
272 Ok(_permit) => match Self::execute_single_job(&fetcher, &job).await {
273 Ok(path) => BatchResult {
274 id: job.id.clone(),
275 success: true,
276 path: Some(path),
277 error: None,
278 duration_ms: start.elapsed().as_millis() as u64,
279 },
280 Err(e) => BatchResult {
281 id: job.id.clone(),
282 success: false,
283 path: None,
284 error: Some(e.to_string()),
285 duration_ms: start.elapsed().as_millis() as u64,
286 },
287 },
288 Err(e) => BatchResult {
289 id: job.id.clone(),
290 success: false,
291 path: None,
292 error: Some(format!("semaphore acquire error: {e}")),
293 duration_ms: start.elapsed().as_millis() as u64,
294 },
295 };
296
297 (job.id, result)
298 });
299
300 futures.push(future);
301 } else {
302 i += 1;
303 }
304 }
305
306 if let Some(result) = futures.next().await {
308 let (job_id, job_result): (String, BatchResult) = result;
309
310 job_results.insert(job_id.clone(), job_result.clone());
311 results.push(job_result.clone());
312
313 if options.fail_fast && !job_result.success {
315 return Err(Error::Network(format!(
316 "Batch download failed (fail_fast enabled): {}",
317 job_result.error.as_deref().unwrap_or_default()
318 )));
319 }
320 }
321 }
322
323 Ok(results)
324 }
325
326 async fn execute_single_job(
328 fetcher: &Arc<Fetcher<C>>,
329 job: &BatchDownloadJob,
330 ) -> Result<PathBuf> {
331 let source = DownloadSource {
332 url: job.url.clone(),
333 priority: 0,
334 checksum: job.checksum,
335 source_type: SourceType::Primary,
336 region: None,
337 };
338
339 let options = job.options.clone().unwrap_or_default();
340
341 Ok(fetcher
342 .try_source(&source, &job.destination, &options)
343 .await?
344 .destination)
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds the fixture `job{n}`, which downloads `http://example.com/{n}` to `/tmp/{n}`.
    fn job(n: u32, dependencies: &[&str]) -> BatchDownloadJob {
        BatchDownloadJob {
            id: format!("job{n}"),
            url: format!("http://example.com/{n}"),
            destination: PathBuf::from(format!("/tmp/{n}")),
            checksum: None,
            dependencies: dependencies.iter().map(|dep| dep.to_string()).collect(),
            options: None,
        }
    }

    #[test]
    fn test_batch_options_default() {
        let BatchOptions {
            max_concurrent,
            fail_fast,
            retry_policy,
        } = BatchOptions::default();

        assert_eq!(max_concurrent, 4);
        assert!(!fail_fast);
        assert!(matches!(retry_policy, BatchRetryPolicy::RetryCount(3)));
    }

    // NOTE(review): the three tests below call local stubs, not `BatchFetcher`, so they do
    // not exercise the real validation or sorting code. Testing the real methods requires
    // constructing a `BatchFetcher` with a concrete `HttpClient`.

    #[test]
    fn test_validate_dependencies_no_cycle() {
        let jobs = vec![job(1, &[]), job(2, &["job1"])];

        struct MockFetcher;
        impl MockFetcher {
            fn validate_dependencies(&self, _jobs: &[BatchDownloadJob]) -> Result<()> {
                Ok(())
            }
        }

        let outcome = MockFetcher.validate_dependencies(&jobs);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_dependencies_cycle() {
        let jobs = vec![job(1, &["job2"]), job(2, &["job1"])];

        struct MockFetcher;
        impl MockFetcher {
            fn validate_dependencies(&self, _jobs: &[BatchDownloadJob]) -> Result<()> {
                Err(Error::InvalidState(
                    "Circular dependency detected".to_string(),
                ))
            }
        }

        let outcome = MockFetcher.validate_dependencies(&jobs);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_topological_sort() {
        let jobs = vec![job(1, &[]), job(2, &["job1"]), job(3, &["job2"])];

        struct MockFetcher;
        impl MockFetcher {
            fn topological_sort(
                &self,
                jobs: &[BatchDownloadJob],
            ) -> Result<Vec<BatchDownloadJob>> {
                Ok(jobs.to_vec())
            }
        }

        let sorted = MockFetcher.topological_sort(&jobs).unwrap();
        let order: Vec<&str> = sorted.iter().map(|j| j.id.as_str()).collect();

        assert_eq!(order, ["job1", "job2", "job3"]);
    }
}