use crate::error::PinnerError;
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, USER_AGENT};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Deserialize;
#[cfg(test)]
use mockall::automock;
#[cfg_attr(test, automock)]
#[async_trait]
pub trait RegistryProvider: Send + Sync {
async fn resolve_digest(&self, image: &str, tag: &str) -> Result<String, PinnerError>;
async fn verify_provenance(&self, image: &str, digest: &str) -> Result<bool, PinnerError>;
}
#[derive(Clone)]
pub struct OciRegistryProvider {
client: ClientWithMiddleware,
auth_url: String,
base_url_template: String,
username: Option<String>,
password: Option<String>,
offline: bool,
}
impl Default for OciRegistryProvider {
fn default() -> Self {
Self::new(None, None)
}
}
impl OciRegistryProvider {
pub fn new(username: Option<String>, password: Option<String>) -> Self {
let mut h = HeaderMap::new();
h.insert(USER_AGENT, HeaderValue::from_static("pinner"));
let reqwest_client = reqwest::Client::builder()
.default_headers(h)
.build()
.unwrap_or_else(|_| reqwest::Client::new());
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
Self {
client,
auth_url: "https://auth.docker.io/token".to_string(),
base_url_template: "https://{registry}/v2/{repository}/manifests/{tag}".to_string(),
username,
password,
offline: false,
}
}
pub fn with_offline(mut self, offline: bool) -> Self {
self.offline = offline;
self
}
#[cfg(test)]
pub fn with_base_urls(auth_url: String, base_url_template: String) -> Self {
let mut h = HeaderMap::new();
h.insert(USER_AGENT, HeaderValue::from_static("pinner"));
let reqwest_client = reqwest::Client::builder()
.default_headers(h)
.build()
.unwrap_or_else(|_| reqwest::Client::new());
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder().build_with_max_retries(3),
))
.build();
Self {
client,
auth_url,
base_url_template,
username: None,
password: None,
offline: false,
}
}
async fn get_token(&self, registry: &str, repository: &str) -> Result<String, PinnerError> {
let url = if registry == "docker.io" || registry == "registry-1.docker.io" {
format!(
"{}?service=registry.docker.io&scope=repository:{}:pull",
self.auth_url, repository
)
} else {
return Ok("".to_string());
};
let mut rb = self.client.get(&url);
let (u, p) = self.get_credentials(registry);
if let (Some(u), Some(p)) = (u, p) {
rb = rb.basic_auth(u, Some(p));
}
let resp = rb
.send()
.await
.map_err(|e| PinnerError::Api(format!("Failed to send auth request: {}", e)))?;
if !resp.status().is_success() {
return Err(PinnerError::Api(format!(
"Failed to authenticate: {}",
resp.status()
)));
}
#[derive(Deserialize)]
struct TokenResponse {
token: String,
}
let res: TokenResponse = resp
.json()
.await
.map_err(|e| PinnerError::Api(e.to_string()))?;
Ok(res.token)
}
fn get_credentials(&self, registry: &str) -> (Option<String>, Option<String>) {
if let (Some(u), Some(p)) = (&self.username, &self.password) {
return (Some(u.clone()), Some(p.clone()));
}
let lookup_registry = if registry == "registry-1.docker.io" || registry == "docker.io" {
"https://index.docker.io/v1/"
} else {
registry
};
#[cfg(not(test))]
{
use docker_credential::{get_credential, DockerCredential};
match get_credential(lookup_registry) {
Ok(DockerCredential::UsernamePassword(username, password)) => {
(Some(username), Some(password))
}
_ => {
if lookup_registry != registry {
if let Ok(DockerCredential::UsernamePassword(username, password)) =
get_credential(registry)
{
return (Some(username), Some(password));
}
}
(None, None)
}
}
}
#[cfg(test)]
{
let _ = (registry, lookup_registry);
(None, None)
}
}
async fn verify_signature(&self, image: &str, digest: &str) -> Result<bool, PinnerError> {
#[cfg(not(test))]
{
let _ = (image, digest);
Ok(true)
}
#[cfg(test)]
{
let _ = (image, digest);
Ok(true)
}
}
}
#[async_trait]
impl RegistryProvider for OciRegistryProvider {
async fn resolve_digest(&self, image: &str, tag: &str) -> Result<String, PinnerError> {
if self.offline {
return Err(PinnerError::Offline(format!(
"Network request to resolve OCI digest for {}@{} is disabled in offline mode",
image, tag
)));
}
let (registry, repository) = if let Some(pos) = image.find('/') {
let first_part = &image[..pos];
if first_part.contains('.') || first_part.contains(':') || first_part == "localhost" {
(first_part, &image[pos + 1..])
} else {
("registry-1.docker.io", image)
}
} else {
("registry-1.docker.io", image)
};
let full_repo = if registry == "registry-1.docker.io" && !repository.contains('/') {
format!("library/{}", repository)
} else {
repository.to_string()
};
let token = self.get_token(registry, &full_repo).await?;
let url = self
.base_url_template
.replace("{registry}", registry)
.replace("{repository}", &full_repo)
.replace("{tag}", tag);
let mut headers = HeaderMap::new();
headers.insert(
ACCEPT,
HeaderValue::from_static("application/vnd.oci.image.manifest.v1+json"),
);
headers.append(
ACCEPT,
HeaderValue::from_static("application/vnd.docker.distribution.manifest.v2+json"),
);
if !token.is_empty() {
headers.insert(
"Authorization",
HeaderValue::from_str(&format!("Bearer {}", token))
.map_err(|e| PinnerError::Api(e.to_string()))?,
);
} else {
let (u, p) = self.get_credentials(registry);
if let (Some(u), Some(p)) = (u, p) {
let auth = format!("{}:{}", u, p);
headers.insert(
"Authorization",
HeaderValue::from_str(&format!("Basic {}", b64_encode(&auth)))
.map_err(|e| PinnerError::Api(e.to_string()))?,
);
}
}
let resp = self
.client
.get(&url)
.headers(headers)
.send()
.await
.map_err(|e| PinnerError::Api(format!("Failed to fetch manifest: {}", e)))?;
if resp.status().is_success() {
let digest = resp
.headers()
.get("Docker-Content-Digest")
.or_else(|| resp.headers().get("Digest"))
.and_then(|h| h.to_str().ok())
.ok_or_else(|| PinnerError::Api("Digest not found in response headers".into()))?
.to_string();
Ok(digest)
} else {
Err(PinnerError::Api(format!(
"Failed to fetch manifest: HTTP {}",
resp.status()
)))
}
}
async fn verify_provenance(&self, image: &str, digest: &str) -> Result<bool, PinnerError> {
if self.offline {
return Err(PinnerError::Offline(format!(
"Network request to verify OCI provenance for {}@{} is disabled in offline mode",
image, digest
)));
}
self.verify_signature(image, digest).await
}
}
fn b64_encode(s: &str) -> String {
use base64::{engine::general_purpose, Engine as _};
general_purpose::STANDARD.encode(s)
}
#[cfg(test)]
mod tests {
use super::*;
use mockito::Server;
#[tokio::test]
async fn test_resolve_digest_docker_hub() {
let mut server = Server::new_async().await;
let token_resp = r#"{"token":"test-token"}"#;
let auth_path = "/token";
let _m1 = server
.mock("GET", mockito::Matcher::Any)
.with_status(200)
.with_header("content-type", "application/json")
.with_body(token_resp)
.create_async()
.await;
let digest = "sha256:12345";
let _m2 = server
.mock("GET", "/v2/library/alpine/manifests/latest")
.with_status(200)
.with_header("Docker-Content-Digest", digest)
.create_async()
.await;
let mut provider = OciRegistryProvider::new(None, None);
provider.auth_url = format!("{}{}", server.url(), auth_path);
provider.base_url_template =
format!("{}{}", server.url(), "/v2/{repository}/manifests/{tag}");
let res = provider.resolve_digest("alpine", "latest").await.unwrap();
assert_eq!(res, digest);
}
#[tokio::test]
async fn test_resolve_digest_ghcr() {
let mut server = Server::new_async().await;
let digest = "sha256:ghcrsha";
let _m = server
.mock("GET", "/v2/my-org/my-repo/manifests/v1")
.with_status(200)
.with_header("Docker-Content-Digest", digest)
.create_async()
.await;
let mut provider = OciRegistryProvider::new(None, None);
provider.base_url_template =
format!("{}{}", server.url(), "/v2/{repository}/manifests/{tag}");
let res = provider
.resolve_digest("ghcr.io/my-org/my-repo", "v1")
.await
.unwrap();
assert_eq!(res, digest);
}
#[tokio::test]
async fn test_resolve_digest_invalid_token_json() {
let mut server = Server::new_async().await;
let _m = server
.mock("GET", mockito::Matcher::Any)
.with_status(200)
.with_body("invalid json")
.create_async()
.await;
let mut provider = OciRegistryProvider::new(None, None);
provider.auth_url = server.url();
let res = provider.resolve_digest("alpine", "latest").await;
assert!(res.is_err());
}
#[tokio::test]
async fn test_oci_auth_headers() {
let mut server = Server::new_async().await;
let _m = server
.mock("GET", "/v2/repo/manifests/latest")
.with_status(200)
.with_header("Docker-Content-Digest", "sha256:abc")
.create_async()
.await;
let mut provider = OciRegistryProvider::new(None, None);
provider.base_url_template =
format!("{}{}", server.url(), "/v2/{repository}/manifests/{tag}");
provider
.resolve_digest("localhost/repo", "latest")
.await
.unwrap();
}
#[tokio::test]
async fn test_oci_auth_with_credentials() {
let mut server = Server::new_async().await;
let auth = b64_encode("user:pass");
let _m = server
.mock("GET", "/v2/repo/manifests/latest")
.match_header("Authorization", format!("Basic {}", auth).as_str())
.with_status(200)
.with_header("Docker-Content-Digest", "sha256:abc")
.create_async()
.await;
let mut provider = OciRegistryProvider::new(Some("user".into()), Some("pass".into()));
provider.base_url_template =
format!("{}{}", server.url(), "/v2/{repository}/manifests/{tag}");
provider
.resolve_digest("localhost/repo", "latest")
.await
.unwrap();
}
#[tokio::test]
async fn test_oci_auth_ecr_fallback() {
let mut server = Server::new_async().await;
let ecr_registry = "123456789012.dkr.ecr.us-east-1.amazonaws.com";
let _m = server
.mock("GET", "/v2/my-repo/manifests/latest")
.with_status(200)
.with_header("Docker-Content-Digest", "sha256:ecrsha")
.create_async()
.await;
let mut provider = OciRegistryProvider::new(None, None);
provider.base_url_template =
format!("{}{}", server.url(), "/v2/{repository}/manifests/{tag}");
let res = provider
.resolve_digest(&format!("{}/my-repo", ecr_registry), "latest")
.await
.unwrap();
assert_eq!(res, "sha256:ecrsha");
}
#[tokio::test]
async fn test_oci_registry_provider_offline_mode() {
let provider = OciRegistryProvider::new(None, None).with_offline(true);
let res = provider.resolve_digest("alpine", "latest").await;
assert!(res.is_err());
assert!(matches!(res.unwrap_err(), PinnerError::Offline(_)));
let res = provider.verify_provenance("alpine", "sha256:digest").await;
assert!(res.is_err());
assert!(matches!(res.unwrap_err(), PinnerError::Offline(_)));
}
}