// razor_stream/client/failover.rs

1use crate::client::task::*;
2use crate::client::{
3    ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientPool, ClientTransport,
4};
5use crate::proto::RpcAction;
6use crate::{
7    Codec,
8    error::{EncodedErr, RpcIntErr},
9};
10use arc_swap::ArcSwapOption;
11use captains_log::filter::LogFilter;
12use crossfire::*;
13use std::fmt;
14use std::sync::{
15    Arc, Weak,
16    atomic::{AtomicU64, AtomicUsize, Ordering},
17};
18
/// A pool that fails over between multiple addresses, with an optional
/// round_robin selection strategy.
///
/// Supports async and blocking context.
///
/// Only retries `RpcIntErr` values that are less than `RpcIntErr::Method`;
/// custom errors are currently not retried due to the complexity of the
/// generics involved.
///
/// NOTE: there is a reference cycle between `FailoverPoolInner` and its
/// `ClientPool`s. Do not clone `FailoverPool` — it has a custom `Drop` that
/// breaks the cycle. Put `FailoverPool` in an `Arc` for shared usage.
pub struct FailoverPool<F, P>(Arc<FailoverPoolInner<F, P>>)
where
    F: ClientFacts,
    P: ClientTransport;
32
/// Shared state behind `FailoverPool`.
///
/// Also implements `ClientFacts` itself so the contained `ClientPool`s report
/// task failures back here (see `error_handle`), which is what creates the
/// reference cycle broken by `FailoverPool::drop`.
struct FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Current cluster snapshot; `None` only before construction completes
    // and after `FailoverPool::drop` clears it to break the cycle.
    pools: ArcSwapOption<ClusterConfig<F, P>>,
    // When true, fresh requests are spread across pools via `rr_counter`.
    round_robin: bool,
    facts: Arc<F>,
    // Max number of failover retries per task.
    retry_limit: usize,
    // Feeds failed-but-retryable tasks to the background retry worker.
    retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
    // Monotonic version source for `ClusterConfig::ver`.
    ver: AtomicU64,
    // Round-robin seed; wraps on overflow and is only used modulo pool count.
    rr_counter: AtomicUsize,
    // Channel size used when constructing new `ClientPool`s in update_addrs.
    pool_channel_size: usize,
    logger: Arc<LogFilter>,
}
48
/// An immutable snapshot of the pool set; replaced wholesale by
/// `update_addrs` and versioned so in-flight retries can detect staleness.
struct ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    pools: Vec<ClientPool<FailoverPoolInner<F, P>, P>>,
    // Version observed by tasks at dispatch; compared in the retry worker.
    ver: u64,
}
57
impl<F, P> FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Creates the failover pool with one `ClientPool` per address and
    /// spawns the background retry worker.
    ///
    /// * `round_robin` — spread fresh requests across pools instead of
    ///   always starting selection at index 0.
    /// * `retry_limit` — max failover retries per task.
    /// * `pool_channel_size` — channel size for each underlying pool.
    pub fn new(
        facts: Arc<F>, addrs: Vec<String>, round_robin: bool, retry_limit: usize,
        pool_channel_size: usize,
    ) -> Self {
        let (retry_tx, retry_rx) = mpsc::unbounded_async();
        // NOTE: the ClientPool has cycle reference with FailoverPoolInner
        let inner = Arc::new(FailoverPoolInner::<F, P> {
            pools: ArcSwapOption::new(None),
            round_robin,
            facts: facts.clone(),
            retry_limit,
            retry_tx: retry_tx.into(),
            ver: AtomicU64::new(1),
            rr_counter: AtomicUsize::new(0),
            pool_channel_size,
            logger: facts.new_logger(),
        });
        let mut pools = Vec::with_capacity(addrs.len());
        for addr in addrs.iter() {
            // Each pool holds a strong Arc back to `inner` — this is the
            // reference cycle that FailoverPool::drop breaks.
            let pool = ClientPool::new(inner.clone(), &addr, pool_channel_size);
            pools.push(pool);
        }
        inner.pools.store(Some(Arc::new(ClusterConfig { ver: 0, pools })));

        // The worker only holds a Weak ref, so it does not keep the pool
        // alive; it exits once the retry channel's senders are gone.
        let retry_logger = facts.new_logger();
        let weak_self = Arc::downgrade(&inner);
        facts.spawn_detach(async move {
            FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
        });
        Self(inner)
    }

    /// Replaces the address set, reusing existing `ClientPool`s for addresses
    /// that are still present and creating new pools for new addresses.
    /// Pools whose address disappeared are dropped along with the old
    /// snapshot. Bumps the cluster version so in-flight retries restart
    /// selection against the new snapshot.
    pub fn update_addrs(&self, addrs: Vec<String>) {
        let inner = &self.0;
        let old_cluster_arc = inner.pools.load();
        let old_pools = old_cluster_arc.as_ref().map(|c| c.pools.clone()).unwrap_or_else(Vec::new);

        let mut new_pools = Vec::with_capacity(addrs.len());

        // Index surviving pools by address so they can be reused as-is.
        let mut old_pools_map = std::collections::HashMap::with_capacity(old_pools.len());
        for pool in old_pools {
            old_pools_map.insert(pool.get_addr().to_string(), pool);
        }

        for addr in addrs {
            if let Some(reused_pool) = old_pools_map.remove(&addr) {
                new_pools.push(reused_pool);
            } else {
                // Create a new pool for the new address
                let new_pool = ClientPool::new(inner.clone(), &addr, inner.pool_channel_size);
                new_pools.push(new_pool);
            }
        }

        // fetch_add returns the old value, so +1 yields the freshly claimed
        // version for this snapshot.
        let new_ver = inner.ver.fetch_add(1, Ordering::Relaxed) + 1;
        let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
        inner.pools.store(Some(Arc::new(new_cluster)));
    }
}
122
123impl<F, P> ClusterConfig<F, P>
124where
125    F: ClientFacts,
126    P: ClientTransport,
127{
128    #[inline]
129    fn select(
130        &self, round_robin: bool, rr_counter: &AtomicUsize, last_index: Option<usize>,
131    ) -> Option<(usize, &ClientPool<FailoverPoolInner<F, P>, P>)> {
132        let l = self.pools.len();
133        if l == 0 {
134            return None;
135        }
136        let seed = if let Some(index) = last_index {
137            index + 1
138        } else if round_robin {
139            rr_counter.fetch_add(1, Ordering::Relaxed)
140        } else {
141            0
142        };
143        for i in seed..seed + l {
144            let pool = &self.pools[i % l];
145            if pool.is_healthy() {
146                return Some((i, pool));
147            }
148        }
149        return None;
150    }
151}
152
153impl<F, P> FailoverPoolInner<F, P>
154where
155    F: ClientFacts,
156    P: ClientTransport,
157{
158    async fn retry_worker(
159        weak_self: Weak<Self>, logger: Arc<LogFilter>,
160        retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
161    ) {
162        while let Ok(mut task) = retry_rx.recv().await {
163            if let Some(inner) = weak_self.upgrade() {
164                let cluster = inner.pools.load();
165                if let Some(cluster) = cluster.as_ref() {
166                    // if cluster config changed, restart selection
167                    let last_index = if cluster.ver == task.cluster_ver {
168                        Some(task.last_index)
169                    } else {
170                        task.cluster_ver = cluster.ver;
171                        None // restart selection
172                    };
173                    if let Some((index, pool)) =
174                        cluster.select(inner.round_robin, &inner.rr_counter, last_index)
175                    {
176                        if let Some(last) = last_index {
177                            logger_trace!(
178                                logger,
179                                "FailoverPool: task {:?} retry {}->{}",
180                                task.inner,
181                                last,
182                                index
183                            );
184                        }
185                        task.last_index = index;
186                        pool.send_req(task).await; // retry is async
187                        continue;
188                    }
189                }
190                // if we are here, something is wrong, no pool available or selection failed
191                logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
192                task.done();
193            } else {
194                logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
195                task.done();
196            }
197        }
198        logger_trace!(logger, "FailoverPool retry worker exit");
199    }
200}
201
impl<F, P> Drop for FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    fn drop(&mut self) {
        // Remove cycle reference before drop: clearing the snapshot drops
        // the ClientPools, which hold the Arcs back to FailoverPoolInner,
        // letting the whole structure (and the retry worker) unwind.
        self.0.pools.store(None);
    }
}
212
impl<F, P> Drop for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    fn drop(&mut self) {
        // Trace only: confirms the reference cycle was actually broken.
        logger_trace!(self.logger, "FailoverPool dropped");
    }
}
222
223impl<F, P> std::ops::Deref for FailoverPoolInner<F, P>
224where
225    F: ClientFacts,
226    P: ClientTransport,
227{
228    type Target = F;
229
230    fn deref(&self) -> &Self::Target {
231        &self.facts
232    }
233}
234
235impl<F, P> ClientFacts for FailoverPoolInner<F, P>
236where
237    F: ClientFacts,
238    P: ClientTransport,
239{
240    type Codec = F::Codec;
241
242    type Task = FailoverTask<F::Task>;
243
244    #[inline]
245    fn new_logger(&self) -> Arc<LogFilter> {
246        self.facts.new_logger()
247    }
248
249    #[inline]
250    fn get_config(&self) -> &ClientConfig {
251        self.facts.get_config()
252    }
253
254    #[inline]
255    fn error_handle(&self, task: FailoverTask<F::Task>) {
256        if task.should_retry {
257            if task.retry <= self.retry_limit {
258                if let Err(SendError(_task)) = self.retry_tx.send(task) {
259                    _task.done();
260                }
261                return;
262            }
263        }
264        task.inner.done();
265    }
266}
267
268impl<F, P> ClientCaller for FailoverPool<F, P>
269where
270    F: ClientFacts,
271    P: ClientTransport,
272{
273    type Facts = F;
274
275    async fn send_req(&self, mut task: F::Task) {
276        let cluster = self.0.pools.load();
277        if let Some(cluster) = cluster.as_ref() {
278            if let Some((index, pool)) =
279                cluster.select(self.0.round_robin, &self.0.rr_counter, None)
280            {
281                let failover_task = FailoverTask {
282                    last_index: index,
283                    cluster_ver: cluster.ver,
284                    inner: task,
285                    retry: 0,
286                    should_retry: false,
287                };
288                pool.send_req(failover_task).await;
289                return;
290            }
291        }
292
293        // No pools available
294        task.set_rpc_error(RpcIntErr::Unreachable);
295        task.done();
296    }
297}
298
299impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
300where
301    F: ClientFacts,
302    P: ClientTransport,
303{
304    type Facts = F;
305    fn send_req_blocking(&self, mut task: F::Task) {
306        let cluster = self.0.pools.load();
307        if let Some(cluster) = cluster.as_ref() {
308            if let Some((index, pool)) =
309                cluster.select(self.0.round_robin, &self.0.rr_counter, None)
310            {
311                let failover_task = FailoverTask {
312                    last_index: index,
313                    cluster_ver: cluster.ver,
314                    inner: task,
315                    retry: 0,
316                    should_retry: false,
317                };
318                pool.send_req_blocking(failover_task);
319                return;
320            }
321        }
322
323        // No pools available
324        task.set_rpc_error(RpcIntErr::Unreachable);
325        task.done();
326    }
327}
328
/// Wrapper task that adds failover bookkeeping around the user task `T`.
pub struct FailoverTask<T: ClientTask> {
    // Index of the pool this task was last dispatched to; the retry worker
    // resumes selection just past it.
    last_index: usize,
    // `ClusterConfig::ver` observed at dispatch; a mismatch on retry means
    // the pool set changed and selection must restart.
    cluster_ver: u64,
    // The wrapped user task.
    inner: T,
    // Number of retryable rpc errors so far (compared against retry_limit).
    retry: usize,
    // Set by `set_rpc_error` when the error class is retryable.
    should_retry: bool,
}
336
impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
    /// Delegates request encoding to the wrapped task.
    #[inline(always)]
    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
        self.inner.encode_req(codec, buf)
    }

    /// Delegates request-blob access to the wrapped task.
    #[inline(always)]
    fn get_req_blob(&self) -> Option<&[u8]> {
        self.inner.get_req_blob()
    }
}
348
impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
    /// Delegates response decoding to the wrapped task.
    #[inline(always)]
    fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
        self.inner.decode_resp(codec, buf)
    }

    /// Delegates response-blob buffer reservation to the wrapped task.
    #[inline(always)]
    fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
        self.inner.reserve_resp_blob(_size)
    }
}
360
361impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
362    #[inline(always)]
363    fn set_custom_error<C: Codec>(&mut self, codec: &C, e: EncodedErr) {
364        self.should_retry = false;
365        self.inner.set_custom_error(codec, e);
366    }
367
368    #[inline(always)]
369    fn set_rpc_error(&mut self, e: RpcIntErr) {
370        if e < RpcIntErr::Method {
371            self.should_retry = true;
372            self.retry += 1;
373        } else {
374            self.should_retry = false;
375        }
376        self.inner.set_rpc_error(e.clone());
377    }
378
379    #[inline(always)]
380    fn set_ok(&mut self) {
381        self.inner.set_ok();
382    }
383
384    #[inline(always)]
385    fn done(self) {
386        self.inner.done();
387    }
388}
389
impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
    /// Delegates action lookup to the wrapped task.
    #[inline(always)]
    fn get_action<'a>(&'a self) -> RpcAction<'a> {
        self.inner.get_action()
    }
}
396
impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
    type Target = ClientTaskCommon;
    /// Expose the wrapped task's common fields via its own Deref.
    fn deref(&self) -> &Self::Target {
        self.inner.deref()
    }
}
403
impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
    /// Mutable counterpart of `Deref`, forwarded to the wrapped task.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.deref_mut()
    }
}
409
410impl<T: ClientTask> ClientTask for FailoverTask<T> {}
411
impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
    /// Formats as the wrapped task; failover bookkeeping is omitted.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.inner.fmt(f)
    }
}