use std::{borrow::Cow, str::FromStr as _};
use pyo3::{
exceptions::PyValueError,
intern,
prelude::*,
pybacked::{PyBackedBytes, PyBackedStr},
types::{PyBytes, PyDict, PyList, PyString, PyType},
};
use pyo3_utils::py_wrapper::{PyWrapper, PyWrapperT0, PyWrapperT2};
use tauri::ipc::{self, CommandArg as _, CommandItem, InvokeBody, InvokeMessage};
use crate::{
ext_mod::{
webview::{Webview, WebviewWindow},
PyAppHandleExt as _, StateManager,
},
tauri_runtime::Runtime,
utils::TauriError,
};
// Shorthand aliases that pin tauri's IPC / webview types to this crate's `Runtime`.
type IpcInvoke = tauri::ipc::Invoke<Runtime>;
type IpcInvokeResolver = tauri::ipc::InvokeResolver<Runtime>;
type TauriWebviewWindow = tauri::webview::WebviewWindow<Runtime>;
// Prefixed with `Tauri` to keep it apart from the local `InvokeResponseBody` enum.
type TauriInvokeResponseBody = tauri::ipc::InvokeResponseBody;
/// Response payload received from Python: either a string or a bytes-like value.
///
/// Extraction from a Python object is derived via `FromPyObject`.
#[derive(FromPyObject)]
enum InvokeResponseBody {
/// String payload; converted to `TauriInvokeResponseBody::Json`.
Json(PyBackedStr),
/// Byte payload; converted to `TauriInvokeResponseBody::Raw`.
Raw(PyBackedBytes),
}
/// Converts the Python-backed payload into tauri's owned response body by
/// copying the string / bytes out of the Python-backed storage.
impl From<InvokeResponseBody> for TauriInvokeResponseBody {
    fn from(value: InvokeResponseBody) -> Self {
        match value {
            InvokeResponseBody::Json(text) => Self::Json(String::from(&*text)),
            InvokeResponseBody::Raw(bytes) => Self::Raw(bytes.to_vec()),
        }
    }
}
/// Python-visible handle used to settle one IPC invocation.
///
/// Wraps a tauri `InvokeResolver`. `resolve` / `reject` take the resolver out of
/// `inner`, so presumably the handle can only be settled once — TODO confirm
/// against `PyWrapperT2` semantics.
#[pyclass(frozen, generic)]
#[non_exhaustive]
pub struct InvokeResolver {
// The wrapped tauri resolver; taken out by `resolve` / `reject`.
inner: PyWrapper<PyWrapperT2<IpcInvokeResolver>>,
/// Arguments dict supplied at construction; readable from Python.
#[pyo3(get)]
arguments: Py<PyDict>,
}
impl InvokeResolver {
    /// Builds a resolver handle from a tauri resolver and its already-bound arguments.
    #[inline]
    fn new(resolver: IpcInvokeResolver, arguments: Py<PyDict>) -> Self {
        let inner = PyWrapper::new2(resolver);
        Self { inner, arguments }
    }
}
#[pymethods]
impl InvokeResolver {
fn resolve(&self, py: Python<'_>, value: InvokeResponseBody) -> PyResult<()> {
py.allow_threads(|| {
let resolver = self.inner.try_take_inner()??;
resolver.resolve(TauriInvokeResponseBody::from(value));
Ok(())
})
}
fn reject(&self, py: Python<'_>, value: Cow<'_, str>) -> PyResult<()> {
py.allow_threads(|| {
let resolver = self.inner.try_take_inner()??;
resolver.reject(value);
Ok(())
})
}
}
/// Python-visible wrapper around a tauri `Invoke` (message + resolver + acl).
///
/// The wrapped value is taken out of `inner` by `bind_to`, `resolve` and `reject`.
#[pyclass(frozen)]
#[non_exhaustive]
pub struct Invoke {
// The wrapped tauri invoke; taken out when the invoke is bound or settled.
inner: PyWrapper<PyWrapperT2<IpcInvoke>>,
/// Function name read from the `pyfunc` header at construction; readable from Python.
#[pyo3(get)]
command: Py<PyString>,
}
impl Invoke {
    /// Wraps a tauri invoke for Python, reading the target function name from
    /// the message headers.
    ///
    /// Returns `None` — after rejecting the invoke with the error text — when the
    /// function name cannot be obtained from the `pyfunc` header.
    #[cfg(feature = "__private")]
    pub fn new(py: Python<'_>, invoke: IpcInvoke) -> Option<Self> {
        let command = match Self::get_func_name_from_message(&invoke.message) {
            Ok(name) => PyString::new(py, name).unbind(),
            Err(reason) => {
                invoke.resolver.reject(reason);
                return None;
            }
        };
        Some(Self {
            inner: PyWrapper::new2(invoke),
            command,
        })
    }

