1use std::{
11 cmp::min,
12 pin::Pin,
13 slice::Iter,
14 sync::Arc,
15 time::{Duration, Instant},
16};
17
18use futures_util::stream::Stream;
19
20use crate::{
21 dns_lru::MAX_TTL,
22 lookup_ip::LookupIpIter,
23 name_server::{ConnectionProvider, NameServerPool},
24 proto::{
25 DnsHandle, ProtoError, RetryDnsHandle,
26 op::Query,
27 rr::{
28 RData, Record,
29 rdata::{self, A, AAAA, NS, PTR},
30 },
31 xfer::{DnsRequest, DnsResponse},
32 },
33};
34
35#[cfg(feature = "__dnssec")]
36use crate::proto::dnssec::{DnssecDnsHandle, Proven};
37
38#[derive(Clone, Debug, Eq, PartialEq)]
42pub struct Lookup {
43 query: Query,
44 records: Arc<[Record]>,
45 valid_until: Instant,
46}
47
48impl Lookup {
49 pub fn from_rdata(query: Query, rdata: RData) -> Self {
51 let record = Record::from_rdata(query.name().clone(), MAX_TTL, rdata);
52 Self::new_with_max_ttl(query, Arc::from([record]))
53 }
54
55 pub fn new_with_max_ttl(query: Query, records: Arc<[Record]>) -> Self {
57 let valid_until = Instant::now() + Duration::from_secs(u64::from(MAX_TTL));
58 Self {
59 query,
60 records,
61 valid_until,
62 }
63 }
64
65 pub fn new_with_deadline(query: Query, records: Arc<[Record]>, valid_until: Instant) -> Self {
67 Self {
68 query,
69 records,
70 valid_until,
71 }
72 }
73
74 pub fn query(&self) -> &Query {
76 &self.query
77 }
78
79 pub fn iter(&self) -> LookupIter<'_> {
83 LookupIter(self.records.iter())
84 }
85
86 #[cfg(feature = "__dnssec")]
88 pub fn dnssec_iter(&self) -> DnssecIter<'_> {
89 DnssecIter(self.dnssec_record_iter())
90 }
91
92 pub fn record_iter(&self) -> LookupRecordIter<'_> {
96 LookupRecordIter(self.records.iter())
97 }
98
99 #[cfg(feature = "__dnssec")]
101 pub fn dnssec_record_iter(&self) -> DnssecLookupRecordIter<'_> {
102 DnssecLookupRecordIter(self.records.iter())
103 }
104
105 pub fn valid_until(&self) -> Instant {
107 self.valid_until
108 }
109
110 #[doc(hidden)]
111 pub fn is_empty(&self) -> bool {
112 self.records.is_empty()
113 }
114
115 pub(crate) fn len(&self) -> usize {
116 self.records.len()
117 }
118
119 pub fn records(&self) -> &[Record] {
122 self.records.as_ref()
123 }
124
125 pub(crate) fn append(&self, other: Self) -> Self {
127 let mut records = Vec::with_capacity(self.len() + other.len());
128 records.extend_from_slice(&self.records);
129 records.extend_from_slice(&other.records);
130
131 let valid_until = min(self.valid_until(), other.valid_until());
133 Self::new_with_deadline(self.query.clone(), Arc::from(records), valid_until)
134 }
135
136 pub fn extend_records(&mut self, other: Vec<Record>) {
138 let mut records = Vec::with_capacity(self.len() + other.len());
139 records.extend_from_slice(&self.records);
140 records.extend(other);
141 self.records = Arc::from(records);
142 }
143}
144
145pub struct LookupIter<'a>(Iter<'a, Record>);
147
148impl<'a> Iterator for LookupIter<'a> {
149 type Item = &'a RData;
150
151 fn next(&mut self) -> Option<Self::Item> {
152 self.0.next().map(Record::data)
153 }
154}
155
/// Borrowing iterator over the [`RData`] of each record in a [`Lookup`], each wrapped in
/// [`Proven`].
///
/// Created by [`Lookup::dnssec_iter`].
#[cfg(feature = "__dnssec")]
pub struct DnssecIter<'a>(DnssecLookupRecordIter<'a>);

#[cfg(feature = "__dnssec")]
impl<'a> Iterator for DnssecIter<'a> {
    type Item = Proven<&'a RData>;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|r| r.map(Record::data))
    }

    // One output item per inner item, so the wrapped iterator's bounds apply unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}
168
169pub struct LookupRecordIter<'a>(Iter<'a, Record>);
171
172impl<'a> Iterator for LookupRecordIter<'a> {
173 type Item = &'a Record;
174
175 fn next(&mut self) -> Option<Self::Item> {
176 self.0.next()
177 }
178}
179
/// Borrowing iterator over each [`Record`] in a [`Lookup`], each wrapped in [`Proven`].
///
/// Created by [`Lookup::dnssec_record_iter`].
#[cfg(feature = "__dnssec")]
pub struct DnssecLookupRecordIter<'a>(Iter<'a, Record>);

