use async_trait::async_trait;
use chrono::{DateTime, SecondsFormat, Utc};
use crate::action::Action;
use crate::audit::{Audit, NewAudit};
use crate::backend::{AuditQuery, Backend, Order};
use crate::changes::AuditedChanges;
use crate::error::{AuditError, Result};
use crate::id::AuditId;
/// Column list used by every `SELECT` whose rows are decoded into `RawAudit`.
///
/// The names (and their order) mirror the `audits` table definition in the
/// schema constants below.
const COLUMNS: &str = concat!(
    "id, auditable_type, auditable_id, associated_type, associated_id, ",
    "user_type, user_id, username, action, audited_changes, version, comment, ",
    "remote_address, request_uuid, created_at"
);
/// SQL flavour used when rendering bound-parameter placeholders.
///
/// `dead_code` is allowed because, when only one of the `sqlite` / `postgres`
/// cargo features is enabled, the other variant is never constructed.
#[allow(dead_code)]
#[derive(Clone, Copy, PartialEq, Eq)]
enum Dialect {
    /// Positional `?` placeholders.
    Sqlite,
    /// Numbered `$1, $2, ...` placeholders.
    Postgres,
}

/// One value bound to a placeholder of a statement built with [`Qb`].
enum Bind {
    /// Non-null text.
    Str(String),
    /// Nullable text; `None` binds SQL `NULL`.
    OptStr(Option<String>),
    /// 64-bit integer.
    Int(i64),
}

/// Tiny SQL builder: appends raw fragments and dialect-specific placeholders,
/// collecting the bound values in the same order as the placeholders.
struct Qb {
    /// Decides how placeholders are rendered.
    dialect: Dialect,
    /// SQL text accumulated so far.
    sql: String,
    /// Values to bind, in placeholder order.
    binds: Vec<Bind>,
    /// Number of placeholders emitted so far (the last `$N` on Postgres).
    idx: usize,
}

impl Qb {
    /// Starts a statement whose SQL text begins with `head`.
    fn new(dialect: Dialect, head: impl Into<String>) -> Self {
        let sql = head.into();
        Self {
            dialect,
            sql,
            binds: Vec::new(),
            idx: 0,
        }
    }

    /// Appends `s` verbatim (no escaping — never pass untrusted input here).
    fn raw(&mut self, s: &str) -> &mut Self {
        self.sql += s;
        self
    }

    /// Bumps the placeholder counter and renders the next placeholder.
    fn placeholder(&mut self) -> String {
        self.idx += 1;
        match self.dialect {
            Dialect::Postgres => format!("${}", self.idx),
            Dialect::Sqlite => "?".to_owned(),
        }
    }

    /// Writes the next placeholder into the SQL and records `bind` for it.
    fn push_bind(&mut self, bind: Bind) -> &mut Self {
        let marker = self.placeholder();
        self.sql.push_str(&marker);
        self.binds.push(bind);
        self
    }

    /// Binds a non-null text value at the current position.
    fn bind_str(&mut self, value: impl Into<String>) -> &mut Self {
        self.push_bind(Bind::Str(value.into()))
    }

    /// Binds a nullable text value at the current position.
    fn bind_opt(&mut self, value: Option<String>) -> &mut Self {
        self.push_bind(Bind::OptStr(value))
    }

