sql-audit 0.1.0

A companion library to `sql-audit-cli` implementing all of the same functionality, plus useful runtime operations.
Documentation
//! The companion library to `sql-audit-cli`. You can also use it on its own if you'd like to
//! generate the audit infrastructure programmatically.
//!
//! ## Limitations
//! 1. Currently this lib only knows how to use `sqlx`, so you must use it too.
//!
//! ## Tour
//! - [`generate_audit`] is the programmatic equivalent of running `sql-audit-cli`.
//! - [`query_audit`] lets you query the audit table. It currently only supports fetching all
//!     records for a single table.
//! - [`set_local_app_user`] sets the value to fill into `app_user` in the audit table for the
//!     duration of the current transaction.
//! - [`set_local_request_id`] sets the value to fill into `request_id` in the audit table for the
//!     duration of the current transaction.

#![forbid(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::pedantic)]
#![deny(clippy::cargo)]

use color_eyre::eyre::WrapErr;
use color_eyre::Result;
use futures::future::try_join_all;
use sqlx::{PgPool, Row};

/// Name of the Postgres schema that holds the audit table.
const AUDIT_SCHEMA_NAME: &str = "sql_audit";
/// Unqualified name of the audit table inside the audit schema.
const TABLE_NAME: &str = "audit";

/// Name of the Postgres setting that the audit trigger reads to fill the `app_user` column of the
/// audit table. [`set_local_app_user`] writes this setting.
pub const APP_USER_SETTING: &str = "sql_audit.app_user";
/// Name of the Postgres setting that the audit trigger reads to fill the `request_id` column of
/// the audit table. [`set_local_request_id`] writes this setting.
pub const REQUEST_ID_SETTING: &str = "sql_audit.request_id";

/// Set up auditing for the database that `pool` connects to.
///
/// The steps run in order, each one finishing before the next starts:
/// 1. Create the `sql_audit` schema and its `audit` table, unless the schema already exists.
/// 2. Create (or replace) the `sql_audit_trigger` function that every generated `TRIGGER` calls.
/// 3. Create a trigger on every table in the `public` schema that is not listed in
///    `exclude_tables`.
///
/// ### Limitations
/// The trigger function reads a `pk` column from each audited row, so triggers on tables that do
/// not have a `pk` column will fail.
///
/// ### Errors
/// 1. Failing to connect through the provided `pool` at any point.
/// 2. Any error reported by Postgres (e.g. the user doesn't have the right permissions).
pub async fn generate_audit(pool: &sqlx::PgPool, exclude_tables: Vec<String>) -> Result<()> {
    // Schema-qualified name of the audit table, shared by the setup steps below.
    let audit_table = [AUDIT_SCHEMA_NAME, TABLE_NAME].join(".");

    enforce_audit_table_exists(pool, &audit_table).await?;
    create_change_trigger_function(pool, &audit_table).await?;
    create_triggers(pool, exclude_tables).await?;
    Ok(())
}

/// Fetch every audit record whose `table_name` column equals `table_name`.
///
/// The table name is sent as a bind parameter; only the crate's own schema and table constants
/// are interpolated into the SQL text.
///
/// ## Errors
/// Fails if the query fails, usually due to connection issues or because the audit table has not
/// been generated for this database.
pub async fn query_audit(pool: &sqlx::PgPool, table_name: &str) -> Result<Vec<AuditRecord>> {
    let sql = format!(
        "SELECT * FROM {} WHERE table_name = $1",
        [AUDIT_SCHEMA_NAME, TABLE_NAME].join(".")
    );
    let fetched = sqlx::query_as::<_, AuditRecord>(&sql)
        .bind(table_name)
        .fetch_all(pool)
        .await;
    fetched.wrap_err("Could not fetch audit data, did you generate audit for this database?")
}

/// Set a value to be recorded in the `app_user` column of the audit table.
///
/// This calls Postgres' `set_config` with `is_local = true` for [`APP_USER_SETTING`], so the value
/// is only valid for the current transaction and does nothing if there is no current transaction.
///
/// NOTE(review): the statement is executed directly against `pool`. It presumably runs on
/// whichever connection the pool hands out, outside any transaction the caller has open on
/// another connection. Confirm that callers can actually observe the setting.
///
/// ## Errors
/// Fails if the query fails, usually due to connection issues.
pub async fn set_local_app_user(user: &str, pool: &PgPool) -> Result<()> {
    // The setting name is a crate constant, so interpolating it is safe. The caller-supplied value
    // is the statement's only placeholder (`$1`), so exactly one value must be bound.
    sqlx::query(&format!(
        "SELECT set_config('{}', $1, true)",
        APP_USER_SETTING
    ))
    .bind(user)
    .execute(pool)
    .await?;
    Ok(())
}

