Skip to main content

authx_plugins/oauth/
service.rs

1use std::sync::Arc;
2
3use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
4use chrono::Utc;
5use rand::Rng;
6use sha2::{Digest, Sha256};
7use tracing::instrument;
8
9use authx_core::{
10    crypto::{encrypt, sha256_hex},
11    error::{AuthError, Result},
12    events::{AuthEvent, EventBus},
13    models::{CreateSession, CreateUser, Session, UpsertOAuthAccount, User},
14};
15use authx_storage::ports::{OAuthAccountRepository, SessionRepository, UserRepository};
16
17use super::providers::OAuthProvider;
18
/// Returned from `begin()`. The caller should store `state` and `code_verifier`
/// (e.g., in a server-side session or signed cookie) to verify the callback.
///
/// `Debug` is implemented by hand rather than derived so that the PKCE
/// `code_verifier` is redacted. Together with the authorization code, the verifier
/// is enough to redeem the grant, so it must not end up in logs.
pub struct OAuthBeginResponse {
    /// Provider authorization URL to send the user agent to.
    pub authorization_url: String,
    /// Random anti-CSRF value; the callback compares it with the value echoed back.
    pub state: String,
    /// PKCE verifier; pass it back to the callback as `code_verifier`.
    pub code_verifier: String,
}

