use super::EnqueueError;
use crate::dns::Message;
use crate::plugin::PluginHandler;
use crate::server::{Protocol, RequestContext, RequestHandler};
use dashmap::DashSet;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tracing::{debug, trace, warn};
use super::stats::RefreshStats;
/// A single cache-refresh unit of work handed to the coordinator's worker pool.
#[derive(Clone)]
pub struct RefreshTask {
// Deduplication key: the coordinator allows only one in-flight task per key.
pub key: String,
// DNS query to re-issue through the plugin chain when the task runs.
pub message: Message,
// Plugin handler that will process the refresh query.
pub handler: Arc<PluginHandler>,
// Presumably the name of the cache entry being refreshed — not read within
// this file; TODO(review): confirm against the enqueuing caller.
pub entry_name: String,
// Creation timestamp; used to measure how long the task sat in the queue.
pub created_at: Instant,
}
/// Bounded worker pool that deduplicates and executes cache-refresh tasks.
pub struct RefreshCoordinator {
// Sending half of the bounded task queue; dropping it (in `shutdown`)
// closes the channel and lets workers exit their receive loops.
tx: mpsc::Sender<RefreshTask>,
// Keys currently queued or being refreshed; used to reject duplicates.
processing: Arc<DashSet<String>>,
// Shared counters (enqueued, processed, timeouts, ...) exposed via `stats()`.
stats: Arc<RefreshStats>,
// Join handles of the spawned worker tasks, awaited during shutdown.
worker_handles: Arc<tokio::sync::Mutex<Vec<JoinHandle<()>>>>,
}
impl RefreshCoordinator {
    /// Maximum wall-clock time allotted to a single refresh query before it
    /// is abandoned and counted as a timeout. Hoisted out of the worker loop
    /// body so the limit is declared once, at type scope.
    const REFRESH_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates a coordinator with `worker_count` background workers draining
    /// a bounded queue of capacity `queue_capacity`.
    ///
    /// Workers share a single receiver behind an async mutex: whichever
    /// worker acquires the lock first takes the next task, which gives a
    /// simple competitive-consumer pool without extra channels.
    pub fn new(worker_count: usize, queue_capacity: usize) -> Self {
        let (tx, rx) = mpsc::channel(queue_capacity);
        let rx = Arc::new(tokio::sync::Mutex::new(rx));
        let processing = Arc::new(DashSet::new());
        let stats = Arc::new(RefreshStats::new());
        let mut handles = Vec::with_capacity(worker_count);
        debug!(
            worker_count = worker_count,
            queue_capacity = queue_capacity,
            "Starting refresh coordinator worker pool"
        );
        for worker_id in 0..worker_count {
            let rx_clone = Arc::clone(&rx);
            let processing_clone = Arc::clone(&processing);
            let stats_clone = Arc::clone(&stats);
            let handle = tokio::spawn(async move {
                Self::worker_loop(worker_id, rx_clone, processing_clone, stats_clone).await;
            });
            handles.push(handle);
        }
        Self {
            tx,
            processing,
            stats,
            worker_handles: Arc::new(tokio::sync::Mutex::new(handles)),
        }
    }

    /// Attempts to enqueue a refresh task.
    ///
    /// The task's key is first inserted into the `processing` set; if it is
    /// already present the task is a duplicate and is rejected. On any send
    /// failure the key is removed again so a later attempt can succeed.
    ///
    /// # Errors
    /// - [`EnqueueError::AlreadyProcessing`] if a task with the same key is
    ///   already queued or running.
    /// - [`EnqueueError::QueueFull`] if the bounded queue is at capacity.
    /// - [`EnqueueError::Closed`] if the workers have shut down.
    pub async fn enqueue(&self, task: RefreshTask) -> crate::Result<()> {
        // `insert` returns false when the key was already present.
        if !self.processing.insert(task.key.clone()) {
            trace!(key = %task.key, "Refresh already in progress, skipping duplicate");
            self.stats.record_dedup_skipped();
            return Err(EnqueueError::AlreadyProcessing.into());
        }
        // Keep a copy of the key so we can undo the dedup marker on failure
        // (the task itself is moved into `try_send`).
        let key_for_cleanup = task.key.clone();
        match self.tx.try_send(task) {
            Ok(_) => {
                self.stats.record_enqueued();
                Ok(())
            }
            Err(mpsc::error::TrySendError::Full(_)) => {
                self.processing.remove(&key_for_cleanup);
                self.stats.record_rejected();
                warn!(
                    key = %key_for_cleanup,
                    queue_depth = self.stats.queue_depth(),
                    "Refresh queue full, rejecting task"
                );
                Err(EnqueueError::QueueFull.into())
            }
            Err(mpsc::error::TrySendError::Closed(_)) => {
                self.processing.remove(&key_for_cleanup);
                warn!("Refresh coordinator channel closed");
                Err(EnqueueError::Closed.into())
            }
        }
    }

    /// Returns a shared handle to the coordinator's counters.
    pub fn stats(&self) -> Arc<RefreshStats> {
        Arc::clone(&self.stats)
    }

