oauth2-passkey 0.6.0

OAuth2 and Passkey authentication library for Rust web applications
Documentation
use crate::storage::validate_sqlite_table_schema;
use crate::userdb::DB_TABLE_USERS;
use chrono::{DateTime, Utc};
use sqlx::{Pool, Sqlite};

use crate::passkey::errors::PasskeyError;
use crate::passkey::types::{CredentialId, CredentialSearchField, PasskeyCredential};

use super::config::DB_TABLE_PASSKEY_CREDENTIALS;

// SQLite implementations
/// Creates the passkey credential table and its secondary indexes in SQLite.
///
/// Every statement uses `IF NOT EXISTS`, so running this against an already
/// initialized database leaves it unchanged. Index names are built from the
/// configured table name with each `.` replaced by `_`, so a dotted table name
/// still produces a dot-free index identifier.
///
/// # Errors
///
/// Returns [`PasskeyError::Storage`] if either the table or the index creation fails.
pub(super) async fn create_tables_sqlite(pool: &Pool<Sqlite>) -> Result<(), PasskeyError> {
    let table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
    let users = DB_TABLE_USERS.as_str();
    let index_prefix = table.replace('.', "_");

    let create_table_sql = format!(
        r#"
        CREATE TABLE IF NOT EXISTS {table} (
            sequence_number INTEGER PRIMARY KEY AUTOINCREMENT,
            credential_id TEXT NOT NULL UNIQUE,
            user_id TEXT NOT NULL,
            public_key TEXT NOT NULL,
            counter INTEGER NOT NULL DEFAULT 0,
            user_handle TEXT NOT NULL,
            user_name TEXT NOT NULL,
            user_display_name TEXT NOT NULL,
            aaguid TEXT NOT NULL,
            rp_id TEXT NOT NULL DEFAULT '',
            created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
            last_used_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (user_id) REFERENCES {users}(id) ON DELETE CASCADE
        )
        "#
    );

    // Both index statements are sent together as a single query string.
    let create_indexes_sql = format!(
        r#"
        CREATE INDEX IF NOT EXISTS idx_{index_prefix}_user_name ON {table}(user_name);
        CREATE INDEX IF NOT EXISTS idx_{index_prefix}_user_id ON {table}(user_id);
        "#
    );

    // The table must exist before its indexes, so run these strictly in order.
    for sql in [create_table_sql, create_indexes_sql] {
        sqlx::query(&sql)
            .execute(pool)
            .await
            .map_err(|e| PasskeyError::Storage(e.to_string()))?;
    }

    Ok(())
}

/// Checks that the existing SQLite passkey credential table has the columns and
/// declared types this module expects.
///
/// The comparison itself is delegated to `validate_sqlite_table_schema`, with
/// `PasskeyError::Storage` passed as the constructor for any error it produces.
pub(super) async fn validate_passkey_tables_sqlite(
    pool: &Pool<Sqlite>,
) -> Result<(), PasskeyError> {
    // (column name, declared SQLite type) — keep in sync with `create_tables_sqlite`.
    let columns = vec![
        ("sequence_number", "INTEGER"),
        ("credential_id", "TEXT"),
        ("user_id", "TEXT"),
        ("public_key", "TEXT"),
        ("counter", "INTEGER"),
        ("user_handle", "TEXT"),
        ("user_name", "TEXT"),
        ("user_display_name", "TEXT"),
        ("aaguid", "TEXT"),
        ("rp_id", "TEXT"),
        ("created_at", "TIMESTAMP"),
        ("updated_at", "TIMESTAMP"),
        ("last_used_at", "TIMESTAMP"),
    ];

    let table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
    validate_sqlite_table_schema(pool, table, &columns, PasskeyError::Storage).await
}

