use std::task::{Context, Poll};
use futures::FutureExt;
use libp2p_core::upgrade::{DeniedUpgrade, ReadyUpgrade};
use libp2p_swarm::{
handler::{
ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
ListenUpgradeError,
},
ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, SubstreamProtocol,
};
use tracing::error;
use void::Void;
use crate::Run;
/// Event reported to the behaviour when an inbound run completes successfully.
#[derive(Debug)]
pub struct Event {
/// The `Run` value yielded by the finished inbound future.
pub stats: Run,
}
/// Connection handler that drives inbound runs and reports their results as [`Event`]s.
pub struct Handler {
/// In-flight inbound runs; each future resolves to the run's `Run` value or an I/O error.
/// Built in `Handler::new` from `crate::RUN_TIMEOUT` and
/// `crate::MAX_PARALLEL_RUNS_PER_CONNECTION`.
inbound: futures_bounded::FuturesSet<Result<Run, std::io::Error>>,
}
impl Handler {
    /// Creates a handler whose set of inbound runs starts out empty.
    ///
    /// The set is configured with `crate::RUN_TIMEOUT` and
    /// `crate::MAX_PARALLEL_RUNS_PER_CONNECTION`.
    pub fn new() -> Self {
        let inbound = futures_bounded::FuturesSet::new(
            crate::RUN_TIMEOUT,
            crate::MAX_PARALLEL_RUNS_PER_CONNECTION,
        );

        Self { inbound }
    }
}
impl Default for Handler {
fn default() -> Self {
Self::new()
}
}
impl ConnectionHandler for Handler {
type FromBehaviour = Void;
type ToBehaviour = Event;
type InboundProtocol = ReadyUpgrade<StreamProtocol>;
type OutboundProtocol = DeniedUpgrade;
type OutboundOpenInfo = Void;
type InboundOpenInfo = ();
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
SubstreamProtocol::new(ReadyUpgrade::new(crate::PROTOCOL_NAME), ())
}
fn on_behaviour_event(&mut self, v: Self::FromBehaviour) {
void::unreachable(v)
}
fn on_connection_event(
&mut self,
event: ConnectionEvent<
Self::InboundProtocol,
Self::OutboundProtocol,
Self::InboundOpenInfo,
Self::OutboundOpenInfo,
>,
) {
match event {
ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
protocol,
info: _,
}) => {
if self
.inbound
.try_push(crate::protocol::receive_send(protocol).boxed())
.is_err()
{
tracing::warn!("Dropping inbound stream because we are at capacity");
}
}
ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { info, .. }) => {
void::unreachable(info)
}
ConnectionEvent::DialUpgradeError(DialUpgradeError { info, .. }) => {
void::unreachable(info)
}
ConnectionEvent::AddressChange(_)
| ConnectionEvent::LocalProtocolsChange(_)
| ConnectionEvent::RemoteProtocolsChange(_) => {}
ConnectionEvent::ListenUpgradeError(ListenUpgradeError { info: (), error }) => {
void::unreachable(error)
}
_ => {}
}
}
#[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<
ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
> {
loop {
match self.inbound.poll_unpin(cx) {
Poll::Ready(Ok(Ok(stats))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event { stats }))
}
Poll::Ready(Ok(Err(e))) => {
error!("{e:?}");
continue;
}
Poll::Ready(Err(e @ futures_bounded::Timeout { .. })) => {
error!("inbound perf request timed out: {e}");
continue;
}
Poll::Pending => {}
}
return Poll::Pending;
}
}
}