1use std::net::SocketAddr;
2
3use adns_proto::{Header, Packet, PacketParseError, Question};
4use rand::{thread_rng, Rng};
5use thiserror::Error;
6use tokio::{
7 io::{AsyncReadExt, AsyncWriteExt},
8 net::{TcpStream, ToSocketAddrs, UdpSocket},
9};
10
11pub struct DnsClient {
12 udp: UdpSocket,
13}
14
15#[derive(Error, Debug)]
16pub enum DnsQueryError {
17 #[error("packet ID mismatch")]
18 IDMismatch,
19 #[error("packet too large >64KB")]
20 PacketTooLarge,
21 #[error("{0}")]
22 IoError(#[from] std::io::Error),
23 #[error("dns parse error {0}")]
24 PacketParseError(#[from] PacketParseError),
25}
26
27impl DnsClient {
28 pub async fn new() -> Result<Self, DnsQueryError> {
29 Ok(Self {
30 udp: UdpSocket::bind("[::]:0".parse::<SocketAddr>().unwrap()).await?,
31 })
32 }
33
34 pub async fn query(
35 &mut self,
36 servers: impl ToSocketAddrs,
37 questions: Vec<Question>,
38 ) -> Result<Packet, DnsQueryError> {
39 let id: u16 = thread_rng().gen();
40 let packet = Packet {
41 header: Header {
42 id,
43 recursion_desired: true,
44 recursion_available: true,
45 ..Default::default()
46 },
47 questions,
48 ..Default::default()
49 };
50 let serialized = packet.serialize(usize::MAX);
51 if serialized.len() > 512 {
52 self.query_tcp(&servers, id, &serialized).await
53 } else {
54 self.udp.send_to(&serialized, &servers).await?;
55 let mut response = [0u8; 512];
56 let mut size;
57 loop {
58 size = self.udp.recv(&mut response).await?;
59 if size < 2 || u16::from_be_bytes(response[..2].try_into().unwrap()) != id {
60 continue;
61 }
62 break;
63 }
64 match Packet::parse(&response[..size]) {
65 Ok(packet) => Ok(packet.0),
66 Err(PacketParseError::Truncated) => self.query_tcp(&servers, id, &serialized).await,
67 Err(e) => Err(e.into()),
68 }
69 }
70 }
71
72 async fn query_tcp(
73 &mut self,
74 servers: impl ToSocketAddrs,
75 id: u16,
76 packet: &[u8],
77 ) -> Result<Packet, DnsQueryError> {
78 let mut client = TcpStream::connect(servers).await?;
79 client
80 .write_u16(
81 packet
82 .len()
83 .try_into()
84 .map_err(|_| DnsQueryError::PacketTooLarge)?,
85 )
86 .await?;
87 client.write_all(packet).await?;
88 let len = client.read_u16().await?;
89 let mut response = vec![0u8; len as usize];
90 client.read_exact(&mut response).await?;
91
92 let packet = Packet::parse(&response)?.0;
93 if packet.header.id != id {
94 return Err(DnsQueryError::IDMismatch);
95 }
96 Ok(packet)
97 }
98}
99
#[cfg(test)]
mod tests {
    use adns_proto::Type;

    use super::*;

    /// End-to-end A-record query against 8.8.8.8; requires network access.
    #[tokio::test]
    async fn test_query() {
        let mut client = DnsClient::new().await.unwrap();
        let response = client
            .query(
                "8.8.8.8:53",
                vec![Question::new(Type::A, "google.com").unwrap()],
            )
            .await
            .unwrap();
        assert!(
            !response.answers.is_empty(),
            "expected at least one answer for google.com A"
        );
        for answer in &response.answers {
            println!("{answer}");
        }
    }
}