pub mod either;
mod map_in;
mod map_out;
pub mod multi;
mod one_shot;
mod pending;
mod select;
use core::slice;
use std::{
collections::{HashMap, HashSet},
error, fmt, io,
task::{Context, Poll},
time::Duration,
};
use libp2p_core::Multiaddr;
pub use map_in::MapInEvent;
pub use map_out::MapOutEvent;
pub use one_shot::{OneShotHandler, OneShotHandlerConfig};
pub use pending::PendingConnectionHandler;
pub use select::ConnectionHandlerSelect;
use smallvec::SmallVec;
pub use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper, UpgradeInfoSend};
use crate::{connection::AsStrHashEq, StreamProtocol};
/// Handler for a single connection.
///
/// It is parameterised over:
/// - the events it exchanges with the behaviour (`FromBehaviour` / `ToBehaviour`);
/// - the upgrades applied to inbound and outbound substreams;
/// - the user data carried alongside those upgrades.
///
/// NOTE(review): the loop that drives a handler (who calls `poll` and the `on_*` hooks, and
/// when) is not part of this file; the notes below describe only the visible signatures and
/// default bodies.
pub trait ConnectionHandler: Send + 'static {
    /// Event type received through [`ConnectionHandler::on_behaviour_event`].
    type FromBehaviour: fmt::Debug + Send + 'static;
    /// Event type emitted through [`ConnectionHandlerEvent::NotifyBehaviour`] and
    /// [`ConnectionHandler::poll_close`].
    type ToBehaviour: fmt::Debug + Send + 'static;
    /// Upgrade applied to inbound substreams (see [`ConnectionHandler::listen_protocol`]).
    type InboundProtocol: InboundUpgradeSend;
    /// Upgrade applied to outbound substreams requested via
    /// [`ConnectionHandlerEvent::OutboundSubstreamRequest`].
    type OutboundProtocol: OutboundUpgradeSend;
    /// User data attached to inbound upgrades; handed back in [`FullyNegotiatedInbound`] and
    /// [`ListenUpgradeError`].
    type InboundOpenInfo: Send + 'static;
    /// User data attached to outbound upgrades; handed back in [`FullyNegotiatedOutbound`]
    /// and [`DialUpgradeError`].
    type OutboundOpenInfo: Send + 'static;

    /// Returns the upgrade, its `InboundOpenInfo` and the timeout to use for inbound
    /// substreams, bundled as a [`SubstreamProtocol`].
    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;

    /// Returns whether this handler asks for the connection to be kept alive.
    ///
    /// The default implementation returns `false`.
    fn connection_keep_alive(&self) -> bool {
        false
    }

    /// Polls the handler for its next [`ConnectionHandlerEvent`].
    ///
    /// The event is one of:
    /// - an outbound substream request;
    /// - a report of the remote's protocol support;
    /// - an event for the behaviour.
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<
        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
    >;

    /// Yields remaining `ToBehaviour` events while the connection is closing.
    ///
    /// Presumably the caller polls this until it returns `Ready(None)` (TODO confirm against
    /// the caller). The default implementation returns `Poll::Ready(None)` immediately.
    fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
        Poll::Ready(None)
    }

    /// Wraps this handler in a [`MapInEvent`] so it accepts `TNewIn` events from the
    /// behaviour.
    ///
    /// `map` borrows a `FromBehaviour` out of each `TNewIn`, or returns `None`; how `None` is
    /// treated is decided by `MapInEvent`.
    fn map_in_event<TNewIn, TMap>(self, map: TMap) -> MapInEvent<Self, TNewIn, TMap>
    where
        Self: Sized,
        TMap: Fn(&TNewIn) -> Option<&Self::FromBehaviour>,
    {
        MapInEvent::new(self, map)
    }

    /// Wraps this handler in a [`MapOutEvent`] that passes every `ToBehaviour` event
    /// through `map`.
    fn map_out_event<TMap, TNewOut>(self, map: TMap) -> MapOutEvent<Self, TMap>
    where
        Self: Sized,
        TMap: FnMut(Self::ToBehaviour) -> TNewOut,
    {
        MapOutEvent::new(self, map)
    }

    /// Combines this handler with `other` into a [`ConnectionHandlerSelect`].
    fn select<TProto2>(self, other: TProto2) -> ConnectionHandlerSelect<Self, TProto2>
    where
        Self: Sized,
    {
        ConnectionHandlerSelect::new(self, other)
    }

    /// Handles an event sent to this handler by the behaviour.
    fn on_behaviour_event(&mut self, _event: Self::FromBehaviour);

    /// Handles a [`ConnectionEvent`].
    ///
    /// Such events cover negotiated substreams, upgrade failures, an address change, and
    /// local or remote protocol changes.
    fn on_connection_event(
        &mut self,
        event: ConnectionEvent<
            Self::InboundProtocol,
            Self::OutboundProtocol,
            Self::InboundOpenInfo,
            Self::OutboundOpenInfo,
        >,
    );
}
/// Connection-level event passed to [`ConnectionHandler::on_connection_event`].
///
/// `IP`/`OP` are the inbound/outbound upgrades. `IOI`/`OOI` are the user data attached to
/// inbound/outbound substreams; both default to `()`. Borrowed payloads live for `'a`.
#[non_exhaustive]
pub enum ConnectionEvent<'a, IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI = (), OOI = ()> {
    /// An inbound substream finished its upgrade; carries the upgrade output.
    FullyNegotiatedInbound(FullyNegotiatedInbound<IP, IOI>),
    /// An outbound substream finished its upgrade; carries the upgrade output.
    FullyNegotiatedOutbound(FullyNegotiatedOutbound<OP, OOI>),
    /// The connection's address changed; carries the new address.
    AddressChange(AddressChange<'a>),
    /// Upgrading an outbound substream failed.
    DialUpgradeError(DialUpgradeError<OOI, OP>),
    /// Upgrading an inbound substream failed.
    ListenUpgradeError(ListenUpgradeError<IOI, IP>),
    /// Protocols were added to or removed from the local side's set (per the variant name).
    LocalProtocolsChange(ProtocolsChange<'a>),
    /// Protocols were added to or removed from the remote side's set (per the variant name).
    RemoteProtocolsChange(ProtocolsChange<'a>),
}
/// Hand-written `Debug` for [`ConnectionEvent`].
///
/// Written by hand so the impl can require `Debug` on the upgrades' `Output`/`Error`
/// associated types as well as on the type parameters themselves.
impl<IP, OP, IOI, OOI> fmt::Debug for ConnectionEvent<'_, IP, OP, IOI, OOI>
where
    IP: InboundUpgradeSend + fmt::Debug,
    IP::Output: fmt::Debug,
    IP::Error: fmt::Debug,
    OP: OutboundUpgradeSend + fmt::Debug,
    OP::Output: fmt::Debug,
    OP::Error: fmt::Debug,
    IOI: fmt::Debug,
    OOI: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant is a single-field tuple variant, so resolve the variant name and its
        // payload first and render both through one `debug_tuple` call.
        let (variant, payload): (&str, &dyn fmt::Debug) = match self {
            Self::FullyNegotiatedInbound(inner) => ("FullyNegotiatedInbound", inner),
            Self::FullyNegotiatedOutbound(inner) => ("FullyNegotiatedOutbound", inner),
            Self::AddressChange(inner) => ("AddressChange", inner),
            Self::DialUpgradeError(inner) => ("DialUpgradeError", inner),
            Self::ListenUpgradeError(inner) => ("ListenUpgradeError", inner),
            Self::LocalProtocolsChange(inner) => ("LocalProtocolsChange", inner),
            Self::RemoteProtocolsChange(inner) => ("RemoteProtocolsChange", inner),
        };
        f.debug_tuple(variant).field(payload).finish()
    }
}
impl<IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI, OOI>
    ConnectionEvent<'_, IP, OP, IOI, OOI>
{
    /// Returns `true` if the event concerns an outbound stream.
    ///
    /// That is the case for a negotiated outbound substream or a failed dial upgrade.
    /// Address and protocol-change events are neither inbound nor outbound.
    pub fn is_outbound(&self) -> bool {
        // Deliberately exhaustive: a new variant must force a decision here.
        match self {
            Self::FullyNegotiatedInbound(_)
            | Self::ListenUpgradeError(_)
            | Self::AddressChange(_)
            | Self::LocalProtocolsChange(_)
            | Self::RemoteProtocolsChange(_) => false,
            Self::FullyNegotiatedOutbound(_) | Self::DialUpgradeError(_) => true,
        }
    }

    /// Returns `true` if the event concerns an inbound stream.
    ///
    /// That is the case for a negotiated inbound substream or a failed listen upgrade.
    /// Address and protocol-change events are neither inbound nor outbound.
    pub fn is_inbound(&self) -> bool {
        // Deliberately exhaustive: a new variant must force a decision here.
        match self {
            Self::FullyNegotiatedOutbound(_)
            | Self::DialUpgradeError(_)
            | Self::AddressChange(_)
            | Self::LocalProtocolsChange(_)
            | Self::RemoteProtocolsChange(_) => false,
            Self::FullyNegotiatedInbound(_) | Self::ListenUpgradeError(_) => true,
        }
    }
}
/// Payload of [`ConnectionEvent::FullyNegotiatedInbound`].
#[derive(Debug)]
pub struct FullyNegotiatedInbound<IP: InboundUpgradeSend, IOI = ()> {
    /// Output produced by the inbound upgrade `IP`.
    pub protocol: IP::Output,
    /// User data attached to this inbound substream (`InboundOpenInfo` in a
    /// [`ConnectionHandler`]).
    pub info: IOI,
}
/// Payload of [`ConnectionEvent::FullyNegotiatedOutbound`].
#[derive(Debug)]
pub struct FullyNegotiatedOutbound<OP: OutboundUpgradeSend, OOI = ()> {
    /// Output produced by the outbound upgrade `OP`.
    pub protocol: OP::Output,
    /// User data attached to this outbound substream (`OutboundOpenInfo` in a
    /// [`ConnectionHandler`]).
    pub info: OOI,
}
/// Payload of [`ConnectionEvent::AddressChange`].
#[derive(Debug)]
pub struct AddressChange<'a> {
    /// The new address, borrowed for the lifetime of the event.
    pub new_address: &'a Multiaddr,
}
/// A change to a set of supported protocols.
///
/// The contained iterators borrow from the `buffer` passed to the `pub(crate)` constructors
/// in this module.
#[derive(Debug, Clone)]
pub enum ProtocolsChange<'a> {
    /// Protocols that were added to the set.
    Added(ProtocolsAdded<'a>),
    /// Protocols that were removed from the set.
    Removed(ProtocolsRemoved<'a>),
}
impl<'a> ProtocolsChange<'a> {
    /// Builds a [`ProtocolsChange::Added`] listing every protocol in `new_protocols`.
    ///
    /// `buffer` is cleared and used as backing storage; the returned event borrows from it.
    /// Entries that are not valid [`StreamProtocol`]s (`try_from_owned` fails) are skipped.
    pub(crate) fn from_initial_protocols<'b, T: AsRef<str> + 'b>(
        new_protocols: impl IntoIterator<Item = &'b T>,
        buffer: &'a mut Vec<StreamProtocol>,
    ) -> Self {
        buffer.clear();
        buffer.extend(
            new_protocols
                .into_iter()
                .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok()),
        );
        ProtocolsChange::Added(ProtocolsAdded {
            protocols: buffer.iter(),
        })
    }

    /// Reports the protocols in `to_add` that are not already in `existing_protocols`.
    ///
    /// `existing_protocols` is only read, never modified. `buffer` is cleared and backs the
    /// returned event. Returns `None` if every protocol in `to_add` is already present.
    pub(crate) fn add(
        existing_protocols: &HashSet<StreamProtocol>,
        to_add: HashSet<StreamProtocol>,
        buffer: &'a mut Vec<StreamProtocol>,
    ) -> Option<Self> {
        buffer.clear();
        buffer.extend(
            to_add
                .into_iter()
                .filter(|i| !existing_protocols.contains(i)),
        );
        if buffer.is_empty() {
            return None;
        }
        Some(Self::Added(ProtocolsAdded {
            protocols: buffer.iter(),
        }))
    }

    /// Takes every protocol in `to_remove` out of `existing_protocols` and reports it.
    ///
    /// Only protocols that were actually present are reported. `buffer` is cleared and backs
    /// the returned event. Returns `None` if none of `to_remove` was present.
    pub(crate) fn remove(
        existing_protocols: &mut HashSet<StreamProtocol>,
        to_remove: HashSet<StreamProtocol>,
        buffer: &'a mut Vec<StreamProtocol>,
    ) -> Option<Self> {
        buffer.clear();
        buffer.extend(
            to_remove
                .into_iter()
                .filter_map(|i| existing_protocols.take(&i)),
        );
        if buffer.is_empty() {
            return None;
        }
        Some(Self::Removed(ProtocolsRemoved {
            protocols: buffer.iter(),
        }))
    }

    /// Diffs the complete new protocol list against `existing_protocols` and updates the map
    /// in place to match it.
    ///
    /// The map's `bool` values are a scratch "seen this round" flag. Their values on entry are
    /// ignored, and on return every remaining entry is `true`.
    ///
    /// Returns up to two events, both borrowing from `buffer` (which is cleared first):
    /// 1. an `Added` event for protocols not previously in the map;
    /// 2. a `Removed` event for entries that did not appear in `new_protocols`.
    ///
    /// Returns an empty vector when the set is unchanged. `new_protocols` may contain
    /// duplicates.
    pub(crate) fn from_full_sets<T: AsRef<str>>(
        existing_protocols: &mut HashMap<AsStrHashEq<T>, bool>,
        new_protocols: impl IntoIterator<Item = T>,
        buffer: &'a mut Vec<StreamProtocol>,
    ) -> SmallVec<[Self; 2]> {
        buffer.clear();

        // Reset the "seen" flags; anything still `false` after the loop is no longer supported.
        for seen in existing_protocols.values_mut() {
            *seen = false;
        }

        // Count *distinct* entries marked as seen. Counting raw iterator items instead would
        // let a duplicate in `new_protocols` mask a removed protocol in the early return below
        // (e.g. existing {a, b} with new [a, a] would wrongly report "no change").
        let mut distinct_seen = 0usize;
        for new_protocol in new_protocols {
            let seen = existing_protocols
                .entry(AsStrHashEq(new_protocol))
                .or_insert_with_key(|k| {
                    // Previously unknown protocol: remember it as an addition.
                    buffer.extend(StreamProtocol::try_from_owned(k.0.as_ref().to_owned()).ok());
                    false
                });
            if !*seen {
                *seen = true;
                distinct_seen += 1;
            }
        }

        // Every entry was seen and nothing new was added: the set did not change.
        if distinct_seen == existing_protocols.len() && buffer.is_empty() {
            return SmallVec::new();
        }

        let num_new_protocols = buffer.len();
        // Drop every entry not seen this round, recording it after the additions in `buffer`.
        existing_protocols.retain(|p, &mut is_supported| {
            if !is_supported {
                buffer.extend(StreamProtocol::try_from_owned(p.0.as_ref().to_owned()).ok());
            }
            is_supported
        });

        let (added, removed) = buffer.split_at(num_new_protocols);
        let mut changes = SmallVec::new();
        if !added.is_empty() {
            changes.push(ProtocolsChange::Added(ProtocolsAdded {
                protocols: added.iter(),
            }));
        }
        if !removed.is_empty() {
            changes.push(ProtocolsChange::Removed(ProtocolsRemoved {
                protocols: removed.iter(),
            }));
        }
        changes
    }
}
/// Iterator over the protocols in a [`ProtocolsChange::Added`] event.
///
/// It borrows from the buffer handed to the `ProtocolsChange` constructor.
#[derive(Debug, Clone)]
pub struct ProtocolsAdded<'a> {
    pub(crate) protocols: slice::Iter<'a, StreamProtocol>,
}
/// Iterator over the protocols in a [`ProtocolsChange::Removed`] event.
///
/// It borrows from the buffer handed to the `ProtocolsChange` constructor.
#[derive(Debug, Clone)]
pub struct ProtocolsRemoved<'a> {
    pub(crate) protocols: slice::Iter<'a, StreamProtocol>,
}
/// Yields the added protocols in buffer order.
impl<'a> Iterator for ProtocolsAdded<'a> {
    type Item = &'a StreamProtocol;

