1use crate::client::task::*;
2use crate::client::{
3 ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientTransport, ConnPool,
4};
5use crate::proto::RpcAction;
6use crate::{
7 Codec,
8 error::{EncodedErr, RpcIntErr},
9};
10use ahash::AHashMap;
11use arc_swap::ArcSwap;
12use captains_log::filter::LogFilter;
13use crossfire::{AsyncRx, MTx, SendError, mpsc};
14use orb::prelude::{AsyncExec, AsyncRuntime};
15use parking_lot::Mutex;
16use std::fmt;
17use std::sync::{
18 Arc, Weak,
19 atomic::{AtomicUsize, Ordering},
20};
21
22pub struct FailoverPool<F, P>
30where
31 F: ClientFacts,
32 P: ClientTransport,
33{
34 inner: Arc<FailoverPoolInner<F, P>>,
35}
36
37struct FailoverFacts<F>
38where
39 F: ClientFacts,
40{
41 retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
42 facts: Arc<F>,
43 logger: Arc<LogFilter>,
44 retry_limit: usize,
45}
46
47struct FailoverPoolInner<F, P>
48where
49 F: ClientFacts,
50 P: ClientTransport,
51{
52 pools: ArcSwap<ClusterConfig<F, P>>,
53 stateless: bool,
54 next_node: AtomicUsize,
58 pool_channel_size: usize,
59 facts: Arc<FailoverFacts<F>>,
60 add_pool_mutex: Mutex<()>,
62 rt: Option<<P::RT as AsyncRuntime>::Exec>,
63}
64
65struct ClusterConfig<F, P>
66where
67 F: ClientFacts,
68 P: ClientTransport,
69{
70 pools: Vec<ConnPool<FailoverFacts<F>, P>>,
71 ver: u64,
72}
73
74impl<F, P> FailoverPool<F, P>
75where
76 F: ClientFacts,
77 P: ClientTransport,
78{
79 pub fn new(
89 facts: Arc<F>, rt: Option<&<P::RT as AsyncRuntime>::Exec>, addrs: Vec<String>,
90 stateless: bool, retry_limit: usize, pool_channel_size: usize,
91 ) -> Self {
92 let (retry_tx, retry_rx) = mpsc::unbounded_async();
93 let retry_logger = facts.new_logger();
94 let wrapped_facts =
95 Arc::new(FailoverFacts { retry_limit, retry_tx, logger: facts.new_logger(), facts });
96 let mut pools = Vec::with_capacity(addrs.len());
97 for addr in addrs.iter() {
98 let pool = ConnPool::new(wrapped_facts.clone(), rt, addr, pool_channel_size);
99 pools.push(pool);
100 }
101 let inner = Arc::new(FailoverPoolInner::<F, P> {
103 pools: ArcSwap::new(Arc::new(ClusterConfig { ver: 0, pools })),
104 stateless,
105 facts: wrapped_facts,
106 next_node: AtomicUsize::new(0),
107 pool_channel_size,
108 add_pool_mutex: Mutex::new(()),
109 rt: rt.cloned(),
110 });
111 let weak_self = Arc::downgrade(&inner);
112 let f = FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx);
113 if let Some(_rt) = rt {
114 _rt.spawn_detach(f);
115 } else {
116 P::RT::spawn_detach(f);
117 }
118 Self { inner }
119 }
120
121 #[inline]
123 pub fn get_retry_limit(&self) -> usize {
124 self.inner.facts.retry_limit
125 }
126
127 pub async fn resubmit(
132 &self, task: F::Task, addr_or_retry: Result<String, usize>, retry_count: usize,
133 max_retries: Option<usize>,
134 ) where
135 F::Task: ClientTask,
136 {
137 match &addr_or_retry {
139 Ok(addr) => {
140 let (pool, index, conf_ver) = self.get_or_add_addr(addr);
141 self.inner.next_node.store(index, Ordering::SeqCst);
144 let failover_task = FailoverTask {
145 last_index: index,
146 config_ver: conf_ver,
147 inner: task,
148 retry: retry_count,
149 should_retry: false,
150 max_retries: max_retries.unwrap_or(0),
151 };
152 pool.send_req(failover_task).await;
153 return;
154 }
155 Err(last_index) => {
156 let cluster = self.inner.pools.load();
157 if let Some((pool, index)) = cluster.select(self.inner.stateless, Err(*last_index))
159 {
160 let failover_task = FailoverTask {
161 last_index: index,
162 config_ver: cluster.ver,
163 inner: task,
164 retry: 0,
165 should_retry: false,
166 max_retries: 0,
167 };
168 pool.send_req(failover_task).await;
169 return;
170 }
171 let mut task = task;
173 task.set_rpc_error(RpcIntErr::Unreachable);
174 task.done();
175 }
176 }
177 }
178
179 fn get_or_add_addr(&self, addr: &str) -> (ConnPool<FailoverFacts<F>, P>, usize, u64) {
181 let inner = &self.inner;
182 {
184 let cluster = inner.pools.load();
185 if let Some((pool, idx)) = cluster.get_by_addr(addr) {
186 return (pool.clone(), idx, cluster.ver);
187 }
188 }
189 {
190 let _guard = self.inner.add_pool_mutex.lock();
192 let old_cluster = self.inner.pools.load_full();
194 if let Some((pool, idx)) = old_cluster.get_by_addr(addr) {
195 return (pool.clone(), idx, old_cluster.ver);
196 }
197 let mut new_cluster = Vec::with_capacity(old_cluster.pools.len() + 1);
198 let new_pool = ConnPool::new(inner.facts.clone(), None, addr, inner.pool_channel_size);
200
201 new_cluster.push(new_pool.clone());
204 new_cluster.extend(old_cluster.pools.iter().cloned());
205 let new_ver = old_cluster.ver.wrapping_add(1);
206 drop(old_cluster);
207 let new_cluster = ClusterConfig { pools: new_cluster, ver: new_ver };
208 inner.pools.store(Arc::new(new_cluster));
209 (new_pool, 0, new_ver)
210 }
211 }
212
213 pub fn update_addrs(&self, addrs: Vec<String>) {
214 let inner = &self.inner;
215 {
216 let _guard = self.inner.add_pool_mutex.lock();
217 let old_cluster = inner.pools.load_full();
218 let mut new_pools: Vec<ConnPool<FailoverFacts<F>, P>> = Vec::with_capacity(addrs.len());
219 let mut old_pools_map = AHashMap::with_capacity(old_cluster.pools.len());
220 for pool in &old_cluster.pools {
221 old_pools_map.insert(pool.get_addr().to_string(), pool);
222 }
223 for addr in addrs {
224 if let Some(reused_pool) = old_pools_map.remove(&addr) {
225 new_pools.push(reused_pool.clone());
226 } else {
227 let new_pool = ConnPool::new(
229 inner.facts.clone(),
230 inner.rt.as_ref(),
231 &addr,
232 inner.pool_channel_size,
233 );
234 new_pools.push(new_pool);
235 }
236 }
237 let new_ver = old_cluster.ver.wrapping_add(1);
238 drop(old_cluster);
239 let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
240 inner.pools.store(Arc::new(new_cluster));
241 }
242 }
243}
244
245impl<F, P> ClusterConfig<F, P>
246where
247 F: ClientFacts,
248 P: ClientTransport,
249{
250 #[inline]
255 fn select(
256 &self, stateless: bool, route: Result<&AtomicUsize, usize>,
257 ) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
258 let l = self.pools.len();
259 if l == 0 {
260 return None;
261 }
262 let seed = match &route {
265 Err(index) => *index + 1, Ok(next_node) => {
267 if stateless {
269 next_node.fetch_add(1, Ordering::Relaxed)
271 } else {
272 next_node.load(Ordering::SeqCst)
273 }
274 }
275 };
276 for i in seed..seed + l {
277 let pool = &self.pools[i % l];
278 if pool.is_healthy() {
279 return Some((pool, i));
280 }
281 }
282 None
283 }
284
285 fn get_by_addr(&self, addr: &str) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
287 for (i, pool) in self.pools.iter().enumerate() {
288 if pool.get_addr() == addr {
289 return Some((pool, i));
292 }
293 }
294 None
295 }
296}
297
298impl<F, P> FailoverPoolInner<F, P>
299where
300 F: ClientFacts,
301 P: ClientTransport,
302{
303 async fn retry_worker(
304 weak_self: Weak<Self>, logger: Arc<LogFilter>,
305 retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
306 ) {
307 while let Ok(mut task) = retry_rx.recv().await {
308 if let Some(inner) = weak_self.upgrade() {
309 let cluster = inner.pools.load();
310 let route = if cluster.ver == task.config_ver {
311 Err(task.last_index)
312 } else {
313 task.config_ver = cluster.ver;
316 Ok(&inner.next_node) };
318 if let Some((pool, index)) = cluster.select(inner.stateless, route) {
319 if let Err(last) = &route {
320 logger_trace!(
321 logger,
322 "FailoverPool: task {:?} retry {}->{}",
323 task.inner,
324 last,
325 index
326 );
327 }
328 task.last_index = index;
329 pool.send_req(task).await; continue;
331 }
332 logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
334 task.done();
335 } else {
336 logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
337 task.done();
338 }
339 }
340 logger_trace!(logger, "FailoverPool retry worker exit");
341 }
342}
343
344impl<F, P> Drop for FailoverPoolInner<F, P>
345where
346 F: ClientFacts,
347 P: ClientTransport,
348{
349 #[inline]
350 fn drop(&mut self) {
351 logger_trace!(self.facts.logger, "FailoverPool dropped");
352 }
353}
354
355impl<F> std::ops::Deref for FailoverFacts<F>
357where
358 F: ClientFacts,
359{
360 type Target = F;
361
362 #[inline]
363 fn deref(&self) -> &Self::Target {
364 self.facts.as_ref()
365 }
366}
367
368impl<F> ClientFacts for FailoverFacts<F>
369where
370 F: ClientFacts,
371{
372 type Codec = F::Codec;
373
374 type Task = FailoverTask<F::Task>;
375
376 #[inline]
377 fn new_logger(&self) -> Arc<LogFilter> {
378 self.facts.new_logger()
379 }
380
381 #[inline]
382 fn get_config(&self) -> &ClientConfig {
383 self.facts.get_config()
384 }
385
386 #[inline]
387 fn error_handle(&self, task: FailoverTask<F::Task>) {
388 let retry_limit = if task.max_retries > 0 { task.max_retries } else { self.retry_limit };
390 if task.should_retry && task.retry <= retry_limit {
391 if let Err(SendError(_task)) = self.retry_tx.send(task) {
392 _task.done();
393 }
394 return;
395 }
396 task.inner.done();
397 }
398}
399
400impl<F, P> Clone for FailoverPool<F, P>
401where
402 F: ClientFacts,
403 P: ClientTransport,
404{
405 #[inline]
406 fn clone(&self) -> Self {
407 Self { inner: self.inner.clone() }
408 }
409}
410
411impl<F, P> ClientCaller for FailoverPool<F, P>
412where
413 F: ClientFacts,
414 P: ClientTransport,
415{
416 type Facts = F;
417
418 async fn send_req(&self, mut task: F::Task) {
419 let cluster = self.inner.pools.load();
420 if let Some((pool, index)) = cluster.select(self.inner.stateless, Ok(&self.inner.next_node))
421 {
422 let failover_task = FailoverTask {
423 last_index: index,
424 config_ver: cluster.ver,
425 inner: task,
426 retry: 0,
427 should_retry: false,
428 max_retries: 0, };
430 pool.send_req(failover_task).await;
431 return;
432 }
433
434 task.set_rpc_error(RpcIntErr::Unreachable);
436 task.done();
437 }
438}
439
440impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
441where
442 F: ClientFacts,
443 P: ClientTransport,
444{
445 type Facts = F;
446 fn send_req_blocking(&self, mut task: F::Task) {
447 let cluster = self.inner.pools.load();
448 if let Some((pool, index)) = cluster.select(self.inner.stateless, Ok(&self.inner.next_node))
449 {
450 let failover_task = FailoverTask {
451 last_index: index,
452 config_ver: cluster.ver,
453 inner: task,
454 retry: 0,
455 should_retry: false,
456 max_retries: 0, };
458 pool.send_req_blocking(failover_task);
459 return;
460 }
461
462 task.set_rpc_error(RpcIntErr::Unreachable);
464 task.done();
465 }
466}
467
468pub struct FailoverTask<T: ClientTask> {
469 last_index: usize,
470 config_ver: u64,
471 inner: T,
472 retry: usize,
473 should_retry: bool,
474 max_retries: usize,
476}
477
478impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
479 #[inline(always)]
480 fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
481 self.inner.encode_req(codec, buf)
482 }
483
484 #[inline(always)]
485 fn get_req_blob(&self) -> Option<&[u8]> {
486 self.inner.get_req_blob()
487 }
488}
489
490impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
491 #[inline(always)]
492 fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
493 self.inner.decode_resp(codec, buf)
494 }
495
496 #[inline(always)]
497 fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
498 self.inner.reserve_resp_blob(_size)
499 }
500}
501
502impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
503 #[inline(always)]
504 fn set_custom_error<C: Codec>(
505 &mut self, codec: &C, e: EncodedErr, _last_index: usize, _conf_ver: u64,
506 ) {
507 self.should_retry = false;
508 self.inner.set_custom_error(codec, e, self.last_index, self.config_ver);
509 }
510
511 #[inline(always)]
512 fn set_rpc_error(&mut self, e: RpcIntErr) {
513 if e < RpcIntErr::Method {
514 self.should_retry = true;
515 self.retry += 1;
516 } else {
517 self.should_retry = false;
518 }
519 self.inner.set_rpc_error(e.clone());
520 }
521
522 #[inline(always)]
523 fn set_ok(&mut self) {
524 self.inner.set_ok();
525 }
526
527 #[inline(always)]
528 fn done(self) {
529 self.inner.done();
530 }
531}
532
533impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
534 #[inline(always)]
535 fn get_action<'a>(&'a self) -> RpcAction<'a> {
536 self.inner.get_action()
537 }
538}
539
540impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
541 type Target = ClientTaskCommon;
542 fn deref(&self) -> &Self::Target {
543 self.inner.deref()
544 }
545}
546
547impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
548 fn deref_mut(&mut self) -> &mut Self::Target {
549 self.inner.deref_mut()
550 }
551}
552
553impl<T: ClientTask> ClientTask for FailoverTask<T> {}
554
555impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
556 #[inline]
557 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
558 self.inner.fmt(f)
559 }
560}