1use std::{collections::HashMap, error::Error, fmt, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
2use plain_binary_stream::{BinaryStream, Serializable};
3use tokio::{io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream, tcp::ReadHalf}, sync::{mpsc, Mutex}, time::timeout};
4use crate::{AM, BantamPacketType, ByePacket, DataPacket, HandshakePacket, HandshakeResponsePacket, PacketHeader, SerializableSocketAddr};
5
/// Size in bytes of the stack buffer used for every individual socket read.
pub const TCP_STREAM_READ_BUFFER_SIZE: usize = 256;
/// Seconds to wait for an outgoing TCP connection to be established before giving up.
pub const TCP_STREAM_CONNECTION_TIMEOUT_SECS: u64 = 15;
/// Seconds to wait for each further read while a partially received packet is being completed.
pub const TCP_STREAM_READ_TIMEOUT_SECS: u64 = 5;
9type Tx = mpsc::UnboundedSender<Vec<u8>>;
10
/// Protocol-level failures raised by the bantam networking layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BantamError {
    /// An outgoing connection was not established within `TCP_STREAM_CONNECTION_TIMEOUT_SECS`.
    ConnectionTimeout,
    /// No further bytes of a partially received packet arrived within
    /// `TCP_STREAM_READ_TIMEOUT_SECS`.
    ReadTimeout,
    /// Intended for packets that cannot be decoded.
    /// NOTE(review): not constructed anywhere in the code shown here.
    ReceivedCorruptedPacket,
}

impl fmt::Display for BantamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Do not delegate to `self.to_string()`: `ToString` is implemented on top of
        // `Display`, so that would call this method again and recurse until the stack overflows.
        let message = match self {
            BantamError::ConnectionTimeout => "connection attempt timed out",
            BantamError::ReadTimeout => "timed out while reading packet data",
            BantamError::ReceivedCorruptedPacket => "received a corrupted packet",
        };
        f.write_str(message)
    }
}

/// No underlying cause is wrapped, so the default `source` (returning `None`) is correct. The
/// deprecated `description`/`cause` methods are left at their defaults.
impl Error for BantamError {}
37
38pub struct Peer {
39 pub listener_addr: SocketAddr,
41 sender: Tx
42}
43
/// Callbacks through which the networking layer reports events to application state.
///
/// The networking code invokes every callback while holding the lock of the shared `AM<T>`
/// it was given, so callbacks should return quickly.
pub trait ExternalSharedState {
    /// Called after a handshake response from `addr` has been processed.
    ///
    /// NOTE(review): in the code shown this only happens on the side that dialed the
    /// connection. Peers that connect *to* this node never trigger it. Confirm that this is
    /// intended.
    fn on_connected(&mut self, addr: SocketAddr);

    /// Called when the IO loop for the connection with `addr` ends, for any reason. This
    /// includes connections that never completed a handshake.
    fn on_disconnected(&mut self, addr: SocketAddr);

