#[path = "auth_next/mod.rs"]
mod auth_next;
use crate::error::RegistryResult;
/// Source of the bearer token used to authenticate against the registry.
///
/// `Debug` is implemented by hand rather than derived so that secret material
/// never ends up in logs or panic messages: a static token is rendered as
/// `<redacted>`.
#[derive(Clone)]
pub enum TokenProvider {
    /// A fixed token supplied up front.
    Static(String),
    /// No authentication; `get_token` yields `Ok(None)`.
    None,
    /// Tokens obtained through an OIDC exchange (requires the `oidc` feature).
    #[cfg(feature = "oidc")]
    Oidc(OidcProvider),
}

impl std::fmt::Debug for TokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Show only that a static token is configured, never its value.
            TokenProvider::Static(_) => f.debug_tuple("Static").field(&"<redacted>").finish(),
            // Matches the previously derived output for the unit variant.
            TokenProvider::None => f.write_str("None"),
            // NOTE(review): this delegates to `OidcProvider`'s `Debug`, which must
            // not expose its request token or cached access token.
            #[cfg(feature = "oidc")]
            TokenProvider::Oidc(provider) => f.debug_tuple("Oidc").field(provider).finish(),
        }
    }
}
/// Constructors and accessors for [`TokenProvider`].
///
/// Every method here is a thin forwarder into `auth_next::providers`; the
/// actual logic lives in that submodule.
impl TokenProvider {
/// Creates a provider that always yields `token`.
///
/// Per the tests in this file, the result reports `is_authenticated() == true`
/// and `get_token()` returns `Some(token)`.
pub fn static_token(token: impl Into<String>) -> Self {
auth_next::providers::static_token(token)
}
/// Builds a provider from process environment variables.
///
/// Per the tests in this file, a non-empty `ASSAY_REGISTRY_TOKEN` yields
/// `TokenProvider::Static` and an empty one yields `TokenProvider::None`.
/// NOTE(review): the tests also clear `ASSAY_REGISTRY_OIDC`, which presumably
/// selects the OIDC provider — confirm in `auth_next::providers::from_env`.
pub fn from_env() -> Self {
auth_next::providers::from_env()
}
/// Returns the token to send to the registry, or `Ok(None)` when there is
/// nothing to send (the `None` variant, per the tests).
///
/// # Errors
///
/// Propagates any error produced by the underlying provider; presumably this
/// includes OIDC exchange failures for the `Oidc` variant — confirm in
/// `auth_next::providers::get_token`.
pub async fn get_token(&self) -> RegistryResult<Option<String>> {
auth_next::providers::get_token(self).await
}
/// Reports whether this provider is configured to authenticate.
///
/// Per the tests: `true` for `Static`, `false` for `None`. The answer for
/// `Oidc` is decided in `auth_next::providers::is_authenticated`.
pub fn is_authenticated(&self) -> bool {
auth_next::providers::is_authenticated(self)
}
/// Builds an OIDC-backed provider for GitHub Actions.
///
/// NOTE(review): presumably wraps `OidcProvider::from_github_actions` in the
/// `Oidc` variant — confirm in `auth_next::providers::github_oidc`.
///
/// # Errors
///
/// Returns whatever error `auth_next::providers::github_oidc` reports when the
/// provider cannot be set up.
#[cfg(feature = "oidc")]
pub fn github_oidc() -> RegistryResult<Self> {
auth_next::providers::github_oidc()
}
}
impl Default for TokenProvider {
fn default() -> Self {
Self::from_env()
}
}
/// OIDC-based token provider.
///
/// It requests an identity token from `token_request_url` and exchanges it at
/// `registry_exchange_url` for a registry access token. The access token is
/// cached in `cached_token`.
///
/// `Debug` is implemented by hand rather than derived so that the request
/// token and any cached access token are never printed.
#[cfg(feature = "oidc")]
#[derive(Clone)]
pub struct OidcProvider {
    /// Endpoint queried for the identity token (the tests expect a GET carrying
    /// an `audience` query parameter).
    token_request_url: String,
    /// Bearer credential sent when requesting the identity token. Secret.
    request_token: String,
    /// Registry endpoint that receives the identity token (POST in the tests)
    /// and returns an access token.
    registry_exchange_url: String,
    /// Audience requested for the identity token.
    audience: String,
    /// Most recently obtained access token. Clones share it through the `Arc`.
    cached_token: std::sync::Arc<tokio::sync::RwLock<Option<CachedToken>>>,
}

