use crate::protocol::generic_proto::{
upgrade::{
NotificationsIn, NotificationsOut, NotificationsInSubstream, NotificationsOutSubstream,
NotificationsHandshakeError, RegisteredProtocol, RegisteredProtocolSubstream,
RegisteredProtocolEvent, UpgradeCollec
},
};
use bytes::BytesMut;
use libp2p::core::{either::EitherOutput, ConnectedPoint, PeerId};
use libp2p::core::upgrade::{SelectUpgrade, InboundUpgrade, OutboundUpgrade};
use libp2p::swarm::{
ProtocolsHandler, ProtocolsHandlerEvent,
IntoProtocolsHandler,
KeepAlive,
ProtocolsHandlerUpgrErr,
SubstreamProtocol,
NegotiatedSubstream,
};
use futures::{
channel::mpsc,
lock::{Mutex as FuturesMutex, MutexGuard as FuturesMutexGuard},
prelude::*
};
use log::error;
use parking_lot::{Mutex, RwLock};
use smallvec::SmallVec;
use std::{borrow::Cow, collections::VecDeque, mem, pin::Pin, str, sync::Arc, task::{Context, Poll}, time::Duration};
use wasm_timer::Instant;
/// Buffer size handed to `mpsc::channel` when creating the channel that backs
/// `NotificationsSink::reserve_notification`.
const ASYNC_NOTIFICATIONS_BUFFER_SIZE: usize = 8;

/// Buffer size handed to `mpsc::channel` when creating the channel that backs
/// `NotificationsSink::send_sync_notification`. When a send on it fails, a `ForceClose`
/// message is queued, upon which the handler closes the connection.
const SYNC_NOTIFICATIONS_BUFFER_SIZE: usize = 2 * 1024;