    /// Called with the payload of every data packet received from `sender_addr`.
    fn on_receive_packet(&mut self, bytes: Vec<u8>, sender_addr: SocketAddr);
}
49
50pub struct PeerSharedState {
51 listener_addr: SocketAddr,
52 peers: HashMap<SocketAddr, Peer>
53}
54
55impl PeerSharedState {
56 pub fn get_peers(&self) -> Vec<(&SocketAddr, &Peer)> {
57 let peers = self.peers.iter().map(|pair| {
58 (pair.0, pair.1)
59 }).collect();
60
61 peers
62 }
63
64 pub fn get_peer_count(&self) -> usize {
65 return self.peers.len()
66 }
67
68 fn add_peer(&mut self, addr: SocketAddr, peer: Peer) -> bool {
69 self.peers.insert(addr, peer).is_none()
70 }
71
72 fn remove_peer(&mut self, addr: SocketAddr) -> bool {
73 self.peers.remove(&addr).is_some()
74 }
75
76 async fn broadcast(&self, bytes: Vec<u8>) -> Result<(), Box<dyn Error + Send + Sync>> {
77 Ok(for peer in self.peers.iter() {
78 peer.1.sender.send(bytes.clone())? })
80 }
81
82 async fn unicast(&mut self, bytes: Vec<u8>, addr: SocketAddr) -> Result<(), Box<dyn Error + Send + Sync>> {
83 if let Some(peer) = self.peers.get(&addr) {
85 peer.sender.send(bytes)?;
86 }
87 else {
88 eprintln!("Peer with address {} is not part of this network.", addr);
89 }
90
91 Ok(())
92 }
93}
94
95pub async fn setup_peer<T: ExternalSharedState + Send + Sync + 'static>(port: u16, addr: SocketAddr,
97 ext_shared_state: AM<T>) -> Result<AM<PeerSharedState>, Box<dyn Error + Send + Sync>> {
98 let shared_state = setup_ingoing_peer(port, ext_shared_state.clone()).await?;
99 let shared_state_ref = shared_state.clone();
100 setup_outgoing_peer(addr, shared_state, true, ext_shared_state.clone()).await?;
101
102 Ok(shared_state_ref)
103}
104
105pub async fn setup_ingoing_peer<T: ExternalSharedState + Send + Sync + 'static>(port: u16,
112 ext_shared_state: AM<T>) -> Result<AM<PeerSharedState>, Box<dyn Error + Send + Sync>> {
113 let listener_addr: SocketAddr = format!("0.0.0.0:{}", port).parse().unwrap();
114 let listener = TcpListener::bind(listener_addr).await?;
115 let shared_state = Arc::new(Mutex::new(PeerSharedState {
116 listener_addr: listener_addr,
117 peers: HashMap::new()
118 }));
119 let shared_state_ref = shared_state.clone();
121
122 tokio::spawn(async move {
123 loop {
124 match listener.accept().await {
125 Ok((conn, addr)) => {
126 let shared_state_conn = shared_state.clone();
127 println!("Accepted connection with {}...", &addr);
128
129 handle_peer(conn, addr, shared_state_conn, ext_shared_state.clone())
130 },
131 Err(e) => {
132 eprintln!("Failed to accept new connection: {}.", e);
133 break;
134 }
135 }
136 }
137 println!("Terminating listener loop. Ingoing peer shut down.")
138 });
139
140 Ok(shared_state_ref)
141}
142
143async fn setup_outgoing_peer<T: ExternalSharedState + Send + Sync + 'static>(addr: SocketAddr,
145 shared_state: AM<PeerSharedState>, request_peers: bool, ext_shared_state: AM<T>) -> Result<(), Box<dyn Error + Send + Sync>> {
146 match timeout(Duration::from_secs(TCP_STREAM_CONNECTION_TIMEOUT_SECS),
148 TcpStream::connect(addr)).await {
149 Ok(connection_result) => {
150 match connection_result {
151 Ok(mut conn) => {
152 conn.write_all(&construct_bantam_packet(HandshakePacket::new(
154 shared_state.lock().await.listener_addr.port(), request_peers))).await?;
155 println!("Connecting with {}...", &addr);
156 handle_peer(conn, addr, shared_state.clone(), ext_shared_state);
157
158 Ok(())
159 },
160 Err(e) => return Err(Box::new(e))
161 }
162 },
163 Err(_) => {
164 return Err(Box::new(BantamError::ConnectionTimeout));
165 }
166 }
167}
168
169pub async fn shutdown(shared_state: AM<PeerSharedState>) -> Result<(), Box<dyn Error + Send + Sync>> {
170 send_bantam_packet(ByePacket::new(0), shared_state).await
171}
172
173fn handle_peer<T: ExternalSharedState + Send + Sync + 'static>(conn: TcpStream, addr: SocketAddr,
174 shared_state: AM<PeerSharedState>, ext_shared_state: AM<T>) {
175 if let Err(e) = conn.set_linger(None) {
177 println!("Failed to set linger duration of connection with {}: {}.", addr, e);
178 }
179 if let Err(e) = conn.set_nodelay(true) {
180 println!("Failed to set no delay of connection with {}: {}.", addr, e);
181 }
182
183 tokio::spawn(async move {
184 if let Err(e) = handle_peer_io_loop(conn, addr, shared_state, ext_shared_state).await {
185 eprintln!("Failed running IO loop for {}: {}", addr, e);
186 }
187 });
188}
189
190async fn handle_peer_io_loop<T: ExternalSharedState + Send + Sync + 'static>(mut conn: TcpStream, addr: SocketAddr,
191 shared_state: AM<PeerSharedState>, ext_shared_state: AM<T>) -> Result<(), Box<dyn Error + Send + Sync>> {
192 let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
193 let (mut reader, mut writer) = conn.split();
194
195 loop {
196 let mut buffer = [0u8; TCP_STREAM_READ_BUFFER_SIZE];
197 tokio::select! {
198 Some(msg) = rx.recv() => {
199 writer.write_all(&msg).await?;
200 }
201 read_result = reader.read(&mut buffer) => match read_result {
202 Ok(bytes_received) => {
203 let mut buffer_vec = buffer[0..bytes_received].to_vec();
204 match read_segments(bytes_received, &mut buffer_vec, &mut reader).await {
206 Err(e) => {
207 eprintln!("Error occured while reading packet segments of {}: {}", addr, e);
208 break;
209 },
210 Ok(total_bytes_received) if total_bytes_received == 0 => break,
211 _ => ()
212 }
213
214 match process_packet(buffer_vec, addr, tx.clone(), shared_state.clone(), ext_shared_state.clone()).await {
216 Ok(connected) if !connected => break, Err(e) => {
218 eprintln!("Error occured while processing received packet: {}.", e);
219 break
220 },
221 _ => ()
222 }
223 },
224 Err(e) if e.kind() == ErrorKind::ConnectionReset => break, Err(e) => {
226 eprintln!("Error ({:?}) occured while reading stream of {}: {}", e.kind(), &addr, e);
227 break
228 }
229 }
230 }
231 }
232
233 println!("Peer {} disconnected.", &addr);
234 ext_shared_state.lock().await.on_disconnected(addr);
235 shared_state.lock().await.remove_peer(addr);
236
237 Ok(())
238}
239
240async fn read_segments<'a>(bytes_received: usize, buffer_vec: &mut Vec<u8>, reader: &mut ReadHalf<'a>)
241 -> Result<usize, Box<dyn Error + Send + Sync>> {
242 if bytes_received <= 4 { return Ok(0);
244 }
245
246 let total_packet_size = check_first_packet_segment(buffer_vec);
247 let mut total_packet_bytes_received = bytes_received - 4;
248 let packet_bytes_received_percent_step = (total_packet_size as f32 * 0.25f32) as usize;
249 let mut packet_bytes_received_step = packet_bytes_received_percent_step;
250
251 while total_packet_bytes_received < total_packet_size {
253 let mut buffer = [0u8; TCP_STREAM_READ_BUFFER_SIZE];
254 match timeout(Duration::from_secs(TCP_STREAM_READ_TIMEOUT_SECS), reader.read(&mut buffer)).await {
255 Ok(read_result) => {
256 match read_result {
257 Ok(bytes_received) => {
258 buffer_vec.extend(buffer[0..bytes_received].to_vec());
259 total_packet_bytes_received += bytes_received;
260
261 let progress_percentage = f32::round(total_packet_bytes_received as f32 /
262 total_packet_size as f32 * 100f32);
263 if total_packet_bytes_received > packet_bytes_received_step {
264 println!("Downloaded {}% of packet ({}b of {}b).", progress_percentage,
265 total_packet_bytes_received, total_packet_size);
266 packet_bytes_received_step += packet_bytes_received_percent_step;
267 }
268 },
269 Err(e) => return Err(Box::new(e))
270 }
271 },
272 Err(_) => return Err(Box::new(BantamError::ReadTimeout))
273 }
274 }
275
276 Ok(total_packet_size)
277}
278
/// Strips the 4-byte little-endian length prefix from the front of `buffer`. Returns the
/// prefix value, i.e. the body length of the packet that follows.
///
/// # Panics
/// Panics if `buffer` holds fewer than 4 bytes.
fn check_first_packet_segment(buffer: &mut Vec<u8>) -> usize {
    let mut prefix = [0u8; 4];
    prefix.copy_from_slice(&buffer[..4]);
    buffer.drain(..4);
    u32::from_le_bytes(prefix) as usize
}
289
290async fn process_packet<T: ExternalSharedState + Send + Sync + 'static>(bytes: Vec<u8>, addr: SocketAddr,
291 tx: Tx, shared_state: AM<PeerSharedState>, ext_shared_state: AM<T>) -> Result<bool, Box<dyn Error + Send + Sync>> {
292 let (packet_type, mut stream) = deconstruct_bantam_packet(bytes);
293 match packet_type {
294 BantamPacketType::Handshake => {
295 process_handshake_packet(stream, addr, tx.clone(), shared_state.clone()).await?;
296 },
297 BantamPacketType::HandshakeResponse => {
298 process_handshake_response_packet(stream, addr, tx.clone(), shared_state.clone(),
299 ext_shared_state.clone()).await?;
300 },
301 BantamPacketType::Data => {
302 let data_packet = DataPacket::from_stream(&mut stream);
303 ext_shared_state.lock().await.on_receive_packet(data_packet.bytes, addr);
304 },
305 BantamPacketType::Bye => {
306 println!("Peer {} manually disconnected.", &addr);
307 return Ok(false);
308 },
309 }
310
311 Ok(true)
312}
313
314async fn process_handshake_packet(mut stream: BinaryStream, addr: SocketAddr, tx: Tx,
315 shared_state: AM<PeerSharedState>) -> Result<(), Box<dyn Error + Send + Sync>> {
316 let handshake = HandshakePacket::from_stream(&mut stream);
318
319 let mut listener_addr = addr.clone();
323 listener_addr.set_port(handshake.listening_port);
324 shared_state.lock().await.add_peer(addr, Peer {
325 listener_addr, sender: tx.clone()
326 });
327
328 let peer_addresses = match handshake.request_peers {
332 true => {
333 let mut peer_addresses = vec![];
334 for (_, peer) in shared_state.lock().await.peers.iter() {
335 if peer.listener_addr != listener_addr {
337 peer_addresses.push(SerializableSocketAddr::from_sock_addr(peer.listener_addr));
339 }
340 }
341
342 println!("Peer {} connected. Integrating into network...", &addr);
343 peer_addresses
344 },
345 false => {
346 println!("Peer {} connected.", &addr);
347 vec![]
348 }
349 };
350
351 send_bantam_packet_to(HandshakeResponsePacket::new(peer_addresses),
353 addr, shared_state.clone()).await
354}
355
356async fn process_handshake_response_packet<T: ExternalSharedState + Send + Sync + 'static>(
357 mut stream: BinaryStream, addr: SocketAddr, tx: Tx, shared_state: AM<PeerSharedState>,
358 ext_shared_state: AM<T>) -> Result<(), Box<dyn Error + Send + Sync>> {
359 shared_state.lock().await.add_peer(addr, Peer {
363 listener_addr: addr, sender: tx.clone()
364 });
365
366 let handshake_response = HandshakeResponsePacket::from_stream(&mut stream);
368 if handshake_response.peers.len() > 0 {
370 println!("Connecting with remaining peers in the network...");
371 for addr in handshake_response.peers {
372 let sock_addr = addr.to_sock_addr();
373 if !shared_state.lock().await.peers.contains_key(&sock_addr) {
374 if let Err(e) = setup_outgoing_peer(sock_addr, shared_state.clone(), false, ext_shared_state.clone()).await {
375 println!("Connection attempt with {} failed: {}.", sock_addr, e);
376 continue;
377 }
378 }
379 }
380 }
381
382 println!("Established connection with {}.", &addr);
383 ext_shared_state.lock().await.on_connected(addr);
384 Ok(())
385}
386
387async fn send_packet(bytes: Vec<u8>, shared_state: AM<PeerSharedState>)
388 -> Result<(), Box<dyn Error + Send + Sync>> {
389 shared_state.lock().await.broadcast(bytes).await
390}
391
392async fn send_packet_to(bytes: Vec<u8>, addr: SocketAddr, shared_state: AM<PeerSharedState>)
393 -> Result<(), Box<dyn Error + Send + Sync>> {
394 shared_state.lock().await.unicast(bytes, addr).await
395}
396
397fn construct_bantam_packet<T: PacketHeader<BantamPacketType> + Serializable>(
398 packet: T) -> Vec<u8> {
399 let mut stream = BinaryStream::new();
400 stream.write_packet_type(packet.get_type()).unwrap();
401 packet.to_stream(&mut stream);
402
403 let buffer = stream.get_buffer_vec();
404 let mut header = u32::to_le_bytes(buffer.len() as u32).to_vec();
406 header.extend(buffer);
408 header
409}
410
411fn deconstruct_bantam_packet(bytes: Vec<u8>) -> (BantamPacketType, BinaryStream) {
412 let mut stream = BinaryStream::from_bytes(&bytes);
413 (stream.read_packet_type::<BantamPacketType>().unwrap(), stream)
414}
415
416async fn send_bantam_packet<T: PacketHeader<BantamPacketType> + Serializable>(
417 packet: T, shared_state: AM<PeerSharedState>) -> Result<(), Box<dyn Error + Send + Sync>> {
418 send_packet(construct_bantam_packet(packet), shared_state).await
419}
420
421async fn send_bantam_packet_to<T: PacketHeader<BantamPacketType> + Serializable>(packet: T,
422 addr: SocketAddr, shared_state: AM<PeerSharedState>) -> Result<(), Box<dyn Error + Send + Sync>> {
423 send_packet_to(construct_bantam_packet(packet), addr, shared_state).await
424}
425
426pub async fn send_data_packet(bytes: Vec<u8>, shared_state: AM<PeerSharedState>)
427 -> Result<(), Box<dyn Error + Send + Sync>> {
428 send_packet(construct_bantam_packet(DataPacket::new(bytes)), shared_state).await
429}
430
431pub async fn send_data_packet_to(bytes: Vec<u8>, addr: SocketAddr, shared_state: AM<PeerSharedState>)
432 -> Result<(), Box<dyn Error + Send + Sync>> {
433 send_packet_to(construct_bantam_packet(DataPacket::new(bytes)), addr, shared_state).await
434}