1use crate::{
2 interface::{
3 async_trait, AsyncRead, AsyncWrite, INet, ITcpStream, Net, TcpListener, TcpStream,
4 UdpChannel, UdpSocket,
5 },
6 Address, Context, Result, NOT_IMPLEMENTED,
7};
8use futures_util::future::try_join;
9use std::{
10 collections::VecDeque,
11 future::Future,
12 io,
13 net::SocketAddr,
14 pin::Pin,
15 task::{self, Poll},
16};
17pub use tokio::io::copy_bidirectional;
18use tokio::io::{AsyncReadExt, ReadBuf};
/// A boxed, type-erased future that is `Send` and may borrow data for `'a`.
///
/// Used as the return type of the hand-written `INet` methods on `CombineNet`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
20
21pub async fn connect_tcp(
23 t1: impl AsyncRead + AsyncWrite,
24 t2: impl AsyncRead + AsyncWrite,
25) -> io::Result<()> {
26 tokio::pin!(t1);
27 tokio::pin!(t2);
28 copy_bidirectional(&mut t1, &mut t2).await?;
29 Ok(())
30}
31
/// A [`TcpStream`] wrapper that can look ahead at incoming bytes.
///
/// Bytes pulled off the socket by [`PeekableTcpStream::peek_exact`] are kept in
/// `buf` and handed out by `poll_read` before any further data is read from `tcp`.
pub struct PeekableTcpStream {
    /// The wrapped connection.
    tcp: TcpStream,
    /// Bytes already read from `tcp` but not yet returned through `poll_read`.
    buf: VecDeque<u8>,
}
36
37impl AsyncRead for PeekableTcpStream {
38 fn poll_read(
39 mut self: Pin<&mut Self>,
40 cx: &mut task::Context<'_>,
41 buf: &mut ReadBuf,
42 ) -> Poll<io::Result<()>> {
43 let (first, ..) = &self.buf.as_slices();
44 if first.len() > 0 {
45 let read = first.len().min(buf.remaining());
46 let unfilled = buf.initialize_unfilled_to(read);
47 unfilled[0..read].copy_from_slice(&first[0..read]);
48 buf.advance(read);
49
50 self.buf.drain(0..read);
52
53 Poll::Ready(Ok(()))
54 } else {
55 Pin::new(&mut self.tcp).poll_read(cx, buf)
56 }
57 }
58}
59impl AsyncWrite for PeekableTcpStream {
60 fn poll_write(
61 mut self: Pin<&mut Self>,
62 cx: &mut task::Context<'_>,
63 buf: &[u8],
64 ) -> Poll<io::Result<usize>> {
65 Pin::new(&mut self.tcp).poll_write(cx, buf)
66 }
67
68 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
69 Pin::new(&mut self.tcp).poll_flush(cx)
70 }
71
72 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
73 Pin::new(&mut self.tcp).poll_shutdown(cx)
74 }
75}
76
77#[async_trait]
78impl ITcpStream for PeekableTcpStream {
79 async fn peer_addr(&self) -> crate::Result<SocketAddr> {
80 self.tcp.peer_addr().await
81 }
82
83 async fn local_addr(&self) -> crate::Result<SocketAddr> {
84 self.tcp.local_addr().await
85 }
86}
87
88impl PeekableTcpStream {
89 pub fn new(tcp: TcpStream) -> Self {
90 PeekableTcpStream {
91 tcp,
92 buf: VecDeque::new(),
93 }
94 }
95 async fn fill_buf(&mut self, size: usize) -> crate::Result<()> {
97 if size > self.buf.len() {
98 let to_read = size - self.buf.len();
99 let mut buf = vec![0u8; to_read];
100 self.tcp.read_exact(&mut buf).await?;
101 self.buf.append(&mut buf.into());
102 }
103 Ok(())
104 }
105 pub async fn peek_exact(&mut self, buf: &mut [u8]) -> crate::Result<()> {
106 self.fill_buf(buf.len()).await?;
107 let self_buf = self.buf.make_contiguous();
108 buf.copy_from_slice(&self_buf[0..buf.len()]);
109
110 Ok(())
111 }
112 pub fn into_inner(self) -> (TcpStream, VecDeque<u8>) {
113 (self.tcp, self.buf)
114 }
115}
116
117pub struct NotImplementedNet;
119
120#[async_trait]
121impl INet for NotImplementedNet {
122 async fn tcp_connect(&self, _ctx: &mut Context, _addr: Address) -> Result<TcpStream> {
123 Err(NOT_IMPLEMENTED)
124 }
125
126 async fn tcp_bind(&self, _ctx: &mut Context, _addr: Address) -> Result<TcpListener> {
127 Err(NOT_IMPLEMENTED)
128 }
129
130 async fn udp_bind(&self, _ctx: &mut Context, _addr: Address) -> Result<UdpSocket> {
131 Err(NOT_IMPLEMENTED)
132 }
133}
134
/// An [`INet`] that delegates each operation to its own [`Net`].
pub struct CombineNet {
    /// Net that handles `tcp_connect`.
    pub tcp_connect: Net,
    /// Net that handles `tcp_bind`.
    pub tcp_bind: Net,
    /// Net that handles `udp_bind`.
    pub udp_bind: Net,
}
141
142impl INet for CombineNet {
143 #[inline(always)]
144 fn tcp_connect<'life0: 'a, 'life1: 'a, 'a>(
145 &'life0 self,
146 ctx: &'life1 mut Context,
147 addr: Address,
148 ) -> BoxFuture<'a, Result<TcpStream>>
149 where
150 Self: 'a,
151 {
152 self.tcp_connect.tcp_connect(ctx, addr)
153 }
154
155 #[inline(always)]
156 fn tcp_bind<'life0: 'a, 'life1: 'a, 'a>(
157 &'life0 self,
158 ctx: &'life1 mut Context,
159 addr: Address,
160 ) -> BoxFuture<'a, Result<TcpListener>>
161 where
162 Self: 'a,
163 {
164 self.tcp_bind.tcp_bind(ctx, addr)
165 }
166
167 #[inline(always)]
168 fn udp_bind<'life0: 'a, 'life1: 'a, 'a>(
169 &'life0 self,
170 ctx: &'life1 mut Context,
171 addr: Address,
172 ) -> BoxFuture<'a, Result<UdpSocket>>
173 where
174 Self: 'a,
175 {
176 self.udp_bind.udp_bind(ctx, addr)
177 }
178}
179
180pub fn get_one_net(mut nets: Vec<Net>) -> Result<Net> {
181 if nets.len() != 1 {
182 return Err(crate::Error::Other("Must have one net".to_string().into()));
183 }
184
185 Ok(nets.remove(0))
186}
187
188pub async fn connect_udp(udp_channel: UdpChannel, udp: UdpSocket) -> crate::Result<()> {
189 let in_side = async {
190 let mut buf = [0u8; crate::constant::UDP_BUFFER_SIZE];
191 while let Ok((size, addr)) = udp_channel.recv_send_to(&mut buf).await {
192 let buf = &buf[..size];
193 udp.send_to(buf, addr).await?;
194 }
195 crate::Result::<()>::Ok(())
196 };
197 let out_side = async {
198 let mut buf = [0u8; crate::constant::UDP_BUFFER_SIZE];
199 while let Ok((size, addr)) = udp.recv_from(&mut buf).await {
200 let buf = &buf[..size];
201 udp_channel.send_recv_from(buf, addr).await?;
202 }
203 crate::Result::<()>::Ok(())
204 };
205 try_join(in_side, out_side).await?;
206 Ok(())
207}