use async_trait::async_trait;
use bytes::Bytes;
use parking_lot::Mutex;
use rand::Rng;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::Hash;
use std::io;
use std::net::Shutdown;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use super::{Transport, TransportError};
use crate::hub::endpoint::{Endpoint, TransportCloser};
use crate::hub::event::{self, IOSource, TaskHandle};
use tracing::Level;
/// Identifier of a pool participant (the local node or a remote peer).
///
/// Blanket-implemented below, so any hashable, comparable, cloneable,
/// debuggable, thread-safe `'static` type can be used as an ID.
pub trait ID: Hash + Eq + Send + Sync + Clone + std::fmt::Debug + 'static {}
impl<T: Hash + Eq + Send + Sync + Clone + std::fmt::Debug + 'static> ID for T {}
/// Errors returned by public pool operations such as `set_retry_delay`.
#[derive(Debug)]
pub enum PoolError {
    /// The requested remote is not registered in the pool.
    RemoteNotExist,
    /// An address was rejected. NOTE(review): not constructed anywhere in the visible code.
    InvalidAddr,
    /// The pool's background task could not be reached or did not reply.
    Unknown,
}
/// Which of the two handshake messages a [`Handshake`] carries.
///
/// `Init` is sent by the dialing (client) side; `Ack` is the server's reply.
/// NOTE(review): this enum is serialized inside `PlainHandshake`; depending on
/// rmp_serde's enum encoding the variant order may be wire-visible, so do not
/// reorder without checking.
#[derive(Serialize, Deserialize, Clone, Copy)]
pub enum HandshakeKind {
    Init,
    Ack,
}
/// Wire format of the handshake exchanged when a new stream is established.
pub trait Handshake<I: ID>: Sized + Send {
    /// Builds a message of `kind` carrying the sender's random `nonce` and ID.
    fn new(kind: HandshakeKind, nonce: u64, sender_id: I) -> Self;
    /// Decodes a received message; `None` if `bytes` is not a valid handshake.
    fn parse(bytes: &[u8]) -> Option<Self>;
    /// Encodes the message for sending; `None` if encoding fails.
    fn seal(&self) -> Option<Bytes>;
    /// Returns the message kind.
    fn kind(&self) -> HandshakeKind;
    /// Returns the sender's session nonce.
    fn nonce(&self) -> u64;
    /// Returns the sender's ID.
    fn sender_id(&self) -> I;
}
/// A message-oriented byte stream driven without blocking.
///
/// NOTE(review): the exact meaning of `try_send(None)` and of its `bool` result is
/// defined by the `Transport` contract elsewhere in the crate — confirm there.
pub trait NonBlockingStream: Send {
    /// Attempts to receive one message without blocking.
    fn try_recv(&mut self) -> Result<Bytes, TransportError>;
    /// Attempts to send (or, presumably with `None`, flush) without blocking.
    fn try_send(&mut self, data: Option<Bytes>) -> Result<bool, TransportError>;
    /// Returns the I/O source used to wait for readiness.
    fn source(&mut self) -> IOSource;
    /// Shuts down the read half, write half, or both.
    fn shutdown(&mut self, how: Shutdown) -> io::Result<()>;
}
/// Source of raw streams for the pool, keyed by remote ID.
#[async_trait]
pub trait StreamFactory<I: ID>: Clone + Send + Sync + 'static {
    /// Opens an outbound stream to `remote_id`; `None` on failure.
    async fn create_stream(&self, remote_id: I) -> Option<Box<dyn NonBlockingStream>>;
    /// Waits for the next inbound stream (called in a loop by `PoolCore::discover_loop`).
    async fn discover_stream(&self) -> Box<dyn NonBlockingStream>;
}
/// Verdict of an [`AccessControl`] query.
pub enum ACDecision {
    /// Admit the remote.
    Allow,
    /// Reject the remote; its discovered stream is dropped.
    Deny,
}
/// Policy consulted when a not-yet-registered remote opens a stream to us.
#[async_trait]
pub trait AccessControl<I: ID>: Send + Sync + 'static {
    /// Decides whether the remote identified by `id` may be admitted.
    async fn query(&self, id: &I) -> ACDecision;
}
/// Default [`Handshake`] implementation, encoded as MessagePack via `rmp_serde`.
///
/// NOTE(review): encoding uses `rmp_serde::encode::write`, which (assuming
/// rmp_serde's default compact configuration — confirm) writes struct fields
/// positionally, so the field order below is part of the wire format.
#[derive(Serialize, Deserialize)]
pub struct PlainHandshake<I: ID> {
    kind: HandshakeKind,
    nonce: u64,
    sender_id: I,
}
/// Access-control policy that admits every remote unconditionally.
#[derive(Clone)]
pub struct AcceptAll;

#[async_trait]
impl<I: ID> AccessControl<I> for AcceptAll {
    /// Always answers [`ACDecision::Allow`], ignoring the queried ID.
    async fn query(&self, _remote: &I) -> ACDecision {
        let verdict = ACDecision::Allow;
        verdict
    }
}
use rmp_serde::{decode, encode};
impl<I: ID + Serialize + DeserializeOwned> Handshake<I> for PlainHandshake<I> {
fn new(kind: HandshakeKind, nonce: u64, sender_id: I) -> Self {
Self { kind, nonce, sender_id }
}
fn parse(bytes: &[u8]) -> Option<Self> {
let decoded: Result<Self, _> = decode::from_slice(&bytes);
decoded.ok()
}
fn seal(&self) -> Option<Bytes> {
let mut buffer = Vec::new();
encode::write(&mut buffer, &self).ok()?;
Some(buffer.into())
}
fn kind(&self) -> HandshakeKind {
self.kind
}
fn nonce(&self) -> u64 {
self.nonce
}
fn sender_id(&self) -> I {
self.sender_id.clone()
}
}
/// Produces the delay to wait before each reconnection attempt.
#[async_trait]
pub trait DelayGenerator: Send {
    /// Called when a remote's session is reset (see `RemoteState::reset`).
    async fn reset(&mut self);
    /// Returns the delay before the next client-stream attempt.
    async fn next_delay(&mut self) -> Duration;
}
/// A [`DelayGenerator`] that always yields the same fixed delay.
pub struct ConstDelay(pub Duration);

#[async_trait]
impl DelayGenerator for ConstDelay {
    /// Stateless, so resetting does nothing.
    async fn reset(&mut self) {}

    /// Returns the configured delay unchanged.
    async fn next_delay(&mut self) -> Duration {
        let ConstDelay(fixed) = *self;
        fixed
    }
}
/// A [`DelayGenerator`] that draws each delay uniformly from a half-open range.
pub struct RandomDelay {
    range: std::ops::Range<Duration>,
    rng: rand::rngs::StdRng,
}

impl RandomDelay {
    /// Creates a generator over `range`, seeding its RNG from OS entropy.
    pub fn new(range: std::ops::Range<Duration>) -> Self {
        use rand::SeedableRng;
        RandomDelay {
            rng: rand::rngs::StdRng::from_entropy(),
            range,
        }
    }
}
#[async_trait]
impl DelayGenerator for RandomDelay {
    /// No state is carried between draws, so resetting does nothing.
    async fn reset(&mut self) {}

    /// Draws a delay uniformly from `range`.
    ///
    /// `Rng::gen_range` panics on an empty range (`start >= end`), which a
    /// caller-supplied range (e.g. `Config::retry_delay`) can easily produce. In
    /// that case `range.start` is returned as a fixed delay instead of panicking.
    async fn next_delay(&mut self) -> Duration {
        if self.range.is_empty() {
            return self.range.start;
        }
        self.rng.gen_range(self.range.clone())
    }
}
/// Cloneable wrapper around the hub driver, exposing the two operations the pool uses.
#[derive(Clone)]
struct Driver(crate::hub::Driver);

impl Driver {
    /// Creates an endpoint via `create_endpoint_with_options(true)`, discarding any error.
    ///
    /// NOTE(review): the meaning of the `true` option is defined by `crate::hub::Driver`.
    async fn create_endpoint(&self) -> Option<Endpoint> {
        let Driver(inner) = self;
        match inner.create_endpoint_with_options(true).await {
            Ok(ep) => Some(ep),
            Err(_) => None,
        }
    }