    fn next(&mut self) -> Option<Self::Item> {
        self.protocols.next()
    }

    // Forward the exact length of the underlying slice iterator so `collect`/`extend` can
    // preallocate; the default `(0, None)` discards that information.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.protocols.size_hint()
    }
}

/// Yields the removed protocols in buffer order.
impl<'a> Iterator for ProtocolsRemoved<'a> {
    type Item = &'a StreamProtocol;

    fn next(&mut self) -> Option<Self::Item> {
        self.protocols.next()
    }

    // See `ProtocolsAdded::size_hint`: forward the slice iterator's exact bounds.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.protocols.size_hint()
    }
}
/// Payload of [`ConnectionEvent::DialUpgradeError`]: an outbound substream upgrade failed.
#[derive(Debug)]
pub struct DialUpgradeError<OOI, OP: OutboundUpgradeSend> {
    /// User data that was attached to the outbound substream request.
    pub info: OOI,
    /// Why the upgrade failed (timeout, negotiation failure, I/O, or an `OP` error).
    pub error: StreamUpgradeError<OP::Error>,
}
/// Payload of [`ConnectionEvent::ListenUpgradeError`]: an inbound substream upgrade failed.
#[derive(Debug)]
pub struct ListenUpgradeError<IOI, IP: InboundUpgradeSend> {
    /// User data that was attached to the inbound substream.
    pub info: IOI,
    /// The error returned by the inbound upgrade `IP`.
    pub error: IP::Error,
}
/// An upgrade bundled with user data (`info`) and a timeout, describing how to open and
/// upgrade a substream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SubstreamProtocol<TUpgrade, TInfo = ()> {
    upgrade: TUpgrade,
    info: TInfo,
    timeout: Duration,
}

