agnostic_dns/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3#![deny(warnings, missing_docs)]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![cfg_attr(docsrs, allow(unused_attributes))]
6
7use agnostic_net::{Net, UdpSocket, runtime::RuntimeLite};
8use futures_util::future::FutureExt;
9use std::{future::Future, io, marker::PhantomData, net::SocketAddr, pin::Pin, time::Duration};
10
11use hickory_proto::Time;
12pub use hickory_resolver::config::*;
13use hickory_resolver::{
14  AsyncResolver,
15  name_server::{ConnectionProvider, GenericConnector, RuntimeProvider, Spawn},
16};
17
18pub use agnostic_net as net;
19
20#[cfg(test)]
21mod tests;
22
23/// Agnostic aysnc DNS resolver
24pub type Dns<N> = AsyncResolver<AsyncConnectionProvider<N>>;
25
/// Async spawner that hands background tasks to the runtime of `N`.
///
/// The spawner only carries type information about `N` (a `PhantomData`), so
/// `Debug`, `Default`, `Clone` and `Copy` are implemented by hand: deriving
/// them would needlessly require `N` itself to implement each trait.
#[repr(transparent)]
pub struct AsyncSpawn<N> {
  _marker: PhantomData<N>,
}

impl<N> std::fmt::Debug for AsyncSpawn<N> {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    // Matches the output `#[derive(Debug)]` would produce, without `N: Debug`.
    f.debug_struct("AsyncSpawn")
      .field("_marker", &self._marker)
      .finish()
  }
}

impl<N> Default for AsyncSpawn<N> {
  fn default() -> Self {
    Self {
      _marker: PhantomData,
    }
  }
}

impl<N> Clone for AsyncSpawn<N> {
  fn clone(&self) -> Self {
    *self
  }
}

impl<N> Copy for AsyncSpawn<N> {}
40
41impl<N: Net> Spawn for AsyncSpawn<N> {
42  fn spawn_bg<F>(&mut self, future: F)
43  where
44    F: Future<Output = Result<(), hickory_proto::error::ProtoError>> + Send + 'static,
45  {
46    <N::Runtime as RuntimeLite>::spawn_detach(future);
47  }
48}
49
50/// Defines which async runtime that handles IO and timers.
51pub struct AsyncRuntimeProvider<N> {
52  runtime: AsyncSpawn<N>,
53}
54
55impl<N> Default for AsyncRuntimeProvider<N> {
56  fn default() -> Self {
57    Self::new()
58  }
59}
60
61impl<N> AsyncRuntimeProvider<N> {
62  /// Create a new `AsyncRuntimeProvider`.
63  pub fn new() -> Self {
64    Self {
65      runtime: AsyncSpawn {
66        _marker: PhantomData,
67      },
68    }
69  }
70}
71
72impl<N> Clone for AsyncRuntimeProvider<N> {
73  fn clone(&self) -> Self {
74    *self
75  }
76}
77
78impl<N> Copy for AsyncRuntimeProvider<N> {}
79
80/// Timer implementation for the dns.
81pub struct Timer<N>(PhantomData<N>);
82
83#[async_trait::async_trait]
84impl<N: Net> Time for Timer<N> {
85  async fn delay_for(duration: Duration) {
86    let _ = <N::Runtime as RuntimeLite>::sleep(duration).await;
87  }
88
89  async fn timeout<F: 'static + Future + Send>(
90    duration: Duration,
91    future: F,
92  ) -> Result<F::Output, std::io::Error> {
93    <N::Runtime as RuntimeLite>::timeout(duration, future)
94      .await
95      .map_err(Into::into)
96  }
97}
98
/// DNS time: timer operations backed by the runtime of `N`.
///
/// Only type information about `N` is stored (a `PhantomData`), so `Clone`,
/// `Copy` and `Debug` are implemented by hand; deriving them would needlessly
/// require `N` to implement each trait.
pub struct AgnosticTime<N>(PhantomData<N>);

impl<N> Clone for AgnosticTime<N> {
  fn clone(&self) -> Self {
    *self
  }
}

impl<N> Copy for AgnosticTime<N> {}

