Skip to main content

hickory_net/tcp/
tcp_client_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use core::net::SocketAddr;
9use core::pin::Pin;
10use core::task::{Context, Poll, ready};
11use core::time::Duration;
12use std::future::Future;
13
14use futures_util::{StreamExt, stream::Stream};
15use tracing::warn;
16
17use crate::error::NetError;
18use crate::proto::op::SerialMessage;
19#[cfg(feature = "tokio")]
20use crate::runtime::TokioTime;
21#[cfg(feature = "tokio")]
22use crate::runtime::iocompat::AsyncIoTokioAsStd;
23use crate::runtime::{DnsTcpStream, RuntimeProvider, Spawn};
24use crate::tcp::TcpStream;
25use crate::xfer::{DnsClientStream, DnsExchange};
26use crate::{BufDnsStreamHandle, DnsMultiplexer};
27
28/// Tcp client stream
29///
30/// Use with `hickory_client::client::DnsMultiplexer` impls
31#[must_use = "futures do nothing unless polled"]
32pub struct TcpClientStream<S>
33where
34    S: DnsTcpStream,
35{
36    tcp_stream: TcpStream<S>,
37}
38
39impl<S: DnsTcpStream> TcpClientStream<S> {
40    /// Create a new [`DnsExchange`] wrapped around a multiplexed [`TcpClientStream`]
41    ///
42    /// # Arguments
43    ///
44    /// * `remote_addr` - Address of the remote nameserver
45    /// * `bind_addr` - Optional local address to bind to
46    /// * `timeout` - Timeout for requests
47    /// * `max_active_requests` - Optional limit on concurrent in-flight requests.
48    ///   If `None`, uses the default (32).
49    /// * `provider` - Runtime provider for spawning background tasks
50    pub async fn exchange<P: RuntimeProvider<Tcp = S>>(
51        remote_addr: SocketAddr,
52        bind_addr: Option<SocketAddr>,
53        timeout: Duration,
54        max_active_requests: Option<usize>,
55        provider: P,
56    ) -> Result<DnsExchange<P>, NetError> {
57        let mut handle = provider.create_handle();
58        let (future, sender) = Self::new(remote_addr, bind_addr, Some(timeout), provider);
59
60        // TODO: need config for Signer...
61        let mut multiplexer = DnsMultiplexer::new(future.await?, sender).with_timeout(timeout);
62        if let Some(max) = max_active_requests {
63            multiplexer = multiplexer.with_max_active_requests(max);
64        }
65        let (exchange, bg) = DnsExchange::from_stream(multiplexer);
66        handle.spawn_bg(bg);
67        Ok(exchange)
68    }
69
70    /// Create a new TcpClientStream
71    pub fn new<P: RuntimeProvider<Tcp = S>>(
72        peer_addr: SocketAddr,
73        bind_addr: Option<SocketAddr>,
74        timeout: Option<Duration>,
75        provider: P,
76    ) -> (
77        impl Future<Output = Result<Self, NetError>> + Send + 'static,
78        BufDnsStreamHandle,
79    ) {
80        let (sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
81        (
82            async move {
83                let tcp = provider.connect_tcp(peer_addr, bind_addr, timeout).await?;
84                Ok(Self::from_stream(TcpStream::from_stream_with_receiver(
85                    tcp,
86                    peer_addr,
87                    outbound_messages,
88                )))
89            },
90            sender,
91        )
92    }
93
94    /// Wraps the TcpStream in TcpClientStream
95    pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
96        Self { tcp_stream }
97    }
98}
99
100impl<S: DnsTcpStream> DnsClientStream for TcpClientStream<S> {
101    type Time = S::Time;
102
103    fn name_server_addr(&self) -> SocketAddr {
104        self.tcp_stream.peer_addr()
105    }
106}
107
108impl<S: DnsTcpStream> Stream for TcpClientStream<S> {
109    type Item = Result<SerialMessage, NetError>;
110
111    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112        let message = match ready!(self.tcp_stream.poll_next_unpin(cx)) {
113            Some(Ok(t)) => t,
114            Some(Err(e)) => return Poll::Ready(Some(Err(NetError::from(e)))),
115            None => return Poll::Ready(None),
116        };
117
118        // this is busted if the tcp connection doesn't have a peer
119        let peer = self.tcp_stream.peer_addr();
120        if message.addr() != peer {
121            // TODO: this should be an error, right?
122            warn!("{} does not match name_server: {}", message.addr(), peer)
123        }
124
125        Poll::Ready(Some(Ok(message)))
126    }
127}
128
#[cfg(feature = "tokio")]
/// Allows a Tokio I/O type wrapped in [`AsyncIoTokioAsStd`] to be used as a [`DnsTcpStream`],
/// using [`TokioTime`] as its time source.
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + Sync + Sized + 'static>
    DnsTcpStream for AsyncIoTokioAsStd<T>
{
    type Time = TokioTime;
}
136
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tcp::tests::tcp_client_stream_test;

    /// Installs the test tracing subscriber, then runs the shared TCP client stream test
    /// against `ip` using a fresh Tokio runtime provider.
    async fn run_on(ip: IpAddr) {
        subscribe();
        tcp_client_stream_test(ip, TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv4() {
        run_on(IpAddr::V4(Ipv4Addr::LOCALHOST)).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv6() {
        // `Ipv6Addr::LOCALHOST` is `::1`.
        run_on(IpAddr::V6(Ipv6Addr::LOCALHOST)).await;
    }
}