    /// Spawns `fut` on the hub's spawner and returns its task handle.
    fn spawn(&self, fut: impl std::future::Future<Output = ()> + Send + 'static) -> TaskHandle {
        let spawner = self.0.spawner();
        spawner.spawn_with_task_handle(fut)
    }
}
/// Registry entry for one remote.
struct RemoteTransport<I: ID + Serialize> {
    // Handle to the remote's spawned state machine.
    remote: Remote<I>,
    // The initial (inactive) transport; `Some` until the first `new_remote` takes it.
    transport: Option<Box<dyn Transport>>,
}
/// Ownership token for a remote's registry entry; dropping it removes the entry.
///
/// The token travels with the remote's transport (inactive, then active), so
/// the entry lives exactly as long as that transport.
struct TransportToken<I: ID + Serialize> {
    id: I,
    transports: Arc<Mutex<HashMap<I, RemoteTransport<I>>>>,
}

impl<I: ID + Serialize> TransportToken<I> {
    /// Binds a token to remote `id` within the shared `transports` map.
    fn new(id: I, transports: Arc<Mutex<HashMap<I, RemoteTransport<I>>>>) -> Self {
        TransportToken { transports, id }
    }
}
impl<I: ID + Serialize> Drop for TransportToken<I> {
    /// Removes this remote from the shared map and, if it was still present,
    /// asks its state machine to close.
    ///
    /// The entry is removed in its own statement so the (non-reentrant)
    /// parking_lot guard is released before `close()` runs and before the removed
    /// entry, which owns a task handle and possibly a boxed transport, is dropped.
    /// Keeping the guard in the `if let` scrutinee would hold the lock across all
    /// of that drop glue.
    fn drop(&mut self) {
        let removed = self.transports.lock().remove(&self.id);
        if let Some(entry) = removed {
            entry.remote.close();
        }
    }
}
/// State owned by the registry task: every known remote and its pending transport.
struct Remotes<I: ID + Serialize + DeserializeOwned, F: StreamFactory<I>, A: AccessControl<I>> {
    local_id: I,
    // Receiving end of the registry event channel.
    rx: mpsc::Receiver<Event<I>>,
    // Cloned into each `Remote` and `InactiveTransport` so they can post events back.
    tx: mpsc::Sender<Event<I>>,
    // Consulted before admitting a remote that reaches out but is not yet registered.
    access_control: A,
    // Shared with every `TransportToken`; dropping a token removes its entry.
    transports: Arc<Mutex<HashMap<I, RemoteTransport<I>>>>,
    stream_factory: F,
    driver: Driver,
    config: Config,
}
/// Emits a `tracing` event under target "Pool Remotes", tagged with the local ID.
macro_rules! log {
    ($lvl: expr, $id: expr, $($arg: tt)+) => {
        tracing::event!(target: "Pool Remotes", $lvl, local = ?$id, $($arg)+)
    }
}
/// `Display` adapter rendering a (local, remote) ID pair as `[local<->remote]`.
struct RemoteLogField<'a, I: ID>(&'a I, &'a I);