    /// Name of the header that carries the Python function name.
    const PYFUNC_HEADER_KEY: &str = "pyfunc";

    /// Returns the value of the `pyfunc` header as `&str`.
    ///
    /// Yields an error message when the header is absent or when its value
    /// cannot be converted with `to_str()`.
    #[inline]
    fn get_func_name_from_message(message: &InvokeMessage<Runtime>) -> Result<&str, String> {
        let Some(raw_name) = message.headers().get(Self::PYFUNC_HEADER_KEY) else {
            return Err(format!("There is no {} header", Self::PYFUNC_HEADER_KEY));
        };
        raw_name.to_str().map_err(|err| err.to_string())
    }
}
#[pymethods]
impl Invoke {
// Keys that `bind_to` looks for in the `parameters` dict and sets in the
// resulting arguments dict.
const BODY_KEY: &str = "body";
const APP_HANDLE_KEY: &str = "app_handle";
const WEBVIEW_WINDOW_KEY: &str = "webview_window";
const HEADERS_KEY: &str = "headers";
const STATES_KEY: &str = "states";
/// Consumes the wrapped invoke and builds the arguments requested by `parameters`.
///
/// For each of the keys `body`, `app_handle`, `webview_window`, `headers` and
/// `states` that is present in `parameters`, the matching value is computed from
/// the invoke message and stored under the same key in a fresh dict. For all keys
/// except `states` only the presence of the key is checked.
///
/// Returns `Some(InvokeResolver)` holding the resolver and the new dict on success.
/// Returns `None` after rejecting the invocation when:
/// - `body` is requested but the payload is JSON rather than raw bytes,
/// - the webview window cannot be extracted from the command,
/// - a requested state type is not managed.
///
/// Errors from taking the invoke out of its wrapper, from dict operations, from
/// obtaining the app handle, or from downcasting `states` entries are raised to Python.
///
/// NOTE(review): any `?` early return after the invoke has been taken drops
/// `resolver` without calling `resolve` / `reject` — presumably the frontend call
/// then never settles; confirm whether that is intended.
fn bind_to(&self, parameters: &Bound<'_, PyDict>) -> PyResult<Option<InvokeResolver>> {
let py = parameters.py();
// Take ownership of the invoke; after this point `self` no longer holds it.
let invoke = self.inner.try_take_inner()??;
let IpcInvoke {
message,
resolver,
acl,
} = invoke;
let arguments = PyDict::new(py);
// `body`: only a raw (bytes) payload is accepted.
let body_key = intern!(py, Invoke::BODY_KEY);
if parameters.contains(body_key)? {
match message.payload() {
InvokeBody::Json(_) => {
resolver.reject(
"Please use `ArrayBuffer` or `Uint8Array` for `InvokeBody::Raw`. \
If you are using `pyInvoke`, please report this as bug to pytauri developers.",
);
return Ok(None);
}
InvokeBody::Raw(body) => arguments.set_item(body_key, PyBytes::new(py, body))?,
}
}
// `app_handle`: the Python app handle obtained from the message's webview.
let app_handle_key = intern!(py, Invoke::APP_HANDLE_KEY);
if parameters.contains(app_handle_key)? {
let py_app_handle = message.webview_ref().try_py_app_handle()?;
arguments.set_item(app_handle_key, py_app_handle)?;
}
// `webview_window`: extracted through tauri's `CommandArg::from_command`.
let webview_window_key = intern!(py, Invoke::WEBVIEW_WINDOW_KEY);
if parameters.contains(webview_window_key)? {
// `name` / `key` are placeholder strings — presumably only used by tauri for
// error reporting; verify against `CommandItem` docs.
let command_webview_window_item = CommandItem {
plugin: None,
name: "__whatever__pyfunc",
key: "__whatever__webviewWindow",
message: &message,
acl: &acl,
};
let webview_window = match TauriWebviewWindow::from_command(command_webview_window_item)
{
Ok(webview_window) => webview_window,
Err(e) => {
// Forward tauri's own error to the frontend.
resolver.invoke_error(e);
return Ok(None);
}
};
arguments.set_item(webview_window_key, WebviewWindow::new(webview_window))?;
}
// `headers`: list of `(name_bytes, value_bytes)` pairs, without the `pyfunc` header.
let headers_key = intern!(py, Invoke::HEADERS_KEY);
if parameters.contains(headers_key)? {
let headers: Vec<(&[u8], &[u8])> = message
.headers()
.iter()
.filter(|(key, _)| **key != Self::PYFUNC_HEADER_KEY)
.map(|(key, value)| (key.as_ref(), value.as_bytes()))
.collect();
let py_headers = PyList::new(py, headers)?;
arguments.set_item(headers_key, py_headers)?;
}
// `states`: `parameters["states"]` must be a dict whose values are Python types;
// each one is looked up in the state manager and stored under the same field key.
let states_key = intern!(py, Invoke::STATES_KEY);
if let Some(states_params) = parameters.get_item(states_key)? {
let states_params = states_params.downcast_into::<PyDict>()?;
let states_args = PyDict::new(py);
let state_manager = StateManager::get_or_init(py, message.webview_ref());
for (key, value) in states_params.into_iter() {
let state_type = value.downcast::<PyType>()?;
if let Some(state) = state_manager.try_state(py, state_type)? {
states_args.set_item(key, state)?;
} else {
// A requested state type is not managed: reject and stop binding.
resolver.reject(format!(
"state `{state_type}` not managed for field `{key}`. \
You must call `.manage()` before using this command"
));
return Ok(None);
}
}
arguments.set_item(states_key, states_args)?;
}
// Hand the still-unsettled resolver to Python together with the bound arguments.
Ok(Some(InvokeResolver::new(resolver, arguments.unbind())))
}
fn resolve(&self, py: Python<'_>, value: InvokeResponseBody) -> PyResult<()> {
py.allow_threads(|| {
let resolver = self.inner.try_take_inner()??.resolver;
resolver.resolve(TauriInvokeResponseBody::from(value));
Ok(())
})
}
fn reject(&self, py: Python<'_>, value: Cow<'_, str>) -> PyResult<()> {
py.allow_threads(|| {
let resolver = self.inner.try_take_inner()??.resolver;
resolver.reject(value);
Ok(())
})
}
}
/// Python-visible wrapper around tauri's `ipc::JavaScriptChannelId`.
#[pyclass(frozen)]
#[non_exhaustive]
pub struct JavaScriptChannelId(PyWrapper<PyWrapperT0<ipc::JavaScriptChannelId>>);
impl JavaScriptChannelId {
    /// Wraps a tauri channel id for exposure to Python.
    fn new(js_channel_id: ipc::JavaScriptChannelId) -> Self {
        let wrapped = PyWrapper::new0(js_channel_id);
        Self(wrapped)
    }
}
#[pymethods]
impl JavaScriptChannelId {
    /// Parses a channel id from its string form.
    ///
    /// Raises `ValueError` carrying tauri's static error message when parsing fails.
    #[staticmethod]
    fn from_str(py: Python<'_>, value: &str) -> PyResult<Self> {
        ipc::JavaScriptChannelId::from_str(value)
            .map(Self::new)
            .map_err(|err| {
                // The parse error is a static string, so it is interned on the Python side.
                let msg: &'static str = err;
                PyValueError::new_err(PyString::intern(py, msg).unbind())
            })
    }

