use crate::Error::{ArchiveHashMismatch, RepositoryFailure, VersionNotFound};
use crate::repository::Archive;
use crate::repository::maven::models::Metadata;
use crate::repository::model::Repository;
use crate::{Result, hasher};
use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::header::HeaderMap;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::RetryTransientMiddleware;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_tracing::TracingMiddleware;
use semver::{Version, VersionReq};
use std::env;
use std::io::Write;
use std::sync::LazyLock;
use tracing::{debug, instrument, warn};
#[cfg(feature = "indicatif")]
use tracing_indicatif::span_ext::IndicatifSpanExt;
/// `User-Agent` value attached by `Maven::headers`, formatted as
/// `<crate name>/<crate version>` from the Cargo-provided package metadata.
static USER_AGENT: LazyLock<String> = LazyLock::new(|| {
    let package = env!("CARGO_PKG_NAME");
    let version = env!("CARGO_PKG_VERSION");
    [package, "/", version].concat()
});
/// Repository implementation that resolves versions and archives from a
/// Maven-style HTTP layout.
///
/// Instances are created through `Maven::new`, which hands them out as a
/// boxed `Repository` trait object.
#[derive(Debug)]
pub struct Maven {
/// Base URL of the artifact directory. `maven-metadata.xml` and the
/// per-version `<version>/<artifact>-<version>.jar` paths are appended to it.
url: String,
}
impl Maven {
    /// Creates a Maven repository rooted at `url` and returns it as a boxed
    /// [`Repository`] trait object.
    ///
    /// # Errors
    /// The current implementation always returns `Ok`; the `Result` is part
    /// of the constructor's signature.
    #[expect(clippy::new_ret_no_self)]
    pub fn new(url: &str) -> Result<Box<dyn Repository>> {
        let repository = Self {
            url: String::from(url),
        };
        Ok(Box::new(repository))
    }

    /// Fetches `<url>/maven-metadata.xml` and returns the artifact id together
    /// with the highest listed version that satisfies `version_req`.
    ///
    /// # Errors
    /// * The HTTP request fails or answers with a non-success status.
    /// * The metadata cannot be deserialized, or any listed version is not
    ///   valid semver (the first such entry aborts the lookup).
    /// * `VersionNotFound` when no listed version matches `version_req`.
    #[instrument(level = "debug")]
    async fn get_artifact(&self, version_req: &VersionReq) -> Result<(String, Version)> {
        debug!("Attempting to locate release for version requirement {version_req}");
        let client = reqwest_client();
        let metadata_url = format!("{}/maven-metadata.xml", self.url);
        let body = client
            .get(&metadata_url)
            .headers(Self::headers())
            .send()
            .await?
            .error_for_status()?
            .text()
            .await?;
        let metadata: Metadata = quick_xml::de::from_str(&body)?;

        // Track the greatest matching version seen so far.
        let mut newest: Option<Version> = None;
        for raw_version in &metadata.versioning.versions.version {
            let candidate = Version::parse(raw_version)?;
            if !version_req.matches(&candidate) {
                continue;
            }
            // Only a strictly greater version replaces the current best.
            if newest
                .as_ref()
                .is_some_and(|current| candidate <= *current)
            {
                continue;
            }
            newest = Some(candidate);
        }

        let Some(version) = newest else {
            return Err(VersionNotFound(version_req.to_string()));
        };
        debug!("Version {version} found for version requirement {version_req}");
        Ok((metadata.artifact_id, version))
    }

    /// Builds the header map sent with every request issued by this type.
    ///
    /// # Panics
    /// Panics if `USER_AGENT` cannot be parsed into a header value.
    fn headers() -> HeaderMap {
        let mut header_map = HeaderMap::new();
        header_map.append("User-Agent", USER_AGENT.parse().unwrap());
        header_map
    }
}
#[async_trait]
impl Repository for Maven {
    /// Returns the display name of this repository implementation.
    #[instrument(level = "debug")]
    fn name(&self) -> &str {
        "Maven"
    }

    /// Resolves the highest version listed in the repository metadata that
    /// satisfies `version_req`.
    ///
    /// # Errors
    /// Propagates any failure from the metadata lookup, including
    /// `VersionNotFound` when nothing matches.
    #[instrument(level = "debug")]
    async fn get_version(&self, version_req: &VersionReq) -> Result<Version> {
        debug!("Attempting to locate release for version requirement {version_req}");
        let (_, version) = self.get_artifact(version_req).await?;
        Ok(version)
    }

