1use std::{net::SocketAddr, sync::Arc, time::Duration};
2
3use chik_protocol::{
4 Bytes32, ChikProtocolMessage, CoinStateFilters, Message, PuzzleSolutionResponse,
5 RegisterForCoinUpdates, RegisterForPhUpdates, RejectCoinState, RejectPuzzleSolution,
6 RejectPuzzleState, RequestChildren, RequestCoinState, RequestPeers, RequestPuzzleSolution,
7 RequestPuzzleState, RequestRemoveCoinSubscriptions, RequestRemovePuzzleSubscriptions,
8 RequestTransaction, RespondChildren, RespondCoinState, RespondPeers, RespondPuzzleSolution,
9 RespondPuzzleState, RespondRemoveCoinSubscriptions, RespondRemovePuzzleSubscriptions,
10 RespondToCoinUpdates, RespondToPhUpdates, RespondTransaction, SendTransaction, SpendBundle,
11 TransactionAck,
12};
13use chik_traits::Streamable;
14use futures_util::{
15 stream::{SplitSink, SplitStream},
16 SinkExt, StreamExt,
17};
18use tokio::{
19 net::TcpStream,
20 sync::{mpsc, oneshot, Mutex},
21 task::JoinHandle,
22};
23use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
24use tracing::{debug, warn};
25
26use crate::{request_map::RequestMap, ClientError, RateLimiter, V2_RATE_LIMITS};
27
28#[cfg(any(feature = "native-tls", feature = "rustls"))]
29use tokio_tungstenite::Connector;
30
31type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
32type Sink = SplitSink<WebSocket, tungstenite::Message>;
33type Stream = SplitStream<WebSocket>;
34type Response<T, E> = std::result::Result<T, E>;
35
/// Tunable settings applied when a [`Peer`] is created.
#[derive(Debug, Clone, Copy)]
pub struct PeerOptions {
    /// Factor handed to the outbound rate limiter when the peer is built.
    // NOTE(review): presumably scales the protocol rate limits (values below
    // 1.0 being more conservative) — confirm against `RateLimiter::new`.
    pub rate_limit_factor: f64,
}

