hickory_resolver/name_server/name_server.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::fmt::{self, Debug, Formatter};
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Instant;
12
13use futures_util::lock::Mutex;
14use futures_util::stream::{Stream, once};
15use tracing::debug;
16
17use crate::config::{NameServerConfig, ResolverOpts};
18use crate::name_server::connection_provider::{ConnectionProvider, GenericConnector};
19use crate::name_server::{NameServerState, NameServerStats};
20use crate::proto::{
21    ProtoError,
22    xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer},
23};
24
25/// This struct is used to create `DnsHandle` with the help of `P`.
26#[derive(Clone)]
27pub struct NameServer<P: ConnectionProvider> {
28    config: NameServerConfig,
29    options: ResolverOpts,
30    client: Arc<Mutex<Option<P::Conn>>>,
31    state: Arc<NameServerState>,
32    pub(crate) stats: Arc<NameServerStats>,
33    connection_provider: P,
34}
35
36/// Specifies the details of a remote NameServer used for lookups
37pub type GenericNameServer<R> = NameServer<GenericConnector<R>>;
38
39impl<P> Debug for NameServer<P>
40where
41    P: ConnectionProvider + Send,
42{
43    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
44        write!(f, "config: {:?}, options: {:?}", self.config, self.options)
45    }
46}
47
48impl<P> NameServer<P>
49where
50    P: ConnectionProvider + Send,
51{
52    /// Construct a new Nameserver with the configuration and options. The connection provider will create UDP and TCP sockets
53    pub fn new(config: NameServerConfig, options: ResolverOpts, connection_provider: P) -> Self {
54        Self {
55            config,
56            options,
57            client: Arc::new(Mutex::new(None)),
58            state: Arc::new(NameServerState::init(None)),
59            stats: Arc::new(NameServerStats::default()),
60            connection_provider,
61        }
62    }
63
64    #[doc(hidden)]
65    pub fn from_conn(
66        config: NameServerConfig,
67        options: ResolverOpts,
68        client: P::Conn,
69        connection_provider: P,
70    ) -> Self {
71        Self {
72            config,
73            options,
74            client: Arc::new(Mutex::new(Some(client))),
75            state: Arc::new(NameServerState::init(None)),
76            stats: Arc::new(NameServerStats::default()),
77            connection_provider,
78        }
79    }
80
81    #[cfg(test)]
82    #[allow(dead_code)]
83    pub(crate) fn is_connected(&self) -> bool {
84        !self.state.is_failed()
85            && if let Some(client) = self.client.try_lock() {
86                client.is_some()
87            } else {
88                // assuming that if someone has it locked it will be or is connected
89                true
90            }
91    }
92
93    /// This will return a mutable client to allows for sending messages.
94    ///
95    /// If the connection is in a failed state, then this will establish a new connection
96    async fn connected_mut_client(&mut self) -> Result<P::Conn, ProtoError> {
97        let mut client = self.client.lock().await;
98
99        // if this is in a failure state
100        if self.state.is_failed() || client.is_none() {
101            debug!("reconnecting: {:?}", self.config);
102
103            // TODO: we need the local EDNS options
104            self.state.reinit(None);
105
106            let new_client = Box::pin(
107                self.connection_provider
108                    .new_connection(&self.config, &self.options)?,
109            )
110            .await?;
111
112            // establish a new connection
113            *client = Some(new_client);
114        } else {
115            debug!("existing connection: {:?}", self.config);
116        }
117
118        Ok((*client)
119            .clone()
120            .expect("bad state, client should be connected"))
121    }
122
123    async fn inner_send<R: Into<DnsRequest> + Unpin + Send + 'static>(
124        mut self,
125        request: R,
126    ) -> Result<DnsResponse, ProtoError> {
127        let client = self.connected_mut_client().await?;
128        let now = Instant::now();
129        let response = client.send(request).first_answer().await;
130        let rtt = now.elapsed();
131
132        match response {
133            Ok(response) => {
134                // First evaluate if the message succeeded.
135                let result =
136                    ProtoError::from_response(response, self.config.trust_negative_responses);
137                self.stats.record(rtt, &result);
138                let response = result?;
139
140                // TODO: consider making message::take_edns...
141                let remote_edns = response.extensions().clone();
142
143                // take the remote edns options and store them
144                self.state.establish(remote_edns);
145
146                Ok(response)
147            }
148            Err(error) => {
149                debug!(config = ?self.config, "name_server connection failure: {}", error);
150
151                // this transitions the state to failure
152                self.state.fail(Instant::now());
153
154                // record the failure
155                self.stats.record_connection_failure();
156
157                // These are connection failures, not lookup failures, that is handled in the resolver layer
158                Err(error)
159            }
160        }
161    }
162
163    /// Specifies that this NameServer will treat negative responses as permanent failures and will not retry
164    pub fn trust_nx_responses(&self) -> bool {
165        self.config.trust_negative_responses
166    }
167}
168
169impl<P> DnsHandle for NameServer<P>
170where
171    P: ConnectionProvider + Clone,
172{
173    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
174
175    fn is_verifying_dnssec(&self) -> bool {
176        self.options.validate
177    }
178
179    // TODO: there needs to be some way of customizing the connection based on EDNS options from the server side...
180    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
181        let this = self.clone();
182        // if state is failed, return future::err(), unless retry delay expired..
183        Box::pin(once(this.inner_send(request)))
184    }
185}
186
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::str::FromStr;
    use std::time::Duration;

    use hickory_proto::op::Message;
    use hickory_proto::rr::rdata::NULL;
    use hickory_proto::rr::{RData, Record};
    use test_support::subscribe;
    use tokio::net::UdpSocket;
    use tokio::spawn;

    use super::*;
    use crate::name_server::connection_provider::TokioConnectionProvider;
    use crate::proto::op::{Query, ResponseCode};
    use crate::proto::rr::{Name, RecordType};
    use crate::proto::xfer::{DnsHandle, DnsRequestOptions, FirstAnswer, Protocol};

    /// Plain UDP configuration for `socket_addr`: no TLS name, no HTTP endpoint, no explicit
    /// bind address, and negative responses not trusted.
    fn udp_config(socket_addr: SocketAddr) -> NameServerConfig {
        NameServerConfig {
            socket_addr,
            protocol: Protocol::Udp,
            tls_dns_name: None,
            http_endpoint: None,
            trust_negative_responses: false,
            bind_addr: None,
        }
    }

    /// Queries a public resolver (8.8.8.8:53) over UDP; requires network access.
    #[tokio::test]
    async fn test_name_server() {
        subscribe();

        let public_resolver = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53);
        let name_server = GenericNameServer::new(
            udp_config(public_resolver),
            ResolverOpts::default(),
            TokioConnectionProvider::default(),
        );

        let query = Query::query(Name::parse("www.example.com.", None).unwrap(), RecordType::A);
        let answer = name_server
            .lookup(query, DnsRequestOptions::default())
            .first_answer()
            .await
            .expect("query failed");

        assert_eq!(answer.response_code(), ResponseCode::NoError);
    }

    /// A lookup against an address expected to have no DNS server must surface an error.
    #[tokio::test]
    async fn test_failed_name_server() {
        subscribe();

        let dead_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 252);
        // The lookup is expected to fail, so keep the timeout tiny to fail fast.
        let options = ResolverOpts {
            timeout: Duration::from_millis(1),
            ..ResolverOpts::default()
        };
        let name_server = GenericNameServer::new(
            udp_config(dead_addr),
            options,
            TokioConnectionProvider::default(),
        );

        let query = Query::query(Name::parse("www.example.com.", None).unwrap(), RecordType::A);
        let outcome = name_server
            .lookup(query, DnsRequestOptions::default())
            .first_answer()
            .await;

        assert!(outcome.is_err());
    }

    /// With case randomization enabled, the query name in the returned response must match the
    /// originally requested name exactly, including letter case.
    #[tokio::test]
    async fn case_randomization_query_preserved() {
        subscribe();

        let server = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.unwrap();
        let server_addr = server.local_addr().unwrap();
        let name = Name::from_str("dead.beef.").unwrap();
        let payload = b"DEADBEEF";

        // One-shot local responder: echoes the request id and question section back with a
        // single NULL answer record.
        let answer_name = name.clone();
        spawn(async move {
            let mut buf = [0_u8; 512];
            let (len, peer) = server.recv_from(&mut buf).await.unwrap();
            let incoming = Message::from_vec(&buf[..len]).unwrap();

            let mut reply = Message::new();
            reply.set_id(incoming.id());
            reply.add_queries(incoming.queries().to_vec());
            reply.add_answer(Record::from_rdata(
                answer_name,
                0,
                RData::NULL(NULL::with(payload.to_vec())),
            ));

            let wire = reply.to_vec().unwrap();
            server.send_to(&wire, peer).await.unwrap();
        });

        let resolver_opts = ResolverOpts {
            case_randomization: true,
            ..Default::default()
        };
        let mut request_options = DnsRequestOptions::default();
        request_options.case_randomization = true;

        let name_server = NameServer::new(
            NameServerConfig::new(server_addr, Protocol::Udp),
            resolver_opts,
            TokioConnectionProvider::default(),
        );

        let response = name_server
            .lookup(Query::query(name.clone(), RecordType::NULL), request_options)
            .first_answer()
            .await
            .unwrap();

        let echoed_name = response.query().unwrap().name();
        assert!(echoed_name.eq_case(&name));
    }
}