1use std::{
6 collections::HashMap,
7 str::FromStr,
8 sync::{
9 atomic::{AtomicU32, AtomicUsize, Ordering},
10 Arc, Mutex,
11 },
12};
13
14use serde::{Deserialize, Deserializer, Serialize, Serializer};
15
16use crate::{
17 command,
18 ipc::{CommandArg, CommandItem},
19 plugin::{Builder as PluginBuilder, TauriPlugin},
20 Manager, Runtime, State, Webview,
21};
22
23use super::{
24 format_callback::format_raw_js, CallbackFn, InvokeError, InvokeResponseBody, IpcResponse,
25 Request, Response,
26};
27
/// Prefix of the string form of a channel: a channel is serialized as this prefix followed by
/// its numeric ID, and channel strings received from the frontend must start with it.
pub const IPC_PAYLOAD_PREFIX: &str = "__CHANNEL__:";
/// Name under which the internal channel plugin is registered.
pub const CHANNEL_PLUGIN_NAME: &str = "__TAURI_CHANNEL__";
/// Fully qualified command the frontend is told to invoke to fetch a queued channel payload.
pub const FETCH_CHANNEL_DATA_COMMAND: &str = "plugin:__TAURI_CHANNEL__|fetch";
/// Request header carrying the ID of the queued payload being fetched.
const CHANNEL_ID_HEADER_NAME: &str = "Tauri-Channel-Id";

/// JSON payloads whose length in bytes is below this value are embedded directly in the evaluated
/// script; anything else is queued and fetched by the frontend.
const MAX_JSON_DIRECT_EXECUTE_THRESHOLD: usize = 8192;
/// Raw payloads whose length in bytes is below this value are embedded directly in the evaluated
/// script (as a JSON array of bytes); anything else is queued and fetched by the frontend.
const MAX_RAW_DIRECT_EXECUTE_THRESHOLD: usize = 1024;

