1pub mod handler;
2
3pub use handler::Handler;
4
5use std::{
6 collections::{HashMap, HashSet, VecDeque},
7 task::{Context, Poll},
8};
9
10use smallvec::SmallVec;
11use volans_core::{PeerId, Url};
12use volans_swarm::{
13 BehaviorEvent, ConnectionDenied, ConnectionId, DialOpts, NetworkBehavior,
14 NetworkOutgoingBehavior, THandlerAction, THandlerEvent,
15 behavior::NotifyHandler,
16 error::{ConnectionError, DialError},
17};
18
19use crate::{Codec, Config, OutboundFailure, RequestId, client::handler::OutboundRequest};
20
21pub struct Behavior<TCodec>
22where
23 TCodec: Codec + Clone + Send + 'static,
24{
25 clients: HashMap<PeerId, SmallVec<[ConnectionId; 2]>>,
26 codec: TCodec,
27 config: Config,
28 pending_event: VecDeque<BehaviorEvent<Event<TCodec::Response>, THandlerAction<Self>>>,
29 pending_response: HashSet<RequestId>,
30 pending_requests: HashMap<PeerId, SmallVec<[OutboundRequest<TCodec>; 10]>>,
31 pending_dial: HashSet<PeerId>,
32}
33
34impl<TCodec> Behavior<TCodec>
35where
36 TCodec: Codec + Clone + Send + 'static,
37{
38 pub fn with_codec(codec: TCodec, config: Config) -> Self {
39 Self {
40 clients: HashMap::new(),
41 codec,
42 config,
43 pending_event: VecDeque::new(),
44 pending_response: HashSet::new(),
45 pending_requests: HashMap::new(),
46 pending_dial: HashSet::new(),
47 }
48 }
49
50 pub fn send_request(
51 &mut self,
52 peer_id: PeerId,
53 protocol: TCodec::Protocol,
54 request: TCodec::Request,
55 ) -> RequestId {
56 let request_id = RequestId::next();
57 let request = OutboundRequest {
58 request_id,
59 request,
60 protocol,
61 };
62 if let Some(request) = self.try_send_request(&peer_id, request) {
63 self.pending_dial.insert(peer_id);
64 self.pending_requests
65 .entry(peer_id)
66 .or_default()
67 .push(request);
68 }
69 request_id
70 }
71
72 fn remove_pending_response(&mut self, request_id: RequestId) -> bool {
74 self.pending_response.remove(&request_id)
75 }
76
77 fn try_send_request(
78 &mut self,
79 peer_id: &PeerId,
80 request: OutboundRequest<TCodec>,
81 ) -> Option<OutboundRequest<TCodec>> {
82 if let Some(connections) = self.clients.get_mut(peer_id) {
83 if connections.is_empty() {
84 return Some(request);
85 }
86 let index = request.request_id.0 & connections.len();
87 let connection_id = &mut connections[index];
88 self.pending_response.insert(request.request_id);
89 self.pending_event.push_back(BehaviorEvent::HandlerAction {
90 peer_id: *peer_id,
91 handler: NotifyHandler::One(*connection_id),
92 action: request,
93 });
94 None
95 } else {
96 Some(request)
97 }
98 }
99}
100
101#[derive(Debug)]
102pub enum Event<TResponse> {
103 Response {
104 peer_id: PeerId,
105 connection_id: ConnectionId,
106 request_id: RequestId,
107 response: TResponse,
108 },
109 Failure {
110 peer_id: PeerId,
111 connection_id: ConnectionId,
112 request_id: RequestId,
113 cause: OutboundFailure,
114 },
115 ResponseSent {
116 peer_id: PeerId,
117 connection_id: ConnectionId,
118 request_id: RequestId,
119 },
120}
121
122impl<TCodec> NetworkBehavior for Behavior<TCodec>
123where
124 TCodec: Codec + Clone + Send + 'static,
125{
126 type ConnectionHandler = Handler<TCodec>;
127 type Event = Event<TCodec::Response>;
128 fn on_connection_handler_event(
129 &mut self,
130 id: ConnectionId,
131 peer_id: PeerId,
132 event: THandlerEvent<Self>,
133 ) {
134 match event {
135 handler::Event::Response {
136 request_id,
137 response,
138 } => {
139 let removed = self.remove_pending_response(request_id);
140 debug_assert!(removed, "Response for unknown request: {request_id}");
141 self.pending_event
142 .push_back(BehaviorEvent::Behavior(Event::Response {
143 peer_id,
144 connection_id: id,
145 request_id,
146 response,
147 }));
148 }
149 handler::Event::Unsupported(request_id) => {
150 let removed = self.remove_pending_response(request_id);
151 debug_assert!(removed, "Response for unknown request: {request_id}");
152 self.pending_event
153 .push_back(BehaviorEvent::Behavior(Event::Failure {
154 peer_id,
155 connection_id: id,
156 request_id,
157 cause: OutboundFailure::UnsupportedProtocols,
158 }));
159 }
160 handler::Event::StreamError { request_id, error } => {
161 let removed = self.remove_pending_response(request_id);
162 debug_assert!(removed, "Response for unknown request: {request_id}");
163 self.pending_event
164 .push_back(BehaviorEvent::Behavior(Event::Failure {
165 peer_id,
166 connection_id: id,
167 request_id,
168 cause: error.into(),
169 }));
170 }
171 handler::Event::Timeout(request_id) => {
172 let removed = self.remove_pending_response(request_id);
173 debug_assert!(removed, "Response for unknown request: {request_id}");
174 self.pending_event
175 .push_back(BehaviorEvent::Behavior(Event::Failure {
176 peer_id,
177 connection_id: id,
178 request_id,
179 cause: OutboundFailure::Timeout,
180 }));
181 }
182 }
183 }
184
185 fn poll(
186 &mut self,
187 _cx: &mut Context<'_>,
188 ) -> Poll<BehaviorEvent<Self::Event, THandlerAction<Self>>> {
189 if let Some(event) = self.pending_event.pop_front() {
190 return Poll::Ready(event);
191 }
192 Poll::Pending
193 }
194}
195
196impl<TCodec> NetworkOutgoingBehavior for Behavior<TCodec>
197where
198 TCodec: Codec + Clone + Send + 'static,
199{
200 fn handle_established_connection(
201 &mut self,
202 _id: ConnectionId,
203 _peer_id: PeerId,
204 _addr: &Url,
205 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
206 let handler = handler::Handler::new(self.codec.clone(), self.config.request_timeout);
207 Ok(handler)
208 }
209
210 fn on_connection_established(&mut self, id: ConnectionId, peer_id: PeerId, _addr: &Url) {
211 self.clients.entry(peer_id).or_default().push(id);
212 }
213
214 fn on_connection_closed(
215 &mut self,
216 id: ConnectionId,
217 peer_id: PeerId,
218 _addr: &Url,
219 _reason: Option<&ConnectionError>,
220 ) {
221 self.clients
222 .entry(peer_id)
223 .or_default()
224 .retain(|x| *x != id);
225 if self
226 .clients
227 .get(&peer_id)
228 .map(|v| v.is_empty())
229 .unwrap_or(false)
230 {
231 self.clients.remove(&peer_id);
232 }
233 }
234
235 fn on_dial_failure(
236 &mut self,
237 id: ConnectionId,
238 peer_id: Option<PeerId>,
239 _addr: Option<&Url>,
240 _error: &DialError,
241 ) {
242 if let Some(peer) = peer_id {
243 if let Some(pending) = self.pending_requests.remove(&peer) {
244 for request in pending {
245 let event = Event::Failure {
246 peer_id: peer,
247 connection_id: id,
248 request_id: request.request_id,
249 cause: OutboundFailure::DialFailure,
250 };
251 self.pending_event.push_back(BehaviorEvent::Behavior(event));
252 }
253 }
254 }
255 }
256
257 fn poll_dial(&mut self, _cx: &mut Context<'_>) -> Poll<DialOpts> {
258 if let Some(peer_id) = self.pending_dial.iter().next().cloned() {
259 self.pending_dial.remove(&peer_id);
260 Poll::Ready(DialOpts::new(None, Some(peer_id)))
261 } else {
262 Poll::Pending
263 }
264 }
265}