/// Store `request_id` so the audit trigger can record it in the `request_id` column of the audit
/// table.
///
/// The value is written to [`REQUEST_ID_SETTING`] through `set_config` with `is_local = true`, so
/// it is only valid for the current transaction and does nothing if there is no current
/// transaction.
///
/// NOTE(review): the statement is executed directly against `pool`. It presumably runs on
/// whichever connection the pool hands out, outside any transaction the caller has open on
/// another connection. Confirm that callers can actually observe the setting.
///
/// ## Errors
/// Fails if the query fails, usually due to connection issues.
pub async fn set_local_request_id(request_id: &str, pool: &PgPool) -> Result<()> {
    // Only the constant setting name is interpolated; the value itself travels as `$1`.
    let statement = format!("SELECT set_config('{}', $1, true)", REQUEST_ID_SETTING);
    sqlx::query(&statement)
        .bind(request_id)
        .execute(pool)
        .await?;
    Ok(())
}

/// A structure representing a single row of the `audit` table, as returned by [`query_audit`].
///
/// The table's `ts` timestamp column has no matching field here, so it is not exposed.
#[derive(sqlx::FromRow, Debug, Eq, PartialEq)]
pub struct AuditRecord {
    /// Auto-incrementing surrogate key of the audit row itself (a `serial` column).
    pub id: i32,
    /// Name of the audited table that the changed row belongs to.
    pub table_name: String,
    /// The `pk` column of the changed row (taken from OLD for UPDATE/DELETE and NEW for INSERT).
    pub pk: i32,
    /// The operation that fired the trigger: one of "INSERT", "UPDATE", "DELETE".
    pub operation: String,
    /// The database user recorded for the change, filled in by the column's `current_user`
    /// default.
    ///
    /// NOTE(review): the row is inserted from a `SECURITY DEFINER` trigger function, so
    /// `current_user` may resolve to the function's owner rather than the user who issued the
    /// statement. Confirm against a live database.
    pub db_user: String,
    /// Semantically the user of the application that made the change. Realistically whatever was
    /// set using [`set_local_app_user`] during the transaction, or `None` if nothing was set.
    pub app_user: Option<String>,
    /// Semantically the unique identifier of the request that made the change. Realistically
    /// whatever was set using [`set_local_request_id`] during the transaction, or `None` if
    /// nothing was set.
    pub request_id: Option<String>,
    /// The JSON representation of the row before the operation took place. Only populated for
    /// UPDATE and DELETE.
    pub old_val: Option<serde_json::Value>,
    /// The JSON representation of the row after the operation took place. Only populated for
    /// INSERT and UPDATE.
    pub new_val: Option<serde_json::Value>,
}

/// Create the audit schema and the audit table when the schema is not there yet.
///
/// `table_name` is the schema-qualified name used in the `CREATE TABLE` statement. Fails if any
/// of the statements fails.
async fn enforce_audit_table_exists(pool: &sqlx::PgPool, table_name: &str) -> Result<()> {
    let existing_schema =
        sqlx::query("SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1;")
            .bind(AUDIT_SCHEMA_NAME)
            .fetch_optional(pool)
            .await?;

    // Only the schema is checked: when it exists, the table is assumed to exist as well and
    // nothing is created.
    if existing_schema.is_some() {
        return Ok(());
    }

    let create_schema = format!("CREATE SCHEMA {};", AUDIT_SCHEMA_NAME);
    sqlx::query(&create_schema).execute(pool).await?;

    // `ts` and `db_user` are never written by the trigger; they rely on these column defaults.
    let create_table = format!(
        r##"
            CREATE TABLE {} (
                id serial,
                ts timestamp DEFAULT now(),
                table_name text,
                pk integer,
                operation text,
                db_user text DEFAULT current_user,
                app_user text,
                request_id text,
                new_val jsonb,
                old_val jsonb
            );
        "##,
        table_name
    );
    sqlx::query(&create_table).execute(pool).await?;
    Ok(())
}