    /// Binds an integer value at the current position.
    fn bind_int(&mut self, value: i64) -> &mut Self {
        self.push_bind(Bind::Int(value))
    }
}
/// Returns the `ORDER BY` clause (including its leading space) for `order`.
///
/// `id` is always appended as a secondary key in the same direction, so rows
/// sharing a `created_at` or `version` value still come back deterministically.
fn order_clause(order: Order) -> &'static str {
    match order {
        Order::CreatedAtAsc => " ORDER BY created_at ASC, id ASC",
        Order::CreatedAtDesc => " ORDER BY created_at DESC, id DESC",
        Order::VersionAsc => " ORDER BY version ASC, id ASC",
        Order::VersionDesc => " ORDER BY version DESC, id DESC",
    }
}
/// Formats `time` as RFC 3339 text with exactly six fractional digits and a `Z`
/// suffix (e.g. `2024-01-02T03:04:05.000006Z`).
///
/// All timestamps this module writes or compares against go through here, so
/// the TEXT `created_at` column compares lexicographically in chronological
/// order (for four-digit years).
fn rfc3339(time: DateTime<Utc>) -> String {
    let use_z = true;
    time.to_rfc3339_opts(SecondsFormat::Micros, use_z)
}
/// One `audits` row exactly as the driver decodes it, before `into_audit`
/// parses the text columns into domain types.
///
/// `sqlx::FromRow` matches fields to columns by name, so these names must stay
/// in sync with `COLUMNS`.
#[derive(sqlx::FromRow)]
struct RawAudit {
    /// Primary key.
    id: i64,
    /// `auditable_type` column.
    auditable_type: String,
    /// `auditable_id` column; wrapped into an `AuditId`.
    auditable_id: String,
    /// `associated_type` column (nullable).
    associated_type: Option<String>,
    /// `associated_id` column (nullable); wrapped into an `AuditId`.
    associated_id: Option<String>,
    /// `user_type` column (nullable).
    user_type: Option<String>,
    /// `user_id` column (nullable); wrapped into an `AuditId`.
    user_id: Option<String>,
    /// `username` column (nullable).
    username: Option<String>,
    /// Textual form of an `Action`, parsed with `str::parse`.
    action: String,
    /// JSON text deserialized into `AuditedChanges`.
    audited_changes: String,
    /// Version as stored (64-bit); `Audit` exposes it as `i32`.
    version: i64,
    /// `comment` column (nullable).
    comment: Option<String>,
    /// `remote_address` column (nullable).
    remote_address: Option<String>,
    /// `request_uuid` column (nullable).
    request_uuid: Option<String>,
    /// RFC 3339 timestamp text.
    created_at: String,
}
impl RawAudit {
    /// Converts a decoded row into an `Audit`, parsing its text columns.
    ///
    /// # Errors
    ///
    /// - `AuditError::Serialization` if `audited_changes` is not valid JSON for
    ///   `AuditedChanges`.
    /// - `AuditError::backend(..)` if `action` does not parse as an `Action`, if
    ///   `created_at` is not RFC 3339, or if `version` does not fit in an `i32`.
    fn into_audit(self) -> Result<Audit> {
        let action = self.action.parse::<Action>().map_err(AuditError::backend)?;
        let audited_changes: AuditedChanges =
            serde_json::from_str(&self.audited_changes).map_err(AuditError::Serialization)?;
        let created_at = DateTime::parse_from_rfc3339(&self.created_at)
            .map_err(AuditError::backend)?
            .with_timezone(&Utc);
        // The column is 64-bit but `Audit::version` is i32; report out-of-range
        // values instead of silently wrapping them with `as`.
        let version = i32::try_from(self.version).map_err(AuditError::backend)?;
        Ok(Audit {
            id: self.id,
            auditable_type: self.auditable_type,
            auditable_id: AuditId::new(self.auditable_id),
            associated_type: self.associated_type,
            associated_id: self.associated_id.map(AuditId::new),
            user_type: self.user_type,
            user_id: self.user_id.map(AuditId::new),
            username: self.username,
            action,
            audited_changes,
            version,
            comment: self.comment,
            remote_address: self.remote_address,
            request_uuid: self.request_uuid,
            created_at,
        })
    }
}
/// Runs `$sql` with `$binds` on `$exec` and evaluates to every row as a
/// `Vec<RawAudit>`.
///
/// Expand only inside an `async` fn returning `Result<_>`: driver errors are
/// propagated with `?` as `AuditError::backend(..)`.
macro_rules! fetch_all_raw {
    ($exec:expr, $sql:expr, $binds:expr) => {{
        let mut query = sqlx::query_as::<_, RawAudit>($sql);
        for bind in $binds {
            query = match bind {
                Bind::Int(n) => query.bind(n),
                Bind::Str(text) => query.bind(text),
                Bind::OptStr(maybe) => query.bind(maybe),
            };
        }
        let rows = query.fetch_all($exec).await;
        rows.map_err(AuditError::backend)?
    }};
}
/// Runs `$sql` with `$binds` on `$exec` and evaluates to the first row, if any,
/// as an `Option<RawAudit>`.
///
/// Expand only inside an `async` fn returning `Result<_>`: driver errors are
/// propagated with `?` as `AuditError::backend(..)`.
macro_rules! fetch_opt_raw {
    ($exec:expr, $sql:expr, $binds:expr) => {{
        let mut query = sqlx::query_as::<_, RawAudit>($sql);
        for bind in $binds {
            query = match bind {
                Bind::Int(n) => query.bind(n),
                Bind::Str(text) => query.bind(text),
                Bind::OptStr(maybe) => query.bind(maybe),
            };
        }
        let row = query.fetch_optional($exec).await;
        row.map_err(AuditError::backend)?
    }};
}
/// Runs `$sql` with `$binds` on `$exec` and evaluates to its single `i64` value
/// (`fetch_one`, so a missing row is an error).
///
/// Expand only inside an `async` fn returning `Result<_>`: driver errors are
/// propagated with `?` as `AuditError::backend(..)`.
macro_rules! fetch_scalar_i64 {
    ($exec:expr, $sql:expr, $binds:expr) => {{
        let mut scalar = sqlx::query_scalar::<_, i64>($sql);
        for bind in $binds {
            scalar = match bind {
                Bind::Int(n) => scalar.bind(n),
                Bind::Str(text) => scalar.bind(text),
                Bind::OptStr(maybe) => scalar.bind(maybe),
            };
        }
        let value = scalar.fetch_one($exec).await;
        value.map_err(AuditError::backend)?
    }};
}
/// Executes `$sql` with `$binds` on `$exec` (SQLite path) and evaluates to the
/// driver's query result.
///
/// Expand only inside an `async` fn returning `Result<_>`: driver errors are
/// propagated with `?` as `AuditError::backend(..)`.
#[cfg(feature = "sqlite")]
macro_rules! exec_sql {
    ($exec:expr, $sql:expr, $binds:expr) => {{
        let mut statement = sqlx::query($sql);
        for bind in $binds {
            statement = match bind {
                Bind::Int(n) => statement.bind(n),
                Bind::Str(text) => statement.bind(text),
                Bind::OptStr(maybe) => statement.bind(maybe),
            };
        }
        let done = statement.execute($exec).await;
        done.map_err(AuditError::backend)?
    }};
}
/// Executes `$sql` with `$binds` on `$exec` (Postgres path) and evaluates to the
/// driver's query result.
///
/// Unlike `exec_sql!`, errors are propagated with `?` unchanged as
/// `sqlx::Error`, so the enclosing block (error type `sqlx::Error`) can inspect
/// them, e.g. with `is_deadlock`.
#[cfg(feature = "postgres")]
macro_rules! exec_raw {
    ($exec:expr, $sql:expr, $binds:expr) => {{
        let mut statement = sqlx::query($sql);
        for bind in $binds {
            statement = match bind {
                Bind::Int(n) => statement.bind(n),
                Bind::Str(text) => statement.bind(text),
                Bind::OptStr(maybe) => statement.bind(maybe),
            };
        }
        let done = statement.execute($exec).await;
        done?
    }};
}
/// Returns `true` when `err` is a Postgres database error with SQLSTATE `40P01`
/// (`deadlock_detected`) or `40001` (`serialization_failure`).
///
/// Despite the name, serialization failures match too. `combine` uses this to
/// decide which transaction failures to ignore.
#[cfg(feature = "postgres")]
fn is_deadlock(err: &sqlx::Error) -> bool {
    match err {
        sqlx::Error::Database(db_err) => {
            let code = db_err.code();
            matches!(code.as_deref(), Some("40P01" | "40001"))
        }
        _ => false,
    }
}
macro_rules! do_insert {
($pool:expr, $dialect:expr, $audit:expr) => {{
let mut tx = $pool.begin().await.map_err(AuditError::backend)?;
let version: i32 = if $audit.action == Action::Create {
1
} else {
let (s, b) = build_max_version($dialect, &$audit.auditable_type, &$audit.auditable_id);
fetch_scalar_i64!(&mut *tx, &s, b) as i32
};
let (s, b) = build_insert($dialect, &$audit, version);
let id = fetch_scalar_i64!(&mut *tx, &s, b);
tx.commit().await.map_err(AuditError::backend)?;
(id, version)
}};
}
fn build_max_version(
dialect: Dialect,
auditable_type: &str,
auditable_id: &AuditId,
) -> (String, Vec<Bind>) {
let mut qb = Qb::new(
dialect,
"SELECT COALESCE(MAX(version), 0) + 1 FROM audits WHERE auditable_type = ",
);
qb.bind_str(auditable_type)
.raw(" AND auditable_id = ")
.bind_str(auditable_id.as_str());
(qb.sql, qb.binds)
}
fn build_insert(dialect: Dialect, audit: &NewAudit, version: i32) -> (String, Vec<Bind>) {
let changes_json =
serde_json::to_string(&audit.audited_changes).unwrap_or_else(|_| "{}".into());
let mut qb = Qb::new(
dialect,
"INSERT INTO audits (auditable_type, auditable_id, associated_type, associated_id, \
user_type, user_id, username, action, audited_changes, version, comment, \
remote_address, request_uuid, created_at) VALUES (",
);
qb.bind_str(audit.auditable_type.clone())
.raw(", ")
.bind_str(audit.auditable_id.as_str())
.raw(", ")
.bind_opt(audit.associated_type.clone())
.raw(", ")
.bind_opt(audit.associated_id.as_ref().map(|i| i.as_str().to_string()))
.raw(", ")
.bind_opt(audit.user_type.clone())
.raw(", ")
.bind_opt(audit.user_id.as_ref().map(|i| i.as_str().to_string()))
.raw(", ")
.bind_opt(audit.username.clone())
.raw(", ")
.bind_str(audit.action.as_str())
.raw(", ")
.bind_str(changes_json)
.raw(", ")
.bind_int(version as i64)
.raw(", ")
.bind_opt(audit.comment.clone())
.raw(", ")
.bind_opt(audit.remote_address.clone())
.raw(", ")
.bind_opt(audit.request_uuid.clone())
.raw(", ")
.bind_str(rfc3339(audit.created_at))
.raw(") RETURNING id");
(qb.sql, qb.binds)
}
fn build_select_auditable(
dialect: Dialect,
auditable_type: &str,
auditable_id: &AuditId,
query: &AuditQuery,
) -> (String, Vec<Bind>) {
let mut qb = Qb::new(
dialect,
format!("SELECT {COLUMNS} FROM audits WHERE auditable_type = "),
);
qb.bind_str(auditable_type)
.raw(" AND auditable_id = ")
.bind_str(auditable_id.as_str());
apply_filters(&mut qb, query);
(qb.sql, qb.binds)
}
fn build_select_associated(
dialect: Dialect,
associated_type: &str,
associated_id: &AuditId,
query: &AuditQuery,
) -> (String, Vec<Bind>) {
let mut qb = Qb::new(
dialect,
format!("SELECT {COLUMNS} FROM audits WHERE associated_type = "),
);
qb.bind_str(associated_type)
.raw(" AND associated_id = ")
.bind_str(associated_id.as_str());
apply_filters(&mut qb, query);
(qb.sql, qb.binds)
}
fn apply_filters(qb: &mut Qb, query: &AuditQuery) {
if let Some(action) = query.action {
qb.raw(" AND action = ").bind_str(action.as_str());
}
if let Some(v) = query.from_version {
qb.raw(" AND version >= ").bind_int(v as i64);
}
if let Some(v) = query.to_version {
qb.raw(" AND version <= ").bind_int(v as i64);
}
if let Some(t) = query.up_until {
qb.raw(" AND created_at <= ").bind_str(rfc3339(t));
}
qb.raw(order_clause(query.order));
if let Some(l) = query.limit {
qb.raw(" LIMIT ").bind_int(l);
}
if let Some(o) = query.offset {
qb.raw(" OFFSET ").bind_int(o);
}
}
/// SQLite DDL, executed statement by statement by `SqlxBackend::migrate`.
///
/// Every statement uses `IF NOT EXISTS`, so re-running it against an existing
/// database is a no-op. It does NOT alter an `audits` table that already exists
/// with a different shape.
///
/// `created_at` is TEXT because rows store the RFC 3339 strings produced by
/// `rfc3339`. `audited_changes` holds JSON text.
#[cfg(feature = "sqlite")]
const SQLITE_SCHEMA: &[&str] = &[
"CREATE TABLE IF NOT EXISTS audits (\
id INTEGER PRIMARY KEY AUTOINCREMENT, \
auditable_type TEXT NOT NULL, \
auditable_id TEXT NOT NULL, \
associated_type TEXT, \
associated_id TEXT, \
user_type TEXT, \
user_id TEXT, \
username TEXT, \
action TEXT NOT NULL, \
audited_changes TEXT NOT NULL, \
version INTEGER NOT NULL DEFAULT 0, \
comment TEXT, \
remote_address TEXT, \
request_uuid TEXT, \
created_at TEXT NOT NULL)",
// Secondary indexes.
// NOTE(review): `auditable_index` covers the same columns as the unique index
// `unique_auditable_version` below and looks redundant — confirm before dropping
// it (existing databases would keep both anyway).
"CREATE INDEX IF NOT EXISTS auditable_index ON audits (auditable_type, auditable_id, version)",
"CREATE INDEX IF NOT EXISTS associated_index ON audits (associated_type, associated_id)",
"CREATE INDEX IF NOT EXISTS user_index ON audits (user_id, user_type)",
"CREATE INDEX IF NOT EXISTS index_audits_on_request_uuid ON audits (request_uuid)",
"CREATE INDEX IF NOT EXISTS index_audits_on_created_at ON audits (created_at)",
// At most one row per (auditable_type, auditable_id, version).
"CREATE UNIQUE INDEX IF NOT EXISTS unique_auditable_version ON audits (auditable_type, auditable_id, version)",
];
/// PostgreSQL DDL, executed statement by statement by `SqlxBackend::migrate`.
///
/// Every statement uses `IF NOT EXISTS`, so re-running it against an existing
/// database is a no-op. It does NOT alter an `audits` table that already exists
/// with a different shape.
///
/// `created_at` is TEXT (not a timestamp type) because rows store the RFC 3339
/// strings produced by `rfc3339`. `version` is BIGINT, which decodes as `i64`.
#[cfg(feature = "postgres")]
const POSTGRES_SCHEMA: &[&str] = &[
"CREATE TABLE IF NOT EXISTS audits (\
id BIGSERIAL PRIMARY KEY, \
auditable_type TEXT NOT NULL, \
auditable_id TEXT NOT NULL, \
associated_type TEXT, \
associated_id TEXT, \
user_type TEXT, \
user_id TEXT, \
username TEXT, \
action TEXT NOT NULL, \
audited_changes TEXT NOT NULL, \
version BIGINT NOT NULL DEFAULT 0, \
comment TEXT, \
remote_address TEXT, \
request_uuid TEXT, \
created_at TEXT NOT NULL)",
// Secondary indexes.
// NOTE(review): `auditable_index` covers the same columns as the unique index
// `unique_auditable_version` below and looks redundant — confirm before dropping
// it (existing databases would keep both anyway).
"CREATE INDEX IF NOT EXISTS auditable_index ON audits (auditable_type, auditable_id, version)",
"CREATE INDEX IF NOT EXISTS associated_index ON audits (associated_type, associated_id)",
"CREATE INDEX IF NOT EXISTS user_index ON audits (user_id, user_type)",
"CREATE INDEX IF NOT EXISTS index_audits_on_request_uuid ON audits (request_uuid)",
"CREATE INDEX IF NOT EXISTS index_audits_on_created_at ON audits (created_at)",
// At most one row per (auditable_type, auditable_id, version).
"CREATE UNIQUE INDEX IF NOT EXISTS unique_auditable_version ON audits (auditable_type, auditable_id, version)",
];
/// Audit storage backed by an `sqlx` connection pool.
///
/// Which variants exist depends on the enabled cargo features (`sqlite`,
/// `postgres`). The variant also selects the SQL placeholder dialect.
pub enum SqlxBackend {
    /// SQLite pool; statements use `?` placeholders.
    #[cfg(feature = "sqlite")]
    Sqlite(sqlx::SqlitePool),

