Skip to main content

modkit_auth/plugins/keycloak.rs

1use crate::{
2    claims::{Claims, Permission},
3    claims_error::ClaimsError,
4    plugin_traits::ClaimsPlugin,
5    standard_claims::StandardClaim,
6    validation::{extract_audiences, extract_string, parse_timestamp, parse_uuid_from_value},
7};
8use serde_json::Value;
9
/// Keycloak-specific claims plugin.
///
/// Normalizes a Keycloak token payload into the crate's common `Claims` shape:
/// - Roles are gathered from a top-level `roles` array, from `realm_access.roles`, and — when
///   [`client_roles`](Self::client_roles) is set — from `resource_access.<client>.roles`.
/// - Each role may be prefixed with [`role_prefix`](Self::role_prefix) and is then turned into a
///   permission by splitting on its last `:` (`resource:action`); a role without `:` gets the
///   wildcard action `*`.
/// - The tenant ID is read from a configurable claim (default: `tenant_id`) and must be a UUID.
/// - Audiences are read from the standard `aud` claim (string or array) only; this plugin does
///   not consult `azp` or `resource_access` for audiences.
#[derive(Debug, Clone)]
pub struct KeycloakClaimsPlugin {
    /// Name of the claim holding the tenant UUID (default: `tenant_id`).
    pub tenant_claim: String,

    /// Optional client ID whose `resource_access.<client>.roles` are included as roles.
    /// `None` means client roles are ignored.
    pub client_roles: Option<String>,

