hickory_net/tcp/tcp_stream.rs
1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This module contains all the TCP structures for demuxing TCP into streams of DNS packets.
9
10use core::mem;
11use core::net::SocketAddr;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14use core::time::Duration;
15use std::io;
16
17use futures_io::IoSlice;
18use futures_util::stream::Stream;
19use futures_util::{self, future::Future, ready};
20use tracing::{debug, trace};
21
22use crate::proto::op::SerialMessage;
23use crate::runtime::{DnsTcpStream, Time};
24use crate::xfer::{BufDnsStreamHandle, StreamReceiver};
25
/// Progress of the outbound message currently being written to the socket.
///
/// A message goes out as a two-byte length prefix followed by the payload and
/// is then flushed; each variant records how far that sequence has advanced.
enum WriteTcpState {
    /// The two-byte length prefix has not been fully written yet.
    LenBytes {
        /// Number of bytes written so far, counted from the start of `length`.
        pos: usize,
        /// The length prefix that precedes the payload on the wire.
        length: [u8; 2],
        /// Payload to send once the length prefix is out.
        bytes: Vec<u8>,
    },
    /// The length prefix is out; the payload is being written.
    Bytes {
        /// Number of payload bytes already written.
        pos: usize,
        /// Payload being sent to the remote.
        bytes: Vec<u8>,
    },
    /// Everything has been written; the socket is being flushed.
    Flushing,
}
47
/// Progress of the inbound message currently being read from the socket.
///
/// Reading alternates between the two-byte length prefix and a payload buffer
/// sized from that prefix.
pub(crate) enum ReadTcpState {
    /// The two-byte length prefix is being read.
    LenBytes {
        /// Number of prefix bytes received so far.
        pos: usize,
        /// Storage for the length prefix.
        bytes: [u8; 2],
    },
    /// The payload of the DNS packet is being read.
    Bytes {
        /// Number of payload bytes received so far.
        pos: usize,
        /// Buffer the payload is read into.
        bytes: Vec<u8>,
    },
}
65
66/// A Stream used for sending data to and from a remote DNS endpoint (client or server).
67#[must_use = "futures do nothing unless polled"]
68pub struct TcpStream<S: DnsTcpStream> {
69 socket: S,
70 outbound_messages: StreamReceiver,
71 send_state: Option<WriteTcpState>,
72 read_state: ReadTcpState,
73 peer_addr: SocketAddr,
74}
75
76impl<S: DnsTcpStream> TcpStream<S> {
77 /// Returns the address of the peer connection.
78 pub fn peer_addr(&self) -> SocketAddr {
79 self.peer_addr
80 }
81
82 fn pollable_split(
83 &mut self,
84 ) -> (
85 &mut S,
86 &mut StreamReceiver,
87 &mut Option<WriteTcpState>,
88 &mut ReadTcpState,
89 ) {
90 (
91 &mut self.socket,
92 &mut self.outbound_messages,
93 &mut self.send_state,
94 &mut self.read_state,
95 )
96 }
97
98 /// Initializes a TcpStream with an established connection.
99 ///
100 /// Uses the default buffer size (32) for the outbound message queue.
101 ///
102 /// # Arguments
103 ///
104 /// * `stream` - the established IO stream for communication
105 /// * `peer_addr` - address of the remote peer
106 pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
107 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
108 let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
109 (stream, message_sender)
110 }
111
112 /// Initializes a TcpStream with an established connection and explicit buffer size.
113 ///
114 /// Use this when you need a larger buffer to handle high message rates without
115 /// dropping messages due to backpressure.
116 ///
117 /// # Arguments
118 ///
119 /// * `stream` - the established IO stream for communication
120 /// * `peer_addr` - address of the remote peer
121 /// * `buffer_size` - maximum number of messages that can be queued for sending
122 pub fn from_stream_with_buffer_size(
123 stream: S,
124 peer_addr: SocketAddr,
125 buffer_size: usize,
126 ) -> (Self, BufDnsStreamHandle) {
127 let (message_sender, outbound_messages) =
128 BufDnsStreamHandle::with_buffer_size(peer_addr, buffer_size);
129 let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
130 (stream, message_sender)
131 }
132
133 /// Wraps a stream where a sender and receiver have already been established
134 pub fn from_stream_with_receiver(
135 socket: S,
136 peer_addr: SocketAddr,
137 outbound_messages: StreamReceiver,
138 ) -> Self {
139 Self {
140 socket,
141 outbound_messages,
142 send_state: None,
143 read_state: ReadTcpState::LenBytes {
144 pos: 0,
145 bytes: [0u8; 2],
146 },
147 peer_addr,
148 }
149 }
150
151 /// Creates a new future of the eventually establish a IO stream connection or fail trying
152 ///
153 /// # Arguments
154 ///
155 /// * `future` - underlying stream future which this tcp stream relies on
156 /// * `name_server` - the IP and Port of the DNS server to connect to
157 /// * `timeout` - connection timeout
158 pub fn with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
159 future: F,
160 name_server: SocketAddr,
161 timeout: Duration,
162 ) -> (
163 impl Future<Output = Result<Self, io::Error>> + Send,
164 BufDnsStreamHandle,
165 ) {
166 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
167 let stream_fut = Self::connect_with_future(future, name_server, timeout, outbound_messages);
168
169 (stream_fut, message_sender)
170 }
171
172 async fn connect_with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
173 future: F,
174 name_server: SocketAddr,
175 timeout: Duration,
176 outbound_messages: StreamReceiver,
177 ) -> Result<Self, io::Error> {
178 let socket = S::Time::timeout(timeout, future).await??;
179 debug!("TCP connection established to: {}", name_server);
180 Ok(Self {
181 socket,
182 outbound_messages,
183 send_state: None,
184 read_state: ReadTcpState::LenBytes {
185 pos: 0,
186 bytes: [0u8; 2],
187 },
188 peer_addr: name_server,
189 })
190 }
191}
192
193impl<S: DnsTcpStream> Stream for TcpStream<S> {
194 type Item = io::Result<SerialMessage>;
195
196 #[allow(clippy::cognitive_complexity)]
197 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198 let peer = self.peer_addr;
199 let (socket, outbound_messages, send_state, read_state) = self.pollable_split();
200 let mut socket = Pin::new(socket);
201 let mut outbound_messages = Pin::new(outbound_messages);
202
203 // this will not accept incoming data while there is data to send
204 // makes this self throttling.
205 // TODO: it might be interesting to try and split the sending and receiving futures.
206 loop {
207 // in the case we are sending, send it all?
208 if send_state.is_some() {
209 // sending...
210 match send_state {
211 Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
212 let wrote = ready!(socket.as_mut().poll_write_vectored(
213 cx,
214 &[IoSlice::new(&length[*pos..]), IoSlice::new(bytes)]
215 ))?;
216 *pos += wrote;
217 }
218 Some(WriteTcpState::Bytes { pos, bytes }) => {
219 let wrote = ready!(socket.as_mut().poll_write(cx, &bytes[*pos..]))?;
220 *pos += wrote;
221 }
222 Some(WriteTcpState::Flushing) => {
223 ready!(socket.as_mut().poll_flush(cx))?;
224 }
225 _ => (),
226 }
227
228 // get current state
229 let current_state = send_state.take();
230
231 // switch states
232 match current_state {
233 Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
234 if pos < length.len() {
235 *send_state = Some(WriteTcpState::LenBytes { pos, length, bytes });
236 } else if pos < length.len() + bytes.len() {
237 *send_state = Some(WriteTcpState::Bytes {
238 pos: pos - length.len(),
239 bytes,
240 });
241 } else {
242 *send_state = Some(WriteTcpState::Flushing);
243 }
244 }
245 Some(WriteTcpState::Bytes { pos, bytes }) => {
246 if pos < bytes.len() {
247 *send_state = Some(WriteTcpState::Bytes { pos, bytes });
248 } else {
249 // At this point we successfully delivered the entire message.
250 // flush
251 *send_state = Some(WriteTcpState::Flushing);
252 }
253 }
254 Some(WriteTcpState::Flushing) => {
255 // At this point we successfully delivered the entire message.
256 send_state.take();
257 }
258 None => (),
259 };
260 } else {
261 // then see if there is more to send
262 match outbound_messages.as_mut().poll_next(cx)
263 // .map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))?
264 {
265 // already handled above, here to make sure the poll() pops the next message
266 Poll::Ready(Some(message)) => {
267 // if there is no peer, this connection should die...
268 let (buffer, dst) = message.into();
269
270 // This is an error if the destination is not our peer (this is TCP after all)
271 // This will kill the connection...
272 if peer != dst {
273 return Poll::Ready(Some(Err(io::Error::new(
274 io::ErrorKind::InvalidData,
275 format!("mismatched peer: {peer} and dst: {dst}"),
276 ))));
277 }
278
279 // will return if the socket will block
280 // the length is 16 bits
281 let len = u16::to_be_bytes(buffer.len() as u16);
282
283 debug!("sending message len: {} to: {}", buffer.len(), dst);
284 *send_state = Some(WriteTcpState::LenBytes {
285 pos: 0,
286 length: len,
287 bytes: buffer,
288 });
289 }
290 // now we get to drop through to the receives...
291 // TODO: should we also return None if there are no more messages to send?
292 Poll::Pending => break,
293 Poll::Ready(None) => {
294 debug!("no messages to send");
295 break;
296 }
297 }
298 }
299 }
300
301 let mut ret_buf = None;
302
303 // this will loop while there is data to read, or the data has been read, or an IO
304 // event would block
305 while ret_buf.is_none() {
306 // Evaluates the next state. If None is the result, then no state change occurs,
307 // if Some(_) is returned, then that will be used as the next state.
308 let new_state: Option<ReadTcpState> = match read_state {
309 ReadTcpState::LenBytes { pos, bytes } => {
310 // debug!("reading length {}", bytes.len());
311 let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
312 if read == 0 {
313 // the Stream was closed!
314 debug!("zero bytes read, stream closed?");
315 //try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function
316
317 if *pos == 0 {
318 // Since this is the start of the next message, we have a clean end
319 return Poll::Ready(None);
320 } else {
321 return Poll::Ready(Some(Err(io::Error::new(
322 io::ErrorKind::BrokenPipe,
323 "closed while reading length",
324 ))));
325 }
326 }
327 trace!("in ReadTcpState::LenBytes: {}", pos);
328 *pos += read;
329
330 if *pos < bytes.len() {
331 trace!("remain ReadTcpState::LenBytes: {}", pos);
332 None
333 } else {
334 let length = u16::from_be_bytes(*bytes);
335 trace!("got length: {}", length);
336 let mut bytes = vec![0; length as usize];
337 bytes.resize(length as usize, 0);
338
339 trace!("move ReadTcpState::Bytes: {}", bytes.len());
340 Some(ReadTcpState::Bytes { pos: 0, bytes })
341 }
342 }
343 ReadTcpState::Bytes { pos, bytes } => {
344 let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
345 if read == 0 {
346 // the Stream was closed!
347 trace!("zero bytes read for message, stream closed?");
348
349 // Since this is the start of the next message, we have a clean end
350 // try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function
351 return Poll::Ready(Some(Err(io::Error::new(
352 io::ErrorKind::BrokenPipe,
353 "closed while reading message",
354 ))));
355 }
356
357 trace!("in ReadTcpState::Bytes: {}", bytes.len());
358 *pos += read;
359
360 if *pos < bytes.len() {
361 trace!("remain ReadTcpState::Bytes: {}", bytes.len());
362 None
363 } else {
364 trace!("reset ReadTcpState::LenBytes: {}", 0);
365 Some(ReadTcpState::LenBytes {
366 pos: 0,
367 bytes: [0u8; 2],
368 })
369 }
370 }
371 };
372
373 // this will move to the next state,
374 // if it was a completed receipt of bytes, then it will move out the bytes
375 if let Some(state) = new_state {
376 if let ReadTcpState::Bytes { pos, bytes } = mem::replace(read_state, state) {
377 assert_eq!(pos, bytes.len());
378 ret_buf = Some(bytes);
379 }
380 }
381 }
382
383 // if the buffer is ready, return it, if not we're Pending
384 if let Some(buffer) = ret_buf {
385 let src_addr = self.peer_addr;
386 Poll::Ready(Some(Ok(SerialMessage::new(buffer, src_addr))))
387 } else {
388 debug!("bottomed out");
389 // at a minimum the outbound_messages should have been polled,
390 // which will wake this future up later...
391 Poll::Pending
392 }
393 }
394}
395
/// Round-trip tests for `TcpStream` over loopback, run on the tokio runtime.
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tcp::tests::tcp_stream_test;

    /// Runs the shared TCP stream test against the IPv4 loopback address.
    #[tokio::test]
    async fn test_tcp_stream_ipv4() {
        subscribe();
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        tcp_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }

    /// Runs the shared TCP stream test against the IPv6 loopback address (`::1`).
    #[tokio::test]
    async fn test_tcp_stream_ipv6() {
        subscribe();
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        tcp_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }
}
421}