use std::result;
use std::sync::{atomic, Arc};
use tokio::sync::{mpsc, oneshot};
use super::endpoint::{self, new_endpoint, CtrlStatus, Endpoint, EndpointState, OutboundWaker};
use super::event::*;
use tracing::Level;
/// Failures reported by a worker's event loop or by requests sent to a worker.
#[derive(Debug)]
pub enum WorkerError {
    /// A request could not be delivered to the worker, or the worker dropped the reply
    /// channel without answering.
    Invalid,
    /// No endpoint is stored under the requested endpoint id.
    EndpointNotExist,
    /// Wraps an `std::io::Error` raised by the poller (creating, polling or waking it, or
    /// (de)registering a source).
    PollFailure(std::io::Error),
    /// The worker has no endpoint ids left to hand out.
    Full,
}

impl From<std::io::Error> for WorkerError {
    /// Treats any I/O error as a poller failure so `?` can be used on poller calls.
    fn from(err: std::io::Error) -> Self {
        Self::PollFailure(err)
    }
}

/// Errors returned by the public [`Driver`] API.
#[derive(Debug)]
pub enum DriverError {
    /// A request to one of the workers failed.
    Worker(WorkerError),
    /// No Tokio runtime was available to spawn the driver's background tasks.
    TokioUnavailable,
    /// Refers to a worker index that does not exist.
    InvalidWorkerID,
}

impl From<WorkerError> for DriverError {
    /// Lifts a worker-level failure into a driver-level error.
    fn from(err: WorkerError) -> Self {
        Self::Worker(err)
    }
}

