razor_stream/client/failover.rs

use crate::client::task::*;
use crate::client::{
    ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientPool, ClientTransport,
};
use crate::proto::RpcAction;
use crate::{
    Codec,
    error::{EncodedErr, RpcIntErr},
};
use arc_swap::ArcSwapOption;
use captains_log::filter::LogFilter;
use crossfire::*;
use std::fmt;
use std::sync::{
    Arc, Weak,
    atomic::{AtomicU64, AtomicUsize, Ordering},
};

/// A pool that supports failover across multiple addresses, with an optional round_robin strategy.
///
/// Supports both async and blocking contexts.
///
/// Only retries on `RpcIntErr` values ordered below `RpcIntErr::Method`;
/// custom errors are currently not retried due to the complexity of generics.
///
/// NOTE: there is a cycle reference between `FailoverPoolInner` and its `ClientPool`s.
/// Do not clone `FailoverPool`, because breaking the cycle relies on its custom `Drop`;
/// put the `FailoverPool` in an `Arc` for shared usage.
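///
/// A minimal usage sketch (not compiled here); `MyFacts` and `TcpTransport` are
/// placeholders for your own `ClientFacts` and `ClientTransport` implementations,
/// and the addresses are examples:
///
/// ```ignore
/// let facts = Arc::new(MyFacts::default());
/// let pool = Arc::new(FailoverPool::<MyFacts, TcpTransport>::new(
///     facts,
///     vec!["10.0.0.1:9000".to_string(), "10.0.0.2:9000".to_string()],
///     true, // round_robin
///     2,    // retry_limit
///     128,  // pool_channel_size
/// ));
/// ```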
pub struct FailoverPool<F, P>(Arc<FailoverPoolInner<F, P>>)
where
    F: ClientFacts,
    P: ClientTransport;

struct FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Current cluster config; swapped atomically by `update_addrs`
    pools: ArcSwapOption<ClusterConfig<F, P>>,
    round_robin: bool,
    facts: Arc<F>,
    retry_limit: usize,
    /// Channel to the background retry worker
    retry_tx: MTx<FailoverTask<F::Task>>,
    /// Monotonic counter used to version cluster configs
    ver: AtomicU64,
    /// Shared round-robin cursor for pool selection
    rr_counter: AtomicUsize,
    pool_channel_size: usize,
    logger: Arc<LogFilter>,
}

struct ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    pools: Vec<ClientPool<FailoverPoolInner<F, P>, P>>,
    ver: u64,
}

impl<F, P> FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    pub fn new(
        facts: Arc<F>, addrs: Vec<String>, round_robin: bool, retry_limit: usize,
        pool_channel_size: usize,
    ) -> Self {
        let (retry_tx, retry_rx) = mpsc::unbounded_async();
        // NOTE: the ClientPool has a cycle reference with FailoverPoolInner
        let inner = Arc::new(FailoverPoolInner::<F, P> {
            pools: ArcSwapOption::new(None),
            round_robin,
            facts: facts.clone(),
            retry_limit,
            retry_tx: retry_tx.into(),
            ver: AtomicU64::new(1),
            rr_counter: AtomicUsize::new(0),
            pool_channel_size,
            logger: facts.new_logger(),
        });
        let mut pools = Vec::with_capacity(addrs.len());
        for addr in addrs.iter() {
            let pool = ClientPool::new(inner.clone(), &addr, pool_channel_size);
            pools.push(pool);
        }
        inner.pools.store(Some(Arc::new(ClusterConfig { ver: 0, pools })));

        let retry_logger = facts.new_logger();
        let weak_self = Arc::downgrade(&inner);
        facts.spawn_detach(async move {
            FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
        });
        Self(inner)
    }

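    /// Replaces the set of backend addresses.
    ///
    /// Pools for addresses that are still present are reused; pools whose addresses
    /// were removed are dropped once the old `ClusterConfig` is released. A sketch of
    /// a typical call (addresses are placeholders):
    ///
    /// ```ignore
    /// pool.update_addrs(vec!["10.0.0.1:9000".to_string(), "10.0.0.3:9000".to_string()]);
    /// ```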
    pub fn update_addrs(&self, addrs: Vec<String>) {
        let inner = &self.0;
        let old_cluster_arc = inner.pools.load();
        let old_pools = old_cluster_arc.as_ref().map(|c| c.pools.clone()).unwrap_or_else(Vec::new);

        let mut new_pools = Vec::with_capacity(addrs.len());

        let mut old_pools_map = std::collections::HashMap::with_capacity(old_pools.len());
        for pool in old_pools {
            old_pools_map.insert(pool.get_addr().to_string(), pool);
        }

        for addr in addrs {
            if let Some(reused_pool) = old_pools_map.remove(&addr) {
                // Reuse the existing pool for an address that is still present
                new_pools.push(reused_pool);
            } else {
                // Create a new pool for the new address
                let new_pool = ClientPool::new(inner.clone(), &addr, inner.pool_channel_size);
                new_pools.push(new_pool);
            }
        }

        let new_ver = inner.ver.fetch_add(1, Ordering::Relaxed) + 1;
        let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
        inner.pools.store(Some(Arc::new(new_cluster)));
    }
}

impl<F, P> ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Picks the next healthy pool: when retrying, start just after `last_index`;
    /// otherwise start from the round-robin counter (or index 0 when round_robin is off).
    #[inline]
    fn select(
        &self, round_robin: bool, rr_counter: &AtomicUsize, last_index: Option<usize>,
    ) -> Option<(usize, &ClientPool<FailoverPoolInner<F, P>, P>)> {
        let l = self.pools.len();
        if l == 0 {
            return None;
        }
        let seed = if let Some(index) = last_index {
            index + 1
        } else if round_robin {
            rr_counter.fetch_add(1, Ordering::Relaxed)
        } else {
            0
        };
        // Probe at most `l` pools, wrapping around, and take the first healthy one
        for i in seed..seed + l {
            let pool = &self.pools[i % l];
            if pool.is_healthy() {
                return Some((i, pool));
            }
        }
        None
    }
}

impl<F, P> FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
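    /// Background worker driving retries: tasks that failed with a retryable error are
    /// handed over by `error_handle` through `retry_tx`; this worker re-selects a pool,
    /// skipping the pool the task last failed on as long as the cluster config version
    /// has not changed.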
    async fn retry_worker(
        weak_self: Weak<Self>, logger: Arc<LogFilter>, retry_rx: AsyncRx<FailoverTask<F::Task>>,
    ) {
        while let Ok(mut task) = retry_rx.recv().await {
            if let Some(inner) = weak_self.upgrade() {
                let cluster = inner.pools.load();
                if let Some(cluster) = cluster.as_ref() {
                    // if the cluster config changed, restart selection
                    let last_index = if cluster.ver == task.cluster_ver {
                        Some(task.last_index)
                    } else {
                        task.cluster_ver = cluster.ver;
                        None // restart selection
                    };
                    if let Some((index, pool)) =
                        cluster.select(inner.round_robin, &inner.rr_counter, last_index)
                    {
                        if let Some(last) = last_index {
                            logger_trace!(
                                logger,
                                "FailoverPool: task {:?} retry {}->{}",
                                task.inner,
                                last,
                                index
                            );
                        }
                        task.last_index = index;
                        pool.send_req(task).await; // retry is async
                        continue;
                    }
                }
                // if we get here, no pool is available or selection failed
                logger_debug!(logger, "FailoverPool: no next hop for {:?}", task.inner);
                task.done();
            } else {
                logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
                task.done();
            }
        }
        logger_trace!(logger, "FailoverPool retry worker exit");
    }
}

impl<F, P> Drop for FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    fn drop(&mut self) {
        // Remove cycle reference before drop
        self.0.pools.store(None);
    }
}

impl<F, P> Drop for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    fn drop(&mut self) {
        logger_trace!(self.logger, "FailoverPool dropped");
    }
}

impl<F, P> std::ops::Deref for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    type Target = F;

    fn deref(&self) -> &Self::Target {
        &self.facts
    }
}

impl<F, P> ClientFacts for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    type Codec = F::Codec;

    type Task = FailoverTask<F::Task>;

    #[inline]
    fn new_logger(&self) -> Arc<LogFilter> {
        self.facts.new_logger()
    }

    #[inline]
    fn get_config(&self) -> &ClientConfig {
        self.facts.get_config()
    }

    #[inline]
    fn error_handle(&self, task: FailoverTask<F::Task>) {
        if task.should_retry && task.retry <= self.retry_limit {
            // Hand the task to the retry worker; finish it if the worker is gone
            if let Err(SendError(task)) = self.retry_tx.send(task) {
                task.done();
            }
            return;
        }
        task.inner.done();
    }
}

impl<F, P> ClientCaller for FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    type Facts = F;

    async fn send_req(&self, mut task: F::Task) {
        let cluster = self.0.pools.load();
        if let Some(cluster) = cluster.as_ref() {
            if let Some((index, pool)) =
                cluster.select(self.0.round_robin, &self.0.rr_counter, None)
            {
                let failover_task = FailoverTask {
                    last_index: index,
                    cluster_ver: cluster.ver,
                    inner: task,
                    retry: 0,
                    should_retry: false,
                };
                pool.send_req(failover_task).await;
                return;
            }
        }

        // No healthy pool available
        task.set_rpc_error(RpcIntErr::Unreachable);
        task.done();
    }
}

impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    type Facts = F;

    fn send_req_blocking(&self, mut task: F::Task) {
        let cluster = self.0.pools.load();
        if let Some(cluster) = cluster.as_ref() {
            if let Some((index, pool)) =
                cluster.select(self.0.round_robin, &self.0.rr_counter, None)
            {
                let failover_task = FailoverTask {
                    last_index: index,
                    cluster_ver: cluster.ver,
                    inner: task,
                    retry: 0,
                    should_retry: false,
                };
                pool.send_req_blocking(failover_task);
                return;
            }
        }

        // No healthy pool available
        task.set_rpc_error(RpcIntErr::Unreachable);
        task.done();
    }
}

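/// Wraps a user task together with the failover bookkeeping (last attempted pool
/// index, cluster config version, and retry count) consumed by `error_handle` and
/// the retry worker.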
pub struct FailoverTask<T: ClientTask> {
    last_index: usize,
    cluster_ver: u64,
    inner: T,
    retry: usize,
    should_retry: bool,
}

impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
    #[inline(always)]
    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
        self.inner.encode_req(codec, buf)
    }

    #[inline(always)]
    fn get_req_blob(&self) -> Option<&[u8]> {
        self.inner.get_req_blob()
    }
}

impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
    #[inline(always)]
    fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
        self.inner.decode_resp(codec, buf)
    }

    #[inline(always)]
    fn reserve_resp_blob(&mut self, size: i32) -> Option<&mut [u8]> {
        self.inner.reserve_resp_blob(size)
    }
}

impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
    #[inline(always)]
    fn set_custom_error<C: Codec>(&mut self, codec: &C, e: EncodedErr) {
        self.should_retry = false;
        self.inner.set_custom_error(codec, e);
    }

    #[inline(always)]
    fn set_rpc_error(&mut self, e: RpcIntErr) {
        // Errors ordered below RpcIntErr::Method are considered retryable (see the FailoverPool doc)
        if e < RpcIntErr::Method {
            self.should_retry = true;
            self.retry += 1;
        } else {
            self.should_retry = false;
        }
        self.inner.set_rpc_error(e);
    }

    #[inline(always)]
    fn set_ok(&mut self) {
        self.inner.set_ok();
    }

    #[inline(always)]
    fn done(self) {
        self.inner.done();
    }
}

impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
    #[inline(always)]
    fn get_action<'a>(&'a self) -> RpcAction<'a> {
        self.inner.get_action()
    }
}

impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
    type Target = ClientTaskCommon;
    fn deref(&self) -> &Self::Target {
        self.inner.deref()
    }
}

impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.deref_mut()
    }
}

impl<T: ClientTask> ClientTask for FailoverTask<T> {}

impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.inner.fmt(f)
    }
}