impl<'a, I: ID> std::fmt::Display for RemoteLogField<'a, I> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let RemoteLogField(local, remote) = self;
        f.write_fmt(format_args!("[{:?}<->{:?}]", local, remote))
    }
}
/// Emits a `tracing` event under target "Pool Remote", tagged with the
/// `[local<->remote]` pair rendered by `RemoteLogField`.
macro_rules! remote_log {
    ($lvl: expr, $id: expr, $rid: expr, $($arg: tt)+) => {
        tracing::event!(target: "Pool Remote", $lvl, remote = %RemoteLogField(&$id, &$rid), $($arg)+)
    }
}
impl<I: ID + Serialize + DeserializeOwned, F: StreamFactory<I>, A: AccessControl<I>> Remotes<I, F, A> {
    /// Returns the registry entry for `id`, inserting a fresh one if absent.
    ///
    /// A new entry owns a spawned `Remote` state machine and an
    /// `InactiveTransport` whose `TransportToken` removes the entry when dropped.
    fn create_transport<'a, 'b, H: Handshake<I>>(
        &'a self, transports: &'b mut HashMap<I, RemoteTransport<I>>, id: &I,
    ) -> &'b mut RemoteTransport<I> {
        transports.entry(id.clone()).or_insert_with(|| RemoteTransport {
            remote: Remote::new::<H, _>(
                id.clone(),
                self.local_id.clone(),
                self.tx.clone(),
                self.stream_factory.clone(),
                &self.config,
                self.driver.clone(),
            ),
            transport: Some(Box::new(InactiveTransport::new(
                TransportToken::new(id.clone(), self.transports.clone()),
                self.tx.clone(),
            ))),
        })
    }

    /// Dispatches one registry event.
    async fn handle_message<H: Handshake<I>>(&mut self, msg: Event<I>) {
        match msg {
            Event::DiscoveredStream(remote_id, ep, r) => {
                // Access control is async, so consult it *before* taking the lock:
                // a parking_lot guard must never be held across an `.await`. It
                // would block every `TransportToken::drop` meanwhile and make this
                // future `!Send`.
                let is_known = self.transports.lock().contains_key(&remote_id);
                let allowed = if is_known {
                    false
                } else {
                    matches!(self.access_control.query(&remote_id).await, ACDecision::Allow)
                };
                let mut transports = self.transports.lock();
                let s = if transports.contains_key(&remote_id) {
                    transports.get(&remote_id)
                } else if allowed {
                    log!(Level::DEBUG, self.local_id, "adding unknown remote {:?}", remote_id);
                    Some(&*self.create_transport::<H>(&mut *transports, &remote_id))
                } else {
                    // Denied by access control, or the entry vanished between the
                    // two lock acquisitions (its token was dropped meanwhile).
                    None
                };
                match s {
                    Some(s) => {
                        // On failure the endpoint is handed back and dropped here.
                        if s.remote.discovered_stream(ep, r).is_err() {
                            log!(Level::WARN, self.local_id, "failed to handle a new discovered stream");
                        }
                    }
                    None => log!(Level::DEBUG, self.local_id, "dropping unknown remote {:?}", remote_id),
                }
            }
            Event::Renew(token, stream, closer, resp, dedup_counter) => {
                let transports = self.transports.lock();
                // A live token guarantees its entry exists: only dropping the token removes it.
                let s = &transports.get(&token.id).expect("remote should exist").remote;
                match stream {
                    Some((stream, stream_mode)) => s.renew(token, stream, stream_mode, closer, resp, dedup_counter),
                    None => s.init_renew(token, closer, resp),
                }
            }
            Event::New(remote_id, resp) => {
                // Outbound registration: no access-control check. The initial transport
                // can be taken only once; later requests for the same remote get `None`.
                let mut transports = self.transports.lock();
                resp.send(match transports.get_mut(&remote_id) {
                    Some(t) => t.transport.take(),
                    None => self
                        .create_transport::<H>(&mut *transports, &remote_id)
                        .transport
                        .take(),
                })
                .ok();
            }
            Event::SetRetryDelay(remote_id, delay, resp) => match self.transports.lock().get(&remote_id) {
                Some(s) => s.remote.set_retry_delay(delay, resp),
                None => {
                    log!(Level::WARN, self.local_id, "remote {:?} not found", remote_id);
                    resp.send(None).ok();
                }
            },
        }
    }
}
/// Tuning parameters for a pool.
///
/// `Builder` is derived (presumably `derive_builder` — confirm), with defaults
/// given by the `#[builder(default = ...)]` attributes.
#[derive(Debug, Clone, Builder)]
pub struct Config {
    /// Capacity of the registry's event channel.
    #[builder(default = "1024")]
    pub max_pending_pool_ops: usize,
    /// Capacity of each remote's discovered-stream channel; when it is full,
    /// further discovered streams for that remote are dropped with a warning.
    #[builder(default = "16")]
    pub max_pending_discovered_stream: usize,
    /// Size passed to the `TaskPool` that identifies inbound streams whose sender is not yet known.
    #[builder(default = "1024")]
    pub max_unknown_streams: usize,
    /// Time limit for the server-side handshake, and for reading the first message of an inbound stream.
    #[builder(default = "Duration::from_millis(1000)")]
    pub server_stream_timeout: Duration,
    /// Time limit for the client-side handshake (after the outbound stream has been created).
    #[builder(default = "Duration::from_millis(1000)")]
    pub client_stream_timeout: Duration,
    /// Range from which the default `RandomDelay` draws reconnection delays.
    #[builder(default = "(Duration::from_millis(1000)..Duration::from_millis(2000))")]
    pub retry_delay: std::ops::Range<Duration>,
}
/// Core of a pool: a handle to the background task that owns all remotes.
pub struct PoolCore<I: ID + Serialize + DeserializeOwned> {
    // Sender into the registry event loop.
    reg_tx: mpsc::Sender<Event<I>>,
    // Handle of the spawned background task.
    _loop: TaskHandle,
}
/// Reply channel for a renew: the next transport plus inbound bytes buffered for it.
type RenewSender = oneshot::Sender<(Box<dyn Transport>, Vec<Bytes>)>;
/// Events handled by the registry loop (`Remotes::handle_message`).
enum Event<I: ID + Serialize> {
    /// Request the initial transport for a remote, registering it if new.
    New(I, oneshot::Sender<Option<Box<dyn Transport>>>),
    /// An inbound stream whose first handshake message named sender `I` with nonce `u64`.
    DiscoveredStream(I, Endpoint, u64),
    /// Renew request forwarded to the remote's state machine. A `None` stream
    /// means there is no current stream (initial renew — presumably from
    /// `InactiveTransport::renew`). The `u64` is the session's dedup counter.
    Renew(
        TransportToken<I>,
        Option<(Box<dyn NonBlockingStream>, StreamMode)>,
        TransportCloser,
        RenewSender,
        u64,
    ),
    /// Replace a remote's retry-delay generator; replies `None` if the remote is unknown.
    SetRetryDelay(I, Box<dyn DelayGenerator>, oneshot::Sender<Option<()>>),
}
impl<I: ID + Serialize + DeserializeOwned> PoolCore<I> {
    /// Accepts inbound streams forever and identifies who opened them.
    ///
    /// Each stream from `stream_factory.discover_stream()` is wrapped in a
    /// `HandshakeTransport` on a freshly created endpoint. A task is then submitted
    /// to a `TaskPool` (constructed with `max_unknown_streams`). It reads the first
    /// inbound message, parses it as a handshake `H`, and forwards
    /// `Event::DiscoveredStream(sender_id, ep, nonce)` to the registry. Reading and
    /// parsing are bounded by `timeout`. The handshake kind is not checked here.
    ///
    /// The outer `loop` has no exit, so this future never completes on its own.
    async fn discover_loop<H: Handshake<I>, F: StreamFactory<I>>(
        stream_factory: F, event_tx: mpsc::Sender<Event<I>>, driver: Driver, max_unknown_streams: usize,
        timeout: Duration, local_id: I,
    ) {
        let pool = event::TaskPool::new(max_unknown_streams, driver.0.spawner());
        loop {
            let stream = stream_factory.discover_stream().await;
            let tp = Box::new(HandshakeTransport::new(stream));
            let ep = match driver.create_endpoint().await {
                Some(ep) => ep,
                None => {
                    // The discovered stream (inside `tp`) is dropped here.
                    log!(Level::ERROR, local_id, "discover_loop: failed to create endpoint");
                    continue
                }
            };
            if let Err(e) = ep.set_transport(tp).await {
                log!(
                    Level::ERROR,
                    local_id,
                    "discover_loop: failed to set_transport: {:?}",
                    e
                );
                continue
            }
            let event_tx_clone = event_tx.clone();
            let local_id_clone = local_id.clone();
            // Identification: the first message must be a parseable handshake naming
            // the sender. On any early exit `ep` is simply dropped.
            let protocol = async move {
                let raw = match ep.inbound().recv(None).await {
                    Some(msg) => msg,
                    None => return,
                };
                match H::parse(&raw) {
                    Some(hmsg) => {
                        let sender_id = hmsg.sender_id();
                        let nonce = hmsg.nonce();
                        log!(Level::DEBUG, local_id_clone, "remote {:?} reaching out", sender_id);
                        event_tx_clone
                            .send(Event::DiscoveredStream(sender_id, ep, nonce))
                            .await
                            .ok();
                    }
                    None => {
                        log!(Level::WARN, local_id_clone, "invalid message, dropping the endpoint");
                    }
                }
            };
            let local_id_clone = local_id.clone();
            // NOTE(review): whether `submit` waits for free capacity or fails when the
            // pool is full is defined by `event::TaskPool` — confirm there.
            if let Err(_) = pool
                .submit(async move {
                    log!(
                        Level::DEBUG,
                        local_id_clone,
                        "discovered stream timeout in {:?}",
                        timeout
                    );
                    match event::timeout(timeout, protocol).await {
                        Ok(_) => (),
                        Err(_) => {
                            log!(Level::DEBUG, local_id_clone, "handshake discovered stream timeout",);
                        }
                    }
                })
                .await
            {
                log!(Level::WARN, local_id, "pool error");
            }
        }
    }
    /// Creates the pool core and spawns its background task.
    ///
    /// The task runs `discover_loop` and the registry loop (feeding each `Event` to
    /// `Remotes::handle_message`) concurrently under `tokio::select!`, so it ends as
    /// soon as either finishes. `reg` keeps its own `tx`, so `reg.rx.recv()` does
    /// not yield `None` merely because all outside senders are gone.
    /// `config.server_stream_timeout` also bounds identification of inbound streams.
    pub fn new<H: Handshake<I>, F: StreamFactory<I>, A: AccessControl<I>>(
        local_id: I, stream_factory: F, access_control: A, driver: &crate::hub::Driver, config: Config,
    ) -> Self {
        let (tx, rx) = mpsc::channel(config.max_pending_pool_ops);
        let reg_tx = tx.clone();
        let timeout = config.server_stream_timeout.clone();
        let max_unknown_streams = config.max_unknown_streams;
        let driver = Driver(driver.clone());
        let mut reg = Remotes {
            local_id: local_id.clone(),
            tx,
            rx,
            access_control,
            transports: Arc::new(Mutex::new(HashMap::new())),
            stream_factory: stream_factory.clone(),
            driver: driver.clone(),
            config,
        };
        let _loop = driver.clone().spawn(async move {
            let sender = reg.tx.clone();
            tokio::select! {
                _ = Self::discover_loop::<H, F>(stream_factory, sender, driver, max_unknown_streams, timeout, local_id) => {},
                _ = async move {
                    while let Some(msg) = reg.rx.recv().await {
                        reg.handle_message::<H>(msg).await
                    }
                } => {}
            }
        });
        Self { reg_tx, _loop }
    }
    /// Requests the transport for `remote_id`, registering the remote if it is new.
    ///
    /// No access-control check is applied to outbound registration. Returns `None`
    /// if the background task cannot be reached, or if this remote's transport was
    /// already handed out by an earlier call.
    pub async fn new_remote(&self, remote_id: I) -> Option<Box<dyn Transport>> {
        let (tx, rx) = oneshot::channel();
        self.reg_tx.send(Event::New(remote_id, tx)).await.ok();
        rx.await.ok()?
    }
    /// Replaces the retry-delay generator used for `remote_id`.
    ///
    /// # Errors
    /// `PoolError::Unknown` if the background task cannot be reached;
    /// `PoolError::RemoteNotExist` if `remote_id` is not registered.
    pub async fn set_retry_delay(&self, remote_id: &I, delay: Box<dyn DelayGenerator>) -> Result<(), PoolError> {
        let (tx, rx) = oneshot::channel();
        // Inner `None` = channel failure (Unknown); inner `Some(None)` = no such remote.
        async {
            self.reg_tx
                .send(Event::SetRetryDelay(remote_id.clone(), delay, tx))
                .await
                .ok()?;
            rx.await.ok()
        }
        .await
        .ok_or(PoolError::Unknown)?
        .ok_or(PoolError::RemoteNotExist)
    }
}
use parking_lot::RwLock;
/// Stream factory that dials string addresses instead of remote IDs.
///
/// Used by [`Pool`], which maps IDs to addresses through its address book.
#[async_trait]
pub trait AddressedStreamFactory: Clone + Send + Sync + 'static {
    /// Opens an outbound stream to `addr`; `None` on failure.
    async fn create_stream(&self, addr: &str) -> Option<Box<dyn NonBlockingStream>>;
    /// Waits for the next inbound stream. The default never resolves, so a factory
    /// that does not override it never yields inbound streams.
    async fn discover_stream(&self) -> Box<dyn NonBlockingStream> {
        futures::future::pending().await
    }
}
/// Shared map from remote ID to its candidate addresses, plus a `Notify` that
/// wakes a `create_stream` call waiting for the address list to become non-empty.
#[derive(Clone)]
struct AddressBook<I: ID>(Arc<RwLock<HashMap<I, (Vec<String>, Arc<tokio::sync::Notify>)>>>);