impl<N> std::fmt::Debug for AgnosticTime<N> {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    // Matches the output `#[derive(Debug)]` would produce, without `N: Debug`.
    f.debug_tuple("AgnosticTime").field(&self.0).finish()
  }
}
102
103#[async_trait::async_trait]
104impl<N> Time for AgnosticTime<N>
105where
106  N: Net,
107{
108  async fn delay_for(duration: Duration) {
109    <N::Runtime as RuntimeLite>::sleep(duration).await;
110  }
111
112  async fn timeout<F: 'static + Future + Send>(
113    duration: Duration,
114    future: F,
115  ) -> Result<F::Output, std::io::Error> {
116    <N::Runtime as RuntimeLite>::timeout(duration, future)
117      .await
118      .map_err(Into::into)
119  }
120}
121
122/// DNS tcp
123#[doc(hidden)]
124pub struct AsyncDnsTcp<N: Net>(N::TcpStream);
125
126impl<N: Net> hickory_proto::tcp::DnsTcpStream for AsyncDnsTcp<N> {
127  type Time = AgnosticTime<N>;
128}
129
130impl<N: Net> AsyncDnsTcp<N> {
131  async fn connect(addr: SocketAddr) -> std::io::Result<Self> {
132    <N::TcpStream as agnostic_net::TcpStream>::connect(addr)
133      .await
134      .map(Self)
135  }
136}
137
138impl<N: Net> futures_util::AsyncRead for AsyncDnsTcp<N> {
139  fn poll_read(
140    mut self: std::pin::Pin<&mut Self>,
141    cx: &mut std::task::Context<'_>,
142    buf: &mut [u8],
143  ) -> std::task::Poll<io::Result<usize>> {
144    futures_util::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
145  }
146}
147
148impl<N: Net> futures_util::AsyncWrite for AsyncDnsTcp<N> {
149  fn poll_write(
150    mut self: Pin<&mut Self>,
151    cx: &mut std::task::Context<'_>,
152    buf: &[u8],
153  ) -> std::task::Poll<io::Result<usize>> {
154    futures_util::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
155  }
156
157  fn poll_flush(
158    mut self: Pin<&mut Self>,
159    cx: &mut std::task::Context<'_>,
160  ) -> std::task::Poll<io::Result<()>> {
161    futures_util::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
162  }
163
164  fn poll_close(
165    mut self: Pin<&mut Self>,
166    cx: &mut std::task::Context<'_>,
167  ) -> std::task::Poll<io::Result<()>> {
168    futures_util::AsyncWrite::poll_close(Pin::new(&mut self.0), cx)
169  }
170}
171
172/// DNS udp
173pub struct AsyncDnsUdp<N: Net>(N::UdpSocket);
174
175impl<N: Net> AsyncDnsUdp<N> {
176  async fn bind(addr: SocketAddr) -> std::io::Result<Self> {
177    <N::UdpSocket as UdpSocket>::bind(addr).await.map(Self)
178  }
179}
180
181impl<N: Net> hickory_proto::udp::DnsUdpSocket for AsyncDnsUdp<N> {
182  type Time = AgnosticTime<N>;
183
184  fn poll_recv_from(
185    &self,
186    cx: &mut std::task::Context<'_>,
187    buf: &mut [u8],
188  ) -> std::task::Poll<io::Result<(usize, SocketAddr)>> {
189    self.0.poll_recv_from(cx, buf)
190  }
191
192  fn poll_send_to(
193    &self,
194    cx: &mut std::task::Context<'_>,
195    buf: &[u8],
196    target: SocketAddr,
197  ) -> std::task::Poll<io::Result<usize>> {
198    self.0.poll_send_to(cx, buf, target)
199  }
200}
201
#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
impl<N> hickory_proto::udp::QuicLocalAddr for AsyncDnsUdp<N>
where
  N: Net,
{
  /// Report the wrapped socket's `local_addr`.
  fn local_addr(&self) -> std::io::Result<SocketAddr> {
    let Self(socket) = self;
    <N::UdpSocket as UdpSocket>::local_addr(socket)
  }
}
208
209impl<N: Net> RuntimeProvider for AsyncRuntimeProvider<N> {
210  type Handle = AsyncSpawn<N>;
211
212  type Timer = Timer<N>;
213
214  type Udp = AsyncDnsUdp<N>;
215
216  type Tcp = AsyncDnsTcp<N>;
217
218  fn create_handle(&self) -> Self::Handle {
219    self.runtime
220  }
221
222  fn connect_tcp(
223    &self,
224    addr: SocketAddr,
225  ) -> std::pin::Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
226    AsyncDnsTcp::connect(addr).boxed()
227  }
228
229  fn bind_udp(
230    &self,
231    local_addr: SocketAddr,
232    _server_addr: SocketAddr,
233  ) -> std::pin::Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
234    AsyncDnsUdp::bind(local_addr).boxed()
235  }
236}
237
238/// Create `DnsHandle` with the help of `AsyncRuntimeProvider`.
239pub struct AsyncConnectionProvider<N: Net> {
240  runtime_provider: AsyncRuntimeProvider<N>,
241  connection_provider: GenericConnector<AsyncRuntimeProvider<N>>,
242}
243
244impl<N: Net> Default for AsyncConnectionProvider<N> {
245  fn default() -> Self {
246    Self::new()
247  }
248}
249
250impl<N: Net> AsyncConnectionProvider<N> {
251  /// Create a new `AsyncConnectionProvider`.
252  pub fn new() -> Self {
253    Self {
254      runtime_provider: AsyncRuntimeProvider::new(),
255      connection_provider: GenericConnector::new(AsyncRuntimeProvider::new()),
256    }
257  }
258}
259
260impl<N: Net> Clone for AsyncConnectionProvider<N> {
261  fn clone(&self) -> Self {
262    Self {
263      runtime_provider: self.runtime_provider,
264      connection_provider: self.connection_provider.clone(),
265    }
266  }
267}
268
269impl<N: Net> ConnectionProvider for AsyncConnectionProvider<N> {
270  type Conn = <GenericConnector<AsyncRuntimeProvider<N>> as ConnectionProvider>::Conn;
271  type FutureConn = <GenericConnector<AsyncRuntimeProvider<N>> as ConnectionProvider>::FutureConn;
272  type RuntimeProvider = AsyncRuntimeProvider<N>;
273
274  fn new_connection(
275    &self,
276    config: &hickory_resolver::config::NameServerConfig,
277    options: &hickory_resolver::config::ResolverOpts,
278  ) -> Self::FutureConn {
279    self.connection_provider.new_connection(config, options)
280  }
281}
282
283#[cfg(unix)]
284pub use dns_util::read_resolv_conf;
285
286#[cfg(unix)]
287pub use hickory_resolver::system_conf::parse_resolv_conf;
288pub use hickory_resolver::system_conf::read_system_conf;
289
290#[cfg(unix)]
291mod dns_util {
292  use std::{io, path::Path};
293
294  use hickory_resolver::config::{ResolverConfig, ResolverOpts};
295
296  /// Read the DNS configuration from a file.
297  pub fn read_resolv_conf<P: AsRef<Path>>(path: P) -> io::Result<(ResolverConfig, ResolverOpts)> {
298    std::fs::read_to_string(path).and_then(|conf| {
299      hickory_resolver::system_conf::parse_resolv_conf(conf)
300        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
301    })
302  }
303}