use std::{
collections::VecDeque,
io,
pin::Pin,
task::{Context, Poll},
time::{Duration, Instant},
};
use asynchronous_codec::Framed;
use futures::prelude::*;
use futures::StreamExt;
use iroh_metrics::{bitswap::BitswapMetrics, core::MRecorder, inc};
use libp2p::core::upgrade::{
InboundUpgrade, NegotiationError, OutboundUpgrade, ProtocolError, UpgradeError,
};
use libp2p::swarm::{
ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr, KeepAlive,
NegotiatedSubstream, SubstreamProtocol,
};
use smallvec::SmallVec;
use tokio::sync::oneshot;
use tracing::{debug, error, trace, warn};
use crate::{
error::Error,
message::BitswapMessage,
network,
protocol::{BitswapCodec, ProtocolConfig, ProtocolId},
};
/// Keep-alive window, in seconds, applied when a handler is created and again
/// whenever the connection is unprotected.
const INITIAL_KEEP_ALIVE: u64 = 30;
#[derive(thiserror::Error, Debug)]
pub enum BitswapHandlerError {
#[error("max inbound substreams")]
MaxInboundSubstreams,
#[error("max outbound substreams")]
MaxOutboundSubstreams,
#[error("max transmission size")]
MaxTransmissionSize,
#[error("negotiation timeout")]
NegotiationTimeout,
#[error("negotatiation protocol error {0}")]
NegotiationProtocolError(#[from] ProtocolError),
#[error("io {0}")]
Io(#[from] std::io::Error),
#[error("bitswap {0}")]
Bitswap(#[from] Error),
}
/// Notifications the handler hands to the owning network behaviour.
#[derive(Debug)]
pub enum HandlerEvent {
    /// A message received from the peer, tagged with the protocol it arrived on.
    Message { message: BitswapMessage, protocol: ProtocolId },
    /// A substream negotiated `protocol`; the handler's `poll` reports this at most once.
    Connected { protocol: ProtocolId },
    /// Outbound negotiation failed because the peer does not support the protocol.
    /// (The misspelled name is part of the public API and is kept as-is.)
    ProtocolNotSuppported,
}
/// One-shot channel on which the submitter of a message learns whether it was sent.
type BitswapMessageResponse = oneshot::Sender<Result<(), network::SendError>>;

/// Commands the network behaviour sends to a `BitswapHandler`.
#[derive(Debug)]
// `Message` carries a whole `BitswapMessage` inline, so it is presumably far
// larger than the unit variants; clippy's size lint is silenced deliberately.
#[allow(clippy::large_enum_variant)]
pub enum BitswapHandlerIn {
    /// Queue a message for the peer; the outcome goes back through the paired sender.
    Message(BitswapMessage, BitswapMessageResponse),
    /// Set the connection's keep-alive to `KeepAlive::Yes`.
    Protect,
    /// Restore a keep-alive deadline of `INITIAL_KEEP_ALIVE` seconds from now.
    Unprotect,
}
const MAX_SUBSTREAM_CREATION: usize = 5;
type BitswapConnectionHandlerEvent = ConnectionHandlerEvent<
ProtocolConfig,
(BitswapMessage, BitswapMessageResponse),
HandlerEvent,
BitswapHandlerError,
>;
/// Per-connection state of the bitswap protocol handler.
#[derive(Debug)]
pub struct BitswapHandler {
    /// Inbound protocol advertised to libp2p; also cloned as the template for
    /// outbound substream requests.
    listen_protocol: SubstreamProtocol<ProtocolConfig, ()>,
    /// The single outbound substream, if any.
    outbound_substream: Option<OutboundSubstreamState>,
    /// The single inbound substream, if any.
    inbound_substream: Option<InboundSubstreamState>,
    /// Events returned from `poll` ahead of any other work.
    events: SmallVec<[BitswapConnectionHandlerEvent; 4]>,
    /// Messages (with their response senders) waiting to be written.
    send_queue: SmallVec<[(BitswapMessage, BitswapMessageResponse); 16]>,
    /// True while an outbound substream request is in flight.
    outbound_substream_establishing: bool,
    /// Outbound substreams negotiated so far; bounded by `MAX_SUBSTREAM_CREATION`.
    outbound_substreams_created: usize,
    /// Inbound substreams negotiated so far; bounded by `MAX_SUBSTREAM_CREATION`.
    inbound_substreams_created: usize,
    /// Set once negotiation fails because the peer does not support the protocol.
    protocol_unsupported: bool,
    /// Whether `Connected` or `ProtocolNotSuppported` has already been reported.
    protocol_sent: bool,
    /// First protocol id negotiated on any substream.
    protocol: Option<ProtocolId>,
    /// How long the connection stays alive after traffic.
    idle_timeout: Duration,
    /// Dial upgrade errors waiting to be handled by `poll`.
    upgrade_errors: VecDeque<ConnectionHandlerUpgrErr<BitswapHandlerError>>,
    /// Value reported by `connection_keep_alive`.
    keep_alive: KeepAlive,
}
/// State of the inbound substream, advanced by the handler's `poll`.
#[derive(Debug)]
enum InboundSubstreamState {
    /// Open and waiting for the next item from the peer.
    WaitingInput(Framed<NegotiatedSubstream, BitswapCodec>),
    /// Being closed after end-of-stream or a read error.
    Closing(Framed<NegotiatedSubstream, BitswapCodec>),
    /// Placeholder left while `poll` has moved the real state out; observing
    /// it at the top of the loop is a bug.
    Poisoned,
}

/// State of the outbound substream, advanced by the handler's `poll`.
#[derive(Debug)]
// `PendingSend` holds a whole message inline, which dwarfs the other variants.
#[allow(clippy::large_enum_variant)]
enum OutboundSubstreamState {
    /// Idle; the next entry of `send_queue` is picked up from here.
    WaitingOutput(Framed<NegotiatedSubstream, BitswapCodec>),
    /// Waiting for the sink to accept the message.
    PendingSend(
        Framed<NegotiatedSubstream, BitswapCodec>,
        (BitswapMessage, BitswapMessageResponse),
    ),
    /// Message handed to the sink; waiting for the flush to complete.
    PendingFlush(Framed<NegotiatedSubstream, BitswapCodec>),
    /// Closing the substream (not constructed anywhere in this file).
    _Closing(Framed<NegotiatedSubstream, BitswapCodec>),
    /// Placeholder left while `poll` has moved the real state out.
    Poisoned,
}
impl BitswapHandler {
    /// Creates the handler for one connection.
    ///
    /// `protocol_config` is used for inbound substreams and as the template for
    /// outbound requests. `idle_timeout` is the keep-alive window applied after
    /// traffic. Initially the connection is kept alive for `INITIAL_KEEP_ALIVE`
    /// seconds.
    pub fn new(protocol_config: ProtocolConfig, idle_timeout: Duration) -> Self {
        let initial_deadline = Instant::now() + Duration::from_secs(INITIAL_KEEP_ALIVE);
        let listen_protocol = SubstreamProtocol::new(protocol_config, ());

        Self {
            listen_protocol,
            outbound_substream: None,
            inbound_substream: None,
            events: Default::default(),
            send_queue: SmallVec::new(),
            outbound_substream_establishing: false,
            outbound_substreams_created: 0,
            inbound_substreams_created: 0,
            protocol_unsupported: false,
            protocol_sent: false,
            protocol: None,
            idle_timeout,
            upgrade_errors: VecDeque::new(),
            keep_alive: KeepAlive::Until(initial_deadline),
        }
    }
}
impl ConnectionHandler for BitswapHandler {
type InEvent = BitswapHandlerIn;
type OutEvent = HandlerEvent;
type Error = BitswapHandlerError;
type InboundOpenInfo = ();
type InboundProtocol = ProtocolConfig;
type OutboundOpenInfo = (BitswapMessage, BitswapMessageResponse);
type OutboundProtocol = ProtocolConfig;
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
self.listen_protocol.clone()
}
fn inject_fully_negotiated_inbound(
&mut self,
protocol: <Self::InboundProtocol as InboundUpgrade<NegotiatedSubstream>>::Output,
_info: Self::InboundOpenInfo,
) {
let substream = protocol;
if self.protocol_unsupported {
return;
}
let protocol_id = substream.codec().protocol;
if self.protocol.is_none() {
self.protocol = Some(protocol_id);
}
self.inbound_substreams_created += 1;
trace!("New inbound substream request: {:?}", protocol_id);
self.inbound_substream = Some(InboundSubstreamState::WaitingInput(substream));
}
fn inject_fully_negotiated_outbound(
&mut self,
protocol: <Self::OutboundProtocol as OutboundUpgrade<NegotiatedSubstream>>::Output,
message: Self::OutboundOpenInfo,
) {
let substream = protocol;
if self.protocol_unsupported {
return;
}
let protocol_id = substream.codec().protocol;
if self.protocol.is_none() {
self.protocol = Some(protocol_id);
}
self.outbound_substream_establishing = false;
self.outbound_substreams_created += 1;
if self.outbound_substream.is_some() {
warn!("Established an outbound substream with one already available");
self.send_queue.push(message);
} else {
trace!("New outbound substream: {:?}", protocol_id);
self.outbound_substream = Some(OutboundSubstreamState::PendingSend(substream, message));
}
}
fn inject_event(&mut self, message: BitswapHandlerIn) {
match message {
BitswapHandlerIn::Message(m, response) => {
tracing::debug!("sending message ({})", self.protocol_unsupported);
if self.protocol_unsupported {
inc!(BitswapMetrics::ProtocolUnsupported);
response
.send(Err(network::SendError::ProtocolNotSupported))
.ok();
} else {
self.send_queue.push((m, response));
self.keep_alive = KeepAlive::Until(Instant::now() + self.idle_timeout);
}
}
BitswapHandlerIn::Protect => {
self.keep_alive = KeepAlive::Yes;
}
BitswapHandlerIn::Unprotect => {
self.keep_alive =
KeepAlive::Until(Instant::now() + Duration::from_secs(INITIAL_KEEP_ALIVE));
}
}
}
fn inject_dial_upgrade_error(
&mut self,
_: Self::OutboundOpenInfo,
e: ConnectionHandlerUpgrErr<
<Self::OutboundProtocol as OutboundUpgrade<NegotiatedSubstream>>::Error,
>,
) {
self.outbound_substream_establishing = false;
warn!("Dial upgrade error {:?}", e);
self.upgrade_errors.push_back(e);
}
fn connection_keep_alive(&self) -> KeepAlive {
self.keep_alive
}
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<BitswapConnectionHandlerEvent> {
inc!(BitswapMetrics::HandlerPollCount);
if !self.events.is_empty() {
return Poll::Ready(self.events.remove(0));
}
inc!(BitswapMetrics::HandlerPollEventCount);
if let Some(error) = self.upgrade_errors.pop_front() {
inc!(BitswapMetrics::HandlerConnUpgradeErrors);
let reported_error = match error {
ConnectionHandlerUpgrErr::Timeout | ConnectionHandlerUpgrErr::Timer => {
Some(BitswapHandlerError::NegotiationTimeout)
}
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)) => Some(e),
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(negotiation_error)) => {
match negotiation_error {
NegotiationError::Failed => {
self.protocol_unsupported = true;
if !self.protocol_sent {
self.protocol_sent = true;
self.inbound_substream = None;
self.outbound_substream = None;
self.keep_alive = KeepAlive::No;
return Poll::Ready(ConnectionHandlerEvent::Custom(
HandlerEvent::ProtocolNotSuppported,
));
} else {
None
}
}
NegotiationError::ProtocolError(e) => {
Some(BitswapHandlerError::NegotiationProtocolError(e))
}
}
}
};
if let Some(error) = reported_error {
return Poll::Ready(ConnectionHandlerEvent::Close(error));
}
}
if !self.protocol_sent {
if let Some(protocol) = self.protocol.as_ref() {
self.protocol_sent = true;
return Poll::Ready(ConnectionHandlerEvent::Custom(HandlerEvent::Connected {
protocol: *protocol,
}));
}
}
if self.inbound_substreams_created > MAX_SUBSTREAM_CREATION {
inc!(BitswapMetrics::InboundSubstreamsCreatedLimit);
return Poll::Ready(ConnectionHandlerEvent::Close(
BitswapHandlerError::MaxInboundSubstreams,
));
}
if !self.send_queue.is_empty()
&& self.outbound_substream.is_none()
&& !self.outbound_substream_establishing
{
inc!(BitswapMetrics::OutboundSubstreamsEvent);
if self.outbound_substreams_created >= MAX_SUBSTREAM_CREATION {
inc!(BitswapMetrics::OutboundSubstreamsCreatedLimit);
return Poll::Ready(ConnectionHandlerEvent::Close(
BitswapHandlerError::MaxOutboundSubstreams,
));
}
let message = self.send_queue.remove(0);
self.send_queue.shrink_to_fit();
self.outbound_substream_establishing = true;
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: self.listen_protocol.clone().map_info(|()| message),
});
}
loop {
inc!(BitswapMetrics::HandlerInboundLoopCount);
match std::mem::replace(
&mut self.inbound_substream,
Some(InboundSubstreamState::Poisoned),
) {
Some(InboundSubstreamState::WaitingInput(mut substream)) => {
match substream.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(message))) => {
self.keep_alive = KeepAlive::Until(Instant::now() + self.idle_timeout);
self.inbound_substream =
Some(InboundSubstreamState::WaitingInput(substream));
return Poll::Ready(ConnectionHandlerEvent::Custom(message));
}
Poll::Ready(Some(Err(error))) => {
match error {
BitswapHandlerError::MaxTransmissionSize => {
warn!("Message exceeded the maximum transmission size");
self.inbound_substream =
Some(InboundSubstreamState::WaitingInput(substream));
}
_ => {
warn!("Inbound stream error: {}", error);
self.inbound_substream =
Some(InboundSubstreamState::Closing(substream));
}
}
}
Poll::Ready(None) => {
debug!("Peer closed their outbound stream");
self.inbound_substream =
Some(InboundSubstreamState::Closing(substream));
}
Poll::Pending => {
self.inbound_substream =
Some(InboundSubstreamState::WaitingInput(substream));
break;
}
}
}
Some(InboundSubstreamState::Closing(mut substream)) => {
match Sink::poll_close(Pin::new(&mut substream), cx) {
Poll::Ready(res) => {
if let Err(e) = res {
warn!("Inbound substream error while closing: {:?}", e);
}
self.inbound_substream = None;
if self.outbound_substream.is_none() {
self.keep_alive = KeepAlive::No;
}
break;
}
Poll::Pending => {
self.inbound_substream =
Some(InboundSubstreamState::Closing(substream));
break;
}
}
}
None => {
self.inbound_substream = None;
break;
}
Some(InboundSubstreamState::Poisoned) => {
unreachable!("Error occurred during inbound stream processing")
}
}
}
loop {
inc!(BitswapMetrics::HandlerOutboundLoopCount);
match std::mem::replace(
&mut self.outbound_substream,
Some(OutboundSubstreamState::Poisoned),
) {
Some(OutboundSubstreamState::WaitingOutput(substream)) => {
if !self.send_queue.is_empty() {
let message = self.send_queue.remove(0);
self.send_queue.shrink_to_fit();
self.outbound_substream =
Some(OutboundSubstreamState::PendingSend(substream, message));
} else {
self.outbound_substream =
Some(OutboundSubstreamState::WaitingOutput(substream));
break;
}
}
Some(OutboundSubstreamState::PendingSend(mut substream, (message, response))) => {
match Sink::poll_ready(Pin::new(&mut substream), cx) {
Poll::Ready(Ok(())) => {
tracing::debug!("sedning message");
match Sink::start_send(Pin::new(&mut substream), message) {
Ok(()) => {
response.send(Ok(())).ok();
self.outbound_substream =
Some(OutboundSubstreamState::PendingFlush(substream))
}
e @ Err(BitswapHandlerError::MaxTransmissionSize) => {
error!("Message exceeded the maximum transmission size and was not sent.");
response
.send(Err(network::SendError::Other(
e.unwrap_err().to_string(),
)))
.ok();
self.outbound_substream =
Some(OutboundSubstreamState::WaitingOutput(substream));
}
Err(e) => {
error!("Error sending message: {}", e);
response
.send(Err(network::SendError::Other(e.to_string())))
.ok();
return Poll::Ready(ConnectionHandlerEvent::Close(e));
}
}
}
Poll::Ready(Err(e)) => {
error!("Outbound substream error while sending output: {:?}", e);
return Poll::Ready(ConnectionHandlerEvent::Close(e));
}
Poll::Pending => {
self.keep_alive = KeepAlive::Yes;
self.outbound_substream = Some(OutboundSubstreamState::PendingSend(
substream,
(message, response),
));
break;
}
}
}
Some(OutboundSubstreamState::PendingFlush(mut substream)) => {
match Sink::poll_flush(Pin::new(&mut substream), cx) {
Poll::Ready(Ok(())) => {
self.keep_alive = KeepAlive::Until(Instant::now() + self.idle_timeout);
self.outbound_substream =
Some(OutboundSubstreamState::WaitingOutput(substream))
}
Poll::Ready(Err(e)) => {
return Poll::Ready(ConnectionHandlerEvent::Close(e))
}
Poll::Pending => {
self.keep_alive = KeepAlive::Yes;
self.outbound_substream =
Some(OutboundSubstreamState::PendingFlush(substream));
break;
}
}
}
Some(OutboundSubstreamState::_Closing(mut substream)) => {
match Sink::poll_close(Pin::new(&mut substream), cx) {
Poll::Ready(Ok(())) => {
self.outbound_substream = None;
if self.inbound_substream.is_none() {
self.keep_alive = KeepAlive::No;
}
break;
}
Poll::Ready(Err(e)) => {
warn!("Outbound substream error while closing: {:?}", e);
return Poll::Ready(ConnectionHandlerEvent::Close(
io::Error::new(
io::ErrorKind::BrokenPipe,
"Failed to close outbound substream",
)
.into(),
));
}
Poll::Pending => {
self.keep_alive = KeepAlive::No;
self.outbound_substream =
Some(OutboundSubstreamState::_Closing(substream));
break;
}
}
}
None => {
self.outbound_substream = None;
break;
}
Some(OutboundSubstreamState::Poisoned) => {
unreachable!("Error occurred during outbound stream processing")
}
}
}
Poll::Pending
}
}