Skip to main content

authx_plugins/oidc_federation/
service.rs

//! OIDC Federation — sign in via external IdPs (Okta, Azure AD, Google Workspace).
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
8use chrono::Utc;
9use rand::Rng;
10use serde::Deserialize;
11use sha2::{Digest, Sha256};
12use tracing::instrument;
13use uuid::Uuid;
14
15use authx_core::{
16    crypto::{decrypt, encrypt, sha256_hex},
17    error::{AuthError, Result},
18    models::{ClaimMappingRule, CreateSession, CreateUser, Session, UpsertOAuthAccount, User},
19};
20use authx_storage::ports::{
21    OAuthAccountRepository, OidcFederationProviderRepository, OrgRepository, SessionRepository,
22    UserRepository,
23};
24
/// Response from `begin()` — redirect the user to `authorization_url`.
///
/// `code_verifier` is the PKCE secret for this flow and, unlike `state`, is not
/// already visible in `authorization_url`; the `Debug` impl therefore redacts it
/// so the struct can be logged without leaking the verifier.
pub struct OidcFederationBeginResponse {
    /// IdP authorization URL, carrying `state` and the PKCE S256 challenge as query parameters.
    pub authorization_url: String,
    /// Random per-flow value that must come back on the callback.
    pub state: String,
    /// PKCE verifier whose challenge is embedded in `authorization_url`.
    pub code_verifier: String,
}

