Skip to main content

pcu/
pypi.rs

1use check_updates_core::{PackageInfo, Version};
2use anyhow::{anyhow, Context, Result};
3use serde::Deserialize;
4use std::collections::HashMap;
5use std::str::FromStr;
6use std::sync::Arc;
7use tokio::sync::Semaphore;
8
9/// Client for querying PyPI API
10pub struct PyPiClient {
11    client: reqwest::Client,
12    base_url: String,
13    include_prerelease: bool,
14}
15
16/// PyPI JSON API response structure
17#[derive(Debug, Deserialize)]
18struct PyPiResponse {
19    info: PyPiInfo,
20    releases: HashMap<String, Vec<PyPiRelease>>,
21}
22
23#[derive(Debug, Deserialize)]
24struct PyPiInfo {
25    name: String,
26}
27
28#[derive(Debug, Deserialize)]
29struct PyPiRelease {
30    #[allow(dead_code)]
31    yanked: Option<bool>,
32}
33
34impl PyPiClient {
35    pub fn new(include_prerelease: bool) -> Self {
36        Self {
37            client: reqwest::Client::builder()
38                .user_agent("python-check-updates/0.1.0")
39                .timeout(std::time::Duration::from_secs(30))
40                .build()
41                .unwrap_or_else(|_| reqwest::Client::new()),
42            base_url: "https://pypi.org/pypi".to_string(),
43            include_prerelease,
44        }
45    }
46
47    pub fn with_index_url(mut self, url: &str) -> Self {
48        // Remove trailing slash if present
49        self.base_url = url.trim_end_matches('/').to_string();
50        self
51    }
52
53    /// Fetch package info from PyPI
54    pub async fn get_package(&self, name: &str) -> Result<PackageInfo> {
55        let url = format!("{}/{}/json", self.base_url, name);
56
57        let response = self
58            .client
59            .get(&url)
60            .send()
61            .await
62            .context(format!("Failed to fetch package '{name}'"))?;
63
64        if !response.status().is_success() {
65            if response.status() == 404 {
66                return Err(anyhow!("Package '{name}' not found on PyPI"));
67            }
68            return Err(anyhow!(
69                "PyPI API request failed with status: {}",
70                response.status()
71            ));
72        }
73
74        let pypi_data: PyPiResponse = response
75            .json()
76            .await
77            .context(format!("Failed to parse JSON response for '{name}'"))?;
78
79        // Parse all versions from releases
80        let mut all_versions: Vec<Version> = Vec::new();
81        for (version_str, releases) in &pypi_data.releases {
82            // Skip yanked releases (empty release list or all yanked)
83            if releases.is_empty() {
84                continue;
85            }
86
87            // Check if all releases are yanked
88            let all_yanked = releases.iter().all(|r| r.yanked.unwrap_or(false));
89            if all_yanked {
90                continue;
91            }
92
93            // Try to parse the version
94            if let Ok(version) = Version::from_str(version_str) {
95                all_versions.push(version);
96            }
97        }
98
99        if all_versions.is_empty() {
100            return Err(anyhow!("No valid versions found for package '{name}'"));
101        }
102
103        // Sort versions in ascending order
104        all_versions.sort();
105
106        // Filter versions based on prerelease setting
107        let filtered_versions: Vec<Version> = if self.include_prerelease {
108            all_versions.clone()
109        } else {
110            all_versions
111                .iter()
112                .filter(|v| !v.is_prerelease())
113                .cloned()
114                .collect()
115        };
116
117        if filtered_versions.is_empty() {
118            return Err(anyhow!(
119                "No stable versions found for package '{name}' (use --pre-release to include pre-releases)"
120            ));
121        }
122
123        // Get latest version (with or without prerelease)
124        let latest = if self.include_prerelease {
125            all_versions
126                .last()
127                .ok_or_else(|| anyhow!("No versions found"))?
128                .clone()
129        } else {
130            filtered_versions
131                .last()
132                .ok_or_else(|| anyhow!("No stable versions found"))?
133                .clone()
134        };
135
136        // Get latest stable version (always filter out prereleases)
137        let latest_stable = all_versions
138            .iter()
139            .rfind(|v| !v.is_prerelease())
140            .cloned();
141
142        Ok(PackageInfo {
143            name: pypi_data.info.name,
144            versions: filtered_versions,
145            latest,
146            latest_stable,
147        })
148    }
149
150    /// Fetch multiple packages concurrently
151    pub async fn get_packages(
152        &self,
153        names: &[String],
154        progress_callback: impl Fn(usize, usize) + Send + Sync + 'static,
155    ) -> Result<GetPackagesResult> {
156        let total = names.len();
157        let progress_callback = Arc::new(progress_callback);
158
159        // Limit concurrent requests to avoid overwhelming the server
160        let semaphore = Arc::new(Semaphore::new(10));
161
162        let mut tasks = Vec::new();
163
164        for (index, name) in names.iter().enumerate() {
165            let client = self.clone();
166            let name = name.clone();
167            let callback = Arc::clone(&progress_callback);
168            let semaphore = Arc::clone(&semaphore);
169
170            let task = tokio::spawn(async move {
171                // Acquire semaphore permit
172                let _permit = semaphore.acquire().await.expect("semaphore closed");
173
174                let result = client.get_package(&name).await;
175
176                // Call progress callback
177                callback(index + 1, total);
178
179                (name, result)
180            });
181
182            tasks.push(task);
183        }
184
185        // Wait for all tasks to complete
186        let mut packages = HashMap::new();
187        let mut errors = Vec::new();
188
189        for task in tasks {
190            match task.await {
191                Ok((name, Ok(package_info))) => {
192                    packages.insert(name, package_info);
193                }
194                Ok((name, Err(e))) => {
195                    // Extract just the error message without "Failed to fetch" prefix
196                    let error_msg = e.to_string();
197                    errors.push((name, error_msg));
198                }
199                Err(e) => {
200                    errors.push(("unknown".to_string(), format!("Task failed: {e}")));
201                }
202            }
203        }
204
205        // Format errors as strings
206        let formatted_errors: Vec<String> = errors
207            .into_iter()
208            .map(|(name, msg)| format!("{name}: {msg}"))
209            .collect();
210
211        // If we have some results, return them even if some packages failed
212        if !packages.is_empty() || formatted_errors.is_empty() {
213            Ok(GetPackagesResult {
214                packages,
215                errors: formatted_errors,
216            })
217        } else {
218            // All packages failed
219            Err(anyhow!(
220                "Failed to fetch all packages:\n{}",
221                formatted_errors.join("\n")
222            ))
223        }
224    }
225}
226
227/// Result of fetching multiple packages
228#[derive(Debug, Clone)]
229pub struct GetPackagesResult {
230    pub packages: HashMap<String, PackageInfo>,
231    pub errors: Vec<String>,
232}
233
234// Implement Clone for PyPiClient to support concurrent usage
235impl Clone for PyPiClient {
236    fn clone(&self) -> Self {
237        Self {
238            client: self.client.clone(),
239            base_url: self.base_url.clone(),
240            include_prerelease: self.include_prerelease,
241        }
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // NOTE: apart from `test_custom_index_url`, these tests send real requests
    // to pypi.org.

    /// A well-known package resolves and has at least one stable version.
    #[tokio::test]
    async fn test_get_package_requests() {
        let pypi = PyPiClient::new(false);
        let result = pypi.get_package("requests").await;

        assert!(result.is_ok(), "Failed to fetch requests package: {:?}", result.err());

        let info = result.unwrap();
        assert_eq!(info.name.to_lowercase(), "requests");
        assert!(!info.versions.is_empty());
        assert!(info.latest_stable.is_some());
    }

    /// An unknown package yields the dedicated "not found" error.
    #[tokio::test]
    async fn test_get_package_not_found() {
        let pypi = PyPiClient::new(false);
        let result = pypi
            .get_package("this-package-definitely-does-not-exist-12345")
            .await;

        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("not found"));
    }

    /// A batch lookup returns packages and drives the progress callback.
    #[tokio::test]
    async fn test_get_packages_concurrent() {
        let pypi = PyPiClient::new(false);
        let wanted = vec!["requests".to_string(), "flask".to_string()];

        // The callback must be `Send + Sync + 'static`, so count invocations
        // through a shared atomic.
        let progress_calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&progress_calls);

        let result = pypi
            .get_packages(&wanted, move |_current, _total| {
                counter.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        assert!(result.is_ok(), "Failed to fetch packages: {:?}", result.err());

        let fetched = result.unwrap();
        assert!(!fetched.packages.is_empty());

        // Verify the progress callback was called.
        let calls = progress_calls.load(Ordering::SeqCst);
        assert!(calls > 0, "Progress callback should have been called");
    }

    /// A trailing slash on a custom index URL is stripped.
    #[tokio::test]
    async fn test_custom_index_url() {
        let pypi = PyPiClient::new(false).with_index_url("https://pypi.org/pypi/");

        assert_eq!(pypi.base_url, "https://pypi.org/pypi");
    }

    /// Including pre-releases never yields fewer versions than excluding them.
    #[tokio::test]
    async fn test_prerelease_filtering() {
        let stable_client = PyPiClient::new(false);
        let pre_client = PyPiClient::new(true);

        // Depends on the live state of the package; if either lookup fails the
        // comparison is skipped rather than failing the test.
        let stable_outcome = stable_client.get_package("django").await;
        let pre_outcome = pre_client.get_package("django").await;

        if let (Ok(stable), Ok(pre)) = (stable_outcome, pre_outcome) {
            // The pre-release client might have more versions.
            assert!(pre.versions.len() >= stable.versions.len());
        }
    }
}