snap_coin/node/
peer.rs

1use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
2
3use log::{error};
4use tokio::{
5    net::{
6        TcpStream,
7        tcp::{OwnedReadHalf, OwnedWriteHalf},
8    },
9    sync::{
10        Mutex,
11        mpsc::{self, Receiver},
12        oneshot,
13    },
14    time::{sleep, timeout},
15};
16
17use thiserror::Error;
18
19use crate::{
20    core::blockchain::BlockchainError,
21    node::{
22        message::{Command, Message, MessageId},peer_behavior::SharedPeerBehavior
23    },
24};
25
26/// Message expecting a response OR not
27pub enum Outgoing {
28    Request(Message, oneshot::Sender<Message>),
29    OneWay(Message),
30}
31
32type Pending = Arc<Mutex<HashMap<MessageId, oneshot::Sender<Message>>>>;
33type KillSignal = String;
34
/// Upper bound on how long [`PeerHandle::request`] waits on a peer before
/// giving up and killing it (5 seconds).
pub const PEER_TIMEOUT: Duration = Duration::from_secs(5);

/// Pause between two consecutive pings sent by the pinger task (5 seconds).
pub const PEER_PING_INTERVAL: Duration = Duration::from_secs(5);
40
41#[derive(Error, Debug)]
42pub enum PeerError {
43    #[error("IO error: {0}")]
44    Io(String),
45
46    #[error("Timeout waiting for peer response")]
47    Timeout,
48
49    #[error("Failed to send request to peer: {0}")]
50    SendError(String),
51
52    #[error("Failed to receive response from peer: {0}")]
53    ReceiveError(String),
54
55    #[error("Peer killed: {0}")]
56    Killed(String),
57
58    #[error("Message decoding error: {0}")]
59    MessageDecode(String),
60
61    #[error("Message encoding error: {0}")]
62    MessageEncode(String),
63
64    #[error("Peer disconnected unexpectedly")]
65    Disconnected,
66
67    #[error("Unknown error: {0}")]
68    Unknown(String),
69
70    #[error("Blockchain error: {0}")]
71    Blockchain(#[from] BlockchainError),
72
73    #[error("Sync error: {0}")]
74    SyncError(String),
75}
76
77/// Used to reference, request, and kill
78#[derive(Clone, Debug)]
79pub struct PeerHandle {
80    pub address: SocketAddr,
81    pub is_client: bool,
82    send: mpsc::Sender<Outgoing>,
83    kill: Arc<Mutex<Option<oneshot::Sender<KillSignal>>>>,
84}
85
86impl PeerHandle {
87    /// Send a request message, and expect a response message from this peer
88    pub async fn request(&self, request: Message) -> Result<Message, PeerError> {
89        let (callback_tx, callback_rx) = oneshot::channel::<Message>();
90
91        match timeout(
92            PEER_TIMEOUT,
93            self.send.send(Outgoing::Request(request, callback_tx)),
94        )
95        .await
96        {
97            Ok(res) => res.map_err(|e| PeerError::SendError(e.to_string()))?,
98            Err(_) => {
99                self.kill("Peer timed out".to_string()).await?;
100                return Err(PeerError::Timeout);
101            }
102        }
103
104        callback_rx
105            .await
106            .map_err(|e| PeerError::ReceiveError(e.to_string()))
107    }
108
109    /// Send a message without expecting a response
110    pub async fn send(&self, message: Message) -> Result<(), PeerError> {
111        self.send
112            .send(Outgoing::OneWay(message))
113            .await
114            .map_err(|e| PeerError::SendError(e.to_string()))
115    }
116
117    /// Send a kill signal to this peer
118    pub async fn kill(&self, message: String) -> Result<(), PeerError> {
119        if let Some(kill) = self.kill.lock().await.take() {
120            kill.send(message.clone())
121                .map_err(|_| PeerError::Killed(message))?;
122        }
123        Ok(())
124    }
125}
126
127/// Create a new peer, start internal tasks, and return a PeerHandle
128pub fn create_peer(
129    stream: TcpStream,
130    behavior: SharedPeerBehavior,
131    is_client: bool,
132) -> Result<PeerHandle, PeerError> {
133    let address = stream
134        .peer_addr()
135        .map_err(|e| PeerError::Io(format!("IO error: {e}")))?;
136
137    let (outgoing_tx, outgoing_rx) = mpsc::channel::<Outgoing>(64);
138    let (kill, should_kill) = oneshot::channel::<KillSignal>();
139
140    let handle = PeerHandle {
141        send: outgoing_tx,
142        kill: Arc::new(Mutex::new(Some(kill))),
143        is_client,
144        address,
145    };
146    let my_handle = handle.clone();
147
148    tokio::spawn(async move {
149        let behavior_on_kill = behavior.clone();
150        let my_handle_on_kill = my_handle.clone();
151        if let Err(e) = async move {
152            let (reader, writer) = stream.into_split();
153
154            let pending: Pending =
155                Arc::new(Mutex::new(HashMap::<MessageId, oneshot::Sender<Message>>::new()));
156
157            tokio::select! {
158                res = reader_task(reader, pending.clone(), my_handle.clone(), behavior.clone()) => res,
159                res = writer_task(writer, outgoing_rx, pending) => res,
160                res = pinger_task(my_handle, behavior.clone()) => res,
161                res = async move {
162                    let message = should_kill
163                        .await
164                        .map_err(|_| PeerError::Killed("Kill channel closed".to_string()))?;
165                    Err(PeerError::Killed(message))
166                } => res
167            }?;
168
169            Ok::<(), PeerError>(())
170        }
171        .await
172        {
173            tokio::spawn(async move {
174                behavior_on_kill.on_kill(&my_handle_on_kill).await;
175                error!("Peer error (disconnected): {e}");
176            });
177            
178        }
179    });
180
181    Ok(handle)
182}
183
184async fn reader_task(
185    mut stream: OwnedReadHalf,
186    pending: Pending,
187    my_handle: PeerHandle,
188    behavior: SharedPeerBehavior
189) -> Result<(), PeerError> {
190    loop {
191        let message = Message::from_stream(&mut stream)
192            .await
193            .map_err(|e| PeerError::MessageDecode(e.to_string()))?;
194
195        if let Some(requester) = pending.lock().await.remove(&message.id) {
196            let _ = requester.send(message);
197        } else {
198            let response = behavior.on_message(message, &my_handle).await?;
199            my_handle.send(response).await?;
200        }
201    }
202}
203
204async fn writer_task(
205    mut stream: OwnedWriteHalf,
206    mut receiver: Receiver<Outgoing>,
207    pending: Pending,
208) -> Result<(), PeerError> {
209    while let Some(outgoing) = receiver.recv().await {
210        match outgoing {
211            Outgoing::Request(msg, responder) => {
212                pending.lock().await.insert(msg.id, responder);
213                msg.send(&mut stream)
214                    .await
215                    .map_err(|e| PeerError::MessageEncode(e.to_string()))?;
216            }
217            Outgoing::OneWay(msg) => {
218                msg.send(&mut stream)
219                    .await
220                    .map_err(|e| PeerError::MessageEncode(e.to_string()))?;
221            }
222        }
223    }
224    Err(PeerError::Disconnected)
225}
226
227async fn pinger_task(
228    my_handle: PeerHandle,
229    behavior: SharedPeerBehavior
230) -> Result<(), PeerError> {
231    loop {
232        sleep(PEER_PING_INTERVAL).await;
233        my_handle.request(
234            Message::new(Command::Ping {
235                height: behavior.get_height().await,
236            }),
237        )
238        .await?;
239    }
240}