1use serde_json::Value;
19
/// Provider-independent identity extracted from an SSO token's claims.
///
/// Which claim becomes `tenant` depends on the provider: Google uses `hd`, Azure AD uses
/// `tid`, Keycloak uses the realm parsed from `iss`, and unknown issuers use the issuer host.
///
/// `Hash` is derived alongside `Eq` so identities can be used as `HashMap`/`HashSet` keys
/// (e.g. for de-duplicating or caching per-user state).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SsoIdentity {
    /// The raw `iss` claim of the token.
    pub issuer: String,
    /// Provider-specific tenant/organisation key (see the type-level docs).
    pub tenant: String,
    /// The `sub` claim: the user's identifier at the issuer.
    pub subject: String,
}
33
/// Error returned when a token's claims cannot be mapped to an `SsoIdentity`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SsoError {
    /// A required claim is absent or not a JSON string, or a value derived from one
    /// (such as a Keycloak realm parsed out of `iss`) could not be obtained.
    MissingClaim(&'static str),
}

impl std::fmt::Display for SsoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SsoError::MissingClaim(name) => write!(f, "missing required claim `{}`", name),
        }
    }
}

// Lets callers propagate `SsoError` with `?` into `Box<dyn Error>` / error-chaining crates.
impl std::error::Error for SsoError {}
39
40pub trait SsoProvider: Send + Sync {
44 fn id(&self) -> &'static str;
46 fn matches_issuer(&self, iss: &str) -> bool;
48 fn extract(&self, claims: &Value) -> Result<SsoIdentity, SsoError>;
50}
51
52fn claim<'a>(claims: &'a Value, key: &'static str) -> Result<&'a str, SsoError> {
53 claims
54 .get(key)
55 .and_then(Value::as_str)
56 .ok_or(SsoError::MissingClaim(key))
57}
58
59pub struct Google;
62impl SsoProvider for Google {
63 fn id(&self) -> &'static str {
64 "google"
65 }
66 fn matches_issuer(&self, iss: &str) -> bool {
67 iss == "https://accounts.google.com"
68 }
69 fn extract(&self, c: &Value) -> Result<SsoIdentity, SsoError> {
70 Ok(SsoIdentity {
71 issuer: claim(c, "iss")?.into(),
72 tenant: claim(c, "hd")?.into(),
73 subject: claim(c, "sub")?.into(),
74 })
75 }
76}
77
78pub struct AzureAd;
81impl SsoProvider for AzureAd {
82 fn id(&self) -> &'static str {
83 "azure"
84 }
85 fn matches_issuer(&self, iss: &str) -> bool {
86 iss.starts_with("https://login.microsoftonline.com/")
87 }
88 fn extract(&self, c: &Value) -> Result<SsoIdentity, SsoError> {
89 Ok(SsoIdentity {
90 issuer: claim(c, "iss")?.into(),
91 tenant: claim(c, "tid")?.into(),
92 subject: claim(c, "sub")?.into(),
93 })
94 }
95}
96
97pub struct Keycloak;
100impl SsoProvider for Keycloak {
101 fn id(&self) -> &'static str {
102 "keycloak"
103 }
104 fn matches_issuer(&self, iss: &str) -> bool {
105 iss.contains("/realms/")
106 }
107 fn extract(&self, c: &Value) -> Result<SsoIdentity, SsoError> {
108 let iss = claim(c, "iss")?;
109 let realm = iss
110 .rsplit("/realms/")
111 .next()
112 .and_then(|s| s.split('/').next())
113 .filter(|s| !s.is_empty())
114 .ok_or(SsoError::MissingClaim("realm"))?;
115 Ok(SsoIdentity {
116 issuer: iss.into(),
117 tenant: realm.into(),
118 subject: claim(c, "sub")?.into(),
119 })
120 }
121}
122
123pub struct Generic;
127impl SsoProvider for Generic {
128 fn id(&self) -> &'static str {
129 "generic"
130 }
131 fn matches_issuer(&self, _iss: &str) -> bool {
132 true
133 }
134 fn extract(&self, c: &Value) -> Result<SsoIdentity, SsoError> {
135 let iss = claim(c, "iss")?;
136 let host = iss
137 .strip_prefix("https://")
138 .or_else(|| iss.strip_prefix("http://"))
139 .unwrap_or(iss)
140 .split('/')
141 .next()
142 .unwrap_or(iss);
143 Ok(SsoIdentity {
144 issuer: iss.into(),
145 tenant: host.into(),
146 subject: claim(c, "sub")?.into(),
147 })
148 }
149}
150
151pub fn builtins() -> [&'static dyn SsoProvider; 4] {
154 [&Google, &AzureAd, &Keycloak, &Generic]
155}
156
157pub fn provider_for(iss: &str) -> &'static dyn SsoProvider {
160 builtins()
161 .into_iter()
162 .find(|p| p.matches_issuer(iss))
163 .expect("Generic matches all issuers")
164}
165
166pub fn normalize(claims: &Value) -> Result<(SsoIdentity, &'static str), SsoError> {
169 let iss = claim(claims, "iss")?;
170 let p = provider_for(iss);
171 Ok((p.extract(claims)?, p.id()))
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Google Workspace tokens: tenant comes from `hd`.
    #[test]
    fn google_uses_hd_as_tenant() {
        let claims =
            json!({"iss": "https://accounts.google.com", "hd": "slanchaai.com", "sub": "117"});
        let (identity, provider) = normalize(&claims).unwrap();
        assert_eq!(provider, "google");
        assert_eq!(identity.tenant, "slanchaai.com");
        assert_eq!(identity.subject, "117");
    }

    // Azure tokens: tenant comes from `tid`.
    #[test]
    fn azure_uses_tid_as_tenant() {
        let claims = json!({
            "iss": "https://login.microsoftonline.com/abc-123/v2.0",
            "tid": "abc-123",
            "sub": "u9"
        });
        let (identity, provider) = normalize(&claims).unwrap();
        assert_eq!(provider, "azure");
        assert_eq!(identity.tenant, "abc-123");
    }

    // Keycloak tokens: tenant is the realm segment of `iss`.
    #[test]
    fn keycloak_extracts_realm_from_issuer() {
        let claims = json!({"iss": "https://id.example.com/realms/acme", "sub": "kc1"});
        let (identity, provider) = normalize(&claims).unwrap();
        assert_eq!(provider, "keycloak");
        assert_eq!(identity.tenant, "acme");
    }

    // Unknown issuers: tenant is the host of `iss`.
    #[test]
    fn generic_falls_back_to_issuer_host() {
        let claims = json!({"iss": "https://idp.unknown.example/", "sub": "g1"});
        let (identity, provider) = normalize(&claims).unwrap();
        assert_eq!(provider, "generic");
        assert_eq!(identity.tenant, "idp.unknown.example");
    }

    // A Google token without `hd` must not produce an identity.
    #[test]
    fn missing_tenant_claim_errors() {
        let claims = json!({"iss": "https://accounts.google.com", "sub": "117"});
        let outcome = normalize(&claims);
        assert_eq!(outcome, Err(SsoError::MissingClaim("hd")));
    }

    // Adding a provider only requires one `SsoProvider` impl.
    #[test]
    fn a_new_provider_is_one_impl() {
        struct Okta;

        impl SsoProvider for Okta {
            fn id(&self) -> &'static str {
                "okta"
            }

            fn matches_issuer(&self, iss: &str) -> bool {
                iss.ends_with(".okta.com")
            }

            fn extract(&self, claims: &Value) -> Result<SsoIdentity, SsoError> {
                let issuer = claim(claims, "iss")?;
                // The org is the first DNS label of the https issuer host.
                let first_label = match issuer.strip_prefix("https://") {
                    Some(host) => host.split('.').next(),
                    None => None,
                };
                let org = first_label.ok_or(SsoError::MissingClaim("org"))?;
                let subject = claim(claims, "sub")?;
                Ok(SsoIdentity {
                    issuer: issuer.into(),
                    tenant: org.into(),
                    subject: subject.into(),
                })
            }
        }

        let provider = Okta;
        assert!(provider.matches_issuer("https://slanchaai.okta.com"));
        let claims = json!({"iss": "https://slanchaai.okta.com", "sub": "ok1"});
        let identity = provider.extract(&claims).unwrap();
        assert_eq!(identity.tenant, "slanchaai");
    }
}