use core::ops::{Deref, DerefMut};
#[cfg(feature = "std")]
use core::time::Duration;
#[cfg(feature = "std")]
use super::{DEFAULT_RETRY_FLOOR, Edns};
use super::{Message, Query, edns::DEFAULT_MAX_PAYLOAD_LEN};
/// Per-request settings that control how a [`DnsRequest`] is built.
///
/// Marked `#[non_exhaustive]`: code outside this crate cannot build it with a struct
/// expression and should start from [`DnsRequestOptions::default`] and assign fields.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct DnsRequestOptions {
/// When `true`, `DnsRequest::from_query` attaches an EDNS record to the message.
pub use_edns: bool,
/// Maximum payload size written into the EDNS record by `DnsRequest::from_query`
/// (only applied when `use_edns` is `true`).
pub edns_payload_len: u16,
/// Value passed to the EDNS record's DNSSEC OK setter by `DnsRequest::from_query`
/// (only applied when `use_edns` is `true`).
pub edns_set_dnssec_ok: bool,
/// Request depth limit; not read in this file. NOTE(review): presumably bounds
/// recursive/chained lookups elsewhere — confirm at the use site.
pub max_request_depth: usize,
/// Copied into the message's recursion-desired flag by `DnsRequest::from_query`.
pub recursion_desired: bool,
/// When `true`, `DnsRequest::from_query` randomizes the letter case of the query
/// name's labels and records the unmodified query (see [`DnsRequest::original_query`]).
#[cfg(feature = "std")]
pub case_randomization: bool,
/// Retry interval; not read in this file. NOTE(review): presumably the delay before
/// re-sending an unanswered request — confirm at the use site.
#[cfg(feature = "std")]
pub retry_interval: Duration,
}
impl Default for DnsRequestOptions {
fn default() -> Self {
Self {
max_request_depth: 26,
use_edns: true,
edns_payload_len: DEFAULT_MAX_PAYLOAD_LEN,
edns_set_dnssec_ok: false,
recursion_desired: true,
#[cfg(feature = "std")]
case_randomization: false,
#[cfg(feature = "std")]
retry_interval: DEFAULT_RETRY_FLOOR,
}
}
}
/// A DNS [`Message`] together with the [`DnsRequestOptions`] it was created with.
///
/// Dereferences (mutably and immutably) to the inner [`Message`].
#[derive(Clone, PartialEq, Eq)]
pub struct DnsRequest {
/// The wrapped message.
message: Message,
/// Options this request was created with.
options: DnsRequestOptions,
/// The query as it was before case randomization when `from_query` randomized it;
/// otherwise whatever was set via [`DnsRequest::with_original_query`] (initially `None`).
original_query: Option<Query>,
}
impl DnsRequest {
#[cfg(feature = "std")]
pub fn from_query(mut query: Query, options: DnsRequestOptions) -> Self {
let mut message = Message::query();
let mut original_query = None;
if options.case_randomization {
original_query = Some(query.clone());
query.name.randomize_label_case();
}
message.queries.push(query);
message.metadata.recursion_desired = options.recursion_desired;
if options.use_edns {
message
.edns
.get_or_insert_with(Edns::new)
.set_max_payload(options.edns_payload_len)
.set_dnssec_ok(options.edns_set_dnssec_ok);
}
Self::new(message, options).with_original_query(original_query)
}
pub fn new(message: Message, options: DnsRequestOptions) -> Self {
Self {
message,
options,
original_query: None,
}
}
pub fn with_original_query(mut self, original_query: Option<Query>) -> Self {
self.original_query = original_query;
self
}
pub fn options(&self) -> &DnsRequestOptions {
&self.options
}
pub fn options_mut(&mut self) -> &mut DnsRequestOptions {
&mut self.options
}
pub fn into_parts(self) -> (Message, DnsRequestOptions) {
(self.message, self.options)
}
pub fn original_query(&self) -> Option<&Query> {
self.original_query.as_ref()
}
}
impl Deref for DnsRequest {
    type Target = Message;

    /// Exposes the wrapped [`Message`] read-only, so its fields and accessors can be
    /// used directly on a request.
    fn deref(&self) -> &Message {
        &self.message
    }
}
impl DerefMut for DnsRequest {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.message
}
}
impl From<Message> for DnsRequest {
fn from(message: Message) -> Self {
Self::new(message, DnsRequestOptions::default())
}
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::rr::{Name, RecordType};

    /// An `A` query for `example.com.` shared by the tests below.
    fn example_query() -> Query {
        Query::query(Name::from_ascii("example.com.").unwrap(), RecordType::A)
    }

    #[test]
    fn from_query_default_includes_edns() {
        let request = DnsRequest::from_query(example_query(), DnsRequestOptions::default());
        assert!(request.edns.is_some());
        assert_eq!(request.max_payload(), DEFAULT_MAX_PAYLOAD_LEN);
    }

    #[test]
    fn from_query_edns_disabled_no_opt() {
        let request = DnsRequest::from_query(
            example_query(),
            DnsRequestOptions {
                use_edns: false,
                ..DnsRequestOptions::default()
            },
        );
        assert!(request.edns.is_none());
        // Without an EDNS record the reported payload limit is 512 bytes.
        assert_eq!(request.max_payload(), 512);
    }

    #[test]
    fn from_query_without_case_randomization_records_no_original() {
        let request = DnsRequest::from_query(example_query(), DnsRequestOptions::default());
        assert!(request.original_query().is_none());
    }

    #[test]
    fn from_query_case_randomization_records_original() {
        let query = example_query();
        let request = DnsRequest::from_query(
            query.clone(),
            DnsRequestOptions {
                case_randomization: true,
                ..DnsRequestOptions::default()
            },
        );
        // The recorded query is the caller's unmodified input.
        assert!(request.original_query() == Some(&query));
        assert_eq!(request.queries.len(), 1);
    }

    #[test]
    fn from_query_propagates_recursion_desired() {
        let request = DnsRequest::from_query(
            example_query(),
            DnsRequestOptions {
                recursion_desired: false,
                ..DnsRequestOptions::default()
            },
        );
        assert!(!request.metadata.recursion_desired);
    }
}