/// Timeout attached to every outbound notifications substream request.
const OPEN_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Grace period, counted from the creation of the handler, during which the connection is
/// kept alive even though no protocol is active.
const INITIAL_KEEPALIVE_TIME: Duration = Duration::from_millis(5_000);
/// Prototype of a [`NotifsHandler`]: the configuration from which a handler is built by
/// `IntoProtocolsHandler::into_handler`.
pub struct NotifsHandlerProto {
    /// One entry per notifications protocol: protocol name, inbound upgrade built for it,
    /// shared handshake message, and maximum notification size.
    protocols: Vec<(Cow<'static, str>, NotificationsIn, Arc<RwLock<Vec<u8>>>, u64)>,
    /// Legacy protocol, offered on inbound substreams alongside the notifications protocols.
    legacy_protocol: RegisteredProtocol,
}
/// Connection handler driving the notifications substreams, and the legacy substreams, of a
/// single connection to a peer.
pub struct NotifsHandler {
    /// Configuration and state of each notifications protocol. Indices into this list are the
    /// `protocol_index` values carried by `NotifsHandlerIn` and `NotifsHandlerOut`.
    protocols: Vec<Protocol>,
    /// Set to `Instant::now()` when the handler is built; start of the initial keep-alive
    /// grace period.
    when_connection_open: Instant,
    /// Endpoint of the connection, reported back in `NotifsHandlerOut::OpenResultOk`.
    endpoint: ConnectedPoint,
    /// Identity of the remote, copied into every `NotificationsSink` this handler creates.
    peer_id: PeerId,
    /// Legacy protocol upgrade offered on inbound substreams.
    legacy_protocol: RegisteredProtocol,
    /// Accepted legacy substreams that are read from.
    legacy_substreams: SmallVec<[RegisteredProtocolSubstream<NegotiatedSubstream>; 4]>,
    /// Legacy substreams that are polled until they end, after a shutdown was requested or
    /// after the remote finished them.
    legacy_shutdown: SmallVec<[RegisteredProtocolSubstream<NegotiatedSubstream>; 4]>,
    /// Events waiting to be returned by `poll`, oldest first.
    events_queue: VecDeque<
        ProtocolsHandlerEvent<NotificationsOut, usize, NotifsHandlerOut, NotifsHandlerError>
    >,
}
/// Configuration and current state of one notifications protocol within a [`NotifsHandler`].
struct Protocol {
    /// Name of the protocol, passed to `NotificationsOut::new` for outbound substreams.
    name: Cow<'static, str>,
    /// Inbound upgrade for this protocol, offered by `listen_protocol`.
    in_upgrade: NotificationsIn,
    /// Handshake message sent to the remote. It is read each time a substream is opened or
    /// accepted, so updates are picked up by later substreams.
    handshake: Arc<RwLock<Vec<u8>>>,
    /// Maximum notification size passed to `NotificationsOut::new`.
    max_notification_size: u64,
    /// Current state of this protocol's substreams.
    state: State,
}
/// State of a single notifications protocol within the handler.
enum State {
    /// No substream is in use for this protocol.
    Closed {
        /// `true` while an outbound substream request emitted earlier has not completed yet
        /// (neither `inject_fully_negotiated_outbound` nor `inject_dial_upgrade_error` has
        /// been called for it). While `true`, an `Open` does not emit a new request.
        pending_opening: bool,
    },
    /// The remote opened an inbound substream and `NotifsHandlerOut::OpenDesiredByRemote`
    /// was emitted. The handler waits for the outer layer to answer with `Open` or `Close`.
    OpenDesiredByRemote {
        /// Inbound substream opened by the remote; no handshake has been sent on it yet.
        in_substream: NotificationsInSubstream<NegotiatedSubstream>,
        /// Same meaning as in `Closed`.
        pending_opening: bool,
    },
    /// `Open` was received and an outbound substream is being negotiated.
    Opening {
        /// Inbound substream, if the remote opened one; our handshake was already sent on it.
        in_substream: Option<NotificationsInSubstream<NegotiatedSubstream>>,
    },
    /// The outbound substream is negotiated and `OpenResultOk` was emitted.
    Open {
        /// Receiving side of the `NotificationsSink` handed out in `OpenResultOk`: its async
        /// and sync channels merged into a single stream.
        notifications_sink_rx: stream::Select<
            stream::Fuse<mpsc::Receiver<NotificationsSinkMessage>>,
            stream::Fuse<mpsc::Receiver<NotificationsSinkMessage>>
        >,
        /// Outbound substream; `None` once it has been discarded (e.g. after a flush error).
        out_substream: Option<NotificationsOutSubstream<NegotiatedSubstream>>,
        /// Inbound substream; `None` if the remote has not opened one or it was closed.
        in_substream: Option<NotificationsInSubstream<NegotiatedSubstream>>,
    },
}
impl IntoProtocolsHandler for NotifsHandlerProto {
type Handler = NotifsHandler;
fn inbound_protocol(&self) -> SelectUpgrade<UpgradeCollec<NotificationsIn>, RegisteredProtocol> {
let protocols = self.protocols.iter()
.map(|(_, p, _, _)| p.clone())
.collect::<UpgradeCollec<_>>();
SelectUpgrade::new(protocols, self.legacy_protocol.clone())
}
fn into_handler(self, peer_id: &PeerId, connected_point: &ConnectedPoint) -> Self::Handler {
NotifsHandler {
protocols: self.protocols.into_iter().map(|(name, in_upgrade, handshake, max_size)| {
Protocol {
name,
in_upgrade,
handshake,
state: State::Closed {
pending_opening: false,
},
max_notification_size: max_size,
}
}).collect(),
peer_id: peer_id.clone(),
endpoint: connected_point.clone(),
when_connection_open: Instant::now(),
legacy_protocol: self.legacy_protocol,
legacy_substreams: SmallVec::new(),
legacy_shutdown: SmallVec::new(),
events_queue: VecDeque::with_capacity(16),
}
}
}
/// Commands sent to the handler by the outer layer.
#[derive(Debug, Clone)]
pub enum NotifsHandlerIn {
    /// Opens the given protocol. It is answered with `OpenResultOk` or `OpenResultErr`, and
    /// must not be sent while the protocol is already opening or open.
    Open {
        /// Index of the protocol in the handler's protocol list.
        protocol_index: usize,
    },
    /// Closes the given protocol. It is always answered with `CloseResult`.
    Close {
        /// Index of the protocol in the handler's protocol list.
        protocol_index: usize,
    },
}
/// Events reported by the handler to the outer layer.
#[derive(Debug)]
pub enum NotifsHandlerOut {
    /// Answer to `Open`: the outbound substream was successfully negotiated.
    OpenResultOk {
        /// Index of the protocol in the handler's protocol list.
        protocol_index: usize,
        /// Endpoint of the connection.
        endpoint: ConnectedPoint,
        /// Handshake received from the remote on the outbound substream.
        received_handshake: Vec<u8>,
        /// Sink through which notifications can be sent on this protocol.
        notifications_sink: NotificationsSink,
    },
    /// Answer to `Open`: opening failed, or a `Close` arrived before it completed.
    OpenResultErr {
        /// Index of the protocol in the handler's protocol list.
        protocol_index: usize,
    },
    /// Answer to `Close`.
    CloseResult {
        /// Index of the protocol in the handler's protocol list.
        protocol_index: usize,
    },
    /// The remote opened an inbound substream for a protocol that was closed.
    OpenDesiredByRemote {
        /// Index of the protocol in the handler's protocol list.
        protocol_index: usize,
    },
    /// A substream of this protocol failed or was closed, so the protocol should be closed.
    CloseDesired {
        /// Index of the protocol in the handler's protocol list.
        protocol_index: usize,
    },
    /// Message received on a legacy substream.
    CustomMessage {
        /// Raw message bytes.
        message: BytesMut,
    },
    /// Notification received on an inbound notifications substream.
    Notification {
        /// Index of the protocol in the handler's protocol list.
        protocol_index: usize,
        /// Raw notification bytes.
        message: BytesMut,
    },
}
/// Handle used to send notifications to a remote on one protocol. Clones share the same
/// underlying channels through an `Arc`.
#[derive(Debug, Clone)]
pub struct NotificationsSink {
    inner: Arc<NotificationsSinkInner>,
}
/// Shared state behind a [`NotificationsSink`].
#[derive(Debug)]
struct NotificationsSinkInner {
    /// Identity of the remote the notifications are sent to.
    peer_id: PeerId,
    /// Sender used by `reserve_notification`. It is guarded by an async mutex so that the
    /// reservation (the guard, kept in `Ready`) can be held across `.await` points.
    async_channel: FuturesMutex<mpsc::Sender<NotificationsSinkMessage>>,
    /// Sender used by `send_sync_notification`.
    sync_channel: Mutex<mpsc::Sender<NotificationsSinkMessage>>,
}
/// Message passed from a [`NotificationsSink`] to the handler.
#[derive(Debug)]
enum NotificationsSinkMessage {
    /// Notification to write to the outbound substream.
    Notification {
        message: Vec<u8>,
    },
    /// Queued when the sync channel could not accept a notification. The handler then closes
    /// the connection with `NotifsHandlerError::SyncNotificationsClogged`.
    ForceClose,
}
impl NotificationsSink {
pub fn peer_id(&self) -> &PeerId {
&self.inner.peer_id
}
pub fn send_sync_notification<'a>(
&'a self,
message: impl Into<Vec<u8>>
) {
let mut lock = self.inner.sync_channel.lock();
let result = lock.try_send(NotificationsSinkMessage::Notification {
message: message.into()
});
if result.is_err() {
let _result2 = lock.clone().try_send(NotificationsSinkMessage::ForceClose);
debug_assert!(_result2.map(|()| true).unwrap_or_else(|err| err.is_disconnected()));
}
}
pub async fn reserve_notification<'a>(&'a self) -> Result<Ready<'a>, ()> {
let mut lock = self.inner.async_channel.lock().await;
let poll_ready = future::poll_fn(|cx| lock.poll_ready(cx)).await;
if poll_ready.is_ok() {
Ok(Ready { lock })
} else {
Err(())
}
}
}
/// Returned by `NotificationsSink::reserve_notification` once the async channel is ready to
/// accept a notification. The channel lock is held until this value is consumed by `send`
/// or dropped.
#[must_use]
#[derive(Debug)]
pub struct Ready<'a> {
    /// Guard over the async channel's sender.
    lock: FuturesMutexGuard<'a, mpsc::Sender<NotificationsSinkMessage>>,
}
impl<'a> Ready<'a> {
    /// Pushes `notification` into the reserved slot, consuming the reservation and releasing
    /// the lock. Returns `Err(())` if the channel rejects it.
    pub fn send(
        mut self,
        notification: impl Into<Vec<u8>>
    ) -> Result<(), ()> {
        let message = NotificationsSinkMessage::Notification { message: notification.into() };
        match self.lock.start_send(message) {
            Ok(()) => Ok(()),
            Err(_) => Err(()),
        }
    }
}
/// Error with which the handler closes the connection.
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum NotifsHandlerError {
    /// A `ForceClose` was received from a notifications sink (its sync channel overflowed),
    /// or a legacy substream reported that it was clogged.
    SyncNotificationsClogged,
}
impl NotifsHandlerProto {
    /// Builds a handler prototype.
    ///
    /// `legacy_protocol` is offered on inbound substreams next to the notifications protocols.
    /// Each entry of `list` is `(protocol name, handshake message, maximum notification size)`.
    /// A `NotificationsIn` inbound upgrade is created for every entry.
    pub fn new(
        legacy_protocol: RegisteredProtocol,
        list: impl Into<Vec<(Cow<'static, str>, Arc<RwLock<Vec<u8>>>, u64)>>,
    ) -> Self {
        let entries: Vec<(Cow<'static, str>, Arc<RwLock<Vec<u8>>>, u64)> = list.into();

        let mut protocols = Vec::with_capacity(entries.len());
        for (name, handshake, max_notification_size) in entries {
            let in_upgrade = NotificationsIn::new(name.clone(), max_notification_size);
            protocols.push((name, in_upgrade, handshake, max_notification_size));
        }

        Self { protocols, legacy_protocol }
    }
}
impl ProtocolsHandler for NotifsHandler {
    type InEvent = NotifsHandlerIn;
    type OutEvent = NotifsHandlerOut;
    type Error = NotifsHandlerError;
    type InboundProtocol = SelectUpgrade<UpgradeCollec<NotificationsIn>, RegisteredProtocol>;
    type OutboundProtocol = NotificationsOut;
    /// Index in `self.protocols` of the protocol an outbound substream is requested for.
    type OutboundOpenInfo = usize;
    type InboundOpenInfo = ();

    /// Offers every notifications protocol, followed by the legacy protocol.
    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, ()> {
        let protocols = self.protocols.iter()
            .map(|p| p.in_upgrade.clone())
            .collect::<UpgradeCollec<_>>();

        let with_legacy = SelectUpgrade::new(protocols, self.legacy_protocol.clone());
        SubstreamProtocol::new(with_legacy, ())
    }

    /// Handles a newly negotiated inbound substream. It is either a notifications substream
    /// (`First`, tagged with its protocol index) or a legacy substream (`Second`).
    fn inject_fully_negotiated_inbound(
        &mut self,
        out: <Self::InboundProtocol as InboundUpgrade<NegotiatedSubstream>>::Output,
        (): ()
    ) {
        match out {
            EitherOutput::First(((_remote_handshake, mut new_substream), protocol_index)) => {
                let protocol_info = &mut self.protocols[protocol_index];
                match protocol_info.state {
                    State::Closed { pending_opening } => {
                        // Let the outer layer decide; the substream is kept until it answers
                        // with `Open` or `Close`.
                        self.events_queue.push_back(ProtocolsHandlerEvent::Custom(
                            NotifsHandlerOut::OpenDesiredByRemote {
                                protocol_index,
                            }
                        ));

                        protocol_info.state = State::OpenDesiredByRemote {
                            in_substream: new_substream,
                            pending_opening,
                        };
                    },
                    State::OpenDesiredByRemote { .. } => {
                        // An inbound substream is already pending: drop the new one.
                        return;
                    },
                    State::Opening { ref mut in_substream, .. } |
                    State::Open { ref mut in_substream, .. } => {
                        if in_substream.is_some() {
                            // An inbound substream already exists: drop the new one.
                            return;
                        }

                        // Clone the handshake on its own statement so the read lock is
                        // released right away.
                        let handshake_message = protocol_info.handshake.read().clone();
                        new_substream.send_handshake(handshake_message);
                        *in_substream = Some(new_substream);
                    },
                };
            }

            EitherOutput::Second((substream, _handshake)) => {
                // Legacy substreams do not produce `OpenDesiredByRemote`; they are kept (up
                // to a small limit) and read while protocol 0 is open.
                // NOTE(review): `<= 4` admits up to five substreams while the `SmallVec`
                // inline capacity is four; confirm the intended limit.
                if self.legacy_substreams.len() <= 4 {
                    self.legacy_substreams.push(substream);
                }
            },
        }
    }

    /// Handles the successful negotiation of an outbound substream for `protocol_index`.
    /// If the protocol is still `Opening`, it becomes `Open` and `OpenResultOk` is queued
    /// with a fresh `NotificationsSink`. Otherwise the substream is dropped.
    fn inject_fully_negotiated_outbound(
        &mut self,
        (handshake, substream): <Self::OutboundProtocol as OutboundUpgrade<NegotiatedSubstream>>::Output,
        protocol_index: Self::OutboundOpenInfo
    ) {
        match self.protocols[protocol_index].state {
            State::Closed { ref mut pending_opening } |
            State::OpenDesiredByRemote { ref mut pending_opening, .. } => {
                // The `Open` that caused this request was cancelled by a `Close`; the
                // request is now finished and the substream is dropped.
                debug_assert!(*pending_opening);
                *pending_opening = false;
            }
            State::Open { .. } => {
                error!(target: "sub-libp2p", "☎️ State mismatch in notifications handler");
                debug_assert!(false);
            }
            State::Opening { ref mut in_substream } => {
                let (async_tx, async_rx) = mpsc::channel(ASYNC_NOTIFICATIONS_BUFFER_SIZE);
                let (sync_tx, sync_rx) = mpsc::channel(SYNC_NOTIFICATIONS_BUFFER_SIZE);
                let notifications_sink = NotificationsSink {
                    inner: Arc::new(NotificationsSinkInner {
                        peer_id: self.peer_id.clone(),
                        async_channel: FuturesMutex::new(async_tx),
                        sync_channel: Mutex::new(sync_tx),
                    }),
                };

                self.protocols[protocol_index].state = State::Open {
                    notifications_sink_rx: stream::select(async_rx.fuse(), sync_rx.fuse()),
                    out_substream: Some(substream),
                    in_substream: in_substream.take(),
                };

                self.events_queue.push_back(ProtocolsHandlerEvent::Custom(
                    NotifsHandlerOut::OpenResultOk {
                        protocol_index,
                        endpoint: self.endpoint.clone(),
                        received_handshake: handshake,
                        notifications_sink
                    }
                ));
            }
        }
    }

    /// Applies an `Open` or `Close` command from the outer layer.
    fn inject_event(&mut self, message: NotifsHandlerIn) {
        match message {
            NotifsHandlerIn::Open { protocol_index } => {
                let protocol_info = &mut self.protocols[protocol_index];
                match &mut protocol_info.state {
                    State::Closed { pending_opening } => {
                        // An earlier request still in flight is reused instead of
                        // emitting a new one.
                        if !*pending_opening {
                            let proto = NotificationsOut::new(
                                protocol_info.name.clone(),
                                protocol_info.handshake.read().clone(),
                                protocol_info.max_notification_size
                            );

                            self.events_queue.push_back(ProtocolsHandlerEvent::OutboundSubstreamRequest {
                                protocol: SubstreamProtocol::new(proto, protocol_index)
                                    .with_timeout(OPEN_TIMEOUT),
                            });
                        }

                        protocol_info.state = State::Opening {
                            in_substream: None,
                        };
                    },
                    State::OpenDesiredByRemote { pending_opening, in_substream } => {
                        let handshake_message = protocol_info.handshake.read().clone();

                        if !*pending_opening {
                            let proto = NotificationsOut::new(
                                protocol_info.name.clone(),
                                handshake_message.clone(),
                                protocol_info.max_notification_size,
                            );

                            self.events_queue.push_back(ProtocolsHandlerEvent::OutboundSubstreamRequest {
                                protocol: SubstreamProtocol::new(proto, protocol_index)
                                    .with_timeout(OPEN_TIMEOUT),
                            });
                        }

                        // Accept the remote's pending inbound substream.
                        in_substream.send_handshake(handshake_message);

                        // Moving the substream out is done in two steps because of borrowing.
                        let in_substream = match
                            mem::replace(&mut protocol_info.state, State::Opening { in_substream: None })
                        {
                            State::OpenDesiredByRemote { in_substream, .. } => in_substream,
                            _ => unreachable!()
                        };
                        protocol_info.state = State::Opening {
                            in_substream: Some(in_substream),
                        };
                    },
                    State::Opening { .. } |
                    State::Open { .. } => {
                        // Sending `Open` while already opening or open is forbidden.
                        error!(target: "sub-libp2p", "opening already-opened handler");
                        debug_assert!(false);
                    },
                }
            },

            NotifsHandlerIn::Close { protocol_index } => {
                // Any close request also starts shutting down all legacy substreams.
                for mut substream in self.legacy_substreams.drain(..) {
                    substream.shutdown();
                    self.legacy_shutdown.push(substream);
                }

                match self.protocols[protocol_index].state {
                    State::Open { .. } => {
                        self.protocols[protocol_index].state = State::Closed {
                            pending_opening: false,
                        };
                    },
                    State::Opening { .. } => {
                        // The outbound request is still in flight; remember it so its
                        // outcome is absorbed later.
                        self.protocols[protocol_index].state = State::Closed {
                            pending_opening: true,
                        };

                        self.events_queue.push_back(ProtocolsHandlerEvent::Custom(
                            NotifsHandlerOut::OpenResultErr {
                                protocol_index,
                            }
                        ));
                    },
                    State::OpenDesiredByRemote { pending_opening, .. } => {
                        self.protocols[protocol_index].state = State::Closed {
                            pending_opening,
                        };
                    }
                    State::Closed { .. } => {},
                }

                self.events_queue.push_back(
                    ProtocolsHandlerEvent::Custom(NotifsHandlerOut::CloseResult {
                        protocol_index,
                    })
                );
            },
        }
    }

    /// Handles the failure of an outbound substream request for protocol `num`.
    fn inject_dial_upgrade_error(
        &mut self,
        num: usize,
        _: ProtocolsHandlerUpgrErr<NotificationsHandshakeError>
    ) {
        match self.protocols[num].state {
            State::Closed { ref mut pending_opening } |
            State::OpenDesiredByRemote { ref mut pending_opening, .. } => {
                debug_assert!(*pending_opening);
                *pending_opening = false;
            }

            State::Opening { .. } => {
                self.protocols[num].state = State::Closed {
                    pending_opening: false,
                };

                self.events_queue.push_back(ProtocolsHandlerEvent::Custom(
                    NotifsHandlerOut::OpenResultErr {
                        protocol_index: num,
                    }
                ));
            }

            // No outbound request can be in flight while `Open`.
            State::Open { .. } => debug_assert!(false),
        }
    }

    /// Keeps the connection alive while a legacy substream is held or any protocol is not
    /// `Closed`. Otherwise it is kept only until `INITIAL_KEEPALIVE_TIME` after the handler
    /// was built.
    fn connection_keep_alive(&self) -> KeepAlive {
        if !self.legacy_substreams.is_empty() {
            return KeepAlive::Yes;
        }

        if self.protocols.iter().any(|p| !matches!(p.state, State::Closed { .. })) {
            return KeepAlive::Yes;
        }

        KeepAlive::Until(self.when_connection_open + INITIAL_KEEPALIVE_TIME)
    }

    /// Drives the handler. Queued events are returned first. Then, for each protocol, the
    /// inbound substream is polled, the outbound substream is flushed, and pending
    /// notifications are written. Finally, legacy substreams are read and shut down.
    fn poll(
        &mut self,
        cx: &mut Context,
    ) -> Poll<
        ProtocolsHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::OutEvent, Self::Error>
    > {
        if let Some(ev) = self.events_queue.pop_front() {
            return Poll::Ready(ev);
        }

        for protocol_index in 0..self.protocols.len() {
            // Inbound substream. A closed inbound substream is tolerated everywhere except in
            // `OpenDesiredByRemote`, where it sends the protocol back to `Closed`.
            match &mut self.protocols[protocol_index].state {
                State::Closed { .. } |
                State::Open { in_substream: None, .. } |
                State::Opening { in_substream: None } => {}

                State::Open { in_substream: in_substream @ Some(_), .. } => {
                    match Stream::poll_next(Pin::new(in_substream.as_mut().unwrap()), cx) {
                        Poll::Pending => {},
                        Poll::Ready(Some(Ok(message))) => {
                            let event = NotifsHandlerOut::Notification {
                                protocol_index,
                                message,
                            };
                            return Poll::Ready(ProtocolsHandlerEvent::Custom(event))
                        },
                        Poll::Ready(None) | Poll::Ready(Some(Err(_))) =>
                            *in_substream = None,
                    }
                }

                State::OpenDesiredByRemote { in_substream, pending_opening } => {
                    match NotificationsInSubstream::poll_process(Pin::new(in_substream), cx) {
                        Poll::Pending => {},
                        Poll::Ready(Ok(void)) => match void {},
                        Poll::Ready(Err(_)) => {
                            self.protocols[protocol_index].state = State::Closed {
                                pending_opening: *pending_opening,
                            };
                            return Poll::Ready(ProtocolsHandlerEvent::Custom(
                                NotifsHandlerOut::CloseDesired { protocol_index }
                            ))
                        },
                    }
                }

                State::Opening { in_substream: in_substream @ Some(_), .. } => {
                    match NotificationsInSubstream::poll_process(Pin::new(in_substream.as_mut().unwrap()), cx) {
                        Poll::Pending => {},
                        Poll::Ready(Ok(void)) => match void {},
                        Poll::Ready(Err(_)) => *in_substream = None,
                    }
                }
            }

            // Flush the outbound substream. A flush error discards it and reports
            // `CloseDesired`.
            match &mut self.protocols[protocol_index].state {
                State::Open { out_substream: out_substream @ Some(_), .. } => {
                    match Sink::poll_flush(Pin::new(out_substream.as_mut().unwrap()), cx) {
                        Poll::Pending | Poll::Ready(Ok(())) => {},
                        Poll::Ready(Err(_)) => {
                            *out_substream = None;
                            let event = NotifsHandlerOut::CloseDesired { protocol_index };
                            return Poll::Ready(ProtocolsHandlerEvent::Custom(event));
                        }
                    };
                }

                State::Closed { .. } |
                State::Opening { .. } |
                State::Open { out_substream: None, .. } |
                State::OpenDesiredByRemote { .. } => {}
            }

            // Move pending notifications from the sink channels into the outbound substream.
            if let State::Open { notifications_sink_rx, out_substream: Some(out_substream), .. }
                = &mut self.protocols[protocol_index].state
            {
                loop {
                    // Only take a message once the substream is ready to accept one. Errors
                    // from `poll_ready`/`start_send` are ignored here; presumably a broken
                    // substream is detected by `poll_flush` above on a later call.
                    match out_substream.poll_ready_unpin(cx) {
                        Poll::Ready(_) => {},
                        Poll::Pending => break
                    }

                    let message = match notifications_sink_rx.poll_next_unpin(cx) {
                        Poll::Ready(Some(msg)) => msg,
                        Poll::Ready(None) | Poll::Pending => break,
                    };

                    match message {
                        NotificationsSinkMessage::Notification { message } => {
                            let _ = out_substream.start_send_unpin(message);

                            // `start_send` only queues the message; flushing happens earlier
                            // in this method, so schedule another `poll` to flush it.
                            cx.waker().wake_by_ref();
                        }
                        NotificationsSinkMessage::ForceClose => {
                            return Poll::Ready(
                                ProtocolsHandlerEvent::Close(NotifsHandlerError::SyncNotificationsClogged)
                            );
                        }
                    }
                }
            }
        }

        // Legacy substreams are only read while protocol 0 is `Open`. `first()` is used
        // rather than indexing so that a handler built without any protocol does not panic.
        if matches!(self.protocols.first(), Some(Protocol { state: State::Open { .. }, .. })) {
            for n in (0..self.legacy_substreams.len()).rev() {
                let mut substream = self.legacy_substreams.swap_remove(n);
                let poll_outcome = Pin::new(&mut substream).poll_next(cx);
                match poll_outcome {
                    Poll::Pending => self.legacy_substreams.push(substream),
                    Poll::Ready(Some(Ok(RegisteredProtocolEvent::Message(message)))) => {
                        self.legacy_substreams.push(substream);
                        return Poll::Ready(ProtocolsHandlerEvent::Custom(
                            NotifsHandlerOut::CustomMessage { message }
                        ))
                    },
                    Poll::Ready(Some(Ok(RegisteredProtocolEvent::Clogged))) => {
                        return Poll::Ready(ProtocolsHandlerEvent::Close(
                            NotifsHandlerError::SyncNotificationsClogged
                        ))
                    }
                    Poll::Ready(None) | Poll::Ready(Some(Err(_))) => {
                        // A cleanly finished substream is still shut down gracefully; an
                        // errored one is simply dropped.
                        if matches!(poll_outcome, Poll::Ready(None)) {
                            self.legacy_shutdown.push(substream);
                        }

                        // The remote closed its legacy substream: drop the outbound
                        // substream of protocol 0 and report `CloseDesired` once. If it is
                        // already `None`, `CloseDesired` was already reported when it was
                        // discarded, so nothing more is emitted.
                        if let State::Open { out_substream, .. } = &mut self.protocols[0].state {
                            if out_substream.take().is_some() {
                                return Poll::Ready(ProtocolsHandlerEvent::Custom(
                                    NotifsHandlerOut::CloseDesired {
                                        protocol_index: 0,
                                    }
                                ))
                            }
                        }
                    }
                }
            }
        }

        shutdown_list(&mut self.legacy_shutdown, cx);

        Poll::Pending
    }
}
/// Drives every substream in `list` towards its end and removes the finished ones.
///
/// Entries are visited from the last to the first. Each substream is polled repeatedly and
/// whatever it yields is discarded. If it returns `Pending`, it is put back in the list. If
/// it reports an error or the end of the stream, it is dropped from the list.
fn shutdown_list
    (list: &mut SmallVec<impl smallvec::Array<Item = RegisteredProtocolSubstream<NegotiatedSubstream>>>,
    cx: &mut Context)
{
    let mut index = list.len();
    while index != 0 {
        index -= 1;
        let mut substream = list.swap_remove(index);

        let finished = loop {
            match substream.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(_))) => continue,
                Poll::Ready(Some(Err(_))) | Poll::Ready(None) => break true,
                Poll::Pending => break false,
            }
        };

        if !finished {
            list.push(substream);
        }
    }
}