authx_plugins/oidc_federation/
service.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
8use chrono::Utc;
9use rand::Rng;
10use serde::Deserialize;
11use sha2::{Digest, Sha256};
12use tracing::instrument;
13use uuid::Uuid;
14
15use authx_core::{
16 crypto::{decrypt, encrypt, sha256_hex},
17 error::{AuthError, Result},
18 models::{ClaimMappingRule, CreateSession, CreateUser, Session, UpsertOAuthAccount, User},
19};
20use authx_storage::ports::{
21 OAuthAccountRepository, OidcFederationProviderRepository, OrgRepository, SessionRepository,
22 UserRepository,
23};
24
/// Data handed back to the caller after a federation sign-in has been started.
pub struct OidcFederationBeginResponse {
    /// Authorization-endpoint URL (query string included) to redirect the user agent to.
    pub authorization_url: String,
    /// Opaque `state` value that is also embedded in `authorization_url`.
    pub state: String,
    /// PKCE code verifier that matches the `code_challenge` in `authorization_url`.
    /// This is a secret: whoever holds it can redeem the authorization code.
    pub code_verifier: String,
}

// Implemented by hand rather than derived so that `{:?}` formatting (logs, tracing
// fields, panic messages) never prints the PKCE verifier.
impl std::fmt::Debug for OidcFederationBeginResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OidcFederationBeginResponse")
            .field("authorization_url", &self.authorization_url)
            .field("state", &self.state)
            .field("code_verifier", &"[REDACTED]")
            .finish()
    }
}
32
33#[derive(Debug, Deserialize)]
35struct OidcDiscovery {
36 authorization_endpoint: String,
37 token_endpoint: String,
38 #[serde(default)]
39 userinfo_endpoint: Option<String>,
40}
41
42#[derive(Debug, Deserialize)]
44pub struct OidcUserInfo {
45 pub sub: String,
46 #[serde(default)]
47 pub email: Option<String>,
48 #[serde(default)]
49 pub email_verified: Option<bool>,
50 #[serde(default)]
51 pub name: Option<String>,
52 #[serde(default)]
53 pub preferred_username: Option<String>,
54 #[serde(flatten)]
56 pub extra: serde_json::Value,
57}
58
/// Server-side record of a sign-in that has been started but not yet completed.
/// Entries are stored in the service's `pending` map, keyed by the `state` parameter.
struct FederationState {
    /// PKCE verifier to present at the token endpoint.
    code_verifier: String,
    /// Redirect URI used in the authorization request; repeated in the token request.
    redirect_uri: String,
    /// Point in time after which the entry is no longer accepted.
    expires_at: Instant,
}
65
66pub struct OidcFederationService<S> {
68 storage: S,
69 session_ttl_secs: i64,
70 encryption_key: [u8; 32],
71 client: reqwest::Client,
72 pending: Arc<std::sync::RwLock<HashMap<String, FederationState>>>,
74}
75
76impl<S> OidcFederationService<S>
77where
78 S: OidcFederationProviderRepository
79 + UserRepository
80 + SessionRepository
81 + OAuthAccountRepository
82 + OrgRepository
83 + Clone
84 + Send
85 + Sync
86 + 'static,
87{
88 pub fn new(storage: S, session_ttl_secs: i64, encryption_key: [u8; 32]) -> Self {
89 Self {
90 storage,
91 session_ttl_secs,
92 encryption_key,
93 client: reqwest::Client::new(),
94 pending: Arc::new(std::sync::RwLock::new(HashMap::new())),
95 }
96 }
97
98 #[instrument(skip(self))]
100 pub async fn begin(
101 &self,
102 provider_name: &str,
103 redirect_uri: &str,
104 ) -> Result<OidcFederationBeginResponse> {
105 let provider = OidcFederationProviderRepository::find_by_name(&self.storage, provider_name)
106 .await?
107 .ok_or_else(|| {
108 AuthError::Internal(format!("unknown federation provider: {provider_name}"))
109 })?;
110
111 if !provider.enabled {
112 return Err(AuthError::Internal("provider is disabled".into()));
113 }
114
115 let discovery_url = format!(
116 "{}/.well-known/openid-configuration",
117 provider.issuer.trim_end_matches('/')
118 );
119 let discovery: OidcDiscovery = self
120 .client
121 .get(&discovery_url)
122 .send()
123 .await
124 .map_err(|e| AuthError::Internal(format!("oidc discovery failed: {e}")))?
125 .json()
126 .await
127 .map_err(|e| AuthError::Internal(format!("oidc discovery parse failed: {e}")))?;
128
129 let verifier_bytes: [u8; 32] = rand::thread_rng().r#gen();
130 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
131 let mut hasher = Sha256::new();
132 hasher.update(code_verifier.as_bytes());
133 let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
134
135 let state_bytes: [u8; 16] = rand::thread_rng().r#gen();
136 let state = hex::encode(state_bytes);
137
138 {
139 let mut pending = self
140 .pending
141 .write()
142 .map_err(|e| AuthError::Internal(format!("lock poisoned: {e}")))?;
143 pending.insert(
144 state.clone(),
145 FederationState {
146 code_verifier: code_verifier.clone(),
147 redirect_uri: redirect_uri.to_string(),
148 expires_at: Instant::now() + Duration::from_secs(600), },
150 );
151 }
152
153 let mut auth_url = reqwest::Url::parse(&discovery.authorization_endpoint)
154 .map_err(|e| AuthError::Internal(format!("invalid auth endpoint: {e}")))?;
155 auth_url
156 .query_pairs_mut()
157 .append_pair("response_type", "code")
158 .append_pair("client_id", &provider.client_id)
159 .append_pair("redirect_uri", redirect_uri)
160 .append_pair("scope", &provider.scopes)
161 .append_pair("state", &state)
162 .append_pair("code_challenge", &code_challenge)
163 .append_pair("code_challenge_method", "S256");
164
165 Ok(OidcFederationBeginResponse {
166 authorization_url: auth_url.to_string(),
167 state,
168 code_verifier,
169 })
170 }
171
172 #[instrument(skip(self))]
175 pub async fn callback(
176 &self,
177 provider_name: &str,
178 code: &str,
179 state: &str,
180 ip: &str,
181 ) -> Result<(User, Session, String)> {
182 let (code_verifier, redirect_uri) = {
183 let mut pending = self
184 .pending
185 .write()
186 .map_err(|e| AuthError::Internal(format!("lock poisoned: {e}")))?;
187 let entry = pending.remove(state).ok_or(AuthError::InvalidToken)?;
188 if entry.expires_at < Instant::now() {
189 return Err(AuthError::InvalidToken);
190 }
191 (entry.code_verifier, entry.redirect_uri)
192 };
193
194 let provider = OidcFederationProviderRepository::find_by_name(&self.storage, provider_name)
195 .await?
196 .ok_or_else(|| {
197 AuthError::Internal(format!("unknown federation provider: {provider_name}"))
198 })?;
199
200 let secret_bytes = decrypt(&self.encryption_key, &provider.secret_enc)
201 .map_err(|e| AuthError::Internal(format!("decrypt client secret: {e}")))?;
202 let secret = String::from_utf8(secret_bytes)
203 .map_err(|_| AuthError::Internal("client secret not valid UTF-8".into()))?;
204
205 let discovery_url = format!(
206 "{}/.well-known/openid-configuration",
207 provider.issuer.trim_end_matches('/')
208 );
209 let discovery: OidcDiscovery = self
210 .client
211 .get(&discovery_url)
212 .send()
213 .await
214 .map_err(|e| AuthError::Internal(format!("oidc discovery: {e}")))?
215 .json()
216 .await
217 .map_err(|e| AuthError::Internal(format!("oidc discovery parse: {e}")))?;
218
219 let token_resp = self
220 .client
221 .post(&discovery.token_endpoint)
222 .form(&[
223 ("grant_type", "authorization_code"),
224 ("code", code),
225 ("redirect_uri", &redirect_uri),
226 ("client_id", &provider.client_id),
227 ("client_secret", &secret),
228 ("code_verifier", &code_verifier),
229 ])
230 .send()
231 .await
232 .map_err(|e| AuthError::Internal(format!("token exchange: {e}")))?;
233
234 if !token_resp.status().is_success() {
235 let status = token_resp.status();
236 let body = token_resp.text().await.unwrap_or_default();
237 return Err(AuthError::Internal(format!(
238 "token exchange failed {}: {}",
239 status, body
240 )));
241 }
242
243 #[derive(Deserialize)]
244 struct TokenResponse {
245 access_token: String,
246 }
247 let tokens: TokenResponse = token_resp
248 .json()
249 .await
250 .map_err(|e| AuthError::Internal(format!("token parse: {e}")))?;
251
252 let userinfo_endpoint = discovery
253 .userinfo_endpoint
254 .ok_or_else(|| AuthError::Internal("IdP has no userinfo endpoint".into()))?;
255
256 let userinfo: OidcUserInfo = self
257 .client
258 .get(&userinfo_endpoint)
259 .bearer_auth(&tokens.access_token)
260 .send()
261 .await
262 .map_err(|e| AuthError::Internal(format!("userinfo: {e}")))?
263 .json()
264 .await
265 .map_err(|e| AuthError::Internal(format!("userinfo parse: {e}")))?;
266
267 let username = userinfo
268 .preferred_username
269 .clone()
270 .or_else(|| userinfo.name.clone());
271 let email = userinfo
272 .email
273 .filter(|_| userinfo.email_verified.unwrap_or(true))
274 .or_else(|| userinfo.preferred_username.clone())
275 .unwrap_or_else(|| format!("{}@{}", userinfo.sub, provider_name));
276
277 let user = match UserRepository::find_by_email(&self.storage, &email).await? {
278 Some(u) => u,
279 None => {
280 UserRepository::create(
281 &self.storage,
282 CreateUser {
283 email: email.clone(),
284 username,
285 metadata: None,
286 },
287 )
288 .await?
289 }
290 };
291
292 let access_enc = encrypt(&self.encryption_key, tokens.access_token.as_bytes())
293 .map_err(|e| AuthError::Internal(format!("encrypt: {e}")))?;
294
295 OAuthAccountRepository::upsert(
296 &self.storage,
297 UpsertOAuthAccount {
298 user_id: user.id,
299 provider: provider_name.to_string(),
300 provider_user_id: userinfo.sub,
301 access_token_enc: access_enc,
302 refresh_token_enc: None,
303 expires_at: None,
304 },
305 )
306 .await?;
307
308 let session_org_id: Option<Uuid> = self
310 .apply_claim_mapping(user.id, &provider, &userinfo.extra)
311 .await;
312
313 let raw: [u8; 32] = rand::thread_rng().r#gen();
314 let raw_str = hex::encode(raw);
315 let token_hash = sha256_hex(raw_str.as_bytes());
316
317 let session = SessionRepository::create(
318 &self.storage,
319 CreateSession {
320 user_id: user.id,
321 token_hash,
322 device_info: serde_json::json!({ "oidc_federation": provider_name }),
323 ip_address: ip.to_string(),
324 org_id: session_org_id.or(provider.org_id),
325 expires_at: Utc::now() + chrono::Duration::seconds(self.session_ttl_secs),
326 },
327 )
328 .await?;
329
330 tracing::info!(user_id = %user.id, provider = provider_name, "oidc federation sign-in complete");
331 Ok((user.clone(), session, raw_str))
332 }
333
334 async fn apply_claim_mapping(
337 &self,
338 user_id: Uuid,
339 provider: &authx_core::models::OidcFederationProvider,
340 claims: &serde_json::Value,
341 ) -> Option<Uuid> {
342 let mut resolved_org_id = None;
343
344 for rule in &provider.claim_mapping {
345 if !rule_matches(rule, claims) {
346 continue;
347 }
348
349 match rule.action.as_str() {
350 "add_to_org" => {
351 if let Ok(Some(org)) =
352 OrgRepository::find_by_slug(&self.storage, &rule.target).await
353 {
354 if let Ok(roles) = OrgRepository::find_roles(&self.storage, org.id).await {
356 let role_id = roles
357 .iter()
358 .find(|r| r.name == "member")
359 .or(roles.first())
360 .map(|r| r.id);
361 if let Some(rid) = role_id {
362 let _ =
363 OrgRepository::add_member(&self.storage, org.id, user_id, rid)
364 .await;
365 }
366 }
367 resolved_org_id = Some(org.id);
368 }
369 }
370 "assign_role" => {
371 if let Some(org_id) = provider.org_id
373 && let Ok(roles) = OrgRepository::find_roles(&self.storage, org_id).await
374 && let Some(role) = roles.iter().find(|r| r.name == rule.target)
375 {
376 let _ = OrgRepository::update_member_role(
377 &self.storage,
378 org_id,
379 user_id,
380 role.id,
381 )
382 .await;
383 }
384 }
385 other => {
386 tracing::debug!(action = other, "unknown claim mapping action, skipping");
387 }
388 }
389 }
390
391 resolved_org_id
392 }
393}
394
395fn rule_matches(rule: &ClaimMappingRule, claims: &serde_json::Value) -> bool {
397 let claim_value = match claims.get(&rule.source_claim) {
398 Some(v) => v,
399 None => return false,
400 };
401
402 match rule.match_type.as_str() {
403 "equals" => match claim_value {
404 serde_json::Value::String(s) => s == &rule.match_value,
405 serde_json::Value::Bool(b) => b.to_string() == rule.match_value,
406 serde_json::Value::Number(n) => n.to_string() == rule.match_value,
407 _ => false,
408 },
409 "contains" => match claim_value {
410 serde_json::Value::String(s) => s.contains(&rule.match_value),
411 serde_json::Value::Array(arr) => arr
412 .iter()
413 .any(|v| v.as_str().map(|s| s == rule.match_value).unwrap_or(false)),
414 _ => false,
415 },
416 "exists" => true, _ => false,
418 }
419}