1use std::net::SocketAddr;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use bytes::BufMut;
6use bytes::Bytes;
7use bytes::BytesMut;
8use futures::SinkExt;
9use futures::StreamExt;
10use tokio::net::TcpStream;
11use tokio::net::UdpSocket;
12use tokio_util::codec::BytesCodec;
13use tokio_util::codec::Framed;
14
15use super::bytesio_errors::{BytesIOError, BytesIOErrorValue};
16
/// Transport kind reported by a network IO object.
///
/// The derives let callers compare, copy and debug-print the value instead of
/// having to `match` on it every time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NetType {
    /// Stream transport over TCP.
    TCP,
    /// Datagram transport over UDP.
    UDP,
}
21
22#[async_trait]
23pub trait TNetIO: Send + Sync {
24 async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>;
25 async fn read(&mut self) -> Result<BytesMut, BytesIOError>;
26 async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>;
27 fn get_net_type(&self) -> NetType;
28}
29
30pub struct UdpIO {
31 socket: UdpSocket,
32}
33
34impl UdpIO {
35 pub async fn new(remote_domain: String, remote_port: u16, local_port: u16) -> Option<Self> {
36 let remote_address = if remote_domain == "localhost" {
37 format!("127.0.0.1:{remote_port}")
38 } else {
39 format!("{remote_domain}:{remote_port}")
40 };
41 log::info!("remote address: {}", remote_address);
42 let local_address = format!("0.0.0.0:{local_port}");
43 if let Ok(local_socket) = UdpSocket::bind(local_address).await {
44 if let Ok(remote_socket_addr) = remote_address.parse::<SocketAddr>() {
45 if let Err(err) = local_socket.connect(remote_socket_addr).await {
46 log::info!("connect to remote udp socket error: {}", err);
47 }
48
49 return Some(Self {
50 socket: local_socket,
51 });
52 } else {
53 log::error!("remote_address parse error: {:?}", remote_address);
54 }
55 }
56
57 None
58 }
59
60 pub async fn new_with_local_port(local_port: u16) -> Option<Self> {
61 let local_address = format!("0.0.0.0:{local_port}");
62
63 if let Ok(local_socket) = UdpSocket::bind(local_address).await {
64 return Some(Self {
65 socket: local_socket,
66 });
67 }
68 None
69 }
70
71 pub fn get_local_port(&self) -> Option<u16> {
72 if let Ok(local_addr) = self.socket.local_addr() {
73 log::info!("local address: {}", local_addr);
74 return Some(local_addr.port());
75 }
76
77 None
78 }
79}
80
81pub async fn new_udpio_pair() -> Option<(UdpIO, UdpIO)> {
82 let mut next_local_port = 0;
83 let first_local_port;
84
85 if let Some(udpio_0) = UdpIO::new_with_local_port(next_local_port).await {
87 if let Some(local_port_0) = udpio_0.get_local_port() {
88 first_local_port = local_port_0;
89 } else {
90 log::error!("cannot get local port");
91 return None;
92 }
93
94 if first_local_port == 65535 {
95 next_local_port = 1;
96 } else if let Some(udpio_1) = UdpIO::new_with_local_port(first_local_port + 1).await {
97 return Some((udpio_0, udpio_1));
98 } else if first_local_port + 1 == 65535 {
99 next_local_port = 1;
100 } else {
101 next_local_port = first_local_port + 2;
102 }
103 } else {
104 return None;
105 }
106
107 loop {
108 log::trace!("next local port: {next_local_port} and first port: {first_local_port}");
109
110 if next_local_port == 65535 {
111 next_local_port = 1;
112 continue;
113 }
114
115 if next_local_port == first_local_port {
116 return None;
117 }
118
119 if let Some(udpio_0) = UdpIO::new_with_local_port(next_local_port).await {
120 if let Some(udpio_1) = UdpIO::new_with_local_port(next_local_port + 1).await {
121 return Some((udpio_0, udpio_1));
122 } else if next_local_port + 1 == 65535 {
123 next_local_port = 1;
124 } else {
125 next_local_port += 2;
126 }
127 } else {
128 next_local_port += 1;
130 }
131 }
132 }
134
135#[async_trait]
136impl TNetIO for UdpIO {
137 fn get_net_type(&self) -> NetType {
138 NetType::UDP
139 }
140
141 async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> {
142 self.socket.send(bytes.as_ref()).await?;
143 Ok(())
144 }
145
146 async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> {
147 match tokio::time::timeout(duration, self.read()).await {
148 Ok(data) => data,
149 Err(err) => Err(BytesIOError {
150 value: BytesIOErrorValue::TimeoutError(err),
151 }),
152 }
153 }
154
155 async fn read(&mut self) -> Result<BytesMut, BytesIOError> {
156 let mut buf = vec![0; 4096];
157 let len = self.socket.recv(&mut buf).await?;
158 let mut rv = BytesMut::new();
159 rv.put(&buf[..len]);
160
161 Ok(rv)
162 }
163}
164
165pub struct TcpIO {
166 stream: Framed<TcpStream, BytesCodec>,
167 }
169
170impl TcpIO {
171 pub fn new(stream: TcpStream) -> Self {
172 Self {
173 stream: Framed::new(stream, BytesCodec::new()),
174 }
176 }
177}
178
179#[async_trait]
180impl TNetIO for TcpIO {
181 fn get_net_type(&self) -> NetType {
182 NetType::TCP
183 }
184
185 async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> {
186 self.stream.send(bytes).await?;
187
188 Ok(())
189 }
190
191 async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> {
192 match tokio::time::timeout(duration, self.read()).await {
193 Ok(data) => data,
194 Err(err) => Err(BytesIOError {
195 value: BytesIOErrorValue::TimeoutError(err),
196 }),
197 }
198 }
199
200 async fn read(&mut self) -> Result<BytesMut, BytesIOError> {
201 let message = self.stream.next().await;
202
203 match message {
204 Some(data) => match data {
205 Ok(bytes) => Ok(bytes),
206 Err(err) => Err(BytesIOError {
207 value: BytesIOErrorValue::IOError(err),
208 }),
209 },
210 None => Err(BytesIOError {
211 value: BytesIOErrorValue::NoneReturn,
212 }),
213 }
214 }
215}
216
#[cfg(test)]
mod tests {