impl<TUpgrade, TInfo> SubstreamProtocol<TUpgrade, TInfo> {
    /// Bundles `upgrade` and `info` with a default timeout of 10 seconds.
    pub fn new(upgrade: TUpgrade, info: TInfo) -> Self {
        // Applied until overridden through `with_timeout`.
        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
        Self {
            upgrade,
            info,
            timeout: DEFAULT_TIMEOUT,
        }
    }

    /// Replaces the upgrade with `f(upgrade)`, keeping `info` and the timeout.
    pub fn map_upgrade<U, F>(self, f: F) -> SubstreamProtocol<U, TInfo>
    where
        F: FnOnce(TUpgrade) -> U,
    {
        let Self {
            upgrade,
            info,
            timeout,
        } = self;
        SubstreamProtocol {
            upgrade: f(upgrade),
            info,
            timeout,
        }
    }

    /// Replaces the user data with `f(info)`, keeping the upgrade and the timeout.
    pub fn map_info<U, F>(self, f: F) -> SubstreamProtocol<TUpgrade, U>
    where
        F: FnOnce(TInfo) -> U,
    {
        let Self {
            upgrade,
            info,
            timeout,
        } = self;
        SubstreamProtocol {
            upgrade,
            info: f(info),
            timeout,
        }
    }