/// Inserts a passkey credential, replacing any existing row that conflicts with it
/// (SQLite `INSERT OR REPLACE`).
///
/// `credential_id` supplies the `credential_id` column; every other column is taken
/// from `credential`, with the signature counter converted via `as i64` before binding.
/// Under SQLite's REPLACE semantics a conflicting row is deleted and a new one inserted,
/// so a replaced credential presumably receives a fresh `sequence_number` — confirm
/// if anything depends on it.
///
/// # Errors
///
/// Returns [`PasskeyError::Storage`] if the statement fails.
pub(super) async fn store_credential_sqlite(
    pool: &Pool<Sqlite>,
    credential_id: CredentialId,
    credential: &PasskeyCredential,
) -> Result<(), PasskeyError> {
    let table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
    let sql = format!(
        r#"
        INSERT OR REPLACE INTO {table}
        (credential_id, user_id, public_key, counter, user_handle, user_name, user_display_name, aaguid, rp_id, created_at, updated_at, last_used_at)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        "#
    );

    // Bind order must match the column list in `sql`.
    sqlx::query(&sql)
        .bind(credential_id.as_str())
        .bind(&credential.user_id)
        .bind(&credential.public_key)
        .bind(credential.counter as i64)
        .bind(&credential.user.user_handle)
        .bind(&credential.user.name)
        .bind(&credential.user.display_name)
        .bind(&credential.aaguid)
        .bind(&credential.rp_id)
        .bind(&credential.created_at)
        .bind(&credential.updated_at)
        .bind(&credential.last_used_at)
        .execute(pool)
        .await
        .map(|_| ())
        .map_err(|e| PasskeyError::Storage(e.to_string()))
}

