1use check_updates_core::{PackageInfo, Version};
2use anyhow::{anyhow, Context, Result};
3use serde::Deserialize;
4use std::collections::HashMap;
5use std::str::FromStr;
6use std::sync::Arc;
7use tokio::sync::Semaphore;
8
9pub struct PyPiClient {
11 client: reqwest::Client,
12 base_url: String,
13 include_prerelease: bool,
14}
15
16#[derive(Debug, Deserialize)]
18struct PyPiResponse {
19 info: PyPiInfo,
20 releases: HashMap<String, Vec<PyPiRelease>>,
21}
22
23#[derive(Debug, Deserialize)]
24struct PyPiInfo {
25 name: String,
26}
27
28#[derive(Debug, Deserialize)]
29struct PyPiRelease {
30 #[allow(dead_code)]
31 yanked: Option<bool>,
32}
33
34impl PyPiClient {
35 pub fn new(include_prerelease: bool) -> Self {
36 Self {
37 client: reqwest::Client::builder()
38 .user_agent("python-check-updates/0.1.0")
39 .timeout(std::time::Duration::from_secs(30))
40 .build()
41 .unwrap_or_else(|_| reqwest::Client::new()),
42 base_url: "https://pypi.org/pypi".to_string(),
43 include_prerelease,
44 }
45 }
46
47 pub fn with_index_url(mut self, url: &str) -> Self {
48 self.base_url = url.trim_end_matches('/').to_string();
50 self
51 }
52
53 pub async fn get_package(&self, name: &str) -> Result<PackageInfo> {
55 let url = format!("{}/{}/json", self.base_url, name);
56
57 let response = self
58 .client
59 .get(&url)
60 .send()
61 .await
62 .context(format!("Failed to fetch package '{name}'"))?;
63
64 if !response.status().is_success() {
65 if response.status() == 404 {
66 return Err(anyhow!("Package '{name}' not found on PyPI"));
67 }
68 return Err(anyhow!(
69 "PyPI API request failed with status: {}",
70 response.status()
71 ));
72 }
73
74 let pypi_data: PyPiResponse = response
75 .json()
76 .await
77 .context(format!("Failed to parse JSON response for '{name}'"))?;
78
79 let mut all_versions: Vec<Version> = Vec::new();
81 for (version_str, releases) in &pypi_data.releases {
82 if releases.is_empty() {
84 continue;
85 }
86
87 let all_yanked = releases.iter().all(|r| r.yanked.unwrap_or(false));
89 if all_yanked {
90 continue;
91 }
92
93 if let Ok(version) = Version::from_str(version_str) {
95 all_versions.push(version);
96 }
97 }
98
99 if all_versions.is_empty() {
100 return Err(anyhow!("No valid versions found for package '{name}'"));
101 }
102
103 all_versions.sort();
105
106 let filtered_versions: Vec<Version> = if self.include_prerelease {
108 all_versions.clone()
109 } else {
110 all_versions
111 .iter()
112 .filter(|v| !v.is_prerelease())
113 .cloned()
114 .collect()
115 };
116
117 if filtered_versions.is_empty() {
118 return Err(anyhow!(
119 "No stable versions found for package '{name}' (use --pre-release to include pre-releases)"
120 ));
121 }
122
123 let latest = if self.include_prerelease {
125 all_versions
126 .last()
127 .ok_or_else(|| anyhow!("No versions found"))?
128 .clone()
129 } else {
130 filtered_versions
131 .last()
132 .ok_or_else(|| anyhow!("No stable versions found"))?
133 .clone()
134 };
135
136 let latest_stable = all_versions
138 .iter()
139 .rfind(|v| !v.is_prerelease())
140 .cloned();
141
142 Ok(PackageInfo {
143 name: pypi_data.info.name,
144 versions: filtered_versions,
145 latest,
146 latest_stable,
147 })
148 }
149
150 pub async fn get_packages(
152 &self,
153 names: &[String],
154 progress_callback: impl Fn(usize, usize) + Send + Sync + 'static,
155 ) -> Result<GetPackagesResult> {
156 let total = names.len();
157 let progress_callback = Arc::new(progress_callback);
158
159 let semaphore = Arc::new(Semaphore::new(10));
161
162 let mut tasks = Vec::new();
163
164 for (index, name) in names.iter().enumerate() {
165 let client = self.clone();
166 let name = name.clone();
167 let callback = Arc::clone(&progress_callback);
168 let semaphore = Arc::clone(&semaphore);
169
170 let task = tokio::spawn(async move {
171 let _permit = semaphore.acquire().await.expect("semaphore closed");
173
174 let result = client.get_package(&name).await;
175
176 callback(index + 1, total);
178
179 (name, result)
180 });
181
182 tasks.push(task);
183 }
184
185 let mut packages = HashMap::new();
187 let mut errors = Vec::new();
188
189 for task in tasks {
190 match task.await {
191 Ok((name, Ok(package_info))) => {
192 packages.insert(name, package_info);
193 }
194 Ok((name, Err(e))) => {
195 let error_msg = e.to_string();
197 errors.push((name, error_msg));
198 }
199 Err(e) => {
200 errors.push(("unknown".to_string(), format!("Task failed: {e}")));
201 }
202 }
203 }
204
205 let formatted_errors: Vec<String> = errors
207 .into_iter()
208 .map(|(name, msg)| format!("{name}: {msg}"))
209 .collect();
210
211 if !packages.is_empty() || formatted_errors.is_empty() {
213 Ok(GetPackagesResult {
214 packages,
215 errors: formatted_errors,
216 })
217 } else {
218 Err(anyhow!(
220 "Failed to fetch all packages:\n{}",
221 formatted_errors.join("\n")
222 ))
223 }
224 }
225}
226
227#[derive(Debug, Clone)]
229pub struct GetPackagesResult {
230 pub packages: HashMap<String, PackageInfo>,
231 pub errors: Vec<String>,
232}
233
234impl Clone for PyPiClient {
236 fn clone(&self) -> Self {
237 Self {
238 client: self.client.clone(),
239 base_url: self.base_url.clone(),
240 include_prerelease: self.include_prerelease,
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // NOTE: apart from `test_custom_index_url` and `test_get_packages_empty_input`,
    // these tests query the live PyPI service and need network access.

    #[tokio::test]
    async fn test_get_package_requests() {
        let client = PyPiClient::new(false);
        let result = client.get_package("requests").await;

        assert!(result.is_ok(), "Failed to fetch requests package: {:?}", result.err());

        let package_info = result.unwrap();
        assert_eq!(package_info.name.to_lowercase(), "requests");
        assert!(!package_info.versions.is_empty());
        assert!(package_info.latest_stable.is_some());
    }

    #[tokio::test]
    async fn test_get_package_not_found() {
        let client = PyPiClient::new(false);
        let result = client
            .get_package("this-package-definitely-does-not-exist-12345")
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_get_packages_concurrent() {
        let client = PyPiClient::new(false);
        let packages = vec!["requests".to_string(), "flask".to_string()];

        // Record every (current, total) pair reported so we can check exact progress.
        let reported = Arc::new(std::sync::Mutex::new(Vec::new()));
        let reported_clone = Arc::clone(&reported);

        let result = client
            .get_packages(&packages, move |current, total| {
                reported_clone.lock().unwrap().push((current, total));
            })
            .await;

        assert!(result.is_ok(), "Failed to fetch packages: {:?}", result.err());

        let results = result.unwrap();
        assert!(!results.packages.is_empty());

        let mut seen = reported.lock().unwrap().clone();
        seen.sort_unstable();
        assert_eq!(
            seen,
            vec![(1, 2), (2, 2)],
            "each lookup should report progress exactly once"
        );
    }

    #[tokio::test]
    async fn test_get_packages_empty_input() {
        let client = PyPiClient::new(false);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = Arc::clone(&calls);

        let result = client
            .get_packages(&[], move |_, _| {
                calls_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await
            .expect("empty input should not be an error");

        assert!(result.packages.is_empty());
        assert!(result.errors.is_empty());
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_custom_index_url() {
        let client = PyPiClient::new(false).with_index_url("https://pypi.org/pypi/");
        assert_eq!(client.base_url, "https://pypi.org/pypi");

        // Every trailing slash is removed, not just one.
        let client = PyPiClient::new(false).with_index_url("https://mirror.example/pypi//");
        assert_eq!(client.base_url, "https://mirror.example/pypi");
    }

    #[tokio::test]
    async fn test_prerelease_filtering() {
        let client_stable = PyPiClient::new(false);
        let client_pre = PyPiClient::new(true);

        let result_stable = client_stable.get_package("django").await;
        let result_pre = client_pre.get_package("django").await;

        // Deliberately lenient: skip the comparison if either network request failed.
        if let (Ok(stable), Ok(pre)) = (result_stable, result_pre) {
            assert!(pre.versions.len() >= stable.versions.len());
            assert!(stable.versions.iter().all(|v| !v.is_prerelease()));
            assert!(stable.latest_stable.as_ref() == Some(&stable.latest));
        }
    }
}