tauri_plugin_websocket/
lib.rs

1// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
5//! Open a WebSocket connection using a Rust client in JS.
6
7#![doc(
8    html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
9    html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
10)]
11
12use futures_util::{stream::SplitSink, SinkExt, StreamExt};
13use http::header::{HeaderName, HeaderValue};
14use serde::{ser::Serializer, Deserialize, Serialize};
15use tauri::{
16    ipc::Channel,
17    plugin::{Builder as PluginBuilder, TauriPlugin},
18    Manager, Runtime, State, Window,
19};
20use tokio::{net::TcpStream, sync::Mutex};
21#[cfg(any(
22    feature = "rustls-tls",
23    feature = "rustls-tls-native-roots",
24    feature = "native-tls"
25))]
26use tokio_tungstenite::connect_async_tls_with_config;
27#[cfg(not(any(
28    feature = "rustls-tls",
29    feature = "rustls-tls-native-roots",
30    feature = "native-tls"
31)))]
32use tokio_tungstenite::connect_async_with_config;
33use tokio_tungstenite::{
34    tungstenite::{
35        client::IntoClientRequest,
36        protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig},
37        Message,
38    },
39    Connector, MaybeTlsStream, WebSocketStream,
40};
41
42use std::collections::HashMap;
43use std::str::FromStr;
44
45type Id = u32;
46type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
47type WebSocketWriter = SplitSink<WebSocket, Message>;
48type Result<T> = std::result::Result<T, Error>;
49
50#[derive(Debug, thiserror::Error)]
51enum Error {
52    #[error(transparent)]
53    Websocket(#[from] tokio_tungstenite::tungstenite::Error),
54    #[error("connection not found for the given id: {0}")]
55    ConnectionNotFound(Id),
56    #[error(transparent)]
57    InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue),
58    #[error(transparent)]
59    InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName),
60}
61
62impl Serialize for Error {
63    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
64    where
65        S: Serializer,
66    {
67        serializer.serialize_str(self.to_string().as_str())
68    }
69}
70
/// Open connections keyed by their [`Id`].
///
/// Only the write half of each socket is stored here; the read half is consumed by the
/// task spawned in `connect`. Guarded by a `tokio::sync::Mutex`.
#[derive(Default)]
struct ConnectionManager(Mutex<HashMap<Id, WebSocketWriter>>);

