use bytes::Bytes;
use parking_lot::{Mutex, MutexGuard, RwLock};
use std::result;
use std::sync::Weak;
use tokio::sync::{mpsc, oneshot};
use super::driver::{self, Worker, WorkerError};
use super::event::{self, AsyncSpawner};
use super::transport::{Transport, TransportError};
use tracing::Level;
/// Errors reported by endpoint operations.
#[derive(Debug)]
pub enum EndpointError {
    /// A transport was required but the endpoint has none.
    TransportNotAvailable,
    /// Failure reported by, or while talking to, the owning worker.
    Worker(WorkerError),
    /// An I/O error (e.g. from shutting down a transport).
    IO(std::io::Error),
}

impl From<WorkerError> for EndpointError {
    fn from(inner: WorkerError) -> Self {
        Self::Worker(inner)
    }
}

impl From<std::io::Error> for EndpointError {
    fn from(inner: std::io::Error) -> Self {
        Self::IO(inner)
    }
}
/// Error returned by [`Sender`] operations; the unsent payload is handed back.
#[derive(Debug, Clone)]
pub enum SenderError {
    /// The outbound channel is full (only from `try_send`).
    NotReady(Bytes),
    /// The outbound channel is closed.
    Closed(Bytes),
}
/// Error returned by [`Receiver::try_recv`].
#[derive(Debug, Clone)]
pub enum ReceiverError {
    /// Nothing is buffered right now.
    NotReady,
    /// The inbound channel is disconnected and empty.
    Closed,
}
/// Result type used by endpoint control operations.
pub type Result<T> = std::result::Result<T, EndpointError>;
/// Cloneable handle for queueing outbound data on an endpoint.
#[derive(Clone)]
pub struct Sender(OutboundSender<Bytes>);
/// State protected by the endpoint's inbound lock.
struct EndpointReceiverInner {
    // User end of the inbound channel.
    rx: InboundReceiver<Bytes>,
    // Bytes already pulled from `rx` but not yet returned (e.g. surplus of a sized `recv`).
    leftover: Bytes,
}
/// Exclusive access to an endpoint's inbound stream; holds the inbound lock while alive.
pub struct Receiver<'a>(MutexGuard<'a, EndpointReceiverInner>);
/// Control-plane state of an endpoint.
struct ControllerInner {
    // Worker the endpoint is currently registered with.
    worker: Worker,
    // Wakers shared with the outbound/inbound channels (obtained via `get_waker`);
    // `set_worker` retargets them at the new worker.
    outbound_waker: Arc<RwLock<OutboundWaker>>,
    inbound_waker: Arc<RwLock<OutboundWaker>>,
    // Endpoint id as known to `worker` (replaced by `add_endpoint` in `set_worker`).
    id: usize,
}
/// Exclusive access to an endpoint's control state; holds the control lock while alive.
struct Controller<'a>(MutexGuard<'a, ControllerInner>);
impl Sender {
pub async fn send(&self, data: Bytes) -> result::Result<(), SenderError> {
match self.0.send(data).await {
Err(mpsc::error::SendError(data)) => Err(SenderError::Closed(data)),
Ok(_) => Ok(()),
}
}
pub fn try_send(&self, data: Bytes) -> result::Result<(), SenderError> {
self.0.try_send(data).map_err(|e| match e {
mpsc::error::TrySendError::Full(data) => SenderError::NotReady(data),
mpsc::error::TrySendError::Closed(data) => SenderError::Closed(data),
})
}
}
impl<'a> Receiver<'a> {
    /// Receives inbound data, waiting until some is available.
    ///
    /// * `size == None`: returns the buffered leftover if there is one, otherwise the next
    ///   message exactly as queued by the worker.
    /// * `size == Some(s)`: returns exactly `s` bytes, concatenating messages as needed; any
    ///   surplus is kept as leftover for later calls.
    ///
    /// Returns `None` when the inbound channel closes before the request can be served. For
    /// `Some(s)`, the bytes gathered before the close are kept as leftover rather than lost,
    /// so `try_recv`/`drain` can still hand them out.
    ///
    /// Note: for `Some(s)`, if this future is dropped while waiting, messages already pulled
    /// from the channel during this call are lost.
    pub async fn recv(&mut self, size: Option<usize>) -> Option<Bytes> {
        let inner = &mut self.0;
        let s = match size {
            None if inner.leftover.is_empty() => return inner.rx.recv().await,
            None => return Some(std::mem::take(&mut inner.leftover)),
            Some(s) => s,
        };
        if inner.leftover.len() >= s {
            return Some(inner.leftover.split_to(s));
        }
        use bytes::BufMut;
        let mut buffer = bytes::BytesMut::new();
        buffer.put(&inner.leftover[..]);
        while buffer.len() < s {
            match inner.rx.recv().await {
                Some(chunk) => buffer.put(chunk),
                None => {
                    // Channel closed short of `s` bytes: keep what we have instead of dropping it.
                    inner.leftover = buffer.freeze();
                    return None;
                }
            }
        }
        inner.leftover = buffer.freeze();
        Some(inner.leftover.split_to(s))
    }

