use std::{
cell::RefCell,
marker::PhantomData,
pin::Pin,
rc::Rc,
task::{Context, Poll},
};
use futures_channel::mpsc;
use futures_core::{future::LocalBoxFuture, Future};
use futures_util::{FutureExt, StreamExt};
use gloo_events::EventListener;
use js_sys::{ArrayBuffer, Uint8Array};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use wasm_bindgen::JsCast;
#[doc(hidden)]
pub use bincode;
#[doc(hidden)]
pub use futures_channel;
#[doc(hidden)]
pub use futures_core;
#[doc(hidden)]
pub use futures_util;
#[doc(hidden)]
pub use gloo_events;
#[doc(hidden)]
pub use js_sys;
#[doc(hidden)]
pub use pin_utils;
#[doc(hidden)]
pub use serde;
#[doc(hidden)]
pub use wasm_bindgen;
pub use web_rpc_macro::service;
pub mod client;
#[doc(hidden)]
pub mod codec;
pub mod interface;
pub mod port;
#[doc(hidden)]
pub mod service;
pub use interface::Interface;
/// Header that starts every message exchanged over the port.
///
/// Each variant carries the sequence id of the call it refers to. The header is encoded with
/// `bincode` and sent as the first buffer of a message. Because bincode encodes enum variants
/// by index, the order of the variants is part of the wire format and must not be changed.
#[doc(hidden)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageHeader {
    /// A call made by the client; followed by the serialized request payload.
    Request(usize),
    /// Sent by the client to cancel the call with this id; carries no payload.
    Abort(usize),
    /// The reply to a single-response call; followed by the serialized response payload.
    Response(usize),
    /// One item of a streamed reply; followed by the serialized item payload.
    StreamItem(usize),
    /// The streamed reply with this id has finished; carries no payload.
    StreamEnd(usize),
}
/// Configures how an [`Interface`] is turned into RPC endpoints.
///
/// `C` is the client type, recorded only as a `PhantomData` marker. `S` is the service
/// implementation. `Builder::new` starts with `()` in both positions. `with_client` and
/// `with_service` fill them in, and `build` then yields a client, a [`Server`], or both.
pub struct Builder<C, S> {
// Type-level marker for the client type; no client value is stored.
client: PhantomData<C>,
// The service implementation, or `()` while none has been attached.
service: S,
// The interface whose port, listener and incoming messages the built endpoints take over.
interface: Interface,
}
impl Builder<(), ()> {
    /// Starts a builder over `interface` with neither a client type nor a service configured.
    pub fn new(interface: Interface) -> Self {
        Self {
            client: PhantomData,
            service: (),
            interface,
        }
    }
}
impl<C> Builder<C, ()> {
    /// Attaches a service implementation, converting `implementation` into `S` via `Into`.
    ///
    /// The client type chosen so far (if any) is carried over unchanged.
    pub fn with_service<S: service::Service>(self, implementation: impl Into<S>) -> Builder<C, S> {
        let service: S = implementation.into();
        Builder {
            service,
            client: self.client,
            interface: self.interface,
        }
    }
}
impl<S> Builder<(), S> {
    /// Selects `C` as the client type this builder will produce.
    ///
    /// Any service attached so far is carried over unchanged.
    pub fn with_client<C: client::Client>(self) -> Builder<C, S> {
        let Self {
            interface,
            service,
            client: _,
        } = self;
        Builder {
            client: PhantomData,
            service,
            interface,
        }
    }
}
/// The server half of an RPC connection, exposed as a future that runs the service task.
///
/// Requests are only executed while this future is being polled.
#[must_use = "Server must be polled in order for RPC requests to be executed"]
pub struct Server {
// Held (never read) so the port's event listener stays alive as long as the server does.
// When client and server are built together, this `Rc` is shared with the client.
_listener: Rc<EventListener>,
// The service task created by `service::task`, type-erased into a boxed local future.
task: LocalBoxFuture<'static, ()>,
}
impl Future for Server {
    type Output = ();

