use anyhow::{Context, Result, anyhow, bail};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use openidconnect::JsonWebKey as _;
use openidconnect::core::CoreJwsSigningAlgorithm;
/// Looks up the claim `name` in `claims` and returns it as an owned string.
///
/// Yields `None` when the claim is missing, is not a JSON string, or is the
/// empty string.
pub(super) fn extract_optional_string_claim(
    name: &str,
    claims: &serde_json::Value,
) -> Option<String> {
    claims
        .get(name)
        .and_then(serde_json::Value::as_str)
        .filter(|text| !text.is_empty())
        .map(str::to_owned)
}
/// Returns the claim `name` from `claims` as an owned string, or `default`.
///
/// `default` is used whenever the claim is missing, is not a JSON string, or
/// is the empty string.
pub(super) fn extract_string_claim(
    name: &str,
    claims: &serde_json::Value,
    default: &str,
) -> String {
    let chosen = claims
        .get(name)
        .and_then(serde_json::Value::as_str)
        .filter(|text| !text.is_empty())
        .unwrap_or(default);
    chosen.to_owned()
}
/// The pieces of a compact JWS as produced by `parse_compact_jws`.
///
/// Parsing only decodes the token; it does not check the signature. Use
/// `jwks_signature_verifies` for that.
pub(super) struct ParsedJws {
/// Signing algorithm deserialized from the header's `alg` member.
pub(super) alg: CoreJwsSigningAlgorithm,
/// The header's `kid` member, when present and a JSON string.
pub(super) kid: Option<String>,
/// The first two token segments joined by `.`, still base64url-encoded;
/// this is the byte string handed to signature verification.
pub(super) signed_input: String,
/// The base64url-decoded third token segment.
pub(super) signature_bytes: Vec<u8>,
/// The base64url-decoded second token segment parsed as JSON.
pub(super) payload: serde_json::Value,
}
/// Reports whether any eligible key in `jwks` verifies the signature of
/// `parsed`.
///
/// Key eligibility depends on the token header:
/// - header carries a `kid`: only keys whose own key id equals it are tried;
///   keys without a key id are skipped.
/// - header carries no `kid`: every key in the set is tried.
///
/// Returns `true` as soon as one eligible key accepts the signature over
/// `parsed.signed_input` with algorithm `parsed.alg`, otherwise `false`.
pub(super) fn jwks_signature_verifies(
    jwks: &openidconnect::core::CoreJsonWebKeySet,
    parsed: &ParsedJws,
) -> bool {
    let header_kid = parsed.kid.as_deref();
    let message = parsed.signed_input.as_bytes();
    let signature = parsed.signature_bytes.as_slice();

    jwks.keys().iter().any(|key| {
        let eligible = match header_kid {
            // No kid in the header: nothing to narrow by, so try every key.
            None => true,
            // Kid in the header: the key must advertise the very same id.
            Some(wanted) => key
                .key_id()
                .is_some_and(|own_id| own_id.as_str() == wanted),
        };
        eligible && key.verify_signature(&parsed.alg, message, signature).is_ok()
    })
}
pub(super) fn parse_compact_jws(token: &str) -> Result<ParsedJws> {
let mut parts = token.split('.');
let h = parts.next().ok_or_else(|| anyhow!("malformed JWS"))?;
let p = parts.next().ok_or_else(|| anyhow!("malformed JWS"))?;
let s = parts.next().ok_or_else(|| anyhow!("malformed JWS"))?;
if parts.next().is_some() {
bail!("malformed JWS: too many segments");
}
let header_bytes = URL_SAFE_NO_PAD
.decode(h)
.context("logout_token header base64url decode")?;
let payload_bytes = URL_SAFE_NO_PAD
.decode(p)
.context("logout_token payload base64url decode")?;
let signature_bytes = URL_SAFE_NO_PAD
.decode(s)
.context("logout_token signature base64url decode")?;
let header: serde_json::Value = serde_json::from_slice(&header_bytes)
.context("logout_token header JSON parse")?;
let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
.context("logout_token payload JSON parse")?;
let alg_str = header
.get("alg")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow!("logout_token header missing alg"))?;
let alg: CoreJwsSigningAlgorithm =
serde_json::from_value(serde_json::Value::String(alg_str.to_owned()))
.with_context(|| {
format!("logout_token alg '{alg_str}' is not recognised")
})?;
let kid = header
.get("kid")
.and_then(|v| v.as_str())
.map(|s| s.to_owned());
Ok(ParsedJws {
alg,
kid,
signed_input: format!("{h}.{p}"),
signature_bytes,
payload,
})
}
/// Extracts the list of group names stored under the claim `name` in
/// `claims`.
///
/// This forwards both arguments unchanged to
/// `extract_groups_claim_from_json` and returns its result.
pub(super) fn extract_groups_claim(
name: &str,
claims: &serde_json::Value,
) -> Vec<String> {
extract_groups_claim_from_json(name, claims)
}
pub(super) fn extract_groups_claim_from_json(
name: &str,
json: &serde_json::Value,
) -> Vec<String> {
match json.get(name) {
Some(v) => extract_groups_claim_from_value(v),
None => Vec::new(),
}
}
/// Turns a single JSON value into a list of group names.
///
/// - Array: keeps every element that is a non-empty string, in order;
///   non-string elements are skipped.
/// - String: splits on whitespace, one group per word.
/// - Anything else: yields an empty list.
fn extract_groups_claim_from_value(
    v: &serde_json::Value,
) -> Vec<String> {
    if let Some(entries) = v.as_array() {
        return entries
            .iter()
            .filter_map(serde_json::Value::as_str)
            .filter(|group| !group.is_empty())
            .map(str::to_owned)
            .collect();
    }

    if let Some(joined) = v.as_str() {
        return joined.split_whitespace().map(str::to_owned).collect();
    }

    Vec::new()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a compact JWS from raw header, payload and signature bytes,
    /// encoding each segment as unpadded base64url.
    fn make_jws(header: &[u8], payload: &[u8], sig: &[u8]) -> String {
        format!(
            "{}.{}.{}",
            URL_SAFE_NO_PAD.encode(header),
            URL_SAFE_NO_PAD.encode(payload),
            URL_SAFE_NO_PAD.encode(sig),
        )
    }

    #[test]
    fn parse_compact_jws_success_rs256() {
        let token = make_jws(
            br#"{"alg":"RS256"}"#,
            br#"{"sub":"user"}"#,
            b"\x00\x01\x02",
        );
        let parsed = parse_compact_jws(&token).unwrap();
        assert!(matches!(
            parsed.alg,
            CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256
        ));
        assert_eq!(parsed.kid, None);
        assert_eq!(parsed.payload["sub"], "user");
        // The signing input must be exactly the encoded `header.payload`
        // prefix of the token, i.e. everything before the last dot.
        let (expected_signed_input, _) = token.rsplit_once('.').unwrap();
        assert_eq!(parsed.signed_input, expected_signed_input);
        assert_eq!(parsed.signature_bytes, b"\x00\x01\x02");
    }

    #[test]
    fn parse_compact_jws_with_kid() {
        let token = make_jws(
            br#"{"alg":"RS256","kid":"mykey"}"#,
            br#"{"sub":"u"}"#,
            b"\xAB\xCD",
        );
        let parsed = parse_compact_jws(&token).unwrap();
        assert_eq!(parsed.kid.as_deref(), Some("mykey"));
    }

    #[test]
    fn parse_compact_jws_too_few_parts() {
        assert!(parse_compact_jws("a.b").is_err());
    }

    #[test]
    fn parse_compact_jws_too_many_parts() {
        assert!(parse_compact_jws("a.b.c.d").is_err());
    }

    #[test]
    fn parse_compact_jws_bad_base64_header() {
        assert!(parse_compact_jws("!!!.e30.AAAA").is_err());
    }

    #[test]
    fn parse_compact_jws_bad_json_header() {
        let h = URL_SAFE_NO_PAD.encode(b"not-json");
        let token = format!("{h}.e30.AAAA");
        assert!(parse_compact_jws(&token).is_err());
    }

    #[test]
    fn parse_compact_jws_missing_alg() {
        let token = make_jws(br#"{"kid":"x"}"#, b"{}", b"\x00");
        assert!(parse_compact_jws(&token).is_err());
    }

    #[test]
    fn parse_compact_jws_unknown_alg() {
        let token = make_jws(br#"{"alg":"BOGUS"}"#, b"{}", b"\x00");
        assert!(parse_compact_jws(&token).is_err());
    }

    #[test]
    fn parse_compact_jws_bad_base64_payload() {
        let h = URL_SAFE_NO_PAD.encode(br#"{"alg":"RS256"}"#);
        let token = format!("{h}.!!!.AAAA");
        assert!(parse_compact_jws(&token).is_err());
    }

    #[test]
    fn parse_compact_jws_bad_json_payload() {
        let h = URL_SAFE_NO_PAD.encode(br#"{"alg":"RS256"}"#);
        let p = URL_SAFE_NO_PAD.encode(b"not-json");
        let token = format!("{h}.{p}.AAAA");
        assert!(parse_compact_jws(&token).is_err());
    }

    #[test]
    fn parse_compact_jws_bad_base64_signature() {
        // Header and payload are valid; only the signature segment is not
        // base64url.
        let h = URL_SAFE_NO_PAD.encode(br#"{"alg":"RS256"}"#);
        let token = format!("{h}.e30.!!!");
        assert!(parse_compact_jws(&token).is_err());
    }

    #[test]
    fn optional_string_claim_requires_non_empty_string() {
        let claims =
            serde_json::json!({"name": "alice", "empty": "", "num": 7});
        assert_eq!(
            extract_optional_string_claim("name", &claims).as_deref(),
            Some("alice"),
        );
        assert_eq!(extract_optional_string_claim("empty", &claims), None);
        assert_eq!(extract_optional_string_claim("num", &claims), None);
        assert_eq!(extract_optional_string_claim("missing", &claims), None);
    }

    #[test]
    fn string_claim_falls_back_to_default() {
        let claims =
            serde_json::json!({"name": "alice", "empty": "", "num": 7});
        assert_eq!(extract_string_claim("name", &claims, "dflt"), "alice");
        assert_eq!(extract_string_claim("empty", &claims, "dflt"), "dflt");
        assert_eq!(extract_string_claim("num", &claims, "dflt"), "dflt");
        assert_eq!(extract_string_claim("missing", &claims, "dflt"), "dflt");
    }

    #[test]
    fn groups_from_json_array_filters_empty_strings() {
        let json = serde_json::json!({"g": ["a", "", "b", ""]});
        assert_eq!(
            extract_groups_claim_from_json("g", &json),
            ["a", "b"],
        );
    }

    #[test]
    fn groups_from_json_array_skips_non_string_items() {
        let json = serde_json::json!({"g": ["a", 42, true, "b"]});
        assert_eq!(
            extract_groups_claim_from_json("g", &json),
            ["a", "b"],
        );
    }

    #[test]
    fn groups_from_json_missing_claim_returns_empty() {
        assert!(
            extract_groups_claim_from_json("g", &serde_json::json!({}))
                .is_empty()
        );
    }

    #[test]
    fn groups_from_json_non_string_non_array_returns_empty() {
        let json = serde_json::json!({"g": 42});
        assert!(extract_groups_claim_from_json("g", &json).is_empty());
    }

    #[test]
    fn groups_from_space_delimited_string_splits_on_whitespace() {
        let json = serde_json::json!({"g": "a b\tc"});
        assert_eq!(
            extract_groups_claim_from_json("g", &json),
            ["a", "b", "c"],
        );
    }

    #[test]
    fn groups_from_single_bare_string_yields_one_group() {
        let json = serde_json::json!({"g": "solo"});
        assert_eq!(
            extract_groups_claim_from_json("g", &json),
            ["solo"],
        );
    }

    #[test]
    fn groups_from_empty_string_returns_empty() {
        // Covers both the truly empty string and a whitespace-only one.
        for blank in ["", " "] {
            let json = serde_json::json!({"g": blank});
            assert!(extract_groups_claim_from_json("g", &json).is_empty());
        }
    }

    #[test]
    fn groups_from_bool_value_returns_empty() {
        let json = serde_json::json!({"g": true});
        assert!(extract_groups_claim_from_json("g", &json).is_empty());
    }
}