use crate::err::as_io_error;
use dnssector::DNS_MAX_COMPRESSED_SIZE;
use std::{
io::{self, ErrorKind},
net::SocketAddr,
time::Duration,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpStream, UdpSocket},
};
/// Sends a raw DNS message to `upstream` over UDP and returns the first datagram received.
///
/// Each call binds a fresh socket to `local` and connects it to `upstream`, so only
/// datagrams coming from `upstream` are accepted by `recv`. The reply is read into a buffer
/// of `DNS_MAX_COMPRESSED_SIZE` bytes and the returned vector is shrunk to the number of
/// bytes actually received.
///
/// # Errors
///
/// Propagates any bind, connect, send or receive error. If the whole exchange does not
/// complete within `timeout`, the elapsed-timer error is converted by `as_io_error` with
/// `ErrorKind::TimedOut`.
pub async fn query_raw_udp(
    local: &SocketAddr,
    upstream: &SocketAddr,
    packet: &[u8],
    timeout: Duration,
) -> io::Result<Vec<u8>> {
    // Async blocks are lazy: nothing below runs until the timeout future polls it.
    let exchange = async {
        let sock = UdpSocket::bind(local).await?;
        sock.connect(upstream).await?;
        sock.send(packet).await?;

        let mut buf = vec![0; DNS_MAX_COMPRESSED_SIZE];
        let received = sock.recv(&mut buf).await?;
        buf.truncate(received);
        Ok::<_, io::Error>(buf)
    };

    let outcome = tokio::time::timeout(timeout, exchange)
        .await
        .map_err(as_io_error(ErrorKind::TimedOut))?;
    outcome
}
/// Sends a raw DNS message to `upstream` over TCP and returns the raw DNS response.
///
/// DNS over TCP frames every message, in both directions, with a two-byte big-endian
/// length prefix (RFC 1035 §4.2.2, RFC 7766 §8). This function adds that prefix to the
/// outgoing `packet` and strips it from the reply. Both the input and the returned vector
/// are therefore bare DNS messages, the same convention `query_raw_udp` uses.
///
/// # Errors
///
/// - `ErrorKind::InvalidInput` if `packet` is longer than 65535 bytes, because such a
///   message cannot be expressed with a 16-bit length prefix.
/// - Any connect, write or read error, including `UnexpectedEof` if the server closes the
///   connection before sending a complete response.
/// - If the whole exchange does not complete within `timeout`, the elapsed-timer error is
///   converted by `as_io_error` with `ErrorKind::TimedOut`.
pub async fn query_raw_tcp(
    upstream: &SocketAddr,
    packet: &[u8],
    timeout: Duration,
) -> io::Result<Vec<u8>> {
    // The TCP length prefix is a u16; a larger message cannot be framed without wrapping.
    if packet.len() > usize::from(u16::MAX) {
        return Err(io::Error::new(
            ErrorKind::InvalidInput,
            "DNS query too large for TCP transport",
        ));
    }
    // Cast cannot truncate: the length was bounds-checked just above.
    let query_len = packet.len() as u16;

    // Build prefix + message in one buffer so a single write is issued. With TCP_NODELAY
    // enabled, two separate writes could otherwise send the 2-byte prefix as its own segment.
    let mut framed = Vec::with_capacity(2 + packet.len());
    framed.extend_from_slice(&query_len.to_be_bytes());
    framed.extend_from_slice(packet);

    tokio::time::timeout(timeout, async {
        let mut stream = TcpStream::connect(upstream).await?;
        // Best effort: disabling Nagle only lowers latency; failure is not fatal.
        let _ = stream.set_nodelay(true);
        stream.write_all(&framed).await?;

        let mut response_len_bytes = [0u8; 2];
        stream.read_exact(&mut response_len_bytes).await?;
        let response_len = usize::from(u16::from_be_bytes(response_len_bytes));

        let mut response = vec![0; response_len];
        stream.read_exact(&mut response).await?;
        Ok(response)
    })
    .await
    .map_err(as_io_error(ErrorKind::TimedOut))?
}