1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use futures_util::{SinkExt, StreamExt};
11use serde_json::Value;
12use tokio::net::TcpStream;
13use tokio::sync::{mpsc, oneshot, Mutex};
14use tokio_tungstenite::tungstenite::Message;
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
16
/// Response wait, in milliseconds, used by `ConnectorClient::send` (35 seconds).
const DEFAULT_TIMEOUT_MS: u64 = 35 * 1_000;
18
19type _WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
20
21struct PendingRequest {
22 tx: oneshot::Sender<Result<Value, String>>,
23}
24
25pub struct ConnectorClient {
27 write_tx: Option<mpsc::UnboundedSender<String>>,
28 pending: Arc<Mutex<HashMap<String, PendingRequest>>>,
29 _reader_handle: Option<tokio::task::JoinHandle<()>>,
30}
31
32impl ConnectorClient {
33 pub fn new() -> Self {
34 Self {
35 write_tx: None,
36 pending: Arc::new(Mutex::new(HashMap::new())),
37 _reader_handle: None,
38 }
39 }
40
41 pub async fn connect(&mut self, host: &str, port: u16) -> Result<(), String> {
43 self.disconnect().await;
44
45 let url = format!("ws://{host}:{port}");
46 let (ws, _) = tokio_tungstenite::connect_async(&url)
47 .await
48 .map_err(|e| format!("WebSocket connection failed: {e}"))?;
49
50 let (ws_write, ws_read) = ws.split();
51
52 let (write_tx, mut write_rx) = mpsc::unbounded_channel::<String>();
54 let writer_handle = tokio::spawn(async move {
55 let mut ws_write = ws_write;
56 while let Some(msg) = write_rx.recv().await {
57 if ws_write.send(Message::Text(msg.into())).await.is_err() {
58 break;
59 }
60 }
61 });
62
63 let pending = self.pending.clone();
65 let reader_handle = tokio::spawn(async move {
66 let mut ws_read = ws_read;
67 while let Some(Ok(msg)) = ws_read.next().await {
68 if let Message::Text(text) = msg {
69 let text: &str = text.as_ref();
70 if let Ok(response) = serde_json::from_str::<Value>(text) {
71 let id = response
72 .get("id")
73 .and_then(|v| v.as_str())
74 .unwrap_or("")
75 .to_string();
76
77 let mut pending = pending.lock().await;
78 if let Some(req) = pending.remove(&id) {
79 let result = if let Some(error) = response.get("error") {
80 Err(error
81 .as_str()
82 .unwrap_or("Unknown error")
83 .to_string())
84 } else {
85 Ok(response
86 .get("result")
87 .cloned()
88 .unwrap_or(Value::Null))
89 };
90 let _ = req.tx.send(result);
91 }
92 }
93 }
94 }
95 let mut pending = pending.lock().await;
97 for (_, req) in pending.drain() {
98 let _ = req.tx.send(Err("Connection closed".to_string()));
99 }
100 drop(writer_handle);
101 });
102
103 self.write_tx = Some(write_tx);
104 self._reader_handle = Some(reader_handle);
105
106 Ok(())
107 }
108
109 pub async fn disconnect(&mut self) {
111 self.write_tx = None;
112 if let Some(handle) = self._reader_handle.take() {
113 handle.abort();
114 }
115 let mut pending = self.pending.lock().await;
116 for (_, req) in pending.drain() {
117 let _ = req.tx.send(Err("Disconnected".to_string()));
118 }
119 }
120
121 pub fn is_connected(&self) -> bool {
123 self.write_tx.is_some()
124 }
125
126 pub async fn send(&self, command: Value) -> Result<Value, String> {
128 self.send_with_timeout(command, DEFAULT_TIMEOUT_MS).await
129 }
130
131 pub async fn send_with_timeout(
133 &self,
134 command: Value,
135 timeout_ms: u64,
136 ) -> Result<Value, String> {
137 let write_tx = self
138 .write_tx
139 .as_ref()
140 .ok_or_else(|| "Not connected".to_string())?;
141
142 let id = uuid::Uuid::new_v4().to_string();
143 let (tx, rx) = oneshot::channel();
144
145 {
146 let mut pending = self.pending.lock().await;
147 pending.insert(id.clone(), PendingRequest { tx });
148 }
149
150 let mut msg = match command {
152 Value::Object(map) => map,
153 _ => return Err("Command must be a JSON object".to_string()),
154 };
155 msg.insert("id".to_string(), Value::String(id.clone()));
156
157 let json = serde_json::to_string(&msg).map_err(|e| e.to_string())?;
158 write_tx
159 .send(json)
160 .map_err(|_| "Send failed: connection closed".to_string())?;
161
162 match tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await {
164 Ok(Ok(result)) => result,
165 Ok(Err(_)) => {
166 self.pending.lock().await.remove(&id);
167 Err("Response channel closed".to_string())
168 }
169 Err(_) => {
170 self.pending.lock().await.remove(&id);
171 Err("Request timeout".to_string())
172 }
173 }
174 }
175}
176
177impl Default for ConnectorClient {
178 fn default() -> Self {
179 Self::new()
180 }
181}