impl Default for PeerOptions {
    /// Builds the default options, with a rate limit factor of `0.6`.
    fn default() -> Self {
        const DEFAULT_RATE_LIMIT_FACTOR: f64 = 0.6;

        PeerOptions {
            rate_limit_factor: DEFAULT_RATE_LIMIT_FACTOR,
        }
    }
}
48
49#[derive(Debug, Clone)]
50pub struct Peer(Arc<PeerInner>);
51
52#[derive(Debug)]
53struct PeerInner {
54 sink: Mutex<Sink>,
55 inbound_handle: JoinHandle<()>,
56 requests: Arc<RequestMap>,
57 socket_addr: SocketAddr,
58 outbound_rate_limiter: Mutex<RateLimiter>,
59}
60
61impl Peer {
62 #[cfg(any(feature = "native-tls", feature = "rustls"))]
64 pub async fn connect(
65 socket_addr: SocketAddr,
66 connector: Connector,
67 options: PeerOptions,
68 ) -> Result<(Self, mpsc::Receiver<Message>), ClientError> {
69 Self::connect_full_uri(&format!("wss://{socket_addr}/ws"), connector, options).await
70 }
71
72 #[cfg(any(feature = "native-tls", feature = "rustls"))]
75 pub async fn connect_full_uri(
76 uri: &str,
77 connector: Connector,
78 options: PeerOptions,
79 ) -> Result<(Self, mpsc::Receiver<Message>), ClientError> {
80 let (ws, _) =
81 tokio_tungstenite::connect_async_tls_with_config(uri, None, false, Some(connector))
82 .await?;
83 Self::from_websocket(ws, options)
84 }
85
86 pub fn from_websocket(
89 ws: WebSocket,
90 options: PeerOptions,
91 ) -> Result<(Self, mpsc::Receiver<Message>), ClientError> {
92 let socket_addr = match ws.get_ref() {
93 #[cfg(feature = "native-tls")]
94 MaybeTlsStream::NativeTls(tls) => {
95 let tls_stream = tls.get_ref();
96 let tcp_stream = tls_stream.get_ref().get_ref();
97 tcp_stream.peer_addr()?
98 }
99 #[cfg(feature = "rustls")]
100 MaybeTlsStream::Rustls(tls) => {
101 let (tcp_stream, _) = tls.get_ref();
102 tcp_stream.peer_addr()?
103 }
104 MaybeTlsStream::Plain(plain) => plain.peer_addr()?,
105 _ => return Err(ClientError::UnsupportedTls),
106 };
107
108 let (sink, stream) = ws.split();
109 let (sender, receiver) = mpsc::channel(32);
110
111 let requests = Arc::new(RequestMap::new());
112 let requests_clone = requests.clone();
113
114 let inbound_handle = tokio::spawn(async move {
115 if let Err(error) = handle_inbound_messages(stream, sender, requests_clone).await {
116 debug!("Error handling message: {error}");
117 }
118 });
119
120 let peer = Self(Arc::new(PeerInner {
121 sink: Mutex::new(sink),
122 inbound_handle,
123 requests,
124 socket_addr,
125 outbound_rate_limiter: Mutex::new(RateLimiter::new(
126 false,
127 60,
128 options.rate_limit_factor,
129 V2_RATE_LIMITS.clone(),
130 )),
131 }));
132
133 Ok((peer, receiver))
134 }
135
136 pub fn socket_addr(&self) -> SocketAddr {
138 self.0.socket_addr
139 }
140
141 pub async fn send_transaction(
142 &self,
143 spend_bundle: SpendBundle,
144 ) -> Result<TransactionAck, ClientError> {
145 self.request_infallible(SendTransaction::new(spend_bundle))
146 .await
147 }
148
149 pub async fn request_puzzle_state(
150 &self,
151 puzzle_hashes: Vec<Bytes32>,
152 previous_height: Option<u32>,
153 header_hash: Bytes32,
154 filters: CoinStateFilters,
155 subscribe_when_finished: bool,
156 ) -> Result<Response<RespondPuzzleState, RejectPuzzleState>, ClientError> {
157 self.request_fallible(RequestPuzzleState::new(
158 puzzle_hashes,
159 previous_height,
160 header_hash,
161 filters,
162 subscribe_when_finished,
163 ))
164 .await
165 }
166
167 pub async fn request_coin_state(
168 &self,
169 coin_ids: Vec<Bytes32>,
170 previous_height: Option<u32>,
171 header_hash: Bytes32,
172 subscribe: bool,
173 ) -> Result<Response<RespondCoinState, RejectCoinState>, ClientError> {
174 self.request_fallible(RequestCoinState::new(
175 coin_ids,
176 previous_height,
177 header_hash,
178 subscribe,
179 ))
180 .await
181 }
182
183 pub async fn register_for_ph_updates(
184 &self,
185 puzzle_hashes: Vec<Bytes32>,
186 min_height: u32,
187 ) -> Result<RespondToPhUpdates, ClientError> {
188 self.request_infallible(RegisterForPhUpdates::new(puzzle_hashes, min_height))
189 .await
190 }
191
192 pub async fn register_for_coin_updates(
193 &self,
194 coin_ids: Vec<Bytes32>,
195 min_height: u32,
196 ) -> Result<RespondToCoinUpdates, ClientError> {
197 self.request_infallible(RegisterForCoinUpdates::new(coin_ids, min_height))
198 .await
199 }
200
201 pub async fn remove_puzzle_subscriptions(
202 &self,
203 puzzle_hashes: Option<Vec<Bytes32>>,
204 ) -> Result<RespondRemovePuzzleSubscriptions, ClientError> {
205 self.request_infallible(RequestRemovePuzzleSubscriptions::new(puzzle_hashes))
206 .await
207 }
208
209 pub async fn remove_coin_subscriptions(
210 &self,
211 coin_ids: Option<Vec<Bytes32>>,
212 ) -> Result<RespondRemoveCoinSubscriptions, ClientError> {
213 self.request_infallible(RequestRemoveCoinSubscriptions::new(coin_ids))
214 .await
215 }
216
217 pub async fn request_transaction(
218 &self,
219 transaction_id: Bytes32,
220 ) -> Result<RespondTransaction, ClientError> {
221 self.request_infallible(RequestTransaction::new(transaction_id))
222 .await
223 }
224
225 pub async fn request_puzzle_and_solution(
226 &self,
227 coin_id: Bytes32,
228 height: u32,
229 ) -> Result<Response<PuzzleSolutionResponse, RejectPuzzleSolution>, ClientError> {
230 match self
231 .request_fallible::<RespondPuzzleSolution, _, _>(RequestPuzzleSolution::new(
232 coin_id, height,
233 ))
234 .await?
235 {
236 Ok(response) => Ok(Ok(response.response)),
237 Err(rejection) => Ok(Err(rejection)),
238 }
239 }
240
241 pub async fn request_children(&self, coin_id: Bytes32) -> Result<RespondChildren, ClientError> {
242 self.request_infallible(RequestChildren::new(coin_id)).await
243 }
244
245 pub async fn request_peers(&self) -> Result<RespondPeers, ClientError> {
246 self.request_infallible(RequestPeers::new()).await
247 }
248
249 pub async fn send<T>(&self, body: T) -> Result<(), ClientError>
251 where
252 T: Streamable + ChikProtocolMessage,
253 {
254 self.send_raw(Message {
255 msg_type: T::msg_type(),
256 id: None,
257 data: body.to_bytes()?.into(),
258 })
259 .await?;
260
261 Ok(())
262 }
263
264 pub async fn request_fallible<T, E, B>(&self, body: B) -> Result<Response<T, E>, ClientError>
266 where
267 T: Streamable + ChikProtocolMessage,
268 E: Streamable + ChikProtocolMessage,
269 B: Streamable + ChikProtocolMessage,
270 {
271 let message = self.request_raw(body).await?;
272 if message.msg_type != T::msg_type() && message.msg_type != E::msg_type() {
273 return Err(ClientError::InvalidResponse(
274 vec![T::msg_type(), E::msg_type()],
275 message.msg_type,
276 ));
277 }
278 if message.msg_type == T::msg_type() {
279 Ok(Ok(T::from_bytes(&message.data)?))
280 } else {
281 Ok(Err(E::from_bytes(&message.data)?))
282 }
283 }
284
285 pub async fn request_infallible<T, B>(&self, body: B) -> Result<T, ClientError>
287 where
288 T: Streamable + ChikProtocolMessage,
289 B: Streamable + ChikProtocolMessage,
290 {
291 let message = self.request_raw(body).await?;
292 if message.msg_type != T::msg_type() {
293 return Err(ClientError::InvalidResponse(
294 vec![T::msg_type()],
295 message.msg_type,
296 ));
297 }
298 Ok(T::from_bytes(&message.data)?)
299 }
300
301 pub async fn request_raw<T>(&self, body: T) -> Result<Message, ClientError>
303 where
304 T: Streamable + ChikProtocolMessage,
305 {
306 let (sender, receiver) = oneshot::channel();
307
308 self.send_raw(Message {
309 msg_type: T::msg_type(),
310 id: Some(self.0.requests.insert(sender).await),
311 data: body.to_bytes()?.into(),
312 })
313 .await?;
314
315 Ok(receiver.await?)
316 }
317
318 async fn send_raw(&self, message: Message) -> Result<(), ClientError> {
319 loop {
320 if !self
321 .0
322 .outbound_rate_limiter
323 .lock()
324 .await
325 .handle_message(&message)
326 {
327 tokio::time::sleep(Duration::from_secs(1)).await;
328 continue;
329 }
330
331 self.0
332 .sink
333 .lock()
334 .await
335 .send(message.to_bytes()?.into())
336 .await?;
337
338 return Ok(());
339 }
340 }
341
342 pub async fn close(&self) -> Result<(), ClientError> {
343 self.0.sink.lock().await.close().await?;
344 Ok(())
345 }
346}
347
348impl Drop for PeerInner {
349 fn drop(&mut self) {
350 self.inbound_handle.abort();
351 }
352}
353
354async fn handle_inbound_messages(
355 mut stream: Stream,
356 sender: mpsc::Sender<Message>,
357 requests: Arc<RequestMap>,
358) -> Result<(), ClientError> {
359 use tungstenite::Message::{Binary, Close, Frame, Ping, Pong, Text};
360
361 while let Some(message) = stream.next().await {
362 let message = message?;
363
364 match message {
365 Frame(..) => unreachable!(),
366 Close(..) => break,
367 Ping(..) | Pong(..) => {}
368 Text(text) => {
369 warn!("Received unexpected text message: {text}");
370 }
371 Binary(binary) => {
372 let message = Message::from_bytes(&binary)?;
373
374 let Some(id) = message.id else {
375 sender.send(message).await.ok();
376 continue;
377 };
378
379 let Some(request) = requests.remove(id).await else {
380 warn!(
381 "Received {:?} message with untracked id {id}",
382 message.msg_type
383 );
384 return Err(ClientError::UnexpectedMessage(message.msg_type));
385 };
386
387 request.send(message);
388 }
389 }
390 }
391 Ok(())
392}