1use crate::client::task::*;
2use crate::client::{
3 ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientPool, ClientTransport,
4};
5use crate::proto::RpcAction;
6use crate::{
7 Codec,
8 error::{EncodedErr, RpcIntErr},
9};
10use arc_swap::ArcSwapOption;
11use captains_log::filter::LogFilter;
12use crossfire::*;
13use std::fmt;
14use std::sync::{
15 Arc, Weak,
16 atomic::{AtomicU64, AtomicUsize, Ordering},
17};
18
19pub struct FailoverPool<F, P>(Arc<FailoverPoolInner<F, P>>)
29where
30 F: ClientFacts,
31 P: ClientTransport;
32
33struct FailoverPoolInner<F, P>
34where
35 F: ClientFacts,
36 P: ClientTransport,
37{
38 pools: ArcSwapOption<ClusterConfig<F, P>>,
39 round_robin: bool,
40 facts: Arc<F>,
41 retry_limit: usize,
42 retry_tx: MTx<FailoverTask<F::Task>>,
43 ver: AtomicU64,
44 rr_counter: AtomicUsize,
45 pool_channel_size: usize,
46 logger: Arc<LogFilter>,
47}
48
49struct ClusterConfig<F, P>
50where
51 F: ClientFacts,
52 P: ClientTransport,
53{
54 pools: Vec<ClientPool<FailoverPoolInner<F, P>, P>>,
55 ver: u64,
56}
57
58impl<F, P> FailoverPool<F, P>
59where
60 F: ClientFacts,
61 P: ClientTransport,
62{
63 pub fn new(
64 facts: Arc<F>, addrs: Vec<String>, round_robin: bool, retry_limit: usize,
65 pool_channel_size: usize,
66 ) -> Self {
67 let (retry_tx, retry_rx) = mpsc::unbounded_async();
68 let inner = Arc::new(FailoverPoolInner::<F, P> {
70 pools: ArcSwapOption::new(None),
71 round_robin,
72 facts: facts.clone(),
73 retry_limit,
74 retry_tx: retry_tx.into(),
75 ver: AtomicU64::new(1),
76 rr_counter: AtomicUsize::new(0),
77 pool_channel_size,
78 logger: facts.new_logger(),
79 });
80 let mut pools = Vec::with_capacity(addrs.len());
81 for addr in addrs.iter() {
82 let pool = ClientPool::new(inner.clone(), &addr, pool_channel_size);
83 pools.push(pool);
84 }
85 inner.pools.store(Some(Arc::new(ClusterConfig { ver: 0, pools })));
86
87 let retry_logger = facts.new_logger();
88 let weak_self = Arc::downgrade(&inner);
89 facts.spawn_detach(async move {
90 FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
91 });
92 Self(inner)
93 }
94
95 pub fn update_addrs(&self, addrs: Vec<String>) {
96 let inner = &self.0;
97 let old_cluster_arc = inner.pools.load();
98 let old_pools = old_cluster_arc.as_ref().map(|c| c.pools.clone()).unwrap_or_else(Vec::new);
99
100 let mut new_pools = Vec::with_capacity(addrs.len());
101
102 let mut old_pools_map = std::collections::HashMap::with_capacity(old_pools.len());
103 for pool in old_pools {
104 old_pools_map.insert(pool.get_addr().to_string(), pool);
105 }
106
107 for addr in addrs {
108 if let Some(reused_pool) = old_pools_map.remove(&addr) {
109 new_pools.push(reused_pool);
110 } else {
111 let new_pool = ClientPool::new(inner.clone(), &addr, inner.pool_channel_size);
113 new_pools.push(new_pool);
114 }
115 }
116
117 let new_ver = inner.ver.fetch_add(1, Ordering::Relaxed) + 1;
118 let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
119 inner.pools.store(Some(Arc::new(new_cluster)));
120 }
121}
122
123impl<F, P> ClusterConfig<F, P>
124where
125 F: ClientFacts,
126 P: ClientTransport,
127{
128 #[inline]
129 fn select(
130 &self, round_robin: bool, rr_counter: &AtomicUsize, last_index: Option<usize>,
131 ) -> Option<(usize, &ClientPool<FailoverPoolInner<F, P>, P>)> {
132 let l = self.pools.len();
133 if l == 0 {
134 return None;
135 }
136 let seed = if let Some(index) = last_index {
137 index + 1
138 } else if round_robin {
139 rr_counter.fetch_add(1, Ordering::Relaxed)
140 } else {
141 0
142 };
143 for i in seed..seed + l {
144 let pool = &self.pools[i % l];
145 if pool.is_healthy() {
146 return Some((i, pool));
147 }
148 }
149 return None;
150 }
151}
152
153impl<F, P> FailoverPoolInner<F, P>
154where
155 F: ClientFacts,
156 P: ClientTransport,
157{
158 async fn retry_worker(
159 weak_self: Weak<Self>, logger: Arc<LogFilter>, retry_rx: AsyncRx<FailoverTask<F::Task>>,
160 ) {
161 while let Ok(mut task) = retry_rx.recv().await {
162 if let Some(inner) = weak_self.upgrade() {
163 let cluster = inner.pools.load();
164 if let Some(cluster) = cluster.as_ref() {
165 let last_index = if cluster.ver == task.cluster_ver {
167 Some(task.last_index)
168 } else {
169 task.cluster_ver = cluster.ver;
170 None };
172 if let Some((index, pool)) =
173 cluster.select(inner.round_robin, &inner.rr_counter, last_index)
174 {
175 if let Some(last) = last_index {
176 logger_trace!(
177 logger,
178 "FailoverPool: task {:?} retry {}->{}",
179 task.inner,
180 last,
181 index
182 );
183 }
184 task.last_index = index;
185 pool.send_req(task).await; continue;
187 }
188 }
189 logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
191 task.done();
192 } else {
193 logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
194 task.done();
195 }
196 }
197 logger_trace!(logger, "FailoverPool retry worker exit");
198 }
199}
200
201impl<F, P> Drop for FailoverPool<F, P>
202where
203 F: ClientFacts,
204 P: ClientTransport,
205{
206 fn drop(&mut self) {
207 self.0.pools.store(None);
209 }
210}
211
212impl<F, P> Drop for FailoverPoolInner<F, P>
213where
214 F: ClientFacts,
215 P: ClientTransport,
216{
217 fn drop(&mut self) {
218 logger_trace!(self.logger, "FailoverPool dropped");
219 }
220}
221
222impl<F, P> std::ops::Deref for FailoverPoolInner<F, P>
223where
224 F: ClientFacts,
225 P: ClientTransport,
226{
227 type Target = F;
228
229 fn deref(&self) -> &Self::Target {
230 &self.facts
231 }
232}
233
234impl<F, P> ClientFacts for FailoverPoolInner<F, P>
235where
236 F: ClientFacts,
237 P: ClientTransport,
238{
239 type Codec = F::Codec;
240
241 type Task = FailoverTask<F::Task>;
242
243 #[inline]
244 fn new_logger(&self) -> Arc<LogFilter> {
245 self.facts.new_logger()
246 }
247
248 #[inline]
249 fn get_config(&self) -> &ClientConfig {
250 self.facts.get_config()
251 }
252
253 #[inline]
254 fn error_handle(&self, task: FailoverTask<F::Task>) {
255 if task.should_retry {
256 if task.retry <= self.retry_limit {
257 if let Err(SendError(_task)) = self.retry_tx.send(task) {
258 _task.done();
259 }
260 return;
261 }
262 }
263 task.inner.done();
264 }
265}
266
267impl<F, P> ClientCaller for FailoverPool<F, P>
268where
269 F: ClientFacts,
270 P: ClientTransport,
271{
272 type Facts = F;
273
274 async fn send_req(&self, mut task: F::Task) {
275 let cluster = self.0.pools.load();
276 if let Some(cluster) = cluster.as_ref() {
277 if let Some((index, pool)) =
278 cluster.select(self.0.round_robin, &self.0.rr_counter, None)
279 {
280 let failover_task = FailoverTask {
281 last_index: index,
282 cluster_ver: cluster.ver,
283 inner: task,
284 retry: 0,
285 should_retry: false,
286 };
287 pool.send_req(failover_task).await;
288 return;
289 }
290 }
291
292 task.set_rpc_error(RpcIntErr::Unreachable);
294 task.done();
295 }
296}
297
298impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
299where
300 F: ClientFacts,
301 P: ClientTransport,
302{
303 type Facts = F;
304 fn send_req_blocking(&self, mut task: F::Task) {
305 let cluster = self.0.pools.load();
306 if let Some(cluster) = cluster.as_ref() {
307 if let Some((index, pool)) =
308 cluster.select(self.0.round_robin, &self.0.rr_counter, None)
309 {
310 let failover_task = FailoverTask {
311 last_index: index,
312 cluster_ver: cluster.ver,
313 inner: task,
314 retry: 0,
315 should_retry: false,
316 };
317 pool.send_req_blocking(failover_task);
318 return;
319 }
320 }
321
322 task.set_rpc_error(RpcIntErr::Unreachable);
324 task.done();
325 }
326}
327
328pub struct FailoverTask<T: ClientTask> {
329 last_index: usize,
330 cluster_ver: u64,
331 inner: T,
332 retry: usize,
333 should_retry: bool,
334}
335
336impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
337 #[inline(always)]
338 fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
339 self.inner.encode_req(codec, buf)
340 }
341
342 #[inline(always)]
343 fn get_req_blob(&self) -> Option<&[u8]> {
344 self.inner.get_req_blob()
345 }
346}
347
348impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
349 #[inline(always)]
350 fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
351 self.inner.decode_resp(codec, buf)
352 }
353
354 #[inline(always)]
355 fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
356 self.inner.reserve_resp_blob(_size)
357 }
358}
359
360impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
361 #[inline(always)]
362 fn set_custom_error<C: Codec>(&mut self, codec: &C, e: EncodedErr) {
363 self.should_retry = false;
364 self.inner.set_custom_error(codec, e);
365 }
366
367 #[inline(always)]
368 fn set_rpc_error(&mut self, e: RpcIntErr) {
369 if e < RpcIntErr::Method {
370 self.should_retry = true;
371 self.retry += 1;
372 } else {
373 self.should_retry = false;
374 }
375 self.inner.set_rpc_error(e.clone());
376 }
377
378 #[inline(always)]
379 fn set_ok(&mut self) {
380 self.inner.set_ok();
381 }
382
383 #[inline(always)]
384 fn done(self) {
385 self.inner.done();
386 }
387}
388
389impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
390 #[inline(always)]
391 fn get_action<'a>(&'a self) -> RpcAction<'a> {
392 self.inner.get_action()
393 }
394}
395
396impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
397 type Target = ClientTaskCommon;
398 fn deref(&self) -> &Self::Target {
399 self.inner.deref()
400 }
401}
402
403impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
404 fn deref_mut(&mut self) -> &mut Self::Target {
405 self.inner.deref_mut()
406 }
407}
408
409impl<T: ClientTask> ClientTask for FailoverTask<T> {}
410
411impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
412 #[inline]
413 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
414 self.inner.fmt(f)
415 }
416}