hickory_net/tcp/
tcp_client_stream.rs1use core::net::SocketAddr;
9use core::pin::Pin;
10use core::task::{Context, Poll, ready};
11use core::time::Duration;
12use std::future::Future;
13
14use futures_util::{StreamExt, stream::Stream};
15use tracing::warn;
16
17use crate::error::NetError;
18use crate::proto::op::SerialMessage;
19#[cfg(feature = "tokio")]
20use crate::runtime::TokioTime;
21#[cfg(feature = "tokio")]
22use crate::runtime::iocompat::AsyncIoTokioAsStd;
23use crate::runtime::{DnsTcpStream, RuntimeProvider, Spawn};
24use crate::tcp::TcpStream;
25use crate::xfer::{DnsClientStream, DnsExchange};
26use crate::{BufDnsStreamHandle, DnsMultiplexer};
27
28#[must_use = "futures do nothing unless polled"]
32pub struct TcpClientStream<S>
33where
34 S: DnsTcpStream,
35{
36 tcp_stream: TcpStream<S>,
37}
38
39impl<S: DnsTcpStream> TcpClientStream<S> {
40 pub async fn exchange<P: RuntimeProvider<Tcp = S>>(
51 remote_addr: SocketAddr,
52 bind_addr: Option<SocketAddr>,
53 timeout: Duration,
54 max_active_requests: Option<usize>,
55 provider: P,
56 ) -> Result<DnsExchange<P>, NetError> {
57 let mut handle = provider.create_handle();
58 let (future, sender) = Self::new(remote_addr, bind_addr, Some(timeout), provider);
59
60 let mut multiplexer = DnsMultiplexer::new(future.await?, sender).with_timeout(timeout);
62 if let Some(max) = max_active_requests {
63 multiplexer = multiplexer.with_max_active_requests(max);
64 }
65 let (exchange, bg) = DnsExchange::from_stream(multiplexer);
66 handle.spawn_bg(bg);
67 Ok(exchange)
68 }
69
70 pub fn new<P: RuntimeProvider<Tcp = S>>(
72 peer_addr: SocketAddr,
73 bind_addr: Option<SocketAddr>,
74 timeout: Option<Duration>,
75 provider: P,
76 ) -> (
77 impl Future<Output = Result<Self, NetError>> + Send + 'static,
78 BufDnsStreamHandle,
79 ) {
80 let (sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
81 (
82 async move {
83 let tcp = provider.connect_tcp(peer_addr, bind_addr, timeout).await?;
84 Ok(Self::from_stream(TcpStream::from_stream_with_receiver(
85 tcp,
86 peer_addr,
87 outbound_messages,
88 )))
89 },
90 sender,
91 )
92 }
93
94 pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
96 Self { tcp_stream }
97 }
98}
99
100impl<S: DnsTcpStream> DnsClientStream for TcpClientStream<S> {
101 type Time = S::Time;
102
103 fn name_server_addr(&self) -> SocketAddr {
104 self.tcp_stream.peer_addr()
105 }
106}
107
108impl<S: DnsTcpStream> Stream for TcpClientStream<S> {
109 type Item = Result<SerialMessage, NetError>;
110
111 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112 let message = match ready!(self.tcp_stream.poll_next_unpin(cx)) {
113 Some(Ok(t)) => t,
114 Some(Err(e)) => return Poll::Ready(Some(Err(NetError::from(e)))),
115 None => return Poll::Ready(None),
116 };
117
118 let peer = self.tcp_stream.peer_addr();
120 if message.addr() != peer {
121 warn!("{} does not match name_server: {}", message.addr(), peer)
123 }
124
125 Poll::Ready(Some(Ok(message)))
126 }
127}
128
/// Lets tokio I/O types wrapped in [`AsyncIoTokioAsStd`] be used as a [`DnsTcpStream`],
/// with [`TokioTime`] as the associated `Time` type.
#[cfg(feature = "tokio")]
impl<T> DnsTcpStream for AsyncIoTokioAsStd<T>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin + 'static,
{
    type Time = TokioTime;
}
136
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tcp::tests::tcp_client_stream_test;

    /// Installs the test subscriber, then runs the shared TCP client stream test
    /// against `addr` on the tokio runtime provider.
    async fn run_stream_test(addr: IpAddr) {
        subscribe();
        tcp_client_stream_test(addr, TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv4() {
        run_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST)).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv6() {
        // `Ipv6Addr::LOCALHOST` is `::1`.
        run_stream_test(IpAddr::V6(Ipv6Addr::LOCALHOST)).await;
    }
}