use std::sync::Arc;
use futures::stream;
use pgwire::api::results::{DataRowEncoder, QueryResponse, Response};
use pgwire::error::PgWireResult;
use crate::control::security::identity::{AuthenticatedIdentity, Role};
use crate::control::server::pgwire::types::{int8_field, text_field};
use crate::control::state::SharedState;
/// Row budget for a single audit-log query: used when the query carries no
/// `LIMIT`, and also the ceiling that any explicit `LIMIT` is clamped to.
const DEFAULT_LIMIT: usize = 10000;
/// Answers a query against the audit log (e.g. `SELECT * FROM _system.audit_log ...`).
///
/// The caller must be a superuser or hold [`Role::Monitor`]; otherwise a
/// SQLSTATE `42501` (insufficient privilege) error is returned.
///
/// `upper` is the query text. The filter parser matches upper-case keywords
/// (`SEQ`, `LIMIT`), so it is presumably already upper-cased by the caller.
/// The `seq` range and `LIMIT` parsed from it are applied to whichever backend
/// serves the rows:
/// - the credential catalog, when one is configured;
/// - the in-memory audit log otherwise.
///
/// Columns: `seq`, `timestamp_us`, `event`, `tenant_id` (0 when the entry has
/// no tenant), `source`, `detail`, `prev_hash`.
///
/// # Errors
/// Returns an error on:
/// - permission denial;
/// - a catalog load failure (wrapped as `ApiError`);
/// - a field-encoding failure.
pub fn audit_log(
    state: &SharedState,
    identity: &AuthenticatedIdentity,
    upper: &str,
) -> PgWireResult<Vec<Response>> {
    if !identity.is_superuser && !identity.has_role(&Role::Monitor) {
        return Err(pgwire::error::PgWireError::UserError(Box::new(
            pgwire::error::ErrorInfo::new(
                "ERROR".to_string(),
                "42501".to_string(),
                "permission denied: audit_log:read requires superuser or monitor role".to_string(),
            ),
        )));
    }
    let schema = Arc::new(vec![
        int8_field("seq"),
        int8_field("timestamp_us"),
        text_field("event"),
        int8_field("tenant_id"),
        text_field("source"),
        text_field("detail"),
        text_field("prev_hash"),
    ]);
    // Parse the filters before choosing a backend so the in-memory fallback
    // honours `WHERE seq ...` / `LIMIT` exactly like the catalog path does.
    let (from_seq, to_seq, limit) = extract_seq_range_and_limit(upper);
    let Some(catalog) = state.credentials.catalog() else {
        return audit_log_from_memory(state, schema, from_seq, to_seq, limit);
    };
    // NOTE(review): the `0, u64::MAX` pair looks like an unbounded secondary
    // window (possibly timestamps) — confirm against `load_audit_entries_ranged`.
    let entries = catalog
        .load_audit_entries_ranged(from_seq, to_seq, 0, u64::MAX, limit)
        .map_err(|e| pgwire::error::PgWireError::ApiError(Box::new(e)))?;
    let mut rows = Vec::with_capacity(entries.len());
    let mut encoder = DataRowEncoder::new(Arc::clone(&schema));
    for entry in &entries {
        encoder.encode_field(&(entry.seq as i64))?;
        encoder.encode_field(&(entry.timestamp_us as i64))?;
        encoder.encode_field(&entry.event.as_str())?;
        encoder.encode_field(&(entry.tenant_id.unwrap_or(0) as i64))?;
        encoder.encode_field(&entry.source.as_str())?;
        encoder.encode_field(&entry.detail.as_str())?;
        encoder.encode_field(&entry.prev_hash.as_str())?;
        rows.push(Ok(encoder.take_row()));
    }
    Ok(vec![Response::Query(QueryResponse::new(
        schema,
        stream::iter(rows),
    ))])
}