    /// Returns `self` with its timeout replaced by `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Borrows the upgrade.
    pub fn upgrade(&self) -> &TUpgrade {
        &self.upgrade
    }

    /// Borrows the user data.
    pub fn info(&self) -> &TInfo {
        &self.info
    }

    /// Borrows the timeout.
    pub fn timeout(&self) -> &Duration {
        &self.timeout
    }

    /// Splits into the upgrade and the user data, discarding the timeout.
    pub fn into_upgrade(self) -> (TUpgrade, TInfo) {
        let Self { upgrade, info, .. } = self;
        (upgrade, info)
    }
}
/// Event returned by [`ConnectionHandler::poll`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, TCustom> {
    /// Request to open an outbound substream with the given upgrade, user data and timeout.
    OutboundSubstreamRequest {
        protocol: SubstreamProtocol<TConnectionUpgrade, TOutboundOpenInfo>,
    },
    /// Reports protocols added to or removed from the remote's supported set (per the
    /// variant name).
    ReportRemoteProtocols(ProtocolSupport),
    /// Event for the behaviour (`ToBehaviour` in a [`ConnectionHandler`]).
    NotifyBehaviour(TCustom),
}
/// An owned set of protocols that were added to or removed from a supported-protocol set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProtocolSupport {
    /// These protocols are now supported.
    Added(HashSet<StreamProtocol>),
    /// These protocols are no longer supported.
    Removed(HashSet<StreamProtocol>),
}
impl<TConnectionUpgrade, TOutboundOpenInfo, TCustom>
    ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, TCustom>
{
    /// Maps the user data of an `OutboundSubstreamRequest` through `map`.
    ///
    /// Every other variant is passed through unchanged and `map` is not called.
    pub fn map_outbound_open_info<F, I>(
        self,
        map: F,
    ) -> ConnectionHandlerEvent<TConnectionUpgrade, I, TCustom>
    where
        F: FnOnce(TOutboundOpenInfo) -> I,
    {
        match self {
            Self::NotifyBehaviour(custom) => ConnectionHandlerEvent::NotifyBehaviour(custom),
            Self::ReportRemoteProtocols(report) => {
                ConnectionHandlerEvent::ReportRemoteProtocols(report)
            }
            Self::OutboundSubstreamRequest { protocol: request } => {
                let protocol = request.map_info(map);
                ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }
            }
        }
    }

    /// Maps the upgrade of an `OutboundSubstreamRequest` through `map`.
    ///
    /// Every other variant is passed through unchanged and `map` is not called.
    pub fn map_protocol<F, I>(self, map: F) -> ConnectionHandlerEvent<I, TOutboundOpenInfo, TCustom>
    where
        F: FnOnce(TConnectionUpgrade) -> I,
    {
        match self {
            Self::NotifyBehaviour(custom) => ConnectionHandlerEvent::NotifyBehaviour(custom),
            Self::ReportRemoteProtocols(report) => {
                ConnectionHandlerEvent::ReportRemoteProtocols(report)
            }
            Self::OutboundSubstreamRequest { protocol: request } => {
                let protocol = request.map_upgrade(map);
                ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }
            }
        }
    }

    /// Maps the payload of a `NotifyBehaviour` event through `map`.
    ///
    /// Every other variant is passed through unchanged and `map` is not called.
    pub fn map_custom<F, I>(
        self,
        map: F,
    ) -> ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, I>
    where
        F: FnOnce(TCustom) -> I,
    {
        match self {
            Self::NotifyBehaviour(custom) => ConnectionHandlerEvent::NotifyBehaviour(map(custom)),
            Self::ReportRemoteProtocols(report) => {
                ConnectionHandlerEvent::ReportRemoteProtocols(report)
            }
            Self::OutboundSubstreamRequest { protocol } => {
                ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }
            }
        }
    }
}
/// Why opening or upgrading a substream failed.
#[derive(Debug)]
pub enum StreamUpgradeError<TUpgrErr> {
    /// Opening the substream timed out.
    Timeout,
    /// The upgrade itself returned an error.
    Apply(TUpgrErr),
    /// No protocol could be agreed upon with the remote.
    NegotiationFailed,
    /// An I/O error occurred.
    Io(io::Error),
}

