use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[cfg(feature = "passkey")]
use std::sync::Arc;
#[cfg(feature = "passkey")]
use dashmap::DashMap;
#[cfg(feature = "passkey")]
use url::Url;
#[cfg(feature = "passkey")]
use webauthn_rs::prelude::{
CreationChallengeResponse, Passkey, PasskeyAuthentication, PasskeyRegistration,
PublicKeyCredential, RegisterPublicKeyCredential, RequestChallengeResponse, Uuid,
};
#[cfg(feature = "passkey")]
use webauthn_rs::{Webauthn, WebauthnBuilder};
/// Errors reported by [`PasskeyBackend`] implementations in this module.
#[derive(Debug, Error)]
pub enum PasskeyError {
    /// The backend is a stub; every `PasskeyTodo` method returns this.
    #[error("passkey backend not implemented")]
    Unimplemented,

    /// A ceremony step was rejected; the payload carries the underlying
    /// error text (or a description such as an unknown credential id).
    #[error("verification failed: {0}")]
    Verification(String),

    /// An internal backend failure, e.g. a challenge that could not be
    /// serialised to JSON.
    #[error("backend: {0}")]
    Backend(String),

    /// The backend could not be constructed from the supplied settings.
    #[error("configuration: {0}")]
    Config(String),

    /// No ceremony state is stored for the account named in the payload.
    #[error("no in-flight ceremony for user: {0}")]
    NoCeremony(String),

    /// The client response could not be deserialised into the expected shape.
    #[error("response parse: {0}")]
    Parse(String),
}
/// Options handed to the client to begin a registration ceremony.
///
/// The wrapper is transparent on the wire: because `raw` is flattened, the
/// struct (de)serialises as the fields of `raw` itself.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RegistrationOptions {
    /// The option payload as untyped JSON.
    #[serde(flatten)]
    pub raw: serde_json::Value,
}
/// The client's answer to a registration ceremony.
///
/// On deserialisation the `id` field is lifted out of the JSON object and
/// every remaining field is collected into `raw`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RegistrationResponse {
    /// The `id` field of the response.
    pub id: String,
    /// All other fields of the response, as untyped JSON.
    #[serde(flatten)]
    pub raw: serde_json::Value,
}
/// Options handed to the client to begin an authentication ceremony.
///
/// The wrapper is transparent on the wire: because `raw` is flattened, the
/// struct (de)serialises as the fields of `raw` itself.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AuthenticationOptions {
    /// The option payload as untyped JSON.
    #[serde(flatten)]
    pub raw: serde_json::Value,
}
/// The client's answer to an authentication ceremony.
///
/// On deserialisation the `id` field is lifted out of the JSON object and
/// every remaining field is collected into `raw`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AuthenticationResponse {
    /// The `id` field of the response.
    pub id: String,
    /// All other fields of the response, as untyped JSON.
    #[serde(flatten)]
    pub raw: serde_json::Value,
}
/// Asynchronous interface for the two passkey ceremonies: registration and
/// authentication, each split into an "options" step and a "verify" step.
#[async_trait]
pub trait PasskeyBackend: Send + Sync + 'static {
    /// Produces the options that start a registration ceremony for
    /// `account_id`.
    async fn registration_options(
        &self,
        account_id: &str,
    ) -> Result<RegistrationOptions, PasskeyError>;

    /// Checks the client's registration response `resp` for `account_id`.
    async fn registration_verify(
        &self,
        account_id: &str,
        resp: RegistrationResponse,
    ) -> Result<(), PasskeyError>;

    /// Produces the options that start an authentication ceremony for
    /// `account_id`.
    async fn authentication_options(
        &self,
        account_id: &str,
    ) -> Result<AuthenticationOptions, PasskeyError>;

    /// Checks the client's authentication response `resp`.
    ///
    /// No account id is passed in. The `WebauthnPasskey` implementation in
    /// this file resolves the account from the credential id and returns
    /// that account id on success.
    async fn authentication_verify(
        &self,
        resp: AuthenticationResponse,
    ) -> Result<String, PasskeyError>;
}
/// Placeholder backend: its `PasskeyBackend` implementation in this file
/// answers every call with `PasskeyError::Unimplemented`.
#[doc(hidden)]
pub struct PasskeyTodo;
/// Alternative name for [`PasskeyTodo`]; both names refer to the same type.
#[doc(hidden)]
pub type NullPasskeyBackend = PasskeyTodo;
/// Stub implementation: all four ceremony steps fail with
/// [`PasskeyError::Unimplemented`] and ignore their arguments.
#[async_trait]
impl PasskeyBackend for PasskeyTodo {
    /// Always returns `Err(PasskeyError::Unimplemented)`.
    async fn registration_options(
        &self,
        _account_id: &str,
    ) -> Result<RegistrationOptions, PasskeyError> {
        Err(PasskeyError::Unimplemented)
    }