    /// Optional prefix prepended to every role as `<prefix>:<role>`.
    pub role_prefix: Option<String>,
}
28
29impl Default for KeycloakClaimsPlugin {
30    fn default() -> Self {
31        Self {
32            tenant_claim: "tenant_id".to_owned(),
33            client_roles: None,
34            role_prefix: None,
35        }
36    }
37}
38
39impl KeycloakClaimsPlugin {
40    /// Create a new Keycloak plugin with custom configuration
41    pub fn new(
42        tenant_claim: impl Into<String>,
43        client_roles: Option<String>,
44        role_prefix: Option<String>,
45    ) -> Self {
46        Self {
47            tenant_claim: tenant_claim.into(),
48            client_roles,
49            role_prefix,
50        }
51    }
52
53    /// Extract permissions from Keycloak's complex role structure
54    fn extract_permissions(&self, raw: &Value) -> Vec<Permission> {
55        let mut roles = Vec::new();
56
57        // 1. Check for top-level "roles" array (simplified format)
58        if let Some(Value::Array(arr)) = raw.get("roles") {
59            roles.extend(
60                arr.iter()
61                    .filter_map(|v| v.as_str())
62                    .map(ToString::to_string),
63            );
64        }
65
66        // 2. Extract from realm_access.roles
67        if let Some(Value::Object(realm)) = raw.get("realm_access")
68            && let Some(Value::Array(arr)) = realm.get("roles")
69        {
70            roles.extend(
71                arr.iter()
72                    .filter_map(|v| v.as_str())
73                    .map(ToString::to_string),
74            );
75        }
76
77        // 3. Extract from resource_access.<client>.roles
78        if let Some(client_id) = &self.client_roles
79            && let Some(Value::Object(resource_access)) = raw.get("resource_access")
80            && let Some(Value::Object(client)) = resource_access.get(client_id)
81            && let Some(Value::Array(arr)) = client.get("roles")
82        {
83            roles.extend(
84                arr.iter()
85                    .filter_map(|v| v.as_str())
86                    .map(ToString::to_string),
87            );
88        }
89
90        // Apply role prefix if configured
91        if let Some(prefix) = &self.role_prefix {
92            roles = roles.into_iter().map(|r| format!("{prefix}:{r}")).collect();
93        }
94
95        // Deduplicate
96        roles.sort();
97        roles.dedup();
98
99        // Convert roles to permissions (resource_pattern:action format)
100        roles
101            .into_iter()
102            .filter_map(|role| {
103                // Try to parse as "resource:action" format
104                if let Some(pos) = role.rfind(':') {
105                    Permission::builder()
106                        .resource_pattern(&role[..pos])
107                        .action(&role[pos + 1..])
108                        .build()
109                        .ok()
110                } else {
111                    // Treat as resource with wildcard action
112                    Permission::builder()
113                        .resource_pattern(&role)
114                        .action("*")
115                        .build()
116                        .ok()
117                }
118            })
119            .collect()
120    }
121}
122
123impl ClaimsPlugin for KeycloakClaimsPlugin {
124    fn name(&self) -> &'static str {
125        "keycloak"
126    }
127
128    fn normalize(&self, raw: &Value) -> Result<Claims, ClaimsError> {
129        // 1. Extract subject (required, must be UUID)
130        let subject = raw
131            .get(StandardClaim::SUB)
132            .ok_or_else(|| ClaimsError::MissingClaim(StandardClaim::SUB.to_owned()))
133            .and_then(|v| parse_uuid_from_value(v, StandardClaim::SUB))?;
134
135        // 2. Extract issuer (required)
136        let issuer = raw
137            .get(StandardClaim::ISS)
138            .ok_or_else(|| ClaimsError::MissingClaim(StandardClaim::ISS.to_owned()))
139            .and_then(|v| extract_string(v, StandardClaim::ISS))?;
140
141        // 3. Extract audiences (handle string or array)
142        let audiences = raw
143            .get(StandardClaim::AUD)
144            .map(extract_audiences)
145            .unwrap_or_default();
146
147        // 4. Extract expiration time
148        let expires_at = raw
149            .get(StandardClaim::EXP)
150            .map(|v| parse_timestamp(v, StandardClaim::EXP))
151            .transpose()?;
152
153        // 5. Extract not-before time
154        let not_before = raw
155            .get(StandardClaim::NBF)
156            .map(|v| parse_timestamp(v, StandardClaim::NBF))
157            .transpose()?;
158
159        // 6. Extract issued-at time
160        let issued_at = raw
161            .get(StandardClaim::IAT)
162            .map(|v| parse_timestamp(v, StandardClaim::IAT))
163            .transpose()?;
164
165        // 7. Extract JWT ID
166        let jwt_id = raw
167            .get(StandardClaim::JTI)
168            .and_then(|v| v.as_str())
169            .map(ToString::to_string);
170
171        // 8. Extract tenant_id (required, must be UUID)
172        let tenant_id = raw
173            .get(&self.tenant_claim)
174            .ok_or_else(|| ClaimsError::MissingClaim(self.tenant_claim.clone()))
175            .and_then(|v| parse_uuid_from_value(v, &self.tenant_claim))?;
176
177        // 9. Extract permissions using Keycloak-specific logic
178        let permissions = self.extract_permissions(raw);
179
180        // 10. Collect extra claims (excluding standard ones and Keycloak-specific fields)
181        let mut extras = serde_json::Map::new();
182        let keycloak_fields = ["roles", "realm_access", "resource_access"];
183
184        if let Value::Object(obj) = raw {
185            for (key, value) in obj {
186                let is_standard = StandardClaim::is_registered(key);
187                let is_keycloak = keycloak_fields.contains(&key.as_str());
188                let is_tenant = key == &self.tenant_claim;
189
190                if !is_standard && !is_keycloak && !is_tenant {
191                    extras.insert(key.clone(), value.clone());
192                }
193            }
194        }
195
196        // Add email if present
197        if let Some(email) = raw.get("email") {
198            extras.insert("email".to_owned(), email.clone());
199        }
200
201        // Add preferred_username if present
202        if let Some(username) = raw.get("preferred_username") {
203            extras.insert("preferred_username".to_owned(), username.clone());
204        }
205
206        // Add name if present
207        if let Some(name) = raw.get("name") {
208            extras.insert("name".to_owned(), name.clone());
209        }
210
211        Ok(Claims {
212            issuer,
213            subject,
214            audiences,
215            expires_at,
216            not_before,
217            issued_at,
218            jwt_id,
219            tenant_id,
220            permissions,
221            extras,
222        })
223    }
224}
225
#[cfg(test)]
// The far-future `exp` literal (9999999999) reads better as a plain Unix timestamp.
#[allow(clippy::unreadable_literal)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use serde_json::json;
    use uuid::Uuid;

    #[test]
    fn test_keycloak_plugin_normalize() {
        let plugin = KeycloakClaimsPlugin::default();

        let user_id = Uuid::new_v4();
        let tenant_id = Uuid::new_v4();

        let claims = json!({
            "iss": "https://kc.example.com/realms/test",
            "sub": user_id.to_string(),
            "aud": "modkit-api",
            "exp": 9999999999i64,
            "tenant_id": tenant_id.to_string(),
            "realm_access": {
                "roles": ["users:read", "admin:write"]
            },
            "email": "test@example.com"
        });

        let normalized = plugin.normalize(&claims).unwrap();

        assert_eq!(normalized.subject, user_id);
        assert_eq!(normalized.issuer, "https://kc.example.com/realms/test");
        assert_eq!(normalized.audiences, vec!["modkit-api"]);
        assert_eq!(normalized.tenant_id, tenant_id);
        assert_eq!(normalized.permissions.len(), 2);
        assert_eq!(
            normalized.extras.get("email").unwrap().as_str().unwrap(),
            "test@example.com"
        );
    }

    #[test]
    fn test_keycloak_extract_permissions_with_client() {
        let plugin = KeycloakClaimsPlugin::new("tenant_id", Some("modkit-api".to_owned()), None);

        let claims = json!({
            "realm_access": {
                "roles": ["realm:role"]
            },
            "resource_access": {
                "modkit-api": {
                    "roles": ["api:role"]
                }
            }
        });

        let permissions = plugin.extract_permissions(&claims);
        assert_eq!(permissions.len(), 2);
    }

    #[test]
    fn test_keycloak_client_roles_ignored_when_not_configured() {
        let plugin = KeycloakClaimsPlugin::default();

        let claims = json!({
            "resource_access": {
                "modkit-api": {
                    "roles": ["api:role"]
                }
            }
        });

        assert!(plugin.extract_permissions(&claims).is_empty());
    }

    #[test]
    fn test_keycloak_extract_permissions_dedups_across_sources() {
        let plugin = KeycloakClaimsPlugin::default();

        let claims = json!({
            "roles": ["users:read"],
            "realm_access": {
                "roles": ["users:read", "users:write"]
            }
        });

        assert_eq!(plugin.extract_permissions(&claims).len(), 2);
    }

    #[test]
    fn test_keycloak_role_without_colon_gets_wildcard_action() {
        let plugin = KeycloakClaimsPlugin::default();

        // Non-string entries must be skipped, not turned into permissions.
        let claims = json!({
            "realm_access": {
                "roles": ["auditor", 42, null]
            }
        });

        let permissions = plugin.extract_permissions(&claims);
        assert_eq!(permissions.len(), 1);
        assert!(
            permissions
                .iter()
                .any(|p| p.resource_pattern() == "auditor" && p.action() == "*")
        );
    }

    #[test]
    fn test_keycloak_extract_permissions_with_prefix() {
        let plugin = KeycloakClaimsPlugin::new("tenant_id", None, Some("kc".to_owned()));

        let claims = json!({
            "realm_access": {
                "roles": ["admin", "user"]
            }
        });

        let permissions = plugin.extract_permissions(&claims);
        assert_eq!(permissions.len(), 2);
        // Prefixed roles become "kc:admin" and "kc:user", parsed as resource:action
        assert!(
            permissions
                .iter()
                .any(|p| p.resource_pattern() == "kc" && p.action() == "admin")
        );
        assert!(
            permissions
                .iter()
                .any(|p| p.resource_pattern() == "kc" && p.action() == "user")
        );
    }

    #[test]
    fn test_keycloak_extras_exclude_role_containers_and_tenant() {
        let plugin = KeycloakClaimsPlugin::new("tenant_id", Some("modkit-api".to_owned()), None);

        let claims = json!({
            "iss": "https://kc.example.com/realms/test",
            "sub": Uuid::new_v4().to_string(),
            "tenant_id": Uuid::new_v4().to_string(),
            "roles": ["top:level"],
            "realm_access": { "roles": ["users:read"] },
            "resource_access": { "modkit-api": { "roles": ["api:role"] } },
            "department": "engineering",
            "preferred_username": "jdoe"
        });

        let normalized = plugin.normalize(&claims).unwrap();

        for excluded in ["roles", "realm_access", "resource_access", "tenant_id"] {
            assert!(
                !normalized.extras.contains_key(excluded),
                "{excluded} leaked into extras"
            );
        }
        assert_eq!(
            normalized.extras.get("department"),
            Some(&json!("engineering"))
        );
        assert_eq!(
            normalized.extras.get("preferred_username"),
            Some(&json!("jdoe"))
        );
    }

    #[test]
    fn test_keycloak_missing_subject_fails() {
        let plugin = KeycloakClaimsPlugin::default();

        let claims = json!({
            "iss": "https://kc.example.com/realms/test",
            "aud": "modkit-api"
        });

        let result = plugin.normalize(&claims);
        assert!(matches!(result, Err(ClaimsError::MissingClaim(_))));
    }

    #[test]
    fn test_keycloak_missing_tenant_fails() {
        let plugin = KeycloakClaimsPlugin::default();

        let claims = json!({
            "iss": "https://kc.example.com/realms/test",
            "sub": Uuid::new_v4().to_string(),
            "aud": "modkit-api"
        });

        let result = plugin.normalize(&claims);
        assert!(matches!(
            result,
            Err(ClaimsError::MissingClaim(ref claim)) if claim.as_str() == "tenant_id"
        ));
    }

    #[test]
    fn test_keycloak_invalid_uuid_fails() {
        let plugin = KeycloakClaimsPlugin::default();

        let claims = json!({
            "iss": "https://kc.example.com/realms/test",
            "sub": "not-a-uuid",
            "aud": "modkit-api"
        });

        let result = plugin.normalize(&claims);
        assert!(matches!(
            result,
            Err(ClaimsError::InvalidClaimFormat { .. })
        ));
    }
}