impl<TUpgrErr> StreamUpgradeError<TUpgrErr> {
    /// Converts the upgrade error in `Apply` with `f`.
    ///
    /// Other variants are carried over as-is and `f` is not called.
    pub fn map_upgrade_err<F, E>(self, f: F) -> StreamUpgradeError<E>
    where
        F: FnOnce(TUpgrErr) -> E,
    {
        match self {
            Self::Apply(upgrade_err) => StreamUpgradeError::Apply(f(upgrade_err)),
            Self::Io(io_err) => StreamUpgradeError::Io(io_err),
            Self::NegotiationFailed => StreamUpgradeError::NegotiationFailed,
            Self::Timeout => StreamUpgradeError::Timeout,
        }
    }
}
/// Human-readable rendering of a [`StreamUpgradeError`].
///
/// For `Apply` and `Io`, the wrapped error's chain is appended via
/// `crate::print_error_chain`.
impl<TUpgrErr> fmt::Display for StreamUpgradeError<TUpgrErr>
where
    TUpgrErr: error::Error + 'static,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Timeout => f.write_str("Timeout error while opening a substream"),
            Self::NegotiationFailed => f.write_str("no protocols could be agreed upon"),
            Self::Apply(err) => {
                f.write_str("Apply: ")?;
                crate::print_error_chain(f, err)
            }
            Self::Io(e) => {
                f.write_str("IO error: ")?;
                crate::print_error_chain(f, e)
            }
        }
    }
}
impl<TUpgrErr> error::Error for StreamUpgradeError<TUpgrErr>
where
TUpgrErr: error::Error + 'static,
{
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
None
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a protocol set from whitespace-separated names, prefixing each with `/`.
    fn protocol_set_of(s: &'static str) -> HashSet<StreamProtocol> {
        s.split_whitespace()
            .map(|p| StreamProtocol::try_from_owned(format!("/{p}")).unwrap())
            .collect()
    }

    /// Runs `ProtocolsChange::remove` and collects the reported removals.
    fn test_remove(
        existing: &mut HashSet<StreamProtocol>,
        to_remove: HashSet<StreamProtocol>,
    ) -> HashSet<StreamProtocol> {
        ProtocolsChange::remove(existing, to_remove, &mut Vec::new())
            .into_iter()
            .flat_map(|c| match c {
                ProtocolsChange::Added(_) => panic!("unexpected added"),
                ProtocolsChange::Removed(r) => r.cloned(),
            })
            .collect::<HashSet<_>>()
    }

    #[test]
    fn test_protocol_remove_subset() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b");
        let change = test_remove(&mut existing, to_remove);
        assert_eq!(existing, protocol_set_of("c"));
        assert_eq!(change, protocol_set_of("a b"));
    }

    #[test]
    fn test_protocol_remove_all() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b c");
        let change = test_remove(&mut existing, to_remove);
        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of("a b c"));
    }

    #[test]
    fn test_protocol_remove_superset() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b c d");
        let change = test_remove(&mut existing, to_remove);
        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of("a b c"));
    }

    #[test]
    fn test_protocol_remove_none() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("d");
        let change = test_remove(&mut existing, to_remove);
        assert_eq!(existing, protocol_set_of("a b c"));
        assert_eq!(change, protocol_set_of(""));
    }

    #[test]
    fn test_protocol_remove_none_from_empty() {
        let mut existing = protocol_set_of("");
        let to_remove = protocol_set_of("d");
        let change = test_remove(&mut existing, to_remove);
        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of(""));
    }

    /// Runs `ProtocolsChange::from_full_sets` and returns `[removed, added]`.
    fn test_from_full_sets(
        existing: HashSet<StreamProtocol>,
        new: HashSet<StreamProtocol>,
    ) -> [HashSet<StreamProtocol>; 2] {
        let mut buffer = Vec::new();
        let mut existing = existing
            .iter()
            .map(|p| (AsStrHashEq(p.as_ref()), true))
            .collect::<HashMap<_, _>>();
        let changes = ProtocolsChange::from_full_sets(
            &mut existing,
            new.iter().map(AsRef::as_ref),
            &mut buffer,
        );
        let mut added_changes = HashSet::new();
        let mut removed_changes = HashSet::new();
        for change in changes {
            match change {
                ProtocolsChange::Added(a) => {
                    added_changes.extend(a.cloned());
                }
                ProtocolsChange::Removed(r) => {
                    removed_changes.extend(r.cloned());
                }
            }
        }
        [removed_changes, added_changes]
    }

    #[test]
    fn test_from_full_sets_subset() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("a b");
        let [removed_changes, added_changes] = test_from_full_sets(existing, new);
        assert_eq!(added_changes, protocol_set_of(""));
        assert_eq!(removed_changes, protocol_set_of("c"));
    }

    #[test]
    fn test_from_full_sets_superset() {
        let existing = protocol_set_of("a b");
        let new = protocol_set_of("a b c");
        let [removed_changes, added_changes] = test_from_full_sets(existing, new);
        assert_eq!(added_changes, protocol_set_of("c"));
        assert_eq!(removed_changes, protocol_set_of(""));
    }

    #[test]
    fn test_from_full_sets_intersection() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("b c d");
        let [removed_changes, added_changes] = test_from_full_sets(existing, new);
        assert_eq!(added_changes, protocol_set_of("d"));
        assert_eq!(removed_changes, protocol_set_of("a"));
    }

    #[test]
    fn test_from_full_sets_disjoint() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("d e f");
        let [removed_changes, added_changes] = test_from_full_sets(existing, new);
        assert_eq!(added_changes, protocol_set_of("d e f"));
        assert_eq!(removed_changes, protocol_set_of("a b c"));
    }

    #[test]
    fn test_from_full_sets_empty() {
        let existing = protocol_set_of("");
        let new = protocol_set_of("");
        let [removed_changes, added_changes] = test_from_full_sets(existing, new);
        assert_eq!(added_changes, protocol_set_of(""));
        assert_eq!(removed_changes, protocol_set_of(""));
    }
}