uni_stream/
stream.rs

//! Network stream abstractions: the [`NetworkStream`] trait, concrete TCP and UDP implementations of it, and the providers used to connect and bind them.
2
3use std::net::SocketAddr;
4
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use tokio::net::tcp::{ReadHalf, WriteHalf};
7use tokio::net::{TcpListener, TcpStream};
8
9use crate::udp::UdpListener;
10
11use super::addr::{each_addr, ToSocketAddrs};
12use super::udp::{UdpStream, UdpStreamReadHalf, UdpStreamWriteHalf};
13
/// Crate-local result shorthand whose error type defaults to [`std::io::Error`].
type Result<T, E = std::io::Error> = core::result::Result<T, E>;
15
16/// Used to abstract Stream operations, see [`tokio::net::TcpStream`] for details
17pub trait NetworkStream: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static {
18    /// The reader association type used to represent the read operation
19    type ReaderRef<'a>: AsyncReadExt + Send + Unpin + Send
20    where
21        Self: 'a;
22    /// The writer association type used to represent the write operation
23    type WriterRef<'a>: AsyncWriteExt + Send + Unpin + Send
24    where
25        Self: 'a;
26
27    /// Used to get internal specific implementations such as [`UdpStream`]
28    type InnerStream: AsyncReadExt + AsyncWriteExt + Unpin + Send;
29
30    /// Splitting the Stream into a read side and a write side is useful in scenarios where you need to use read and write separately.
31    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>);
32
33    /// Get the internal concrete implementation, note that this operation transfers ownership
34    fn into_inner_stream(self) -> Self::InnerStream;
35
36    /// get  local address
37    fn local_addr(&self) -> Result<SocketAddr>;
38
39    /// get  peer address
40    fn peer_addr(&self) -> Result<SocketAddr>;
41}
42
// Generates a newtype `$struct_name` wrapping `$inner_ty`, with a `new`
// constructor and `AsyncRead`/`AsyncWrite` impls that forward every call to
// the wrapped stream. `$doc_string` becomes the rustdoc of the generated struct.
macro_rules! gen_stream_impl {
    ($struct_name:ident, $inner_ty:ty, $doc_string:literal) => {
        #[doc = $doc_string]
        pub struct $struct_name($inner_ty);

        impl $struct_name {
            /// Wraps an already-established stream.
            pub fn new(stream: $inner_ty) -> Self {
                Self(stream)
            }
        }

        impl AsyncRead for $struct_name {
            fn poll_read(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_read(cx, buf)
            }
        }

        impl AsyncWrite for $struct_name {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                buf: &[u8],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
            }

            // Forwarded explicitly: the trait's default implementation writes
            // only the first non-empty buffer, which would hide the inner
            // stream's vectored-write support behind this wrapper.
            fn poll_write_vectored(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
                bufs: &[std::io::IoSlice<'_>],
            ) -> std::task::Poll<std::io::Result<usize>> {
                std::pin::Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
            }

            fn is_write_vectored(&self) -> bool {
                self.0.is_write_vectored()
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_flush(cx)
            }

            fn poll_shutdown(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<std::io::Result<()>> {
                std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
            }
        }
    };
}
90
91gen_stream_impl!(
92    TcpStreamImpl,
93    TcpStream,
94    "Implementing NetworkStream for TcpStream"
95);
96
97gen_stream_impl!(
98    UdpStreamImpl,
99    UdpStream,
100    "Implementing NetworkStream for UdpStream"
101);
102
103impl NetworkStream for TcpStreamImpl {
104    type ReaderRef<'a> = ReadHalf<'a>
105    where
106        Self: 'a;
107
108    type WriterRef<'a> = WriteHalf<'a>
109    where
110        Self: 'a;
111
112    type InnerStream = TcpStream;
113
114    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
115        self.0.split()
116    }
117
118    fn into_inner_stream(self) -> Self::InnerStream {
119        self.0
120    }
121
122    fn local_addr(&self) -> Result<SocketAddr> {
123        self.0.local_addr()
124    }
125
126    fn peer_addr(&self) -> Result<SocketAddr> {
127        self.0.peer_addr()
128    }
129}
130
131impl NetworkStream for UdpStreamImpl {
132    type ReaderRef<'a> = UdpStreamReadHalf<'static>;
133
134    type WriterRef<'a> = UdpStreamWriteHalf<'a>
135    where
136        Self: 'a;
137
138    type InnerStream = UdpStream;
139
140    fn split(&mut self) -> (Self::ReaderRef<'_>, Self::WriterRef<'_>) {
141        self.0.split()
142    }
143
144    fn into_inner_stream(self) -> Self::InnerStream {
145        self.0
146    }
147
148    fn local_addr(&self) -> Result<SocketAddr> {
149        self.0.local_addr()
150    }
151
152    fn peer_addr(&self) -> Result<SocketAddr> {
153        self.0.peer_addr()
154    }
155}
156
157/// Provides an abstraction for connect
158pub trait StreamProvider {
159    /// Stream obtained after connect
160    type Item: NetworkStream;
161
162    /// Getting the Stream through a connection,
163    /// the only difference between this process and tokio::net::TcpStream::connect is that
164    /// it will be resolved through a customized dns service
165    fn connect<A: ToSocketAddrs + Send>(
166        addr: A,
167    ) -> impl std::future::Future<Output = Result<Self::Item>> + Send;
168}
169
170/// The medium used to get the [`TcpStreamImpl`]
171pub struct TcpStreamProvider;
172
173impl StreamProvider for TcpStreamProvider {
174    type Item = TcpStreamImpl;
175
176    async fn connect<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
177        Ok(TcpStreamImpl(each_addr(addr, TcpStream::connect).await?))
178    }
179}
180
181/// The medium used to get the [`UdpStreamImpl`]
182pub struct UdpStreamProvider;
183
184impl StreamProvider for UdpStreamProvider {
185    type Item = UdpStreamImpl;
186
187    async fn connect<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Item> {
188        Ok(UdpStreamImpl(UdpStream::connect(addr).await?))
189    }
190}
191
192/// Provides an abstraction for bind
193pub trait ListenerProvider {
194    /// Listener obtained after bind
195    type Listener: StreamAccept + 'static;
196
197    /// Getting the Listener through a binding,
198    /// the only difference between this process and `tokio::net::TcpListener::bind` is that
199    /// it will be resolved through a customized dns service
200    fn bind<A: ToSocketAddrs + Send>(
201        addr: A,
202    ) -> impl std::future::Future<Output = Result<Self::Listener>> + Send;
203}
204
205/// Abstractions for Listener-provided operations
206pub trait StreamAccept {
207    /// Stream obtained after accept
208    type Item: NetworkStream;
209
210    /// Listener waits to get new Stream
211    fn accept(&self) -> impl std::future::Future<Output = Result<(Self::Item, SocketAddr)>> + Send;
212}
213
214/// The medium used to get the [`TcpListenerImpl`]
215pub struct TcpListenerProvider;
216
217impl ListenerProvider for TcpListenerProvider {
218    type Listener = TcpListenerImpl;
219
220    async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
221        Ok(TcpListenerImpl(each_addr(addr, TcpListener::bind).await?))
222    }
223}
224
225/// Implementing [`StreamAccept`] for TcpListener
226pub struct TcpListenerImpl(TcpListener);
227
228impl StreamAccept for TcpListenerImpl {
229    type Item = TcpStreamImpl;
230
231    async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
232        let (stream, addr) = self.0.accept().await?;
233        Ok((TcpStreamImpl::new(stream), addr))
234    }
235}
236
237/// The medium used to get the [`TcpListenerImpl`]
238pub struct UdpListenerProvider;
239
240impl ListenerProvider for UdpListenerProvider {
241    type Listener = UdpListenerImpl;
242
243    async fn bind<A: ToSocketAddrs + Send>(addr: A) -> Result<Self::Listener> {
244        Ok(UdpListenerImpl(UdpListener::bind(addr).await?))
245    }
246}
247
248/// Implementing [`StreamAccept`] for [`UdpListener`]
249pub struct UdpListenerImpl(UdpListener);
250
251impl StreamAccept for UdpListenerImpl {
252    type Item = UdpStreamImpl;
253
254    async fn accept(&self) -> Result<(Self::Item, SocketAddr)> {
255        let (stream, addr) = self.0.accept().await?;
256        Ok((UdpStreamImpl::new(stream), addr))
257    }
258}