/// Source of IDs for channels created with [`Channel::new`].
static CHANNEL_COUNTER: AtomicU32 = AtomicU32::new(0);
/// Source of IDs for payloads stored in the [`ChannelDataIpcQueue`].
static CHANNEL_DATA_COUNTER: AtomicU32 = AtomicU32::new(0);
43
/// Shared map of channel payloads waiting to be fetched by the frontend, keyed by data ID.
///
/// Entries are inserted when a payload is too large to embed in an evaluated script and are
/// removed by the `fetch` plugin command. Cloning yields another handle to the same map.
#[derive(Default, Clone)]
pub struct ChannelDataIpcQueue(Arc<Mutex<HashMap<u32, InvokeResponseBody>>>);
47
/// A handle used to send messages of type `TSend` to a message handler.
///
/// Clones share the same underlying state, so every clone reports the same ID and delivers to
/// the same handler.
pub struct Channel<TSend = InvokeResponseBody> {
  // Shared state: the channel ID, the message handler and the optional drop callback.
  inner: Arc<ChannelInner>,
  // Ties the handle to its message type without storing a value of that type.
  phantom: std::marker::PhantomData<TSend>,
}
53
// Only compiled when the `specta` feature is enabled: derives `specta::Type` on a local
// stand-in struct that is declared as a remote definition of `super::Channel` and exported
// under the name `TAURI_CHANNEL`.
#[cfg(feature = "specta")]
const _: () = {
  #[derive(specta::Type)]
  #[specta(remote = super::Channel, rename = "TAURI_CHANNEL")]
  #[allow(dead_code)]
  struct Channel<TSend>(std::marker::PhantomData<TSend>);
};
61
62impl<TSend> Clone for Channel<TSend> {
63 fn clone(&self) -> Self {
64 Self {
65 inner: self.inner.clone(),
66 phantom: self.phantom,
67 }
68 }
69}
70
/// Optional callback run when a channel's shared state is dropped.
type OnDropFn = Option<Box<dyn Fn() + Send + Sync + 'static>>;
/// Handler that receives every message body sent through a channel.
type OnMessageFn = Box<dyn Fn(InvokeResponseBody) -> crate::Result<()> + Send + Sync>;
73
/// State shared by all clones of a [`Channel`].
struct ChannelInner {
  // Identifier reported by `Channel::id` and used in the channel's serialized form.
  id: u32,
  // Invoked with the body of every message sent through the channel.
  on_message: OnMessageFn,
  // Invoked once when this state is dropped, if set.
  on_drop: OnDropFn,
}
79
80impl Drop for ChannelInner {
81 fn drop(&mut self) {
82 if let Some(on_drop) = &self.on_drop {
83 on_drop();
84 }
85 }
86}
87
88impl<TSend> Serialize for Channel<TSend> {
89 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
90 where
91 S: Serializer,
92 {
93 serializer.serialize_str(&format!("{IPC_PAYLOAD_PREFIX}{}", self.inner.id))
94 }
95}
96
/// Identifier of a channel received from the frontend, wrapping the callback that messages for
/// that channel are delivered to.
///
/// It is parsed from a string made of [`IPC_PAYLOAD_PREFIX`] followed by the numeric ID.
pub struct JavaScriptChannelId(CallbackFn);
120
121impl FromStr for JavaScriptChannelId {
122 type Err = &'static str;
123
124 fn from_str(s: &str) -> Result<Self, Self::Err> {
125 s.strip_prefix(IPC_PAYLOAD_PREFIX)
126 .ok_or("invalid channel string")
127 .and_then(|id| id.parse().map_err(|_| "invalid channel ID"))
128 .map(|id| Self(CallbackFn(id)))
129 }
130}
131
132impl JavaScriptChannelId {
133 pub fn channel_on<R: Runtime, TSend>(&self, webview: Webview<R>) -> Channel<TSend> {
135 let callback_fn = self.0;
136 let callback_id = callback_fn.0;
137
138 let counter = Arc::new(AtomicUsize::new(0));
139 let counter_clone = counter.clone();
140 let webview_clone = webview.clone();
141
142 Channel::new_with_id(
143 callback_id,
144 Box::new(move |body| {
145 let current_index = counter.fetch_add(1, Ordering::Relaxed);
146
147 if let Some(interceptor) = &webview.manager.channel_interceptor {
148 if interceptor(&webview, callback_fn, current_index, &body) {
149 return Ok(());
150 }
151 }
152
153 match body {
154 InvokeResponseBody::Json(json_string)
156 if json_string.len() < MAX_JSON_DIRECT_EXECUTE_THRESHOLD =>
157 {
158 webview.eval(format_raw_js(
159 callback_id,
160 format!("{{ message: {json_string}, index: {current_index} }}"),
161 ))?;
162 }
163 InvokeResponseBody::Raw(bytes) if bytes.len() < MAX_RAW_DIRECT_EXECUTE_THRESHOLD => {
164 let bytes_as_json_array = serde_json::to_string(&bytes)?;
165 webview.eval(format_raw_js(callback_id, format!("{{ message: new Uint8Array({bytes_as_json_array}).buffer, index: {current_index} }}")))?;
166 }
167 _ => {
169 let data_id = CHANNEL_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
170
171 webview
172 .state::<ChannelDataIpcQueue>()
173 .0
174 .lock()
175 .unwrap()
176 .insert(data_id, body);
177
178 webview.eval(format!(
179 "window.__TAURI_INTERNALS__.invoke('{FETCH_CHANNEL_DATA_COMMAND}', null, {{ headers: {{ '{CHANNEL_ID_HEADER_NAME}': '{data_id}' }} }}).then((response) => window.__TAURI_INTERNALS__.runCallback({callback_id}, {{ message: response, index: {current_index} }})).catch(console.error)",
180 ))?;
181 }
182 }
183
184 Ok(())
185 }),
186 Some(Box::new(move || {
187 let current_index = counter_clone.load(Ordering::Relaxed);
188 let _ = webview_clone.eval(format_raw_js(
189 callback_id,
190 format!("{{ end: true, index: {current_index} }}"),
191 ));
192 })),
193 )
194 }
195}
196
197impl<'de> Deserialize<'de> for JavaScriptChannelId {
198 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
199 where
200 D: Deserializer<'de>,
201 {
202 let value: String = Deserialize::deserialize(deserializer)?;
203 Self::from_str(&value).map_err(|_| {
204 serde::de::Error::custom(format!(
205 "invalid channel value `{value}`, expected a string in the `{IPC_PAYLOAD_PREFIX}ID` format"
206 ))
207 })
208 }
209}
210
211impl<TSend> Channel<TSend> {
212 pub fn new<F: Fn(InvokeResponseBody) -> crate::Result<()> + Send + Sync + 'static>(
214 on_message: F,
215 ) -> Self {
216 Self::new_with_id(
217 CHANNEL_COUNTER.fetch_add(1, Ordering::Relaxed),
218 Box::new(on_message),
219 None,
220 )
221 }
222
223 fn new_with_id(id: u32, on_message: OnMessageFn, on_drop: OnDropFn) -> Self {
224 #[allow(clippy::let_and_return)]
225 let channel = Self {
226 inner: Arc::new(ChannelInner {
227 id,
228 on_message,
229 on_drop,
230 }),
231 phantom: Default::default(),
232 };
233
234 #[cfg(mobile)]
235 crate::plugin::mobile::register_channel(Channel {
236 inner: channel.inner.clone(),
237 phantom: Default::default(),
238 });
239
240 channel
241 }
242
243 pub(crate) fn from_callback_fn<R: Runtime>(webview: Webview<R>, callback: CallbackFn) -> Self {
245 let callback_id = callback.0;
246 Channel::new_with_id(
247 callback_id,
248 Box::new(move |body| {
249 match body {
250 InvokeResponseBody::Json(json_string)
252 if json_string.len() < MAX_JSON_DIRECT_EXECUTE_THRESHOLD =>
253 {
254 webview.eval(format_raw_js(callback_id, json_string))?;
255 }
256 InvokeResponseBody::Raw(bytes) if bytes.len() < MAX_RAW_DIRECT_EXECUTE_THRESHOLD => {
257 let bytes_as_json_array = serde_json::to_string(&bytes)?;
258 webview.eval(format_raw_js(
259 callback_id,
260 format!("new Uint8Array({bytes_as_json_array}).buffer"),
261 ))?;
262 }
263 _ => {
265 let data_id = CHANNEL_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
266
267 webview
268 .state::<ChannelDataIpcQueue>()
269 .0
270 .lock()
271 .unwrap()
272 .insert(data_id, body);
273
274 webview.eval(format!(
275 "window.__TAURI_INTERNALS__.invoke('{FETCH_CHANNEL_DATA_COMMAND}', null, {{ headers: {{ '{CHANNEL_ID_HEADER_NAME}': '{data_id}' }} }}).then((response) => window.__TAURI_INTERNALS__.runCallback({callback_id}, response)).catch(console.error)",
276 ))?;
277 }
278 }
279
280 Ok(())
281 }),
282 None,
283 )
284 }
285
286 pub fn id(&self) -> u32 {
288 self.inner.id
289 }
290
291 pub fn send(&self, data: TSend) -> crate::Result<()>
293 where
294 TSend: IpcResponse,
295 {
296 (self.inner.on_message)(data.body()?)
297 }
298}
299
300impl<'de, R: Runtime, TSend> CommandArg<'de, R> for Channel<TSend> {
301 fn from_command(command: CommandItem<'de, R>) -> Result<Self, InvokeError> {
303 let name = command.name;
304 let arg = command.key;
305 let webview = command.message.webview();
306 let value: String =
307 Deserialize::deserialize(command).map_err(|e| crate::Error::InvalidArgs(name, arg, e))?;
308 JavaScriptChannelId::from_str(&value)
309 .map(|id| id.channel_on(webview))
310 .map_err(|_| {
311 InvokeError::from(format!(
312 "invalid channel value `{value}`, expected a string in the `{IPC_PAYLOAD_PREFIX}ID` format"
313 ))
314 })
315 }
316}
317
318#[command(root = "crate")]
319fn fetch(
320 request: Request<'_>,
321 cache: State<'_, ChannelDataIpcQueue>,
322) -> Result<Response, &'static str> {
323 if let Some(id) = request
324 .headers()
325 .get(CHANNEL_ID_HEADER_NAME)
326 .and_then(|v| v.to_str().ok())
327 .and_then(|id| id.parse().ok())
328 {
329 if let Some(data) = cache.0.lock().unwrap().remove(&id) {
330 Ok(Response::new(data))
331 } else {
332 Err("data not found")
333 }
334 } else {
335 Err("missing channel id header")
336 }
337}
338
339pub fn plugin<R: Runtime>() -> TauriPlugin<R> {
340 PluginBuilder::new(CHANNEL_PLUGIN_NAME)
341 .invoke_handler(crate::generate_handler![
342 #![plugin(__TAURI_CHANNEL__)]
343 fetch
344 ])
345 .build()
346}