rd_interface/
util.rs

1use crate::{
2    interface::{
3        async_trait, AsyncRead, AsyncWrite, INet, ITcpStream, Net, TcpListener, TcpStream,
4        UdpChannel, UdpSocket,
5    },
6    Address, Context, Result, NOT_IMPLEMENTED,
7};
8use futures_util::future::try_join;
9use std::{
10    collections::VecDeque,
11    future::Future,
12    io,
13    net::SocketAddr,
14    pin::Pin,
15    task::{self, Poll},
16};
17pub use tokio::io::copy_bidirectional;
18use tokio::io::{AsyncReadExt, ReadBuf};
/// Heap-allocated, type-erased future that is `Send` and may borrow data for `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
20
21/// Connect two `TcpStream`
22pub async fn connect_tcp(
23    t1: impl AsyncRead + AsyncWrite,
24    t2: impl AsyncRead + AsyncWrite,
25) -> io::Result<()> {
26    tokio::pin!(t1);
27    tokio::pin!(t2);
28    copy_bidirectional(&mut t1, &mut t2).await?;
29    Ok(())
30}
31
/// A [`TcpStream`] wrapper that can look ahead at incoming bytes without
/// consuming them.
///
/// Bytes pulled in by [`peek_exact`](Self::peek_exact) are held in `buf` and
/// are handed out again by the `AsyncRead` impl before any new data is read
/// from `tcp`.
pub struct PeekableTcpStream {
    /// The wrapped stream; writes always go here, as do reads once `buf` is empty.
    tcp: TcpStream,
    /// Bytes already read from `tcp` but not yet returned to a reader.
    buf: VecDeque<u8>,
}
36
37impl AsyncRead for PeekableTcpStream {
38    fn poll_read(
39        mut self: Pin<&mut Self>,
40        cx: &mut task::Context<'_>,
41        buf: &mut ReadBuf,
42    ) -> Poll<io::Result<()>> {
43        let (first, ..) = &self.buf.as_slices();
44        if first.len() > 0 {
45            let read = first.len().min(buf.remaining());
46            let unfilled = buf.initialize_unfilled_to(read);
47            unfilled[0..read].copy_from_slice(&first[0..read]);
48            buf.advance(read);
49
50            // remove 0..read
51            self.buf.drain(0..read);
52
53            Poll::Ready(Ok(()))
54        } else {
55            Pin::new(&mut self.tcp).poll_read(cx, buf)
56        }
57    }
58}
59impl AsyncWrite for PeekableTcpStream {
60    fn poll_write(
61        mut self: Pin<&mut Self>,
62        cx: &mut task::Context<'_>,
63        buf: &[u8],
64    ) -> Poll<io::Result<usize>> {
65        Pin::new(&mut self.tcp).poll_write(cx, buf)
66    }
67
68    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
69        Pin::new(&mut self.tcp).poll_flush(cx)
70    }
71
72    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
73        Pin::new(&mut self.tcp).poll_shutdown(cx)
74    }
75}
76
77#[async_trait]
78impl ITcpStream for PeekableTcpStream {
79    async fn peer_addr(&self) -> crate::Result<SocketAddr> {
80        self.tcp.peer_addr().await
81    }
82
83    async fn local_addr(&self) -> crate::Result<SocketAddr> {
84        self.tcp.local_addr().await
85    }
86}
87
88impl PeekableTcpStream {
89    pub fn new(tcp: TcpStream) -> Self {
90        PeekableTcpStream {
91            tcp,
92            buf: VecDeque::new(),
93        }
94    }
95    // Fill self.buf to size using self.tcp.read_exact
96    async fn fill_buf(&mut self, size: usize) -> crate::Result<()> {
97        if size > self.buf.len() {
98            let to_read = size - self.buf.len();
99            let mut buf = vec![0u8; to_read];
100            self.tcp.read_exact(&mut buf).await?;
101            self.buf.append(&mut buf.into());
102        }
103        Ok(())
104    }
105    pub async fn peek_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
106        self.fill_buf(buf.len()).await?;
107        let self_buf = self.buf.make_contiguous();
108        buf.copy_from_slice(&self_buf[0..buf.len()]);
109
110        Ok(())
111    }
112    pub fn into_inner(self) -> (TcpStream, VecDeque<u8>) {
113        (self.tcp, self.buf)
114    }
115}
116
117/// A no-op Net returns [`Error::NotImplemented`](crate::Error::NotImplemented) for every method.
118pub struct NotImplementedNet;
119
120#[async_trait]
121impl INet for NotImplementedNet {
122    async fn tcp_connect(&self, _ctx: &mut Context, _addr: Address) -> Result<TcpStream> {
123        Err(NOT_IMPLEMENTED)
124    }
125
126    async fn tcp_bind(&self, _ctx: &mut Context, _addr: Address) -> Result<TcpListener> {
127        Err(NOT_IMPLEMENTED)
128    }
129
130    async fn udp_bind(&self, _ctx: &mut Context, _addr: Address) -> Result<UdpSocket> {
131        Err(NOT_IMPLEMENTED)
132    }
133}
134
135/// A new Net calls [`tcp_connect()`](crate::INet::tcp_connect()), [`tcp_bind()`](crate::INet::tcp_bind()), [`udp_bind()`](crate::INet::udp_bind()) from different Net.
136pub struct CombineNet {
137    pub tcp_connect: Net,
138    pub tcp_bind: Net,
139    pub udp_bind: Net,
140}
141
142impl INet for CombineNet {
143    #[inline(always)]
144    fn tcp_connect<'life0: 'a, 'life1: 'a, 'a>(
145        &'life0 self,
146        ctx: &'life1 mut Context,
147        addr: Address,
148    ) -> BoxFuture<'a, Result<TcpStream>>
149    where
150        Self: 'a,
151    {
152        self.tcp_connect.tcp_connect(ctx, addr)
153    }
154
155    #[inline(always)]
156    fn tcp_bind<'life0: 'a, 'life1: 'a, 'a>(
157        &'life0 self,
158        ctx: &'life1 mut Context,
159        addr: Address,
160    ) -> BoxFuture<'a, Result<TcpListener>>
161    where
162        Self: 'a,
163    {
164        self.tcp_bind.tcp_bind(ctx, addr)
165    }
166
167    #[inline(always)]
168    fn udp_bind<'life0: 'a, 'life1: 'a, 'a>(
169        &'life0 self,
170        ctx: &'life1 mut Context,
171        addr: Address,
172    ) -> BoxFuture<'a, Result<UdpSocket>>
173    where
174        Self: 'a,
175    {
176        self.udp_bind.udp_bind(ctx, addr)
177    }
178}
179
180pub fn get_one_net(mut nets: Vec<Net>) -> Result<Net> {
181    if nets.len() != 1 {
182        return Err(crate::Error::Other("Must have one net".to_string().into()));
183    }
184
185    Ok(nets.remove(0))
186}
187
188pub async fn connect_udp(udp_channel: UdpChannel, udp: UdpSocket) -> crate::Result<()> {
189    let in_side = async {
190        let mut buf = [0u8; crate::constant::UDP_BUFFER_SIZE];
191        while let Ok((size, addr)) = udp_channel.recv_send_to(&mut buf).await {
192            let buf = &buf[..size];
193            udp.send_to(buf, addr).await?;
194        }
195        crate::Result::<()>::Ok(())
196    };
197    let out_side = async {
198        let mut buf = [0u8; crate::constant::UDP_BUFFER_SIZE];
199        while let Ok((size, addr)) = udp.recv_from(&mut buf).await {
200            let buf = &buf[..size];
201            udp_channel.send_recv_from(buf, addr).await?;
202        }
203        crate::Result::<()>::Ok(())
204    };
205    try_join(in_side, out_side).await?;
206    Ok(())
207}