Skip to main content

agnostic_dns/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3#![deny(warnings, missing_docs)]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![cfg_attr(docsrs, allow(unused_attributes))]
6
7use agnostic_net::{Net, UdpSocket, runtime::RuntimeLite};
8use futures_util::future::FutureExt;
9use std::{future::Future, io, marker::PhantomData, net::SocketAddr, pin::Pin, time::Duration};
10
11use hickory_proto::runtime::{RuntimeProvider, Spawn, Time};
12pub use hickory_proto::xfer::Protocol;
13pub use hickory_resolver::config::*;
14use hickory_resolver::{
15  Resolver,
16  name_server::{ConnectionProvider, GenericConnector},
17};
18
19pub use agnostic_net as net;
20
21#[cfg(test)]
22mod tests;
23
24/// Agnostic async DNS resolver
25pub type Dns<N> = Resolver<AsyncConnectionProvider<N>>;
26
/// Async spawner.
///
/// A zero-sized handle that spawns background tasks onto the runtime of `N`.
/// It stores nothing but a type marker, so it is freely `Copy`.
#[repr(transparent)]
pub struct AsyncSpawn<N> {
  _marker: PhantomData<N>,
}

// `Debug` and `Default` are written by hand instead of derived: a derive would
// require `N: Debug` / `N: Default` even though only a `PhantomData<N>` is
// stored, needlessly restricting which `N` these impls apply to.
impl<N> std::fmt::Debug for AsyncSpawn<N> {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("AsyncSpawn")
      .field("_marker", &self._marker)
      .finish()
  }
}

impl<N> Default for AsyncSpawn<N> {
  fn default() -> Self {
    Self {
      _marker: PhantomData,
    }
  }
}

impl<N> Clone for AsyncSpawn<N> {
  fn clone(&self) -> Self {
    *self
  }
}

