1use core::{
11 pin::Pin,
12 task::{Context, Poll},
13 time::Duration,
14};
15use std::collections::{HashMap, hash_map::Entry};
16
17use futures_channel::mpsc;
18use futures_util::{
19 FutureExt,
20 future::BoxFuture,
21 stream::{Stream, StreamExt},
22};
23use rand::RngExt;
24use tracing::debug;
25
26use super::{
27 BufDnsStreamHandle, DnsClientStream, DnsRequestSender, DnsResponseStream, ignore_send,
28};
29use crate::proto::op::{DnsRequest, DnsResponse, SerialMessage};
30#[cfg(feature = "__dnssec")]
31use crate::proto::rr::{TSigVerifier, TSigner};
32use crate::{DnsStreamHandle, error::NetError, runtime::Time};
33
34struct ActiveRequest {
35 completion: mpsc::Sender<Result<DnsResponse, NetError>>,
37 request_id: u16,
38 timeout: BoxFuture<'static, ()>,
39 #[cfg(feature = "__dnssec")]
40 verifier: Option<TSigVerifier>,
41}
42
43impl ActiveRequest {
44 fn new(
45 completion: mpsc::Sender<Result<DnsResponse, NetError>>,
46 request_id: u16,
47 timeout: BoxFuture<'static, ()>,
48 #[cfg(feature = "__dnssec")] verifier: Option<TSigVerifier>,
49 ) -> Self {
50 Self {
51 completion,
52 request_id,
53 timeout,
55 #[cfg(feature = "__dnssec")]
56 verifier,
57 }
58 }
59
60 fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
62 self.timeout.poll_unpin(cx)
63 }
64
65 fn is_canceled(&self) -> bool {
67 self.completion.is_closed()
68 }
69
70 fn request_id(&self) -> u16 {
72 self.request_id
73 }
74
75 fn complete_with_error(mut self, error: NetError) {
77 ignore_send(self.completion.try_send(Err(error)));
78 }
79}
80
81#[must_use = "futures do nothing unless polled"]
87pub struct DnsMultiplexer<S> {
88 stream: S,
89 timeout_duration: Duration,
90 stream_handle: BufDnsStreamHandle,
91 active_requests: HashMap<u16, ActiveRequest>,
92 max_active_requests: usize,
93 #[cfg(feature = "__dnssec")]
94 signer: Option<TSigner>,
95 is_shutdown: bool,
96}
97
98impl<S: DnsClientStream> DnsMultiplexer<S> {
99 pub fn new(stream: S, stream_handle: BufDnsStreamHandle) -> Self {
111 Self {
112 stream,
113 timeout_duration: Duration::from_secs(5),
114 stream_handle,
115 active_requests: HashMap::default(),
116 max_active_requests: 32,
117 #[cfg(feature = "__dnssec")]
118 signer: None,
119 is_shutdown: false,
120 }
121 }
122
123 pub fn with_timeout(mut self, timeout: Duration) -> Self {
125 self.timeout_duration = timeout;
126 self
127 }
128
129 pub fn with_max_active_requests(mut self, max: usize) -> Self {
135 self.max_active_requests = max;
136 self
137 }
138
139 #[cfg(feature = "__dnssec")]
141 pub fn with_signer(mut self, signer: TSigner) -> Self {
142 self.signer = Some(signer);
143 self
144 }
145
146 fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
149 let mut canceled = HashMap::<u16, NetError>::new();
150 for (&id, active_req) in &mut self.active_requests {
151 if active_req.is_canceled() {
152 canceled.insert(id, NetError::from("requestor canceled"));
153 }
154
155 match active_req.poll_timeout(cx) {
157 Poll::Ready(()) => {
158 debug!("request timed out: {}", id);
159 canceled.insert(id, NetError::Timeout);
160 }
161 Poll::Pending => (),
162 }
163 }
164
165 for (id, error) in canceled {
167 if let Some(active_request) = self.active_requests.remove(&id) {
168 active_request.complete_with_error(error);
170 }
171 }
172 }
173
174 fn next_random_query_id(&self) -> Result<u16, NetError> {
176 let mut rand = rand::rng();
177
178 for _ in 0..100 {
179 let id: u16 = rand.random(); if !self.active_requests.contains_key(&id) {
182 return Ok(id);
183 }
184 }
185
186 Err(NetError::from(
187 "id space exhausted, consider filing an issue",
188 ))
189 }
190
191 fn stream_closed_close_all(&mut self, error: NetError) {
193 debug!(%error, addr = %self.stream.name_server_addr());
194 for (_, active_request) in self.active_requests.drain() {
195 active_request.complete_with_error(error.clone());
197 }
198 }
199}
200
201impl<S: DnsClientStream> DnsRequestSender for DnsMultiplexer<S> {
202 fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
203 if self.is_shutdown {
204 panic!("can not send messages after stream is shutdown")
205 }
206
207 if self.active_requests.len() >= self.max_active_requests {
208 return NetError::Busy.into();
209 }
210
211 let query_id = match self.next_random_query_id() {
212 Ok(id) => id,
213 Err(e) => return e.into(),
214 };
215
216 let (mut request, _) = request.into_parts();
217 request.metadata.id = query_id;
218
219 #[cfg(feature = "__dnssec")]
220 let mut verifier = None;
221 #[cfg(feature = "__dnssec")]
222 if let Some(signer) = &self.signer {
223 if signer.should_sign_message(&request) {
224 match request.finalize(signer, S::Time::current_time()) {
225 Ok(answer_verifier) => verifier = answer_verifier,
226 Err(e) => {
227 debug!("could not sign message: {}", e);
228 return NetError::from(e).into();
229 }
230 }
231 }
232 }
233
234 let timeout = S::Time::delay_for(self.timeout_duration);
236
237 let (complete, receiver) = mpsc::channel(QUERY_RESPONSE_BUFFER_SIZE);
238
239 let active_request = ActiveRequest::new(
241 complete,
242 request.id,
243 timeout,
244 #[cfg(feature = "__dnssec")]
245 verifier,
246 );
247
248 match request.to_vec() {
249 Ok(buffer) => {
250 debug!(id = %active_request.request_id(), "sending message");
251 let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
252
253 debug!(
254 "final message: {}",
255 serial_message
256 .to_message()
257 .expect("bizarre we just made this message")
258 );
259
260 match self.stream_handle.send(serial_message) {
263 Ok(()) => self
264 .active_requests
265 .insert(active_request.request_id(), active_request),
266 Err(err) => return err.into(),
267 };
268 }
269 Err(error) => {
270 debug!(
271 id = %active_request.request_id(),
272 %error,
273 "error message"
274 );
275 return NetError::from(error).into();
277 }
278 }
279
280 receiver.into()
281 }
282
283 fn shutdown(&mut self) {
284 self.is_shutdown = true;
285 }
286
287 fn is_shutdown(&self) -> bool {
288 self.is_shutdown
289 }
290}
291
292impl<S: DnsClientStream> Stream for DnsMultiplexer<S> {
293 type Item = Result<(), NetError>;
294
295 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
296 self.drop_cancelled(cx);
298
299 if self.is_shutdown && self.active_requests.is_empty() {
300 debug!("stream is done: {}", self.stream.name_server_addr());
301 return Poll::Ready(None);
302 }
303
304 let mut messages_received = 0;
308 for i in 0..QOS_MAX_RECEIVE_MSGS {
309 match self.stream.poll_next_unpin(cx) {
310 Poll::Ready(Some(Ok(buffer))) => {
311 messages_received = i;
312
313 match DnsResponse::from_buffer(buffer.into_parts().0) {
315 Ok(response) => match self.active_requests.entry(response.id) {
316 Entry::Occupied(mut request_entry) => {
317 let active_request = request_entry.get_mut();
319 #[cfg(feature = "__dnssec")]
320 if let Some(verifier) = &mut active_request.verifier {
321 ignore_send(
322 active_request.completion.try_send(
323 verifier
324 .verify(response.as_buffer())
325 .map_err(NetError::from),
326 ),
327 );
328 } else {
329 ignore_send(active_request.completion.try_send(Ok(response)));
330 }
331 #[cfg(not(feature = "__dnssec"))]
332 ignore_send(active_request.completion.try_send(Ok(response)));
333 }
334 Entry::Vacant(..) => debug!("unexpected request_id: {}", response.id),
335 },
336 Err(error) => debug!(%error, "error decoding message"),
338 }
339 }
340 Poll::Ready(err) => {
341 let err = match err {
342 Some(Err(e)) => e,
343 None => NetError::from("stream closed"),
344 _ => unreachable!(),
345 };
346
347 self.stream_closed_close_all(err);
348 self.is_shutdown = true;
349 return Poll::Ready(None);
350 }
351 Poll::Pending => break,
352 }
353 }
354
355 if messages_received == QOS_MAX_RECEIVE_MSGS {
359 cx.waker().wake_by_ref();
361 }
362
363 Poll::Pending
365 }
366}
367
/// Maximum number of messages read from the underlying stream in one `poll_next` call before
/// control is handed back to the executor.
const QOS_MAX_RECEIVE_MSGS: usize = 100;

