1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
//! Constructing and sending requests.

#![warn(missing_docs)]
#![warn(clippy::missing_docs_in_private_items)]

use crate::base::iana::Rcode;
use crate::base::message::{CopyRecordsError, ShortMessage};
use crate::base::message_builder::{
    AdditionalBuilder, MessageBuilder, PushError, StaticCompressor,
};
use crate::base::opt::{ComposeOptData, LongOptData, OptRecord};
use crate::base::wire::{Composer, ParseError};
use crate::base::{Header, Message, ParsedName, Rtype};
use crate::rdata::AllRecordData;
use bytes::Bytes;
use octseq::Octets;
use std::boxed::Box;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::vec::Vec;
use std::{error, fmt};
use tracing::trace;

//------------ ComposeRequest ------------------------------------------------

/// Builds a DNS request message from a set of recorded changes.
///
/// Changes are recorded through the mutating methods (header, EDNS UDP
/// payload size, DNSSEC OK flag, EDNS options). The resulting message can
/// then be appended to a composer, or produced as a [`Message`] or as raw
/// octets.
pub trait ComposeRequest: Debug + Send + Sync {
    /// Writes the final request message into `target`.
    fn append_message<Target: Composer>(
        &self,
        target: &mut Target,
    ) -> Result<(), CopyRecordsError>;

    /// Builds the final request message, including all recorded changes.
    fn to_message(&self) -> Result<Message<Vec<u8>>, Error>;

    /// Builds the final request message and returns its wire-format octets.
    fn to_vec(&self) -> Result<Vec<u8>, Error>;

    /// Gives mutable access to the header used for the final message.
    fn header_mut(&mut self) -> &mut Header;

    /// Records the EDNS UDP payload size to advertise.
    fn set_udp_payload_size(&mut self, value: u16);

    /// Records the value of the EDNS DNSSEC OK (DO) flag.
    fn set_dnssec_ok(&mut self, value: bool);

    /// Records an additional EDNS option.
    ///
    /// Fails with [`LongOptData`] if the option does not fit.
    fn add_opt(
        &mut self,
        opt: &impl ComposeOptData,
    ) -> Result<(), LongOptData>;

    /// Checks whether `answer` is a response to this request.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool;
}

//------------ SendRequest ---------------------------------------------------

/// Starts DNS requests built from a request composer of type `CR`.
///
/// The boxed return type of [`SendRequest::send_request`] should eventually
/// become an associated type. For now, the redundant transport's use of
/// `dyn` request objects prevents that change.
pub trait SendRequest<CR> {
    /// Starts a request for `request_msg`.
    ///
    /// Returns a handle from which the response can be obtained.
    fn send_request(
        &self,
        request_msg: CR,
    ) -> Box<dyn GetResponse + Send + Sync>;
}

//------------ GetResponse ---------------------------------------------------

/// Obtains the outcome of a DNS request that has already been started.
///
/// Ideally the future returned by [`GetResponse::get_response`] would be an
/// associated type. The many existing uses of `dyn GetResponse` currently
/// rule that out.
pub trait GetResponse: Debug {
    /// Returns a future that resolves to the response message or an error.
    ///
    /// Implementations are intended to be cancel safe.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    >;
}

//------------ RequestMessage ------------------------------------------------

/// A [`ComposeRequest`] implementation based on an existing [`Message`].
///
/// The base message is stored unchanged. A replacement header and an
/// optional OPT record are kept next to it and applied when the request is
/// composed.
#[derive(Clone, Debug)]
pub struct RequestMessage<Octs>
where
    Octs: AsRef<[u8]>,
{
    /// The message the request is derived from.
    msg: Message<Octs>,

    /// Header written into the composed request instead of `msg`'s header.
    header: Header,

    /// OPT record to emit, present once EDNS data has been recorded.
    opt: Option<OptRecord<Vec<u8>>>,
}

impl<Octs: AsRef<[u8]> + Debug + Octets> RequestMessage<Octs> {
    /// Creates a new request message based on `msg`.
    ///
    /// The header of `msg` is copied so that it can be changed without
    /// touching the base message. No OPT record is recorded initially.
    ///
    /// Note: when the request is composed, any OPT record contained in
    /// `msg` is skipped. Only an OPT record recorded on this object (via the
    /// EDNS-related [`ComposeRequest`] methods) is emitted.
    /// NOTE(review): this drops EDNS settings of the base message; confirm
    /// that this is intended.
    pub fn new(msg: impl Into<Message<Octs>>) -> Self {
        let msg = msg.into();
        let header = msg.header();
        Self {
            msg,
            header,
            opt: None,
        }
    }

    /// Returns a mutable reference to the recorded OPT record.
    ///
    /// A default OPT record is inserted first if none has been recorded.
    fn opt_mut(&mut self) -> &mut OptRecord<Vec<u8>> {
        self.opt.get_or_insert_with(Default::default)
    }