impl<I: ID> AddressBook<I> {
    /// Replaces the address list for `id` (creating the entry if needed) and wakes one waiter.
    fn set_addr(&self, id: I, addrs: &[String]) {
        let mut book = self.0.write();
        let slot = book
            .entry(id)
            .or_insert_with(|| (Vec::new(), Arc::new(tokio::sync::Notify::new())));
        slot.0 = addrs.to_vec();
        // `notify_one` stores a permit when nobody is waiting yet, so a caller that
        // read an empty list just before this update still wakes up.
        slot.1.notify_one();
    }

    /// Removes `id` from the book and wakes any caller parked on its notifier.
    ///
    /// The woken caller re-reads the book, finds no entry and returns `None`.
    /// Without the wake-up it would wait forever: a later `set_addr` for the same
    /// ID creates a *new* `Notify`, which the parked caller never observes.
    fn remove_addr(&self, id: &I) {
        let removed = self.0.write().remove(id);
        if let Some((_, notifier)) = removed {
            notifier.notify_one();
        }
    }
}
/// Adapts an [`AddressedStreamFactory`] to [`StreamFactory`] by resolving IDs
/// through an [`AddressBook`].
#[derive(Clone)]
struct AddressedFactory<I: ID, F: AddressedStreamFactory> {
    address_book: AddressBook<I>,
    factory: F,
}
#[async_trait]
impl<I: ID, F: AddressedStreamFactory> StreamFactory<I> for AddressedFactory<I, F> {
    /// Dials `remote_id` using its address-book entry.
    ///
    /// Returns `None` at once if the remote has no entry. While the entry's address
    /// list is empty, it parks on the entry's notifier and re-checks after every
    /// wake-up. It then tries the addresses in order, returning the first stream the
    /// inner factory produces, or `None` if every address fails.
    async fn create_stream(&self, remote_id: I) -> Option<Box<dyn NonBlockingStream>> {
        let candidates = loop {
            // The read guard is a temporary of this statement, so it is released
            // before any `.await` below.
            let entry = self.address_book.0.read().get(&remote_id).cloned();
            let (addrs, notifier) = entry?;
            if !addrs.is_empty() {
                break addrs;
            }
            tracing::event!(target: "Pool AddressBook", Level::DEBUG, "waiting until address list is non-empty");
            notifier.notified().await;
        };
        for addr in candidates.iter() {
            tracing::event!(target: "Pool AddressBook", Level::DEBUG, "trying {} for remote={:?}", addr, remote_id);
            if let Some(stream) = self.factory.create_stream(addr).await {
                return Some(stream);
            }
        }
        tracing::event!(target: "Pool AddressBook", Level::DEBUG, "tried all addresses but failed");
        None
    }

    /// Inbound streams come straight from the inner factory.
    async fn discover_stream(&self) -> Box<dyn NonBlockingStream> {
        let inner = &self.factory;
        inner.discover_stream().await
    }
}
/// A pool whose outbound streams are dialed by address, using an internal
/// ID-to-address book that callers maintain via `set_addr` / `remove_addr`.
pub struct Pool<I: ID + Serialize + DeserializeOwned> {
    core: PoolCore<I>,
    address_book: AddressBook<I>,
}
impl<I: ID + Serialize + DeserializeOwned> Pool<I> {
pub fn new<H: Handshake<I>, F: AddressedStreamFactory, A: AccessControl<I>>(
local_id: I, factory: F, access_control: A, driver: &crate::hub::Driver, config: Config,
) -> Self {
let address_book = AddressBook(Arc::new(RwLock::new(HashMap::new())));
let factory = AddressedFactory {
address_book: address_book.clone(),
factory,
};
let core = PoolCore::new::<H, _, A>(local_id, factory, access_control, driver, config);
Self { address_book, core }
}
pub async fn new_remote(&self, remote_id: I) -> Option<Box<dyn Transport>> {
self.core.new_remote(remote_id).await
}
pub async fn set_retry_delay(&self, remote_id: &I, delay: Box<dyn DelayGenerator>) -> Result<(), PoolError> {
self.core.set_retry_delay(remote_id, delay).await
}
pub fn set_addr(&self, remote_id: I, addrs: &[String]) {
self.address_book.set_addr(remote_id, addrs)
}
pub fn remove_addr(&self, remote_id: &I) {
self.address_book.remove_addr(remote_id)
}
}
/// Lifecycle of one stream slot (client or server) in a dedup session.
enum PendingStream {
    /// An outbound stream is being created by the factory (possibly after a retry delay).
    InClientStreamPreparation(TaskHandle),
    /// A handshake task is running on this stream.
    InHandshake(TaskHandle),
    /// Handshake done; waiting to be handed to a renew request.
    Ready(Box<dyn NonBlockingStream>),
    /// Handed out in an `ActiveTransport`; the closer triggers a renew of that transport.
    InTransmit(TransportCloser),
    /// Failed or finished.
    Ended,
    /// Slot not used yet.
    None,
}
impl std::fmt::Debug for PendingStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
use PendingStream::*;
write!(
f,
"{}",
match self {
InClientStreamPreparation(_) => "InClientStreamPreparation",
InHandshake(_) => "InHandshake",
Ready(_) => "Ready",
InTransmit(_) => "InTransmit",
Ended => "Ended",
None => "None",
}
)
}
}
impl PendingStream {
fn take_for_transmit(&mut self, closer: &TransportCloser) -> Option<Box<dyn NonBlockingStream>> {
if let Self::Ready(_) = self {
let ps = std::mem::replace(self, Self::InTransmit(closer.clone()));
Some(if let Self::Ready(stream) = ps {
stream
} else {
unreachable!()
})
} else {
None
}
}
async fn close(&mut self) -> Option<()> {
match self {
Self::Ready(stream) => stream.shutdown(Shutdown::Both).ok(),
Self::InTransmit(closer) => closer.renew().await.ok(),
_ => Some(()),
}
}
}
/// How the next `Renew` carrying the in-transmit stream is treated.
enum RenewMode {
    /// Set when a stream is handed out: the next renew gets the same stream back
    /// with its write half shut down.
    Terminating,
    /// The next renew marks the stream ended and waits for a replacement.
    Terminated,
}
/// One attempt to settle on a single stream to a remote. Both sides may dial each
/// other at the same time; comparing nonces decides which stream wins.
struct DedupSession {
    // Stream we dialed.
    client_stream: PendingStream,
    // Stream the remote dialed.
    server_stream: PendingStream,
    // Random nonce for this session.
    local_nonce: u64,
    // The remote's nonce, once seen; a different nonce later triggers a reset.
    remote_nonce: Option<u64>,
}
impl DedupSession {
fn new() -> Self {
Self {
client_stream: PendingStream::None,
server_stream: PendingStream::None,
local_nonce: rand::thread_rng().gen(),
remote_nonce: None,
}
}
async fn close(&mut self) {
self.client_stream.close().await;
self.server_stream.close().await;
}
}
/// State of the per-remote event loop spawned by `Remote::new`.
struct RemoteState<I: ID + Serialize, F: StreamFactory<I>> {
    remote_id: I,
    local_id: I,
    // Registry sender, handed to each `ActiveTransport` so it can post `Event::Renew`.
    registry_tx: mpsc::Sender<Event<I>>,
    // Events from spawned handshake tasks and from the `Remote` handle.
    internal_rx: mpsc::UnboundedReceiver<RemoteEvent<I>>,
    internal_tx: mpsc::UnboundedSender<RemoteEvent<I>>,
    // Inbound streams discovered by the registry for this remote.
    external_rx: mpsc::Receiver<ExternalEvent>,
    dedup: DedupSession,
    // Bumped on every reset; task results carrying an older value are ignored.
    dedup_counter: u64,
    // A renew request waiting for a ready stream.
    renew: Option<(TransportToken<I>, TransportCloser, RenewSender)>,
    renew_mode: RenewMode,
    // Bytes received during handshakes, delivered with the next handed-out transport.
    inbound_leftover: Vec<Bytes>,
    retry_delay: Box<dyn DelayGenerator>,
    client_stream_timeout: Duration,
    server_stream_timeout: Duration,
    stream_factory: F,
    driver: Driver,
}
/// Events delivered to a remote from the registry.
enum ExternalEvent {
    /// An inbound endpoint whose handshake carried the remote nonce `u64`.
    DiscoveredStream(Endpoint, u64),
}
/// Which side opened a stream.
enum StreamMode {
    /// We dialed the remote.
    Client,
    /// The remote dialed us.
    Server,
}
/// Events on a remote's internal channel. The trailing `u64` on task results is
/// the `dedup_counter` at the time the task started.
enum RemoteEvent<I: ID + Serialize> {
    /// An `ActiveTransport` is being renewed and returns its stream.
    Renew(
        TransportToken<I>,
        Box<dyn NonBlockingStream>,
        StreamMode,
        TransportCloser,
        RenewSender,
        u64,
    ),
    /// First renew from an inactive transport: wait for any ready stream.
    InitRenew(TransportToken<I>, TransportCloser, RenewSender),
    /// A client- or server-side handshake failed or timed out.
    HandshakeFailed(StreamMode, u64),
    /// A handshake succeeded: (stream, buffered inbound bytes, remote nonce, mode, counter).
    HandshakeFinished(Box<dyn NonBlockingStream>, Vec<Bytes>, u64, StreamMode, u64),
    /// The factory produced an outbound stream; the client handshake can start.
    ClientStreamCreated(Box<dyn NonBlockingStream>, u64),
    /// The remote was removed from the registry.
    Close,
    /// Replace the retry-delay generator and acknowledge.
    SetRetryDelay(Box<dyn DelayGenerator>, oneshot::Sender<Option<()>>),
}
impl<I: ID + Serialize, F: StreamFactory<I>> RemoteState<I, F> {
    /// Hands a ready stream to the waiting renew request, if both exist.
    ///
    /// A `Ready` client stream is preferred over a `Ready` server stream. On success
    /// the chosen slot becomes `InTransmit` and `renew_mode` becomes `Terminating`.
    /// The waiter receives an `ActiveTransport` together with all buffered inbound
    /// bytes, and `true` is returned. Otherwise the pending request (if any) is kept
    /// and `false` is returned.
    fn check_stream_ready(&mut self) -> bool {
        if let Some((token, closer, resp_tx)) = self.renew.take() {
            let stream = if let Some(stream) = self.dedup.client_stream.take_for_transmit(&closer) {
                Some((stream, StreamMode::Client))
            } else if let Some(stream) = self.dedup.server_stream.take_for_transmit(&closer) {
                Some((stream, StreamMode::Server))
            } else {
                None
            };
            match stream {
                Some((stream, stream_mode)) => {
                    remote_log!(
                        Level::DEBUG,
                        self.local_id,
                        self.remote_id,
                        "reached stable transmission client={:?} server={:?}",
                        self.dedup.client_stream,
                        self.dedup.server_stream,
                    );
                    self.renew_mode = RenewMode::Terminating;
                    resp_tx
                        .send((
                            Box::new(ActiveTransport::new(
                                token,
                                self.dedup_counter,
                                self.registry_tx.clone(),
                                stream,
                                stream_mode,
                            )),
                            std::mem::take(&mut self.inbound_leftover),
                        ))
                        .ok();
                    return true
                }
                None => self.renew = Some((token, closer, resp_tx)),
            }
        } else {
            remote_log!(Level::DEBUG, self.local_id, self.remote_id, "no waiting renew()");
        }
        false
    }