/// Buffer size passed to `mpsc::channel` when creating each request's response channel.
const QUERY_RESPONSE_BUFFER_SIZE: usize = 8;
375
#[cfg(test)]
mod test {
    use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    use futures_util::{
        future::{self, BoxFuture},
        ready,
        stream::TryStreamExt,
    };
    use test_support::subscribe;

    use super::*;
    use crate::proto::op::{DnsRequestOptions, Message, Query};
    use crate::proto::rr::rdata::{NS, SOA};
    use crate::proto::rr::{DNSClass, Name, RData, Record, RecordType};
    use crate::proto::serialize::binary::BinEncodable;
    use crate::xfer::{DnsClientStream, StreamReceiver};

    /// In-memory client stream that replays canned `messages` as responses.
    ///
    /// It first reads one request sent by the multiplexer through `receiver` and remembers
    /// its id. It then overwrites the id of every canned message with that id, so the
    /// multiplexer routes them to the request under test.
    struct MockClientStream {
        // Canned responses, stored reversed so `pop` yields them in the original order.
        messages: Vec<Message>,
        // Reported as the name server address and attached to each yielded SerialMessage.
        addr: SocketAddr,
        // Id of the request read from `receiver`; `None` until one has been read.
        id: Option<u16>,
        // Receiving side of the multiplexer's BufDnsStreamHandle; set before polling.
        receiver: Option<StreamReceiver>,
    }

