use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde_json::Value;
use sqlx::{PgPool, Postgres, QueryBuilder};
use uuid::Uuid;
use crate::errors::AppError;
use crate::repositories::pagination::{cap_limit, cap_offset};
#[cfg(test)]
use crate::repositories::pagination::{DEFAULT_MAX_OFFSET, DEFAULT_MAX_PAGE_SIZE};
use crate::repositories::{AuditEventType, AuditLogEntry, AuditLogQuery, AuditLogRepository};
/// PostgreSQL-backed implementation of [`AuditLogRepository`].
pub struct PostgresAuditLogRepository {
/// Connection pool that every query issued by this repository runs on.
pool: PgPool,
}
impl PostgresAuditLogRepository {
/// Creates a repository that executes its queries on `pool`.
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
}
/// Raw row shape read back from the `audit_logs` table.
///
/// Field names mirror the column names selected by the queries in this file.
/// `event_type` is kept as plain text here and only validated when the row is
/// converted into an [`AuditLogEntry`].
#[derive(sqlx::FromRow)]
struct AuditLogRow {
id: Uuid,
// Textual event type; parsed with `AuditEventType::from_str` during conversion.
event_type: String,
actor_user_id: Option<Uuid>,
session_id: Option<Uuid>,
org_id: Option<Uuid>,
target_type: Option<String>,
target_id: Option<Uuid>,
ip_address: Option<String>,
user_agent: Option<String>,
metadata: Value,
created_at: DateTime<Utc>,
}
impl TryFrom<AuditLogRow> for AuditLogEntry {
type Error = AppError;
fn try_from(row: AuditLogRow) -> Result<Self, Self::Error> {
let event_type = AuditEventType::from_str(&row.event_type)
.ok_or_else(|| AppError::Database(format!("Invalid event type: {}", row.event_type)))?;
Ok(Self {
id: row.id,
event_type,
actor_user_id: row.actor_user_id,
session_id: row.session_id,
org_id: row.org_id,
target_type: row.target_type,
target_id: row.target_id,
ip_address: row.ip_address,
user_agent: row.user_agent,
metadata: row.metadata,
created_at: row.created_at,
})
}
}
/// Resolves the `(limit, offset)` pair to use for a paginated query.
///
/// A missing `limit` defaults to 100 and a missing `offset` to 0; both values
/// are then passed through `cap_limit` / `cap_offset` and cast to `i64`.
fn cap_pagination_values(query: &AuditLogQuery) -> (i64, i64) {
    let requested_limit = query.limit.unwrap_or(100);
    let requested_offset = query.offset.unwrap_or(0);
    (
        cap_limit(requested_limit) as i64,
        cap_offset(requested_offset) as i64,
    )
}
fn append_where_conditions<'a>(builder: &mut QueryBuilder<'a, Postgres>, query: &'a AuditLogQuery) {
let mut has_condition = false;
if let Some(ref org_id) = query.org_id {
builder.push(if has_condition { " AND " } else { " WHERE " });
builder.push("org_id = ").push_bind(*org_id);
has_condition = true;
}
if let Some(ref actor_id) = query.actor_user_id {
builder.push(if has_condition { " AND " } else { " WHERE " });
builder.push("actor_user_id = ").push_bind(*actor_id);
has_condition = true;
}
if let Some(ref event_type) = query.event_type {
builder.push(if has_condition { " AND " } else { " WHERE " });
builder.push("event_type = ").push_bind(event_type.as_str());
has_condition = true;
}
if let Some(ref target_type) = query.target_type {
builder.push(if has_condition { " AND " } else { " WHERE " });
builder
.push("target_type = ")
.push_bind(target_type.clone());
has_condition = true;
}
if let Some(ref target_id) = query.target_id {
builder.push(if has_condition { " AND " } else { " WHERE " });
builder.push("target_id = ").push_bind(*target_id);
has_condition = true;
}
if let Some(ref session_id) = query.session_id {
builder.push(if has_condition { " AND " } else { " WHERE " });
builder.push("session_id = ").push_bind(*session_id);
#[allow(unused_assignments)]
{
has_condition = true;
}
}
}
#[async_trait]
impl AuditLogRepository for PostgresAuditLogRepository {
async fn create(&self, entry: AuditLogEntry) -> Result<AuditLogEntry, AppError> {
let row: AuditLogRow = sqlx::query_as(
r#"
INSERT INTO audit_logs (id, event_type, actor_user_id, session_id, org_id, target_type, target_id, ip_address, user_agent, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING id, event_type, actor_user_id, session_id, org_id, target_type, target_id, ip_address, user_agent, metadata, created_at
"#,
)
.bind(entry.id)
.bind(entry.event_type.as_str())
.bind(entry.actor_user_id)
.bind(entry.session_id)
.bind(entry.org_id)
.bind(&entry.target_type)
.bind(entry.target_id)
.bind(&entry.ip_address)
.bind(&entry.user_agent)
.bind(&entry.metadata)
.fetch_one(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
row.try_into()
}
async fn find_by_id(&self, id: Uuid) -> Result<Option<AuditLogEntry>, AppError> {
let row: Option<AuditLogRow> = sqlx::query_as(
r#"
SELECT id, event_type, actor_user_id, session_id, org_id, target_type, target_id, ip_address, user_agent, metadata, created_at
FROM audit_logs WHERE id = $1
"#,
)
.bind(id)
.fetch_optional(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
match row {
Some(r) => Ok(Some(r.try_into()?)),
None => Ok(None),
}
}
async fn query(&self, query: AuditLogQuery) -> Result<Vec<AuditLogEntry>, AppError> {
let (limit, offset) = cap_pagination_values(&query);
let mut builder: QueryBuilder<Postgres> = QueryBuilder::new(
"SELECT id, event_type, actor_user_id, session_id, org_id, target_type, target_id, ip_address, user_agent, metadata, created_at FROM audit_logs",
);
append_where_conditions(&mut builder, &query);
builder.push(" ORDER BY created_at DESC LIMIT ");
builder.push_bind(limit);
builder.push(" OFFSET ");
builder.push_bind(offset);
let rows: Vec<AuditLogRow> = builder
.build_query_as()
.fetch_all(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
rows.into_iter().map(TryInto::try_into).collect()
}
/// Counts the audit log rows matching the filters in `query`.
///
/// Only the filters are used; `limit` and `offset` are not read here.
///
/// # Errors
///
/// Returns `AppError::Database` when the query fails.
async fn count(&self, query: AuditLogQuery) -> Result<u64, AppError> {
    let mut sql: QueryBuilder<Postgres> = QueryBuilder::new("SELECT COUNT(*) FROM audit_logs");
    append_where_conditions(&mut sql, &query);

    let (matching,): (i64,) = sql
        .build_query_as()
        .fetch_one(&self.pool)
        .await
        .map_err(|e| AppError::Database(e.to_string()))?;

    // A negative count is reported as zero rather than wrapping around.
    Ok(u64::try_from(matching).unwrap_or(0))
}
async fn query_with_count(
&self,
query: AuditLogQuery,
) -> Result<(Vec<AuditLogEntry>, u64), AppError> {
let (limit, offset) = cap_pagination_values(&query);
let mut builder: QueryBuilder<Postgres> = QueryBuilder::new(
r#"WITH filtered AS (
SELECT id, event_type, actor_user_id, session_id, org_id,
target_type, target_id, ip_address, user_agent, metadata, created_at,
COUNT(*) OVER() as total_count
FROM audit_logs"#,
);
append_where_conditions(&mut builder, &query);
builder.push(
r#"
)
SELECT id, event_type, actor_user_id, session_id, org_id,
target_type, target_id, ip_address, user_agent, metadata, created_at,
total_count
FROM filtered
ORDER BY created_at DESC
LIMIT "#,
);
builder.push_bind(limit);
builder.push(" OFFSET ");
builder.push_bind(offset);
#[derive(sqlx::FromRow)]
struct AuditLogRowWithCount {
id: Uuid,
event_type: String,
actor_user_id: Option<Uuid>,
session_id: Option<Uuid>,
org_id: Option<Uuid>,
target_type: Option<String>,
target_id: Option<Uuid>,
ip_address: Option<String>,
user_agent: Option<String>,
metadata: Value,
created_at: DateTime<Utc>,
total_count: i64,
}
let rows: Vec<AuditLogRowWithCount> = builder
.build_query_as()
.fetch_all(&self.pool)
.await
.map_err(|e| AppError::Database(e.to_string()))?;
let total = rows
.first()
.map(|r| r.total_count.max(0) as u64)
.unwrap_or(0);
let entries: Result<Vec<AuditLogEntry>, AppError> = rows
.into_iter()
.map(|row| {
let event_type = AuditEventType::from_str(&row.event_type).ok_or_else(|| {
AppError::Database(format!("Invalid event type: {}", row.event_type))
})?;
Ok(AuditLogEntry {
id: row.id,
event_type,
actor_user_id: row.actor_user_id,
session_id: row.session_id,
org_id: row.org_id,
target_type: row.target_type,
target_id: row.target_id,
ip_address: row.ip_address,
user_agent: row.user_agent,
metadata: row.metadata,
created_at: row.created_at,
})
})
.collect();
Ok((entries?, total))
}
/// Deletes every audit log row whose `created_at` is strictly earlier than
/// `date` and returns how many rows were removed.
///
/// # Errors
///
/// Returns `AppError::Database` when the delete fails.
async fn delete_older_than(&self, date: DateTime<Utc>) -> Result<u64, AppError> {
    sqlx::query("DELETE FROM audit_logs WHERE created_at < $1")
        .bind(date)
        .execute(&self.pool)
        .await
        .map(|outcome| outcome.rows_affected())
        .map_err(|e| AppError::Database(e.to_string()))
}
/// Deletes every audit log row whose `org_id` equals the given one and
/// returns how many rows were removed.
///
/// # Errors
///
/// Returns `AppError::Database` when the delete fails.
async fn delete_by_org(&self, org_id: Uuid) -> Result<u64, AppError> {
    sqlx::query("DELETE FROM audit_logs WHERE org_id = $1")
        .bind(org_id)
        .execute(&self.pool)
        .await
        .map(|outcome| outcome.rows_affected())
        .map_err(|e| AppError::Database(e.to_string()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A query with no `limit`/`offset` resolves to `DEFAULT_MAX_PAGE_SIZE`
    /// and an offset of zero.
    #[test]
    fn test_cap_pagination_values_defaults() {
        let capped = cap_pagination_values(&AuditLogQuery::default());

        assert_eq!(capped.0, DEFAULT_MAX_PAGE_SIZE as i64);
        assert_eq!(capped.1, 0);
    }

    /// Values one past the maximum are clamped down to the maximum.
    #[test]
    fn test_cap_pagination_values_caps() {
        let oversized = AuditLogQuery {
            limit: Some(DEFAULT_MAX_PAGE_SIZE + 1),
            offset: Some(DEFAULT_MAX_OFFSET + 1),
            ..Default::default()
        };

        let capped = cap_pagination_values(&oversized);

        assert_eq!(capped.0, DEFAULT_MAX_PAGE_SIZE as i64);
        assert_eq!(capped.1, DEFAULT_MAX_OFFSET as i64);
    }
}