1use crate::log::Log;
22use byteorder::{LittleEndian, WriteBytesExt};
23use serde::{de::DeserializeOwned, Serialize};
24use std::{
25 io::{self, ErrorKind, Read, Write},
26 net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
27};
28
/// Server-side TCP listener that hands out accepted connections as [`NetStream`]s.
///
/// The socket is put into non-blocking mode by [`NetListener::bind`], so polling it for new
/// connections does not stall the caller.
pub struct NetListener {
    /// Underlying OS listener socket.
    listener: TcpListener,
}
32
33impl NetListener {
34 pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
35 let listener = TcpListener::bind(addr)?;
36 listener.set_nonblocking(true)?;
37 Ok(Self { listener })
38 }
39
40 pub fn local_address(&self) -> io::Result<SocketAddr> {
41 self.listener.local_addr()
42 }
43
44 pub fn accept_connections(&self) -> Vec<NetStream> {
45 let mut streams = Vec::new();
46 while let Ok(result) = self.listener.accept() {
47 streams.push(NetStream::from_inner(result.0).unwrap())
48 }
49 streams
50 }
51}
52
/// A framed TCP connection: each message is a little-endian `u32` length prefix followed by a
/// bincode-encoded payload.
pub struct NetStream {
    /// Underlying OS socket.
    stream: TcpStream,
    /// Received bytes that have not yet been decoded into complete messages.
    rx_buffer: Vec<u8>,
    /// Buffer for outgoing message bytes, used by the send path.
    tx_buffer: Vec<u8>,
}
58
59impl NetStream {
60 pub fn from_inner(stream: TcpStream) -> io::Result<Self> {
61 stream.set_nonblocking(true)?;
62 stream.set_nodelay(true)?;
63
64 Ok(Self {
65 stream,
66 rx_buffer: Default::default(),
67 tx_buffer: Default::default(),
68 })
69 }
70
71 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
72 Self::from_inner(TcpStream::connect(addr)?)
73 }
74
75 pub fn send_message<T>(&mut self, data: &T) -> io::Result<()>
76 where
77 T: Serialize,
78 {
79 self.tx_buffer.clear();
80 if self.tx_buffer.capacity() < std::mem::size_of::<T>() {
81 self.tx_buffer.reserve(std::mem::size_of::<T>());
82 }
83 bincode::serialize_into(&mut self.tx_buffer, data)
84 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
85 self.stream
86 .write_u32::<LittleEndian>(self.tx_buffer.len() as u32)?;
87 self.stream.write_all(&self.tx_buffer)?;
88 Ok(())
89 }
90
91 pub fn peer_address(&self) -> io::Result<SocketAddr> {
92 self.stream.peer_addr()
93 }
94
95 pub fn string_peer_address(&self) -> String {
96 self.peer_address()
97 .map(|addr| addr.to_string())
98 .unwrap_or_else(|_| "Unknown".into())
99 }
100
101 fn next_message<M>(&mut self) -> Option<M>
102 where
103 M: DeserializeOwned,
104 {
105 if self.rx_buffer.len() < 4 {
106 return None;
107 }
108
109 let length = u32::from_le_bytes([
110 self.rx_buffer[0],
111 self.rx_buffer[1],
112 self.rx_buffer[2],
113 self.rx_buffer[3],
114 ]) as usize;
115
116 let end = 4 + length;
117
118 if let Some(data) = self.rx_buffer.as_slice().get(4..end) {
120 let message = match bincode::deserialize::<M>(data) {
121 Ok(message) => Some(message),
122 Err(err) => {
123 Log::err(format!(
124 "Failed to parse a network message of {length} bytes long. Reason: {err:?}"
125 ));
126
127 None
128 }
129 };
130
131 self.rx_buffer.drain(..end);
132
133 message
134 } else {
135 None
136 }
137 }
138
139 pub fn process_input<M>(&mut self, mut func: impl FnMut(M))
140 where
141 M: DeserializeOwned,
142 {
143 loop {
145 let mut bytes = [0; 8192];
146 match self.stream.read(&mut bytes) {
147 Ok(bytes_count) => {
148 if bytes_count == 0 {
149 break;
150 } else {
151 self.rx_buffer.extend(&bytes[..bytes_count])
152 }
153 }
154 Err(err) => match err.kind() {
155 ErrorKind::WouldBlock => {
156 break;
157 }
158 ErrorKind::Interrupted => {
159 }
161 _ => {
162 Log::err(format!(
163 "An error occurred when reading data from socket: {err}"
164 ));
165
166 self.rx_buffer.clear();
167
168 return;
169 }
170 },
171 }
172 }
173
174 while let Some(message) = self.next_message() {
176 func(message)
177 }
178 }
179}