python_check_updates/
pypi.rs

1use crate::version::Version;
2use anyhow::{anyhow, Context, Result};
3use serde::Deserialize;
4use std::collections::HashMap;
5use std::str::FromStr;
6use std::sync::Arc;
7use tokio::sync::Semaphore;
8
/// Client for querying the PyPI JSON API.
pub struct PyPiClient {
    /// HTTP client used to send every request.
    client: reqwest::Client,
    /// Index base URL; package metadata is requested from `{base_url}/{name}/json`.
    base_url: String,
    /// When `false`, pre-release versions are filtered out of query results.
    include_prerelease: bool,
}
15
/// Package information from PyPI
#[derive(Debug, Clone)]
pub struct PackageInfo {
    /// Package name as reported by the index in the response's `info.name` field.
    pub name: String,
    /// Parsed, non-yanked versions in ascending order. Pre-releases are only
    /// present when the client was created with `include_prerelease` enabled.
    pub versions: Vec<Version>,
    /// Highest entry of `versions`.
    pub latest: Version,
    /// Highest version that is not a pre-release, or `None` if there is none.
    pub latest_stable: Option<Version>,
}
24
/// PyPI JSON API response structure
///
/// Only the fields this module reads are declared.
#[derive(Debug, Deserialize)]
struct PyPiResponse {
    /// Package-level metadata.
    info: PyPiInfo,
    /// Release entries keyed by version string.
    releases: HashMap<String, Vec<PyPiRelease>>,
}
31
/// The `info` object of a PyPI JSON API response.
#[derive(Debug, Deserialize)]
struct PyPiInfo {
    /// Package name as reported by the index.
    name: String,
}
36
37#[derive(Debug, Deserialize)]
38struct PyPiRelease {
39    #[allow(dead_code)]
40    yanked: Option<bool>,
41}
42
43impl PyPiClient {
44    pub fn new(include_prerelease: bool) -> Self {
45        Self {
46            client: reqwest::Client::builder()
47                .user_agent("python-check-updates/0.1.0")
48                .timeout(std::time::Duration::from_secs(30))
49                .build()
50                .unwrap_or_else(|_| reqwest::Client::new()),
51            base_url: "https://pypi.org/pypi".to_string(),
52            include_prerelease,
53        }
54    }
55
56    pub fn with_index_url(mut self, url: &str) -> Self {
57        // Remove trailing slash if present
58        self.base_url = url.trim_end_matches('/').to_string();
59        self
60    }
61
62    /// Fetch package info from PyPI
63    pub async fn get_package(&self, name: &str) -> Result<PackageInfo> {
64        let url = format!("{}/{}/json", self.base_url, name);
65
66        let response = self
67            .client
68            .get(&url)
69            .send()
70            .await
71            .context(format!("Failed to fetch package '{}'", name))?;
72
73        if !response.status().is_success() {
74            if response.status() == 404 {
75                return Err(anyhow!("Package '{}' not found on PyPI", name));
76            }
77            return Err(anyhow!(
78                "PyPI API request failed with status: {}",
79                response.status()
80            ));
81        }
82
83        let pypi_data: PyPiResponse = response
84            .json()
85            .await
86            .context(format!("Failed to parse JSON response for '{}'", name))?;
87
88        // Parse all versions from releases
89        let mut all_versions: Vec<Version> = Vec::new();
90        for (version_str, releases) in &pypi_data.releases {
91            // Skip yanked releases (empty release list or all yanked)
92            if releases.is_empty() {
93                continue;
94            }
95
96            // Check if all releases are yanked
97            let all_yanked = releases.iter().all(|r| r.yanked.unwrap_or(false));
98            if all_yanked {
99                continue;
100            }
101
102            // Try to parse the version
103            if let Ok(version) = Version::from_str(version_str) {
104                all_versions.push(version);
105            }
106        }
107
108        if all_versions.is_empty() {
109            return Err(anyhow!("No valid versions found for package '{}'", name));
110        }
111
112        // Sort versions in ascending order
113        all_versions.sort();
114
115        // Filter versions based on prerelease setting
116        let filtered_versions: Vec<Version> = if self.include_prerelease {
117            all_versions.clone()
118        } else {
119            all_versions
120                .iter()
121                .filter(|v| !v.is_prerelease())
122                .cloned()
123                .collect()
124        };
125
126        if filtered_versions.is_empty() {
127            return Err(anyhow!(
128                "No stable versions found for package '{}' (use --pre-release to include pre-releases)",
129                name
130            ));
131        }
132
133        // Get latest version (with or without prerelease)
134        let latest = if self.include_prerelease {
135            all_versions
136                .last()
137                .ok_or_else(|| anyhow!("No versions found"))?
138                .clone()
139        } else {
140            filtered_versions
141                .last()
142                .ok_or_else(|| anyhow!("No stable versions found"))?
143                .clone()
144        };
145
146        // Get latest stable version (always filter out prereleases)
147        let latest_stable = all_versions
148            .iter()
149            .filter(|v| !v.is_prerelease())
150            .last()
151            .cloned();
152
153        Ok(PackageInfo {
154            name: pypi_data.info.name,
155            versions: filtered_versions,
156            latest,
157            latest_stable,
158        })
159    }
160
161    /// Fetch multiple packages concurrently
162    pub async fn get_packages(
163        &self,
164        names: &[String],
165        progress_callback: impl Fn(usize, usize) + Send + Sync + 'static,
166    ) -> Result<GetPackagesResult> {
167        let total = names.len();
168        let progress_callback = Arc::new(progress_callback);
169
170        // Limit concurrent requests to avoid overwhelming the server
171        let semaphore = Arc::new(Semaphore::new(10));
172
173        let mut tasks = Vec::new();
174
175        for (index, name) in names.iter().enumerate() {
176            let client = self.clone();
177            let name = name.clone();
178            let callback = Arc::clone(&progress_callback);
179            let semaphore = Arc::clone(&semaphore);
180
181            let task = tokio::spawn(async move {
182                // Acquire semaphore permit
183                let _permit = semaphore.acquire().await.unwrap();
184
185                let result = client.get_package(&name).await;
186
187                // Call progress callback
188                callback(index + 1, total);
189
190                (name, result)
191            });
192
193            tasks.push(task);
194        }
195
196        // Wait for all tasks to complete
197        let mut packages = HashMap::new();
198        let mut errors = Vec::new();
199
200        for task in tasks {
201            match task.await {
202                Ok((name, Ok(package_info))) => {
203                    packages.insert(name, package_info);
204                }
205                Ok((name, Err(e))) => {
206                    // Extract just the error message without "Failed to fetch" prefix
207                    let error_msg = e.to_string();
208                    errors.push((name, error_msg));
209                }
210                Err(e) => {
211                    errors.push(("unknown".to_string(), format!("Task failed: {}", e)));
212                }
213            }
214        }
215
216        // Format errors as strings
217        let formatted_errors: Vec<String> = errors
218            .into_iter()
219            .map(|(name, msg)| format!("{}: {}", name, msg))
220            .collect();
221
222        // If we have some results, return them even if some packages failed
223        if !packages.is_empty() || formatted_errors.is_empty() {
224            Ok(GetPackagesResult {
225                packages,
226                errors: formatted_errors,
227            })
228        } else {
229            // All packages failed
230            Err(anyhow!(
231                "Failed to fetch all packages:\n{}",
232                formatted_errors.join("\n")
233            ))
234        }
235    }
236}
237
/// Result of fetching multiple packages
#[derive(Debug, Clone)]
pub struct GetPackagesResult {
    /// Successfully fetched packages, keyed by the name that was requested.
    pub packages: HashMap<String, PackageInfo>,
    /// One `"<name>: <message>"` entry per package that could not be fetched.
    pub errors: Vec<String>,
}
244
245// Implement Clone for PyPiClient to support concurrent usage
246impl Clone for PyPiClient {
247    fn clone(&self) -> Self {
248        Self {
249            client: self.client.clone(),
250            base_url: self.base_url.clone(),
251            include_prerelease: self.include_prerelease,
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A well-known package resolves to a name, a version list, and a stable release.
    /// Requires network access to the public index.
    #[tokio::test]
    async fn test_get_package_requests() {
        let outcome = PyPiClient::new(false).get_package("requests").await;

        assert!(outcome.is_ok(), "Failed to fetch requests package: {:?}", outcome.err());

        let info = outcome.unwrap();
        assert_eq!(info.name.to_lowercase(), "requests");
        assert!(!info.versions.is_empty());
        assert!(info.latest_stable.is_some());
    }

    /// An unknown package name is reported as "not found".
    /// Requires network access to the public index.
    #[tokio::test]
    async fn test_get_package_not_found() {
        let outcome = PyPiClient::new(false)
            .get_package("this-package-definitely-does-not-exist-12345")
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not found"));
    }

    /// Fetching several packages returns results and drives the progress callback.
    /// Requires network access to the public index.
    #[tokio::test]
    async fn test_get_packages_concurrent() {
        let fetcher = PyPiClient::new(false);
        let wanted = vec!["requests".to_string(), "flask".to_string()];

        // Shared atomic counter so the callback can be called from any task.
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_for_callback = Arc::clone(&counter);

        let outcome = fetcher
            .get_packages(&wanted, move |_current, _total| {
                counter_for_callback.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        assert!(outcome.is_ok(), "Failed to fetch packages: {:?}", outcome.err());

        let fetched = outcome.unwrap();
        assert!(!fetched.packages.is_empty());

        // The callback must have fired at least once.
        let observed = counter.load(Ordering::SeqCst);
        assert!(observed > 0, "Progress callback should have been called");
    }

    /// A trailing slash on a custom index URL is removed.
    #[tokio::test]
    async fn test_custom_index_url() {
        let fetcher = PyPiClient::new(false).with_index_url("https://pypi.org/pypi/");

        assert_eq!(fetcher.base_url, "https://pypi.org/pypi");
    }

    /// Including pre-releases never yields fewer versions than excluding them.
    /// The check is skipped when either lookup fails, since it depends on the
    /// live state of the index.
    #[tokio::test]
    async fn test_prerelease_filtering() {
        let stable_only = PyPiClient::new(false);
        let with_pre = PyPiClient::new(true);

        let stable_lookup = stable_only.get_package("django").await;
        let pre_lookup = with_pre.get_package("django").await;

        if let (Ok(stable), Ok(pre)) = (stable_lookup, pre_lookup) {
            // The pre-release client sees a superset of the stable versions.
            assert!(pre.versions.len() >= stable.versions.len());
        }
    }
}