1use crate::client::task::*;
2use crate::client::{
3 ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientPool, ClientTransport,
4};
5use crate::proto::RpcAction;
6use crate::{
7 Codec,
8 error::{EncodedErr, RpcIntErr},
9};
10use arc_swap::ArcSwap;
11use captains_log::filter::LogFilter;
12use crossfire::{AsyncRx, MTx, SendError, mpsc};
13use std::fmt;
14use std::sync::{
15 Arc, Weak,
16 atomic::{AtomicU64, AtomicUsize, Ordering},
17};
18
/// A failover-capable front over multiple [`ClientPool`]s, one per address.
///
/// Requests are dispatched to a healthy pool (optionally rotated round-robin);
/// tasks that fail with a retryable RPC error are re-queued and re-dispatched
/// to the next pool by a background retry worker. `Clone` is cheap: it only
/// bumps the refcount of the shared inner state.
pub struct FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Shared state; the retry worker holds only a `Weak` to this.
    inner: Arc<FailoverPoolInner<F, P>>,
}
33
/// Facts wrapper injected into every underlying [`ClientPool`] so that failed
/// tasks are funneled into the retry channel (via `error_handle`) instead of
/// completing immediately.
struct FailoverFacts<F>
where
    F: ClientFacts,
{
    // Sender side of the retry queue drained by `FailoverPoolInner::retry_worker`.
    retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
    // The user-supplied facts; `Deref` below forwards to it.
    facts: Arc<F>,
    logger: Arc<LogFilter>,
    // Maximum retry count accepted per task (checked in `error_handle`).
    retry_limit: usize,
}
43
/// Shared state behind a [`FailoverPool`].
struct FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Current pool snapshot; hot-swapped atomically by `update_addrs`.
    pools: ArcSwap<ClusterConfig<F, P>>,
    // When true, selection starts from `rr_counter`; otherwise from index 0.
    round_robin: bool,
    // Monotonic version source for `ClusterConfig::ver` (bumped on `update_addrs`).
    ver: AtomicU64,
    // Shared round-robin cursor consumed by `ClusterConfig::select`.
    rr_counter: AtomicUsize,
    // Channel size handed to each newly constructed `ClientPool`.
    pool_channel_size: usize,
    facts: Arc<FailoverFacts<F>>,
}
56
/// An immutable snapshot of the pool set plus its version.
struct ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    pools: Vec<ClientPool<FailoverFacts<F>, P>>,
    // Snapshot version; tasks carry it so a retry can detect topology changes.
    ver: u64,
}
65
66impl<F, P> FailoverPool<F, P>
67where
68 F: ClientFacts,
69 P: ClientTransport,
70{
71 pub fn new(
75 facts: Arc<F>, addrs: Vec<String>, round_robin: bool, retry_limit: usize,
76 pool_channel_size: usize,
77 ) -> Self {
78 let (retry_tx, retry_rx) = mpsc::unbounded_async();
79 let retry_logger = facts.new_logger();
80 let wrapped_facts =
81 Arc::new(FailoverFacts { retry_limit, retry_tx, logger: facts.new_logger(), facts });
82 let mut pools = Vec::with_capacity(addrs.len());
83 for addr in addrs.iter() {
84 let pool = ClientPool::new(wrapped_facts.clone(), addr, pool_channel_size);
85 pools.push(pool);
86 }
87 let inner = Arc::new(FailoverPoolInner::<F, P> {
89 pools: ArcSwap::new(Arc::new(ClusterConfig { ver: 0, pools })),
90 round_robin,
91 facts: wrapped_facts,
92 ver: AtomicU64::new(1),
93 rr_counter: AtomicUsize::new(0),
94 pool_channel_size,
95 });
96 let weak_self = Arc::downgrade(&inner);
97 inner.facts.spawn_detach(async move {
98 FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
99 });
100 Self { inner }
101 }
102
103 pub fn update_addrs(&self, addrs: Vec<String>) {
104 let inner = &self.inner;
105 let old_pools = inner.pools.load_full();
106 let mut new_pools: Vec<ClientPool<FailoverFacts<F>, P>> = Vec::with_capacity(addrs.len());
107
108 let mut old_pools_map = std::collections::HashMap::with_capacity(old_pools.pools.len());
109 for pool in &old_pools.pools {
110 old_pools_map.insert(pool.get_addr().to_string(), pool);
111 }
112
113 for addr in addrs {
114 if let Some(reused_pool) = old_pools_map.remove(&addr) {
115 new_pools.push(reused_pool.clone());
116 } else {
117 let new_pool = ClientPool::new(inner.facts.clone(), &addr, inner.pool_channel_size);
119 new_pools.push(new_pool);
120 }
121 }
122 let new_ver = inner.ver.fetch_add(1, Ordering::Relaxed) + 1;
123 let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
124 inner.pools.store(Arc::new(new_cluster));
125 }
126}
127
128impl<F, P> ClusterConfig<F, P>
129where
130 F: ClientFacts,
131 P: ClientTransport,
132{
133 #[inline]
134 fn select(
135 &self, round_robin: bool, rr_counter: &AtomicUsize, last_index: Option<usize>,
136 ) -> Option<(usize, &ClientPool<FailoverFacts<F>, P>)> {
137 let l = self.pools.len();
138 if l == 0 {
139 return None;
140 }
141 let seed = if let Some(index) = last_index {
142 index + 1
144 } else if round_robin {
145 rr_counter.fetch_add(1, Ordering::Relaxed)
146 } else {
147 0
148 };
149 for i in seed..seed + l {
150 let pool = &self.pools[i % l];
151 if pool.is_healthy() {
152 return Some((i, pool));
153 }
154 }
155 return None;
156 }
157}
158
159impl<F, P> FailoverPoolInner<F, P>
160where
161 F: ClientFacts,
162 P: ClientTransport,
163{
164 async fn retry_worker(
165 weak_self: Weak<Self>, logger: Arc<LogFilter>,
166 retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
167 ) {
168 while let Ok(mut task) = retry_rx.recv().await {
169 if let Some(inner) = weak_self.upgrade() {
170 let cluster = inner.pools.load();
171 let last_index = if cluster.ver == task.cluster_ver {
173 Some(task.last_index)
174 } else {
175 task.cluster_ver = cluster.ver;
176 None };
178 if let Some((index, pool)) =
179 cluster.select(inner.round_robin, &inner.rr_counter, last_index)
180 {
181 if let Some(last) = last_index {
182 logger_trace!(
183 logger,
184 "FailoverPool: task {:?} retry {}->{}",
185 task.inner,
186 last,
187 index
188 );
189 }
190 task.last_index = index;
191 pool.send_req(task).await; continue;
193 }
194 logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
196 task.done();
197 } else {
198 logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
199 task.done();
200 }
201 }
202 logger_trace!(logger, "FailoverPool retry worker exit");
203 }
204}
205
impl<F, P> Drop for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    #[inline]
    fn drop(&mut self) {
        // Trace-only. Dropping the inner also makes the retry worker's Weak
        // upgrade fail, so the worker drains its queue and exits.
        logger_trace!(self.facts.logger, "FailoverPool dropped");
    }
}
216
217impl<F> std::ops::Deref for FailoverFacts<F>
219where
220 F: ClientFacts,
221{
222 type Target = F;
223
224 #[inline]
225 fn deref(&self) -> &Self::Target {
226 self.facts.as_ref()
227 }
228}
229
230impl<F> ClientFacts for FailoverFacts<F>
231where
232 F: ClientFacts,
233{
234 type Codec = F::Codec;
235
236 type Task = FailoverTask<F::Task>;
237
238 #[inline]
239 fn new_logger(&self) -> Arc<LogFilter> {
240 self.facts.new_logger()
241 }
242
243 #[inline]
244 fn get_config(&self) -> &ClientConfig {
245 self.facts.get_config()
246 }
247
248 #[inline]
249 fn error_handle(&self, task: FailoverTask<F::Task>) {
250 if task.should_retry && task.retry <= self.retry_limit {
251 if let Err(SendError(_task)) = self.retry_tx.send(task) {
252 _task.done();
253 }
254 return;
255 }
256 task.inner.done();
257 }
258}
259
260impl<F, P> Clone for FailoverPool<F, P>
261where
262 F: ClientFacts,
263 P: ClientTransport,
264{
265 #[inline]
266 fn clone(&self) -> Self {
267 Self { inner: self.inner.clone() }
268 }
269}
270
271impl<F, P> ClientCaller for FailoverPool<F, P>
272where
273 F: ClientFacts,
274 P: ClientTransport,
275{
276 type Facts = F;
277
278 async fn send_req(&self, mut task: F::Task) {
279 let cluster = self.inner.pools.load();
280 if let Some((index, pool)) =
281 cluster.select(self.inner.round_robin, &self.inner.rr_counter, None)
282 {
283 let failover_task = FailoverTask {
284 last_index: index,
285 cluster_ver: cluster.ver,
286 inner: task,
287 retry: 0,
288 should_retry: false,
289 };
290 pool.send_req(failover_task).await;
291 return;
292 }
293
294 task.set_rpc_error(RpcIntErr::Unreachable);
296 task.done();
297 }
298}
299
300impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
301where
302 F: ClientFacts,
303 P: ClientTransport,
304{
305 type Facts = F;
306 fn send_req_blocking(&self, mut task: F::Task) {
307 let cluster = self.inner.pools.load();
308 if let Some((index, pool)) =
309 cluster.select(self.inner.round_robin, &self.inner.rr_counter, None)
310 {
311 let failover_task = FailoverTask {
312 last_index: index,
313 cluster_ver: cluster.ver,
314 inner: task,
315 retry: 0,
316 should_retry: false,
317 };
318 pool.send_req_blocking(failover_task);
319 return;
320 }
321
322 task.set_rpc_error(RpcIntErr::Unreachable);
324 task.done();
325 }
326}
327
/// A client task wrapped with failover bookkeeping.
pub struct FailoverTask<T: ClientTask> {
    // Index of the pool this task was last dispatched to (from ClusterConfig::select).
    last_index: usize,
    // Version of the ClusterConfig snapshot that `last_index` refers to.
    cluster_ver: u64,
    // The user's task; most trait impls below simply delegate to it.
    inner: T,
    // Retry attempts recorded so far (bumped in `set_rpc_error`).
    retry: usize,
    // Set by set_rpc_error/set_custom_error; gates re-dispatch in `error_handle`.
    should_retry: bool,
}
335
impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
    // Request encoding is delegated untouched to the wrapped task.
    #[inline(always)]
    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
        self.inner.encode_req(codec, buf)
    }

    #[inline(always)]
    fn get_req_blob(&self) -> Option<&[u8]> {
        self.inner.get_req_blob()
    }
}
347
348impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
349 #[inline(always)]
350 fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
351 self.inner.decode_resp(codec, buf)
352 }
353
354 #[inline(always)]
355 fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
356 self.inner.reserve_resp_blob(_size)
357 }
358}
359
360impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
361 #[inline(always)]
362 fn set_custom_error<C: Codec>(&mut self, codec: &C, e: EncodedErr) {
363 self.should_retry = false;
364 self.inner.set_custom_error(codec, e);
365 }
366
367 #[inline(always)]
368 fn set_rpc_error(&mut self, e: RpcIntErr) {
369 if e < RpcIntErr::Method {
370 self.should_retry = true;
371 self.retry += 1;
372 } else {
373 self.should_retry = false;
374 }
375 self.inner.set_rpc_error(e.clone());
376 }
377
378 #[inline(always)]
379 fn set_ok(&mut self) {
380 self.inner.set_ok();
381 }
382
383 #[inline(always)]
384 fn done(self) {
385 self.inner.done();
386 }
387}
388
impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
    // The RPC action is taken verbatim from the wrapped task.
    #[inline(always)]
    fn get_action<'a>(&'a self) -> RpcAction<'a> {
        self.inner.get_action()
    }
}
395
impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
    type Target = ClientTaskCommon;
    // Delegate to the wrapped task's own Deref to the common task header.
    fn deref(&self) -> &Self::Target {
        self.inner.deref()
    }
}
402
impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
    // Mutable access to the common task header, via the wrapped task.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.deref_mut()
    }
}
408
// Marker impl: FailoverTask satisfies all ClientTask supertraits (impls above).
impl<T: ClientTask> ClientTask for FailoverTask<T> {}
410
411impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
412 #[inline]
413 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
414 self.inner.fmt(f)
415 }
416}