use std::{
borrow::Cow,
ffi::{c_char, c_void, CStr},
panic::AssertUnwindSafe,
ptr::NonNull,
};
use http::{
header::{CONTENT_LENGTH, CONTENT_TYPE},
Request, Response as HttpResponse, StatusCode, Version,
};
use objc2::{
rc::Retained,
runtime::{AnyClass, AnyObject, ClassBuilder, ProtocolObject},
AllocAnyThread, ClassType, Message,
};
use objc2_foundation::{
NSData, NSHTTPURLResponse, NSMutableDictionary, NSObject, NSObjectProtocol, NSString, NSURL,
NSUUID,
};
use objc2_web_kit::{WKURLSchemeHandler, WKURLSchemeTask};
use crate::{wkwebview::WEBVIEW_STATE, RequestAsyncResponder, WryWebView};
/// Returns the Objective-C class used as the URL scheme handler for `name`.
///
/// The class is named `{name}URLSchemeHandler_{addr}`, where `addr` is the
/// address of `WEBVIEW_STATE` rendered in hexadecimal. If a builder for that
/// name can be obtained, a new `NSObject` subclass is declared with:
///
/// * the ivars `webview_id` (`*mut c_char`) and `protocol_index` (`usize`),
/// * the methods `webView:startURLSchemeTask:` and `webView:stopURLSchemeTask:`,
///
/// and then registered. Otherwise the class is looked up by name and returned.
///
/// # Panics
///
/// Panics if `name` contains an interior NUL byte, or if the class can neither
/// be built nor found by name.
pub fn create(name: &str) -> &AnyClass {
  unsafe {
    // Build a NUL-terminated class name that embeds the address of the global state.
    let state_addr = std::ptr::addr_of!(WEBVIEW_STATE) as usize;
    let class_name = format!("{name}URLSchemeHandler_{state_addr:x}\0");
    let class_name = CStr::from_bytes_with_nul(class_name.as_bytes()).unwrap();

    // No builder available: fall back to the class already known under this name.
    let Some(mut builder) = ClassBuilder::new(class_name, NSObject::class()) else {
      return AnyClass::get(class_name).expect("Failed to get the class definition");
    };

    // Per-instance data read back by `start_task`.
    builder.add_ivar::<*mut c_char>(c"webview_id");
    builder.add_ivar::<usize>(c"protocol_index");

    // WKURLSchemeHandler entry points.
    builder.add_method(
      objc2::sel!(webView:startURLSchemeTask:),
      start_task as extern "C" fn(_, _, _, _),
    );
    builder.add_method(
      objc2::sel!(webView:stopURLSchemeTask:),
      stop_task as extern "C" fn(_, _, _, _),
    );

    builder.register()
  }
}
/// Objective-C implementation of `webView:startURLSchemeTask:`.
///
/// What it does, in order:
///
/// 1. Registers the task with `webview` under the task's hash and keeps the
///    returned UUID so later stages can tell whether the task is still current.
/// 2. Reads the `webview_id` and `protocol_index` ivars from `this` and uses
///    them to look up the protocol handler in `WEBVIEW_STATE`.
/// 3. Converts the task's request (URL, method, body or body stream, headers)
///    into an `http::Request<Vec<u8>>`.
/// 4. Calls the handler with a `RequestAsyncResponder`. When the responder is
///    invoked it re-validates the web view and the task before every call on
///    the task, converts the `http::Response` into an `NSHTTPURLResponse` plus
///    `NSData`, and delivers it. Objective-C exceptions raised by the task
///    calls are caught and reported as `CustomProtocolTaskInvalid`.
///
/// If the request cannot be built, the task is answered with a 404. If no
/// handler is found, only a warning is logged (with the `tracing` feature).
///
/// # Panics
///
/// Panics if the ivars are missing, if the `WEBVIEW_STATE` lock is poisoned,
/// or if the request has no URL or HTTP method.
extern "C" fn start_task(
  this: &AnyObject,
  _sel: objc2::runtime::Sel,
  webview: &WryWebView,
  task: &ProtocolObject<dyn WKURLSchemeTask>,
) {
  unsafe {
    #[cfg(feature = "tracing")]
    let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
      .entered();

    // Track the task so that responders can detect a stopped or replaced task.
    let task_key = task.hash();
    let task_uuid = webview.add_custom_task_key(task_key);

    // NOTE(review): assumes the `webview_id` ivar holds a valid, NUL-terminated
    // string that outlives every responder created below — confirm at the site
    // that instantiates this class. Non-UTF-8 ids degrade to the empty string.
    let ivar = this.class().instance_variable(c"webview_id").unwrap();
    let webview_id_ptr: *mut c_char = *ivar.load(this);
    let webview_id = CStr::from_ptr(webview_id_ptr)
      .to_str()
      .ok()
      .unwrap_or_default();

    let ivar = this.class().instance_variable(c"protocol_index").unwrap();
    let protocol_index: usize = *ivar.load(this);

    // Clone the handler out of the state so the read lock is released before it runs.
    let function = WEBVIEW_STATE
      .read()
      .unwrap()
      .get(webview_id)
      .and_then(|v| v.protocol_ptrs.get(protocol_index))
      .cloned();

    if let Some(function) = function {
      let request = task.request();
      let url = request.URL().unwrap();
      let uri = url.absoluteString().unwrap().to_string();

      #[cfg(feature = "tracing")]
      span.record("uri", uri.clone());

      let method = request.HTTPMethod().unwrap().to_string();
      let mut http_request = Request::builder().uri(uri).method(method.as_str());

      // Collect the request body, either from the in-memory data or by draining the stream.
      let mut sent_form_body = Vec::new();
      let body = request.HTTPBody();
      let body_stream = request.HTTPBodyStream();
      if let Some(body) = body {
        sent_form_body = body.to_vec();
      } else if let Some(body_stream) = body_stream {
        body_stream.open();

        while body_stream.hasBytesAvailable() {
          sent_form_body.reserve(128);
          // Read directly into the vector's spare capacity.
          let p = sent_form_body.as_mut_ptr().add(sent_form_body.len());
          let read_length = sent_form_body.capacity() - sent_form_body.len();
          let count = body_stream.read_maxLength(NonNull::new(p).unwrap(), read_length);
          // A negative count reports a stream error and zero reports the end of
          // the stream. Neither may be passed to `set_len`: a negative value cast
          // to `usize` would produce a bogus length.
          if count <= 0 {
            break;
          }
          sent_form_body.set_len(sent_form_body.len() + count as usize);
        }

        body_stream.close();
      }

      // Copy every request header into the `http` request builder.
      let all_headers = request.allHTTPHeaderFields();
      if let Some(all_headers) = all_headers {
        for current_header in all_headers.allKeys().iter() {
          let header_value = all_headers.valueForKey(&current_header).unwrap();
          http_request = http_request.header(current_header.to_string(), header_value.to_string());
        }
      }

      // Answers the task with an empty 404 response; used when the request cannot be built.
      let respond_with_404 = || {
        let urlresponse = NSHTTPURLResponse::alloc();
        let response = NSHTTPURLResponse::initWithURL_statusCode_HTTPVersion_headerFields(
          urlresponse,
          &url,
          StatusCode::NOT_FOUND.as_u16().try_into().unwrap(),
          Some(&NSString::from_str(
            format!("{:#?}", Version::HTTP_11).as_str(),
          )),
          None,
        )
        .unwrap();
        task.didReceiveResponse(&response);
        task.didFinish();
      };

      /// Fails when `webview_id` is no longer present in `WEBVIEW_STATE`.
      fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
        if !WEBVIEW_STATE.read().unwrap().contains_key(webview_id) {
          return Err(crate::Error::CustomProtocolTaskInvalid);
        }
        Ok(())
      }

      /// Fails when the UUID currently stored for `task_key` is missing or differs
      /// from `current_uuid`.
      fn check_task_is_valid(
        webview: &WryWebView,
        task_key: usize,
        current_uuid: Retained<NSUUID>,
      ) -> crate::Result<()> {
        let latest_task_uuid = webview.get_custom_task_uuid(task_key);
        let Some(latest_uuid) = latest_task_uuid else {
          return Err(crate::Error::CustomProtocolTaskInvalid);
        };
        if latest_uuid != current_uuid {
          return Err(crate::Error::CustomProtocolTaskInvalid);
        }
        Ok(())
      }

      match http_request.body(sent_form_body) {
        Ok(final_request) => {
          // Retain the Objective-C objects so the responder can own them.
          let webview = webview.retain();
          let task = task.retain();

          let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
            Box::new(move |sent_response| {
              // Bail out early if the web view or the task went away in the meantime.
              let validate = || -> crate::Result<()> {
                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid.clone())?;
                Ok(())
              };

              if let Err(_e) = validate() {
                #[cfg(feature = "tracing")]
                tracing::warn!("Task invalid before sending response: {:?}", _e);
                return;
              }

              /// Delivers `sent_response` to `task`, re-validating the web view and
              /// the task before each call on the task.
              unsafe fn response(
                task: Retained<ProtocolObject<dyn WKURLSchemeTask>>,
                webview: Retained<WryWebView>,
                task_key: usize,
                task_uuid: Retained<NSUUID>,
                webview_id: &str,
                url: Retained<NSURL>,
                sent_response: HttpResponse<Cow<'_, [u8]>>,
              ) -> crate::Result<()> {
                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid.clone())?;

                let content = sent_response.body();
                let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
                let wanted_status_code = sent_response.status().as_u16() as i32;
                let wanted_version = format!("{:#?}", sent_response.version());

                let headers = NSMutableDictionary::new();

                // Skip a Content-Type that is not representable as a string instead
                // of panicking, matching how the generic header loop below behaves.
                if let Some(mime) = wanted_mime.and_then(|mime| mime.to_str().ok()) {
                  headers.insert(
                    &*NSString::from_str(CONTENT_TYPE.as_str()),
                    &*NSString::from_str(mime),
                  );
                }
                headers.insert(
                  &*NSString::from_str(CONTENT_LENGTH.as_str()),
                  &*NSString::from_str(&content.len().to_string()),
                );

                // Copy the remaining headers; values that are not valid strings are skipped.
                for (name, value) in sent_response.headers().iter() {
                  if let Ok(value) = value.to_str() {
                    headers.insert(
                      &*NSString::from_str(name.as_str()),
                      &*NSString::from_str(value),
                    );
                  }
                }

                let urlresponse = NSHTTPURLResponse::alloc();
                let response = NSHTTPURLResponse::initWithURL_statusCode_HTTPVersion_headerFields(
                  urlresponse,
                  &url,
                  wanted_status_code.try_into().unwrap(),
                  Some(&NSString::from_str(&wanted_version)),
                  Some(&headers),
                )
                .unwrap();

                // Each task call is guarded by a validity check and wrapped so that
                // an Objective-C exception becomes an error rather than unwinding.
                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid.clone())?;
                objc2::exception::catch(AssertUnwindSafe(|| {
                  task.didReceiveResponse(&response);
                }))
                .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

                let data = NSData::alloc();
                let data = NSData::initWithBytes_length(
                  data,
                  content.as_ptr() as *mut c_void,
                  content.len(),
                );

                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid.clone())?;
                objc2::exception::catch(AssertUnwindSafe(|| {
                  task.didReceiveData(&data);
                }))
                .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid)?;
                objc2::exception::catch(AssertUnwindSafe(|| {
                  task.didFinish();
                }))
                .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

                // The task is complete: forget its key while the web view still exists.
                if WEBVIEW_STATE.read().unwrap().contains_key(webview_id) {
                  webview.remove_custom_task_key(task_key);
                  Ok(())
                } else {
                  Err(crate::Error::CustomProtocolTaskInvalid)
                }
              }

              #[cfg(feature = "tracing")]
              let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();

              if let Err(_e) = response(
                task,
                webview,
                task_key,
                task_uuid,
                webview_id,
                url,
                sent_response,
              ) {
                #[cfg(feature = "tracing")]
                tracing::error!("Error responding to task: {:?}", _e);
              }
            });

          #[cfg(feature = "tracing")]
          let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();
          function(
            webview_id,
            final_request,
            RequestAsyncResponder { responder },
          );
        }
        Err(_) => respond_with_404(),
      };
    } else {
      #[cfg(feature = "tracing")]
      tracing::warn!(
        "Either WebView or WebContext instance is dropped! This handler shouldn't be called."
      );
    };
  }
}
/// Objective-C implementation of `webView:stopURLSchemeTask:`.
///
/// Removes the entry that `webview` keeps for `task`, keyed by the task's hash.
/// The receiver and selector are unused.
extern "C" fn stop_task(
  _this: &ProtocolObject<dyn WKURLSchemeHandler>,
  _sel: objc2::runtime::Sel,
  webview: &WryWebView,
  task: &ProtocolObject<dyn WKURLSchemeTask>,
) {
  // Same key derivation as the one used when the task was started.
  let task_key = task.hash();
  webview.remove_custom_task_key(task_key);
}