tauri/ipc/
channel.rs

1// Copyright 2019-2024 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
5use std::{
6  collections::HashMap,
7  str::FromStr,
8  sync::{
9    atomic::{AtomicU32, AtomicUsize, Ordering},
10    Arc, Mutex,
11  },
12};
13
14use serde::{Deserialize, Deserializer, Serialize, Serializer};
15
16use crate::{
17  command,
18  ipc::{CommandArg, CommandItem},
19  plugin::{Builder as PluginBuilder, TauriPlugin},
20  Manager, Runtime, State, Webview,
21};
22
23use super::{
24  format_callback::format_raw_js, CallbackFn, InvokeError, InvokeResponseBody, IpcResponse,
25  Request, Response,
26};
27
/// Prefix of the string representation of a channel (`__CHANNEL__:<id>`).
///
/// [`Channel`] serializes to this format and [`JavaScriptChannelId`] parses it.
pub const IPC_PAYLOAD_PREFIX: &str = "__CHANNEL__:";
/// Name of the internal plugin that serves channel payloads too large to be sent inline.
// TODO: Change this to `channel` in v3
pub const CHANNEL_PLUGIN_NAME: &str = "__TAURI_CHANNEL__";
/// Fully qualified IPC command name of the channel plugin's `fetch` command.
// TODO: Change this to `plugin:channel|fetch` in v3
pub const FETCH_CHANNEL_DATA_COMMAND: &str = "plugin:__TAURI_CHANNEL__|fetch";
/// Request header used by the `fetch` command to identify which queued payload to return.
///
/// Despite the name, its value is a payload ID (from `CHANNEL_DATA_COUNTER`), not a channel ID.
const CHANNEL_ID_HEADER_NAME: &str = "Tauri-Channel-Id";

/// JSON payloads strictly shorter than this many bytes are evaluated directly in the webview
/// instead of going through the fetch process.
// 8192 byte JSON payload runs roughly 2x faster through eval than through fetch on WebView2 v135
const MAX_JSON_DIRECT_EXECUTE_THRESHOLD: usize = 8192;
/// Raw (binary) payloads strictly shorter than this many bytes are evaluated directly in the
/// webview instead of going through the fetch process.
// 1024 byte payload runs roughly 30% faster through eval than through fetch on macOS
const MAX_RAW_DIRECT_EXECUTE_THRESHOLD: usize = 1024;