    /// Returns the leftover if any, otherwise one queued message, without waiting.
    ///
    /// # Errors
    /// `ReceiverError::NotReady` when nothing is queued, `ReceiverError::Closed` when the
    /// channel is disconnected.
    pub fn try_recv(&mut self) -> result::Result<Bytes, ReceiverError> {
        let inner = &mut self.0;
        if !inner.leftover.is_empty() {
            return Ok(std::mem::take(&mut inner.leftover));
        }
        inner.rx.try_recv().map_err(|e| match e {
            mpsc::error::TryRecvError::Empty => ReceiverError::NotReady,
            mpsc::error::TryRecvError::Disconnected => ReceiverError::Closed,
        })
    }

    /// Takes everything currently available: the leftover (if non-empty) followed by every
    /// queued message, without waiting.
    pub fn drain(&mut self) -> Vec<Bytes> {
        let inner = &mut self.0;
        let mut res = Vec::new();
        if !inner.leftover.is_empty() {
            res.push(std::mem::take(&mut inner.leftover));
        }
        // Go through the `InboundReceiver` wrapper (like `try_recv` does), not the raw mpsc
        // receiver, so a worker that armed `wait_sig` is woken once channel space frees up.
        while let Ok(msg) = inner.rx.try_recv() {
            res.push(msg);
        }
        res
    }
}
/// Control commands for one endpoint, handled on the worker by `EndpointState::ctrl`.
pub(super) enum CtrlCmd {
    /// Install a transport and start renewing it; replies once installed.
    SetTransport(Box<dyn Transport>, oneshot::Sender<Result<()>>),
    /// Flush outbound data, then remove and hand back the current transport.
    TakeTransport(oneshot::Sender<Result<Box<dyn Transport>>>),
    /// Sent by the spawned renew task: (transport token, transport, renew result with the
    /// initial inbound messages).
    FinishRenew(usize, Box<dyn Transport>, result::Result<Vec<Bytes>, TransportError>),
    /// Try to move pending data in both directions; no reply.
    Flush,
    /// Shut down the current transport if connected; `Some(token)` restricts this to that
    /// transport generation.
    RenewTransport(Option<usize>, oneshot::Sender<Result<()>>),
}
/// Sends `$val` through the oneshot responder `$resp`; if the requesting side is gone,
/// logs a warning instead of failing.
macro_rules! worker_send_response {
    ($resp: expr, $val: expr) => {
        match $resp.send($val) {
            Ok(()) => {}
            Err(_) => {
                tracing::event!(target: "Driver", Level::WARN, "worker dropped the response due to receiving side's issue");
            }
        }
    };
}
/// Builds the `CtrlCmd` variant expression used by `make_ctrl_handler!`.
///
/// `[$v, ...]` are the handler's argument names. When a response sender `$tx` is given it
/// is appended as the last field; with no arguments and no `$tx` a unit variant is produced.
macro_rules! _make_ctrl_handler_args {
    // Arguments plus response sender: `Enum::Req(v1, .., tx)`.
    ($enum: ty, $req: ident, $tx: ident, [$($v:ident),+]) => {
        <$enum>::$req($($v),+, $tx)
    };
    // Response sender only: `Enum::Req(tx)`.
    ($enum: ty, $req: ident, $tx: ident, []) => {
        <$enum>::$req($tx)
    };
    // Arguments, no response sender: `Enum::Req(v1, ..)`.
    ($enum: ty, $req: ident, [$($v:ident),+] ) => {
        <$enum>::$req($($v),+)
    };
    // Neither: unit variant `Enum::Req`.
    ($enum: ty, $req: ident, []) => {
        <$enum>::$req
    }
}
/// Generates an async `Controller` method that forwards `CtrlCmd::$req` to the worker's
/// `ctrl_tx`, tagged with this endpoint's id.
///
/// * `fn name(args) -> T`: creates a oneshot channel, sends the command with its sender and
///   awaits the worker's `Result<T>` reply.
/// * `fn name(args)`: fire-and-forget; returns once the command is queued.
///
/// A closed command channel or a dropped reply maps to `EndpointError::Worker(WorkerError::Invalid)`.
macro_rules! make_ctrl_handler {
    ($req:ident, $vis:vis fn $func:ident ($($v:ident: $t:ty),*) -> $T: ty) => {
        $vis async fn $func(&self, $($v: $t),*) -> Result<$T> {
            let (tx, rx) = oneshot::channel();
            self.0.worker.ctrl_tx
                .send(driver::CtrlCmd::Endpoint(_make_ctrl_handler_args!(CtrlCmd, $req, tx, [$($v),*]), self.0.id))
                .await
                .map_err(|_| EndpointError::Worker(WorkerError::Invalid))?;
            // Outer error: reply sender dropped; inner `Result` is the worker's answer.
            rx.await.map_err(|_| EndpointError::Worker(WorkerError::Invalid))?
        }
    };
    ($req:ident, $vis:vis fn $func:ident ($($v:ident: $t:ty),*)) => {
        $vis async fn $func(&self, $($v: $t),*) -> Result<()> {
            self.0.worker.ctrl_tx
                .send(driver::CtrlCmd::Endpoint(_make_ctrl_handler_args!(CtrlCmd, $req, [$($v),*]), self.0.id))
                .await
                .map_err(|_| EndpointError::Worker(WorkerError::Invalid))
        }
    }
}
impl<'a> Controller<'a> {
    make_ctrl_handler!(SetTransport, fn set_transport(transport: Box<dyn Transport>) -> ());
    make_ctrl_handler!(TakeTransport, fn take_transport() -> Box<dyn Transport>);
    make_ctrl_handler!(Flush, fn flush());
    make_ctrl_handler!(RenewTransport, fn renew_transport(token: Option<usize>) -> ());