    /// PostgreSQL pool; statements use `$N` placeholders.
    #[cfg(feature = "postgres")]
    Postgres(sqlx::PgPool),
}
impl SqlxBackend {
    /// Wraps an existing SQLite pool.
    #[cfg(feature = "sqlite")]
    pub fn sqlite(pool: sqlx::SqlitePool) -> Self {
        Self::Sqlite(pool)
    }

    /// Wraps an existing PostgreSQL pool.
    #[cfg(feature = "postgres")]
    pub fn postgres(pool: sqlx::PgPool) -> Self {
        Self::Postgres(pool)
    }

    /// Opens a SQLite pool for `url` capped at a single connection.
    ///
    /// NOTE(review): the single connection presumably keeps in-memory databases
    /// shared and serialises writers — confirm that is the intent.
    ///
    /// # Errors
    ///
    /// Connection failures are returned as `AuditError::backend(..)`.
    #[cfg(feature = "sqlite")]
    pub async fn connect_sqlite(url: &str) -> Result<Self> {
        let options = sqlx::sqlite::SqlitePoolOptions::new().max_connections(1);
        let pool = options.connect(url).await.map_err(AuditError::backend)?;
        Ok(Self::Sqlite(pool))
    }

    /// Opens a PostgreSQL pool for `url` with sqlx's default pool options.
    ///
    /// # Errors
    ///
    /// Connection failures are returned as `AuditError::backend(..)`.
    #[cfg(feature = "postgres")]
    pub async fn connect_postgres(url: &str) -> Result<Self> {
        let options = sqlx::postgres::PgPoolOptions::new();
        let pool = options.connect(url).await.map_err(AuditError::backend)?;
        Ok(Self::Postgres(pool))
    }

