Skip to main content

snap_coin/node/
peer.rs

1use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
2
3use log::{error};
4use tokio::{
5    net::{
6        TcpStream,
7        tcp::{OwnedReadHalf, OwnedWriteHalf},
8    },
9    sync::{
10        Mutex,
11        mpsc::{self, Receiver},
12        oneshot,
13    },
14    time::{sleep, timeout},
15};
16
17use thiserror::Error;
18
19use crate::{
20    core::blockchain::BlockchainError, light_node::block_meta_store::BlockMetaStoreError, node::{
21        message::{Command, Message, MessageId},peer_behavior::SharedPeerBehavior
22    }
23};
24
25/// Message expecting a response OR not
26pub enum Outgoing {
27    Request(Message, oneshot::Sender<Message>),
28    OneWay(Message),
29}
30
31type Pending = Arc<Mutex<HashMap<MessageId, oneshot::Sender<Message>>>>;
32type KillSignal = String;
33
/// Deadline (5 s) that [`PeerHandle::request`] passes to `tokio::time::timeout`.
pub const PEER_TIMEOUT: Duration = Duration::from_secs(5);

/// Pause (5 s) between consecutive pings sent by the pinger task.
pub const PEER_PING_INTERVAL: Duration = Duration::from_secs(5);
39
40#[derive(Error, Debug)]
41pub enum PeerError {
42    #[error("IO error: {0}")]
43    Io(String),
44
45    #[error("Timeout waiting for peer response")]
46    Timeout,
47
48    #[error("Failed to send request to peer: {0}")]
49    SendError(String),
50
51    #[error("Failed to receive response from peer: {0}")]
52    ReceiveError(String),
53
54    #[error("Peer killed: {0}")]
55    Killed(String),
56
57    #[error("Message decoding error: {0}")]
58    MessageDecode(String),
59
60    #[error("Message encoding error: {0}")]
61    MessageEncode(String),
62
63    #[error("Peer disconnected unexpectedly")]
64    Disconnected,
65
66    #[error("Unknown error: {0}")]
67    Unknown(String),
68
69    #[error("Blockchain error: {0}")]
70    Blockchain(#[from] BlockchainError),
71
72    #[error("Meta Store error: {0}")]
73    BlockMetaStore(#[from] BlockMetaStoreError),
74
75    #[error("Sync error: {0}")]
76    SyncError(String),
77
78    #[error("Incorrect response received")]
79    IncorrectResponse,
80}
81
82/// Used to reference, request, and kill
83#[derive(Clone, Debug)]
84pub struct PeerHandle {
85    pub address: SocketAddr,
86    pub is_client: bool,
87    send: mpsc::Sender<Outgoing>,
88    kill: Arc<Mutex<Option<oneshot::Sender<KillSignal>>>>,
89}
90
91impl PeerHandle {
92    /// Send a request message, and expect a response message from this peer
93    pub async fn request(&self, request: Message) -> Result<Message, PeerError> {
94        let (callback_tx, callback_rx) = oneshot::channel::<Message>();
95
96        match timeout(
97            PEER_TIMEOUT,
98            self.send.send(Outgoing::Request(request, callback_tx)),
99        )
100        .await
101        {
102            Ok(res) => res.map_err(|e| PeerError::SendError(e.to_string()))?,
103            Err(_) => {
104                self.kill("Peer timed out".to_string()).await?;
105                return Err(PeerError::Timeout);
106            }
107        }
108
109        callback_rx
110            .await
111            .map_err(|e| PeerError::ReceiveError(e.to_string()))
112    }
113
114    /// Send a message without expecting a response
115    pub async fn send(&self, message: Message) -> Result<(), PeerError> {
116        self.send
117            .send(Outgoing::OneWay(message))
118            .await
119            .map_err(|e| PeerError::SendError(e.to_string()))
120    }
121
122    /// Send a kill signal to this peer
123    pub async fn kill(&self, message: String) -> Result<(), PeerError> {
124        if let Some(kill) = self.kill.lock().await.take() {
125            kill.send(message.clone())
126                .map_err(|_| PeerError::Killed(message))?;
127        }
128        Ok(())
129    }
130}
131
132/// Create a new peer, start internal tasks, and return a PeerHandle
133pub fn create_peer(
134    stream: TcpStream,
135    behavior: SharedPeerBehavior,
136    is_client: bool,
137) -> Result<PeerHandle, PeerError> {
138    let address = stream
139        .peer_addr()
140        .map_err(|e| PeerError::Io(format!("IO error: {e}")))?;
141
142    let (outgoing_tx, outgoing_rx) = mpsc::channel::<Outgoing>(64);
143    let (kill, should_kill) = oneshot::channel::<KillSignal>();
144
145    let handle = PeerHandle {
146        send: outgoing_tx,
147        kill: Arc::new(Mutex::new(Some(kill))),
148        is_client,
149        address,
150    };
151    let my_handle = handle.clone();
152
153    tokio::spawn(async move {
154        let behavior_on_kill = behavior.clone();
155        let my_handle_on_kill = my_handle.clone();
156        if let Err(e) = async move {
157            let (reader, writer) = stream.into_split();
158
159            let pending: Pending =
160                Arc::new(Mutex::new(HashMap::<MessageId, oneshot::Sender<Message>>::new()));
161
162            tokio::select! {
163                res = reader_task(reader, pending.clone(), my_handle.clone(), behavior.clone()) => res,
164                res = writer_task(writer, outgoing_rx, pending) => res,
165                res = pinger_task(my_handle, behavior.clone()) => res,
166                res = async move {
167                    let message = should_kill
168                        .await
169                        .map_err(|_| PeerError::Killed("Kill channel closed".to_string()))?;
170                    Err(PeerError::Killed(message))
171                } => res
172            }?;
173
174            Ok::<(), PeerError>(())
175        }
176        .await
177        {
178            tokio::spawn(async move {
179                behavior_on_kill.on_kill(&my_handle_on_kill).await;
180                error!("Peer error (disconnected): {e}");
181            });
182            
183        }
184    });
185
186    Ok(handle)
187}
188
189async fn reader_task(
190    mut stream: OwnedReadHalf,
191    pending: Pending,
192    my_handle: PeerHandle,
193    behavior: SharedPeerBehavior
194) -> Result<(), PeerError> {
195    loop {
196        let message = Message::from_stream(&mut stream)
197            .await
198            .map_err(|e| PeerError::MessageDecode(e.to_string()))?;
199
200        if let Some(requester) = pending.lock().await.remove(&message.id) {
201            let _ = requester.send(message);
202        } else {
203            let response = behavior.on_message(message, &my_handle).await?;
204            my_handle.send(response).await?;
205        }
206    }
207}
208
209async fn writer_task(
210    mut stream: OwnedWriteHalf,
211    mut receiver: Receiver<Outgoing>,
212    pending: Pending,
213) -> Result<(), PeerError> {
214    while let Some(outgoing) = receiver.recv().await {
215        match outgoing {
216            Outgoing::Request(msg, responder) => {
217                pending.lock().await.insert(msg.id, responder);
218                msg.send(&mut stream)
219                    .await
220                    .map_err(|e| PeerError::MessageEncode(e.to_string()))?;
221            }
222            Outgoing::OneWay(msg) => {
223                msg.send(&mut stream)
224                    .await
225                    .map_err(|e| PeerError::MessageEncode(e.to_string()))?;
226            }
227        }
228    }
229    Err(PeerError::Disconnected)
230}
231
232async fn pinger_task(
233    my_handle: PeerHandle,
234    behavior: SharedPeerBehavior
235) -> Result<(), PeerError> {
236    loop {
237        sleep(PEER_PING_INTERVAL).await;
238        my_handle.request(
239            Message::new(Command::Ping {
240                height: behavior.get_height().await,
241            }),
242        )
243        .await?;
244    }
245}