1use crate::client::task::*;
2use crate::client::{
3 ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientPool, ClientTransport,
4};
5use crate::proto::RpcAction;
6use crate::{
7 Codec,
8 error::{EncodedErr, RpcIntErr},
9};
10use arc_swap::ArcSwapOption;
11use captains_log::filter::LogFilter;
12use crossfire::*;
13use std::fmt;
14use std::sync::{
15 Arc, Weak,
16 atomic::{AtomicU64, AtomicUsize, Ordering},
17};
18
/// Failover-capable client: one [`ClientPool`] per backend address, with
/// re-dispatch of retryable failures to the next healthy pool.
///
/// New requests are routed round-robin or first-healthy (see
/// `ClusterConfig::select`); a detached background worker re-queues
/// retryable tasks up to `retry_limit` times. Dropping the `FailoverPool`
/// clears the inner pool list, which releases the member pools' `Arc`s back
/// to the shared inner state.
pub struct FailoverPool<F, P>(Arc<FailoverPoolInner<F, P>>)
where
    F: ClientFacts,
    P: ClientTransport;
32
/// Shared state behind a [`FailoverPool`]; it also serves as the
/// `ClientFacts` implementation handed to each member [`ClientPool`].
struct FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Current cluster snapshot; stored as `None` once the owning
    // `FailoverPool` is dropped.
    pools: ArcSwapOption<ClusterConfig<F, P>>,
    // When true, fresh selections start from `rr_counter` instead of slot 0.
    round_robin: bool,
    // User-supplied facts all calls are delegated to (see the Deref impl).
    facts: Arc<F>,
    // Tasks whose `retry` count exceeds this are completed with their last
    // recorded error instead of being re-queued.
    retry_limit: usize,
    // Producer side of the retry queue consumed by `retry_worker`.
    retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
    // Monotonic counter used to stamp each new `ClusterConfig`.
    ver: AtomicU64,
    // Rotating seed for round-robin selection (only read when `round_robin`).
    rr_counter: AtomicUsize,
    // Channel capacity forwarded to every `ClientPool` created later in
    // `update_addrs`.
    pool_channel_size: usize,
    logger: Arc<LogFilter>,
}
48
/// Immutable snapshot of the pool list plus a version stamp; swapped
/// atomically in `FailoverPoolInner::pools` whenever the address set changes.
struct ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    pools: Vec<ClientPool<FailoverPoolInner<F, P>, P>>,
    // Compared against `FailoverTask::cluster_ver` to detect a topology
    // change between dispatch and retry.
    ver: u64,
}
57
58impl<F, P> FailoverPool<F, P>
59where
60 F: ClientFacts,
61 P: ClientTransport,
62{
63 pub fn new(
64 facts: Arc<F>, addrs: Vec<String>, round_robin: bool, retry_limit: usize,
65 pool_channel_size: usize,
66 ) -> Self {
67 let (retry_tx, retry_rx) = mpsc::unbounded_async();
68 let inner = Arc::new(FailoverPoolInner::<F, P> {
70 pools: ArcSwapOption::new(None),
71 round_robin,
72 facts: facts.clone(),
73 retry_limit,
74 retry_tx: retry_tx.into(),
75 ver: AtomicU64::new(1),
76 rr_counter: AtomicUsize::new(0),
77 pool_channel_size,
78 logger: facts.new_logger(),
79 });
80 let mut pools = Vec::with_capacity(addrs.len());
81 for addr in addrs.iter() {
82 let pool = ClientPool::new(inner.clone(), &addr, pool_channel_size);
83 pools.push(pool);
84 }
85 inner.pools.store(Some(Arc::new(ClusterConfig { ver: 0, pools })));
86
87 let retry_logger = facts.new_logger();
88 let weak_self = Arc::downgrade(&inner);
89 facts.spawn_detach(async move {
90 FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
91 });
92 Self(inner)
93 }
94
95 pub fn update_addrs(&self, addrs: Vec<String>) {
96 let inner = &self.0;
97 let old_cluster_arc = inner.pools.load();
98 let old_pools = old_cluster_arc.as_ref().map(|c| c.pools.clone()).unwrap_or_else(Vec::new);
99
100 let mut new_pools = Vec::with_capacity(addrs.len());
101
102 let mut old_pools_map = std::collections::HashMap::with_capacity(old_pools.len());
103 for pool in old_pools {
104 old_pools_map.insert(pool.get_addr().to_string(), pool);
105 }
106
107 for addr in addrs {
108 if let Some(reused_pool) = old_pools_map.remove(&addr) {
109 new_pools.push(reused_pool);
110 } else {
111 let new_pool = ClientPool::new(inner.clone(), &addr, inner.pool_channel_size);
113 new_pools.push(new_pool);
114 }
115 }
116
117 let new_ver = inner.ver.fetch_add(1, Ordering::Relaxed) + 1;
118 let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
119 inner.pools.store(Some(Arc::new(new_cluster)));
120 }
121}
122
123impl<F, P> ClusterConfig<F, P>
124where
125 F: ClientFacts,
126 P: ClientTransport,
127{
128 #[inline]
129 fn select(
130 &self, round_robin: bool, rr_counter: &AtomicUsize, last_index: Option<usize>,
131 ) -> Option<(usize, &ClientPool<FailoverPoolInner<F, P>, P>)> {
132 let l = self.pools.len();
133 if l == 0 {
134 return None;
135 }
136 let seed = if let Some(index) = last_index {
137 index + 1
138 } else if round_robin {
139 rr_counter.fetch_add(1, Ordering::Relaxed)
140 } else {
141 0
142 };
143 for i in seed..seed + l {
144 let pool = &self.pools[i % l];
145 if pool.is_healthy() {
146 return Some((i, pool));
147 }
148 }
149 return None;
150 }
151}
152
153impl<F, P> FailoverPoolInner<F, P>
154where
155 F: ClientFacts,
156 P: ClientTransport,
157{
158 async fn retry_worker(
159 weak_self: Weak<Self>, logger: Arc<LogFilter>,
160 retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
161 ) {
162 while let Ok(mut task) = retry_rx.recv().await {
163 if let Some(inner) = weak_self.upgrade() {
164 let cluster = inner.pools.load();
165 if let Some(cluster) = cluster.as_ref() {
166 let last_index = if cluster.ver == task.cluster_ver {
168 Some(task.last_index)
169 } else {
170 task.cluster_ver = cluster.ver;
171 None };
173 if let Some((index, pool)) =
174 cluster.select(inner.round_robin, &inner.rr_counter, last_index)
175 {
176 if let Some(last) = last_index {
177 logger_trace!(
178 logger,
179 "FailoverPool: task {:?} retry {}->{}",
180 task.inner,
181 last,
182 index
183 );
184 }
185 task.last_index = index;
186 pool.send_req(task).await; continue;
188 }
189 }
190 logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
192 task.done();
193 } else {
194 logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
195 task.done();
196 }
197 }
198 logger_trace!(logger, "FailoverPool retry worker exit");
199 }
200}
201
202impl<F, P> Drop for FailoverPool<F, P>
203where
204 F: ClientFacts,
205 P: ClientTransport,
206{
207 fn drop(&mut self) {
208 self.0.pools.store(None);
210 }
211}
212
impl<F, P> Drop for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Trace-level breadcrumb confirming the shared inner state was actually
    // released (i.e. the Arc cycle with the member pools was broken).
    fn drop(&mut self) {
        logger_trace!(self.logger, "FailoverPool dropped");
    }
}
222
// Expose the wrapped user facts through deref so `FailoverPoolInner` can
// stand in for `F` where convenient (e.g. calling `F`'s inherent methods).
// NOTE(review): Deref to a non-smart-pointer target is normally discouraged;
// confirm downstream code relies on this before refactoring it away.
impl<F, P> std::ops::Deref for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    type Target = F;

    fn deref(&self) -> &Self::Target {
        &self.facts
    }
}
234
235impl<F, P> ClientFacts for FailoverPoolInner<F, P>
236where
237 F: ClientFacts,
238 P: ClientTransport,
239{
240 type Codec = F::Codec;
241
242 type Task = FailoverTask<F::Task>;
243
244 #[inline]
245 fn new_logger(&self) -> Arc<LogFilter> {
246 self.facts.new_logger()
247 }
248
249 #[inline]
250 fn get_config(&self) -> &ClientConfig {
251 self.facts.get_config()
252 }
253
254 #[inline]
255 fn error_handle(&self, task: FailoverTask<F::Task>) {
256 if task.should_retry {
257 if task.retry <= self.retry_limit {
258 if let Err(SendError(_task)) = self.retry_tx.send(task) {
259 _task.done();
260 }
261 return;
262 }
263 }
264 task.inner.done();
265 }
266}
267
268impl<F, P> ClientCaller for FailoverPool<F, P>
269where
270 F: ClientFacts,
271 P: ClientTransport,
272{
273 type Facts = F;
274
275 async fn send_req(&self, mut task: F::Task) {
276 let cluster = self.0.pools.load();
277 if let Some(cluster) = cluster.as_ref() {
278 if let Some((index, pool)) =
279 cluster.select(self.0.round_robin, &self.0.rr_counter, None)
280 {
281 let failover_task = FailoverTask {
282 last_index: index,
283 cluster_ver: cluster.ver,
284 inner: task,
285 retry: 0,
286 should_retry: false,
287 };
288 pool.send_req(failover_task).await;
289 return;
290 }
291 }
292
293 task.set_rpc_error(RpcIntErr::Unreachable);
295 task.done();
296 }
297}
298
299impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
300where
301 F: ClientFacts,
302 P: ClientTransport,
303{
304 type Facts = F;
305 fn send_req_blocking(&self, mut task: F::Task) {
306 let cluster = self.0.pools.load();
307 if let Some(cluster) = cluster.as_ref() {
308 if let Some((index, pool)) =
309 cluster.select(self.0.round_robin, &self.0.rr_counter, None)
310 {
311 let failover_task = FailoverTask {
312 last_index: index,
313 cluster_ver: cluster.ver,
314 inner: task,
315 retry: 0,
316 should_retry: false,
317 };
318 pool.send_req_blocking(failover_task);
319 return;
320 }
321 }
322
323 task.set_rpc_error(RpcIntErr::Unreachable);
325 task.done();
326 }
327}
328
/// Wrapper that adds failover bookkeeping to a user task.
pub struct FailoverTask<T: ClientTask> {
    // Raw scan position of the pool last tried (not wrapped by pool count).
    last_index: usize,
    // `ClusterConfig::ver` the task was last dispatched under; a mismatch on
    // retry means the topology changed and the scan restarts.
    cluster_ver: u64,
    // The user's task; all codec/completion calls delegate to it.
    inner: T,
    // Number of retryable failures so far (bumped in `set_rpc_error`).
    retry: usize,
    // Set by `set_rpc_error` for retryable errors; cleared on custom errors.
    should_retry: bool,
}
336
// Request encoding is delegated untouched to the wrapped task.
impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
    #[inline(always)]
    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
        self.inner.encode_req(codec, buf)
    }

    #[inline(always)]
    fn get_req_blob(&self) -> Option<&[u8]> {
        self.inner.get_req_blob()
    }
}
348
// Response decoding is delegated untouched to the wrapped task.
impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
    #[inline(always)]
    fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
        self.inner.decode_resp(codec, buf)
    }

    #[inline(always)]
    fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
        self.inner.reserve_resp_blob(_size)
    }
}
360
361impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
362 #[inline(always)]
363 fn set_custom_error<C: Codec>(&mut self, codec: &C, e: EncodedErr) {
364 self.should_retry = false;
365 self.inner.set_custom_error(codec, e);
366 }
367
368 #[inline(always)]
369 fn set_rpc_error(&mut self, e: RpcIntErr) {
370 if e < RpcIntErr::Method {
371 self.should_retry = true;
372 self.retry += 1;
373 } else {
374 self.should_retry = false;
375 }
376 self.inner.set_rpc_error(e.clone());
377 }
378
379 #[inline(always)]
380 fn set_ok(&mut self) {
381 self.inner.set_ok();
382 }
383
384 #[inline(always)]
385 fn done(self) {
386 self.inner.done();
387 }
388}
389
// The RPC action is taken straight from the wrapped task.
impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
    #[inline(always)]
    fn get_action<'a>(&'a self) -> RpcAction<'a> {
        self.inner.get_action()
    }
}
396
// Forward the shared task header (`ClientTaskCommon`) from the wrapped task.
impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
    type Target = ClientTaskCommon;
    fn deref(&self) -> &Self::Target {
        self.inner.deref()
    }
}
403
// Mutable access to the shared task header, likewise forwarded.
impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.deref_mut()
    }
}
409
// Marker impl: the delegating trait impls complete the ClientTask bundle.
impl<T: ClientTask> ClientTask for FailoverTask<T> {}
411
// Debug output mirrors the wrapped task's, keeping log lines stable across
// the failover wrapper.
impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.inner.fmt(f)
    }
}