Skip to main content

hickory_net/tcp/
tcp_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This module contains all the TCP structures for demuxing TCP into streams of DNS packets.
9
10use core::mem;
11use core::net::SocketAddr;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14use core::time::Duration;
15use std::io;
16
17use futures_io::IoSlice;
18use futures_util::stream::Stream;
19use futures_util::{self, future::Future, ready};
20use tracing::{debug, trace};
21
22use crate::proto::op::SerialMessage;
23use crate::runtime::{DnsTcpStream, Time};
24use crate::xfer::{BufDnsStreamHandle, StreamReceiver};
25
/// Current state while writing a DNS message to the remote end of the TCP connection.
///
/// A message is written as its two-byte big-endian length prefix, followed by the
/// message bytes, after which the socket is flushed.
enum WriteTcpState {
    /// Currently writing the two-byte length prefix (a vectored write may also
    /// push some of the message bytes in the same call).
    LenBytes {
        /// Number of bytes written so far, counted from the start of `length`.
        /// After a vectored write this may exceed `length.len()`, in which case
        /// the excess has already been written from `bytes`.
        pos: usize,
        /// Big-endian length prefix of `bytes`.
        length: [u8; 2],
        /// Message to write after the length prefix.
        bytes: Vec<u8>,
    },
    /// Currently writing the remainder of the message bytes.
    Bytes {
        /// Number of bytes of `bytes` already written.
        pos: usize,
        /// Message being written to the remote.
        bytes: Vec<u8>,
    },
    /// The whole message has been written; the socket is being flushed.
    Flushing,
}
47
/// Progress of the read half of the TCP framing state machine.
///
/// Every inbound DNS message is preceded by a two-byte big-endian length; the
/// reader alternates between collecting that prefix and collecting the body.
pub(crate) enum ReadTcpState {
    /// Collecting the two-byte length prefix of the next message.
    LenBytes {
        /// How many prefix bytes have arrived so far.
        pos: usize,
        /// Storage the prefix bytes are read into.
        bytes: [u8; 2],
    },
    /// Collecting the body of a DNS message whose length is already known.
    Bytes {
        /// How many body bytes have arrived so far.
        pos: usize,
        /// Body buffer, sized from the decoded length prefix.
        bytes: Vec<u8>,
    },
}
65
66/// A Stream used for sending data to and from a remote DNS endpoint (client or server).
67#[must_use = "futures do nothing unless polled"]
68pub struct TcpStream<S: DnsTcpStream> {
69    socket: S,
70    outbound_messages: StreamReceiver,
71    send_state: Option<WriteTcpState>,
72    read_state: ReadTcpState,
73    peer_addr: SocketAddr,
74}
75
76impl<S: DnsTcpStream> TcpStream<S> {
77    /// Returns the address of the peer connection.
78    pub fn peer_addr(&self) -> SocketAddr {
79        self.peer_addr
80    }
81
82    fn pollable_split(
83        &mut self,
84    ) -> (
85        &mut S,
86        &mut StreamReceiver,
87        &mut Option<WriteTcpState>,
88        &mut ReadTcpState,
89    ) {
90        (
91            &mut self.socket,
92            &mut self.outbound_messages,
93            &mut self.send_state,
94            &mut self.read_state,
95        )
96    }
97
98    /// Initializes a TcpStream with an established connection.
99    ///
100    /// Uses the default buffer size (32) for the outbound message queue.
101    ///
102    /// # Arguments
103    ///
104    /// * `stream` - the established IO stream for communication
105    /// * `peer_addr` - address of the remote peer
106    pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
107        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
108        let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
109        (stream, message_sender)
110    }
111
112    /// Initializes a TcpStream with an established connection and explicit buffer size.
113    ///
114    /// Use this when you need a larger buffer to handle high message rates without
115    /// dropping messages due to backpressure.
116    ///
117    /// # Arguments
118    ///
119    /// * `stream` - the established IO stream for communication
120    /// * `peer_addr` - address of the remote peer
121    /// * `buffer_size` - maximum number of messages that can be queued for sending
122    pub fn from_stream_with_buffer_size(
123        stream: S,
124        peer_addr: SocketAddr,
125        buffer_size: usize,
126    ) -> (Self, BufDnsStreamHandle) {
127        let (message_sender, outbound_messages) =
128            BufDnsStreamHandle::with_buffer_size(peer_addr, buffer_size);
129        let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
130        (stream, message_sender)
131    }
132
133    /// Wraps a stream where a sender and receiver have already been established
134    pub fn from_stream_with_receiver(
135        socket: S,
136        peer_addr: SocketAddr,
137        outbound_messages: StreamReceiver,
138    ) -> Self {
139        Self {
140            socket,
141            outbound_messages,
142            send_state: None,
143            read_state: ReadTcpState::LenBytes {
144                pos: 0,
145                bytes: [0u8; 2],
146            },
147            peer_addr,
148        }
149    }
150
151    /// Creates a new future of the eventually establish a IO stream connection or fail trying
152    ///
153    /// # Arguments
154    ///
155    /// * `future` - underlying stream future which this tcp stream relies on
156    /// * `name_server` - the IP and Port of the DNS server to connect to
157    /// * `timeout` - connection timeout
158    pub fn with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
159        future: F,
160        name_server: SocketAddr,
161        timeout: Duration,
162    ) -> (
163        impl Future<Output = Result<Self, io::Error>> + Send,
164        BufDnsStreamHandle,
165    ) {
166        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
167        let stream_fut = Self::connect_with_future(future, name_server, timeout, outbound_messages);
168
169        (stream_fut, message_sender)
170    }
171
172    async fn connect_with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
173        future: F,
174        name_server: SocketAddr,
175        timeout: Duration,
176        outbound_messages: StreamReceiver,
177    ) -> Result<Self, io::Error> {
178        let socket = S::Time::timeout(timeout, future).await??;
179        debug!("TCP connection established to: {}", name_server);
180        Ok(Self {
181            socket,
182            outbound_messages,
183            send_state: None,
184            read_state: ReadTcpState::LenBytes {
185                pos: 0,
186                bytes: [0u8; 2],
187            },
188            peer_addr: name_server,
189        })
190    }
191}
192
193impl<S: DnsTcpStream> Stream for TcpStream<S> {
194    type Item = io::Result<SerialMessage>;
195
196    #[allow(clippy::cognitive_complexity)]
197    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198        let peer = self.peer_addr;
199        let (socket, outbound_messages, send_state, read_state) = self.pollable_split();
200        let mut socket = Pin::new(socket);
201        let mut outbound_messages = Pin::new(outbound_messages);
202
203        // this will not accept incoming data while there is data to send
204        //  makes this self throttling.
205        // TODO: it might be interesting to try and split the sending and receiving futures.
206        loop {
207            // in the case we are sending, send it all?
208            if send_state.is_some() {
209                // sending...
210                match send_state {
211                    Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
212                        let wrote = ready!(socket.as_mut().poll_write_vectored(
213                            cx,
214                            &[IoSlice::new(&length[*pos..]), IoSlice::new(bytes)]
215                        ))?;
216                        *pos += wrote;
217                    }
218                    Some(WriteTcpState::Bytes { pos, bytes }) => {
219                        let wrote = ready!(socket.as_mut().poll_write(cx, &bytes[*pos..]))?;
220                        *pos += wrote;
221                    }
222                    Some(WriteTcpState::Flushing) => {
223                        ready!(socket.as_mut().poll_flush(cx))?;
224                    }
225                    _ => (),
226                }
227
228                // get current state
229                let current_state = send_state.take();
230
231                // switch states
232                match current_state {
233                    Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
234                        if pos < length.len() {
235                            *send_state = Some(WriteTcpState::LenBytes { pos, length, bytes });
236                        } else if pos < length.len() + bytes.len() {
237                            *send_state = Some(WriteTcpState::Bytes {
238                                pos: pos - length.len(),
239                                bytes,
240                            });
241                        } else {
242                            *send_state = Some(WriteTcpState::Flushing);
243                        }
244                    }
245                    Some(WriteTcpState::Bytes { pos, bytes }) => {
246                        if pos < bytes.len() {
247                            *send_state = Some(WriteTcpState::Bytes { pos, bytes });
248                        } else {
249                            // At this point we successfully delivered the entire message.
250                            //  flush
251                            *send_state = Some(WriteTcpState::Flushing);
252                        }
253                    }
254                    Some(WriteTcpState::Flushing) => {
255                        // At this point we successfully delivered the entire message.
256                        send_state.take();
257                    }
258                    None => (),
259                };
260            } else {
261                // then see if there is more to send
262                match outbound_messages.as_mut().poll_next(cx)
263                    // .map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))?
264                {
265                    // already handled above, here to make sure the poll() pops the next message
266                    Poll::Ready(Some(message)) => {
267                        // if there is no peer, this connection should die...
268                        let (buffer, dst) = message.into();
269
270                        // This is an error if the destination is not our peer (this is TCP after all)
271                        //  This will kill the connection...
272                        if peer != dst {
273                            return Poll::Ready(Some(Err(io::Error::new(
274                                io::ErrorKind::InvalidData,
275                                format!("mismatched peer: {peer} and dst: {dst}"),
276                            ))));
277                        }
278
279                        // will return if the socket will block
280                        // the length is 16 bits
281                        let len = u16::to_be_bytes(buffer.len() as u16);
282
283                        debug!("sending message len: {} to: {}", buffer.len(), dst);
284                        *send_state = Some(WriteTcpState::LenBytes {
285                            pos: 0,
286                            length: len,
287                            bytes: buffer,
288                        });
289                    }
290                    // now we get to drop through to the receives...
291                    // TODO: should we also return None if there are no more messages to send?
292                    Poll::Pending => break,
293                    Poll::Ready(None) => {
294                        debug!("no messages to send");
295                        break;
296                    }
297                }
298            }
299        }
300
301        let mut ret_buf = None;
302
303        // this will loop while there is data to read, or the data has been read, or an IO
304        //  event would block
305        while ret_buf.is_none() {
306            // Evaluates the next state. If None is the result, then no state change occurs,
307            //  if Some(_) is returned, then that will be used as the next state.
308            let new_state: Option<ReadTcpState> = match read_state {
309                ReadTcpState::LenBytes { pos, bytes } => {
310                    // debug!("reading length {}", bytes.len());
311                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
312                    if read == 0 {
313                        // the Stream was closed!
314                        debug!("zero bytes read, stream closed?");
315                        //try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function
316
317                        if *pos == 0 {
318                            // Since this is the start of the next message, we have a clean end
319                            return Poll::Ready(None);
320                        } else {
321                            return Poll::Ready(Some(Err(io::Error::new(
322                                io::ErrorKind::BrokenPipe,
323                                "closed while reading length",
324                            ))));
325                        }
326                    }
327                    trace!("in ReadTcpState::LenBytes: {}", pos);
328                    *pos += read;
329
330                    if *pos < bytes.len() {
331                        trace!("remain ReadTcpState::LenBytes: {}", pos);
332                        None
333                    } else {
334                        let length = u16::from_be_bytes(*bytes);
335                        trace!("got length: {}", length);
336                        let mut bytes = vec![0; length as usize];
337                        bytes.resize(length as usize, 0);
338
339                        trace!("move ReadTcpState::Bytes: {}", bytes.len());
340                        Some(ReadTcpState::Bytes { pos: 0, bytes })
341                    }
342                }
343                ReadTcpState::Bytes { pos, bytes } => {
344                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
345                    if read == 0 {
346                        // the Stream was closed!
347                        trace!("zero bytes read for message, stream closed?");
348
349                        // Since this is the start of the next message, we have a clean end
350                        // try!(self.socket.shutdown(Shutdown::Both));  // TODO: add generic shutdown function
351                        return Poll::Ready(Some(Err(io::Error::new(
352                            io::ErrorKind::BrokenPipe,
353                            "closed while reading message",
354                        ))));
355                    }
356
357                    trace!("in ReadTcpState::Bytes: {}", bytes.len());
358                    *pos += read;
359
360                    if *pos < bytes.len() {
361                        trace!("remain ReadTcpState::Bytes: {}", bytes.len());
362                        None
363                    } else {
364                        trace!("reset ReadTcpState::LenBytes: {}", 0);
365                        Some(ReadTcpState::LenBytes {
366                            pos: 0,
367                            bytes: [0u8; 2],
368                        })
369                    }
370                }
371            };
372
373            // this will move to the next state,
374            //  if it was a completed receipt of bytes, then it will move out the bytes
375            if let Some(state) = new_state {
376                if let ReadTcpState::Bytes { pos, bytes } = mem::replace(read_state, state) {
377                    assert_eq!(pos, bytes.len());
378                    ret_buf = Some(bytes);
379                }
380            }
381        }
382
383        // if the buffer is ready, return it, if not we're Pending
384        if let Some(buffer) = ret_buf {
385            let src_addr = self.peer_addr;
386            Poll::Ready(Some(Ok(SerialMessage::new(buffer, src_addr))))
387        } else {
388            debug!("bottomed out");
389            // at a minimum the outbound_messages should have been polled,
390            //  which will wake this future up later...
391            Poll::Pending
392        }
393    }
394}
395
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tcp::tests::tcp_stream_test;

    /// Calls `subscribe()` and then runs the shared `tcp_stream_test` against `addr`
    /// using the Tokio runtime provider.
    async fn run_on_loopback(addr: IpAddr) {
        subscribe();
        tcp_stream_test(addr, TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv4() {
        run_on_loopback(IpAddr::from(Ipv4Addr::LOCALHOST)).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv6() {
        // `Ipv6Addr::LOCALHOST` is `::1`.
        run_on_loopback(IpAddr::from(Ipv6Addr::LOCALHOST)).await;
    }
}