    use super::new_udpio_pair;
    use super::UdpIO;

    /// Requests a port pair and, when one is returned, prints it and checks
    /// that the second port directly follows the first.
    async fn print_and_check_pair() {
        if let Some((udpio1, udpio2)) = new_udpio_pair().await {
            let port1 = udpio1.get_local_port();
            let port2 = udpio2.get_local_port();
            println!("{:?} == {:?}", port1, port2);

            assert_eq!(port1.and_then(|port| port.checked_add(1)), port2);
        }
    }

    /// A pair requested on an otherwise idle machine uses consecutive ports.
    #[tokio::test]
    async fn test_new_udpio_pair() {
        print_and_check_pair().await;
    }

    /// Tries to occupy every port first, then requests a pair; the call must
    /// return (with or without a pair) while the sockets are still held.
    #[tokio::test]
    async fn test_new_udpio_pair2() {
        println!("test_new_udpio_pair2 begin...");
        let mut socket: Vec<UdpIO> = Vec::new();

        for i in 1..=65535 {
            println!("cur port: {}", i);
            if let Some(udpio) = UdpIO::new_with_local_port(i).await {
                socket.push(udpio)
            } else {
                println!("new local port fail: {}", i);
            }
        }

        println!("socket size: {}", socket.len());

        print_and_check_pair().await;
    }

    /// Binds an OS-assigned port, drops the socket, then tries to bind the
    /// same port again and prints the outcome.
    #[tokio::test]
    async fn test_new_udpio_pair3() {
        // The temporary socket is dropped at the end of this statement, which
        // releases the port before the second bind below.
        let first_local_port = UdpIO::new_with_local_port(0)
            .await
            .and_then(|udpio_0| udpio_0.get_local_port())
            .unwrap_or(0);
        println!("first_local_port: {}", first_local_port);

        if (UdpIO::new_with_local_port(first_local_port).await).is_some() {
            println!("success")
        } else {
            println!("fail")
        }
    }
}