/// Create (or replace) the `sql_audit_trigger` function that all generated triggers execute.
///
/// For each row-level INSERT, UPDATE or DELETE the function inserts one row into `audit_table`
/// containing the table name, the row's `pk`, the operation, the current values of the
/// [`APP_USER_SETTING`] and [`REQUEST_ID_SETTING`] settings, and the JSON form of the new and/or
/// old row. The settings are read with `missing_ok`, so an unset setting does not raise an error.
///
/// Fails if the statement fails.
async fn create_change_trigger_function(pool: &sqlx::PgPool, audit_table: &str) -> Result<()> {
    // Every branch takes the setting names from the crate constants so the trigger can never
    // drift from what `set_local_app_user` / `set_local_request_id` write.
    sqlx::query(&format!(
        r##"
    CREATE OR REPLACE FUNCTION sql_audit_trigger() RETURNS trigger AS $$
    BEGIN
        IF TG_OP = 'INSERT'
        THEN
            INSERT INTO {audit_table} (table_name, pk, operation, app_user, request_id, new_val)
            VALUES (
                TG_RELNAME, NEW.pk, TG_OP, current_setting('{app_user}', 't'),
                current_setting('{request_id}', 't'), row_to_json(NEW)
            );
            RETURN NEW;
        ELSIF   TG_OP = 'UPDATE'
        THEN
            INSERT INTO {audit_table} (table_name, pk, operation, app_user, request_id, new_val, old_val)
            VALUES (
                TG_RELNAME, OLD.pk, TG_OP, current_setting('{app_user}', 't'),
                current_setting('{request_id}', 't'), row_to_json(NEW), row_to_json(OLD)
            );
            RETURN NEW;
        ELSIF   TG_OP = 'DELETE'
        THEN
            INSERT INTO {audit_table} (table_name, pk, operation, app_user, request_id, old_val)
            VALUES (
                TG_RELNAME, OLD.pk, TG_OP, current_setting('{app_user}', 't'),
                current_setting('{request_id}', 't'), row_to_json(OLD)
            );
            RETURN OLD;
        END IF;
    END;
    $$ LANGUAGE 'plpgsql' SECURITY DEFINER;
    "##,
        audit_table = audit_table,
        app_user = APP_USER_SETTING,
        request_id = REQUEST_ID_SETTING,
    ))
    .execute(pool)
    .await?;
    Ok(())
}

#[allow(clippy::filter_map)]
async fn create_triggers(pool: &sqlx::PgPool, exclude_tables: Vec<String>) -> Result<()> {
    let futures = sqlx::query(
        r##"
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public';"##,
    )
    .fetch_all(pool)
    .await?
    .into_iter()
    .filter_map(|row| row.try_get("table_name").ok())
    .filter(|table_name| !exclude_tables.contains(table_name))
    .map(|table| create_trigger(pool, table));

    try_join_all(futures).await?;
    Ok(())
}

/// (Re)create the `sql_audit` row-level trigger on a single table of the `public` schema.
///
/// `table` is the bare table name exactly as stored in the catalog (for example as reported by
/// `information_schema.tables`). Any existing `sql_audit` trigger on the table is dropped first
/// so the function can be run repeatedly.
///
/// Fails if either statement fails.
async fn create_trigger(pool: &sqlx::PgPool, table: String) -> Result<()> {
    // Identifiers cannot be sent as bind parameters, so quote the name instead: wrap it in double
    // quotes and double any embedded quote. This keeps mixed-case names, reserved words (`user`,
    // `order`, ...) and names with special characters working. The name is also qualified with
    // `public` so it cannot resolve to a same-named table in another schema on the search path.
    let qualified_table = format!("public.\"{}\"", table.replace('"', "\"\""));

    sqlx::query(&format!(
        "DROP TRIGGER IF EXISTS sql_audit on {};",
        qualified_table
    ))
    .execute(pool)
    .await?;
    let create_query = format!(
        r##"
            CREATE TRIGGER sql_audit BEFORE INSERT OR UPDATE OR DELETE ON {}
            FOR EACH ROW EXECUTE PROCEDURE sql_audit_trigger();
            "##,
        qualified_table
    );
    sqlx::query(&create_query).execute(pool).await?;
    Ok(())
}

#[cfg(test)]
mod tests {
    // Placeholder test: it does not call any function from this crate.
    #[test]
    fn it_works() {
        assert_eq!(2 + 2, 4);
    }
}