    /// Parks a renew request until a stream becomes ready.
    ///
    /// # Panics
    /// If another renew request is already parked.
    fn wait_stream_ready(&mut self, token: TransportToken<I>, closer: TransportCloser, resp_tx: RenewSender) {
        assert!(self.renew.is_none());
        self.renew = Some((token, closer, resp_tx));
    }

    /// Spawns the server side of the handshake on the discovered endpoint `ep`,
    /// whose Init carried the remote nonce `nonce`.
    ///
    /// The task replies with an Ack carrying the local nonce, then detaches the raw
    /// stream (plus anything the endpoint already received) and reports
    /// `HandshakeFinished(.., Server, ..)`. Any failure, or exceeding `timeout`,
    /// reports `HandshakeFailed(Server, ..)`.
    async fn create_server_stream<H: Handshake<I>>(
        &mut self, timeout: Duration, ep: Endpoint, nonce: u64,
    ) -> TaskHandle {
        let internal_tx = self.internal_tx.clone();
        let local_nonce = self.dedup.local_nonce;
        let local_id = self.local_id.clone();
        let remote_id = self.remote_id.clone();
        let dedup_counter = self.dedup_counter;
        // Handshake body: `None` means failure, reported by the wrapper task below.
        let protocol = async move {
            let ack = H::new(HandshakeKind::Ack, local_nonce, local_id.clone())
                .seal()
                .or_else(|| {
                    remote_log!(
                        Level::ERROR,
                        local_id,
                        remote_id,
                        "server stream failed to seal handshake ack message"
                    );
                    None
                })?;
            remote_log!(
                Level::DEBUG,
                local_id,
                remote_id,
                "handshake ACK ({:08x}<->{:08x})",
                local_nonce,
                nonce
            );
            ep.outbound().send(ack).await.ok()?;
            let tp = ep.take_transport().await.ok()?;
            let inbound_leftover = ep.inbound().drain();
            internal_tx
                .send(RemoteEvent::HandshakeFinished(
                    tp.take_stream().ok()?,
                    inbound_leftover,
                    nonce,
                    StreamMode::Server,
                    dedup_counter,
                ))
                .ok();
            Some(())
        };
        let internal_tx = self.internal_tx.clone();
        let local_id = self.local_id.clone();
        let remote_id = self.remote_id.clone();
        self.driver.spawn(async move {
            remote_log!(
                Level::DEBUG,
                local_id,
                remote_id,
                "server-side handshake timeout in {:?}",
                timeout
            );
            match event::timeout(timeout, protocol).await {
                Ok(r) => {
                    if r.is_none() {
                        remote_log!(Level::DEBUG, local_id, remote_id, "server-side handshake failed");
                        internal_tx
                            .send(RemoteEvent::HandshakeFailed(StreamMode::Server, dedup_counter))
                            .ok();
                    }
                }
                Err(_) => {
                    remote_log!(Level::DEBUG, local_id, remote_id, "server-side handshake timeout");
                    internal_tx
                        .send(RemoteEvent::HandshakeFailed(StreamMode::Server, dedup_counter))
                        .ok();
                }
            }
        })
    }

    /// Spawns the client side of the handshake on a freshly dialed `stream`.
    ///
    /// The task attaches the stream to a new endpoint, sends an Init carrying the
    /// local nonce, and waits for the first reply. It then detaches the raw stream
    /// (plus anything else already received). A parseable reply reports
    /// `HandshakeFinished(.., Client, ..)` with the remote's nonce. Any failure,
    /// including an unparseable reply or exceeding `timeout`, reports
    /// `HandshakeFailed(Client, ..)`.
    async fn create_client_stream<H: Handshake<I>>(
        &mut self, stream: Box<dyn NonBlockingStream>, timeout: Duration,
    ) -> TaskHandle {
        let internal_tx = self.internal_tx.clone();
        let remote_id = self.remote_id.clone();
        let local_id = self.local_id.clone();
        let driver = self.driver.clone();
        let nonce = self.dedup.local_nonce;
        let dedup_counter = self.dedup_counter;
        // Handshake body: `None` means failure, reported by the wrapper task below.
        let protocol = async move {
            let tp = Box::new(HandshakeTransport::new(stream));
            let ep = driver.create_endpoint().await.or_else(|| {
                remote_log!(
                    Level::ERROR,
                    local_id,
                    remote_id,
                    "client stream failed to create endpoint",
                );
                None
            })?;
            ep.set_transport(tp).await.ok()?;
            let init = H::new(HandshakeKind::Init, nonce, local_id.clone())
                .seal()
                .or_else(|| {
                    remote_log!(
                        Level::ERROR,
                        local_id,
                        remote_id,
                        "client stream failed to seal handshake init message"
                    );
                    None
                })?;
            remote_log!(Level::DEBUG, local_id, remote_id, "handshake INIT ({:08x})", nonce);
            ep.outbound().send(init).await.ok()?;
            let raw = ep.inbound().recv(None).await?;
            let tp = ep.take_transport().await.ok()?;
            let inbound_leftover = ep.inbound().drain();
            let mut stream = tp.take_stream().ok()?;
            match H::parse(&raw) {
                Some(ack) => {
                    internal_tx
                        .send(RemoteEvent::HandshakeFinished(
                            stream,
                            inbound_leftover,
                            ack.nonce(),
                            StreamMode::Client,
                            dedup_counter,
                        ))
                        .ok();
                }
                None => {
                    // An unparseable reply is a failed handshake. Return `None` so the
                    // wrapper reports `HandshakeFailed(Client)`. Returning `Some(())`
                    // would send no event at all and leave the client slot stuck in
                    // `InHandshake`.
                    stream.shutdown(Shutdown::Both).ok();
                    return None
                }
            }
            Some(())
        };
        let internal_tx = self.internal_tx.clone();
        let local_id = self.local_id.clone();
        let remote_id = self.remote_id.clone();
        self.driver.spawn(async move {
            remote_log!(
                Level::DEBUG,
                local_id,
                remote_id,
                "client-side handshake timeout in {:?}",
                timeout
            );
            match event::timeout(timeout, protocol).await {
                Ok(r) => {
                    if r.is_none() {
                        remote_log!(Level::DEBUG, local_id, remote_id, "client-side handshake failed");
                        internal_tx
                            .send(RemoteEvent::HandshakeFailed(StreamMode::Client, dedup_counter))
                            .ok();
                    }
                }
                Err(_) => {
                    remote_log!(Level::DEBUG, local_id, remote_id, "client-side handshake timeout");
                    internal_tx
                        .send(RemoteEvent::HandshakeFailed(StreamMode::Client, dedup_counter))
                        .ok();
                }
            }
        })
    }