impl std::fmt::Debug for OidcFederationBeginResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OidcFederationBeginResponse")
            .field("authorization_url", &self.authorization_url)
            .field("state", &self.state)
            // Never print the PKCE secret.
            .field("code_verifier", &"<redacted>")
            .finish()
    }
}
32
33/// Discovered OIDC provider configuration.
34#[derive(Debug, Deserialize)]
35struct OidcDiscovery {
36    authorization_endpoint: String,
37    token_endpoint: String,
38    #[serde(default)]
39    userinfo_endpoint: Option<String>,
40}
41
42/// UserInfo from OIDC IdP.
43#[derive(Debug, Deserialize)]
44pub struct OidcUserInfo {
45    pub sub: String,
46    #[serde(default)]
47    pub email: Option<String>,
48    #[serde(default)]
49    pub email_verified: Option<bool>,
50    #[serde(default)]
51    pub name: Option<String>,
52    #[serde(default)]
53    pub preferred_username: Option<String>,
54    /// All extra claims for claim mapping evaluation.
55    #[serde(flatten)]
56    pub extra: serde_json::Value,
57}
58
59/// Stored federation flow state (code_verifier + redirect_uri).
60struct FederationState {
61    code_verifier: String,
62    redirect_uri: String,
63    expires_at: Instant,
64}
65
66/// OIDC Federation service — sign in via Okta, Azure AD, Google Workspace, etc.
67pub struct OidcFederationService<S> {
68    storage: S,
69    session_ttl_secs: i64,
70    encryption_key: [u8; 32],
71    client: reqwest::Client,
72    /// state -> (code_verifier, redirect_uri) for callback lookup.
73    pending: Arc<std::sync::RwLock<HashMap<String, FederationState>>>,
74}
75
76impl<S> OidcFederationService<S>
77where
78    S: OidcFederationProviderRepository
79        + UserRepository
80        + SessionRepository
81        + OAuthAccountRepository
82        + OrgRepository
83        + Clone
84        + Send
85        + Sync
86        + 'static,
87{
88    pub fn new(storage: S, session_ttl_secs: i64, encryption_key: [u8; 32]) -> Self {
89        Self {
90            storage,
91            session_ttl_secs,
92            encryption_key,
93            client: reqwest::Client::new(),
94            pending: Arc::new(std::sync::RwLock::new(HashMap::new())),
95        }
96    }
97
98    /// Begin OIDC federation flow. Returns URL to redirect user to the IdP.
99    #[instrument(skip(self))]
100    pub async fn begin(
101        &self,
102        provider_name: &str,
103        redirect_uri: &str,
104    ) -> Result<OidcFederationBeginResponse> {
105        let provider = OidcFederationProviderRepository::find_by_name(&self.storage, provider_name)
106            .await?
107            .ok_or_else(|| {
108                AuthError::Internal(format!("unknown federation provider: {provider_name}"))
109            })?;
110
111        if !provider.enabled {
112            return Err(AuthError::Internal("provider is disabled".into()));
113        }
114
115        let discovery_url = format!(
116            "{}/.well-known/openid-configuration",
117            provider.issuer.trim_end_matches('/')
118        );
119        let discovery: OidcDiscovery = self
120            .client
121            .get(&discovery_url)
122            .send()
123            .await
124            .map_err(|e| AuthError::Internal(format!("oidc discovery failed: {e}")))?
125            .json()
126            .await
127            .map_err(|e| AuthError::Internal(format!("oidc discovery parse failed: {e}")))?;
128
129        let verifier_bytes: [u8; 32] = rand::thread_rng().gen();
130        let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
131        let mut hasher = Sha256::new();
132        hasher.update(code_verifier.as_bytes());
133        let code_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
134
135        let state_bytes: [u8; 16] = rand::thread_rng().gen();
136        let state = hex::encode(state_bytes);
137
138        {
139            let mut pending = self
140                .pending
141                .write()
142                .map_err(|e| AuthError::Internal(format!("lock poisoned: {e}")))?;
143            pending.insert(
144                state.clone(),
145                FederationState {
146                    code_verifier: code_verifier.clone(),
147                    redirect_uri: redirect_uri.to_string(),
148                    expires_at: Instant::now() + Duration::from_secs(600), // 10 min
149                },
150            );
151        }
152
153        let mut auth_url = reqwest::Url::parse(&discovery.authorization_endpoint)
154            .map_err(|e| AuthError::Internal(format!("invalid auth endpoint: {e}")))?;
155        auth_url
156            .query_pairs_mut()
157            .append_pair("response_type", "code")
158            .append_pair("client_id", &provider.client_id)
159            .append_pair("redirect_uri", redirect_uri)
160            .append_pair("scope", &provider.scopes)
161            .append_pair("state", &state)
162            .append_pair("code_challenge", &code_challenge)
163            .append_pair("code_challenge_method", "S256");
164
165        Ok(OidcFederationBeginResponse {
166            authorization_url: auth_url.to_string(),
167            state,
168            code_verifier,
169        })
170    }
171
172    /// Handle callback from IdP. Exchange code, get userinfo, find-or-create user, create session.
173    /// Looks up code_verifier and redirect_uri from the pending state stored during begin().
174    #[instrument(skip(self))]
175    pub async fn callback(
176        &self,
177        provider_name: &str,
178        code: &str,
179        state: &str,
180        ip: &str,
181    ) -> Result<(User, Session, String)> {
182        let (code_verifier, redirect_uri) = {
183            let mut pending = self
184                .pending
185                .write()
186                .map_err(|e| AuthError::Internal(format!("lock poisoned: {e}")))?;
187            let entry = pending.remove(state).ok_or(AuthError::InvalidToken)?;
188            if entry.expires_at < Instant::now() {
189                return Err(AuthError::InvalidToken);
190            }
191            (entry.code_verifier, entry.redirect_uri)
192        };
193
194        let provider = OidcFederationProviderRepository::find_by_name(&self.storage, provider_name)
195            .await?
196            .ok_or_else(|| {
197                AuthError::Internal(format!("unknown federation provider: {provider_name}"))
198            })?;
199
200        let secret_bytes = decrypt(&self.encryption_key, &provider.secret_enc)
201            .map_err(|e| AuthError::Internal(format!("decrypt client secret: {e}")))?;
202        let secret = String::from_utf8(secret_bytes)
203            .map_err(|_| AuthError::Internal("client secret not valid UTF-8".into()))?;
204
205        let discovery_url = format!(
206            "{}/.well-known/openid-configuration",
207            provider.issuer.trim_end_matches('/')
208        );
209        let discovery: OidcDiscovery = self
210            .client
211            .get(&discovery_url)
212            .send()
213            .await
214            .map_err(|e| AuthError::Internal(format!("oidc discovery: {e}")))?
215            .json()
216            .await
217            .map_err(|e| AuthError::Internal(format!("oidc discovery parse: {e}")))?;
218
219        let token_resp = self
220            .client
221            .post(&discovery.token_endpoint)
222            .form(&[
223                ("grant_type", "authorization_code"),
224                ("code", code),
225                ("redirect_uri", &redirect_uri),
226                ("client_id", &provider.client_id),
227                ("client_secret", &secret),
228                ("code_verifier", &code_verifier),
229            ])
230            .send()
231            .await
232            .map_err(|e| AuthError::Internal(format!("token exchange: {e}")))?;
233
234        if !token_resp.status().is_success() {
235            let status = token_resp.status();
236            let body = token_resp.text().await.unwrap_or_default();
237            return Err(AuthError::Internal(format!(
238                "token exchange failed {}: {}",
239                status, body
240            )));
241        }
242
243        #[derive(Deserialize)]
244        struct TokenResponse {
245            access_token: String,
246        }
247        let tokens: TokenResponse = token_resp
248            .json()
249            .await
250            .map_err(|e| AuthError::Internal(format!("token parse: {e}")))?;
251
252        let userinfo_endpoint = discovery
253            .userinfo_endpoint
254            .ok_or_else(|| AuthError::Internal("IdP has no userinfo endpoint".into()))?;
255
256        let userinfo: OidcUserInfo = self
257            .client
258            .get(&userinfo_endpoint)
259            .bearer_auth(&tokens.access_token)
260            .send()
261            .await
262            .map_err(|e| AuthError::Internal(format!("userinfo: {e}")))?
263            .json()
264            .await
265            .map_err(|e| AuthError::Internal(format!("userinfo parse: {e}")))?;
266
267        let username = userinfo
268            .preferred_username
269            .clone()
270            .or_else(|| userinfo.name.clone());
271        let email = userinfo
272            .email
273            .filter(|_| userinfo.email_verified.unwrap_or(true))
274            .or_else(|| userinfo.preferred_username.clone())
275            .unwrap_or_else(|| format!("{}@{}", userinfo.sub, provider_name));
276
277        let user = match UserRepository::find_by_email(&self.storage, &email).await? {
278            Some(u) => u,
279            None => {
280                let u = UserRepository::create(
281                    &self.storage,
282                    CreateUser {
283                        email: email.clone(),
284                        username,
285                        metadata: None,
286                    },
287                )
288                .await?;
289                u
290            }
291        };
292
293        let access_enc = encrypt(&self.encryption_key, tokens.access_token.as_bytes())
294            .map_err(|e| AuthError::Internal(format!("encrypt: {e}")))?;
295
296        OAuthAccountRepository::upsert(
297            &self.storage,
298            UpsertOAuthAccount {
299                user_id: user.id,
300                provider: provider_name.to_string(),
301                provider_user_id: userinfo.sub,
302                access_token_enc: access_enc,
303                refresh_token_enc: None,
304                expires_at: None,
305            },
306        )
307        .await?;
308
309        // Apply claim mapping rules
310        let session_org_id: Option<Uuid> = self
311            .apply_claim_mapping(user.id, &provider, &userinfo.extra)
312            .await;
313
314        let raw: [u8; 32] = rand::thread_rng().gen();
315        let raw_str = hex::encode(raw);
316        let token_hash = sha256_hex(raw_str.as_bytes());
317
318        let session = SessionRepository::create(
319            &self.storage,
320            CreateSession {
321                user_id: user.id,
322                token_hash,
323                device_info: serde_json::json!({ "oidc_federation": provider_name }),
324                ip_address: ip.to_string(),
325                org_id: session_org_id.or(provider.org_id),
326                expires_at: Utc::now() + chrono::Duration::seconds(self.session_ttl_secs),
327            },
328        )
329        .await?;
330
331        tracing::info!(user_id = %user.id, provider = provider_name, "oidc federation sign-in complete");
332        Ok((user.clone(), session, raw_str))
333    }
334
335    /// Evaluate claim mapping rules against external IdP claims.
336    /// Returns an org_id if a rule resolved to "add_to_org".
337    async fn apply_claim_mapping(
338        &self,
339        user_id: Uuid,
340        provider: &authx_core::models::OidcFederationProvider,
341        claims: &serde_json::Value,
342    ) -> Option<Uuid> {
343        let mut resolved_org_id = None;
344
345        for rule in &provider.claim_mapping {
346            if !rule_matches(rule, claims) {
347                continue;
348            }
349
350            match rule.action.as_str() {
351                "add_to_org" => {
352                    if let Ok(Some(org)) =
353                        OrgRepository::find_by_slug(&self.storage, &rule.target).await
354                    {
355                        // Find default role for the org
356                        if let Ok(roles) = OrgRepository::find_roles(&self.storage, org.id).await {
357                            let role_id = roles
358                                .iter()
359                                .find(|r| r.name == "member")
360                                .or(roles.first())
361                                .map(|r| r.id);
362                            if let Some(rid) = role_id {
363                                let _ =
364                                    OrgRepository::add_member(&self.storage, org.id, user_id, rid)
365                                        .await;
366                            }
367                        }
368                        resolved_org_id = Some(org.id);
369                    }
370                }
371                "assign_role" => {
372                    // Assign a specific role within the provider's org
373                    if let Some(org_id) = provider.org_id {
374                        if let Ok(roles) = OrgRepository::find_roles(&self.storage, org_id).await {
375                            if let Some(role) = roles.iter().find(|r| r.name == rule.target) {
376                                let _ = OrgRepository::update_member_role(
377                                    &self.storage,
378                                    org_id,
379                                    user_id,
380                                    role.id,
381                                )
382                                .await;
383                            }
384                        }
385                    }
386                }
387                other => {
388                    tracing::debug!(action = other, "unknown claim mapping action, skipping");
389                }
390            }
391        }
392
393        resolved_org_id
394    }
395}
396
397/// Check if a claim mapping rule matches against the given claims JSON.
398fn rule_matches(rule: &ClaimMappingRule, claims: &serde_json::Value) -> bool {
399    let claim_value = match claims.get(&rule.source_claim) {
400        Some(v) => v,
401        None => return false,
402    };
403
404    match rule.match_type.as_str() {
405        "equals" => match claim_value {
406            serde_json::Value::String(s) => s == &rule.match_value,
407            serde_json::Value::Bool(b) => b.to_string() == rule.match_value,
408            serde_json::Value::Number(n) => n.to_string() == rule.match_value,
409            _ => false,
410        },
411        "contains" => match claim_value {
412            serde_json::Value::String(s) => s.contains(&rule.match_value),
413            serde_json::Value::Array(arr) => arr
414                .iter()
415                .any(|v| v.as_str().map(|s| s == rule.match_value).unwrap_or(false)),
416            _ => false,
417        },
418        "exists" => true, // claim exists, that's enough
419        _ => false,
420    }
421}