    /// Gracefully stops the worker pool.
    ///
    /// Dropping the sender closes the channel; workers drain any remaining
    /// tasks, observe `recv() == None`, and exit. Each worker is then joined
    /// so panics are surfaced (logged) rather than lost.
    pub async fn shutdown(self) -> crate::Result<()> {
        debug!("Shutting down refresh coordinator");
        drop(self.tx);
        // Take the handles out of the mutex so the guard is released before
        // awaiting the (potentially slow) joins.
        let mut handles_guard = self.worker_handles.lock().await;
        let handles = std::mem::take(&mut *handles_guard);
        drop(handles_guard);
        for handle in handles {
            match handle.await {
                Ok(_) => {
                    trace!("Worker task completed successfully");
                }
                Err(e) => {
                    warn!("Worker task panicked: {}", e);
                }
            }
        }
        debug!("Refresh coordinator shutdown complete");
        Ok(())
    }

    /// Worker body: repeatedly pulls a task from the shared receiver and
    /// processes it, until the channel is closed.
    async fn worker_loop(
        worker_id: usize,
        rx: Arc<tokio::sync::Mutex<mpsc::Receiver<RefreshTask>>>,
        processing: Arc<DashSet<String>>,
        stats: Arc<RefreshStats>,
    ) {
        trace!(worker_id = worker_id, "Refresh worker started");
        loop {
            // Hold the receiver lock only across `recv` so other workers can
            // pull tasks while this one is busy processing.
            let task = {
                let mut rx_guard = rx.lock().await;
                rx_guard.recv().await
            };
            let task = match task {
                Some(task) => task,
                None => {
                    debug!(
                        worker_id = worker_id,
                        "Refresh worker stopping (channel closed)"
                    );
                    break;
                }
            };
            let key = task.key.clone();
            Self::process_task(worker_id, task, &stats).await;
            // Clear the dedup marker regardless of outcome so the key can be
            // refreshed again later.
            processing.remove(&key);
        }
        trace!(worker_id = worker_id, "Refresh worker stopped");
    }

    /// Executes one task under [`Self::REFRESH_TIMEOUT`], recording timing
    /// and outcome counters. Extracted from `worker_loop` to keep the
    /// receive loop and the per-task bookkeeping separate.
    async fn process_task(worker_id: usize, task: RefreshTask, stats: &RefreshStats) {
        let start = Instant::now();
        let key = task.key.clone();
        // Time spent waiting in the queue before a worker picked this up.
        let queued_duration = start.duration_since(task.created_at);
        trace!(
            worker_id = worker_id,
            key = %key,
            queued_ms = queued_duration.as_millis(),
            "Processing refresh task"
        );
        let result =
            tokio::time::timeout(Self::REFRESH_TIMEOUT, Self::execute_task(&task)).await;
        let duration = start.elapsed();
        stats.record_processed();
        match result {
            Ok(Ok(_)) => {
                stats.record_success();
                debug!(
                    worker_id = worker_id,
                    key = %key,
                    duration_ms = duration.as_millis(),
                    "Refresh succeeded"
                );
            }
            Ok(Err(e)) => {
                stats.record_failed();
                debug!(
                    worker_id = worker_id,
                    key = %key,
                    duration_ms = duration.as_millis(),
                    error = %e,
                    "Refresh failed"
                );
            }
            Err(_) => {
                stats.record_timeout();
                warn!(
                    worker_id = worker_id,
                    key = %key,
                    timeout_secs = Self::REFRESH_TIMEOUT.as_secs(),
                    "Refresh timeout"
                );
            }
        }
    }

    /// Re-issues the task's DNS query through the plugin handler.
    ///
    /// Any response code other than `NoError` is treated as a failed refresh
    /// and surfaced as a plugin error.
    async fn execute_task(task: &RefreshTask) -> crate::Result<()> {
        trace!(key = %task.key, "Executing refresh query");
        let ctx = RequestContext::new(task.message.clone(), Protocol::Udp);
        match task.handler.handle(ctx).await {
            Ok(response) => {
                if response.response_code() == crate::dns::ResponseCode::NoError {
                    trace!(key = %task.key, "Refresh query returned NoError");
                    Ok(())
                } else {
                    trace!(
                        key = %task.key,
                        rcode = ?response.response_code(),
                        "Refresh query returned error response"
                    );
                    Err(crate::Error::Plugin(format!(
                        "Response code: {:?}",
                        response.response_code()
                    )))
                }
            }
            Err(e) => {
                trace!(key = %task.key, error = %e, "Refresh query failed");
                Err(e)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Shutdown must join every spawned worker and report success.
    #[tokio::test]
    async fn test_refresh_coordinator_shutdown() {
        let coordinator = RefreshCoordinator::new(2, 10);
        {
            let handles = coordinator.worker_handles.lock().await;
            assert_eq!(handles.len(), 2, "Should have 2 worker handles");
        }
        let result = coordinator.shutdown().await;
        assert!(result.is_ok(), "Shutdown should succeed");
        debug!("Test passed: coordinator shutdown successful");
    }

    /// A freshly created coordinator starts with zero enqueued tasks and
    /// shuts down cleanly.
    #[tokio::test]
    async fn test_refresh_coordinator_created() {
        let coordinator = RefreshCoordinator::new(1, 10);
        let stats = coordinator.stats();
        // assert_eq! gives a useful failure message, unlike assert!(a == b).
        assert_eq!(stats.enqueued.load(Ordering::Relaxed), 0);
        // Previously the shutdown result was silently discarded; assert it.
        let result = coordinator.shutdown().await;
        assert!(result.is_ok(), "Shutdown should succeed");
        debug!("Test passed: coordinator created and shutdown successfully");
    }
}