use serde_json::Value;
/// Provider-independent identity extracted from a token's claims.
///
/// Derives `Hash` alongside `Eq` so identities can be used directly as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SsoIdentity {
    /// The raw `iss` claim, copied verbatim by every built-in provider.
    pub issuer: String,
    /// Provider-specific tenant key (e.g. Google `hd`, Azure `tid`, Keycloak realm,
    /// or the issuer host for the generic fallback).
    pub tenant: String,
    /// The raw `sub` claim.
    pub subject: String,
}
/// Failure to build an `SsoIdentity` from a claims object.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SsoError {
    /// A required value was absent or not a string. The payload names the claim
    /// (e.g. `"iss"`, `"hd"`, `"sub"`) or, for values a provider derives from the issuer,
    /// the derived field (e.g. `"realm"`).
    MissingClaim(&'static str),
}

impl std::fmt::Display for SsoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SsoError::MissingClaim(key) => write!(f, "missing or invalid claim: {}", key),
        }
    }
}

/// Lets `SsoError` flow into `Box<dyn Error>` and other standard error-handling plumbing.
impl std::error::Error for SsoError {}
/// A pluggable identity provider: recognises its own issuers and turns a JSON claims
/// object into a provider-independent `SsoIdentity`.
///
/// Implementors must be `Send + Sync` so they can be shared as `&'static dyn SsoProvider`.
pub trait SsoProvider: Send + Sync {
    /// Short identifier reported next to the identity by `normalize`.
    fn id(&self) -> &'static str;

    /// Whether tokens carrying this `iss` value should be routed to this provider.
    fn matches_issuer(&self, issuer: &str) -> bool;