/// Result type for worker-level operations.
type WorkerResult<T> = result::Result<T, WorkerError>;
/// Result type for the public driver API.
pub type Result<T> = result::Result<T, DriverError>;
/// The driver's view of one worker.
struct DrivingWorker {
    /// Cloneable handle used to send commands to the worker.
    worker: Worker,
    /// Number of endpoints this driver has placed on the worker; used to pick the
    /// least-loaded worker in `create_endpoint_with_options`.
    /// NOTE(review): nothing in this file decrements it when endpoints are removed or
    /// dropped — confirm whether that is intended.
    nendpoints: atomic::AtomicUsize,
    /// Sends the shutdown request to the worker's event loop.
    close: EventBlockingSender<()>,
}
/// Cheaply cloneable handle to a pool of workers; all clones share the same workers.
#[derive(Clone)]
pub struct Driver(Arc<DriverInner>);
/// State shared by all `Driver` clones.
struct DriverInner {
    /// One entry per worker, indexed by worker id.
    workers: Vec<DrivingWorker>,
    /// Spawner used for the driver's async tasks.
    spawner: AsyncSpawner,
}
impl Driver {
    /// Creates a driver with `nworker` workers, spawning its background tasks on the
    /// current Tokio runtime.
    ///
    /// # Errors
    ///
    /// [`DriverError::TokioUnavailable`] if no Tokio runtime handle is available, or
    /// [`DriverError::Worker`] if a worker fails to start.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn new(nworker: usize) -> Result<Self> {
        let handle =
            tokio::runtime::Handle::try_current().map_err(|_| DriverError::TokioUnavailable)?;
        Self::new_with_spawner(nworker, AsyncSpawner(handle))
    }

    /// Creates a driver with `nworker` workers in the browser.
    ///
    /// When the page is `crossOriginIsolated`, tasks run on a web-worker thread pool of
    /// `nworker` threads. Otherwise, or if the pool cannot be created, a warning is logged
    /// and tasks run single-threaded.
    ///
    /// # Errors
    ///
    /// [`DriverError::Worker`] if a worker fails to start.
    #[cfg(target_arch = "wasm32")]
    pub async fn new(nworker: usize) -> Result<Self> {
        use wasm_bindgen::JsValue;
        let cross_origin_isolated = (|| {
            JsValue::try_from(&web_sys::window()?.get("crossOriginIsolated")?)
                .ok()?
                .as_bool()
        })()
        .unwrap_or(false);
        let spawner = if cross_origin_isolated {
            match wasm_futures_executor::ThreadPool::new(nworker).await {
                Ok(tp) => AsyncSpawner::WASMExecutor(tp),
                Err(_) => {
                    tracing::event!(target: "Driver", Level::WARN, "web worker is unsupported");
                    AsyncSpawner::SingleThreaded
                }
            }
        } else {
            tracing::event!(
                target: "Driver",
                Level::WARN,
                "crossOriginIsolated is not enabled! Falling back to run without web workers"
            );
            AsyncSpawner::SingleThreaded
        };
        Self::new_with_spawner(nworker, spawner)
    }

    /// Creates `nworker` workers, each with a control-channel capacity of 1024 commands.
    /// Their async tasks run on `spawner`.
    ///
    /// # Errors
    ///
    /// The first worker start-up failure, as [`DriverError::Worker`].
    pub fn new_with_spawner(nworker: usize, spawner: AsyncSpawner) -> Result<Self> {
        let workers = (0..nworker)
            .map(|i| DrivingWorker::new(i, 1024, spawner.clone()))
            .collect::<WorkerResult<Vec<_>>>()?;
        Ok(Driver(Arc::new(DriverInner { workers, spawner })))
    }

    /// Creates a new endpoint on the worker with the fewest endpoints placed by this driver.
    ///
    /// `renew_once` is forwarded to the worker as the endpoint's `renewable_transport` flag.
    ///
    /// # Errors
    ///
    /// - [`DriverError::InvalidWorkerID`] if the driver has no workers.
    /// - [`DriverError::Worker`] if the worker cannot be reached or has no endpoint ids left.
    ///
    /// On failure the worker's load counter is restored.
    pub async fn create_endpoint_with_options(&self, renew_once: bool) -> Result<Endpoint> {
        // Least-loaded placement; `min_by_key` returns the first of several equal minima, so
        // ties go to the lowest worker index.
        let target = self
            .0
            .workers
            .iter()
            .min_by_key(|w| w.nendpoints.load(atomic::Ordering::Relaxed))
            // With no workers there is nothing to place the endpoint on; report it instead of
            // panicking on an out-of-bounds index.
            .ok_or(DriverError::InvalidWorkerID)?;
        // Reserve the slot before awaiting so concurrent callers observe the new load.
        target.nendpoints.fetch_add(1, atomic::Ordering::Relaxed);
        let w = &target.worker;
        match w.new_endpoint(w.clone(), renew_once).await {
            Ok(ep) => Ok(ep),
            Err(e) => {
                // No endpoint was handed out, so undo the reservation.
                target.nendpoints.fetch_sub(1, atomic::Ordering::Relaxed);
                Err(e.into())
            }
        }
    }

    /// Creates a new endpoint with default options (`renew_once = false`).
    #[inline]
    pub async fn create_endpoint(&self) -> Result<Endpoint> {
        self.create_endpoint_with_options(false).await
    }

    /// Returns a handle to worker `wid`.
    ///
    /// # Errors
    ///
    /// [`DriverError::InvalidWorkerID`] if `wid` is not a valid worker index
    /// (`wid >= number of workers`).
    #[inline]
    pub fn worker(&self, wid: usize) -> Result<Worker> {
        self.0
            .workers
            .get(wid)
            .map(|w| w.worker.clone())
            .ok_or(DriverError::InvalidWorkerID)
    }

    /// The spawner used by this driver's async tasks.
    #[inline(always)]
    pub fn spawner(&self) -> &AsyncSpawner {
        &self.0.spawner
    }

    /// Total number of endpoints this driver has placed on its workers.
    pub fn nendpoints(&self) -> usize {
        self.0
            .workers
            .iter()
            .map(|w| w.nendpoints.load(atomic::Ordering::Relaxed))
            .sum()
    }
}
impl Drop for DriverInner {
fn drop(&mut self) {
for DrivingWorker { close, .. } in self.workers.iter() {
close.notify(());
}
}
}
/// Commands processed by a worker's event loop.
///
/// Variants carrying a `oneshot::Sender` expect a reply.
pub(super) enum CtrlCmd {
    /// Create a new endpoint owned by this worker. The `bool` is forwarded to `new_endpoint`
    /// as `renewable_transport`. The reply is the endpoint handle, or `WorkerError::Full`
    /// when no id is available.
    NewEndpoint(Worker, bool, oneshot::Sender<WorkerResult<Endpoint>>),
    /// Take the endpoint state with this id out of the worker, deregistering its transport
    /// if still connected. The state is sent back in the reply.
    RemoveEndpoint(usize, oneshot::Sender<WorkerResult<EndpointState>>),
    /// Discard the endpoint state with this id. Produced from ids sent on `Worker::drop_tx`.
    DropEndpoint(usize),
    /// Forward an endpoint-level command to the endpoint with the given id.
    Endpoint(endpoint::CtrlCmd, usize),
    /// Adopt an existing endpoint state under a freshly allocated id, registering its
    /// transport if connected. The reply is the new id.
    AddEndpoint(EndpointState, oneshot::Sender<WorkerResult<usize>>),
    /// Deregister the endpoint's transport from the poller without removing the endpoint.
    Deregister(usize),
    /// Ask whether the endpoint with this id is disconnected.
    IsDisconnected(usize, oneshot::Sender<WorkerResult<bool>>),
}
/// Cloneable handle for sending requests to one worker's event loop.
#[derive(Clone)]
pub struct Worker {
    /// Bounded channel of control commands, forwarded into the worker's control queue.
    pub(super) ctrl_tx: mpsc::Sender<CtrlCmd>,
    /// Unbounded channel of endpoint ids to discard; each id becomes a
    /// `CtrlCmd::DropEndpoint`. NOTE(review): presumably unbounded so it can be used where
    /// awaiting is impossible (e.g. `Drop`) — confirm.
    pub(super) drop_tx: mpsc::UnboundedSender<usize>,
    /// Index of the worker within its driver.
    pub(super) id: usize,
    /// Queues an endpoint id whose outbound side should be pumped.
    pub(super) outbound: EndpointEventSender,
    /// Queues an endpoint id whose inbound side should be pumped.
    pub(super) inbound: EndpointEventSender,
}
// Generates an async request method on `Worker` that sends `CtrlCmd::$req` to the worker's
// event loop over `ctrl_tx`. The command is built by `_make_ctrl_handler_args!` (defined
// elsewhere) from the listed parameters.
//
// * `-> T` form: a oneshot reply sender is passed along with the command (presumably as the
//   variant's last field, matching the `CtrlCmd` variants). The method awaits the worker's
//   `WorkerResult<T>` reply and returns it unchanged.
// * No return type: fire-and-forget; the method only awaits enqueueing the command.
//
// In both forms a closed control channel is reported as `WorkerError::Invalid`. In the
// first form, a reply channel dropped without an answer is reported the same way.
macro_rules! make_ctrl_handler {
    ($req:ident, $vis:vis fn $func:ident ($($v:ident: $t:ty),*) -> $T: ty) => {
        $vis async fn $func(&self, $($v: $t),*) -> WorkerResult<$T> {
            let (tx, rx) = oneshot::channel();
            self.ctrl_tx
                .send(_make_ctrl_handler_args!(CtrlCmd, $req, tx, [$($v),*]))
                .await
                .map_err(|_| WorkerError::Invalid)?;
            // The outer `?` only strips the channel-closed layer; the worker's own
            // `WorkerResult` is returned as-is.
            rx.await.map_err(|_| WorkerError::Invalid)?
        }
    };
    ($req:ident, $vis:vis fn $func: ident ($($v:ident: $t:ty),*)) => {
        $vis async fn $func(&self, $($v: $t),*) -> WorkerResult<()> {
            self.ctrl_tx
                .send(_make_ctrl_handler_args!(CtrlCmd, $req, [$($v),*]))
                .await
                .map_err(|_| WorkerError::Invalid)
        }
    }
}
// Async request helpers generated by `make_ctrl_handler!`. Each sends one `CtrlCmd` to the
// worker's event loop; the ones with a return type await the worker's reply.
impl Worker {
    // Create a new endpoint owned by this worker.
    make_ctrl_handler!(NewEndpoint, fn new_endpoint(w: Worker, renewable_transport: bool) -> Endpoint);
    // Take an endpoint's state out of this worker.
    make_ctrl_handler!(RemoveEndpoint, pub(super) fn remove_endpoint(eid: usize) -> EndpointState);
    // Adopt an endpoint state; returns its id on this worker.
    make_ctrl_handler!(AddEndpoint, pub(super) fn add_endpoint(eps: EndpointState) -> usize);
    // Deregister an endpoint's transport from the poller (no reply).
    make_ctrl_handler!(Deregister, pub(super) fn deregister(eid: usize));
    // Query whether an endpoint is disconnected.
    make_ctrl_handler!(IsDisconnected, pub(super) fn is_disconnected(eid: usize) -> bool);
}
/// State owned and mutated by a single worker's event loop.
struct WorkerState {
    /// Index of this worker (the same value as `Worker::id`).
    #[allow(dead_code)]
    id: usize,
    /// Readiness poller; native transports are registered with it directly.
    poll: IOPoll,
    /// Spawner passed to endpoint control handling and transport renewal.
    async_spawner: AsyncSpawner,
    /// Waker registered under `WorkerState::CTRL_TOKEN`.
    waker: IOWaker,
    /// Queue of control commands forwarded from the async channels.
    ctrl_ev: EventReceiver<CtrlCmd>,
    /// Queue of endpoint ids whose outbound side should be pumped.
    outbound_ev_rx: EndpointEventReceiver,
    /// Queue of endpoint ids whose inbound side should be pumped.
    inbound_ev_rx: EndpointEventReceiver,
    /// Sender for `outbound_ev_rx`, cloned into new endpoints and generic-source notifiers.
    outbound_ev_tx: EndpointEventSender,
    /// Sender for `inbound_ev_rx`, cloned into new endpoints and generic-source notifiers.
    inbound_ev_tx: EndpointEventSender,
    /// Receives the shutdown request.
    close_ev: EventBlockingReceiver<()>,
    /// Endpoint states indexed by endpoint id; `None` marks a free slot.
    endpoints: Vec<Option<EndpointState>>,
    /// Ids of freed slots, reused before `endpoints` grows.
    ep_freed: Vec<usize>,
}
impl WorkerState {
    /// Largest id that may be handed out to an endpoint (it doubles as the endpoint's poll
    /// token).
    const MAX_ENDPOINT_TOKEN: usize = usize::MAX >> 1;
    /// Token of the worker's waker. It lies above every endpoint token, so waker events never
    /// collide with endpoint events.
    const CTRL_TOKEN: usize = Self::MAX_ENDPOINT_TOKEN + 1;

