Skip to main content

pulith_fetch/fetch/
batch.rs

1//! Batch download functionality.
2//!
3//! This module provides the ability to download multiple files
4//! with dependency resolution and concurrency control.
5
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use futures_util::{StreamExt, stream::FuturesUnordered};
13use tokio::sync::Semaphore;
14
15use crate::error::{Error, Result};
16use crate::{DownloadSource, FetchOptions, SourceType};
17use crate::{Fetcher, HttpClient};
18
19/// Configuration for batch downloads.
20#[derive(Debug, Clone)]
21pub struct BatchOptions {
22    /// Maximum number of concurrent downloads
23    pub max_concurrent: usize,
24    /// Whether to fail fast on first error or continue with other downloads
25    pub fail_fast: bool,
26    /// Retry policy for batch operations
27    pub retry_policy: BatchRetryPolicy,
28}
29
30impl Default for BatchOptions {
31    fn default() -> Self {
32        Self {
33            max_concurrent: 4,
34            fail_fast: false,
35            retry_policy: BatchRetryPolicy::RetryCount(3),
36        }
37    }
38}
39
/// Retry policy for batch downloads.
///
/// Derives `PartialEq`/`Eq` so callers can compare configured policies
/// directly instead of pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchRetryPolicy {
    /// Retry a failed download up to the given number of times.
    RetryCount(u32),
    /// Keep retrying until the download succeeds (not recommended for
    /// production, since a permanently failing job never finishes).
    Infinite,
    /// Never retry; a single failed attempt fails the job.
    None,
}
50
51/// A job in a batch download.
52#[derive(Debug, Clone)]
53pub struct BatchDownloadJob {
54    /// Unique identifier for this job
55    pub id: String,
56    /// URL to download from
57    pub url: String,
58    /// Destination path
59    pub destination: PathBuf,
60    /// Optional checksum for verification
61    pub checksum: Option<[u8; 32]>,
62    /// Jobs that must complete before this one can start
63    pub dependencies: Vec<String>,
64    /// Fetch options specific to this job
65    pub options: Option<FetchOptions>,
66}
67
/// Outcome of a single job in a batch.
#[derive(Clone, Debug)]
pub struct BatchResult {
    /// Id of the job this result belongs to.
    pub id: String,
    /// `true` when the download completed without error.
    pub success: bool,
    /// Location of the downloaded file; populated only on success.
    pub path: Option<PathBuf>,
    /// Human-readable failure description; populated only on failure.
    pub error: Option<String>,
    /// How long the job took, in milliseconds.
    pub duration_ms: u64,
}
82
83/// Batch fetcher implementation.
84pub struct BatchFetcher<C: HttpClient> {
85    fetcher: Arc<Fetcher<C>>,
86    _workspace_root: PathBuf,
87}
88
89type JobFuture = Pin<Box<dyn Future<Output = (String, BatchResult)> + Send>>;
90
91impl<C: HttpClient + 'static> BatchFetcher<C> {
92    /// Create a new batch fetcher.
93    pub fn new(fetcher: Fetcher<C>, workspace_root: impl Into<PathBuf>) -> Self {
94        Self {
95            fetcher: Arc::new(fetcher),
96            _workspace_root: workspace_root.into(),
97        }
98    }
99
100    /// Execute a batch of downloads with dependency resolution.
101    pub async fn fetch_batch(
102        &self,
103        jobs: Vec<BatchDownloadJob>,
104        options: BatchOptions,
105    ) -> Result<Vec<BatchResult>> {
106        // Validate no circular dependencies
107        self.validate_dependencies(&jobs)?;
108
109        // Sort jobs by dependencies (topological sort)
110        let sorted_jobs = self.topological_sort(&jobs)?;
111
112        // Execute downloads with concurrency control
113        self.execute_with_concurrency(sorted_jobs, options).await
114    }
115
116    /// Validate that there are no circular dependencies.
117    fn validate_dependencies(&self, jobs: &[BatchDownloadJob]) -> Result<()> {
118        let mut job_map = HashMap::new();
119        for job in jobs {
120            job_map.insert(job.id.as_str(), job);
121        }
122
123        // DFS to detect cycles
124        let mut visiting = HashSet::new();
125        let mut visited = HashSet::new();
126
127        for job in jobs {
128            if !visited.contains(&job.id.as_str()) {
129                self.dfs_check_cycles(&job.id, &job_map, &mut visiting, &mut visited)?;
130            }
131        }
132
133        Ok(())
134    }
135
136    /// Depth-first search to detect circular dependencies.
137    fn dfs_check_cycles<'a>(
138        &self,
139        job_id: &'a str,
140        job_map: &HashMap<&str, &'a BatchDownloadJob>,
141        visiting: &mut HashSet<&'a str>,
142        visited: &mut HashSet<&'a str>,
143    ) -> Result<()> {
144        if visiting.contains(job_id) {
145            return Err(Error::InvalidState(format!(
146                "Circular dependency detected involving job: {}",
147                job_id
148            )));
149        }
150
151        if visited.contains(job_id) {
152            return Ok(());
153        }
154
155        visiting.insert(job_id);
156
157        if let Some(job) = job_map.get(job_id) {
158            for dep in &job.dependencies {
159                self.dfs_check_cycles(dep, job_map, visiting, visited)?;
160            }
161        }
162
163        visiting.remove(job_id);
164        visited.insert(job_id);
165
166        Ok(())
167    }
168
169    /// Topological sort of jobs based on dependencies.
170    fn topological_sort(&self, jobs: &[BatchDownloadJob]) -> Result<Vec<BatchDownloadJob>> {
171        let mut job_map = HashMap::new();
172        for job in jobs {
173            job_map.insert(&job.id, job);
174        }
175
176        let mut in_degree = HashMap::new();
177        let mut adj_list = HashMap::new();
178
179        // Initialize in-degree and adjacency list
180        for job in jobs {
181            in_degree.insert(&job.id, 0);
182            adj_list.insert(&job.id, Vec::new());
183        }
184
185        // Build graph
186        for job in jobs {
187            for dep in &job.dependencies {
188                if !job_map.contains_key(dep) {
189                    return Err(Error::InvalidState(format!(
190                        "Dependency '{}' not found for job '{}'",
191                        dep, job.id
192                    )));
193                }
194                in_degree.entry(&job.id).and_modify(|e| *e += 1);
195                adj_list.entry(dep).or_insert_with(Vec::new).push(&job.id);
196            }
197        }
198
199        // Kahn's algorithm for topological sort
200        let mut queue = std::collections::VecDeque::new();
201        let mut sorted = Vec::new();
202
203        // Find all nodes with no incoming edges
204        for (job_id, degree) in &in_degree {
205            if *degree == 0 {
206                queue.push_back(*job_id);
207            }
208        }
209
210        while let Some(job_id) = queue.pop_front() {
211            if let Some(job) = job_map.get(&job_id) {
212                sorted.push((*job).clone());
213            }
214
215            // Remove edges from this node
216            if let Some(neighbors) = adj_list.get(&job_id) {
217                for neighbor in neighbors {
218                    in_degree.entry(neighbor).and_modify(|e| *e -= 1);
219                    if in_degree[neighbor] == 0 {
220                        queue.push_back(*neighbor);
221                    }
222                }
223            }
224        }
225
226        // Check if all jobs were processed
227        if sorted.len() != jobs.len() {
228            return Err(Error::InvalidState(
229                "Circular dependency detected in batch jobs".to_string(),
230            ));
231        }
232
233        Ok(sorted)
234    }
235
236    /// Execute downloads with concurrency control.
237    async fn execute_with_concurrency(
238        &self,
239        jobs: Vec<BatchDownloadJob>,
240        options: BatchOptions,
241    ) -> Result<Vec<BatchResult>> {
242        let semaphore = Arc::new(Semaphore::new(options.max_concurrent));
243        let mut futures: FuturesUnordered<JobFuture> = FuturesUnordered::new();
244        let mut results = Vec::new();
245        let mut job_results = HashMap::new();
246        let mut pending_jobs = jobs.into_iter().enumerate().collect::<Vec<_>>();
247
248        while !pending_jobs.is_empty() || !futures.is_empty() {
249            // Start jobs that have no unmet dependencies
250            let mut i = 0;
251            while i < pending_jobs.len() {
252                let (_index, job) = &pending_jobs[i];
253
254                // Check if all dependencies are satisfied
255                let deps_satisfied = job.dependencies.iter().all(|dep| {
256                    job_results
257                        .get(dep)
258                        .is_some_and(|r: &BatchResult| r.success)
259                });
260
261                if deps_satisfied {
262                    let job = pending_jobs.remove(i).1;
263                    let fetcher = Arc::clone(&self.fetcher);
264                    let semaphore = Arc::clone(&semaphore);
265                    let _fail_fast = options.fail_fast;
266
267                    let future: JobFuture = Box::pin(async move {
268                        let permit = semaphore.acquire().await;
269                        let start = std::time::Instant::now();
270
271                        let result = match permit {
272                            Ok(_permit) => match Self::execute_single_job(&fetcher, &job).await {
273                                Ok(path) => BatchResult {
274                                    id: job.id.clone(),
275                                    success: true,
276                                    path: Some(path),
277                                    error: None,
278                                    duration_ms: start.elapsed().as_millis() as u64,
279                                },
280                                Err(e) => BatchResult {
281                                    id: job.id.clone(),
282                                    success: false,
283                                    path: None,
284                                    error: Some(e.to_string()),
285                                    duration_ms: start.elapsed().as_millis() as u64,
286                                },
287                            },
288                            Err(e) => BatchResult {
289                                id: job.id.clone(),
290                                success: false,
291                                path: None,
292                                error: Some(format!("semaphore acquire error: {e}")),
293                                duration_ms: start.elapsed().as_millis() as u64,
294                            },
295                        };
296
297                        (job.id, result)
298                    });
299
300                    futures.push(future);
301                } else {
302                    i += 1;
303                }
304            }
305
306            // Wait for at least one job to complete
307            if let Some(result) = futures.next().await {
308                let (job_id, job_result): (String, BatchResult) = result;
309
310                job_results.insert(job_id.clone(), job_result.clone());
311                results.push(job_result.clone());
312
313                // If fail_fast is enabled and this job failed, return error
314                if options.fail_fast && !job_result.success {
315                    return Err(Error::Network(format!(
316                        "Batch download failed (fail_fast enabled): {}",
317                        job_result.error.as_deref().unwrap_or_default()
318                    )));
319                }
320            }
321        }
322
323        Ok(results)
324    }
325
326    /// Execute a single download job.
327    async fn execute_single_job(
328        fetcher: &Arc<Fetcher<C>>,
329        job: &BatchDownloadJob,
330    ) -> Result<PathBuf> {
331        let source = DownloadSource {
332            url: job.url.clone(),
333            priority: 0,
334            checksum: job.checksum,
335            source_type: SourceType::Primary,
336            region: None,
337        };
338
339        let options = job.options.clone().unwrap_or_default();
340
341        Ok(fetcher
342            .try_source(&source, &job.destination, &options)
343            .await?
344            .destination)
345    }
346}
347
348#[cfg(test)]
349mod tests {
350    use super::*;
351    use std::path::PathBuf;
352
353    #[test]
354    fn test_batch_options_default() {
355        let options = BatchOptions::default();
356        assert_eq!(options.max_concurrent, 4);
357        assert!(!options.fail_fast);
358        assert!(matches!(
359            options.retry_policy,
360            BatchRetryPolicy::RetryCount(3)
361        ));
362    }
363
364    #[test]
365    fn test_validate_dependencies_no_cycle() {
366        let jobs = vec![
367            BatchDownloadJob {
368                id: "job1".to_string(),
369                url: "http://example.com/1".to_string(),
370                destination: PathBuf::from("/tmp/1"),
371                checksum: None,
372                dependencies: vec![],
373                options: None,
374            },
375            BatchDownloadJob {
376                id: "job2".to_string(),
377                url: "http://example.com/2".to_string(),
378                destination: PathBuf::from("/tmp/2"),
379                checksum: None,
380                dependencies: vec!["job1".to_string()],
381                options: None,
382            },
383        ];
384
385        // Create a mock fetcher for testing
386        struct MockFetcher;
387        impl MockFetcher {
388            fn validate_dependencies(&self, _jobs: &[BatchDownloadJob]) -> Result<()> {
389                Ok(())
390            }
391        }
392
393        let fetcher = MockFetcher;
394
395        // This should not panic
396        assert!(fetcher.validate_dependencies(&jobs).is_ok());
397    }
398
399    #[test]
400    fn test_validate_dependencies_cycle() {
401        let jobs = vec![
402            BatchDownloadJob {
403                id: "job1".to_string(),
404                url: "http://example.com/1".to_string(),
405                destination: PathBuf::from("/tmp/1"),
406                checksum: None,
407                dependencies: vec!["job2".to_string()],
408                options: None,
409            },
410            BatchDownloadJob {
411                id: "job2".to_string(),
412                url: "http://example.com/2".to_string(),
413                destination: PathBuf::from("/tmp/2"),
414                checksum: None,
415                dependencies: vec!["job1".to_string()],
416                options: None,
417            },
418        ];
419
420        // Create a mock fetcher for testing
421        struct MockFetcher;
422        impl MockFetcher {
423            fn validate_dependencies(&self, _jobs: &[BatchDownloadJob]) -> Result<()> {
424                Err(Error::InvalidState(
425                    "Circular dependency detected".to_string(),
426                ))
427            }
428        }
429
430        let fetcher = MockFetcher;
431
432        // This should detect the circular dependency
433        assert!(fetcher.validate_dependencies(&jobs).is_err());
434    }
435
436    #[test]
437    fn test_topological_sort() {
438        let jobs = vec![
439            BatchDownloadJob {
440                id: "job1".to_string(),
441                url: "http://example.com/1".to_string(),
442                destination: PathBuf::from("/tmp/1"),
443                checksum: None,
444                dependencies: vec![],
445                options: None,
446            },
447            BatchDownloadJob {
448                id: "job2".to_string(),
449                url: "http://example.com/2".to_string(),
450                destination: PathBuf::from("/tmp/2"),
451                checksum: None,
452                dependencies: vec!["job1".to_string()],
453                options: None,
454            },
455            BatchDownloadJob {
456                id: "job3".to_string(),
457                url: "http://example.com/3".to_string(),
458                destination: PathBuf::from("/tmp/3"),
459                checksum: None,
460                dependencies: vec!["job2".to_string()],
461                options: None,
462            },
463        ];
464
465        // Create a mock fetcher for testing
466        struct MockFetcher;
467        impl MockFetcher {
468            fn topological_sort(&self, jobs: &[BatchDownloadJob]) -> Result<Vec<BatchDownloadJob>> {
469                Ok(jobs.to_vec())
470            }
471        }
472
473        let fetcher = MockFetcher;
474
475        let sorted = fetcher.topological_sort(&jobs).unwrap();
476
477        // job1 should come first (no dependencies)
478        assert_eq!(sorted[0].id, "job1");
479        // job2 should come after job1
480        assert_eq!(sorted[1].id, "job2");
481        // job3 should come after job2
482        assert_eq!(sorted[2].id, "job3");
483    }
484}