bantamweight/
peer.rs

1use std::{collections::HashMap, error::Error, fmt, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
2use plain_binary_stream::{BinaryStream, Serializable};
3use tokio::{io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream, tcp::ReadHalf}, sync::{mpsc, Mutex}, time::timeout};
4use crate::{AM, BantamPacketType, ByePacket, DataPacket, HandshakePacket, HandshakeResponsePacket, PacketHeader, SerializableSocketAddr};
5
6pub const TCP_STREAM_READ_BUFFER_SIZE: usize = 256;
7pub const TCP_STREAM_CONNECTION_TIMEOUT_SECS: u64 = 15;
8pub const TCP_STREAM_READ_TIMEOUT_SECS: u64 = 5;
9type Tx = mpsc::UnboundedSender<Vec<u8>>;
10
/// Errors raised by the bantam networking layer itself, as opposed to I/O errors from the socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BantamError {
    /// An outgoing connection was not established within `TCP_STREAM_CONNECTION_TIMEOUT_SECS`.
    ConnectionTimeout,
    /// The rest of a partially received packet did not arrive within `TCP_STREAM_READ_TIMEOUT_SECS`.
    ReadTimeout,
    /// A received packet could not be decoded.
    /// NOTE(review): not constructed anywhere in this module as far as visible.
    ReceivedCorruptedPacket,
}

impl fmt::Display for BantamError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Never delegate to `self.to_string()` here: `to_string` is implemented on top of `Display`,
        // so doing that recursed without end and overflowed the stack on every formatted error.
        let message = match self {
            BantamError::ConnectionTimeout => "timed out while connecting to peer",
            BantamError::ReadTimeout => "timed out while reading from peer",
            BantamError::ReceivedCorruptedPacket => "received corrupted packet",
        };
        f.write_str(message)
    }
}

/// The defaults of `source` (`None`), `cause` (forwards to `source`) and the deprecated
/// `description` are exactly what this type needs, so nothing is overridden.
impl Error for BantamError {}
37
38pub struct Peer {
39    // Add ingoing/outgoing peer tag?
40    pub listener_addr: SocketAddr,
41    sender: Tx
42}
43
/// Application hooks that the networking layer invokes as peers come and go and as data arrives.
///
/// The implementor is shared with the networking tasks through an `AM<T>` handle, and every
/// callback runs while that lock is held, so callbacks should return quickly.
pub trait ExternalSharedState {
    /// The connection with `addr` is fully established (its handshake response was processed).
    ///
    /// NOTE(review): in this module the call happens only on the connecting (outgoing) side;
    /// accepted (ingoing) connections later report `on_disconnected` without a matching
    /// `on_connected`. Confirm whether that asymmetry is intended.
    fn on_connected(&mut self, addr: SocketAddr);

    /// The connection with `addr` has ended, e.g. so that a miner can drop the blockchain
    /// requests still pending for that peer.
    fn on_disconnected(&mut self, addr: SocketAddr);

