use crate::config::ProtocolId;
use bytes::BytesMut;
use futures::prelude::*;
use asynchronous_codec::Framed;
use libp2p::core::{UpgradeInfo, InboundUpgrade, OutboundUpgrade, upgrade::ProtocolName};
use parking_lot::RwLock;
use std::{collections::VecDeque, io, pin::Pin, sync::Arc, vec::IntoIter as VecIntoIter};
use std::task::{Context, Poll};
use unsigned_varint::codec::UviBytes;
/// Connection upgrade for a single protocol.
///
/// The same value is used for both the inbound and the outbound side of a substream: it
/// advertises one protocol name per supported version and, once a substream is negotiated,
/// exchanges handshake messages with the remote.
pub struct RegisteredProtocol {
/// Identifier of the protocol, as passed to `new`.
id: ProtocolId,
/// Base name of the protocol as advertised on the network: `/substrate/<id>/`.
/// The decimal version number is appended to it to obtain a full protocol name.
base_name: Vec<u8>,
/// Versions of the protocol that are supported, sorted in descending order by `new`.
supported_versions: Vec<u8>,
/// Handshake message sent to the remote right after a substream is opened.
/// Held behind a shared lock; it is read each time an upgrade runs.
handshake_message: Arc<RwLock<Vec<u8>>>,
}
impl RegisteredProtocol {
pub fn new(protocol: impl Into<ProtocolId>, versions: &[u8], handshake_message: Arc<RwLock<Vec<u8>>>)
-> Self {
let protocol = protocol.into();
let mut base_name = b"/substrate/".to_vec();
base_name.extend_from_slice(protocol.as_ref().as_bytes());
base_name.extend_from_slice(b"/");
RegisteredProtocol {
base_name,
id: protocol,
supported_versions: {
let mut tmp = versions.to_vec();
tmp.sort_by(|a, b| b.cmp(&a));
tmp
},
handshake_message,
}
}
pub fn handshake_message(&self) -> &Arc<RwLock<Vec<u8>>> {
&self.handshake_message
}
}
impl Clone for RegisteredProtocol {
fn clone(&self) -> Self {
RegisteredProtocol {
id: self.id.clone(),
base_name: self.base_name.clone(),
supported_versions: self.supported_versions.clone(),
handshake_message: self.handshake_message.clone(),
}
}
}
/// Substream obtained from a successful `RegisteredProtocol` upgrade.
///
/// Wraps the raw substream in a `UviBytes`-framed codec and is driven through its `Stream`
/// implementation, which both writes queued packets and yields received ones.
pub struct RegisteredProtocolSubstream<TSubstream> {
/// Set by `shutdown`. Once `true`, polling closes the sink and then ends the stream instead
/// of reading further messages.
is_closing: bool,
/// Packets waiting to be handed to the sink; drained at the start of every poll.
send_queue: VecDeque<BytesMut>,
/// `true` after a packet has been passed to the sink and until a flush has completed.
requires_poll_flush: bool,
/// The framed substream, fused so that it can still be polled after it has ended.
inner: stream::Fuse<Framed<TSubstream, UviBytes<BytesMut>>>,
/// `true` once a `Clogged` event has been emitted; reset as soon as the send queue drops
/// back below the clogging threshold, so the event is reported once per congestion episode.
clogged_fuse: bool,
}
impl<TSubstream> RegisteredProtocolSubstream<TSubstream> {
    /// Marks the substream as closing and discards every packet still waiting in the send
    /// queue.
    ///
    /// Nothing is written here: the sink is actually closed by subsequent polls of the stream.
    pub fn shutdown(&mut self) {
        let Self {
            is_closing,
            send_queue,
            ..
        } = self;

        send_queue.clear();
        *is_closing = true;
    }
}
/// Event produced by polling a `RegisteredProtocolSubstream`.
#[derive(Debug, Clone)]
pub enum RegisteredProtocolEvent {
/// A frame has been received from the remote; holds the bytes yielded by the framed codec.
Message(BytesMut),
/// The local send queue holds 1536 packets or more. Emitted once, and again only after the
/// queue has dropped below that threshold in between.
Clogged,
}
impl<TSubstream> Stream for RegisteredProtocolSubstream<TSubstream>
where
    TSubstream: AsyncRead + AsyncWrite + Unpin,
{
    type Item = Result<RegisteredProtocolEvent, io::Error>;

    /// Drives the substream: writes queued packets, handles shutdown, reports congestion,
    /// flushes, and finally yields the next message received from the remote.
    ///
    /// Any I/O error from the underlying sink or stream is yielded as `Some(Err(_))`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Queue length from which the remote is reported as clogged.
        const CLOGGED_THRESHOLD: usize = 1536;

        let this = &mut *self;

        // Hand as many queued packets to the sink as it is willing to accept right now.
        while !this.send_queue.is_empty() {
            if Pin::new(&mut this.inner).poll_ready(cx)?.is_pending() {
                break;
            }

            if let Some(packet) = this.send_queue.pop_front() {
                Pin::new(&mut this.inner).start_send(packet)?;
                this.requires_poll_flush = true;
            }
        }

        // When shutting down, the stream ends as soon as the sink has been closed.
        if this.is_closing {
            return Pin::new(&mut this.inner).poll_close(cx)?.map(|()| None);
        }

        // Report congestion exactly once per episode. The fuse is what lets the code below
        // still run while the queue stays full; without it every poll would return here.
        if this.send_queue.len() < CLOGGED_THRESHOLD {
            this.clogged_fuse = false;
        } else if !this.clogged_fuse {
            this.clogged_fuse = true;
            return Poll::Ready(Some(Ok(RegisteredProtocolEvent::Clogged)));
        }

        // Flush whatever has been handed to the sink.
        if this.requires_poll_flush && Pin::new(&mut this.inner).poll_flush(cx)?.is_ready() {
            this.requires_poll_flush = false;
        }

        // Read incoming packets. `inner` is fused, so polling it after it ended is fine.
        match Pin::new(&mut this.inner).poll_next(cx)? {
            Poll::Ready(Some(data)) => {
                Poll::Ready(Some(Ok(RegisteredProtocolEvent::Message(data))))
            }
            // The remote is done; only finish once all outgoing data has been written out.
            Poll::Ready(None) if !this.requires_poll_flush && this.send_queue.is_empty() => {
                Poll::Ready(None)
            }
            Poll::Ready(None) | Poll::Pending => Poll::Pending,
        }
    }
}
impl UpgradeInfo for RegisteredProtocol {
    type Info = RegisteredProtocolName;
    type InfoIter = VecIntoIter<Self::Info>;