/// Source of IDs for channels created on the Rust side via `Channel::new`.
static CHANNEL_COUNTER: AtomicU32 = AtomicU32::new(0);
/// Source of IDs for payloads parked in `ChannelDataIpcQueue` until the webview fetches them.
static CHANNEL_DATA_COUNTER: AtomicU32 = AtomicU32::new(0);
43
44/// Maps a channel id to a pending data that must be send to the JavaScript side via the IPC.
45#[derive(Default, Clone)]
46pub struct ChannelDataIpcQueue(Arc<Mutex<HashMap<u32, InvokeResponseBody>>>);
47
48/// An IPC channel.
49pub struct Channel<TSend = InvokeResponseBody> {
50  inner: Arc<ChannelInner>,
51  phantom: std::marker::PhantomData<TSend>,
52}
53
// With the `specta` feature enabled, derive `specta::Type` for the remote `super::Channel`
// type, exposed under the name `TAURI_CHANNEL`. The local struct only supplies a shape for the
// derive; it is never constructed (hence `dead_code` is allowed).
#[cfg(feature = "specta")]
const _: () = {
  #[derive(specta::Type)]
  #[specta(remote = super::Channel, rename = "TAURI_CHANNEL")]
  #[allow(dead_code)]
  struct Channel<TSend>(std::marker::PhantomData<TSend>);
};
61
62impl<TSend> Clone for Channel<TSend> {
63  fn clone(&self) -> Self {
64    Self {
65      inner: self.inner.clone(),
66      phantom: self.phantom,
67    }
68  }
69}
70
71type OnDropFn = Option<Box<dyn Fn() + Send + Sync + 'static>>;
72type OnMessageFn = Box<dyn Fn(InvokeResponseBody) -> crate::Result<()> + Send + Sync>;
73
74struct ChannelInner {
75  id: u32,
76  on_message: OnMessageFn,
77  on_drop: OnDropFn,
78}
79
80impl Drop for ChannelInner {
81  fn drop(&mut self) {
82    if let Some(on_drop) = &self.on_drop {
83      on_drop();
84    }
85  }
86}
87
88impl<TSend> Serialize for Channel<TSend> {
89  fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
90  where
91    S: Serializer,
92  {
93    serializer.serialize_str(&format!("{IPC_PAYLOAD_PREFIX}{}", self.inner.id))
94  }
95}
96
97/// The ID of a channel that was defined on the JavaScript layer.
98///
99/// Useful when expecting [`Channel`] as part of a JSON object instead of a top-level command argument.
100///
101/// # Examples
102///
103/// ```rust
104/// use tauri::{ipc::JavaScriptChannelId, Runtime, Webview};
105///
106/// #[derive(serde::Deserialize)]
107/// #[serde(rename_all = "camelCase")]
108/// struct Button {
109///   label: String,
110///   on_click: JavaScriptChannelId,
111/// }
112///
113/// #[tauri::command]
114/// fn add_button<R: Runtime>(webview: Webview<R>, button: Button) {
115///   let channel = button.on_click.channel_on(webview);
116///   channel.send("clicked").unwrap();
117/// }
118/// ```
119pub struct JavaScriptChannelId(CallbackFn);
120
121impl FromStr for JavaScriptChannelId {
122  type Err = &'static str;
123
124  fn from_str(s: &str) -> Result<Self, Self::Err> {
125    s.strip_prefix(IPC_PAYLOAD_PREFIX)
126      .ok_or("invalid channel string")
127      .and_then(|id| id.parse().map_err(|_| "invalid channel ID"))
128      .map(|id| Self(CallbackFn(id)))
129  }
130}
131
132impl JavaScriptChannelId {
133  /// Gets a [`Channel`] for this channel ID on the given [`Webview`].
134  pub fn channel_on<R: Runtime, TSend>(&self, webview: Webview<R>) -> Channel<TSend> {
135    let callback_fn = self.0;
136    let callback_id = callback_fn.0;
137
138    let counter = Arc::new(AtomicUsize::new(0));
139    let counter_clone = counter.clone();
140    let webview_clone = webview.clone();
141
142    Channel::new_with_id(
143      callback_id,
144      Box::new(move |body| {
145        let current_index = counter.fetch_add(1, Ordering::Relaxed);
146
147        if let Some(interceptor) = &webview.manager.channel_interceptor {
148          if interceptor(&webview, callback_fn, current_index, &body) {
149            return Ok(());
150          }
151        }
152
153        match body {
154          // Don't go through the fetch process if the payload is small
155          InvokeResponseBody::Json(json_string)
156            if json_string.len() < MAX_JSON_DIRECT_EXECUTE_THRESHOLD =>
157          {
158            webview.eval(format_raw_js(
159              callback_id,
160              format!("{{ message: {json_string}, index: {current_index} }}"),
161            ))?;
162          }
163          InvokeResponseBody::Raw(bytes) if bytes.len() < MAX_RAW_DIRECT_EXECUTE_THRESHOLD => {
164            let bytes_as_json_array = serde_json::to_string(&bytes)?;
165            webview.eval(format_raw_js(callback_id, format!("{{ message: new Uint8Array({bytes_as_json_array}).buffer, index: {current_index} }}")))?;
166          }
167          // use the fetch API to speed up larger response payloads
168          _ => {
169            let data_id = CHANNEL_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
170
171            webview
172              .state::<ChannelDataIpcQueue>()
173              .0
174              .lock()
175              .unwrap()
176              .insert(data_id, body);
177
178            webview.eval(format!(
179              "window.__TAURI_INTERNALS__.invoke('{FETCH_CHANNEL_DATA_COMMAND}', null, {{ headers: {{ '{CHANNEL_ID_HEADER_NAME}': '{data_id}' }} }}).then((response) => window.__TAURI_INTERNALS__.runCallback({callback_id}, {{ message: response, index: {current_index} }})).catch(console.error)",
180            ))?;
181          }
182        }
183
184        Ok(())
185      }),
186      Some(Box::new(move || {
187        let current_index = counter_clone.load(Ordering::Relaxed);
188        let _ = webview_clone.eval(format_raw_js(
189          callback_id,
190          format!("{{ end: true, index: {current_index} }}"),
191        ));
192      })),
193    )
194  }
195}
196
197impl<'de> Deserialize<'de> for JavaScriptChannelId {
198  fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
199  where
200    D: Deserializer<'de>,
201  {
202    let value: String = Deserialize::deserialize(deserializer)?;
203    Self::from_str(&value).map_err(|_| {
204      serde::de::Error::custom(format!(
205        "invalid channel value `{value}`, expected a string in the `{IPC_PAYLOAD_PREFIX}ID` format"
206      ))
207    })
208  }
209}
210
211impl<TSend> Channel<TSend> {
212  /// Creates a new channel with the given message handler.
213  pub fn new<F: Fn(InvokeResponseBody) -> crate::Result<()> + Send + Sync + 'static>(
214    on_message: F,
215  ) -> Self {
216    Self::new_with_id(
217      CHANNEL_COUNTER.fetch_add(1, Ordering::Relaxed),
218      Box::new(on_message),
219      None,
220    )
221  }
222
223  fn new_with_id(id: u32, on_message: OnMessageFn, on_drop: OnDropFn) -> Self {
224    #[allow(clippy::let_and_return)]
225    let channel = Self {
226      inner: Arc::new(ChannelInner {
227        id,
228        on_message,
229        on_drop,
230      }),
231      phantom: Default::default(),
232    };
233
234    #[cfg(mobile)]
235    crate::plugin::mobile::register_channel(Channel {
236      inner: channel.inner.clone(),
237      phantom: Default::default(),
238    });
239
240    channel
241  }
242
243  // This is used from the IPC handler
244  pub(crate) fn from_callback_fn<R: Runtime>(webview: Webview<R>, callback: CallbackFn) -> Self {
245    let callback_id = callback.0;
246    Channel::new_with_id(
247      callback_id,
248      Box::new(move |body| {
249        match body {
250          // Don't go through the fetch process if the payload is small
251          InvokeResponseBody::Json(json_string)
252            if json_string.len() < MAX_JSON_DIRECT_EXECUTE_THRESHOLD =>
253          {
254            webview.eval(format_raw_js(callback_id, json_string))?;
255          }
256          InvokeResponseBody::Raw(bytes) if bytes.len() < MAX_RAW_DIRECT_EXECUTE_THRESHOLD => {
257            let bytes_as_json_array = serde_json::to_string(&bytes)?;
258            webview.eval(format_raw_js(
259              callback_id,
260              format!("new Uint8Array({bytes_as_json_array}).buffer"),
261            ))?;
262          }
263          // use the fetch API to speed up larger response payloads
264          _ => {
265            let data_id = CHANNEL_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
266
267            webview
268              .state::<ChannelDataIpcQueue>()
269              .0
270              .lock()
271              .unwrap()
272              .insert(data_id, body);
273
274            webview.eval(format!(
275              "window.__TAURI_INTERNALS__.invoke('{FETCH_CHANNEL_DATA_COMMAND}', null, {{ headers: {{ '{CHANNEL_ID_HEADER_NAME}': '{data_id}' }} }}).then((response) => window.__TAURI_INTERNALS__.runCallback({callback_id}, response)).catch(console.error)",
276            ))?;
277          }
278        }
279
280        Ok(())
281      }),
282      None,
283    )
284  }
285
286  /// The channel identifier.
287  pub fn id(&self) -> u32 {
288    self.inner.id
289  }
290
291  /// Sends the given data through the channel.
292  pub fn send(&self, data: TSend) -> crate::Result<()>
293  where
294    TSend: IpcResponse,
295  {
296    (self.inner.on_message)(data.body()?)
297  }
298}
299
300impl<'de, R: Runtime, TSend> CommandArg<'de, R> for Channel<TSend> {
301  /// Grabs the [`Webview`] from the [`CommandItem`] and returns the associated [`Channel`].
302  fn from_command(command: CommandItem<'de, R>) -> Result<Self, InvokeError> {
303    let name = command.name;
304    let arg = command.key;
305    let webview = command.message.webview();
306    let value: String =
307      Deserialize::deserialize(command).map_err(|e| crate::Error::InvalidArgs(name, arg, e))?;
308    JavaScriptChannelId::from_str(&value)
309      .map(|id| id.channel_on(webview))
310      .map_err(|_| {
311        InvokeError::from(format!(
312	        "invalid channel value `{value}`, expected a string in the `{IPC_PAYLOAD_PREFIX}ID` format"
313	      ))
314      })
315  }
316}
317
318#[command(root = "crate")]
319fn fetch(
320  request: Request<'_>,
321  cache: State<'_, ChannelDataIpcQueue>,
322) -> Result<Response, &'static str> {
323  if let Some(id) = request
324    .headers()
325    .get(CHANNEL_ID_HEADER_NAME)
326    .and_then(|v| v.to_str().ok())
327    .and_then(|id| id.parse().ok())
328  {
329    if let Some(data) = cache.0.lock().unwrap().remove(&id) {
330      Ok(Response::new(data))
331    } else {
332      Err("data not found")
333    }
334  } else {
335    Err("missing channel id header")
336  }
337}
338
/// Builds the internal channel plugin (registered as `CHANNEL_PLUGIN_NAME`), which exposes the
/// `fetch` command used to retrieve channel payloads too large to be evaluated inline.
pub fn plugin<R: Runtime>() -> TauriPlugin<R> {
  // NOTE(review): the `plugin(...)` identifier in the handler below is written out by hand and
  // presumably must stay in sync with `CHANNEL_PLUGIN_NAME` — verify when renaming.
  PluginBuilder::new(CHANNEL_PLUGIN_NAME)
    .invoke_handler(crate::generate_handler![
      #![plugin(__TAURI_CHANNEL__)]
      fetch
    ])
    .build()
}