1use crate::cache::ContentCache;
2use crate::errors::Result;
3use crate::registry::RegistryClient;
4use crate::resolver::DependencyGraph;
5use futures::future::join_all;
6use serde::{Deserialize, Serialize};
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9use tokio::sync::{mpsc, Semaphore};
10
/// Progress event sent by [`Fetcher::fetch_all`] while it processes a dependency graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchProgress {
    /// Name of the package this event is about.
    pub package: String,
    /// Version specifier of the package, as recorded in the dependency graph.
    pub version: String,
    /// `true` when the package was already available and no download was performed.
    pub cached: bool,
    /// Total number of packages in the dependency graph being fetched.
    pub total: usize,
    /// Number of packages finished (cached or fetched) when this event was produced.
    pub completed: usize,
}
19
20#[derive(Debug, Default, Serialize, Deserialize)]
21pub struct FetchResult {
22 pub fetched: usize,
23 pub cached: usize,
24 pub failed: Vec<String>,
25}
26
27pub struct Fetcher {
28 registry: Arc<RegistryClient>,
29 cache: Arc<ContentCache>,
30 concurrency: usize,
31}
32
33impl Fetcher {
34 pub fn new(
35 registry: Arc<RegistryClient>,
36 cache: Arc<ContentCache>,
37 concurrency: Option<usize>,
38 ) -> Self {
39 Self {
40 registry,
41 cache,
42 concurrency: concurrency.unwrap_or(64),
43 }
44 }
45
46 pub async fn fetch_all(
47 &self,
48 graph: &DependencyGraph,
49 progress_tx: Option<mpsc::Sender<FetchProgress>>,
50 ) -> Result<FetchResult> {
51 let total = graph.packages.len();
52 let semaphore = Arc::new(Semaphore::new(self.concurrency));
53 let mut tasks = Vec::new();
54
55 let cached_count = Arc::new(AtomicUsize::new(0));
56 let fetched_count = Arc::new(AtomicUsize::new(0));
57 let failed_list = Arc::new(tokio::sync::Mutex::new(Vec::new()));
58
59 for package in &graph.packages {
60 let pkg_name = package.name.clone();
61 let pkg_version = package.version.clone();
62 let pkg_integrity = package.integrity.clone();
63 let pkg_tarball_url = package.tarball_url.clone();
64
65 if pkg_version.starts_with("link:")
67 || pkg_version.starts_with("file:")
68 || pkg_version.starts_with("workspace:")
69 {
70 cached_count.fetch_add(1, Ordering::Relaxed);
71 continue;
72 }
73
74 if pkg_version.starts_with("git+") || pkg_version.starts_with("github:") {
76 if !pkg_integrity.is_empty() && self.cache.has(&pkg_integrity) {
77 cached_count.fetch_add(1, Ordering::Relaxed);
78 continue;
79 }
80 if !pkg_tarball_url.is_empty() {
82 let cache = Arc::clone(&self.cache);
83 let registry = Arc::clone(&self.registry);
84 let semaphore = Arc::clone(&semaphore);
85 let fetched_count = Arc::clone(&fetched_count);
86 let failed_list = Arc::clone(&failed_list);
87 let task = tokio::spawn(async move {
88 let _permit = semaphore.acquire().await.unwrap();
89 match registry.fetch_tarball(&pkg_tarball_url).await {
90 Ok(tarball_bytes) => {
91 let git_integrity = format!(
92 "git-{}-{}",
93 pkg_name,
94 pkg_version.replace(['/', ':', '#', '+'], "-")
95 );
96 let cache_clone = Arc::clone(&cache);
97 let bytes_vec = tarball_bytes.to_vec();
98 match tokio::task::spawn_blocking(move || {
99 cache_clone.store(&git_integrity, &bytes_vec)
100 })
101 .await
102 {
103 Ok(Ok(_)) => {
104 fetched_count.fetch_add(1, Ordering::Relaxed);
105 }
106 _ => {
107 failed_list
108 .lock()
109 .await
110 .push(format!("{}@{}", pkg_name, pkg_version));
111 }
112 }
113 }
114 Err(_) => {
115 failed_list
116 .lock()
117 .await
118 .push(format!("{}@{}", pkg_name, pkg_version));
119 }
120 }
121 });
122 tasks.push(task);
123 }
124 continue;
125 }
126
127 let cache = Arc::clone(&self.cache);
128 let registry = Arc::clone(&self.registry);
129 let semaphore = Arc::clone(&semaphore);
130 let progress_tx = progress_tx.clone();
131 let cached_count = Arc::clone(&cached_count);
132 let fetched_count = Arc::clone(&fetched_count);
133 let failed_list = Arc::clone(&failed_list);
134
135 if cache.has(&pkg_integrity) {
137 let new_cached = cached_count.fetch_add(1, Ordering::Relaxed) + 1;
138 let completed = new_cached + fetched_count.load(Ordering::Relaxed);
139
140 if let Some(tx) = progress_tx {
141 let progress = FetchProgress {
142 package: pkg_name,
143 version: pkg_version,
144 cached: true,
145 total,
146 completed,
147 };
148 let _ = tx.send(progress).await;
149 }
150 continue;
151 }
152
153 let task = tokio::spawn(async move {
155 let _permit = semaphore.acquire().await.unwrap();
156
157 match registry.fetch_tarball(&pkg_tarball_url).await {
158 Ok(tarball_bytes) => {
159 let cache_clone = Arc::clone(&cache);
160 let integrity_clone = pkg_integrity.clone();
161 let bytes_vec = tarball_bytes.to_vec();
162 let store_result = tokio::task::spawn_blocking(move || {
163 cache_clone.store(&integrity_clone, &bytes_vec)
164 })
165 .await;
166 match store_result {
167 Ok(Ok(_)) => {
168 let new_fetched = fetched_count.fetch_add(1, Ordering::Relaxed) + 1;
169 let completed = new_fetched + cached_count.load(Ordering::Relaxed);
170
171 if let Some(tx) = progress_tx {
172 let progress = FetchProgress {
173 package: pkg_name,
174 version: pkg_version,
175 cached: false,
176 total,
177 completed,
178 };
179 let _ = tx.send(progress).await;
180 }
181 }
182 _ => {
183 failed_list
184 .lock()
185 .await
186 .push(format!("{}@{}", pkg_name, pkg_version));
187 }
188 }
189 }
190 Err(_) => {
191 failed_list
192 .lock()
193 .await
194 .push(format!("{}@{}", pkg_name, pkg_version));
195 }
196 }
197 });
198
199 tasks.push(task);
200 }
201
202 join_all(tasks).await;
204
205 let cached = cached_count.load(Ordering::Relaxed);
206 let fetched = fetched_count.load(Ordering::Relaxed);
207 let failed = failed_list.lock().await.clone();
208
209 Ok(FetchResult {
210 fetched,
211 cached,
212 failed,
213 })
214 }
215}