1use std::sync::Arc;
2
3use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
4use chrono::Utc;
5use rand::Rng;
6use sha2::{Digest, Sha256};
7use tracing::instrument;
8
9use authx_core::{
10 crypto::{encrypt, sha256_hex},
11 error::{AuthError, Result},
12 events::{AuthEvent, EventBus},
13 models::{CreateSession, CreateUser, Session, UpsertOAuthAccount, User},
14};
15use authx_storage::ports::{OAuthAccountRepository, SessionRepository, UserRepository};
16
17use super::providers::OAuthProvider;
18
/// Values produced by `OAuthService::begin` for starting an authorization flow.
///
/// `Debug` is implemented by hand so that the PKCE `code_verifier` — a secret
/// that must only ever be sent to the provider's token endpoint — is not
/// written out when this value is logged or included in an error report.
pub struct OAuthBeginResponse {
    /// URL returned by the provider's `authorization_url` for this flow.
    pub authorization_url: String,
    /// Random, hex-encoded state value that was embedded in `authorization_url`.
    pub state: String,
    /// PKCE code verifier matching the challenge embedded in `authorization_url`.
    pub code_verifier: String,
}

impl std::fmt::Debug for OAuthBeginResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthBeginResponse")
            .field("authorization_url", &self.authorization_url)
            .field("state", &self.state)
            // Never print the verifier itself; its presence is all a log needs.
            .field("code_verifier", &"<redacted>")
            .finish()
    }
}
27
/// Borrowed inputs for `OAuthService::callback`.
pub struct OAuthCallbackRequest<'req> {
    /// Name used to look up the registered provider.
    pub provider_name: &'req str,
    /// Authorization code handed to the provider's `exchange_code`.
    pub code: &'req str,
    /// State value the caller expects for this flow.
    pub expected_state: &'req str,
    /// State value that arrived with the callback; compared against `expected_state`.
    pub received_state: &'req str,
    /// PKCE code verifier handed to the provider's `exchange_code`.
    pub code_verifier: &'req str,
    /// Redirect URI handed to the provider's `exchange_code`.
    pub redirect_uri: &'req str,
    /// Client IP address recorded on the created session.
    pub ip: &'req str,
}
40
/// OAuth sign-in service: bundles the storage backend, the event bus, the
/// registered providers and the settings used when completing a sign-in.
pub struct OAuthService<S> {
    /// Backend used for the user, session and OAuth-account repository calls.
    storage: S,
    /// Bus on which `AuthEvent`s are emitted.
    events: EventBus,
    /// Registered providers, keyed by the name each provider reports.
    providers: std::collections::HashMap<String, Arc<dyn OAuthProvider>>,
    /// Lifetime, in seconds, given to sessions created by `callback`.
    session_ttl_secs: i64,
    /// Key passed to `encrypt` for provider access/refresh tokens before they are stored.
    encryption_key: [u8; 32],
}
52
53impl<S> OAuthService<S>
54where
55 S: UserRepository + SessionRepository + OAuthAccountRepository + Clone + Send + Sync + 'static,
56{
57 pub fn new(
58 storage: S,
59 events: EventBus,
60 session_ttl_secs: i64,
61 encryption_key: [u8; 32],
62 ) -> Self {
63 Self {
64 storage,
65 events,
66 providers: Default::default(),
67 session_ttl_secs,
68 encryption_key,
69 }
70 }
71
72 pub fn register(mut self, provider: impl OAuthProvider + 'static) -> Self {
74 self.providers
75 .insert(provider.name().to_owned(), Arc::new(provider));
76 self
77 }
78
79 fn provider(&self, name: &str) -> Result<&dyn OAuthProvider> {
80 self.providers
81 .get(name)
82 .map(|p| p.as_ref())
83 .ok_or_else(|| AuthError::Internal(format!("unknown oauth provider: {name}")))
84 }
85
86 #[instrument(skip(self), fields(provider = %provider_name))]
88 pub fn begin(&self, provider_name: &str, _redirect_uri: &str) -> Result<OAuthBeginResponse> {
89 self.provider(provider_name)?;
90
91 let verifier_bytes: [u8; 32] = rand::thread_rng().r#gen();
93 let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
94
95 let mut hasher = Sha256::new();
97 hasher.update(code_verifier.as_bytes());
98 let digest = hasher.finalize();
99 let code_challenge = URL_SAFE_NO_PAD.encode(digest);
100
101 let state_bytes: [u8; 16] = rand::thread_rng().r#gen();
103 let state = hex::encode(state_bytes);
104
105 let authorization_url = self
106 .provider(provider_name)?
107 .authorization_url(&state, &code_challenge);
108
109 tracing::info!(provider = %provider_name, "oauth flow started");
110 Ok(OAuthBeginResponse {
111 authorization_url,
112 state,
113 code_verifier,
114 })
115 }
116
117 #[instrument(skip(self, req), fields(provider = %req.provider_name, ip = %req.ip))]
124 pub async fn callback(&self, req: OAuthCallbackRequest<'_>) -> Result<(User, Session, String)> {
125 use subtle::ConstantTimeEq;
126 if req
127 .expected_state
128 .as_bytes()
129 .ct_eq(req.received_state.as_bytes())
130 .unwrap_u8()
131 == 0
132 {
133 tracing::warn!(provider = %req.provider_name, "oauth state mismatch — possible CSRF");
134 return Err(AuthError::InvalidToken);
135 }
136
137 let provider = self.provider(req.provider_name)?;
138 let tokens = provider
139 .exchange_code(req.code, req.code_verifier, req.redirect_uri)
140 .await?;
141 let info = provider.fetch_user_info(&tokens.access_token).await?;
142
143 let access_enc = encrypt(&self.encryption_key, tokens.access_token.as_bytes())
145 .map_err(|e| AuthError::Internal(format!("token encrypt: {e}")))?;
146 let refresh_enc = tokens
147 .refresh_token
148 .as_deref()
149 .map(|r| encrypt(&self.encryption_key, r.as_bytes()))
150 .transpose()
151 .map_err(|e| AuthError::Internal(format!("token encrypt: {e}")))?;
152
153 let expires_at = tokens
154 .expires_in
155 .map(|secs| Utc::now() + chrono::Duration::seconds(secs as i64));
156
157 let user = match UserRepository::find_by_email(&self.storage, &info.email).await? {
159 Some(u) => u,
160 None => {
161 let u = UserRepository::create(
162 &self.storage,
163 CreateUser {
164 email: info.email.clone(),
165 username: None,
166 metadata: None,
167 },
168 )
169 .await?;
170 self.events.emit(AuthEvent::UserCreated { user: u.clone() });
171 u
172 }
173 };
174
175 OAuthAccountRepository::upsert(
177 &self.storage,
178 UpsertOAuthAccount {
179 user_id: user.id,
180 provider: req.provider_name.to_owned(),
181 provider_user_id: info.provider_user_id,
182 access_token_enc: access_enc,
183 refresh_token_enc: refresh_enc,
184 expires_at,
185 },
186 )
187 .await?;
188
189 self.events.emit(AuthEvent::OAuthLinked {
190 user_id: user.id,
191 provider: req.provider_name.to_owned(),
192 });
193
194 let raw: [u8; 32] = rand::thread_rng().r#gen();
196 let raw_str = hex::encode(raw);
197 let token_hash = sha256_hex(raw_str.as_bytes());
198
199 let session = SessionRepository::create(
200 &self.storage,
201 CreateSession {
202 user_id: user.id,
203 token_hash,
204 device_info: serde_json::json!({ "provider": req.provider_name }),
205 ip_address: req.ip.to_owned(),
206 org_id: None,
207 expires_at: Utc::now() + chrono::Duration::seconds(self.session_ttl_secs),
208 },
209 )
210 .await?;
211
212 self.events.emit(AuthEvent::SignIn {
213 user: user.clone(),
214 session: session.clone(),
215 });
216 tracing::info!(user_id = %user.id, provider = %req.provider_name, "oauth sign-in complete");
217 Ok((user, session, raw_str))
218 }
219}