use dioxus_interpreter_js::MutationState;
use futures_channel::oneshot;
use futures_util::FutureExt;
use rand::{RngCore, SeedableRng};
use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::net::{TcpListener, TcpStream};
use std::pin::Pin;
use std::rc::Rc;
use std::sync::atomic::AtomicU32;
use std::sync::Mutex;
use std::{
net::IpAddr,
sync::{Arc, RwLock},
};
use tokio::sync::Notify;
/// Per-webview handle to the queue of serialized edits headed for that webview.
///
/// All state lives behind `Rc<RefCell<..>>`: clones share the same inner state, and the
/// handle cannot be sent to another thread (`Rc` is not `Send`).
#[derive(Clone)]
pub(crate) struct WryQueue {
    inner: Rc<RefCell<WryQueueInner>>,
}
impl WryQueue {
pub(crate) fn with_mutation_state_mut<O: 'static>(
&self,
callback: impl FnOnce(&mut MutationState) -> O,
) -> O {
let mut inner = self.inner.borrow_mut();
callback(&mut inner.mutation_state)
}
pub(crate) fn send_edits(&self) {
let mut myself = self.inner.borrow_mut();
let webview_id = myself.location.webview_id;
let serialized_edits = myself.mutation_state.export_memory();
let receiver = myself.websocket.send_edits(webview_id, serialized_edits);
myself.edits_in_progress = Some(receiver);
}
pub(crate) fn poll_edits_flushed(
&self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<()> {
let mut self_mut = self.inner.borrow_mut();
if let Some(receiver) = self_mut.edits_in_progress.as_mut() {
receiver.poll_unpin(cx).map(|_| ())
} else {
std::task::Poll::Ready(())
}
}
pub(crate) fn poll_new_edits_location(
&self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<()> {
let mut self_mut = self.inner.borrow_mut();
let poll = self_mut
.server_location_changed_future
.as_mut()
.poll_unpin(cx);
if poll.is_ready() {
self_mut.server_location_changed_future =
owned_notify_future(self_mut.server_location_changed.clone());
}
poll
}
pub(crate) fn edits_path(&self) -> String {
let WebviewWebsocketLocation {
webview_id, server, ..
} = &self.inner.borrow().location;
let server = server.lock().unwrap();
let port = server.port;
let key = &server.client_key;
let key_hex = encode_key_string(key);
format!("ws://127.0.0.1:{port}/{webview_id}/{key_hex}")
}
pub(crate) fn required_server_key(&self) -> String {
let server = &self.inner.borrow().location.server;
let server = server.lock().unwrap();
encode_key_string(&server.server_key)
}
}
/// Shared state behind a [`WryQueue`].
pub(crate) struct WryQueueInner {
    /// Which webview this queue belongs to, and the (shared, mutable) server location it
    /// should connect to.
    location: WebviewWebsocketLocation,
    /// Handle to the edit server; `send_edits` enqueues batches through it.
    websocket: EditWebsocket,
    /// Resolves when the batch last handed to `websocket` completes. It completes when the
    /// webview acknowledges the batch, or is canceled if the sender is dropped.
    /// `None` until the first send.
    edits_in_progress: Option<oneshot::Receiver<()>>,
    /// Notifier shared with the edit server; fired when the server restarts on a new location.
    server_location_changed: Arc<Notify>,
    /// Pending wait on `server_location_changed`; replaced with a fresh one each time it fires.
    server_location_changed_future: Pin<Box<dyn Future<Output = ()>>>,
    /// Buffer of mutations; exported as a byte vector by `send_edits`.
    mutation_state: MutationState,
}
/// Identifies one webview's edit connection: its id plus the server it should talk to.
#[derive(Clone)]
pub(crate) struct WebviewWebsocketLocation {
    /// Id assigned by `EditWebsocket::create_queue`; also the first path segment of the
    /// websocket URL.
    webview_id: u32,
    /// Shared with the edit server's current location and overwritten in place when the
    /// server restarts, so readers always see the latest port and keys.
    server: Arc<Mutex<ServerLocation>>,
}
/// Where the edit websocket server listens, plus the secrets used during a connection.
///
/// `start_server` produces a new value (new port, fresh keys) every time the server
/// (re)starts.
// NOTE(review): avoid adding `Debug` here — printing this value would leak both keys into logs.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct ServerLocation {
    /// Port of the listener bound on 127.0.0.1.
    port: u16,
    /// Secret the webview must present as the second path segment of its websocket URL;
    /// compared in constant time during the handshake.
    client_key: [u8; KEY_SIZE],
    /// Secret sent to the webview as a text message right after the handshake. Presumably
    /// this lets the webview verify it reached this server — confirm against the client code.
    server_key: [u8; KEY_SIZE],
}
/// Binds a fresh loopback listener on an OS-chosen port and generates a new pair of keys.
///
/// # Panics
/// Panics if the listener cannot be bound or its local address cannot be read.
pub(crate) fn start_server() -> (ServerLocation, TcpListener) {
    // Keys are generated before binding, matching the order the server has always used.
    let client_key = create_secure_key();
    let server_key = create_secure_key();

    let loopback = IpAddr::from([127, 0, 0, 1]);
    let listener = TcpListener::bind((loopback, 0))
        .expect("Failed to bind local TCP listener for edit socket");
    let port = listener.local_addr().unwrap().port();

    (
        ServerLocation {
            port,
            client_key,
            server_key,
        },
        listener,
    )
}
/// Cloneable handle to the edit websocket server shared by every webview queue.
#[derive(Clone)]
pub(crate) struct EditWebsocket {
    /// Current port and keys; shared with the accept thread, which overwrites it on restart.
    current_location: Arc<Mutex<ServerLocation>>,
    /// Next webview id to hand out (incremented by `create_queue`).
    max_webview_id: Arc<AtomicU32>,
    /// Connection state per webview id: queued edits, or a channel to a live connection.
    connections: Arc<RwLock<HashMap<u32, WebviewConnectionState>>>,
    /// Fired by the accept thread after the server restarts on a new location.
    server_location: Arc<Notify>,
}
impl EditWebsocket {
pub(crate) fn start() -> Self {
let connections = Arc::new(RwLock::new(HashMap::new()));
let notify = Arc::new(Notify::new());
let (location, server) = start_server();
let current_location = Arc::new(Mutex::new(location));
let connections_ = connections.clone();
let current_location_ = current_location.clone();
let notify_ = notify.clone();
std::thread::spawn(move || {
Self::accept_loop(notify_, server, current_location_, connections_)
});
Self {
connections,
max_webview_id: Default::default(),
current_location,
server_location: notify,
}
}
fn accept_loop(
notify: Arc<Notify>,
mut server: TcpListener,
current_location: Arc<Mutex<ServerLocation>>,
connections: Arc<RwLock<HashMap<u32, WebviewConnectionState>>>,
) {
loop {
while let Ok((stream, _)) = server.accept() {
Self::handle_connection(stream, current_location.clone(), connections.clone());
}
let (location, new_server) = start_server();
notify.notify_waiters();
*current_location.lock().unwrap() = location;
server = new_server;
}
}
fn handle_connection(
stream: TcpStream,
server_location: Arc<Mutex<ServerLocation>>,
connections: Arc<RwLock<HashMap<u32, WebviewConnectionState>>>,
) {
use tungstenite::handshake::server::{Request, Response};
let current_server_location = { *server_location.lock().unwrap() };
let hex_encoded_client_key = encode_key_string(¤t_server_location.client_key);
let hex_encoded_server_key = encode_key_string(¤t_server_location.server_key);
let mut location = None;
let on_request = |req: &Request, res| {
let path = req.uri().path();
let mut segments = path.trim_matches('/').split('/');
let webview_id = segments
.next()
.and_then(|s| s.parse::<u32>().ok())
.ok_or_else(|| {
Response::builder()
.status(400)
.body(Some("Bad Request: Invalid webview ID".to_string()))
.unwrap()
})?;
let key = segments.next().ok_or_else(|| {
Response::builder()
.status(400)
.body(Some("Bad Request: Missing key".to_string()))
.unwrap()
})?;
let key_matches: bool =
subtle::ConstantTimeEq::ct_eq(hex_encoded_client_key.as_ref(), key.as_bytes())
.into();
if !key_matches {
return Err(Response::builder()
.status(403)
.body(Some("Forbidden: Invalid key".to_string()))
.unwrap());
}
location = Some(WebviewWebsocketLocation {
webview_id,
server: server_location,
});
Ok(res)
};
let mut websocket = match tungstenite::accept_hdr(stream, on_request) {
Ok(ws) => ws,
Err(e) => {
tracing::error!("Error accepting websocket connection: {}", e);
return;
}
};
websocket
.send(tungstenite::Message::Text(hex_encoded_server_key.into()))
.unwrap();
let location = match location {
Some(loc) => loc,
None => {
tracing::error!("WebSocket connection without a valid webview ID");
return;
}
};
let (edits_outgoing, edits_incoming_rx) = std::sync::mpsc::channel::<MsgPair>();
let connections_ = connections.clone();
std::thread::spawn(move || {
let mut queued_message = None;
'connection: while let Ok(msg) = edits_incoming_rx.recv() {
let data = msg.edits.clone();
queued_message = Some(msg);
if let Err(e) = websocket.send(tungstenite::Message::Binary(data.into())) {
tracing::error!("Error sending edits to webview: {}", e);
break 'connection;
}
while let Ok(ws_msg) = websocket.read() {
match ws_msg {
tungstenite::Message::Binary(_) => break,
tungstenite::Message::Close(_) => {
break 'connection;
}
_ => {}
}
}
let msg = queued_message.take().expect("Message should be set here");
if msg.response.send(()).is_err() {
tracing::error!("Error sending edits applied notification");
}
}
tracing::trace!("Webview {} closed the connection", location.webview_id);
let mut connection = WebviewConnectionState::default();
if let Some(msg) = queued_message {
connection.add_message_pair(msg);
}
connections_
.write()
.unwrap()
.insert(location.webview_id, connection);
});
let mut connections = connections.write().unwrap();
match connections.remove(&location.webview_id) {
Some(WebviewConnectionState::Pending { mut pending }) => {
while let Some(pair) = pending.pop_front() {
_ = edits_outgoing.send(pair);
}
}
Some(WebviewConnectionState::Connected { .. }) => {
tracing::error!(
"Webview {} was already connected. Rejecting new connection.",
location.webview_id
);
return;
}
None => {}
}
connections.insert(
location.webview_id,
WebviewConnectionState::Connected { edits_outgoing },
);
}
pub(crate) fn create_queue(&self) -> WryQueue {
let webview_id = self
.max_webview_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let server = self.current_location.clone();
let server_location = self.server_location.clone();
WryQueue {
inner: Rc::new(RefCell::new(WryQueueInner {
server_location_changed: server_location.clone(),
server_location_changed_future: owned_notify_future(server_location),
location: WebviewWebsocketLocation { webview_id, server },
websocket: self.clone(),
edits_in_progress: None,
mutation_state: MutationState::default(),
})),
}
}
fn send_edits(&mut self, webview: u32, edits: Vec<u8>) -> oneshot::Receiver<()> {
let mut connections_mut = self.connections.write().unwrap();
let connection = connections_mut.entry(webview).or_default();
connection.add_message(edits)
}
}
/// Delivery state for one webview's edits.
enum WebviewConnectionState {
    /// No live websocket: edits accumulate here in send order until the webview connects.
    Pending {
        pending: VecDeque<MsgPair>,
    },
    /// A forwarding thread owns the websocket; edits are handed to it through this channel.
    Connected {
        edits_outgoing: std::sync::mpsc::Sender<MsgPair>,
    },
}
impl Default for WebviewConnectionState {
fn default() -> Self {
WebviewConnectionState::Pending {
pending: VecDeque::new(),
}
}
}
impl WebviewConnectionState {
    /// Wraps `edits` with a fresh acknowledgement channel and enqueues them.
    ///
    /// Returns the receiving half, which completes once the webview acknowledges the edits.
    fn add_message(&mut self, edits: Vec<u8>) -> oneshot::Receiver<()> {
        let (response, acknowledged) = oneshot::channel();
        self.add_message_pair(MsgPair { edits, response });
        acknowledged
    }

