snap_coin/node/
peer.rs

1use bincode::error::EncodeError;
2use std::{
3    collections::{HashMap, VecDeque},
4    net::SocketAddr,
5    sync::Arc,
6    time::Duration,
7};
8use thiserror::Error;
9use tokio::{
10    net::TcpStream,
11    sync::{RwLock, oneshot},
12    task::JoinHandle,
13    time::{sleep, timeout},
14};
15
16use crate::{
17    core::{blockchain::BlockchainError, utxo::TransactionError},
18    node::{
19        message::{Command, Message, MessageError},
20        node::Node,
21        sync::sync_to_peer,
22    },
23};
24
/// Errors produced while communicating or synchronising with a peer.
#[derive(Error, Debug)]
pub enum PeerError {
    /// A message-layer error; displayed using the wrapped error's own text.
    #[error("{0}")]
    MessageError(#[from] MessageError),

    /// The connection is considered lost. Also returned when a request,
    /// an incoming-message handler or a stream write times out.
    #[error("Disconnected")]
    Disconnected,

    /// An error reported by the blockchain layer.
    #[error("Blockchain error: {0}")]
    BlockchainError(#[from] BlockchainError),

    /// An error reported while handling a transaction.
    #[error("Transaction error: {0}")]
    TransactionError(#[from] TransactionError),

    /// The peer answered a sync request with an unexpected response.
    #[error("Sync peer returned an invalid response")]
    SyncResponseInvalid,

    /// No fork point could be found with the peer.
    #[error("Could not find fork point with peer")]
    NoForkPoint,

    /// A received block has an invalid difficulty.
    #[error("Block has invalid difficulty")]
    BadBlockDifficulty,

    /// A received block has an invalid block hash.
    #[error("Block has invalid block hash")]
    BadBlockHash,

    /// A received block has no block hash attached.
    #[error("Block has no block hash attached")]
    NoBlockHash,

    /// A `bincode` encoding error.
    #[error("Encode error: {0}")]
    EncodeError(#[from] EncodeError),
}
57
/// Upper bound on handling a single incoming message and on writing a single
/// outgoing message; exceeding it makes the connection fail with
/// `PeerError::Disconnected`.
pub const TIMEOUT: Duration = Duration::from_secs(15);
59
/// One peer connection together with its outgoing queue and in-flight requests.
///
/// A peer can be either a client peer or a connected peer (see `is_client`).
pub struct Peer {
    /// Socket address of the peer; also used to identify it when it is removed
    /// from the node's peer list.
    pub address: SocketAddr,

    /// Whether this is a client peer. Peers with this flag set are left out of
    /// `SendPeers` responses.
    // NOTE(review): presumably "client" means the remote side dialled us — confirm.
    pub is_client: bool,

    /// Outgoing messages waiting to be written to the stream.
    send_queue: VecDeque<Message>,

    /// Pending requests waiting for a response (message id -> oneshot sender).
    pending: HashMap<u16, oneshot::Sender<Message>>,
}
72
73impl Peer {
74    /// Create a new peer
75    pub fn new(address: SocketAddr, is_client: bool) -> Self {
76        Self {
77            address,
78            is_client,
79            send_queue: VecDeque::new(),
80            pending: HashMap::new(),
81        }
82    }
83
84    async fn on_fail(peer: Arc<RwLock<Peer>>, node: Arc<RwLock<Node>>) {
85        let peer_address = peer.read().await.address;
86
87        let mut node_peers = node.write().await;
88
89        let mut new_peers = Vec::new();
90        for p in node_peers.peers.drain(..) {
91            let p_address = p.read().await.address;
92            if p_address != peer_address {
93                new_peers.push(p);
94            }
95        }
96
97        node_peers.peers = new_peers;
98    }
99
100    /// Main connection handler
101    pub async fn connect(
102        peer: Arc<RwLock<Peer>>,
103        node: Arc<RwLock<Node>>,
104        stream: TcpStream,
105    ) -> JoinHandle<Result<(), PeerError>> {
106        let (mut read_stream, mut write_stream) = stream.into_split();
107
108        // Spawn peer handler task
109        tokio::spawn(async move {
110            let peer_cloned = peer.clone();
111            let node_cloned = node.clone();
112            // Spawn ping / pong task
113            let pinger = {
114                let peer_outer = peer.clone();
115                let node_outer = node.clone();
116
117                Box::pin(async move {
118                    loop {
119                        sleep(Duration::from_secs(5)).await;
120
121                        let height = node_outer.read().await.blockchain.get_height();
122
123                        let response = Peer::request(
124                            peer_outer.clone(),
125                            Message::new(Command::Ping { height }),
126                        )
127                        .await?;
128
129                        if let Command::Pong { height } = response.command {
130                            let local_height = node_outer.read().await.blockchain.get_height();
131
132                            if local_height < height {
133                                let node_for_task = node_outer.clone();
134                                let peer_for_task = peer_outer.clone();
135
136                                tokio::spawn(async move {
137                                    if node_for_task.read().await.is_syncing {
138                                        return;
139                                    }
140
141                                    node_for_task.write().await.is_syncing = true;
142
143                                    let result = sync_to_peer(
144                                        node_for_task.clone(),
145                                        peer_for_task.clone(),
146                                        height,
147                                    )
148                                    .await;
149
150                                    if let Err(e) = result {
151                                        Node::log(format!(
152                                            "[SYNC] Failed: {}, disconnecting from {}",
153                                            e,
154                                            peer_for_task.read().await.address
155                                        ));
156
157                                        let node_for_task = node_for_task.clone();
158                                        Peer::on_fail(peer_for_task, node_for_task).await;
159                                    } else {
160                                        Node::log("[SYNC] Completed".to_string());
161                                    }
162
163                                    node_for_task.write().await.is_syncing = false;
164                                });
165                            }
166                        }
167                    }
168                    #[allow(unreachable_code)]
169                    Ok::<(), PeerError>(())
170                })
171            };
172
173            // Spawn reader task
174            let reader = {
175                let peer = peer.clone();
176                let node = node.clone();
177                Box::pin(async move {
178                    loop {
179                        let msg = Message::from_stream(&mut read_stream).await?;
180                        match timeout(
181                            TIMEOUT,
182                            Peer::handle_incoming(peer.clone(), node.clone(), msg),
183                        )
184                        .await
185                        {
186                            Ok(()) => {}
187                            Err(..) => return Err(PeerError::Disconnected),
188                        }
189                    }
190                    #[allow(unreachable_code)]
191                    Ok::<(), PeerError>(())
192                })
193            };
194
195            // Spawn writer task
196            let writer = {
197                let peer = peer.clone();
198                Box::pin(async move {
199                    loop {
200                        let maybe_msg = {
201                            let mut p = peer.write().await;
202                            p.send_queue.pop_front()
203                        };
204
205                        if let Some(msg) = maybe_msg {
206                            match timeout(TIMEOUT, msg.send(&mut write_stream)).await {
207                                Ok(e) => e?,
208                                Err(..) => return Err(PeerError::Disconnected),
209                            }
210                        } else {
211                            sleep(Duration::from_millis(10)).await;
212                        }
213                    }
214                    #[allow(unreachable_code)]
215                    Ok::<(), PeerError>(())
216                })
217            };
218
219            // Join all tasks
220            let result = tokio::select! {
221              r = reader => r,
222              r = writer => r,
223              r = pinger => r,
224            };
225
226            if let Err(e) = result {
227                Node::log(format!(
228                    "Disconnected from peer: {}:{}. Error: {:?}",
229                    peer.read().await.address.ip(),
230                    peer.read().await.address.port(),
231                    e
232                ));
233
234                tokio::spawn(async move {
235                    Self::on_fail(peer_cloned, node_cloned).await;
236                });
237            }
238            Ok(())
239        })
240    }
241
242    /// Handle incoming message
243    async fn handle_incoming(peer: Arc<RwLock<Peer>>, node: Arc<RwLock<Node>>, message: Message) {
244        {
245            let mut p = peer.write().await;
246            if let Some(tx) = p.pending.remove(&message.id) {
247                let _ = tx.send(message);
248                return;
249            }
250        }
251
252        Peer::on_message(peer.clone(), node.clone(), message).await;
253    }
254
255    /// Handle incoming message
256    async fn on_message(peer: Arc<RwLock<Peer>>, node: Arc<RwLock<Node>>, message: Message) {
257        if let Err(err) = async {
258            match message.command {
259                Command::Connect => {
260                    Peer::send(peer, message.make_response(Command::AcknowledgeConnection)).await;
261                }
262                Command::AcknowledgeConnection => {
263                    Node::log(format!("Got unhandled AcknowledgeConnection"));
264                }
265                Command::Ping { height: _ } => {
266                    Peer::send(
267                        peer.clone(),
268                        message.make_response(Command::Pong {
269                            height: node.read().await.blockchain.get_height(),
270                        }),
271                    )
272                    .await;
273                }
274                Command::Pong { .. } => {
275                    Node::log(format!("Got unhandled Pong"));
276                }
277                Command::GetPeers => {
278                    let peers: Vec<String> = {
279                        let node_read = node.read().await;
280                        let mut peer_addrs = Vec::new();
281                        for p in &node_read.peers {
282                            if p.read().await.is_client {
283                                continue;
284                            }
285                            let p_addr = p.read().await.address.to_string();
286                            peer_addrs.push(p_addr);
287                        }
288                        peer_addrs
289                    };
290                    let response = message.make_response(Command::SendPeers { peers });
291                    Peer::send(peer, response).await;
292                }
293                Command::SendPeers { .. } => {
294                    Node::log(format!("Got unhandled SendPeers"));
295                }
296                Command::NewBlock { ref block } => {
297                    // Make sure block is not in the blockchain
298                    if Some(node.read().await.last_seen_block) != block.hash {
299                        Node::submit_block(node.clone(), block.clone()).await?;
300                    }
301                }
302                Command::NewTransaction { ref transaction } => {
303                    // Check if transaction was already seen
304                    if !node
305                        .read()
306                        .await
307                        .mempool
308                        .validate_transaction(transaction)
309                        .await
310                    {
311                        return Ok(());
312                    }
313
314                    Node::submit_transaction(node, transaction.clone()).await?;
315                }
316                Command::GetBlock { block_hash } => {
317                    Peer::send(
318                        peer,
319                        message.make_response(Command::GetBlockResponse {
320                            block: node.read().await.blockchain.get_block_by_hash(&block_hash),
321                        }),
322                    )
323                    .await;
324                }
325                Command::GetBlockResponse { .. } => {
326                    Node::log(format!("Got unhandled SendBlock"));
327                }
328                Command::GetBlockHashes { start, end } => {
329                    let mut block_hashes = Vec::new();
330                    for i in start..end {
331                        if let Some(block_hash) =
332                            node.read().await.blockchain.get_block_hash_by_height(i)
333                        {
334                            block_hashes.push(*block_hash);
335                        }
336                    }
337                    Peer::send(
338                        peer,
339                        message.make_response(Command::GetBlockHashesResponse { block_hashes }),
340                    )
341                    .await;
342                }
343                Command::GetBlockHashesResponse { .. } => {
344                    Node::log(format!("Got unhandled SendBlockHashes"));
345                }
346            };
347            Ok::<(), PeerError>(())
348        }
349        .await
350        {
351            Node::log(format!("Error processing incoming message: {err}"));
352        }
353    }
354
355    /// Send a request and wait for the response
356    pub async fn request(peer: Arc<RwLock<Peer>>, message: Message) -> Result<Message, PeerError> {
357        let id = message.id;
358
359        let (tx, rx) = oneshot::channel();
360
361        {
362            let mut p = peer.write().await;
363            p.pending.insert(id, tx);
364            p.send_queue.push_back(message);
365        }
366
367        match timeout(Duration::from_secs(10), rx).await {
368            Ok(Ok(msg)) => Ok(msg),
369            Ok(Err(_)) => Err(PeerError::Disconnected),
370            Err(_) => Err(PeerError::Disconnected),
371        }
372    }
373
374    /// Send a message to this peer, without expecting a response
375    pub async fn send(peer: Arc<RwLock<Peer>>, message: Message) {
376        let mut p = peer.write().await;
377        p.send_queue.push_back(message);
378    }
379
380    /// Send this message to all peers but this one
381    pub async fn send_to_peers(node: Arc<RwLock<Node>>, message: Message) {
382        // clone the peer list while holding the lock, then drop the lock
383        let peers = {
384            let guard = node.read().await;
385            guard.peers.clone()
386        };
387
388        for peer in peers {
389            // now safe to await
390            Peer::send(peer, message.clone()).await;
391        }
392    }
393}