use crate::ports::outbound::{LicenseRepository, PyPiMetadata};
use crate::shared::Result;
use async_trait::async_trait;
use serde::Deserialize;
use std::collections::HashSet;
use std::time::Duration;
/// Subset of a PyPI JSON API response that this module deserializes.
///
/// `fetch_from_pypi` decodes the body returned by
/// `https://pypi.org/pypi/{package}/{version}/json` into this type.
#[derive(Debug, Deserialize)]
struct PyPiPackageInfo {
/// The response's `info` object (license fields, summary, classifiers).
info: PyPiInfo,
/// The response's `urls` array; a missing key deserializes to an empty vector.
#[serde(default)]
urls: Vec<PyPiUrl>,
}
/// One element of the `urls` array in a PyPI JSON response.
///
/// Only the `digests` object is read; any other keys in the element are ignored.
#[derive(Debug, Deserialize)]
struct PyPiUrl {
/// Digests attached to this entry; a missing key falls back to `PyPiDigests::default()`.
#[serde(default)]
digests: PyPiDigests,
}
/// The `digests` object of a `urls` entry, reduced to the SHA-256 value.
#[derive(Debug, Default, Deserialize)]
struct PyPiDigests {
/// Value of the `sha256` key, or `None` when the key is absent.
#[serde(default)]
sha256: Option<String>,
}
/// The `info` object of a PyPI JSON response, reduced to the fields this module uses.
///
/// Every field is `#[serde(default)]`, so a key missing from the response
/// deserializes to `None` / an empty vector instead of failing.
#[derive(Debug, Deserialize)]
struct PyPiInfo {
/// Value of the `license` key, if present.
#[serde(default)]
license: Option<String>,
/// Value of the `license_expression` key, if present.
// NOTE(review): presumably an SPDX license expression — confirm against the PyPI API docs.
#[serde(default)]
license_expression: Option<String>,
/// Value of the `summary` key, if present.
#[serde(default)]
summary: Option<String>,
/// Values of the `classifiers` array; empty when the key is absent.
#[serde(default)]
classifiers: Vec<String>,
}
/// `LicenseRepository` implementation that reads package metadata from the
/// PyPI JSON API (`https://pypi.org/pypi/...`).
pub struct PyPiLicenseRepository {
/// HTTP client used for every request; `new` builds it with a 10-second
/// timeout and a `uv-sbom/<crate version>` user agent.
client: reqwest::Client,
/// Number of attempts `fetch_with_retry` may make per lookup. This counts
/// total attempts, not additional retries; `new` sets it to 3.
max_retries: u32,
}
impl PyPiLicenseRepository {
pub fn new() -> Result<Self> {
let version = env!("CARGO_PKG_VERSION");
let user_agent = format!("uv-sbom/{}", version);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.user_agent(user_agent)
.build()?;
Ok(Self {
client,
max_retries: 3,
})
}
async fn fetch_with_retry(&self, package_name: &str, version: &str) -> Result<PyPiPackageInfo> {
let mut last_error = None;
for attempt in 1..=self.max_retries {
match self.fetch_from_pypi(package_name, version).await {
Ok(result) => return Ok(result),
Err(e) => {
last_error = Some(e);
if attempt < self.max_retries {
tokio::time::sleep(Duration::from_millis(100 * attempt as u64)).await;
}
}
}
}
Err(last_error.unwrap())
}
fn validate_url_component(component: &str, component_type: &str) -> Result<()> {
if component.contains('/') || component.contains('\\') {
anyhow::bail!(
"Security: {} contains path separators which are not allowed",
component_type
);
}
if component.contains("..") {
anyhow::bail!(
"Security: {} contains '..' which is not allowed",
component_type
);
}
if component.contains('#') || component.contains('?') || component.contains('@') {
anyhow::bail!(
"Security: {} contains URL-unsafe characters",
component_type
);
}
Ok(())
}
async fn fetch_from_pypi(&self, package_name: &str, version: &str) -> Result<PyPiPackageInfo> {
Self::validate_url_component(package_name, "Package name")?;
Self::validate_url_component(version, "Version")?;
let encoded_package = urlencoding::encode(package_name);
let encoded_version = urlencoding::encode(version);
let url = format!(
"https://pypi.org/pypi/{}/{}/json",
encoded_package, encoded_version
);
let response = self.client.get(&url).send().await?;
if !response.status().is_success() {
anyhow::bail!("PyPI API returned status code {}", response.status());
}
let package_info: PyPiPackageInfo = response.json().await?;
Ok(package_info)
}
}
impl PyPiLicenseRepository {
pub async fn verify_package_exists(&self, package_name: &str) -> bool {
let normalized = package_name.to_lowercase().replace('_', "-");
let url = format!("https://pypi.org/pypi/{}/json", normalized);
match self
.client
.head(&url)
.timeout(Duration::from_secs(5))
.send()
.await
{
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
pub async fn verify_packages(&self, names: &[String]) -> HashSet<String> {
use futures::stream::{self, StreamExt};
const MAX_CONCURRENT: usize = 10;
let results: Vec<(String, bool)> = stream::iter(names.iter().cloned())
.map(|name| async move {
let exists = self.verify_package_exists(&name).await;
(name, exists)
})
.buffer_unordered(MAX_CONCURRENT)
.collect()
.await;
results
.into_iter()
.filter_map(|(name, exists)| if exists { Some(name) } else { None })
.collect()
}
}
#[async_trait]
impl LicenseRepository for PyPiLicenseRepository {
async fn fetch_license_info(&self, package_name: &str, version: &str) -> Result<PyPiMetadata> {
let package_info = self.fetch_with_retry(package_name, version).await?;
let sha256_hash = package_info
.urls
.iter()
.find_map(|url| url.digests.sha256.clone());
Ok((
package_info.info.license,
package_info.info.license_expression,
package_info.info.classifiers,
package_info.info.summary,
sha256_hash,
))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a PyPI JSON fixture, panicking if it does not parse.
    fn parse_fixture(raw: &str) -> PyPiPackageInfo {
        serde_json::from_str(raw).unwrap()
    }

    /// The constructor succeeds with its built-in client configuration.
    #[test]
    fn test_pypi_client_creation() {
        assert!(PyPiLicenseRepository::new().is_ok());
    }

    /// An empty input slice yields an empty result set.
    #[tokio::test]
    async fn test_verify_packages_empty_list() {
        let repository = PyPiLicenseRepository::new().unwrap();
        let verified = repository.verify_packages(&[]).await;
        assert!(verified.is_empty());
    }

    /// A `urls` entry carrying a SHA-256 digest is deserialized intact.
    #[test]
    fn test_pypi_url_deserialization_with_digests() {
        let parsed = parse_fixture(
            r#"{
"info": {
"license": "MIT",
"summary": "A test package"
},
"urls": [
{
"digests": {
"sha256": "abc123def456"
}
}
]
}"#,
        );
        assert_eq!(parsed.urls.len(), 1);
        assert_eq!(
            parsed.urls[0].digests.sha256.as_deref(),
            Some("abc123def456")
        );
    }

    /// A response without a `urls` key deserializes to an empty list.
    #[test]
    fn test_pypi_url_deserialization_without_urls() {
        let parsed = parse_fixture(
            r#"{
"info": {
"license": "MIT",
"summary": "A test package"
}
}"#,
        );
        assert!(parsed.urls.is_empty());
    }

    /// An empty `digests` object leaves the SHA-256 value unset.
    #[test]
    fn test_pypi_url_deserialization_empty_digests() {
        let parsed = parse_fixture(
            r#"{
"info": {
"license": "MIT",
"summary": "A test package"
},
"urls": [
{
"digests": {}
}
]
}"#,
        );
        assert!(parsed.urls[0].digests.sha256.is_none());
    }
}