use crate::SessionParser;
use super::datagram::DnsMessage;
use super::parser::{DnsParseResult, parse_message};
/// Incremental DNS-over-TCP parser for one TCP session.
///
/// The two directions of the connection are buffered separately, because TCP
/// may split a length-prefixed DNS frame across segments or coalesce several
/// frames into one segment. Bytes stay in a buffer until they form a complete
/// frame.
#[derive(Clone, Debug, Default)]
pub struct DnsTcpParser {
    /// Pending bytes fed through the initiator side that do not yet form a
    /// complete frame.
    init_buf: Vec<u8>,
    /// Pending bytes fed through the responder side that do not yet form a
    /// complete frame.
    resp_buf: Vec<u8>,
}
impl DnsTcpParser {
    /// Extracts every complete length-prefixed DNS frame from the front of `buf`.
    ///
    /// DNS over TCP prefixes each message with a two-byte big-endian length that
    /// excludes the prefix itself. Frames are parsed in wire order. Queries and
    /// responses are returned in that same order.
    ///
    /// A frame whose body fails to parse is skipped. Its length prefix is still
    /// used to step over it, so later frames in the stream stay aligned.
    ///
    /// On return, `buf` holds only the trailing bytes of an incomplete frame, if
    /// any. These are kept for the next call.
    fn drain(buf: &mut Vec<u8>) -> Vec<DnsMessage> {
        let mut out = Vec::new();
        // Offset of the first byte not yet consumed as part of a complete frame.
        let mut consumed = 0usize;

        loop {
            let rest = &buf[consumed..];
            if rest.len() < 2 {
                break;
            }
            let len = usize::from(u16::from_be_bytes([rest[0], rest[1]]));
            let frame_total = 2 + len;
            if rest.len() < frame_total {
                // Body not fully received yet; wait for more bytes.
                break;
            }

            match parse_message(&rest[2..frame_total]) {
                Ok(DnsParseResult::Query(q)) => out.push(DnsMessage::Query(q)),
                Ok(DnsParseResult::Response(r)) => out.push(DnsMessage::Response(r)),
                // Deliberate best-effort: drop the malformed message but keep
                // framing, since the length prefix alone tells us where the next
                // frame starts.
                Err(_) => {}
            }
            consumed += frame_total;
        }

        // Compact once. Draining after every frame would shift the remaining
        // bytes repeatedly, which is quadratic for many pipelined frames.
        buf.drain(..consumed);
        out
    }
}
impl SessionParser for DnsTcpParser {
    type Message = DnsMessage;

    /// Appends bytes from the initiator side and returns every DNS message
    /// that those bytes completed, in wire order.
    fn feed_initiator(&mut self, bytes: &[u8]) -> Vec<DnsMessage> {
        match bytes {
            // Empty input short-circuits without touching the buffer.
            [] => Vec::new(),
            chunk => {
                self.init_buf.extend_from_slice(chunk);
                Self::drain(&mut self.init_buf)
            }
        }
    }

    /// Appends bytes from the responder side and returns every DNS message
    /// that those bytes completed, in wire order.
    fn feed_responder(&mut self, bytes: &[u8]) -> Vec<DnsMessage> {
        match bytes {
            // Empty input short-circuits without touching the buffer.
            [] => Vec::new(),
            chunk => {
                self.resp_buf.extend_from_slice(chunk);
                Self::drain(&mut self.resp_buf)
            }
        }
    }

    /// Discards any buffered, not-yet-framed initiator bytes.
    fn rst_initiator(&mut self) {
        self.init_buf.truncate(0);
    }

    /// Discards any buffered, not-yet-framed responder bytes.
    fn rst_responder(&mut self) {
        self.resp_buf.truncate(0);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one length-prefixed DNS-over-TCP frame containing a single
    /// `<qname> IN A` query with the recursion-desired flag set.
    fn build_a_query_tcp(tx_id: u16, qname: &str) -> Vec<u8> {
        let mut body = Vec::new();
        body.extend_from_slice(&tx_id.to_be_bytes());
        body.extend_from_slice(&0x0100u16.to_be_bytes()); // flags: QR=0, opcode 0, RD=1
        body.extend_from_slice(&1u16.to_be_bytes()); // QDCOUNT
        body.extend_from_slice(&0u16.to_be_bytes()); // ANCOUNT
        body.extend_from_slice(&0u16.to_be_bytes()); // NSCOUNT
        body.extend_from_slice(&0u16.to_be_bytes()); // ARCOUNT
        for label in qname.split('.') {
            let label_len = u8::try_from(label.len()).expect("test label longer than 255 bytes");
            body.push(label_len);
            body.extend_from_slice(label.as_bytes());
        }
        body.push(0); // root label terminates QNAME
        body.extend_from_slice(&1u16.to_be_bytes()); // QTYPE = A
        body.extend_from_slice(&1u16.to_be_bytes()); // QCLASS = IN

        let body_len = u16::try_from(body.len()).expect("test body exceeds u16 length prefix");
        let mut frame = Vec::with_capacity(2 + body.len());
        frame.extend_from_slice(&body_len.to_be_bytes());
        frame.extend_from_slice(&body);
        frame
    }

    #[test]
    fn parses_one_query() {
        let mut p = DnsTcpParser::default();
        let bytes = build_a_query_tcp(0x1234, "example.com");
        let msgs = p.feed_initiator(&bytes);
        assert_eq!(msgs.len(), 1);
        match &msgs[0] {
            DnsMessage::Query(q) => assert_eq!(q.transaction_id, 0x1234),
            _ => panic!("expected Query"),
        }
    }

    #[test]
    fn parses_multiple_pipelined_queries() {
        let mut p = DnsTcpParser::default();
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&build_a_query_tcp(1, "a.example"));
        bytes.extend_from_slice(&build_a_query_tcp(2, "b.example"));
        bytes.extend_from_slice(&build_a_query_tcp(3, "c.example"));
        let msgs = p.feed_initiator(&bytes);
        assert_eq!(msgs.len(), 3);
    }

