bytesio/
bytesio.rs

1use std::net::SocketAddr;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use bytes::BufMut;
6use bytes::Bytes;
7use bytes::BytesMut;
8use futures::SinkExt;
9use futures::StreamExt;
10use tokio::net::TcpStream;
11use tokio::net::UdpSocket;
12use tokio_util::codec::BytesCodec;
13use tokio_util::codec::Framed;
14
15use super::bytesio_errors::{BytesIOError, BytesIOErrorValue};
16
/// Transport protocol that backs a [`TNetIO`] implementation.
// The variant names predate the Rust `UpperCamelCase` acronym convention (`Tcp`/`Udp`);
// renaming them would break existing callers, so the clippy lint is silenced instead.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NetType {
    TCP,
    UDP,
}
21
22#[async_trait]
23pub trait TNetIO: Send + Sync {
24    async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError>;
25    async fn read(&mut self) -> Result<BytesMut, BytesIOError>;
26    async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError>;
27    fn get_net_type(&self) -> NetType;
28}
29
30pub struct UdpIO {
31    socket: UdpSocket,
32}
33
34impl UdpIO {
35    pub async fn new(remote_domain: String, remote_port: u16, local_port: u16) -> Option<Self> {
36        let remote_address = if remote_domain == "localhost" {
37            format!("127.0.0.1:{remote_port}")
38        } else {
39            format!("{remote_domain}:{remote_port}")
40        };
41        log::info!("remote address: {}", remote_address);
42        let local_address = format!("0.0.0.0:{local_port}");
43        if let Ok(local_socket) = UdpSocket::bind(local_address).await {
44            if let Ok(remote_socket_addr) = remote_address.parse::<SocketAddr>() {
45                if let Err(err) = local_socket.connect(remote_socket_addr).await {
46                    log::info!("connect to remote udp socket error: {}", err);
47                }
48
49                return Some(Self {
50                    socket: local_socket,
51                });
52            } else {
53                log::error!("remote_address parse error: {:?}", remote_address);
54            }
55        }
56
57        None
58    }
59
60    pub async fn new_with_local_port(local_port: u16) -> Option<Self> {
61        let local_address = format!("0.0.0.0:{local_port}");
62
63        if let Ok(local_socket) = UdpSocket::bind(local_address).await {
64            return Some(Self {
65                socket: local_socket,
66            });
67        }
68        None
69    }
70
71    pub fn get_local_port(&self) -> Option<u16> {
72        if let Ok(local_addr) = self.socket.local_addr() {
73            log::info!("local address: {}", local_addr);
74            return Some(local_addr.port());
75        }
76
77        None
78    }
79}
80
81pub async fn new_udpio_pair() -> Option<(UdpIO, UdpIO)> {
82    let mut next_local_port = 0;
83    let first_local_port;
84
85    // get the first available port
86    if let Some(udpio_0) = UdpIO::new_with_local_port(next_local_port).await {
87        if let Some(local_port_0) = udpio_0.get_local_port() {
88            first_local_port = local_port_0;
89        } else {
90            log::error!("cannot get local port");
91            return None;
92        }
93
94        if first_local_port == 65535 {
95            next_local_port = 1;
96        } else if let Some(udpio_1) = UdpIO::new_with_local_port(first_local_port + 1).await {
97            return Some((udpio_0, udpio_1));
98        } else if first_local_port + 1 == 65535 {
99            next_local_port = 1;
100        } else {
101            next_local_port = first_local_port + 2;
102        }
103    } else {
104        return None;
105    }
106
107    loop {
108        log::trace!("next local port: {next_local_port} and first port: {first_local_port}");
109
110        if next_local_port == 65535 {
111            next_local_port = 1;
112            continue;
113        }
114
115        if next_local_port == first_local_port {
116            return None;
117        }
118
119        if let Some(udpio_0) = UdpIO::new_with_local_port(next_local_port).await {
120            if let Some(udpio_1) = UdpIO::new_with_local_port(next_local_port + 1).await {
121                return Some((udpio_0, udpio_1));
122            } else if next_local_port + 1 == 65535 {
123                next_local_port = 1;
124            } else {
125                next_local_port += 2;
126            }
127        } else {
128            // try next port
129            next_local_port += 1;
130        }
131    }
132    //None
133}
134
135#[async_trait]
136impl TNetIO for UdpIO {
137    fn get_net_type(&self) -> NetType {
138        NetType::UDP
139    }
140
141    async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> {
142        self.socket.send(bytes.as_ref()).await?;
143        Ok(())
144    }
145
146    async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> {
147        match tokio::time::timeout(duration, self.read()).await {
148            Ok(data) => data,
149            Err(err) => Err(BytesIOError {
150                value: BytesIOErrorValue::TimeoutError(err),
151            }),
152        }
153    }
154
155    async fn read(&mut self) -> Result<BytesMut, BytesIOError> {
156        let mut buf = vec![0; 4096];
157        let len = self.socket.recv(&mut buf).await?;
158        let mut rv = BytesMut::new();
159        rv.put(&buf[..len]);
160
161        Ok(rv)
162    }
163}
164
165pub struct TcpIO {
166    stream: Framed<TcpStream, BytesCodec>,
167    //timeout: Duration,
168}
169
170impl TcpIO {
171    pub fn new(stream: TcpStream) -> Self {
172        Self {
173            stream: Framed::new(stream, BytesCodec::new()),
174            // timeout: ms,
175        }
176    }
177}
178
179#[async_trait]
180impl TNetIO for TcpIO {
181    fn get_net_type(&self) -> NetType {
182        NetType::TCP
183    }
184
185    async fn write(&mut self, bytes: Bytes) -> Result<(), BytesIOError> {
186        self.stream.send(bytes).await?;
187
188        Ok(())
189    }
190
191    async fn read_timeout(&mut self, duration: Duration) -> Result<BytesMut, BytesIOError> {
192        match tokio::time::timeout(duration, self.read()).await {
193            Ok(data) => data,
194            Err(err) => Err(BytesIOError {
195                value: BytesIOErrorValue::TimeoutError(err),
196            }),
197        }
198    }
199
200    async fn read(&mut self) -> Result<BytesMut, BytesIOError> {
201        let message = self.stream.next().await;
202
203        match message {
204            Some(data) => match data {
205                Ok(bytes) => Ok(bytes),
206                Err(err) => Err(BytesIOError {
207                    value: BytesIOErrorValue::IOError(err),
208                }),
209            },
210            None => Err(BytesIOError {
211                value: BytesIOErrorValue::NoneReturn,
212            }),
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::new_udpio_pair;
    use super::UdpIO;

    #[tokio::test]
    async fn test_new_udpio_pair() {
        if let Some((udpio_0, udpio_1)) = new_udpio_pair().await {
            let port_0 = udpio_0.get_local_port();
            let port_1 = udpio_1.get_local_port();
            println!("{:?} == {:?}", port_0, port_1);
            // `new_udpio_pair` always hands out consecutive ports.
            assert_eq!(port_0.and_then(|port| port.checked_add(1)), port_1);
        }
    }

    /// Exhausts every UDP port, then checks that `new_udpio_pair` gives up instead of hanging.
    #[tokio::test]
    #[ignore = "binds every UDP port on the host: very slow and starves tests running in parallel"]
    async fn test_new_udpio_pair2() {
        println!("test_new_udpio_pair2 begin...");
        let mut sockets: Vec<UdpIO> = Vec::new();

        for port in 1..=u16::MAX {
            if let Some(udpio) = UdpIO::new_with_local_port(port).await {
                sockets.push(udpio);
            } else {
                println!("new local port fail: {}", port);
            }
        }

        println!("socket size: {}", sockets.len());

        if let Some((udpio_0, udpio_1)) = new_udpio_pair().await {
            println!(
                "{:?} == {:?}",
                udpio_0.get_local_port(),
                udpio_1.get_local_port()
            );
        }
    }

    /// Checks that a port released by dropping its socket can be bound again.
    #[tokio::test]
    async fn test_new_udpio_pair3() {
        // Get an OS-assigned port.
        let mut first_local_port = 0;
        if let Some(udpio_0) = UdpIO::new_with_local_port(0).await {
            if let Some(local_port_0) = udpio_0.get_local_port() {
                first_local_port = local_port_0;
            }
        }
        // `udpio_0` was dropped at the end of the block above, releasing its port.
        println!("first_local_port: {}", first_local_port);

        if UdpIO::new_with_local_port(first_local_port).await.is_some() {
            println!("success")
        } else {
            println!("fail")
        }
    }
}