use alloc::boxed::Box;
use alloc::rc::Rc;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use base64::Engine;
use core::cell::RefCell;
use core::future::poll_fn;
use core::pin::{Pin, pin};
use futures_util::FutureExt;
use std::collections::HashMap;
use std::sync::Arc;
use http::Response;
use crate::batch::{Runtime, in_runtime};
use crate::function_registry::FUNCTION_REGISTRY;
use crate::ipc::{DecodedVariant, IPCMessage, MessageType, OutboundIPCMessage, decode_data};
use crate::runtime::{AppEventVariant, IPCSenders, WryIPC, handle_callbacks};
pub use crate::runtime::WryBindgenEvent;
/// A one-shot sink for an HTTP response.
///
/// `respond` takes `Box<Self>`, so an implementor is consumed by the call and
/// can deliver at most one response. The trait is used as a trait object
/// (`Box<dyn ImplWryBindgenResponder>`) by `WryBindgenResponder`.
pub trait ImplWryBindgenResponder {
/// Consumes the responder and hands it the finished `response`.
fn respond(self: Box<Self>, response: Response<Vec<u8>>);
}
/// Type-erased, single-use responder for one request.
///
/// It can be built from any `ImplWryBindgenResponder` via `new`, or from a
/// `FnOnce(Response<Vec<u8>>)` closure via `From`.
pub struct WryBindgenResponder {
// Boxed implementor that receives the response when `respond` is called.
respond: Box<dyn ImplWryBindgenResponder>,
}
/// Lets any one-shot closure that accepts a `Response<Vec<u8>>` be used as a
/// `WryBindgenResponder`.
impl<F> From<F> for WryBindgenResponder
where
    F: FnOnce(Response<Vec<u8>>) + 'static,
{
    fn from(respond: F) -> Self {
        // Local adapter that gives a bare closure an `ImplWryBindgenResponder` impl.
        struct ClosureResponder<C>(C);

        impl<C> ImplWryBindgenResponder for ClosureResponder<C>
        where
            C: FnOnce(Response<Vec<u8>>) + 'static,
        {
            fn respond(self: Box<Self>, response: Response<Vec<u8>>) {
                // Unbox to take ownership of the closure so it can be called once.
                let ClosureResponder(callback) = *self;
                callback(response)
            }
        }

        let boxed: Box<dyn ImplWryBindgenResponder> = Box::new(ClosureResponder(respond));
        Self { respond: boxed }
    }
}
impl WryBindgenResponder {
pub fn new(f: impl ImplWryBindgenResponder + 'static) -> Self {
Self {
respond: Box::new(f),
}
}
fn respond(self, response: Response<Vec<u8>>) {
self.respond.respond(response);
}
fn respond_ipc(self, response: IPCMessage) {
let body = response.into_data();
let engine = base64::engine::general_purpose::STANDARD;
let body_base64 = engine.encode(&body);
self.respond(
http::Response::builder()
.status(200)
.header("Content-Type", "text/plain")
.body(body_base64.into_bytes())
.expect("Failed to build response"),
);
}
}
/// Extracts the IPC message carried in a request's `dioxus-data` header.
///
/// Returns `None` when the header is absent or when `decode_data` returns
/// `None` for the header bytes.
fn decode_request_data(request: &http::Request<Vec<u8>>) -> Option<IPCMessage> {
    request
        .headers()
        .get("dioxus-data")
        .and_then(|raw| decode_data(raw.as_bytes()))
}
/// Tracks whether a webview has been reported as loaded.
enum WebviewLoadingState {
/// Not loaded yet; outbound messages are collected in `queued` and handled
/// once the webview-loaded event arrives.
Pending { queued: Vec<OutboundIPCMessage> },
/// The webview has loaded; outbound messages are handled as they arrive.
Loaded,
}
impl Default for WebviewLoadingState {
fn default() -> Self {
WebviewLoadingState::Pending { queued: Vec::new() }
}
}
/// Everything tracked for a single webview.
struct WebviewState {
// Request/response bookkeeping between JS and Rust for this webview.
messages: WebviewMessageLayer,
// Whether the webview has loaded; while pending, outbound messages are queued inside it.
loading_state: WebviewLoadingState,
// Callback invoked with script source text (e.g. the
// `window.evaluate_from_rust_binary(...)` call built for top-level
// evaluations). Replaced with the embedder's callback in `AppBuilder::build`.
evaluate_script: Box<dyn FnMut(&str)>,
}
/// Bookkeeping for the message exchange between JS and Rust.
struct WebviewMessageLayer {
// Responder of the JS request that is currently waiting for a reply from
// Rust, if any. At most one is held at a time.
current_xhr: Option<WryBindgenResponder>,
// One entry per Rust-initiated evaluation that JS has not responded to yet;
// pushed in `receive_rust_message`, popped when a JS `Respond` arrives.
rust_eval_stack: Vec<RustEvalKind>,
// Channel used to forward JS messages onward via `start_send`.
sender: IPCSenders,
}
/// How a Rust-initiated evaluation was delivered to JS, which decides how the
/// matching JS `Respond` request is answered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RustEvalKind {
/// Sent with `top_level` set; the message is returned to the caller of
/// `receive_rust_message`. The JS `Respond` request is answered immediately.
TopLevel,
/// Delivered as the reply to a suspended JS request. The JS `Respond`
/// request is parked as the new suspended request.
Nested,
}
impl WebviewState {
    /// Creates the state for a webview in its default (pending) loading state.
    ///
    /// `sender` is handed to the message layer; `evaluate_script` is stored as
    /// the callback used by [`WebviewState::evaluate_script`].
    fn new(sender: IPCSenders, evaluate_script: impl FnMut(&str) + 'static) -> Self {
        let script_runner: Box<dyn FnMut(&str)> = Box::new(evaluate_script);
        Self {
            messages: WebviewMessageLayer::new(sender),
            loading_state: WebviewLoadingState::default(),
            evaluate_script: script_runner,
        }
    }