/// Builds the audit-log response from the in-memory audit log.
///
/// This is used when no credential catalog is configured. Entries whose `seq`
/// lies outside `from_seq..=to_seq` are dropped. Of the remaining entries, the
/// last `limit` are returned, in the order `all()` yields them. With the
/// default filters (`1..=u64::MAX`, [`DEFAULT_LIMIT`]) this is the tail of the
/// log.
///
/// # Errors
/// Propagates field-encoding failures.
fn audit_log_from_memory(
    state: &SharedState,
    schema: Arc<Vec<pgwire::api::results::FieldInfo>>,
    from_seq: u64,
    to_seq: u64,
    limit: usize,
) -> PgWireResult<Vec<Response>> {
    // A poisoned mutex still guards readable data; serving it is preferable to
    // failing the whole query.
    let log = match state.audit.lock() {
        Ok(l) => l,
        Err(p) => p.into_inner(),
    };
    let all = log.all();
    // An inverted range (from > to) yields no matches, i.e. an empty result.
    let matching: Vec<_> = all
        .iter()
        .filter(|entry| (from_seq..=to_seq).contains(&entry.seq))
        .collect();
    // Keep only the trailing `limit` matches, as the unfiltered path always did.
    let skip = matching.len().saturating_sub(limit);
    let mut rows = Vec::with_capacity(matching.len() - skip);
    let mut encoder = DataRowEncoder::new(Arc::clone(&schema));
    for entry in &matching[skip..] {
        encoder.encode_field(&(entry.seq as i64))?;
        encoder.encode_field(&(entry.timestamp_us as i64))?;
        encoder.encode_field(&format!("{:?}", entry.event))?;
        encoder.encode_field(&(entry.tenant_id.map_or(0i64, |t| t.as_u64() as i64)))?;
        encoder.encode_field(&entry.source.as_str())?;
        encoder.encode_field(&entry.detail.as_str())?;
        encoder.encode_field(&entry.prev_hash.as_str())?;
        rows.push(Ok(encoder.take_row()));
    }
    Ok(vec![Response::Query(QueryResponse::new(
        schema,
        stream::iter(rows),
    ))])
}
/// Extracts an inclusive `seq` range and a row limit from upper-cased query text.
///
/// Returns `(from_seq, to_seq, limit)`. The defaults are
/// `(1, u64::MAX, DEFAULT_LIMIT)`.
///
/// Lower bound: `SEQ >= n` wins over `SEQ > n` (which becomes `n + 1`), and
/// that wins over `SEQ = n` (which pins both ends to `n`).
///
/// Upper bound: `SEQ <= n` wins over `SEQ < n` (which becomes `n - 1`,
/// saturating at 0). An upper bound overrides the `to` set by `SEQ = n`.
///
/// Each operator is accepted with one space after `SEQ` or with none. An
/// explicit `LIMIT n` is clamped to [`DEFAULT_LIMIT`].
fn extract_seq_range_and_limit(upper: &str) -> (u64, u64, usize) {
    // Looks up an operator in its spaced spelling first, then its compact one.
    let bound = |spaced: &str, compact: &str| {
        parse_seq_bound(upper, spaced).or_else(|| parse_seq_bound(upper, compact))
    };

    // LIMIT is a trailing clause, so search from the end. A forward search would
    // stop at an earlier "LIMIT" (e.g. inside a string literal in the WHERE
    // clause) and silently ignore the real limit.
    let limit = upper
        .rfind("LIMIT")
        .and_then(|pos| {
            let after = upper[pos + "LIMIT".len()..].trim_start();
            let end = after
                .find(|c: char| !c.is_ascii_digit())
                .unwrap_or(after.len());
            after[..end].parse::<usize>().ok()
        })
        .map_or(DEFAULT_LIMIT, |n| n.min(DEFAULT_LIMIT));

    let mut from_seq: u64 = 1;
    let mut to_seq: u64 = u64::MAX;
    // `>=` must be tried before `>`: the "SEQ >" pattern also matches the start
    // of "SEQ >=" (its parse then fails on the '=' and yields None anyway).
    if let Some(from) = bound("SEQ >=", "SEQ>=") {
        from_seq = from;
    } else if let Some(from) = bound("SEQ >", "SEQ>") {
        from_seq = from.saturating_add(1);
    } else if let Some(eq) = bound("SEQ =", "SEQ=") {
        from_seq = eq;
        to_seq = eq;
    }
    if let Some(to) = bound("SEQ <=", "SEQ<=") {
        to_seq = to;
    } else if let Some(to) = bound("SEQ <", "SEQ<") {
        to_seq = to.saturating_sub(1);
    }
    (from_seq, to_seq, limit)
}
/// Parses the unsigned integer that follows the first occurrence of `pattern`
/// in `upper`, skipping any whitespace after the pattern.
///
/// Only the leading run of ASCII digits is read. Returns `None` in any of
/// these cases:
/// - the pattern is absent;
/// - no digit follows it;
/// - the number overflows `u64`.
///
/// Later occurrences of `pattern` are never consulted.
fn parse_seq_bound(upper: &str, pattern: &str) -> Option<u64> {
    let start = upper.find(pattern)? + pattern.len();
    let rest = upper[start..].trim_start();
    let digits = match rest.char_indices().find(|&(_, ch)| !ch.is_ascii_digit()) {
        Some((idx, _)) => &rest[..idx],
        None => rest,
    };
    digits.parse().ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Unit tests for the ad-hoc `seq` / `LIMIT` filter parser.

    #[test]
    fn parse_no_range_returns_defaults() {
        let (from, to, limit) = extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG");
        assert_eq!(from, 1);
        assert_eq!(to, u64::MAX);
        assert_eq!(limit, DEFAULT_LIMIT);
    }

    #[test]
    fn parse_seq_ge_bound() {
        let (from, to, limit) =
            extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG WHERE SEQ >= 42");
        assert_eq!(from, 42);
        assert_eq!(to, u64::MAX);
        assert_eq!(limit, DEFAULT_LIMIT);
    }

    #[test]
    fn parse_seq_le_bound() {
        let (from, to, _) =
            extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG WHERE SEQ <= 100");
        assert_eq!(from, 1);
        assert_eq!(to, 100);
    }

    #[test]
    fn parse_limit() {
        let (_, _, limit) =
            extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG LIMIT 500");
        assert_eq!(limit, 500);
    }

    #[test]
    fn parse_limit_capped_at_default() {
        let (_, _, limit) =
            extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG LIMIT 99999");
        assert_eq!(limit, DEFAULT_LIMIT);
    }

    #[test]
    fn parse_limit_zero() {
        let (_, _, limit) = extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG LIMIT 0");
        assert_eq!(limit, 0);
    }

    #[test]
    fn parse_seq_equality() {
        let (from, to, _) =
            extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG WHERE SEQ = 77");
        assert_eq!(from, 77);
        assert_eq!(to, 77);
    }

    #[test]
    fn parse_seq_gt_strict() {
        let (from, to, _) =
            extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG WHERE SEQ > 10");
        assert_eq!(from, 11);
        assert_eq!(to, u64::MAX);
    }

    #[test]
    fn parse_seq_lt_strict() {
        let (from, to, _) =
            extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG WHERE SEQ < 10");
        assert_eq!(from, 1);
        assert_eq!(to, 9);
    }

    #[test]
    fn parse_seq_lt_zero_saturates() {
        // `SEQ < 0` must not underflow; it yields an empty (inverted) range.
        let (from, to, _) =
            extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG WHERE SEQ < 0");
        assert_eq!((from, to), (1, 0));
    }

    #[test]
    fn parse_compact_operators() {
        let (from, to, _) =
            extract_seq_range_and_limit("SELECT * FROM _SYSTEM.AUDIT_LOG WHERE SEQ>=5 AND SEQ<20");
        assert_eq!((from, to), (5, 19));
    }

    #[test]
    fn parse_combined_bounds_and_limit() {
        let parsed = extract_seq_range_and_limit(
            "SELECT * FROM _SYSTEM.AUDIT_LOG WHERE SEQ > 10 AND SEQ <= 20 LIMIT 3",
        );
        assert_eq!(parsed, (11, 20, 3));
    }

    #[test]
    fn parse_seq_bound_requires_digits() {
        assert_eq!(parse_seq_bound("WHERE SEQ >= 42", "SEQ >"), None);
        assert_eq!(parse_seq_bound("WHERE SEQ = ABC", "SEQ ="), None);
        assert_eq!(parse_seq_bound("WHERE SEQ = 7)", "SEQ ="), Some(7));
    }
}