// razor_stream/client/failover.rs
1use crate::client::task::*;
2use crate::client::{
3    ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientPool, ClientTransport,
4};
5use crate::proto::RpcAction;
6use crate::{
7    Codec,
8    error::{EncodedErr, RpcIntErr},
9};
10use arc_swap::ArcSwap;
11use captains_log::filter::LogFilter;
12use crossfire::{AsyncRx, MTx, SendError, mpsc};
13use orb::AsyncRuntime;
14use std::fmt;
15use std::sync::{
16    Arc, Weak,
17    atomic::{AtomicU64, AtomicUsize, Ordering},
18};
19
/// A pool that supports failover to multiple addresses with a bias or round-robin strategy.
///
/// Supports async and blocking context.
///
/// Only retries `RpcIntErr` values less than `RpcIntErr::Method`;
/// custom errors are currently ignored due to the complexity of generics.
/// (If you need custom failover logic, copy the code and implement your own pool.)
pub struct FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Shared state; `FailoverPool` itself is a cheap, clonable handle.
    inner: Arc<FailoverPoolInner<F, P>>,
}
34
/// Wrapper around the user's `ClientFacts` that intercepts task errors
/// (via `error_handle`) and re-queues retryable tasks to the retry worker.
struct FailoverFacts<F>
where
    F: ClientFacts,
{
    // Sender side of the channel consumed by `FailoverPoolInner::retry_worker`.
    retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
    // The user-supplied facts; also reachable through the `Deref` impl below.
    facts: Arc<F>,
    logger: Arc<LogFilter>,
    // Maximum retry count per task before it is completed with its error.
    retry_limit: usize,
}
44
/// Shared state behind `FailoverPool` handles.
struct FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Current cluster snapshot; atomically swapped by `update_addrs`.
    pools: ArcSwap<ClusterConfig<F, P>>,
    // true: all addresses picked with equal chance; false: biased to the first.
    round_robin: bool,
    // Monotonic version source; each new `ClusterConfig.ver` is derived from it.
    ver: AtomicU64,
    // Round-robin cursor (only consulted when `round_robin` is true).
    rr_counter: AtomicUsize,
    // Channel size used when constructing new `ClientPool`s.
    pool_channel_size: usize,
    facts: Arc<FailoverFacts<F>>,
}
57
/// An immutable snapshot of the client pools for one set of addresses.
struct ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    pools: Vec<ClientPool<FailoverFacts<F>, P>>,
    // Version stamp; tasks carry it so retries can detect a config change.
    ver: u64,
}
66
impl<F, P> FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Initiate the pool with multiple addresses.
    ///
    /// When `round_robin == true`, all addresses in the pool are selected with equal chance;
    /// when `round_robin == false`, the first address is always picked unless an error happens.
    pub fn new<RT: AsyncRuntime + Clone>(
        facts: Arc<F>, rt: &RT, addrs: Vec<String>, round_robin: bool, retry_limit: usize,
        pool_channel_size: usize,
    ) -> Self {
        // Unbounded channel feeding failed-but-retryable tasks to the retry worker.
        let (retry_tx, retry_rx) = mpsc::unbounded_async();
        let retry_logger = facts.new_logger();
        // Wrap the user facts so that task errors are routed through
        // `FailoverFacts::error_handle` (see impl below).
        let wrapped_facts =
            Arc::new(FailoverFacts { retry_limit, retry_tx, logger: facts.new_logger(), facts });
        let mut pools = Vec::with_capacity(addrs.len());
        for addr in addrs.iter() {
            let pool = ClientPool::new::<RT>(wrapped_facts.clone(), rt, addr, pool_channel_size);
            pools.push(pool);
        }
        // NOTE: the ClientPool has cycle reference with FailoverPoolInner
        let inner = Arc::new(FailoverPoolInner::<F, P> {
            pools: ArcSwap::new(Arc::new(ClusterConfig { ver: 0, pools })),
            round_robin,
            facts: wrapped_facts,
            // Starts at 1 so it is always distinct from the initial cluster's ver of 0.
            ver: AtomicU64::new(1),
            rr_counter: AtomicUsize::new(0),
            pool_channel_size,
        });
        // The worker holds only a Weak ref, so it does not keep the pool alive
        // and exits once the retry channel closes.
        let weak_self = Arc::downgrade(&inner);
        rt.spawn_detach(async move {
            FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
        });
        Self { inner }
    }

    /// Replace the address set, reusing existing pools for unchanged addresses
    /// and creating fresh pools for new ones.
    ///
    /// NOTE(review): concurrent calls race between `load_full` and `store`;
    /// the last store wins wholesale — confirm callers serialize updates.
    pub fn update_addrs<RT: AsyncRuntime + Clone>(&self, addrs: Vec<String>, rt: &RT) {
        let inner = &self.inner;
        let old_pools = inner.pools.load_full();
        let mut new_pools: Vec<ClientPool<FailoverFacts<F>, P>> = Vec::with_capacity(addrs.len());

        // Index existing pools by address so they can be carried over.
        let mut old_pools_map = std::collections::HashMap::with_capacity(old_pools.pools.len());
        for pool in &old_pools.pools {
            old_pools_map.insert(pool.get_addr().to_string(), pool);
        }

        for addr in addrs {
            if let Some(reused_pool) = old_pools_map.remove(&addr) {
                new_pools.push(reused_pool.clone());
            } else {
                // Create a new pool for the new address
                let new_pool =
                    ClientPool::new::<RT>(inner.facts.clone(), rt, &addr, inner.pool_channel_size);
                new_pools.push(new_pool);
            }
        }
        // Bump the version so in-flight retries detect the config change and
        // restart their pool selection.
        let new_ver = inner.ver.fetch_add(1, Ordering::Relaxed) + 1;
        let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
        inner.pools.store(Arc::new(new_cluster));
    }
}
129
130impl<F, P> ClusterConfig<F, P>
131where
132    F: ClientFacts,
133    P: ClientTransport,
134{
135    #[inline]
136    fn select(
137        &self, round_robin: bool, rr_counter: &AtomicUsize, last_index: Option<usize>,
138    ) -> Option<(usize, &ClientPool<FailoverFacts<F>, P>)> {
139        let l = self.pools.len();
140        if l == 0 {
141            return None;
142        }
143        let seed = if let Some(index) = last_index {
144            // last_index is address with error occurs
145            index + 1
146        } else if round_robin {
147            rr_counter.fetch_add(1, Ordering::Relaxed)
148        } else {
149            0
150        };
151        for i in seed..seed + l {
152            let pool = &self.pools[i % l];
153            if pool.is_healthy() {
154                return Some((i, pool));
155            }
156        }
157        return None;
158    }
159}
160
161impl<F, P> FailoverPoolInner<F, P>
162where
163    F: ClientFacts,
164    P: ClientTransport,
165{
166    async fn retry_worker(
167        weak_self: Weak<Self>, logger: Arc<LogFilter>,
168        retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
169    ) {
170        while let Ok(mut task) = retry_rx.recv().await {
171            if let Some(inner) = weak_self.upgrade() {
172                let cluster = inner.pools.load();
173                // if cluster config changed, restart selection
174                let last_index = if cluster.ver == task.cluster_ver {
175                    Some(task.last_index)
176                } else {
177                    task.cluster_ver = cluster.ver;
178                    None // restart selection
179                };
180                if let Some((index, pool)) =
181                    cluster.select(inner.round_robin, &inner.rr_counter, last_index)
182                {
183                    if let Some(last) = last_index {
184                        logger_trace!(
185                            logger,
186                            "FailoverPool: task {:?} retry {}->{}",
187                            task.inner,
188                            last,
189                            index
190                        );
191                    }
192                    task.last_index = index;
193                    pool.send_req(task).await; // retry is async
194                    continue;
195                }
196                // if we are here, something is wrong, no pool available or selection failed
197                logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
198                task.done();
199            } else {
200                logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
201                task.done();
202            }
203        }
204        logger_trace!(logger, "FailoverPool retry worker exit");
205    }
206}
207
impl<F, P> Drop for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Trace-only drop hook for lifetime debugging; no cleanup is performed here.
    #[inline]
    fn drop(&mut self) {
        logger_trace!(self.facts.logger, "FailoverPool dropped");
    }
}
218
219/// `orb::AsyncRuntime` will follow deref to blanket impl it for wrapper types
220impl<F> std::ops::Deref for FailoverFacts<F>
221where
222    F: ClientFacts,
223{
224    type Target = F;
225
226    #[inline]
227    fn deref(&self) -> &Self::Target {
228        self.facts.as_ref()
229    }
230}
231
232impl<F> ClientFacts for FailoverFacts<F>
233where
234    F: ClientFacts,
235{
236    type Codec = F::Codec;
237
238    type Task = FailoverTask<F::Task>;
239
240    #[inline]
241    fn new_logger(&self) -> Arc<LogFilter> {
242        self.facts.new_logger()
243    }
244
245    #[inline]
246    fn get_config(&self) -> &ClientConfig {
247        self.facts.get_config()
248    }
249
250    #[inline]
251    fn error_handle(&self, task: FailoverTask<F::Task>) {
252        if task.should_retry && task.retry <= self.retry_limit {
253            if let Err(SendError(_task)) = self.retry_tx.send(task) {
254                _task.done();
255            }
256            return;
257        }
258        task.inner.done();
259    }
260}
261
262impl<F, P> Clone for FailoverPool<F, P>
263where
264    F: ClientFacts,
265    P: ClientTransport,
266{
267    #[inline]
268    fn clone(&self) -> Self {
269        Self { inner: self.inner.clone() }
270    }
271}
272
273impl<F, P> ClientCaller for FailoverPool<F, P>
274where
275    F: ClientFacts,
276    P: ClientTransport,
277{
278    type Facts = F;
279
280    async fn send_req(&self, mut task: F::Task) {
281        let cluster = self.inner.pools.load();
282        if let Some((index, pool)) =
283            cluster.select(self.inner.round_robin, &self.inner.rr_counter, None)
284        {
285            let failover_task = FailoverTask {
286                last_index: index,
287                cluster_ver: cluster.ver,
288                inner: task,
289                retry: 0,
290                should_retry: false,
291            };
292            pool.send_req(failover_task).await;
293            return;
294        }
295
296        // No pools available
297        task.set_rpc_error(RpcIntErr::Unreachable);
298        task.done();
299    }
300}
301
302impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
303where
304    F: ClientFacts,
305    P: ClientTransport,
306{
307    type Facts = F;
308    fn send_req_blocking(&self, mut task: F::Task) {
309        let cluster = self.inner.pools.load();
310        if let Some((index, pool)) =
311            cluster.select(self.inner.round_robin, &self.inner.rr_counter, None)
312        {
313            let failover_task = FailoverTask {
314                last_index: index,
315                cluster_ver: cluster.ver,
316                inner: task,
317                retry: 0,
318                should_retry: false,
319            };
320            pool.send_req_blocking(failover_task);
321            return;
322        }
323
324        // No pools available
325        task.set_rpc_error(RpcIntErr::Unreachable);
326        task.done();
327    }
328}
329
/// Wrapper task that adds failover bookkeeping around the user's task.
pub struct FailoverTask<T: ClientTask> {
    // Index of the pool this task was last dispatched to.
    last_index: usize,
    // Version of the `ClusterConfig` used for the last selection.
    cluster_ver: u64,
    // The user's underlying task.
    inner: T,
    // Number of retry attempts so far (bumped by `set_rpc_error`).
    retry: usize,
    // Set by `set_rpc_error` when the error is considered transient.
    should_retry: bool,
}
337
// Request encoding is delegated to the wrapped task unchanged.
impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
    #[inline(always)]
    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
        self.inner.encode_req(codec, buf)
    }

    #[inline(always)]
    fn get_req_blob(&self) -> Option<&[u8]> {
        self.inner.get_req_blob()
    }
}
349
// Response decoding is delegated to the wrapped task unchanged.
impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
    #[inline(always)]
    fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
        self.inner.decode_resp(codec, buf)
    }

    #[inline(always)]
    fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
        self.inner.reserve_resp_blob(_size)
    }
}
361
362impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
363    #[inline(always)]
364    fn set_custom_error<C: Codec>(&mut self, codec: &C, e: EncodedErr) {
365        self.should_retry = false;
366        self.inner.set_custom_error(codec, e);
367    }
368
369    #[inline(always)]
370    fn set_rpc_error(&mut self, e: RpcIntErr) {
371        if e < RpcIntErr::Method {
372            self.should_retry = true;
373            self.retry += 1;
374        } else {
375            self.should_retry = false;
376        }
377        self.inner.set_rpc_error(e.clone());
378    }
379
380    #[inline(always)]
381    fn set_ok(&mut self) {
382        self.inner.set_ok();
383    }
384
385    #[inline(always)]
386    fn done(self) {
387        self.inner.done();
388    }
389}
390
// The RPC action is delegated to the wrapped task unchanged.
impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
    #[inline(always)]
    fn get_action<'a>(&'a self) -> RpcAction<'a> {
        self.inner.get_action()
    }
}
397
// Expose the wrapped task's `ClientTaskCommon` through deref.
impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
    type Target = ClientTaskCommon;
    fn deref(&self) -> &Self::Target {
        self.inner.deref()
    }
}
404
// Mutable counterpart of the `Deref` impl above.
impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.deref_mut()
    }
}
410
// `FailoverTask` satisfies all `ClientTask` super-traits; the marker impl is empty.
impl<T: ClientTask> ClientTask for FailoverTask<T> {}
412
413impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
414    #[inline]
415    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
416        self.inner.fmt(f)
417    }
418}