    /// Passes `script` to the stored script callback.
    fn evaluate_script(&mut self, script: &str) {
        let run = &mut self.evaluate_script;
        run(script);
    }
}
impl WebviewMessageLayer {
fn new(sender: IPCSenders) -> Self {
Self {
current_xhr: None,
rust_eval_stack: Vec::new(),
sender,
}
}
fn receive_js_message(&mut self, msg: IPCMessage, responder: WryBindgenResponder) {
let msg_type = msg.ty().unwrap();
if self.current_xhr.is_some() {
responder.respond(error_response());
return;
}
let top_level_responder = match msg_type {
MessageType::Evaluate => {
self.current_xhr = Some(responder);
None
}
MessageType::Respond => match self.rust_eval_stack.pop() {
Some(RustEvalKind::Nested) => {
self.current_xhr = Some(responder);
None
}
Some(RustEvalKind::TopLevel) => Some(responder),
None => {
responder.respond(error_response());
return;
}
},
};
if self.sender.start_send(msg) {
if let Some(responder) = top_level_responder {
responder.respond(blank_response());
}
} else if let Some(responder) = top_level_responder {
responder.respond(error_response());
} else if let Some(responder) = self.current_xhr.take() {
responder.respond(error_response());
}
}
fn receive_rust_message(&mut self, ipc_msg: OutboundIPCMessage) -> Option<IPCMessage> {
let ty = ipc_msg.message.ty().unwrap();
let top_level = ipc_msg.top_level;
let message = ipc_msg.message;
match ty {
MessageType::Respond => {
let responder = self
.current_xhr
.take()
.expect("Rust Respond with no suspended JS XHR to reply to");
responder.respond_ipc(message);
None
}
MessageType::Evaluate if top_level => {
self.rust_eval_stack.push(RustEvalKind::TopLevel);
Some(message)
}
MessageType::Evaluate => {
let responder = self
.current_xhr
.take()
.expect("Nested Rust Evaluate with no suspended JS XHR to reply to");
responder.respond_ipc(message);
self.rust_eval_stack.push(RustEvalKind::Nested);
None
}
}
}
}
/// Returns an identifier that differs from every one previously returned in
/// this process (until the `u64` counter wraps).
///
/// Identifiers start at `0` and grow by one per call.
fn unique_id() -> u64 {
    use core::sync::atomic::{AtomicU64, Ordering};

    // Holds the next identifier to hand out.
    static NEXT_ID: AtomicU64 = AtomicU64::new(0);

    // `fetch_add` returns the value from before the increment.
    NEXT_ID.fetch_add(1, Ordering::Relaxed)
}
/// An application that has been bound to a webview but not started yet.
pub struct PreparedApp {
// Id of the webview this app was built for.
id: u64,
// Deferred constructor for the app future. The constructor itself is `Send`;
// the future it produces carries no `Send` bound.
future: Box<dyn FnOnce() -> Pin<Box<dyn core::future::Future<Output = ()> + 'static>> + Send>,
}
impl PreparedApp {
    /// Returns the id of the webview this app was built for.
    pub fn id(&self) -> u64 {
        let Self { id, .. } = self;
        *id
    }

