use std::sync::Arc;
use async_trait::async_trait;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine as _;
use block2::{DynBlock, RcBlock};
use objc2::rc::Retained;
use objc2::runtime::{AnyObject, ProtocolObject};
use objc2::{define_class, msg_send, DefinedClass, MainThreadMarker, MainThreadOnly};
use objc2_app_kit::{NSBitmapImageFileType, NSBitmapImageRep, NSImage};
use objc2_foundation::{NSData, NSDictionary, NSError, NSObject, NSObjectProtocol, NSString};
use objc2_web_kit::{
WKContentWorld, WKFrameInfo, WKPDFConfiguration, WKSnapshotConfiguration, WKUIDelegate,
WKWebView,
};
use serde_json::Value;
use tauri::{Manager, Runtime, WebviewWindow};
use tokio::sync::oneshot;
use crate::platform::alert_state::{AlertState, AlertStateManager, AlertType, PendingAlert};
use crate::platform::{wrap_script_for_frame_context, FrameId, PlatformExecutor, PrintOptions};
use crate::server::response::WebDriverErrorResponse;
use crate::webdriver::Timeouts;
static DELEGATE_KEY: u8 = 0;
#[derive(Clone)]
pub struct MacOSExecutor<R: Runtime> {
window: WebviewWindow<R>,
timeouts: Timeouts,
frame_context: Vec<FrameId>,
}
impl<R: Runtime> MacOSExecutor<R> {
pub fn new(window: WebviewWindow<R>, timeouts: Timeouts, frame_context: Vec<FrameId>) -> Self {
Self {
window,
timeouts,
frame_context,
}
}
}
pub fn register_webview_handlers<R: Runtime>(webview: &tauri::Webview<R>) {
use objc2::ffi::{objc_setAssociatedObject, OBJC_ASSOCIATION_RETAIN_NONATOMIC};
let manager = webview.app_handle().state::<AlertStateManager>();
let alert_state = manager.get_or_create(webview.label());
let _ = webview.with_webview(move |webview| unsafe {
let wk_webview: &WKWebView = &*webview.inner().cast();
let delegate = WebDriverUIDelegate::new(alert_state);
let delegate_protocol: Retained<ProtocolObject<dyn WKUIDelegate>> =
ProtocolObject::from_retained(delegate);
let _: () = msg_send![wk_webview, setUIDelegate: &*delegate_protocol];
objc_setAssociatedObject(
std::ptr::from_ref::<WKWebView>(wk_webview)
.cast_mut()
.cast(),
std::ptr::addr_of!(DELEGATE_KEY).cast(),
Retained::into_raw(delegate_protocol).cast(),
OBJC_ASSOCIATION_RETAIN_NONATOMIC,
);
tracing::debug!("Registered UI delegate for webview");
});
}
#[async_trait]
impl<R: Runtime + 'static> PlatformExecutor<R> for MacOSExecutor<R> {
fn window(&self) -> &WebviewWindow<R> {
&self.window
}
fn script_timeout_ms(&self) -> u64 {
self.timeouts.script_ms
}
async fn evaluate_js(&self, script: &str) -> Result<Value, WebDriverErrorResponse> {
let (tx, rx) = oneshot::channel();
let script_owned = wrap_script_for_frame_context(script, &self.frame_context);
let result = self.window.with_webview(move |webview| unsafe {
let wk_webview: &WKWebView = &*webview.inner().cast();
let ns_script = NSString::from_str(&script_owned);
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
let block = RcBlock::new(move |result: *mut AnyObject, error: *mut NSError| {
let response = if !error.is_null() {
let error_ref = &*error;
let description = error_ref.localizedDescription();
Err(description.to_string())
} else if result.is_null() {
Ok(Value::Null)
} else {
let obj = &*result;
Ok(ns_object_to_json(obj))
};
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(response);
}
}
});
wk_webview.evaluateJavaScript_completionHandler(&ns_script, Some(&block));
});
if let Err(e) = result {
return Err(WebDriverErrorResponse::javascript_error(
&e.to_string(),
None,
));
}
let timeout = std::time::Duration::from_millis(self.timeouts.script_ms);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(Ok(value))) => Ok(serde_json::json!({
"success": true,
"value": value
})),
Ok(Ok(Err(error))) => Err(WebDriverErrorResponse::javascript_error(&error, None)),
Ok(Err(_)) => Err(WebDriverErrorResponse::unknown_error("Channel closed")),
Err(_) => Err(WebDriverErrorResponse::script_timeout()),
}
}
async fn execute_async_script(
&self,
script: &str,
args: &[Value],
) -> Result<Value, WebDriverErrorResponse> {
let args_json = serde_json::to_string(args)
.map_err(|e| WebDriverErrorResponse::invalid_argument(&e.to_string()))?;
let wrapper = format!(
r"return new Promise((resolve, reject) => {{
var ELEMENT_KEY = 'element-6066-11e4-a52e-4f735466cecf';
function deserializeArg(arg) {{
if (arg === null || arg === undefined) return arg;
if (Array.isArray(arg)) return arg.map(deserializeArg);
if (typeof arg === 'object') {{
if (arg[ELEMENT_KEY]) {{
var el = window['__wd_el_' + arg[ELEMENT_KEY].replace(/-/g, '')];
if (!el) throw new Error('stale element reference');
return el;
}}
var result = {{}};
for (var key in arg) {{
if (arg.hasOwnProperty(key)) result[key] = deserializeArg(arg[key]);
}}
return result;
}}
return arg;
}}
var __done = function(result, error) {{
if (error) {{
reject(new Error(typeof error === 'string' ? error : String(error)));
}} else {{
resolve(result);
}}
}};
var __args = {args_json}.map(deserializeArg);
__args.push(__done);
try {{
(function() {{ {script} }}).apply(null, __args);
}} catch (e) {{
reject(e);
}}
}})"
);
let (tx, rx) = oneshot::channel();
let result = self.window.with_webview(move |webview| unsafe {
let wk_webview: &WKWebView = &*webview.inner().cast();
let ns_script = NSString::from_str(&wrapper);
let mtm = MainThreadMarker::new_unchecked();
let empty_dict: Retained<NSDictionary<NSString, AnyObject>> = NSDictionary::new();
let content_world = WKContentWorld::pageWorld(mtm);
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
let block = RcBlock::new(move |result: *mut AnyObject, error: *mut NSError| {
let response = if !error.is_null() {
let error_ref = &*error;
let description = error_ref.localizedDescription();
Err(description.to_string())
} else if result.is_null() {
Ok(Value::Null)
} else {
let obj = &*result;
Ok(ns_object_to_json(obj))
};
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(response);
}
}
});
wk_webview.callAsyncJavaScript_arguments_inFrame_inContentWorld_completionHandler(
&ns_script,
Some(&empty_dict),
None,
&content_world,
Some(&block),
);
});
if let Err(e) = result {
return Err(WebDriverErrorResponse::javascript_error(
&e.to_string(),
None,
));
}
let timeout = std::time::Duration::from_millis(self.timeouts.script_ms);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(Ok(value))) => Ok(value),
Ok(Ok(Err(error))) => Err(WebDriverErrorResponse::javascript_error(&error, None)),
Ok(Err(_)) => Err(WebDriverErrorResponse::unknown_error("Channel closed")),
Err(_) => Err(WebDriverErrorResponse::script_timeout()),
}
}
async fn take_screenshot(&self) -> Result<String, WebDriverErrorResponse> {
let (tx, rx) = oneshot::channel();
let result = self.window.with_webview(move |webview| unsafe {
let wk_webview: &WKWebView = &*webview.inner().cast();
let mtm = MainThreadMarker::new_unchecked();
let config = WKSnapshotConfiguration::new(mtm);
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
let block = RcBlock::new(move |image: *mut NSImage, error: *mut NSError| {
let response = if !error.is_null() {
let error_ref = &*error;
let description = error_ref.localizedDescription();
Err(description.to_string())
} else if image.is_null() {
Err("No image returned".to_string())
} else {
let image_ref = &*image;
image_to_png_base64(image_ref)
};
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(response);
}
}
});
wk_webview.takeSnapshotWithConfiguration_completionHandler(Some(&config), &block);
});
if let Err(e) = result {
return Err(WebDriverErrorResponse::unknown_error(&e.to_string()));
}
let timeout = std::time::Duration::from_millis(self.timeouts.script_ms);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(Ok(base64))) => Ok(base64),
Ok(Ok(Err(error))) => Err(WebDriverErrorResponse::unknown_error(&error)),
Ok(Err(_)) => Err(WebDriverErrorResponse::unknown_error("Channel closed")),
Err(_) => Err(WebDriverErrorResponse::script_timeout()),
}
}
async fn take_element_screenshot(
&self,
js_var: &str,
) -> Result<String, WebDriverErrorResponse> {
let script = format!(
r"(function() {{
var el = window.{js_var};
if (!el || !el.isConnected) {{
throw new Error('stale element reference');
}}
// Use html2canvas-like approach if available, otherwise scroll into view
el.scrollIntoView({{ block: 'center', inline: 'center' }});
// Return element bounds for clipping
var rect = el.getBoundingClientRect();
return {{
x: rect.x,
y: rect.y,
width: rect.width,
height: rect.height
}};
}})()"
);
self.evaluate_js(&script).await?;
self.take_screenshot().await
}
async fn print_page(&self, options: PrintOptions) -> Result<String, WebDriverErrorResponse> {
let page_width = options.page_width.unwrap_or(21.0);
let page_height = options.page_height.unwrap_or(29.7);
let margin_top = options.margin_top.unwrap_or(1.0);
let margin_bottom = options.margin_bottom.unwrap_or(1.0);
let margin_left = options.margin_left.unwrap_or(1.0);
let margin_right = options.margin_right.unwrap_or(1.0);
let orientation = options.orientation.as_deref().unwrap_or("portrait");
let css_script = format!(
r"(function() {{
var style = document.createElement('style');
style.id = '__webdriver_print_style';
style.textContent = `
@page {{
size: {page_width}cm {page_height}cm {orientation};
margin: {margin_top}cm {margin_right}cm {margin_bottom}cm {margin_left}cm;
}}
@media print {{
body {{
-webkit-print-color-adjust: exact;
print-color-adjust: exact;
}}
}}
`;
document.head.appendChild(style);
return true;
}})()"
);
self.evaluate_js(&css_script).await?;
let (tx, rx) = oneshot::channel();
let result = self.window.with_webview(move |webview| unsafe {
let wk_webview: &WKWebView = &*webview.inner().cast();
let mtm = MainThreadMarker::new_unchecked();
let config = WKPDFConfiguration::new(mtm);
let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
let block = RcBlock::new(move |data: *mut NSData, error: *mut NSError| {
let response = if !error.is_null() {
let error_ref = &*error;
let description = error_ref.localizedDescription();
Err(description.to_string())
} else if data.is_null() {
Err("No PDF data returned".to_string())
} else {
let data_ref = &*data;
let bytes = data_ref.to_vec();
Ok(bytes)
};
if let Ok(mut guard) = tx.lock() {
if let Some(tx) = guard.take() {
let _ = tx.send(response);
}
}
});
wk_webview.createPDFWithConfiguration_completionHandler(Some(&config), &block);
});
if let Err(e) = result {
return Err(WebDriverErrorResponse::unknown_error(&e.to_string()));
}
let timeout = std::time::Duration::from_millis(self.timeouts.script_ms);
let pdf_result = match tokio::time::timeout(timeout, rx).await {
Ok(Ok(Ok(bytes))) => Ok(bytes),
Ok(Ok(Err(error))) => Err(WebDriverErrorResponse::unknown_error(&error)),
Ok(Err(_)) => Err(WebDriverErrorResponse::unknown_error("Channel closed")),
Err(_) => Err(WebDriverErrorResponse::script_timeout()),
};
let _ = self
.evaluate_js(
r"(function() {
var style = document.getElementById('__webdriver_print_style');
if (style) style.remove();
return true;
})()",
)
.await;
pdf_result.map(|bytes| BASE64_STANDARD.encode(&bytes))
}
}
unsafe fn image_to_png_base64(image: &NSImage) -> Result<String, String> {
let tiff_data: Option<objc2::rc::Retained<NSData>> = image.TIFFRepresentation();
let tiff_data = tiff_data.ok_or("Failed to get TIFF representation")?;
let bitmap_rep = NSBitmapImageRep::imageRepWithData(&tiff_data)
.ok_or("Failed to create bitmap image rep")?;
let empty_dict: objc2::rc::Retained<NSDictionary<NSString>> = NSDictionary::new();
let png_data: Option<objc2::rc::Retained<NSData>> =
bitmap_rep.representationUsingType_properties(NSBitmapImageFileType::PNG, &empty_dict);
let png_data = png_data.ok_or("Failed to convert to PNG")?;
let bytes = png_data.to_vec();
Ok(BASE64_STANDARD.encode(&bytes))
}
pub(super) unsafe fn ns_object_to_json(obj: &AnyObject) -> Value {
use objc2_foundation::NSString as NSStr;
let class = obj.class();
let class_name = class.name().to_str().unwrap_or("");
if class_name.contains("String") {
let ns_str: &NSStr = &*std::ptr::from_ref::<AnyObject>(obj).cast::<NSStr>();
return Value::String(ns_str.to_string());
}
if class_name.contains("Number") || class_name.contains("Boolean") {
use objc2::msg_send;
use objc2::runtime::Bool;
if class_name.contains("Boolean") {
let bool_val: Bool = msg_send![obj, boolValue];
return Value::Bool(bool_val.as_bool());
}
let double_val: f64 = msg_send![obj, doubleValue];
let int_val: i64 = msg_send![obj, longLongValue];
#[allow(clippy::cast_precision_loss)]
if (int_val as f64 - double_val).abs() < f64::EPSILON {
return Value::Number(serde_json::Number::from(int_val));
} else if let Some(n) = serde_json::Number::from_f64(double_val) {
return Value::Number(n);
}
return Value::Null;
}
if class_name.contains("Array") {
use objc2::msg_send;
let count: usize = msg_send![obj, count];
let mut arr = Vec::new();
for i in 0..count {
let item: *mut AnyObject = msg_send![obj, objectAtIndex: i];
if !item.is_null() {
arr.push(ns_object_to_json(&*item));
}
}
return Value::Array(arr);
}
if class_name.contains("Dictionary") {
use objc2::msg_send;
let keys: *mut AnyObject = msg_send![obj, allKeys];
if keys.is_null() {
return Value::Object(serde_json::Map::new());
}
let count: usize = msg_send![keys, count];
let mut map = serde_json::Map::new();
for i in 0..count {
let key: *mut AnyObject = msg_send![keys, objectAtIndex: i];
if key.is_null() {
continue;
}
let key_class = (&*key).class().name().to_str().unwrap_or("");
if !key_class.contains("String") {
continue;
}
let ns_key: &NSStr = &*key.cast_const().cast::<NSStr>();
let key_str = ns_key.to_string();
let val: *mut AnyObject = msg_send![obj, objectForKey: key];
if !val.is_null() {
map.insert(key_str, ns_object_to_json(&*val));
}
}
return Value::Object(map);
}
if class_name.contains("Null") {
return Value::Null;
}
Value::Null
}
struct WebDriverUIDelegateIvars {
alert_state: Arc<AlertState>,
}
unsafe impl Send for WebDriverUIDelegateIvars {}
unsafe impl Sync for WebDriverUIDelegateIvars {}
define_class!(
#[unsafe(super(NSObject))]
#[thread_kind = MainThreadOnly]
#[name = "WebDriverUIDelegate"]
#[ivars = WebDriverUIDelegateIvars]
struct WebDriverUIDelegate;
unsafe impl NSObjectProtocol for WebDriverUIDelegate {}
#[allow(non_snake_case)]
unsafe impl WKUIDelegate for WebDriverUIDelegate {
#[unsafe(method(webView:runJavaScriptAlertPanelWithMessage:initiatedByFrame:completionHandler:))]
fn webView_runJavaScriptAlertPanelWithMessage_initiatedByFrame_completionHandler(
&self,
_webview: &WKWebView,
message: &NSString,
_frame: &WKFrameInfo,
completion_handler: &DynBlock<dyn Fn()>,
) {
let message_str = message.to_string();
tracing::debug!("Intercepted alert: {message_str}");
let (tx, rx) = std::sync::mpsc::channel();
self.ivars().alert_state.set_pending(PendingAlert {
message: message_str,
default_text: None,
alert_type: AlertType::Alert,
responder: tx,
});
let timeout = std::time::Duration::from_secs(30);
let _ = rx.recv_timeout(timeout);
completion_handler.call(());
}
#[unsafe(method(webView:runJavaScriptConfirmPanelWithMessage:initiatedByFrame:completionHandler:))]
fn webView_runJavaScriptConfirmPanelWithMessage_initiatedByFrame_completionHandler(
&self,
_webview: &WKWebView,
message: &NSString,
_frame: &WKFrameInfo,
completion_handler: &DynBlock<dyn Fn(objc2::runtime::Bool)>,
) {
let message_str = message.to_string();
tracing::debug!("Intercepted confirm: {message_str}");
let (tx, rx) = std::sync::mpsc::channel();
self.ivars().alert_state.set_pending(PendingAlert {
message: message_str,
default_text: None,
alert_type: AlertType::Confirm,
responder: tx,
});
let timeout = std::time::Duration::from_secs(30);
let response = rx.recv_timeout(timeout);
let accepted = response.map(|r| r.accepted).unwrap_or(true);
completion_handler.call((objc2::runtime::Bool::from(accepted),));
}
#[unsafe(method(webView:runJavaScriptTextInputPanelWithPrompt:defaultText:initiatedByFrame:completionHandler:))]
fn webView_runJavaScriptTextInputPanelWithPrompt_defaultText_initiatedByFrame_completionHandler(
&self,
_webview: &WKWebView,
prompt: &NSString,
default_text: Option<&NSString>,
_frame: &WKFrameInfo,
completion_handler: &DynBlock<dyn Fn(*mut NSString)>,
) {
let prompt_str = prompt.to_string();
let default = default_text.map(std::string::ToString::to_string);
tracing::debug!("Intercepted prompt: {prompt_str}");
let (tx, rx) = std::sync::mpsc::channel();
self.ivars().alert_state.set_pending(PendingAlert {
message: prompt_str,
default_text: default.clone(),
alert_type: AlertType::Prompt,
responder: tx,
});
let timeout = std::time::Duration::from_secs(30);
let response = rx.recv_timeout(timeout);
let result: *mut NSString = match response {
Ok(r) if r.accepted => {
let text = r.prompt_text.or(default).unwrap_or_default();
let ns_str = NSString::from_str(&text);
Retained::into_raw(ns_str)
}
_ => std::ptr::null_mut(), };
completion_handler.call((result,));
}
}
);
impl WebDriverUIDelegate {
unsafe fn new(alert_state: Arc<AlertState>) -> Retained<Self> {
let mtm = MainThreadMarker::new_unchecked();
let this = Self::alloc(mtm);
let this = this.set_ivars(WebDriverUIDelegateIvars { alert_state });
msg_send![super(this), init]
}
}