    /// Always returns `Err(PasskeyError::Unimplemented)`.
    async fn registration_verify(
        &self,
        _account_id: &str,
        _resp: RegistrationResponse,
    ) -> Result<(), PasskeyError> {
        Err(PasskeyError::Unimplemented)
    }

    /// Always returns `Err(PasskeyError::Unimplemented)`.
    async fn authentication_options(
        &self,
        _account_id: &str,
    ) -> Result<AuthenticationOptions, PasskeyError> {
        Err(PasskeyError::Unimplemented)
    }

    /// Always returns `Err(PasskeyError::Unimplemented)`.
    async fn authentication_verify(
        &self,
        _resp: AuthenticationResponse,
    ) -> Result<String, PasskeyError> {
        Err(PasskeyError::Unimplemented)
    }
}
/// [`PasskeyBackend`] built on `webauthn-rs`, keeping all ceremony state and
/// registered credentials in in-process maps keyed by account id.
#[cfg(feature = "passkey")]
pub struct WebauthnPasskey {
    /// The configured WebAuthn relying-party handle.
    webauthn: Arc<Webauthn>,
    /// State of a started registration ceremony, per account id.
    registration_state: DashMap<String, PasskeyRegistration>,
    /// State of a started authentication ceremony, per account id.
    authentication_state: DashMap<String, PasskeyAuthentication>,
    /// Passkeys registered so far, per account id.
    credentials: DashMap<String, Vec<Passkey>>,
}
#[cfg(feature = "passkey")]
impl WebauthnPasskey {
    /// Creates a backend for the relying party identified by `rp_id`,
    /// displayed as `rp_name`, and served from `origin`.
    ///
    /// # Errors
    ///
    /// Returns [`PasskeyError::Config`] when the builder rejects the
    /// relying-party id / origin or fails to build.
    pub fn new(rp_id: &str, rp_name: &str, origin: &Url) -> Result<Self, PasskeyError> {
        let webauthn = WebauthnBuilder::new(rp_id, origin)
            .map_err(|err| PasskeyError::Config(err.to_string()))?
            .rp_name(rp_name)
            .build()
            .map_err(|err| PasskeyError::Config(err.to_string()))?;

        Ok(Self {
            webauthn: Arc::new(webauthn),
            registration_state: DashMap::new(),
            authentication_state: DashMap::new(),
            credentials: DashMap::new(),
        })
    }

    /// Derives a stable UUID from `account_id`.
    ///
    /// The first 16 bytes of the SHA-256 digest of the id are used, with the
    /// version nibble forced to 4 and the variant bits forced to `10`, so the
    /// same account id always maps to the same UUID.
    fn account_uuid(account_id: &str) -> Uuid {
        use sha2::{Digest, Sha256};

        let hash = Sha256::digest(account_id.as_bytes());
        let mut raw = [0u8; 16];
        let wanted = raw.len();
        raw.copy_from_slice(&hash[..wanted]);
        // Stamp the version (high nibble of byte 6) and variant (top bits of byte 8).
        raw[6] = (raw[6] & 0x0f) | 0x40;
        raw[8] = (raw[8] & 0x3f) | 0x80;
        Uuid::from_bytes(raw)
    }

    /// Returns a copy of the passkeys stored for `account_id`, or an empty
    /// vector when the account has none.
    pub fn credentials_for(&self, account_id: &str) -> Vec<Passkey> {
        match self.credentials.get(account_id) {
            Some(entry) => entry.value().clone(),
            None => Vec::new(),
        }
    }
}
#[cfg(feature = "passkey")]
#[async_trait]
impl PasskeyBackend for WebauthnPasskey {
    /// Starts a registration ceremony for `account_id`.
    ///
    /// The account id doubles as the user name and display name passed to
    /// `webauthn-rs`. Credentials already stored for the account are sent as
    /// the exclude list. The ceremony state is stored under `account_id`,
    /// replacing any state from an earlier, unfinished call.
    ///
    /// # Errors
    ///
    /// * [`PasskeyError::Verification`] if the ceremony cannot be started.
    /// * [`PasskeyError::Backend`] if the challenge cannot be serialised.
    async fn registration_options(
        &self,
        account_id: &str,
    ) -> Result<RegistrationOptions, PasskeyError> {
        let uuid = Self::account_uuid(account_id);
        let existing: Vec<_> = self
            .credentials_for(account_id)
            .iter()
            .map(|p| p.cred_id().clone())
            .collect();
        // An empty exclude list is passed as `None`.
        let exclude = if existing.is_empty() {
            None
        } else {
            Some(existing)
        };
        let (ccr, state): (CreationChallengeResponse, PasskeyRegistration) = self
            .webauthn
            .start_passkey_registration(uuid, account_id, account_id, exclude)
            .map_err(|e| PasskeyError::Verification(e.to_string()))?;
        self.registration_state
            .insert(account_id.to_string(), state);
        let raw = serde_json::to_value(&ccr)
            .map_err(|e| PasskeyError::Backend(format!("serialise ccr: {e}")))?;
        Ok(RegistrationOptions { raw })
    }

