Skip to main content

rustauth_passkey/
store.rs

1use rustauth_core::context::AuthContext;
2use rustauth_core::crypto::random::generate_random_string;
3use rustauth_core::db::{
4    DbAdapter, DbRecord, DbSchema, DbValue, Delete, FindMany, FindOne, SchemaTable, Update,
5};
6use rustauth_core::error::RustAuthError;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use time::OffsetDateTime;
10
/// Model name handed to `SchemaTable::new` to resolve the passkey table.
const PASSKEY_MODEL: &str = "passkey";
12
/// A passkey row as parsed from the database by [`PasskeyStore`].
///
/// Serialized with camelCase keys (except `credentialID`, which is renamed
/// explicitly). Optional fields that are `None` are left out of the
/// serialized output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Passkey {
    /// Row identifier.
    pub id: String,
    /// Optional display name; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Stored public key, used to rebuild credential state for legacy rows.
    pub public_key: String,
    /// Identifier of the user that owns this passkey.
    pub user_id: String,
    /// Credential identifier; serialized under the key `credentialID`.
    #[serde(rename = "credentialID")]
    pub credential_id: String,
    /// Stored counter value; `update_after_authentication` only updates a row
    /// whose stored counter equals the expected one.
    pub counter: i64,
    /// Device type string as stored.
    // NOTE(review): the set of allowed values is not visible in this file.
    pub device_type: String,
    /// Backed-up flag as stored.
    pub backed_up: bool,
    /// Optional transports string; omitted from serialized output when `None`.
    // NOTE(review): encoding of multiple transports is not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub transports: Option<String>,
    /// Creation time, when the row carries one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<OffsetDateTime>,
    /// Optional AAGUID string; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aaguid: Option<String>,
    /// Full stored WebAuthn credential JSON, or `Value::Null` for rows that
    /// have none. Never serialized or deserialized (`serde(skip)`).
    #[serde(skip)]
    pub webauthn_credential: Value,
}
35
/// Data-access helper for passkey rows.
///
/// Borrows a database adapter for `'a` and owns the schema used to map
/// logical field names when building queries and parsing records.
#[derive(Clone)]
pub struct PasskeyStore<'a> {
    /// Adapter every query in this store is executed against.
    adapter: &'a dyn DbAdapter,
    /// Schema from which the passkey `SchemaTable` is resolved.
    schema: DbSchema,
}
41
42impl Passkey {
43    /// Value for `excludeCredentials` during registration (full credential or legacy id).
44    pub(crate) fn registration_exclude_value(&self) -> Value {
45        if !self.webauthn_credential.is_null() {
46            self.webauthn_credential.clone()
47        } else {
48            Value::String(self.credential_id.clone())
49        }
50    }
51
52    /// Stored WebAuthn credential state for authentication ceremonies.
53    ///
54    /// Legacy rows without `webauthn_credential` JSON are rebuilt from the
55    /// stored COSE public key and passkey metadata.
56    pub(crate) fn authentication_credential_value(&self) -> Result<Option<Value>, RustAuthError> {
57        if !self.webauthn_credential.is_null() {
58            return Ok(Some(self.webauthn_credential.clone()));
59        }
60        crate::webauthn::legacy_passkey_credential_value(
61            &self.credential_id,
62            &self.public_key,
63            self.counter,
64            &self.device_type,
65            self.backed_up,
66            self.transports.as_deref(),
67        )
68        .map(Some)
69    }
70}
71
72impl<'a> PasskeyStore<'a> {
73    pub fn with_schema(adapter: &'a dyn DbAdapter, schema: DbSchema) -> Self {
74        Self { adapter, schema }
75    }
76
77    pub fn from_context(context: &'a AuthContext) -> Result<Self, RustAuthError> {
78        Ok(Self::with_schema(
79            context.adapter_ref()?,
80            context.db_schema.clone(),
81        ))
82    }
83
84    /// Convenience alias for [`Self::from_context`].
85    pub fn new(context: &'a AuthContext) -> Result<Self, RustAuthError> {
86        Self::from_context(context)
87    }
88
89    fn passkeys(&self) -> Result<SchemaTable<'_>, RustAuthError> {
90        SchemaTable::new(&self.schema, PASSKEY_MODEL)
91    }
92
93    fn parse_passkey(&self, record: DbRecord) -> Result<Passkey, RustAuthError> {
94        passkey_from_record(self.passkeys()?.map_record(record)?)
95    }
96
97    pub async fn list_by_user(&self, user_id: &str) -> Result<Vec<Passkey>, RustAuthError> {
98        let passkeys = self.passkeys()?;
99        self.adapter
100            .find_many(
101                FindMany::new(passkeys.model()).where_clause(
102                    passkeys.where_eq("user_id", DbValue::String(user_id.to_owned()))?,
103                ),
104            )
105            .await?
106            .into_iter()
107            .map(|record| self.parse_passkey(record))
108            .collect()
109    }
110
111    pub async fn find_by_id(&self, id: &str) -> Result<Option<Passkey>, RustAuthError> {
112        let passkeys = self.passkeys()?;
113        self.adapter
114            .find_one(
115                FindOne::new(passkeys.model())
116                    .where_clause(passkeys.where_eq("id", DbValue::String(id.to_owned()))?),
117            )
118            .await?
119            .map(|record| self.parse_passkey(record))
120            .transpose()
121    }
122
123    pub async fn find_by_credential_id(
124        &self,
125        credential_id: &str,
126    ) -> Result<Option<Passkey>, RustAuthError> {
127        let passkeys = self.passkeys()?;
128        self.adapter
129            .find_one(FindOne::new(passkeys.model()).where_clause(
130                passkeys.where_eq("credential_id", DbValue::String(credential_id.to_owned()))?,
131            ))
132            .await?
133            .map(|record| self.parse_passkey(record))
134            .transpose()
135    }
136
137    pub async fn create(
138        &self,
139        user_id: &str,
140        name: Option<String>,
141        credential: crate::webauthn::VerifiedPasskeyCredential,
142    ) -> Result<Passkey, RustAuthError> {
143        let passkeys = self.passkeys()?;
144        let now = OffsetDateTime::now_utc();
145        let record = self
146            .adapter
147            .create(
148                passkeys
149                    .create()
150                    .data("id", DbValue::String(generate_random_string(32)))
151                    .data("name", optional_string(name))
152                    .data("public_key", DbValue::String(credential.public_key))
153                    .data("user_id", DbValue::String(user_id.to_owned()))
154                    .data("credential_id", DbValue::String(credential.credential_id))
155                    .data("counter", DbValue::Number(i64::from(credential.counter)))
156                    .data("device_type", DbValue::String(credential.device_type))
157                    .data("backed_up", DbValue::Boolean(credential.backed_up))
158                    .data("transports", optional_string(credential.transports))
159                    .data("created_at", DbValue::Timestamp(now))
160                    .data("aaguid", optional_string(credential.aaguid))
161                    .data("webauthn_credential", DbValue::Json(credential.credential))
162                    .force_allow_id(),
163            )
164            .await?;
165        self.parse_passkey(record)
166    }
167
168    pub async fn update_name_for_user(
169        &self,
170        id: &str,
171        user_id: &str,
172        name: String,
173    ) -> Result<Option<Passkey>, RustAuthError> {
174        let passkeys = self.passkeys()?;
175        self.adapter
176            .update(
177                Update::new(passkeys.model())
178                    .where_clause(passkeys.where_eq("id", DbValue::String(id.to_owned()))?)
179                    .where_clause(
180                        passkeys.where_eq("user_id", DbValue::String(user_id.to_owned()))?,
181                    )
182                    .data("name", DbValue::String(name)),
183            )
184            .await?
185            .map(|record| self.parse_passkey(record))
186            .transpose()
187    }
188
189    pub async fn update_after_authentication(
190        &self,
191        id: &str,
192        expected_counter: i64,
193        verification: crate::webauthn::VerifiedAuthentication,
194    ) -> Result<Option<Passkey>, RustAuthError> {
195        let passkeys = self.passkeys()?;
196        let mut update = Update::new(passkeys.model())
197            .where_clause(passkeys.where_eq("id", DbValue::String(id.to_owned()))?)
198            .where_clause(passkeys.where_eq("counter", DbValue::Number(expected_counter))?)
199            .data(
200                "counter",
201                DbValue::Number(i64::from(verification.new_counter)),
202            );
203        if let Some(credential) = verification.credential {
204            update = update.data("webauthn_credential", DbValue::Json(credential));
205        }
206        self.adapter
207            .update(update)
208            .await?
209            .map(|record| self.parse_passkey(record))
210            .transpose()
211    }
212
213    pub async fn delete_for_user(&self, id: &str, user_id: &str) -> Result<bool, RustAuthError> {
214        let passkeys = self.passkeys()?;
215        let Some(passkey) = self.find_by_id(id).await? else {
216            return Ok(false);
217        };
218        if passkey.user_id != user_id {
219            return Ok(false);
220        }
221        self.adapter
222            .delete(
223                Delete::new(passkeys.model())
224                    .where_clause(passkeys.where_eq("id", DbValue::String(id.to_owned()))?),
225            )
226            .await?;
227        Ok(true)
228    }
229}
230
231fn optional_string(value: Option<String>) -> DbValue {
232    value.map(DbValue::String).unwrap_or(DbValue::Null)
233}
234
235fn passkey_from_record(record: DbRecord) -> Result<Passkey, RustAuthError> {
236    Ok(Passkey {
237        id: required_string(&record, "id")?.to_owned(),
238        name: optional_string_field(&record, "name")?,
239        public_key: required_string(&record, "public_key")?.to_owned(),
240        user_id: required_string(&record, "user_id")?.to_owned(),
241        credential_id: required_string(&record, "credential_id")?.to_owned(),
242        counter: required_number(&record, "counter")?,
243        device_type: required_string(&record, "device_type")?.to_owned(),
244        backed_up: required_bool(&record, "backed_up")?,
245        transports: optional_string_field(&record, "transports")?,
246        created_at: optional_timestamp(&record, "created_at")?,
247        aaguid: optional_string_field(&record, "aaguid")?,
248        webauthn_credential: match record.get("webauthn_credential") {
249            Some(DbValue::Json(value)) => value.clone(),
250            Some(DbValue::Null) | None => Value::Null,
251            Some(_) => return Err(invalid_field("webauthn_credential", "json")),
252        },
253    })
254}
255
256fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
257    match record.get(field) {
258        Some(DbValue::String(value)) => Ok(value),
259        Some(_) => Err(invalid_field(field, "string")),
260        None => Err(missing_field(field)),
261    }
262}
263
264fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, RustAuthError> {
265    match record.get(field) {
266        Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
267        Some(DbValue::Null) | None => Ok(None),
268        Some(_) => Err(invalid_field(field, "string or null")),
269    }
270}
271
272fn required_number(record: &DbRecord, field: &str) -> Result<i64, RustAuthError> {
273    match record.get(field) {
274        Some(DbValue::Number(value)) => Ok(*value),
275        Some(_) => Err(invalid_field(field, "number")),
276        None => Err(missing_field(field)),
277    }
278}
279
280fn required_bool(record: &DbRecord, field: &str) -> Result<bool, RustAuthError> {
281    match record.get(field) {
282        Some(DbValue::Boolean(value)) => Ok(*value),
283        Some(_) => Err(invalid_field(field, "boolean")),
284        None => Err(missing_field(field)),
285    }
286}
287
288fn optional_timestamp(
289    record: &DbRecord,
290    field: &str,
291) -> Result<Option<OffsetDateTime>, RustAuthError> {
292    match record.get(field) {
293        Some(DbValue::Timestamp(value)) => Ok(Some(*value)),
294        Some(DbValue::Null) | None => Ok(None),
295        Some(_) => Err(invalid_field(field, "timestamp or null")),
296    }
297}
298
299fn missing_field(field: &str) -> RustAuthError {
300    RustAuthError::Adapter(format!("passkey record is missing `{field}`"))
301}
302
303fn invalid_field(field: &str, expected: &str) -> RustAuthError {
304    RustAuthError::Adapter(format!("passkey record field `{field}` must be {expected}"))
305}