    /// Copies the base message into `target`, applying recorded changes.
    ///
    /// The stored header replaces the base message's header. Question,
    /// answer and authority records are copied as they are. Additional
    /// records are copied except for OPT records; the recorded OPT record,
    /// if any, is appended at the end instead.
    ///
    /// Returns the builder positioned at the additional section so that
    /// callers can finish the message.
    fn append_message_impl<Target: Composer>(
        &self,
        mut target: MessageBuilder<Target>,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
        *target.header_mut() = self.header;

        // Question section.
        let source = self.msg.question();
        let mut target = target.question();
        for question in source {
            target.push(question?)?;
        }

        // Answer section.
        let mut source = source.answer()?;
        let mut target = target.answer();
        for rr in &mut source {
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            target.push(rr)?;
        }

        // Authority section; it always follows the answer section.
        let mut source =
            source.next_section()?.expect("section should be present");
        let mut target = target.authority();
        for rr in &mut source {
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            target.push(rr)?;
        }

        // Additional section; it always follows the authority section.
        let source =
            source.next_section()?.expect("section should be present");
        let mut target = target.additional();
        for rr in source {
            let rr = rr?;
            // OPT records of the base message are dropped; the recorded
            // OPT record (if any) is pushed below instead.
            if rr.rtype() != Rtype::OPT {
                let rr = rr
                    .into_record::<AllRecordData<_, ParsedName<_>>>()?
                    .expect("record expected");
                target.push(rr)?;
            }
        }

        if let Some(opt) = self.opt.as_ref() {
            target.push(opt.as_record())?;
        }

        Ok(target)
    }

    /// Builds a new, owned message from the base message plus changes.
    ///
    /// # Errors
    ///
    /// Returns an error if parsing the base message or pushing into the
    /// new message fails (see `append_message_impl`).
    fn to_message_impl(&self) -> Result<Message<Vec<u8>>, Error> {
        let target =
            MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
                .expect("Vec is expected to have enough space");

        let target = self.append_message_impl(target)?;

        // Finishing the additional builder hands back the target with all
        // sections intact, so the builder does not need to be cloned.
        let octets = target.finish().into_target();
        let msg = Message::from_octets(octets).expect(
            "Message should be able to parse output from MessageBuilder",
        );
        Ok(msg)
    }
}

impl<Octs: AsRef<[u8]> + Clone + Debug + Octets + Send + Sync + 'static>
    ComposeRequest for RequestMessage<Octs>
{
    /// Composes the request and writes it into `target`.
    ///
    /// A failure to start a message in `target` is reported as
    /// `PushError::ShortBuf`.
    fn append_message<Target: Composer>(
        &self,
        target: &mut Target,
    ) -> Result<(), CopyRecordsError> {
        let target = MessageBuilder::from_target(target)
            .map_err(|_| CopyRecordsError::Push(PushError::ShortBuf))?;
        self.append_message_impl(target)?;
        Ok(())
    }

    /// Builds the request and returns its wire-format octets.
    fn to_vec(&self) -> Result<Vec<u8>, Error> {
        // The freshly built message owns its buffer, so move it out
        // instead of copying it.
        Ok(self.to_message()?.into_octets())
    }

    /// Builds the request as an owned message.
    fn to_message(&self) -> Result<Message<Vec<u8>>, Error> {
        self.to_message_impl()
    }

    /// Returns the header that will be used for the composed request.
    fn header_mut(&mut self) -> &mut Header {
        &mut self.header
    }

    /// Records the UDP payload size, creating an OPT record if needed.
    fn set_udp_payload_size(&mut self, value: u16) {
        self.opt_mut().set_udp_payload_size(value);
    }

    /// Records the DNSSEC OK flag, creating an OPT record if needed.
    fn set_dnssec_ok(&mut self, value: bool) {
        self.opt_mut().set_dnssec_ok(value);
    }

    /// Appends an EDNS option to the recorded OPT record.
    fn add_opt(
        &mut self,
        opt: &impl ComposeOptData,
    ) -> Result<(), LongOptData> {
        self.opt_mut().push(opt).map_err(|e| e.unlimited_buf())
    }

    /// Checks whether `answer` is a reply to this request.
    ///
    /// The answer must have QR set and carry the request's ID. An error
    /// reply (RCODE other than NOERROR) with all sections empty is
    /// accepted. Otherwise the question section must equal that of the
    /// base message.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool {
        let answer_header = answer.header();
        let answer_hcounts = answer.header_counts();

        // First check qr is set and IDs match.
        if !answer_header.qr() || answer_header.id() != self.header.id() {
            trace!(
                "Wrong QR or ID: QR={}, answer ID={}, self ID={}",
                answer_header.qr(),
                answer_header.id(),
                self.header.id()
            );
            return false;
        }

        // If the result is an error, then the question section can be empty.
        // In that case we require all other sections to be empty as well.
        if answer_header.rcode() != Rcode::NOERROR
            && answer_hcounts.qdcount() == 0
            && answer_hcounts.ancount() == 0
            && answer_hcounts.nscount() == 0
            && answer_hcounts.arcount() == 0
        {
            // We can accept this as a valid reply.
            return true;
        }

        // Now the question section in the reply has to be the same as in the
        // query.
        if answer_hcounts.qdcount() != self.msg.header_counts().qdcount() {
            trace!("Wrong QD count");
            false
        } else {
            let res = answer.question() == self.msg.for_slice().question();
            if !res {
                trace!("Wrong question");
            }
            res
        }
    }
}

