oauth2_passkey/passkey/storage/postgres.rs

1use crate::storage::validate_postgres_table_schema;
2use crate::userdb::DB_TABLE_USERS;
3use chrono::{DateTime, Utc};
4use sqlx::{Pool, Postgres};
5
6use crate::passkey::errors::PasskeyError;
7use crate::passkey::types::{
8    CredentialSearchField, PasskeyCredential, PublicKeyCredentialUserEntity,
9};
10
11use super::config::DB_TABLE_PASSKEY_CREDENTIALS;
12
13// PostgreSQL implementations
14pub(super) async fn create_tables_postgres(pool: &Pool<Postgres>) -> Result<(), PasskeyError> {
15    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
16    let users_table = DB_TABLE_USERS.as_str();
17
18    sqlx::query(&format!(
19        r#"
20        CREATE TABLE IF NOT EXISTS {passkey_table} (
21            credential_id TEXT PRIMARY KEY NOT NULL,
22            user_id TEXT NOT NULL REFERENCES {users_table}(id),
23            public_key TEXT NOT NULL,
24            counter INTEGER NOT NULL DEFAULT 0,
25            user_handle TEXT NOT NULL,
26            user_name TEXT NOT NULL,
27            user_display_name TEXT NOT NULL,
28            aaguid TEXT NOT NULL,
29            created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
30            updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
31            last_used_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
32            FOREIGN KEY (user_id) REFERENCES {users_table}(id)
33        )
34        "#
35    ))
36    .execute(pool)
37    .await
38    .map_err(|e| PasskeyError::Storage(e.to_string()))?;
39
40    sqlx::query(&format!(
41        r#"
42        CREATE INDEX IF NOT EXISTS idx_{}_user_name ON {}(user_name);
43        "#,
44        passkey_table.replace(".", "_"),
45        passkey_table
46    ))
47    .execute(pool)
48    .await
49    .map_err(|e| PasskeyError::Storage(e.to_string()))?;
50
51    sqlx::query(&format!(
52        r#"
53        CREATE INDEX IF NOT EXISTS idx_{}_user_id ON {}(user_id);
54        "#,
55        passkey_table.replace(".", "_"),
56        passkey_table
57    ))
58    .execute(pool)
59    .await
60    .map_err(|e| PasskeyError::Storage(e.to_string()))?;
61
62    Ok(())
63}
64
65/// Validates that the Passkey credential table schema matches what we expect
66pub(super) async fn validate_passkey_tables_postgres(
67    pool: &Pool<Postgres>,
68) -> Result<(), PasskeyError> {
69    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
70
71    // Define expected schema (column name, data type)
72    let expected_columns = [
73        ("credential_id", "text"),
74        ("user_id", "text"),
75        ("public_key", "text"),
76        ("counter", "integer"),
77        ("user_handle", "text"),
78        ("user_name", "text"),
79        ("user_display_name", "text"),
80        ("aaguid", "text"),
81        ("created_at", "timestamp with time zone"),
82        ("updated_at", "timestamp with time zone"),
83        ("last_used_at", "timestamp with time zone"),
84    ];
85
86    validate_postgres_table_schema(
87        pool,
88        passkey_table,
89        &expected_columns,
90        PasskeyError::Storage,
91    )
92    .await
93}
94
95pub(super) async fn store_credential_postgres(
96    pool: &Pool<Postgres>,
97    credential_id: &str,
98    credential: &PasskeyCredential,
99) -> Result<(), PasskeyError> {
100    let counter_i32 = credential.counter as i32;
101    let public_key = &credential.public_key;
102    let user_id = &credential.user_id;
103    let user_handle = &credential.user.user_handle;
104    let user_name = &credential.user.name;
105    let user_display_name = &credential.user.display_name;
106    let aaguid = &credential.aaguid;
107    let created_at = &credential.created_at;
108    let updated_at = &credential.updated_at;
109    let last_used_at = &credential.last_used_at;
110    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
111
112    sqlx::query_as::<_, (i32,)>(&format!(
113        r#"
114        INSERT INTO {passkey_table}
115        (credential_id, user_id, public_key, counter, user_handle, user_name, user_display_name, aaguid, created_at, updated_at, last_used_at)
116        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
117        ON CONFLICT (credential_id) DO UPDATE
118        SET user_id = $2, public_key = $3, counter = $4, user_handle = $5, user_name = $6, user_display_name = $7, aaguid = $8, updated_at = CURRENT_TIMESTAMP, last_used_at = CURRENT_TIMESTAMP
119        RETURNING 1
120        "#
121    ))
122    .bind(credential_id)
123    .bind(user_id)
124    .bind(public_key)
125    .bind(counter_i32)
126    .bind(user_handle)
127    .bind(user_name)
128    .bind(user_display_name)
129    .bind(aaguid)
130    .bind(created_at)
131    .bind(updated_at)
132    .bind(last_used_at)
133    .fetch_optional(pool)
134    .await
135    .map_err(|e| PasskeyError::Storage(e.to_string()))?;
136
137    Ok(())
138}
139
140pub(super) async fn get_credential_postgres(
141    pool: &Pool<Postgres>,
142    credential_id: &str,
143) -> Result<Option<PasskeyCredential>, PasskeyError> {
144    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
145
146    sqlx::query_as::<_, PasskeyCredential>(&format!(
147        r#"SELECT * FROM {passkey_table} WHERE credential_id = $1"#
148    ))
149    .bind(credential_id)
150    .fetch_optional(pool)
151    .await
152    .map_err(|e| PasskeyError::Storage(e.to_string()))
153}
154
155pub(super) async fn get_credentials_by_field_postgres(
156    pool: &Pool<Postgres>,
157    field: &CredentialSearchField,
158) -> Result<Vec<PasskeyCredential>, PasskeyError> {
159    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
160    let (query, value) = match field {
161        CredentialSearchField::CredentialId(credential_id) => (
162            &format!(r#"SELECT * FROM {passkey_table} WHERE credential_id = $1"#),
163            credential_id.as_str(),
164        ),
165        CredentialSearchField::UserId(id) => (
166            &format!(r#"SELECT * FROM {passkey_table} WHERE user_id = $1"#),
167            id.as_str(),
168        ),
169        CredentialSearchField::UserHandle(handle) => (
170            &format!(r#"SELECT * FROM {passkey_table} WHERE user_handle = $1"#),
171            handle.as_str(),
172        ),
173        CredentialSearchField::UserName(name) => (
174            &format!(r#"SELECT * FROM {passkey_table} WHERE user_name = $1"#),
175            name.as_str(),
176        ),
177    };
178
179    sqlx::query_as::<_, PasskeyCredential>(query)
180        .bind(value)
181        .fetch_all(pool)
182        .await
183        .map_err(|e| PasskeyError::Storage(e.to_string()))
184}
185
186pub(super) async fn update_credential_counter_postgres(
187    pool: &Pool<Postgres>,
188    credential_id: &str,
189    counter: u32,
190) -> Result<(), PasskeyError> {
191    let counter_i32 = counter as i32;
192    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
193
194    sqlx::query_as::<_, (i32,)>(&format!(
195        r#"
196        UPDATE {passkey_table}
197        SET counter = $1, updated_at = CURRENT_TIMESTAMP
198        WHERE credential_id = $2
199        RETURNING 1
200        "#
201    ))
202    .bind(counter_i32)
203    .bind(credential_id)
204    .fetch_optional(pool)
205    .await
206    .map_err(|e| PasskeyError::Storage(e.to_string()))?;
207
208    Ok(())
209}
210
211pub(super) async fn delete_credential_by_field_postgres(
212    pool: &Pool<Postgres>,
213    field: &CredentialSearchField,
214) -> Result<(), PasskeyError> {
215    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
216    let (query, value) = match field {
217        CredentialSearchField::CredentialId(credential_id) => (
218            &format!(r#"DELETE FROM {passkey_table} WHERE credential_id = $1"#),
219            credential_id.as_str(),
220        ),
221        CredentialSearchField::UserId(id) => (
222            &format!(r#"DELETE FROM {passkey_table} WHERE user_id = $1"#),
223            id.as_str(),
224        ),
225        CredentialSearchField::UserHandle(handle) => (
226            &format!(r#"DELETE FROM {passkey_table} WHERE user_handle = $1"#),
227            handle.as_str(),
228        ),
229        CredentialSearchField::UserName(name) => (
230            &format!(r#"DELETE FROM {passkey_table} WHERE user_name = $1"#),
231            name.as_str(),
232        ),
233    };
234
235    sqlx::query(query)
236        .bind(value)
237        .execute(pool)
238        .await
239        .map_err(|e| PasskeyError::Storage(e.to_string()))?;
240
241    Ok(())
242}
243
244pub(super) async fn update_credential_user_details_postgres(
245    pool: &Pool<Postgres>,
246    credential_id: &str,
247    name: &str,
248    display_name: &str,
249) -> Result<(), PasskeyError> {
250    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
251
252    sqlx::query(&format!(
253        r#"UPDATE {passkey_table} SET user_name = $1, user_display_name = $2 WHERE credential_id = $3"#
254    ))
255    .bind(name)
256    .bind(display_name)
257    .bind(credential_id)
258    .execute(pool)
259    .await
260    .map_err(|e| PasskeyError::Storage(e.to_string()))?;
261
262    Ok(())
263}
264
265use sqlx::{FromRow, Row, postgres::PgRow, sqlite::SqliteRow};
266
267// Implement FromRow for PasskeyCredential to handle the flattened database structure for SQLite
268impl<'r> FromRow<'r, SqliteRow> for PasskeyCredential {
269    fn from_row(row: &'r SqliteRow) -> Result<Self, sqlx::Error> {
270        let credential_id: String = row.try_get("credential_id")?;
271        let user_id: String = row.try_get("user_id")?;
272        let public_key: String = row.try_get("public_key")?;
273        let counter: i64 = row.try_get("counter")?;
274        let user_handle: String = row.try_get("user_handle")?;
275        let user_name: String = row.try_get("user_name")?;
276        let user_display_name: String = row.try_get("user_display_name")?;
277        let aaguid: String = row.try_get("aaguid")?;
278        let created_at: DateTime<Utc> = row.try_get("created_at")?;
279        let updated_at: DateTime<Utc> = row.try_get("updated_at")?;
280        let last_used_at: DateTime<Utc> = row.try_get("last_used_at")?;
281
282        Ok(PasskeyCredential {
283            credential_id,
284            user_id,
285            public_key,
286            counter: counter as u32,
287            user: PublicKeyCredentialUserEntity {
288                user_handle,
289                name: user_name,
290                display_name: user_display_name,
291            },
292            aaguid,
293            created_at,
294            updated_at,
295            last_used_at,
296        })
297    }
298}
299
300// Implement FromRow for PasskeyCredential to handle the flattened database structure for PostgreSQL
301impl<'r> FromRow<'r, PgRow> for PasskeyCredential {
302    fn from_row(row: &'r PgRow) -> Result<Self, sqlx::Error> {
303        let credential_id: String = row.try_get("credential_id")?;
304        let user_id: String = row.try_get("user_id")?;
305        let public_key: String = row.try_get("public_key")?;
306        let counter: i32 = row.try_get("counter")?;
307        let user_handle: String = row.try_get("user_handle")?;
308        let user_name: String = row.try_get("user_name")?;
309        let user_display_name: String = row.try_get("user_display_name")?;
310        let aaguid: String = row.try_get("aaguid")?;
311        let created_at: DateTime<Utc> = row.try_get("created_at")?;
312        let updated_at: DateTime<Utc> = row.try_get("updated_at")?;
313        let last_used_at: DateTime<Utc> = row.try_get("last_used_at")?;
314
315        Ok(PasskeyCredential {
316            credential_id,
317            user_id,
318            public_key,
319            counter: counter as u32,
320            user: PublicKeyCredentialUserEntity {
321                user_handle,
322                name: user_name,
323                display_name: user_display_name,
324            },
325            aaguid,
326            created_at,
327            updated_at,
328            last_used_at,
329        })
330    }
331}
332
333pub(super) async fn update_credential_last_used_at_postgres(
334    pool: &Pool<Postgres>,
335    credential_id: &str,
336    last_used_at: DateTime<Utc>,
337) -> Result<(), PasskeyError> {
338    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
339
340    sqlx::query(&format!(
341        r#"UPDATE {passkey_table} SET last_used_at = $1 WHERE credential_id = $2"#
342    ))
343    .bind(last_used_at)
344    .bind(credential_id)
345    .execute(pool)
346    .await
347    .map_err(|e| PasskeyError::Storage(e.to_string()))?;
348
349    Ok(())
350}