1use crate::log::Log;
22use byteorder::{LittleEndian, WriteBytesExt};
23use serde::{de::DeserializeOwned, Serialize};
24use std::{
25 io::{self, ErrorKind, Read, Write},
26 net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
27};
28
/// TCP listener that accepts incoming connections without blocking the caller.
///
/// The wrapped socket is switched to non-blocking mode by [`NetListener::bind`].
#[derive(Debug)]
pub struct NetListener {
    /// Underlying listening socket.
    listener: TcpListener,
}
32
33impl NetListener {
34 pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
35 let listener = TcpListener::bind(addr)?;
36 listener.set_nonblocking(true)?;
37 Ok(Self { listener })
38 }
39
40 pub fn local_address(&self) -> io::Result<SocketAddr> {
41 self.listener.local_addr()
42 }
43
44 pub fn accept_connections(&self) -> Vec<NetStream> {
45 let mut streams = Vec::new();
46 while let Ok(result) = self.listener.accept() {
47 streams.push(NetStream::from_inner(result.0).unwrap())
48 }
49 streams
50 }
51}
52
/// Non-blocking TCP stream that exchanges length-prefixed, `bincode`-encoded messages.
#[derive(Debug)]
pub struct NetStream {
    /// Underlying connected socket.
    stream: TcpStream,
    /// Bytes received from the socket that have not been turned into messages yet.
    rx_buffer: Vec<u8>,
    /// Scratch buffer reused between calls when serializing outgoing messages.
    tx_buffer: Vec<u8>,
}
58
59impl NetStream {
60 pub fn from_inner(stream: TcpStream) -> io::Result<Self> {
61 stream.set_nonblocking(true)?;
62 stream.set_nodelay(true)?;
63
64 Ok(Self {
65 stream,
66 rx_buffer: Default::default(),
67 tx_buffer: Default::default(),
68 })
69 }
70
71 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
72 Self::from_inner(TcpStream::connect(addr)?)
73 }
74
75 pub fn send_message<T>(&mut self, data: &T) -> io::Result<()>
76 where
77 T: Serialize,
78 {
79 self.tx_buffer.clear();
80 if self.tx_buffer.capacity() < std::mem::size_of::<T>() {
81 self.tx_buffer.reserve(std::mem::size_of::<T>());
82 }
83 bincode::serialize_into(&mut self.tx_buffer, data).map_err(io::Error::other)?;
84 self.stream
85 .write_u32::<LittleEndian>(self.tx_buffer.len() as u32)?;
86 self.stream.write_all(&self.tx_buffer)?;
87 Ok(())
88 }
89
90 pub fn peer_address(&self) -> io::Result<SocketAddr> {
91 self.stream.peer_addr()
92 }
93
94 pub fn string_peer_address(&self) -> String {
95 self.peer_address()
96 .map(|addr| addr.to_string())
97 .unwrap_or_else(|_| "Unknown".into())
98 }
99
100 fn next_message<M>(&mut self) -> Option<M>
101 where
102 M: DeserializeOwned,
103 {
104 if self.rx_buffer.len() < 4 {
105 return None;
106 }
107
108 let length = u32::from_le_bytes([
109 self.rx_buffer[0],
110 self.rx_buffer[1],
111 self.rx_buffer[2],
112 self.rx_buffer[3],
113 ]) as usize;
114
115 let end = 4 + length;
116
117 if let Some(data) = self.rx_buffer.as_slice().get(4..end) {
119 let message = match bincode::deserialize::<M>(data) {
120 Ok(message) => Some(message),
121 Err(err) => {
122 Log::err(format!(
123 "Failed to parse a network message of {length} bytes long. Reason: {err:?}"
124 ));
125
126 None
127 }
128 };
129
130 self.rx_buffer.drain(..end);
131
132 message
133 } else {
134 None
135 }
136 }
137
138 pub fn process_input<M>(&mut self, mut func: impl FnMut(M))
139 where
140 M: DeserializeOwned,
141 {
142 loop {
144 let mut bytes = [0; 8192];
145 match self.stream.read(&mut bytes) {
146 Ok(bytes_count) => {
147 if bytes_count == 0 {
148 break;
149 } else {
150 self.rx_buffer.extend(&bytes[..bytes_count])
151 }
152 }
153 Err(err) => match err.kind() {
154 ErrorKind::WouldBlock => {
155 break;
156 }
157 ErrorKind::Interrupted => {
158 }
160 _ => {
161 Log::err(format!(
162 "An error occurred when reading data from socket: {err}"
163 ));
164
165 self.rx_buffer.clear();
166
167 return;
168 }
169 },
170 }
171 }
172
173 while let Some(message) = self.next_message() {
175 func(message)
176 }
177 }
178}