#[cfg(feature = "oidc")]
impl std::fmt::Debug for OidcProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OidcProvider")
            .field("token_request_url", &self.token_request_url)
            .field("request_token", &"<redacted>")
            .field("registry_exchange_url", &self.registry_exchange_url)
            .field("audience", &self.audience)
            // `cached_token` holds a live access token; leave it out entirely.
            .finish_non_exhaustive()
    }
}
/// Access token returned by the registry exchange, together with its expiry.
///
/// `Debug` is implemented by hand so the token value is never printed.
#[cfg(feature = "oidc")]
#[derive(Clone)]
struct CachedToken {
    /// The bearer token itself. Secret.
    token: String,
    /// Instant (UTC) at which the token expires.
    expires_at: chrono::DateTime<chrono::Utc>,
}

#[cfg(feature = "oidc")]
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
/// Operations on [`OidcProvider`].
///
/// Each method forwards to `auth_next::oidc` (token acquisition) or
/// `auth_next::cache` (cache handling).
#[cfg(feature = "oidc")]
impl OidcProvider {
/// Builds a provider configured for GitHub Actions.
///
/// NOTE(review): presumably reads its settings from the GitHub Actions runner
/// environment — confirm in `auth_next::oidc::from_github_actions`.
///
/// # Errors
///
/// Returns whatever error `auth_next::oidc::from_github_actions` reports when
/// the configuration is unavailable.
pub fn from_github_actions() -> RegistryResult<Self> {
auth_next::oidc::from_github_actions()
}
/// Creates a provider from explicit endpoints and credentials.
///
/// As exercised by the tests in this file:
/// - `token_request_url` receives a GET with an `audience` query parameter and
///   `Authorization: Bearer <request_token>`.
/// - `registry_exchange_url` receives a POST whose JSON body carries the
///   identity token and answers with `access_token` / `expires_in`.
pub fn new(
token_request_url: impl Into<String>,
request_token: impl Into<String>,
registry_exchange_url: impl Into<String>,
audience: impl Into<String>,
) -> Self {
auth_next::oidc::new(
token_request_url,
request_token,
registry_exchange_url,
audience,
)
}
/// Returns a registry access token, reusing the cached one when it is still
/// valid.
///
/// Per the tests, a cached token that expires well in the future is returned
/// without any network call, and an expired one triggers a new exchange.
///
/// # Errors
///
/// Propagates failures from the identity-token request or the exchange. For
/// example, the tests expect `RegistryError::Unauthorized` on a 401 from the
/// identity endpoint.
pub async fn get_token(&self) -> RegistryResult<Option<String>> {
auth_next::cache::get_token(self).await
}
/// Performs the exchange with retries.
///
/// The retry test in this file expects 4 attempts against an endpoint that
/// always returns 500, at least 2s of total backoff, and a final
/// `RegistryError::Network`.
// NOTE(review): this private wrapper is not called in this file; presumably it
// is used from the `auth_next` child module — confirm.
async fn exchange_token_with_retry(&self) -> RegistryResult<String> {
auth_next::oidc::exchange_token_with_retry(self).await
}
/// Performs a single identity-token fetch followed by an exchange (no
/// retries); implemented in `auth_next::oidc`.
async fn exchange_token(&self) -> RegistryResult<String> {
auth_next::oidc::exchange_token(self).await
}
/// Requests an identity token from `token_request_url`.
async fn get_github_oidc_token(&self) -> RegistryResult<String> {
auth_next::oidc::get_github_oidc_token(self).await
}
/// Exchanges `oidc_token` at `registry_exchange_url` for a registry access
/// token.
async fn exchange_for_registry_token(&self, oidc_token: &str) -> RegistryResult<String> {
auth_next::oidc::exchange_for_registry_token(self, oidc_token).await
}
/// Drops any cached access token; the tests observe `cached_token` as `None`
/// afterwards.
pub async fn clear_cache(&self) {
auth_next::cache::clear_cache(self).await
}
}
/// Unit tests for the non-OIDC `TokenProvider` variants and for `from_env`.
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Clears every registry-auth variable these tests touch, so each case
    /// starts from a known environment.
    fn reset_registry_env() {
        std::env::remove_var("ASSAY_REGISTRY_TOKEN");
        std::env::remove_var("ASSAY_REGISTRY_OIDC");
    }

    #[test]
    fn test_static_token() {
        assert!(TokenProvider::static_token("test-token").is_authenticated());
    }

    #[test]
    fn test_no_auth() {
        assert!(!TokenProvider::None.is_authenticated());
    }

    #[tokio::test]
    async fn test_get_static_token() {
        let expected = Some(String::from("my-token"));
        let actual = TokenProvider::static_token("my-token")
            .get_token()
            .await
            .unwrap();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn test_get_no_token() {
        let actual = TokenProvider::None.get_token().await.unwrap();
        assert!(actual.is_none());
    }

    // Environment mutation is process-global, so these run serially.
    #[test]
    #[serial]
    fn test_from_env_static() {
        reset_registry_env();
        std::env::set_var("ASSAY_REGISTRY_TOKEN", "env-token");
        let built = TokenProvider::from_env();
        std::env::remove_var("ASSAY_REGISTRY_TOKEN");
        assert!(matches!(built, TokenProvider::Static(_)));
    }

    #[test]
    #[serial]
    fn test_from_env_empty_token() {
        reset_registry_env();
        std::env::set_var("ASSAY_REGISTRY_TOKEN", "");
        let built = TokenProvider::from_env();
        std::env::remove_var("ASSAY_REGISTRY_TOKEN");
        assert!(matches!(built, TokenProvider::None));
    }
}
/// Integration-style tests for the OIDC provider, using `wiremock` servers in
/// place of the GitHub identity endpoint and the registry exchange endpoint.
#[cfg(all(test, feature = "oidc"))]
mod oidc_tests {
    use super::*;
    use wiremock::matchers::{body_json, header, method, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Happy path: GitHub issues an identity token, the registry exchanges it,
    /// and a second call returns the same access token.
    #[tokio::test]
    async fn test_oidc_full_flow() {
        let github_mock = MockServer::start().await;
        let registry_mock = MockServer::start().await;

        Mock::given(method("GET"))
            .and(query_param("audience", "https://registry.test"))
            .and(header("authorization", "Bearer gh-request-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "value": "github-oidc-jwt-token"
            })))
            .mount(&github_mock)
            .await;

