Skip to main content

modkit_auth/plugins/
oidc.rs

1use crate::{
2    claims::{Claims, Permission},
3    claims_error::ClaimsError,
4    plugin_traits::ClaimsPlugin,
5    standard_claims::StandardClaim,
6    validation::{extract_audiences, extract_string, parse_timestamp, parse_uuid_from_value},
7};
8use serde_json::Value;
9
/// Generic OIDC claims plugin.
///
/// Handles standard OIDC claims, with configurable names for the tenant and
/// roles claims. Serves as a fallback for any OIDC-compliant provider that does
/// not need provider-specific handling.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GenericOidcPlugin {
    /// Name of the top-level claim that carries the tenant id (default: `tenant_id`).
    pub tenant_claim: String,

    /// Name of the top-level claim that carries the roles (default: `roles`).
    pub roles_claim: String,
}

impl Default for GenericOidcPlugin {
    /// Use the conventional `tenant_id` and `roles` claim names.
    fn default() -> Self {
        Self {
            tenant_claim: String::from("tenant_id"),
            roles_claim: String::from("roles"),
        }
    }
}
32
33impl GenericOidcPlugin {
34    /// Create a new generic OIDC plugin with custom configuration
35    pub fn new(tenant_claim: impl Into<String>, roles_claim: impl Into<String>) -> Self {
36        Self {
37            tenant_claim: tenant_claim.into(),
38            roles_claim: roles_claim.into(),
39        }
40    }
41
42    /// Extract permissions from the configured roles claim
43    fn extract_permissions(&self, raw: &Value) -> Vec<Permission> {
44        let roles: Vec<String> = raw
45            .get(&self.roles_claim)
46            .and_then(|v| v.as_array())
47            .map(|arr| {
48                arr.iter()
49                    .filter_map(|v| v.as_str())
50                    .map(ToString::to_string)
51                    .collect()
52            })
53            .unwrap_or_default();
54
55        // Convert roles to permissions (resource_pattern:action format)
56        roles
57            .into_iter()
58            .filter_map(|role| {
59                // Try to parse as "resource:action" format
60                if let Some(pos) = role.rfind(':') {
61                    Permission::builder()
62                        .resource_pattern(&role[..pos])
63                        .action(&role[pos + 1..])
64                        .build()
65                        .ok()
66                } else {
67                    // Treat as resource with wildcard action
68                    Permission::builder()
69                        .resource_pattern(&role)
70                        .action("*")
71                        .build()
72                        .ok()
73                }
74            })
75            .collect()
76    }
77}
78
79impl ClaimsPlugin for GenericOidcPlugin {
80    fn name(&self) -> &'static str {
81        "generic-oidc"
82    }
83
84    fn normalize(&self, raw: &Value) -> Result<Claims, ClaimsError> {
85        // 1. Extract subject (required, must be UUID)
86        let subject = raw
87            .get(StandardClaim::SUB)
88            .ok_or_else(|| ClaimsError::MissingClaim(StandardClaim::SUB.to_owned()))
89            .and_then(|v| parse_uuid_from_value(v, StandardClaim::SUB))?;
90
91        // 2. Extract issuer (required)
92        let issuer = raw
93            .get(StandardClaim::ISS)
94            .ok_or_else(|| ClaimsError::MissingClaim(StandardClaim::ISS.to_owned()))
95            .and_then(|v| extract_string(v, StandardClaim::ISS))?;
96
97        // 3. Extract audiences (handle string or array)
98        let audiences = raw
99            .get(StandardClaim::AUD)
100            .map(extract_audiences)
101            .unwrap_or_default();
102
103        // 4. Extract expiration time
104        let expires_at = raw
105            .get(StandardClaim::EXP)
106            .map(|v| parse_timestamp(v, StandardClaim::EXP))
107            .transpose()?;
108
109        // 5. Extract not-before time
110        let not_before = raw
111            .get(StandardClaim::NBF)
112            .map(|v| parse_timestamp(v, StandardClaim::NBF))
113            .transpose()?;
114
115        // 6. Extract issued-at time
116        let issued_at = raw
117            .get(StandardClaim::IAT)
118            .map(|v| parse_timestamp(v, StandardClaim::IAT))
119            .transpose()?;
120
121        // 7. Extract JWT ID
122        let jwt_id = raw
123            .get(StandardClaim::JTI)
124            .and_then(|v| v.as_str())
125            .map(ToString::to_string);
126
127        // 8. Extract tenant_id (required, must be UUID)
128        let tenant_id = raw
129            .get(&self.tenant_claim)
130            .ok_or_else(|| ClaimsError::MissingClaim(self.tenant_claim.clone()))
131            .and_then(|v| parse_uuid_from_value(v, &self.tenant_claim))?;
132
133        // 9. Extract permissions from configured field
134        let permissions = self.extract_permissions(raw);
135
136        // 10. Collect extra claims (excluding standard ones)
137        let mut extras = serde_json::Map::new();
138
139        if let Value::Object(obj) = raw {
140            for (key, value) in obj {
141                let is_standard = StandardClaim::is_registered(key);
142                let is_tenant = key == &self.tenant_claim;
143                let is_roles = key == &self.roles_claim;
144
145                if !is_standard && !is_tenant && !is_roles {
146                    extras.insert(key.clone(), value.clone());
147                }
148            }
149        }
150
151        // Explicitly add common OIDC profile claims to extras
152        for field in [
153            "email",
154            "name",
155            "preferred_username",
156            "given_name",
157            "family_name",
158            "picture",
159        ] {
160            if let Some(value) = raw.get(field) {
161                extras.insert(field.to_owned(), value.clone());
162            }
163        }
164
165        Ok(Claims {
166            issuer,
167            subject,
168            audiences,
169            expires_at,
170            not_before,
171            issued_at,
172            jwt_id,
173            tenant_id,
174            permissions,
175            extras,
176        })
177    }
178}
179
#[cfg(test)]
// The far-future `exp` fixture (9999999999) reads more clearly without separators.
#[allow(clippy::unreadable_literal)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use serde_json::json;
    use uuid::Uuid;

    #[test]
    fn test_generic_oidc_normalize() {
        let plugin = GenericOidcPlugin::default();

        let user_id = Uuid::new_v4();
        let tenant_id = Uuid::new_v4();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": user_id.to_string(),
            "aud": ["api", "ui"],
            "exp": 9999999999i64,
            "roles": ["users:read", "admin:write"],
            "tenant_id": tenant_id.to_string(),
            "email": "test@example.com",
            "name": "Test User"
        });

        let normalized = plugin.normalize(&claims).unwrap();

        assert_eq!(normalized.subject, user_id);
        assert_eq!(normalized.issuer, "https://auth.example.com");
        assert_eq!(normalized.audiences, vec!["api", "ui"]);
        assert_eq!(normalized.tenant_id, tenant_id);
        assert_eq!(normalized.permissions.len(), 2);
        assert_eq!(
            normalized.extras.get("email").unwrap().as_str().unwrap(),
            "test@example.com"
        );
        assert_eq!(
            normalized.extras.get("name").unwrap().as_str().unwrap(),
            "Test User"
        );
    }

    #[test]
    fn test_generic_oidc_custom_claims() {
        let plugin = GenericOidcPlugin::new("organization_id", "permissions");

        let user_id = Uuid::new_v4();
        let org_id = Uuid::new_v4();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": user_id.to_string(),
            "aud": "api",
            "permissions": ["read:*", "write:*"],
            "organization_id": org_id.to_string()
        });

        let normalized = plugin.normalize(&claims).unwrap();

        assert_eq!(normalized.tenant_id, org_id);
        assert_eq!(normalized.permissions.len(), 2);
    }

    #[test]
    fn test_generic_oidc_missing_subject_fails() {
        let plugin = GenericOidcPlugin::default();

        let claims = json!({
            "iss": "https://auth.example.com",
            "aud": "api"
        });

        let result = plugin.normalize(&claims);
        assert!(matches!(result, Err(ClaimsError::MissingClaim(_))));
    }

    #[test]
    fn test_generic_oidc_missing_tenant_fails() {
        let plugin = GenericOidcPlugin::default();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": Uuid::new_v4().to_string(),
            "aud": "api"
        });

        let result = plugin.normalize(&claims);
        assert!(matches!(
            &result,
            Err(ClaimsError::MissingClaim(claim)) if claim == "tenant_id"
        ));
    }

    #[test]
    fn test_generic_oidc_missing_custom_tenant_reports_configured_name() {
        let plugin = GenericOidcPlugin::new("organization_id", "permissions");

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": Uuid::new_v4().to_string(),
            // The default tenant claim must not be picked up by a custom-configured plugin.
            "tenant_id": Uuid::new_v4().to_string()
        });

        let result = plugin.normalize(&claims);
        assert!(matches!(
            &result,
            Err(ClaimsError::MissingClaim(claim)) if claim == "organization_id"
        ));
    }

    #[test]
    fn test_generic_oidc_handles_string_audience() {
        let plugin = GenericOidcPlugin::default();

        let user_id = Uuid::new_v4();
        let tenant_id = Uuid::new_v4();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": user_id.to_string(),
            "aud": "api",  // String instead of array
            "exp": 9999999999i64,
            "tenant_id": tenant_id.to_string()
        });

        let normalized = plugin.normalize(&claims).unwrap();
        assert_eq!(normalized.audiences, vec!["api"]);
        // No roles claim at all yields no permissions.
        assert!(normalized.permissions.is_empty());
    }

    #[test]
    fn test_generic_oidc_role_entries_edge_cases() {
        let plugin = GenericOidcPlugin::default();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": Uuid::new_v4().to_string(),
            "tenant_id": Uuid::new_v4().to_string(),
            // Non-string entries are ignored; a bare role gets a wildcard action.
            "roles": ["users:read", "admin", 42, null, {"nested": true}]
        });

        let normalized = plugin.normalize(&claims).unwrap();
        assert_eq!(normalized.permissions.len(), 2);
    }

    #[test]
    fn test_generic_oidc_non_array_roles_yield_no_permissions() {
        let plugin = GenericOidcPlugin::default();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": Uuid::new_v4().to_string(),
            "tenant_id": Uuid::new_v4().to_string(),
            "roles": "users:read"
        });

        let normalized = plugin.normalize(&claims).unwrap();
        assert!(normalized.permissions.is_empty());
    }

    #[test]
    fn test_generic_oidc_extras_exclude_tenant_and_roles() {
        let plugin = GenericOidcPlugin::default();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": Uuid::new_v4().to_string(),
            "tenant_id": Uuid::new_v4().to_string(),
            "roles": ["users:read"],
            "department": "engineering"
        });

        let normalized = plugin.normalize(&claims).unwrap();
        assert_eq!(
            normalized.extras.get("department").unwrap().as_str().unwrap(),
            "engineering"
        );
        assert!(!normalized.extras.contains_key("tenant_id"));
        assert!(!normalized.extras.contains_key("roles"));
    }
}