1pub mod handler;
2
3pub use handler::Handler;
4
5use std::{
6 collections::{HashMap, HashSet, VecDeque},
7 task::{Context, Poll},
8};
9
10use smallvec::SmallVec;
11use volans_core::{Multiaddr, PeerId};
12use volans_swarm::{
13 BehaviorEvent, ConnectionDenied, ConnectionId, DialOpts, NetworkBehavior,
14 NetworkOutgoingBehavior, THandlerAction, THandlerEvent,
15 behavior::NotifyHandler,
16 error::{ConnectionError, DialError},
17};
18
19use crate::{Codec, Config, OutboundFailure, RequestId, client::handler::OutboundRequest};
20
21pub struct Behavior<TCodec>
22where
23 TCodec: Codec + Clone + Send + 'static,
24{
25 clients: HashMap<PeerId, SmallVec<[ConnectionId; 2]>>,
26 codec: TCodec,
27 config: Config,
28 pending_event: VecDeque<BehaviorEvent<Event<TCodec::Response>, THandlerAction<Self>>>,
29 pending_response: HashSet<RequestId>,
30 pending_requests: HashMap<PeerId, SmallVec<[OutboundRequest<TCodec>; 10]>>,
31 pending_dial: HashSet<PeerId>,
32}
33
34impl<TCodec> Behavior<TCodec>
35where
36 TCodec: Codec + Clone + Send + 'static,
37{
38 pub fn with_codec(codec: TCodec, config: Config) -> Self {
39 Self {
40 clients: HashMap::new(),
41 codec,
42 config,
43 pending_event: VecDeque::new(),
44 pending_response: HashSet::new(),
45 pending_requests: HashMap::new(),
46 pending_dial: HashSet::new(),
47 }
48 }
49
50 pub fn send_request(
51 &mut self,
52 peer_id: PeerId,
53 protocol: TCodec::Protocol,
54 request: TCodec::Request,
55 ) -> RequestId {
56 let request_id = RequestId::next();
57 let request = OutboundRequest {
58 request_id,
59 request,
60 protocol,
61 };
62 if let Some(request) = self.try_send_request(&peer_id, request) {
63 self.pending_dial.insert(peer_id);
64 self.pending_requests
65 .entry(peer_id)
66 .or_default()
67 .push(request);
68 }
69 request_id
70 }
71
72 fn remove_pending_response(&mut self, request_id: RequestId) -> bool {
74 self.pending_response.remove(&request_id)
75 }
76
77 fn try_send_request(
78 &mut self,
79 peer_id: &PeerId,
80 request: OutboundRequest<TCodec>,
81 ) -> Option<OutboundRequest<TCodec>> {
82 if let Some(connections) = self.clients.get_mut(peer_id) {
83 if connections.is_empty() {
84 return Some(request);
85 }
86 let index = request.request_id.0 & connections.len();
87 let connection_id = &mut connections[index];
88 self.pending_response.insert(request.request_id);
89 self.pending_event.push_back(BehaviorEvent::HandlerAction {
90 peer_id: *peer_id,
91 handler: NotifyHandler::One(*connection_id),
92 action: request,
93 });
94 None
95 } else {
96 Some(request)
97 }
98 }
99}
100
101#[derive(Debug)]
102pub enum Event<TResponse> {
103 Response {
104 peer_id: PeerId,
105 connection_id: ConnectionId,
106 request_id: RequestId,
107 response: TResponse,
108 },
109 Failure {
110 peer_id: PeerId,
111 connection_id: ConnectionId,
112 request_id: RequestId,
113 cause: OutboundFailure,
114 },
115}
116
117impl<TCodec> NetworkBehavior for Behavior<TCodec>
118where
119 TCodec: Codec + Clone + Send + 'static,
120{
121 type ConnectionHandler = Handler<TCodec>;
122 type Event = Event<TCodec::Response>;
123 fn on_connection_handler_event(
124 &mut self,
125 id: ConnectionId,
126 peer_id: PeerId,
127 event: THandlerEvent<Self>,
128 ) {
129 match event {
130 handler::Event::Response {
131 request_id,
132 response,
133 } => {
134 let removed = self.remove_pending_response(request_id);
135 debug_assert!(removed, "Response for unknown request: {request_id}");
136 self.pending_event
137 .push_back(BehaviorEvent::Behavior(Event::Response {
138 peer_id,
139 connection_id: id,
140 request_id,
141 response,
142 }));
143 }
144 handler::Event::Unsupported(request_id) => {
145 let removed = self.remove_pending_response(request_id);
146 debug_assert!(removed, "Response for unknown request: {request_id}");
147 self.pending_event
148 .push_back(BehaviorEvent::Behavior(Event::Failure {
149 peer_id,
150 connection_id: id,
151 request_id,
152 cause: OutboundFailure::UnsupportedProtocols,
153 }));
154 }
155 handler::Event::StreamError { request_id, error } => {
156 let removed = self.remove_pending_response(request_id);
157 debug_assert!(removed, "Response for unknown request: {request_id}");
158 self.pending_event
159 .push_back(BehaviorEvent::Behavior(Event::Failure {
160 peer_id,
161 connection_id: id,
162 request_id,
163 cause: error.into(),
164 }));
165 }
166 handler::Event::Timeout(request_id) => {
167 let removed = self.remove_pending_response(request_id);
168 debug_assert!(removed, "Response for unknown request: {request_id}");
169 self.pending_event
170 .push_back(BehaviorEvent::Behavior(Event::Failure {
171 peer_id,
172 connection_id: id,
173 request_id,
174 cause: OutboundFailure::Timeout,
175 }));
176 }
177 }
178 }
179
180 fn poll(
181 &mut self,
182 _cx: &mut Context<'_>,
183 ) -> Poll<BehaviorEvent<Self::Event, THandlerAction<Self>>> {
184 if let Some(event) = self.pending_event.pop_front() {
185 return Poll::Ready(event);
186 }
187 Poll::Pending
188 }
189}
190
191impl<TCodec> NetworkOutgoingBehavior for Behavior<TCodec>
192where
193 TCodec: Codec + Clone + Send + 'static,
194{
195 fn handle_established_connection(
196 &mut self,
197 _id: ConnectionId,
198 _peer_id: PeerId,
199 _addr: &Multiaddr,
200 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
201 let handler = handler::Handler::new(self.codec.clone(), self.config.request_timeout);
202 Ok(handler)
203 }
204
205 fn on_connection_established(&mut self, id: ConnectionId, peer_id: PeerId, _addr: &Multiaddr) {
206 self.clients.entry(peer_id).or_default().push(id);
207 }
208
209 fn on_connection_closed(
210 &mut self,
211 id: ConnectionId,
212 peer_id: PeerId,
213 _addr: &Multiaddr,
214 _reason: Option<&ConnectionError>,
215 ) {
216 self.clients
217 .entry(peer_id)
218 .or_default()
219 .retain(|x| *x != id);
220 if self
221 .clients
222 .get(&peer_id)
223 .map(|v| v.is_empty())
224 .unwrap_or(false)
225 {
226 self.clients.remove(&peer_id);
227 }
228 }
229
230 fn on_dial_failure(
231 &mut self,
232 id: ConnectionId,
233 peer_id: Option<PeerId>,
234 _addr: Option<&Multiaddr>,
235 _error: &DialError,
236 ) {
237 if let Some(peer) = peer_id {
238 if let Some(pending) = self.pending_requests.remove(&peer) {
239 for request in pending {
240 let event = Event::Failure {
241 peer_id: peer,
242 connection_id: id,
243 request_id: request.request_id,
244 cause: OutboundFailure::DialFailure,
245 };
246 self.pending_event.push_back(BehaviorEvent::Behavior(event));
247 }
248 }
249 }
250 }
251
252 fn poll_dial(&mut self, _cx: &mut Context<'_>) -> Poll<DialOpts> {
253 if let Some(peer_id) = self.pending_dial.iter().next().cloned() {
254 self.pending_dial.remove(&peer_id);
255 Poll::Ready(DialOpts::new(None, Some(peer_id)))
256 } else {
257 Poll::Pending
258 }
259 }
260}