        Mock::given(method("POST"))
            .and(body_json(serde_json::json!({
                "token": "github-oidc-jwt-token",
                "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
                "subject_token_type": "urn:ietf:params:oauth:token-type:jwt"
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "registry-access-token",
                "expires_in": 3600,
                "token_type": "Bearer"
            })))
            .mount(&registry_mock)
            .await;

        let provider = OidcProvider::new(
            format!("{}?foo=bar", github_mock.uri()),
            "gh-request-token",
            format!("{}/auth/oidc/exchange", registry_mock.uri()),
            "https://registry.test",
        );

        let token = provider.get_token().await.unwrap();
        assert_eq!(token, Some("registry-access-token".to_string()));

        let token2 = provider.get_token().await.unwrap();
        assert_eq!(token2, Some("registry-access-token".to_string()));
    }

    /// A 401 from the identity endpoint surfaces as `RegistryError::Unauthorized`.
    #[tokio::test]
    async fn test_oidc_github_failure() {
        let github_mock = MockServer::start().await;

        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(401).set_body_string("Unauthorized"))
            .mount(&github_mock)
            .await;

        let provider = OidcProvider::new(
            format!("{}?foo=bar", github_mock.uri()),
            "bad-token",
            "https://registry.test/auth/oidc/exchange",
            "https://registry.test",
        );

        let result = provider.get_token().await;
        assert!(matches!(
            result,
            Err(crate::error::RegistryError::Unauthorized { .. })
        ));
    }

    /// A fresh cached token is served without network access, and
    /// `clear_cache` empties the cache.
    #[tokio::test]
    async fn test_oidc_cache_clear() {
        let provider = OidcProvider::new(
            "https://github.example/token?foo=bar",
            "token",
            "https://registry.test/exchange",
            "https://registry.test",
        );

        {
            let mut cache = provider.cached_token.write().await;
            *cache = Some(CachedToken {
                token: "cached-token".to_string(),
                expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
            });
        }

        let token = provider.get_token().await.unwrap();
        assert_eq!(token, Some("cached-token".to_string()));

        provider.clear_cache().await;
        let cache = provider.cached_token.read().await;
        assert!(cache.is_none());
    }

    /// An expired cached token forces a second exchange (hence `.expect(2)`).
    #[tokio::test]
    async fn test_token_expiry_triggers_refresh() {
        let github_mock = MockServer::start().await;
        let registry_mock = MockServer::start().await;

        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "value": "github-oidc-jwt-token"
            })))
            .mount(&github_mock)
            .await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "registry-access-token",
                "expires_in": 60,
                "token_type": "Bearer"
            })))
            .expect(2)
            .mount(&registry_mock)
            .await;

        let provider = OidcProvider::new(
            format!("{}?foo=bar", github_mock.uri()),
            "gh-request-token",
            format!("{}/auth/oidc/exchange", registry_mock.uri()),
            "https://registry.test",
        );

        let _ = provider.get_token().await.unwrap();

        {
            let mut cache = provider.cached_token.write().await;
            *cache = Some(CachedToken {
                token: "old-token".to_string(),
                expires_at: chrono::Utc::now() - chrono::Duration::seconds(1),
            });
        }

        let token = provider.get_token().await.unwrap();
        assert_eq!(token, Some("registry-access-token".to_string()));
    }

    // NOTE(review): this test re-derives the refresh condition locally instead
    // of exercising `get_token`. It assumes the 90s buffer mirrors the one in
    // `auth_next::cache` — confirm, or drive the real code path.
    #[tokio::test]
    async fn test_token_cache_buffer() {
        let provider = OidcProvider::new(
            "https://github.example/token?foo=bar",
            "token",
            "https://registry.test/exchange",
            "https://registry.test",
        );

        {
            let mut cache = provider.cached_token.write().await;
            *cache = Some(CachedToken {
                token: "almost-expired".to_string(),
                expires_at: chrono::Utc::now() + chrono::Duration::seconds(80),
            });
        }

        let cache = provider.cached_token.read().await;
        let cached = cache.as_ref().unwrap();
        let buffer = chrono::Duration::seconds(90);
        let should_refresh = cached.expires_at <= chrono::Utc::now() + buffer;
        assert!(
            should_refresh,
            "Token expiring in 80s should trigger refresh (90s buffer)"
        );
    }

    // Only checks that the variant name is shown. Whether the secret itself is
    // hidden depends on `TokenProvider`'s `Debug` implementation.
    #[tokio::test]
    async fn test_token_not_in_debug_output() {
        let provider = TokenProvider::static_token("secret-token-12345");
        let debug_output = format!("{:?}", provider);
        assert!(
            debug_output.contains("Static"),
            "Should show token type in debug"
        );
    }

    /// A registry that always returns 500 is retried (4 attempts expected) with
    /// backoff before failing with a network error.
    #[tokio::test]
    async fn test_oidc_retry_backoff_on_failure() {
        let github_mock = MockServer::start().await;
        let registry_mock = MockServer::start().await;

        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "value": "github-oidc-jwt-token"
            })))
            .mount(&github_mock)
            .await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(500).set_body_string("Server Error"))
            .expect(4)
            .mount(&registry_mock)
            .await;

        let provider = OidcProvider::new(
            format!("{}?foo=bar", github_mock.uri()),
            "gh-request-token",
            format!("{}/auth/oidc/exchange", registry_mock.uri()),
            "https://registry.test",
        );

        let start = std::time::Instant::now();
        let result = provider.get_token().await;
        let elapsed = start.elapsed();

        assert!(
            matches!(result, Err(crate::error::RegistryError::Network { .. })),
            "Should fail with network error after retries: {:?}",
            result
        );
        assert!(
            elapsed.as_secs() >= 2,
            "Should have exponential backoff, elapsed: {:?}",
            elapsed
        );
    }
}