1use std::{
2 collections::HashMap,
3 sync::{Arc, RwLock},
4 time::{Duration, Instant},
5};
6
7use chrono::Utc;
8use rand::Rng;
9use serde::Serialize;
10use tracing::instrument;
11use uuid::Uuid;
12use webauthn_rs::prelude::{
13 CreationChallengeResponse, Passkey, PasskeyAuthentication, PasskeyRegistration,
14 PublicKeyCredential, RegisterPublicKeyCredential, RequestChallengeResponse, Url, Webauthn,
15 WebauthnBuilder,
16};
17
18use authx_core::{
19 crypto::sha256_hex,
20 error::{AuthError, Result, StorageError},
21 models::{CreateCredential, CreateSession, CredentialKind},
22};
23use authx_storage::ports::{CredentialRepository, SessionRepository, UserRepository};
24
25enum PendingCeremony {
26 Registration {
27 user_id: Uuid,
28 state: PasskeyRegistration,
29 expires_at: Instant,
30 },
31 Authentication {
32 user_id: Uuid,
33 state: PasskeyAuthentication,
34 passkey: Box<Passkey>,
35 expires_at: Instant,
36 },
37}
38
/// Payload returned when a registration or authentication ceremony begins.
#[derive(Debug, Clone, Serialize)]
pub struct WebAuthnBeginResponse {
    /// Opaque server-side ceremony id (hex encoding of 32 random bytes). The caller
    /// sends it back in the matching `finish_*` request. This is NOT the WebAuthn
    /// challenge itself; that lives inside `options`.
    pub challenge: String,
    /// User the ceremony was started for.
    pub user_id: Uuid,
    /// Ceremony lifetime in whole seconds (the configured TTL, truncated).
    pub timeout_secs: u64,
    /// The `webauthn-rs` `CreationChallengeResponse` (registration) or
    /// `RequestChallengeResponse` (authentication) serialized to JSON; presumably
    /// forwarded unchanged to the client-side WebAuthn API.
    pub options: serde_json::Value,
}
47
/// Client input for completing a passkey registration ceremony.
#[derive(Debug, Clone)]
pub struct FinishRegistrationRequest {
    /// Ceremony id previously returned as `WebAuthnBeginResponse::challenge`.
    pub challenge: String,
    /// Registration (attestation) response produced by the client's authenticator.
    pub credential: RegisterPublicKeyCredential,
}
53
/// Result of a successful passkey registration.
#[derive(Debug, Clone, Serialize)]
pub struct WebAuthnRegistrationResult {
    /// User the new passkey belongs to.
    pub user_id: Uuid,
    /// Always `true` when returned by `finish_registration`: persistence failures
    /// are reported as errors instead of `false`.
    pub credential_stored: bool,
}
59
/// Client input for completing a passkey sign-in ceremony.
#[derive(Debug, Clone)]
pub struct FinishAuthenticationRequest {
    /// Ceremony id previously returned as `WebAuthnBeginResponse::challenge`.
    pub challenge: String,
    /// Authentication (assertion) response produced by the client's authenticator.
    pub credential: PublicKeyCredential,
    /// Caller-supplied IP address, stored as the new session's `ip_address`.
    pub ip: String,
}
66
67#[derive(Debug, Clone, Serialize)]
68pub struct WebAuthnAuthenticationResult {
69 pub user_id: Uuid,
70 pub session_id: Uuid,
71 pub token: String,
72}
73
74pub struct WebAuthnService<S> {
79 storage: S,
80 webauthn: Webauthn,
81 challenge_ttl: Duration,
82 session_ttl_secs: i64,
83 pending: Arc<RwLock<HashMap<String, PendingCeremony>>>,
84}
85
86impl<S> WebAuthnService<S>
87where
88 S: UserRepository + CredentialRepository + SessionRepository + Clone + Send + Sync + 'static,
89{
90 pub fn new(
91 storage: S,
92 rp_id: impl Into<String>,
93 rp_origin: impl Into<String>,
94 challenge_ttl: Duration,
95 session_ttl_secs: i64,
96 ) -> Result<Self> {
97 let rp_id = rp_id.into();
98 let rp_origin = rp_origin.into();
99 let origin = Url::parse(&rp_origin)
100 .map_err(|e| AuthError::Internal(format!("invalid rp origin: {e}")))?;
101 let webauthn = WebauthnBuilder::new(&rp_id, &origin)
102 .map_err(|e| AuthError::Internal(format!("invalid webauthn config: {e}")))?
103 .timeout(challenge_ttl)
104 .build()
105 .map_err(|e| AuthError::Internal(format!("invalid webauthn builder config: {e}")))?;
106
107 Ok(Self {
108 storage,
109 webauthn,
110 challenge_ttl,
111 session_ttl_secs,
112 pending: Arc::new(RwLock::new(HashMap::new())),
113 })
114 }
115
116 #[instrument(skip(self), fields(user_id = %user_id))]
117 pub async fn begin_registration(&self, user_id: Uuid) -> Result<WebAuthnBeginResponse> {
118 let user = UserRepository::find_by_id(&self.storage, user_id)
119 .await?
120 .ok_or(AuthError::UserNotFound)?;
121
122 let exclude_credentials = self
123 .stored_passkey(user_id)
124 .await?
125 .map(|passkey| vec![passkey.cred_id().clone()]);
126
127 let (creation, state) = self
128 .webauthn
129 .start_passkey_registration(user.id, &user.email, &user.email, exclude_credentials)
130 .map_err(|e| AuthError::Internal(format!("start passkey registration failed: {e}")))?;
131
132 let ceremony_id = self.store_registration_state(user_id, state)?;
133 Ok(WebAuthnBeginResponse {
134 challenge: ceremony_id,
135 user_id,
136 timeout_secs: self.challenge_ttl.as_secs(),
137 options: creation_challenge_to_json(creation)?,
138 })
139 }
140
141 #[instrument(skip(self, req))]
142 pub async fn finish_registration(
143 &self,
144 req: FinishRegistrationRequest,
145 ) -> Result<WebAuthnRegistrationResult> {
146 let (user_id, state) = self.take_registration_state(&req.challenge)?;
147 let passkey = self
148 .webauthn
149 .finish_passkey_registration(&req.credential, &state)
150 .map_err(|_| AuthError::InvalidToken)?;
151
152 self.persist_passkey(user_id, passkey).await?;
153
154 tracing::info!(user_id = %user_id, "webauthn credential registered");
155 Ok(WebAuthnRegistrationResult {
156 user_id,
157 credential_stored: true,
158 })
159 }
160
161 #[instrument(skip(self), fields(user_id = %user_id))]
162 pub async fn begin_authentication(&self, user_id: Uuid) -> Result<WebAuthnBeginResponse> {
163 let _ = UserRepository::find_by_id(&self.storage, user_id)
164 .await?
165 .ok_or(AuthError::UserNotFound)?;
166
167 let passkey = self
168 .stored_passkey(user_id)
169 .await?
170 .ok_or(AuthError::InvalidCredentials)?;
171 let (request, state) = self
172 .webauthn
173 .start_passkey_authentication(std::slice::from_ref(&passkey))
174 .map_err(|e| {
175 AuthError::Internal(format!("start passkey authentication failed: {e}"))
176 })?;
177
178 let ceremony_id = self.store_authentication_state(user_id, state, passkey)?;
179 Ok(WebAuthnBeginResponse {
180 challenge: ceremony_id,
181 user_id,
182 timeout_secs: self.challenge_ttl.as_secs(),
183 options: request_challenge_to_json(request)?,
184 })
185 }
186
187 #[instrument(skip(self, req))]
188 pub async fn finish_authentication(
189 &self,
190 req: FinishAuthenticationRequest,
191 ) -> Result<WebAuthnAuthenticationResult> {
192 let (user_id, state, mut passkey) = self.take_authentication_state(&req.challenge)?;
193 let auth_result = self
194 .webauthn
195 .finish_passkey_authentication(&req.credential, &state)
196 .map_err(|_| AuthError::InvalidCredentials)?;
197
198 let _ = passkey.update_credential(&auth_result);
199 self.persist_passkey(user_id, passkey).await?;
200
201 let raw_token = generate_token();
202 let token_hash = sha256_hex(raw_token.as_bytes());
203 let session = SessionRepository::create(
204 &self.storage,
205 CreateSession {
206 user_id,
207 token_hash,
208 device_info: serde_json::json!({ "webauthn": true }),
209 ip_address: req.ip,
210 org_id: None,
211 expires_at: Utc::now() + chrono::Duration::seconds(self.session_ttl_secs),
212 },
213 )
214 .await?;
215
216 tracing::info!(user_id = %user_id, session_id = %session.id, "webauthn sign-in complete");
217 Ok(WebAuthnAuthenticationResult {
218 user_id,
219 session_id: session.id,
220 token: raw_token,
221 })
222 }
223
224 async fn persist_passkey(&self, user_id: Uuid, passkey: Passkey) -> Result<()> {
225 let serialized = serde_json::to_value(passkey)
226 .map_err(|e| AuthError::Internal(format!("serialize passkey failed: {e}")))?;
227 let cred_hash = sha256_hex(
228 serde_json::to_string(&serialized)
229 .map_err(|e| AuthError::Internal(format!("encode passkey failed: {e}")))?
230 .as_bytes(),
231 );
232
233 self.delete_existing_webauthn_credential(user_id).await?;
234 CredentialRepository::create(
235 &self.storage,
236 CreateCredential {
237 user_id,
238 kind: CredentialKind::Webauthn,
239 credential_hash: cred_hash,
240 metadata: Some(serde_json::json!({
241 "passkey": serialized,
242 "updated_at": Utc::now(),
243 })),
244 },
245 )
246 .await?;
247 Ok(())
248 }
249
250 async fn stored_passkey(&self, user_id: Uuid) -> Result<Option<Passkey>> {
251 let credential = CredentialRepository::find_by_user_and_kind(
252 &self.storage,
253 user_id,
254 CredentialKind::Webauthn,
255 )
256 .await?;
257 let Some(credential) = credential else {
258 return Ok(None);
259 };
260
261 let passkey_value = credential
262 .metadata
263 .get("passkey")
264 .cloned()
265 .ok_or(AuthError::InvalidCredentials)?;
266 let passkey =
267 serde_json::from_value(passkey_value).map_err(|_| AuthError::InvalidCredentials)?;
268 Ok(Some(passkey))
269 }
270
271 async fn delete_existing_webauthn_credential(&self, user_id: Uuid) -> Result<()> {
272 match CredentialRepository::delete_by_user_and_kind(
273 &self.storage,
274 user_id,
275 CredentialKind::Webauthn,
276 )
277 .await
278 {
279 Ok(()) | Err(AuthError::Storage(StorageError::NotFound)) => Ok(()),
280 Err(e) => Err(e),
281 }
282 }
283
284 fn store_registration_state(
285 &self,
286 user_id: Uuid,
287 state: PasskeyRegistration,
288 ) -> Result<String> {
289 let ceremony_id = generate_token();
290 let mut pending = self
291 .pending
292 .write()
293 .map_err(|e| AuthError::Internal(format!("webauthn ceremony lock poisoned: {e}")))?;
294 purge_expired(&mut pending);
295 pending.insert(
296 ceremony_id.clone(),
297 PendingCeremony::Registration {
298 user_id,
299 state,
300 expires_at: Instant::now() + self.challenge_ttl,
301 },
302 );
303 Ok(ceremony_id)
304 }
305
306 fn store_authentication_state(
307 &self,
308 user_id: Uuid,
309 state: PasskeyAuthentication,
310 passkey: Passkey,
311 ) -> Result<String> {
312 let ceremony_id = generate_token();
313 let mut pending = self
314 .pending
315 .write()
316 .map_err(|e| AuthError::Internal(format!("webauthn ceremony lock poisoned: {e}")))?;
317 purge_expired(&mut pending);
318 pending.insert(
319 ceremony_id.clone(),
320 PendingCeremony::Authentication {
321 user_id,
322 state,
323 passkey: Box::new(passkey),
324 expires_at: Instant::now() + self.challenge_ttl,
325 },
326 );
327 Ok(ceremony_id)
328 }
329
330 fn take_registration_state(&self, ceremony_id: &str) -> Result<(Uuid, PasskeyRegistration)> {
331 let mut pending = self
332 .pending
333 .write()
334 .map_err(|e| AuthError::Internal(format!("webauthn ceremony lock poisoned: {e}")))?;
335 let Some(ceremony) = pending.remove(ceremony_id) else {
336 return Err(AuthError::InvalidToken);
337 };
338
339 match ceremony {
340 PendingCeremony::Registration {
341 user_id,
342 state,
343 expires_at,
344 } if expires_at >= Instant::now() => Ok((user_id, state)),
345 _ => Err(AuthError::InvalidToken),
346 }
347 }
348
349 fn take_authentication_state(
350 &self,
351 ceremony_id: &str,
352 ) -> Result<(Uuid, PasskeyAuthentication, Passkey)> {
353 let mut pending = self
354 .pending
355 .write()
356 .map_err(|e| AuthError::Internal(format!("webauthn ceremony lock poisoned: {e}")))?;
357 let Some(ceremony) = pending.remove(ceremony_id) else {
358 return Err(AuthError::InvalidToken);
359 };
360
361 match ceremony {
362 PendingCeremony::Authentication {
363 user_id,
364 state,
365 passkey,
366 expires_at,
367 } if expires_at >= Instant::now() => Ok((user_id, state, *passkey)),
368 _ => Err(AuthError::InvalidToken),
369 }
370 }
371}
372
373fn creation_challenge_to_json(challenge: CreationChallengeResponse) -> Result<serde_json::Value> {
374 serde_json::to_value(challenge)
375 .map_err(|e| AuthError::Internal(format!("serialize registration challenge failed: {e}")))
376}
377
378fn request_challenge_to_json(challenge: RequestChallengeResponse) -> Result<serde_json::Value> {
379 serde_json::to_value(challenge)
380 .map_err(|e| AuthError::Internal(format!("serialize authentication challenge failed: {e}")))
381}
382
383fn generate_token() -> String {
384 let bytes: [u8; 32] = rand::thread_rng().gen();
385 hex::encode(bytes)
386}
387
388fn purge_expired(pending: &mut HashMap<String, PendingCeremony>) {
389 let now = Instant::now();
390 pending.retain(|_, ceremony| match ceremony {
391 PendingCeremony::Registration { expires_at, .. }
392 | PendingCeremony::Authentication { expires_at, .. } => *expires_at >= now,
393 });
394}