use std::sync::Arc;
use futures::stream;
use pgwire::api::results::{DataRowEncoder, QueryResponse, Response};
use pgwire::error::PgWireResult;
use crate::control::security::audit::entry::hash_entry;
use crate::control::security::identity::AuthenticatedIdentity;
use crate::control::state::SharedState;
use super::super::super::types::{sqlstate_error, text_field};
use super::helpers::clean_arg;
pub async fn verify_audit_chain(
state: &SharedState,
_identity: &AuthenticatedIdentity,
sql: &str,
) -> PgWireResult<Vec<Response>> {
let args = super::helpers::extract_function_args(sql, "VERIFY_AUDIT_CHAIN")?;
let from_seq: u64 = if !args.is_empty() {
clean_arg(args[0])
.parse()
.map_err(|_| sqlstate_error("22023", "from_seq must be an integer"))?
} else {
1
};
let to_seq: u64 = if args.len() >= 2 {
clean_arg(args[1])
.parse()
.map_err(|_| sqlstate_error("22023", "to_seq must be an integer"))?
} else {
u64::MAX
};
let audit_entries = state
.wal
.recover_audit_entries()
.map_err(|e| sqlstate_error("XX000", &format!("audit WAL recovery failed: {e}")))?;
let mut valid = true;
let mut checked = 0u64;
let mut broken_at: Option<u64> = None;
let mut prev_hash = String::new();
if !audit_entries.is_empty() {
for (_data_lsn, audit_bytes) in &audit_entries {
let entry: crate::control::security::audit::AuditEntry =
match zerompk::from_msgpack(audit_bytes) {
Ok(e) => e,
Err(e) => {
tracing::warn!(error = %e, "skipping malformed audit WAL entry");
continue;
}
};
if entry.seq < from_seq {
prev_hash = hash_entry(&entry);
continue;
}
if entry.seq > to_seq {
break;
}
if entry.prev_hash != prev_hash {
valid = false;
broken_at = Some(entry.seq);
break;
}
prev_hash = hash_entry(&entry);
checked += 1;
}
} else {
let log = match state.audit.lock() {
Ok(l) => l,
Err(poisoned) => {
tracing::warn!("audit log mutex poisoned, recovering");
poisoned.into_inner()
}
};
for entry in log.all() {
if entry.seq < from_seq {
prev_hash = hash_entry(entry);
continue;
}
if entry.seq > to_seq {
break;
}
if entry.prev_hash != prev_hash {
valid = false;
broken_at = Some(entry.seq);
break;
}
prev_hash = hash_entry(entry);
checked += 1;
}
}
let result = serde_json::json!({
"valid": valid,
"entries_checked": checked,
"from_seq": from_seq,
"to_seq": if to_seq == u64::MAX { checked + from_seq - 1 } else { to_seq },
"broken_at_seq": broken_at,
"last_hash": prev_hash,
});
let schema = Arc::new(vec![text_field("result")]);
let mut encoder = DataRowEncoder::new(schema.clone());
encoder
.encode_field(&result.to_string())
.map_err(|e| sqlstate_error("XX000", &e.to_string()))?;
Ok(vec![Response::Query(QueryResponse::new(
schema,
stream::iter(vec![Ok(encoder.take_row())]),
))])
}