// volans_request/client.rs

1pub mod handler;
2
3pub use handler::Handler;
4
5use std::{
6    collections::{HashMap, HashSet, VecDeque},
7    task::{Context, Poll},
8};
9
10use smallvec::SmallVec;
11use volans_core::{Multiaddr, PeerId};
12use volans_swarm::{
13    BehaviorEvent, ConnectionDenied, ConnectionId, DialOpts, NetworkBehavior,
14    NetworkOutgoingBehavior, THandlerAction, THandlerEvent,
15    behavior::NotifyHandler,
16    error::{ConnectionError, DialError},
17};
18
19use crate::{Codec, Config, OutboundFailure, RequestId, client::handler::OutboundRequest};
20
21pub struct Behavior<TCodec>
22where
23    TCodec: Codec + Clone + Send + 'static,
24{
25    clients: HashMap<PeerId, SmallVec<[ConnectionId; 2]>>,
26    codec: TCodec,
27    config: Config,
28    pending_event: VecDeque<BehaviorEvent<Event<TCodec::Response>, THandlerAction<Self>>>,
29    pending_response: HashSet<RequestId>,
30    pending_requests: HashMap<PeerId, SmallVec<[OutboundRequest<TCodec>; 10]>>,
31    pending_dial: HashSet<PeerId>,
32}
33
34impl<TCodec> Behavior<TCodec>
35where
36    TCodec: Codec + Clone + Send + 'static,
37{
38    pub fn with_codec(codec: TCodec, config: Config) -> Self {
39        Self {
40            clients: HashMap::new(),
41            codec,
42            config,
43            pending_event: VecDeque::new(),
44            pending_response: HashSet::new(),
45            pending_requests: HashMap::new(),
46            pending_dial: HashSet::new(),
47        }
48    }
49
50    pub fn send_request(
51        &mut self,
52        peer_id: PeerId,
53        protocol: TCodec::Protocol,
54        request: TCodec::Request,
55    ) -> RequestId {
56        let request_id = RequestId::next();
57        let request = OutboundRequest {
58            request_id,
59            request,
60            protocol,
61        };
62        if let Some(request) = self.try_send_request(&peer_id, request) {
63            self.pending_dial.insert(peer_id);
64            self.pending_requests
65                .entry(peer_id)
66                .or_default()
67                .push(request);
68        }
69        request_id
70    }
71
72    // 移除Pending Response
73    fn remove_pending_response(&mut self, request_id: RequestId) -> bool {
74        self.pending_response.remove(&request_id)
75    }
76
77    fn try_send_request(
78        &mut self,
79        peer_id: &PeerId,
80        request: OutboundRequest<TCodec>,
81    ) -> Option<OutboundRequest<TCodec>> {
82        if let Some(connections) = self.clients.get_mut(peer_id) {
83            if connections.is_empty() {
84                return Some(request);
85            }
86            let index = request.request_id.0 & connections.len();
87            let connection_id = &mut connections[index];
88            self.pending_response.insert(request.request_id);
89            self.pending_event.push_back(BehaviorEvent::HandlerAction {
90                peer_id: *peer_id,
91                handler: NotifyHandler::One(*connection_id),
92                action: request,
93            });
94            None
95        } else {
96            Some(request)
97        }
98    }
99}
100
/// Outcome of a request issued through `Behavior::send_request`, reported by the
/// behaviour to the swarm.
#[derive(Debug)]
pub enum Event<TResponse> {
    /// The connection handler delivered a response for the request.
    Response {
        /// Peer the request was sent to.
        peer_id: PeerId,
        /// Connection whose handler reported the response.
        connection_id: ConnectionId,
        /// Id returned by `send_request` for this request.
        request_id: RequestId,
        /// Response payload reported by the connection handler.
        response: TResponse,
    },
    /// The request ended without a response.
    Failure {
        /// Peer the request was addressed to.
        peer_id: PeerId,
        /// Connection whose handler reported the failure, or the id of the failed
        /// dial for requests that never reached a connection.
        connection_id: ConnectionId,
        /// Id returned by `send_request` for this request.
        request_id: RequestId,
        /// Why the request failed (unsupported protocol, stream error, timeout,
        /// dial failure).
        cause: OutboundFailure,
    },
}
116
117impl<TCodec> NetworkBehavior for Behavior<TCodec>
118where
119    TCodec: Codec + Clone + Send + 'static,
120{
121    type ConnectionHandler = Handler<TCodec>;
122    type Event = Event<TCodec::Response>;
123    fn on_connection_handler_event(
124        &mut self,
125        id: ConnectionId,
126        peer_id: PeerId,
127        event: THandlerEvent<Self>,
128    ) {
129        match event {
130            handler::Event::Response {
131                request_id,
132                response,
133            } => {
134                let removed = self.remove_pending_response(request_id);
135                debug_assert!(removed, "Response for unknown request: {request_id}");
136                self.pending_event
137                    .push_back(BehaviorEvent::Behavior(Event::Response {
138                        peer_id,
139                        connection_id: id,
140                        request_id,
141                        response,
142                    }));
143            }
144            handler::Event::Unsupported(request_id) => {
145                let removed = self.remove_pending_response(request_id);
146                debug_assert!(removed, "Response for unknown request: {request_id}");
147                self.pending_event
148                    .push_back(BehaviorEvent::Behavior(Event::Failure {
149                        peer_id,
150                        connection_id: id,
151                        request_id,
152                        cause: OutboundFailure::UnsupportedProtocols,
153                    }));
154            }
155            handler::Event::StreamError { request_id, error } => {
156                let removed = self.remove_pending_response(request_id);
157                debug_assert!(removed, "Response for unknown request: {request_id}");
158                self.pending_event
159                    .push_back(BehaviorEvent::Behavior(Event::Failure {
160                        peer_id,
161                        connection_id: id,
162                        request_id,
163                        cause: error.into(),
164                    }));
165            }
166            handler::Event::Timeout(request_id) => {
167                let removed = self.remove_pending_response(request_id);
168                debug_assert!(removed, "Response for unknown request: {request_id}");
169                self.pending_event
170                    .push_back(BehaviorEvent::Behavior(Event::Failure {
171                        peer_id,
172                        connection_id: id,
173                        request_id,
174                        cause: OutboundFailure::Timeout,
175                    }));
176            }
177        }
178    }
179
180    fn poll(
181        &mut self,
182        _cx: &mut Context<'_>,
183    ) -> Poll<BehaviorEvent<Self::Event, THandlerAction<Self>>> {
184        if let Some(event) = self.pending_event.pop_front() {
185            return Poll::Ready(event);
186        }
187        Poll::Pending
188    }
189}
190
191impl<TCodec> NetworkOutgoingBehavior for Behavior<TCodec>
192where
193    TCodec: Codec + Clone + Send + 'static,
194{
195    fn handle_established_connection(
196        &mut self,
197        _id: ConnectionId,
198        _peer_id: PeerId,
199        _addr: &Multiaddr,
200    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
201        let handler = handler::Handler::new(self.codec.clone(), self.config.request_timeout);
202        Ok(handler)
203    }
204
205    fn on_connection_established(&mut self, id: ConnectionId, peer_id: PeerId, _addr: &Multiaddr) {
206        self.clients.entry(peer_id).or_default().push(id);
207    }
208
209    fn on_connection_closed(
210        &mut self,
211        id: ConnectionId,
212        peer_id: PeerId,
213        _addr: &Multiaddr,
214        _reason: Option<&ConnectionError>,
215    ) {
216        self.clients
217            .entry(peer_id)
218            .or_default()
219            .retain(|x| *x != id);
220        if self
221            .clients
222            .get(&peer_id)
223            .map(|v| v.is_empty())
224            .unwrap_or(false)
225        {
226            self.clients.remove(&peer_id);
227        }
228    }
229
230    fn on_dial_failure(
231        &mut self,
232        id: ConnectionId,
233        peer_id: Option<PeerId>,
234        _addr: Option<&Multiaddr>,
235        _error: &DialError,
236    ) {
237        if let Some(peer) = peer_id {
238            if let Some(pending) = self.pending_requests.remove(&peer) {
239                for request in pending {
240                    let event = Event::Failure {
241                        peer_id: peer,
242                        connection_id: id,
243                        request_id: request.request_id,
244                        cause: OutboundFailure::DialFailure,
245                    };
246                    self.pending_event.push_back(BehaviorEvent::Behavior(event));
247                }
248            }
249        }
250    }
251
252    fn poll_dial(&mut self, _cx: &mut Context<'_>) -> Poll<DialOpts> {
253        if let Some(peer_id) = self.pending_dial.iter().next().cloned() {
254            self.pending_dial.remove(&peer_id);
255            Poll::Ready(DialOpts::new(None, Some(peer_id)))
256        } else {
257            Poll::Pending
258        }
259    }
260}