use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine as _;
use serde_json::Value;
use tauri::{Manager, Runtime, WebviewWindow};
use tokio::sync::oneshot;
use webview2_com::Microsoft::Web::WebView2::Win32::{
ICoreWebView2, ICoreWebView2CapturePreviewCompletedHandler, ICoreWebView2Environment6,
ICoreWebView2ExecuteScriptCompletedHandler, ICoreWebView2PrintToPdfCompletedHandler,
ICoreWebView2ScriptDialogOpeningEventHandler, ICoreWebView2WebMessageReceivedEventHandler,
ICoreWebView2_7, COREWEBVIEW2_CAPTURE_PREVIEW_IMAGE_FORMAT_PNG,
COREWEBVIEW2_PRINT_ORIENTATION_LANDSCAPE, COREWEBVIEW2_PRINT_ORIENTATION_PORTRAIT,
};
use windows::core::{Interface, HSTRING, PCWSTR};
use windows::Win32::Foundation::HGLOBAL;
use windows::Win32::System::Com::StructuredStorage::CreateStreamOnHGlobal;
use windows::Win32::System::Com::{
CoInitializeEx, COINIT_APARTMENTTHREADED, STATFLAG_NONAME, STREAM_SEEK_SET,
};
use windows_core::BOOL;
use crate::platform::alert_state::{AlertState, AlertStateManager, AlertType, PendingAlert};
use crate::platform::{wrap_script_for_frame_context, FrameId, PlatformExecutor, PrintOptions};
use crate::server::response::WebDriverErrorResponse;
use crate::webdriver::Timeouts;
/// Tag placed in the `handler` field of every message the async-script wrapper
/// posts through `window.chrome.webview.postMessage`; the native
/// `WebMessageReceived` handler ignores messages whose `handler` differs.
const HANDLER_NAME: &str = "webdriver_async";
/// Per-window serialization of script execution.
///
/// Each webview label maps to its own async mutex. Script-running commands in
/// this file hold the returned lock for the whole round-trip, so scripts aimed
/// at the same window never interleave.
#[derive(Default)]
pub struct ScriptExecutionLocks {
    locks: std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>,
}
impl ScriptExecutionLocks {
    /// Returns the lock for `label`, creating it on first use.
    ///
    /// A poisoned registry mutex is recovered rather than turned into a panic:
    /// the only mutation performed under it is a single map insertion, which
    /// cannot leave the map half-updated, so panicking would merely cascade an
    /// unrelated failure into every later script call.
    pub fn get(&self, label: &str) -> Arc<tokio::sync::Mutex<()>> {
        let mut locks = self
            .locks
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        Arc::clone(
            locks
                .entry(label.to_string())
                .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(()))),
        )
    }
}
/// Bookkeeping for in-flight `execute_async_script` calls.
///
/// `pending` maps an async-script id to the sender that resolves the waiting
/// command; `registered_handlers` records the webview labels that already have
/// the native `WebMessageReceived` handler attached.
///
/// Both collections are only changed by single insert/remove operations, so a
/// poisoned mutex is recovered instead of silently disabling the bookkeeping.
#[derive(Default)]
pub struct AsyncScriptState {
    pending: Mutex<HashMap<String, oneshot::Sender<Result<Value, String>>>>,
    registered_handlers: Mutex<HashSet<String>>,
}
impl AsyncScriptState {
    /// Registers `id` as pending and returns the receiver that
    /// [`Self::complete`] resolves. Re-registering an id replaces (and drops)
    /// the earlier sender.
    pub fn register(&self, id: String) -> oneshot::Receiver<Result<Value, String>> {
        let (tx, rx) = oneshot::channel();
        self.pending_map().insert(id, tx);
        rx
    }
    /// Resolves pending call `id` with `result`.
    ///
    /// Unknown or already-finished ids are ignored, as is a receiver that has
    /// stopped waiting (e.g. after a timeout).
    pub fn complete(&self, id: &str, result: Result<Value, String>) {
        // Take the sender out first so the map lock is not held while sending.
        let sender = self.pending_map().remove(id);
        if let Some(tx) = sender {
            let _ = tx.send(result);
        }
    }
    /// Forgets pending call `id` without resolving it; its receiver then
    /// observes a closed channel.
    pub fn cancel(&self, id: &str) {
        self.pending_map().remove(id);
    }
    /// Marks `label` as having a message handler.
    ///
    /// Returns `true` if it was already marked (the caller must not register
    /// another handler) and `false` if this call marked it (the caller should
    /// register one now).
    pub fn mark_handler_registered(&self, label: &str) -> bool {
        let mut handlers = self
            .registered_handlers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        !handlers.insert(label.to_string())
    }
    /// Locks the pending map, recovering it if a previous holder panicked.
    fn pending_map(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<String, oneshot::Sender<Result<Value, String>>>> {
        self.pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// A raw COM interface pointer that can be moved to another thread.
///
/// The script-dialog handler turns owned interface references into raw
/// pointers (`into_raw`), ships them to a worker thread in this wrapper, and
/// rebuilds the interfaces there with `from_raw`.
struct SendableComPtr(*mut std::ffi::c_void);

// SAFETY: the wrapper only carries the pointer value across threads; it never
// dereferences it. Whoever reconstructs the interface on the other side takes
// responsibility for using it correctly there.
unsafe impl Send for SendableComPtr {}

impl SendableComPtr {
    /// Yields the wrapped pointer; the COM reference count is left untouched.
    fn as_ptr(&self) -> *mut std::ffi::c_void {
        let Self(raw) = self;
        *raw
    }
}
/// WebView2-backed [`PlatformExecutor`] bound to one webview window.
///
/// Besides the window it carries the session timeouts and the frame path that
/// scripts are wrapped for before execution.
#[derive(Clone)]
pub struct WindowsExecutor<R: Runtime> {
    window: WebviewWindow<R>,
    timeouts: Timeouts,
    frame_context: Vec<FrameId>,
}
impl<R: Runtime> WindowsExecutor<R> {
    /// Creates an executor for `window` using `timeouts`, running scripts in
    /// the frame identified by `frame_context`.
    pub fn new(window: WebviewWindow<R>, timeouts: Timeouts, frame_context: Vec<FrameId>) -> Self {
        Self {
            frame_context,
            timeouts,
            window,
        }
    }
}
impl<R: Runtime + 'static> WindowsExecutor<R> {
async fn evaluate_js_inner(&self, script: &str) -> Result<Value, WebDriverErrorResponse> {
let (tx, rx) = oneshot::channel();
let script_preview: String = script.chars().take(100).collect();
let script_owned = wrap_script_for_frame_context(script, &self.frame_context);
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
let result = self.window.with_webview({
let tx = tx.clone();
move |webview| unsafe {
let _ = CoInitializeEx(None, COINIT_APARTMENTTHREADED);
if let Ok(webview2) = webview.controller().CoreWebView2() {
let script_hstring = HSTRING::from(&script_owned);
let handler: ICoreWebView2ExecuteScriptCompletedHandler =
ExecuteScriptHandler::new(tx.clone()).into();
if let Err(e) = webview2.ExecuteScript(PCWSTR(script_hstring.as_ptr()), &handler) {
tracing::error!("ExecuteScript call failed for script '{}...': {e:?}", script_preview);
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err(format!("ExecuteScript failed: {e:?}")));
}
}
}
} else {
tracing::error!("Failed to get CoreWebView2 for script execution");
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err("Failed to get CoreWebView2".to_string()));
}
}
}
}
});
if let Err(e) = result {
tracing::error!("with_webview failed: {e}");
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err(e.to_string()));
}
}
}
let timeout = std::time::Duration::from_millis(self.timeouts.script_ms);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(Ok(value))) => Ok(serde_json::json!({
"success": true,
"value": value
})),
Ok(Ok(Err(error))) => Err(WebDriverErrorResponse::javascript_error(&error, None)),
Ok(Err(_)) => {
tracing::error!("Channel closed unexpectedly during script execution");
Err(WebDriverErrorResponse::unknown_error("Channel closed"))
},
Err(_) => Err(WebDriverErrorResponse::script_timeout()),
}
}
}
pub fn register_webview_handlers<R: Runtime>(webview: &tauri::Webview<R>) {
let manager = webview.app_handle().state::<AlertStateManager>();
let alert_state = manager.get_or_create(webview.label());
let _ = webview.with_webview(move |webview| unsafe {
let _ = CoInitializeEx(None, COINIT_APARTMENTTHREADED);
if let Ok(webview2) = webview.controller().CoreWebView2() {
if let Ok(settings) = webview2.Settings() {
if let Err(e) = settings.SetAreDefaultScriptDialogsEnabled(false) {
tracing::error!("Failed to disable default script dialogs: {e:?}");
return;
}
} else {
tracing::error!("Failed to get webview settings");
return;
}
let handler: ICoreWebView2ScriptDialogOpeningEventHandler =
ScriptDialogOpeningHandler::new(alert_state).into();
let mut token = std::mem::zeroed();
if let Err(e) = webview2.add_ScriptDialogOpening(&handler, &raw mut token) {
tracing::error!("Failed to register ScriptDialogOpening handler: {e:?}");
} else {
tracing::debug!("Registered script dialog handler for webview");
}
std::mem::forget(handler);
}
});
}
#[async_trait]
impl<R: Runtime + 'static> PlatformExecutor<R> for WindowsExecutor<R> {
fn window(&self) -> &WebviewWindow<R> {
&self.window
}
fn script_timeout_ms(&self) -> u64 {
self.timeouts.script_ms
}
async fn evaluate_js(&self, script: &str) -> Result<Value, WebDriverErrorResponse> {
let locks = self.window.state::<ScriptExecutionLocks>();
let lock = locks.get(self.window.label());
let _guard = lock.lock().await;
self.evaluate_js_inner(script).await
}
async fn take_screenshot(&self) -> Result<String, WebDriverErrorResponse> {
let (tx, rx) = oneshot::channel();
let result = self.window.with_webview(move |webview| {
unsafe {
if let Ok(webview2) = webview.controller().CoreWebView2() {
let stream = match CreateStreamOnHGlobal(HGLOBAL::default(), true) {
Ok(s) => s,
Err(e) => {
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err(format!("Failed to create stream: {e}")));
}
}
return;
}
};
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
let handler = CapturePreviewHandler::new(tx, stream.clone());
let handler: ICoreWebView2CapturePreviewCompletedHandler = handler.into();
if let Err(e) = webview2.CapturePreview(
COREWEBVIEW2_CAPTURE_PREVIEW_IMAGE_FORMAT_PNG,
&stream,
&handler,
) {
tracing::error!("CapturePreview failed: {e}");
}
}
}
});
if let Err(e) = result {
return Err(WebDriverErrorResponse::unknown_error(&e.to_string()));
}
let timeout = std::time::Duration::from_millis(self.timeouts.script_ms);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(Ok(base64))) => {
if base64.is_empty() {
Err(WebDriverErrorResponse::unknown_error(
"Screenshot returned empty data",
))
} else {
Ok(base64)
}
}
Ok(Ok(Err(error))) => Err(WebDriverErrorResponse::unknown_error(&error)),
Ok(Err(_)) => Err(WebDriverErrorResponse::unknown_error("Channel closed")),
Err(_) => Err(WebDriverErrorResponse::script_timeout()),
}
}
async fn take_element_screenshot(
&self,
js_var: &str,
) -> Result<String, WebDriverErrorResponse> {
let script = format!(
r"(function() {{
var el = window.{js_var};
if (!el || !el.isConnected) {{
throw new Error('stale element reference');
}}
el.scrollIntoView({{ block: 'center', inline: 'center' }});
return true;
}})()"
);
self.evaluate_js(&script).await?;
self.take_screenshot().await
}
#[allow(clippy::too_many_lines)]
async fn print_page(&self, options: PrintOptions) -> Result<String, WebDriverErrorResponse> {
let (tx, rx) = oneshot::channel();
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
let temp_dir = tempfile::TempDir::new().map_err(|e| {
WebDriverErrorResponse::unknown_error(&format!("Failed to create temp dir: {e}"))
})?;
let pdf_path = temp_dir.path().join("print.pdf");
let pdf_path_clone = pdf_path.clone();
let orientation = options.orientation.clone();
let scale = options.scale;
let background = options.background;
let page_width = options.page_width;
let page_height = options.page_height;
let margin_top = options.margin_top;
let margin_bottom = options.margin_bottom;
let margin_left = options.margin_left;
let margin_right = options.margin_right;
let result = self.window.with_webview(move |webview| unsafe {
let _ = CoInitializeEx(None, COINIT_APARTMENTTHREADED);
let webview2 = match webview.controller().CoreWebView2() {
Ok(wv) => wv,
Err(e) => {
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err(format!("Failed to get CoreWebView2: {e:?}")));
}
}
return;
}
};
let webview7: ICoreWebView2_7 = match webview2.cast() {
Ok(wv) => wv,
Err(e) => {
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ =
tx.send(Err(format!("Failed to cast to ICoreWebView2_7: {e:?}")));
}
}
return;
}
};
let environment = match webview7.Environment() {
Ok(env) => env,
Err(e) => {
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err(format!("Failed to get environment: {e:?}")));
}
}
return;
}
};
let env6: ICoreWebView2Environment6 = match environment.cast() {
Ok(env) => env,
Err(e) => {
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err(format!(
"Failed to cast to ICoreWebView2Environment6: {e:?}"
)));
}
}
return;
}
};
let settings = match env6.CreatePrintSettings() {
Ok(s) => s,
Err(e) => {
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err(format!("Failed to create print settings: {e:?}")));
}
}
return;
}
};
if let Some(ref orient) = orientation {
let orientation_val = if orient == "landscape" {
COREWEBVIEW2_PRINT_ORIENTATION_LANDSCAPE
} else {
COREWEBVIEW2_PRINT_ORIENTATION_PORTRAIT
};
let _ = settings.SetOrientation(orientation_val);
}
if let Some(s) = scale {
let _ = settings.SetScaleFactor(s);
}
if let Some(bg) = background {
let _ = settings.SetShouldPrintBackgrounds(bg);
}
if let Some(w) = page_width {
let _ = settings.SetPageWidth(w / 2.54);
}
if let Some(h) = page_height {
let _ = settings.SetPageHeight(h / 2.54);
}
if let Some(m) = margin_top {
let _ = settings.SetMarginTop(m / 2.54);
}
if let Some(m) = margin_bottom {
let _ = settings.SetMarginBottom(m / 2.54);
}
if let Some(m) = margin_left {
let _ = settings.SetMarginLeft(m / 2.54);
}
if let Some(m) = margin_right {
let _ = settings.SetMarginRight(m / 2.54);
}
let handler: ICoreWebView2PrintToPdfCompletedHandler =
handlers::PrintToPdfHandler::new(tx).into();
let path_str = pdf_path_clone.to_string_lossy().to_string();
let path_hstring = HSTRING::from(&path_str);
if let Err(e) = webview7.PrintToPdf(&path_hstring, &settings, &handler) {
tracing::error!("PrintToPdf call failed: {e:?}");
}
});
if let Err(e) = result {
return Err(WebDriverErrorResponse::unknown_error(&e.to_string()));
}
let timeout = std::time::Duration::from_millis(self.timeouts.script_ms);
let print_result = match tokio::time::timeout(timeout, rx).await {
Ok(Ok(Ok(()))) => Ok(()),
Ok(Ok(Err(error))) => Err(WebDriverErrorResponse::unknown_error(&error)),
Ok(Err(_)) => Err(WebDriverErrorResponse::unknown_error("Channel closed")),
Err(_) => Err(WebDriverErrorResponse::script_timeout()),
};
print_result?;
let pdf_data = std::fs::read(&pdf_path).map_err(|e| {
WebDriverErrorResponse::unknown_error(&format!("Failed to read PDF file: {e}"))
})?;
Ok(BASE64_STANDARD.encode(&pdf_data))
}
async fn execute_async_script(
&self,
script: &str,
args: &[Value],
) -> Result<Value, WebDriverErrorResponse> {
let args_json = serde_json::to_string(args)
.map_err(|e| WebDriverErrorResponse::invalid_argument(&e.to_string()))?;
let async_id = uuid::Uuid::new_v4().to_string();
let app = self.window.app_handle().clone();
let async_state = app.state::<AsyncScriptState>();
let label = self.window.label().to_string();
if !async_state.mark_handler_registered(&label) {
let app_clone = app.clone();
let handler_result = self.window.with_webview(move |webview| unsafe {
let _ = CoInitializeEx(None, COINIT_APARTMENTTHREADED);
if let Ok(webview2) = webview.controller().CoreWebView2() {
let state = app_clone.state::<AsyncScriptState>();
register_message_handler(&webview2, state.inner());
}
});
if let Err(e) = handler_result {
return Err(WebDriverErrorResponse::unknown_error(&format!(
"Failed to register message handler: {e}"
)));
}
}
let rx = async_state.register(async_id.clone());
let wrapper = format!(
r"(function() {{
var ELEMENT_KEY = 'element-6066-11e4-a52e-4f735466cecf';
function deserializeArg(arg) {{
if (arg === null || arg === undefined) return arg;
if (Array.isArray(arg)) return arg.map(deserializeArg);
if (typeof arg === 'object') {{
if (arg[ELEMENT_KEY]) {{
var el = window['__wd_el_' + arg[ELEMENT_KEY].replace(/-/g, '')];
if (!el) throw new Error('stale element reference');
return el;
}}
var result = {{}};
for (var key in arg) {{
if (arg.hasOwnProperty(key)) result[key] = deserializeArg(arg[key]);
}}
return result;
}}
return arg;
}}
var __done = function(r) {{
window.chrome.webview.postMessage(JSON.stringify({{
handler: '{HANDLER_NAME}',
id: '{async_id}',
result: r,
error: null
}}));
}};
var __args = {args_json}.map(deserializeArg);
__args.push(__done);
try {{
(function() {{ {script} }}).apply(null, __args);
}} catch (e) {{
window.chrome.webview.postMessage(JSON.stringify({{
handler: '{HANDLER_NAME}',
id: '{async_id}',
result: null,
error: e.message || String(e)
}}));
}}
}})()"
);
let locks = self.window.state::<ScriptExecutionLocks>();
let lock = locks.get(self.window.label());
let _guard = lock.lock().await;
self.evaluate_js_inner(&wrapper).await?;
let timeout_ms = self.timeouts.script_ms;
let timeout = std::time::Duration::from_millis(timeout_ms);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(Ok(value))) => Ok(value),
Ok(Ok(Err(error))) => Err(WebDriverErrorResponse::javascript_error(&error, None)),
Ok(Err(_)) => Err(WebDriverErrorResponse::unknown_error("Channel closed")),
Err(_) => {
async_state.cancel(&async_id);
Err(WebDriverErrorResponse::script_timeout())
}
}
}
}
/// A one-shot result slot shared between an async caller and a WebView2
/// completion handler: whoever `take()`s the sender first delivers the outcome
/// (`Ok(T)` or an error message), and later attempts find `None`.
type OneShotSlot<T> = Arc<std::sync::Mutex<Option<oneshot::Sender<Result<T, String>>>>>;
/// Slot for `ExecuteScript` results (the JSON-decoded script value).
type ScriptResultSender = OneShotSlot<Value>;
/// Slot for `CapturePreview` results (base64 of the captured PNG stream).
type CaptureResultSender = OneShotSlot<String>;
/// Slot for `PrintToPdf` completion (no payload on success).
type PrintResultSender = OneShotSlot<()>;
mod handlers {
#![allow(clippy::inline_always, clippy::ref_as_ptr)]
use serde_json::Value;
use webview2_com::Microsoft::Web::WebView2::Win32::{
ICoreWebView2, ICoreWebView2CapturePreviewCompletedHandler,
ICoreWebView2CapturePreviewCompletedHandler_Impl, ICoreWebView2Deferral,
ICoreWebView2ExecuteScriptCompletedHandler,
ICoreWebView2ExecuteScriptCompletedHandler_Impl, ICoreWebView2PrintToPdfCompletedHandler,
ICoreWebView2PrintToPdfCompletedHandler_Impl, ICoreWebView2ScriptDialogOpeningEventArgs,
ICoreWebView2ScriptDialogOpeningEventHandler,
ICoreWebView2ScriptDialogOpeningEventHandler_Impl,
ICoreWebView2WebMessageReceivedEventArgs, ICoreWebView2WebMessageReceivedEventHandler,
ICoreWebView2WebMessageReceivedEventHandler_Impl, COREWEBVIEW2_SCRIPT_DIALOG_KIND_ALERT,
COREWEBVIEW2_SCRIPT_DIALOG_KIND_CONFIRM, COREWEBVIEW2_SCRIPT_DIALOG_KIND_PROMPT,
};
use windows::core::{implement, Interface};
use super::{
AlertState, AlertType, AsyncScriptState, CaptureResultSender, PendingAlert,
PrintResultSender, ScriptResultSender, SendableComPtr, HANDLER_NAME,
};
use crate::platform::alert_state::AlertResponse;
use std::sync::Arc;
#[implement(ICoreWebView2ExecuteScriptCompletedHandler)]
pub struct ExecuteScriptHandler {
pub tx: ScriptResultSender,
}
impl ExecuteScriptHandler {
pub fn new(tx: ScriptResultSender) -> Self {
Self { tx }
}
}
impl ICoreWebView2ExecuteScriptCompletedHandler_Impl for ExecuteScriptHandler_Impl {
fn Invoke(
&self,
errorcode: windows::core::HRESULT,
resultobjectasjson: &windows::core::PCWSTR,
) -> windows::core::Result<()> {
let response = if errorcode.is_err() {
Err(format!("Script execution failed: {errorcode:?}"))
} else {
let json_str = unsafe { resultobjectasjson.to_string().unwrap_or_default() };
match serde_json::from_str(&json_str) {
Ok(value) => Ok(value),
Err(_) => Ok(Value::String(json_str)),
}
};
if let Ok(mut guard) = self.tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(response);
}
}
Ok(())
}
}
#[implement(ICoreWebView2CapturePreviewCompletedHandler)]
pub struct CapturePreviewHandler {
pub tx: CaptureResultSender,
pub stream: windows::Win32::System::Com::IStream,
}
impl CapturePreviewHandler {
pub fn new(tx: CaptureResultSender, stream: windows::Win32::System::Com::IStream) -> Self {
Self { tx, stream }
}
}
impl ICoreWebView2CapturePreviewCompletedHandler_Impl for CapturePreviewHandler_Impl {
fn Invoke(&self, errorcode: windows::core::HRESULT) -> windows::core::Result<()> {
let response = if errorcode.is_err() {
Err(format!("Capture preview failed: {errorcode:?}"))
} else {
unsafe {
use super::{STATFLAG_NONAME, STREAM_SEEK_SET};
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine as _;
let mut stat = std::mem::zeroed();
if self.stream.Stat(&raw mut stat, STATFLAG_NONAME).is_err() {
return Ok(());
}
let size = usize::try_from(stat.cbSize).unwrap_or(0);
if size == 0 {
if let Ok(mut guard) = self.tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err("Empty stream".to_string()));
}
}
return Ok(());
}
let _ = self.stream.Seek(0, STREAM_SEEK_SET, None);
let mut buffer = vec![0u8; size];
let mut bytes_read = 0u32;
if self
.stream
.Read(
buffer.as_mut_ptr().cast(),
u32::try_from(size).unwrap_or(u32::MAX),
Some(&raw mut bytes_read),
)
.is_err()
{
if let Ok(mut guard) = self.tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Err("Failed to read stream".to_string()));
}
}
return Ok(());
}
buffer.truncate(bytes_read as usize);
let base64 = BASE64_STANDARD.encode(&buffer);
if let Ok(mut guard) = self.tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(Ok(base64));
}
}
return Ok(());
}
};
if let Ok(mut guard) = self.tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(response);
}
}
Ok(())
}
}
#[implement(ICoreWebView2PrintToPdfCompletedHandler)]
pub struct PrintToPdfHandler {
pub tx: PrintResultSender,
}
impl PrintToPdfHandler {
pub fn new(tx: PrintResultSender) -> Self {
Self { tx }
}
}
impl ICoreWebView2PrintToPdfCompletedHandler_Impl for PrintToPdfHandler_Impl {
fn Invoke(
&self,
errorcode: windows::core::HRESULT,
issuccessful: super::BOOL,
) -> windows::core::Result<()> {
let response = if errorcode.is_err() {
Err(format!("PrintToPdf failed: {errorcode:?}"))
} else if !issuccessful.as_bool() {
Err("PrintToPdf was not successful".to_string())
} else {
Ok(())
};
if let Ok(mut guard) = self.tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(response);
}
}
Ok(())
}
}
#[implement(ICoreWebView2WebMessageReceivedEventHandler)]
pub struct WebMessageReceivedHandler {
state_ptr: *const AsyncScriptState,
}
unsafe impl Send for WebMessageReceivedHandler {}
unsafe impl Sync for WebMessageReceivedHandler {}
impl WebMessageReceivedHandler {
pub fn new(state: &AsyncScriptState) -> Self {
Self {
state_ptr: state as *const AsyncScriptState,
}
}
}
impl ICoreWebView2WebMessageReceivedEventHandler_Impl for WebMessageReceivedHandler_Impl {
fn Invoke(
&self,
_sender: windows::core::Ref<'_, ICoreWebView2>,
args: windows::core::Ref<'_, ICoreWebView2WebMessageReceivedEventArgs>,
) -> windows::core::Result<()> {
unsafe {
let state_ptr = self.state_ptr;
if state_ptr.is_null() {
tracing::error!("AsyncScriptState pointer is null");
return Ok(());
}
let state = &*state_ptr;
let Some(args_owned) = args.clone() else {
return Ok(());
};
let mut msg_ptr = windows::core::PWSTR::null();
if args_owned.WebMessageAsJson(&raw mut msg_ptr).is_err() {
return Ok(()); }
let msg_text = msg_ptr.to_string().unwrap_or_default();
let inner_str: String = match serde_json::from_str(&msg_text) {
Ok(s) => s,
Err(_) => return Ok(()), };
let msg: Value = match serde_json::from_str(&inner_str) {
Ok(v) => v,
Err(_) => return Ok(()), };
let handler = msg.get("handler").and_then(Value::as_str);
if handler != Some(HANDLER_NAME) {
return Ok(()); }
let Some(async_id) = msg.get("id").and_then(Value::as_str) else {
tracing::warn!("Message missing 'id' field");
return Ok(());
};
let async_id = async_id.to_string();
if let Some(error) = msg.get("error").and_then(Value::as_str) {
if !error.is_empty() {
state.complete(&async_id, Err(error.to_string()));
return Ok(());
}
}
let result = msg.get("result").cloned().unwrap_or(Value::Null);
state.complete(&async_id, Ok(result));
}
Ok(())
}
}
#[implement(ICoreWebView2ScriptDialogOpeningEventHandler)]
pub struct ScriptDialogOpeningHandler {
alert_state: Arc<AlertState>,
}
unsafe impl Send for ScriptDialogOpeningHandler {}
unsafe impl Sync for ScriptDialogOpeningHandler {}
impl ScriptDialogOpeningHandler {
pub fn new(alert_state: Arc<AlertState>) -> Self {
Self { alert_state }
}
}
impl ICoreWebView2ScriptDialogOpeningEventHandler_Impl for ScriptDialogOpeningHandler_Impl {
fn Invoke(
&self,
_sender: windows::core::Ref<'_, ICoreWebView2>,
args: windows::core::Ref<'_, ICoreWebView2ScriptDialogOpeningEventArgs>,
) -> windows::core::Result<()> {
let (args_ptr, deferral_ptr, rx) = unsafe {
let Some(args) = args.clone() else {
return Ok(());
};
let mut kind = std::mem::zeroed();
if args.Kind(&raw mut kind).is_err() {
tracing::error!("Failed to get script dialog kind");
return Ok(());
}
let mut message_ptr = windows::core::PWSTR::null();
if args.Message(&raw mut message_ptr).is_err() {
tracing::error!("Failed to get script dialog message");
return Ok(());
}
let message = message_ptr.to_string().unwrap_or_default();
let mut default_text_ptr = windows::core::PWSTR::null();
let default_text = if args.DefaultText(&raw mut default_text_ptr).is_ok() {
let text = default_text_ptr.to_string().unwrap_or_default();
if text.is_empty() {
None
} else {
Some(text)
}
} else {
None
};
let alert_type = if kind == COREWEBVIEW2_SCRIPT_DIALOG_KIND_ALERT {
AlertType::Alert
} else if kind == COREWEBVIEW2_SCRIPT_DIALOG_KIND_CONFIRM {
AlertType::Confirm
} else if kind == COREWEBVIEW2_SCRIPT_DIALOG_KIND_PROMPT {
AlertType::Prompt
} else {
let _ = args.Accept();
return Ok(());
};
tracing::debug!("Intercepted {:?} dialog: {}", alert_type, message);
let deferral = match args.GetDeferral() {
Ok(d) => d,
Err(e) => {
tracing::error!("Failed to get deferral: {e:?}");
let _ = args.Accept();
return Ok(());
}
};
let (tx, rx) = std::sync::mpsc::channel::<AlertResponse>();
self.alert_state.set_pending(PendingAlert {
message: message.clone(),
default_text: default_text.clone(),
alert_type,
responder: tx,
});
let args_ptr = SendableComPtr(args.into_raw());
let deferral_ptr = SendableComPtr(deferral.into_raw());
(args_ptr, deferral_ptr, rx)
};
std::thread::spawn(move || {
let timeout = std::time::Duration::from_secs(30);
let response = rx.recv_timeout(timeout);
unsafe {
let args =
ICoreWebView2ScriptDialogOpeningEventArgs::from_raw(args_ptr.as_ptr());
let deferral = ICoreWebView2Deferral::from_raw(deferral_ptr.as_ptr());
match response {
Ok(AlertResponse {
accepted,
prompt_text,
}) => {
if accepted {
if let Some(text) = prompt_text {
let result = windows::core::HSTRING::from(text.as_str());
let _ =
args.SetResultText(windows::core::PCWSTR(result.as_ptr()));
}
let _ = args.Accept();
}
}
Err(_) => {
let _ = args.Accept();
}
}
let _ = deferral.Complete();
}
});
Ok(())
}
}
}
use handlers::{
CapturePreviewHandler, ExecuteScriptHandler, ScriptDialogOpeningHandler,
WebMessageReceivedHandler,
};
unsafe fn register_message_handler(webview: &ICoreWebView2, state: &AsyncScriptState) {
let handler: ICoreWebView2WebMessageReceivedEventHandler =
WebMessageReceivedHandler::new(state).into();
let mut token = std::mem::zeroed();
if let Err(e) = webview.add_WebMessageReceived(&handler, &raw mut token) {
tracing::error!("Failed to register WebMessageReceived handler: {e:?}");
} else {
tracing::debug!("Registered native message handler for webview");
}
}