use dnssector::DNS_MAX_COMPRESSED_SIZE;
use smol::{
future::FutureExt,
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpStream, UdpSocket},
Timer,
};
use std::{
io::{self, ErrorKind},
net::SocketAddr,
time::Duration,
};
/// Performs a single raw UDP exchange with `upstream`.
///
/// A socket is bound to `local` and connected to `upstream`. `packet` is sent as one
/// datagram, and the first datagram received back is returned. It is read into a
/// buffer of `DNS_MAX_COMPRESSED_SIZE` bytes and shrunk to the received length.
/// The response contents are not validated here.
///
/// # Errors
///
/// Propagates any bind/connect/send/recv error. If the whole exchange has not
/// finished before `timeout` elapses, an error of kind `ErrorKind::TimedOut` is
/// returned instead.
pub async fn query_raw_udp(
    local: &SocketAddr,
    upstream: &SocketAddr,
    packet: &[u8],
    timeout: Duration,
) -> io::Result<Vec<u8>> {
    let exchange = async {
        let sock = UdpSocket::bind(local).await?;
        sock.connect(upstream).await?;
        sock.send(packet).await?;

        let mut buf = vec![0; DNS_MAX_COMPRESSED_SIZE];
        let received = sock.recv(&mut buf).await?;
        // Drop the unused tail of the receive buffer.
        buf.truncate(received);
        Ok::<_, io::Error>(buf)
    };

    // Race the exchange against the timer; whichever finishes first wins.
    exchange.or(with_timeout(timeout)).await
}
/// Performs a single raw TCP exchange with `upstream` over a fresh connection.
///
/// `packet` is written to the stream verbatim; no length prefix is added here.
/// DNS over TCP frames each message with a 2-byte big-endian length. The caller is
/// therefore presumably expected to include that prefix in `packet`
/// (NOTE(review): confirm against callers).
///
/// The response is read as a 2-byte big-endian length followed by that many bytes.
/// Only the message bytes are returned; the length prefix is stripped.
///
/// # Errors
///
/// Propagates any connect/write/read error, including `UnexpectedEof` if the peer
/// closes the connection before the announced number of bytes arrives. If the whole
/// exchange has not finished before `timeout` elapses, an error of kind
/// `ErrorKind::TimedOut` is returned instead.
pub async fn query_raw_tcp(
    upstream: &SocketAddr,
    packet: &[u8],
    timeout: Duration,
) -> io::Result<Vec<u8>> {
    async {
        let mut stream = TcpStream::connect(upstream).await?;
        // Disabling Nagle is best-effort; a failure here is deliberately ignored.
        let _ = stream.set_nodelay(true);
        stream.write_all(packet).await?;

        // Response framing: 2-byte big-endian length, then the message itself.
        let mut len_prefix = [0u8; 2];
        stream.read_exact(&mut len_prefix).await?;
        let response_len = usize::from(u16::from_be_bytes(len_prefix));

        let mut response = vec![0; response_len];
        stream.read_exact(&mut response).await?;
        Ok(response)
    }
    .or(with_timeout(timeout))
    .await
}
/// Resolves to an `ErrorKind::TimedOut` error once `timeout` has elapsed.
///
/// It never produces `Ok`. It is meant to be raced (via `FutureExt::or`) against an
/// I/O exchange, so that the exchange is abandoned when the deadline passes first.
async fn with_timeout(timeout: Duration) -> io::Result<Vec<u8>> {
    let _ = Timer::after(timeout).await;
    Err(io::Error::from(ErrorKind::TimedOut))
}