    /// Moves this endpoint to `worker`. This is a no-op if the current worker has the same id.
    ///
    /// Steps:
    /// 1. Remove the endpoint state from the current worker and add it to `worker`,
    ///    recording the new id.
    /// 2. Point the shared outbound/inbound wakers at the new worker.
    /// 3. Notify both directions so the new worker picks up pending data.
    ///
    /// # Errors
    /// Errors from `remove_endpoint`/`add_endpoint`, or `WorkerError::Invalid` if a wake-up
    /// notification is not delivered.
    pub async fn set_worker(mut self, worker: Worker) -> Result<Controller<'a>> {
        if worker.id == self.0.worker.id {
            return Ok(self);
        }
        let eps = self.0.worker.remove_endpoint(self.0.id).await?;
        self.0.worker = worker;
        self.0.id = self.0.worker.add_endpoint(eps).await?;
        let id = self.0.id;
        // Each `write()` guard is a temporary dropped at the end of its statement, so no
        // blocking (parking_lot) lock is held across the `.await`s. Holding it there would
        // block every sender/receiver that reads the waker until the notification completes.
        *self.0.outbound_waker.write() = OutboundWaker {
            ev: self.0.worker.outbound.clone(),
            eid: id,
        };
        self.0
            .worker
            .outbound
            .notify(id)
            .await
            .ok_or(EndpointError::Worker(WorkerError::Invalid))?;
        *self.0.inbound_waker.write() = OutboundWaker {
            ev: self.0.worker.inbound.clone(),
            eid: id,
        };
        self.0
            .worker
            .inbound
            .notify(id)
            .await
            .ok_or(EndpointError::Worker(WorkerError::Invalid))?;
        Ok(self)
    }

    /// Worker the endpoint is registered with.
    fn worker(&self) -> &Worker {
        &self.0.worker
    }

    /// Endpoint id within its worker.
    fn id(&self) -> usize {
        self.0.id
    }
}
/// Handle passed to `Transport::renew`.
///
/// Calling [`TransportCloser::renew`] asks the endpoint's worker to shut down the transport
/// generation identified by `token`. Only a weak reference to the endpoint is kept.
#[derive(Clone)]
pub struct TransportCloser {
    ep: Weak<EndpointInner>,
    token: usize,
}
impl TransportCloser {
    /// Requests renewal of this closer's transport generation.
    ///
    /// Returns `Ok(())` without doing anything if the endpoint has already been dropped.
    pub async fn renew(&self) -> Result<()> {
        match self.ep.upgrade() {
            None => Ok(()),
            Some(ep) => Controller(ep.ctrl.lock()).renew_transport(Some(self.token)).await,
        }
    }
}
/// User-facing, cheaply cloneable handle to an endpoint.
#[derive(Clone)]
pub struct Endpoint(Arc<EndpointInner>);
/// Shared endpoint data behind [`Endpoint`].
struct EndpointInner {
    outbound_tx: OutboundSender<Bytes>,
    inbound_rx: Mutex<EndpointReceiverInner>,
    ctrl: Mutex<ControllerInner>,
    // Reports the endpoint id to the worker when the last handle goes away.
    drop_tx: mpsc::UnboundedSender<usize>,
}
impl Drop for EndpointInner {
    fn drop(&mut self) {
        let endpoint_id = self.ctrl.get_mut().id;
        // A failed send means the receiving side is already gone; nothing left to tell.
        let _ = self.drop_tx.send(endpoint_id);
    }
}
/// Connection-state outcome returned by `EndpointState::ctrl`.
pub(super) enum CtrlStatus {
    /// A renew finished successfully; both directions are now marked connected.
    Connected,
    /// `RenewTransport` shut down a connected transport.
    Disconnected,
    /// No connection-state transition to report.
    None,
}
/// Worker-side state of one endpoint: its transport plus the worker ends of the
/// outbound and inbound channels.
pub(super) struct EndpointState {
    // Current transport; `None` when none is installed or while a renew task owns it.
    transport: Option<Box<dyn Transport>>,
    // Worker end of the outbound channel (user -> transport).
    outbound_rx: OutboundReceiver<Bytes>,
    // Outbound payload taken from the channel but not yet accepted by the transport.
    // NOTE(review): `None` is set after `try_send` returns `Ok(true)` and leads to a
    // `try_send(None)` on the next call; looks like a "flush pending" marker, confirm.
    outbound_leftover: Option<Bytes>,
    // Worker end of the inbound channel (transport -> user).
    inbound_tx: InboundSender<Bytes>,
    // Payload read from the transport that did not fit into the full inbound channel.
    inbound_leftover: Bytes,
    pub id: usize,
    // Worker this state was created for; its `ctrl_tx` receives `FinishRenew`.
    worker: Worker,
    // Spawned renew task, kept so it can be aborted (not tracked on wasm32).
    connecting: Option<tokio::task::JoinHandle<()>>,
    // Transport generation, bumped on every renew; stale `FinishRenew`s are ignored.
    transport_token: usize,
    inbound_connected: bool,
    outbound_connected: bool,
    // Handed to `Transport::renew`; carries the current `transport_token`.
    closer: TransportCloser,
    // When set, a failed renew is not retried and a lost direction closes its channel.
    renew_once: bool,
    // Set when a direction's connected flag went down; read and cleared by `get_conn_change`.
    conn_changed: bool,
}
/// Creates a connected pair: the user-facing [`Endpoint`] and the worker-side [`EndpointState`].
///
/// Both channels are bounded to `channel_size`. The wakers are shared between the channels
/// and the endpoint's control state so a later `set_worker` can retarget them. The state
/// starts with no transport and both directions disconnected.
pub(super) fn new_endpoint(
    id: usize, channel_size: usize, worker: Worker, outbound_waker: OutboundWaker, inbound_waker: OutboundWaker,
    renew_once: bool,
) -> (Endpoint, EndpointState) {
    let (outbound_tx, outbound_rx) = new_outbound_mpsc(channel_size, outbound_waker);
    let (inbound_tx, inbound_rx) = new_inbound_spsc(channel_size, inbound_waker);
    let ctrl = Mutex::new(ControllerInner {
        worker: worker.clone(),
        outbound_waker: outbound_tx.get_waker(),
        inbound_waker: inbound_tx.get_waker(),
        id,
    });
    let drop_tx = worker.drop_tx.clone();
    let ep = Endpoint(Arc::new(EndpointInner {
        outbound_tx,
        inbound_rx: Mutex::new(EndpointReceiverInner {
            rx: inbound_rx,
            leftover: Bytes::new(),
        }),
        ctrl,
        drop_tx,
    }));
    // Downgrade the existing Arc directly; cloning the Endpoint first only bumped and
    // dropped an extra strong reference.
    let closer = TransportCloser {
        ep: Arc::downgrade(&ep.0),
        token: 0,
    };
    let eps = EndpointState {
        id,
        inbound_tx,
        outbound_rx,
        outbound_leftover: Some(Bytes::new()),
        inbound_leftover: Bytes::new(),
        worker,
        transport: None,
        connecting: None,
        transport_token: 0,
        inbound_connected: false,
        outbound_connected: false,
        conn_changed: false,
        closer,
        renew_once,
    };
    (ep, eps)
}
impl Endpoint {
pub fn inbound(&self) -> Receiver<'_> {
Receiver(self.0.inbound_rx.lock())
}
pub fn outbound(&self) -> Sender {
Sender(self.0.outbound_tx.clone())
}
fn ctrl(&self) -> Controller<'_> {
Controller(self.0.ctrl.lock())
}
pub async fn set_transport(&self, transport: Box<dyn Transport>) -> Result<&Self> {
self.ctrl().set_transport(transport).await?;
Ok(self)
}
pub async fn take_transport(&self) -> Result<Box<dyn Transport>> {
let ctrl = self.ctrl();
ctrl.worker().deregister(ctrl.id()).await?;
ctrl.take_transport().await
}
pub async fn flush(&self) -> Result<&Self> {
self.ctrl().flush().await?;
Ok(self)
}
pub async fn set_worker(&self, worker: Worker) -> Result<&Self> {
self.ctrl().set_worker(worker).await?;
Ok(self)
}
pub async fn worker(&self) -> Worker {
self.ctrl().worker().clone()
}
pub async fn reset(&self) -> Result<&Self> {
self.ctrl().renew_transport(None).await?;
Ok(self)
}
pub async fn is_disconnected(&self) -> Result<bool> {
let ctrl = self.ctrl();
let res = ctrl.worker().is_disconnected(ctrl.id()).await?;
Ok(res)
}
}
impl EndpointState {
    /// Moves pending outbound data from the endpoint's channel into the transport.
    ///
    /// The leftover from a previous call goes first, then queued messages, until the
    /// transport stops accepting data; the unsent payload is kept as leftover. If the
    /// outbound direction goes down during the call, a connection change is recorded on
    /// every exit path. Does nothing when there is no transport.
    ///
    /// # Errors
    /// Transport errors other than `NotReady` and `HalfTerminated`.
    pub fn try_outbound(&mut self) -> result::Result<(), TransportError> {
        let was_connected = self.outbound_connected;
        let res = self.try_outbound_inner();
        // Checked here, not at the end of `try_outbound_inner`: its `try_send!` returns early
        // (e.g. `HalfTerminated` while flushing the leftover, or a fatal error). Skipping the
        // check there would lose the notification for good, since `outbound_connected` is
        // already false on the next call.
        if was_connected && !self.outbound_connected {
            self.outbound_conn_change();
        }
        res
    }

