chik_client/
peer.rs

1use std::sync::atomic::{AtomicU16, Ordering};
2use std::{collections::HashMap, sync::Arc};
3
4use chik_protocol::*;
5use chik_traits::Streamable;
6use futures_util::stream::SplitSink;
7use futures_util::{SinkExt, StreamExt};
8use tokio::sync::{broadcast, oneshot, Mutex};
9use tokio::{net::TcpStream, task::JoinHandle};
10use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
11use tungstenite::Message as WsMessage;
12
13use crate::utils::stream;
14use crate::Error;
15
16type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
17type Requests = Arc<Mutex<HashMap<u16, oneshot::Sender<Message>>>>;
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub enum PeerEvent {
21    CoinStateUpdate(CoinStateUpdate),
22    NewPeakWallet(NewPeakWallet),
23    MempoolItemsAdded(MempoolItemsAdded),
24    MempoolItemsRemoved(MempoolItemsRemoved),
25}
26
27pub struct Peer {
28    sink: Mutex<SplitSink<WebSocket, tungstenite::Message>>,
29    inbound_task: JoinHandle<()>,
30    event_receiver: broadcast::Receiver<PeerEvent>,
31    requests: Requests,
32
33    // TODO: This does not currently prevent multiple requests with the same id at the same time.
34    // If one of them is still running while all other ids are being iterated through.
35    nonce: AtomicU16,
36}
37
38impl Peer {
39    pub fn new(ws: WebSocket) -> Self {
40        let (sink, mut stream) = ws.split();
41        let (event_sender, event_receiver) = broadcast::channel(32);
42
43        let requests = Requests::default();
44        let requests_clone = Arc::clone(&requests);
45
46        let inbound_task = tokio::spawn(async move {
47            while let Some(message) = stream.next().await {
48                if let Ok(message) = message {
49                    Self::handle_inbound(message, &requests_clone, &event_sender)
50                        .await
51                        .ok();
52                }
53            }
54        });
55
56        Self {
57            sink: Mutex::new(sink),
58            inbound_task,
59            event_receiver,
60            requests,
61            nonce: AtomicU16::new(0),
62        }
63    }
64
65    pub async fn send_handshake(
66        &self,
67        network_id: String,
68        node_type: NodeType,
69        mempool_updates: bool,
70    ) -> Result<(), Error<()>> {
71        let mut capabilities = vec![
72            (1, "1".to_string()),
73            (2, "1".to_string()),
74            (3, "1".to_string()),
75        ];
76
77        if mempool_updates {
78            capabilities.push((5, "1".to_string()));
79        }
80
81        let body = Handshake {
82            network_id,
83            protocol_version: "0.0.34".to_string(),
84            software_version: "0.0.0".to_string(),
85            server_port: 0,
86            node_type,
87            capabilities,
88        };
89        self.send(body).await
90    }
91
92    pub async fn request_puzzle_and_solution(
93        &self,
94        coin_id: Bytes32,
95        height: u32,
96    ) -> Result<PuzzleSolutionResponse, Error<RejectPuzzleSolution>> {
97        let body = RequestPuzzleSolution {
98            coin_name: coin_id,
99            height,
100        };
101        let response: RespondPuzzleSolution = self.request_or_reject(body).await?;
102        Ok(response.response)
103    }
104
105    pub async fn send_transaction(
106        &self,
107        spend_bundle: SpendBundle,
108    ) -> Result<TransactionAck, Error<()>> {
109        let body = SendTransaction {
110            transaction: spend_bundle,
111        };
112        self.request(body).await
113    }
114
115    pub async fn request_block_header(
116        &self,
117        height: u32,
118    ) -> Result<HeaderBlock, Error<RejectHeaderRequest>> {
119        let body = RequestBlockHeader { height };
120        let response: RespondBlockHeader = self.request_or_reject(body).await?;
121        Ok(response.header_block)
122    }
123
124    pub async fn request_block_headers(
125        &self,
126        start_height: u32,
127        end_height: u32,
128        return_filter: bool,
129    ) -> Result<Vec<HeaderBlock>, Error<()>> {
130        let body = RequestBlockHeaders {
131            start_height,
132            end_height,
133            return_filter,
134        };
135        let response: RespondBlockHeaders =
136            self.request_or_reject(body)
137                .await
138                .map_err(|error: Error<RejectBlockHeaders>| match error {
139                    Error::Rejection(_rejection) => Error::Rejection(()),
140                    Error::Chik(error) => Error::Chik(error),
141                    Error::WebSocket(error) => Error::WebSocket(error),
142                    Error::InvalidResponse(error) => Error::InvalidResponse(error),
143                    Error::MissingResponse => Error::MissingResponse,
144                })?;
145        Ok(response.header_blocks)
146    }
147
148    pub async fn request_removals(
149        &self,
150        height: u32,
151        header_hash: Bytes32,
152        coin_ids: Option<Vec<Bytes32>>,
153    ) -> Result<RespondRemovals, Error<RejectRemovalsRequest>> {
154        let body = RequestRemovals {
155            height,
156            header_hash,
157            coin_names: coin_ids,
158        };
159        self.request_or_reject(body).await
160    }
161
162    pub async fn request_additions(
163        &self,
164        height: u32,
165        header_hash: Option<Bytes32>,
166        puzzle_hashes: Option<Vec<Bytes32>>,
167    ) -> Result<RespondAdditions, Error<RejectAdditionsRequest>> {
168        let body = RequestAdditions {
169            height,
170            header_hash,
171            puzzle_hashes,
172        };
173        self.request_or_reject(body).await
174    }
175
176    pub async fn register_for_ph_updates(
177        &self,
178        puzzle_hashes: Vec<Bytes32>,
179        min_height: u32,
180    ) -> Result<Vec<CoinState>, Error<()>> {
181        let body = RegisterForPhUpdates {
182            puzzle_hashes,
183            min_height,
184        };
185        let response: RespondToPhUpdates = self.request(body).await?;
186        Ok(response.coin_states)
187    }
188
189    pub async fn register_for_coin_updates(
190        &self,
191        coin_ids: Vec<Bytes32>,
192        min_height: u32,
193    ) -> Result<Vec<CoinState>, Error<()>> {
194        let body = RegisterForCoinUpdates {
195            coin_ids,
196            min_height,
197        };
198        let response: RespondToCoinUpdates = self.request(body).await?;
199        Ok(response.coin_states)
200    }
201
202    pub async fn request_children(&self, coin_id: Bytes32) -> Result<Vec<CoinState>, Error<()>> {
203        let body = RequestChildren { coin_name: coin_id };
204        let response: RespondChildren = self.request(body).await?;
205        Ok(response.coin_states)
206    }
207
208    pub async fn request_ses_info(
209        &self,
210        start_height: u32,
211        end_height: u32,
212    ) -> Result<RespondSesInfo, Error<()>> {
213        let body = RequestSesInfo {
214            start_height,
215            end_height,
216        };
217        self.request(body).await
218    }
219
220    pub async fn request_fee_estimates(
221        &self,
222        time_targets: Vec<u64>,
223    ) -> Result<FeeEstimateGroup, Error<()>> {
224        let body = RequestFeeEstimates { time_targets };
225        let response: RespondFeeEstimates = self.request(body).await?;
226        Ok(response.estimates)
227    }
228
229    pub async fn send<T>(&self, body: T) -> Result<(), Error<()>>
230    where
231        T: Streamable + ChikProtocolMessage,
232    {
233        // Create the message.
234        let message = Message {
235            msg_type: T::msg_type(),
236            id: None,
237            data: stream(&body)?.into(),
238        };
239
240        // Send the message through the websocket.
241        let mut sink = self.sink.lock().await;
242        sink.send(stream(&message)?.into()).await?;
243
244        Ok(())
245    }
246
247    pub async fn request_or_reject<T, R, B>(&self, body: B) -> Result<T, Error<R>>
248    where
249        T: Streamable + ChikProtocolMessage,
250        R: Streamable + ChikProtocolMessage,
251        B: Streamable + ChikProtocolMessage,
252    {
253        let message = self.request_raw(body).await?;
254        let data = message.data.as_ref();
255
256        if message.msg_type == T::msg_type() {
257            T::from_bytes(data).or(Err(Error::InvalidResponse(message)))
258        } else if message.msg_type == R::msg_type() {
259            let rejection = R::from_bytes(data).or(Err(Error::InvalidResponse(message)))?;
260            Err(Error::Rejection(rejection))
261        } else {
262            Err(Error::InvalidResponse(message))
263        }
264    }
265
266    pub async fn request<Response, T>(&self, body: T) -> Result<Response, Error<()>>
267    where
268        Response: Streamable + ChikProtocolMessage,
269        T: Streamable + ChikProtocolMessage,
270    {
271        let message = self.request_raw(body).await?;
272        let data = message.data.as_ref();
273
274        if message.msg_type == Response::msg_type() {
275            Response::from_bytes(data).or(Err(Error::InvalidResponse(message)))
276        } else {
277            Err(Error::InvalidResponse(message))
278        }
279    }
280
281    pub async fn request_raw<T, R>(&self, body: T) -> Result<Message, Error<R>>
282    where
283        T: Streamable + ChikProtocolMessage,
284    {
285        // Get the current nonce and increment.
286        let message_id = self.nonce.fetch_add(1, Ordering::SeqCst);
287
288        // Create the message.
289        let message = Message {
290            msg_type: T::msg_type(),
291            id: Some(message_id),
292            data: stream(&body)?.into(),
293        };
294
295        // Create a saved oneshot channel to receive the response.
296        let (sender, receiver) = oneshot::channel::<Message>();
297        self.requests.lock().await.insert(message_id, sender);
298
299        // Send the message.
300        let bytes = match stream(&message) {
301            Ok(bytes) => bytes.into(),
302            Err(error) => {
303                self.requests.lock().await.remove(&message_id);
304                return Err(error.into());
305            }
306        };
307        let send_result = self.sink.lock().await.send(bytes).await;
308
309        if let Err(error) = send_result {
310            self.requests.lock().await.remove(&message_id);
311            return Err(error.into());
312        }
313
314        // Wait for the response.
315        let response = receiver.await;
316
317        // Remove the one shot channel.
318        self.requests.lock().await.remove(&message_id);
319
320        // Handle the response, if present.
321        response.or(Err(Error::MissingResponse))
322    }
323
324    pub fn receiver(&self) -> &broadcast::Receiver<PeerEvent> {
325        &self.event_receiver
326    }
327
328    pub fn receiver_mut(&mut self) -> &mut broadcast::Receiver<PeerEvent> {
329        &mut self.event_receiver
330    }
331
332    async fn handle_inbound(
333        message: WsMessage,
334        requests: &Requests,
335        event_sender: &broadcast::Sender<PeerEvent>,
336    ) -> Result<(), Error<()>> {
337        // Parse the message.
338        let message = Message::from_bytes(message.into_data().as_ref())?;
339
340        if let Some(id) = message.id {
341            // Send response through oneshot channel if present.
342            if let Some(request) = requests.lock().await.remove(&id) {
343                request.send(message).ok();
344            }
345            return Ok(());
346        }
347
348        macro_rules! events {
349            ( $( $event:ident ),+ $(,)? ) => {
350                match message.msg_type {
351                    $( ProtocolMessageTypes::$event => {
352                        event_sender
353                            .send(PeerEvent::$event($event::from_bytes(message.data.as_ref())?))
354                            .ok();
355                    } )+
356                    _ => {}
357                }
358            };
359        }
360
361        // TODO: Handle unexpected messages.
362        events!(
363            CoinStateUpdate,
364            NewPeakWallet,
365            MempoolItemsAdded,
366            MempoolItemsRemoved
367        );
368
369        Ok(())
370    }
371}
372
373impl Drop for Peer {
374    fn drop(&mut self) {
375        self.inbound_task.abort();
376    }
377}