// chik_sdk_client/peer.rs

1use std::{net::SocketAddr, sync::Arc, time::Duration};
2
3use chik_protocol::{
4    Bytes32, ChikProtocolMessage, CoinStateFilters, Message, PuzzleSolutionResponse,
5    RegisterForCoinUpdates, RegisterForPhUpdates, RejectCoinState, RejectPuzzleSolution,
6    RejectPuzzleState, RequestChildren, RequestCoinState, RequestPeers, RequestPuzzleSolution,
7    RequestPuzzleState, RequestRemoveCoinSubscriptions, RequestRemovePuzzleSubscriptions,
8    RequestTransaction, RespondChildren, RespondCoinState, RespondPeers, RespondPuzzleSolution,
9    RespondPuzzleState, RespondRemoveCoinSubscriptions, RespondRemovePuzzleSubscriptions,
10    RespondToCoinUpdates, RespondToPhUpdates, RespondTransaction, SendTransaction, SpendBundle,
11    TransactionAck,
12};
13use chik_traits::Streamable;
14use futures_util::{
15    stream::{SplitSink, SplitStream},
16    SinkExt, StreamExt,
17};
18use tokio::{
19    net::TcpStream,
20    sync::{mpsc, oneshot, Mutex},
21    task::JoinHandle,
22};
23use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
24use tracing::{debug, warn};
25
26use crate::{request_map::RequestMap, ClientError, RateLimiter, V2_RATE_LIMITS};
27
28#[cfg(any(feature = "native-tls", feature = "rustls"))]
29use tokio_tungstenite::Connector;
30
31type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
32type Sink = SplitSink<WebSocket, tungstenite::Message>;
33type Stream = SplitStream<WebSocket>;
34type Response<T, E> = std::result::Result<T, E>;
35
/// Tunables applied when a [`Peer`] is created.
#[derive(Debug, Clone, Copy)]
pub struct PeerOptions {
    /// Factor handed to the outbound rate limiter when the peer is constructed.
    /// NOTE(review): presumably scales the protocol rate limits (e.g. `0.6` = 60% of the
    /// nominal limit) — confirm against `RateLimiter::new`.
    pub rate_limit_factor: f64,
}

