mod protocol;
use crate::codec::RequestResponseCodec;
use crate::{RequestId, EMPTY_QUEUE_SHRINK_THRESHOLD};
pub use protocol::{ProtocolSupport, RequestProtocol, ResponseProtocol};
use futures::{channel::oneshot, future::BoxFuture, prelude::*, stream::FuturesUnordered};
use instant::Instant;
use libp2p_core::upgrade::{NegotiationError, UpgradeError};
use libp2p_swarm::{
protocols_handler::{
KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr,
},
SubstreamProtocol,
};
use smallvec::SmallVec;
use std::{
collections::VecDeque,
fmt, io,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::{Context, Poll},
time::Duration,
};
/// `ProtocolsHandler` for the request-response protocol.
///
/// Outbound requests are queued until a substream is requested for them.
/// Every inbound substream handed out by `listen_protocol` is tracked by a
/// future that resolves once the request has been read. Events destined for
/// the behaviour are buffered and released one per `poll`.
#[doc(hidden)]
pub struct RequestResponseHandler<TCodec: RequestResponseCodec> {
    /// Protocol names offered on inbound substreams; cloned into every
    /// `ResponseProtocol` built by `listen_protocol`.
    inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
    /// Codec cloned into every inbound `ResponseProtocol`.
    codec: TCodec,
    /// Idle time granted (on top of `substream_timeout`) once the handler has
    /// nothing left to do, before the connection may be closed.
    keep_alive_timeout: Duration,
    /// Timeout applied to both inbound and outbound substream upgrades.
    substream_timeout: Duration,
    /// Current keep-alive decision reported via `connection_keep_alive`.
    keep_alive: KeepAlive,
    /// Fatal upgrade error; when set, the next `poll` closes the connection.
    pending_error: Option<ProtocolsHandlerUpgrErr<io::Error>>,
    /// Events waiting to be returned to the behaviour, one per `poll`.
    pending_events: VecDeque<RequestResponseHandlerEvent<TCodec>>,
    /// Requests waiting for an outbound substream to be requested.
    outbound: VecDeque<RequestProtocol<TCodec>>,
    /// One future per inbound substream handed out by `listen_protocol`.
    /// Each resolves to the request (with its ID) plus the sender to answer
    /// it, or to `Canceled` if the request sender is dropped without a request.
    inbound: FuturesUnordered<
        BoxFuture<
            'static,
            Result<
                (
                    (RequestId, TCodec::Request),
                    oneshot::Sender<TCodec::Response>,
                ),
                oneshot::Canceled,
            >,
        >,
    >,
    /// Counter from which inbound `RequestId`s are allocated; shared via `Arc`
    /// with whoever constructed this handler.
    inbound_request_id: Arc<AtomicU64>,
}
impl<TCodec> RequestResponseHandler<TCodec>
where
TCodec: RequestResponseCodec,
{
pub(super) fn new(
inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
codec: TCodec,
keep_alive_timeout: Duration,
substream_timeout: Duration,
inbound_request_id: Arc<AtomicU64>,
) -> Self {
Self {
inbound_protocols,
codec,
keep_alive: KeepAlive::Yes,
keep_alive_timeout,
substream_timeout,
outbound: VecDeque::new(),
inbound: FuturesUnordered::new(),
pending_events: VecDeque::new(),
pending_error: None,
inbound_request_id,
}
}
}
/// Events reported by `RequestResponseHandler` to its behaviour.
#[doc(hidden)]
pub enum RequestResponseHandlerEvent<TCodec: RequestResponseCodec> {
    /// An inbound request was read. A response sent through `sender` is
    /// delivered to the inbound substream that carried the request.
    Request {
        request_id: RequestId,
        request: TCodec::Request,
        sender: oneshot::Sender<TCodec::Response>,
    },
    /// A response arrived for the outbound request `request_id`.
    Response {
        request_id: RequestId,
        response: TCodec::Response,
    },
    /// The inbound upgrade for this request finished and reported `sent == true`.
    ResponseSent(RequestId),
    /// The inbound upgrade for this request finished and reported `sent == false`
    /// (presumably no response was supplied — confirm against `ResponseProtocol`).
    ResponseOmission(RequestId),
    /// The outbound substream upgrade for this request timed out.
    OutboundTimeout(RequestId),
    /// Protocol negotiation for the outbound request failed
    /// (`NegotiationError::Failed`).
    OutboundUnsupportedProtocols(RequestId),
    /// The inbound substream upgrade for this request timed out.
    InboundTimeout(RequestId),
    /// Protocol negotiation for the inbound substream failed
    /// (`NegotiationError::Failed`).
    InboundUnsupportedProtocols(RequestId),
}
impl<TCodec: RequestResponseCodec> fmt::Debug for RequestResponseHandlerEvent<TCodec> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RequestResponseHandlerEvent::Request {
request_id,
request: _,
sender: _,
} => f
.debug_struct("RequestResponseHandlerEvent::Request")
.field("request_id", request_id)
.finish(),
RequestResponseHandlerEvent::Response {
request_id,
response: _,
} => f
.debug_struct("RequestResponseHandlerEvent::Response")
.field("request_id", request_id)
.finish(),
RequestResponseHandlerEvent::ResponseSent(request_id) => f
.debug_tuple("RequestResponseHandlerEvent::ResponseSent")
.field(request_id)
.finish(),
RequestResponseHandlerEvent::ResponseOmission(request_id) => f
.debug_tuple("RequestResponseHandlerEvent::ResponseOmission")
.field(request_id)
.finish(),
RequestResponseHandlerEvent::OutboundTimeout(request_id) => f
.debug_tuple("RequestResponseHandlerEvent::OutboundTimeout")
.field(request_id)
.finish(),
RequestResponseHandlerEvent::OutboundUnsupportedProtocols(request_id) => f
.debug_tuple("RequestResponseHandlerEvent::OutboundUnsupportedProtocols")
.field(request_id)
.finish(),
RequestResponseHandlerEvent::InboundTimeout(request_id) => f
.debug_tuple("RequestResponseHandlerEvent::InboundTimeout")
.field(request_id)
.finish(),
RequestResponseHandlerEvent::InboundUnsupportedProtocols(request_id) => f
.debug_tuple("RequestResponseHandlerEvent::InboundUnsupportedProtocols")
.field(request_id)
.finish(),
}
}
}
impl<TCodec> ProtocolsHandler for RequestResponseHandler<TCodec>
where
    TCodec: RequestResponseCodec + Send + Clone + 'static,
{
    type InEvent = RequestProtocol<TCodec>;
    type OutEvent = RequestResponseHandlerEvent<TCodec>;
    type Error = ProtocolsHandlerUpgrErr<io::Error>;
    type InboundProtocol = ResponseProtocol<TCodec>;
    type OutboundProtocol = RequestProtocol<TCodec>;
    type OutboundOpenInfo = RequestId;
    type InboundOpenInfo = RequestId;

    /// Prepares the upgrade for the next inbound substream.
    ///
    /// A fresh `RequestId` is taken from the shared counter. Two oneshot
    /// channels link the upgrade to this handler: one carries the decoded
    /// request in, the other carries the response back out. A future waiting
    /// for the request is registered in `inbound`. The returned upgrade uses
    /// `substream_timeout`.
    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
        let request_id = RequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed));
        let (request_tx, request_rx) = oneshot::channel();
        let (response_tx, response_rx) = oneshot::channel();

        let upgrade = ResponseProtocol {
            protocols: self.inbound_protocols.clone(),
            codec: self.codec.clone(),
            request_sender: request_tx,
            response_receiver: response_rx,
            request_id,
        };

        // Once the request arrives, pair it with the sender that will carry
        // the response back to this very upgrade.
        let pending_request = request_rx
            .map_ok(move |request| (request, response_tx))
            .boxed();
        self.inbound.push(pending_request);

        SubstreamProtocol::new(upgrade, request_id).with_timeout(self.substream_timeout)
    }

    /// Records whether the inbound substream for `request_id` reported the
    /// response as sent (`ResponseSent`) or not (`ResponseOmission`).
    fn inject_fully_negotiated_inbound(&mut self, sent: bool, request_id: RequestId) {
        let event = if sent {
            RequestResponseHandlerEvent::ResponseSent(request_id)
        } else {
            RequestResponseHandlerEvent::ResponseOmission(request_id)
        };
        self.pending_events.push_back(event);
    }

    /// Queues the response received for the outbound request `request_id`.
    fn inject_fully_negotiated_outbound(
        &mut self,
        response: TCodec::Response,
        request_id: RequestId,
    ) {
        let event = RequestResponseHandlerEvent::Response {
            request_id,
            response,
        };
        self.pending_events.push_back(event);
    }

    /// Queues a new outbound request and keeps the connection alive for it.
    fn inject_event(&mut self, request: Self::InEvent) {
        self.keep_alive = KeepAlive::Yes;
        self.outbound.push_back(request);
    }

    /// Handles a failed outbound upgrade.
    ///
    /// Timeouts and failed protocol negotiation are reported as events. Any
    /// other error is stored and closes the connection on the next `poll`.
    fn inject_dial_upgrade_error(
        &mut self,
        info: RequestId,
        error: ProtocolsHandlerUpgrErr<io::Error>,
    ) {
        let event = match error {
            ProtocolsHandlerUpgrErr::Timeout => RequestResponseHandlerEvent::OutboundTimeout(info),
            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => {
                RequestResponseHandlerEvent::OutboundUnsupportedProtocols(info)
            }
            fatal => {
                self.pending_error = Some(fatal);
                return;
            }
        };
        self.pending_events.push_back(event);
    }

    /// Handles a failed inbound upgrade.
    ///
    /// Timeouts and failed protocol negotiation are reported as events. Any
    /// other error is stored and closes the connection on the next `poll`.
    fn inject_listen_upgrade_error(
        &mut self,
        info: RequestId,
        error: ProtocolsHandlerUpgrErr<io::Error>,
    ) {
        let event = match error {
            ProtocolsHandlerUpgrErr::Timeout => RequestResponseHandlerEvent::InboundTimeout(info),
            ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => {
                RequestResponseHandlerEvent::InboundUnsupportedProtocols(info)
            }
            fatal => {
                self.pending_error = Some(fatal);
                return;
            }
        };
        self.pending_events.push_back(event);
    }

    /// Reports the keep-alive state maintained by `inject_event` and `poll`.
    fn connection_keep_alive(&self) -> KeepAlive {
        self.keep_alive
    }

    /// Advances the handler, producing at most one event per call.
    ///
    /// Priority order:
    /// 1. a stored fatal error;
    /// 2. queued events;
    /// 3. newly read inbound requests;
    /// 4. the next outbound substream request.
    ///
    /// When fully idle, the keep-alive is downgraded to a deadline.
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<
        ProtocolsHandlerEvent<
            Self::OutboundProtocol,
            Self::OutboundOpenInfo,
            Self::OutEvent,
            Self::Error,
        >,
    > {
        // A fatal upgrade error takes precedence and closes the connection.
        if let Some(error) = self.pending_error.take() {
            return Poll::Ready(ProtocolsHandlerEvent::Close(error));
        }

        // Release one buffered event; once the queue is empty, give back any
        // oversized allocation.
        match self.pending_events.pop_front() {
            Some(event) => return Poll::Ready(ProtocolsHandlerEvent::Custom(event)),
            None => {
                if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
                    self.pending_events.shrink_to_fit();
                }
            }
        }

        // Surface the first inbound request that has been read.
        loop {
            match self.inbound.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(((request_id, request), sender)))) => {
                    self.keep_alive = KeepAlive::Yes;
                    return Poll::Ready(ProtocolsHandlerEvent::Custom(
                        RequestResponseHandlerEvent::Request {
                            request_id,
                            request,
                            sender,
                        },
                    ));
                }
                Poll::Ready(Some(Err(oneshot::Canceled))) => {
                    // The request sender was dropped without delivering a
                    // request. Upgrade failures reach the handler through
                    // `inject_listen_upgrade_error`, so nothing is reported here.
                }
                Poll::Ready(None) | Poll::Pending => break,
            }
        }

        // Ask for a substream for the next queued outbound request.
        if let Some(request) = self.outbound.pop_front() {
            let request_id = request.request_id;
            let protocol =
                SubstreamProtocol::new(request, request_id).with_timeout(self.substream_timeout);
            return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol });
        }

        debug_assert!(self.outbound.is_empty());
        if self.outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
            self.outbound.shrink_to_fit();
        }

        // Nothing is pending. Upgrades started just before may still be in
        // flight, so the idle deadline includes the substream timeout as well
        // as the keep-alive timeout.
        if self.inbound.is_empty() && self.keep_alive.is_yes() {
            let until = Instant::now() + self.substream_timeout + self.keep_alive_timeout;
            self.keep_alive = KeepAlive::Until(until);
        }

        Poll::Pending
    }
}