Skip to main content

dnx_core/
fetcher.rs

1use crate::cache::ContentCache;
2use crate::errors::Result;
3use crate::registry::RegistryClient;
4use crate::resolver::DependencyGraph;
5use futures::future::join_all;
6use serde::{Deserialize, Serialize};
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9use tokio::sync::{mpsc, Semaphore};
10
/// A progress event emitted by [`Fetcher::fetch_all`] for a single package.
#[derive(Clone, Debug)]
pub struct FetchProgress {
    /// Package name.
    pub package: String,
    /// Version string exactly as recorded in the dependency graph.
    pub version: String,
    /// `true` when the package was already in the cache, `false` when it was just downloaded.
    pub cached: bool,
    /// Number of packages in the whole dependency graph.
    pub total: usize,
    /// Packages counted as done (cached + fetched) at the moment this event was produced.
    pub completed: usize,
}

20#[derive(Debug, Default, Serialize, Deserialize)]
21pub struct FetchResult {
22    pub fetched: usize,
23    pub cached: usize,
24    pub failed: Vec<String>,
25}
26
27pub struct Fetcher {
28    registry: Arc<RegistryClient>,
29    cache: Arc<ContentCache>,
30    concurrency: usize,
31}
32
33impl Fetcher {
34    pub fn new(
35        registry: Arc<RegistryClient>,
36        cache: Arc<ContentCache>,
37        concurrency: Option<usize>,
38    ) -> Self {
39        Self {
40            registry,
41            cache,
42            concurrency: concurrency.unwrap_or(64),
43        }
44    }
45
46    pub async fn fetch_all(
47        &self,
48        graph: &DependencyGraph,
49        progress_tx: Option<mpsc::Sender<FetchProgress>>,
50    ) -> Result<FetchResult> {
51        let total = graph.packages.len();
52        let semaphore = Arc::new(Semaphore::new(self.concurrency));
53        let mut tasks = Vec::new();
54
55        let cached_count = Arc::new(AtomicUsize::new(0));
56        let fetched_count = Arc::new(AtomicUsize::new(0));
57        let failed_list = Arc::new(tokio::sync::Mutex::new(Vec::new()));
58
59        for package in &graph.packages {
60            let pkg_name = package.name.clone();
61            let pkg_version = package.version.clone();
62            let pkg_integrity = package.integrity.clone();
63            let pkg_tarball_url = package.tarball_url.clone();
64
65            // Skip packages that don't come from the registry (git, file, link, workspace)
66            if pkg_version.starts_with("link:")
67                || pkg_version.starts_with("file:")
68                || pkg_version.starts_with("workspace:")
69            {
70                cached_count.fetch_add(1, Ordering::Relaxed);
71                continue;
72            }
73
74            // Handle git deps: clone to temp, store in cache
75            if pkg_version.starts_with("git+") || pkg_version.starts_with("github:") {
76                if !pkg_integrity.is_empty() && self.cache.has(&pkg_integrity) {
77                    cached_count.fetch_add(1, Ordering::Relaxed);
78                    continue;
79                }
80                // For git deps with a tarball URL (github codeload), fetch as tarball
81                if !pkg_tarball_url.is_empty() {
82                    let cache = Arc::clone(&self.cache);
83                    let registry = Arc::clone(&self.registry);
84                    let semaphore = Arc::clone(&semaphore);
85                    let fetched_count = Arc::clone(&fetched_count);
86                    let failed_list = Arc::clone(&failed_list);
87                    let task = tokio::spawn(async move {
88                        let _permit = semaphore.acquire().await.unwrap();
89                        match registry.fetch_tarball(&pkg_tarball_url).await {
90                            Ok(tarball_bytes) => {
91                                let git_integrity = format!(
92                                    "git-{}-{}",
93                                    pkg_name,
94                                    pkg_version.replace(['/', ':', '#', '+'], "-")
95                                );
96                                let cache_clone = Arc::clone(&cache);
97                                let bytes_vec = tarball_bytes.to_vec();
98                                match tokio::task::spawn_blocking(move || {
99                                    cache_clone.store(&git_integrity, &bytes_vec)
100                                })
101                                .await
102                                {
103                                    Ok(Ok(_)) => {
104                                        fetched_count.fetch_add(1, Ordering::Relaxed);
105                                    }
106                                    _ => {
107                                        failed_list
108                                            .lock()
109                                            .await
110                                            .push(format!("{}@{}", pkg_name, pkg_version));
111                                    }
112                                }
113                            }
114                            Err(_) => {
115                                failed_list
116                                    .lock()
117                                    .await
118                                    .push(format!("{}@{}", pkg_name, pkg_version));
119                            }
120                        }
121                    });
122                    tasks.push(task);
123                }
124                continue;
125            }
126
127            let cache = Arc::clone(&self.cache);
128            let registry = Arc::clone(&self.registry);
129            let semaphore = Arc::clone(&semaphore);
130            let progress_tx = progress_tx.clone();
131            let cached_count = Arc::clone(&cached_count);
132            let fetched_count = Arc::clone(&fetched_count);
133            let failed_list = Arc::clone(&failed_list);
134
135            // Check if already cached
136            if cache.has(&pkg_integrity) {
137                let new_cached = cached_count.fetch_add(1, Ordering::Relaxed) + 1;
138                let completed = new_cached + fetched_count.load(Ordering::Relaxed);
139
140                if let Some(tx) = progress_tx {
141                    let progress = FetchProgress {
142                        package: pkg_name,
143                        version: pkg_version,
144                        cached: true,
145                        total,
146                        completed,
147                    };
148                    let _ = tx.send(progress).await;
149                }
150                continue;
151            }
152
153            // Spawn task to fetch package
154            let task = tokio::spawn(async move {
155                let _permit = semaphore.acquire().await.unwrap();
156
157                match registry.fetch_tarball(&pkg_tarball_url).await {
158                    Ok(tarball_bytes) => {
159                        let cache_clone = Arc::clone(&cache);
160                        let integrity_clone = pkg_integrity.clone();
161                        let bytes_vec = tarball_bytes.to_vec();
162                        let store_result = tokio::task::spawn_blocking(move || {
163                            cache_clone.store(&integrity_clone, &bytes_vec)
164                        })
165                        .await;
166                        match store_result {
167                            Ok(Ok(_)) => {
168                                let new_fetched = fetched_count.fetch_add(1, Ordering::Relaxed) + 1;
169                                let completed = new_fetched + cached_count.load(Ordering::Relaxed);
170
171                                if let Some(tx) = progress_tx {
172                                    let progress = FetchProgress {
173                                        package: pkg_name,
174                                        version: pkg_version,
175                                        cached: false,
176                                        total,
177                                        completed,
178                                    };
179                                    let _ = tx.send(progress).await;
180                                }
181                            }
182                            _ => {
183                                failed_list
184                                    .lock()
185                                    .await
186                                    .push(format!("{}@{}", pkg_name, pkg_version));
187                            }
188                        }
189                    }
190                    Err(_) => {
191                        failed_list
192                            .lock()
193                            .await
194                            .push(format!("{}@{}", pkg_name, pkg_version));
195                    }
196                }
197            });
198
199            tasks.push(task);
200        }
201
202        // Wait for all tasks to complete
203        join_all(tasks).await;
204
205        let cached = cached_count.load(Ordering::Relaxed);
206        let fetched = fetched_count.load(Ordering::Relaxed);
207        let failed = failed_list.lock().await.clone();
208
209        Ok(FetchResult {
210            fetched,
211            cached,
212            failed,
213        })
214    }
215}