use crate::error::WSError;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use serde::{Deserialize, Serialize};
use std::env;
use zeroize::Zeroize;
/// An OIDC identity token together with the identity and issuer claims extracted from it.
///
/// The raw `token` is a bearer credential. The `Debug` implementation below redacts it so
/// it cannot leak through `{:?}` logging or panic messages. `Serialize` still emits it,
/// because serialization is an explicit request for the token itself.
#[derive(Clone, Serialize, Deserialize)]
pub struct OidcToken {
    /// The raw JWT in compact (`header.payload.signature`) form. Treat as a secret.
    pub token: String,
    /// Identity claim taken from the token (the GitHub provider prefers `email`, then `sub`,
    /// then `actor`).
    pub identity: String,
    /// Issuer (`iss`) claim of the token.
    pub issuer: String,
}
impl std::fmt::Debug for OidcToken {
    /// Formats the token with its secret `token` field redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OidcToken")
            .field("token", &"<redacted>")
            .field("identity", &self.identity)
            .field("issuer", &self.issuer)
            .finish()
    }
}
impl Drop for OidcToken {
    /// Wipes the contents of every field with `zeroize` before the strings are freed,
    /// so the JWT and the claims copied out of it are not left behind in released memory.
    fn drop(&mut self) {
        for secret in [&mut self.token, &mut self.identity, &mut self.issuer] {
            secret.zeroize();
        }
    }
}
impl OidcToken {
    /// Extracts the `sub` (subject) claim from the JWT stored in `self.token`.
    ///
    /// The JWT signature is **not** verified. Only the payload segment is decoded and parsed.
    ///
    /// # Errors
    ///
    /// Returns [`WSError::OidcError`] if:
    /// - the token does not have exactly three dot-separated segments;
    /// - the payload is not valid base64 (URL-safe, with or without padding, or standard);
    /// - the payload is not UTF-8 or not JSON;
    /// - the payload has no string `sub` claim.
    pub fn get_sub_claim(&self) -> Result<String, WSError> {
        // Check for exactly three segments without collecting them into a Vec.
        let mut segments = self.token.split('.');
        let payload = match (segments.next(), segments.next(), segments.next(), segments.next()) {
            (Some(_), Some(payload), Some(_), None) => payload,
            _ => return Err(WSError::OidcError("Invalid JWT token format".to_string())),
        };
        // JWTs use unpadded base64url, but some producers add `=` padding, which the
        // NO_PAD engine rejects. Strip it first, and fall back to the standard alphabet
        // for non-conforming tokens.
        let decoded = URL_SAFE_NO_PAD
            .decode(payload.trim_end_matches('='))
            .or_else(|_| base64::prelude::BASE64_STANDARD.decode(payload))
            .map_err(|e| WSError::OidcError(format!("Failed to decode JWT payload: {}", e)))?;
        let payload_str = std::str::from_utf8(&decoded)
            .map_err(|e| WSError::OidcError(format!("Invalid UTF-8 in JWT payload: {}", e)))?;
        let payload_json: serde_json::Value = serde_json::from_str(payload_str)
            .map_err(|e| WSError::OidcError(format!("Failed to parse JWT payload: {}", e)))?;
        payload_json
            .get("sub")
            .and_then(|v| v.as_str())
            .map(str::to_string)
            .ok_or_else(|| WSError::OidcError("No 'sub' claim found in JWT token".to_string()))
    }
}
/// A source of OIDC identity tokens, such as a CI environment.
///
/// The `Send + Sync` bound lets a boxed provider be moved to or shared between threads.
pub trait OidcProvider: Send + Sync {
    /// Human-readable provider name, e.g. `"GitHub Actions"`.
    fn name(&self) -> &str;

    /// Obtains an identity token from this provider.
    ///
    /// # Errors
    ///
    /// Returns a [`WSError`] when the token cannot be obtained or parsed.
    fn get_token(&self) -> Result<OidcToken, WSError>;
}
#[derive(Debug, Clone)]
pub struct GitHubOidcProvider {
request_token: String,
request_url: String,
}
impl GitHubOidcProvider {
    /// Creates a provider from the GitHub Actions environment. Equivalent to [`Self::from_env`].
    ///
    /// # Errors
    ///
    /// See [`Self::from_env`].
    pub fn new() -> Result<Self, WSError> {
        Self::from_env()
    }