#[cfg(feature = "__dnssec")]
impl<'a> Iterator for DnssecLookupRecordIter<'a> {
    type Item = Proven<&'a Record>;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(Proven::from)
    }

    // Each record maps to exactly one `Proven` item, so the slice's exact bounds carry over.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}
192
193impl IntoIterator for Lookup {
195 type Item = RData;
196 type IntoIter = LookupIntoIter;
197
198 fn into_iter(self) -> Self::IntoIter {
200 LookupIntoIter {
201 records: Arc::clone(&self.records),
202 index: 0,
203 }
204 }
205}
206
207pub struct LookupIntoIter {
211 records: Arc<[Record]>,
212 index: usize,
213}
214
215impl Iterator for LookupIntoIter {
216 type Item = RData;
217
218 fn next(&mut self) -> Option<Self::Item> {
219 let rdata = self.records.get(self.index).map(Record::data);
220 self.index += 1;
221 rdata.cloned()
222 }
223}
224
225#[derive(Clone)]
227#[doc(hidden)]
228pub enum LookupEither<P: ConnectionProvider + Send> {
229 Retry(RetryDnsHandle<NameServerPool<P>>),
230 #[cfg(feature = "__dnssec")]
231 Secure(DnssecDnsHandle<RetryDnsHandle<NameServerPool<P>>>),
232}
233
234impl<P: ConnectionProvider> DnsHandle for LookupEither<P> {
235 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
236
237 fn is_verifying_dnssec(&self) -> bool {
238 match self {
239 Self::Retry(c) => c.is_verifying_dnssec(),
240 #[cfg(feature = "__dnssec")]
241 Self::Secure(c) => c.is_verifying_dnssec(),
242 }
243 }
244
245 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
246 match self {
247 Self::Retry(c) => c.send(request),
248 #[cfg(feature = "__dnssec")]
249 Self::Secure(c) => c.send(request),
250 }
251 }
252}
253
254#[derive(Debug, Clone)]
256pub struct SrvLookup(Lookup);
257
258impl SrvLookup {
259 pub fn iter(&self) -> SrvLookupIter<'_> {
261 SrvLookupIter(self.0.iter())
262 }
263
264 pub fn query(&self) -> &Query {
266 self.0.query()
267 }
268
269 pub fn ip_iter(&self) -> LookupIpIter<'_> {
273 LookupIpIter(self.0.iter())
274 }
275
276 pub fn as_lookup(&self) -> &Lookup {
280 &self.0
281 }
282}
283
284impl From<Lookup> for SrvLookup {
285 fn from(lookup: Lookup) -> Self {
286 Self(lookup)
287 }
288}
289
290pub struct SrvLookupIter<'i>(LookupIter<'i>);
292
293impl<'i> Iterator for SrvLookupIter<'i> {
294 type Item = &'i rdata::SRV;
295
296 fn next(&mut self) -> Option<Self::Item> {
297 let iter: &mut _ = &mut self.0;
298 iter.find_map(|rdata| match rdata {
299 RData::SRV(data) => Some(data),
300 _ => None,
301 })
302 }
303}
304
305impl IntoIterator for SrvLookup {
306 type Item = rdata::SRV;
307 type IntoIter = SrvLookupIntoIter;
308
309 fn into_iter(self) -> Self::IntoIter {
311 SrvLookupIntoIter(self.0.into_iter())
312 }
313}
314
315pub struct SrvLookupIntoIter(LookupIntoIter);
317
318impl Iterator for SrvLookupIntoIter {
319 type Item = rdata::SRV;
320
321 fn next(&mut self) -> Option<Self::Item> {
322 let iter: &mut _ = &mut self.0;
323 iter.find_map(|rdata| match rdata {
324 RData::SRV(data) => Some(data),
325 _ => None,
326 })
327 }
328}
329
/// Generates a typed wrapper around [`Lookup`] that exposes only one kind of record data.
///
/// * `$l`  – the wrapper type (e.g. `Ipv4Lookup`)
/// * `$i`  – its borrowing iterator type
/// * `$ii` – its owning iterator type
/// * `$r`  – the `RData` variant to select (e.g. `RData::A`)
/// * `$t`  – the payload type carried by that variant (e.g. `A`)
macro_rules! lookup_type {
    ($l:ident, $i:ident, $ii:ident, $r:path, $t:path) => {
        #[doc = concat!("A [`Lookup`] whose iterators yield only `", stringify!($r), "` record data.")]
        #[derive(Debug, Clone)]
        pub struct $l(Lookup);

        impl $l {
            #[doc = stringify!(Returns an iterator over the records that match $r)]
            pub fn iter(&self) -> $i<'_> {
                $i(self.0.iter())
            }

            /// Returns the query that produced this lookup.
            pub fn query(&self) -> &Query {
                self.0.query()
            }

            /// Returns the deadline of the underlying lookup.
            pub fn valid_until(&self) -> Instant {
                self.0.valid_until()
            }

            /// Returns the underlying [`Lookup`], giving access to every record in the response.
            pub fn as_lookup(&self) -> &Lookup {
                &self.0
            }
        }

        impl From<Lookup> for $l {
            fn from(lookup: Lookup) -> Self {
                Self(lookup)
            }
        }

        impl From<$l> for Lookup {
            fn from(typed: $l) -> Self {
                typed.0
            }
        }

        #[doc = concat!("Borrowing iterator over the `", stringify!($r), "` record data of a lookup; other data is skipped.")]
        pub struct $i<'i>(LookupIter<'i>);

        impl<'i> Iterator for $i<'i> {
            type Item = &'i $t;

            fn next(&mut self) -> Option<Self::Item> {
                self.0.find_map(|record_data| match record_data {
                    $r(payload) => Some(payload),
                    _ => None,
                })
            }
        }

        impl IntoIterator for $l {
            type Item = $t;
            type IntoIter = $ii;

            fn into_iter(self) -> Self::IntoIter {
                $ii(self.0.into_iter())
            }
        }

        #[doc = concat!("Owning iterator over the `", stringify!($r), "` record data of a lookup; other data is skipped.")]
        pub struct $ii(LookupIntoIter);

        impl Iterator for $ii {
            type Item = $t;

            fn next(&mut self) -> Option<Self::Item> {
                self.0.find_map(|record_data| match record_data {
                    $r(payload) => Some(payload),
                    _ => None,
                })
            }
        }
    };
}
414
415lookup_type!(
417 ReverseLookup,
418 ReverseLookupIter,
419 ReverseLookupIntoIter,
420 RData::PTR,
421 PTR
422);
423lookup_type!(Ipv4Lookup, Ipv4LookupIter, Ipv4LookupIntoIter, RData::A, A);
424lookup_type!(
425 Ipv6Lookup,
426 Ipv6LookupIter,
427 Ipv6LookupIntoIter,
428 RData::AAAA,
429 AAAA
430);
431lookup_type!(
432 MxLookup,
433 MxLookupIter,
434 MxLookupIntoIter,
435 RData::MX,
436 rdata::MX
437);
438lookup_type!(
439 TlsaLookup,
440 TlsaLookupIter,
441 TlsaLookupIntoIter,
442 RData::TLSA,
443 rdata::TLSA
444);
445lookup_type!(
446 TxtLookup,
447 TxtLookupIter,
448 TxtLookupIntoIter,
449 RData::TXT,
450 rdata::TXT
451);
452lookup_type!(
453 CertLookup,
454 CertLookupIter,
455 CertLookupIntoIter,
456 RData::CERT,
457 rdata::CERT
458);
459lookup_type!(
460 SoaLookup,
461 SoaLookupIter,
462 SoaLookupIntoIter,
463 RData::SOA,
464 rdata::SOA
465);
466lookup_type!(NsLookup, NsLookupIter, NsLookupIntoIter, RData::NS, NS);
467
#[cfg(test)]
mod tests {
    use std::str::FromStr;
    use std::sync::Arc;