    #[test]
    fn pipelined_queries_keep_wire_order() {
        let mut p = DnsTcpParser::default();
        let bytes = [
            build_a_query_tcp(1, "a.example"),
            build_a_query_tcp(2, "b.example"),
            build_a_query_tcp(3, "c.example"),
        ]
        .concat();
        let msgs = p.feed_initiator(&bytes);
        assert_eq!(msgs.len(), 3);
        for (msg, expected) in msgs.iter().zip([1, 2, 3]) {
            match msg {
                DnsMessage::Query(q) => assert_eq!(q.transaction_id, expected),
                _ => panic!("expected Query"),
            }
        }
        assert!(p.init_buf.is_empty());
    }

    #[test]
    fn partial_trailing_frame_is_kept_for_next_feed() {
        let mut p = DnsTcpParser::default();
        let first = build_a_query_tcp(10, "first.example");
        let second = build_a_query_tcp(11, "second.example");
        let bytes = [first.as_slice(), second.as_slice()].concat();
        let cut = first.len() + second.len() / 2;

        let msgs = p.feed_initiator(&bytes[..cut]);
        assert_eq!(msgs.len(), 1);
        // Only the incomplete second frame remains buffered.
        assert_eq!(p.init_buf, &bytes[first.len()..cut]);

        let msgs = p.feed_initiator(&bytes[cut..]);
        assert_eq!(msgs.len(), 1);
        match &msgs[0] {
            DnsMessage::Query(q) => assert_eq!(q.transaction_id, 11),
            _ => panic!("expected Query"),
        }
        assert!(p.init_buf.is_empty());
    }

    #[test]
    fn split_segments_concatenate() {
        let mut p = DnsTcpParser::default();
        let bytes = build_a_query_tcp(42, "split.example");
        let mut all = Vec::new();
        for chunk in bytes.chunks(1) {
            all.extend(p.feed_initiator(chunk));
        }
        assert_eq!(all.len(), 1);
        match &all[0] {
            DnsMessage::Query(q) => assert_eq!(q.transaction_id, 42),
            _ => panic!("expected Query"),
        }
    }

    #[test]
    fn split_at_length_prefix() {
        let mut p = DnsTcpParser::default();
        let bytes = build_a_query_tcp(7, "prefix.split");
        assert!(p.feed_initiator(&bytes[..1]).is_empty());
        let msgs = p.feed_initiator(&bytes[1..]);
        assert_eq!(msgs.len(), 1);
    }

    #[test]
    fn malformed_body_consumes_frame_and_keeps_framing() {
        let mut p = DnsTcpParser::default();
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&12u16.to_be_bytes());
        bytes.extend_from_slice(&[0xff; 12]);
        bytes.extend_from_slice(&build_a_query_tcp(99, "valid.after"));
        let msgs = p.feed_initiator(&bytes);
        assert_eq!(msgs.len(), 1);
        match &msgs[0] {
            DnsMessage::Query(q) => assert_eq!(q.transaction_id, 99),
            _ => panic!("expected Query"),
        }
    }

    #[test]
    fn responder_direction_is_buffered_independently() {
        let mut p = DnsTcpParser::default();
        let bytes = build_a_query_tcp(5, "resp.example");
        // Leave a partial frame pending on the initiator side only.
        assert!(p.feed_initiator(&bytes[..3]).is_empty());
        let msgs = p.feed_responder(&bytes);
        assert_eq!(msgs.len(), 1);
        assert_eq!(p.init_buf, &bytes[..3]);
        assert!(p.resp_buf.is_empty());
    }

    #[test]
    fn rst_clears_buffer() {
        let mut p = DnsTcpParser::default();
        let bytes = build_a_query_tcp(1, "partial.example");
        assert!(p.feed_initiator(&bytes[..bytes.len() / 2]).is_empty());
        p.rst_initiator();
        assert!(p.init_buf.is_empty());
    }

    #[test]
    fn rst_responder_leaves_initiator_untouched() {
        let mut p = DnsTcpParser::default();
        let bytes = build_a_query_tcp(9, "rst.example");
        assert!(p.feed_initiator(&bytes[..4]).is_empty());
        assert!(p.feed_responder(&bytes[..4]).is_empty());
        p.rst_responder();
        assert!(p.resp_buf.is_empty());
        assert_eq!(p.init_buf, &bytes[..4]);
    }

    #[test]
    fn empty_feed_returns_empty() {
        let mut p = DnsTcpParser::default();
        assert!(p.feed_initiator(&[]).is_empty());
        assert!(p.feed_responder(&[]).is_empty());
    }

    #[test]
    fn auto_factory_via_default_clone() {
        use crate::SessionParserFactory;
        let mut f = DnsTcpParser::default();
        let mut p: DnsTcpParser = SessionParserFactory::<()>::new_parser(&mut f, &());
        let bytes = build_a_query_tcp(7, "auto.factory");
        let msgs = p.feed_initiator(&bytes);
        assert_eq!(msgs.len(), 1);
    }
}