    /// Polls the wrapped service task and completes when that task completes.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Every field of `Server` is `Unpin`, so it is fine to take a plain `&mut`.
        let this = self.get_mut();
        this.task.as_mut().poll(cx)
    }
}
impl<C> Builder<C, ()>
where
    C: client::Client + From<client::Configuration<C::Response>> + 'static,
    <C as client::Client>::Response: DeserializeOwned,
{
    /// Consumes the builder and returns a client only.
    ///
    /// A dispatcher future drains incoming messages from the interface and routes them by
    /// header:
    /// - `Response` completes and removes the matching entry in the pending-call map.
    /// - `StreamItem` is forwarded to the matching stream sender.
    /// - `StreamEnd` removes that sender.
    ///
    /// `C` receives both maps, the port, the listener, the shared dispatcher, and an abort
    /// callback that posts an `Abort` header for a given sequence id.
    ///
    /// # Panics
    ///
    /// The dispatcher panics on a malformed message, or when it receives a `Request` or
    /// `Abort`, which only a server handles. The abort callback panics if posting to the
    /// port fails.
    pub fn build(self) -> C {
        /// Shifts the next element off `array`, which must be an `ArrayBuffer`, and copies
        /// its contents into a `Vec<u8>`.
        fn next_bytes(array: &js_sys::Array) -> Vec<u8> {
            let buffer = array.shift().dyn_into::<ArrayBuffer>().unwrap();
            Uint8Array::new(&buffer).to_vec()
        }

        let Self {
            interface:
                Interface {
                    port,
                    listener,
                    mut messages_rx,
                },
            ..
        } = self;

        let pending_calls: Rc<RefCell<client::CallbackMap<C::Response>>> = Default::default();
        let pending_streams: Rc<RefCell<client::StreamCallbackMap<C::Response>>> =
            Default::default();

        let dispatcher = {
            let calls = Rc::clone(&pending_calls);
            let streams = Rc::clone(&pending_streams);
            async move {
                while let Some(array) = messages_rx.next().await {
                    let header: MessageHeader =
                        bincode::deserialize(&next_bytes(&array)).unwrap();
                    match header {
                        MessageHeader::Response(seq_id) => {
                            let response: C::Response =
                                bincode::deserialize(&next_bytes(&array)).unwrap();
                            if let Some(callback_tx) = calls.borrow_mut().remove(&seq_id) {
                                let _ = callback_tx.send((response, array));
                            }
                        }
                        MessageHeader::StreamItem(seq_id) => {
                            let item: C::Response =
                                bincode::deserialize(&next_bytes(&array)).unwrap();
                            if let Some(item_tx) = streams.borrow().get(&seq_id) {
                                let _ = item_tx.unbounded_send((item, array));
                            }
                        }
                        MessageHeader::StreamEnd(seq_id) => {
                            streams.borrow_mut().remove(&seq_id);
                        }
                        MessageHeader::Request(_) | MessageHeader::Abort(_) => {
                            panic!("client received a server message")
                        }
                    }
                }
            }
            .boxed_local()
            .shared()
        };

        let abort_port = port.clone();
        let abort_sender = move |seq_id: usize| {
            let header_bytes = bincode::serialize(&MessageHeader::Abort(seq_id)).unwrap();
            let buffer = Uint8Array::from(&header_bytes[..]).buffer();
            abort_port
                .post_message(&js_sys::Array::of1(&buffer), &js_sys::Array::of1(&buffer))
                .unwrap();
        };

        C::from((
            pending_calls,
            pending_streams,
            port,
            Rc::new(listener),
            dispatcher,
            Rc::new(abort_sender),
        ))
    }
}
impl<S> Builder<(), S>
where
S: service::Service + 'static,
<S as service::Service>::Response: Serialize,
{
pub fn build(self) -> Server {
let Builder {
service,
interface:
Interface {
port,
listener,
mut messages_rx,
},
..
} = self;
let (server_requests_tx, server_requests_rx) = mpsc::unbounded();
let (abort_requests_tx, abort_requests_rx) = mpsc::unbounded();
let dispatcher = async move {
while let Some(array) = messages_rx.next().await {
let header_bytes =
Uint8Array::new(&array.shift().dyn_into::<ArrayBuffer>().unwrap()).to_vec();
let header: MessageHeader = bincode::deserialize(&header_bytes).unwrap();
match header {
MessageHeader::Request(seq_id) => {
let payload =
Uint8Array::new(&array.shift().dyn_into::<ArrayBuffer>().unwrap())
.to_vec();
server_requests_tx
.unbounded_send((seq_id, payload, array))
.unwrap();
}
MessageHeader::Abort(seq_id) => {
abort_requests_tx.unbounded_send(seq_id).unwrap();
}
_ => panic!("server received a client message"),
}
}
}
.boxed_local()
.shared();
Server {
_listener: Rc::new(listener),
task: service::task::<S>(
service,
port,
dispatcher,
server_requests_rx,
abort_requests_rx,
)
.boxed_local(),
}
}
}
impl<C, S> Builder<C, S>
where
    C: client::Client + From<client::Configuration<C::Response>> + 'static,
    S: service::Service + 'static,
    <S as service::Service>::Response: Serialize,
    <C as client::Client>::Response: DeserializeOwned,
{
    /// Consumes the builder and returns a client and a [`Server`] that share one port.
    ///
    /// A single dispatcher future drains incoming messages and routes them by header.
    /// `Response`, `StreamItem` and `StreamEnd` go to the client's pending-call and stream
    /// maps. `Request` and `Abort` are forwarded over channels to the service task. The
    /// dispatcher is `shared()` and a clone goes to each half, so whichever half is polled
    /// drives it.
    ///
    /// Requests and aborts that arrive after the [`Server`] has been dropped are discarded.
    /// Dropping the server also drops the service task's channel receivers. The client
    /// therefore keeps working on its own instead of the dispatcher panicking.
    ///
    /// # Panics
    ///
    /// The dispatcher panics on a malformed message. The client's abort callback panics if
    /// posting to the port fails.
    pub fn build(self) -> (C, Server) {
        /// Shifts the next element off `array`, which must be an `ArrayBuffer`, and copies
        /// its contents into a `Vec<u8>`.
        fn next_bytes(array: &js_sys::Array) -> Vec<u8> {
            Uint8Array::new(&array.shift().dyn_into::<ArrayBuffer>().unwrap()).to_vec()
        }

        let Builder {
            service: server,
            interface:
                Interface {
                    port,
                    listener,
                    mut messages_rx,
                },
            ..
        } = self;
        let client_callback_map: Rc<RefCell<client::CallbackMap<C::Response>>> = Default::default();
        let stream_callback_map: Rc<RefCell<client::StreamCallbackMap<C::Response>>> =
            Default::default();
        // Channels feeding the service task; the receivers end up inside the `Server` below.
        let (server_requests_tx, server_requests_rx) = mpsc::unbounded();
        let (abort_requests_tx, abort_requests_rx) = mpsc::unbounded();
        let client_callback_map_cloned = client_callback_map.clone();
        let stream_callback_map_cloned = stream_callback_map.clone();
        let dispatcher = async move {
            while let Some(array) = messages_rx.next().await {
                let header: MessageHeader = bincode::deserialize(&next_bytes(&array)).unwrap();
                match header {
                    MessageHeader::Response(seq_id) => {
                        let response: C::Response =
                            bincode::deserialize(&next_bytes(&array)).unwrap();
                        if let Some(callback_tx) =
                            client_callback_map_cloned.borrow_mut().remove(&seq_id)
                        {
                            let _ = callback_tx.send((response, array));
                        }
                    }
                    MessageHeader::StreamItem(seq_id) => {
                        let response: C::Response =
                            bincode::deserialize(&next_bytes(&array)).unwrap();
                        if let Some(tx) = stream_callback_map_cloned.borrow().get(&seq_id) {
                            let _ = tx.unbounded_send((response, array));
                        }
                    }
                    MessageHeader::StreamEnd(seq_id) => {
                        stream_callback_map_cloned.borrow_mut().remove(&seq_id);
                    }
                    MessageHeader::Request(seq_id) => {
                        let payload = next_bytes(&array);
                        // The receiver lives in the server task. If the `Server` was dropped,
                        // the request cannot be served. Discard it rather than panic inside
                        // a dispatcher that the client also polls.
                        let _ = server_requests_tx.unbounded_send((seq_id, payload, array));
                    }
                    MessageHeader::Abort(seq_id) => {
                        // Same reasoning as for `Request`: nothing to abort without a server.
                        let _ = abort_requests_tx.unbounded_send(seq_id);
                    }
                }
            }
        }
        .boxed_local()
        .shared();
        let port_cloned = port.clone();
        // Posts a bare `Abort` header; the header buffer is transferred, not copied.
        let abort_sender = move |seq_id: usize| {
            let header = MessageHeader::Abort(seq_id);
            let header_bytes = bincode::serialize(&header).unwrap();
            let buffer = js_sys::Uint8Array::from(&header_bytes[..]).buffer();
            let post_args = js_sys::Array::of1(&buffer);
            let transfer_args = js_sys::Array::of1(&buffer);
            port_cloned
                .post_message(&post_args, &transfer_args)
                .unwrap();
        };
        // One listener shared by both halves, so it stays alive while either is alive.
        let listener = Rc::new(listener);
        let client = C::from((
            client_callback_map,
            stream_callback_map,
            port.clone(),
            listener.clone(),
            dispatcher.clone(),
            Rc::new(abort_sender),
        ));
        let server = Server {
            _listener: listener,
            task: service::task::<S>(
                server,
                port,
                dispatcher,
                server_requests_rx,
                abort_requests_rx,
            )
            .boxed_local(),
        };
        (client, server)
    }
}