ant_libp2p_swarm/handler/
one_shot.rs1use std::{
22 error,
23 fmt::Debug,
24 task::{Context, Poll},
25 time::Duration,
26};
27
28use smallvec::SmallVec;
29
30use crate::{
31 handler::{
32 ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
33 FullyNegotiatedInbound, FullyNegotiatedOutbound, SubstreamProtocol,
34 },
35 upgrade::{InboundUpgradeSend, OutboundUpgradeSend},
36 StreamUpgradeError,
37};
38
39pub struct OneShotHandler<TInbound, TOutbound, TEvent>
42where
43 TOutbound: OutboundUpgradeSend,
44{
45 listen_protocol: SubstreamProtocol<TInbound, ()>,
47 events_out: SmallVec<[Result<TEvent, StreamUpgradeError<TOutbound::Error>>; 4]>,
49 dial_queue: SmallVec<[TOutbound; 4]>,
51 dial_negotiated: u32,
53 config: OneShotHandlerConfig,
55}
56
57impl<TInbound, TOutbound, TEvent> OneShotHandler<TInbound, TOutbound, TEvent>
58where
59 TOutbound: OutboundUpgradeSend,
60{
61 pub fn new(
63 listen_protocol: SubstreamProtocol<TInbound, ()>,
64 config: OneShotHandlerConfig,
65 ) -> Self {
66 OneShotHandler {
67 listen_protocol,
68 events_out: SmallVec::new(),
69 dial_queue: SmallVec::new(),
70 dial_negotiated: 0,
71 config,
72 }
73 }
74
75 pub fn pending_requests(&self) -> u32 {
77 self.dial_negotiated + self.dial_queue.len() as u32
78 }
79
80 pub fn listen_protocol_ref(&self) -> &SubstreamProtocol<TInbound, ()> {
85 &self.listen_protocol
86 }
87
88 pub fn listen_protocol_mut(&mut self) -> &mut SubstreamProtocol<TInbound, ()> {
93 &mut self.listen_protocol
94 }
95
96 pub fn send_request(&mut self, upgrade: TOutbound) {
98 self.dial_queue.push(upgrade);
99 }
100}
101
102impl<TInbound, TOutbound, TEvent> Default for OneShotHandler<TInbound, TOutbound, TEvent>
103where
104 TOutbound: OutboundUpgradeSend,
105 TInbound: InboundUpgradeSend + Default,
106{
107 fn default() -> Self {
108 OneShotHandler::new(
109 SubstreamProtocol::new(Default::default(), ()),
110 OneShotHandlerConfig::default(),
111 )
112 }
113}
114
115impl<TInbound, TOutbound, TEvent> ConnectionHandler for OneShotHandler<TInbound, TOutbound, TEvent>
116where
117 TInbound: InboundUpgradeSend + Send + 'static,
118 TOutbound: Debug + OutboundUpgradeSend,
119 TInbound::Output: Into<TEvent>,
120 TOutbound::Output: Into<TEvent>,
121 TOutbound::Error: error::Error + Send + 'static,
122 SubstreamProtocol<TInbound, ()>: Clone,
123 TEvent: Debug + Send + 'static,
124{
125 type FromBehaviour = TOutbound;
126 type ToBehaviour = Result<TEvent, StreamUpgradeError<TOutbound::Error>>;
127 type InboundProtocol = TInbound;
128 type OutboundProtocol = TOutbound;
129 type OutboundOpenInfo = ();
130 type InboundOpenInfo = ();
131
132 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
133 self.listen_protocol.clone()
134 }
135
136 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
137 self.send_request(event);
138 }
139
140 fn poll(
141 &mut self,
142 _: &mut Context<'_>,
143 ) -> Poll<
144 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
145 > {
146 if !self.events_out.is_empty() {
147 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
148 self.events_out.remove(0),
149 ));
150 } else {
151 self.events_out.shrink_to_fit();
152 }
153
154 if !self.dial_queue.is_empty() {
155 if self.dial_negotiated < self.config.max_dial_negotiated {
156 self.dial_negotiated += 1;
157 let upgrade = self.dial_queue.remove(0);
158 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
159 protocol: SubstreamProtocol::new(upgrade, ())
160 .with_timeout(self.config.outbound_substream_timeout),
161 });
162 }
163 } else {
164 self.dial_queue.shrink_to_fit();
165 }
166
167 Poll::Pending
168 }
169
170 fn on_connection_event(
171 &mut self,
172 event: ConnectionEvent<
173 Self::InboundProtocol,
174 Self::OutboundProtocol,
175 Self::InboundOpenInfo,
176 Self::OutboundOpenInfo,
177 >,
178 ) {
179 match event {
180 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
181 protocol: out,
182 ..
183 }) => {
184 self.events_out.push(Ok(out.into()));
185 }
186 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
187 protocol: out,
188 ..
189 }) => {
190 self.dial_negotiated -= 1;
191 self.events_out.push(Ok(out.into()));
192 }
193 ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
194 self.events_out.push(Err(error));
195 }
196 ConnectionEvent::AddressChange(_)
197 | ConnectionEvent::ListenUpgradeError(_)
198 | ConnectionEvent::LocalProtocolsChange(_)
199 | ConnectionEvent::RemoteProtocolsChange(_) => {}
200 }
201 }
202}
203
/// Tunables for `OneShotHandler`.
#[derive(Debug)]
pub struct OneShotHandlerConfig {
    /// Timeout attached to every outbound substream request.
    pub outbound_substream_timeout: Duration,
    /// Upper bound on outbound substream requests in flight at the same time.
    pub max_dial_negotiated: u32,
}

impl Default for OneShotHandlerConfig {
    /// Defaults to a 10-second outbound timeout and at most 8 concurrent
    /// outbound requests.
    fn default() -> Self {
        let outbound_substream_timeout = Duration::from_secs(10);
        let max_dial_negotiated = 8;
        Self {
            outbound_substream_timeout,
            max_dial_negotiated,
        }
    }
}
221
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use futures::{executor::block_on, future::poll_fn};
    use ant_libp2p_core::upgrade::DeniedUpgrade;

    use super::*;

    /// A handler that has been polled until it is pending, with nothing ever
    /// queued, must not ask for the connection to be kept alive.
    #[test]
    fn do_not_keep_idle_connection_alive() {
        let listen = SubstreamProtocol::new(DeniedUpgrade {}, ());
        let mut handler: OneShotHandler<_, DeniedUpgrade, Infallible> =
            OneShotHandler::new(listen, OneShotHandlerConfig::default());

        // Drain the handler: keep polling while it still has events to report.
        block_on(poll_fn(|cx| {
            while handler.poll(cx).is_ready() {}
            Poll::Ready(())
        }));

        assert!(!handler.connection_keep_alive());
    }
}