    /// Builds a provider from `ACTIONS_ID_TOKEN_REQUEST_TOKEN` and `ACTIONS_ID_TOKEN_REQUEST_URL`.
    ///
    /// # Errors
    ///
    /// Returns [`WSError::OidcError`] naming the first variable that is unset or not valid Unicode.
    pub fn from_env() -> Result<Self, WSError> {
        let request_token = Self::required_env("ACTIONS_ID_TOKEN_REQUEST_TOKEN")?;
        let request_url = Self::required_env("ACTIONS_ID_TOKEN_REQUEST_URL")?;
        Ok(Self {
            request_token,
            request_url,
        })
    }

    /// Reads a mandatory environment variable and maps any failure to an `OidcError`.
    fn required_env(name: &str) -> Result<String, WSError> {
        env::var(name).map_err(|_| {
            WSError::OidcError(format!("{} environment variable not found", name))
        })
    }

    /// Decodes the payload (second) segment of a compact JWT into a JSON value.
    ///
    /// The JWT signature is not verified. Only the claims are read.
    ///
    /// # Errors
    ///
    /// Returns [`WSError::OidcError`] if the token does not have exactly three segments, or
    /// if the payload is not valid base64, UTF-8, or JSON.
    fn decode_jwt_payload(token: &str) -> Result<serde_json::Value, WSError> {
        let mut segments = token.split('.');
        let payload = match (segments.next(), segments.next(), segments.next(), segments.next()) {
            (Some(_), Some(payload), Some(_), None) => payload,
            _ => return Err(WSError::OidcError("Invalid JWT token format".to_string())),
        };
        // JWTs use unpadded base64url, but tolerate `=` padding (rejected by the NO_PAD
        // engine) and fall back to the standard alphabet for non-conforming tokens.
        let decoded = URL_SAFE_NO_PAD
            .decode(payload.trim_end_matches('='))
            .or_else(|_| base64::prelude::BASE64_STANDARD.decode(payload))
            .map_err(|e| WSError::OidcError(format!("Failed to decode JWT payload: {}", e)))?;
        let payload_str = std::str::from_utf8(&decoded)
            .map_err(|e| WSError::OidcError(format!("Invalid UTF-8 in JWT payload: {}", e)))?;
        serde_json::from_str(payload_str)
            .map_err(|e| WSError::OidcError(format!("Failed to parse JWT payload: {}", e)))
    }

    /// Extracts the signer identity from a JWT: the first string-valued claim among
    /// `email`, `sub`, and `actor`, in that order.
    ///
    /// # Errors
    ///
    /// Returns [`WSError::OidcError`] if the token cannot be decoded or none of those
    /// claims is present as a string.
    fn parse_identity(token: &str) -> Result<String, WSError> {
        let claims = Self::decode_jwt_payload(token)?;
        ["email", "sub", "actor"]
            .iter()
            .find_map(|key| claims.get(*key).and_then(|v| v.as_str()))
            .map(str::to_string)
            .ok_or_else(|| {
                WSError::OidcError("No identity field found in JWT token".to_string())
            })
    }