    /// Allocates an endpoint id, preferring a previously freed one.
    ///
    /// A newly grown slot is initialised to `None` and must be filled by the caller.
    /// Returns `None` once the id would exceed `MAX_ENDPOINT_TOKEN`.
    fn new_eid(&mut self) -> Option<usize> {
        if let Some(eid) = self.ep_freed.pop() {
            return Some(eid);
        }
        let eid = self.endpoints.len();
        if eid > Self::MAX_ENDPOINT_TOKEN {
            return None;
        }
        self.endpoints.push(None);
        Some(eid)
    }

    /// Takes the endpoint stored under `eid` out of its slot.
    ///
    /// Returns `None`, without panicking, when `eid` is out of range or the slot is empty.
    fn take_endpoint(&mut self, eid: usize) -> Option<EndpointState> {
        self.endpoints.get_mut(eid).and_then(Option::take)
    }

    /// Detaches the endpoint's transport, if any, from readiness notifications.
    ///
    /// Native (`MIO`) sources are deregistered from `poll`. Generic sources deregister
    /// themselves. An empty source needs nothing.
    fn deregister_endpoint(poll: &mut IOPoll, ep: &mut EndpointState) -> WorkerResult<()> {
        if let Some(tp) = ep.transport() {
            match tp.source() {
                #[cfg(not(target_arch = "wasm32"))]
                IOSource::MIO(src) => poll.deregister(src),
                IOSource::Generic(src) => {
                    // Keeps `poll` used when the `MIO` arm is compiled out (wasm32).
                    _ = poll;
                    src.deregister()
                }
                IOSource::Empty => Ok(()),
            }?
        }
        Ok(())
    }

