reown_relay_client/websocket/
stream.rs1#[cfg(not(target_arch = "wasm32"))]
2use tokio_tungstenite::{
3 connect_async,
4 tungstenite::{protocol::CloseFrame, Message},
5 MaybeTlsStream,
6 WebSocketStream,
7};
8#[cfg(target_arch = "wasm32")]
9use tokio_tungstenite_wasm::{connect as connect_async, CloseFrame, Message, WebSocketStream};
10use {
11 super::{
12 inbound::InboundRequest,
13 outbound::{create_request, OutboundRequest, ResponseFuture},
14 CloseReason,
15 TransportError,
16 WebsocketClientError,
17 },
18 crate::{error::ClientError, HttpRequest, MessageIdGenerator},
19 futures_util::{stream::FusedStream, SinkExt, Stream, StreamExt},
20 reown_relay_rpc::{
21 domain::MessageId,
22 rpc::{self, Params, Payload, Response, ServiceRequest, Subscription},
23 },
24 std::{
25 collections::{hash_map::Entry, HashMap},
26 pin::Pin,
27 task::{Context, Poll},
28 },
29 tokio::sync::{
30 mpsc,
31 mpsc::{UnboundedReceiver, UnboundedSender},
32 oneshot,
33 },
34};
35#[cfg(not(target_arch = "wasm32"))]
36pub type SocketStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
37#[cfg(not(target_arch = "wasm32"))]
38use tokio::net::TcpStream;
39#[cfg(target_arch = "wasm32")]
40pub type SocketStream = WebSocketStream;
41
42#[cfg(not(target_arch = "wasm32"))]
45pub async fn create_stream(request: HttpRequest<()>) -> Result<ClientStream, WebsocketClientError> {
46 let (socket, _) = connect_async(request)
47 .await
48 .map_err(WebsocketClientError::ConnectionFailed)?;
49
50 Ok(ClientStream::new(socket))
51}
52
/// Opens a websocket connection to the URI of `request` and wraps it in a
/// [`ClientStream`] (wasm targets).
///
/// Only the request URI is used: the connector here is given a URL string, so
/// any headers set on `request` are not sent. Callers presumably need to
/// carry anything the relay requires (e.g. auth) in the URI itself — confirm
/// against how `request` is built.
///
/// # Errors
///
/// Returns [`WebsocketClientError::ConnectionFailed`] if the connection
/// cannot be established.
#[cfg(target_arch = "wasm32")]
pub async fn create_stream(request: HttpRequest<()>) -> Result<ClientStream, WebsocketClientError> {
    let url = request.uri().to_string();
    let socket = connect_async(url)
        .await
        .map_err(WebsocketClientError::ConnectionFailed)?;

    Ok(ClientStream::new(socket))
}
62
/// Events produced by polling a [`ClientStream`].
#[derive(Debug)]
pub enum StreamEvent {
    /// The relay sent a request whose params are a subscription message.
    ///
    /// The contained [`InboundRequest`] holds a clone of the stream's outbound
    /// queue sender (presumably used to send the reply — confirm in
    /// `inbound.rs`).
    InboundSubscriptionRequest(InboundRequest<Subscription>),

    /// An inbound message could not be handled. Possible causes:
    /// - a transport read error;
    /// - a payload that failed to deserialize;
    /// - a request with params other than a subscription;
    /// - a response with a zero id;
    /// - a response whose id does not match any pending request.
    InboundError(ClientError),

    /// Writing queued outbound messages to the socket failed.
    OutboundError(ClientError),

