use std::hash::Hash;
use std::sync::{Arc, Mutex};
use crate::{Extracted, FlowExtractor, PacketView, Timestamp};
use super::correlator::Correlator;
use super::parser::{DnsParseResult, parse_message_at};
use super::types::{DnsConfig, DnsHandler};
/// A [`FlowExtractor`] wrapper that taps DNS-over-UDP messages out of the
/// packets passing through it and reports them to a [`DnsHandler`].
///
/// Flow extraction itself is delegated to `inner`; this type only observes.
pub struct DnsUdpObserver<E: FlowExtractor, H: DnsHandler> {
    /// Wrapped extractor; its result is what `extract` returns.
    pub inner: E,
    /// Receives the query, response and unanswered-query callbacks.
    pub handler: Arc<H>,
    /// UDP port treated as DNS; matched against both the source and the
    /// destination port of a datagram.
    pub udp_port: u16,
    /// Query/response correlation state, keyed by the inner extractor's flow
    /// key and guarded by a mutex because `extract` only has `&self`.
    correlator: Arc<Mutex<Correlator<E::Key>>>,
}
impl<E, H> DnsUdpObserver<E, H>
where
    E: FlowExtractor,
    H: DnsHandler,
    E::Key: Eq + Hash + Clone,
{
    /// Wraps `inner` using `DnsConfig::default()` and UDP port 53.
    pub fn new(inner: E, handler: H) -> Self {
        const DEFAULT_DNS_PORT: u16 = 53;
        Self::with_config(inner, handler, DnsConfig::default(), DEFAULT_DNS_PORT)
    }

    /// Wraps `inner`, passing `config` to the correlator and treating
    /// `udp_port` as the DNS port.
    pub fn with_config(inner: E, handler: H, config: DnsConfig, udp_port: u16) -> Self {
        let handler = Arc::new(handler);
        let correlator = Arc::new(Mutex::new(Correlator::with_config(config)));
        Self {
            inner,
            handler,
            udp_port,
            correlator,
        }
    }

    /// Asks the correlator to sweep at time `now` and reports every query it
    /// hands back through `on_unanswered`.
    ///
    /// The correlator lock is released before the handler is invoked.
    ///
    /// # Panics
    ///
    /// Panics if the correlator mutex is poisoned.
    pub fn sweep_unanswered(&self, now: Timestamp) {
        // The guard is a temporary of this statement, so it is dropped here
        // and the callbacks below run without the lock held.
        let unanswered = self.correlator.lock().unwrap().sweep(now);
        for query in unanswered {
            self.handler.on_unanswered(&query);
        }
    }
}
impl<E, H> FlowExtractor for DnsUdpObserver<E, H>
where
    E: FlowExtractor,
    H: DnsHandler,
    E::Key: Eq + Hash + Clone,
{
    type Key = E::Key;

    /// Runs the inner extractor and, as a side effect, reports any DNS message
    /// carried by the packet to the handler.
    ///
    /// A packet is inspected when it is an unfragmented UDP datagram whose
    /// source or destination port equals `udp_port` and whose payload parses
    /// as a DNS message. Queries are recorded in the correlator and responses
    /// are matched against it only when the inner extractor produced a flow
    /// key; the handler is notified either way. A matched response has its
    /// `elapsed` field filled in before `on_response` is called.
    ///
    /// Returns the inner extractor's result unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the correlator mutex is poisoned.
    fn extract(&self, view: PacketView<'_>) -> Option<Extracted<E::Key>> {
        let inner_result = self.inner.extract(view);

        if let Some(udp) = peek_udp(view.frame)
            && (udp.dst_port == self.udp_port || udp.src_port == self.udp_port)
            && let Ok(parsed) = parse_message_at(udp.payload, view.timestamp)
        {
            // Borrow the flow key: only the query path needs an owned copy,
            // so responses no longer pay for a clone on every packet.
            let scope = inner_result.as_ref().map(|extracted| &extracted.key);

            match parsed {
                DnsParseResult::Query(q) => {
                    if let Some(key) = scope {
                        self.correlator
                            .lock()
                            .unwrap()
                            .record_query(key.clone(), q.clone());
                    }
                    self.handler.on_query(&q);
                }
                DnsParseResult::Response(mut resp) => {
                    if let Some(key) = scope
                        && let Some((_, elapsed)) = self.correlator.lock().unwrap().match_response(
                            key,
                            resp.transaction_id,
                            view.timestamp,
                        )
                    {
                        resp.elapsed = Some(elapsed);
                    }
                    // The lock guard above is gone by now; the handler runs
                    // without the correlator lock held.
                    self.handler.on_response(&resp);
                }
            }
        }

        inner_result
    }
}
/// Ports and payload of a UDP datagram located inside a raw frame.
struct UdpInfo<'a> {
    src_port: u16,
    dst_port: u16,
    /// Datagram payload, bounded by the UDP length field (so any trailing
    /// frame padding is excluded).
    payload: &'a [u8],
}

