use std::collections::HashMap;
use std::hash::Hash;
use std::time::Duration;
use crate::Timestamp;
use ahash::RandomState;
use super::types::{DnsConfig, DnsQuery};
pub struct Correlator<S: Eq + Hash + Clone> {
pending: HashMap<(S, u16), DnsQuery, RandomState>,
config: DnsConfig,
}
impl<S: Eq + Hash + Clone> Correlator<S> {
    /// Creates a correlator configured with `DnsConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(DnsConfig::default())
    }

    /// Creates a correlator with an explicit configuration.
    ///
    /// `config.max_pending` bounds how many queries are tracked at once (see
    /// [`Correlator::record_query`]). `config.query_timeout` is the age at which
    /// [`Correlator::sweep`] expires a pending query.
    pub fn with_config(config: DnsConfig) -> Self {
        Self {
            pending: HashMap::with_hasher(RandomState::new()),
            config,
        }
    }

    /// Records an outstanding query under `(scope, q.transaction_id)`.
    ///
    /// If a query with the same scope and transaction id is already pending, it
    /// is replaced. When adding a *new* key while `config.max_pending` entries
    /// are already tracked, the pending query with the smallest timestamp is
    /// evicted first. Finding that entry scans the whole table.
    pub fn record_query(&mut self, scope: S, q: DnsQuery) {
        let key = (scope, q.transaction_id);
        // Replacing an existing key does not grow the table. Evicting in that
        // case would drop an unrelated pending query for no reason, e.g. when a
        // client retransmits a query.
        if !self.pending.contains_key(&key) && self.pending.len() >= self.config.max_pending {
            self.evict_oldest();
        }
        self.pending.insert(key, q);
    }

    /// Removes the pending query with the smallest timestamp, if any.
    fn evict_oldest(&mut self) {
        let oldest_key = self
            .pending
            .iter()
            .min_by_key(|(_, pending_q)| pending_q.timestamp)
            .map(|(k, _)| k.clone());
        if let Some(k) = oldest_key {
            self.pending.remove(&k);
        }
    }

    /// Matches a response to its pending query and stops tracking that query.
    ///
    /// Returns the query and the time elapsed between its timestamp and
    /// `response_time`. The elapsed time saturates at zero if the response
    /// appears to precede the query. Returns `None` if no query with this
    /// scope and transaction id is pending.
    pub fn match_response(
        &mut self,
        scope: &S,
        tx_id: u16,
        response_time: Timestamp,
    ) -> Option<(DnsQuery, Duration)> {
        let key = (scope.clone(), tx_id);
        let q = self.pending.remove(&key)?;
        let elapsed = response_time
            .to_duration()
            .saturating_sub(q.timestamp.to_duration());
        Some((q, elapsed))
    }

    /// Removes and returns every pending query whose age at `now` is at least
    /// `config.query_timeout`.
    ///
    /// Queries with a timestamp later than `now` have age zero. The order of
    /// the returned queries follows the map's iteration order and is not
    /// meaningful.
    pub fn sweep(&mut self, now: Timestamp) -> Vec<DnsQuery> {
        let now_d = now.to_duration();
        let timeout = self.config.query_timeout;
        // Collect the keys first: values cannot be moved out of the map while
        // it is being iterated.
        let expired: Vec<(S, u16)> = self
            .pending
            .iter()
            .filter(|(_, pending_q)| {
                now_d.saturating_sub(pending_q.timestamp.to_duration()) >= timeout
            })
            .map(|(k, _)| k.clone())
            .collect();
        let mut out = Vec::with_capacity(expired.len());
        out.extend(expired.into_iter().filter_map(|k| self.pending.remove(&k)));
        out
    }

    /// Number of queries currently awaiting a response.
    pub fn pending_len(&self) -> usize {
        self.pending.len()
    }
}
impl<S: Eq + Hash + Clone> Default for Correlator<S> {
fn default() -> Self {
Self::new()
}
}