use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
/// Result of comparing an access token's `aud` claim with an expected resource.
///
/// Produced by [`validate_token_aud`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AudCheckOutcome {
    /// At least one audience string in the token equals the expected resource.
    Match,
    /// The token carries audience strings, but none equals the expected resource.
    Mismatch {
        /// Every string audience read from the token, in claim order.
        found: Vec<String>,
        /// The resource the token was checked against.
        expected: String,
    },
    /// The token was decoded, but its `aud` claim is absent or holds no strings.
    MissingAud,
    /// The token could not be decoded as a JWT, so no audience check was possible.
    OpaqueToken,
}
impl AudCheckOutcome {
pub fn is_mismatch(&self) -> bool {
matches!(self, AudCheckOutcome::Mismatch { .. })
}
}
/// Reads the `aud` claim of `access_token` and reports whether it names
/// `expected_resource`.
///
/// The signature is never checked; the payload segment is only decoded so the
/// audience can be compared. A [`AudCheckOutcome::Match`] therefore says nothing
/// about whether the token itself is valid.
///
/// Returns:
/// - [`AudCheckOutcome::OpaqueToken`] when the token is not three non-empty
///   dot-separated segments whose middle segment is base64url-encoded JSON object.
/// - [`AudCheckOutcome::MissingAud`] when the claims have no `aud`, or `aud`
///   contains no string values.
/// - [`AudCheckOutcome::Match`] when any audience string equals
///   `expected_resource` exactly (byte-for-byte, no URL normalization).
/// - [`AudCheckOutcome::Mismatch`] otherwise, carrying the audiences found.
pub fn validate_token_aud(access_token: &str, expected_resource: &str) -> AudCheckOutcome {
    let Some(payload_b64) = jwt_payload_segment(access_token) else {
        return AudCheckOutcome::OpaqueToken;
    };
    // JWS base64url omits `=` padding (RFC 7515), but some issuers emit it
    // anyway. Strip it so such tokens are still inspected rather than being
    // misreported as opaque. Unpadded input never contains `=`, so this is a no-op there.
    let Ok(payload_bytes) = URL_SAFE_NO_PAD.decode(payload_b64.trim_end_matches('=')) else {
        return AudCheckOutcome::OpaqueToken;
    };
    // A JWT claims set must be a JSON object (RFC 7519 §7.2). Any other JSON
    // value means this is not a JWT we can read, not a JWT lacking `aud`.
    let Ok(claims) =
        serde_json::from_slice::<serde_json::Map<String, serde_json::Value>>(&payload_bytes)
    else {
        return AudCheckOutcome::OpaqueToken;
    };
    let Some(aud_field) = claims.get("aud") else {
        return AudCheckOutcome::MissingAud;
    };
    let found = collect_aud(aud_field);
    if found.is_empty() {
        return AudCheckOutcome::MissingAud;
    }
    if found.iter().any(|aud| aud == expected_resource) {
        AudCheckOutcome::Match
    } else {
        AudCheckOutcome::Mismatch {
            found,
            expected: expected_resource.to_string(),
        }
    }
}
/// Returns the middle (payload) segment of a `header.payload.signature` token.
///
/// Yields `None` unless `token` splits on `.` into exactly three segments, all
/// of them non-empty.
fn jwt_payload_segment(token: &str) -> Option<&str> {
    let (header, rest) = token.split_once('.')?;
    let (payload, signature) = rest.split_once('.')?;
    // A further `.` inside the tail means there are more than three segments.
    let well_formed = !header.is_empty()
        && !payload.is_empty()
        && !signature.is_empty()
        && !signature.contains('.');
    well_formed.then_some(payload)
}
/// Extracts audience strings from an `aud` claim value.
///
/// - A JSON string yields a single entry.
/// - A JSON array yields its string elements in order, skipping non-string elements.
/// - Any other JSON type yields an empty vector.
fn collect_aud(value: &serde_json::Value) -> Vec<String> {
    if let Some(single) = value.as_str() {
        return vec![single.to_owned()];
    }
    value
        .as_array()
        .map(|entries| {
            entries
                .iter()
                .filter_map(serde_json::Value::as_str)
                .map(str::to_owned)
                .collect()
        })
        .unwrap_or_default()
}
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent's `URL_SAFE_NO_PAD` and `Engine`
    // imports, so no separate base64 import is needed here.
    use super::*;

    /// Builds a compact JWT-shaped string (`header.payload.signature`).
    ///
    /// The payload segment is `payload` serialized as JSON and base64url-encoded
    /// without padding. The signature is a fixed placeholder, which is fine here
    /// because the code under test never verifies it.
    pub(crate) fn fake_jwt(payload: serde_json::Value) -> String {
        let header_b64 = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
        let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload).unwrap());
        let sig_b64 = URL_SAFE_NO_PAD.encode(b"signature-bytes");
        format!("{header_b64}.{payload_b64}.{sig_b64}")
    }

    #[test]
    fn match_for_string_aud() {
        let token = fake_jwt(serde_json::json!({
            "aud": "https://server.example.com",
            "sub": "user-1",
        }));
        let outcome = validate_token_aud(&token, "https://server.example.com");
        assert_eq!(outcome, AudCheckOutcome::Match);
    }

    #[test]
    fn match_for_array_aud_containing_expected() {
        let token = fake_jwt(serde_json::json!({
            "aud": ["https://other.example.com", "https://server.example.com"],
        }));
        let outcome = validate_token_aud(&token, "https://server.example.com");
        assert_eq!(outcome, AudCheckOutcome::Match);
    }

    #[test]
    fn match_ignores_non_string_entries_in_array_aud() {
        let token = fake_jwt(serde_json::json!({
            "aud": [1, "https://server.example.com", null],
        }));
        assert_eq!(
            validate_token_aud(&token, "https://server.example.com"),
            AudCheckOutcome::Match
        );
    }

    #[test]
    fn mismatch_for_string_aud_pointing_elsewhere() {
        let token = fake_jwt(serde_json::json!({
            "aud": "https://other.example.com",
        }));
        let outcome = validate_token_aud(&token, "https://server.example.com");
        match outcome {
            AudCheckOutcome::Mismatch { found, expected } => {
                assert_eq!(found, vec!["https://other.example.com".to_string()]);
                assert_eq!(expected, "https://server.example.com");
            }
            other => panic!("expected Mismatch, got {other:?}"),
        }
    }

    #[test]
    fn mismatch_for_array_aud_without_expected() {
        let token = fake_jwt(serde_json::json!({
            "aud": ["a", "b", "c"],
        }));
        let outcome = validate_token_aud(&token, "z");
        assert!(outcome.is_mismatch());
    }

    #[test]
    fn mismatch_reports_only_string_audiences() {
        let token = fake_jwt(serde_json::json!({
            "aud": ["x", 2, "y"],
        }));
        assert_eq!(
            validate_token_aud(&token, "z"),
            AudCheckOutcome::Mismatch {
                found: vec!["x".to_string(), "y".to_string()],
                expected: "z".to_string(),
            }
        );
    }

    #[test]
    fn is_mismatch_only_for_mismatch_variant() {
        assert!(!AudCheckOutcome::Match.is_mismatch());
        assert!(!AudCheckOutcome::MissingAud.is_mismatch());
        assert!(!AudCheckOutcome::OpaqueToken.is_mismatch());
    }

    #[test]
    fn missing_aud_when_payload_has_no_field() {
        let token = fake_jwt(serde_json::json!({
            "sub": "user-1",
        }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
    }

    #[test]
    fn missing_aud_when_field_is_unexpected_type() {
        let token = fake_jwt(serde_json::json!({ "aud": 42 }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
        let token = fake_jwt(serde_json::json!({ "aud": [123, 456] }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
    }

    #[test]
    fn missing_aud_when_field_is_null_or_empty_array() {
        let token = fake_jwt(serde_json::json!({ "aud": null }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
        let token = fake_jwt(serde_json::json!({ "aud": [] }));
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::MissingAud
        );
    }

    #[test]
    fn opaque_token_when_not_three_segments() {
        assert_eq!(
            validate_token_aud("abc.def", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("just-a-random-string", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("a.b.c.d", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("a..c", "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }

    #[test]
    fn opaque_token_when_header_or_signature_segment_is_empty() {
        assert_eq!(
            validate_token_aud(".b.c", "anything"),
            AudCheckOutcome::OpaqueToken
        );
        assert_eq!(
            validate_token_aud("a.b.", "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }

    #[test]
    fn opaque_token_when_middle_segment_isnt_json() {
        let header = URL_SAFE_NO_PAD.encode(b"{}");
        let payload = URL_SAFE_NO_PAD.encode(b"not-actually-json");
        let sig = URL_SAFE_NO_PAD.encode(b"sig");
        let token = format!("{header}.{payload}.{sig}");
        assert_eq!(
            validate_token_aud(&token, "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }

    #[test]
    fn opaque_token_when_middle_segment_isnt_base64url() {
        let token = "abc.@@@not-base64@@@.def";
        assert_eq!(
            validate_token_aud(token, "anything"),
            AudCheckOutcome::OpaqueToken
        );
    }
}