Skip to main content

authx_plugins/oidc_federation/
service.rs

1//! OIDC Federation — sign in via external IdPs (Okta, Azure AD, Google Workspace).
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
8use chrono::Utc;
9use rand::Rng;
10use serde::Deserialize;
11use sha2::{Digest, Sha256};
12use tracing::instrument;
13use uuid::Uuid;
14
15use authx_core::{
16    crypto::{decrypt, encrypt, sha256_hex},
17    error::{AuthError, Result},
18    models::{ClaimMappingRule, CreateSession, CreateUser, Session, UpsertOAuthAccount, User},
19};
20use authx_storage::ports::{
21    OAuthAccountRepository, OidcFederationProviderRepository, OrgRepository, SessionRepository,
22    UserRepository,
23};
24
/// Response from begin() — redirect the user to authorization_url.
///
/// The service keeps its own copy of `state`, the PKCE verifier and the redirect URI in its
/// pending-flow table, so `callback()` only needs the `state` (and `code`) the IdP sends back.
#[derive(Debug)]
pub struct OidcFederationBeginResponse {
    /// IdP authorization endpoint with `response_type`, `client_id`, `redirect_uri`, `scope`,
    /// `state` and the S256 PKCE challenge already appended as query parameters.
    pub authorization_url: String,
    /// Random anti-CSRF value (16 random bytes, hex-encoded) also embedded in `authorization_url`.
    pub state: String,
    /// PKCE code verifier whose SHA-256 challenge is embedded in `authorization_url`.
    /// NOTE(review): this is a secret and the derived `Debug` prints it — avoid logging this
    /// struct, and confirm callers actually need the verifier at all.
    pub code_verifier: String,
}
32
/// Discovered OIDC provider configuration.
///
/// The subset of the IdP's `{issuer}/.well-known/openid-configuration` document this service
/// reads; other fields in the document are ignored.
#[derive(Debug, Deserialize)]
struct OidcDiscovery {
    /// Where `begin()` sends the user agent to authenticate.
    authorization_endpoint: String,
    /// Where `callback()` exchanges the authorization code for tokens.
    token_endpoint: String,
    /// Where `callback()` fetches the user's claims; `None` when the document omits it, in which
    /// case `callback()` fails.
    #[serde(default)]
    userinfo_endpoint: Option<String>,
}
41
/// UserInfo from OIDC IdP.
///
/// Claims returned by the IdP's userinfo endpoint. The optional standard claims are `None` when
/// the IdP omits them (or sends `null`).
#[derive(Debug, Deserialize)]
pub struct OidcUserInfo {
    /// The IdP's subject identifier for the user.
    pub sub: String,
    #[serde(default)]
    pub email: Option<String>,
    /// Whether the IdP vouches for `email`; `None` when the claim is absent.
    #[serde(default)]
    pub email_verified: Option<bool>,
    #[serde(default)]
    pub name: Option<String>,
    #[serde(default)]
    pub preferred_username: Option<String>,
    /// Every claim NOT captured by the named fields above (e.g. `groups`).
    /// `#[serde(flatten)]` only receives leftover keys, so `sub`, `email`, `email_verified`,
    /// `name` and `preferred_username` never appear in this value.
    #[serde(flatten)]
    pub extra: serde_json::Value,
}
58
/// Stored federation flow state (code_verifier + redirect_uri).
///
/// One in-flight flow, recorded by `begin()` and removed again by `callback()`.
struct FederationState {
    /// PKCE verifier sent to the token endpoint by `callback()`.
    code_verifier: String,
    /// Redirect URI used in the authorization request; sent again with the token request.
    redirect_uri: String,
    /// After this instant `callback()` rejects the flow as `InvalidToken`.
    expires_at: Instant,
}
65
/// OIDC Federation service — sign in via Okta, Azure AD, Google Workspace, etc.
///
/// Runs an authorization-code + PKCE flow against a configured external IdP (`begin()` then
/// `callback()`), then finds or creates the local user, records the linked IdP account,
/// applies the provider's claim-mapping rules and opens a local session.
pub struct OidcFederationService<S> {
    /// Repository backend for providers, users, sessions, linked accounts and orgs.
    storage: S,
    /// Lifetime, in seconds, of sessions created by `callback()`.
    session_ttl_secs: i64,
    /// Key for `authx_core::crypto::{encrypt, decrypt}`: decrypts stored provider client
    /// secrets and encrypts IdP access tokens before they are persisted.
    encryption_key: [u8; 32],
    /// Shared HTTP client for discovery, token and userinfo requests.
    client: reqwest::Client,
    /// In-flight flows started by `begin()` and removed by `callback()`.
    /// Held in process memory only: a flow begun on one instance cannot be completed on another.
    pending: Arc<std::sync::RwLock<HashMap<String, FederationState>>>,
}
75
76impl<S> OidcFederationService<S>
77where
78    S: OidcFederationProviderRepository
79        + UserRepository
80        + SessionRepository
81        + OAuthAccountRepository
82        + OrgRepository
83        + Clone
84        + Send
85        + Sync
86        + 'static,
87{
88    pub fn new(storage: S, session_ttl_secs: i64, encryption_key: [u8; 32]) -> Self {
89        Self {
90            storage,
91            session_ttl_secs,
92            encryption_key,
93            client: reqwest::Client::new(),
94            pending: Arc::new(std::sync::RwLock::new(HashMap::new())),
95        }
96    }
97
98    /// Begin OIDC federation flow. Returns URL to redirect user to the IdP.
99    #[instrument(skip(self))]
100    pub async fn begin(
101        &self,
102        provider_name: &str,
103        redirect_uri: &str,
104    ) -> Result<OidcFederationBeginResponse> {
105        let provider = OidcFederationProviderRepository::find_by_name(&self.storage, provider_name)
106            .await?
107            .ok_or_else(|| {
108                AuthError::Internal(format!("unknown federation provider: {provider_name}"))
109            })?;
110
111        if !provider.enabled {
112            return Err(AuthError::Internal("provider is disabled".into()));
113        }
114
115        let discovery_url = format!(
116            "{}/.well-known/openid-configuration",
117            provider.issuer.trim_end_matches('/')
118        );
119        let discovery: OidcDiscovery = self
120            .client
121            .get(&discovery_url)
122            .send()
123            .await
124            .map_err(|e| AuthError::Internal(format!("oidc discovery failed: {e}")))?
125            .json()
126            .await
127            .map_err(|e| AuthError::Internal(format!("oidc discovery parse failed: {e}")))?;
128
129        let verifier_bytes: [u8; 32] = rand::thread_rng().r#gen();
130        let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
131        let mut hasher = Sha256::new();
132        hasher.update(code_verifier.as_bytes());
133        let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
134
135        let state_bytes: [u8; 16] = rand::thread_rng().r#gen();
136        let state = hex::encode(state_bytes);
137
138        {
139            let mut pending = self
140                .pending
141                .write()
142                .map_err(|e| AuthError::Internal(format!("lock poisoned: {e}")))?;
143            pending.insert(
144                state.clone(),
145                FederationState {
146                    code_verifier: code_verifier.clone(),
147                    redirect_uri: redirect_uri.to_string(),
148                    expires_at: Instant::now() + Duration::from_secs(600), // 10 min
149                },
150            );
151        }
152
153        let mut auth_url = reqwest::Url::parse(&discovery.authorization_endpoint)
154            .map_err(|e| AuthError::Internal(format!("invalid auth endpoint: {e}")))?;
155        auth_url
156            .query_pairs_mut()
157            .append_pair("response_type", "code")
158            .append_pair("client_id", &provider.client_id)
159            .append_pair("redirect_uri", redirect_uri)
160            .append_pair("scope", &provider.scopes)
161            .append_pair("state", &state)
162            .append_pair("code_challenge", &code_challenge)
163            .append_pair("code_challenge_method", "S256");
164
165        Ok(OidcFederationBeginResponse {
166            authorization_url: auth_url.to_string(),
167            state,
168            code_verifier,
169        })
170    }
171
172    /// Handle callback from IdP. Exchange code, get userinfo, find-or-create user, create session.
173    /// Looks up code_verifier and redirect_uri from the pending state stored during begin().
174    #[instrument(skip(self))]
175    pub async fn callback(
176        &self,
177        provider_name: &str,
178        code: &str,
179        state: &str,
180        ip: &str,
181    ) -> Result<(User, Session, String)> {
182        let (code_verifier, redirect_uri) = {
183            let mut pending = self
184                .pending
185                .write()
186                .map_err(|e| AuthError::Internal(format!("lock poisoned: {e}")))?;
187            let entry = pending.remove(state).ok_or(AuthError::InvalidToken)?;
188            if entry.expires_at < Instant::now() {
189                return Err(AuthError::InvalidToken);
190            }
191            (entry.code_verifier, entry.redirect_uri)
192        };
193
194        let provider = OidcFederationProviderRepository::find_by_name(&self.storage, provider_name)
195            .await?
196            .ok_or_else(|| {
197                AuthError::Internal(format!("unknown federation provider: {provider_name}"))
198            })?;
199
200        let secret_bytes = decrypt(&self.encryption_key, &provider.secret_enc)
201            .map_err(|e| AuthError::Internal(format!("decrypt client secret: {e}")))?;
202        let secret = String::from_utf8(secret_bytes)
203            .map_err(|_| AuthError::Internal("client secret not valid UTF-8".into()))?;
204
205        let discovery_url = format!(
206            "{}/.well-known/openid-configuration",
207            provider.issuer.trim_end_matches('/')
208        );
209        let discovery: OidcDiscovery = self
210            .client
211            .get(&discovery_url)
212            .send()
213            .await
214            .map_err(|e| AuthError::Internal(format!("oidc discovery: {e}")))?
215            .json()
216            .await
217            .map_err(|e| AuthError::Internal(format!("oidc discovery parse: {e}")))?;
218
219        let token_resp = self
220            .client
221            .post(&discovery.token_endpoint)
222            .form(&[
223                ("grant_type", "authorization_code"),
224                ("code", code),
225                ("redirect_uri", &redirect_uri),
226                ("client_id", &provider.client_id),
227                ("client_secret", &secret),
228                ("code_verifier", &code_verifier),
229            ])
230            .send()
231            .await
232            .map_err(|e| AuthError::Internal(format!("token exchange: {e}")))?;
233
234        if !token_resp.status().is_success() {
235            let status = token_resp.status();
236            let body = token_resp.text().await.unwrap_or_default();
237            return Err(AuthError::Internal(format!(
238                "token exchange failed {}: {}",
239                status, body
240            )));
241        }
242
243        #[derive(Deserialize)]
244        struct TokenResponse {
245            access_token: String,
246        }
247        let tokens: TokenResponse = token_resp
248            .json()
249            .await
250            .map_err(|e| AuthError::Internal(format!("token parse: {e}")))?;
251
252        let userinfo_endpoint = discovery
253            .userinfo_endpoint
254            .ok_or_else(|| AuthError::Internal("IdP has no userinfo endpoint".into()))?;
255
256        let userinfo: OidcUserInfo = self
257            .client
258            .get(&userinfo_endpoint)
259            .bearer_auth(&tokens.access_token)
260            .send()
261            .await
262            .map_err(|e| AuthError::Internal(format!("userinfo: {e}")))?
263            .json()
264            .await
265            .map_err(|e| AuthError::Internal(format!("userinfo parse: {e}")))?;
266
267        let username = userinfo
268            .preferred_username
269            .clone()
270            .or_else(|| userinfo.name.clone());
271        let email = userinfo
272            .email
273            .filter(|_| userinfo.email_verified.unwrap_or(true))
274            .or_else(|| userinfo.preferred_username.clone())
275            .unwrap_or_else(|| format!("{}@{}", userinfo.sub, provider_name));
276
277        let user = match UserRepository::find_by_email(&self.storage, &email).await? {
278            Some(u) => u,
279            None => {
280                UserRepository::create(
281                    &self.storage,
282                    CreateUser {
283                        email: email.clone(),
284                        username,
285                        metadata: None,
286                    },
287                )
288                .await?
289            }
290        };
291
292        let access_enc = encrypt(&self.encryption_key, tokens.access_token.as_bytes())
293            .map_err(|e| AuthError::Internal(format!("encrypt: {e}")))?;
294
295        OAuthAccountRepository::upsert(
296            &self.storage,
297            UpsertOAuthAccount {
298                user_id: user.id,
299                provider: provider_name.to_string(),
300                provider_user_id: userinfo.sub,
301                access_token_enc: access_enc,
302                refresh_token_enc: None,
303                expires_at: None,
304            },
305        )
306        .await?;
307
308        // Apply claim mapping rules
309        let session_org_id: Option<Uuid> = self
310            .apply_claim_mapping(user.id, &provider, &userinfo.extra)
311            .await;
312
313        let raw: [u8; 32] = rand::thread_rng().r#gen();
314        let raw_str = hex::encode(raw);
315        let token_hash = sha256_hex(raw_str.as_bytes());
316
317        let session = SessionRepository::create(
318            &self.storage,
319            CreateSession {
320                user_id: user.id,
321                token_hash,
322                device_info: serde_json::json!({ "oidc_federation": provider_name }),
323                ip_address: ip.to_string(),
324                org_id: session_org_id.or(provider.org_id),
325                expires_at: Utc::now() + chrono::Duration::seconds(self.session_ttl_secs),
326            },
327        )
328        .await?;
329
330        tracing::info!(user_id = %user.id, provider = provider_name, "oidc federation sign-in complete");
331        Ok((user.clone(), session, raw_str))
332    }
333
334    /// Evaluate claim mapping rules against external IdP claims.
335    /// Returns an org_id if a rule resolved to "add_to_org".
336    async fn apply_claim_mapping(
337        &self,
338        user_id: Uuid,
339        provider: &authx_core::models::OidcFederationProvider,
340        claims: &serde_json::Value,
341    ) -> Option<Uuid> {
342        let mut resolved_org_id = None;
343
344        for rule in &provider.claim_mapping {
345            if !rule_matches(rule, claims) {
346                continue;
347            }
348
349            match rule.action.as_str() {
350                "add_to_org" => {
351                    if let Ok(Some(org)) =
352                        OrgRepository::find_by_slug(&self.storage, &rule.target).await
353                    {
354                        // Find default role for the org
355                        if let Ok(roles) = OrgRepository::find_roles(&self.storage, org.id).await {
356                            let role_id = roles
357                                .iter()
358                                .find(|r| r.name == "member")
359                                .or(roles.first())
360                                .map(|r| r.id);
361                            if let Some(rid) = role_id {
362                                let _ =
363                                    OrgRepository::add_member(&self.storage, org.id, user_id, rid)
364                                        .await;
365                            }
366                        }
367                        resolved_org_id = Some(org.id);
368                    }
369                }
370                "assign_role" => {
371                    // Assign a specific role within the provider's org
372                    if let Some(org_id) = provider.org_id
373                        && let Ok(roles) = OrgRepository::find_roles(&self.storage, org_id).await
374                        && let Some(role) = roles.iter().find(|r| r.name == rule.target)
375                    {
376                        let _ = OrgRepository::update_member_role(
377                            &self.storage,
378                            org_id,
379                            user_id,
380                            role.id,
381                        )
382                        .await;
383                    }
384                }
385                other => {
386                    tracing::debug!(action = other, "unknown claim mapping action, skipping");
387                }
388            }
389        }
390
391        resolved_org_id
392    }
393}
394
395/// Check if a claim mapping rule matches against the given claims JSON.
396fn rule_matches(rule: &ClaimMappingRule, claims: &serde_json::Value) -> bool {
397    let claim_value = match claims.get(&rule.source_claim) {
398        Some(v) => v,
399        None => return false,
400    };
401
402    match rule.match_type.as_str() {
403        "equals" => match claim_value {
404            serde_json::Value::String(s) => s == &rule.match_value,
405            serde_json::Value::Bool(b) => b.to_string() == rule.match_value,
406            serde_json::Value::Number(n) => n.to_string() == rule.match_value,
407            _ => false,
408        },
409        "contains" => match claim_value {
410            serde_json::Value::String(s) => s.contains(&rule.match_value),
411            serde_json::Value::Array(arr) => arr
412                .iter()
413                .any(|v| v.as_str().map(|s| s == rule.match_value).unwrap_or(false)),
414            _ => false,
415        },
416        "exists" => true, // claim exists, that's enough
417        _ => false,
418    }
419}