    /// Consumes the prepared app and runs its stored constructor to obtain the
    /// app future. The constructor is not run before this call.
    pub fn into_future(self) -> Pin<Box<dyn core::future::Future<Output = ()> + 'static>> {
        let Self {
            future: make_future,
            ..
        } = self;
        make_future()
    }
}
/// Serves the `__wbg__/...` routes on behalf of one webview.
pub struct ProtocolHandler {
// Id of the webview whose state this handler looks up.
id: u64,
// Per-webview state table; `AppBuilder::protocol_handler` fills this with a
// clone of the `Rc` held by `WryBindgen`.
webview: Rc<RefCell<HashMap<u64, WebviewState>>>,
}
impl ProtocolHandler {
    /// Answers `request` if it targets one of the `__wbg__/` routes.
    ///
    /// The path is taken from the URI after stripping, if present, one of the
    /// origins `{protocol}://index.html`, `http://{protocol}.index.html` or
    /// `https://{protocol}.index.html`, and trimming surrounding slashes.
    ///
    /// Routes under `__wbg__/`:
    /// * `snippets/<path>` — the registered module, or a `404` response.
    /// * `init.js` — the initialization script.
    /// * `initialized` — sends a webview-loaded event through `proxy` and
    ///   replies with an empty `200`.
    /// * `handler` — decodes the IPC message from the request and passes it to
    ///   this webview's message layer; replies `400` if the webview is unknown
    ///   or the message cannot be decoded.
    ///
    /// Returns `None` when the request was answered. Returns the untouched
    /// `responder` when the request is not one this handler serves, so the
    /// caller can answer it.
    pub fn handle_request<F, R: Into<WryBindgenResponder>>(
        &self,
        protocol: &str,
        proxy: F,
        request: &http::Request<Vec<u8>>,
        responder: R,
    ) -> Option<R>
    where
        F: Fn(WryBindgenEvent),
    {
        let uri = request.uri().to_string();

        // Tried in order; the first origin that prefixes the URI is removed.
        let origins = [
            format!("{protocol}://index.html"),
            format!("http://{protocol}.index.html"),
            format!("https://{protocol}.index.html"),
        ];
        let path = origins
            .iter()
            .find_map(|origin| uri.strip_prefix(origin.as_str()))
            .unwrap_or(uri.as_str())
            .trim_matches('/');

        let Some(route) = path.strip_prefix("__wbg__/") else {
            return Some(responder);
        };

        if let Some(module_path) = route.strip_prefix("snippets/") {
            let responder: WryBindgenResponder = responder.into();
            let reply = match FUNCTION_REGISTRY.get_module(module_path) {
                Some(content) => module_response(content),
                None => not_found_response(),
            };
            responder.respond(reply);
            return None;
        }

        match route {
            "init.js" => {
                let responder: WryBindgenResponder = responder.into();
                responder.respond(module_response(&init_script()));
                None
            }
            "initialized" => {
                proxy(WryBindgenEvent::webview_loaded(self.id));
                let responder: WryBindgenResponder = responder.into();
                responder.respond(blank_response());
                None
            }
            "handler" => {
                let responder: WryBindgenResponder = responder.into();
                let mut table = self.webview.borrow_mut();
                let Some(state) = table.get_mut(&self.id) else {
                    responder.respond(error_response());
                    return None;
                };
                let Some(msg) = decode_request_data(request) else {
                    responder.respond(error_response());
                    return None;
                };
                state.messages.receive_js_message(msg, responder);
                None
            }
            _ => Some(responder),
        }
    }
}
/// Builds the script served for `__wbg__/init.js`: the bundled `js/main.js`,
/// a newline, then the script produced by `FUNCTION_REGISTRY`.
fn init_script() -> String {
    let bootstrap: &str = include_str!("./js/main.js");
    let registry_script = FUNCTION_REGISTRY.script();
    format!("{bootstrap}\n{registry_script}")
}
/// Entry point that owns the per-webview state table and creates app builders.
pub struct WryBindgen {
// Callback handed to `WryIPC::new` for every app created via `app_builder`.
event_loop_proxy: Arc<dyn Fn(WryBindgenEvent) + Send + Sync>,
// State of every webview registered by `app_builder`, keyed by webview id.
webview: Rc<RefCell<HashMap<u64, WebviewState>>>,
}
impl WryBindgen {
pub fn new(event_loop_proxy: impl Fn(WryBindgenEvent) + Send + Sync + 'static) -> Self {
Self {
event_loop_proxy: Arc::new(event_loop_proxy),
webview: Rc::new(RefCell::new(HashMap::new())),
}
}
pub fn app_builder<'a>(&'a self) -> AppBuilder<'a> {
let event_loop_proxy = self.event_loop_proxy.clone();
let webview_id = unique_id();
let (ipc, senders) = WryIPC::new(event_loop_proxy);
self.webview.borrow_mut().insert(
webview_id,
WebviewState::new(senders, |_| {
unreachable!("evaluate_script will only be used after spawning the app")
}),
);
AppBuilder {
webview_id,
bindgen: self,
ipc,
}
}
pub fn handle_user_event(&self, event: WryBindgenEvent) {
let id = event.id();
match event.into_variant() {
AppEventVariant::Ipc(ipc_msg) => self.handle_ipc_message(id, ipc_msg),
AppEventVariant::WebviewLoaded => {
let mut state = self.webview.borrow_mut();
let Some(webview_state) = state.get_mut(&id) else {
return;
};
if let WebviewLoadingState::Pending { queued } = std::mem::replace(
&mut webview_state.loading_state,
WebviewLoadingState::Loaded,
) {
for msg in queued {
self.immediately_handle_ipc_message(webview_state, msg);
}
}
}
}
}
fn handle_ipc_message(&self, id: u64, ipc_msg: OutboundIPCMessage) {
let mut state = self.webview.borrow_mut();
let Some(webview_state) = state.get_mut(&id) else {
return;
};
if let WebviewLoadingState::Pending { queued } = &mut webview_state.loading_state {
queued.push(ipc_msg);
return;
}
self.immediately_handle_ipc_message(webview_state, ipc_msg)
}
fn immediately_handle_ipc_message(
&self,
webview_state: &mut WebviewState,
ipc_msg: OutboundIPCMessage,
) {
let Some(message) = webview_state.messages.receive_rust_message(ipc_msg) else {
return;
};
let decoded = message.decoded().unwrap();
if let DecodedVariant::Evaluate { .. } = decoded {
let engine = base64::engine::general_purpose::STANDARD;
let data_base64 = engine.encode(message.data());
let code = format!("window.evaluate_from_rust_binary(\"{data_base64}\")");
webview_state.evaluate_script(&code);
}
}
}
/// Builder for an app bound to one webview, created by `WryBindgen::app_builder`.
pub struct AppBuilder<'a> {
// Id of the webview registered for this app.
webview_id: u64,
// The `WryBindgen` that created this builder and holds the webview state table.
bindgen: &'a WryBindgen,
// IPC half created alongside the webview's senders; moved into the runtime by `build`.
ipc: WryIPC,
}
impl<'a> AppBuilder<'a> {
pub fn protocol_handler(&self) -> ProtocolHandler {
ProtocolHandler {
id: self.webview_id,
webview: self.bindgen.webview.clone(),
}
}
pub fn build<F>(
self,
app: impl FnOnce() -> F + Send + 'static,
evaluate_script: impl FnMut(&str) + 'static,
) -> PreparedApp
where
F: core::future::Future<Output = ()> + 'static,
{
{
let mut webviews = self.bindgen.webview.borrow_mut();
let webview_state = webviews
.get_mut(&self.webview_id)
.expect("The webview state was created in WryBindgen::spawner");
webview_state.evaluate_script = Box::new(evaluate_script);
}
let start_future = move || {
let run_app_in_runtime = async move {
let run_app = app();
let wait_for_events = handle_callbacks();
futures_util::select! {
_ = run_app.fuse() => {},
_ = wait_for_events.fuse() => {},
}
};
let runtime = Runtime::new(self.ipc, self.webview_id);
let mut maybe_runtime = Some(runtime);
let poll_in_runtime = async move {
let mut run_app_in_runtime = pin!(run_app_in_runtime);
poll_fn(move |ctx| {
let (new_runtime, poll_result) =
in_runtime(maybe_runtime.take().unwrap(), || {
run_app_in_runtime.as_mut().poll(ctx)
});
maybe_runtime = Some(new_runtime);
poll_result
})
.await
};
Box::pin(poll_in_runtime) as Pin<Box<dyn Future<Output = ()> + 'static>>
};
PreparedApp {
id: self.webview_id,
future: Box::new(start_future),
}
}
}
pub fn blank_response() -> http::Response<Vec<u8>> {
http::Response::builder()
.status(200)
.body(vec![])
.expect("Failed to build blank response")
}
pub fn error_response() -> http::Response<Vec<u8>> {
http::Response::builder()
.status(400)
.body(vec![])
.expect("Failed to build error response")
}
pub fn module_response(content: &str) -> http::Response<Vec<u8>> {
http::Response::builder()
.status(200)
.header("Content-Type", "application/javascript")
.header("access-control-allow-origin", "*")
.body(content.as_bytes().to_vec())
.expect("Failed to build module response")
}
pub fn not_found_response() -> http::Response<Vec<u8>> {
http::Response::builder()
.status(404)
.body(b"Not Found".to_vec())
.expect("Failed to build not found response")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ipc::EncodedData;

    /// Builds an IPC message whose payload is just the message-type byte.
    fn ipc_message(message_type: MessageType) -> IPCMessage {
        let mut encoded = EncodedData::new();
        encoded.push_u8(message_type as u8);
        IPCMessage::new(encoded.to_bytes())
    }

    /// Builds a request for the `handler` route carrying a message of the given
    /// type, base64-encoded in the `dioxus-data` header.
    fn handler_request(message_type: MessageType) -> http::Request<Vec<u8>> {
        let message = ipc_message(message_type);
        let header = base64::engine::general_purpose::STANDARD.encode(message.data());
        http::Request::builder()
            .uri("wry://index.html/__wbg__/handler")
            .header("dioxus-data", header)
            .body(Vec::new())
            .expect("failed to build request")
    }

    /// Sends a `handler` request of the given type through `protocol_handler`,
    /// asserts that it was consumed, and returns the response that was
    /// delivered to the responder.
    fn dispatch_to_handler(
        protocol_handler: &ProtocolHandler,
        message_type: MessageType,
    ) -> http::Response<Vec<u8>> {
        let slot = Rc::new(RefCell::new(None));
        let writer = Rc::clone(&slot);
        let request = handler_request(message_type);

        let unhandled = protocol_handler.handle_request(
            "wry",
            |_| {},
            &request,
            move |response| *writer.borrow_mut() = Some(response),
        );
        assert!(unhandled.is_none());

        let delivered = slot
            .borrow_mut()
            .take()
            .expect("closed runtime should receive an error response");
        delivered
    }

    #[test]
    fn handler_responds_error_when_evaluate_arrives_after_runtime_drop() {
        let bindgen = WryBindgen::new(|_| {});
        let app_builder = bindgen.app_builder();
        let protocol_handler = app_builder.protocol_handler();
        // Dropping the builder drops its IPC half before any message arrives.
        drop(app_builder);

        let response = dispatch_to_handler(&protocol_handler, MessageType::Evaluate);

        assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn handler_responds_error_when_top_level_respond_arrives_after_runtime_drop() {
        let bindgen = WryBindgen::new(|_| {});
        let app_builder = bindgen.app_builder();
        let webview_id = app_builder.webview_id;
        let protocol_handler = app_builder.protocol_handler();

        let evaluated_scripts = Rc::new(RefCell::new(Vec::new()));
        let script_sink = Rc::clone(&evaluated_scripts);
        let prepared_app = app_builder.build(
            || async {},
            move |script| script_sink.borrow_mut().push(script.to_string()),
        );

        // Mark the webview as loaded, then issue a top-level evaluate from Rust.
        bindgen.handle_user_event(WryBindgenEvent::webview_loaded(webview_id));
        let top_level_eval = OutboundIPCMessage::new(ipc_message(MessageType::Evaluate), true);
        bindgen.handle_user_event(WryBindgenEvent::ipc(webview_id, top_level_eval));
        assert_eq!(evaluated_scripts.borrow().len(), 1);

        // Drop the app before JS gets to respond.
        drop(prepared_app);

        let response = dispatch_to_handler(&protocol_handler, MessageType::Respond);

        assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
    }
}