    /// Builds the identity from `claims`, or reports which required value was unusable.
    fn extract(&self, claims: &Value) -> Result<SsoIdentity, SsoError>;
}
/// Borrows the string value stored under `key` in `claims`.
///
/// # Errors
/// `SsoError::MissingClaim(key)` when the key is absent, when `claims` is not a JSON object,
/// or when the value is not a JSON string. These cases are not distinguished.
fn claim<'a>(claims: &'a Value, key: &'static str) -> Result<&'a str, SsoError> {
    let found = claims.get(key).and_then(|raw| raw.as_str());
    match found {
        Some(text) => Ok(text),
        None => Err(SsoError::MissingClaim(key)),
    }
}
/// Google sign-in; the tenant is the `hd` (hosted domain) claim.
pub struct Google;
impl SsoProvider for Google {
    fn id(&self) -> &'static str {
        "google"
    }
    /// Google's OpenID Connect docs state that an ID token's `iss` is either
    /// `https://accounts.google.com` or the scheme-less `accounts.google.com`.
    /// Both must route here. Otherwise the scheme-less form falls through to `Generic`,
    /// which would put every Google user into the single tenant `accounts.google.com`.
    fn matches_issuer(&self, iss: &str) -> bool {
        matches!(iss, "https://accounts.google.com" | "accounts.google.com")
    }
    /// # Errors
    /// `MissingClaim` naming `iss`, `hd` or `sub` when that claim is absent or not a string.
    /// Tokens without `hd` are therefore rejected rather than given an empty tenant.
    fn extract(&self, c: &Value) -> Result<SsoIdentity, SsoError> {
        Ok(SsoIdentity {
            issuer: claim(c, "iss")?.into(),
            tenant: claim(c, "hd")?.into(),
            subject: claim(c, "sub")?.into(),
        })
    }
}
/// Microsoft Entra ID (Azure AD); the tenant is the `tid` claim.
pub struct AzureAd;
impl SsoProvider for AzureAd {
    fn id(&self) -> &'static str {
        "azure"
    }
    /// Matches both Microsoft issuer families:
    /// - v2.0 tokens, e.g. `https://login.microsoftonline.com/{tid}/v2.0`
    /// - v1.0 tokens, `https://sts.windows.net/{tid}/`
    ///
    /// Without the v1.0 prefix those tokens fall through to `Generic`, and every Azure tenant
    /// collapses into the single tenant `sts.windows.net`.
    fn matches_issuer(&self, iss: &str) -> bool {
        const ISSUER_PREFIXES: [&str; 2] = [
            "https://login.microsoftonline.com/",
            "https://sts.windows.net/",
        ];
        ISSUER_PREFIXES.iter().any(|&prefix| iss.starts_with(prefix))
    }
    /// # Errors
    /// `MissingClaim` naming `iss`, `tid` or `sub` when that claim is absent or not a string.
    fn extract(&self, c: &Value) -> Result<SsoIdentity, SsoError> {
        // NOTE(review): `tid` is not cross-checked against the tenant id embedded in `iss`.
        // Microsoft recommends that check for multi-tenant apps; confirm it happens upstream.
        Ok(SsoIdentity {
            issuer: claim(c, "iss")?.into(),
            tenant: claim(c, "tid")?.into(),
            subject: claim(c, "sub")?.into(),
        })
    }
}
/// Keycloak; the tenant is the realm name embedded in the issuer URL.
pub struct Keycloak;
impl SsoProvider for Keycloak {
    fn id(&self) -> &'static str {
        "keycloak"
    }
    /// Claims any issuer containing `/realms/`, whatever its host.
    fn matches_issuer(&self, iss: &str) -> bool {
        iss.contains("/realms/")
    }
    /// Uses the path segment after the *last* `/realms/` in `iss` as the tenant
    /// (e.g. `https://id.example.com/realms/acme` → `acme`).
    ///
    /// # Errors
    /// `MissingClaim("iss")` / `MissingClaim("sub")` when those claims are absent or not strings.
    /// `MissingClaim("realm")` when `iss` has no `/realms/` marker or the segment after it
    /// is empty.
    fn extract(&self, c: &Value) -> Result<SsoIdentity, SsoError> {
        let iss = claim(c, "iss")?;
        // `rsplit_once` yields None when the marker is absent. A non-Keycloak issuer passed
        // straight to `extract` is therefore rejected, instead of its scheme ("https:")
        // becoming the realm.
        let realm = iss
            .rsplit_once("/realms/")
            .and_then(|(_, after_marker)| after_marker.split('/').next())
            .filter(|segment| !segment.is_empty())
            .ok_or(SsoError::MissingClaim("realm"))?;
        // NOTE(review): the tenant is the bare realm name. Two Keycloak servers on different
        // hosts that both expose realm `acme` map to the same tenant. Confirm issuers are
        // allow-listed before reaching this code.
        Ok(SsoIdentity {
            issuer: iss.into(),
            tenant: realm.into(),
            subject: claim(c, "sub")?.into(),
        })
    }
}
pub struct Generic;
impl SsoProvider for Generic {
fn id(&self) -> &'static str {
"generic"
}
fn matches_issuer(&self, _iss: &str) -> bool {
true
}
fn extract(&self, c: &Value) -> Result<SsoIdentity, SsoError> {
let iss = claim(c, "iss")?;
let host = iss
.strip_prefix("https://")
.or_else(|| iss.strip_prefix("http://"))
.unwrap_or(iss)
.split('/')
.next()
.unwrap_or(iss);
Ok(SsoIdentity {
issuer: iss.into(),
tenant: host.into(),
subject: claim(c, "sub")?.into(),
})
}
}
/// The built-in providers, in match-priority order.
///
/// `provider_for` picks the first entry whose `matches_issuer` accepts an issuer.
/// The catch-all `Generic` therefore has to remain the final element.
pub fn builtins() -> [&'static dyn SsoProvider; 4] {
    const ORDERED: [&dyn SsoProvider; 4] = [&Google, &AzureAd, &Keycloak, &Generic];
    ORDERED
}
/// Returns the first built-in provider (in `builtins()` order) that accepts `iss`.
///
/// # Panics
/// Only if the built-in list ever stops ending with a provider that accepts every issuer.
pub fn provider_for(iss: &str) -> &'static dyn SsoProvider {
    let candidates = builtins();
    let chosen = candidates
        .iter()
        .copied()
        .find(|provider| provider.matches_issuer(iss));
    // `Generic` accepts everything and sits last, so `find` cannot come up empty.
    chosen.expect("Generic matches all issuers")
}
/// Routes `claims` to the built-in provider selected by its `iss` claim and extracts an identity.
///
/// Returns the identity together with the chosen provider's `id()`.
///
/// # Errors
/// `SsoError::MissingClaim("iss")` when `iss` is absent or not a string. Otherwise, whatever
/// the selected provider's `extract` reports.
pub fn normalize(claims: &Value) -> Result<(SsoIdentity, &'static str), SsoError> {
    let issuer = claim(claims, "iss")?;
    let provider = provider_for(issuer);
    let identity = provider.extract(claims)?;
    Ok((identity, provider.id()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Each case feeds a hand-built claims object through the public entry point and checks
    // which provider was selected and which tenant it derived.

    #[test]
    fn google_uses_hd_as_tenant() {
        let claims = json!({
            "iss": "https://accounts.google.com",
            "hd": "slanchaai.com",
            "sub": "117"
        });
        let (identity, provider_id) = normalize(&claims).unwrap();
        assert_eq!(provider_id, "google");
        assert_eq!(identity.tenant, "slanchaai.com");
        assert_eq!(identity.subject, "117");
    }

    #[test]
    fn azure_uses_tid_as_tenant() {
        let claims = json!({
            "iss": "https://login.microsoftonline.com/abc-123/v2.0",
            "tid": "abc-123",
            "sub": "u9"
        });
        let (identity, provider_id) = normalize(&claims).unwrap();
        assert_eq!(provider_id, "azure");
        assert_eq!(identity.tenant, "abc-123");
    }

    #[test]
    fn keycloak_extracts_realm_from_issuer() {
        let claims = json!({
            "iss": "https://id.example.com/realms/acme",
            "sub": "kc1"
        });
        let (identity, provider_id) = normalize(&claims).unwrap();
        assert_eq!(provider_id, "keycloak");
        assert_eq!(identity.tenant, "acme");
    }

    #[test]
    fn generic_falls_back_to_issuer_host() {
        let claims = json!({
            "iss": "https://idp.unknown.example/",
            "sub": "g1"
        });
        let (identity, provider_id) = normalize(&claims).unwrap();
        assert_eq!(provider_id, "generic");
        assert_eq!(identity.tenant, "idp.unknown.example");
    }

    #[test]
    fn missing_tenant_claim_errors() {
        // Google issuer without `hd`: extraction must fail and name the missing claim.
        let claims = json!({"iss": "https://accounts.google.com", "sub": "117"});
        let outcome = normalize(&claims);
        assert_eq!(outcome, Err(SsoError::MissingClaim("hd")));
    }

    #[test]
    fn a_new_provider_is_one_impl() {
        // A provider defined outside this module only needs a single trait impl.
        struct Okta;
        impl SsoProvider for Okta {
            fn id(&self) -> &'static str {
                "okta"
            }
            fn matches_issuer(&self, iss: &str) -> bool {
                iss.ends_with(".okta.com")
            }
            fn extract(&self, c: &Value) -> Result<SsoIdentity, SsoError> {
                let iss = claim(c, "iss")?;
                // Tenant = first DNS label of the issuer host.
                let first_label = iss
                    .strip_prefix("https://")
                    .and_then(|host| host.split('.').next());
                let org_slug = first_label.ok_or(SsoError::MissingClaim("org"))?;
                Ok(SsoIdentity {
                    issuer: iss.into(),
                    tenant: org_slug.into(),
                    subject: claim(c, "sub")?.into(),
                })
            }
        }
        let provider = Okta;
        assert!(provider.matches_issuer("https://slanchaai.okta.com"));
        let claims = json!({"iss": "https://slanchaai.okta.com", "sub": "ok1"});
        let identity = provider.extract(&claims).unwrap();
        assert_eq!(identity.tenant, "slanchaai");
    }
}