impl<N> Copy for AsyncSpawn<N> {}
41
42impl<N: Net> Spawn for AsyncSpawn<N> {
43  fn spawn_bg<F>(&mut self, future: F)
44  where
45    F: Future<Output = Result<(), hickory_proto::ProtoError>> + Send + 'static,
46  {
47    <N::Runtime as RuntimeLite>::spawn_detach(future);
48  }
49}
50
51/// Defines which async runtime that handles IO and timers.
52pub struct AsyncRuntimeProvider<N> {
53  runtime: AsyncSpawn<N>,
54}
55
56impl<N> Default for AsyncRuntimeProvider<N> {
57  fn default() -> Self {
58    Self::new()
59  }
60}
61
62impl<N> AsyncRuntimeProvider<N> {
63  /// Create a new `AsyncRuntimeProvider`.
64  pub fn new() -> Self {
65    Self {
66      runtime: AsyncSpawn {
67        _marker: PhantomData,
68      },
69    }
70  }
71}
72
73impl<N> Clone for AsyncRuntimeProvider<N> {
74  fn clone(&self) -> Self {
75    *self
76  }
77}
78
79impl<N> Copy for AsyncRuntimeProvider<N> {}
80
81/// Timer implementation for the dns.
82pub struct Timer<N>(PhantomData<N>);
83
84#[async_trait::async_trait]
85impl<N: Net> Time for Timer<N> {
86  async fn delay_for(duration: Duration) {
87    let _ = <N::Runtime as RuntimeLite>::sleep(duration).await;
88  }
89
90  async fn timeout<F: 'static + Future + Send>(
91    duration: Duration,
92    future: F,
93  ) -> Result<F::Output, std::io::Error> {
94    <N::Runtime as RuntimeLite>::timeout(duration, future)
95      .await
96      .map_err(Into::into)
97  }
98}
99
100/// DNS time
101#[derive(Clone, Copy, Debug)]
102pub struct AgnosticTime<N>(PhantomData<N>);
103
104#[async_trait::async_trait]
105impl<N> Time for AgnosticTime<N>
106where
107  N: Net,
108{
109  async fn delay_for(duration: Duration) {
110    <N::Runtime as RuntimeLite>::sleep(duration).await;
111  }
112
113  async fn timeout<F: 'static + Future + Send>(
114    duration: Duration,
115    future: F,
116  ) -> Result<F::Output, std::io::Error> {
117    <N::Runtime as RuntimeLite>::timeout(duration, future)
118      .await
119      .map_err(Into::into)
120  }
121}
122
123/// DNS tcp
124#[doc(hidden)]
125pub struct AsyncDnsTcp<N: Net>(N::TcpStream);
126
127impl<N: Net> hickory_proto::tcp::DnsTcpStream for AsyncDnsTcp<N> {
128  type Time = AgnosticTime<N>;
129}
130
131impl<N: Net> AsyncDnsTcp<N> {
132  async fn connect(addr: SocketAddr) -> std::io::Result<Self> {
133    <N::TcpStream as agnostic_net::TcpStream>::connect(addr)
134      .await
135      .map(Self)
136  }
137}
138
139impl<N: Net> futures_util::AsyncRead for AsyncDnsTcp<N> {
140  fn poll_read(
141    mut self: std::pin::Pin<&mut Self>,
142    cx: &mut std::task::Context<'_>,
143    buf: &mut [u8],
144  ) -> std::task::Poll<io::Result<usize>> {
145    futures_util::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
146  }
147}
148
149impl<N: Net> futures_util::AsyncWrite for AsyncDnsTcp<N> {
150  fn poll_write(
151    mut self: Pin<&mut Self>,
152    cx: &mut std::task::Context<'_>,
153    buf: &[u8],
154  ) -> std::task::Poll<io::Result<usize>> {
155    futures_util::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
156  }
157
158  fn poll_flush(
159    mut self: Pin<&mut Self>,
160    cx: &mut std::task::Context<'_>,
161  ) -> std::task::Poll<io::Result<()>> {
162    futures_util::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
163  }
164
165  fn poll_close(
166    mut self: Pin<&mut Self>,
167    cx: &mut std::task::Context<'_>,
168  ) -> std::task::Poll<io::Result<()>> {
169    futures_util::AsyncWrite::poll_close(Pin::new(&mut self.0), cx)
170  }
171}
172
173/// DNS udp
174pub struct AsyncDnsUdp<N: Net>(N::UdpSocket);
175
176impl<N: Net> AsyncDnsUdp<N> {
177  async fn bind(addr: SocketAddr) -> std::io::Result<Self> {
178    <N::UdpSocket as UdpSocket>::bind(addr).await.map(Self)
179  }
180}
181
182impl<N: Net> hickory_proto::udp::DnsUdpSocket for AsyncDnsUdp<N> {
183  type Time = AgnosticTime<N>;
184
185  fn poll_recv_from(
186    &self,
187    cx: &mut std::task::Context<'_>,
188    buf: &mut [u8],
189  ) -> std::task::Poll<io::Result<(usize, SocketAddr)>> {
190    self.0.poll_recv_from(cx, buf)
191  }
192
193  fn poll_send_to(
194    &self,
195    cx: &mut std::task::Context<'_>,
196    buf: &[u8],
197    target: SocketAddr,
198  ) -> std::task::Poll<io::Result<usize>> {
199    self.0.poll_send_to(cx, buf, target)
200  }
201}
202
203impl<N: Net> RuntimeProvider for AsyncRuntimeProvider<N> {
204  type Handle = AsyncSpawn<N>;
205
206  type Timer = Timer<N>;
207
208  type Udp = AsyncDnsUdp<N>;
209
210  type Tcp = AsyncDnsTcp<N>;
211
212  fn create_handle(&self) -> Self::Handle {
213    self.runtime
214  }
215
216  fn connect_tcp(
217    &self,
218    server_addr: SocketAddr,
219    _bind_addr: Option<SocketAddr>,
220    timeout: Option<Duration>,
221  ) -> std::pin::Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
222    async move {
223      let connect_future = AsyncDnsTcp::connect(server_addr);
224      match timeout {
225        Some(duration) => <N::Runtime as RuntimeLite>::timeout(duration, connect_future)
226          .await
227          .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?,
228        None => connect_future.await,
229      }
230    }
231    .boxed()
232  }
233
234  fn bind_udp(
235    &self,
236    local_addr: SocketAddr,
237    _server_addr: SocketAddr,
238  ) -> std::pin::Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
239    AsyncDnsUdp::bind(local_addr).boxed()
240  }
241}
242
243/// Create `DnsHandle` with the help of `AsyncRuntimeProvider`.
244pub struct AsyncConnectionProvider<N: Net> {
245  runtime_provider: AsyncRuntimeProvider<N>,
246  connection_provider: GenericConnector<AsyncRuntimeProvider<N>>,
247}
248
249impl<N: Net> Default for AsyncConnectionProvider<N> {
250  fn default() -> Self {
251    Self::new()
252  }
253}
254
255impl<N: Net> AsyncConnectionProvider<N> {
256  /// Create a new `AsyncConnectionProvider`.
257  pub fn new() -> Self {
258    Self {
259      runtime_provider: AsyncRuntimeProvider::new(),
260      connection_provider: GenericConnector::new(AsyncRuntimeProvider::new()),
261    }
262  }
263}
264
265impl<N: Net> Clone for AsyncConnectionProvider<N> {
266  fn clone(&self) -> Self {
267    Self {
268      runtime_provider: self.runtime_provider,
269      connection_provider: self.connection_provider.clone(),
270    }
271  }
272}
273
274impl<N: Net> ConnectionProvider for AsyncConnectionProvider<N> {
275  type Conn = <GenericConnector<AsyncRuntimeProvider<N>> as ConnectionProvider>::Conn;
276  type FutureConn = <GenericConnector<AsyncRuntimeProvider<N>> as ConnectionProvider>::FutureConn;
277  type RuntimeProvider = AsyncRuntimeProvider<N>;
278
279  fn new_connection(
280    &self,
281    config: &hickory_resolver::config::NameServerConfig,
282    options: &hickory_resolver::config::ResolverOpts,
283  ) -> Result<Self::FutureConn, io::Error> {
284    self.connection_provider.new_connection(config, options)
285  }
286}
287
288#[cfg(unix)]
289pub use dns_util::read_resolv_conf;
290
291#[cfg(unix)]
292pub use hickory_resolver::system_conf::parse_resolv_conf;
293pub use hickory_resolver::system_conf::read_system_conf;
294
295#[cfg(unix)]
296mod dns_util {
297  use std::{io, path::Path};
298
299  use hickory_resolver::config::{ResolverConfig, ResolverOpts};
300
301  /// Read the DNS configuration from a file.
302  pub fn read_resolv_conf<P: AsRef<Path>>(path: P) -> io::Result<(ResolverConfig, ResolverOpts)> {
303    std::fs::read_to_string(path).and_then(|conf| {
304      hickory_resolver::system_conf::parse_resolv_conf(conf)
305        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
306    })
307  }
308}