    /// Registers the endpoint's transport, if any, for readiness under id `eid`.
    ///
    /// Native sources are registered with `poll` for read and write readiness. Generic
    /// sources receive a notifier that feeds the worker's inbound/outbound queues. An
    /// inbound and an outbound wake-up are then queued for `eid` (presumably so data that
    /// is already pending gets processed).
    fn register_endpoint(
        poll: &mut IOPoll, ep: &mut EndpointState, eid: usize, inbound_ev: &EndpointEventSender,
        outbound_ev: &EndpointEventSender,
    ) -> WorkerResult<()> {
        if let Some(tp) = ep.transport() {
            match tp.source() {
                #[cfg(not(target_arch = "wasm32"))]
                IOSource::MIO(src) => poll.register(src, IOToken(eid), IOInterest::READABLE | IOInterest::WRITABLE),
                IOSource::Generic(src) => {
                    // Keeps `poll` used when the `MIO` arm is compiled out (wasm32).
                    _ = poll;
                    src.register(IONotifier::new(eid, inbound_ev.clone(), outbound_ev.clone()))
                }
                IOSource::Empty => Ok(()),
            }?;
            inbound_ev.blocking_notify(eid);
            outbound_ev.blocking_notify(eid);
        }
        Ok(())
    }

    /// Executes one control command.
    ///
    /// Commands that carry a reply channel are answered through `worker_send_response!`.
    /// Poller errors are propagated, which ends the event loop.
    fn ctrl_handler(&mut self, ctrl: CtrlCmd) -> WorkerResult<()> {
        use CtrlCmd::*;
        match ctrl {
            NewEndpoint(worker, renewable_transport, resp) => {
                if let Some(eid) = self.new_eid() {
                    let (ep, eps) = new_endpoint(
                        eid,
                        // NOTE(review): presumably a buffer size for the endpoint — confirm.
                        65536,
                        worker,
                        OutboundWaker {
                            ev: self.outbound_ev_tx.clone(),
                            eid,
                        },
                        OutboundWaker {
                            ev: self.inbound_ev_tx.clone(),
                            eid,
                        },
                        renewable_transport,
                    );
                    self.endpoints[eid] = Some(eps);
                    worker_send_response!(resp, Ok(ep));
                } else {
                    worker_send_response!(resp, Err(WorkerError::Full));
                }
            }
            RemoveEndpoint(eid, resp) => match self.take_endpoint(eid) {
                None => worker_send_response!(resp, Err(WorkerError::EndpointNotExist)),
                Some(mut ep) => {
                    if !ep.disconnected() {
                        Self::deregister_endpoint(&mut self.poll, &mut ep)?
                    }
                    worker_send_response!(resp, Ok(ep));
                    // Only an id that actually held an endpoint goes back on the free list.
                    // An empty slot's id is already free, and pushing it again would let
                    // two future endpoints share one id.
                    self.ep_freed.push(eid)
                }
            },
            DropEndpoint(eid) => match self.take_endpoint(eid) {
                None => {
                    tracing::event!(target: "Worker", Level::WARN, "an endpoint that does not exist is dropped")
                }
                Some(mut ep) => {
                    if !ep.disconnected() {
                        Self::deregister_endpoint(&mut self.poll, &mut ep)?
                    }
                    // As above: free the id only if it was occupied.
                    self.ep_freed.push(eid)
                }
            },
            AddEndpoint(mut eps, resp) => {
                if let Some(eid) = self.new_eid() {
                    if !eps.disconnected() {
                        Self::register_endpoint(
                            &mut self.poll,
                            &mut eps,
                            eid,
                            &self.inbound_ev_tx,
                            &self.outbound_ev_tx,
                        )?
                    }
                    eps.id = eid;
                    self.endpoints[eid] = Some(eps);
                    worker_send_response!(resp, Ok(eid));
                } else {
                    worker_send_response!(resp, Err(WorkerError::Full));
                }
            }
            Endpoint(ep_ctrl, eid) => {
                if let Some(eps) = self.endpoints.get_mut(eid).and_then(Option::as_mut) {
                    match eps.ctrl(ep_ctrl, &self.async_spawner) {
                        CtrlStatus::Connected => Self::register_endpoint(
                            &mut self.poll,
                            eps,
                            eid,
                            &self.inbound_ev_tx,
                            &self.outbound_ev_tx,
                        )?,
                        CtrlStatus::Disconnected => {
                            eps.disconnect();
                            Self::check_connection(eps, &mut self.poll, &self.async_spawner)?;
                        }
                        _ => (),
                    }
                }
            }
            Deregister(eid) => {
                if let Some(ep) = self.endpoints.get_mut(eid).and_then(Option::as_mut) {
                    Self::deregister_endpoint(&mut self.poll, ep)?
                }
            }
            IsDisconnected(eid, resp) => match self.endpoints.get_mut(eid).and_then(Option::as_mut) {
                Some(ep) => worker_send_response!(resp, Ok(ep.disconnected())),
                None => worker_send_response!(resp, Err(WorkerError::EndpointNotExist)),
            },
        }
        Ok(())
    }

