use std::sync::Arc;
use tokio::sync::mpsc::Receiver;
use tokio_util::sync::CancellationToken;
use tracing::warn;
use crate::{
resolver::{pipeline::Outcome, state::ResolverState},
storage::query_log::{QueryLogRecord, QueryLogRepository},
telemetry::QueryEvent,
};
/// Upper bound on the number of events written by a single `insert_batch` call.
const BATCH_CAPACITY: usize = 256;
/// How long a partially filled batch may wait for more events (measured from the
/// arrival of its first events) before it is flushed anyway.
const FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_millis(1_000);
/// Background writer that receives [`QueryEvent`]s from a channel, groups them
/// into batches, and persists them through a [`QueryLogRepository`].
///
/// Construct with `QueryLogWriter::new` and drive it with `QueryLogWriter::run`.
pub struct QueryLogWriter<R> {
    /// Receiving half of the event channel; the senders live outside this file.
    rx: Receiver<QueryEvent>,
    /// Storage backend that batches are written to via `insert_batch`.
    repo: R,
    /// Shared resolver state; used in this file only to look up which blocklist
    /// source a blocked name is attributed to.
    state: Arc<ResolverState>,
}
impl<R> QueryLogWriter<R>
where
    R: QueryLogRepository,
{
    /// Creates a writer that consumes `rx` and persists batches through `repo`.
    ///
    /// `state` is consulted when a record needs blocklist attribution.
    pub fn new(rx: Receiver<QueryEvent>, repo: R, state: Arc<ResolverState>) -> Self {
        Self { rx, repo, state }
    }

    /// Converts one event into the record that will be stored.
    ///
    /// Only `Outcome::BlockedByBlocklist` events get a `blocklist_id`. It is looked up
    /// with `primary_source` on the resolver's blocklist at the moment the batch is
    /// flushed, not when the query was answered, so it reflects the blocklist as it
    /// is at flush time. Every other outcome keeps whatever `QueryLogRecord::from`
    /// produced.
    fn record_for(&self, event: &QueryEvent) -> QueryLogRecord {
        let mut record = QueryLogRecord::from(event);
        if matches!(event.outcome, Outcome::BlockedByBlocklist) {
            record.blocklist_id = self.state.blocklist().primary_source(&event.qname);
        }
        record
    }

    /// Runs the batching loop until `token` is cancelled or the channel closes.
    ///
    /// Each batch holds at most [`BATCH_CAPACITY`] events. After the first events
    /// of a batch arrive, the writer keeps collecting until one of these happens:
    /// - the batch is full;
    /// - [`FLUSH_INTERVAL`] has elapsed;
    /// - the channel closes;
    /// - `token` is cancelled.
    ///
    /// The batch is then flushed. On cancellation, events still buffered in the
    /// channel are persisted by `final_drain` before this returns. If the channel
    /// closes while no batch is pending, this returns directly, because nothing is
    /// left to drain.
    pub async fn run(mut self, token: CancellationToken) {
        loop {
            let mut batch = Vec::with_capacity(BATCH_CAPACITY);

            // Phase 1: wait, with no deadline, for the first event(s) of a batch.
            // `biased` makes cancellation win over pending events; anything still in
            // the channel is handled by `final_drain`.
            tokio::select! {
                biased;
                _ = token.cancelled() => break,
                received = self.rx.recv_many(&mut batch, BATCH_CAPACITY) => {
                    // tokio's `recv_many` returns 0 (for a non-zero limit) only when
                    // the channel is closed and empty, so there is nothing to drain.
                    if received == 0 {
                        return;
                    }
                }
            }

            // Phase 2: top the batch up until it is full, FLUSH_INTERVAL has passed
            // since phase 1 completed, the channel closes, or shutdown is requested.
            if batch.len() < BATCH_CAPACITY {
                let flush_at = tokio::time::sleep(FLUSH_INTERVAL);
                // Pinned once and polled by reference so the deadline keeps running
                // across iterations instead of restarting.
                tokio::pin!(flush_at);
                loop {
                    let remaining = BATCH_CAPACITY - batch.len();
                    if remaining == 0 {
                        break;
                    }
                    tokio::select! {
                        biased;
                        _ = token.cancelled() => break,
                        _ = &mut flush_at => break,
                        received = self.rx.recv_many(&mut batch, remaining) => {
                            // Channel closed: flush what we have. The next outer
                            // iteration observes the closure and returns.
                            if received == 0 {
                                break;
                            }
                        }
                    }
                }
            }

            self.flush(batch).await;
            if token.is_cancelled() {
                break;
            }
        }
        self.final_drain().await;
    }

    /// Persists whatever is still buffered in the channel after cancellation.
    ///
    /// The receiver is closed first, so no new messages can be enqueued (senders'
    /// `send`/`try_send` fail from now on). Messages already buffered, and those
    /// sent through permits reserved before the close, can still be received.
    ///
    /// This bounds the drain. With the channel left open, producers that keep
    /// sending after cancellation could refill the buffer during each awaited
    /// flush, and this loop (and therefore `run`) would never finish. Nothing in
    /// this type prevents producers from doing that.
    async fn final_drain(&mut self) {
        self.rx.close();
        let mut batch = Vec::with_capacity(BATCH_CAPACITY);
        while let Ok(event) = self.rx.try_recv() {
            batch.push(event);
            if batch.len() >= BATCH_CAPACITY {
                self.flush(std::mem::take(&mut batch)).await;
                // `mem::take` left an unallocated Vec behind; restore capacity.
                batch.reserve(BATCH_CAPACITY);
            }
        }
        self.flush(batch).await;
    }

    /// Converts `batch` into records and writes them with a single `insert_batch`.
    ///
    /// Empty batches are skipped. Persistence is best-effort: on failure the batch
    /// is discarded and a warning carrying the error and record count is logged.
    async fn flush(&self, batch: Vec<QueryEvent>) {
        if batch.is_empty() {
            return;
        }
        let records: Vec<QueryLogRecord> =
            batch.iter().map(|event| self.record_for(event)).collect();
        if let Err(e) = self.repo.insert_batch(&records).await {
            warn!(
                error = %e,
                count = records.len(),
                "failed to persist query-log batch"
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use tempfile::TempDir;
    use tokio::sync::mpsc;

    use super::*;
    use crate::{
        codec::{message::Qtype, name::Name},
        storage::{Db, query_log::SqliteQueryLogRepo},
    };

    /// Opens a temporary database and returns its directory guard, a query-log
    /// repository, the database handle, and hydrated resolver state.
    ///
    /// Callers bind the `TempDir` (as `_dir`) so it lives for the whole test.
    async fn open_repo() -> (TempDir, SqliteQueryLogRepo, Db, Arc<ResolverState>) {
        let (dir, db) = crate::test_support::temp_db().await;
        let repo = db.query_log();
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        (dir, repo, db, state)
    }

    /// A `Forwarded` query event for `name`.
    fn event(name: &str) -> QueryEvent {
        event_with(name, Outcome::Forwarded)
    }

    /// Builds a query event for `name` with the given `outcome`.
    ///
    /// The client address (`10.0.0.1:1000`), query type (`A`) and timestamp
    /// (`1_000`) are fixed.
    fn event_with(name: &str, outcome: Outcome) -> QueryEvent {
        let qname = name.parse::<Name>().unwrap();
        QueryEvent::new("10.0.0.1:1000".parse().unwrap(), qname, Qtype::A, outcome)
            .with_ts(1_000)
    }

    /// Stores `entries` (name, source id) into the resolver's blocklist.
    fn install_blocklist(state: &ResolverState, entries: &[(&str, i64)]) {
        let table = entries
            .iter()
            .map(|&(name, id)| (name.parse::<Name>().unwrap(), id))
            .collect();
        state.blocklist().store(table);
    }

    /// Spawns `writer.run` on the runtime with a clone of `token`.
    fn spawn_writer(
        writer: QueryLogWriter<SqliteQueryLogRepo>,
        token: &CancellationToken,
    ) -> tokio::task::JoinHandle<()> {
        let token = token.clone();
        tokio::spawn(async move { writer.run(token).await })
    }

    /// Cancels `token`, drops the sender, then waits for the writer task to end.
    async fn shut_down(
        token: CancellationToken,
        tx: mpsc::Sender<QueryEvent>,
        handle: tokio::task::JoinHandle<()>,
    ) {
        token.cancel();
        drop(tx);
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn writes_enqueued_events_to_db() {
        let (_dir, repo, db, state) = open_repo().await;
        let (tx, rx) = mpsc::channel(64);
        let token = CancellationToken::new();
        let handle = spawn_writer(QueryLogWriter::new(rx, repo, state), &token);
        for name in ["a.test.", "b.test."] {
            tx.send(event(name)).await.unwrap();
        }
        shut_down(token, tx, handle).await;
        let rows = db.query_log().page(None, 10).await.unwrap();
        assert_eq!(rows.len(), 2, "both enqueued events must be persisted");
    }

    #[tokio::test]
    async fn final_drain_flushes_buffered_events_on_cancel() {
        let (_dir, repo, db, state) = open_repo().await;
        let (tx, rx) = mpsc::channel(64);
        (0..5).for_each(|i| tx.try_send(event(&format!("d{i}.test."))).unwrap());
        // Cancelled before `run` starts: every event must come from the final drain.
        let token = CancellationToken::new();
        token.cancel();
        QueryLogWriter::new(rx, repo, state).run(token).await;
        let rows = db.query_log().page(None, 10).await.unwrap();
        assert_eq!(rows.len(), 5, "cancellation must drain the buffer");
    }

    #[tokio::test]
    async fn closed_channel_ends_the_run() {
        let (_dir, repo, _db, state) = open_repo().await;
        let (tx, rx) = mpsc::channel(64);
        let writer = QueryLogWriter::new(rx, repo, state);
        drop(tx);
        let limit = std::time::Duration::from_secs(5);
        tokio::time::timeout(limit, writer.run(CancellationToken::new()))
            .await
            .expect("run must end when the channel closes");
    }

    #[tokio::test]
    async fn blocked_by_blocklist_persists_primary_source() {
        let (_dir, repo, db, state) = open_repo().await;
        install_blocklist(&state, &[("ads.example.com", 7)]);
        let (tx, rx) = mpsc::channel(64);
        let token = CancellationToken::new();
        let handle = spawn_writer(QueryLogWriter::new(rx, repo, state), &token);
        tx.send(event_with("ads.example.com", Outcome::BlockedByBlocklist))
            .await
            .unwrap();
        shut_down(token, tx, handle).await;
        let rows = db.query_log().page(None, 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(
            rows[0].blocklist_id,
            Some(7),
            "block must be attributed to its primary source"
        );
    }

    #[tokio::test]
    async fn non_blocklist_outcomes_persist_null() {
        let (_dir, repo, db, state) = open_repo().await;
        install_blocklist(&state, &[("ads.example.com", 7)]);
        let (tx, rx) = mpsc::channel(64);
        let token = CancellationToken::new();
        let handle = spawn_writer(QueryLogWriter::new(rx, repo, state), &token);
        let sent = [
            ("ads.example.com", Outcome::BlockedByAdmin),
            ("safe.example.com", Outcome::Forwarded),
        ];
        for (name, outcome) in sent {
            tx.send(event_with(name, outcome)).await.unwrap();
        }
        shut_down(token, tx, handle).await;
        let rows = db.query_log().page(None, 10).await.unwrap();
        assert_eq!(rows.len(), 2);
        assert!(
            rows.iter().all(|r| r.blocklist_id.is_none()),
            "only BlockedByBlocklist rows may carry a blocklist_id"
        );
    }

    #[tokio::test]
    async fn blocked_domain_absent_from_snapshot_persists_null() {
        let (_dir, repo, db, state) = open_repo().await;
        let (tx, rx) = mpsc::channel(64);
        let token = CancellationToken::new();
        let handle = spawn_writer(QueryLogWriter::new(rx, repo, state), &token);
        tx.send(event_with("gone.example.com", Outcome::BlockedByBlocklist))
            .await
            .unwrap();
        shut_down(token, tx, handle).await;
        let rows = db.query_log().page(None, 10).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(
            rows[0].blocklist_id, None,
            "absent attribution must store NULL, not error"
        );
    }
}