    /// Extracts the issuer (`iss`) claim from a JWT.
    ///
    /// # Errors
    ///
    /// Returns [`WSError::OidcError`] if the token cannot be decoded or has no string `iss`.
    fn parse_issuer(token: &str) -> Result<String, WSError> {
        Self::decode_jwt_payload(token)?
            .get("iss")
            .and_then(|v| v.as_str())
            .map(str::to_string)
            .ok_or_else(|| WSError::OidcError("No issuer field found in JWT token".to_string()))
    }
}
impl OidcProvider for GitHubOidcProvider {
    /// Requests a fresh token from GitHub and extracts its identity and issuer claims.
    ///
    /// If either claim cannot be extracted, the fetched JWT is zeroized before the error
    /// is returned. This matches the zeroize-on-drop handling that [`OidcToken`] applies
    /// once a token is wrapped.
    ///
    /// # Errors
    ///
    /// Propagates errors from the HTTP request and from claim parsing.
    fn get_token(&self) -> Result<OidcToken, WSError> {
        let mut token = self.get_token_impl()?;
        let claims = Self::parse_identity(&token)
            .and_then(|identity| Self::parse_issuer(&token).map(|issuer| (identity, issuer)));
        match claims {
            Ok((identity, issuer)) => Ok(OidcToken {
                token,
                identity,
                issuer,
            }),
            Err(err) => {
                // The raw JWT is a credential. Do not let it be freed un-wiped.
                token.zeroize();
                Err(err)
            }
        }
    }
    fn name(&self) -> &str {
        "GitHub Actions"
    }
}
#[cfg(not(target_os = "wasi"))]
impl GitHubOidcProvider {
    /// Fetches a raw OIDC JWT with audience `sigstore` from the GitHub Actions token
    /// endpoint, using a blocking `ureq` GET request authenticated with the request token.
    ///
    /// # Errors
    ///
    /// Returns [`WSError::OidcError`] if the request fails, the body cannot be read or
    /// parsed as JSON, or the JSON has no string `value` field.
    fn get_token_impl(&self) -> Result<String, WSError> {
        // Append `audience` to an existing query string, or start one if the URL has none.
        let separator = if self.request_url.contains('?') { '&' } else { '?' };
        let url = format!("{}{}audience=sigstore", self.request_url, separator);
        let response = ureq::get(&url)
            .header("Authorization", &format!("Bearer {}", self.request_token))
            .call()
            .map_err(|e| {
                WSError::OidcError(format!("Failed to retrieve OIDC token from GitHub: {}", e))
            })?;
        let body = response
            .into_body()
            .read_to_string()
            .map_err(|e| WSError::OidcError(format!("Failed to read response body: {}", e)))?;
        let json: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
            WSError::OidcError(format!("Failed to parse GitHub OIDC response: {}", e))
        })?;
        json.get("value")
            .and_then(|v| v.as_str())
            .map(str::to_string)
            .ok_or_else(|| {
                WSError::OidcError("No 'value' field in GitHub OIDC response".to_string())
            })
    }
}
#[cfg(target_os = "wasi")]
impl GitHubOidcProvider {
    /// Fetches a raw OIDC JWT with audience `sigstore` from the GitHub Actions token
    /// endpoint, using the WASI HTTP `outgoing-handler` interface.
    ///
    /// # Errors
    ///
    /// Returns [`WSError::OidcError`] if:
    /// - the URL has no `http://`/`https://` scheme;
    /// - the request cannot be built or sent;
    /// - the response status is not 2xx;
    /// - the body cannot be read;
    /// - the JSON response has no string `value` field.
    fn get_token_impl(&self) -> Result<String, WSError> {
        use wasi::http::outgoing_handler;
        use wasi::http::types::{Fields, Method, OutgoingRequest, Scheme};
        use wasi::io::streams::StreamError;

        // Append `audience` to an existing query string, or start one if the URL has none.
        let separator = if self.request_url.contains('?') { '&' } else { '?' };
        let url_str = format!("{}{}audience=sigstore", self.request_url, separator);

        // Use the scheme actually present in the URL instead of always forcing HTTPS.
        let (scheme, rest) = if let Some(rest) = url_str.strip_prefix("https://") {
            (Scheme::Https, rest)
        } else if let Some(rest) = url_str.strip_prefix("http://") {
            (Scheme::Http, rest)
        } else {
            return Err(WSError::OidcError("Invalid OIDC request URL scheme".to_string()));
        };

        // The authority ends at the first '/' or '?'. Everything after it is path + query,
        // which must start with '/'.
        let (authority, path) = match rest.find(|c: char| c == '/' || c == '?') {
            Some(idx) => {
                let (authority, tail) = rest.split_at(idx);
                let path = if tail.starts_with('/') {
                    tail.to_string()
                } else {
                    format!("/{}", tail)
                };
                (authority, path)
            }
            None => (rest, "/".to_string()),
        };

        let headers = Fields::new();
        let auth_value = format!("Bearer {}", self.request_token);
        headers
            .append(
                &"Authorization".to_string(),
                &auth_value.as_bytes().to_vec(),
            )
            .map_err(|_| WSError::OidcError("Failed to set Authorization header".to_string()))?;
        let request = OutgoingRequest::new(headers);
        request
            .set_method(&Method::Get)
            .map_err(|_| WSError::OidcError("Failed to set HTTP method".to_string()))?;
        request
            .set_scheme(Some(&scheme))
            .map_err(|_| WSError::OidcError("Failed to set HTTPS scheme".to_string()))?;
        request
            .set_authority(Some(authority))
            .map_err(|_| WSError::OidcError("Failed to set authority".to_string()))?;
        request
            .set_path_with_query(Some(&path))
            .map_err(|_| WSError::OidcError("Failed to set path".to_string()))?;

        let future_response = outgoing_handler::handle(request, None)
            .map_err(|_| WSError::OidcError("Failed to send HTTP request".to_string()))?;
        // `get` does not wait and yields `None` until the response has arrived, so block on
        // the future's pollable first. The temporary pollable is dropped at the end of the
        // statement.
        future_response.subscribe().block();
        let incoming_response = match future_response.get() {
            Some(Ok(Ok(response))) => response,
            Some(Ok(Err(code))) => {
                return Err(WSError::OidcError(format!(
                    "Failed to get HTTP response: {:?}",
                    code
                )))
            }
            Some(Err(())) => {
                return Err(WSError::OidcError("Failed to get HTTP response".to_string()))
            }
            None => return Err(WSError::OidcError("HTTP request not ready".to_string())),
        };

        let status = incoming_response.status();
        if !(200..300).contains(&status) {
            return Err(WSError::OidcError(format!(
                "GitHub OIDC endpoint returned HTTP status {}",
                status
            )));
        }

        let body = incoming_response
            .consume()
            .map_err(|_| WSError::OidcError("Failed to get response body".to_string()))?;
        // `stream` is declared after `body`, so it is dropped first, as the child resource
        // must be.
        let stream = body
            .stream()
            .map_err(|_| WSError::OidcError("Failed to get body stream".to_string()))?;
        let mut bytes = Vec::new();
        loop {
            match stream.blocking_read(8192) {
                Ok(chunk) if chunk.is_empty() => break,
                Ok(chunk) => bytes.extend_from_slice(&chunk),
                // `Closed` is the normal end-of-body signal, not a failure.
                Err(StreamError::Closed) => break,
                Err(_) => {
                    return Err(WSError::OidcError("Failed to read from stream".to_string()))
                }
            }
        }

        let json: serde_json::Value = serde_json::from_slice(&bytes).map_err(|e| {
            WSError::OidcError(format!("Failed to parse GitHub OIDC response: {}", e))
        })?;
        json.get("value")
            .and_then(|v| v.as_str())
            .map(str::to_string)
            .ok_or_else(|| {
                WSError::OidcError("No 'value' field in GitHub OIDC response".to_string())
            })
    }
}
/// Placeholder OIDC provider for Google Cloud.
///
/// Token retrieval is not implemented yet. The credentials path is only recorded.
#[derive(Clone, Debug)]
pub struct GoogleOidcProvider {
    /// Value of `GOOGLE_APPLICATION_CREDENTIALS` at construction time, if it was set.
    _credentials_path: Option<String>,
}
impl GoogleOidcProvider {
    /// Creates a provider from the environment. Equivalent to [`Self::from_env`].
    pub fn new() -> Result<Self, WSError> {
        Self::from_env()
    }

