reown_relay_client/websocket/stream.rs

1#[cfg(not(target_arch = "wasm32"))]
2use tokio_tungstenite::{
3    connect_async,
4    tungstenite::{protocol::CloseFrame, Message},
5    MaybeTlsStream,
6    WebSocketStream,
7};
8#[cfg(target_arch = "wasm32")]
9use tokio_tungstenite_wasm::{connect as connect_async, CloseFrame, Message, WebSocketStream};
10use {
11    super::{
12        inbound::InboundRequest,
13        outbound::{create_request, OutboundRequest, ResponseFuture},
14        CloseReason,
15        TransportError,
16        WebsocketClientError,
17    },
18    crate::{error::ClientError, HttpRequest, MessageIdGenerator},
19    futures_util::{stream::FusedStream, SinkExt, Stream, StreamExt},
20    reown_relay_rpc::{
21        domain::MessageId,
22        rpc::{self, Params, Payload, Response, ServiceRequest, Subscription},
23    },
24    std::{
25        collections::{hash_map::Entry, HashMap},
26        pin::Pin,
27        task::{Context, Poll},
28    },
29    tokio::sync::{
30        mpsc,
31        mpsc::{UnboundedReceiver, UnboundedSender},
32        oneshot,
33    },
34};
35#[cfg(not(target_arch = "wasm32"))]
36pub type SocketStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
37#[cfg(not(target_arch = "wasm32"))]
38use tokio::net::TcpStream;
39#[cfg(target_arch = "wasm32")]
40pub type SocketStream = WebSocketStream;
41
42/// Opens a connection to the Relay and returns [`ClientStream`] for the
43/// connection.
44#[cfg(not(target_arch = "wasm32"))]
45pub async fn create_stream(request: HttpRequest<()>) -> Result<ClientStream, WebsocketClientError> {
46    let (socket, _) = connect_async(request)
47        .await
48        .map_err(WebsocketClientError::ConnectionFailed)?;
49
50    Ok(ClientStream::new(socket))
51}
52
/// Opens a connection to the Relay and returns a [`ClientStream`] wrapping it.
///
/// Only the request URI is used to open the socket; the other parts of
/// `request` (such as headers) are not passed to the connection.
///
/// # Errors
///
/// Returns [`WebsocketClientError::ConnectionFailed`] if the websocket
/// connection cannot be established.
#[cfg(target_arch = "wasm32")]
pub async fn create_stream(request: HttpRequest<()>) -> Result<ClientStream, WebsocketClientError> {
    let url = request.uri().to_string();
    let socket = connect_async(url)
        .await
        .map_err(WebsocketClientError::ConnectionFailed)?;

    Ok(ClientStream::new(socket))
}
62
63/// Possible events produced by the [`ClientStream`].
64///
65/// The events are produced by polling [`ClientStream`] in a loop.
66#[derive(Debug)]
67pub enum StreamEvent {
68    /// Inbound request for receiving a subscription message.
69    ///
70    /// Currently, [`Subscription`] is the only request that the Relay sends to
71    /// the clients.
72    InboundSubscriptionRequest(InboundRequest<Subscription>),
73
74    /// Error generated when failed to parse an inbound message, invalid request
75    /// type or message ID.
76    InboundError(ClientError),
77
78    /// Error generated when failed to write data to the underlying websocket
79    /// stream.
80    OutboundError(ClientError),
81
82    /// The websocket connection was closed.
83    ///
84    /// This is the last event that can be produced by the stream.
85    ConnectionClosed(Option<CloseFrame<'static>>),
86}
87
88/// Lower-level [`FusedStream`] interface for the client connection.
89///
90/// The stream produces [`StreamEvent`] when polled, and can be used to send RPC
91/// requests (see [`ClientStream::send()`] and [`ClientStream::send_raw()`]).
92///
93/// For a higher-level interface see [`Client`](crate::client::Client). For an
94/// example usage of the stream see `client::connection` module.
95pub struct ClientStream {
96    socket: SocketStream,
97    outbound_tx: UnboundedSender<Message>,
98    outbound_rx: UnboundedReceiver<Message>,
99    requests: HashMap<MessageId, oneshot::Sender<Result<serde_json::Value, ClientError>>>,
100    id_generator: MessageIdGenerator,
101    close_frame: Option<CloseFrame<'static>>,
102}
103
104impl ClientStream {
105    pub fn new(socket: SocketStream) -> Self {
106        let requests = HashMap::new();
107        let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
108        let id_generator = MessageIdGenerator::new();
109
110        Self {
111            socket,
112            outbound_tx,
113            outbound_rx,
114            requests,
115            id_generator,
116            close_frame: None,
117        }
118    }
119
120    /// Sends an already serialized [`OutboundRequest`][OutboundRequest] (see
121    /// [`create_request()`]).
122    pub fn send_raw(&mut self, request: OutboundRequest) {
123        let tx = request.tx;
124        let id = self.id_generator.next();
125        let request = Payload::Request(rpc::Request::new(id, request.params));
126        let serialized = serde_json::to_string(&request);
127
128        match serialized {
129            Ok(data) => match self.requests.entry(id) {
130                Entry::Occupied(_) => {
131                    tx.send(Err(ClientError::DuplicateRequestId)).ok();
132                }
133
134                Entry::Vacant(entry) => {
135                    entry.insert(tx);
136                    self.outbound_tx.send(Message::Text(data)).ok();
137                }
138            },
139
140            Err(err) => {
141                tx.send(Err(ClientError::Serialization(err))).ok();
142            }
143        }
144    }
145
146    /// Serialize the request into a generic [`OutboundRequest`] and sends it,
147    /// returning a future that resolves with the response.
148    pub fn send<T>(&mut self, request: T) -> ResponseFuture<T>
149    where
150        T: ServiceRequest,
151    {
152        let (request, response) = create_request(request);
153        self.send_raw(request);
154        response
155    }
156
157    /// Closes the connection.
158    #[cfg(not(target_arch = "wasm32"))]
159    pub async fn close(&mut self, frame: Option<CloseFrame<'static>>) -> Result<(), ClientError> {
160        self.close_frame = frame.clone();
161        self.socket
162            .close(frame)
163            .await
164            .map_err(|err| WebsocketClientError::ClosingFailed(err).into())
165    }
166
167    #[cfg(target_arch = "wasm32")]
168    pub async fn close(&mut self, frame: Option<CloseFrame<'static>>) -> Result<(), ClientError> {
169        self.close_frame = frame.clone();
170        self.socket
171            .close()
172            .await
173            .map_err(|err| WebsocketClientError::ClosingFailed(err).into())
174    }
175
176    fn parse_inbound(&mut self, result: Result<Message, TransportError>) -> Option<StreamEvent> {
177        match result {
178            Ok(message) => match &message {
179                Message::Binary(_) | Message::Text(_) => {
180                    let payload: Payload = match serde_json::from_slice(&message.into_data()) {
181                        Ok(payload) => payload,
182
183                        Err(err) => {
184                            return Some(StreamEvent::InboundError(ClientError::Deserialization(
185                                err,
186                            )))
187                        }
188                    };
189
190                    match payload {
191                        Payload::Request(request) => {
192                            let id = request.id;
193
194                            let event =
195                                match request.params {
196                                    Params::Subscription(data) => {
197                                        StreamEvent::InboundSubscriptionRequest(
198                                            InboundRequest::new(id, data, self.outbound_tx.clone()),
199                                        )
200                                    }
201
202                                    _ => StreamEvent::InboundError(ClientError::InvalidRequestType),
203                                };
204
205                            Some(event)
206                        }
207
208                        Payload::Response(response) => {
209                            let id = response.id();
210
211                            if id.is_zero() {
212                                return match response {
213                                    Response::Error(response) => Some(StreamEvent::InboundError(
214                                        ClientError::from(response.error),
215                                    )),
216
217                                    Response::Success(_) => Some(StreamEvent::InboundError(
218                                        ClientError::InvalidResponseId,
219                                    )),
220                                };
221                            }
222
223                            if let Some(tx) = self.requests.remove(&id) {
224                                let result = match response {
225                                    Response::Error(response) => {
226                                        Err(ClientError::from(response.error))
227                                    }
228
229                                    Response::Success(response) => Ok(response.result),
230                                };
231
232                                tx.send(result).ok();
233
234                                // Perform compaction if required.
235                                if self.requests.len() * 3 < self.requests.capacity() {
236                                    self.requests.shrink_to_fit();
237                                }
238
239                                None
240                            } else {
241                                Some(StreamEvent::InboundError(ClientError::InvalidResponseId))
242                            }
243                        }
244                    }
245                }
246
247                Message::Close(frame) => {
248                    self.close_frame = frame.clone();
249                    Some(StreamEvent::ConnectionClosed(frame.clone()))
250                }
251                #[cfg(not(target_arch = "wasm32"))]
252                _ => None,
253            },
254
255            Err(error) => Some(StreamEvent::InboundError(
256                WebsocketClientError::Transport(error).into(),
257            )),
258        }
259    }
260
261    fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), TransportError>> {
262        let mut should_flush = false;
263
264        loop {
265            // `poll_ready() needs to be called before each `start_send()` to make sure the
266            // sink is ready to accept more data.
267            match self.socket.poll_ready_unpin(cx) {
268                // The sink is ready to accept more data.
269                Poll::Ready(Ok(())) => {
270                    if let Poll::Ready(Some(next_message)) = self.outbound_rx.poll_recv(cx) {
271                        if let Err(err) = self.socket.start_send_unpin(next_message) {
272                            return Poll::Ready(Err(err));
273                        }
274
275                        should_flush = true;
276                    } else if should_flush {
277                        // We've sent out some messages, now we need to flush.
278                        return self.socket.poll_flush_unpin(cx);
279                    } else {
280                        return Poll::Pending;
281                    }
282                }
283
284                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
285
286                // The sink is not ready.
287                Poll::Pending => return Poll::Pending,
288            }
289        }
290    }
291}
292
293impl Stream for ClientStream {
294    type Item = StreamEvent;
295
296    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297        #[cfg(not(target_arch = "wasm32"))]
298        if self.socket.is_terminated() {
299            return Poll::Ready(None);
300        }
301
302        while let Poll::Ready(data) = self.socket.poll_next_unpin(cx) {
303            match data {
304                Some(result) => {
305                    if let Some(event) = self.parse_inbound(result) {
306                        return Poll::Ready(Some(event));
307                    }
308                }
309
310                None => {
311                    return Poll::Ready(Some(StreamEvent::ConnectionClosed(
312                        self.close_frame.clone(),
313                    )))
314                }
315            }
316        }
317
318        match self.poll_write(cx) {
319            Poll::Ready(Err(error)) => Poll::Ready(Some(StreamEvent::OutboundError(
320                WebsocketClientError::Transport(error).into(),
321            ))),
322
323            _ => Poll::Pending,
324        }
325    }
326}
327
328impl FusedStream for ClientStream {
329    #[cfg(not(target_arch = "wasm32"))]
330    fn is_terminated(&self) -> bool {
331        self.socket.is_terminated()
332    }
333
334    #[cfg(target_arch = "wasm32")]
335    fn is_terminated(&self) -> bool {
336        false
337    }
338}
339
340impl Drop for ClientStream {
341    fn drop(&mut self) {
342        let reason = CloseReason(self.close_frame.take());
343
344        for (_, tx) in self.requests.drain() {
345            tx.send(Err(
346                WebsocketClientError::ConnectionClosed(reason.clone()).into()
347            ))
348            .ok();
349        }
350    }
351}