use crate::core::{BranchName, DependencyName, DependencyRef};
use crate::error::PinnerError;
use crate::resolver::azure::ReqwestAzureProvider;
use crate::resolver::bitbucket::ReqwestBitbucketProvider;
use crate::resolver::circleci::ReqwestCircleCiProvider;
use crate::resolver::forgejo::ReqwestForgejoProvider;
use crate::resolver::github::ReqwestGithubProvider;
use crate::resolver::gitlab::ReqwestGitLabProvider;
use async_trait::async_trait;
use moka::future::Cache;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait RemoteProvider: Send + Sync {
async fn get_commit_sha(
&self,
action: &DependencyName,
tag: &str,
key: &str,
) -> Result<DependencyRef, PinnerError>;
async fn get_latest_release(
&self,
action: &DependencyName,
key: &str,
) -> Result<String, PinnerError>;
async fn list_tags(
&self,
action: &DependencyName,
key: &str,
) -> Result<Vec<String>, PinnerError>;
async fn get_default_branch(
&self,
action: &DependencyName,
key: &str,
) -> Result<BranchName, PinnerError>;
}
pub struct CachedProvider<T: RemoteProvider> {
inner: T,
sha_cache: Cache<(DependencyName, String), DependencyRef>,
release_cache: Cache<DependencyName, String>,
branch_cache: Cache<DependencyName, BranchName>,
disk_cache_path: Option<PathBuf>,
offline: bool,
}
impl<T: RemoteProvider> CachedProvider<T> {
pub fn new(inner: T, disk_cache_path: Option<PathBuf>, offline: bool) -> Self {
Self {
inner,
sha_cache: Cache::builder()
.max_capacity(1000)
.time_to_live(Duration::from_secs(3600))
.build(),
release_cache: Cache::builder()
.max_capacity(500)
.time_to_live(Duration::from_secs(3600))
.build(),
branch_cache: Cache::builder()
.max_capacity(500)
.time_to_live(Duration::from_secs(3600))
.build(),
disk_cache_path,
offline,
}
}
}
#[async_trait]
impl<T: RemoteProvider> RemoteProvider for CachedProvider<T> {
async fn get_commit_sha(
&self,
action: &DependencyName,
tag: &str,
key: &str,
) -> Result<DependencyRef, PinnerError> {
let mem_key = (action.clone(), tag.to_string());
if let Some(sha) = self.sha_cache.get(&mem_key).await {
return Ok(sha);
}
let disk_key = format!("sha:{}:{}:{}", action, tag, key);
if let Some(path) = &self.disk_cache_path {
if let Ok(data) = cacache::read(path, &disk_key).await {
let sha = DependencyRef::from(String::from_utf8_lossy(&data).to_string());
self.sha_cache.insert(mem_key, sha.clone()).await;
return Ok(sha);
}
}
if self.offline {
return Err(PinnerError::Offline(format!(
"Network request for commit SHA of {}@{} is disabled in offline mode",
action, tag
)));
}
let sha = self.inner.get_commit_sha(action, tag, key).await?;
self.sha_cache.insert(mem_key, sha.clone()).await;
if let Some(path) = &self.disk_cache_path {
let _ = cacache::write(path, &disk_key, sha.to_string().as_bytes()).await;
}
Ok(sha)
}
async fn get_latest_release(
&self,
action: &DependencyName,
key: &str,
) -> Result<String, PinnerError> {
if let Some(tag) = self.release_cache.get(action).await {
return Ok(tag);
}
let disk_key = format!("release:{}:{}", action, key);
if let Some(path) = &self.disk_cache_path {
if let Ok(data) = cacache::read(path, &disk_key).await {
let tag = String::from_utf8_lossy(&data).to_string();
self.release_cache.insert(action.clone(), tag.clone()).await;
return Ok(tag);
}
}
if self.offline {
return Err(PinnerError::Offline(format!(
"Network request for latest release of {} is disabled in offline mode",
action
)));
}
let tag = self.inner.get_latest_release(action, key).await?;
self.release_cache.insert(action.clone(), tag.clone()).await;
if let Some(path) = &self.disk_cache_path {
let _ = cacache::write(path, &disk_key, tag.as_bytes()).await;
}
Ok(tag)
}
async fn list_tags(
&self,
action: &DependencyName,
key: &str,
) -> Result<Vec<String>, PinnerError> {
if self.offline {
return Err(PinnerError::Offline(format!(
"Network request for listing tags of {} is disabled in offline mode",
action
)));
}
self.inner.list_tags(action, key).await
}
async fn get_default_branch(
&self,
action: &DependencyName,
key: &str,
) -> Result<BranchName, PinnerError> {
if let Some(branch) = self.branch_cache.get(action).await {
return Ok(branch);
}
let disk_key = format!("branch:{}:{}", action, key);
if let Some(path) = &self.disk_cache_path {
if let Ok(data) = cacache::read(path, &disk_key).await {
let branch = BranchName(String::from_utf8_lossy(&data).to_string());
self.branch_cache
.insert(action.clone(), branch.clone())
.await;
return Ok(branch);
}
}
if self.offline {
return Err(PinnerError::Offline(format!(
"Network request for default branch of {} is disabled in offline mode",
action
)));
}
let branch = self.inner.get_default_branch(action, key).await?;
self.branch_cache
.insert(action.clone(), branch.clone())
.await;
if let Some(path) = &self.disk_cache_path {
let _ = cacache::write(path, &disk_key, branch.0.as_bytes()).await;
}
Ok(branch)
}
}
pub struct RateLimitMiddleware;
#[async_trait]
impl reqwest_middleware::Middleware for RateLimitMiddleware {
async fn handle(
&self,
req: reqwest::Request,
extensions: &mut http::Extensions,
next: reqwest_middleware::Next<'_>,
) -> reqwest_middleware::Result<reqwest::Response> {
let mut attempts = 0;
loop {
let req_to_send = req.try_clone().ok_or_else(|| {
reqwest_middleware::Error::middleware(std::io::Error::other(
"Request is not cloneable",
))
})?;
let res = next.clone().run(req_to_send, extensions).await;
match &res {
Ok(resp) if resp.status().as_u16() == 403 || resp.status().as_u16() == 429 => {
attempts += 1;
if attempts >= 3 {
return res;
}
let mut delay_secs = 2;
if let Some(retry_after) = resp.headers().get("retry-after") {
if let Ok(retry_str) = retry_after.to_str() {
if let Ok(secs) = retry_str.parse::<u64>() {
delay_secs = secs;
}
}
} else if let Some(reset) = resp.headers().get("x-ratelimit-reset") {
if let Ok(reset_str) = reset.to_str() {
if let Ok(ts) = reset_str.parse::<i64>() {
use chrono::Utc;
let now = Utc::now().timestamp();
if ts > now {
delay_secs = (ts - now) as u64;
}
}
}
}
let delay_secs = std::cmp::min(delay_secs, 15);
tracing::warn!(
"Hit rate limit or abuse detection (HTTP {}). Retrying in {} seconds (attempt {}/3)...",
resp.status(),
delay_secs,
attempts
);
tokio::time::sleep(std::time::Duration::from_secs(delay_secs)).await;
}
_ => return res,
}
}
}
}
pub struct BaseHttpClient {
pub client: ClientWithMiddleware,
pub base_url: String,
}
impl BaseHttpClient {
pub fn new(
base_url: String,
token: Option<String>,
token_prefix: &str,
env_var: &str,
) -> Result<Self, PinnerError> {
let mut h = HeaderMap::new();
h.insert(USER_AGENT, HeaderValue::from_static("pinner"));
let token = token.or_else(|| std::env::var(env_var).ok());
if let Some(t) = token {
let auth_val = if token_prefix.is_empty() {
t
} else {
format!("{} {}", token_prefix, t)
};
if let Ok(auth) = HeaderValue::from_str(&auth_val) {
h.insert(AUTHORIZATION, auth);
}
}
let reqwest_client = reqwest::Client::builder()
.default_headers(h)
.build()
.map_err(|e| PinnerError::Api(format!("Failed to build reqwest client: {}", e)))?;
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
let client = ClientBuilder::new(reqwest_client)
.with(RateLimitMiddleware)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
Ok(Self { client, base_url })
}
pub fn handle_error(&self, resp: reqwest::Response, action: &DependencyName) -> PinnerError {
let status = resp.status();
match status.as_u16() {
403 | 429 => {
let mut msg = format!(
"API rate limit exceeded (HTTP {}) at {}. Try providing an API token to increase limits.",
status, self.base_url
);
if let Some(reset) = resp.headers().get("x-ratelimit-reset") {
if let Ok(reset_str) = reset.to_str() {
if let Ok(ts) = reset_str.parse::<i64>() {
use chrono::{TimeZone, Utc};
if let Some(dt) = Utc.timestamp_opt(ts, 0).single() {
msg.push_str(&format!(
" Rate limit resets at {}.",
dt.format("%Y-%m-%d %H:%M:%S UTC")
));
}
}
}
} else if let Some(retry) = resp.headers().get("retry-after") {
if let Ok(retry_str) = retry.to_str() {
msg.push_str(&format!(" Retry after {} seconds.", retry_str));
}
}
PinnerError::RateLimit(msg)
}
_ => PinnerError::Api(format!(
"HTTP {}: Error for action {} at {}",
status, action, self.base_url
)),
}
}
}
#[derive(serde::Deserialize)]
pub struct RefResponse {
pub sha: String,
}
#[derive(serde::Deserialize)]
pub struct ReleaseResponse {
pub tag_name: String,
}
#[derive(serde::Deserialize)]
pub struct RepoResponse {
pub default_branch: String,
}
#[derive(Clone)]
pub struct ProviderRegistry {
pub providers: Vec<(Arc<dyn RemoteProvider>, ProviderTypeInfo)>,
}
#[derive(Clone)]
pub struct ProviderTypeInfo {
pub domains: Vec<String>,
pub keys: Vec<String>,
pub variant: String,
}
impl ProviderRegistry {
pub fn new(config: UnifiedProviderConfig) -> Result<Self, PinnerError> {
let mut registry = Self {
providers: Vec::new(),
};
registry.register(
Arc::new(CachedProvider::new(
ReqwestGithubProvider::new(config.github_url.clone(), config.github_token.clone())?,
config.disk_cache_path.clone(),
config.offline,
)),
ProviderTypeInfo {
domains: vec!["github.com".to_string()],
keys: vec!["uses".to_string(), "image".to_string()],
variant: "GitHub".to_string(),
},
);
registry.register(
Arc::new(CachedProvider::new(
ReqwestAzureProvider::new(ReqwestGithubProvider::new(
config.github_url.clone(),
config.github_token.clone(),
)?),
config.disk_cache_path.clone(),
config.offline,
)),
ProviderTypeInfo {
domains: vec![],
keys: vec!["task".to_string(), "template".to_string()],
variant: "Azure".to_string(),
},
);
registry.register(
Arc::new(CachedProvider::new(
ReqwestBitbucketProvider::new(config.bitbucket_url, config.bitbucket_token)?,
config.disk_cache_path.clone(),
config.offline,
)),
ProviderTypeInfo {
domains: vec!["bitbucket.org".to_string()],
keys: vec!["pipe".to_string(), "image".to_string()],
variant: "Bitbucket".to_string(),
},
);
registry.register(
Arc::new(CachedProvider::new(
ReqwestGitLabProvider::new(config.gitlab_url, config.gitlab_token)?,
config.disk_cache_path.clone(),
config.offline,
)),
ProviderTypeInfo {
domains: vec!["gitlab.com".to_string()],
keys: vec![
"include".to_string(),
"image".to_string(),
"ref".to_string(),
],
variant: "GitLab".to_string(),
},
);
registry.register(
Arc::new(CachedProvider::new(
ReqwestForgejoProvider::new(config.forgejo_url, config.forgejo_token)?,
config.disk_cache_path.clone(),
config.offline,
)),
ProviderTypeInfo {
domains: vec!["codeberg.org".to_string(), "forgejo".to_string()],
keys: vec!["uses".to_string(), "image".to_string()],
variant: "Forgejo".to_string(),
},
);
registry.register(
Arc::new(CachedProvider::new(
ReqwestCircleCiProvider::new(config.circleci_url, config.circleci_token)?,
config.disk_cache_path.clone(),
config.offline,
)),
ProviderTypeInfo {
domains: vec![],
keys: vec!["orbs".to_string()],
variant: "CircleCi".to_string(),
},
);
Ok(registry)
}
pub fn register(&mut self, provider: Arc<dyn RemoteProvider>, info: ProviderTypeInfo) {
self.providers.push((provider, info));
}
pub fn get_provider(&self, key: &str, action: &DependencyName) -> Arc<dyn RemoteProvider> {
let action_str = action.0.as_str();
for (provider, info) in &self.providers {
if info.domains.iter().any(|d| action_str.contains(d)) {
return provider.clone();
}
}
for (provider, info) in &self.providers {
if info.keys.iter().any(|k| k == key) {
return provider.clone();
}
}
self.providers[0].0.clone()
}
}
#[derive(Clone)]
pub enum ProviderType {
GitHub(Arc<CachedProvider<ReqwestGithubProvider>>),
Bitbucket(Arc<CachedProvider<ReqwestBitbucketProvider>>),
GitLab(Arc<CachedProvider<ReqwestGitLabProvider>>),
Forgejo(Arc<CachedProvider<ReqwestForgejoProvider>>),
CircleCi(Arc<CachedProvider<ReqwestCircleCiProvider>>),
}
#[derive(Clone)]
pub struct UnifiedProviderConfig {
pub github_url: String,
pub github_token: Option<String>,
pub bitbucket_url: String,
pub bitbucket_token: Option<String>,
pub gitlab_url: String,
pub gitlab_token: Option<String>,
pub forgejo_url: String,
pub forgejo_token: Option<String>,
pub circleci_url: String,
pub circleci_token: Option<String>,
pub disk_cache_path: Option<PathBuf>,
pub offline: bool,
}
impl Default for UnifiedProviderConfig {
fn default() -> Self {
Self {
github_url: "https://api.github.com".to_string(),
github_token: None,
bitbucket_url: "https://api.bitbucket.org/2.0".to_string(),
bitbucket_token: None,
gitlab_url: "https://gitlab.com/api/v4".to_string(),
gitlab_token: None,
forgejo_url: "https://codeberg.org/api/v1".to_string(),
forgejo_token: None,
circleci_url: "https://circleci.com/api/v2".to_string(),
circleci_token: None,
disk_cache_path: None,
offline: false,
}
}
}
#[derive(Clone)]
pub struct UnifiedProvider {
pub registry: ProviderRegistry,
}
impl UnifiedProvider {
pub fn new(config: UnifiedProviderConfig) -> Result<Self, PinnerError> {
Ok(Self {
registry: ProviderRegistry::new(config)?,
})
}
}
#[async_trait]
impl RemoteProvider for UnifiedProvider {
async fn get_commit_sha(
&self,
action: &DependencyName,
tag: &str,
key: &str,
) -> Result<DependencyRef, PinnerError> {
self.registry
.get_provider(key, action)
.get_commit_sha(action, tag, key)
.await
}
async fn get_latest_release(
&self,
action: &DependencyName,
key: &str,
) -> Result<String, PinnerError> {
self.registry
.get_provider(key, action)
.get_latest_release(action, key)
.await
}
async fn list_tags(
&self,
action: &DependencyName,
key: &str,
) -> Result<Vec<String>, PinnerError> {
self.registry
.get_provider(key, action)
.list_tags(action, key)
.await
}
async fn get_default_branch(
&self,
action: &DependencyName,
key: &str,
) -> Result<BranchName, PinnerError> {
self.registry
.get_provider(key, action)
.get_default_branch(action, key)
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
use reqwest::{Response, StatusCode};
#[test]
fn test_action_name_display_and_from() {
let name = DependencyName::from("actions/checkout");
assert_eq!(format!("{}", name), "actions/checkout");
assert_eq!(
DependencyName::from("a".to_string()),
DependencyName("a".into())
);
assert_eq!(DependencyName::from(""), DependencyName("".into()));
assert_eq!(
DependencyName::from("".to_string()),
DependencyName("".into())
);
}
#[test]
fn test_commit_sha_display_and_from() {
let sha = DependencyRef::from("a1b2c3d".to_string());
assert_eq!(format!("{}", sha), "a1b2c3d");
}
#[test]
fn test_branch_name_display_and_from() {
let branch = BranchName::from("main".to_string());
assert_eq!(format!("{}", branch), "main");
}
#[tokio::test]
async fn test_handle_api_error() {
let provider = ReqwestGithubProvider::new("https://api.github.com".into(), None).unwrap();
let action = DependencyName::from("o/r");
let resp = Response::from(
http::Response::builder()
.status(StatusCode::FORBIDDEN)
.body("")
.unwrap(),
);
let err = provider.base.handle_error(resp, &action);
assert!(matches!(err, PinnerError::RateLimit(_)));
assert!(format!("{}", err).contains("rate limit exceeded"));
assert!(format!("{}", err).contains("https://api.github.com"));
let resp = Response::from(
http::Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.body("")
.unwrap(),
);
let err = provider.base.handle_error(resp, &action);
assert!(matches!(err, PinnerError::RateLimit(_)));
assert!(format!("{}", err).contains("rate limit exceeded"));
assert!(format!("{}", err).contains("https://api.github.com"));
let resp = Response::from(
http::Response::builder()
.status(StatusCode::NOT_FOUND)
.body("")
.unwrap(),
);
let err = provider.base.handle_error(resp, &action);
assert!(matches!(err, PinnerError::Api(_)));
assert!(format!("{}", err).contains("HTTP 404"));
assert!(format!("{}", err).contains("o/r"));
assert!(format!("{}", err).contains("https://api.github.com"));
let resp = Response::from(
http::Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body("")
.unwrap(),
);
let err = provider.base.handle_error(resp, &action);
assert!(matches!(err, PinnerError::Api(_)));
assert!(format!("{}", err).contains("HTTP 500"));
assert!(format!("{}", err).contains("https://api.github.com"));
}
#[tokio::test]
async fn test_handle_rate_limit_headers() {
let provider = ReqwestGithubProvider::new("https://api.github.com".into(), None).unwrap();
let action = DependencyName::from("o/r");
let ts = 1718374400; let resp = Response::from(
http::Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.header("x-ratelimit-reset", ts.to_string())
.body("")
.unwrap(),
);
let err = provider.base.handle_error(resp, &action);
assert!(format!("{}", err).contains("Rate limit resets at 2024-06-14 14:13:20 UTC"));
let resp = Response::from(
http::Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.header("retry-after", "60")
.body("")
.unwrap(),
);
let err = provider.base.handle_error(resp, &action);
assert!(format!("{}", err).contains("Retry after 60 seconds"));
}
#[tokio::test]
async fn test_list_tags_success() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/repos/o/r/tags")
.with_status(200)
.with_body(r#"[{"name":"v1.0.0"},{"name":"v1.1.0"}]"#)
.create_async()
.await;
let provider = ReqwestGithubProvider::new(server.url(), None).unwrap();
let tags = provider
.list_tags(&DependencyName::from("o/r"), "uses")
.await
.unwrap();
assert_eq!(tags, vec!["v1.0.0".to_string(), "v1.1.0".to_string()]);
}
#[tokio::test]
async fn test_list_tags_error() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/repos/o/r/tags")
.with_status(500)
.create_async()
.await;
let provider = ReqwestGithubProvider::new(server.url(), None).unwrap();
let res = provider
.list_tags(&DependencyName::from("o/r"), "uses")
.await;
assert!(res.is_err());
}
#[tokio::test]
async fn test_get_latest_release_404_fallback() {
let mut server = mockito::Server::new_async().await;
let _m1 = server
.mock("GET", "/repos/o/r/releases/latest")
.with_status(404)
.create_async()
.await;
let _m2 = server
.mock("GET", "/repos/o/r")
.with_status(200)
.with_body(r#"{"default_branch":"main"}"#)
.create_async()
.await;
let provider = ReqwestGithubProvider::new(server.url(), None).unwrap();
let tag = provider
.get_latest_release(&DependencyName::from("o/r"), "uses")
.await
.unwrap();
assert_eq!(tag, "main");
}
#[tokio::test]
async fn test_get_default_branch_fail_fallback() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/repos/o/r")
.with_status(500)
.create_async()
.await;
let provider = ReqwestGithubProvider::new(server.url(), None).unwrap();
let branch = provider
.get_default_branch(&DependencyName::from("o/r"), "uses")
.await
.unwrap();
assert_eq!(branch.0, "main");
}
#[test]
fn test_dto_deserialization() {
let ref_json = r#"{"sha":"123"}"#;
let res: RefResponse = serde_json::from_str(ref_json).unwrap();
assert_eq!(res.sha, "123");
let rel_json = r#"{"tag_name":"v1"}"#;
let res: ReleaseResponse = serde_json::from_str(rel_json).unwrap();
assert_eq!(res.tag_name, "v1");
let repo_json = r#"{"default_branch":"develop"}"#;
let res: RepoResponse = serde_json::from_str(repo_json).unwrap();
assert_eq!(res.default_branch, "develop");
#[derive(serde::Deserialize)]
struct BitbucketDCRepoResponse {
#[serde(rename = "defaultBranch")]
default_branch: String,
}
let bb_json = r#"{"defaultBranch":"prod"}"#;
let res: BitbucketDCRepoResponse = serde_json::from_str(bb_json).unwrap();
assert_eq!(res.default_branch, "prod");
}
#[tokio::test]
async fn test_unified_provider_exhaustive() {
let mut server = mockito::Server::new_async().await;
let _m1 = server
.mock("GET", "/repos/o/r/releases/latest")
.with_status(200)
.with_body(r#"{"tag_name":"v1"}"#)
.create_async()
.await;
let _m2 = server
.mock("GET", "/api/v4/projects/o%2Fr/repository/tags")
.with_status(200)
.with_body(r#"[{"name":"v2"}]"#)
.create_async()
.await;
let unified = UnifiedProvider::new(UnifiedProviderConfig {
github_url: server.url(),
bitbucket_url: server.url(),
gitlab_url: server.url(),
forgejo_url: server.url(),
circleci_url: server.url(),
..Default::default()
})
.unwrap();
let rel = unified
.get_latest_release(&DependencyName::from("o/r"), "uses")
.await
.unwrap();
assert_eq!(rel, "v1");
let tags = unified
.list_tags(&DependencyName::from("o/r"), "include")
.await
.unwrap();
assert_eq!(tags, vec!["v2".to_string()]);
let branch = unified
.get_default_branch(&DependencyName::from("o/r"), "none")
.await
.unwrap();
assert_eq!(branch.0, "main");
}
#[tokio::test]
async fn test_unified_provider_error() {
let unified = UnifiedProvider::new(UnifiedProviderConfig {
github_url: "http://invalid".into(),
bitbucket_url: "http://invalid".into(),
gitlab_url: "http://invalid".into(),
forgejo_url: "http://invalid".into(),
circleci_url: "http://invalid".into(),
..Default::default()
})
.unwrap();
let res = unified
.get_commit_sha(&DependencyName::from("o/r"), "v1", "uses")
.await;
assert!(res.is_err());
}
#[tokio::test]
async fn test_provider_errors_exhaustive() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", mockito::Matcher::Any)
.with_status(404)
.create_async()
.await;
let gitlab = ReqwestGitLabProvider::new(server.url(), None).unwrap();
assert!(gitlab
.get_commit_sha(&DependencyName::from("o/r"), "v1", "")
.await
.is_err());
assert!(gitlab
.list_tags(&DependencyName::from("o/r"), "")
.await
.unwrap()
.is_empty());
assert_eq!(
gitlab
.get_default_branch(&DependencyName::from("o/r"), "")
.await
.unwrap()
.0,
"main"
);
let forgejo = ReqwestForgejoProvider::new(server.url(), None).unwrap();
assert!(forgejo
.get_commit_sha(&DependencyName::from("o/r"), "v1", "")
.await
.is_err());
assert!(forgejo
.list_tags(&DependencyName::from("o/r"), "")
.await
.unwrap()
.is_empty());
assert_eq!(
forgejo
.get_default_branch(&DependencyName::from("o/r"), "")
.await
.unwrap()
.0,
"main"
);
let bb_cloud = ReqwestBitbucketProvider::with_type(server.url(), None, true).unwrap();
assert!(bb_cloud
.get_commit_sha(&DependencyName::from("o/r"), "v1", "")
.await
.is_err());
assert_eq!(
bb_cloud
.get_default_branch(&DependencyName::from("o/r"), "")
.await
.unwrap()
.0,
"main"
);
let bb_dc = ReqwestBitbucketProvider::with_type(server.url(), None, false).unwrap();
assert!(bb_dc
.get_commit_sha(&DependencyName::from("o/r"), "v1", "")
.await
.is_err());
assert_eq!(
bb_dc
.get_default_branch(&DependencyName::from("o/r"), "")
.await
.unwrap()
.0,
"main"
);
}
#[tokio::test]
async fn test_gitlab_latest_release() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/api/v4/projects/o%2Fr/releases")
.with_status(200)
.with_body(r#"[{"tag_name":"v10"}]"#)
.create_async()
.await;
let provider = ReqwestGitLabProvider::new(server.url(), None).unwrap();
let rel = provider
.get_latest_release(&DependencyName::from("o/r"), "")
.await
.unwrap();
assert_eq!(rel, "v10");
}
#[tokio::test]
async fn test_forgejo_methods() {
let mut server = mockito::Server::new_async().await;
let _m1 = server
.mock("GET", "/api/v1/repos/o/r/releases")
.with_status(200)
.with_body(r#"[{"tag_name":"f1"}]"#)
.create_async()
.await;
let _m2 = server
.mock("GET", "/api/v1/repos/o/r/tags")
.with_status(200)
.with_body(r#"[{"name":"t1"}]"#)
.create_async()
.await;
let provider = ReqwestForgejoProvider::new(server.url(), None).unwrap();
assert_eq!(
provider
.get_latest_release(&DependencyName::from("o/r"), "")
.await
.unwrap(),
"f1"
);
assert_eq!(
provider
.list_tags(&DependencyName::from("o/r"), "")
.await
.unwrap(),
vec!["t1".to_string()]
);
}
#[tokio::test]
async fn test_gitlab_provider() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/api/v4/projects/o%2Fr/repository/commits/v1")
.with_status(200)
.with_body(r#"{"id":"gitlabsha"}"#)
.create_async()
.await;
let provider = ReqwestGitLabProvider::new(server.url(), None).unwrap();
let sha = provider
.get_commit_sha(&DependencyName::from("o/r"), "v1", "include")
.await
.unwrap();
assert_eq!(sha.to_string(), "gitlabsha");
}
#[tokio::test]
async fn test_gitlab_provider_error() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/api/v4/projects/o%2Fr/repository/commits/v1")
.with_status(404)
.create_async()
.await;
let provider = ReqwestGitLabProvider::new(server.url(), None).unwrap();
let res = provider
.get_commit_sha(&DependencyName::from("o/r"), "v1", "include")
.await;
assert!(res.is_err());
assert!(res.unwrap_err().to_string().contains("HTTP 404"));
}
#[tokio::test]
async fn test_forgejo_provider() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/api/v1/repos/o/r/commits/v1")
.with_status(200)
.with_body(r#"{"sha":"forgejosha"}"#)
.create_async()
.await;
let provider = ReqwestForgejoProvider::new(server.url(), None).unwrap();
let sha = provider
.get_commit_sha(&DependencyName::from("o/r"), "v1", "uses")
.await
.unwrap();
assert_eq!(sha.to_string(), "forgejosha");
}
#[tokio::test]
async fn test_forgejo_provider_error() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/api/v1/repos/o/r/commits/v1")
.with_status(404)
.create_async()
.await;
let provider = ReqwestForgejoProvider::new(server.url(), None).unwrap();
let res = provider
.get_commit_sha(&DependencyName::from("o/r"), "v1", "uses")
.await;
assert!(res.is_err());
assert!(res.unwrap_err().to_string().contains("HTTP 404"));
}
#[tokio::test]
async fn test_bitbucket_cloud_branch_fallback() {
let mut server = mockito::Server::new_async().await;
let _m1 = server
.mock("GET", "/repositories/o/p/refs/tags/v1")
.with_status(404)
.create_async()
.await;
let _m2 = server
.mock("GET", "/repositories/o/p/refs/branches/v1")
.with_status(200)
.with_body(r#"{"target":{"hash":"branchsha"}}"#)
.create_async()
.await;
let provider = ReqwestBitbucketProvider::with_type(server.url(), None, true).unwrap();
let sha = provider
.get_commit_sha(&DependencyName::from("o/p"), "v1", "pipe")
.await
.unwrap();
assert_eq!(sha.to_string(), "branchsha");
}
#[tokio::test]
async fn test_bitbucket_cloud_annotated_tag() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/repositories/o/p/refs/tags/v1")
.with_status(200)
.with_body(r#"{"target":{"hash":"tagsha","target":{"hash":"realsha"}}}"#)
.create_async()
.await;
let provider = ReqwestBitbucketProvider::with_type(server.url(), None, true).unwrap();
let sha = provider
.get_commit_sha(&DependencyName::from("o/p"), "v1", "pipe")
.await
.unwrap();
assert_eq!(sha.to_string(), "realsha");
}
#[tokio::test]
async fn test_bitbucket_dc_branch_fallback() {
let mut server = mockito::Server::new_async().await;
let _m1 = server
.mock("GET", "/rest/api/1.0/projects/PROJ/repos/repo/tags/v1")
.with_status(404)
.create_async()
.await;
let _m2 = server
.mock(
"GET",
"/rest/api/1.0/projects/PROJ/repos/repo/branches?filterText=v1",
)
.with_status(200)
.with_body(r#"{"values":[{"latestCommit":"branchsha"}]}"#)
.create_async()
.await;
let provider = ReqwestBitbucketProvider::with_type(server.url(), None, false).unwrap();
let sha = provider
.get_commit_sha(&DependencyName::from("PROJ/repo"), "v1", "pipe")
.await
.unwrap();
assert_eq!(sha.to_string(), "branchsha");
}
#[test]
#[serial_test::serial]
fn test_token_injection() {
std::env::set_var("GITHUB_TOKEN", "env_token");
let _client = BaseHttpClient::new(
"https://api.github.com".to_string(),
None,
"Bearer",
"GITHUB_TOKEN",
)
.unwrap();
std::env::remove_var("GITHUB_TOKEN");
let _client2 = BaseHttpClient::new(
"https://api.github.com".to_string(),
Some("manual_token".into()),
"Bearer",
"GITHUB_TOKEN",
)
.unwrap();
}
#[test]
fn test_provider_registry_routing() {
let config = UnifiedProviderConfig::default();
let registry = ProviderRegistry::new(config).unwrap();
let provider =
registry.get_provider("image", &DependencyName::from("gitlab.com/group/repo"));
assert!(format!("{:?}", Arc::as_ptr(&provider)).is_ascii());
let provider =
registry.get_provider("pipe", &DependencyName::from("sonarsource/sonarcloud-scan"));
assert!(format!("{:?}", Arc::as_ptr(&provider)).is_ascii());
let provider = registry.get_provider("orbs", &DependencyName::from("circleci/node"));
assert!(format!("{:?}", Arc::as_ptr(&provider)).is_ascii());
let provider = registry.get_provider("uses", &DependencyName::from("actions/checkout"));
assert!(format!("{:?}", Arc::as_ptr(&provider)).is_ascii());
}
#[tokio::test]
async fn test_provider_caching() {
let mut s = mockito::Server::new_async().await;
let _m = s
.mock("GET", "/repos/o/r/commits/v1")
.with_status(200)
.with_body(r#"{"sha":"123"}"#)
.expect(1) .create_async()
.await;
let provider = CachedProvider::new(
ReqwestGithubProvider::new(s.url(), None).unwrap(),
None,
false,
);
let action = DependencyName::from("o/r");
let sha1 = provider
.get_commit_sha(&action, "v1", "uses")
.await
.unwrap();
assert_eq!(sha1.to_string(), "123");
let sha2 = provider
.get_commit_sha(&action, "v1", "uses")
.await
.unwrap();
assert_eq!(sha2.to_string(), "123");
let _m2 = s
.mock("GET", "/repos/o/r/releases/latest")
.with_status(200)
.with_body(r#"{"tag_name":"v2"}"#)
.expect(1)
.create_async()
.await;
let r1 = provider.get_latest_release(&action, "uses").await.unwrap();
assert_eq!(r1, "v2");
let r2 = provider.get_latest_release(&action, "uses").await.unwrap();
assert_eq!(r2, "v2");
let _m3 = s
.mock("GET", "/repos/o/r")
.with_status(200)
.with_body(r#"{"default_branch":"main"}"#)
.expect(1)
.create_async()
.await;
let b1 = provider.get_default_branch(&action, "uses").await.unwrap();
assert_eq!(b1.0, "main");
let b2 = provider.get_default_branch(&action, "uses").await.unwrap();
assert_eq!(b2.0, "main");
}
#[tokio::test]
async fn test_bitbucket_cloud_provider() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/repositories/o/p/refs/tags/v1")
.with_status(200)
.with_body(r#"{"target":{"hash":"cloudsha"}}"#)
.create_async()
.await;
let provider = ReqwestBitbucketProvider::with_type(server.url(), None, true).unwrap();
let sha = provider
.get_commit_sha(&DependencyName::from("o/p"), "v1", "pipe")
.await
.unwrap();
assert_eq!(sha.to_string(), "cloudsha");
}
#[tokio::test]
async fn test_bitbucket_dc_provider() {
let mut server = mockito::Server::new_async().await;
let _m = server
.mock("GET", "/rest/api/1.0/projects/PROJ/repos/repo/tags/v1")
.with_status(200)
.with_body(r#"{"latestCommit":"dcsha"}"#)
.create_async()
.await;
let provider = ReqwestBitbucketProvider::with_type(server.url(), None, false).unwrap();
assert!(!provider.is_cloud);
let sha = provider
.get_commit_sha(&DependencyName::from("PROJ/repo"), "v1", "pipe")
.await
.unwrap();
assert_eq!(sha.to_string(), "dcsha");
}
#[tokio::test]
async fn test_bitbucket_dc_invalid_format() {
let provider =
ReqwestBitbucketProvider::with_type("http://bb.local".into(), None, false).unwrap();
let res = provider
.get_commit_sha(&DependencyName::from("invalid-format"), "v1", "pipe")
.await;
assert!(res.is_err());
assert!(res
.unwrap_err()
.to_string()
.contains("Invalid Bitbucket action format"));
}
#[tokio::test]
async fn test_unified_provider_routing() {
let mut server = mockito::Server::new_async().await;
let _m_github = server
.mock("GET", "/repos/o/r/commits/v1")
.with_status(200)
.with_body(r#"{"sha":"githubsha"}"#)
.create_async()
.await;
let _m_bitbucket = server
.mock("GET", "/rest/api/1.0/projects/o/repos/p/tags/v1")
.with_status(200)
.with_body(r#"{"latestCommit":"bitbucketsha"}"#)
.create_async()
.await;
let unified = UnifiedProvider::new(UnifiedProviderConfig {
github_url: server.url(),
bitbucket_url: server.url(),
gitlab_url: server.url(),
forgejo_url: server.url(),
circleci_url: server.url(),
..Default::default()
})
.unwrap();
let sha1 = unified
.get_commit_sha(&DependencyName::from("o/r"), "v1", "uses")
.await
.unwrap();
assert_eq!(sha1.to_string(), "githubsha");
let sha2 = unified
.get_commit_sha(&DependencyName::from("o/p"), "v1", "pipe")
.await
.unwrap();
assert_eq!(sha2.to_string(), "bitbucketsha");
}
#[test]
fn test_handle_error_rate_limit() {
let base = BaseHttpClient::new(
"https://api.github.com".to_string(),
None,
"Bearer",
"GITHUB_TOKEN",
)
.unwrap();
let action = DependencyName::from("actions/checkout");
let mut builder = http::Response::builder().status(403);
builder = builder.header("x-ratelimit-reset", "1718352000"); let resp = builder.body("").unwrap();
let reqwest_resp = reqwest::Response::from(resp);
let err = base.handle_error(reqwest_resp, &action);
assert!(matches!(err, PinnerError::RateLimit(_)));
assert!(err
.to_string()
.contains("Rate limit resets at 2024-06-14 08:00:00 UTC"));
let mut builder = http::Response::builder().status(429);
builder = builder.header("retry-after", "60");
let resp = builder.body("").unwrap();
let reqwest_resp = reqwest::Response::from(resp);
let err = base.handle_error(reqwest_resp, &action);
assert!(matches!(err, PinnerError::RateLimit(_)));
assert!(err.to_string().contains("Retry after 60 seconds"));
let resp = http::Response::builder().status(500).body("").unwrap();
let reqwest_resp = reqwest::Response::from(resp);
let err = base.handle_error(reqwest_resp, &action);
assert!(matches!(err, PinnerError::Api(_)));
assert!(err.to_string().contains("HTTP 500"));
}
#[tokio::test]
async fn test_unified_provider_fallback() {
let unified = UnifiedProvider::new(UnifiedProviderConfig::default()).unwrap();
let action = DependencyName::from("o/r");
let key = "unknown";
let res_commit = unified.get_commit_sha(&action, "v1", key).await;
assert!(res_commit.is_err());
let res_branch = unified.get_default_branch(&action, key).await;
assert_eq!(res_branch.unwrap().0, "main");
}
#[tokio::test]
async fn test_cached_provider_offline_mode() {
let remote = MockRemoteProvider::new();
let provider = CachedProvider::new(remote, None, true);
let action = DependencyName::from("o/r");
let res = provider.get_commit_sha(&action, "v1", "uses").await;
assert!(res.is_err());
assert!(matches!(res.unwrap_err(), PinnerError::Offline(_)));
let res = provider.get_latest_release(&action, "uses").await;
assert!(res.is_err());
assert!(matches!(res.unwrap_err(), PinnerError::Offline(_)));
let res = provider.list_tags(&action, "uses").await;
assert!(res.is_err());
assert!(matches!(res.unwrap_err(), PinnerError::Offline(_)));
let res = provider.get_default_branch(&action, "uses").await;
assert!(res.is_err());
assert!(matches!(res.unwrap_err(), PinnerError::Offline(_)));
}
}