    /// Starts preparing an outbound stream in the client slot.
    ///
    /// When `is_delayed`, the next retry delay is drawn now (before spawning) and
    /// slept inside the task. The task then asks the factory for a stream and
    /// reports `ClientStreamCreated` on success or `HandshakeFailed(Client)` on failure.
    async fn try_client_stream<H: Handshake<I>>(&mut self, is_delayed: bool) {
        let internal_tx = self.internal_tx.clone();
        let remote_id = self.remote_id.clone();
        let local_id = self.local_id.clone();
        let delay = if is_delayed {
            Some(self.retry_delay.next_delay().await)
        } else {
            None
        };
        let factory = self.stream_factory.clone();
        let dedup_counter = self.dedup_counter;
        self.dedup.client_stream = PendingStream::InClientStreamPreparation(self.driver.spawn(async move {
            if let Some(delay) = delay {
                remote_log!(
                    Level::DEBUG,
                    local_id,
                    remote_id,
                    "client-side handshake delays for {:.3}s",
                    delay.as_secs() as f32 + delay.subsec_millis() as f32 / 1e3,
                );
                event::sleep(delay).await;
            }
            let stream = match factory.create_stream(remote_id.clone()).await {
                Some(s) => s,
                None => {
                    remote_log!(Level::DEBUG, local_id, remote_id, "failed to create client stream");
                    internal_tx
                        .send(RemoteEvent::HandshakeFailed(StreamMode::Client, dedup_counter))
                        .ok();
                    return
                }
            };
            internal_tx
                .send(RemoteEvent::ClientStreamCreated(stream, dedup_counter))
                .ok();
        }));
    }

    /// Handles an inbound stream discovered by the registry for this remote.
    ///
    /// A nonce different from the remote nonce already seen resets the session and
    /// drops `ep`. A nonce equal to our own local nonce is dropped. Otherwise a
    /// server-side handshake starts if the server slot is free and either our nonce
    /// is larger or our client stream has not reached its handshake yet.
    async fn handle_external<H: Handshake<I>>(&mut self, msg: ExternalEvent) {
        let ExternalEvent::DiscoveredStream(ep, nonce) = msg;
        if let Some(known_nonce) = self.dedup.remote_nonce {
            if known_nonce != nonce {
                remote_log!(
                    Level::DEBUG,
                    self.local_id,
                    self.remote_id,
                    "server stream got a different remote nonce, reset"
                );
                self.reset().await;
                return
            }
        }
        self.dedup.remote_nonce = Some(nonce);
        if nonce == self.dedup.local_nonce {
            remote_log!(
                Level::DEBUG,
                self.local_id,
                self.remote_id,
                "server stream got identical nonce, dropping"
            );
            return
        }
        match self.dedup.server_stream {
            PendingStream::None | PendingStream::Ended => {
                if self.dedup.local_nonce > nonce ||
                    matches!(
                        self.dedup.client_stream,
                        PendingStream::None | PendingStream::InClientStreamPreparation(_)
                    )
                {
                    self.dedup.server_stream = PendingStream::InHandshake(
                        self.create_server_stream::<H>(self.server_stream_timeout, ep, nonce)
                            .await,
                    );
                } else {
                    remote_log!(
                        Level::DEBUG,
                        self.local_id,
                        self.remote_id,
                        "server stream priority is lower, will not handshake, dropping",
                    );
                }
            }
            _ => {
                remote_log!(
                    Level::DEBUG,
                    self.local_id,
                    self.remote_id,
                    "the other server stream exists, dropping",
                )
            }
        }
    }

    /// Closes both slots and starts a new session: new local nonce, empty slots,
    /// bumped `dedup_counter` (so in-flight task results are ignored), and a reset
    /// retry-delay generator.
    async fn reset(&mut self) {
        self.dedup.close().await;
        self.dedup = DedupSession::new();
        self.dedup_counter = self.dedup_counter.wrapping_add(1);
        self.retry_delay.reset().await;
        remote_log!(Level::DEBUG, self.local_id, self.remote_id, "reset dedup session");
    }

    /// Restarts the session when neither slot can still produce a stream.
    ///
    /// That is the case when the client slot is `Ended` or still preparing and the
    /// server slot is `Ended` or unused. A delayed client attempt then follows.
    /// Otherwise it tries to hand a ready stream to a waiting renew.
    async fn check_reset<H: Handshake<I>>(&mut self) {
        remote_log!(
            Level::DEBUG,
            self.local_id,
            self.remote_id,
            "stream states client={:?} server={:?}",
            self.dedup.client_stream,
            self.dedup.server_stream
        );
        if matches!(
            self.dedup.client_stream,
            PendingStream::Ended | PendingStream::InClientStreamPreparation(_)
        ) && matches!(self.dedup.server_stream, PendingStream::Ended | PendingStream::None)
        {
            self.reset().await;
            self.try_client_stream::<H>(true).await;
        } else {
            self.check_stream_ready();
        }
    }

