1use std::sync::atomic::{AtomicU16, Ordering};
2use std::{collections::HashMap, sync::Arc};
3
4use chik_protocol::*;
5use chik_traits::Streamable;
6use futures_util::stream::SplitSink;
7use futures_util::{SinkExt, StreamExt};
8use tokio::sync::{broadcast, oneshot, Mutex};
9use tokio::{net::TcpStream, task::JoinHandle};
10use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
11use tungstenite::Message as WsMessage;
12
13use crate::utils::stream;
14use crate::Error;
15
16type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
17type Requests = Arc<Mutex<HashMap<u16, oneshot::Sender<Message>>>>;
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub enum PeerEvent {
21 CoinStateUpdate(CoinStateUpdate),
22 NewPeakWallet(NewPeakWallet),
23 MempoolItemsAdded(MempoolItemsAdded),
24 MempoolItemsRemoved(MempoolItemsRemoved),
25}
26
27pub struct Peer {
28 sink: Mutex<SplitSink<WebSocket, tungstenite::Message>>,
29 inbound_task: JoinHandle<()>,
30 event_receiver: broadcast::Receiver<PeerEvent>,
31 requests: Requests,
32
33 nonce: AtomicU16,
36}
37
38impl Peer {
39 pub fn new(ws: WebSocket) -> Self {
40 let (sink, mut stream) = ws.split();
41 let (event_sender, event_receiver) = broadcast::channel(32);
42
43 let requests = Requests::default();
44 let requests_clone = Arc::clone(&requests);
45
46 let inbound_task = tokio::spawn(async move {
47 while let Some(message) = stream.next().await {
48 if let Ok(message) = message {
49 Self::handle_inbound(message, &requests_clone, &event_sender)
50 .await
51 .ok();
52 }
53 }
54 });
55
56 Self {
57 sink: Mutex::new(sink),
58 inbound_task,
59 event_receiver,
60 requests,
61 nonce: AtomicU16::new(0),
62 }
63 }
64
65 pub async fn send_handshake(
66 &self,
67 network_id: String,
68 node_type: NodeType,
69 mempool_updates: bool,
70 ) -> Result<(), Error<()>> {
71 let mut capabilities = vec![
72 (1, "1".to_string()),
73 (2, "1".to_string()),
74 (3, "1".to_string()),
75 ];
76
77 if mempool_updates {
78 capabilities.push((5, "1".to_string()));
79 }
80
81 let body = Handshake {
82 network_id,
83 protocol_version: "0.0.34".to_string(),
84 software_version: "0.0.0".to_string(),
85 server_port: 0,
86 node_type,
87 capabilities,
88 };
89 self.send(body).await
90 }
91
92 pub async fn request_puzzle_and_solution(
93 &self,
94 coin_id: Bytes32,
95 height: u32,
96 ) -> Result<PuzzleSolutionResponse, Error<RejectPuzzleSolution>> {
97 let body = RequestPuzzleSolution {
98 coin_name: coin_id,
99 height,
100 };
101 let response: RespondPuzzleSolution = self.request_or_reject(body).await?;
102 Ok(response.response)
103 }
104
105 pub async fn send_transaction(
106 &self,
107 spend_bundle: SpendBundle,
108 ) -> Result<TransactionAck, Error<()>> {
109 let body = SendTransaction {
110 transaction: spend_bundle,
111 };
112 self.request(body).await
113 }
114
115 pub async fn request_block_header(
116 &self,
117 height: u32,
118 ) -> Result<HeaderBlock, Error<RejectHeaderRequest>> {
119 let body = RequestBlockHeader { height };
120 let response: RespondBlockHeader = self.request_or_reject(body).await?;
121 Ok(response.header_block)
122 }
123
124 pub async fn request_block_headers(
125 &self,
126 start_height: u32,
127 end_height: u32,
128 return_filter: bool,
129 ) -> Result<Vec<HeaderBlock>, Error<()>> {
130 let body = RequestBlockHeaders {
131 start_height,
132 end_height,
133 return_filter,
134 };
135 let response: RespondBlockHeaders =
136 self.request_or_reject(body)
137 .await
138 .map_err(|error: Error<RejectBlockHeaders>| match error {
139 Error::Rejection(_rejection) => Error::Rejection(()),
140 Error::Chik(error) => Error::Chik(error),
141 Error::WebSocket(error) => Error::WebSocket(error),
142 Error::InvalidResponse(error) => Error::InvalidResponse(error),
143 Error::MissingResponse => Error::MissingResponse,
144 })?;
145 Ok(response.header_blocks)
146 }
147
148 pub async fn request_removals(
149 &self,
150 height: u32,
151 header_hash: Bytes32,
152 coin_ids: Option<Vec<Bytes32>>,
153 ) -> Result<RespondRemovals, Error<RejectRemovalsRequest>> {
154 let body = RequestRemovals {
155 height,
156 header_hash,
157 coin_names: coin_ids,
158 };
159 self.request_or_reject(body).await
160 }
161
162 pub async fn request_additions(
163 &self,
164 height: u32,
165 header_hash: Option<Bytes32>,
166 puzzle_hashes: Option<Vec<Bytes32>>,
167 ) -> Result<RespondAdditions, Error<RejectAdditionsRequest>> {
168 let body = RequestAdditions {
169 height,
170 header_hash,
171 puzzle_hashes,
172 };
173 self.request_or_reject(body).await
174 }
175
176 pub async fn register_for_ph_updates(
177 &self,
178 puzzle_hashes: Vec<Bytes32>,
179 min_height: u32,
180 ) -> Result<Vec<CoinState>, Error<()>> {
181 let body = RegisterForPhUpdates {
182 puzzle_hashes,
183 min_height,
184 };
185 let response: RespondToPhUpdates = self.request(body).await?;
186 Ok(response.coin_states)
187 }
188
189 pub async fn register_for_coin_updates(
190 &self,
191 coin_ids: Vec<Bytes32>,
192 min_height: u32,
193 ) -> Result<Vec<CoinState>, Error<()>> {
194 let body = RegisterForCoinUpdates {
195 coin_ids,
196 min_height,
197 };
198 let response: RespondToCoinUpdates = self.request(body).await?;
199 Ok(response.coin_states)
200 }
201
202 pub async fn request_children(&self, coin_id: Bytes32) -> Result<Vec<CoinState>, Error<()>> {
203 let body = RequestChildren { coin_name: coin_id };
204 let response: RespondChildren = self.request(body).await?;
205 Ok(response.coin_states)
206 }
207
208 pub async fn request_ses_info(
209 &self,
210 start_height: u32,
211 end_height: u32,
212 ) -> Result<RespondSesInfo, Error<()>> {
213 let body = RequestSesInfo {
214 start_height,
215 end_height,
216 };
217 self.request(body).await
218 }
219
220 pub async fn request_fee_estimates(
221 &self,
222 time_targets: Vec<u64>,
223 ) -> Result<FeeEstimateGroup, Error<()>> {
224 let body = RequestFeeEstimates { time_targets };
225 let response: RespondFeeEstimates = self.request(body).await?;
226 Ok(response.estimates)
227 }
228
229 pub async fn send<T>(&self, body: T) -> Result<(), Error<()>>
230 where
231 T: Streamable + ChikProtocolMessage,
232 {
233 let message = Message {
235 msg_type: T::msg_type(),
236 id: None,
237 data: stream(&body)?.into(),
238 };
239
240 let mut sink = self.sink.lock().await;
242 sink.send(stream(&message)?.into()).await?;
243
244 Ok(())
245 }
246
247 pub async fn request_or_reject<T, R, B>(&self, body: B) -> Result<T, Error<R>>
248 where
249 T: Streamable + ChikProtocolMessage,
250 R: Streamable + ChikProtocolMessage,
251 B: Streamable + ChikProtocolMessage,
252 {
253 let message = self.request_raw(body).await?;
254 let data = message.data.as_ref();
255
256 if message.msg_type == T::msg_type() {
257 T::from_bytes(data).or(Err(Error::InvalidResponse(message)))
258 } else if message.msg_type == R::msg_type() {
259 let rejection = R::from_bytes(data).or(Err(Error::InvalidResponse(message)))?;
260 Err(Error::Rejection(rejection))
261 } else {
262 Err(Error::InvalidResponse(message))
263 }
264 }
265
266 pub async fn request<Response, T>(&self, body: T) -> Result<Response, Error<()>>
267 where
268 Response: Streamable + ChikProtocolMessage,
269 T: Streamable + ChikProtocolMessage,
270 {
271 let message = self.request_raw(body).await?;
272 let data = message.data.as_ref();
273
274 if message.msg_type == Response::msg_type() {
275 Response::from_bytes(data).or(Err(Error::InvalidResponse(message)))
276 } else {
277 Err(Error::InvalidResponse(message))
278 }
279 }
280
281 pub async fn request_raw<T, R>(&self, body: T) -> Result<Message, Error<R>>
282 where
283 T: Streamable + ChikProtocolMessage,
284 {
285 let message_id = self.nonce.fetch_add(1, Ordering::SeqCst);
287
288 let message = Message {
290 msg_type: T::msg_type(),
291 id: Some(message_id),
292 data: stream(&body)?.into(),
293 };
294
295 let (sender, receiver) = oneshot::channel::<Message>();
297 self.requests.lock().await.insert(message_id, sender);
298
299 let bytes = match stream(&message) {
301 Ok(bytes) => bytes.into(),
302 Err(error) => {
303 self.requests.lock().await.remove(&message_id);
304 return Err(error.into());
305 }
306 };
307 let send_result = self.sink.lock().await.send(bytes).await;
308
309 if let Err(error) = send_result {
310 self.requests.lock().await.remove(&message_id);
311 return Err(error.into());
312 }
313
314 let response = receiver.await;
316
317 self.requests.lock().await.remove(&message_id);
319
320 response.or(Err(Error::MissingResponse))
322 }
323
324 pub fn receiver(&self) -> &broadcast::Receiver<PeerEvent> {
325 &self.event_receiver
326 }
327
328 pub fn receiver_mut(&mut self) -> &mut broadcast::Receiver<PeerEvent> {
329 &mut self.event_receiver
330 }
331
332 async fn handle_inbound(
333 message: WsMessage,
334 requests: &Requests,
335 event_sender: &broadcast::Sender<PeerEvent>,
336 ) -> Result<(), Error<()>> {
337 let message = Message::from_bytes(message.into_data().as_ref())?;
339
340 if let Some(id) = message.id {
341 if let Some(request) = requests.lock().await.remove(&id) {
343 request.send(message).ok();
344 }
345 return Ok(());
346 }
347
348 macro_rules! events {
349 ( $( $event:ident ),+ $(,)? ) => {
350 match message.msg_type {
351 $( ProtocolMessageTypes::$event => {
352 event_sender
353 .send(PeerEvent::$event($event::from_bytes(message.data.as_ref())?))
354 .ok();
355 } )+
356 _ => {}
357 }
358 };
359 }
360
361 events!(
363 CoinStateUpdate,
364 NewPeakWallet,
365 MempoolItemsAdded,
366 MempoolItemsRemoved
367 );
368
369 Ok(())
370 }
371}
372
373impl Drop for Peer {
374 fn drop(&mut self) {
375 self.inbound_task.abort();
376 }
377}