Skip to main content

authx_plugins/webauthn/
service.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4    time::{Duration, Instant},
5};
6
7use chrono::Utc;
8use rand::Rng;
9use serde::Serialize;
10use tracing::instrument;
11use uuid::Uuid;
12use webauthn_rs::prelude::{
13    CreationChallengeResponse, Passkey, PasskeyAuthentication, PasskeyRegistration,
14    PublicKeyCredential, RegisterPublicKeyCredential, RequestChallengeResponse, Url, Webauthn,
15    WebauthnBuilder,
16};
17
18use authx_core::{
19    crypto::sha256_hex,
20    error::{AuthError, Result, StorageError},
21    models::{CreateCredential, CreateSession, CredentialKind},
22};
23use authx_storage::ports::{CredentialRepository, SessionRepository, UserRepository};
24
25enum PendingCeremony {
26    Registration {
27        user_id: Uuid,
28        state: PasskeyRegistration,
29        expires_at: Instant,
30    },
31    Authentication {
32        user_id: Uuid,
33        state: PasskeyAuthentication,
34        passkey: Box<Passkey>,
35        expires_at: Instant,
36    },
37}
38
39#[derive(Debug, Clone, Serialize)]
40pub struct WebAuthnBeginResponse {
41    /// Opaque server-side ceremony id (previously challenge id in this API).
42    pub challenge: String,
43    pub user_id: Uuid,
44    pub timeout_secs: u64,
45    pub options: serde_json::Value,
46}
47
48#[derive(Debug, Clone)]
49pub struct FinishRegistrationRequest {
50    pub challenge: String,
51    pub credential: RegisterPublicKeyCredential,
52}
53
54#[derive(Debug, Clone, Serialize)]
55pub struct WebAuthnRegistrationResult {
56    pub user_id: Uuid,
57    pub credential_stored: bool,
58}
59
60#[derive(Debug, Clone)]
61pub struct FinishAuthenticationRequest {
62    pub challenge: String,
63    pub credential: PublicKeyCredential,
64    pub ip: String,
65}
66
67#[derive(Debug, Clone, Serialize)]
68pub struct WebAuthnAuthenticationResult {
69    pub user_id: Uuid,
70    pub session_id: Uuid,
71    pub token: String,
72}
73
74/// WebAuthn/passkey service powered by `webauthn-rs`.
75///
76/// Registration/authentication states are held server-side in memory and
77/// expire by TTL. Registered passkeys are persisted in credential metadata.
78pub struct WebAuthnService<S> {
79    storage: S,
80    webauthn: Webauthn,
81    challenge_ttl: Duration,
82    session_ttl_secs: i64,
83    pending: Arc<RwLock<HashMap<String, PendingCeremony>>>,
84}
85
86impl<S> WebAuthnService<S>
87where
88    S: UserRepository + CredentialRepository + SessionRepository + Clone + Send + Sync + 'static,
89{
90    pub fn new(
91        storage: S,
92        rp_id: impl Into<String>,
93        rp_origin: impl Into<String>,
94        challenge_ttl: Duration,
95        session_ttl_secs: i64,
96    ) -> Result<Self> {
97        let rp_id = rp_id.into();
98        let rp_origin = rp_origin.into();
99        let origin = Url::parse(&rp_origin)
100            .map_err(|e| AuthError::Internal(format!("invalid rp origin: {e}")))?;
101        let webauthn = WebauthnBuilder::new(&rp_id, &origin)
102            .map_err(|e| AuthError::Internal(format!("invalid webauthn config: {e}")))?
103            .timeout(challenge_ttl)
104            .build()
105            .map_err(|e| AuthError::Internal(format!("invalid webauthn builder config: {e}")))?;
106
107        Ok(Self {
108            storage,
109            webauthn,
110            challenge_ttl,
111            session_ttl_secs,
112            pending: Arc::new(RwLock::new(HashMap::new())),
113        })
114    }
115
116    #[instrument(skip(self), fields(user_id = %user_id))]
117    pub async fn begin_registration(&self, user_id: Uuid) -> Result<WebAuthnBeginResponse> {
118        let user = UserRepository::find_by_id(&self.storage, user_id)
119            .await?
120            .ok_or(AuthError::UserNotFound)?;
121
122        let exclude_credentials = self
123            .stored_passkey(user_id)
124            .await?
125            .map(|passkey| vec![passkey.cred_id().clone()]);
126
127        let (creation, state) = self
128            .webauthn
129            .start_passkey_registration(user.id, &user.email, &user.email, exclude_credentials)
130            .map_err(|e| AuthError::Internal(format!("start passkey registration failed: {e}")))?;
131
132        let ceremony_id = self.store_registration_state(user_id, state)?;
133        Ok(WebAuthnBeginResponse {
134            challenge: ceremony_id,
135            user_id,
136            timeout_secs: self.challenge_ttl.as_secs(),
137            options: creation_challenge_to_json(creation)?,
138        })
139    }
140
141    #[instrument(skip(self, req))]
142    pub async fn finish_registration(
143        &self,
144        req: FinishRegistrationRequest,
145    ) -> Result<WebAuthnRegistrationResult> {
146        let (user_id, state) = self.take_registration_state(&req.challenge)?;
147        let passkey = self
148            .webauthn
149            .finish_passkey_registration(&req.credential, &state)
150            .map_err(|_| AuthError::InvalidToken)?;
151
152        self.persist_passkey(user_id, passkey).await?;
153
154        tracing::info!(user_id = %user_id, "webauthn credential registered");
155        Ok(WebAuthnRegistrationResult {
156            user_id,
157            credential_stored: true,
158        })
159    }
160
161    #[instrument(skip(self), fields(user_id = %user_id))]
162    pub async fn begin_authentication(&self, user_id: Uuid) -> Result<WebAuthnBeginResponse> {
163        let _ = UserRepository::find_by_id(&self.storage, user_id)
164            .await?
165            .ok_or(AuthError::UserNotFound)?;
166
167        let passkey = self
168            .stored_passkey(user_id)
169            .await?
170            .ok_or(AuthError::InvalidCredentials)?;
171        let (request, state) = self
172            .webauthn
173            .start_passkey_authentication(std::slice::from_ref(&passkey))
174            .map_err(|e| {
175                AuthError::Internal(format!("start passkey authentication failed: {e}"))
176            })?;
177
178        let ceremony_id = self.store_authentication_state(user_id, state, passkey)?;
179        Ok(WebAuthnBeginResponse {
180            challenge: ceremony_id,
181            user_id,
182            timeout_secs: self.challenge_ttl.as_secs(),
183            options: request_challenge_to_json(request)?,
184        })
185    }
186
187    #[instrument(skip(self, req))]
188    pub async fn finish_authentication(
189        &self,
190        req: FinishAuthenticationRequest,
191    ) -> Result<WebAuthnAuthenticationResult> {
192        let (user_id, state, mut passkey) = self.take_authentication_state(&req.challenge)?;
193        let auth_result = self
194            .webauthn
195            .finish_passkey_authentication(&req.credential, &state)
196            .map_err(|_| AuthError::InvalidCredentials)?;
197
198        let _ = passkey.update_credential(&auth_result);
199        self.persist_passkey(user_id, passkey).await?;
200
201        let raw_token = generate_token();
202        let token_hash = sha256_hex(raw_token.as_bytes());
203        let session = SessionRepository::create(
204            &self.storage,
205            CreateSession {
206                user_id,
207                token_hash,
208                device_info: serde_json::json!({ "webauthn": true }),
209                ip_address: req.ip,
210                org_id: None,
211                expires_at: Utc::now() + chrono::Duration::seconds(self.session_ttl_secs),
212            },
213        )
214        .await?;
215
216        tracing::info!(user_id = %user_id, session_id = %session.id, "webauthn sign-in complete");
217        Ok(WebAuthnAuthenticationResult {
218            user_id,
219            session_id: session.id,
220            token: raw_token,
221        })
222    }
223
224    async fn persist_passkey(&self, user_id: Uuid, passkey: Passkey) -> Result<()> {
225        let serialized = serde_json::to_value(passkey)
226            .map_err(|e| AuthError::Internal(format!("serialize passkey failed: {e}")))?;
227        let cred_hash = sha256_hex(
228            serde_json::to_string(&serialized)
229                .map_err(|e| AuthError::Internal(format!("encode passkey failed: {e}")))?
230                .as_bytes(),
231        );
232
233        self.delete_existing_webauthn_credential(user_id).await?;
234        CredentialRepository::create(
235            &self.storage,
236            CreateCredential {
237                user_id,
238                kind: CredentialKind::Webauthn,
239                credential_hash: cred_hash,
240                metadata: Some(serde_json::json!({
241                    "passkey": serialized,
242                    "updated_at": Utc::now(),
243                })),
244            },
245        )
246        .await?;
247        Ok(())
248    }
249
250    async fn stored_passkey(&self, user_id: Uuid) -> Result<Option<Passkey>> {
251        let credential = CredentialRepository::find_by_user_and_kind(
252            &self.storage,
253            user_id,
254            CredentialKind::Webauthn,
255        )
256        .await?;
257        let Some(credential) = credential else {
258            return Ok(None);
259        };
260
261        let passkey_value = credential
262            .metadata
263            .get("passkey")
264            .cloned()
265            .ok_or(AuthError::InvalidCredentials)?;
266        let passkey =
267            serde_json::from_value(passkey_value).map_err(|_| AuthError::InvalidCredentials)?;
268        Ok(Some(passkey))
269    }
270
271    async fn delete_existing_webauthn_credential(&self, user_id: Uuid) -> Result<()> {
272        match CredentialRepository::delete_by_user_and_kind(
273            &self.storage,
274            user_id,
275            CredentialKind::Webauthn,
276        )
277        .await
278        {
279            Ok(()) | Err(AuthError::Storage(StorageError::NotFound)) => Ok(()),
280            Err(e) => Err(e),
281        }
282    }
283
284    fn store_registration_state(
285        &self,
286        user_id: Uuid,
287        state: PasskeyRegistration,
288    ) -> Result<String> {
289        let ceremony_id = generate_token();
290        let mut pending = self
291            .pending
292            .write()
293            .map_err(|e| AuthError::Internal(format!("webauthn ceremony lock poisoned: {e}")))?;
294        purge_expired(&mut pending);
295        pending.insert(
296            ceremony_id.clone(),
297            PendingCeremony::Registration {
298                user_id,
299                state,
300                expires_at: Instant::now() + self.challenge_ttl,
301            },
302        );
303        Ok(ceremony_id)
304    }
305
306    fn store_authentication_state(
307        &self,
308        user_id: Uuid,
309        state: PasskeyAuthentication,
310        passkey: Passkey,
311    ) -> Result<String> {
312        let ceremony_id = generate_token();
313        let mut pending = self
314            .pending
315            .write()
316            .map_err(|e| AuthError::Internal(format!("webauthn ceremony lock poisoned: {e}")))?;
317        purge_expired(&mut pending);
318        pending.insert(
319            ceremony_id.clone(),
320            PendingCeremony::Authentication {
321                user_id,
322                state,
323                passkey: Box::new(passkey),
324                expires_at: Instant::now() + self.challenge_ttl,
325            },
326        );
327        Ok(ceremony_id)
328    }
329
330    fn take_registration_state(&self, ceremony_id: &str) -> Result<(Uuid, PasskeyRegistration)> {
331        let mut pending = self
332            .pending
333            .write()
334            .map_err(|e| AuthError::Internal(format!("webauthn ceremony lock poisoned: {e}")))?;
335        let Some(ceremony) = pending.remove(ceremony_id) else {
336            return Err(AuthError::InvalidToken);
337        };
338
339        match ceremony {
340            PendingCeremony::Registration {
341                user_id,
342                state,
343                expires_at,
344            } if expires_at >= Instant::now() => Ok((user_id, state)),
345            _ => Err(AuthError::InvalidToken),
346        }
347    }
348
349    fn take_authentication_state(
350        &self,
351        ceremony_id: &str,
352    ) -> Result<(Uuid, PasskeyAuthentication, Passkey)> {
353        let mut pending = self
354            .pending
355            .write()
356            .map_err(|e| AuthError::Internal(format!("webauthn ceremony lock poisoned: {e}")))?;
357        let Some(ceremony) = pending.remove(ceremony_id) else {
358            return Err(AuthError::InvalidToken);
359        };
360
361        match ceremony {
362            PendingCeremony::Authentication {
363                user_id,
364                state,
365                passkey,
366                expires_at,
367            } if expires_at >= Instant::now() => Ok((user_id, state, *passkey)),
368            _ => Err(AuthError::InvalidToken),
369        }
370    }
371}
372
373fn creation_challenge_to_json(challenge: CreationChallengeResponse) -> Result<serde_json::Value> {
374    serde_json::to_value(challenge)
375        .map_err(|e| AuthError::Internal(format!("serialize registration challenge failed: {e}")))
376}
377
378fn request_challenge_to_json(challenge: RequestChallengeResponse) -> Result<serde_json::Value> {
379    serde_json::to_value(challenge)
380        .map_err(|e| AuthError::Internal(format!("serialize authentication challenge failed: {e}")))
381}
382
383fn generate_token() -> String {
384    let bytes: [u8; 32] = rand::thread_rng().r#gen();
385    hex::encode(bytes)
386}
387
388fn purge_expired(pending: &mut HashMap<String, PendingCeremony>) {
389    let now = Instant::now();
390    pending.retain(|_, ceremony| match ceremony {
391        PendingCeremony::Registration { expires_at, .. }
392        | PendingCeremony::Authentication { expires_at, .. } => *expires_at >= now,
393    });
394}