    /// Transfer loop behind `try_outbound`; may return from any point.
    fn try_outbound_inner(&mut self) -> result::Result<(), TransportError> {
        let tp = match &mut self.transport {
            Some(tp) => tp,
            None => return Ok(()),
        };
        // Sends `$data` (`Option<Bytes>`; empty payloads are skipped). It returns `Ok($done)`
        // from the enclosing fn/closure when the transport stops taking data, `Err` on fatal
        // errors, and falls through when the data was accepted (`Ok(false)`).
        // NOTE(review): `Ok(true)` seems to mean "accepted, but call again with `None` to
        // flush" (the leftover becomes `None`); confirm against `Transport::try_send`.
        macro_rules! try_send {
            ($data: expr, $done: expr) => {
                let data = $data;
                let data_clone = data.as_ref().map(|b: &Bytes| b.clone());
                let skip = match &data_clone {
                    Some(d) => d.is_empty(),
                    None => false,
                };
                if !skip {
                    let res = tp.try_send(data_clone);
                    match res {
                        Ok(true) => {
                            self.outbound_leftover = None;
                            return Ok($done)
                        }
                        Ok(false) => (),
                        Err(e) => {
                            // Keep the payload so the next call retries it.
                            if data.is_some() {
                                self.outbound_leftover = data;
                            }
                            match e {
                                TransportError::NotReady => return Ok($done),
                                TransportError::HalfTerminated => {
                                    self.outbound_connected = false;
                                    return Ok($done)
                                }
                                _ => return Err(e),
                            }
                        }
                    }
                }
            };
        }
        try_send!(std::mem::replace(&mut self.outbound_leftover, Some(Bytes::new())), ());
        self.outbound_rx.try_recv(|data| {
            try_send!(Some(data), true);
            Ok(false)
        })
    }