impl std::fmt::Debug for OAuthBeginResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthBeginResponse")
            .field("authorization_url", &self.authorization_url)
            .field("state", &self.state)
            // Never print the verifier itself.
            .field("code_verifier", &"<redacted>")
            .finish()
    }
}
27
/// Inputs to `OAuthService::callback`, the second half of the OAuth flow.
///
/// All fields are borrowed, so the request can be assembled directly from the
/// incoming HTTP request and whatever server-side storage held the `begin()` output.
pub struct OAuthCallbackRequest<'req> {
    /// Name the provider was registered under.
    pub provider_name: &'req str,
    /// Authorization code from the redirect query string.
    pub code: &'req str,
    /// The state value originally returned by `begin()` (stored server-side).
    pub expected_state: &'req str,
    /// The state received in the OAuth redirect query parameters.
    pub received_state: &'req str,
    /// PKCE verifier originally returned by `begin()`.
    pub code_verifier: &'req str,
    /// Redirect URI forwarded to the provider's code exchange.
    pub redirect_uri: &'req str,
    /// Client address, recorded on the created session.
    pub ip: &'req str,
}
40
41/// OAuth authentication service supporting multiple providers.
42///
43/// Providers are registered by name via [`OAuthService::register`].
44pub struct OAuthService<S> {
45    storage: S,
46    events: EventBus,
47    providers: std::collections::HashMap<String, Arc<dyn OAuthProvider>>,
48    session_ttl_secs: i64,
49    /// 32-byte key for AES-256-GCM token encryption.
50    encryption_key: [u8; 32],
51}
52
53impl<S> OAuthService<S>
54where
55    S: UserRepository + SessionRepository + OAuthAccountRepository + Clone + Send + Sync + 'static,
56{
57    pub fn new(
58        storage: S,
59        events: EventBus,
60        session_ttl_secs: i64,
61        encryption_key: [u8; 32],
62    ) -> Self {
63        Self {
64            storage,
65            events,
66            providers: Default::default(),
67            session_ttl_secs,
68            encryption_key,
69        }
70    }
71
72    /// Register an OAuth provider.
73    pub fn register(mut self, provider: impl OAuthProvider + 'static) -> Self {
74        self.providers
75            .insert(provider.name().to_owned(), Arc::new(provider));
76        self
77    }
78
79    fn provider(&self, name: &str) -> Result<&dyn OAuthProvider> {
80        self.providers
81            .get(name)
82            .map(|p| p.as_ref())
83            .ok_or_else(|| AuthError::Internal(format!("unknown oauth provider: {name}")))
84    }
85
86    /// Begin an OAuth flow. Generate PKCE verifier+challenge and a random state.
87    #[instrument(skip(self), fields(provider = %provider_name))]
88    pub fn begin(&self, provider_name: &str, _redirect_uri: &str) -> Result<OAuthBeginResponse> {
89        self.provider(provider_name)?;
90
91        // Generate PKCE code_verifier (32 random bytes, base64url-encoded).
92        let verifier_bytes: [u8; 32] = rand::thread_rng().gen();
93        let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
94
95        // code_challenge = BASE64URL(SHA256(verifier))
96        let mut hasher = Sha256::new();
97        hasher.update(code_verifier.as_bytes());
98        let digest = hasher.finalize();
99        let code_challenge = URL_SAFE_NO_PAD.encode(digest);
100
101        // Random state token.
102        let state_bytes: [u8; 16] = rand::thread_rng().gen();
103        let state = hex::encode(state_bytes);
104
105        let authorization_url = self
106            .provider(provider_name)?
107            .authorization_url(&state, &code_challenge);
108
109        tracing::info!(provider = %provider_name, "oauth flow started");
110        Ok(OAuthBeginResponse {
111            authorization_url,
112            state,
113            code_verifier,
114        })
115    }
116
117    /// Handle the OAuth callback. Exchange the code, fetch user info, upsert the
118    /// OAuth account, find-or-create the user by email, and create a session.
119    ///
120    /// `req.expected_state` must be the value returned by `begin()` (stored
121    /// server-side). `req.received_state` comes from the OAuth redirect query
122    /// param. A mismatch is treated as a CSRF attempt and rejected immediately.
123    #[instrument(skip(self, req), fields(provider = %req.provider_name, ip = %req.ip))]
124    pub async fn callback(&self, req: OAuthCallbackRequest<'_>) -> Result<(User, Session, String)> {
125        use subtle::ConstantTimeEq;
126        if req
127            .expected_state
128            .as_bytes()
129            .ct_eq(req.received_state.as_bytes())
130            .unwrap_u8()
131            == 0
132        {
133            tracing::warn!(provider = %req.provider_name, "oauth state mismatch — possible CSRF");
134            return Err(AuthError::InvalidToken);
135        }
136
137        let provider = self.provider(req.provider_name)?;
138        let tokens = provider
139            .exchange_code(req.code, req.code_verifier, req.redirect_uri)
140            .await?;
141        let info = provider.fetch_user_info(&tokens.access_token).await?;
142
143        // Encrypt tokens before storing.
144        let access_enc = encrypt(&self.encryption_key, tokens.access_token.as_bytes())
145            .map_err(|e| AuthError::Internal(format!("token encrypt: {e}")))?;
146        let refresh_enc = tokens
147            .refresh_token
148            .as_deref()
149            .map(|r| encrypt(&self.encryption_key, r.as_bytes()))
150            .transpose()
151            .map_err(|e| AuthError::Internal(format!("token encrypt: {e}")))?;
152
153        let expires_at = tokens
154            .expires_in
155            .map(|secs| Utc::now() + chrono::Duration::seconds(secs as i64));
156
157        // Find or create user by email.
158        let user = match UserRepository::find_by_email(&self.storage, &info.email).await? {
159            Some(u) => u,
160            None => {
161                let u = UserRepository::create(
162                    &self.storage,
163                    CreateUser {
164                        email: info.email.clone(),
165                        username: None,
166                        metadata: None,
167                    },
168                )
169                .await?;
170                self.events.emit(AuthEvent::UserCreated { user: u.clone() });
171                u
172            }
173        };
174
175        // Upsert OAuth account.
176        OAuthAccountRepository::upsert(
177            &self.storage,
178            UpsertOAuthAccount {
179                user_id: user.id,
180                provider: req.provider_name.to_owned(),
181                provider_user_id: info.provider_user_id,
182                access_token_enc: access_enc,
183                refresh_token_enc: refresh_enc,
184                expires_at,
185            },
186        )
187        .await?;
188
189        self.events.emit(AuthEvent::OAuthLinked {
190            user_id: user.id,
191            provider: req.provider_name.to_owned(),
192        });
193
194        // Create session.
195        let raw: [u8; 32] = rand::thread_rng().gen();
196        let raw_str = hex::encode(raw);
197        let token_hash = sha256_hex(raw_str.as_bytes());
198
199        let session = SessionRepository::create(
200            &self.storage,
201            CreateSession {
202                user_id: user.id,
203                token_hash,
204                device_info: serde_json::json!({ "provider": req.provider_name }),
205                ip_address: req.ip.to_owned(),
206                org_id: None,
207                expires_at: Utc::now() + chrono::Duration::seconds(self.session_ttl_secs),
208            },
209        )
210        .await?;
211
212        self.events.emit(AuthEvent::SignIn {
213            user: user.clone(),
214            session: session.clone(),
215        });
216        tracing::info!(user_id = %user.id, provider = %req.provider_name, "oauth sign-in complete");
217        Ok((user, session, raw_str))
218    }
219}