Skip to main content

pro_core/index/
client.rs

1//! PyPI HTTP client with private registry support
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use reqwest::{Client, RequestBuilder};
7use tokio::sync::RwLock;
8use tracing::{debug, instrument};
9
10use super::types::PackageMetadata;
11use crate::pep::Version;
12use crate::registry::{RegistryConfig, ResolvedCredentials};
13use crate::{Error, Result};
14
/// Root of the public PyPI JSON API; used by `PyPIClient::new` and `PyPIClient::default`.
pub const DEFAULT_INDEX_URL: &str = "https://pypi.org/pypi";
17
18/// PyPI index client for fetching package metadata
19#[derive(Clone)]
20pub struct PyPIClient {
21    /// HTTP client
22    client: Client,
23    /// Base URL for the index (JSON API)
24    base_url: String,
25    /// Resolved credentials for authentication
26    credentials: Option<ResolvedCredentials>,
27    /// Cache for package metadata
28    cache: Arc<RwLock<HashMap<String, PackageMetadata>>>,
29}
30
31impl PyPIClient {
32    /// Create a new PyPI client with default settings
33    pub fn new() -> Self {
34        Self::with_url(DEFAULT_INDEX_URL)
35    }
36
37    /// Create a new PyPI client with a custom index URL
38    pub fn with_url(base_url: impl Into<String>) -> Self {
39        let client = Client::builder()
40            .user_agent(concat!("pro/", env!("CARGO_PKG_VERSION")))
41            .build()
42            .expect("failed to build HTTP client");
43
44        Self {
45            client,
46            base_url: base_url.into(),
47            credentials: None,
48            cache: Arc::new(RwLock::new(HashMap::new())),
49        }
50    }
51
52    /// Create a client from a registry configuration
53    pub fn from_registry(config: &RegistryConfig) -> Result<Self> {
54        let client = Client::builder()
55            .user_agent(concat!("pro/", env!("CARGO_PKG_VERSION")))
56            .build()
57            .expect("failed to build HTTP client");
58
59        let credentials = if config.has_auth() {
60            Some(config.resolve_credentials()?)
61        } else {
62            None
63        };
64
65        Ok(Self {
66            client,
67            base_url: config.api_url(),
68            credentials,
69            cache: Arc::new(RwLock::new(HashMap::new())),
70        })
71    }
72
73    /// Set credentials for authentication
74    pub fn with_credentials(mut self, credentials: ResolvedCredentials) -> Self {
75        self.credentials = Some(credentials);
76        self
77    }
78
79    /// Apply authentication to a request
80    fn apply_auth(&self, request: RequestBuilder) -> RequestBuilder {
81        match &self.credentials {
82            Some(creds) => {
83                if let Some(ref token) = creds.token {
84                    // Bearer token authentication
85                    request.bearer_auth(token)
86                } else if let (Some(ref username), Some(ref password)) =
87                    (&creds.username, &creds.password)
88                {
89                    // Basic authentication
90                    request.basic_auth(username, Some(password))
91                } else {
92                    request
93                }
94            }
95            None => request,
96        }
97    }
98
99    /// Get the base URL
100    pub fn base_url(&self) -> &str {
101        &self.base_url
102    }
103
104    /// Fetch package metadata from PyPI
105    #[instrument(skip(self), fields(package = %name))]
106    pub async fn get_package(&self, name: &str) -> Result<PackageMetadata> {
107        let normalized = Self::normalize_name(name);
108
109        // Check cache first
110        {
111            let cache = self.cache.read().await;
112            if let Some(metadata) = cache.get(&normalized) {
113                debug!("cache hit for {}", normalized);
114                return Ok(metadata.clone());
115            }
116        }
117
118        debug!("fetching metadata for {}", normalized);
119
120        let url = format!("{}/{}/json", self.base_url, normalized);
121        let request = self.apply_auth(self.client.get(&url));
122        let response = request.send().await?;
123
124        if response.status() == reqwest::StatusCode::NOT_FOUND {
125            return Err(Error::PackageNotFound {
126                package: name.to_string(),
127            });
128        }
129
130        let metadata: PackageMetadata = response.error_for_status()?.json().await?;
131
132        // Cache the result
133        {
134            let mut cache = self.cache.write().await;
135            cache.insert(normalized, metadata.clone());
136        }
137
138        Ok(metadata)
139    }
140
141    /// Fetch metadata for a specific version
142    #[instrument(skip(self), fields(package = %name, version = %version))]
143    pub async fn get_package_version(&self, name: &str, version: &str) -> Result<PackageMetadata> {
144        let normalized = Self::normalize_name(name);
145
146        debug!("fetching metadata for {}=={}", normalized, version);
147
148        let url = format!("{}/{}/{}/json", self.base_url, normalized, version);
149        let request = self.apply_auth(self.client.get(&url));
150        let response = request.send().await?;
151
152        if response.status() == reqwest::StatusCode::NOT_FOUND {
153            return Err(Error::VersionNotFound {
154                package: name.to_string(),
155                version: version.to_string(),
156            });
157        }
158
159        response
160            .error_for_status()?
161            .json()
162            .await
163            .map_err(Into::into)
164    }
165
166    /// Get all available versions for a package
167    #[instrument(skip(self), fields(package = %name))]
168    pub async fn get_versions(&self, name: &str) -> Result<Vec<Version>> {
169        let metadata = self.get_package(name).await?;
170
171        let mut versions: Vec<Version> = metadata
172            .releases
173            .keys()
174            .filter_map(|v| Version::parse(v).ok())
175            .collect();
176
177        // Sort by version, newest first
178        versions.sort_by(|a, b| b.cmp(a));
179
180        Ok(versions)
181    }
182
183    /// Get available versions that have non-yanked files
184    #[instrument(skip(self), fields(package = %name))]
185    pub async fn get_available_versions(&self, name: &str) -> Result<Vec<Version>> {
186        let metadata = self.get_package(name).await?;
187
188        let mut versions: Vec<Version> = metadata
189            .releases
190            .iter()
191            .filter(|(_, files)| {
192                // Version is available if it has at least one non-yanked file
193                files.iter().any(|f| !f.yanked)
194            })
195            .filter_map(|(v, _)| Version::parse(v).ok())
196            .collect();
197
198        // Sort by version, newest first
199        versions.sort_by(|a, b| b.cmp(a));
200
201        Ok(versions)
202    }
203
204    /// Fetch metadata for multiple packages concurrently
205    #[instrument(skip(self, names))]
206    pub async fn get_packages_concurrent(
207        &self,
208        names: &[String],
209    ) -> HashMap<String, Result<PackageMetadata>> {
210        use futures::future::join_all;
211
212        let futures: Vec<_> = names
213            .iter()
214            .map(|name| {
215                let name = name.clone();
216                let client = self.clone();
217                async move {
218                    let result = client.get_package(&name).await;
219                    (Self::normalize_name(&name), result)
220                }
221            })
222            .collect();
223
224        join_all(futures).await.into_iter().collect()
225    }
226
227    /// Clear the metadata cache
228    pub async fn clear_cache(&self) {
229        let mut cache = self.cache.write().await;
230        cache.clear();
231    }
232
233    /// Normalize a package name according to PEP 503
234    fn normalize_name(name: &str) -> String {
235        name.to_lowercase()
236            .chars()
237            .map(|c| match c {
238                '_' | '.' => '-',
239                c => c,
240            })
241            .collect()
242    }
243}
244
245impl Default for PyPIClient {
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_name() {
        // (input, expected PEP 503 normalized form)
        let cases = [
            ("requests", "requests"),
            ("Requests", "requests"),
            ("my_package", "my-package"),
            ("zope.interface", "zope-interface"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(PyPIClient::normalize_name(input), expected);
        }
    }

    #[tokio::test]
    #[ignore = "requires network"]
    async fn test_get_package() {
        let metadata = PyPIClient::new()
            .get_package("requests")
            .await
            .unwrap();
        assert_eq!(metadata.info.name.to_lowercase(), "requests");
    }

    #[tokio::test]
    #[ignore = "requires network"]
    async fn test_get_versions() {
        let versions = PyPIClient::new()
            .get_versions("requests")
            .await
            .unwrap();
        assert!(!versions.is_empty());
        // Newest first: every adjacent pair must be non-increasing.
        assert!(versions.windows(2).all(|pair| pair[0] >= pair[1]));
    }
}