1use crate::client::task::*;
2use crate::client::{
3 ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientTransport, ConnPool,
4};
5use crate::proto::RpcAction;
6use crate::{
7 Codec,
8 error::{EncodedErr, RpcIntErr},
9};
10use ahash::AHashMap;
11use arc_swap::ArcSwap;
12use captains_log::filter::LogFilter;
13use crossfire::{AsyncRx, MTx, SendError, mpsc};
14use orb::prelude::AsyncExec;
15use parking_lot::Mutex;
16use std::fmt;
17use std::sync::{
18 Arc, Weak,
19 atomic::{AtomicUsize, Ordering},
20};
21
/// Client-side failover layer: maintains one `ConnPool` per server address
/// and re-dispatches failed tasks to other nodes through a background retry
/// worker (see `FailoverPoolInner::retry_worker`).
pub struct FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Shared state; the retry worker holds only a `Weak` to it so the worker
    // can exit once all `FailoverPool` handles are dropped.
    inner: Arc<FailoverPoolInner<F, P>>,
}
36
/// `ClientFacts` wrapper installed into every `ConnPool` so that failed
/// `FailoverTask`s are funneled back into the retry channel instead of
/// being completed immediately (see `error_handle` below).
struct FailoverFacts<F>
where
    F: ClientFacts,
{
    // Sender side of the retry queue drained by `retry_worker`.
    retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
    // The user-supplied facts; delegated to via `Deref` and the trait impl.
    facts: Arc<F>,
    logger: Arc<LogFilter>,
    // Pool-wide retry cap; a task's own `max_retries` (> 0) overrides it.
    retry_limit: usize,
}
46
/// Shared state behind a `FailoverPool` handle.
struct FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Current cluster snapshot; replaced wholesale on membership changes.
    pools: ArcSwap<ClusterConfig<F, P>>,
    // When true, requests round-robin across nodes; when false they stick
    // to the node index recorded in `next_node`.
    stateless: bool,
    // Routing cursor: round-robin counter (stateless) or pinned node index.
    next_node: AtomicUsize,
    pool_channel_size: usize,
    facts: Arc<FailoverFacts<F>>,
    rt: P::RT,
    // Serializes cluster-snapshot swaps (`get_or_add_addr` / `update_addrs`).
    add_pool_mutex: Mutex<()>,
}
64
/// Immutable snapshot of the pool set. `ver` is bumped on every swap so
/// in-flight tasks can detect that their recorded node index is stale.
struct ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    pools: Vec<ConnPool<FailoverFacts<F>, P>>,
    ver: u64,
}
73
impl<F, P> FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Build a failover pool over `addrs` and spawn the retry worker.
    ///
    /// Each address gets its own `ConnPool` wrapped with `FailoverFacts` so
    /// failures are re-queued. The worker holds only a `Weak` reference to
    /// the inner state, so dropping the last handle lets it shut down.
    pub fn new(
        facts: Arc<F>, rt: &P::RT, addrs: Vec<String>, stateless: bool, retry_limit: usize,
        pool_channel_size: usize,
    ) -> Self {
        // Unbounded channel: `error_handle` re-queues synchronously and must
        // not block inside the transport's error path.
        let (retry_tx, retry_rx) = mpsc::unbounded_async();
        let retry_logger = facts.new_logger();
        let wrapped_facts =
            Arc::new(FailoverFacts { retry_limit, retry_tx, logger: facts.new_logger(), facts });
        let mut pools = Vec::with_capacity(addrs.len());
        for addr in addrs.iter() {
            let pool = ConnPool::new(wrapped_facts.clone(), rt, addr, pool_channel_size);
            pools.push(pool);
        }
        let inner = Arc::new(FailoverPoolInner::<F, P> {
            pools: ArcSwap::new(Arc::new(ClusterConfig { ver: 0, pools })),
            stateless,
            facts: wrapped_facts,
            next_node: AtomicUsize::new(0),
            pool_channel_size,
            rt: rt.clone(),
            add_pool_mutex: Mutex::new(()),
        });
        // Weak reference: the worker must not keep the pool alive.
        let weak_self = Arc::downgrade(&inner);
        rt.spawn_detach(async move {
            FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
        });
        Self { inner }
    }

    /// Pool-wide retry cap configured at construction.
    #[inline]
    pub fn get_retry_limit(&self) -> usize {
        self.inner.facts.retry_limit
    }

    /// Re-submit a task, either to a specific address (`Ok(addr)`) or to any
    /// node other than `Err(last_index)`.
    ///
    /// On the addressed path the pool for `addr` is created on demand,
    /// `next_node` is pinned to it, and `retry_count` / `max_retries` are
    /// carried into the wrapped task. On the index path a fresh task (retry
    /// count 0) is routed past the failed node; if no healthy node exists
    /// the task is completed with `RpcIntErr::Unreachable`.
    pub async fn resubmit(
        &self, task: F::Task, addr_or_retry: Result<String, usize>, retry_count: usize,
        max_retries: Option<usize>,
    ) where
        F::Task: ClientTask,
    {
        match &addr_or_retry {
            Ok(addr) => {
                let (pool, index, conf_ver) = self.get_or_add_addr(addr);
                // Pin subsequent sticky (non-stateless) routing to this node.
                self.inner.next_node.store(index, Ordering::SeqCst);
                let failover_task = FailoverTask {
                    last_index: index,
                    config_ver: conf_ver,
                    inner: task,
                    retry: retry_count,
                    should_retry: false,
                    max_retries: max_retries.unwrap_or(0),
                };
                pool.send_req(failover_task).await;
                return;
            }
            Err(last_index) => {
                let cluster = self.inner.pools.load();
                // Probe for a healthy node starting after the failed index.
                if let Some((pool, index)) = cluster.select(self.inner.stateless, Err(*last_index))
                {
                    let failover_task = FailoverTask {
                        last_index: index,
                        config_ver: cluster.ver,
                        inner: task,
                        retry: 0,
                        should_retry: false,
                        max_retries: 0,
                    };
                    pool.send_req(failover_task).await;
                    return;
                }
                // No healthy node: fail the task back to the caller.
                let mut task = task;
                task.set_rpc_error(RpcIntErr::Unreachable);
                task.done();
            }
        }
    }

    /// Look up the pool for `addr`, creating it (and publishing a new cluster
    /// snapshot) when absent. Returns the pool, its index, and the snapshot
    /// version the index is valid for.
    fn get_or_add_addr(&self, addr: &str) -> (ConnPool<FailoverFacts<F>, P>, usize, u64) {
        let inner = &self.inner;
        {
            // Fast path: lock-free check against the current snapshot.
            let cluster = inner.pools.load();
            if let Some((pool, idx)) = cluster.get_by_addr(addr) {
                return (pool.clone(), idx, cluster.ver);
            }
        }
        {
            // Slow path: double-checked under the swap mutex so concurrent
            // callers cannot publish conflicting snapshots.
            let _guard = self.inner.add_pool_mutex.lock();
            let old_cluster = self.inner.pools.load_full();
            if let Some((pool, idx)) = old_cluster.get_by_addr(addr) {
                return (pool.clone(), idx, old_cluster.ver);
            }
            let mut new_cluster = Vec::with_capacity(old_cluster.pools.len() + 1);
            let new_pool =
                ConnPool::new(inner.facts.clone(), &inner.rt, addr, inner.pool_channel_size);

            // New pool goes to index 0; existing pools shift by one. The
            // version bump below invalidates indices held by in-flight tasks.
            new_cluster.push(new_pool.clone());
            new_cluster.extend(old_cluster.pools.iter().cloned());
            let new_ver = old_cluster.ver.wrapping_add(1);
            drop(old_cluster);
            let new_cluster = ClusterConfig { pools: new_cluster, ver: new_ver };
            inner.pools.store(Arc::new(new_cluster));
            (new_pool, 0, new_ver)
        }
    }

    /// Replace the cluster membership with `addrs`, reusing the existing
    /// pool for any address that is still present. Pools for removed
    /// addresses are simply dropped from the new snapshot.
    pub fn update_addrs(&self, addrs: Vec<String>) {
        let inner = &self.inner;
        {
            let _guard = self.inner.add_pool_mutex.lock();
            let old_cluster = inner.pools.load_full();
            let mut new_pools: Vec<ConnPool<FailoverFacts<F>, P>> = Vec::with_capacity(addrs.len());
            let mut old_pools_map = AHashMap::with_capacity(old_cluster.pools.len());
            for pool in &old_cluster.pools {
                old_pools_map.insert(pool.get_addr().to_string(), pool);
            }
            for addr in addrs {
                if let Some(reused_pool) = old_pools_map.remove(&addr) {
                    new_pools.push(reused_pool.clone());
                } else {
                    let new_pool = ConnPool::new(
                        inner.facts.clone(),
                        &inner.rt,
                        &addr,
                        inner.pool_channel_size,
                    );
                    new_pools.push(new_pool);
                }
            }
            let new_ver = old_cluster.ver.wrapping_add(1);
            drop(old_cluster);
            let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
            inner.pools.store(Arc::new(new_cluster));
        }
    }
}
236
237impl<F, P> ClusterConfig<F, P>
238where
239 F: ClientFacts,
240 P: ClientTransport,
241{
242 #[inline]
247 fn select(
248 &self, stateless: bool, route: Result<&AtomicUsize, usize>,
249 ) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
250 let l = self.pools.len();
251 if l == 0 {
252 return None;
253 }
254 let seed = match &route {
257 Err(index) => *index + 1, Ok(next_node) => {
259 if stateless {
261 next_node.fetch_add(1, Ordering::Relaxed)
263 } else {
264 next_node.load(Ordering::SeqCst)
265 }
266 }
267 };
268 for i in seed..seed + l {
269 let pool = &self.pools[i % l];
270 if pool.is_healthy() {
271 return Some((pool, i));
272 }
273 }
274 None
275 }
276
277 fn get_by_addr(&self, addr: &str) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
279 for (i, pool) in self.pools.iter().enumerate() {
280 if pool.get_addr() == addr {
281 return Some((pool, i));
284 }
285 }
286 None
287 }
288}
289
290impl<F, P> FailoverPoolInner<F, P>
291where
292 F: ClientFacts,
293 P: ClientTransport,
294{
295 async fn retry_worker(
296 weak_self: Weak<Self>, logger: Arc<LogFilter>,
297 retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
298 ) {
299 while let Ok(mut task) = retry_rx.recv().await {
300 if let Some(inner) = weak_self.upgrade() {
301 let cluster = inner.pools.load();
302 let route = if cluster.ver == task.config_ver {
303 Err(task.last_index)
304 } else {
305 task.config_ver = cluster.ver;
308 Ok(&inner.next_node) };
310 if let Some((pool, index)) = cluster.select(inner.stateless, route) {
311 if let Err(last) = &route {
312 logger_trace!(
313 logger,
314 "FailoverPool: task {:?} retry {}->{}",
315 task.inner,
316 last,
317 index
318 );
319 }
320 task.last_index = index;
321 pool.send_req(task).await; continue;
323 }
324 logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
326 task.done();
327 } else {
328 logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
329 task.done();
330 }
331 }
332 logger_trace!(logger, "FailoverPool retry worker exit");
333 }
334}
335
// Trace-log teardown so pool lifetime is visible in debug logs.
impl<F, P> Drop for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    #[inline]
    fn drop(&mut self) {
        logger_trace!(self.facts.logger, "FailoverPool dropped");
    }
}
346
347impl<F> std::ops::Deref for FailoverFacts<F>
349where
350 F: ClientFacts,
351{
352 type Target = F;
353
354 #[inline]
355 fn deref(&self) -> &Self::Target {
356 self.facts.as_ref()
357 }
358}
359
360impl<F> ClientFacts for FailoverFacts<F>
361where
362 F: ClientFacts,
363{
364 type Codec = F::Codec;
365
366 type Task = FailoverTask<F::Task>;
367
368 #[inline]
369 fn new_logger(&self) -> Arc<LogFilter> {
370 self.facts.new_logger()
371 }
372
373 #[inline]
374 fn get_config(&self) -> &ClientConfig {
375 self.facts.get_config()
376 }
377
378 #[inline]
379 fn error_handle(&self, task: FailoverTask<F::Task>) {
380 let retry_limit = if task.max_retries > 0 { task.max_retries } else { self.retry_limit };
382 if task.should_retry && task.retry <= retry_limit {
383 if let Err(SendError(_task)) = self.retry_tx.send(task) {
384 _task.done();
385 }
386 return;
387 }
388 task.inner.done();
389 }
390}
391
392impl<F, P> Clone for FailoverPool<F, P>
393where
394 F: ClientFacts,
395 P: ClientTransport,
396{
397 #[inline]
398 fn clone(&self) -> Self {
399 Self { inner: self.inner.clone() }
400 }
401}
402
403impl<F, P> ClientCaller for FailoverPool<F, P>
404where
405 F: ClientFacts,
406 P: ClientTransport,
407{
408 type Facts = F;
409
410 async fn send_req(&self, mut task: F::Task) {
411 let cluster = self.inner.pools.load();
412 if let Some((pool, index)) = cluster.select(self.inner.stateless, Ok(&self.inner.next_node))
413 {
414 let failover_task = FailoverTask {
415 last_index: index,
416 config_ver: cluster.ver,
417 inner: task,
418 retry: 0,
419 should_retry: false,
420 max_retries: 0, };
422 pool.send_req(failover_task).await;
423 return;
424 }
425
426 task.set_rpc_error(RpcIntErr::Unreachable);
428 task.done();
429 }
430}
431
432impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
433where
434 F: ClientFacts,
435 P: ClientTransport,
436{
437 type Facts = F;
438 fn send_req_blocking(&self, mut task: F::Task) {
439 let cluster = self.inner.pools.load();
440 if let Some((pool, index)) = cluster.select(self.inner.stateless, Ok(&self.inner.next_node))
441 {
442 let failover_task = FailoverTask {
443 last_index: index,
444 config_ver: cluster.ver,
445 inner: task,
446 retry: 0,
447 should_retry: false,
448 max_retries: 0, };
450 pool.send_req_blocking(failover_task);
451 return;
452 }
453
454 task.set_rpc_error(RpcIntErr::Unreachable);
456 task.done();
457 }
458}
459
/// Wrapper that carries routing and retry bookkeeping alongside a user task.
pub struct FailoverTask<T: ClientTask> {
    // Index of the node this task was last sent to, valid for `config_ver`.
    last_index: usize,
    // Cluster snapshot version `last_index` belongs to.
    config_ver: u64,
    // The wrapped user task; all codec/protocol calls delegate to it.
    inner: T,
    // Number of rpc-level failures so far (bumped in `set_rpc_error`).
    retry: usize,
    // Set when the last failure was transient and eligible for failover.
    should_retry: bool,
    // Per-task retry cap; 0 means "use the pool-wide limit".
    max_retries: usize,
}
469
// Request encoding: forwarded verbatim to the wrapped task.
impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
    #[inline(always)]
    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
        self.inner.encode_req(codec, buf)
    }

    #[inline(always)]
    fn get_req_blob(&self) -> Option<&[u8]> {
        self.inner.get_req_blob()
    }
}
481
// Response decoding: forwarded verbatim to the wrapped task.
impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
    #[inline(always)]
    fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
        self.inner.decode_resp(codec, buf)
    }

    #[inline(always)]
    fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
        self.inner.reserve_resp_blob(_size)
    }
}
493
494impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
495 #[inline(always)]
496 fn set_custom_error<C: Codec>(
497 &mut self, codec: &C, e: EncodedErr, _last_index: usize, _conf_ver: u64,
498 ) {
499 self.should_retry = false;
500 self.inner.set_custom_error(codec, e, self.last_index, self.config_ver);
501 }
502
503 #[inline(always)]
504 fn set_rpc_error(&mut self, e: RpcIntErr) {
505 if e < RpcIntErr::Method {
506 self.should_retry = true;
507 self.retry += 1;
508 } else {
509 self.should_retry = false;
510 }
511 self.inner.set_rpc_error(e.clone());
512 }
513
514 #[inline(always)]
515 fn set_ok(&mut self) {
516 self.inner.set_ok();
517 }
518
519 #[inline(always)]
520 fn done(self) {
521 self.inner.done();
522 }
523}
524
// Action lookup: forwarded verbatim to the wrapped task.
impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
    #[inline(always)]
    fn get_action<'a>(&'a self) -> RpcAction<'a> {
        self.inner.get_action()
    }
}
531
// Expose the inner task's common fields through the wrapper.
impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
    type Target = ClientTaskCommon;
    fn deref(&self) -> &Self::Target {
        self.inner.deref()
    }
}
538
// Mutable counterpart of the `Deref` impl above it in trait terms.
impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.deref_mut()
    }
}
544
// Marker impl: `FailoverTask` satisfies the full client-task trait bundle.
impl<T: ClientTask> ClientTask for FailoverTask<T> {}
546
547impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
548 #[inline]
549 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
550 self.inner.fmt(f)
551 }
552}