    #[cfg(feature = "__dnssec")]
    use crate::proto::op::Query;
    use crate::proto::rr::{Name, RData, Record};

    use super::*;

    /// Builds an `A` record for `www.example.com.` (TTL 80) pointing at `127.0.0.<last_octet>`.
    fn www_a_record(last_octet: u8) -> Record {
        Record::from_rdata(
            Name::from_str("www.example.com.").unwrap(),
            80,
            RData::A(A::new(127, 0, 0, last_octet)),
        )
    }

    #[test]
    fn test_lookup_into_iter_arc() {
        let mut lookup = LookupIntoIter {
            records: Arc::from([www_a_record(1), www_a_record(2)]),
            index: 0,
        };

        let expected = [
            RData::A(A::new(127, 0, 0, 1)),
            RData::A(A::new(127, 0, 0, 2)),
        ];
        for want in expected {
            assert_eq!(lookup.next().unwrap(), want);
        }
        assert_eq!(lookup.next(), None);
    }

    #[test]
    #[cfg(feature = "__dnssec")]
    fn test_dnssec_lookup() {
        use hickory_proto::dnssec::Proof;

        let mut secure = www_a_record(1);
        secure.set_proof(Proof::Secure);

        let mut insecure = www_a_record(2);
        insecure.set_proof(Proof::Insecure);

        let lookup = Lookup {
            query: Query::default(),
            records: Arc::from([secure.clone(), insecure.clone()]),
            valid_until: Instant::now(),
        };

        let mut proven = lookup.dnssec_iter();

        assert_eq!(
            *proven.next().unwrap().require(Proof::Secure).unwrap(),
            *secure.data()
        );
        assert_eq!(
            *proven.next().unwrap().require(Proof::Insecure).unwrap(),
            *insecure.data()
        );
        assert_eq!(proven.next(), None);
    }
}