pub mod aws;
pub mod azure;
pub mod github;
use async_trait::async_trait;
pub use aws::AwsWebIdentityProvider;
pub use azure::AzureWorkloadIdentityProvider;
pub use github::GitHubActionsProvider;
/// Errors that can occur while detecting an OIDC provider or obtaining a token from it.
#[derive(Debug, thiserror::Error)]
pub enum OidcError {
    /// Returned by [`auto_detect_provider`] when no supported environment was recognized.
    #[error("No OIDC provider detected in environment")]
    NoProviderFound,

    /// A `reqwest` request or response failed; produced automatically by `?` on a
    /// [`reqwest::Error`].
    #[error("Failed to fetch OIDC token: {0}")]
    FetchFailed(#[from] reqwest::Error),

    /// A token response was received but was not considered valid.
    #[error("Invalid OIDC token response")]
    InvalidResponse,

    /// A token file was expected at `path` but was not found.
    #[error("Token file not found: {path}")]
    FileNotFound { path: String },

    /// The environment variable named by `var` was missing.
    #[error("Environment variable missing: {var}")]
    MissingEnvVar { var: String },

    /// Any other I/O failure; produced automatically by `?` on a [`std::io::Error`].
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
/// A source of OIDC identity tokens.
///
/// The `Send + Sync` supertraits allow a boxed provider (as returned by
/// [`auto_detect_provider`]) to be shared across threads and async tasks.
#[async_trait]
pub trait OidcTokenProvider: Send + Sync {
    /// Obtains a token and returns it as a raw string.
    ///
    /// # Errors
    ///
    /// Returns an [`OidcError`] when the token cannot be obtained.
    async fn get_token(&self) -> Result<String, OidcError>;
}
/// Detects which OIDC provider the current process runs under and constructs it.
///
/// Providers are probed in a fixed order; the first one whose variables are all
/// present wins:
///
/// 1. GitHub Actions: `ACTIONS_ID_TOKEN_REQUEST_TOKEN` and `ACTIONS_ID_TOKEN_REQUEST_URL`
/// 2. AWS web identity: `AWS_WEB_IDENTITY_TOKEN_FILE`
/// 3. Azure workload identity: `IDENTITY_ENDPOINT` and `IDENTITY_HEADER`
///
/// A variable counts as present only if it is set to a non-empty, valid-Unicode
/// value (see [`env_is_set`]).
///
/// # Errors
///
/// - [`OidcError::NoProviderFound`] if no provider's variables are present.
/// - Any error from the matched provider's constructor, propagated via `?`.
///   Detection does not fall through to the next provider when construction fails.
pub fn auto_detect_provider() -> Result<Box<dyn OidcTokenProvider>, OidcError> {
    if env_is_set("ACTIONS_ID_TOKEN_REQUEST_TOKEN") && env_is_set("ACTIONS_ID_TOKEN_REQUEST_URL") {
        return Ok(Box::new(GitHubActionsProvider::new()?));
    }
    if env_is_set("AWS_WEB_IDENTITY_TOKEN_FILE") {
        return Ok(Box::new(AwsWebIdentityProvider::from_env()?));
    }
    if env_is_set("IDENTITY_ENDPOINT") && env_is_set("IDENTITY_HEADER") {
        return Ok(Box::new(AzureWorkloadIdentityProvider::new()?));
    }
    Err(OidcError::NoProviderFound)
}

/// Returns `true` when the environment variable `name` holds a non-empty,
/// valid-Unicode value.
///
/// An empty value is treated as unset. Exporting `VAR=` is a common way to blank
/// out a variable, and such a value cannot configure a working provider. Counting
/// it as present would lock detection onto a provider that can only fail, instead
/// of trying the next one.
fn env_is_set(name: &str) -> bool {
    std::env::var(name).is_ok_and(|value| !value.is_empty())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::ffi::OsString;

    // SAFETY (applies to every `set_var`/`remove_var` below): these calls are
    // `unsafe` because another thread reading the environment at the same time is
    // undefined behavior on some platforms. Every test that mutates the
    // environment does so through `EnvGuard` and is `#[serial]`, so these tests
    // never race each other. This assumes no non-serial test in the crate reads the
    // environment concurrently; `test_oidc_error_display` does not.

    /// Every environment variable consulted by [`auto_detect_provider`].
    const DETECTION_VARS: [&str; 5] = [
        "ACTIONS_ID_TOKEN_REQUEST_TOKEN",
        "ACTIONS_ID_TOKEN_REQUEST_URL",
        "AWS_WEB_IDENTITY_TOKEN_FILE",
        "IDENTITY_ENDPOINT",
        "IDENTITY_HEADER",
    ];

    /// Snapshots and clears all detection variables, then restores the snapshot
    /// on drop.
    ///
    /// Because restoration happens in `Drop`, it also runs when an assertion
    /// panics. A failing test therefore cannot leak variables into the next
    /// `#[serial]` test, and the real process environment is left untouched.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl EnvGuard {
        /// Starts from an environment in which no provider is detectable.
        fn clean() -> Self {
            let saved: Vec<_> = DETECTION_VARS
                .iter()
                .map(|&name| (name, std::env::var_os(name)))
                .collect();
            for name in DETECTION_VARS {
                // SAFETY: see module-level note.
                unsafe { std::env::remove_var(name) };
            }
            EnvGuard { saved }
        }

        /// Sets a detection variable for the lifetime of this guard.
        fn set(&self, name: &'static str, value: &str) {
            // Only snapshotted variables are restored on drop.
            debug_assert!(
                DETECTION_VARS.contains(&name),
                "{name} is not restored by EnvGuard"
            );
            // SAFETY: see module-level note.
            unsafe { std::env::set_var(name, value) };
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, previous) in &self.saved {
                // SAFETY: see module-level note.
                unsafe {
                    match previous {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    /// Panics with the underlying error if auto-detection did not yield a provider.
    fn assert_provider_detected(expected: &str) {
        if let Err(err) = auto_detect_provider() {
            panic!("expected the {expected} provider to be detected, got error: {err}");
        }
    }

    #[test]
    fn test_oidc_error_display() {
        let err = OidcError::NoProviderFound;
        assert_eq!(err.to_string(), "No OIDC provider detected in environment");
        let err = OidcError::InvalidResponse;
        assert_eq!(err.to_string(), "Invalid OIDC token response");
        let err = OidcError::FileNotFound {
            path: "/tmp/token".into(),
        };
        assert_eq!(err.to_string(), "Token file not found: /tmp/token");
        let err = OidcError::MissingEnvVar {
            var: "MY_VAR".into(),
        };
        assert_eq!(err.to_string(), "Environment variable missing: MY_VAR");
    }

    #[test]
    #[serial]
    fn test_auto_detect_github() {
        let env = EnvGuard::clean();
        env.set("ACTIONS_ID_TOKEN_REQUEST_TOKEN", "token");
        env.set("ACTIONS_ID_TOKEN_REQUEST_URL", "http://localhost");
        assert_provider_detected("GitHub Actions");
    }

    #[test]
    #[serial]
    fn test_auto_detect_aws() {
        let env = EnvGuard::clean();
        env.set("AWS_WEB_IDENTITY_TOKEN_FILE", "/tmp/token");
        assert_provider_detected("AWS web identity");
    }

    #[test]
    #[serial]
    fn test_auto_detect_azure() {
        let env = EnvGuard::clean();
        env.set("IDENTITY_ENDPOINT", "http://localhost");
        env.set("IDENTITY_HEADER", "value");
        assert_provider_detected("Azure workload identity");
    }

    #[test]
    #[serial]
    fn test_auto_detect_none() {
        // Bound to a named variable (not `_`) so the guard lives until the end of the test.
        let _env = EnvGuard::clean();
        match auto_detect_provider() {
            Err(OidcError::NoProviderFound) => {}
            Err(other) => panic!("expected NoProviderFound, got: {other}"),
            Ok(_) => panic!("expected NoProviderFound, but a provider was detected"),
        }
    }
}