/// Locates an unfragmented UDP datagram inside an Ethernet frame.
///
/// Up to two VLAN tags (EtherType `0x8100` or `0x88a8`) are skipped, then an
/// IPv4 or IPv6 header is expected. Returns `None` when the frame is too
/// short for any header it claims, is not IPv4/IPv6, is an IPv4 fragment,
/// does not carry protocol 17 directly, or has a UDP length that is below
/// the header size or beyond the captured bytes.
///
/// IPv6 extension headers are not traversed: the fixed header's next-header
/// byte must itself be UDP.
fn peek_udp(frame: &[u8]) -> Option<UdpInfo<'_>> {
    const ETH_HEADER_LEN: usize = 14;
    const VLAN_TAG_LEN: usize = 4;
    const MAX_VLAN_TAGS: usize = 2;
    const ETHERTYPE_VLAN: u16 = 0x8100;
    const ETHERTYPE_QINQ: u16 = 0x88a8;
    const ETHERTYPE_IPV4: u16 = 0x0800;
    const ETHERTYPE_IPV6: u16 = 0x86dd;
    const IPV4_MIN_HEADER_LEN: usize = 20;
    const IPV6_HEADER_LEN: usize = 40;
    // More-fragments flag (0x2000) together with the 13-bit fragment offset.
    const IPV4_FRAGMENT_BITS: u16 = 0x3FFF;
    const IP_PROTO_UDP: u8 = 17;
    const UDP_HEADER_LEN: usize = 8;

    /// Reads a big-endian `u16` at `at`, or `None` if it is out of bounds.
    fn be16(buf: &[u8], at: usize) -> Option<u16> {
        let pair = buf.get(at..at + 2)?;
        Some(u16::from_be_bytes([pair[0], pair[1]]))
    }

    // The EtherType occupies the last two bytes of the Ethernet header.
    let mut ethertype = be16(frame, ETH_HEADER_LEN - 2)?;
    let mut l3_start = ETH_HEADER_LEN;
    let mut tags_seen = 0;
    while tags_seen < MAX_VLAN_TAGS && matches!(ethertype, ETHERTYPE_VLAN | ETHERTYPE_QINQ) {
        // Each tag is two bytes of tag control followed by the next EtherType.
        ethertype = be16(frame, l3_start + 2)?;
        l3_start += VLAN_TAG_LEN;
        tags_seen += 1;
    }

    let ip = frame.get(l3_start..)?;
    let (proto, ip_header_len) = match ethertype {
        ETHERTYPE_IPV4 => {
            let fixed = ip.get(..IPV4_MIN_HEADER_LEN)?;
            let ihl = usize::from(fixed[0] & 0x0f) * 4;
            if ihl < IPV4_MIN_HEADER_LEN || ip.len() < ihl {
                return None;
            }
            // Fragments are rejected: a first fragment holds only part of
            // the payload and later ones carry no UDP header at all.
            let frag = u16::from_be_bytes([fixed[6], fixed[7]]);
            if frag & IPV4_FRAGMENT_BITS != 0 {
                return None;
            }
            (fixed[9], ihl)
        }
        ETHERTYPE_IPV6 => {
            let fixed = ip.get(..IPV6_HEADER_LEN)?;
            (fixed[6], IPV6_HEADER_LEN)
        }
        _ => return None,
    };
    if proto != IP_PROTO_UDP {
        return None;
    }

    let udp = ip.get(ip_header_len..)?;
    let header = udp.get(..UDP_HEADER_LEN)?;
    let udp_len = usize::from(u16::from_be_bytes([header[4], header[5]]));
    if udp_len < UDP_HEADER_LEN {
        return None;
    }
    // `get` fails when the declared length runs past the captured bytes.
    let payload = udp.get(UDP_HEADER_LEN..udp_len)?;

    Some(UdpInfo {
        src_port: u16::from_be_bytes([header[0], header[1]]),
        dst_port: u16::from_be_bytes([header[2], header[3]]),
        payload,
    })
}