    /// Moves data from the transport into the inbound channel: the leftover first, then
    /// messages read from the transport.
    ///
    /// Stops when any of these happens:
    /// - the transport reports `NotReady`;
    /// - a 0-byte read occurs (logged);
    /// - the inbound side half-terminates;
    /// - the channel is full (the message is kept as leftover).
    ///
    /// Messages are discarded if the inbound channel is closed. Records a connection change
    /// when the inbound direction went down. Does nothing when there is no transport.
    ///
    /// # Errors
    /// Transport errors other than `NotReady` and `HalfTerminated`.
    pub fn try_inbound(&mut self) -> result::Result<(), TransportError> {
        let tp = match &mut self.transport {
            Some(tp) => tp,
            None => return Ok(()),
        };
        let was_connected = self.inbound_connected;
        let res = self.inbound_tx.try_send(|inbound_tx| {
            // Pushes one message; on a full channel parks it in `inbound_leftover` and returns.
            macro_rules! try_recv {
                ($data: expr) => {{
                    let data = $data;
                    match inbound_tx.try_send(data) {
                        Err(e) => match e {
                            mpsc::error::TrySendError::Full(data) => {
                                self.inbound_leftover = data;
                                return Ok(())
                            }
                            mpsc::error::TrySendError::Closed(_) => (),
                        },
                        _ => (),
                    }
                }};
            }
            if !self.inbound_leftover.is_empty() {
                try_recv!(std::mem::replace(&mut self.inbound_leftover, Bytes::new()))
            }
            loop {
                match tp.try_recv() {
                    Ok(bytes) => {
                        if bytes.is_empty() {
                            tracing::event!(target: "Endpoint", Level::WARN, "got 0 byte from transport");
                            return Ok(())
                        }
                        try_recv!(bytes);
                    }
                    Err(e) => match e {
                        TransportError::NotReady => return Ok(()),
                        TransportError::HalfTerminated => {
                            self.inbound_connected = false;
                            return Ok(())
                        }
                        _ => return Err(e),
                    },
                }
            }
        });
        if was_connected && !self.inbound_connected {
            self.inbound_conn_change();
        }
        res
    }

