// libp2p_swarm/protocols_handler/one_shot.rs

use crate::protocols_handler::{
    KeepAlive,
    ProtocolsHandler,
    ProtocolsHandlerEvent,
    ProtocolsHandlerUpgrErr,
    SubstreamProtocol
};
use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};

use smallvec::SmallVec;
use std::{convert::TryFrom, error, task::Context, task::Poll, time::Duration};
use wasm_timer::Instant;
33
34pub struct OneShotHandler<TInbound, TOutbound, TEvent>
37where
38 TOutbound: OutboundUpgradeSend,
39{
40 listen_protocol: SubstreamProtocol<TInbound, ()>,
42 pending_error: Option<ProtocolsHandlerUpgrErr<<TOutbound as OutboundUpgradeSend>::Error>>,
44 events_out: SmallVec<[TEvent; 4]>,
46 dial_queue: SmallVec<[TOutbound; 4]>,
48 dial_negotiated: u32,
50 keep_alive: KeepAlive,
52 config: OneShotHandlerConfig,
54}
55
56impl<TInbound, TOutbound, TEvent>
57 OneShotHandler<TInbound, TOutbound, TEvent>
58where
59 TOutbound: OutboundUpgradeSend,
60{
61 pub fn new(
63 listen_protocol: SubstreamProtocol<TInbound, ()>,
64 config: OneShotHandlerConfig,
65 ) -> Self {
66 OneShotHandler {
67 listen_protocol,
68 pending_error: None,
69 events_out: SmallVec::new(),
70 dial_queue: SmallVec::new(),
71 dial_negotiated: 0,
72 keep_alive: KeepAlive::Yes,
73 config,
74 }
75 }
76
77 pub fn pending_requests(&self) -> u32 {
79 self.dial_negotiated + self.dial_queue.len() as u32
80 }
81
82 pub fn listen_protocol_ref(&self) -> &SubstreamProtocol<TInbound, ()> {
87 &self.listen_protocol
88 }
89
90 pub fn listen_protocol_mut(&mut self) -> &mut SubstreamProtocol<TInbound, ()> {
95 &mut self.listen_protocol
96 }
97
98 pub fn send_request(&mut self, upgrade: TOutbound) {
100 self.keep_alive = KeepAlive::Yes;
101 self.dial_queue.push(upgrade);
102 }
103}
104
105impl<TInbound, TOutbound, TEvent> Default
106 for OneShotHandler<TInbound, TOutbound, TEvent>
107where
108 TOutbound: OutboundUpgradeSend,
109 TInbound: InboundUpgradeSend + Default,
110{
111 fn default() -> Self {
112 OneShotHandler::new(
113 SubstreamProtocol::new(Default::default(), ()),
114 OneShotHandlerConfig::default()
115 )
116 }
117}
118
119impl<TInbound, TOutbound, TEvent> ProtocolsHandler for OneShotHandler<TInbound, TOutbound, TEvent>
120where
121 TInbound: InboundUpgradeSend + Send + 'static,
122 TOutbound: OutboundUpgradeSend,
123 TInbound::Output: Into<TEvent>,
124 TOutbound::Output: Into<TEvent>,
125 TOutbound::Error: error::Error + Send + 'static,
126 SubstreamProtocol<TInbound, ()>: Clone,
127 TEvent: Send + 'static,
128{
129 type InEvent = TOutbound;
130 type OutEvent = TEvent;
131 type Error = ProtocolsHandlerUpgrErr<
132 <Self::OutboundProtocol as OutboundUpgradeSend>::Error,
133 >;
134 type InboundProtocol = TInbound;
135 type OutboundProtocol = TOutbound;
136 type OutboundOpenInfo = ();
137 type InboundOpenInfo = ();
138
139 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
140 self.listen_protocol.clone()
141 }
142
143 fn inject_fully_negotiated_inbound(
144 &mut self,
145 out: <Self::InboundProtocol as InboundUpgradeSend>::Output,
146 (): Self::InboundOpenInfo
147 ) {
148 if !self.keep_alive.is_yes() {
150 self.keep_alive = KeepAlive::Until(Instant::now() + self.config.keep_alive_timeout);
151 }
152
153 self.events_out.push(out.into());
154 }
155
156 fn inject_fully_negotiated_outbound(
157 &mut self,
158 out: <Self::OutboundProtocol as OutboundUpgradeSend>::Output,
159 _: Self::OutboundOpenInfo,
160 ) {
161 self.dial_negotiated -= 1;
162 self.events_out.push(out.into());
163 }
164
165 fn inject_event(&mut self, event: Self::InEvent) {
166 self.send_request(event);
167 }
168
169 fn inject_dial_upgrade_error(
170 &mut self,
171 _info: Self::OutboundOpenInfo,
172 error: ProtocolsHandlerUpgrErr<
173 <Self::OutboundProtocol as OutboundUpgradeSend>::Error,
174 >,
175 ) {
176 if self.pending_error.is_none() {
177 self.pending_error = Some(error);
178 }
179 }
180
181 fn connection_keep_alive(&self) -> KeepAlive {
182 self.keep_alive
183 }
184
185 fn poll(
186 &mut self,
187 _: &mut Context<'_>,
188 ) -> Poll<
189 ProtocolsHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::OutEvent, Self::Error>,
190 > {
191 if let Some(err) = self.pending_error.take() {
192 return Poll::Ready(ProtocolsHandlerEvent::Close(err))
193 }
194
195 if !self.events_out.is_empty() {
196 return Poll::Ready(ProtocolsHandlerEvent::Custom(
197 self.events_out.remove(0)
198 ));
199 } else {
200 self.events_out.shrink_to_fit();
201 }
202
203 if !self.dial_queue.is_empty() {
204 if self.dial_negotiated < self.config.max_dial_negotiated {
205 self.dial_negotiated += 1;
206 let upgrade = self.dial_queue.remove(0);
207 return Poll::Ready(
208 ProtocolsHandlerEvent::OutboundSubstreamRequest {
209 protocol: SubstreamProtocol::new(upgrade, ())
210 .with_timeout(self.config.outbound_substream_timeout)
211 },
212 );
213 }
214 } else {
215 self.dial_queue.shrink_to_fit();
216
217 if self.dial_negotiated == 0 && self.keep_alive.is_yes() {
218 self.keep_alive = KeepAlive::Until(Instant::now() + self.config.keep_alive_timeout);
219 }
220 }
221
222 Poll::Pending
223 }
224}
225
/// Configuration parameters for the `OneShotHandler`.
///
/// All fields are plain `Copy` values, so the config itself is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OneShotHandlerConfig {
    /// How long the connection is kept alive once the handler has no outbound work left.
    pub keep_alive_timeout: Duration,
    /// Timeout applied to each outbound substream the handler requests.
    pub outbound_substream_timeout: Duration,
    /// Maximum number of outbound substreams being negotiated at the same time.
    pub max_dial_negotiated: u32,
}
236
237impl Default for OneShotHandlerConfig {
238 fn default() -> Self {
239 OneShotHandlerConfig {
240 keep_alive_timeout: Duration::from_secs(10),
241 outbound_substream_timeout: Duration::from_secs(10),
242 max_dial_negotiated: 8,
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    use futures::executor::block_on;
    use futures::future::poll_fn;
    use libp2p_core::upgrade::DeniedUpgrade;
    use void::Void;

    /// A handler with nothing to send or receive must stop reporting `KeepAlive::Yes`
    /// and fall back to a timed keep-alive.
    #[test]
    fn do_not_keep_idle_connection_alive() {
        let listen = SubstreamProtocol::new(DeniedUpgrade, ());
        let mut handler: OneShotHandler<_, DeniedUpgrade, Void> =
            OneShotHandler::new(listen, OneShotHandlerConfig::default());

        // Keep polling until the handler has nothing more to report.
        block_on(poll_fn(|cx| {
            while handler.poll(cx).is_ready() {}
            Poll::Ready(())
        }));

        assert!(matches!(handler.connection_keep_alive(), KeepAlive::Until(_)));
    }
}