    /// The connection was closed, either by a received `Close` message or by
    /// the socket stream ending. Carries the most recently recorded close
    /// frame, if any.
    ConnectionClosed(Option<CloseFrame<'static>>),
}
87
/// A JSON-RPC client over a single websocket connection.
///
/// Outbound requests are queued with [`ClientStream::send`] or
/// [`ClientStream::send_raw`] and are only written to the socket while the
/// stream is polled. Inbound messages are surfaced as [`StreamEvent`]s, and
/// responses are routed to the matching pending request.
pub struct ClientStream {
    /// The underlying websocket connection.
    socket: SocketStream,
    /// Sender half of the outbound message queue; a clone is handed to each
    /// [`InboundRequest`] the stream produces.
    outbound_tx: UnboundedSender<Message>,
    /// Receiver half of the outbound message queue, drained while polling.
    outbound_rx: UnboundedReceiver<Message>,
    /// Response channels of requests still awaiting a response, keyed by
    /// request id. Any left when the stream is dropped are failed with
    /// `ConnectionClosed`.
    requests: HashMap<MessageId, oneshot::Sender<Result<serde_json::Value, ClientError>>>,
    /// Generates ids for outbound requests.
    id_generator: MessageIdGenerator,
    /// Most recent close frame: the one passed to `close`, or the one received
    /// from the peer.
    close_frame: Option<CloseFrame<'static>>,
}
103
104impl ClientStream {
105 pub fn new(socket: SocketStream) -> Self {
106 let requests = HashMap::new();
107 let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
108 let id_generator = MessageIdGenerator::new();
109
110 Self {
111 socket,
112 outbound_tx,
113 outbound_rx,
114 requests,
115 id_generator,
116 close_frame: None,
117 }
118 }
119
120 pub fn send_raw(&mut self, request: OutboundRequest) {
123 let tx = request.tx;
124 let id = self.id_generator.next();
125 let request = Payload::Request(rpc::Request::new(id, request.params));
126 let serialized = serde_json::to_string(&request);
127
128 match serialized {
129 Ok(data) => match self.requests.entry(id) {
130 Entry::Occupied(_) => {
131 tx.send(Err(ClientError::DuplicateRequestId)).ok();
132 }
133
134 Entry::Vacant(entry) => {
135 entry.insert(tx);
136 self.outbound_tx.send(Message::Text(data)).ok();
137 }
138 },
139
140 Err(err) => {
141 tx.send(Err(ClientError::Serialization(err))).ok();
142 }
143 }
144 }
145
146 pub fn send<T>(&mut self, request: T) -> ResponseFuture<T>
149 where
150 T: ServiceRequest,
151 {
152 let (request, response) = create_request(request);
153 self.send_raw(request);
154 response
155 }
156
157 #[cfg(not(target_arch = "wasm32"))]
159 pub async fn close(&mut self, frame: Option<CloseFrame<'static>>) -> Result<(), ClientError> {
160 self.close_frame = frame.clone();
161 self.socket
162 .close(frame)
163 .await
164 .map_err(|err| WebsocketClientError::ClosingFailed(err).into())
165 }
166
167 #[cfg(target_arch = "wasm32")]
168 pub async fn close(&mut self, frame: Option<CloseFrame<'static>>) -> Result<(), ClientError> {
169 self.close_frame = frame.clone();
170 self.socket
171 .close()
172 .await
173 .map_err(|err| WebsocketClientError::ClosingFailed(err).into())
174 }
175
176 fn parse_inbound(&mut self, result: Result<Message, TransportError>) -> Option<StreamEvent> {
177 match result {
178 Ok(message) => match &message {
179 Message::Binary(_) | Message::Text(_) => {
180 let payload: Payload = match serde_json::from_slice(&message.into_data()) {
181 Ok(payload) => payload,
182
183 Err(err) => {
184 return Some(StreamEvent::InboundError(ClientError::Deserialization(
185 err,
186 )))
187 }
188 };
189
190 match payload {
191 Payload::Request(request) => {
192 let id = request.id;
193
194 let event =
195 match request.params {
196 Params::Subscription(data) => {
197 StreamEvent::InboundSubscriptionRequest(
198 InboundRequest::new(id, data, self.outbound_tx.clone()),
199 )
200 }
201
202 _ => StreamEvent::InboundError(ClientError::InvalidRequestType),
203 };
204
205 Some(event)
206 }
207
208 Payload::Response(response) => {
209 let id = response.id();
210
211 if id.is_zero() {
212 return match response {
213 Response::Error(response) => Some(StreamEvent::InboundError(
214 ClientError::from(response.error),
215 )),
216
217 Response::Success(_) => Some(StreamEvent::InboundError(
218 ClientError::InvalidResponseId,
219 )),
220 };
221 }
222
223 if let Some(tx) = self.requests.remove(&id) {
224 let result = match response {
225 Response::Error(response) => {
226 Err(ClientError::from(response.error))
227 }
228
229 Response::Success(response) => Ok(response.result),
230 };
231
232 tx.send(result).ok();
233
234 if self.requests.len() * 3 < self.requests.capacity() {
236 self.requests.shrink_to_fit();
237 }
238
239 None
240 } else {
241 Some(StreamEvent::InboundError(ClientError::InvalidResponseId))
242 }
243 }
244 }
245 }
246
247 Message::Close(frame) => {
248 self.close_frame = frame.clone();
249 Some(StreamEvent::ConnectionClosed(frame.clone()))
250 }
251 #[cfg(not(target_arch = "wasm32"))]
252 _ => None,
253 },
254
255 Err(error) => Some(StreamEvent::InboundError(
256 WebsocketClientError::Transport(error).into(),
257 )),
258 }
259 }
260
261 fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), TransportError>> {
262 let mut should_flush = false;
263
264 loop {
265 match self.socket.poll_ready_unpin(cx) {
268 Poll::Ready(Ok(())) => {
270 if let Poll::Ready(Some(next_message)) = self.outbound_rx.poll_recv(cx) {
271 if let Err(err) = self.socket.start_send_unpin(next_message) {
272 return Poll::Ready(Err(err));
273 }
274
275 should_flush = true;
276 } else if should_flush {
277 return self.socket.poll_flush_unpin(cx);
279 } else {
280 return Poll::Pending;
281 }
282 }
283
284 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
285
286 Poll::Pending => return Poll::Pending,
288 }
289 }
290 }
291}
292
293impl Stream for ClientStream {
294 type Item = StreamEvent;
295
296 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297 #[cfg(not(target_arch = "wasm32"))]
298 if self.socket.is_terminated() {
299 return Poll::Ready(None);
300 }
301
302 while let Poll::Ready(data) = self.socket.poll_next_unpin(cx) {
303 match data {
304 Some(result) => {
305 if let Some(event) = self.parse_inbound(result) {
306 return Poll::Ready(Some(event));
307 }
308 }
309
310 None => {
311 return Poll::Ready(Some(StreamEvent::ConnectionClosed(
312 self.close_frame.clone(),
313 )))
314 }
315 }
316 }
317
318 match self.poll_write(cx) {
319 Poll::Ready(Err(error)) => Poll::Ready(Some(StreamEvent::OutboundError(
320 WebsocketClientError::Transport(error).into(),
321 ))),
322
323 _ => Poll::Pending,
324 }
325 }
326}
327
328impl FusedStream for ClientStream {
329 #[cfg(not(target_arch = "wasm32"))]
330 fn is_terminated(&self) -> bool {
331 self.socket.is_terminated()
332 }
333
334 #[cfg(target_arch = "wasm32")]
335 fn is_terminated(&self) -> bool {
336 false
337 }
338}
339
340impl Drop for ClientStream {
341 fn drop(&mut self) {
342 let reason = CloseReason(self.close_frame.take());
343
344 for (_, tx) in self.requests.drain() {
345 tx.send(Err(
346 WebsocketClientError::ConnectionClosed(reason.clone()).into()
347 ))
348 .ok();
349 }
350 }
351}