    /// Mutable access to the current transport, if any.
    pub fn transport(&mut self) -> Option<&mut Box<dyn Transport>> {
        self.transport.as_mut()
    }

    /// Starts renewing the transport, but only if both directions are down.
    ///
    /// Takes the transport, bumps the transport token (also stored in `closer`), aborts any
    /// previous renew task, and spawns `Transport::renew`. Its result comes back to the
    /// worker as `CtrlCmd::FinishRenew`. On wasm32 the task is spawned with `spawn_local`
    /// and not tracked.
    pub(super) fn renew(&mut self, spawner: &AsyncSpawner) {
        if !self.disconnected() {
            return
        }
        if let Some(tp) = self.transport.take() {
            if let Some(h) = self.connecting.take() {
                h.abort()
            }
            self.transport_token = self.transport_token.wrapping_add(1);
            self.closer.token = self.transport_token;
            let ctrl_tx = self.worker.ctrl_tx.clone();
            let id = self.id;
            let token = self.transport_token;
            let closer = self.closer.clone();
            let fut = async move {
                let (tp, res) = tp.renew(closer).await;
                ctrl_tx
                    .send(driver::CtrlCmd::Endpoint(CtrlCmd::FinishRenew(token, tp, res), id))
                    .await
                    .ok();
            };
            #[cfg(target_arch = "wasm32")]
            {
                _ = spawner;
                wasm_bindgen_futures::spawn_local(fut)
            }
            #[cfg(not(target_arch = "wasm32"))]
            {
                self.connecting = Some(spawner.spawn(fut))
            }
        }
    }

    /// Calls `renew` unless `renew_once` is set.
    pub(super) fn try_renew(&mut self, spawner: &AsyncSpawner) {
        if !self.renew_once {
            self.renew(spawner)
        }
    }

    /// Records an inbound connection change. With `renew_once`, it also replaces the inbound
    /// sender with one whose receiver is already dropped, so further inbound data is discarded.
    fn inbound_conn_change(&mut self) {
        self.conn_changed = true;
        if self.renew_once {
            self.inbound_tx.tx = mpsc::channel(1).0;
        }
    }

    /// Records an outbound connection change. With `renew_once`, it also replaces the
    /// outbound receiver, which drops the original one and so closes the channel to senders.
    fn outbound_conn_change(&mut self) {
        self.conn_changed = true;
        if self.renew_once {
            self.outbound_rx.rx = mpsc::channel(1).1;
        }
    }

    /// Marks both directions disconnected, recording a change for each one that was connected.
    pub fn disconnect(&mut self) {
        if self.inbound_connected {
            self.inbound_connected = false;
            self.inbound_conn_change();
        }
        if self.outbound_connected {
            self.outbound_connected = false;
            self.outbound_conn_change();
        }
    }

    /// True when neither direction is connected.
    pub fn disconnected(&self) -> bool {
        !self.inbound_connected && !self.outbound_connected
    }

    /// Returns `(inbound_connected, outbound_connected)` once after a change, else `None`.
    pub fn get_conn_change(&mut self) -> Option<(bool, bool)> {
        if self.conn_changed {
            self.conn_changed = false;
            Some((self.inbound_connected, self.outbound_connected))
        } else {
            None
        }
    }