    /// Completes the registration ceremony for `account_id` and stores the
    /// resulting passkey.
    ///
    /// The in-flight state is removed before anything else is checked, so a
    /// failed attempt consumes the ceremony and a new one must be started.
    ///
    /// # Errors
    ///
    /// * [`PasskeyError::NoCeremony`] if no registration was started.
    /// * [`PasskeyError::Parse`] if `resp` is not a valid credential payload.
    /// * [`PasskeyError::Verification`] if `webauthn-rs` rejects the response.
    async fn registration_verify(
        &self,
        account_id: &str,
        resp: RegistrationResponse,
    ) -> Result<(), PasskeyError> {
        let (_, state) = self
            .registration_state
            .remove(account_id)
            .ok_or_else(|| PasskeyError::NoCeremony(account_id.to_string()))?;

        // `id` was split from the flattened payload on deserialisation; put it
        // back so the object matches what the typed credential expects. The
        // response is owned, so both parts are moved rather than cloned.
        let RegistrationResponse { id, raw: mut value } = resp;
        if let serde_json::Value::Object(map) = &mut value {
            map.insert("id".into(), serde_json::Value::String(id));
        }
        let reg: RegisterPublicKeyCredential =
            serde_json::from_value(value).map_err(|e| PasskeyError::Parse(e.to_string()))?;

        let passkey = self
            .webauthn
            .finish_passkey_registration(&reg, &state)
            .map_err(|e| PasskeyError::Verification(e.to_string()))?;
        self.credentials
            .entry(account_id.to_string())
            .or_default()
            .push(passkey);
        Ok(())
    }

    /// Starts an authentication ceremony for `account_id` against its stored
    /// passkeys, and stores the ceremony state under `account_id`.
    ///
    /// # Errors
    ///
    /// * [`PasskeyError::NoCeremony`] if the account has no stored passkeys.
    /// * [`PasskeyError::Verification`] if the ceremony cannot be started.
    /// * [`PasskeyError::Backend`] if the challenge cannot be serialised.
    async fn authentication_options(
        &self,
        account_id: &str,
    ) -> Result<AuthenticationOptions, PasskeyError> {
        let creds = self.credentials_for(account_id);
        if creds.is_empty() {
            return Err(PasskeyError::NoCeremony(format!(
                "no passkeys registered for {account_id}"
            )));
        }
        let (rcr, state): (RequestChallengeResponse, PasskeyAuthentication) = self
            .webauthn
            .start_passkey_authentication(&creds)
            .map_err(|e| PasskeyError::Verification(e.to_string()))?;
        self.authentication_state
            .insert(account_id.to_string(), state);
        let raw = serde_json::to_value(&rcr)
            .map_err(|e| PasskeyError::Backend(format!("serialise rcr: {e}")))?;
        Ok(AuthenticationOptions { raw })
    }

