use std::sync::Arc;
use std::time::Duration;
use aa_core::storage::{AuditEntry, AuditSink, Result};
use aa_storage_sqlite_buffer::EventBuffer;
use tokio::task::JoinHandle;
use super::{METRIC_BUFFERED, METRIC_FLUSHED, METRIC_PUBLISHED, METRIC_PUBLISH_ERRORS};
/// Publishes audit entries to a sink, falling back to a local buffer when
/// the sink reports an error so the entries can be replayed later.
pub struct AuditPublisher {
/// Primary destination that every published entry is emitted to first.
sink: Arc<dyn AuditSink>,
/// Fallback store for entries whose emit failed; drained back into `sink`
/// by `flush_pending`.
buffer: Arc<EventBuffer>,
}
impl AuditPublisher {
#[must_use]
pub fn new(sink: Arc<dyn AuditSink>, buffer: Arc<EventBuffer>) -> Self {
Self { sink, buffer }
}
pub async fn publish(&self, entry: AuditEntry) {
match self.sink.emit(entry.clone()).await {
Ok(()) => {
metrics::counter!(METRIC_PUBLISHED).increment(1);
}
Err(_) => {
metrics::counter!(METRIC_PUBLISH_ERRORS).increment(1);
if self.buffer.enqueue(&entry).is_ok() {
metrics::counter!(METRIC_BUFFERED).increment(1);
}
}
}
}
pub fn buffered_len(&self) -> Result<usize> {
self.buffer.len()
}
pub async fn flush_pending(&self) -> Result<usize> {
if self.buffer.is_empty()? {
return Ok(0);
}
let flushed = self.buffer.drain_and_send(&*self.sink).await?;
if flushed > 0 {
metrics::counter!(METRIC_FLUSHED).increment(flushed as u64);
}
Ok(flushed)
}
#[must_use]
pub fn spawn_reconnect_flush_loop(self: Arc<Self>, interval: Duration) -> JoinHandle<()> {
tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval);
loop {
ticker.tick().await;
if let Err(err) = self.flush_pending().await {
tracing::warn!(error = %err, "audit buffer flush failed");
}
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    use aa_core::audit::{AuditEntry, AuditEventType};
    use aa_core::storage::StorageError;
    use aa_core::{AgentId, SessionId};
    use async_trait::async_trait;
    use tempfile::TempDir;

    /// In-memory [`AuditSink`] whose availability can be flipped at runtime.
    ///
    /// While "up" it records every emitted entry; while "down" every emit
    /// fails with a backend error.
    struct FakeSink {
        up: AtomicBool,
        captured: Mutex<Vec<AuditEntry>>,
    }

    impl FakeSink {
        /// Builds a sink that starts in the given availability state.
        fn new(up: bool) -> Self {
            Self {
                up: AtomicBool::new(up),
                captured: Mutex::new(Vec::new()),
            }
        }

        /// Switches the sink between accepting and rejecting entries.
        fn set_up(&self, up: bool) {
            self.up.store(up, Ordering::SeqCst);
        }

        /// Sequence numbers of the captured entries, in arrival order.
        fn captured_seqs(&self) -> Vec<u64> {
            let captured = self.captured.lock().unwrap();
            captured.iter().map(|event| event.seq()).collect()
        }
    }

    #[async_trait]
    impl AuditSink for FakeSink {
        async fn emit(&self, event: AuditEntry) -> Result<()> {
            if self.up.load(Ordering::SeqCst) {
                self.captured.lock().unwrap().push(event);
                Ok(())
            } else {
                Err(StorageError::Backend("fake sink down".to_string()))
            }
        }
    }

    /// Builds a minimal audit entry whose payload encodes `seq`.
    fn entry(seq: u64) -> AuditEntry {
        let payload = format!("{{\"seq\":{seq}}}");
        AuditEntry::new(
            seq,
            seq,
            AuditEventType::ToolCallIntercepted,
            AgentId::from_bytes([0u8; 16]),
            SessionId::from_bytes([0u8; 16]),
            payload,
            [0u8; 32],
        )
    }

    /// Opens a fresh buffer inside a temporary directory.
    ///
    /// The `TempDir` is returned alongside the buffer so the caller keeps the
    /// directory alive for as long as the buffer is in use.
    fn new_buffer() -> (TempDir, Arc<EventBuffer>) {
        let dir = TempDir::new().expect("temp dir");
        let db_path = dir.path().join("buffer.db");
        let buffer = EventBuffer::new(db_path, 10_000).expect("open buffer");
        (dir, Arc::new(buffer))
    }

    /// Common test setup: a temp-dir-backed buffer, a fake sink in the
    /// requested state, and a publisher wired to both.
    fn harness(sink_up: bool) -> (TempDir, Arc<FakeSink>, AuditPublisher) {
        let (dir, buffer) = new_buffer();
        let sink = Arc::new(FakeSink::new(sink_up));
        let publisher = AuditPublisher::new(sink.clone(), buffer);
        (dir, sink, publisher)
    }

    #[tokio::test]
    async fn publish_succeeds_when_sink_up() {
        let (_dir, sink, publisher) = harness(true);

        publisher.publish(entry(1)).await;

        assert_eq!(sink.captured_seqs(), vec![1]);
        assert_eq!(publisher.buffered_len().unwrap(), 0, "nothing should be buffered");
    }

    #[tokio::test]
    async fn buffers_when_sink_down() {
        let (_dir, sink, publisher) = harness(false);

        publisher.publish(entry(1)).await;
        publisher.publish(entry(2)).await;

        assert!(sink.captured_seqs().is_empty(), "sink is down, nothing delivered");
        assert_eq!(publisher.buffered_len().unwrap(), 2, "both events buffered");
    }

    #[tokio::test]
    async fn reconnect_drains_buffer_in_fifo_order() {
        let (_dir, sink, publisher) = harness(false);

        // Sink is down: every publish lands in the buffer.
        for seq in 1..=3 {
            publisher.publish(entry(seq)).await;
        }
        assert_eq!(publisher.buffered_len().unwrap(), 3);

        // Sink comes back: one flush must replay everything, oldest first.
        sink.set_up(true);
        let flushed = publisher.flush_pending().await.unwrap();

        assert_eq!(flushed, 3);
        assert_eq!(publisher.buffered_len().unwrap(), 0, "buffer fully drained");
        assert_eq!(sink.captured_seqs(), vec![1, 2, 3], "replayed in FIFO order");
    }

    #[test]
    fn records_all_four_audit_metrics() {
        use metrics_util::debugging::DebuggingRecorder;

        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();

        // The runtime is built and driven inside the closure so that all
        // publisher activity happens while the local recorder is installed.
        metrics::with_local_recorder(&recorder, || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            runtime.block_on(async {
                let (_dir, sink, publisher) = harness(true);

                // Successful publish.
                publisher.publish(entry(1)).await;

                // Failed publish that falls back to the buffer.
                sink.set_up(false);
                publisher.publish(entry(2)).await;

                // Flush once the sink is reachable again.
                sink.set_up(true);
                publisher.flush_pending().await.unwrap();
            });
        });

        let names: Vec<String> = snapshotter
            .snapshot()
            .into_vec()
            .into_iter()
            .map(|(key, _, _, _)| key.key().name().to_string())
            .collect();

        let expected_metrics = [
            METRIC_PUBLISHED,
            METRIC_PUBLISH_ERRORS,
            METRIC_BUFFERED,
            METRIC_FLUSHED,
        ];
        for expected in expected_metrics {
            assert!(
                names.iter().any(|name| name == expected),
                "missing metric {expected}; got {names:?}"
            );
        }
    }
}