    /// Placeholder dialect matching the wrapped pool.
    fn dialect(&self) -> Dialect {
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(_) => Dialect::Sqlite,
            #[cfg(feature = "postgres")]
            Self::Postgres(_) => Dialect::Postgres,
        }
    }

    /// Creates the `audits` table and its indexes if they do not exist yet.
    ///
    /// Statements run one at a time, outside any explicit transaction, and
    /// stop at the first failure.
    ///
    /// # Errors
    ///
    /// The first failing statement's error, as `AuditError::backend(..)`.
    pub async fn migrate(&self) -> Result<()> {
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(pool) => {
                for &statement in SQLITE_SCHEMA {
                    sqlx::query(statement)
                        .execute(pool)
                        .await
                        .map_err(AuditError::backend)?;
                }
            }
            #[cfg(feature = "postgres")]
            Self::Postgres(pool) => {
                for &statement in POSTGRES_SCHEMA {
                    sqlx::query(statement)
                        .execute(pool)
                        .await
                        .map_err(AuditError::backend)?;
                }
            }
        }
        Ok(())
    }
}
// `Backend` implementation: each operation builds its SQL text and bind values
// once for `self.dialect()` and then runs them on whichever pool `self` wraps.
#[async_trait]
impl Backend for SqlxBackend {
/// Inserts `audit` and returns it as an `Audit` carrying the generated `id` and
/// the `version` chosen by `do_insert!`.
///
/// The other fields are moved from `audit` as-is; the row is not re-read. As a
/// result, the returned `created_at` keeps full precision, while the stored text
/// has microsecond precision (see `rfc3339`).
async fn insert(&self, audit: NewAudit) -> Result<Audit> {
let dialect = self.dialect();
let (id, version) = match self {
#[cfg(feature = "sqlite")]
SqlxBackend::Sqlite(pool) => do_insert!(pool, dialect, audit),
#[cfg(feature = "postgres")]
SqlxBackend::Postgres(pool) => do_insert!(pool, dialect, audit),
};
Ok(Audit {
id,
auditable_type: audit.auditable_type,
auditable_id: audit.auditable_id,
associated_type: audit.associated_type,
associated_id: audit.associated_id,
user_type: audit.user_type,
user_id: audit.user_id,
username: audit.username,
action: audit.action,
audited_changes: audit.audited_changes,
version,
comment: audit.comment,
remote_address: audit.remote_address,
request_uuid: audit.request_uuid,
created_at: audit.created_at,
})
}
/// Returns the audits whose `auditable_type`/`auditable_id` match, filtered,
/// ordered and paginated according to `query`.
///
/// Fails on query errors or on any row `RawAudit::into_audit` cannot convert.
async fn audits_for_auditable(
&self,
auditable_type: &str,
auditable_id: &AuditId,
query: &AuditQuery,
) -> Result<Vec<Audit>> {
let (sql, binds) =
build_select_auditable(self.dialect(), auditable_type, auditable_id, query);
let raws: Vec<RawAudit> = match self {
#[cfg(feature = "sqlite")]
SqlxBackend::Sqlite(pool) => fetch_all_raw!(pool, &sql, binds),
#[cfg(feature = "postgres")]
SqlxBackend::Postgres(pool) => fetch_all_raw!(pool, &sql, binds),
};
// Stops at the first row that fails to convert.
raws.into_iter().map(RawAudit::into_audit).collect()
}
/// Returns the audits whose `associated_type`/`associated_id` match, filtered,
/// ordered and paginated according to `query`.
///
/// Fails on query errors or on any row `RawAudit::into_audit` cannot convert.
async fn audits_for_associated(
&self,
associated_type: &str,
associated_id: &AuditId,
query: &AuditQuery,
) -> Result<Vec<Audit>> {
let (sql, binds) =
build_select_associated(self.dialect(), associated_type, associated_id, query);
let raws: Vec<RawAudit> = match self {
#[cfg(feature = "sqlite")]
SqlxBackend::Sqlite(pool) => fetch_all_raw!(pool, &sql, binds),
#[cfg(feature = "postgres")]
SqlxBackend::Postgres(pool) => fetch_all_raw!(pool, &sql, binds),
};
raws.into_iter().map(RawAudit::into_audit).collect()
}
/// Returns every audit that either belongs to the record
/// (`auditable_type`, `auditable_id`) or names it as its associated record.
/// Results are ordered newest `created_at` first, with `id` breaking ties.
async fn own_and_associated_audits(
&self,
auditable_type: &str,
auditable_id: &AuditId,
) -> Result<Vec<Audit>> {
let dialect = self.dialect();
// The same type/id pair is bound twice: once against the auditable columns
// and once against the associated columns.
let mut qb = Qb::new(
dialect,
format!("SELECT {COLUMNS} FROM audits WHERE (auditable_type = "),
);
qb.bind_str(auditable_type)
.raw(" AND auditable_id = ")
.bind_str(auditable_id.as_str())
.raw(") OR (associated_type = ")
.bind_str(auditable_type)
.raw(" AND associated_id = ")
.bind_str(auditable_id.as_str())
.raw(")")
.raw(order_clause(Order::CreatedAtDesc));
let (sql, binds) = (qb.sql, qb.binds);
let raws: Vec<RawAudit> = match self {
#[cfg(feature = "sqlite")]
SqlxBackend::Sqlite(pool) => fetch_all_raw!(pool, &sql, binds),
#[cfg(feature = "postgres")]
SqlxBackend::Postgres(pool) => fetch_all_raw!(pool, &sql, binds),
};
raws.into_iter().map(RawAudit::into_audit).collect()
}
/// Counts the audits stored for (`auditable_type`, `auditable_id`).
async fn count_for_auditable(
&self,
auditable_type: &str,
auditable_id: &AuditId,
) -> Result<i64> {
let dialect = self.dialect();
let mut qb = Qb::new(
dialect,
"SELECT COUNT(*) FROM audits WHERE auditable_type = ",
);
qb.bind_str(auditable_type)
.raw(" AND auditable_id = ")
.bind_str(auditable_id.as_str());
let (sql, binds) = (qb.sql, qb.binds);
let count = match self {
#[cfg(feature = "sqlite")]
SqlxBackend::Sqlite(pool) => fetch_scalar_i64!(pool, &sql, binds),
#[cfg(feature = "postgres")]
SqlxBackend::Postgres(pool) => fetch_scalar_i64!(pool, &sql, binds),
};
Ok(count)
}
/// Looks up one audit by primary key. Returns `Ok(None)` when no row has `id`,
/// and an error if the row exists but cannot be converted.
async fn find(&self, id: i64) -> Result<Option<Audit>> {
let dialect = self.dialect();
let mut qb = Qb::new(dialect, format!("SELECT {COLUMNS} FROM audits WHERE id = "));
qb.bind_int(id);
let (sql, binds) = (qb.sql, qb.binds);
let raw: Option<RawAudit> = match self {
#[cfg(feature = "sqlite")]
SqlxBackend::Sqlite(pool) => fetch_opt_raw!(pool, &sql, binds),
#[cfg(feature = "postgres")]
SqlxBackend::Postgres(pool) => fetch_opt_raw!(pool, &sql, binds),
};
raw.map(RawAudit::into_audit).transpose()
}
/// In one transaction, overwrites `audited_changes` and `comment` of audit
/// `target_id` with `merged_changes` / `comment`, and deletes the audits listed
/// in `older_ids`.
///
/// `comment: None` stores SQL NULL. When `older_ids` is empty, no DELETE is run.
///
/// Errors: `AuditError::Serialization` if `merged_changes` cannot be serialized.
/// Database failures are returned as backend errors, except for the Postgres
/// errors ignored in that arm (see below).
async fn combine(
&self,
target_id: i64,
merged_changes: &AuditedChanges,
comment: Option<&str>,
older_ids: &[i64],
) -> Result<()> {
let dialect = self.dialect();
let changes_json =
serde_json::to_string(merged_changes).map_err(AuditError::Serialization)?;
let mut update = Qb::new(dialect, "UPDATE audits SET audited_changes = ");
update
.bind_str(changes_json)
.raw(", comment = ")
.bind_opt(comment.map(|c| c.to_string()))
.raw(" WHERE id = ")
.bind_int(target_id);
let (update_sql, update_binds) = (update.sql, update.binds);
// Builds `DELETE ... WHERE id IN (p1, p2, ...)` with one placeholder per id.
let delete: Option<(String, Vec<Bind>)> = if older_ids.is_empty() {
None
} else {
let mut del = Qb::new(dialect, "DELETE FROM audits WHERE id IN (");
for (i, id) in older_ids.iter().enumerate() {
if i > 0 {
del.raw(", ");
}
del.bind_int(*id);
}
del.raw(")");
Some((del.sql, del.binds))
};
match self {
// SQLite: every error is propagated.
#[cfg(feature = "sqlite")]
SqlxBackend::Sqlite(pool) => {
let mut tx = pool.begin().await.map_err(AuditError::backend)?;
exec_sql!(&mut *tx, &update_sql, update_binds);
if let Some((dsql, dbinds)) = delete {
exec_sql!(&mut *tx, &dsql, dbinds);
}
tx.commit().await.map_err(AuditError::backend)?;
}
// Postgres: the transaction runs inside an async block so its raw
// `sqlx::Error` can be classified. Errors for which `is_deadlock` is true
// (SQLSTATE 40P01 / 40001) are swallowed: the transaction does not commit,
// so this call changes nothing, yet it still returns Ok(()).
// NOTE(review): presumably intentional, on the assumption that a concurrent
// combine of the same record produces the same result. The SQLite arm
// propagates every error, though — confirm the asymmetry is intended.
#[cfg(feature = "postgres")]
SqlxBackend::Postgres(pool) => {
let outcome: std::result::Result<(), sqlx::Error> = async {
let mut tx = pool.begin().await?;
exec_raw!(&mut *tx, &update_sql, update_binds);
if let Some((dsql, dbinds)) = delete {
exec_raw!(&mut *tx, &dsql, dbinds);
}
tx.commit().await?;
Ok(())
}
.await;
match outcome {
Ok(()) => {}
Err(e) if is_deadlock(&e) => {}
Err(e) => return Err(AuditError::backend(e)),
}
}
}
Ok(())
}
}