// razor_stream/client/failover.rs

1use crate::client::task::*;
2use crate::client::{
3    ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientPool, ClientTransport,
4};
5use crate::proto::RpcAction;
6use crate::{
7    Codec,
8    error::{EncodedErr, RpcIntErr},
9};
10use arc_swap::ArcSwap;
11use captains_log::filter::LogFilter;
12use crossfire::{AsyncRx, MTx, SendError, mpsc};
13use std::fmt;
14use std::sync::{
15    Arc, Weak,
16    atomic::{AtomicU64, AtomicUsize, Ordering},
17};
18
/// A pool supports failover to multiple address with bias or round_robin strategy
///
/// Supports async and blocking context.
///
/// Only retry RpcIntErr that less than RpcIntErr::Method,
/// currently ignore custom error due to complexity of generic.
/// (If you need to custom failover logic, copy the code and impl your own pool.)
pub struct FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Shared state; `FailoverPool` itself is a cheap, clonable handle.
    inner: Arc<FailoverPoolInner<F, P>>,
}
33
/// Wraps the user-supplied `ClientFacts` so that `error_handle` can route
/// retryable failed tasks back through the retry queue instead of completing
/// them immediately. All other facts calls are delegated (see `Deref` impl).
struct FailoverFacts<F>
where
    F: ClientFacts,
{
    // Producer side of the retry queue consumed by `FailoverPoolInner::retry_worker`.
    retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
    // The original user facts, also reachable through `Deref`.
    facts: Arc<F>,
    logger: Arc<LogFilter>,
    // Maximum retry count; tasks exceeding it complete with their last error.
    retry_limit: usize,
}
43
/// Shared core of the failover pool. The retry worker holds only a `Weak`
/// reference to this, so dropping all `FailoverPool` handles lets it shut down.
struct FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Current address-set snapshot; replaced atomically by `update_addrs`.
    pools: ArcSwap<ClusterConfig<F, P>>,
    // true: spread requests across addresses; false: always prefer the first.
    round_robin: bool,
    // Monotonic counter used to stamp each new `ClusterConfig` snapshot.
    ver: AtomicU64,
    // Shared round-robin cursor (only advanced when `round_robin` is true).
    rr_counter: AtomicUsize,
    pool_channel_size: usize,
    facts: Arc<FailoverFacts<F>>,
}
56
/// Immutable snapshot of the pool set, read lock-free via `ArcSwap` and
/// replaced wholesale by `update_addrs`.
struct ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    pools: Vec<ClientPool<FailoverFacts<F>, P>>,
    // Snapshot version; tasks carry the version they were dispatched under so a
    // retry can detect a stale pool index after a config swap.
    ver: u64,
}
65
impl<F, P> FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Initiate the pool with multiple address.
    /// When round_robin == true, all address in the pool will be select with equal chanced;
    /// When round_robin == false, the first address will always be pick unless error happens.
    pub fn new(
        facts: Arc<F>, addrs: Vec<String>, round_robin: bool, retry_limit: usize,
        pool_channel_size: usize,
    ) -> Self {
        // Unbounded retry channel: the error path must never block on send.
        let (retry_tx, retry_rx) = mpsc::unbounded_async();
        let retry_logger = facts.new_logger();
        let wrapped_facts =
            Arc::new(FailoverFacts { retry_limit, retry_tx, logger: facts.new_logger(), facts });
        let mut pools = Vec::with_capacity(addrs.len());
        for addr in addrs.iter() {
            let pool = ClientPool::new(wrapped_facts.clone(), addr, pool_channel_size);
            pools.push(pool);
        }
        // NOTE: the ClientPool has cycle reference with FailoverPoolInner
        let inner = Arc::new(FailoverPoolInner::<F, P> {
            pools: ArcSwap::new(Arc::new(ClusterConfig { ver: 0, pools })),
            round_robin,
            facts: wrapped_facts,
            ver: AtomicU64::new(1),
            rr_counter: AtomicUsize::new(0),
            pool_channel_size,
        });
        // Spawn the retry worker with only a Weak reference so the worker does
        // not keep the inner state alive after the last handle is dropped.
        let weak_self = Arc::downgrade(&inner);
        inner.facts.spawn_detach(async move {
            FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
        });
        Self { inner }
    }

    /// Replace the address set: pools for addresses that remain are reused,
    /// new addresses get fresh pools, and pools for removed addresses are
    /// dropped together with the old `ClusterConfig` snapshot.
    pub fn update_addrs(&self, addrs: Vec<String>) {
        let inner = &self.inner;
        let old_pools = inner.pools.load_full();
        let mut new_pools: Vec<ClientPool<FailoverFacts<F>, P>> = Vec::with_capacity(addrs.len());

        // Index existing pools by address so they can be carried over.
        let mut old_pools_map = std::collections::HashMap::with_capacity(old_pools.pools.len());
        for pool in &old_pools.pools {
            old_pools_map.insert(pool.get_addr().to_string(), pool);
        }

        for addr in addrs {
            if let Some(reused_pool) = old_pools_map.remove(&addr) {
                new_pools.push(reused_pool.clone());
            } else {
                // Create a new pool for the new address
                let new_pool = ClientPool::new(inner.facts.clone(), &addr, inner.pool_channel_size);
                new_pools.push(new_pool);
            }
        }
        // Bump the version so in-flight retries detect the swap and restart selection.
        let new_ver = inner.ver.fetch_add(1, Ordering::Relaxed) + 1;
        let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
        inner.pools.store(Arc::new(new_cluster));
    }
}
127
128impl<F, P> ClusterConfig<F, P>
129where
130    F: ClientFacts,
131    P: ClientTransport,
132{
133    #[inline]
134    fn select(
135        &self, round_robin: bool, rr_counter: &AtomicUsize, last_index: Option<usize>,
136    ) -> Option<(usize, &ClientPool<FailoverFacts<F>, P>)> {
137        let l = self.pools.len();
138        if l == 0 {
139            return None;
140        }
141        let seed = if let Some(index) = last_index {
142            // last_index is address with error occurs
143            index + 1
144        } else if round_robin {
145            rr_counter.fetch_add(1, Ordering::Relaxed)
146        } else {
147            0
148        };
149        for i in seed..seed + l {
150            let pool = &self.pools[i % l];
151            if pool.is_healthy() {
152                return Some((i, pool));
153            }
154        }
155        return None;
156    }
157}
158
impl<F, P> FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Background loop that re-dispatches failed tasks to the next candidate pool.
    ///
    /// Holds only a `Weak` reference to the inner state: tasks arriving after
    /// the pool has been dropped are completed immediately with whatever error
    /// they already carry. The loop exits once the retry channel is closed.
    async fn retry_worker(
        weak_self: Weak<Self>, logger: Arc<LogFilter>,
        retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
    ) {
        while let Ok(mut task) = retry_rx.recv().await {
            if let Some(inner) = weak_self.upgrade() {
                let cluster = inner.pools.load();
                // if cluster config changed, restart selection
                let last_index = if cluster.ver == task.cluster_ver {
                    Some(task.last_index)
                } else {
                    task.cluster_ver = cluster.ver;
                    None // restart selection
                };
                if let Some((index, pool)) =
                    cluster.select(inner.round_robin, &inner.rr_counter, last_index)
                {
                    if let Some(last) = last_index {
                        logger_trace!(
                            logger,
                            "FailoverPool: task {:?} retry {}->{}",
                            task.inner,
                            last,
                            index
                        );
                    }
                    // Remember where we sent it so the next failure starts after it.
                    task.last_index = index;
                    pool.send_req(task).await; // retry is async
                    continue;
                }
                // if we are here, something is wrong, no pool available or selection failed
                logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
                task.done();
            } else {
                logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
                task.done();
            }
        }
        logger_trace!(logger, "FailoverPool retry worker exit");
    }
}
205
impl<F, P> Drop for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Only logs; the contained pools and channel ends are dropped field-by-field.
    #[inline]
    fn drop(&mut self) {
        logger_trace!(self.facts.logger, "FailoverPool dropped");
    }
}
216
217/// `orb::AsyncRuntime` will follow deref to blanket impl it for wrapper types
218impl<F> std::ops::Deref for FailoverFacts<F>
219where
220    F: ClientFacts,
221{
222    type Target = F;
223
224    #[inline]
225    fn deref(&self) -> &Self::Target {
226        self.facts.as_ref()
227    }
228}
229
230impl<F> ClientFacts for FailoverFacts<F>
231where
232    F: ClientFacts,
233{
234    type Codec = F::Codec;
235
236    type Task = FailoverTask<F::Task>;
237
238    #[inline]
239    fn new_logger(&self) -> Arc<LogFilter> {
240        self.facts.new_logger()
241    }
242
243    #[inline]
244    fn get_config(&self) -> &ClientConfig {
245        self.facts.get_config()
246    }
247
248    #[inline]
249    fn error_handle(&self, task: FailoverTask<F::Task>) {
250        if task.should_retry && task.retry <= self.retry_limit {
251            if let Err(SendError(_task)) = self.retry_tx.send(task) {
252                _task.done();
253            }
254            return;
255        }
256        task.inner.done();
257    }
258}
259
260impl<F, P> Clone for FailoverPool<F, P>
261where
262    F: ClientFacts,
263    P: ClientTransport,
264{
265    #[inline]
266    fn clone(&self) -> Self {
267        Self { inner: self.inner.clone() }
268    }
269}
270
271impl<F, P> ClientCaller for FailoverPool<F, P>
272where
273    F: ClientFacts,
274    P: ClientTransport,
275{
276    type Facts = F;
277
278    async fn send_req(&self, mut task: F::Task) {
279        let cluster = self.inner.pools.load();
280        if let Some((index, pool)) =
281            cluster.select(self.inner.round_robin, &self.inner.rr_counter, None)
282        {
283            let failover_task = FailoverTask {
284                last_index: index,
285                cluster_ver: cluster.ver,
286                inner: task,
287                retry: 0,
288                should_retry: false,
289            };
290            pool.send_req(failover_task).await;
291            return;
292        }
293
294        // No pools available
295        task.set_rpc_error(RpcIntErr::Unreachable);
296        task.done();
297    }
298}
299
300impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
301where
302    F: ClientFacts,
303    P: ClientTransport,
304{
305    type Facts = F;
306    fn send_req_blocking(&self, mut task: F::Task) {
307        let cluster = self.inner.pools.load();
308        if let Some((index, pool)) =
309            cluster.select(self.inner.round_robin, &self.inner.rr_counter, None)
310        {
311            let failover_task = FailoverTask {
312                last_index: index,
313                cluster_ver: cluster.ver,
314                inner: task,
315                retry: 0,
316                should_retry: false,
317            };
318            pool.send_req_blocking(failover_task);
319            return;
320        }
321
322        // No pools available
323        task.set_rpc_error(RpcIntErr::Unreachable);
324        task.done();
325    }
326}
327
/// Wrapper around the user task carrying failover bookkeeping.
pub struct FailoverTask<T: ClientTask> {
    // Index the task was last dispatched to (raw probe index, not wrapped).
    last_index: usize,
    // `ClusterConfig::ver` at dispatch time; a mismatch on retry restarts selection.
    cluster_ver: u64,
    inner: T,
    // Count of rpc-level errors recorded so far; compared against the retry limit.
    retry: usize,
    // Set by `set_rpc_error` when the last error is considered transient.
    should_retry: bool,
}
335
/// Request encoding is delegated unchanged to the wrapped task.
impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
    #[inline(always)]
    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
        self.inner.encode_req(codec, buf)
    }

    #[inline(always)]
    fn get_req_blob(&self) -> Option<&[u8]> {
        self.inner.get_req_blob()
    }
}
347
/// Response decoding is delegated unchanged to the wrapped task.
impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
    #[inline(always)]
    fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
        self.inner.decode_resp(codec, buf)
    }

    #[inline(always)]
    fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
        self.inner.reserve_resp_blob(_size)
    }
}
359
360impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
361    #[inline(always)]
362    fn set_custom_error<C: Codec>(&mut self, codec: &C, e: EncodedErr) {
363        self.should_retry = false;
364        self.inner.set_custom_error(codec, e);
365    }
366
367    #[inline(always)]
368    fn set_rpc_error(&mut self, e: RpcIntErr) {
369        if e < RpcIntErr::Method {
370            self.should_retry = true;
371            self.retry += 1;
372        } else {
373            self.should_retry = false;
374        }
375        self.inner.set_rpc_error(e.clone());
376    }
377
378    #[inline(always)]
379    fn set_ok(&mut self) {
380        self.inner.set_ok();
381    }
382
383    #[inline(always)]
384    fn done(self) {
385        self.inner.done();
386    }
387}
388
/// The rpc action (method routing info) comes straight from the wrapped task.
impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
    #[inline(always)]
    fn get_action<'a>(&'a self) -> RpcAction<'a> {
        self.inner.get_action()
    }
}
395
/// Deref to the wrapped task's `ClientTaskCommon`.
impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
    type Target = ClientTaskCommon;
    fn deref(&self) -> &Self::Target {
        self.inner.deref()
    }
}
402
/// Mutable counterpart of the `Deref` impl above.
impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.deref_mut()
    }
}
408
/// Marker impl: `FailoverTask` satisfies the full `ClientTask` trait bundle.
impl<T: ClientTask> ClientTask for FailoverTask<T> {}
410
/// Debug output shows only the wrapped task, hiding failover bookkeeping fields.
impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.inner.fmt(f)
    }
}