    /// Downloads the `.jar` archive for the highest version matching
    /// `version_req` and verifies it against the checksum file published
    /// next to it.
    ///
    /// The checksum algorithm is the strongest one (in the order `sha512`,
    /// `sha256`, `sha1`, `md5`) for which a hasher is registered for this
    /// repository URL. When the `indicatif` feature is enabled, download
    /// progress is reported on the current tracing span.
    ///
    /// # Errors
    /// * Any failure from the metadata lookup (see `get_version`).
    /// * `RepositoryFailure` when no hasher is registered for the URL.
    /// * An HTTP request fails or answers with a non-success status.
    /// * `ArchiveHashMismatch` when the computed hash differs from the
    ///   published one.
    #[instrument]
    async fn get_archive(&self, version_req: &VersionReq) -> Result<Archive> {
        // Ordered from strongest to weakest; the first registered hasher wins.
        const HASH_EXTENSIONS: [&str; 4] = ["sha512", "sha256", "sha1", "md5"];

        let (artifact, version) = self.get_artifact(version_req).await?;
        let archive_name = format!("{artifact}-{version}.jar");
        let archive_url = format!("{url}/{version}/{artifact}-{version}.jar", url = self.url,);

        // Stop at the first match so a weaker algorithm later in the list
        // can never replace a stronger one that is available.
        let hasher_result = HASH_EXTENSIONS.iter().find_map(|extension| {
            let extension_name = (*extension).to_string();
            let hasher_fn = hasher::registry::get(&self.url, &extension_name).ok()?;
            Some((*extension, hasher_fn))
        });
        let Some((extension, hasher_fn)) = hasher_result else {
            return Err(RepositoryFailure(format!(
                "no hashers found for {}",
                &self.url
            )));
        };

        let archive_hash_url = format!("{archive_url}.{extension}");
        let client = reqwest_client();
        debug!("Downloading archive hash {archive_hash_url}");
        let request = client.get(&archive_hash_url).headers(Self::headers());
        let response = request.send().await?.error_for_status()?;
        // Surrounding whitespace (e.g. a trailing newline in the checksum
        // file) is not part of the digest and must not cause a mismatch.
        let hash = response.text().await?.trim().to_string();
        debug!("Archive hash {archive_hash_url} downloaded: {}", hash.len(),);

        debug!("Downloading archive {archive_url}");
        let request = client.get(&archive_url).headers(Self::headers());
        let response = request.send().await?.error_for_status()?;
        #[cfg(feature = "indicatif")]
        let span = tracing::Span::current();
        #[cfg(feature = "indicatif")]
        {
            let content_length = response.content_length().unwrap_or_default();
            span.pb_set_length(content_length);
        }

        // Accumulate the streamed body in memory, updating progress per chunk.
        let mut bytes = Vec::new();
        let mut source = response.bytes_stream();
        while let Some(chunk) = source.next().await {
            bytes.write_all(&chunk?)?;
            #[cfg(feature = "indicatif")]
            span.pb_set_position(bytes.len() as u64);
        }
        debug!("Archive {archive_url} downloaded: {}", bytes.len(),);

        let archive_hash = hasher_fn(&bytes)?;
        if archive_hash != hash {
            return Err(ArchiveHashMismatch { archive_hash, hash });
        }

        let archive = Archive::new(archive_name, version, bytes);
        Ok(archive)
    }
}
/// Builds the HTTP client used for all repository requests: a plain
/// `reqwest` client wrapped with tracing middleware followed by a
/// transient-failure retry middleware whose exponential backoff policy is
/// built with a maximum of 3 retries.
fn reqwest_client() -> ClientWithMiddleware {
    let backoff = ExponentialBackoff::builder().build_with_max_retries(3);
    let retry_middleware = RetryTransientMiddleware::new_with_policy(backoff);
    let base_client = reqwest::Client::new();
    ClientBuilder::new(base_client)
        .with(TracingMiddleware::default())
        .with(retry_middleware)
        .build()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Artifact directory that every test in this module resolves against.
    const URL: &str = "https://repo1.maven.org/maven2/io/zonky/test/postgres/embedded-postgres-binaries-linux-amd64";

    /// The repository reports its fixed display name.
    #[test]
    fn test_name() {
        let repository = Maven::new(URL).unwrap();
        assert_eq!("Maven", repository.name());
    }

    /// A wildcard requirement resolves to some real (non-zero) version.
    /// Performs a live HTTP request.
    #[tokio::test]
    async fn test_get_version() -> Result<()> {
        let repository = Maven::new(URL)?;
        let resolved = repository.get_version(&VersionReq::STAR).await?;
        assert!(resolved > Version::new(0, 0, 0));
        Ok(())
    }

    /// An exact requirement resolves to exactly that version.
    /// Performs a live HTTP request.
    #[tokio::test]
    async fn test_get_specific_version() -> Result<()> {
        let repository = Maven::new(URL)?;
        let requirement = VersionReq::parse("=16.2.0")?;
        let resolved = repository.get_version(&requirement).await?;
        assert_eq!(Version::new(16, 2, 0), resolved);
        Ok(())
    }

    /// A requirement nothing satisfies yields the "version not found" error.
    /// Performs a live HTTP request.
    #[tokio::test]
    async fn test_get_specific_not_found() -> Result<()> {
        let repository = Maven::new(URL)?;
        let requirement = VersionReq::parse("=0.0.0")?;
        let failure = repository.get_version(&requirement).await.unwrap_err();
        assert_eq!("version not found for '=0.0.0'", failure.to_string());
        Ok(())
    }

    /// Downloading a specific version returns a non-empty archive carrying
    /// the expected file name and version. Performs live HTTP requests.
    #[tokio::test]
    async fn test_get_archive() -> Result<()> {
        let repository = Maven::new(URL)?;
        let version = Version::new(16, 2, 0);
        let requirement = VersionReq::parse(&format!("={version}"))?;
        let archive = repository.get_archive(&requirement).await?;
        let expected_name = format!("embedded-postgres-binaries-linux-amd64-{version}.jar");
        assert_eq!(expected_name, archive.name());
        assert_eq!(&version, archive.version());
        assert!(!archive.bytes().is_empty());
        Ok(())
    }
}