    /// Handles one control command for this endpoint on the worker.
    ///
    /// Every command carrying a responder gets an answer (possibly an error), so the
    /// requester never sees a dropped responder for a command it is allowed to send.
    pub fn ctrl(&mut self, c: CtrlCmd, spawner: &AsyncSpawner) -> CtrlStatus {
        use CtrlCmd::*;
        match c {
            SetTransport(tp, resp) => {
                self.transport = Some(tp);
                self.renew(spawner);
                worker_send_response!(resp, Ok(()));
            }
            TakeTransport(resp) => {
                // Push out what we can before handing the transport over.
                self.try_outbound().ok();
                // Without a transport (e.g. a renew task currently owns it) answer with
                // `TransportNotAvailable`. Dropping `resp` would surface to the caller as
                // a misleading `WorkerError::Invalid`.
                let res = self.transport.take().ok_or(EndpointError::TransportNotAvailable);
                worker_send_response!(resp, res);
            }
            FinishRenew(token, tp, res) => {
                // Result of a superseded renew task: ignore it.
                if self.transport_token != token {
                    return CtrlStatus::None
                }
                self.transport = Some(tp);
                match res {
                    Ok(inbound_init) => {
                        for msg in inbound_init {
                            self.inbound_tx.tx.try_send(msg).ok();
                        }
                        self.inbound_connected = true;
                        self.outbound_connected = true;
                        return CtrlStatus::Connected
                    }
                    Err(_) => self.try_renew(spawner),
                }
            }
            Flush => {
                self.try_outbound().ok();
                self.try_inbound().ok();
            }
            RenewTransport(token, resp) => {
                if self.disconnected() {
                    worker_send_response!(resp, Ok(()));
                    return CtrlStatus::None
                }
                if let Some(token) = token {
                    if token != self.transport_token {
                        // Stale closer: its transport generation was already replaced, so
                        // there is nothing to shut down. Answer instead of dropping `resp`.
                        worker_send_response!(resp, Ok(()));
                        return CtrlStatus::None
                    }
                }
                let res: Result<()> = match self.transport.as_mut() {
                    Some(tp) => tp.shutdown(std::net::Shutdown::Both).map_err(|e| e.into()),
                    None => Ok(()),
                };
                worker_send_response!(resp, res);
                return CtrlStatus::Disconnected
            }
        }
        CtrlStatus::None
    }
}
impl Drop for EndpointState {
    /// Best effort on teardown: push out whatever outbound data the transport still
    /// accepts, then shut the transport down. Errors are ignored.
    fn drop(&mut self) {
        let _ = self.try_outbound();
        if let Some(tp) = self.transport.as_mut() {
            let _ = tp.shutdown(std::net::Shutdown::Both);
        }
    }
}
use std::sync::{atomic, Arc};
/// Wake-up target: notifies the event sender `ev` about endpoint `eid`.
pub(super) struct OutboundWaker {
    pub ev: event::EndpointEventSender,
    pub eid: usize,
}
impl OutboundWaker {
    /// Async notification; returns what `notify` reports (callers treat `None` as failure).
    async fn wake(&self) -> Option<()> {
        let Self { ev, eid } = self;
        ev.notify(*eid).await
    }

    /// Synchronous counterpart of `wake`, via `blocking_notify`.
    fn blocking_wake(&self) -> Option<()> {
        let Self { ev, eid } = self;
        ev.blocking_notify(*eid)
    }
}
/// State shared by both ends of a channel pair (the inbound pair uses it too).
struct OutboundState {
    // Wake-up flag. The worker-side end arms it (`true`) before transferring; the user-side
    // end swaps it back to `false` after a transfer and notifies `waker` if it was armed.
    wait_sig: atomic::AtomicBool,
    // Where wake-ups go; shared with the endpoint's control state so it can be retargeted.
    waker: Arc<RwLock<OutboundWaker>>,
}
/// User end of the outbound channel; wakes the worker after sending when it is waiting.
struct OutboundSender<T> {
    tx: mpsc::Sender<T>,
    state: Arc<OutboundState>,
}
// Implemented by hand: `#[derive(Clone)]` would needlessly require `T: Clone`.
impl<T> Clone for OutboundSender<T> {
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            state: self.state.clone(),
        }
    }
}
/// Worker end of the outbound channel.
struct OutboundReceiver<T> {
    rx: mpsc::Receiver<T>,
    state: Arc<OutboundState>,
}
/// Creates the outbound channel pair (bounded to `buffer_size`) sharing one `OutboundState`
/// whose wake-up flag starts armed.
fn new_outbound_mpsc<T>(buffer_size: usize, waker: OutboundWaker) -> (OutboundSender<T>, OutboundReceiver<T>) {
    let (tx, rx) = mpsc::channel(buffer_size);
    let shared = Arc::new(OutboundState {
        wait_sig: atomic::AtomicBool::new(true),
        waker: Arc::new(RwLock::new(waker)),
    });
    let sender = OutboundSender {
        tx,
        state: Arc::clone(&shared),
    };
    let receiver = OutboundReceiver { rx, state: shared };
    (sender, receiver)
}
impl<T> OutboundSender<T> {
    /// Sends `data`, waiting for capacity, then notifies the waker if the worker armed
    /// `wait_sig`.
    ///
    /// # Errors
    /// `SendError` (carrying `data`) when the channel is closed.
    async fn send(&self, data: T) -> result::Result<(), mpsc::error::SendError<T>> {
        self.tx.send(data).await?;
        if self.state.wait_sig.swap(false, atomic::Ordering::AcqRel) {
            // Copy the target out so the parking_lot read guard is released before awaiting.
            // A guard held across `.await` makes the future `!Send` (parking_lot default). It
            // would also block `set_worker`'s `write()` while the notification is pending.
            let (ev, eid) = {
                let waker = self.state.waker.read();
                (waker.ev.clone(), waker.eid)
            };
            ev.notify(eid).await;
        }
        Ok(())
    }