    /// Lists one protocol name per supported version, in the order the versions are stored.
    ///
    /// Each name is the base name followed by the version number in decimal ASCII.
    #[inline]
    fn protocol_info(&self) -> Self::InfoIter {
        let mut names = Vec::with_capacity(self.supported_versions.len());

        for &version in &self.supported_versions {
            let suffix = version.to_string();
            let name = [&self.base_name[..], suffix.as_bytes()].concat();
            names.push(RegisteredProtocolName { name, version });
        }

        names.into_iter()
    }
}
/// One advertised protocol name, tied to a specific protocol version.
#[derive(Debug, Clone)]
pub struct RegisteredProtocolName {
/// Full protocol name: the base name followed by the version number in decimal ASCII.
name: Vec<u8>,
/// Version number that `name` was built from.
version: u8,
}
impl ProtocolName for RegisteredProtocolName {
    /// Returns the full, versioned protocol name as raw bytes.
    fn protocol_name(&self) -> &[u8] {
        self.name.as_slice()
    }
}
impl<TSubstream> InboundUpgrade<TSubstream> for RegisteredProtocol
where
    TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Output = (RegisteredProtocolSubstream<TSubstream>, Vec<u8>);
    type Future = Pin<Box<dyn Future<Output = Result<Self::Output, io::Error>> + Send>>;
    type Error = io::Error;

    /// Upgrades an inbound substream.
    ///
    /// Sends the local handshake message, then waits for the handshake of the remote.
    /// On success, yields the ready-to-use substream together with the bytes of the remote's
    /// handshake.
    ///
    /// # Errors
    ///
    /// Fails with any I/O error raised while sending or receiving, and with an
    /// `UnexpectedEof` error if the substream ends before a handshake has been received.
    fn upgrade_inbound(
        self,
        socket: TSubstream,
        _: Self::Info,
    ) -> Self::Future {
        Box::pin(async move {
            let mut framed = {
                let mut codec = UviBytes::default();
                // Cap the length of a single frame at 16 MiB.
                codec.set_max_len(16 * 1024 * 1024);
                Framed::new(socket, codec)
            };

            // The read guard is a temporary of this statement, so the lock is released
            // before the first `.await`.
            let handshake = BytesMut::from(&self.handshake_message.read()[..]);
            framed.send(handshake).await?;

            // Same error as on the outbound side, so that a missing handshake is reported
            // identically regardless of who opened the substream.
            let received_handshake = framed.next().await
                .ok_or_else(|| {
                    io::Error::new(io::ErrorKind::UnexpectedEof, "Failed to receive handshake")
                })??;

            Ok((RegisteredProtocolSubstream {
                is_closing: false,
                send_queue: VecDeque::new(),
                requires_poll_flush: false,
                inner: framed.fuse(),
                clogged_fuse: false,
            }, received_handshake.to_vec()))
        })
    }
}
impl<TSubstream> OutboundUpgrade<TSubstream> for RegisteredProtocol
where
    TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Output = <Self as InboundUpgrade<TSubstream>>::Output;
    type Future = <Self as InboundUpgrade<TSubstream>>::Future;
    type Error = <Self as InboundUpgrade<TSubstream>>::Error;

    /// Upgrades an outbound substream.
    ///
    /// Sends the local handshake message, then waits for the handshake of the remote.
    /// On success, yields the ready-to-use substream together with the bytes of the remote's
    /// handshake.
    ///
    /// # Errors
    ///
    /// Fails with any I/O error raised while sending or receiving, and with an
    /// `UnexpectedEof` error if the substream ends before a handshake has been received.
    fn upgrade_outbound(self, socket: TSubstream, _: Self::Info) -> Self::Future {
        Box::pin(async move {
            // Cap the length of a single frame at 16 MiB.
            let mut codec = UviBytes::default();
            codec.set_max_len(16 * 1024 * 1024);
            let mut framed = Framed::new(socket, codec);

            // The read guard is a temporary of this statement, so the lock is released
            // before the first `.await`.
            let local_handshake = BytesMut::from(&self.handshake_message.read()[..]);
            framed.send(local_handshake).await?;

            let remote_handshake = match framed.next().await {
                Some(frame) => frame?,
                None => {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "Failed to receive handshake",
                    ))
                }
            };

            let substream = RegisteredProtocolSubstream {
                is_closing: false,
                send_queue: VecDeque::new(),
                requires_poll_flush: false,
                inner: framed.fuse(),
                clogged_fuse: false,
            };

            Ok((substream, remote_handshake.to_vec()))
        })
    }
}