    /// Applies one internal event to the session state.
    ///
    /// Task results tagged with a stale `dedup_counter` (from before the last
    /// `reset`) are ignored. `Renew` checks the counter only in its `Terminated` branch.
    async fn handle_internal<H: Handshake<I>>(&mut self, msg: RemoteEvent<I>) {
        match msg {
            RemoteEvent::ClientStreamCreated(stream, dedup_counter) => {
                if dedup_counter != self.dedup_counter {
                    return
                }
                self.dedup.client_stream = PendingStream::InHandshake(
                    self.create_client_stream::<H>(stream, self.client_stream_timeout).await,
                );
            }
            RemoteEvent::HandshakeFinished(stream, inbound_leftover, nonce, stream_mode, dedup_counter) => {
                if dedup_counter != self.dedup_counter {
                    return
                }
                // Server-side nonces were already validated in `handle_external`;
                // apply the same checks to the nonce learned via our own dial.
                if let StreamMode::Client = stream_mode {
                    if let Some(known_nonce) = self.dedup.remote_nonce {
                        if known_nonce != nonce {
                            remote_log!(
                                Level::DEBUG,
                                self.local_id,
                                self.remote_id,
                                "client stream got a different remote nonce, reset"
                            );
                            self.reset().await;
                            return
                        }
                    }
                    self.dedup.remote_nonce = Some(nonce);
                    if nonce == self.dedup.local_nonce {
                        remote_log!(
                            Level::DEBUG,
                            self.local_id,
                            self.remote_id,
                            "client stream got identical nonce, dropping"
                        );
                        return
                    }
                }
                let (ready_stream, other_stream, is_client) = match stream_mode {
                    StreamMode::Client => (&mut self.dedup.client_stream, &mut self.dedup.server_stream, true),
                    StreamMode::Server => (&mut self.dedup.server_stream, &mut self.dedup.client_stream, false),
                };
                remote_log!(
                    Level::DEBUG,
                    self.local_id,
                    self.remote_id,
                    local_nonce = %format!("{:08x}", self.dedup.local_nonce),
                    remote_nonce = %format!("{:08x}", self.dedup.remote_nonce.unwrap()),
                    "{} handshake finished",
                    if is_client { "client-side" } else { "server-side" }
                );
                assert!(matches!(ready_stream, PendingStream::InHandshake(_)));
                *ready_stream = PendingStream::Ready(stream);
                match other_stream {
                    PendingStream::InTransmit(closer) => {
                        assert!(self.renew.is_none());
                        // The nonce comparison, flipped by which side dialed, decides
                        // whether the new stream replaces the one in transmission.
                        if (self.dedup.local_nonce > nonce) ^ is_client {
                            remote_log!(
                                Level::DEBUG,
                                self.local_id,
                                self.remote_id,
                                "the other in-transmit stream will be replaced",
                            );
                            closer.renew().await.ok();
                            self.inbound_leftover.extend(inbound_leftover);
                        } else {
                            ready_stream.close().await;
                            *ready_stream = PendingStream::Ended;
                            remote_log!(
                                Level::WARN,
                                self.local_id,
                                self.remote_id,
                                "dropping low-after-high priority stream",
                            );
                        }
                    }
                    _ => {
                        self.inbound_leftover.extend(inbound_leftover);
                        remote_log!(
                            Level::DEBUG,
                            self.local_id,
                            self.remote_id,
                            "the other stream is not in-transmit",
                        );
                        self.check_stream_ready();
                    }
                }
            }
            RemoteEvent::Renew(token, mut stream, stream_mode, closer, resp_tx, dedup_counter) => {
                match self.renew_mode {
                    RenewMode::Terminating => {
                        // First renew after handover: give the same stream back with
                        // its write half shut down, and treat the next renew as final.
                        self.renew_mode = RenewMode::Terminated;
                        remote_log!(
                            Level::DEBUG,
                            self.local_id,
                            self.remote_id,
                            "in-transmit stream is terminating"
                        );
                        stream.shutdown(Shutdown::Write).ok();
                        resp_tx
                            .send((
                                Box::new(ActiveTransport::new(
                                    token,
                                    dedup_counter,
                                    self.registry_tx.clone(),
                                    stream,
                                    stream_mode,
                                )),
                                std::mem::take(&mut self.inbound_leftover),
                            ))
                            .ok();
                    }
                    RenewMode::Terminated => {
                        remote_log!(
                            Level::DEBUG,
                            self.local_id,
                            self.remote_id,
                            "in-transmit stream is terminated"
                        );
                        self.wait_stream_ready(token, closer, resp_tx);
                        if dedup_counter == self.dedup_counter {
                            match stream_mode {
                                StreamMode::Client => self.dedup.client_stream = PendingStream::Ended,
                                StreamMode::Server => self.dedup.server_stream = PendingStream::Ended,
                            }
                            self.check_reset::<H>().await;
                        } else {
                            self.check_stream_ready();
                        }
                    }
                }
            }
            RemoteEvent::InitRenew(token, closer, resp_tx) => {
                remote_log!(Level::DEBUG, self.local_id, self.remote_id, "initial renew");
                self.wait_stream_ready(token, closer, resp_tx);
                if !self.check_stream_ready() {
                    self.try_client_stream::<H>(false).await;
                }
            }
            RemoteEvent::Close => {
                self.dedup.close().await;
            }
            RemoteEvent::HandshakeFailed(stream_mode, dedup_counter) => {
                if dedup_counter != self.dedup_counter {
                    return
                }
                match stream_mode {
                    StreamMode::Server => {
                        assert!(matches!(self.dedup.server_stream, PendingStream::InHandshake(_)));
                        self.dedup.server_stream = PendingStream::Ended
                    }
                    StreamMode::Client => {
                        assert!(matches!(
                            self.dedup.client_stream,
                            PendingStream::InHandshake(_) | PendingStream::InClientStreamPreparation(_)
                        ));
                        self.dedup.client_stream = PendingStream::Ended
                    }
                }
                self.check_reset::<H>().await;
            }
            RemoteEvent::SetRetryDelay(delay, resp_tx) => {
                self.retry_delay = delay;
                resp_tx.send(Some(())).ok();
            }
        }
    }
}
/// Handle to a remote's spawned state machine (`RemoteState`).
struct Remote<I: ID + Serialize> {
    // Bounded channel for discovered inbound streams (capacity from `Config`).
    external_tx: mpsc::Sender<ExternalEvent>,
    // Unbounded channel for renew/close/config requests.
    internal_tx: mpsc::UnboundedSender<RemoteEvent<I>>,
    // Handle of the event-loop task. NOTE(review): whether dropping it cancels the
    // task depends on `TaskHandle` — confirm.
    _loop: TaskHandle,
}
impl<I: ID + Serialize> Remote<I> {
fn new<H: Handshake<I>, F: StreamFactory<I>>(
remote_id: I, local_id: I, registry_tx: mpsc::Sender<Event<I>>, stream_factory: F, config: &Config,
driver: Driver,
) -> Self {
let (internal_tx, internal_rx) = mpsc::unbounded_channel();
let (external_tx, external_rx) = mpsc::channel(config.max_pending_discovered_stream);
let mut state = RemoteState {
remote_id,
local_id,
registry_tx,
dedup: DedupSession::new(),
dedup_counter: 0,
renew: None,
renew_mode: RenewMode::Terminating,
inbound_leftover: Vec::new(),
internal_rx,
internal_tx: internal_tx.clone(),
external_rx,
driver,
stream_factory,
retry_delay: Box::new(RandomDelay::new(config.retry_delay.clone())),
client_stream_timeout: config.client_stream_timeout,
server_stream_timeout: config.server_stream_timeout,
};
let _loop = state.driver.clone().spawn(async move {
loop {
tokio::select! {
msg = state.internal_rx.recv() => {
match msg {
Some(msg) => state.handle_internal::<H>(msg).await,
None => break,
}
},
msg = state.external_rx.recv() => {
match msg {
Some(msg) => state.handle_external::<H>(msg).await,
None => break,
}
}
}
}
});
Self {
external_tx,
internal_tx,
_loop,
}
}
fn discovered_stream(&self, ep: Endpoint, r: u64) -> Result<(), Endpoint> {
let msg = ExternalEvent::DiscoveredStream(ep, r);
self.external_tx.try_send(msg).map_err(|e| {
let ExternalEvent::DiscoveredStream(ep, ..) = match e {
mpsc::error::TrySendError::Full(m) => m,
mpsc::error::TrySendError::Closed(m) => m,
};
ep
})
}
fn renew(
&self, token: TransportToken<I>, stream: Box<dyn NonBlockingStream>, stream_mode: StreamMode,
closer: TransportCloser, resp: RenewSender, counter: u64,
) {
self.internal_tx
.send(RemoteEvent::Renew(token, stream, stream_mode, closer, resp, counter))
.ok();
}
fn init_renew(&self, token: TransportToken<I>, closer: TransportCloser, resp: RenewSender) {
self.internal_tx.send(RemoteEvent::InitRenew(token, closer, resp)).ok();
}
fn close(&self) {
self.internal_tx.send(RemoteEvent::Close).ok();
}
fn set_retry_delay(&self, delay: Box<dyn DelayGenerator>, resp: oneshot::Sender<Option<()>>) {
self.internal_tx.send(RemoteEvent::SetRetryDelay(delay, resp)).ok();
}
}
/// Transport handed to the user once a deduplicated stream is established.
struct ActiveTransport<I: ID + Serialize> {
    // Keeps the remote's registry entry alive.
    token: TransportToken<I>,
    // Session the stream belongs to; sent back with each renew.
    dedup_counter: u64,
    // Registry sender used to post `Event::Renew`.
    event_tx: mpsc::Sender<Event<I>>,
    stream: Box<dyn NonBlockingStream>,
    stream_mode: StreamMode,
    // Set by `shutdown`; only gate this wrapper, the underlying stream is untouched.
    disable_send: bool,
    disable_recv: bool,
}
impl<I: ID + Serialize> ActiveTransport<I> {
fn new(
token: TransportToken<I>, dedup_counter: u64, event_tx: mpsc::Sender<Event<I>>,
stream: Box<dyn NonBlockingStream>, stream_mode: StreamMode,
) -> Self {
Self {
token,
dedup_counter,
event_tx,
stream,
stream_mode,
disable_send: false,
disable_recv: false,
}
}
}
#[async_trait]
impl<I: ID + Serialize> Transport for ActiveTransport<I> {
fn try_send(&mut self, bytes: Option<Bytes>) -> Result<bool, TransportError> {
if self.disable_send {
Err(TransportError::HalfTerminated)
} else {
self.stream.try_send(bytes)
}
}
fn try_recv(&mut self) -> Result<Bytes, TransportError> {
if self.disable_recv {
Err(TransportError::HalfTerminated)
} else {
match self.stream.try_recv() {
Err(TransportError::HalfTerminated) => {
self.stream.shutdown(Shutdown::Both).ok();
Err(TransportError::BothTerminated)
}
res => res,
}
}
}
fn source(&mut self) -> IOSource {
self.stream.source()
}
async fn renew(
mut self: Box<Self>, closer: TransportCloser,
) -> (Box<dyn Transport>, Result<Vec<Bytes>, TransportError>) {
let (resp_tx, resp_rx) = oneshot::channel();
self.event_tx
.send(Event::Renew(
self.token,
Some((self.stream, self.stream_mode)),
closer,
resp_tx,
self.dedup_counter,
))
.await
.ok();
let response = match resp_rx.await {
Ok(r) => r,
Err(_) => futures::future::pending().await,
};
(response.0, Ok(response.1))
}
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match how {
Shutdown::Read => {
self.disable_recv = true;
Ok(())
}
Shutdown::Write => {
self.disable_send = true;
Ok(())
}
Shutdown::Both => {
self.disable_recv = true;
self.disable_send = true;
Ok(())
}
}
}
fn take_stream(self: Box<ActiveTransport<I>>) -> Result<Box<dyn NonBlockingStream>, TransportError> {
Ok(self.stream)
}
}
/// Placeholder transport that has no stream yet; its only useful operation is `renew`.
struct InactiveTransport<I: ID + Serialize> {
    /// Identifies the transport in the renew request.
    token: TransportToken<I>,
    /// Channel used to request a real transport.
    event_tx: mpsc::Sender<Event<I>>,
}
impl<I: ID + Serialize> InactiveTransport<I> {
    /// Builds the placeholder from its token and request channel.
    fn new(token: TransportToken<I>, event_tx: mpsc::Sender<Event<I>>) -> Self {
        InactiveTransport { event_tx, token }
    }
}
#[async_trait]
impl<I: ID + Serialize> Transport for InactiveTransport<I> {
    /// No stream is attached, so nothing can be sent.
    fn try_send(&mut self, _data: Option<Bytes>) -> Result<bool, TransportError> {
        Err(TransportError::NotReady)
    }
    /// No stream is attached, so nothing can be received.
    fn try_recv(&mut self) -> Result<Bytes, TransportError> {
        Err(TransportError::NotReady)
    }
    fn source(&mut self) -> IOSource {
        IOSource::Empty
    }
    /// There is no underlying stream to shut down.
    ///
    /// This used to panic via `unreachable!()`. An inactive transport can sit behind
    /// any `Box<dyn Transport>`, so a caller shutting down while it is not ready would
    /// crash. Report `NotConnected` instead, mirroring a socket that was never
    /// connected.
    fn shutdown(&mut self, _how: Shutdown) -> io::Result<()> {
        Err(io::Error::new(io::ErrorKind::NotConnected, "transport is not ready"))
    }
    /// Sends an `Event::Renew` with no stream (counter 0) and awaits the transport the
    /// event loop returns.
    ///
    /// If the response sender is dropped without replying, this future never
    /// completes.
    async fn renew(
        mut self: Box<Self>, closer: TransportCloser,
    ) -> (Box<dyn Transport>, Result<Vec<Bytes>, TransportError>) {
        let (resp_tx, resp_rx) = oneshot::channel();
        self.event_tx
            .send(Event::Renew(self.token, None, closer, resp_tx, 0))
            .await
            .ok();
        let response = match resp_rx.await {
            Ok(r) => r,
            Err(_) => futures::future::pending().await,
        };
        (response.0, Ok(response.1))
    }
    /// There is no stream to hand out.
    fn take_stream(self: Box<Self>) -> Result<Box<dyn NonBlockingStream>, TransportError> {
        Err(TransportError::NotReady)
    }
}
/// Transport that forwards every call directly to a raw stream.
struct HandshakeTransport {
    /// The wrapped stream; all I/O is delegated to it.
    stream: Box<dyn NonBlockingStream>,
}
impl HandshakeTransport {
    /// Wraps `stream` without touching it.
    fn new(stream: Box<dyn NonBlockingStream>) -> Self {
        Self { stream }
    }
}
#[async_trait]
impl Transport for HandshakeTransport {
fn try_send(&mut self, bytes: Option<Bytes>) -> Result<bool, TransportError> {
self.stream.try_send(bytes)
}
fn try_recv(&mut self) -> Result<Bytes, TransportError> {
self.stream.try_recv()
}
fn source(&mut self) -> IOSource {
self.stream.source()
}
async fn renew(
mut self: Box<HandshakeTransport>, _closer: TransportCloser,
) -> (Box<dyn Transport>, Result<Vec<Bytes>, TransportError>) {
(self, Ok(Vec::new()))
}
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
self.stream.shutdown(how)
}
fn take_stream(self: Box<Self>) -> Result<Box<dyn NonBlockingStream>, TransportError> {
Ok(self.stream)
}
}
#[cfg(target_arch = "wasm32")]
pub mod wasm {
    use super::*;
    use async_trait::async_trait;
    use wasm_bindgen::prelude::*;