    /// Either forwards `pair` to the live connection or appends it to the pending queue.
    fn add_message_pair(&mut self, pair: MsgPair) {
        match self {
            Self::Connected { edits_outgoing } => {
                // Best effort: a failed send only means the forwarding thread is gone.
                let _ = edits_outgoing.send(pair);
            }
            Self::Pending { pending } => pending.push_back(pair),
        }
    }
}
/// One batch of serialized edits plus the channel used to report that it was applied.
struct MsgPair {
    /// Serialized edits, sent to the webview as a single binary websocket message.
    edits: Vec<u8>,
    /// Fired by the forwarding thread after the webview acknowledges the batch.
    response: oneshot::Sender<()>,
}
/// Size in bytes of each randomly generated client/server key.
const KEY_SIZE: usize = 256;
/// Raw key bytes. Despite the name these are *not* encoded; `encode_key_string` produces the
/// base64 text form that appears in the websocket URL and the server-key message.
type EncodedKey = [u8; KEY_SIZE];
fn encode_key_string(key: &EncodedKey) -> String {
base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE, key)
}
/// Generates a new random key from an OS-seeded cryptographically secure RNG.
fn create_secure_key() -> EncodedKey {
    /// Identity function whose bound makes the build fail if the RNG below ever stops being
    /// a `CryptoRng`.
    fn require_csprng<R>(rng: R) -> R
    where
        R: rand::CryptoRng,
    {
        rng
    }

    let mut rng = require_csprng(rand::rngs::StdRng::from_os_rng());
    let mut key = [0u8; KEY_SIZE];
    rng.fill_bytes(&mut key);
    key
}
#[test]
fn test_key_encoding_length() {
    // Padded base64 of KEY_SIZE (256) bytes: ceil(256 / 3) * 4 = 86 * 4 = 344 characters.
    let mut rng = rand::rngs::StdRng::from_os_rng();
    let mut buffer: EncodedKey = [0u8; KEY_SIZE];
    for _round in 0..100 {
        // `fill_bytes` overwrites the whole buffer, so reusing it gives a fresh key each round.
        rng.fill_bytes(&mut buffer);
        assert_eq!(encode_key_string(&buffer).len(), 344);
    }
}
/// Returns a boxed, `'static` future that resolves on a later notification of `notify`.
///
/// The future owns its `Arc<Notify>`, so it can be stored in a struct without borrowing
/// anything.
fn owned_notify_future(notify: Arc<Notify>) -> Pin<Box<dyn Future<Output = ()>>> {
    let mut notify_owned = Box::pin(async move {
        // Create the `Notified` first, then yield, so the first poll returns before
        // `notified` itself is ever polled.
        let notified = notify.notified();
        tokio::task::yield_now().await;
        notified.await;
    });
    // Poll once right away so `notify.notified()` has already been called when this function
    // returns. Presumably this matters because tokio documents that `notify_waiters` wakes
    // `Notified` futures created before the call — confirm against the tokio version in use.
    // The result is discarded: the first poll is expected to stop at `yield_now`.
    _ = (&mut notify_owned).now_or_never();
    notify_owned
}