    /// Reacts to a connection-state change reported by `get_conn_change`.
    ///
    /// If either direction is still active, a native transport is re-registered with
    /// interest in only the active directions. If both are inactive, the transport is
    /// deregistered and `try_renew` is invoked.
    fn check_connection(ep: &mut EndpointState, poll: &mut IOPoll, async_spawner: &AsyncSpawner) -> WorkerResult<()> {
        let (inbound, outbound) = match ep.get_conn_change() {
            Some((i, o)) => (i, o),
            None => return Ok(()),
        };
        if inbound || outbound {
            #[cfg(not(target_arch = "wasm32"))]
            let eid = ep.id;
            if let Some(tp) = ep.transport() {
                match tp.source() {
                    #[cfg(not(target_arch = "wasm32"))]
                    IOSource::MIO(src) => {
                        let mut interest = Some(IOInterest::READABLE | IOInterest::WRITABLE);
                        if !inbound {
                            interest = interest.and_then(|i| i.remove(IOInterest::READABLE));
                        }
                        if !outbound {
                            interest = interest.and_then(|i| i.remove(IOInterest::WRITABLE));
                        }
                        // At least one direction is active in this branch, so the interest
                        // set is never emptied.
                        poll.reregister(src, IOToken(eid), interest.unwrap())?;
                    }
                    _ => (),
                }
            }
        } else {
            Self::deregister_endpoint(poll, ep)?;
            ep.try_renew(async_spawner)
        }
        Ok(())
    }