    /// Non-waiting `send`; the wake-up uses the synchronous `blocking_wake`.
    ///
    /// # Errors
    /// `TrySendError::Full` / `TrySendError::Closed`, both carrying `data`.
    fn try_send(&self, data: T) -> result::Result<(), mpsc::error::TrySendError<T>> {
        self.tx.try_send(data)?;
        if self.state.wait_sig.swap(false, atomic::Ordering::AcqRel) {
            self.state.waker.read().blocking_wake();
        }
        Ok(())
    }

    /// Shared handle to the waker (used to retarget it on worker changes).
    fn get_waker(&self) -> Arc<RwLock<OutboundWaker>> {
        self.state.waker.clone()
    }
}
impl<T> OutboundReceiver<T> {
    /// Feeds queued items to `consumer` without waiting, until the channel is empty or
    /// `consumer` returns `Ok(true)`. The first `Err` from `consumer` is returned as-is.
    ///
    /// Arms `wait_sig` first, so the next `send`/`try_send` on the sender notifies the waker.
    fn try_recv<E>(&mut self, mut consumer: impl FnMut(T) -> result::Result<bool, E>) -> result::Result<(), E> {
        let _ = self.state.wait_sig.swap(true, atomic::Ordering::AcqRel);
        while let Ok(item) = self.rx.try_recv() {
            if consumer(item)? {
                break;
            }
        }
        Ok(())
    }
}
/// Worker end of the inbound channel (transport -> user).
struct InboundSender<T> {
    tx: mpsc::Sender<T>,
    state: Arc<OutboundState>,
}
// Implemented by hand: `#[derive(Clone)]` would needlessly require `T: Clone`.
impl<T> Clone for InboundSender<T> {
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            state: self.state.clone(),
        }
    }
}
/// User end of the inbound channel; wakes the worker after receiving when it is waiting.
struct InboundReceiver<T> {
    rx: mpsc::Receiver<T>,
    state: Arc<OutboundState>,
}
/// Creates the inbound channel pair (bounded to `buffer_size`) sharing one `OutboundState`
/// whose wake-up flag starts armed.
fn new_inbound_spsc<T>(buffer_size: usize, waker: OutboundWaker) -> (InboundSender<T>, InboundReceiver<T>) {
    let (tx, rx) = mpsc::channel(buffer_size);
    let shared = Arc::new(OutboundState {
        wait_sig: atomic::AtomicBool::new(true),
        waker: Arc::new(RwLock::new(waker)),
    });
    let sender = InboundSender {
        tx,
        state: Arc::clone(&shared),
    };
    let receiver = InboundReceiver { rx, state: shared };
    (sender, receiver)
}
impl<T> InboundSender<T> {
    /// Arms `wait_sig`, then lets `producer` push into the raw channel.
    ///
    /// Arming first means the receiver notifies the waker after its next receive.
    fn try_send<E>(
        &self, mut producer: impl FnMut(&mpsc::Sender<T>) -> result::Result<(), E>,
    ) -> result::Result<(), E> {
        let _ = self.state.wait_sig.swap(true, atomic::Ordering::AcqRel);
        producer(&self.tx)
    }

    /// Shared handle to the waker (used to retarget it on worker changes).
    fn get_waker(&self) -> Arc<RwLock<OutboundWaker>> {
        Arc::clone(&self.state.waker)
    }
}
impl<T> InboundReceiver<T> {
    /// Takes one queued item without waiting. If the worker armed `wait_sig`, it is notified
    /// synchronously via `blocking_wake`.
    ///
    /// # Errors
    /// `TryRecvError::Empty` / `TryRecvError::Disconnected`.
    fn try_recv(&mut self) -> result::Result<T, mpsc::error::TryRecvError> {
        let data = self.rx.try_recv()?;
        if self.state.wait_sig.swap(false, atomic::Ordering::AcqRel) {
            self.state.waker.read().blocking_wake();
        }
        Ok(data)
    }

    /// Waits for the next item and returns `None` once the channel is closed and empty.
    /// Notifies the waker if the worker armed `wait_sig`.
    async fn recv(&mut self) -> Option<T> {
        let data = self.rx.recv().await?;
        if self.state.wait_sig.swap(false, atomic::Ordering::AcqRel) {
            // Copy the target out so the parking_lot read guard is released before awaiting.
            // A guard held across `.await` makes the future `!Send` (parking_lot default). It
            // would also block `set_worker`'s `write()` while the notification is pending.
            let (ev, eid) = {
                let waker = self.state.waker.read();
                (waker.ev.clone(), waker.eid)
            };
            ev.notify(eid).await;
        }
        Some(data)
    }
}