impl Default for PeerOptions {
    /// Returns options with a `rate_limit_factor` of `0.6`.
    fn default() -> Self {
        let rate_limit_factor = 0.6;
        PeerOptions { rate_limit_factor }
    }
}
48
49#[derive(Debug, Clone)]
50pub struct Peer(Arc<PeerInner>);
51
52#[derive(Debug)]
53struct PeerInner {
54    sink: Mutex<Sink>,
55    inbound_handle: JoinHandle<()>,
56    requests: Arc<RequestMap>,
57    socket_addr: SocketAddr,
58    outbound_rate_limiter: Mutex<RateLimiter>,
59}
60
61impl Peer {
62    /// Connects to a peer using its IP address and port.
63    #[cfg(any(feature = "native-tls", feature = "rustls"))]
64    pub async fn connect(
65        socket_addr: SocketAddr,
66        connector: Connector,
67        options: PeerOptions,
68    ) -> Result<(Self, mpsc::Receiver<Message>), ClientError> {
69        Self::connect_full_uri(&format!("wss://{socket_addr}/ws"), connector, options).await
70    }
71
72    /// Connects to a peer using its full websocket URI.
73    /// For example, `wss://127.0.0.1:9678/ws`.
74    #[cfg(any(feature = "native-tls", feature = "rustls"))]
75    pub async fn connect_full_uri(
76        uri: &str,
77        connector: Connector,
78        options: PeerOptions,
79    ) -> Result<(Self, mpsc::Receiver<Message>), ClientError> {
80        let (ws, _) =
81            tokio_tungstenite::connect_async_tls_with_config(uri, None, false, Some(connector))
82                .await?;
83        Self::from_websocket(ws, options)
84    }
85
86    /// Creates a peer from an existing websocket connection.
87    /// The connection must be secured with TLS, so that the certificate can be hashed in a peer id.
88    pub fn from_websocket(
89        ws: WebSocket,
90        options: PeerOptions,
91    ) -> Result<(Self, mpsc::Receiver<Message>), ClientError> {
92        let socket_addr = match ws.get_ref() {
93            #[cfg(feature = "native-tls")]
94            MaybeTlsStream::NativeTls(tls) => {
95                let tls_stream = tls.get_ref();
96                let tcp_stream = tls_stream.get_ref().get_ref();
97                tcp_stream.peer_addr()?
98            }
99            #[cfg(feature = "rustls")]
100            MaybeTlsStream::Rustls(tls) => {
101                let (tcp_stream, _) = tls.get_ref();
102                tcp_stream.peer_addr()?
103            }
104            MaybeTlsStream::Plain(plain) => plain.peer_addr()?,
105            _ => return Err(ClientError::UnsupportedTls),
106        };
107
108        let (sink, stream) = ws.split();
109        let (sender, receiver) = mpsc::channel(32);
110
111        let requests = Arc::new(RequestMap::new());
112        let requests_clone = requests.clone();
113
114        let inbound_handle = tokio::spawn(async move {
115            if let Err(error) = handle_inbound_messages(stream, sender, requests_clone).await {
116                debug!("Error handling message: {error}");
117            }
118        });
119
120        let peer = Self(Arc::new(PeerInner {
121            sink: Mutex::new(sink),
122            inbound_handle,
123            requests,
124            socket_addr,
125            outbound_rate_limiter: Mutex::new(RateLimiter::new(
126                false,
127                60,
128                options.rate_limit_factor,
129                V2_RATE_LIMITS.clone(),
130            )),
131        }));
132
133        Ok((peer, receiver))
134    }
135
136    /// The IP address and port of the peer connection.
137    pub fn socket_addr(&self) -> SocketAddr {
138        self.0.socket_addr
139    }
140
141    pub async fn send_transaction(
142        &self,
143        spend_bundle: SpendBundle,
144    ) -> Result<TransactionAck, ClientError> {
145        self.request_infallible(SendTransaction::new(spend_bundle))
146            .await
147    }
148
149    pub async fn request_puzzle_state(
150        &self,
151        puzzle_hashes: Vec<Bytes32>,
152        previous_height: Option<u32>,
153        header_hash: Bytes32,
154        filters: CoinStateFilters,
155        subscribe_when_finished: bool,
156    ) -> Result<Response<RespondPuzzleState, RejectPuzzleState>, ClientError> {
157        self.request_fallible(RequestPuzzleState::new(
158            puzzle_hashes,
159            previous_height,
160            header_hash,
161            filters,
162            subscribe_when_finished,
163        ))
164        .await
165    }
166
167    pub async fn request_coin_state(
168        &self,
169        coin_ids: Vec<Bytes32>,
170        previous_height: Option<u32>,
171        header_hash: Bytes32,
172        subscribe: bool,
173    ) -> Result<Response<RespondCoinState, RejectCoinState>, ClientError> {
174        self.request_fallible(RequestCoinState::new(
175            coin_ids,
176            previous_height,
177            header_hash,
178            subscribe,
179        ))
180        .await
181    }
182
183    pub async fn register_for_ph_updates(
184        &self,
185        puzzle_hashes: Vec<Bytes32>,
186        min_height: u32,
187    ) -> Result<RespondToPhUpdates, ClientError> {
188        self.request_infallible(RegisterForPhUpdates::new(puzzle_hashes, min_height))
189            .await
190    }
191
192    pub async fn register_for_coin_updates(
193        &self,
194        coin_ids: Vec<Bytes32>,
195        min_height: u32,
196    ) -> Result<RespondToCoinUpdates, ClientError> {
197        self.request_infallible(RegisterForCoinUpdates::new(coin_ids, min_height))
198            .await
199    }
200
201    pub async fn remove_puzzle_subscriptions(
202        &self,
203        puzzle_hashes: Option<Vec<Bytes32>>,
204    ) -> Result<RespondRemovePuzzleSubscriptions, ClientError> {
205        self.request_infallible(RequestRemovePuzzleSubscriptions::new(puzzle_hashes))
206            .await
207    }
208
209    pub async fn remove_coin_subscriptions(
210        &self,
211        coin_ids: Option<Vec<Bytes32>>,
212    ) -> Result<RespondRemoveCoinSubscriptions, ClientError> {
213        self.request_infallible(RequestRemoveCoinSubscriptions::new(coin_ids))
214            .await
215    }
216
217    pub async fn request_transaction(
218        &self,
219        transaction_id: Bytes32,
220    ) -> Result<RespondTransaction, ClientError> {
221        self.request_infallible(RequestTransaction::new(transaction_id))
222            .await
223    }
224
225    pub async fn request_puzzle_and_solution(
226        &self,
227        coin_id: Bytes32,
228        height: u32,
229    ) -> Result<Response<PuzzleSolutionResponse, RejectPuzzleSolution>, ClientError> {
230        match self
231            .request_fallible::<RespondPuzzleSolution, _, _>(RequestPuzzleSolution::new(
232                coin_id, height,
233            ))
234            .await?
235        {
236            Ok(response) => Ok(Ok(response.response)),
237            Err(rejection) => Ok(Err(rejection)),
238        }
239    }
240
241    pub async fn request_children(&self, coin_id: Bytes32) -> Result<RespondChildren, ClientError> {
242        self.request_infallible(RequestChildren::new(coin_id)).await
243    }
244
245    pub async fn request_peers(&self) -> Result<RespondPeers, ClientError> {
246        self.request_infallible(RequestPeers::new()).await
247    }
248
249    /// Sends a message to the peer, but does not expect any response.
250    pub async fn send<T>(&self, body: T) -> Result<(), ClientError>
251    where
252        T: Streamable + ChikProtocolMessage,
253    {
254        self.send_raw(Message {
255            msg_type: T::msg_type(),
256            id: None,
257            data: body.to_bytes()?.into(),
258        })
259        .await?;
260
261        Ok(())
262    }
263
264    /// Sends a message to the peer and expects a message that's either a response or a rejection.
265    pub async fn request_fallible<T, E, B>(&self, body: B) -> Result<Response<T, E>, ClientError>
266    where
267        T: Streamable + ChikProtocolMessage,
268        E: Streamable + ChikProtocolMessage,
269        B: Streamable + ChikProtocolMessage,
270    {
271        let message = self.request_raw(body).await?;
272        if message.msg_type != T::msg_type() && message.msg_type != E::msg_type() {
273            return Err(ClientError::InvalidResponse(
274                vec![T::msg_type(), E::msg_type()],
275                message.msg_type,
276            ));
277        }
278        if message.msg_type == T::msg_type() {
279            Ok(Ok(T::from_bytes(&message.data)?))
280        } else {
281            Ok(Err(E::from_bytes(&message.data)?))
282        }
283    }
284
285    /// Sends a message to the peer and expects a specific response message.
286    pub async fn request_infallible<T, B>(&self, body: B) -> Result<T, ClientError>
287    where
288        T: Streamable + ChikProtocolMessage,
289        B: Streamable + ChikProtocolMessage,
290    {
291        let message = self.request_raw(body).await?;
292        if message.msg_type != T::msg_type() {
293            return Err(ClientError::InvalidResponse(
294                vec![T::msg_type()],
295                message.msg_type,
296            ));
297        }
298        Ok(T::from_bytes(&message.data)?)
299    }
300
301    /// Sends a message to the peer and expects any arbitrary protocol message without parsing it.
302    pub async fn request_raw<T>(&self, body: T) -> Result<Message, ClientError>
303    where
304        T: Streamable + ChikProtocolMessage,
305    {
306        let (sender, receiver) = oneshot::channel();
307
308        self.send_raw(Message {
309            msg_type: T::msg_type(),
310            id: Some(self.0.requests.insert(sender).await),
311            data: body.to_bytes()?.into(),
312        })
313        .await?;
314
315        Ok(receiver.await?)
316    }
317
318    async fn send_raw(&self, message: Message) -> Result<(), ClientError> {
319        loop {
320            if !self
321                .0
322                .outbound_rate_limiter
323                .lock()
324                .await
325                .handle_message(&message)
326            {
327                tokio::time::sleep(Duration::from_secs(1)).await;
328                continue;
329            }
330
331            self.0
332                .sink
333                .lock()
334                .await
335                .send(message.to_bytes()?.into())
336                .await?;
337
338            return Ok(());
339        }
340    }
341
342    pub async fn close(&self) -> Result<(), ClientError> {
343        self.0.sink.lock().await.close().await?;
344        Ok(())
345    }
346}
347
348impl Drop for PeerInner {
349    fn drop(&mut self) {
350        self.inbound_handle.abort();
351    }
352}
353
354async fn handle_inbound_messages(
355    mut stream: Stream,
356    sender: mpsc::Sender<Message>,
357    requests: Arc<RequestMap>,
358) -> Result<(), ClientError> {
359    use tungstenite::Message::{Binary, Close, Frame, Ping, Pong, Text};
360
361    while let Some(message) = stream.next().await {
362        let message = message?;
363
364        match message {
365            Frame(..) => unreachable!(),
366            Close(..) => break,
367            Ping(..) | Pong(..) => {}
368            Text(text) => {
369                warn!("Received unexpected text message: {text}");
370            }
371            Binary(binary) => {
372                let message = Message::from_bytes(&binary)?;
373
374                let Some(id) = message.id else {
375                    sender.send(message).await.ok();
376                    continue;
377                };
378
379                let Some(request) = requests.remove(id).await else {
380                    warn!(
381                        "Received {:?} message with untracked id {id}",
382                        message.msg_type
383                    );
384                    return Err(ClientError::UnexpectedMessage(message.msg_type));
385                };
386
387                request.send(message);
388            }
389        }
390    }
391    Ok(())
392}