    /// Pumps one direction of endpoint `eid` after a queued wake-up, then wakes the worker
    /// waker again.
    ///
    /// `pump` returns `false` when the transfer failed; the endpoint is then marked
    /// disconnected. Missing or already disconnected endpoints are skipped.
    fn pump_endpoint(
        &mut self, eid: usize, pump: impl FnOnce(&mut EndpointState) -> bool,
    ) -> WorkerResult<()> {
        if let Some(ep) = self.endpoints.get_mut(eid).and_then(Option::as_mut) {
            if !ep.disconnected() {
                if !pump(&mut *ep) {
                    ep.disconnect();
                }
                Self::check_connection(ep, &mut self.poll, &self.async_spawner)?;
            }
        }
        self.waker.wake()?;
        Ok(())
    }

    /// Handles one readiness event; returns `Ok(true)` when the worker was asked to shut
    /// down.
    ///
    /// For `CTRL_TOKEN` events, at most one queued item is handled per event. The order is:
    /// control command, shutdown request, outbound wake-up, inbound wake-up. After a
    /// command or wake-up the waker is woken again, presumably so remaining items produce
    /// another event. Any other token is an endpoint id whose native transport is readable
    /// and/or writable.
    fn handle_event(&mut self, event: &IOEvent) -> WorkerResult<bool> {
        let id = event.token().0;
        if id == Self::CTRL_TOKEN {
            if let Some(ctrl) = self.ctrl_ev.try_listen() {
                self.ctrl_handler(ctrl)?;
                self.waker.wake()?;
                return Ok(false);
            }
            if self.close_ev.listen().is_some() {
                return Ok(true);
            }
            if let Some(eid) = self.outbound_ev_rx.try_listen() {
                self.pump_endpoint(eid, |ep: &mut EndpointState| ep.try_outbound().is_ok())?;
                return Ok(false);
            }
            if let Some(eid) = self.inbound_ev_rx.try_listen() {
                self.pump_endpoint(eid, |ep: &mut EndpointState| ep.try_inbound().is_ok())?;
                return Ok(false);
            }
            return Ok(false);
        }
        if let Some(ep) = self.endpoints.get_mut(id).and_then(Option::as_mut) {
            if event.is_readable() && ep.try_inbound().is_err() {
                ep.disconnect();
            }
            if event.is_writable() && ep.try_outbound().is_err() {
                ep.disconnect();
            }
            Self::check_connection(ep, &mut self.poll, &self.async_spawner)?;
        }
        Ok(false)
    }

