1use crate::log::Log;
22use byteorder::{LittleEndian, WriteBytesExt};
23use serde::{de::DeserializeOwned, Serialize};
24use std::{
25 io::{self, ErrorKind, Read, Write},
26 net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
27};
28
/// TCP listener that is switched to non-blocking mode by [`NetListener::bind`] and hands out
/// accepted connections as [`NetStream`]s via [`NetListener::accept_connections`].
pub struct NetListener {
    /// The underlying OS listener socket.
    listener: TcpListener,
}
32
33impl NetListener {
34 pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
35 let listener = TcpListener::bind(addr)?;
36 listener.set_nonblocking(true)?;
37 Ok(Self { listener })
38 }
39
40 pub fn local_address(&self) -> io::Result<SocketAddr> {
41 self.listener.local_addr()
42 }
43
44 pub fn accept_connections(&self) -> Vec<NetStream> {
45 let mut streams = Vec::new();
46 while let Ok(result) = self.listener.accept() {
47 streams.push(NetStream::from_inner(result.0).unwrap())
48 }
49 streams
50 }
51}
52
/// A message channel over a TCP stream.
///
/// Outgoing messages are serialized with `bincode` and framed as a little-endian `u32`
/// byte length followed by the payload; incoming bytes are buffered until a complete
/// frame is available and then deserialized.
pub struct NetStream {
    /// The underlying socket; `from_inner` switches it to non-blocking, no-delay mode.
    stream: TcpStream,
    /// Bytes read from the socket that have not yet been decoded into messages.
    rx_buffer: Vec<u8>,
    /// Scratch buffer reused by `send_message` to serialize outgoing data.
    tx_buffer: Vec<u8>,
}
58
59impl NetStream {
60 pub fn from_inner(stream: TcpStream) -> io::Result<Self> {
61 stream.set_nonblocking(true)?;
62 stream.set_nodelay(true)?;
63
64 Ok(Self {
65 stream,
66 rx_buffer: Default::default(),
67 tx_buffer: Default::default(),
68 })
69 }
70
71 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
72 Self::from_inner(TcpStream::connect(addr)?)
73 }
74
75 pub fn send_message<T>(&mut self, data: &T) -> io::Result<()>
76 where
77 T: Serialize,
78 {
79 self.tx_buffer.clear();
80 if self.tx_buffer.capacity() < std::mem::size_of::<T>() {
81 self.tx_buffer.reserve(std::mem::size_of::<T>());
82 }
83 bincode::serialize_into(&mut self.tx_buffer, data).map_err(io::Error::other)?;
84 self.stream
85 .write_u32::<LittleEndian>(self.tx_buffer.len() as u32)?;
86 self.stream.write_all(&self.tx_buffer)?;
87 Ok(())
88 }
89
90 pub fn peer_address(&self) -> io::Result<SocketAddr> {
91 self.stream.peer_addr()
92 }
93
94 pub fn string_peer_address(&self) -> String {
95 self.peer_address()
96 .map(|addr| addr.to_string())
97 .unwrap_or_else(|_| "Unknown".into())
98 }
99
100 fn next_message<M>(&mut self) -> Option<M>
101 where
102 M: DeserializeOwned,
103 {
104 if self.rx_buffer.len() < 4 {
105 return None;
106 }
107
108 let length = u32::from_le_bytes([
109 self.rx_buffer[0],
110 self.rx_buffer[1],
111 self.rx_buffer[2],
112 self.rx_buffer[3],
113 ]) as usize;
114
115 let end = 4 + length;
116
117 if let Some(data) = self.rx_buffer.as_slice().get(4..end) {
119 let message = match bincode::deserialize::<M>(data) {
120 Ok(message) => Some(message),
121 Err(err) => {
122 Log::err(format!(
123 "Failed to parse a network message of {length} bytes long. Reason: {err:?}"
124 ));
125
126 None
127 }
128 };
129
130 self.rx_buffer.drain(..end);
131
132 message
133 } else {
134 None
135 }
136 }
137
138 fn receive_bytes(&mut self) {
139 loop {
141 let mut bytes = [0; 8192];
142 match self.stream.read(&mut bytes) {
143 Ok(bytes_count) => {
144 if bytes_count == 0 {
145 break;
146 } else {
147 self.rx_buffer.extend(&bytes[..bytes_count])
148 }
149 }
150 Err(err) => match err.kind() {
151 ErrorKind::WouldBlock => {
152 break;
153 }
154 ErrorKind::Interrupted => {
155 }
157 _ => {
158 Log::err(format!(
159 "An error occurred when reading data from socket: {err}"
160 ));
161
162 self.rx_buffer.clear();
163
164 return;
165 }
166 },
167 }
168 }
169 }
170
171 pub fn process_input<M>(&mut self, mut func: impl FnMut(M))
172 where
173 M: DeserializeOwned,
174 {
175 self.receive_bytes();
176
177 while let Some(message) = self.next_message() {
179 func(message)
180 }
181 }
182
183 pub fn pop_message<M: DeserializeOwned>(&mut self) -> Option<M> {
184 self.receive_bytes();
185 self.next_message()
186 }
187}