modkit_auth/plugins/
oidc.rs1use crate::{
2 claims::{Claims, Permission},
3 claims_error::ClaimsError,
4 plugin_traits::ClaimsPlugin,
5 standard_claims::StandardClaim,
6 validation::{extract_audiences, extract_string, parse_timestamp, parse_uuid_from_value},
7};
8use serde_json::Value;
9
/// [`ClaimsPlugin`] for generic OpenID Connect providers.
///
/// Maps a raw token payload onto the normalized [`Claims`] model. The names of
/// the claims carrying the tenant id and the role list differ between
/// providers, so they are configurable here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GenericOidcPlugin {
    /// Name of the top-level claim holding the tenant id (parsed as a UUID).
    pub tenant_claim: String,

    /// Name of the top-level claim holding an array of role strings; each role
    /// is mapped to a permission of the form `resource:action`, or to
    /// `resource` with action `*` when it contains no `:`.
    pub roles_claim: String,
}
23
24impl Default for GenericOidcPlugin {
25 fn default() -> Self {
26 Self {
27 tenant_claim: "tenant_id".to_owned(),
28 roles_claim: "roles".to_owned(),
29 }
30 }
31}
32
33impl GenericOidcPlugin {
34 pub fn new(tenant_claim: impl Into<String>, roles_claim: impl Into<String>) -> Self {
36 Self {
37 tenant_claim: tenant_claim.into(),
38 roles_claim: roles_claim.into(),
39 }
40 }
41
42 fn extract_permissions(&self, raw: &Value) -> Vec<Permission> {
44 let roles: Vec<String> = raw
45 .get(&self.roles_claim)
46 .and_then(|v| v.as_array())
47 .map(|arr| {
48 arr.iter()
49 .filter_map(|v| v.as_str())
50 .map(ToString::to_string)
51 .collect()
52 })
53 .unwrap_or_default();
54
55 roles
57 .into_iter()
58 .filter_map(|role| {
59 if let Some(pos) = role.rfind(':') {
61 Permission::builder()
62 .resource_pattern(&role[..pos])
63 .action(&role[pos + 1..])
64 .build()
65 .ok()
66 } else {
67 Permission::builder()
69 .resource_pattern(&role)
70 .action("*")
71 .build()
72 .ok()
73 }
74 })
75 .collect()
76 }
77}
78
79impl ClaimsPlugin for GenericOidcPlugin {
80 fn name(&self) -> &'static str {
81 "generic-oidc"
82 }
83
84 fn normalize(&self, raw: &Value) -> Result<Claims, ClaimsError> {
85 let subject = raw
87 .get(StandardClaim::SUB)
88 .ok_or_else(|| ClaimsError::MissingClaim(StandardClaim::SUB.to_owned()))
89 .and_then(|v| parse_uuid_from_value(v, StandardClaim::SUB))?;
90
91 let issuer = raw
93 .get(StandardClaim::ISS)
94 .ok_or_else(|| ClaimsError::MissingClaim(StandardClaim::ISS.to_owned()))
95 .and_then(|v| extract_string(v, StandardClaim::ISS))?;
96
97 let audiences = raw
99 .get(StandardClaim::AUD)
100 .map(extract_audiences)
101 .unwrap_or_default();
102
103 let expires_at = raw
105 .get(StandardClaim::EXP)
106 .map(|v| parse_timestamp(v, StandardClaim::EXP))
107 .transpose()?;
108
109 let not_before = raw
111 .get(StandardClaim::NBF)
112 .map(|v| parse_timestamp(v, StandardClaim::NBF))
113 .transpose()?;
114
115 let issued_at = raw
117 .get(StandardClaim::IAT)
118 .map(|v| parse_timestamp(v, StandardClaim::IAT))
119 .transpose()?;
120
121 let jwt_id = raw
123 .get(StandardClaim::JTI)
124 .and_then(|v| v.as_str())
125 .map(ToString::to_string);
126
127 let tenant_id = raw
129 .get(&self.tenant_claim)
130 .ok_or_else(|| ClaimsError::MissingClaim(self.tenant_claim.clone()))
131 .and_then(|v| parse_uuid_from_value(v, &self.tenant_claim))?;
132
133 let permissions = self.extract_permissions(raw);
135
136 let mut extras = serde_json::Map::new();
138
139 if let Value::Object(obj) = raw {
140 for (key, value) in obj {
141 let is_standard = StandardClaim::is_registered(key);
142 let is_tenant = key == &self.tenant_claim;
143 let is_roles = key == &self.roles_claim;
144
145 if !is_standard && !is_tenant && !is_roles {
146 extras.insert(key.clone(), value.clone());
147 }
148 }
149 }
150
151 for field in [
153 "email",
154 "name",
155 "preferred_username",
156 "given_name",
157 "family_name",
158 "picture",
159 ] {
160 if let Some(value) = raw.get(field) {
161 extras.insert(field.to_owned(), value.clone());
162 }
163 }
164
165 Ok(Claims {
166 issuer,
167 subject,
168 audiences,
169 expires_at,
170 not_before,
171 issued_at,
172 jwt_id,
173 tenant_id,
174 permissions,
175 extras,
176 })
177 }
178}
179
#[cfg(test)]
// `exp` fixtures such as `9999999999` are "far future" sentinels; digit
// separators would not make them any clearer.
#[allow(clippy::unreadable_literal)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use serde_json::json;
    use uuid::Uuid;

    #[test]
    fn test_generic_oidc_normalize() {
        let plugin = GenericOidcPlugin::default();

        let user_id = Uuid::new_v4();
        let tenant_id = Uuid::new_v4();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": user_id.to_string(),
            "aud": ["api", "ui"],
            "exp": 9999999999i64,
            "roles": ["users:read", "admin:write"],
            "tenant_id": tenant_id.to_string(),
            "email": "test@example.com",
            "name": "Test User"
        });

        let normalized = plugin.normalize(&claims).unwrap();

        assert_eq!(normalized.subject, user_id);
        assert_eq!(normalized.issuer, "https://auth.example.com");
        assert_eq!(normalized.audiences, vec!["api", "ui"]);
        assert_eq!(normalized.tenant_id, tenant_id);
        assert_eq!(normalized.permissions.len(), 2);
        assert_eq!(
            normalized.extras.get("email").unwrap().as_str().unwrap(),
            "test@example.com"
        );
        assert_eq!(
            normalized.extras.get("name").unwrap().as_str().unwrap(),
            "Test User"
        );
    }

    #[test]
    fn test_generic_oidc_custom_claims() {
        let plugin = GenericOidcPlugin::new("organization_id", "permissions");

        let user_id = Uuid::new_v4();
        let org_id = Uuid::new_v4();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": user_id.to_string(),
            "aud": "api",
            "permissions": ["read:*", "write:*"],
            "organization_id": org_id.to_string()
        });

        let normalized = plugin.normalize(&claims).unwrap();

        assert_eq!(normalized.tenant_id, org_id);
        assert_eq!(normalized.permissions.len(), 2);
    }

    #[test]
    fn test_generic_oidc_missing_subject_fails() {
        let plugin = GenericOidcPlugin::default();

        let claims = json!({
            "iss": "https://auth.example.com",
            "aud": "api"
        });

        let result = plugin.normalize(&claims);
        assert!(matches!(result, Err(ClaimsError::MissingClaim(_))));
    }

    #[test]
    fn test_generic_oidc_missing_tenant_fails() {
        let plugin = GenericOidcPlugin::default();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": Uuid::new_v4().to_string(),
            "aud": "api"
        });

        // The error must name the configured tenant claim.
        let result = plugin.normalize(&claims);
        assert!(matches!(
            &result,
            Err(ClaimsError::MissingClaim(claim)) if claim == "tenant_id"
        ));
    }

    #[test]
    fn test_generic_oidc_handles_string_audience() {
        let plugin = GenericOidcPlugin::default();

        let user_id = Uuid::new_v4();
        let tenant_id = Uuid::new_v4();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": user_id.to_string(),
            "aud": "api",
            "exp": 9999999999i64,
            "tenant_id": tenant_id.to_string()
        });

        let normalized = plugin.normalize(&claims).unwrap();
        assert_eq!(normalized.audiences, vec!["api"]);
    }

    #[test]
    fn test_generic_oidc_role_without_action_and_non_string_roles() {
        let plugin = GenericOidcPlugin::default();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": Uuid::new_v4().to_string(),
            "tenant_id": Uuid::new_v4().to_string(),
            // "admin" has no `:` (wildcard action); 42 is not a string and is skipped.
            "roles": ["admin", "users:read", 42]
        });

        let normalized = plugin.normalize(&claims).unwrap();
        assert_eq!(normalized.permissions.len(), 2);
    }

    #[test]
    fn test_generic_oidc_extras_exclude_mapped_claims() {
        let plugin = GenericOidcPlugin::default();

        let claims = json!({
            "iss": "https://auth.example.com",
            "sub": Uuid::new_v4().to_string(),
            "tenant_id": Uuid::new_v4().to_string(),
            "roles": ["users:read"],
            "department": "engineering"
        });

        let normalized = plugin.normalize(&claims).unwrap();

        // Claims already mapped onto dedicated fields are not duplicated...
        assert!(!normalized.extras.contains_key("roles"));
        assert!(!normalized.extras.contains_key("tenant_id"));
        // ...while unknown provider-specific claims are preserved.
        assert_eq!(
            normalized.extras.get("department"),
            Some(&json!("engineering"))
        );
    }
}