    /// Polls and dispatches events until a shutdown request is handled.
    ///
    /// Any poller or handler error ends the loop and is returned.
    #[cfg(not(target_arch = "wasm32"))]
    fn event_loop(&mut self, events_capacity: usize) -> WorkerResult<()> {
        let mut events = IOEvents::with_capacity(events_capacity);
        'outer: loop {
            self.poll.poll(&mut events)?;
            for event in events.iter() {
                if self.handle_event(event)? {
                    break 'outer;
                }
            }
        }
        Ok(())
    }

    /// Async (wasm32) variant of the event loop: polls and dispatches events until a
    /// shutdown request is handled. Any poller or handler error ends the loop and is
    /// returned.
    #[cfg(target_arch = "wasm32")]
    async fn event_loop(&mut self, events_capacity: usize) -> WorkerResult<()> {
        let mut events = IOEvents::with_capacity(events_capacity);
        'outer: loop {
            self.poll.poll(&mut events).await?;
            for event in events.iter() {
                if self.handle_event(event)? {
                    break 'outer;
                }
            }
        }
        Ok(())
    }
}
impl DrivingWorker {
    /// Creates worker `id` and starts its event loop.
    ///
    /// Sets up the following:
    /// - the poller and its control waker;
    /// - the poll-driven queues (control, outbound, inbound, shutdown);
    /// - a task on `async_spawner` that forwards requests from the async control and drop
    ///   channels into the control queue.
    ///
    /// The event loop runs on a dedicated OS thread natively, or as a task on the spawner
    /// on wasm32. `ctrl_buffer_size` is the capacity of the async control channel.
    ///
    /// Returns an error only if creating the poller or its waker fails. An error that ends
    /// the event loop later is logged, not returned.
    fn new(id: usize, ctrl_buffer_size: usize, async_spawner: AsyncSpawner) -> WorkerResult<Self> {
        // Async-facing channels: bounded for control commands, unbounded for drop notices.
        let (ctrl_tx, mut ctrl_rx) = mpsc::channel(ctrl_buffer_size);
        let (drop_tx, mut drop_rx) = mpsc::unbounded_channel();
        // On wasm32 the poller constructor is passed 1024; natively it takes no arguments.
        let mut poll = IOPoll::new(
            #[cfg(target_arch = "wasm32")]
            1024,
        )?;
        // Every queue below shares this waker, registered under `CTRL_TOKEN`. The event loop
        // therefore reads all queued items while handling `CTRL_TOKEN` events.
        let waker = poll.waker(IOToken(WorkerState::CTRL_TOKEN))?;
        let (ctrl_ev_tx, ctrl_ev) = new_poll_event(&waker, 1024);
        let (outbound_ev_tx, outbound_ev_rx) = new_poll_event(&waker, 1024);
        let (inbound_ev_tx, inbound_ev_rx) = new_poll_event(&waker, 1024);
        // Shutdown queue; `DriverInner`'s `Drop` sends one `()` per worker through it.
        let (close_ev_tx, close_ev) = new_poll_event_blocking(&waker, 1);
        // Forward async-side requests into the poll-driven control queue, turning drop notices
        // into `CtrlCmd::DropEndpoint`. The task ends once both channels are closed
        // (`else => None`) or a forward yields `None`.
        async_spawner.spawn(async move {
            while let Some(_) = tokio::select! {
                Some(c) = ctrl_rx.recv() => { ctrl_ev_tx.notify(c).await },
                Some(eid) = drop_rx.recv() => { ctrl_ev_tx.notify(CtrlCmd::DropEndpoint(eid)).await },
                else => None,
            } {}
        });
        // wasm32 needs a second spawner handle because the original moves into WorkerState.
        #[cfg(target_arch = "wasm32")]
        let spawner = async_spawner.clone();
        let mut worker_state = WorkerState {
            id,
            poll,
            async_spawner,
            waker,
            ctrl_ev,
            outbound_ev_rx,
            inbound_ev_rx,
            outbound_ev_tx: outbound_ev_tx.clone(),
            inbound_ev_tx: inbound_ev_tx.clone(),
            close_ev,
            endpoints: Vec::new(),
            ep_freed: Vec::new(),
        };
        tracing::event!(target: "Worker", Level::INFO, id = id, "starting event loop");
        // wasm32: run the async event loop as a task on the spawner.
        #[cfg(target_arch = "wasm32")]
        spawner.spawn(async move {
            if let Err(e) = worker_state.event_loop(1024).await {
                tracing::event!(target: "Worker", Level::ERROR, id = id, "died with {:?}", e);
            }
        });
        // Native: run the synchronous event loop on its own OS thread; the join handle is
        // not kept.
        #[cfg(not(target_arch = "wasm32"))]
        std::thread::spawn(move || {
            if let Err(e) = worker_state.event_loop(1024) {
                tracing::event!(target: "Worker", Level::ERROR, id = id, "died with {:?}", e);
            }
        });
        // The driver keeps the command/wake senders and the shutdown sender.
        Ok(Self {
            nendpoints: atomic::AtomicUsize::new(0),
            worker: Worker {
                ctrl_tx,
                drop_tx,
                id,
                inbound: inbound_ev_tx,
                outbound: outbound_ev_tx,
            },
            close: close_ev_tx,
        })
    }
}
#[cfg(target_arch = "wasm32")]
pub mod wasm {
    //! JavaScript bindings (via `wasm-bindgen`) for the driver and its endpoints.
    use super::*;
    use crate::hub::utils::to_js_error;
    use wasm_bindgen::prelude::*;