    /// Captures the optional `GOOGLE_APPLICATION_CREDENTIALS` path.
    ///
    /// This never returns an error. An unset or non-Unicode variable is stored as `None`.
    pub fn from_env() -> Result<Self, WSError> {
        Ok(Self {
            _credentials_path: env::var("GOOGLE_APPLICATION_CREDENTIALS").ok(),
        })
    }
}
impl OidcProvider for GoogleOidcProvider {
    fn name(&self) -> &str {
        "Google Cloud"
    }

    /// Always fails: token retrieval for Google Cloud is not implemented yet.
    fn get_token(&self) -> Result<OidcToken, WSError> {
        let reason = String::from("Google Cloud OIDC provider not yet implemented");
        Err(WSError::OidcError(reason))
    }
}
/// Placeholder OIDC provider for GitLab CI.
///
/// Token retrieval is not implemented yet. The job JWT is only recorded.
// NOTE(review): `CI_JOB_JWT` has been deprecated by GitLab in favour of `id_tokens`;
// confirm the source variable before implementing this provider.
#[derive(Clone, Debug)]
pub struct GitLabOidcProvider {
    /// Value of `CI_JOB_JWT` at construction time, if it was set.
    _job_jwt: Option<String>,
}
impl GitLabOidcProvider {
    /// Creates a provider from the environment. Equivalent to [`Self::from_env`].
    pub fn new() -> Result<Self, WSError> {
        Self::from_env()
    }

