use crate::muxing::{IStreamMuxer, StreamMuxer, StreamMuxerEx};
use crate::secure_io::SecureInfo;
use crate::transport::{ConnectionInfo, IListener, ITransport, ListenerEvent, TransportListener};
use crate::upgrade::multistream::Multistream;
use crate::upgrade::Upgrader;
use crate::{transport::TransportError, Multiaddr, Transport};
use async_trait::async_trait;
use futures::{future::Either, stream::FuturesUnordered, AsyncRead, AsyncWrite, FutureExt, StreamExt};
use std::{
future::Future,
num::NonZeroUsize,
pin::Pin,
task::{Context, Poll},
};
/// Transport wrapper that upgrades every connection produced by `inner`:
/// first with the security upgrader, then with the stream-muxer upgrader.
///
/// Both upgraders are held inside a [`Multistream`] wrapper.
#[derive(Debug, Clone)]
pub struct TransportUpgrade<InnerTrans, TMux, TSec> {
    /// Underlying transport that yields raw connections.
    inner: InnerTrans,
    /// Stream-muxer upgrader, applied after security.
    mux: Multistream<TMux>,
    /// Security upgrader, applied first.
    sec: Multistream<TSec>,
}

impl<InnerTrans, TMux, TSec> TransportUpgrade<InnerTrans, TMux, TSec>
where
    InnerTrans: Transport,
    InnerTrans::Output: ConnectionInfo + AsyncRead + AsyncWrite + Unpin,
    TSec: Upgrader<InnerTrans::Output>,
    TSec::Output: SecureInfo + AsyncRead + AsyncWrite + Unpin,
    TMux: Upgrader<TSec::Output>,
    TMux::Output: StreamMuxer,
{
    /// Wraps `inner` so its connections are secured with `sec` and then
    /// multiplexed with `mux`.
    pub fn new(inner: InnerTrans, mux: TMux, sec: TSec) -> Self {
        let sec = Multistream::new(sec);
        let mux = Multistream::new(mux);
        Self { inner, mux, sec }
    }
}
#[async_trait]
impl<InnerTrans, TMux, TSec> Transport for TransportUpgrade<InnerTrans, TMux, TSec>
where
    InnerTrans: Transport + Clone + 'static,
    InnerTrans::Output: ConnectionInfo + AsyncRead + AsyncWrite + Unpin + 'static,
    TSec: Upgrader<InnerTrans::Output> + 'static,
    TSec::Output: SecureInfo + AsyncRead + AsyncWrite + Unpin,
    TMux: Upgrader<TSec::Output> + 'static,
    TMux::Output: StreamMuxerEx + 'static,
{
    type Output = IStreamMuxer;

    /// Starts listening on `addr` through the inner transport.
    ///
    /// The returned listener is a [`ListenerUpgrade`], so connections it hands
    /// out have already gone through the security and muxer upgrades.
    fn listen_on(&mut self, addr: Multiaddr) -> Result<IListener<Self::Output>, TransportError> {
        let raw_listener = self.inner.listen_on(addr)?;
        Ok(Box::new(ListenerUpgrade::new(
            raw_listener,
            self.mux.clone(),
            self.sec.clone(),
        )))
    }

    /// Dials `addr` with the inner transport, then negotiates security and a
    /// stream muxer on the new connection, in that order.
    ///
    /// # Errors
    ///
    /// Returns the first error from dialing or from either upgrade step.
    async fn dial(&mut self, addr: Multiaddr) -> Result<Self::Output, TransportError> {
        let raw = self.inner.dial(addr).await?;

        let security = self.sec.clone();
        log::debug!("upgrading outbound security towards {}...", raw.remote_multiaddr());
        let secured = security.select_outbound(raw).await?;

        let muxing = self.mux.clone();
        log::debug!("security applied, upgrading outbound stream muxer...");
        let muxer = muxing.select_outbound(secured).await?;

        Ok(Box::new(muxer))
    }

    /// Returns a boxed copy of this transport.
    fn box_clone(&self) -> ITransport<Self::Output> {
        let copy = self.clone();
        Box::new(copy)
    }

    /// Reports the protocols of the inner transport; the upgrades add none.
    fn protocols(&self) -> Vec<u32> {
        self.inner.protocols()
    }
}
/// Boxed, `Send` future resolving to an upgraded connection, or to the error
/// raised while upgrading it.
type UpgradeFuture<Output> = Pin<Box<dyn Future<Output = Result<Output, TransportError>> + Send>>;
/// Listener adapter that upgrades every connection accepted by an inner
/// listener: first with the security upgrader `TSec`, then with the
/// stream-muxer upgrader `TMux`.
///
/// Pending upgrades are kept in `futures` and driven from the listener's
/// `accept`; how many may be pending at once is bounded by `limit`.
pub struct ListenerUpgrade<TOutput, TMux, TSec>
where
TOutput: ConnectionInfo + AsyncRead + AsyncWrite + Unpin + 'static,
TSec: Upgrader<TOutput> + Send + Clone + 'static,
TSec::Output: SecureInfo + AsyncRead + AsyncWrite + Unpin,
TMux: Upgrader<TSec::Output> + 'static,
TMux::Output: StreamMuxerEx + 'static,
{
/// Listener yielding raw, not-yet-upgraded connections.
inner: IListener<TOutput>,
/// Stream-muxer upgrader; a clone is moved into each inbound upgrade.
mux: Multistream<TMux>,
/// Security upgrader; a clone is moved into each inbound upgrade.
sec: Multistream<TSec>,
/// Inbound upgrades that have been started but not yet completed.
futures: FuturesUnordered<UpgradeFuture<TMux::Output>>,
/// Maximum number of upgrades in flight at once; `None` means unbounded.
limit: Option<NonZeroUsize>,
}
impl<TOutput, TMux, TSec> ListenerUpgrade<TOutput, TMux, TSec>
where
    TOutput: ConnectionInfo + AsyncRead + AsyncWrite + Unpin + 'static,
    TSec: Upgrader<TOutput> + Send + Clone + 'static,
    TSec::Output: SecureInfo + AsyncRead + AsyncWrite + Unpin,
    TMux: Upgrader<TSec::Output> + 'static,
    TMux::Output: StreamMuxerEx + 'static,
{
    /// Wraps `inner` so every connection it accepts is upgraded with `sec`
    /// and then `mux`.
    ///
    /// The number of concurrently running inbound upgrades starts capped at
    /// 10; use [`set_limit`](Self::set_limit) to change or remove the cap.
    pub(crate) fn new(inner: IListener<TOutput>, mux: Multistream<TMux>, sec: Multistream<TSec>) -> Self {
        let pending_upgrades = FuturesUnordered::new();
        let limit = NonZeroUsize::new(10);
        Self {
            inner,
            mux,
            sec,
            futures: pending_upgrades,
            limit,
        }
    }

    /// Returns the cap on concurrently running inbound upgrades, or `None`
    /// when there is no cap.
    pub fn limit(&self) -> Option<NonZeroUsize> {
        self.limit
    }

    /// Replaces the cap on concurrently running inbound upgrades; `None`
    /// removes it entirely.
    pub fn set_limit(&mut self, limit: Option<NonZeroUsize>) {
        self.limit = limit;
    }
}
#[async_trait]
impl<TOutput, TMux, TSec> TransportListener for ListenerUpgrade<TOutput, TMux, TSec>
where
    TOutput: ConnectionInfo + AsyncRead + AsyncWrite + Unpin + 'static,
    TSec: Upgrader<TOutput> + Send + Clone + 'static,
    TSec::Output: SecureInfo + AsyncRead + AsyncWrite + Unpin,
    TMux: Upgrader<TSec::Output> + 'static,
    TMux::Output: StreamMuxerEx + 'static,
{
    type Output = IStreamMuxer;

    /// Waits for the next listener event.
    ///
    /// Raw connections accepted by the inner listener are upgraded
    /// concurrently: security first, then the stream muxer. While `limit`
    /// upgrades are already in flight, the inner listener is not polled for
    /// new connections. Address events from the inner listener are forwarded
    /// unchanged. An inbound connection is reported as `Accepted` only once
    /// both upgrades have succeeded.
    ///
    /// # Errors
    ///
    /// Returns an error only when the inner listener itself fails. A failed
    /// upgrade of one inbound connection is logged and that connection is
    /// dropped, so a single misbehaving remote peer cannot make the whole
    /// listener report a failure.
    async fn accept(&mut self) -> Result<ListenerEvent<Self::Output>, TransportError> {
        loop {
            let has_capacity = self.limit.map_or(true, |limit| self.futures.len() < limit.get());
            // Back-pressure: stop pulling raw connections while the upgrade queue is full.
            let mut next_incoming = if has_capacity {
                self.inner.accept()
            } else {
                futures::future::pending().boxed()
            };
            let mut next_upgraded = self.futures.next();
            // NOTE(review): when an upgrade finishes first, `next_incoming` (the inner
            // `accept()` future) is dropped; this assumes the inner listener's accept
            // is safe to cancel — confirm for each transport.
            let next = futures::future::poll_fn(move |cx: &mut Context| {
                if let Poll::Ready(event) = next_incoming.poll_unpin(cx) {
                    return Poll::Ready(Either::Left(event));
                }
                match next_upgraded.poll_unpin(cx) {
                    Poll::Ready(Some(upgraded)) => Poll::Ready(Either::Right(upgraded)),
                    // `Ready(None)` only means no upgrade is in flight right now;
                    // keep waiting on the inner listener instead of finishing.
                    Poll::Ready(None) | Poll::Pending => Poll::Pending,
                }
            });

            match next.await {
                Either::Left(event) => match event? {
                    ListenerEvent::AddressAdded(a) => return Ok(ListenerEvent::AddressAdded(a)),
                    ListenerEvent::AddressDeleted(a) => return Ok(ListenerEvent::AddressDeleted(a)),
                    ListenerEvent::Accepted(socket) => {
                        let sec = self.sec.clone();
                        let mux = self.mux.clone();
                        self.futures.push(
                            async move {
                                log::trace!("accept a new connection from {}, upgrading...", socket.remote_multiaddr());
                                let sec_socket = sec.select_inbound(socket).await?;
                                mux.select_inbound(sec_socket).await
                            }
                            .boxed(),
                        );
                    }
                },
                Either::Right(Ok(muxer)) => return Ok(ListenerEvent::Accepted(Box::new(muxer))),
                Either::Right(Err(e)) => {
                    // A per-connection handshake failure must not look like a listener failure.
                    log::info!("failed to upgrade inbound connection: {:?}", e);
                }
            }
        }
    }

    /// Returns the address reported by the inner listener, if any.
    fn multi_addr(&self) -> Option<&Multiaddr> {
        self.inner.multi_addr()
    }
}
/// Boxed listener whose `Output` is [`IStreamMuxer`], i.e. it hands out
/// already-upgraded connections.
pub type IListenerEx = IListener<IStreamMuxer>;
/// Boxed transport whose `Output` is [`IStreamMuxer`].
pub type ITransportEx = ITransport<IStreamMuxer>;
/// Makes the boxed transport cloneable by delegating to its `box_clone`.
impl Clone for ITransportEx {
fn clone(&self) -> Self {
self.box_clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::memory::MemoryTransport;
    use crate::transport::protector::ProtectorTransport;
    use crate::upgrade::dummy::DummyUpgrader;
    use libp2p_pnet::*;

    /// Dials a listening `TransportUpgrade` over a PSK-protected in-memory
    /// transport and checks that both ends obtain a muxer from the dummy
    /// upgraders, on which stream operations then fail.
    #[test]
    fn test_dialer_and_listener() {
        // Adding one (saturating) keeps the random port non-zero — presumably
        // because port 0 has special meaning for the memory transport.
        let port = rand::random::<u64>().saturating_add(1);
        let listen_addr: Multiaddr = format!("/memory/{}", port).parse().unwrap();
        let dial_addr = listen_addr.clone();

        let psk: PreSharedKey =
            "/key/swarm/psk/1.0.0/\n/base16/\n6189c5cf0b87fb800c1a9feeda73c6ab5e998db48fb9e6a978575c770ceef683"
                .parse()
                .unwrap();
        let protected = ProtectorTransport::new(MemoryTransport::default(), PnetConfig::new(psk));

        let mut server = TransportUpgrade::new(protected.clone(), DummyUpgrader::new(), DummyUpgrader::new());
        let serve = async move {
            let mut listener = server.listen_on(listen_addr).unwrap();
            let event = listener.accept().await.unwrap();
            let mut muxer = match event {
                ListenerEvent::Accepted(muxer) => muxer,
                _ => panic!("unreachable"),
            };
            // Accepting a stream is expected to fail with the dummy muxer.
            muxer.accept_stream().await.unwrap_err();
        };

        let mut client = TransportUpgrade::new(protected, DummyUpgrader::new(), DummyUpgrader::new());
        let dial = async move {
            let mut muxer = client.dial(dial_addr).await.unwrap();
            // Opening a stream is expected to fail with the dummy muxer.
            assert!(muxer.open_stream().await.is_err());
        };

        futures::executor::block_on(futures::future::join(serve, dial));
    }
}