    /// JS-facing wrapper around an endpoint.
    #[wasm_bindgen]
    pub struct Endpoint(super::Endpoint);

    /// JS-facing wrapper around a boxed transport.
    #[wasm_bindgen]
    pub struct Transport(Box<dyn crate::hub::transport::Transport>);

    #[wasm_bindgen]
    impl Endpoint {
        /// Attaches `transport` to this endpoint.
        pub async fn set_transport(&self, transport: Transport) -> result::Result<(), JsValue> {
            self.0.set_transport(transport.0).await.map_err(to_js_error)?;
            Ok(())
        }

        /// Sends `data`, which must be an `ArrayBuffer` or a string. A string is sent as
        /// its UTF-8 bytes. Anything else is rejected with "invalid data type".
        pub async fn send(&self, data: &JsValue) -> result::Result<(), JsValue> {
            let bytes = if let Ok(buffer) = data.clone().dyn_into::<js_sys::ArrayBuffer>() {
                Some(js_sys::Uint8Array::new(&buffer).to_vec())
            } else if let Ok(text) = data.clone().dyn_into::<js_sys::JsString>() {
                text.as_string().map(String::into_bytes)
            } else {
                None
            };
            let payload = bytes.ok_or_else(|| JsValue::from("invalid data type"))?;
            self.0.outbound().send(payload.into()).await.map_err(to_js_error)
        }

        /// Receives data from the endpoint, passing `size` to the underlying `recv`.
        /// The result is returned as an `ArrayBuffer`; failure yields "recv error".
        pub async fn recv(&self, size: Option<usize>) -> result::Result<JsValue, JsValue> {
            let bytes = self
                .0
                .inbound()
                .recv(size)
                .await
                .ok_or_else(|| JsValue::from("recv error"))?;
            Ok(js_sys::Uint8Array::from(bytes.as_ref()).buffer().into())
        }
    }

    /// JS-facing wrapper around the driver.
    #[wasm_bindgen]
    pub struct Driver(pub(crate) super::Driver);

    #[wasm_bindgen]
    impl Driver {
        /// Creates a driver with `nworker` workers; exposed to JS as the class constructor.
        #[wasm_bindgen(constructor)]
        pub async fn new(nworker: usize) -> result::Result<Driver, JsValue> {
            super::Driver::new(nworker).await.map(Driver).map_err(to_js_error)
        }
    }
}