    /// `bytes` is the payload of a data packet that `sender_addr` sent to this node.
    fn on_receive_packet(&mut self, bytes: Vec<u8>, sender_addr: SocketAddr);
}
49
50pub struct PeerSharedState {
51    listener_addr: SocketAddr,
52    peers: HashMap<SocketAddr, Peer>
53}
54
55impl PeerSharedState {
56    pub fn get_peers(&self) -> Vec<(&SocketAddr, &Peer)> {
57        let peers = self.peers.iter().map(|pair| {
58            (pair.0, pair.1)
59        }).collect();
60
61        peers
62    }
63
64    pub fn get_peer_count(&self) -> usize {
65        return self.peers.len()
66    }
67
68    fn add_peer(&mut self, addr: SocketAddr, peer: Peer) -> bool {
69        self.peers.insert(addr, peer).is_none()
70    }
71
72    fn remove_peer(&mut self, addr: SocketAddr) -> bool {
73        self.peers.remove(&addr).is_some()
74    }
75
76    async fn broadcast(&self, bytes: Vec<u8>) -> Result<(), Box<dyn Error + Send + Sync>> {
77        Ok(for peer in self.peers.iter() {
78            peer.1.sender.send(bytes.clone())? // Get rid of clone
79        })
80    }
81
82    async fn unicast(&mut self, bytes: Vec<u8>, addr: SocketAddr) -> Result<(), Box<dyn Error + Send + Sync>> {
83        // Improve error handling
84        if let Some(peer) = self.peers.get(&addr) {
85            peer.sender.send(bytes)?;
86        }
87        else {
88            eprintln!("Peer with address {} is not part of this network.", addr);
89        }
90        
91        Ok(())
92    }
93}
94
95// --- Called by a peer that wants to connect ot a P2P network. ---
96pub async fn setup_peer<T: ExternalSharedState + Send + Sync + 'static>(port: u16, addr: SocketAddr,
97    ext_shared_state: AM<T>) -> Result<AM<PeerSharedState>, Box<dyn Error + Send + Sync>> {
98    let shared_state = setup_ingoing_peer(port, ext_shared_state.clone()).await?;
99    let shared_state_ref = shared_state.clone();
100    setup_outgoing_peer(addr, shared_state, true, ext_shared_state.clone()).await?;
101
102    Ok(shared_state_ref)
103}
104
105// --- Called by the first peer in a P2P network. ---
106// First of all, create a listener as a node in the P2P network. Second, connect to one of the nodes
107// already integrated in the network and wait for the list with all peers. Connect to each individually
108// in order to join.
109// Everybody joining after us will be accepted via the listener, while everyone joining before we
110// enter the network will be connected to manually.
111pub async fn setup_ingoing_peer<T: ExternalSharedState + Send + Sync + 'static>(port: u16,
112    ext_shared_state: AM<T>) -> Result<AM<PeerSharedState>, Box<dyn Error + Send + Sync>> {
113    let listener_addr: SocketAddr = format!("0.0.0.0:{}", port).parse().unwrap();
114    let listener = TcpListener::bind(listener_addr).await?;
115    let shared_state = Arc::new(Mutex::new(PeerSharedState {
116        listener_addr: listener_addr,
117        peers: HashMap::new()
118    }));
119    // Separate reference for the afterworld
120    let shared_state_ref = shared_state.clone();
121
122    tokio::spawn(async move {
123        loop {
124            match listener.accept().await {
125                Ok((conn, addr)) => {
126                    let shared_state_conn = shared_state.clone();
127                    println!("Accepted connection with {}...", &addr);
128
129                    handle_peer(conn, addr, shared_state_conn, ext_shared_state.clone())
130                },
131                Err(e) => {
132                    eprintln!("Failed to accept new connection: {}.", e);
133                    break;
134                }
135            }
136        }
137        println!("Terminating listener loop. Ingoing peer shut down.")
138    });
139
140    Ok(shared_state_ref)
141}
142
143// Called after setup_ingoing_peer(), in order to enter a P2P network
144async fn setup_outgoing_peer<T: ExternalSharedState + Send + Sync + 'static>(addr: SocketAddr,
145    shared_state: AM<PeerSharedState>, request_peers: bool, ext_shared_state: AM<T>) -> Result<(), Box<dyn Error + Send + Sync>> {
146    // Connect to TCP stream until timeout interrupts the operation.
147    match timeout(Duration::from_secs(TCP_STREAM_CONNECTION_TIMEOUT_SECS),
148        TcpStream::connect(addr)).await {
149        Ok(connection_result) => {
150            match connection_result {
151                Ok(mut conn) => {
152                    // Send the handshake to pass port for ingoing connections and to fetch list of all peers
153                    conn.write_all(&construct_bantam_packet(HandshakePacket::new(
154                    shared_state.lock().await.listener_addr.port(), request_peers))).await?;
155                    println!("Connecting with {}...", &addr);
156                    handle_peer(conn, addr, shared_state.clone(), ext_shared_state);
157
158                    Ok(())
159                },
160                Err(e) => return Err(Box::new(e))
161            }
162        },
163        Err(_) => {
164            return Err(Box::new(BantamError::ConnectionTimeout));
165        }
166    }
167}
168
169pub async fn shutdown(shared_state: AM<PeerSharedState>) -> Result<(), Box<dyn Error + Send + Sync>> {
170    send_bantam_packet(ByePacket::new(0), shared_state).await
171}
172
173fn handle_peer<T: ExternalSharedState + Send + Sync + 'static>(conn: TcpStream, addr: SocketAddr,
174    shared_state: AM<PeerSharedState>, ext_shared_state: AM<T>) {
175    // Configure the connection
176    if let Err(e) = conn.set_linger(None) {
177        println!("Failed to set linger duration of connection with {}: {}.", addr, e);
178    }
179    if let Err(e) = conn.set_nodelay(true) {
180        println!("Failed to set no delay of connection with {}: {}.", addr, e);
181    }
182
183    tokio::spawn(async move {
184        if let Err(e) = handle_peer_io_loop(conn, addr, shared_state, ext_shared_state).await {
185            eprintln!("Failed running IO loop for {}: {}", addr, e);
186        }
187    });
188}
189
190async fn handle_peer_io_loop<T: ExternalSharedState + Send + Sync + 'static>(mut conn: TcpStream, addr: SocketAddr,
191    shared_state: AM<PeerSharedState>, ext_shared_state: AM<T>) -> Result<(), Box<dyn Error + Send + Sync>> {
192    let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
193    let (mut reader, mut writer) = conn.split();
194
195    loop {
196        let mut buffer = [0u8; TCP_STREAM_READ_BUFFER_SIZE];
197        tokio::select! {
198            Some(msg) = rx.recv() => {
199                writer.write_all(&msg).await?;
200            }
201            read_result = reader.read(&mut buffer) => match read_result {
202                Ok(bytes_received) => {
203                    let mut buffer_vec = buffer[0..bytes_received].to_vec();
204                    // Check the segment read first and potentially read remaining segments of the whole packet.
205                    match read_segments(bytes_received, &mut buffer_vec, &mut reader).await {
206                        Err(e) => {
207                            eprintln!("Error occured while reading packet segments of {}: {}", addr, e);
208                            break;
209                        },
210                        Ok(total_bytes_received) if total_bytes_received == 0 => break,
211                        _ => ()
212                    }
213
214                    // Process packet depending on BantamPacketType
215                    match process_packet(buffer_vec, addr, tx.clone(), shared_state.clone(), ext_shared_state.clone()).await {
216                        Ok(connected) if !connected => break, // If the function returns Ok(false), peer sent a ByePacket
217                        Err(e) => {
218                            eprintln!("Error occured while processing received packet: {}.", e);
219                            break
220                        },
221                        _ => ()
222                    }
223                },
224                Err(e) if e.kind() == ErrorKind::ConnectionReset => break, // Peer disconnected
225                Err(e) => {
226                    eprintln!("Error ({:?}) occured while reading stream of {}: {}", e.kind(), &addr, e);
227                    break
228                }
229            }
230        }
231    }
232
233    println!("Peer {} disconnected.", &addr);
234    ext_shared_state.lock().await.on_disconnected(addr);
235    shared_state.lock().await.remove_peer(addr);
236
237    Ok(())
238}
239
240async fn read_segments<'a>(bytes_received: usize, buffer_vec: &mut Vec<u8>, reader: &mut ReadHalf<'a>)
241    -> Result<usize, Box<dyn Error + Send + Sync>> {
242    if bytes_received <= 4 { // As the base bantam packet header (size of packet) is 4 bytes big, a valid message must have > 4 bytes.
243        return Ok(0);
244    }
245
246    let total_packet_size = check_first_packet_segment(buffer_vec);
247    let mut total_packet_bytes_received = bytes_received - 4;
248    let packet_bytes_received_percent_step = (total_packet_size as f32 * 0.25f32) as usize;
249    let mut packet_bytes_received_step = packet_bytes_received_percent_step;
250
251    // As long as the amount of bytes we read is still less than the size announced by the packet, continue reading...
252    while total_packet_bytes_received < total_packet_size {
253        let mut buffer = [0u8; TCP_STREAM_READ_BUFFER_SIZE];
254        match timeout(Duration::from_secs(TCP_STREAM_READ_TIMEOUT_SECS), reader.read(&mut buffer)).await {
255            Ok(read_result) => {
256                match read_result {
257                    Ok(bytes_received) => {
258                        buffer_vec.extend(buffer[0..bytes_received].to_vec());
259                        total_packet_bytes_received += bytes_received;
260                        
261                        let progress_percentage = f32::round(total_packet_bytes_received as f32 /
262                            total_packet_size as f32 * 100f32);
263                        if total_packet_bytes_received > packet_bytes_received_step {
264                            println!("Downloaded {}% of packet ({}b of {}b).", progress_percentage,
265                                total_packet_bytes_received, total_packet_size);
266                            packet_bytes_received_step += packet_bytes_received_percent_step;
267                        }
268                    },
269                    Err(e) => return Err(Box::new(e))
270                }   
271            },
272            Err(_) => return Err(Box::new(BantamError::ReadTimeout))
273        }
274    }
275
276    Ok(total_packet_size)
277}
278
/// Strips the 4-byte little-endian size prefix from the front of `buffer` and returns the packet
/// size it announces.
///
/// # Panics
/// Panics if `buffer` holds fewer than 4 bytes; callers must ensure the whole prefix has arrived.
fn check_first_packet_segment(buffer: &mut Vec<u8>) -> usize {
    let mut size_bytes = [0u8; 4];
    // Copy the prefix out, then drop it with a single shift instead of four `remove(0)` calls,
    // each of which moved the entire remaining buffer.
    size_bytes.copy_from_slice(&buffer[..4]);
    buffer.drain(..4);

    u32::from_le_bytes(size_bytes) as usize
}
289
290async fn process_packet<T: ExternalSharedState + Send + Sync + 'static>(bytes: Vec<u8>, addr: SocketAddr,
291    tx: Tx, shared_state: AM<PeerSharedState>, ext_shared_state: AM<T>) -> Result<bool, Box<dyn Error + Send + Sync>> {
292    let (packet_type, mut stream) = deconstruct_bantam_packet(bytes);
293    match packet_type {
294        BantamPacketType::Handshake => {
295            process_handshake_packet(stream, addr, tx.clone(), shared_state.clone()).await?;
296        },
297        BantamPacketType::HandshakeResponse => {
298            process_handshake_response_packet(stream, addr, tx.clone(), shared_state.clone(),
299                ext_shared_state.clone()).await?;
300        },
301        BantamPacketType::Data => {
302            let data_packet = DataPacket::from_stream(&mut stream);
303            ext_shared_state.lock().await.on_receive_packet(data_packet.bytes, addr);
304        },
305        BantamPacketType::Bye => {
306            println!("Peer {} manually disconnected.", &addr);
307            return Ok(false);
308        },
309    }
310
311    Ok(true)
312}
313
314async fn process_handshake_packet(mut stream: BinaryStream, addr: SocketAddr, tx: Tx,
315    shared_state: AM<PeerSharedState>) -> Result<(), Box<dyn Error + Send + Sync>> {
316    // This will only occur on the ingoing (=listening) peer side
317    let handshake = HandshakePacket::from_stream(&mut stream);
318
319    // After hands have been shook, add the peer to the network
320    // We construct the listener address of the peer from the port they send us.
321    // This way, we can send new connections this peer's actual ingoing listener address.
322    let mut listener_addr = addr.clone();
323    listener_addr.set_port(handshake.listening_port);
324    shared_state.lock().await.add_peer(addr, Peer {
325        listener_addr, sender: tx.clone()
326    });
327
328    // If the ingoing peer requested a list of all P2P network members, fetch the list.
329    // If the peer didn't request such list, they already have a connection to an entry
330    // peer and used their peer address list to connect to us.
331    let peer_addresses = match handshake.request_peers {
332        true => {
333            let mut peer_addresses = vec![];
334            for (_, peer) in shared_state.lock().await.peers.iter() {
335                // If it's not the peer we're currently sending the peer list to...
336                if peer.listener_addr != listener_addr {
337                    // ... then include the peer in the list.
338                    peer_addresses.push(SerializableSocketAddr::from_sock_addr(peer.listener_addr));
339                }
340            }
341
342            println!("Peer {} connected. Integrating into network...", &addr);
343            peer_addresses
344        },
345        false => {
346            println!("Peer {} connected.", &addr);
347            vec![]
348        }
349    };
350
351    // Ingoing peer is obligated to send handshake response.
352    send_bantam_packet_to(HandshakeResponsePacket::new(peer_addresses),
353        addr, shared_state.clone()).await
354}
355
356async fn process_handshake_response_packet<T: ExternalSharedState + Send + Sync + 'static>(
357    mut stream: BinaryStream, addr: SocketAddr, tx: Tx, shared_state: AM<PeerSharedState>,
358    ext_shared_state: AM<T>) -> Result<(), Box<dyn Error + Send + Sync>> {
359    // This will only occur on the outgoing (=connecting) peer side
360
361    // As we now successfully connected, we can safely add the outgoing peer
362    shared_state.lock().await.add_peer(addr, Peer {
363        listener_addr: addr, sender: tx.clone()
364    });
365
366    // Outgoing peer is obligated to send handshake
367    let handshake_response = HandshakeResponsePacket::from_stream(&mut stream);
368    // Connect to all peers that are in the P2P network, according to first-hand connection
369    if handshake_response.peers.len() > 0 {
370        println!("Connecting with remaining peers in the network...");
371        for addr in handshake_response.peers {
372            let sock_addr = addr.to_sock_addr();
373            if !shared_state.lock().await.peers.contains_key(&sock_addr) {
374                if let Err(e) = setup_outgoing_peer(sock_addr, shared_state.clone(), false, ext_shared_state.clone()).await {
375                    println!("Connection attempt with {} failed: {}.", sock_addr, e);
376                    continue;
377                }
378            }
379        }
380    }
381
382    println!("Established connection with {}.", &addr);
383    ext_shared_state.lock().await.on_connected(addr);
384    Ok(())
385}
386
387async fn send_packet(bytes: Vec<u8>, shared_state: AM<PeerSharedState>)
388    -> Result<(), Box<dyn Error + Send + Sync>> {
389    shared_state.lock().await.broadcast(bytes).await
390}
391
392async fn send_packet_to(bytes: Vec<u8>, addr: SocketAddr, shared_state: AM<PeerSharedState>)
393    -> Result<(), Box<dyn Error + Send + Sync>> {
394    shared_state.lock().await.unicast(bytes, addr).await
395}
396
397fn construct_bantam_packet<T: PacketHeader<BantamPacketType> + Serializable>(
398    packet: T) -> Vec<u8> {
399    let mut stream = BinaryStream::new();
400    stream.write_packet_type(packet.get_type()).unwrap();
401    packet.to_stream(&mut stream);
402    
403    let buffer = stream.get_buffer_vec();
404    // Write the total packet size for the receiver, so they know how much to read.
405    let mut header = u32::to_le_bytes(buffer.len() as u32).to_vec();
406    // We need the packet size indicator to be at the very beginning of the packet, so it can be read first.
407    header.extend(buffer);
408    header
409}
410
411fn deconstruct_bantam_packet(bytes: Vec<u8>) -> (BantamPacketType, BinaryStream) {
412    let mut stream = BinaryStream::from_bytes(&bytes);
413    (stream.read_packet_type::<BantamPacketType>().unwrap(), stream)
414}
415
416async fn send_bantam_packet<T: PacketHeader<BantamPacketType> + Serializable>(
417    packet: T, shared_state: AM<PeerSharedState>) -> Result<(), Box<dyn Error + Send + Sync>> {
418    send_packet(construct_bantam_packet(packet), shared_state).await
419}
420
421async fn send_bantam_packet_to<T: PacketHeader<BantamPacketType> + Serializable>(packet: T,
422    addr: SocketAddr, shared_state: AM<PeerSharedState>) -> Result<(), Box<dyn Error + Send + Sync>> {
423    send_packet_to(construct_bantam_packet(packet), addr, shared_state).await
424}
425
426pub async fn send_data_packet(bytes: Vec<u8>, shared_state: AM<PeerSharedState>)
427    -> Result<(), Box<dyn Error + Send + Sync>> {
428    send_packet(construct_bantam_packet(DataPacket::new(bytes)), shared_state).await
429}
430
431pub async fn send_data_packet_to(bytes: Vec<u8>, addr: SocketAddr, shared_state: AM<PeerSharedState>)
432    -> Result<(), Box<dyn Error + Send + Sync>>  {
433    send_packet_to(construct_bantam_packet(DataPacket::new(bytes)), addr, shared_state).await
434}