    /// Captures the optional `CI_JOB_JWT` value.
    ///
    /// This never returns an error. An unset or non-Unicode variable is stored as `None`.
    pub fn from_env() -> Result<Self, WSError> {
        Ok(Self {
            _job_jwt: env::var("CI_JOB_JWT").ok(),
        })
    }
}
impl OidcProvider for GitLabOidcProvider {
    fn name(&self) -> &str {
        "GitLab CI"
    }

    /// Always fails: token retrieval for GitLab CI is not implemented yet.
    fn get_token(&self) -> Result<OidcToken, WSError> {
        let reason = String::from("GitLab CI OIDC provider not yet implemented");
        Err(WSError::OidcError(reason))
    }
}
/// Picks an [`OidcProvider`] based on CI environment variables.
///
/// The checks run in this order, and the first match wins:
/// 1. `GITHUB_ACTIONS == "true"` selects [`GitHubOidcProvider`]. Construction fails if its
///    request variables are missing.
/// 2. `GOOGLE_APPLICATION_CREDENTIALS` being set selects [`GoogleOidcProvider`].
/// 3. `GITLAB_CI == "true"` selects [`GitLabOidcProvider`].
///
/// # Errors
///
/// Returns [`WSError::NoOidcProvider`] when nothing matches. Otherwise it returns the
/// constructor error of the first matching provider, if construction fails.
pub fn detect_oidc_provider() -> Result<Box<dyn OidcProvider>, WSError> {
    let flag_is_true = |name: &str| matches!(env::var(name).as_deref(), Ok("true"));

    let provider: Box<dyn OidcProvider> = if flag_is_true("GITHUB_ACTIONS") {
        Box::new(GitHubOidcProvider::new()?)
    } else if env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() {
        Box::new(GoogleOidcProvider::new()?)
    } else if flag_is_true("GITLAB_CI") {
        Box::new(GitLabOidcProvider::new()?)
    } else {
        return Err(WSError::NoOidcProvider);
    };
    Ok(provider)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unsigned compact JWT whose payload segment is `payload` (base64url, no padding).
    fn jwt_with_payload(payload: &str) -> String {
        format!("header.{}.signature", URL_SAFE_NO_PAD.encode(payload))
    }

    #[test]
    fn test_provider_names() {
        let github = GitHubOidcProvider {
            request_token: "request-token".to_string(),
            request_url: "https://example.com/token?api-version=2.0".to_string(),
        };
        assert_eq!(github.name(), "GitHub Actions");
        let google = GoogleOidcProvider {
            _credentials_path: None,
        };
        assert_eq!(google.name(), "Google Cloud");
        let gitlab = GitLabOidcProvider { _job_jwt: None };
        assert_eq!(gitlab.name(), "GitLab CI");
    }
    #[test]
    fn test_parse_jwt_identity() {
        let token = jwt_with_payload(
            r#"{"email":"test@example.com","sub":"user123","iss":"https://token.actions.githubusercontent.com"}"#,
        );
        let identity = GitHubOidcProvider::parse_identity(&token).unwrap();
        assert_eq!(identity, "test@example.com");
    }
    #[test]
    fn test_parse_jwt_identity_no_email() {
        let token = jwt_with_payload(
            r#"{"sub":"user123","iss":"https://token.actions.githubusercontent.com"}"#,
        );
        let identity = GitHubOidcProvider::parse_identity(&token).unwrap();
        assert_eq!(identity, "user123");
    }
    #[test]
    fn test_parse_jwt_identity_actor_fallback() {
        let token = jwt_with_payload(
            r#"{"actor":"octocat","iss":"https://token.actions.githubusercontent.com"}"#,
        );
        let identity = GitHubOidcProvider::parse_identity(&token).unwrap();
        assert_eq!(identity, "octocat");
    }
    #[test]
    fn test_parse_jwt_identity_skips_non_string_email() {
        let token = jwt_with_payload(r#"{"email":42,"sub":"user123"}"#);
        let identity = GitHubOidcProvider::parse_identity(&token).unwrap();
        assert_eq!(identity, "user123");
    }
    #[test]
    fn test_parse_jwt_identity_missing_claims() {
        let token = jwt_with_payload(r#"{"iss":"https://issuer.com"}"#);
        let result = GitHubOidcProvider::parse_identity(&token);
        assert!(matches!(result, Err(WSError::OidcError(_))));
    }
    #[test]
    fn test_parse_jwt_issuer() {
        let token = jwt_with_payload(
            r#"{"email":"test@example.com","iss":"https://token.actions.githubusercontent.com"}"#,
        );
        let issuer = GitHubOidcProvider::parse_issuer(&token).unwrap();
        assert_eq!(issuer, "https://token.actions.githubusercontent.com");
    }
    #[test]
    fn test_parse_jwt_issuer_missing() {
        let token = jwt_with_payload(r#"{"sub":"user123"}"#);
        let result = GitHubOidcProvider::parse_issuer(&token);
        assert!(matches!(result, Err(WSError::OidcError(_))));
    }
    #[test]
    fn test_parse_invalid_jwt() {
        let result = GitHubOidcProvider::parse_identity("invalid-token");
        assert!(matches!(result, Err(WSError::OidcError(_))));
    }
    #[test]
    fn test_parse_jwt_invalid_base64_payload() {
        let result = GitHubOidcProvider::parse_identity("header.!!!.signature");
        assert!(matches!(result, Err(WSError::OidcError(_))));
    }
    #[test]
    fn test_parse_jwt_non_json_payload() {
        let token = jwt_with_payload("not json");
        let result = GitHubOidcProvider::parse_issuer(&token);
        assert!(matches!(result, Err(WSError::OidcError(_))));
    }
    #[test]
    fn test_google_provider_not_implemented() {
        let provider = GoogleOidcProvider::new().unwrap();
        let result = provider.get_token();
        assert!(matches!(result, Err(WSError::OidcError(_))));
    }
    #[test]
    fn test_gitlab_provider_not_implemented() {
        let provider = GitLabOidcProvider::new().unwrap();
        let result = provider.get_token();
        assert!(matches!(result, Err(WSError::OidcError(_))));
    }
    #[test]
    fn test_oidc_token_serialization() {
        let token = OidcToken {
            token: "test-token".to_string(),
            identity: "user@example.com".to_string(),
            issuer: "https://issuer.example.com".to_string(),
        };
        let json = serde_json::to_string(&token).unwrap();
        let deserialized: OidcToken = serde_json::from_str(&json).unwrap();
        assert_eq!(token.token, deserialized.token);
        assert_eq!(token.identity, deserialized.identity);
        assert_eq!(token.issuer, deserialized.issuer);
    }
    #[test]
    fn test_oidc_token_drop_is_called() {
        let test_token_value = "sensitive-jwt-token-12345";
        let test_identity = "user@example.com";
        let test_issuer = "https://issuer.example.com";
        {
            let token = OidcToken {
                token: test_token_value.to_string(),
                identity: test_identity.to_string(),
                issuer: test_issuer.to_string(),
            };
            assert_eq!(token.token, test_token_value);
            assert_eq!(token.identity, test_identity);
            assert_eq!(token.issuer, test_issuer);
        }
    }
    #[test]
    fn test_oidc_token_drop_with_error_path() {
        fn operation_that_fails(token: OidcToken) -> Result<(), WSError> {
            assert!(!token.token.is_empty());
            Err(WSError::OidcError("Simulated error".to_string()))
        }
        let token = OidcToken {
            token: "secret-token".to_string(),
            identity: "user@test.com".to_string(),
            issuer: "https://test.issuer.com".to_string(),
        };
        let result = operation_that_fails(token);
        assert!(result.is_err());
    }
    #[test]
    fn test_oidc_token_clone_and_drop() {
        let original = OidcToken {
            token: "original-token".to_string(),
            identity: "original@example.com".to_string(),
            issuer: "https://original.issuer.com".to_string(),
        };
        {
            // Dropping (and zeroizing) the clone must not affect the original's buffers.
            let cloned = original.clone();
            assert_eq!(original.token, cloned.token);
            assert_eq!(original.identity, cloned.identity);
            assert_eq!(original.issuer, cloned.issuer);
        }
        assert_eq!(original.token, "original-token");
    }
    #[test]
    fn test_oidc_token_in_result_error_path() {
        fn create_token_and_fail() -> Result<OidcToken, WSError> {
            let token = OidcToken {
                token: "will-be-zeroized".to_string(),
                identity: "test@example.com".to_string(),
                issuer: "https://test.com".to_string(),
            };
            Ok(token)
        }
        let result = create_token_and_fail();
        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.token, "will-be-zeroized");
    }
    #[test]
    fn test_oidc_token_move_semantics() {
        fn consume_token(token: OidcToken) -> String {
            token.identity.clone()
        }
        let token = OidcToken {
            token: "moved-token".to_string(),
            identity: "moved@example.com".to_string(),
            issuer: "https://moved.com".to_string(),
        };
        let identity = consume_token(token);
        assert_eq!(identity, "moved@example.com");
    }
    #[test]
    fn test_oidc_token_empty_strings() {
        let token = OidcToken {
            token: String::new(),
            identity: String::new(),
            issuer: String::new(),
        };
        assert!(token.token.is_empty());
        assert!(token.identity.is_empty());
        assert!(token.issuer.is_empty());
    }
    #[test]
    fn test_oidc_token_large_token() {
        let token = OidcToken {
            token: "a".repeat(10_000),
            identity: "user@example.com".to_string(),
            issuer: "https://issuer.com".to_string(),
        };
        assert_eq!(token.token.len(), 10_000);
    }
    #[test]
    fn test_oidc_token_get_sub_claim_with_drop() {
        let token = OidcToken {
            token: jwt_with_payload(r#"{"sub":"test-subject","iss":"https://issuer.com"}"#),
            identity: "user@example.com".to_string(),
            issuer: "https://issuer.com".to_string(),
        };
        let sub = token.get_sub_claim().unwrap();
        assert_eq!(sub, "test-subject");
    }
    #[test]
    fn test_oidc_token_get_sub_claim_missing_sub() {
        let token = OidcToken {
            token: jwt_with_payload(r#"{"iss":"https://issuer.com"}"#),
            identity: "user@example.com".to_string(),
            issuer: "https://issuer.com".to_string(),
        };
        assert!(matches!(token.get_sub_claim(), Err(WSError::OidcError(_))));
    }
    #[test]
    fn test_oidc_token_get_sub_claim_invalid_format() {
        let token = OidcToken {
            token: "only.two".to_string(),
            identity: "user@example.com".to_string(),
            issuer: "https://issuer.com".to_string(),
        };
        assert!(matches!(token.get_sub_claim(), Err(WSError::OidcError(_))));
    }
    #[test]
    fn test_oidc_token_vec_of_tokens() {
        let tokens: Vec<OidcToken> = (0..10)
            .map(|i| OidcToken {
                token: format!("token-{}", i),
                identity: format!("user{}@example.com", i),
                issuer: "https://issuer.com".to_string(),
            })
            .collect();
        assert_eq!(tokens.len(), 10);
    }
}