//------------ Error ---------------------------------------------------------

/// Errors reported by the client transports.
#[derive(Clone, Debug)]
pub enum Error {
    /// The connection had already been closed.
    ConnectionClosed,

    /// The OPT record has grown too long.
    OptTooLong,

    /// Composing a message failed with a `MessageBuilder` push error.
    MessageBuilderPushError,

    /// Parsing a message failed.
    MessageParseError,

    /// The underlying transport was not found in a redundant connection.
    RedundantTransportNotFound,

    /// The octet sequence is too short to be a valid DNS message.
    ShortMessage,

    /// The message is too long to be sent over a stream transport.
    StreamLongMessage,

    /// The stream transport was closed because it was idle for too long.
    StreamIdleTimeout,

    /// Receiving a reply failed.
    StreamReceiveError,

    /// Reading from the stream failed with the wrapped I/O error.
    StreamReadError(Arc<std::io::Error>),

    /// Reading from the stream took too long.
    StreamReadTimeout,

    /// Too many queries are outstanding on a single stream transport.
    StreamTooManyOutstandingQueries,

    /// Writing to the stream failed with the wrapped I/O error.
    StreamWriteError(Arc<std::io::Error>),

    /// Reading from the stream ended unexpectedly.
    StreamUnexpectedEndOfData,

    /// The reply does not match the query.
    WrongReplyForQuery,

    /// No transport is available to transmit the request.
    NoTransportAvailable,

    /// The datagram transport reported an error.
    Dgram(Arc<super::dgram::QueryError>),
}

impl From<LongOptData> for Error {
    fn from(_: LongOptData) -> Self {
        Self::OptTooLong
    }
}

impl From<ParseError> for Error {
    fn from(_: ParseError) -> Self {
        Self::MessageParseError
    }
}

impl From<ShortMessage> for Error {
    fn from(_: ShortMessage) -> Self {
        Self::ShortMessage
    }
}

impl From<super::dgram::QueryError> for Error {
    fn from(err: super::dgram::QueryError) -> Self {
        Self::Dgram(err.into())
    }
}

impl fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    ///
    /// For [`Error::Dgram`] the wrapped error's own description is used.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::ConnectionClosed => write!(f, "connection closed"),
            Error::OptTooLong => write!(f, "OPT record is too long"),
            Error::MessageBuilderPushError => {
                write!(f, "PushError from MessageBuilder")
            }
            Error::MessageParseError => write!(f, "ParseError from Message"),
            Error::RedundantTransportNotFound => write!(
                f,
                "Underlying transport not found in redundant connection"
            ),
            Error::ShortMessage => {
                write!(f, "octet sequence too short to be a valid message")
            }
            Error::StreamLongMessage => {
                write!(f, "message too long for stream transport")
            }
            Error::StreamIdleTimeout => {
                write!(f, "stream was idle for too long")
            }
            Error::StreamReceiveError => write!(f, "error receiving a reply"),
            Error::StreamReadError(_) => {
                write!(f, "error reading from stream")
            }
            Error::StreamReadTimeout => {
                write!(f, "timeout reading from stream")
            }
            Error::StreamTooManyOutstandingQueries => {
                write!(f, "too many outstanding queries on stream")
            }
            Error::StreamWriteError(_) => {
                write!(f, "error writing to stream")
            }
            Error::StreamUnexpectedEndOfData => {
                write!(f, "unexpected end of data")
            }
            Error::WrongReplyForQuery => {
                write!(f, "reply does not match query")
            }
            Error::NoTransportAvailable => {
                write!(f, "no transport available")
            }
            Error::Dgram(err) => fmt::Display::fmt(err, f),
        }
    }
}

impl From<CopyRecordsError> for Error {
    fn from(err: CopyRecordsError) -> Self {
        match err {
            CopyRecordsError::Parse(_) => Self::MessageParseError,
            CopyRecordsError::Push(_) => Self::MessageBuilderPushError,
        }
    }
}

impl error::Error for Error {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            Error::ConnectionClosed => None,
            Error::OptTooLong => None,
            Error::MessageBuilderPushError => None,
            Error::MessageParseError => None,
            Error::RedundantTransportNotFound => None,
            Error::ShortMessage => None,
            Error::StreamLongMessage => None,
            Error::StreamIdleTimeout => None,
            Error::StreamReceiveError => None,
            Error::StreamReadError(e) => Some(e),
            Error::StreamReadTimeout => None,
            Error::StreamTooManyOutstandingQueries => None,
            Error::StreamWriteError(e) => Some(e),
            Error::StreamUnexpectedEndOfData => None,
            Error::WrongReplyForQuery => None,
            Error::NoTransportAvailable => None,
            Error::Dgram(err) => Some(err),
        }
    }
}