use crate::{webview::web_context::WebContextData, Error};
use glib::FileError;
use http::{header::CONTENT_TYPE, Request, Response};
use std::{
borrow::Cow,
cell::RefCell,
collections::{HashSet, VecDeque},
path::PathBuf,
rc::Rc,
str::FromStr,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Mutex,
},
};
use url::Url;
use webkit2gtk::{
ApplicationInfo, CookiePersistentStorage, LoadEvent, URIRequest, URIRequestExt, WebContext,
WebContextBuilder, WebView, WebViewExt, WebsiteDataManagerBuilder,
};
/// WebKitGTK-backed state of a `super::WebContext` (accessed there as `self.os`).
#[derive(Debug)]
pub struct WebContextImpl {
/// The WebKitGTK context that custom schemes, automation and download handlers are attached to.
context: WebContext,
/// Queue of pending webview loads; `flush` starts them one at a time.
webview_uri_loader: Rc<WebviewUriLoader>,
/// Names of custom URI schemes registered through this context, used to detect duplicates.
registered_protocols: HashSet<String>,
/// Whether WebKit automation is allowed; mirrors the value passed to `set_automation_allowed`.
automation: bool,
/// Application name/version handed to automation sessions; `take()`n by `register_automation`.
app_info: Option<ApplicationInfo>,
}
impl WebContextImpl {
    /// Creates the WebKitGTK state for a web context described by `data`.
    ///
    /// If `data` provides a data directory, a website data manager is attached that keeps
    /// local storage in `<dir>/localstorage`, IndexedDB in `<dir>/databases/indexeddb` and
    /// cookies (text format) in `<dir>/cookies`. Automation starts disabled, and an
    /// `ApplicationInfo` holding this crate's name and version is kept for automation.
    pub fn new(data: &WebContextData) -> Self {
        use webkit2gtk::traits::*;

        let context = match data.data_directory() {
            Some(root) => {
                let local_storage = root.join("localstorage");
                let indexeddb = root.join("databases").join("indexeddb");
                let manager = WebsiteDataManagerBuilder::new()
                    .local_storage_directory(&local_storage.to_string_lossy())
                    .indexeddb_directory(&indexeddb.to_string_lossy())
                    .build();
                if let Some(cookies) = manager.cookie_manager() {
                    let cookie_file = root.join("cookies");
                    cookies.set_persistent_storage(
                        &cookie_file.to_string_lossy(),
                        CookiePersistentStorage::Text,
                    );
                }
                WebContextBuilder::new()
                    .website_data_manager(&manager)
                    .build()
            }
            None => WebContextBuilder::new().build(),
        };

        let automation = false;
        context.set_automation_allowed(automation);

        let app_info = ApplicationInfo::new();
        app_info.set_name(env!("CARGO_PKG_NAME"));
        let major = env!("CARGO_PKG_VERSION_MAJOR")
            .parse()
            .expect("invalid wry version major");
        let minor = env!("CARGO_PKG_VERSION_MINOR")
            .parse()
            .expect("invalid wry version minor");
        let patch = env!("CARGO_PKG_VERSION_PATCH")
            .parse()
            .expect("invalid wry version patch");
        app_info.set_version(major, minor, patch);

        Self {
            context,
            webview_uri_loader: Rc::default(),
            registered_protocols: Default::default(),
            automation,
            app_info: Some(app_info),
        }
    }