/// TLS connector supplied through `Builder::tls_connector`, read by `connect` for every
/// handshake. `None` means no custom connector was configured.
#[cfg(any(
    feature = "rustls-tls",
    feature = "rustls-tls-native-roots",
    feature = "native-tls"
))]
struct TlsConnector(Mutex<Option<Connector>>);
80
81#[derive(Deserialize)]
82#[serde(untagged, rename_all = "camelCase")]
83enum Max {
84    None,
85    Number(usize),
86}
87
/// Connection options sent by the frontend to `connect` (JSON keys are camelCase).
///
/// Every size option is optional; when it is absent the corresponding tungstenite
/// `WebSocketConfig` default is kept (see `impl From<ConnectionConfig> for WebSocketConfig`).
/// NOTE(review): the limit fields are `Option<Max>`, so an explicit JSON `null` is consumed by
/// the `Option` and also keeps the default rather than producing `Max::None` — confirm intended.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct ConnectionConfig {
    /// Forwarded to `WebSocketConfig::read_buffer_size` when set.
    pub read_buffer_size: Option<usize>,
    /// Forwarded to `WebSocketConfig::write_buffer_size` when set.
    pub write_buffer_size: Option<usize>,
    /// Forwarded to `WebSocketConfig::max_write_buffer_size` when set.
    pub max_write_buffer_size: Option<usize>,
    /// Forwarded to `WebSocketConfig::max_message_size` when set.
    pub max_message_size: Option<Max>,
    /// Forwarded to `WebSocketConfig::max_frame_size` when set.
    pub max_frame_size: Option<Max>,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub accept_unmasked_frames: bool,
    /// Extra handshake headers as `(name, value)` pairs. `connect` parses each pair and
    /// `insert`s it, so a later pair with the same name replaces an earlier one.
    pub headers: Option<Vec<(String, String)>>,
}
100
101impl From<ConnectionConfig> for WebSocketConfig {
102    fn from(config: ConnectionConfig) -> Self {
103        let mut builder =
104            WebSocketConfig::default().accept_unmasked_frames(config.accept_unmasked_frames);
105
106        if let Some(read_buffer_size) = config.read_buffer_size {
107            builder = builder.read_buffer_size(read_buffer_size)
108        }
109
110        if let Some(write_buffer_size) = config.write_buffer_size {
111            builder = builder.write_buffer_size(write_buffer_size)
112        }
113
114        if let Some(max_write_buffer_size) = config.max_write_buffer_size {
115            builder = builder.max_write_buffer_size(max_write_buffer_size)
116        }
117
118        if let Some(max_message_size) = config.max_message_size {
119            let max_size = match max_message_size {
120                Max::None => Option::None,
121                Max::Number(n) => Some(n),
122            };
123            builder = builder.max_message_size(max_size);
124        }
125
126        if let Some(max_frame_size) = config.max_frame_size {
127            let max_size = match max_frame_size {
128                Max::None => Option::None,
129                Max::Number(n) => Some(n),
130            };
131            builder = builder.max_frame_size(max_size);
132        }
133
134        builder
135    }
136}
137
/// Close frame as exchanged with the frontend.
///
/// Converted from tungstenite's close frame in `connect` and into it in `send`.
#[derive(Deserialize, Serialize)]
struct CloseFrame {
    /// Close status code (converted from/into tungstenite's `CloseCode` via `into()`).
    pub code: u16,
    /// Reason text carried by the close frame.
    pub reason: String,
}
143
/// A WebSocket message as exchanged with the frontend.
///
/// Uses serde's adjacent tagging (`tag = "type", content = "data"`), e.g.
/// `{"type": "Text", "data": "hello"}`; the tag is the variant name as written. Incoming
/// messages are forwarded in this form by `connect`, and `send` accepts the same form.
#[derive(Deserialize, Serialize)]
#[serde(tag = "type", content = "data")]
enum WebSocketMessage {
    /// Text payload.
    Text(String),
    /// Binary payload (a byte sequence; a JSON array of numbers under serde_json).
    Binary(Vec<u8>),
    /// Ping payload.
    Ping(Vec<u8>),
    /// Pong payload.
    Pong(Vec<u8>),
    /// Close message; `None` when no code/reason is attached.
    Close(Option<CloseFrame>),
}
153
154#[tauri::command]
155async fn connect<R: Runtime>(
156    window: Window<R>,
157    url: String,
158    on_message: Channel<serde_json::Value>,
159    config: Option<ConnectionConfig>,
160) -> Result<Id> {
161    let id = rand::random();
162    let mut request = url.into_client_request()?;
163
164    if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) {
165        for (k, v) in headers {
166            let header_name = HeaderName::from_str(k.as_str())?;
167            let header_value = HeaderValue::from_str(v.as_str())?;
168            request.headers_mut().insert(header_name, header_value);
169        }
170    }
171
172    #[cfg(any(
173        feature = "rustls-tls",
174        feature = "rustls-tls-native-roots",
175        feature = "native-tls"
176    ))]
177    let tls_connector = match window.try_state::<TlsConnector>() {
178        Some(tls_connector) => tls_connector.0.lock().await.clone(),
179        None => None,
180    };
181
182    #[cfg(any(
183        feature = "rustls-tls",
184        feature = "rustls-tls-native-roots",
185        feature = "native-tls"
186    ))]
187    let (ws_stream, _) =
188        connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
189            .await?;
190    #[cfg(not(any(
191        feature = "rustls-tls",
192        feature = "rustls-tls-native-roots",
193        feature = "native-tls"
194    )))]
195    let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?;
196
197    tauri::async_runtime::spawn(async move {
198        let (write, read) = ws_stream.split();
199        let manager = window.state::<ConnectionManager>();
200        manager.0.lock().await.insert(id, write);
201        read.for_each(move |message| {
202            let window_ = window.clone();
203            let on_message_ = on_message.clone();
204            async move {
205                if let Ok(Message::Close(_)) = message {
206                    let manager = window_.state::<ConnectionManager>();
207                    manager.0.lock().await.remove(&id);
208                }
209
210                let response = match message {
211                    Ok(Message::Text(t)) => {
212                        serde_json::to_value(WebSocketMessage::Text(t.to_string())).unwrap()
213                    }
214                    Ok(Message::Binary(t)) => {
215                        serde_json::to_value(WebSocketMessage::Binary(t.to_vec())).unwrap()
216                    }
217                    Ok(Message::Ping(t)) => {
218                        serde_json::to_value(WebSocketMessage::Ping(t.to_vec())).unwrap()
219                    }
220                    Ok(Message::Pong(t)) => {
221                        serde_json::to_value(WebSocketMessage::Pong(t.to_vec())).unwrap()
222                    }
223                    Ok(Message::Close(t)) => {
224                        serde_json::to_value(WebSocketMessage::Close(t.map(|v| CloseFrame {
225                            code: v.code.into(),
226                            reason: v.reason.to_string(),
227                        })))
228                        .unwrap()
229                    }
230                    Ok(Message::Frame(_)) => serde_json::Value::Null, // This value can't be recieved.
231                    Err(e) => serde_json::to_value(Error::from(e)).unwrap(),
232                };
233
234                let _ = on_message_.send(response);
235            }
236        })
237        .await;
238    });
239
240    Ok(id)
241}
242
243#[tauri::command]
244async fn send(
245    manager: State<'_, ConnectionManager>,
246    id: Id,
247    message: WebSocketMessage,
248) -> Result<()> {
249    if let Some(write) = manager.0.lock().await.get_mut(&id) {
250        write
251            .send(match message {
252                WebSocketMessage::Text(t) => Message::Text(t.into()),
253                WebSocketMessage::Binary(t) => Message::Binary(t.into()),
254                WebSocketMessage::Ping(t) => Message::Ping(t.into()),
255                WebSocketMessage::Pong(t) => Message::Pong(t.into()),
256                WebSocketMessage::Close(t) => Message::Close(t.map(|v| ProtocolCloseFrame {
257                    code: v.code.into(),
258                    reason: v.reason.into(),
259                })),
260            })
261            .await?;
262        Ok(())
263    } else {
264        Err(Error::ConnectionNotFound(id))
265    }
266}
267
268pub fn init<R: Runtime>() -> TauriPlugin<R> {
269    Builder::default().build()
270}
271
272#[derive(Default)]
273pub struct Builder {
274    tls_connector: Option<Connector>,
275}
276
277impl Builder {
278    pub fn new() -> Self {
279        Self {
280            tls_connector: None,
281        }
282    }
283
284    pub fn tls_connector(mut self, connector: Connector) -> Self {
285        self.tls_connector.replace(connector);
286        self
287    }
288
289    pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
290        PluginBuilder::new("websocket")
291            .invoke_handler(tauri::generate_handler![connect, send])
292            .setup(|app, _api| {
293                #[cfg(any(feature = "rustls-tls", feature = "rustls-tls-native-roots"))]
294                if (self.tls_connector.is_none()
295                    || matches!(self.tls_connector, Some(Connector::Plain)))
296                    && rustls::crypto::CryptoProvider::get_default().is_none()
297                {
298                    // This can only fail if there is already a default provider which we checked for already.
299                    let _ = rustls::crypto::ring::default_provider().install_default();
300                }
301
302                app.manage(ConnectionManager::default());
303                #[cfg(any(
304                    feature = "rustls-tls",
305                    feature = "rustls-tls-native-roots",
306                    feature = "native-tls"
307                ))]
308                app.manage(TlsConnector(Mutex::new(self.tls_connector)));
309                Ok(())
310            })
311            .build()
312    }
313}