    /// Delay used when the JS callback throws, returns a non-number, or returns an
    /// unusable number.
    const FALLBACK_DELAY_MS: u64 = 1000;

    /// Converts the number returned by the JS retry-delay callback into milliseconds.
    ///
    /// The value is truncated toward zero through an `i32` cast, which saturates: values
    /// above `i32::MAX` become `i32::MAX` ms. Returns `None` for NaN and for values that
    /// are negative after truncation, so the caller uses [`FALLBACK_DELAY_MS`].
    fn parse_delay_ms(value: f64) -> Option<u64> {
        // `NaN as i32` is 0. Without this check, a NaN from the callback would become a
        // zero-delay (busy) retry instead of falling back to the default.
        if value.is_nan() {
            return None;
        }
        let millis = value as i32;
        if millis < 0 {
            None
        } else {
            Some(millis as u64)
        }
    }

    /// [`DelayGenerator`] that asks a JavaScript callback for each retry delay.
    ///
    /// The callback runs on a task started with `spawn_local`. `next_delay` sends a
    /// request over `tx` and awaits the answer on `delay_rx`.
    struct JsDelayGen {
        tx: mpsc::Sender<()>,
        delay_rx: mpsc::UnboundedReceiver<Duration>,
    }
    impl JsDelayGen {
        fn new(callback: js_sys::Function) -> Self {
            let (tx, mut rx) = mpsc::channel(1);
            let (delay_tx, delay_rx) = mpsc::unbounded_channel();
            wasm_bindgen_futures::spawn_local(async move {
                // One callback invocation per request. The loop ends once the request
                // sender held by `JsDelayGen` is dropped.
                while rx.recv().await.is_some() {
                    let millis = callback
                        .call0(&JsValue::null())
                        .ok()
                        .and_then(|ret| ret.as_f64())
                        .and_then(parse_delay_ms)
                        .unwrap_or(FALLBACK_DELAY_MS);
                    delay_tx.send(Duration::from_millis(millis)).ok();
                }
            });
            Self { tx, delay_rx }
        }
    }
    #[async_trait]
    impl DelayGenerator for JsDelayGen {
        /// No internal state to reset; every delay comes fresh from the callback.
        async fn reset(&mut self) {}
        /// Requests one delay from the callback task and waits for its answer.
        ///
        /// # Panics
        /// Panics if the callback task has ended and dropped its sender.
        async fn next_delay(&mut self) -> Duration {
            self.tx.send(()).await.ok();
            self.delay_rx
                .recv()
                .await
                .expect("thread owned closure dropped too early")
        }
    }
}