mod protocol;
use crate::{EMPTY_QUEUE_SHRINK_THRESHOLD, RequestId};
use crate::codec::RequestResponseCodec;
pub use protocol::{RequestProtocol, ResponseProtocol, ProtocolSupport};
use futures::{
channel::oneshot,
future::BoxFuture,
prelude::*,
stream::FuturesUnordered
};
use tetsy_libp2p_core::{
upgrade::{UpgradeError, NegotiationError},
};
use tetsy_libp2p_swarm::{
SubstreamProtocol,
protocols_handler::{
KeepAlive,
ProtocolsHandler,
ProtocolsHandlerEvent,
ProtocolsHandlerUpgrErr,
}
};
use smallvec::SmallVec;
use std::{
collections::VecDeque,
io,
sync::{atomic::{AtomicU64, Ordering}, Arc},
time::Duration,
task::{Context, Poll}
};
use wasm_timer::Instant;
/// Per-connection protocol handler for the request-response protocol.
///
/// Queues outbound requests until a substream is requested for them, tracks
/// inbound upgrades that are waiting to deliver a decoded request, buffers
/// events for the `Swarm`, and maintains the connection's keep-alive state.
#[doc(hidden)]
pub struct RequestResponseHandler<TCodec>
where
    TCodec: RequestResponseCodec,
{
    /// Protocols offered on inbound substreams (cloned into every inbound upgrade).
    inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
    /// Codec cloned into every inbound upgrade.
    codec: TCodec,
    /// Extra idle time, added on top of `substream_timeout`, before the
    /// keep-alive deadline expires once no requests are outstanding.
    keep_alive_timeout: Duration,
    /// Timeout applied to both inbound and outbound substream upgrades.
    substream_timeout: Duration,
    /// Current keep-alive decision, reported through `connection_keep_alive`.
    keep_alive: KeepAlive,
    /// A fatal upgrade error; the next `poll` turns it into a `Close` event.
    pending_error: Option<ProtocolsHandlerUpgrErr<io::Error>>,
    /// Events waiting to be handed to the `Swarm`, one per `poll`.
    pending_events: VecDeque<RequestResponseHandlerEvent<TCodec>>,
    /// Outbound requests waiting for an outbound substream to be requested.
    outbound: VecDeque<RequestProtocol<TCodec>>,
    /// One future per inbound upgrade. Each resolves to the request ID, the
    /// decoded request, and the sender for the response. It resolves to
    /// `Canceled` if the upgrade drops its request sender without using it.
    inbound: FuturesUnordered<
        BoxFuture<
            'static,
            Result<((RequestId, TCodec::Request), oneshot::Sender<TCodec::Response>), oneshot::Canceled>,
        >,
    >,
    /// Counter from which inbound request IDs are drawn. It is wrapped in an
    /// `Arc`, presumably so that it can be shared across handlers (confirm
    /// with the owning behaviour).
    inbound_request_id: Arc<AtomicU64>,
}
impl<TCodec> RequestResponseHandler<TCodec>
where
    TCodec: RequestResponseCodec,
{
    /// Creates a handler with empty request/event queues and no pending
    /// error. The connection starts out kept alive (`KeepAlive::Yes`).
    ///
    /// * `inbound_protocols` – protocols offered on inbound substreams.
    /// * `codec` – codec cloned into each inbound upgrade.
    /// * `keep_alive_timeout` – idle time added after `substream_timeout`
    ///   once no requests are outstanding.
    /// * `substream_timeout` – timeout for each substream upgrade.
    /// * `inbound_request_id` – counter from which inbound request IDs are drawn.
    pub(super) fn new(
        inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
        codec: TCodec,
        keep_alive_timeout: Duration,
        substream_timeout: Duration,
        inbound_request_id: Arc<AtomicU64>,
    ) -> Self {
        Self {
            // Configuration supplied by the caller.
            inbound_protocols,
            codec,
            keep_alive_timeout,
            substream_timeout,
            inbound_request_id,
            // Initial runtime state.
            keep_alive: KeepAlive::Yes,
            pending_error: None,
            pending_events: VecDeque::default(),
            outbound: VecDeque::default(),
            inbound: FuturesUnordered::new(),
        }
    }
}
/// Events emitted by [`RequestResponseHandler`] to the `Swarm`.
#[doc(hidden)]
#[derive(Debug)]
pub enum RequestResponseHandlerEvent<TCodec>
where
    TCodec: RequestResponseCodec,
{
    /// A request arrived on an inbound substream. The response must be sent
    /// through `sender`, which feeds the inbound upgrade's response receiver.
    Request {
        request_id: RequestId,
        request: TCodec::Request,
        sender: oneshot::Sender<TCodec::Response>,
    },
    /// The response to an outbound request was received.
    Response {
        request_id: RequestId,
        response: TCodec::Response,
    },
    /// The inbound upgrade finished and reported that the response was sent.
    ResponseSent(RequestId),
    /// The inbound upgrade finished without sending a response.
    ResponseOmission(RequestId),
    /// Upgrading the outbound substream for a request timed out.
    OutboundTimeout(RequestId),
    /// The remote supports none of the protocols offered for an outbound request.
    OutboundUnsupportedProtocols(RequestId),
    /// Upgrading an inbound substream timed out.
    InboundTimeout(RequestId),
    /// The local node supports none of the protocols requested on an inbound substream.
    InboundUnsupportedProtocols(RequestId),
}
impl<TCodec> ProtocolsHandler for RequestResponseHandler<TCodec>
where
TCodec: RequestResponseCodec + Send + Clone + 'static,
{
type InEvent = RequestProtocol<TCodec>;
type OutEvent = RequestResponseHandlerEvent<TCodec>;
type Error = ProtocolsHandlerUpgrErr<io::Error>;
type InboundProtocol = ResponseProtocol<TCodec>;
type OutboundProtocol = RequestProtocol<TCodec>;
type OutboundOpenInfo = RequestId;
type InboundOpenInfo = RequestId;
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
let (rq_send, rq_recv) = oneshot::channel();
let (rs_send, rs_recv) = oneshot::channel();
let request_id = RequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed));
let proto = ResponseProtocol {
protocols: self.inbound_protocols.clone(),
codec: self.codec.clone(),
request_sender: rq_send,
response_receiver: rs_recv,
request_id
};
self.inbound.push(rq_recv.map_ok(move |rq| (rq, rs_send)).boxed());
SubstreamProtocol::new(proto, request_id).with_timeout(self.substream_timeout)
}
fn inject_fully_negotiated_inbound(
&mut self,
sent: bool,
request_id: RequestId
) {
if sent {
self.pending_events.push_back(
RequestResponseHandlerEvent::ResponseSent(request_id))
} else {
self.pending_events.push_back(
RequestResponseHandlerEvent::ResponseOmission(request_id))
}
}
fn inject_fully_negotiated_outbound(
&mut self,
response: TCodec::Response,
request_id: RequestId,
) {
self.pending_events.push_back(
RequestResponseHandlerEvent::Response {
request_id, response
});
}
fn inject_event(&mut self, request: Self::InEvent) {
self.keep_alive = KeepAlive::Yes;
self.outbound.push_back(request);
}
fn inject_dial_upgrade_error(
&mut self,
info: RequestId,
error: ProtocolsHandlerUpgrErr<io::Error>,
) {
match error {
ProtocolsHandlerUpgrErr::Timeout => {
self.pending_events.push_back(
RequestResponseHandlerEvent::OutboundTimeout(info));
}
ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => {
self.pending_events.push_back(
RequestResponseHandlerEvent::OutboundUnsupportedProtocols(info));
}
_ => {
self.pending_error = Some(error);
}
}
}
fn inject_listen_upgrade_error(
&mut self,
info: RequestId,
error: ProtocolsHandlerUpgrErr<io::Error>
) {
match error {
ProtocolsHandlerUpgrErr::Timeout => {
self.pending_events.push_back(RequestResponseHandlerEvent::InboundTimeout(info))
}
ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => {
self.pending_events.push_back(
RequestResponseHandlerEvent::InboundUnsupportedProtocols(info));
}
_ => {
self.pending_error = Some(error);
}
}
}
fn connection_keep_alive(&self) -> KeepAlive {
self.keep_alive
}
fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<
ProtocolsHandlerEvent<RequestProtocol<TCodec>, RequestId, Self::OutEvent, Self::Error>,
> {
if let Some(err) = self.pending_error.take() {
return Poll::Ready(ProtocolsHandlerEvent::Close(err))
}
if let Some(event) = self.pending_events.pop_front() {
return Poll::Ready(ProtocolsHandlerEvent::Custom(event))
} else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
self.pending_events.shrink_to_fit();
}
while let Poll::Ready(Some(result)) = self.inbound.poll_next_unpin(cx) {
match result {
Ok(((id, rq), rs_sender)) => {
self.keep_alive = KeepAlive::Yes;
return Poll::Ready(ProtocolsHandlerEvent::Custom(
RequestResponseHandlerEvent::Request {
request_id: id, request: rq, sender: rs_sender
}))
}
Err(oneshot::Canceled) => {
}
}
}
if let Some(request) = self.outbound.pop_front() {
let info = request.request_id;
return Poll::Ready(
ProtocolsHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(request, info)
.with_timeout(self.substream_timeout)
},
)
}
debug_assert!(self.outbound.is_empty());
if self.outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
self.outbound.shrink_to_fit();
}
if self.inbound.is_empty() && self.keep_alive.is_yes() {
let until = Instant::now() + self.substream_timeout + self.keep_alive_timeout;
self.keep_alive = KeepAlive::Until(until);
}
Poll::Pending
}
}