use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use std::io::{self, ErrorKind};
/// Sends `query` over `transport` behind a two-byte big-endian length prefix, then reads back
/// a single length-prefixed response and returns its payload (without the prefix).
///
/// After the framed query is written, the write side is finished with `close` when
/// `finish_send` is true; otherwise it is only flushed.
///
/// # Errors
///
/// Returns an `ErrorKind::InvalidInput` error, before anything is written, when `query` is
/// longer than 65535 bytes (its length would not fit the `u16` prefix). Any I/O error from
/// writing, closing/flushing, or reading the response is propagated unchanged.
pub(super) async fn length_prefixed_exchange<T: AsyncRead + AsyncWrite + Unpin>(
    transport: &mut T,
    query: &[u8],
    finish_send: bool,
) -> io::Result<Vec<u8>> {
    let len = match u16::try_from(query.len()) {
        Ok(n) => n,
        Err(_) => {
            return Err(io::Error::new(
                ErrorKind::InvalidInput,
                "DNS query exceeds 65535 bytes",
            ))
        }
    };

    // Prefix and payload go out in a single write.
    let outgoing: Vec<u8> = len.to_be_bytes().iter().chain(query).copied().collect();
    transport.write_all(&outgoing).await?;

    let send_side = if finish_send {
        transport.close().await
    } else {
        transport.flush().await
    };
    send_side?;
    log::trace!("length-prefixed exchange: wrote {len}-byte query, awaiting response length");

    let mut header = [0u8; 2];
    transport.read_exact(&mut header).await?;
    let response_len = usize::from(u16::from_be_bytes(header));
    log::trace!("length-prefixed exchange: reading {response_len}-byte response");

    let mut payload = vec![0u8; response_len];
    transport.read_exact(&mut payload).await?;
    Ok(payload)
}
/// Returns `message` preceded by its length as a two-byte big-endian prefix.
///
/// # Errors
///
/// Returns an `ErrorKind::InvalidInput` error when `message` is longer than 65535 bytes,
/// because such a length cannot be encoded in the `u16` prefix.
pub(super) fn frame(message: &[u8]) -> io::Result<Vec<u8>> {
    let prefix = match u16::try_from(message.len()) {
        Ok(n) => n.to_be_bytes(),
        Err(_) => {
            return Err(io::Error::new(
                ErrorKind::InvalidInput,
                "DNS message exceeds 65535 bytes",
            ))
        }
    };
    Ok(prefix.iter().chain(message).copied().collect())
}
/// Removes the first complete length-prefixed message from the front of `buf` and returns its
/// payload.
///
/// A frame is a two-byte big-endian length followed by that many payload bytes.
///
/// Returns `None` and leaves `buf` unchanged when `buf` does not yet hold a whole frame:
/// either fewer than two bytes are buffered, or fewer payload bytes than advertised. More
/// bytes can then be appended and the call retried. On success the entire frame, prefix
/// included, is removed from `buf`; any bytes after it remain at the front.
pub(super) fn take_frame(buf: &mut Vec<u8>) -> Option<Vec<u8>> {
    let payload_len = match buf.as_slice() {
        [hi, lo, ..] => usize::from(u16::from_be_bytes([*hi, *lo])),
        _ => return None,
    };
    let frame_end = 2 + payload_len;
    if frame_end > buf.len() {
        return None;
    }
    // Drain the whole frame so the remainder shifts to the front, keeping only the bytes
    // after the two-byte prefix.
    Some(buf.drain(..frame_end).skip(2).collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::codec::{build_query, parse_response};
    use futures_lite::future;
    use hickory_proto::{
        op::{Message, MessageType, OpCode},
        rr::{Name, RData, Record, RecordType, rdata::A},
    };
    use std::net::{IpAddr, Ipv4Addr};
    use trillium_testing::{TestTransport, harness, test};

    /// Builds a serialized DNS response with a single `A` record for `example.com.` that
    /// points at `ip`.
    fn a_response(ip: Ipv4Addr) -> Vec<u8> {
        let mut message = Message::new(0, MessageType::Response, OpCode::Query);
        message.add_answer(Record::from_rdata(
            Name::from_utf8("example.com.").unwrap(),
            60,
            RData::A(A(ip)),
        ));
        message.to_vec().unwrap()
    }

    #[test(harness)]
    async fn length_prefixed_exchange_round_trips() {
        let (mut client, mut server) = TestTransport::new();
        let query = build_query("example.com", 443, RecordType::A).unwrap();
        let ip = Ipv4Addr::new(192, 0, 2, 1);
        let response = a_response(ip);

        // Server side: read one length-prefixed query, check it matches what the client sent,
        // then answer with one length-prefixed response.
        let responder = {
            let query = query.clone();
            async move {
                let mut len_buf = [0u8; 2];
                server.read_exact(&mut len_buf).await.unwrap();
                let mut received = vec![0u8; usize::from(u16::from_be_bytes(len_buf))];
                server.read_exact(&mut received).await.unwrap();
                assert_eq!(received, query);

                let mut framed = u16::try_from(response.len())
                    .unwrap()
                    .to_be_bytes()
                    .to_vec();
                framed.extend_from_slice(&response);
                // `write_all` returns a lazy future; it must be awaited or nothing is written
                // and the client never receives its response.
                server.write_all(&framed).await.unwrap();
            }
        };

        let (_, result) = future::zip(
            responder,
            length_prefixed_exchange(&mut client, &query, false),
        )
        .await;
        let (resolved, _) = parse_response(&result.unwrap()).unwrap();
        assert_eq!(resolved.addrs, vec![IpAddr::V4(ip)]);
    }

    #[test(harness)]
    async fn length_prefixed_exchange_rejects_oversized_query() {
        let (mut client, _server) = TestTransport::new();
        // One byte more than a u16 length prefix can describe.
        let oversized = vec![0u8; usize::from(u16::MAX) + 1];
        let err = length_prefixed_exchange(&mut client, &oversized, false)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn frame_and_take_frame_round_trip() {
        let mut buf = frame(b"hello").unwrap();
        buf.extend(frame(b"world").unwrap());
        // A partial frame: advertises 3 payload bytes but only 1 is present so far.
        buf.extend_from_slice(&[0, 3, b'a']);
        assert_eq!(take_frame(&mut buf).unwrap(), b"hello");
        assert_eq!(take_frame(&mut buf).unwrap(), b"world");
        assert!(take_frame(&mut buf).is_none());
        buf.extend_from_slice(b"bc");
        assert_eq!(take_frame(&mut buf).unwrap(), b"abc");
        assert!(buf.is_empty());
        assert!(take_frame(&mut buf).is_none());
    }

    #[test]
    fn frame_rejects_oversized_message() {
        let oversized = vec![0u8; usize::from(u16::MAX) + 1];
        assert_eq!(
            frame(&oversized).unwrap_err().kind(),
            ErrorKind::InvalidInput
        );
    }
}