1use crate::client::task::*;
2use crate::client::{
3 ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientPool, ClientTransport,
4};
5use crate::proto::RpcAction;
6use crate::{
7 Codec,
8 error::{EncodedErr, RpcIntErr},
9};
10use arc_swap::ArcSwap;
11use captains_log::filter::LogFilter;
12use crossfire::{AsyncRx, MTx, SendError, mpsc};
13use orb::AsyncRuntime;
14use std::fmt;
15use std::sync::{
16 Arc, Weak,
17 atomic::{AtomicU64, AtomicUsize, Ordering},
18};
19
/// A client pool spanning multiple backend addresses. Requests are dispatched
/// to a healthy `ClientPool`, and tasks that fail at the transport level are
/// re-queued (via the retry worker) to another address.
pub struct FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Shared state; `FailoverPool` itself is a cheap cloneable handle.
    inner: Arc<FailoverPoolInner<F, P>>,
}
34
/// Wraps the user-provided `ClientFacts` so that failed tasks are funneled
/// back into the retry channel (see `error_handle`) instead of completing
/// immediately.
struct FailoverFacts<F>
where
    F: ClientFacts,
{
    // Producer side of the retry queue drained by `retry_worker`.
    retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
    // The user's original facts; `Deref` forwards to it.
    facts: Arc<F>,
    logger: Arc<LogFilter>,
    // Retry budget per task; compared inclusively (`task.retry <= retry_limit`).
    retry_limit: usize,
}
44
/// Shared state behind `FailoverPool` handles. The retry worker holds only a
/// `Weak` to this, so dropping all handles tears everything down.
struct FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Current cluster snapshot; swapped atomically by `update_addrs`.
    pools: ArcSwap<ClusterConfig<F, P>>,
    // When true, fresh requests start from a rotating index instead of pool 0.
    round_robin: bool,
    // Monotonic version source for newly built `ClusterConfig`s.
    ver: AtomicU64,
    // Rotating seed for round-robin selection.
    rr_counter: AtomicUsize,
    // Channel capacity handed to each `ClientPool` on construction.
    pool_channel_size: usize,
    facts: Arc<FailoverFacts<F>>,
}
57
/// An immutable snapshot of the pool set. Replaced wholesale on address
/// updates; in-flight tasks compare `ver` to detect membership changes.
struct ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    pools: Vec<ClientPool<FailoverFacts<F>, P>>,
    // Snapshot version; a task carrying a stale version restarts selection.
    ver: u64,
}
66
67impl<F, P> FailoverPool<F, P>
68where
69 F: ClientFacts,
70 P: ClientTransport,
71{
72 pub fn new<RT: AsyncRuntime + Clone>(
76 facts: Arc<F>, rt: &RT, addrs: Vec<String>, round_robin: bool, retry_limit: usize,
77 pool_channel_size: usize,
78 ) -> Self {
79 let (retry_tx, retry_rx) = mpsc::unbounded_async();
80 let retry_logger = facts.new_logger();
81 let wrapped_facts =
82 Arc::new(FailoverFacts { retry_limit, retry_tx, logger: facts.new_logger(), facts });
83 let mut pools = Vec::with_capacity(addrs.len());
84 for addr in addrs.iter() {
85 let pool = ClientPool::new::<RT>(wrapped_facts.clone(), rt, addr, pool_channel_size);
86 pools.push(pool);
87 }
88 let inner = Arc::new(FailoverPoolInner::<F, P> {
90 pools: ArcSwap::new(Arc::new(ClusterConfig { ver: 0, pools })),
91 round_robin,
92 facts: wrapped_facts,
93 ver: AtomicU64::new(1),
94 rr_counter: AtomicUsize::new(0),
95 pool_channel_size,
96 });
97 let weak_self = Arc::downgrade(&inner);
98 rt.spawn_detach(async move {
99 FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
100 });
101 Self { inner }
102 }
103
104 pub fn update_addrs<RT: AsyncRuntime + Clone>(&self, addrs: Vec<String>, rt: &RT) {
105 let inner = &self.inner;
106 let old_pools = inner.pools.load_full();
107 let mut new_pools: Vec<ClientPool<FailoverFacts<F>, P>> = Vec::with_capacity(addrs.len());
108
109 let mut old_pools_map = std::collections::HashMap::with_capacity(old_pools.pools.len());
110 for pool in &old_pools.pools {
111 old_pools_map.insert(pool.get_addr().to_string(), pool);
112 }
113
114 for addr in addrs {
115 if let Some(reused_pool) = old_pools_map.remove(&addr) {
116 new_pools.push(reused_pool.clone());
117 } else {
118 let new_pool =
120 ClientPool::new::<RT>(inner.facts.clone(), rt, &addr, inner.pool_channel_size);
121 new_pools.push(new_pool);
122 }
123 }
124 let new_ver = inner.ver.fetch_add(1, Ordering::Relaxed) + 1;
125 let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
126 inner.pools.store(Arc::new(new_cluster));
127 }
128}
129
130impl<F, P> ClusterConfig<F, P>
131where
132 F: ClientFacts,
133 P: ClientTransport,
134{
135 #[inline]
136 fn select(
137 &self, round_robin: bool, rr_counter: &AtomicUsize, last_index: Option<usize>,
138 ) -> Option<(usize, &ClientPool<FailoverFacts<F>, P>)> {
139 let l = self.pools.len();
140 if l == 0 {
141 return None;
142 }
143 let seed = if let Some(index) = last_index {
144 index + 1
146 } else if round_robin {
147 rr_counter.fetch_add(1, Ordering::Relaxed)
148 } else {
149 0
150 };
151 for i in seed..seed + l {
152 let pool = &self.pools[i % l];
153 if pool.is_healthy() {
154 return Some((i, pool));
155 }
156 }
157 return None;
158 }
159}
160
161impl<F, P> FailoverPoolInner<F, P>
162where
163 F: ClientFacts,
164 P: ClientTransport,
165{
166 async fn retry_worker(
167 weak_self: Weak<Self>, logger: Arc<LogFilter>,
168 retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
169 ) {
170 while let Ok(mut task) = retry_rx.recv().await {
171 if let Some(inner) = weak_self.upgrade() {
172 let cluster = inner.pools.load();
173 let last_index = if cluster.ver == task.cluster_ver {
175 Some(task.last_index)
176 } else {
177 task.cluster_ver = cluster.ver;
178 None };
180 if let Some((index, pool)) =
181 cluster.select(inner.round_robin, &inner.rr_counter, last_index)
182 {
183 if let Some(last) = last_index {
184 logger_trace!(
185 logger,
186 "FailoverPool: task {:?} retry {}->{}",
187 task.inner,
188 last,
189 index
190 );
191 }
192 task.last_index = index;
193 pool.send_req(task).await; continue;
195 }
196 logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
198 task.done();
199 } else {
200 logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
201 task.done();
202 }
203 }
204 logger_trace!(logger, "FailoverPool retry worker exit");
205 }
206}
207
impl<F, P> Drop for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Trace-level breadcrumb for debugging lifetime issues; fires once the
    // last strong handle to the inner state is gone.
    #[inline]
    fn drop(&mut self) {
        logger_trace!(self.facts.logger, "FailoverPool dropped");
    }
}
218
219impl<F> std::ops::Deref for FailoverFacts<F>
221where
222 F: ClientFacts,
223{
224 type Target = F;
225
226 #[inline]
227 fn deref(&self) -> &Self::Target {
228 self.facts.as_ref()
229 }
230}
231
232impl<F> ClientFacts for FailoverFacts<F>
233where
234 F: ClientFacts,
235{
236 type Codec = F::Codec;
237
238 type Task = FailoverTask<F::Task>;
239
240 #[inline]
241 fn new_logger(&self) -> Arc<LogFilter> {
242 self.facts.new_logger()
243 }
244
245 #[inline]
246 fn get_config(&self) -> &ClientConfig {
247 self.facts.get_config()
248 }
249
250 #[inline]
251 fn error_handle(&self, task: FailoverTask<F::Task>) {
252 if task.should_retry && task.retry <= self.retry_limit {
253 if let Err(SendError(_task)) = self.retry_tx.send(task) {
254 _task.done();
255 }
256 return;
257 }
258 task.inner.done();
259 }
260}
261
262impl<F, P> Clone for FailoverPool<F, P>
263where
264 F: ClientFacts,
265 P: ClientTransport,
266{
267 #[inline]
268 fn clone(&self) -> Self {
269 Self { inner: self.inner.clone() }
270 }
271}
272
273impl<F, P> ClientCaller for FailoverPool<F, P>
274where
275 F: ClientFacts,
276 P: ClientTransport,
277{
278 type Facts = F;
279
280 async fn send_req(&self, mut task: F::Task) {
281 let cluster = self.inner.pools.load();
282 if let Some((index, pool)) =
283 cluster.select(self.inner.round_robin, &self.inner.rr_counter, None)
284 {
285 let failover_task = FailoverTask {
286 last_index: index,
287 cluster_ver: cluster.ver,
288 inner: task,
289 retry: 0,
290 should_retry: false,
291 };
292 pool.send_req(failover_task).await;
293 return;
294 }
295
296 task.set_rpc_error(RpcIntErr::Unreachable);
298 task.done();
299 }
300}
301
302impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
303where
304 F: ClientFacts,
305 P: ClientTransport,
306{
307 type Facts = F;
308 fn send_req_blocking(&self, mut task: F::Task) {
309 let cluster = self.inner.pools.load();
310 if let Some((index, pool)) =
311 cluster.select(self.inner.round_robin, &self.inner.rr_counter, None)
312 {
313 let failover_task = FailoverTask {
314 last_index: index,
315 cluster_ver: cluster.ver,
316 inner: task,
317 retry: 0,
318 should_retry: false,
319 };
320 pool.send_req_blocking(failover_task);
321 return;
322 }
323
324 task.set_rpc_error(RpcIntErr::Unreachable);
326 task.done();
327 }
328}
329
/// A `ClientTask` wrapper carrying the bookkeeping needed to retry a request
/// on a different pool after a transport-level failure.
pub struct FailoverTask<T: ClientTask> {
    // Probe index of the pool this task was last sent to; only meaningful
    // while `cluster_ver` still matches the live cluster snapshot.
    last_index: usize,
    // Version of the `ClusterConfig` that `last_index` refers to.
    cluster_ver: u64,
    inner: T,
    // Number of failover attempts performed so far (bumped in set_rpc_error).
    retry: usize,
    // Set by `set_rpc_error` when the failure is considered retryable.
    should_retry: bool,
}
337
// Request encoding is delegated untouched to the wrapped task.
impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
    #[inline(always)]
    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
        self.inner.encode_req(codec, buf)
    }

    #[inline(always)]
    fn get_req_blob(&self) -> Option<&[u8]> {
        self.inner.get_req_blob()
    }
}
349
// Response decoding is delegated untouched to the wrapped task.
impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
    #[inline(always)]
    fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
        self.inner.decode_resp(codec, buf)
    }

    #[inline(always)]
    fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
        self.inner.reserve_resp_blob(_size)
    }
}
361
362impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
363 #[inline(always)]
364 fn set_custom_error<C: Codec>(&mut self, codec: &C, e: EncodedErr) {
365 self.should_retry = false;
366 self.inner.set_custom_error(codec, e);
367 }
368
369 #[inline(always)]
370 fn set_rpc_error(&mut self, e: RpcIntErr) {
371 if e < RpcIntErr::Method {
372 self.should_retry = true;
373 self.retry += 1;
374 } else {
375 self.should_retry = false;
376 }
377 self.inner.set_rpc_error(e.clone());
378 }
379
380 #[inline(always)]
381 fn set_ok(&mut self) {
382 self.inner.set_ok();
383 }
384
385 #[inline(always)]
386 fn done(self) {
387 self.inner.done();
388 }
389}
390
// The RPC action is delegated untouched to the wrapped task.
impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
    #[inline(always)]
    fn get_action<'a>(&'a self) -> RpcAction<'a> {
        self.inner.get_action()
    }
}
397
// Expose the inner task's `ClientTaskCommon` through the wrapper.
impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
    type Target = ClientTaskCommon;
    fn deref(&self) -> &Self::Target {
        self.inner.deref()
    }
}
404
// Mutable counterpart of the `Deref` impl above it delegates to the inner task.
impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.deref_mut()
    }
}
410
// `FailoverTask` satisfies the full `ClientTask` bundle purely through the
// delegating impls above.
impl<T: ClientTask> ClientTask for FailoverTask<T> {}
412
// Debug output shows only the inner task; the failover bookkeeping fields are
// omitted from log lines on purpose.
impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.inner.fmt(f)
    }
}