    /// Allows or forbids WebKit automation, updating both the context and the cached flag.
    pub fn set_allows_automation(&mut self, flag: bool) {
        use webkit2gtk::traits::*;
        self.context.set_automation_allowed(flag);
        self.automation = flag;
    }
}
/// WebKitGTK-specific operations on a web context.
pub trait WebContextExt {
/// Returns the underlying WebKitGTK context.
fn context(&self) -> &WebContext;
/// Registers `handler` for the custom URI scheme `name`.
///
/// The handler is submitted to WebKit even if `name` was registered before; in that case
/// `Error::DuplicateCustomProtocol` is returned after submitting it.
fn register_uri_scheme<F>(&mut self, name: &str, handler: F) -> crate::Result<()>
where
F: Fn(&Request<Vec<u8>>) -> crate::Result<Response<Cow<'static, [u8]>>> + 'static;
/// Registers `handler` for the custom URI scheme `name` only if this context has not
/// registered `name` yet; otherwise returns `Error::DuplicateCustomProtocol` without
/// submitting the handler.
fn try_register_uri_scheme<F>(&mut self, name: &str, handler: F) -> crate::Result<()>
where
F: Fn(&Request<Vec<u8>>) -> crate::Result<Response<Cow<'static, [u8]>>> + 'static;
/// Queues `url` (with optional extra request `headers`) to be loaded in `webview`.
/// Nothing is loaded by this call alone; queued loads are started by `flush_queue_loader`.
fn queue_load_uri(&self, webview: Rc<WebView>, url: Url, headers: Option<http::HeaderMap>);
/// Starts the next queued load unless a queued load is already in progress; further
/// queued loads are started as each one finishes.
fn flush_queue_loader(&self);
/// Returns whether WebKit automation is allowed for this context.
fn allows_automation(&self) -> bool;
/// Makes `webview` the web view handed out when an automation session asks for one.
/// At most one call has any effect, because the stored application info is consumed.
fn register_automation(&mut self, webview: WebView);
/// Installs download callbacks on the context.
///
/// The started callback receives the download URI and a mutable destination path.
/// Returning `false` cancels the download; otherwise the (possibly modified) path is
/// applied as the destination. The completed callback receives the URI, the destination
/// path (only when the download did not fail) and a success flag.
fn register_download_handler(
&mut self,
download_started_callback: Option<Box<dyn FnMut(String, &mut PathBuf) -> bool>>,
download_completed_callback: Option<Rc<dyn Fn(String, Option<PathBuf>, bool) + 'static>>,
);
}
impl WebContextExt for super::WebContext {
fn context(&self) -> &WebContext {
&self.os.context
}
fn register_uri_scheme<F>(&mut self, name: &str, handler: F) -> crate::Result<()>
where
F: Fn(&Request<Vec<u8>>) -> crate::Result<Response<Cow<'static, [u8]>>> + 'static,
{
actually_register_uri_scheme(self, name, handler)?;
if self.os.registered_protocols.insert(name.to_string()) {
Ok(())
} else {
Err(Error::DuplicateCustomProtocol(name.to_string()))
}
}
fn try_register_uri_scheme<F>(&mut self, name: &str, handler: F) -> crate::Result<()>
where
F: Fn(&Request<Vec<u8>>) -> crate::Result<Response<Cow<'static, [u8]>>> + 'static,
{
if self.os.registered_protocols.insert(name.to_string()) {
actually_register_uri_scheme(self, name, handler)
} else {
Err(Error::DuplicateCustomProtocol(name.to_string()))
}
}
fn queue_load_uri(&self, webview: Rc<WebView>, url: Url, headers: Option<http::HeaderMap>) {
self.os.webview_uri_loader.push(webview, url, headers)
}
fn flush_queue_loader(&self) {
Rc::clone(&self.os.webview_uri_loader).flush()
}
fn allows_automation(&self) -> bool {
self.os.automation
}
fn register_automation(&mut self, webview: WebView) {
use webkit2gtk::traits::*;
if let (true, Some(app_info)) = (self.os.automation, self.os.app_info.take()) {
self.os.context.connect_automation_started(move |_, auto| {
let webview = webview.clone();
auto.set_application_info(&app_info);
auto.connect_create_web_view(None, move |_| webview.clone());
});
}
}
fn register_download_handler(
&mut self,
download_started_handler: Option<Box<dyn FnMut(String, &mut PathBuf) -> bool>>,
download_completed_handler: Option<Rc<dyn Fn(String, Option<PathBuf>, bool) + 'static>>,
) {
use webkit2gtk::traits::*;
let context = &self.os.context;
let download_started_handler = RefCell::new(download_started_handler);
let failed = Rc::new(RefCell::new(false));
context.connect_download_started(move |_context, download| {
if let Some(uri) = download.request().and_then(|req| req.uri()) {
let uri = uri.to_string();
let mut download_location = download
.destination()
.and_then(|p| PathBuf::from_str(&p).ok())
.unwrap_or_default();
if let Some(download_started_handler) = download_started_handler.borrow_mut().as_mut() {
if download_started_handler(uri, &mut download_location) {
download.connect_response_notify(move |download| {
download.set_destination(&download_location.to_string_lossy());
});
} else {
download.cancel();
}
}
}
download.connect_failed({
let failed = failed.clone();
move |_, _error| {
*failed.borrow_mut() = true;
}
});
if let Some(download_completed_handler) = download_completed_handler.clone() {
download.connect_finished({
let failed = failed.clone();
move |download| {
if let Some(uri) = download.request().and_then(|req| req.uri()) {
let failed = failed.borrow();
let uri = uri.to_string();
download_completed_handler(
uri,
(!*failed)
.then(|| {
download
.destination()
.map_or_else(|| None, |p| Some(PathBuf::from(p.as_str())))
})
.flatten(),
!*failed,
)
}
}
});
}
});
}
}
/// Registers `handler` with WebKit as the handler for the custom URI scheme `name`.
///
/// The scheme is first marked as secure through the context's security manager;
/// `Error::MissingManager` is returned if the context has none. Each incoming WebKit
/// request is converted into an `http::Request` with an empty body. The URI is always
/// copied; with the `linux-headers` feature, the request headers and method are copied too
/// (otherwise the method is always `GET`). The handler's response body is streamed back to
/// WebKit, and with `linux-headers` its status code and headers are forwarded as well.
/// The following are finished with a `glib::Error`:
/// - requests without a URI,
/// - requests that cannot be converted,
/// - requests whose handler returns an error.
fn actually_register_uri_scheme<F>(
    context: &mut super::WebContext,
    name: &str,
    handler: F,
) -> crate::Result<()>
where
    F: Fn(&Request<Vec<u8>>) -> crate::Result<Response<Cow<'static, [u8]>>> + 'static,
{
    use webkit2gtk::traits::*;
    let context = &context.os.context;
    context
        .security_manager()
        .ok_or(Error::MissingManager)?
        .register_uri_scheme_as_secure(name);
    context.register_uri_scheme(name, move |request| {
        #[cfg(feature = "tracing")]
        let span = tracing::info_span!(
            parent: None,
            "wry::custom_protocol::handle",
            uri = tracing::field::Empty
        )
        .entered();

        if let Some(uri) = request.uri() {
            let uri = uri.as_str();
            #[cfg(feature = "tracing")]
            span.record("uri", uri);

            // Without `linux-headers`, only the URI is forwarded and the method stays GET.
            #[allow(unused_mut)]
            let mut http_request = Request::builder().uri(uri).method("GET");
            #[cfg(feature = "linux-headers")]
            {
                use http::{header::HeaderName, HeaderValue};
                if let Some(mut headers) = request.http_headers() {
                    if let Some(map) = http_request.headers_mut() {
                        headers.foreach(move |k, v| {
                            if let Ok(name) = HeaderName::from_bytes(k.as_bytes()) {
                                if let Ok(value) = HeaderValue::from_bytes(v.as_bytes()) {
                                    // `append`, not `insert`: repeated headers arrive as
                                    // separate entries and all of them must be kept.
                                    map.append(name, value);
                                }
                            }
                        });
                    }
                }
                if let Some(method) = request.http_method() {
                    http_request = http_request.method(method.as_str());
                }
            }

            let http_request = match http_request.body(Vec::new()) {
                Ok(req) => req,
                Err(_) => {
                    request.finish_error(&mut glib::Error::new(
                        FileError::Exist,
                        "Could not get uri.",
                    ));
                    return;
                }
            };

            let res = {
                #[cfg(feature = "tracing")]
                let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();
                handler(&http_request)
            };

            match res {
                Ok(http_response) => {
                    let buffer = http_response.body();
                    let input = gio::MemoryInputStream::from_bytes(&glib::Bytes::from(buffer));
                    let content_type = http_response
                        .headers()
                        .get(CONTENT_TYPE)
                        .and_then(|h| h.to_str().ok());

                    #[cfg(feature = "linux-headers")]
                    {
                        use soup::{MessageHeaders, MessageHeadersType};
                        use webkit2gtk::URISchemeResponse;
                        let response = URISchemeResponse::new(&input, buffer.len() as i64);
                        response.set_status(http_response.status().as_u16() as u32, None);
                        if let Some(content_type) = content_type {
                            response.set_content_type(content_type);
                        }
                        let mut headers = MessageHeaders::new(MessageHeadersType::Response);
                        for (name, value) in http_response.headers().into_iter() {
                            headers.append(name.as_str(), value.to_str().unwrap_or(""));
                        }
                        response.set_http_headers(&mut headers);
                        request.finish_with_response(&response);
                    }

                    #[cfg(not(feature = "linux-headers"))]
                    request.finish(&input, buffer.len() as i64, content_type)
                }
                Err(_) => request.finish_error(&mut glib::Error::new(
                    FileError::Exist,
                    "Could not get requested file.",
                )),
            }
        } else {
            request.finish_error(&mut glib::Error::new(
                FileError::Exist,
                "Could not get uri.",
            ));
        }
    });
    Ok(())
}
/// FIFO of pending webview loads that `flush` starts one at a time.
#[derive(Debug, Default)]
struct WebviewUriLoader {
/// Set while a queued load is in flight; see `is_locked` / `unlock`.
lock: AtomicBool,
/// Pending `(webview, url, extra request headers)` entries, oldest first.
// NOTE(review): entries hold `Rc<WebView>`, so this type is single-threaded anyway and a
// `RefCell` would likely suffice instead of the `Mutex` — confirm before changing.
queue: Mutex<VecDeque<(Rc<WebView>, Url, Option<http::HeaderMap>)>>,
}
impl WebviewUriLoader {
fn is_locked(&self) -> bool {
self.lock.swap(true, SeqCst)
}
fn unlock(&self) {
self.lock.store(false, SeqCst)
}
fn push(&self, webview: Rc<WebView>, url: Url, headers: Option<http::HeaderMap>) {
let mut queue = self.queue.lock().expect("poisoned load queue");
queue.push_back((webview, url, headers))
}
fn pop(&self) -> Option<(Rc<WebView>, Url, Option<http::HeaderMap>)> {
let mut queue = self.queue.lock().expect("poisoned load queue");
queue.pop_front()
}
fn flush(self: Rc<Self>) {
if !self.is_locked() {
if let Some((webview, url, headers)) = self.pop() {
webview.connect_load_changed(move |_, event| {
if let LoadEvent::Finished = event {
self.unlock();
Rc::clone(&self).flush();
};
});
if let Some(headers) = headers {
let req = URIRequest::builder().uri(url.as_str()).build();
if let Some(ref mut req_headers) = req.http_headers() {
for (header, value) in headers.iter() {
req_headers.append(
header.to_string().as_str(),
value.to_str().unwrap_or_default(),
);
}
}
webview.load_request(&req);
} else {
webview.load_uri(url.as_str());
}
} else {
self.unlock();
}
}
}
}