adns_client/
lib.rs

1use std::net::SocketAddr;
2
3use adns_proto::{Header, Packet, PacketParseError, Question};
4use rand::{thread_rng, Rng};
5use thiserror::Error;
6use tokio::{
7    io::{AsyncReadExt, AsyncWriteExt},
8    net::{TcpStream, ToSocketAddrs, UdpSocket},
9};
10
11pub struct DnsClient {
12    udp: UdpSocket,
13}
14
15#[derive(Error, Debug)]
16pub enum DnsQueryError {
17    #[error("packet ID mismatch")]
18    IDMismatch,
19    #[error("packet too large >64KB")]
20    PacketTooLarge,
21    #[error("{0}")]
22    IoError(#[from] std::io::Error),
23    #[error("dns parse error {0}")]
24    PacketParseError(#[from] PacketParseError),
25}
26
27impl DnsClient {
28    pub async fn new() -> Result<Self, DnsQueryError> {
29        Ok(Self {
30            udp: UdpSocket::bind("[::]:0".parse::<SocketAddr>().unwrap()).await?,
31        })
32    }
33
34    pub async fn query(
35        &mut self,
36        servers: impl ToSocketAddrs,
37        questions: Vec<Question>,
38    ) -> Result<Packet, DnsQueryError> {
39        let id: u16 = thread_rng().gen();
40        let packet = Packet {
41            header: Header {
42                id,
43                recursion_desired: true,
44                recursion_available: true,
45                ..Default::default()
46            },
47            questions,
48            ..Default::default()
49        };
50        let serialized = packet.serialize(usize::MAX);
51        if serialized.len() > 512 {
52            self.query_tcp(&servers, id, &serialized).await
53        } else {
54            self.udp.send_to(&serialized, &servers).await?;
55            let mut response = [0u8; 512];
56            let mut size;
57            loop {
58                size = self.udp.recv(&mut response).await?;
59                if size < 2 || u16::from_be_bytes(response[..2].try_into().unwrap()) != id {
60                    continue;
61                }
62                break;
63            }
64            match Packet::parse(&response[..size]) {
65                Ok(packet) => Ok(packet.0),
66                Err(PacketParseError::Truncated) => self.query_tcp(&servers, id, &serialized).await,
67                Err(e) => Err(e.into()),
68            }
69        }
70    }
71
72    async fn query_tcp(
73        &mut self,
74        servers: impl ToSocketAddrs,
75        id: u16,
76        packet: &[u8],
77    ) -> Result<Packet, DnsQueryError> {
78        let mut client = TcpStream::connect(servers).await?;
79        client
80            .write_u16(
81                packet
82                    .len()
83                    .try_into()
84                    .map_err(|_| DnsQueryError::PacketTooLarge)?,
85            )
86            .await?;
87        client.write_all(packet).await?;
88        let len = client.read_u16().await?;
89        let mut response = vec![0u8; len as usize];
90        client.read_exact(&mut response).await?;
91
92        let packet = Packet::parse(&response)?.0;
93        if packet.header.id != id {
94            return Err(DnsQueryError::IDMismatch);
95        }
96        Ok(packet)
97    }
98}
99
#[cfg(test)]
mod tests {
    use adns_proto::Type;

    use super::*;

    /// Live smoke test: asks the server at 8.8.8.8:53 for the A records of `google.com` and
    /// prints every answer. Needs outbound network access.
    #[tokio::test]
    async fn test_query() {
        let mut dns = DnsClient::new().await.unwrap();
        let questions = vec![Question::new(Type::A, "google.com").unwrap()];
        let reply = dns.query("8.8.8.8:53", questions).await.unwrap();
        for record in &reply.answers {
            println!("{record}");
        }
    }
}
120}