xrpl_mithril_client/
websocket.rs1use std::collections::HashMap;
28use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
29use std::sync::Arc;
30
31use futures::{SinkExt, StreamExt};
32use tokio::sync::{mpsc, oneshot};
33use tokio_tungstenite::tungstenite::Message;
34use xrpl_mithril_models::requests::XrplRequest;
35
36use crate::client::Client;
37use crate::error::ClientError;
38use crate::subscription::SubscriptionStream;
39
40enum WsCommand {
42 Request {
44 payload: serde_json::Value,
45 response_tx: oneshot::Sender<Result<serde_json::Value, ClientError>>,
46 },
47 Subscribe {
49 stream_tx: mpsc::UnboundedSender<serde_json::Value>,
50 },
51}
52
53pub struct WebSocketClient {
59 command_tx: mpsc::UnboundedSender<WsCommand>,
60 next_id: AtomicU64,
61 connected: Arc<AtomicBool>,
62 _task: tokio::task::JoinHandle<()>,
63}
64
65impl std::fmt::Debug for WebSocketClient {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("WebSocketClient")
68 .field("connected", &self.connected.load(Ordering::Relaxed))
69 .finish()
70 }
71}
72
73impl WebSocketClient {
74 pub async fn connect(url: &str) -> Result<Self, ClientError> {
83 let (ws_stream, _response) = tokio_tungstenite::connect_async(url)
84 .await
85 .map_err(|e| ClientError::WebSocket(e.to_string()))?;
86
87 let (command_tx, command_rx) = mpsc::unbounded_channel();
88 let connected = Arc::new(AtomicBool::new(true));
89 let connected_clone = Arc::clone(&connected);
90
91 let task = tokio::spawn(Self::run_loop(ws_stream, command_rx, connected_clone));
92
93 Ok(Self {
94 command_tx,
95 next_id: AtomicU64::new(1),
96 connected,
97 _task: task,
98 })
99 }
100
101 pub fn subscribe_stream(&self) -> Result<SubscriptionStream, ClientError> {
113 let (stream_tx, stream_rx) = mpsc::unbounded_channel();
114 self.command_tx
115 .send(WsCommand::Subscribe { stream_tx })
116 .map_err(|_| ClientError::ConnectionClosed {
117 reason: "background task ended".into(),
118 })?;
119 Ok(SubscriptionStream::new(stream_rx))
120 }
121
122 #[must_use]
124 pub fn is_connected(&self) -> bool {
125 self.connected.load(Ordering::Relaxed)
126 }
127
128 async fn run_loop(
130 ws_stream: tokio_tungstenite::WebSocketStream<
131 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
132 >,
133 mut command_rx: mpsc::UnboundedReceiver<WsCommand>,
134 connected: Arc<AtomicBool>,
135 ) {
136 let (mut ws_sink, mut ws_source) = ws_stream.split();
137
138 let mut pending: HashMap<u64, oneshot::Sender<Result<serde_json::Value, ClientError>>> =
140 HashMap::new();
141
142 let mut subscribers: Vec<mpsc::UnboundedSender<serde_json::Value>> = Vec::new();
144
145 loop {
146 tokio::select! {
147 Some(cmd) = command_rx.recv() => {
149 match cmd {
150 WsCommand::Request { payload, response_tx } => {
151 let id = payload.get("id")
152 .and_then(|v| v.as_u64())
153 .unwrap_or(0);
154 pending.insert(id, response_tx);
155
156 let msg = Message::Text(payload.to_string().into());
157 if let Err(e) = ws_sink.send(msg).await {
158 if let Some(tx) = pending.remove(&id) {
159 let _ = tx.send(Err(ClientError::WebSocket(e.to_string())));
160 }
161 }
162 }
163 WsCommand::Subscribe { stream_tx } => {
164 subscribers.push(stream_tx);
165 }
166 }
167 }
168
169 Some(msg_result) = ws_source.next() => {
171 match msg_result {
172 Ok(Message::Text(text)) => {
173 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&text) {
174 if let Some(id) = value.get("id").and_then(|v| v.as_u64()) {
176 if let Some(tx) = pending.remove(&id) {
177 let result = extract_result(&value);
179 let _ = tx.send(result);
180 }
181 } else {
182 subscribers.retain(|tx| {
185 tx.send(value.clone()).is_ok()
186 });
187 }
188 }
189 }
190 Ok(Message::Close(_)) => {
191 connected.store(false, Ordering::Relaxed);
192 break;
193 }
194 Ok(Message::Ping(data)) => {
195 let _ = ws_sink.send(Message::Pong(data)).await;
196 }
197 Err(e) => {
198 tracing::error!(error = %e, "WebSocket error");
199 connected.store(false, Ordering::Relaxed);
200 break;
201 }
202 _ => {}
203 }
204 }
205
206 else => break,
207 }
208 }
209
210 for (_id, tx) in pending {
212 let _ = tx.send(Err(ClientError::ConnectionClosed {
213 reason: "WebSocket connection closed".into(),
214 }));
215 }
216 connected.store(false, Ordering::Relaxed);
217 }
218}
219
220fn extract_result(value: &serde_json::Value) -> Result<serde_json::Value, ClientError> {
222 if let Some(result) = value.get("result") {
224 if let Some(status) = result.get("status").and_then(|v| v.as_str()) {
225 if status == "error" {
226 let error = result.get("error").and_then(|v| v.as_str()).map(String::from);
227 let code = result
228 .get("error_code")
229 .and_then(|v| v.as_i64())
230 .map(|c| c as i32);
231 let message = result
232 .get("error_message")
233 .and_then(|v| v.as_str())
234 .unwrap_or_else(|| error.as_deref().unwrap_or("unknown error"))
235 .to_string();
236 return Err(ClientError::RpcError {
237 code,
238 message,
239 error,
240 });
241 }
242 }
243 return Ok(result.clone());
244 }
245
246 Ok(value.clone())
248}
249
250impl Client for WebSocketClient {
251 async fn request<R: XrplRequest + Send + Sync>(
252 &self,
253 request: R,
254 ) -> Result<R::Response, ClientError> {
255 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
256
257 let mut params = serde_json::to_value(&request)?;
259 if let Some(map) = params.as_object_mut() {
260 map.insert("id".into(), serde_json::Value::Number(id.into()));
261 map.insert(
262 "command".into(),
263 serde_json::Value::String(request.method().into()),
264 );
265 }
266
267 let (response_tx, response_rx) = oneshot::channel();
268
269 self.command_tx
270 .send(WsCommand::Request {
271 payload: params,
272 response_tx,
273 })
274 .map_err(|_| ClientError::ConnectionClosed {
275 reason: "background task ended".into(),
276 })?;
277
278 let result = response_rx.await.map_err(|_| ClientError::ConnectionClosed {
279 reason: "response channel dropped".into(),
280 })??;
281
282 let response: R::Response = serde_json::from_value(result)?;
283 Ok(response)
284 }
285}