1mod protocol;
22
23use crate::{EMPTY_QUEUE_SHRINK_THRESHOLD, RequestId};
24use crate::codec::RequestResponseCodec;
25
26pub use protocol::{RequestProtocol, ResponseProtocol, ProtocolSupport};
27
28use futures::{
29 channel::oneshot,
30 future::BoxFuture,
31 prelude::*,
32 stream::FuturesUnordered
33};
34use tetsy_libp2p_core::{
35 upgrade::{UpgradeError, NegotiationError},
36};
37use tetsy_libp2p_swarm::{
38 SubstreamProtocol,
39 protocols_handler::{
40 KeepAlive,
41 ProtocolsHandler,
42 ProtocolsHandlerEvent,
43 ProtocolsHandlerUpgrErr,
44 }
45};
46use smallvec::SmallVec;
47use std::{
48 collections::VecDeque,
49 io,
50 sync::{atomic::{AtomicU64, Ordering}, Arc},
51 time::Duration,
52 task::{Context, Poll}
53};
54use wasm_timer::Instant;
55
56#[doc(hidden)]
58pub struct RequestResponseHandler<TCodec>
59where
60 TCodec: RequestResponseCodec,
61{
62 inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
64 codec: TCodec,
66 keep_alive_timeout: Duration,
69 substream_timeout: Duration,
72 keep_alive: KeepAlive,
74 pending_error: Option<ProtocolsHandlerUpgrErr<io::Error>>,
76 pending_events: VecDeque<RequestResponseHandlerEvent<TCodec>>,
78 outbound: VecDeque<RequestProtocol<TCodec>>,
80 inbound: FuturesUnordered<BoxFuture<'static,
82 Result<
83 ((RequestId, TCodec::Request), oneshot::Sender<TCodec::Response>),
84 oneshot::Canceled
85 >>>,
86 inbound_request_id: Arc<AtomicU64>
87}
88
89impl<TCodec> RequestResponseHandler<TCodec>
90where
91 TCodec: RequestResponseCodec,
92{
93 pub(super) fn new(
94 inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
95 codec: TCodec,
96 keep_alive_timeout: Duration,
97 substream_timeout: Duration,
98 inbound_request_id: Arc<AtomicU64>
99 ) -> Self {
100 Self {
101 inbound_protocols,
102 codec,
103 keep_alive: KeepAlive::Yes,
104 keep_alive_timeout,
105 substream_timeout,
106 outbound: VecDeque::new(),
107 inbound: FuturesUnordered::new(),
108 pending_events: VecDeque::new(),
109 pending_error: None,
110 inbound_request_id
111 }
112 }
113}
114
115#[doc(hidden)]
117#[derive(Debug)]
118pub enum RequestResponseHandlerEvent<TCodec>
119where
120 TCodec: RequestResponseCodec
121{
122 Request {
124 request_id: RequestId,
125 request: TCodec::Request,
126 sender: oneshot::Sender<TCodec::Response>
127 },
128 Response {
130 request_id: RequestId,
131 response: TCodec::Response
132 },
133 ResponseSent(RequestId),
135 ResponseOmission(RequestId),
138 OutboundTimeout(RequestId),
141 OutboundUnsupportedProtocols(RequestId),
143 InboundTimeout(RequestId),
146 InboundUnsupportedProtocols(RequestId),
148}
149
150impl<TCodec> ProtocolsHandler for RequestResponseHandler<TCodec>
151where
152 TCodec: RequestResponseCodec + Send + Clone + 'static,
153{
154 type InEvent = RequestProtocol<TCodec>;
155 type OutEvent = RequestResponseHandlerEvent<TCodec>;
156 type Error = ProtocolsHandlerUpgrErr<io::Error>;
157 type InboundProtocol = ResponseProtocol<TCodec>;
158 type OutboundProtocol = RequestProtocol<TCodec>;
159 type OutboundOpenInfo = RequestId;
160 type InboundOpenInfo = RequestId;
161
162 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
163 let (rq_send, rq_recv) = oneshot::channel();
166
167 let (rs_send, rs_recv) = oneshot::channel();
170
171 let request_id = RequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed));
172
173 let proto = ResponseProtocol {
180 protocols: self.inbound_protocols.clone(),
181 codec: self.codec.clone(),
182 request_sender: rq_send,
183 response_receiver: rs_recv,
184 request_id
185 };
186
187 self.inbound.push(rq_recv.map_ok(move |rq| (rq, rs_send)).boxed());
191
192 SubstreamProtocol::new(proto, request_id).with_timeout(self.substream_timeout)
193 }
194
195 fn inject_fully_negotiated_inbound(
196 &mut self,
197 sent: bool,
198 request_id: RequestId
199 ) {
200 if sent {
201 self.pending_events.push_back(
202 RequestResponseHandlerEvent::ResponseSent(request_id))
203 } else {
204 self.pending_events.push_back(
205 RequestResponseHandlerEvent::ResponseOmission(request_id))
206 }
207 }
208
209 fn inject_fully_negotiated_outbound(
210 &mut self,
211 response: TCodec::Response,
212 request_id: RequestId,
213 ) {
214 self.pending_events.push_back(
215 RequestResponseHandlerEvent::Response {
216 request_id, response
217 });
218 }
219
220 fn inject_event(&mut self, request: Self::InEvent) {
221 self.keep_alive = KeepAlive::Yes;
222 self.outbound.push_back(request);
223 }
224
225 fn inject_dial_upgrade_error(
226 &mut self,
227 info: RequestId,
228 error: ProtocolsHandlerUpgrErr<io::Error>,
229 ) {
230 match error {
231 ProtocolsHandlerUpgrErr::Timeout => {
232 self.pending_events.push_back(
233 RequestResponseHandlerEvent::OutboundTimeout(info));
234 }
235 ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => {
236 self.pending_events.push_back(
242 RequestResponseHandlerEvent::OutboundUnsupportedProtocols(info));
243 }
244 _ => {
245 self.pending_error = Some(error);
248 }
249 }
250 }
251
252 fn inject_listen_upgrade_error(
253 &mut self,
254 info: RequestId,
255 error: ProtocolsHandlerUpgrErr<io::Error>
256 ) {
257 match error {
258 ProtocolsHandlerUpgrErr::Timeout => {
259 self.pending_events.push_back(RequestResponseHandlerEvent::InboundTimeout(info))
260 }
261 ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => {
262 self.pending_events.push_back(
268 RequestResponseHandlerEvent::InboundUnsupportedProtocols(info));
269 }
270 _ => {
271 self.pending_error = Some(error);
274 }
275 }
276 }
277
278 fn connection_keep_alive(&self) -> KeepAlive {
279 self.keep_alive
280 }
281
282 fn poll(
283 &mut self,
284 cx: &mut Context<'_>,
285 ) -> Poll<
286 ProtocolsHandlerEvent<RequestProtocol<TCodec>, RequestId, Self::OutEvent, Self::Error>,
287 > {
288 if let Some(err) = self.pending_error.take() {
290 return Poll::Ready(ProtocolsHandlerEvent::Close(err))
292 }
293
294 if let Some(event) = self.pending_events.pop_front() {
296 return Poll::Ready(ProtocolsHandlerEvent::Custom(event))
297 } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
298 self.pending_events.shrink_to_fit();
299 }
300
301 while let Poll::Ready(Some(result)) = self.inbound.poll_next_unpin(cx) {
303 match result {
304 Ok(((id, rq), rs_sender)) => {
305 self.keep_alive = KeepAlive::Yes;
307 return Poll::Ready(ProtocolsHandlerEvent::Custom(
308 RequestResponseHandlerEvent::Request {
309 request_id: id, request: rq, sender: rs_sender
310 }))
311 }
312 Err(oneshot::Canceled) => {
313 }
317 }
318 }
319
320 if let Some(request) = self.outbound.pop_front() {
322 let info = request.request_id;
323 return Poll::Ready(
324 ProtocolsHandlerEvent::OutboundSubstreamRequest {
325 protocol: SubstreamProtocol::new(request, info)
326 .with_timeout(self.substream_timeout)
327 },
328 )
329 }
330
331 debug_assert!(self.outbound.is_empty());
332
333 if self.outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
334 self.outbound.shrink_to_fit();
335 }
336
337 if self.inbound.is_empty() && self.keep_alive.is_yes() {
338 let until = Instant::now() + self.substream_timeout + self.keep_alive_timeout;
342 self.keep_alive = KeepAlive::Until(until);
343 }
344
345 Poll::Pending
346 }
347}
348