hickory_client/client/
client.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::{
9    future::Future,
10    pin::Pin,
11    sync::Arc,
12    task::{Context, Poll},
13    time::Duration,
14};
15
16use futures_util::{
17    ready,
18    stream::{Stream, StreamExt},
19};
20use rand;
21use tracing::debug;
22
23use crate::{ClientError, ClientErrorKind};
24use hickory_proto::{
25    ProtoError, ProtoErrorKind,
26    op::{Edns, Message, MessageFinalizer, MessageType, OpCode, Query, update_message},
27    rr::{DNSClass, Name, Record, RecordSet, RecordType, rdata::SOA},
28    runtime::TokioTime,
29    xfer::{
30        BufDnsStreamHandle, DnsClientStream, DnsExchange, DnsExchangeBackground, DnsExchangeSend,
31        DnsHandle, DnsMultiplexer, DnsRequest, DnsRequestOptions, DnsRequestSender, DnsResponse,
32    },
33};
34
/// Former name of [`Client`], retained only as a deprecated alias so that code written
/// against the old name keeps compiling; new code should name `Client` directly.
#[doc(hidden)]
#[deprecated(since = "0.25.0", note = "use `Client` instead")]
pub type ClientFuture = Client;
38
/// A DNS Client implemented over futures-rs.
///
/// This Client is generic and capable of wrapping UDP, TCP, and other underlying DNS protocol
///  implementations.
///
/// Cloning a `Client` clones the inner [`DnsExchange`] handle together with the EDNS flag.
#[derive(Clone)]
pub struct Client {
    // Handle through which requests are sent (see the `DnsHandle` impl below).
    exchange: DnsExchange,
    // Whether outgoing messages should use EDNS; `true` after `connect`, toggled by
    // `enable_edns` / `disable_edns`.
    use_edns: bool,
}
48
49impl Client {
50    /// Spawns a new Client Stream. This uses a default timeout of 5 seconds for all requests.
51    ///
52    /// # Arguments
53    ///
54    /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
55    ///   (see TcpClientStream or UdpClientStream)
56    /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
57    /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
58    #[allow(clippy::new_ret_no_self)]
59    pub async fn new<F, S>(
60        stream: F,
61        stream_handle: BufDnsStreamHandle,
62        signer: Option<Arc<dyn MessageFinalizer>>,
63    ) -> Result<(Self, DnsExchangeBackground<DnsMultiplexer<S>, TokioTime>), ProtoError>
64    where
65        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
66        S: DnsClientStream + 'static + Unpin,
67    {
68        Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer).await
69    }
70
71    /// Spawns a new Client Stream.
72    ///
73    /// # Arguments
74    ///
75    /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
76    ///   (see TcpClientStream or UdpClientStream)
77    /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
78    /// * `timeout_duration` - All requests may fail due to lack of response, this is the time to
79    ///   wait for a response before canceling the request.
80    /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
81    pub async fn with_timeout<F, S>(
82        stream: F,
83        stream_handle: BufDnsStreamHandle,
84        timeout_duration: Duration,
85        signer: Option<Arc<dyn MessageFinalizer>>,
86    ) -> Result<(Self, DnsExchangeBackground<DnsMultiplexer<S>, TokioTime>), ProtoError>
87    where
88        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
89        S: DnsClientStream + 'static + Unpin,
90    {
91        let mp = DnsMultiplexer::with_timeout(stream, stream_handle, timeout_duration, signer);
92        Self::connect(mp).await
93    }
94
95    /// Returns a future, which itself wraps a future which is awaiting connection.
96    ///
97    /// The connect_future should be lazy.
98    ///
99    /// # Returns
100    ///
101    /// This returns a tuple of Self a handle to send dns messages and an optional background.
102    ///  The background task must be run on an executor before handle is used, if it is Some.
103    ///  If it is None, then another thread has already run the background.
104    pub async fn connect<F, S>(
105        connect_future: F,
106    ) -> Result<(Self, DnsExchangeBackground<S, TokioTime>), ProtoError>
107    where
108        S: DnsRequestSender,
109        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
110    {
111        let result = DnsExchange::connect(connect_future).await;
112        let use_edns = true;
113        result.map(|(exchange, bg)| (Self { exchange, use_edns }, bg))
114    }
115
116    /// (Re-)enable usage of EDNS for outgoing messages
117    pub fn enable_edns(&mut self) {
118        self.use_edns = true;
119    }
120
121    /// Disable usage of EDNS for outgoing messages
122    pub fn disable_edns(&mut self) {
123        self.use_edns = false;
124    }
125}
126
127impl DnsHandle for Client {
128    type Response = DnsExchangeSend;
129
130    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
131        self.exchange.send(request)
132    }
133
134    fn is_using_edns(&self) -> bool {
135        self.use_edns
136    }
137}
138
139impl<T> ClientHandle for T where T: DnsHandle {}
140
141/// A trait for implementing high level functions of DNS.
142pub trait ClientHandle: 'static + Clone + DnsHandle + Send {
143    /// A *classic* DNS query
144    ///
145    /// *Note* As of now, this will not recurse on PTR or CNAME record responses, that is up to
146    ///        the caller.
147    ///
148    /// # Arguments
149    ///
150    /// * `name` - the label to lookup
151    /// * `query_class` - most likely this should always be DNSClass::IN
152    /// * `query_type` - record type to lookup
153    fn query(
154        &mut self,
155        name: Name,
156        query_class: DNSClass,
157        query_type: RecordType,
158    ) -> ClientResponse<<Self as DnsHandle>::Response> {
159        let mut query = Query::query(name, query_type);
160        query.set_query_class(query_class);
161        let mut options = DnsRequestOptions::default();
162        options.use_edns = self.is_using_edns();
163        ClientResponse(self.lookup(query, options))
164    }
165
166    /// Sends a NOTIFY message to the remote system
167    ///
168    /// [RFC 1996](https://tools.ietf.org/html/rfc1996), DNS NOTIFY, August 1996
169    ///
170    ///
171    /// ```text
172    /// 1. Rationale and Scope
173    ///
174    ///   1.1. Slow propagation of new and changed data in a DNS zone can be
175    ///   due to a zone's relatively long refresh times.  Longer refresh times
176    ///   are beneficial in that they reduce load on the Primary Zone Servers, but
177    ///   that benefit comes at the cost of long intervals of incoherence among
178    ///   authority servers whenever the zone is updated.
179    ///
180    ///   1.2. The DNS NOTIFY transaction allows Primary Zone Servers to inform Secondary
181    ///   Zone Servers when the zone has changed -- an interrupt as opposed to poll
182    ///   model -- which it is hoped will reduce propagation delay while not
183    ///   unduly increasing the masters' load.  This specification only allows
184    ///   slaves to be notified of SOA RR changes, but the architecture of
185    ///   NOTIFY is intended to be extensible to other RR types.
186    ///
187    ///   1.3. This document intentionally gives more definition to the roles
188    ///   of "Primary", "Secondary" and "Stealth" servers, their enumeration in NS
189    ///   RRs, and the SOA MNAME field.  In that sense, this document can be
190    ///   considered an addendum to [RFC1035].
191    ///
192    /// ```
193    ///
194    /// The below section describes how the Notify message should be constructed. The function
195    ///  implementation accepts a Record, but the actual data of the record should be ignored by the
196    ///  server, i.e. the server should make a request subsequent to receiving this Notification for
197    ///  the authority record, but could be used to decide to request an update or not:
198    ///
199    /// ```text
200    ///   3.7. A NOTIFY request has QDCOUNT>0, ANCOUNT>=0, AUCOUNT>=0,
201    ///   ADCOUNT>=0.  If ANCOUNT>0, then the answer section represents an
202    ///   unsecure hint at the new RRset for this <QNAME,QCLASS,QTYPE>.  A
203    ///   Secondary receiving such a hint is free to treat equivalence of this
204    ///   answer section with its local data as a "no further work needs to be
205    ///   done" indication.  If ANCOUNT=0, or ANCOUNT>0 and the answer section
206    ///   differs from the Secondary's local data, then the Secondary should query its
207    ///   known Primaries to retrieve the new data.
208    /// ```
209    ///
210    /// Client's should be ready to handle, or be aware of, a server response of NOTIMP:
211    ///
212    /// ```text
213    ///   3.12. If a NOTIFY request is received by a Secondary who does not
214    ///   implement the NOTIFY opcode, it will respond with a NOTIMP
215    ///   (unimplemented feature error) message.  A Primary Zone Server who receives
216    ///   such a NOTIMP should consider the NOTIFY transaction complete for
217    ///   that Secondary.
218    /// ```
219    ///
220    /// # Arguments
221    ///
222    /// * `name` - the label which is being notified
223    /// * `query_class` - most likely this should always be DNSClass::IN
224    /// * `query_type` - record type which has been updated
225    /// * `rrset` - the new version of the record(s) being notified
226    fn notify<R>(
227        &mut self,
228        name: Name,
229        query_class: DNSClass,
230        query_type: RecordType,
231        rrset: Option<R>,
232    ) -> ClientResponse<<Self as DnsHandle>::Response>
233    where
234        R: Into<RecordSet>,
235    {
236        debug!("notifying: {} {:?}", name, query_type);
237
238        // build the message
239        let mut message: Message = Message::new();
240        let id: u16 = rand::random();
241        message
242            .set_id(id)
243            // 3.3. NOTIFY is similar to QUERY in that it has a request message with
244            // the header QR flag "clear" and a response message with QR "set".  The
245            // response message contains no useful information, but its reception by
246            // the Primary is an indication that the Secondary has received the NOTIFY
247            // and that the Primary Zone Server can remove the Secondary from any retry queue for
248            // this NOTIFY event.
249            .set_message_type(MessageType::Query)
250            .set_op_code(OpCode::Notify);
251
252        // Extended dns
253        if self.is_using_edns() {
254            message
255                .extensions_mut()
256                .get_or_insert_with(Edns::new)
257                .set_max_payload(update_message::MAX_PAYLOAD_LEN)
258                .set_version(0);
259        }
260
261        // add the query
262        let mut query: Query = Query::new();
263        query
264            .set_name(name)
265            .set_query_class(query_class)
266            .set_query_type(query_type);
267        message.add_query(query);
268
269        // add the notify message, see https://tools.ietf.org/html/rfc1996, section 3.7
270        if let Some(rrset) = rrset {
271            message.add_answers(rrset.into());
272        }
273
274        ClientResponse(self.send(message))
275    }
276
277    /// Sends a record to create on the server, this will fail if the record exists (atomicity
278    ///  depends on the server)
279    ///
280    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
281    ///
282    /// ```text
283    ///  2.4.3 - RRset Does Not Exist
284    ///
285    ///   No RRs with a specified NAME and TYPE (in the zone and class denoted
286    ///   by the Zone Section) can exist.
287    ///
288    ///   For this prerequisite, a requestor adds to the section a single RR
289    ///   whose NAME and TYPE are equal to that of the RRset whose nonexistence
290    ///   is required.  The RDLENGTH of this record is zero (0), and RDATA
291    ///   field is therefore empty.  CLASS must be specified as NONE in order
292    ///   to distinguish this condition from a valid RR whose RDLENGTH is
293    ///   naturally zero (0) (for example, the NULL RR).  TTL must be specified
294    ///   as zero (0).
295    ///
296    /// 2.5.1 - Add To An RRset
297    ///
298    ///    RRs are added to the Update Section whose NAME, TYPE, TTL, RDLENGTH
299    ///    and RDATA are those being added, and CLASS is the same as the zone
300    ///    class.  Any duplicate RRs will be silently ignored by the Primary Zone
301    ///    Server.
302    /// ```
303    ///
304    /// # Arguments
305    ///
306    /// * `rrset` - the record(s) to create
307    /// * `zone_origin` - the zone name to update, i.e. SOA name
308    ///
309    /// The update must go to a zone authority (i.e. the server used in the ClientConnection)
310    fn create<R>(
311        &mut self,
312        rrset: R,
313        zone_origin: Name,
314    ) -> ClientResponse<<Self as DnsHandle>::Response>
315    where
316        R: Into<RecordSet>,
317    {
318        let rrset = rrset.into();
319        let message = update_message::create(rrset, zone_origin, self.is_using_edns());
320
321        ClientResponse(self.send(message))
322    }
323
324    /// Appends a record to an existing rrset, optionally require the rrset to exist (atomicity
325    ///  depends on the server)
326    ///
327    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
328    ///
329    /// ```text
330    /// 2.4.1 - RRset Exists (Value Independent)
331    ///
332    ///   At least one RR with a specified NAME and TYPE (in the zone and class
333    ///   specified in the Zone Section) must exist.
334    ///
335    ///   For this prerequisite, a requestor adds to the section a single RR
336    ///   whose NAME and TYPE are equal to that of the zone RRset whose
337    ///   existence is required.  RDLENGTH is zero and RDATA is therefore
338    ///   empty.  CLASS must be specified as ANY to differentiate this
339    ///   condition from that of an actual RR whose RDLENGTH is naturally zero
340    ///   (0) (e.g., NULL).  TTL is specified as zero (0).
341    ///
342    /// 2.5.1 - Add To An RRset
343    ///
344    ///    RRs are added to the Update Section whose NAME, TYPE, TTL, RDLENGTH
345    ///    and RDATA are those being added, and CLASS is the same as the zone
346    ///    class.  Any duplicate RRs will be silently ignored by the Primary Zone
347    ///    Server.
348    /// ```
349    ///
350    /// # Arguments
351    ///
352    /// * `rrset` - the record(s) to append to an RRSet
353    /// * `zone_origin` - the zone name to update, i.e. SOA name
354    /// * `must_exist` - if true, the request will fail if the record does not exist
355    ///
356    /// The update must go to a zone authority (i.e. the server used in the ClientConnection). If
357    /// the rrset does not exist and must_exist is false, then the RRSet will be created.
358    fn append<R>(
359        &mut self,
360        rrset: R,
361        zone_origin: Name,
362        must_exist: bool,
363    ) -> ClientResponse<<Self as DnsHandle>::Response>
364    where
365        R: Into<RecordSet>,
366    {
367        let rrset = rrset.into();
368        let message = update_message::append(rrset, zone_origin, must_exist, self.is_using_edns());
369
370        ClientResponse(self.send(message))
371    }
372
373    /// Compares and if it matches, swaps it for the new value (atomicity depends on the server)
374    ///
375    /// ```text
376    ///  2.4.2 - RRset Exists (Value Dependent)
377    ///
378    ///   A set of RRs with a specified NAME and TYPE exists and has the same
379    ///   members with the same RDATAs as the RRset specified here in this
380    ///   section.  While RRset ordering is undefined and therefore not
381    ///   significant to this comparison, the sets be identical in their
382    ///   extent.
383    ///
384    ///   For this prerequisite, a requestor adds to the section an entire
385    ///   RRset whose preexistence is required.  NAME and TYPE are that of the
386    ///   RRset being denoted.  CLASS is that of the zone.  TTL must be
387    ///   specified as zero (0) and is ignored when comparing RRsets for
388    ///   identity.
389    ///
390    ///  2.5.4 - Delete An RR From An RRset
391    ///
392    ///   RRs to be deleted are added to the Update Section.  The NAME, TYPE,
393    ///   RDLENGTH and RDATA must match the RR being deleted.  TTL must be
394    ///   specified as zero (0) and will otherwise be ignored by the Primary
395    ///   Zone Server.  CLASS must be specified as NONE to distinguish this from an
396    ///   RR addition.  If no such RRs exist, then this Update RR will be
397    ///   silently ignored by the Primary Zone Server.
398    ///
399    ///  2.5.1 - Add To An RRset
400    ///
401    ///   RRs are added to the Update Section whose NAME, TYPE, TTL, RDLENGTH
402    ///   and RDATA are those being added, and CLASS is the same as the zone
403    ///   class.  Any duplicate RRs will be silently ignored by the Primary
404    ///   Zone Server.
405    /// ```
406    ///
407    /// # Arguments
408    ///
409    /// * `current` - the current rrset which must exist for the swap to complete
410    /// * `new` - the new rrset with which to replace the current rrset
411    /// * `zone_origin` - the zone name to update, i.e. SOA name
412    ///
413    /// The update must go to a zone authority (i.e. the server used in the ClientConnection).
414    fn compare_and_swap<C, N>(
415        &mut self,
416        current: C,
417        new: N,
418        zone_origin: Name,
419    ) -> ClientResponse<<Self as DnsHandle>::Response>
420    where
421        C: Into<RecordSet>,
422        N: Into<RecordSet>,
423    {
424        let current = current.into();
425        let new = new.into();
426
427        let message =
428            update_message::compare_and_swap(current, new, zone_origin, self.is_using_edns());
429        ClientResponse(self.send(message))
430    }
431
432    /// Deletes a record (by rdata) from an rrset, optionally require the rrset to exist.
433    ///
434    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
435    ///
436    /// ```text
437    /// 2.4.1 - RRset Exists (Value Independent)
438    ///
439    ///   At least one RR with a specified NAME and TYPE (in the zone and class
440    ///   specified in the Zone Section) must exist.
441    ///
442    ///   For this prerequisite, a requestor adds to the section a single RR
443    ///   whose NAME and TYPE are equal to that of the zone RRset whose
444    ///   existence is required.  RDLENGTH is zero and RDATA is therefore
445    ///   empty.  CLASS must be specified as ANY to differentiate this
446    ///   condition from that of an actual RR whose RDLENGTH is naturally zero
447    ///   (0) (e.g., NULL).  TTL is specified as zero (0).
448    ///
449    /// 2.5.4 - Delete An RR From An RRset
450    ///
451    ///   RRs to be deleted are added to the Update Section.  The NAME, TYPE,
452    ///   RDLENGTH and RDATA must match the RR being deleted.  TTL must be
453    ///   specified as zero (0) and will otherwise be ignored by the Primary
454    ///   Zone Server.  CLASS must be specified as NONE to distinguish this from an
455    ///   RR addition.  If no such RRs exist, then this Update RR will be
456    ///   silently ignored by the Primary Zone Server.
457    /// ```
458    ///
459    /// # Arguments
460    ///
461    /// * `rrset` - the record(s) to delete from a RRSet, the name, type and rdata must match the
462    ///   record to delete
463    /// * `zone_origin` - the zone name to update, i.e. SOA name
464    /// * `signer` - the signer, with private key, to use to sign the request
465    ///
466    /// The update must go to a zone authority (i.e. the server used in the ClientConnection). If
467    /// the rrset does not exist and must_exist is false, then the RRSet will be deleted.
468    fn delete_by_rdata<R>(
469        &mut self,
470        rrset: R,
471        zone_origin: Name,
472    ) -> ClientResponse<<Self as DnsHandle>::Response>
473    where
474        R: Into<RecordSet>,
475    {
476        let rrset = rrset.into();
477        let message = update_message::delete_by_rdata(rrset, zone_origin, self.is_using_edns());
478
479        ClientResponse(self.send(message))
480    }
481
482    /// Deletes an entire rrset, optionally require the rrset to exist.
483    ///
484    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
485    ///
486    /// ```text
487    /// 2.4.1 - RRset Exists (Value Independent)
488    ///
489    ///   At least one RR with a specified NAME and TYPE (in the zone and class
490    ///   specified in the Zone Section) must exist.
491    ///
492    ///   For this prerequisite, a requestor adds to the section a single RR
493    ///   whose NAME and TYPE are equal to that of the zone RRset whose
494    ///   existence is required.  RDLENGTH is zero and RDATA is therefore
495    ///   empty.  CLASS must be specified as ANY to differentiate this
496    ///   condition from that of an actual RR whose RDLENGTH is naturally zero
497    ///   (0) (e.g., NULL).  TTL is specified as zero (0).
498    ///
499    /// 2.5.2 - Delete An RRset
500    ///
501    ///   One RR is added to the Update Section whose NAME and TYPE are those
502    ///   of the RRset to be deleted.  TTL must be specified as zero (0) and is
503    ///   otherwise not used by the Primary Zone Server.  CLASS must be specified as
504    ///   ANY.  RDLENGTH must be zero (0) and RDATA must therefore be empty.
505    ///   If no such RRset exists, then this Update RR will be silently ignored
506    ///   by the Primary Zone Server.
507    /// ```
508    ///
509    /// # Arguments
510    ///
511    /// * `record` - The name, class and record_type will be used to match and delete the RecordSet
512    /// * `zone_origin` - the zone name to update, i.e. SOA name
513    ///
514    /// The update must go to a zone authority (i.e. the server used in the ClientConnection). If
515    /// the rrset does not exist and must_exist is false, then the RRSet will be deleted.
516    fn delete_rrset(
517        &mut self,
518        record: Record,
519        zone_origin: Name,
520    ) -> ClientResponse<<Self as DnsHandle>::Response> {
521        assert!(zone_origin.zone_of(record.name()));
522        let message = update_message::delete_rrset(record, zone_origin, self.is_using_edns());
523
524        ClientResponse(self.send(message))
525    }
526
527    /// Deletes all records at the specified name
528    ///
529    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
530    ///
531    /// ```text
532    /// 2.5.3 - Delete All RRsets From A Name
533    ///
534    ///   One RR is added to the Update Section whose NAME is that of the name
535    ///   to be cleansed of RRsets.  TYPE must be specified as ANY.  TTL must
536    ///   be specified as zero (0) and is otherwise not used by the Primary
537    ///   Zone Server.  CLASS must be specified as ANY.  RDLENGTH must be zero (0)
538    ///   and RDATA must therefore be empty.  If no such RRsets exist, then
539    ///   this Update RR will be silently ignored by the Primary Zone Server.
540    /// ```
541    ///
542    /// # Arguments
543    ///
544    /// * `name_of_records` - the name of all the record sets to delete
545    /// * `zone_origin` - the zone name to update, i.e. SOA name
546    /// * `dns_class` - the class of the SOA
547    ///
548    /// The update must go to a zone authority (i.e. the server used in the ClientConnection). This
549    /// operation attempts to delete all resource record sets the specified name regardless of
550    /// the record type.
551    fn delete_all(
552        &mut self,
553        name_of_records: Name,
554        zone_origin: Name,
555        dns_class: DNSClass,
556    ) -> ClientResponse<<Self as DnsHandle>::Response> {
557        assert!(zone_origin.zone_of(&name_of_records));
558        let message = update_message::delete_all(
559            name_of_records,
560            zone_origin,
561            dns_class,
562            self.is_using_edns(),
563        );
564
565        ClientResponse(self.send(message))
566    }
567
568    /// Download all records from a zone, or all records modified since given SOA was observed.
569    /// The request will either be a AXFR Query (ask for full zone transfer) if a SOA was not
570    /// provided, or a IXFR Query (incremental zone transfer) if a SOA was provided.
571    ///
572    /// # Arguments
573    /// * `zone_origin` - the zone name to update, i.e. SOA name
574    /// * `last_soa` - the last SOA known, if any. If provided, name must match `zone_origin`
575    fn zone_transfer(
576        &mut self,
577        zone_origin: Name,
578        last_soa: Option<SOA>,
579    ) -> ClientStreamXfr<<Self as DnsHandle>::Response> {
580        let ixfr = last_soa.is_some();
581        let message = update_message::zone_transfer(zone_origin, last_soa);
582
583        ClientStreamXfr::new(self.send(message), ixfr)
584    }
585}
586
587/// A stream result of a Client Request
588#[must_use = "stream do nothing unless polled"]
589pub struct ClientStreamingResponse<R>(pub(crate) R)
590where
591    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static;
592
593impl<R> Stream for ClientStreamingResponse<R>
594where
595    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
596{
597    type Item = Result<DnsResponse, ClientError>;
598
599    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
600        Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ClientError::from)))
601    }
602}
603
604/// A future result of a Client Request
605#[must_use = "futures do nothing unless polled"]
606pub struct ClientResponse<R>(pub(crate) R)
607where
608    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static;
609
610impl<R> Future for ClientResponse<R>
611where
612    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
613{
614    type Output = Result<DnsResponse, ClientError>;
615
616    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
617        Poll::Ready(
618            match ready!(self.0.poll_next_unpin(cx)) {
619                Some(r) => r,
620                None => Err(ProtoError::from(ProtoErrorKind::Timeout)),
621            }
622            .map_err(ClientError::from),
623        )
624    }
625}
626
627/// A stream result of a zone transfer Client Request
628/// Accept messages until the end of a zone transfer. For AXFR, it search for a starting and an
629/// ending SOA. For IXFR, it do so taking into account there will be other SOA inbetween
630#[must_use = "stream do nothing unless polled"]
631pub struct ClientStreamXfr<R>
632where
633    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
634{
635    state: ClientStreamXfrState<R>,
636}
637
638impl<R> ClientStreamXfr<R>
639where
640    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
641{
642    fn new(inner: R, maybe_incr: bool) -> Self {
643        Self {
644            state: ClientStreamXfrState::Start { inner, maybe_incr },
645        }
646    }
647}
648
/// State machine for ClientStreamXfr, implementing almost all logic
#[derive(Debug)]
enum ClientStreamXfrState<R> {
    /// Nothing seen yet; the first answer record is expected to be the opening SOA.
    Start {
        // the underlying response stream
        inner: R,
        // whether an incremental (IXFR) answer is acceptable
        maybe_incr: bool,
    },
    /// The opening SOA has been seen; the next record decides between AXFR and IXFR.
    Second {
        inner: R,
        // serial of the opening SOA
        expected_serial: u32,
        maybe_incr: bool,
    },
    /// Receiving a full transfer, waiting for the closing SOA.
    Axfr {
        inner: R,
        expected_serial: u32,
    },
    /// Receiving an incremental transfer.
    Ixfr {
        inner: R,
        // parity of the SOA records seen since entering this state; the transfer can only
        // end on an SOA carrying `expected_serial` while this is `true`
        even: bool,
        expected_serial: u32,
    },
    /// The transfer is over, either complete or rejected with an error.
    Ended,
    /// Placeholder installed by `process` while it temporarily owns the previous state.
    Invalid,
}
673
674impl<R> ClientStreamXfrState<R> {
675    /// Helper to get the stream from the enum
676    fn inner(&mut self) -> &mut R {
677        use ClientStreamXfrState::*;
678        match self {
679            Start { inner, .. } => inner,
680            Second { inner, .. } => inner,
681            Axfr { inner, .. } => inner,
682            Ixfr { inner, .. } => inner,
683            Ended | Invalid => unreachable!(),
684        }
685    }
686
687    /// Helper to ingest answer Records
688    // TODO: this is complex enough it should get its own tests
689    fn process(&mut self, answers: &[Record]) -> Result<(), ClientError> {
690        use ClientStreamXfrState::*;
691        fn get_serial(r: &Record) -> Option<u32> {
692            r.data().as_soa().map(SOA::serial)
693        }
694
695        if answers.is_empty() {
696            return Ok(());
697        }
698        match std::mem::replace(self, Invalid) {
699            Start { inner, maybe_incr } => {
700                if let Some(expected_serial) = get_serial(&answers[0]) {
701                    *self = Second {
702                        inner,
703                        maybe_incr,
704                        expected_serial,
705                    };
706                    self.process(&answers[1..])
707                } else {
708                    *self = Ended;
709                    Ok(())
710                }
711            }
712            Second {
713                inner,
714                maybe_incr,
715                expected_serial,
716            } => {
717                if let Some(serial) = get_serial(&answers[0]) {
718                    // maybe IXFR, or empty AXFR
719                    if serial == expected_serial {
720                        // empty AXFR
721                        *self = Ended;
722                        if answers.len() == 1 {
723                            Ok(())
724                        } else {
725                            // invalid answer : trailing records
726                            Err(ClientErrorKind::Message(
727                                "invalid zone transfer, contains trailing records",
728                            )
729                            .into())
730                        }
731                    } else if maybe_incr {
732                        *self = Ixfr {
733                            inner,
734                            expected_serial,
735                            even: true,
736                        };
737                        self.process(&answers[1..])
738                    } else {
739                        *self = Ended;
740                        Err(ClientErrorKind::Message(
741                            "invalid zone transfer, expected AXFR, got IXFR",
742                        )
743                        .into())
744                    }
745                } else {
746                    // standard AXFR
747                    *self = Axfr {
748                        inner,
749                        expected_serial,
750                    };
751                    self.process(&answers[1..])
752                }
753            }
754            Axfr {
755                inner,
756                expected_serial,
757            } => {
758                let soa_count = answers
759                    .iter()
760                    .filter(|a| a.record_type() == RecordType::SOA)
761                    .count();
762                match soa_count {
763                    0 => {
764                        *self = Axfr {
765                            inner,
766                            expected_serial,
767                        };
768                        Ok(())
769                    }
770                    1 => {
771                        *self = Ended;
772                        match answers.last().map(|r| r.record_type()) {
773                            Some(RecordType::SOA) => Ok(()),
774                            _ => Err(ClientErrorKind::Message(
775                                "invalid zone transfer, contains trailing records",
776                            )
777                            .into()),
778                        }
779                    }
780                    _ => {
781                        *self = Ended;
782                        Err(ClientErrorKind::Message(
783                            "invalid zone transfer, contains trailing records",
784                        )
785                        .into())
786                    }
787                }
788            }
789            Ixfr {
790                inner,
791                even,
792                expected_serial,
793            } => {
794                let even = answers
795                    .iter()
796                    .fold(even, |even, a| even ^ (a.record_type() == RecordType::SOA));
797                if even {
798                    if let Some(serial) = get_serial(answers.last().unwrap()) {
799                        if serial == expected_serial {
800                            *self = Ended;
801                            return Ok(());
802                        }
803                    }
804                }
805                *self = Ixfr {
806                    inner,
807                    even,
808                    expected_serial,
809                };
810                Ok(())
811            }
812            Ended | Invalid => {
813                unreachable!();
814            }
815        }
816    }
817}
818
819impl<R> Stream for ClientStreamXfr<R>
820where
821    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
822{
823    type Item = Result<DnsResponse, ClientError>;
824
825    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
826        use ClientStreamXfrState::*;
827
828        if matches!(self.state, Ended) {
829            return Poll::Ready(None);
830        }
831
832        let message = ready!(self.state.inner().poll_next_unpin(cx)).map(|response| {
833            let ok = response?;
834            self.state.process(ok.answers())?;
835            Ok(ok)
836        });
837        Poll::Ready(message)
838    }
839}
840
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use super::*;

    use ClientStreamXfrState::*;
    use futures_util::stream::iter;
    use hickory_proto::{
        rr::{
            RData,
            rdata::{A, SOA},
        },
        runtime::TokioRuntimeProvider,
    };
    use test_support::subscribe;

    /// Builds an SOA record for `example.com.` carrying the given serial.
    fn soa_record(serial: u32) -> Record {
        let soa = RData::SOA(SOA::new(
            Name::from_ascii("example.com.").unwrap(),
            Name::from_ascii("admin.example.com.").unwrap(),
            serial,
            60,
            60,
            60,
            60,
        ));
        Record::from_rdata(Name::from_ascii("example.com.").unwrap(), 600, soa)
    }

    /// Builds an A record for `www.example.com.` pointing at `0.0.0.<ip>`.
    fn a_record(ip: u8) -> Record {
        let a = RData::A(A::new(0, 0, 0, ip));
        Record::from_rdata(Name::from_ascii("www.example.com.").unwrap(), 600, a)
    }

    /// Builds a response stream in which each inner `Vec` becomes the answer section of one
    /// `DnsResponse`, emitted in order.
    fn get_stream_testcase(
        records: Vec<Vec<Record>>,
    ) -> impl Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static {
        let stream = records.into_iter().map(|r| {
            Ok({
                let mut m = Message::new();
                m.insert_answers(r);
                DnsResponse::from_message(m).unwrap()
            })
        });
        iter(stream)
    }

    #[tokio::test]
    async fn test_stream_xfr_valid_axfr() {
        subscribe();
        let stream = get_stream_testcase(vec![vec![
            soa_record(3),
            a_record(1),
            a_record(2),
            soa_record(3),
        ]]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 4);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_valid_axfr_multipart() {
        subscribe();
        let stream = get_stream_testcase(vec![
            vec![soa_record(3)],
            vec![a_record(1)],
            vec![soa_record(3)],
            vec![a_record(2)], // will be ignored as connection is dropped before reading this message
        ]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Second { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Axfr { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 1);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_empty_axfr() {
        subscribe();
        let stream = get_stream_testcase(vec![vec![soa_record(3)], vec![soa_record(3)]]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Second { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 1);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_axfr_with_ixfr_reply() {
        subscribe();
        let stream = get_stream_testcase(vec![vec![
            soa_record(3),
            soa_record(2),
            a_record(1),
            soa_record(3),
            a_record(2),
            soa_record(3),
        ]]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        stream.next().await.unwrap().unwrap_err();
        assert!(matches!(stream.state, Ended));

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_axfr_with_non_xfr_reply() {
        subscribe();
        let stream = get_stream_testcase(vec![
            vec![a_record(1)], // assume this is an error response, not a zone transfer
            vec![a_record(2)],
        ]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 1);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_invalid_axfr_multipart() {
        subscribe();
        let stream = get_stream_testcase(vec![
            vec![soa_record(3)],
            vec![a_record(1)],
            vec![soa_record(3), a_record(2)],
            vec![soa_record(3)],
        ]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Second { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Axfr { .. }));
        assert_eq!(response.answers().len(), 1);

        stream.next().await.unwrap().unwrap_err();
        assert!(matches!(stream.state, Ended));

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_valid_ixfr() {
        subscribe();
        let stream = get_stream_testcase(vec![vec![
            soa_record(3),
            soa_record(2),
            a_record(1),
            soa_record(3),
            a_record(2),
            soa_record(3),
        ]]);
        let mut stream = ClientStreamXfr::new(stream, true);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 6);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_valid_ixfr_multipart() {
        subscribe();
        let stream = get_stream_testcase(vec![
            vec![soa_record(3)],
            vec![soa_record(2)],
            vec![a_record(1)],
            vec![soa_record(3)],
            vec![a_record(2)],
            vec![soa_record(3)],
            vec![a_record(3)], // will be ignored as connection is dropped before reading this message
        ]);
        let mut stream = ClientStreamXfr::new(stream, true);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Second { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ixfr { even: true, .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ixfr { even: true, .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ixfr { even: false, .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ixfr { even: false, .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 1);

        assert!(stream.next().await.is_none());
    }

    // This test talks to a public resolver over the internet, so it is not hermetic and is
    // skipped by default. Run it explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to a public resolver (8.8.8.8:53)"]
    async fn async_client() {
        subscribe();
        use crate::client::{Client, ClientHandle};
        use hickory_proto::{
            rr::{DNSClass, Name, RecordType},
            tcp::TcpClientStream,
        };
        use std::str::FromStr;

        // Exercise the client over TCP against a public resolver.
        let addr = SocketAddr::from(([8, 8, 8, 8], 53));
        let (stream, sender) = TcpClientStream::new(addr, None, None, TokioRuntimeProvider::new());

        // Create a new client, the bg is a background future which handles
        //   the multiplexing of the DNS requests to the server.
        //   the client is a handle to an unbounded queue for sending requests via the
        //   background. The background must be scheduled to run before the client can
        //   send any dns requests
        let client = Client::new(stream, sender, None);

        // await the connection to be established
        let (mut client, bg) = client.await.expect("connection failed");

        // make sure to run the background task
        tokio::spawn(bg);

        // Create a query future
        let query = client.query(
            Name::from_str("www.example.com.").unwrap(),
            DNSClass::IN,
            RecordType::A,
        );

        // wait for its response
        let (message_returned, buffer) = query.await.unwrap().into_parts();

        // The addresses published for www.example.com change over time (and the name may be
        // answered via a CNAME), so do not pin a specific IP; only require an answer.
        assert!(
            !message_returned.answers().is_empty(),
            "expected at least one answer for www.example.com."
        );

        let message_parsed = Message::from_vec(&buffer)
            .expect("buffer was parsed already by Client so we should be able to do it again");

        // Re-parsing the raw wire bytes must yield the same answers the client returned.
        assert_eq!(message_parsed.answers(), message_returned.answers());
    }
}