    /// Completes an authentication ceremony and returns the id of the account
    /// that owns the presented credential.
    ///
    /// The account is found by scanning all stored passkeys for one whose
    /// credential id matches `resp.id`. The in-flight state is removed before
    /// the response is parsed, so a failed attempt consumes the ceremony.
    /// After a successful verification the stored passkeys are updated with
    /// the authentication result.
    ///
    /// # Errors
    ///
    /// * [`PasskeyError::Verification`] if no stored passkey matches
    ///   `resp.id`, or if `webauthn-rs` rejects the response.
    /// * [`PasskeyError::NoCeremony`] if the owning account has no
    ///   authentication in flight.
    /// * [`PasskeyError::Parse`] if `resp` is not a valid credential payload.
    async fn authentication_verify(
        &self,
        resp: AuthenticationResponse,
    ) -> Result<String, PasskeyError> {
        // The map guard lives only for this statement, so it is released
        // before `credentials` is locked again below.
        let account_id = self
            .credentials
            .iter()
            .find(|entry| {
                entry
                    .value()
                    .iter()
                    .any(|p| base64url_matches(p.cred_id().as_ref(), &resp.id))
            })
            .map(|entry| entry.key().clone())
            .ok_or_else(|| PasskeyError::Verification(format!("unknown credential {}", resp.id)))?;

        let (_, state) = self
            .authentication_state
            .remove(&account_id)
            .ok_or_else(|| PasskeyError::NoCeremony(account_id.clone()))?;

        // Restore the `id` field that serde split from the flattened payload.
        let AuthenticationResponse { id, raw: mut value } = resp;
        if let serde_json::Value::Object(map) = &mut value {
            map.insert("id".into(), serde_json::Value::String(id));
        }
        let cred: PublicKeyCredential =
            serde_json::from_value(value).map_err(|e| PasskeyError::Parse(e.to_string()))?;

        let auth_result = self
            .webauthn
            .finish_passkey_authentication(&cred, &state)
            .map_err(|e| PasskeyError::Verification(e.to_string()))?;

        // Record signature-counter / backup-state changes on the stored
        // passkey. `update_credential` ignores passkeys whose credential id
        // does not match the result, so it is safe to offer it to each one.
        if let Some(mut stored) = self.credentials.get_mut(&account_id) {
            for passkey in stored.iter_mut() {
                let _ = passkey.update_credential(&auth_result);
            }
        }
        Ok(account_id)
    }
}
/// Reports whether `txt` is the URL-safe base64 encoding of `bin`.
///
/// `bin` is encoded without padding and compared to `txt` with any trailing
/// `=` characters stripped, so padded and unpadded forms of `txt` both match.
#[cfg(feature = "passkey")]
fn base64url_matches(bin: &[u8], txt: &str) -> bool {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine;

    let unpadded = txt.trim_end_matches('=');
    URL_SAFE_NO_PAD.encode(bin) == unpadded
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the backend shared by the feature-gated tests below.
    #[cfg(feature = "passkey")]
    fn example_backend() -> WebauthnPasskey {
        let origin = Url::parse("https://idp.example.com").unwrap();
        WebauthnPasskey::new("idp.example.com", "Example IdP", &origin).unwrap()
    }

    // The stub backend must be usable through the trait and fail uniformly.
    #[tokio::test]
    async fn passkey_todo_is_callable_and_returns_unimplemented() {
        let stub = PasskeyTodo;

        let registration = stub.registration_options("acct-1").await;
        assert!(matches!(registration, Err(PasskeyError::Unimplemented)));

        let authentication = stub.authentication_options("acct-1").await;
        assert!(matches!(authentication, Err(PasskeyError::Unimplemented)));
    }

    #[cfg(feature = "passkey")]
    #[tokio::test]
    async fn webauthn_passkey_constructs_with_reasonable_defaults() {
        let origin = Url::parse("https://idp.example.com").unwrap();
        let built = WebauthnPasskey::new("idp.example.com", "Example IdP", &origin);
        let _pk = built.expect("WebauthnPasskey::new with defaults");
    }

    #[cfg(feature = "passkey")]
    #[tokio::test]
    async fn start_registration_returns_non_empty_challenge() {
        let pk = example_backend();
        let opts = pk.registration_options("alice").await.unwrap();

        let challenge = opts
            .raw
            .pointer("/publicKey/challenge")
            .and_then(serde_json::Value::as_str)
            .expect("challenge string present");
        assert!(!challenge.is_empty(), "challenge should be non-empty");
        assert!(
            pk.registration_state.contains_key("alice"),
            "registration state recorded for alice"
        );
    }

    #[cfg(feature = "passkey")]
    #[tokio::test]
    async fn start_registration_is_isolated_per_user() {
        let pk = example_backend();
        for user in ["alice", "bob"] {
            pk.registration_options(user).await.unwrap();
        }

        assert!(pk.registration_state.contains_key("alice"));
        assert!(pk.registration_state.contains_key("bob"));
        assert_eq!(
            pk.registration_state.len(),
            2,
            "per-user isolation retains both states"
        );
    }

    #[cfg(feature = "passkey")]
    #[tokio::test]
    async fn registration_verify_without_start_is_rejected() {
        let pk = example_backend();
        let response = RegistrationResponse {
            id: "abc".into(),
            raw: serde_json::json!({}),
        };

        let outcome = pk.registration_verify("ghost", response).await;
        assert!(matches!(outcome, Err(PasskeyError::NoCeremony(_))));
    }

    #[cfg(feature = "passkey")]
    #[tokio::test]
    async fn authentication_options_rejects_user_with_no_credentials() {
        let pk = example_backend();

        let outcome = pk.authentication_options("never-registered").await;
        assert!(matches!(outcome, Err(PasskeyError::NoCeremony(_))));
    }

    #[cfg(feature = "passkey")]
    #[tokio::test]
    async fn account_uuid_is_deterministic_and_v4() {
        let first = WebauthnPasskey::account_uuid("alice");
        let again = WebauthnPasskey::account_uuid("alice");
        let other = WebauthnPasskey::account_uuid("bob");

        assert_eq!(first, again, "deterministic");
        assert_ne!(first, other, "per-user unique");
        assert_eq!(first.get_version_num(), 4, "RFC 4122 v4");
    }
}