Skip to main content

connector_client/
lib.rs

1//! Shared WebSocket client for connecting to tauri-plugin-connector.
2//!
3//! Both the MCP server and CLI use this crate to communicate with the
4//! running Tauri app's connector plugin over WebSocket.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use futures_util::{SinkExt, StreamExt};
11use serde_json::Value;
12use tokio::net::TcpStream;
13use tokio::sync::{mpsc, oneshot, Mutex};
14use tokio_tungstenite::tungstenite::Message;
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
16
/// Default time, in milliseconds, that [`ConnectorClient::send`] waits for a
/// response before giving up.
const DEFAULT_TIMEOUT_MS: u64 = 35_000;
18
19type _WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
20
21struct PendingRequest {
22    tx: oneshot::Sender<Result<Value, String>>,
23}
24
25/// WebSocket client that communicates with tauri-plugin-connector.
26pub struct ConnectorClient {
27    write_tx: Option<mpsc::UnboundedSender<String>>,
28    pending: Arc<Mutex<HashMap<String, PendingRequest>>>,
29    _reader_handle: Option<tokio::task::JoinHandle<()>>,
30}
31
32impl ConnectorClient {
33    pub fn new() -> Self {
34        Self {
35            write_tx: None,
36            pending: Arc::new(Mutex::new(HashMap::new())),
37            _reader_handle: None,
38        }
39    }
40
41    /// Connect to the plugin's WebSocket server.
42    pub async fn connect(&mut self, host: &str, port: u16) -> Result<(), String> {
43        self.disconnect().await;
44
45        let url = format!("ws://{host}:{port}");
46        let (ws, _) = tokio_tungstenite::connect_async(&url)
47            .await
48            .map_err(|e| format!("WebSocket connection failed: {e}"))?;
49
50        let (ws_write, ws_read) = ws.split();
51
52        // Writer task: forwards messages from channel to WebSocket
53        let (write_tx, mut write_rx) = mpsc::unbounded_channel::<String>();
54        let writer_handle = tokio::spawn(async move {
55            let mut ws_write = ws_write;
56            while let Some(msg) = write_rx.recv().await {
57                if ws_write.send(Message::Text(msg.into())).await.is_err() {
58                    break;
59                }
60            }
61        });
62
63        // Reader task: receives messages from WebSocket and resolves pending requests
64        let pending = self.pending.clone();
65        let reader_handle = tokio::spawn(async move {
66            let mut ws_read = ws_read;
67            while let Some(Ok(msg)) = ws_read.next().await {
68                if let Message::Text(text) = msg {
69                    let text: &str = text.as_ref();
70                    if let Ok(response) = serde_json::from_str::<Value>(text) {
71                        let id = response
72                            .get("id")
73                            .and_then(|v| v.as_str())
74                            .unwrap_or("")
75                            .to_string();
76
77                        let mut pending = pending.lock().await;
78                        if let Some(req) = pending.remove(&id) {
79                            let result = if let Some(error) = response.get("error") {
80                                Err(error
81                                    .as_str()
82                                    .unwrap_or("Unknown error")
83                                    .to_string())
84                            } else {
85                                Ok(response
86                                    .get("result")
87                                    .cloned()
88                                    .unwrap_or(Value::Null))
89                            };
90                            let _ = req.tx.send(result);
91                        }
92                    }
93                }
94            }
95            // Connection closed — reject all pending
96            let mut pending = pending.lock().await;
97            for (_, req) in pending.drain() {
98                let _ = req.tx.send(Err("Connection closed".to_string()));
99            }
100            drop(writer_handle);
101        });
102
103        self.write_tx = Some(write_tx);
104        self._reader_handle = Some(reader_handle);
105
106        Ok(())
107    }
108
109    /// Disconnect from the WebSocket server.
110    pub async fn disconnect(&mut self) {
111        self.write_tx = None;
112        if let Some(handle) = self._reader_handle.take() {
113            handle.abort();
114        }
115        let mut pending = self.pending.lock().await;
116        for (_, req) in pending.drain() {
117            let _ = req.tx.send(Err("Disconnected".to_string()));
118        }
119    }
120
121    /// Check if connected.
122    pub fn is_connected(&self) -> bool {
123        self.write_tx.is_some()
124    }
125
126    /// Send a command and wait for a response.
127    pub async fn send(&self, command: Value) -> Result<Value, String> {
128        self.send_with_timeout(command, DEFAULT_TIMEOUT_MS).await
129    }
130
131    /// Send a command with a custom timeout.
132    pub async fn send_with_timeout(
133        &self,
134        command: Value,
135        timeout_ms: u64,
136    ) -> Result<Value, String> {
137        let write_tx = self
138            .write_tx
139            .as_ref()
140            .ok_or_else(|| "Not connected".to_string())?;
141
142        let id = uuid::Uuid::new_v4().to_string();
143        let (tx, rx) = oneshot::channel();
144
145        {
146            let mut pending = self.pending.lock().await;
147            pending.insert(id.clone(), PendingRequest { tx });
148        }
149
150        // Build the message with the id
151        let mut msg = match command {
152            Value::Object(map) => map,
153            _ => return Err("Command must be a JSON object".to_string()),
154        };
155        msg.insert("id".to_string(), Value::String(id.clone()));
156
157        let json = serde_json::to_string(&msg).map_err(|e| e.to_string())?;
158        write_tx
159            .send(json)
160            .map_err(|_| "Send failed: connection closed".to_string())?;
161
162        // Wait for response with timeout
163        match tokio::time::timeout(Duration::from_millis(timeout_ms), rx).await {
164            Ok(Ok(result)) => result,
165            Ok(Err(_)) => {
166                self.pending.lock().await.remove(&id);
167                Err("Response channel closed".to_string())
168            }
169            Err(_) => {
170                self.pending.lock().await.remove(&id);
171                Err("Request timeout".to_string())
172            }
173        }
174    }
175}
176
177impl Default for ConnectorClient {
178    fn default() -> Self {
179        Self::new()
180    }
181}