/// Looks up a single credential by its `credential_id`.
///
/// Returns `Ok(None)` when no row matches.
///
/// # Errors
///
/// Returns [`PasskeyError::Storage`] if the query or row decoding fails.
pub(super) async fn get_credential_sqlite(
    pool: &Pool<Sqlite>,
    credential_id: CredentialId,
) -> Result<Option<PasskeyCredential>, PasskeyError> {
    let table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
    let sql = format!(r#"SELECT * FROM {table} WHERE credential_id = ?"#);

    let row = sqlx::query_as::<_, PasskeyCredential>(&sql)
        .bind(credential_id.as_str())
        .fetch_optional(pool)
        .await;

    row.map_err(|e| PasskeyError::Storage(e.to_string()))
}

/// Fetches every credential whose column selected by `field` equals the field's value.
///
/// The column name comes from a fixed set chosen by matching on the enum variant;
/// only the value varies at runtime, and it is always passed as a bound parameter.
///
/// # Errors
///
/// Returns [`PasskeyError::Storage`] if the query or row decoding fails.
pub(super) async fn get_credentials_by_field_sqlite(
    pool: &Pool<Sqlite>,
    field: &CredentialSearchField,
) -> Result<Vec<PasskeyCredential>, PasskeyError> {
    let (column, value) = match field {
        CredentialSearchField::CredentialId(credential_id) => {
            ("credential_id", credential_id.as_str())
        }
        CredentialSearchField::UserId(id) => ("user_id", id.as_str()),
        CredentialSearchField::UserHandle(handle) => ("user_handle", handle.as_str()),
        CredentialSearchField::UserName(name) => ("user_name", name.as_str()),
    };

    let table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
    let sql = format!(r#"SELECT * FROM {table} WHERE {column} = ?"#);

    let rows = sqlx::query_as::<_, PasskeyCredential>(&sql)
        .bind(value)
        .fetch_all(pool)
        .await;

    rows.map_err(|e| PasskeyError::Storage(e.to_string()))
}

/// Sets the signature counter of `credential_id` to `new_counter` and bumps
/// `updated_at`, but only if the stored counter is strictly lower than `new_counter`.
///
/// The comparison and the write happen in one `UPDATE ... WHERE counter < ?`
/// statement, so no separate read precedes the write.
///
/// Returns `Ok(true)` if a row was updated. Returns `Ok(false)` if nothing matched:
/// either the credential does not exist or its stored counter is `>= new_counter`.
///
/// NOTE(review): when both the stored and the new counter are 0, this returns
/// `false`. WebAuthn treats 0/0 as "authenticator has no counter" — confirm callers
/// do not interpret that as a cloned authenticator.
///
/// # Errors
///
/// Returns [`PasskeyError::Storage`] if the statement fails.
pub(super) async fn atomic_update_credential_counter_sqlite(
    pool: &Pool<Sqlite>,
    credential_id: CredentialId,
    new_counter: u32,
) -> Result<bool, PasskeyError> {
    let table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
    // Lossless widening: every u32 fits in an i64.
    let counter = i64::from(new_counter);

    let sql = format!(
        r#"
        UPDATE {table}
        SET counter = ?, updated_at = CURRENT_TIMESTAMP
        WHERE credential_id = ? AND counter < ?
        "#
    );

    let outcome = sqlx::query(&sql)
        .bind(counter)
        .bind(credential_id.as_str())
        .bind(counter)
        .execute(pool)
        .await
        .map_err(|e| PasskeyError::Storage(e.to_string()))?;

    Ok(outcome.rows_affected() != 0)
}

/// Deletes every credential whose column selected by `field` equals the field's value.
///
/// Deleting zero rows is not an error. The column name comes from a fixed set
/// chosen by matching on the enum variant; the value is always a bound parameter.
///
/// # Errors
///
/// Returns [`PasskeyError::Storage`] if the statement fails.
pub(super) async fn delete_credential_by_field_sqlite(
    pool: &Pool<Sqlite>,
    field: &CredentialSearchField,
) -> Result<(), PasskeyError> {
    let (column, value) = match field {
        CredentialSearchField::CredentialId(credential_id) => {
            ("credential_id", credential_id.as_str())
        }
        CredentialSearchField::UserId(id) => ("user_id", id.as_str()),
        CredentialSearchField::UserHandle(handle) => ("user_handle", handle.as_str()),
        CredentialSearchField::UserName(name) => ("user_name", name.as_str()),
    };

    let table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
    let sql = format!(r#"DELETE FROM {table} WHERE {column} = ?"#);

    sqlx::query(&sql)
        .bind(value)
        .execute(pool)
        .await
        .map(|_| ())
        .map_err(|e| PasskeyError::Storage(e.to_string()))
}

/// Updates the stored user name and display name for one credential and refreshes
/// its `updated_at` timestamp.
///
/// No error is returned if no row has `credential_id` (zero affected rows is accepted).
///
/// # Errors
///
/// Returns [`PasskeyError::Storage`] if the statement fails.
pub(super) async fn update_credential_user_details_sqlite(
    pool: &Pool<Sqlite>,
    credential_id: CredentialId,
    name: &str,
    display_name: &str,
) -> Result<(), PasskeyError> {
    let passkey_table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();

    // Positional `?` placeholders, matching every other SQLite query in this module.
    // `updated_at` is refreshed because this statement modifies row data, just as the
    // counter and last-used UPDATE helpers refresh it.
    sqlx::query(&format!(
        r#"
        UPDATE {passkey_table}
        SET user_name = ?, user_display_name = ?, updated_at = CURRENT_TIMESTAMP
        WHERE credential_id = ?
        "#
    ))
    .bind(name)
    .bind(display_name)
    .bind(credential_id.as_str())
    .execute(pool)
    .await
    .map_err(|e| PasskeyError::Storage(e.to_string()))?;

    Ok(())
}

/// Records when a credential was last used and refreshes its `updated_at` timestamp.
///
/// `last_used_at` is bound exactly as supplied, while `updated_at` comes from SQLite's
/// `CURRENT_TIMESTAMP`. No error is returned if `credential_id` matches no row.
///
/// NOTE(review): `CURRENT_TIMESTAMP` yields `YYYY-MM-DD HH:MM:SS` text, which may differ
/// from how sqlx encodes a bound `DateTime<Utc>` — confirm nothing compares these
/// columns as strings.
///
/// # Errors
///
/// Returns [`PasskeyError::Storage`] if the statement fails.
pub(super) async fn update_credential_last_used_at_sqlite(
    pool: &Pool<Sqlite>,
    credential_id: CredentialId,
    last_used_at: DateTime<Utc>,
) -> Result<(), PasskeyError> {
    let table = DB_TABLE_PASSKEY_CREDENTIALS.as_str();
    let sql = format!(
        r#"
        UPDATE {table}
        SET last_used_at = ?, updated_at = CURRENT_TIMESTAMP
        WHERE credential_id = ?
        "#
    );

    sqlx::query(&sql)
        .bind(last_used_at)
        .bind(credential_id.as_str())
        .execute(pool)
        .await
        .map(|_| ())
        .map_err(|e| PasskeyError::Storage(e.to_string()))
}