    /// Creates a `Channel` for this id on the given webview.
    ///
    /// Runs with the GIL released; the wrapped webview is cloned before being
    /// handed to tauri.
    fn channel_on(&self, py: Python<'_>, webview: Py<Webview>) -> Channel {
        py.allow_threads(|| {
            let id = self.0.inner_ref();
            let target = webview.get().0.inner_ref().clone();
            Channel::new(id.channel_on(target))
        })
    }
}
/// Python-visible wrapper around tauri's `ipc::Channel`.
#[pyclass(frozen)]
#[non_exhaustive]
pub struct Channel(PyWrapper<PyWrapperT0<ipc::Channel>>);
impl Channel {
    /// Wraps a tauri channel for exposure to Python.
    fn new(channel: ipc::Channel) -> Self {
        let wrapped = PyWrapper::new0(channel);
        Self(wrapped)
    }
}
#[pymethods]
impl Channel {
    /// Returns the id reported by the wrapped tauri channel.
    fn id(&self) -> u32 {
        let channel = self.0.inner_ref();
        channel.id()
    }

    /// Sends `data` through the channel.
    ///
    /// Runs with the GIL released. A tauri send failure is converted through
    /// `TauriError` and raised to Python.
    fn send(&self, py: Python<'_>, data: InvokeResponseBody) -> PyResult<()> {
        py.allow_threads(|| -> PyResult<()> {
            let channel = self.0.inner_ref();
            let body = TauriInvokeResponseBody::from(data);
            channel.send(body).map_err(TauriError::from)?;
            Ok(())
        })
    }
}