    impl MockClientStream {
        /// Builds a mock stream that will replay `messages`, in order, from `addr`.
        fn new(
            mut messages: Vec<Message>,
            addr: SocketAddr,
        ) -> BoxFuture<'static, Result<Self, NetError>> {
            // Reverse so that `Vec::pop` hands messages out in their original order.
            messages.reverse();
            Box::pin(future::ok(Self {
                messages,
                addr,
                id: None,
                receiver: None,
            }))
        }
    }

    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, NetError>;

        /// Yields the next canned response, re-id'd to match the captured request id.
        ///
        /// Returns `Pending` while waiting for the first request on `receiver`, and keeps
        /// returning `Pending` once the canned messages run out, so it never ends by itself.
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let id = if let Some(id) = self.id {
                id
            } else {
                // Learn the query id the multiplexer assigned by decoding the sent request.
                let serial = ready!(
                    self.receiver
                        .as_mut()
                        .expect("should only be polled after receiver has been set")
                        .poll_next_unpin(cx)
                );
                let message = serial.unwrap().to_message().unwrap();
                self.id = Some(message.id);
                message.id
            };

            if let Some(mut message) = self.messages.pop() {
                message.metadata.id = id;
                Poll::Ready(Some(Ok(SerialMessage::new(
                    message.to_bytes().unwrap(),
                    self.addr,
                ))))
            } else {
                Poll::Pending
            }
        }
    }

    impl DnsClientStream for MockClientStream {
        type Time = crate::runtime::TokioTime;

        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }

    /// Builds a multiplexer over a `MockClientStream` that replays `mock_response`.
    ///
    /// NOTE(review): uses a 100 ms request timeout; the tests below appear to rely on the
    /// response stream ending once that timeout fires. Confirm this against how
    /// DnsResponseStream treats `NetError::Timeout`.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let mock_response = MockClientStream::new(mock_response, addr).await.unwrap();
        let (handler, receiver) = BufDnsStreamHandle::new(addr);
        let mut multiplexer =
            DnsMultiplexer::new(mock_response, handler).with_timeout(Duration::from_millis(100));

        // Give the mock the receiving end of the handle so it can learn the request id.
        multiplexer.stream.receiver = Some(receiver);
        multiplexer
    }

    /// An `A` query for www.example.com. plus a single canned response with one answer.
    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com.").unwrap();

        let mut request = Message::query();
        request.metadata.recursion_desired = true;
        request.add_query({
            let mut q = Query::query(name.clone(), RecordType::A);
            q.set_query_class(DNSClass::IN);
            q
        });

        let mut response = request.clone().into_response();
        response.add_answer(Record::from_rdata(
            name,
            86400,
            RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
        ));
        (
            DnsRequest::new(request, DnsRequestOptions::default()),
            vec![response],
        )
    }

    /// An `AXFR` query message for example.com.
    fn axfr_query() -> Message {
        let name = Name::from_ascii("example.com.").unwrap();

        let mut msg = Message::query();
        msg.metadata.recursion_desired = true;
        msg.add_query({
            let mut query = Query::query(name, RecordType::AXFR);
            query.set_query_class(DNSClass::IN);
            query
        });
        msg
    }

    /// The six records of a zone transfer for example.com.
    ///
    /// The list is bracketed by the same SOA record at the start and end.
    fn axfr_response() -> Vec<Record> {
        let origin = Name::from_ascii("example.com.").unwrap();
        let soa = Record::from_rdata(
            origin.clone(),
            3600,
            RData::SOA(SOA::new(
                Name::parse("sns.dns.icann.org.", None).unwrap(),
                Name::parse("noc.dns.icann.org.", None).unwrap(),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            )),
        );

        vec![
            soa.clone(),
            Record::from_rdata(
                origin.clone(),
                86400,
                RData::NS(NS(Name::parse("a.iana-servers.net.", None).unwrap())),
            ),
            Record::from_rdata(
                origin.clone(),
                86400,
                RData::NS(NS(Name::parse("b.iana-servers.net.", None).unwrap())),
            ),
            Record::from_rdata(
                origin.clone(),
                86400,
                RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
            ),
            Record::from_rdata(
                origin,
                86400,
                RData::AAAA(
                    Ipv6Addr::new(
                        0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c,
                    )
                    .into(),
                ),
            ),
            soa,
        ]
    }

    /// AXFR query answered by a single message carrying every transfer record.
    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let msg = axfr_query();

        let mut response = msg.clone().into_response();
        response.insert_answers(axfr_response());
        (
            DnsRequest::new(msg, DnsRequestOptions::default()),
            vec![response],
        )
    }

    /// AXFR query answered by two messages.
    ///
    /// The first message carries records 0..3 of `axfr_response()`; the second carries the rest.
    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let base = axfr_query();

        let query = base.clone();
        let mut rr = axfr_response();
        let rr2 = rr.split_off(3);
        let mut msg1 = base.clone().into_response();
        msg1.insert_answers(rr);
        let mut msg2 = base.into_response();
        msg2.insert_answers(rr2);
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg1, msg2],
        )
    }

    /// A single-response query yields exactly one response.
    #[tokio::test]
    async fn test_multiplexer_a() {
        subscribe();
        let (query, answer) = a_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                // Polling drives the multiplexer; it must not finish while a request is active.
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
    }

    /// An AXFR answered in one message yields one response holding every record.
    #[tokio::test]
    async fn test_multiplexer_axfr() {
        subscribe();
        let (query, answer) = axfr_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                // Polling drives the multiplexer; it must not finish while a request is active.
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
        assert_eq!(response[0].answers.len(), axfr_response().len());
    }

    /// An AXFR split over two messages yields both, together holding every record.
    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        subscribe();
        let (query, answer) = axfr_query_answer_multi();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                // Polling drives the multiplexer; it must not finish while a request is active.
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 2);
        assert_eq!(
            response.iter().map(|m| m.answers.len()).sum::<usize>(),
            axfr_response().len()
        );
    }
}