1use std::collections::HashMap;
4use std::sync::Arc;
5
6use reqwest::{Client, RequestBuilder};
7use tokio::sync::RwLock;
8use tracing::{debug, instrument};
9
10use super::types::PackageMetadata;
11use crate::pep::Version;
12use crate::registry::{RegistryConfig, ResolvedCredentials};
13use crate::{Error, Result};
14
/// Root of the public PyPI JSON API; `PyPIClient::new` targets this index.
pub const DEFAULT_INDEX_URL: &str = "https://pypi.org/pypi";
17
18#[derive(Clone)]
20pub struct PyPIClient {
21 client: Client,
23 base_url: String,
25 credentials: Option<ResolvedCredentials>,
27 cache: Arc<RwLock<HashMap<String, PackageMetadata>>>,
29}
30
31impl PyPIClient {
32 pub fn new() -> Self {
34 Self::with_url(DEFAULT_INDEX_URL)
35 }
36
37 pub fn with_url(base_url: impl Into<String>) -> Self {
39 let client = Client::builder()
40 .user_agent(concat!("pro/", env!("CARGO_PKG_VERSION")))
41 .build()
42 .expect("failed to build HTTP client");
43
44 Self {
45 client,
46 base_url: base_url.into(),
47 credentials: None,
48 cache: Arc::new(RwLock::new(HashMap::new())),
49 }
50 }
51
52 pub fn from_registry(config: &RegistryConfig) -> Result<Self> {
54 let client = Client::builder()
55 .user_agent(concat!("pro/", env!("CARGO_PKG_VERSION")))
56 .build()
57 .expect("failed to build HTTP client");
58
59 let credentials = if config.has_auth() {
60 Some(config.resolve_credentials()?)
61 } else {
62 None
63 };
64
65 Ok(Self {
66 client,
67 base_url: config.api_url(),
68 credentials,
69 cache: Arc::new(RwLock::new(HashMap::new())),
70 })
71 }
72
73 pub fn with_credentials(mut self, credentials: ResolvedCredentials) -> Self {
75 self.credentials = Some(credentials);
76 self
77 }
78
79 fn apply_auth(&self, request: RequestBuilder) -> RequestBuilder {
81 match &self.credentials {
82 Some(creds) => {
83 if let Some(ref token) = creds.token {
84 request.bearer_auth(token)
86 } else if let (Some(ref username), Some(ref password)) =
87 (&creds.username, &creds.password)
88 {
89 request.basic_auth(username, Some(password))
91 } else {
92 request
93 }
94 }
95 None => request,
96 }
97 }
98
99 pub fn base_url(&self) -> &str {
101 &self.base_url
102 }
103
104 #[instrument(skip(self), fields(package = %name))]
106 pub async fn get_package(&self, name: &str) -> Result<PackageMetadata> {
107 let normalized = Self::normalize_name(name);
108
109 {
111 let cache = self.cache.read().await;
112 if let Some(metadata) = cache.get(&normalized) {
113 debug!("cache hit for {}", normalized);
114 return Ok(metadata.clone());
115 }
116 }
117
118 debug!("fetching metadata for {}", normalized);
119
120 let url = format!("{}/{}/json", self.base_url, normalized);
121 let request = self.apply_auth(self.client.get(&url));
122 let response = request.send().await?;
123
124 if response.status() == reqwest::StatusCode::NOT_FOUND {
125 return Err(Error::PackageNotFound {
126 package: name.to_string(),
127 });
128 }
129
130 let metadata: PackageMetadata = response.error_for_status()?.json().await?;
131
132 {
134 let mut cache = self.cache.write().await;
135 cache.insert(normalized, metadata.clone());
136 }
137
138 Ok(metadata)
139 }
140
141 #[instrument(skip(self), fields(package = %name, version = %version))]
143 pub async fn get_package_version(&self, name: &str, version: &str) -> Result<PackageMetadata> {
144 let normalized = Self::normalize_name(name);
145
146 debug!("fetching metadata for {}=={}", normalized, version);
147
148 let url = format!("{}/{}/{}/json", self.base_url, normalized, version);
149 let request = self.apply_auth(self.client.get(&url));
150 let response = request.send().await?;
151
152 if response.status() == reqwest::StatusCode::NOT_FOUND {
153 return Err(Error::VersionNotFound {
154 package: name.to_string(),
155 version: version.to_string(),
156 });
157 }
158
159 response
160 .error_for_status()?
161 .json()
162 .await
163 .map_err(Into::into)
164 }
165
166 #[instrument(skip(self), fields(package = %name))]
168 pub async fn get_versions(&self, name: &str) -> Result<Vec<Version>> {
169 let metadata = self.get_package(name).await?;
170
171 let mut versions: Vec<Version> = metadata
172 .releases
173 .keys()
174 .filter_map(|v| Version::parse(v).ok())
175 .collect();
176
177 versions.sort_by(|a, b| b.cmp(a));
179
180 Ok(versions)
181 }
182
183 #[instrument(skip(self), fields(package = %name))]
185 pub async fn get_available_versions(&self, name: &str) -> Result<Vec<Version>> {
186 let metadata = self.get_package(name).await?;
187
188 let mut versions: Vec<Version> = metadata
189 .releases
190 .iter()
191 .filter(|(_, files)| {
192 files.iter().any(|f| !f.yanked)
194 })
195 .filter_map(|(v, _)| Version::parse(v).ok())
196 .collect();
197
198 versions.sort_by(|a, b| b.cmp(a));
200
201 Ok(versions)
202 }
203
204 #[instrument(skip(self, names))]
206 pub async fn get_packages_concurrent(
207 &self,
208 names: &[String],
209 ) -> HashMap<String, Result<PackageMetadata>> {
210 use futures::future::join_all;
211
212 let futures: Vec<_> = names
213 .iter()
214 .map(|name| {
215 let name = name.clone();
216 let client = self.clone();
217 async move {
218 let result = client.get_package(&name).await;
219 (Self::normalize_name(&name), result)
220 }
221 })
222 .collect();
223
224 join_all(futures).await.into_iter().collect()
225 }
226
227 pub async fn clear_cache(&self) {
229 let mut cache = self.cache.write().await;
230 cache.clear();
231 }
232
233 fn normalize_name(name: &str) -> String {
235 name.to_lowercase()
236 .chars()
237 .map(|c| match c {
238 '_' | '.' => '-',
239 c => c,
240 })
241 .collect()
242 }
243}
244
245impl Default for PyPIClient {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_name() {
        assert_eq!(PyPIClient::normalize_name("requests"), "requests");
        assert_eq!(PyPIClient::normalize_name("Requests"), "requests");
        assert_eq!(PyPIClient::normalize_name("my_package"), "my-package");
        assert_eq!(
            PyPIClient::normalize_name("zope.interface"),
            "zope-interface"
        );
    }

    #[tokio::test]
    #[ignore = "requires network"]
    async fn test_get_package() {
        let client = PyPIClient::new();
        let metadata = client.get_package("requests").await.unwrap();
        assert_eq!(metadata.info.name.to_lowercase(), "requests");
    }

    #[tokio::test]
    #[ignore = "requires network"]
    async fn test_get_versions() {
        let client = PyPIClient::new();
        let versions = client.get_versions("requests").await.unwrap();
        assert!(!versions.is_empty());
        // `get_versions` promises descending order: check every adjacent pair.
        assert!(
            versions.windows(2).all(|pair| pair[0] >= pair[1]),
            "versions are not sorted in descending order"
        );
    }
}