pub mod jwks;
pub mod request;
pub mod responses;
pub mod route_handler;
pub mod sealed_token;
pub mod sts;
use base64::Engine;
pub use jwks::JwksCache;
use multistore::error::ProxyError;
use multistore::registry::CredentialRegistry;
use multistore::types::TemporaryCredentials;
pub use request::try_parse_sts_request;
use request::StsRequest;
pub use responses::{build_sts_error_response, build_sts_response};
pub use sealed_token::TokenKey;
/// Handles an STS (`AssumeRoleWithWebIdentity`) request carried in `query`.
///
/// Returns `None` when [`try_parse_sts_request`] returns `None` for `query`
/// (presumably: the request is not an STS call), leaving it to other handlers.
///
/// Otherwise returns `Some((status, body))`:
/// - on success, the pair built by [`build_sts_response`];
/// - on failure, the pair built by [`build_sts_error_response`]. Failure means one of:
///   - a parse error;
///   - `token_key` being `None`, logged at `error` level;
///   - an error from [`assume_role_with_web_identity`], logged at `warn` level.
///
/// `"STSPRXY"` is passed through as the `key_prefix` for minted credentials.
pub async fn try_handle_sts<C: CredentialRegistry>(
    query: Option<&str>,
    config: &C,
    jwks_cache: &JwksCache,
    token_key: Option<&TokenKey>,
) -> Option<(u16, String)> {
    let parsed = try_parse_sts_request(query)?;

    let response = match parsed {
        Err(parse_error) => build_sts_error_response(&parse_error),
        Ok(sts_request) => match token_key {
            // Without a sealing key we cannot issue session tokens at all.
            None => {
                tracing::error!("STS request received but SESSION_TOKEN_KEY is not configured");
                build_sts_error_response(&ProxyError::ConfigError(
                    "STS requires SESSION_TOKEN_KEY to be configured".into(),
                ))
            }
            Some(key) => {
                let outcome = assume_role_with_web_identity(
                    config,
                    &sts_request,
                    "STSPRXY",
                    jwks_cache,
                    key,
                )
                .await;
                match outcome {
                    Ok(creds) => build_sts_response(&creds),
                    Err(e) => {
                        tracing::warn!(error = %e, "STS request failed");
                        build_sts_error_response(&e)
                    }
                }
            }
        },
    };

    Some(response)
}
/// Decodes the header and payload of a compact JWT **without** checking its signature.
///
/// The token is split on `.`. Only the first two segments are examined, and anything
/// after the second `.` is ignored here. Each of those segments is base64url-decoded
/// (no padding) and parsed as JSON. The returned values must be treated as untrusted
/// until the token has been verified.
///
/// # Errors
/// Returns `ProxyError::InvalidOidcToken` in three cases:
/// - the token has fewer than two segments;
/// - a segment is not valid base64url;
/// - a segment is not valid JSON.
fn jwt_decode_unverified(
    token: &str,
) -> Result<(serde_json::Value, serde_json::Value), ProxyError> {
    // Turns one base64url segment into a JSON value.
    fn decode_segment(segment: &str) -> Result<serde_json::Value, ProxyError> {
        let raw = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(segment)
            .map_err(|e| ProxyError::InvalidOidcToken(format!("base64url decode error: {}", e)))?;
        serde_json::from_slice(&raw)
            .map_err(|e| ProxyError::InvalidOidcToken(format!("invalid JWT JSON: {}", e)))
    }

    let mut segments = token.splitn(3, '.');
    let (Some(header_segment), Some(payload_segment)) = (segments.next(), segments.next()) else {
        return Err(ProxyError::InvalidOidcToken("malformed JWT".into()));
    };

    let header = decode_segment(header_segment)?;
    let payload = decode_segment(payload_segment)?;
    Ok((header, payload))
}
/// Exchanges an OIDC web-identity token for temporary credentials.
///
/// Steps, in order:
/// 1. Look up the role named by `sts_request.role_arn` in `config`.
/// 2. Decode the token *without* verification, only to read its issuer and header.
/// 3. Reject issuers absent from the role's `trusted_oidc_issuers`, any `alg` other
///    than `RS256`, and headers without a `kid`.
/// 4. Fetch the issuer's JWKS through `jwks_cache`, select the key by `kid`, and
///    verify the token with `jwks::verify_token`.
/// 5. If the role has `subject_conditions`, require the verified `sub` claim to match
///    at least one of them. A missing `sub` is treated as the empty string.
/// 6. Mint credentials with `sts::mint_temporary_credentials`. The session token is
///    then replaced with one sealed by `token_key`.
///
/// Session duration is `sts_request.duration_seconds`, or 3600 s if absent. It is
/// raised to at least 900 s and then capped at the role's
/// `max_session_duration_secs`. The role cap always wins.
///
/// # Errors
/// - `ProxyError::RoleNotFound` if the role does not exist.
/// - `ProxyError::InvalidOidcToken` for any of these:
///   - a malformed token or a missing `iss` claim;
///   - an untrusted issuer;
///   - an unsupported algorithm or a missing `kid`;
///   - a subject that matches none of the role's conditions.
/// - Any error propagated from these steps:
///   - the registry lookup;
///   - the JWKS fetch or key lookup;
///   - token verification;
///   - token sealing.
pub async fn assume_role_with_web_identity<C: CredentialRegistry>(
    config: &C,
    sts_request: &StsRequest,
    key_prefix: &str,
    jwks_cache: &JwksCache,
    token_key: &TokenKey,
) -> Result<TemporaryCredentials, ProxyError> {
    // Floor for the requested session length (15 minutes, as in AWS STS).
    const MIN_SESSION_DURATION_SECS: u64 = 900;
    // Used when the request does not specify a duration.
    const DEFAULT_SESSION_DURATION_SECS: u64 = 3600;

    let role = config
        .get_role(&sts_request.role_arn)
        .await?
        .ok_or_else(|| ProxyError::RoleNotFound(sts_request.role_arn.to_string()))?;

    // Unverified decode: used only to choose the issuer and key. Nothing read here
    // is trusted until `verify_token` succeeds below.
    let (header, insecure_claims) = jwt_decode_unverified(&sts_request.web_identity_token)?;
    let issuer = insecure_claims
        .get("iss")
        .and_then(|v| v.as_str())
        .ok_or_else(|| ProxyError::InvalidOidcToken("missing iss claim".into()))?;
    if !role.trusted_oidc_issuers.iter().any(|i| i == issuer) {
        return Err(ProxyError::InvalidOidcToken(format!(
            "untrusted issuer: {}",
            issuer
        )));
    }

    let alg = header.get("alg").and_then(|v| v.as_str()).unwrap_or("");
    if alg != "RS256" {
        return Err(ProxyError::InvalidOidcToken(format!(
            "unsupported JWT algorithm: {}",
            alg
        )));
    }
    // Finish validating the header before any network I/O, so a token without
    // `kid` is rejected without triggering a JWKS fetch.
    let kid = header
        .get("kid")
        .and_then(|v| v.as_str())
        .ok_or_else(|| ProxyError::InvalidOidcToken("JWT missing kid".into()))?;

    let jwks = jwks_cache.get_or_fetch(issuer).await?;
    let key = jwks::find_key(&jwks, kid)?;
    let claims = jwks::verify_token(&sts_request.web_identity_token, key, issuer, &role)?;

    let subject = claims.get("sub").and_then(|v| v.as_str()).unwrap_or("");
    if !role.subject_conditions.is_empty() {
        let matches = role
            .subject_conditions
            .iter()
            .any(|pattern| subject_matches(subject, pattern));
        if !matches {
            return Err(ProxyError::InvalidOidcToken(format!(
                "subject '{}' does not match any conditions",
                subject
            )));
        }
    }

    // Do not use `Ord::clamp` here. It panics when min > max, so a role whose
    // `max_session_duration_secs` is below the floor would crash the request.
    // Applying the floor first and the cap last keeps the role maximum authoritative.
    let duration = sts_request
        .duration_seconds
        .unwrap_or(DEFAULT_SESSION_DURATION_SECS)
        .max(MIN_SESSION_DURATION_SECS)
        .min(role.max_session_duration_secs);

    let mut creds = sts::mint_temporary_credentials(&role, subject, duration, key_prefix, &claims);
    creds.session_token = token_key.seal(&creds)?;
    Ok(creds)
}
/// Glob-style match of an OIDC `sub` claim against a role's subject condition.
///
/// `*` matches any run of characters, including an empty one. Every other character
/// in `pattern` is literal. A pattern with no `*` requires exact equality.
fn subject_matches(subject: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }

    let mut pieces = pattern.split('*');
    // `split` always yields at least one item, so this never falls back.
    let head = pieces.next().unwrap_or("");
    let literals: Vec<&str> = pieces.collect();

    // No `*` in the pattern: only an exact match will do.
    let Some((tail, middle)) = literals.split_last() else {
        return subject == pattern;
    };

    let Some(mut rest) = subject.strip_prefix(head) else {
        return false;
    };

    // Consuming each inner literal at its leftmost occurrence leaves the longest
    // possible remainder for the pieces that follow.
    for piece in middle.iter().filter(|p| !p.is_empty()) {
        match rest.find(*piece) {
            Some(pos) => rest = &rest[pos + piece.len()..],
            None => return false,
        }
    }

    // The trailing literal must sit at the very end, after everything consumed above.
    rest.ends_with(*tail)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_subject_matching() {
        assert!(subject_matches(
            "repo:org/repo:ref:refs/heads/main",
            "repo:org/repo:*"
        ));
        assert!(subject_matches("repo:org/repo:ref:refs/heads/main", "*"));
        assert!(subject_matches(
            "repo:org/repo:ref:refs/heads/main",
            "repo:org/repo:ref:refs/heads/main"
        ));
        assert!(!subject_matches(
            "repo:org/repo:ref:refs/heads/main",
            "repo:other/*"
        ));
        assert!(subject_matches(
            "repo:org/repo:ref:refs/heads/main",
            "repo:org/*:ref:refs/heads/*"
        ));
    }

    #[test]
    fn test_subject_matching_exact() {
        assert!(subject_matches("abc", "abc"));
        assert!(!subject_matches("abc", "abcd"));
        assert!(!subject_matches("abcd", "abc"));
        assert!(!subject_matches("", "abc"));
        assert!(subject_matches("", ""));
    }

    #[test]
    fn test_subject_matching_leading_wildcard() {
        assert!(subject_matches("anything", "*"));
        assert!(subject_matches("", "*"));
        assert!(subject_matches("foo", "*foo"));
        assert!(subject_matches("xfoo", "*foo"));
        assert!(!subject_matches("foox", "*foo"));
    }

    #[test]
    fn test_subject_matching_trailing_wildcard() {
        assert!(subject_matches("foo", "foo*"));
        assert!(subject_matches("foobar", "foo*"));
        assert!(!subject_matches("xfoo", "foo*"));
    }

    #[test]
    fn test_subject_matching_middle_wildcard() {
        assert!(subject_matches("foobar", "foo*bar"));
        assert!(subject_matches("fooXbar", "foo*bar"));
        assert!(subject_matches("fooXYZbar", "foo*bar"));
        assert!(!subject_matches("fooXbaz", "foo*bar"));
        assert!(!subject_matches("xfoobar", "foo*bar"));
    }

    #[test]
    fn test_subject_matching_multiple_wildcards() {
        assert!(subject_matches("axbb", "a*b*b"));
        assert!(!subject_matches("axb", "a*b*b"));
        assert!(!subject_matches("abc", "a*bc*c"));
        assert!(subject_matches("abcc", "a*bc*c"));
        assert!(subject_matches("aab", "*a*ab"));
        assert!(!subject_matches("xab", "*a*ab"));
        assert!(subject_matches("xababab", "*ab*ab"));
        assert!(!subject_matches("xab", "*ab*ab"));
    }

    #[test]
    fn test_subject_matching_double_wildcard() {
        assert!(subject_matches("anything", "**"));
        assert!(subject_matches("", "**"));
    }

    #[test]
    fn test_subject_matching_empty_subject() {
        assert!(subject_matches("", "*"));
        assert!(!subject_matches("", "a"));
        assert!(subject_matches("", ""));
    }

    /// The literal before the first `*` and the literal after the last `*` must
    /// occupy disjoint parts of the subject; they may not share characters.
    #[test]
    fn test_subject_matching_prefix_suffix_do_not_overlap() {
        assert!(!subject_matches("ab", "ab*b"));
        assert!(subject_matches("abb", "ab*b"));
        assert!(!subject_matches("a", "a*a"));
        assert!(subject_matches("aa", "a*a"));
    }

    /// Multibyte subjects must match without slicing inside a UTF-8 sequence.
    #[test]
    fn test_subject_matching_multibyte() {
        assert!(subject_matches("日本語", "日*語"));
        assert!(subject_matches("ré:main", "ré:*"));
        assert!(!subject_matches("日", "日*日"));
    }

    #[test]
    fn test_jwt_decode_unverified_decodes_header_and_payload() {
        // "e30" is the unpadded base64url encoding of `{}`.
        let Ok((header, payload)) = jwt_decode_unverified("e30.e30.signature") else {
            panic!("expected a well-formed token to decode");
        };
        let empty = serde_json::Value::Object(serde_json::Map::new());
        assert_eq!(header, empty);
        assert_eq!(payload, empty);
    }

    #[test]
    fn test_jwt_decode_unverified_rejects_malformed() {
        // No payload segment.
        assert!(jwt_decode_unverified("").is_err());
        assert!(jwt_decode_unverified("e30").is_err());
        // Header is not base64url.
        assert!(jwt_decode_unverified("!!!.e30.sig").is_err());
        // "YWJj" decodes to "abc", which is not JSON.
        assert!(jwt_decode_unverified("YWJj.e30.sig").is_err());
    }
}