mod errors;
pub use errors::GitHubOIDCError;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use log::{debug, error, info, warn};
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct JWK {
pub kty: String,
pub use_: Option<String>,
pub kid: String,
pub alg: Option<String>,
pub n: String,
pub e: String,
pub x5c: Option<Vec<String>>,
pub x5t: Option<String>,
pub x5t_s256: Option<String>,
}
/// A JSON Web Key Set: the document that [`fetch_jwks`] downloads from
/// `<oidc_url>/.well-known/jwks`.
#[derive(Debug, Serialize, Deserialize)]
pub struct GithubJWKS {
    /// All published keys. A token selects one of them through the `kid` in its header.
    pub keys: Vec<JWK>,
}
/// The subset of GitHub Actions OIDC token claims this module reads.
///
/// Deserialized from a verified token payload by
/// [`GithubJWKS::validate_github_token`]. Other claims in the payload are ignored.
#[derive(Debug, Serialize, Deserialize)]
pub struct GitHubClaims {
    /// Subject (`sub` claim).
    pub sub: String,
    /// Repository the token was issued for (`repository` claim). Compared against
    /// `GitHubOIDCConfig::repository` when that is set.
    pub repository: String,
    /// Owner of the repository (`repository_owner` claim). Compared against
    /// `GitHubOIDCConfig::repository_owner` when that is set.
    pub repository_owner: String,
    /// Workflow reference (`job_workflow_ref` claim).
    pub job_workflow_ref: String,
    /// Issued-at time (`iat` claim).
    pub iat: u64,
}
/// Base URL of GitHub Actions' OIDC provider. Pass it to [`fetch_jwks`] to obtain the
/// signing keys for tokens issued to GitHub Actions workflows.
pub const DEFAULT_GITHUB_OIDC_URL: &'static str = "https://token.actions.githubusercontent.com";
/// Downloads the JSON Web Key Set published by an OIDC provider.
///
/// The keys are requested from `{oidc_url}/.well-known/jwks`. Trailing `/` characters on
/// `oidc_url` are ignored, so `https://host` and `https://host/` resolve to the same URL.
///
/// # Errors
///
/// * [`GitHubOIDCError::JWKSFetchError`] if the request fails, or if the server answers
///   with a non-success HTTP status. Previously an error page was fed to the JSON parser
///   and reported as a parse error.
/// * [`GitHubOIDCError::JWKSParseError`] if the response body cannot be deserialized
///   into a [`GithubJWKS`].
pub async fn fetch_jwks(oidc_url: &str) -> Result<GithubJWKS, GitHubOIDCError> {
    info!("Fetching JWKS from {}", oidc_url);
    let jwks_url = format!("{}/.well-known/jwks", oidc_url.trim_end_matches('/'));

    let response = reqwest::Client::new()
        .get(&jwks_url)
        .send()
        .await
        // Turn 4xx/5xx responses into errors so they are reported as fetch failures.
        .and_then(reqwest::Response::error_for_status)
        .map_err(|e| {
            error!("Failed to fetch JWKS: {:?}", e);
            GitHubOIDCError::JWKSFetchError(e.to_string())
        })?;

    let jwks = response.json::<GithubJWKS>().await.map_err(|e| {
        error!("Failed to parse JWKS response: {:?}", e);
        GitHubOIDCError::JWKSParseError(e.to_string())
    })?;

    info!("JWKS fetched successfully");
    Ok(jwks)
}
/// Optional restrictions applied by [`GithubJWKS::validate_github_token`].
///
/// Each field left as `None` (the [`Default`]) disables the corresponding check.
#[derive(Debug, Clone, Default)]
pub struct GitHubOIDCConfig {
    /// Expected audience. When set, it is passed to the JWT validator's audience check.
    pub audience: Option<String>,
    /// Expected value of the token's `repository` claim.
    pub repository: Option<String>,
    /// Expected value of the token's `repository_owner` claim.
    pub repository_owner: Option<String>,
}
impl GithubJWKS {
pub fn validate_github_token(
&self,
token: &str,
config: &GitHubOIDCConfig,
) -> Result<GitHubClaims, GitHubOIDCError> {
debug!("Starting token validation");
if !token.starts_with("eyJ") {
warn!("Invalid token format received");
return Err(GitHubOIDCError::InvalidTokenFormat);
}
debug!("JWKS loaded");
let header = jsonwebtoken::decode_header(token).map_err(|e| {
GitHubOIDCError::HeaderDecodingError(format!(
"Failed to decode header: {}. Make sure you're using a valid JWT, not a PAT.",
e
))
})?;
let decoding_key = if let Some(kid) = header.kid {
let key = self
.keys
.iter()
.find(|k| k.kid == kid)
.ok_or(GitHubOIDCError::KeyNotFound)?;
let modulus = key.n.as_str();
let exponent = key.e.as_str();
DecodingKey::from_rsa_components(modulus, exponent)
.map_err(|e| GitHubOIDCError::DecodingKeyCreationError(e.to_string()))?
} else {
DecodingKey::from_secret("your_secret_key".as_ref())
};
let mut validation = Validation::new(Algorithm::RS256);
if let Some(audience) = &config.audience {
validation.set_audience(&[audience]);
}
let token_data = decode::<GitHubClaims>(token, &decoding_key, &validation)
.map_err(|e| GitHubOIDCError::TokenDecodingError(e.to_string()))?;
let claims = token_data.claims;
if let Some(expected_owner) = &config.repository_owner {
if claims.repository_owner != *expected_owner {
warn!(
"Token organization mismatch. Expected: {}, Found: {}",
expected_owner, claims.repository_owner
);
return Err(GitHubOIDCError::OrganizationMismatch);
}
}
if let Some(expected_repo) = &config.repository {
debug!(
"Comparing repositories - Expected: {}, Found: {}",
expected_repo, claims.repository
);
if claims.repository != *expected_repo {
warn!(
"Token repository mismatch. Expected: {}, Found: {}",
expected_repo, claims.repository
);
return Err(GitHubOIDCError::RepositoryMismatch);
}
}
debug!("Token validation completed successfully");
Ok(claims)
}
}