//! razor_stream/client/failover.rs

1use crate::client::task::*;
2use crate::client::{
3    ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientTransport, ConnPool,
4};
5use crate::proto::RpcAction;
6use crate::{
7    Codec,
8    error::{EncodedErr, RpcIntErr},
9};
10use ahash::AHashMap;
11use arc_swap::ArcSwap;
12use captains_log::filter::LogFilter;
13use crossfire::{AsyncRx, MTx, SendError, mpsc};
14use orb::prelude::AsyncExec;
15use parking_lot::Mutex;
16use std::fmt;
17use std::sync::{
18    Arc, Weak,
19    atomic::{AtomicUsize, Ordering},
20};
21
/// A pool supporting failover to multiple addresses with a stateless (round-robin)
/// or stateful (leader-based) strategy.
///
/// Supports async and blocking contexts.
///
/// Only retries `RpcIntErr` values that are less than `RpcIntErr::Method`;
/// custom errors are currently ignored due to the complexity of generics.
/// (If you need custom failover logic, copy the code and implement your own pool.)
pub struct FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Shared so clones of this handle (and the detached retry worker) all see
    // the same cluster state.
    inner: Arc<FailoverPoolInner<F, P>>,
}
36
/// Facts wrapper handed to each inner `ConnPool`: it intercepts `error_handle`
/// so failed tasks are re-queued to the retry worker instead of completing,
/// and delegates everything else to the user facts (see the `Deref` impl).
struct FailoverFacts<F>
where
    F: ClientFacts,
{
    // Sender into `FailoverPoolInner::retry_worker`.
    retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
    // The user-provided facts this wrapper delegates to.
    facts: Arc<F>,
    // Used by the `Drop` trace of `FailoverPoolInner`.
    logger: Arc<LogFilter>,
    // Default retry cap when a task does not carry its own `max_retries`.
    retry_limit: usize,
}
46
/// Shared state behind `FailoverPool`; kept behind an `Arc` so the retry
/// worker can hold it weakly and exit when the pool is dropped.
struct FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // Atomically swappable snapshot of the current pools (lock-free readers).
    pools: ArcSwap<ClusterConfig<F, P>>,
    // true = round-robin across all addresses; false = stick to the leader.
    stateless: bool,
    /// Next node index for routing:
    /// - In stateless mode: used for round-robin selection
    /// - In stateful mode: used as the current leader index
    next_node: AtomicUsize,
    // Channel size used when constructing new ConnPools on the fly.
    pool_channel_size: usize,
    facts: Arc<FailoverFacts<F>>,
    // Runtime handle, needed to build new ConnPools in get_or_add_addr/update_addrs.
    rt: P::RT,
    /// Mutex to protect concurrent pool addition
    add_pool_mutex: Mutex<()>,
}
64
/// Immutable snapshot of the current set of connection pools; replaced
/// wholesale through the `ArcSwap` whenever addresses change.
struct ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    // One ConnPool per known address; order matters for selection.
    pools: Vec<ConnPool<FailoverFacts<F>, P>>,
    // Wrapping version counter; a mismatch tells the retry worker the
    // topology changed since the task was last routed.
    ver: u64,
}
73
impl<F, P> FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Initiate the pool with multiple addresses.
    /// When stateless == true, all addresses in the pool will be selected with equal chance (round-robin);
    /// When stateless == false, the leader address will always be picked unless error happens.
    pub fn new(
        facts: Arc<F>, rt: &P::RT, addrs: Vec<String>, stateless: bool, retry_limit: usize,
        pool_channel_size: usize,
    ) -> Self {
        // Unbounded so error_handle can enqueue retries without blocking.
        let (retry_tx, retry_rx) = mpsc::unbounded_async();
        let retry_logger = facts.new_logger();
        // Wrap the user facts so every ConnPool routes failures through us.
        let wrapped_facts =
            Arc::new(FailoverFacts { retry_limit, retry_tx, logger: facts.new_logger(), facts });
        let mut pools = Vec::with_capacity(addrs.len());
        for addr in addrs.iter() {
            let pool = ConnPool::new(wrapped_facts.clone(), rt, addr, pool_channel_size);
            pools.push(pool);
        }
        // NOTE: the ConnPool has cycle reference with FailoverPoolInner
        let inner = Arc::new(FailoverPoolInner::<F, P> {
            pools: ArcSwap::new(Arc::new(ClusterConfig { ver: 0, pools })),
            stateless,
            facts: wrapped_facts,
            next_node: AtomicUsize::new(0),
            pool_channel_size,
            rt: rt.clone(),
            add_pool_mutex: Mutex::new(()),
        });
        // Weak reference so the detached worker does not keep the pool alive.
        let weak_self = Arc::downgrade(&inner);
        rt.spawn_detach(async move {
            FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx).await;
        });
        Self { inner }
    }

    /// Get the retry limit for redirect operations
    #[inline]
    pub fn get_retry_limit(&self) -> usize {
        self.inner.facts.retry_limit
    }

    /// Resubmit a request for retry with optional specific address.
    /// Called by APIClientCaller when should_failover returns Ok(_).
    ///
    /// `addr_or_retry`: Ok(addr) routes to that address (adding a pool for it
    /// if unknown); Err(last_index) falls back to selecting the next node.
    ///
    /// NOTE: max_retries is currently not used yet (TODO api interface)
    pub async fn resubmit(
        &self, task: F::Task, addr_or_retry: Result<String, usize>, retry_count: usize,
        max_retries: Option<usize>,
    ) where
        F::Task: ClientTask,
    {
        // If specific address provided, try to find matching pool
        match &addr_or_retry {
            Ok(addr) => {
                // Adds a pool for the address on the fly when it is unknown.
                let (pool, index, conf_ver) = self.get_or_add_addr(addr);
                // Update next_node (leader in stateful mode)
                self.inner.next_node.store(index, Ordering::SeqCst);
                let failover_task = FailoverTask {
                    last_index: index,
                    config_ver: conf_ver,
                    inner: task,
                    retry: retry_count,
                    should_retry: false,
                    max_retries: max_retries.unwrap_or(0),
                };
                pool.send_req(failover_task).await;
                return;
            }
            Err(last_index) => {
                let cluster = self.inner.pools.load();
                // Fallback to select next node
                if let Some((pool, index)) = cluster.select(self.inner.stateless, Err(*last_index))
                {
                    let failover_task = FailoverTask {
                        last_index: index,
                        config_ver: cluster.ver,
                        inner: task,
                        retry: 0,
                        should_retry: false,
                        max_retries: 0,
                    };
                    pool.send_req(failover_task).await;
                    return;
                }
                // No pools available
                let mut task = task;
                task.set_rpc_error(RpcIntErr::Unreachable);
                task.done();
            }
        }
    }

    /// Look up the pool for `addr`, creating and publishing a new one (at
    /// index 0, since a redirected-to address is likely the leader) when it
    /// does not exist yet. Returns (pool, index, config_ver).
    fn get_or_add_addr(&self, addr: &str) -> (ConnPool<FailoverFacts<F>, P>, usize, u64) {
        let inner = &self.inner;
        // Fast path: check if address already exists
        {
            let cluster = inner.pools.load();
            if let Some((pool, idx)) = cluster.get_by_addr(addr) {
                return (pool.clone(), idx, cluster.ver);
            }
        }
        {
            // Slow path: need to add new pool, acquire lock to prevent concurrent modification
            let _guard = self.inner.add_pool_mutex.lock();
            // Double-check after acquiring lock (another thread might have added it)
            let old_cluster = self.inner.pools.load_full();
            if let Some((pool, idx)) = old_cluster.get_by_addr(addr) {
                return (pool.clone(), idx, old_cluster.ver);
            }
            let mut new_cluster = Vec::with_capacity(old_cluster.pools.len() + 1);
            // Create new pool for the address
            let new_pool =
                ConnPool::new(inner.facts.clone(), &inner.rt, addr, inner.pool_channel_size);

            // Build new cluster config with the new pool inserted at front (index 0)
            // New address is likely the leader, so prioritize it
            new_cluster.push(new_pool.clone());
            new_cluster.extend(old_cluster.pools.iter().cloned());
            let new_ver = old_cluster.ver.wrapping_add(1);
            drop(old_cluster);
            let new_cluster = ClusterConfig { pools: new_cluster, ver: new_ver };
            inner.pools.store(Arc::new(new_cluster));
            (new_pool, 0, new_ver)
        }
    }

    /// Replace the address set: pools for addresses that remain are reused,
    /// new addresses get fresh pools, removed ones are dropped with the old
    /// snapshot. Bumps the config version so in-flight retries restart routing.
    pub fn update_addrs(&self, addrs: Vec<String>) {
        let inner = &self.inner;
        {
            // Serialize with get_or_add_addr so snapshots are not lost.
            let _guard = self.inner.add_pool_mutex.lock();
            let old_cluster = inner.pools.load_full();
            let mut new_pools: Vec<ConnPool<FailoverFacts<F>, P>> = Vec::with_capacity(addrs.len());
            let mut old_pools_map = AHashMap::with_capacity(old_cluster.pools.len());
            for pool in &old_cluster.pools {
                old_pools_map.insert(pool.get_addr().to_string(), pool);
            }
            for addr in addrs {
                if let Some(reused_pool) = old_pools_map.remove(&addr) {
                    new_pools.push(reused_pool.clone());
                } else {
                    // Create a new pool for the new address
                    let new_pool = ConnPool::new(
                        inner.facts.clone(),
                        &inner.rt,
                        &addr,
                        inner.pool_channel_size,
                    );
                    new_pools.push(new_pool);
                }
            }
            let new_ver = old_cluster.ver.wrapping_add(1);
            drop(old_cluster);
            let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
            inner.pools.store(Arc::new(new_cluster));
        }
    }
}
236
237impl<F, P> ClusterConfig<F, P>
238where
239    F: ClientFacts,
240    P: ClientTransport,
241{
242    /// Select a pool based on routing strategy
243    /// - stateless=true: round-robin selection using rr_counter
244    /// - stateless=false: leader-based selection, fallback to round-robin if leader unhealthy
245    /// - route: Ok(next_node), Err(last_index)
246    #[inline]
247    fn select(
248        &self, stateless: bool, route: Result<&AtomicUsize, usize>,
249    ) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
250        let l = self.pools.len();
251        if l == 0 {
252            return None;
253        }
254        // TODO should compare config version (if version is changed, the order or addr of the pool
255        // might be changed)
256        let seed = match &route {
257            Err(index) => *index + 1, // try next backup node
258            Ok(next_node) => {
259                // first time
260                if stateless {
261                    // round-robin
262                    next_node.fetch_add(1, Ordering::Relaxed)
263                } else {
264                    next_node.load(Ordering::SeqCst)
265                }
266            }
267        };
268        for i in seed..seed + l {
269            let pool = &self.pools[i % l];
270            if pool.is_healthy() {
271                return Some((pool, i));
272            }
273        }
274        None
275    }
276
277    /// Find pool by address
278    fn get_by_addr(&self, addr: &str) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
279        for (i, pool) in self.pools.iter().enumerate() {
280            if pool.get_addr() == addr {
281                // NOTE: we ignore pool healthy state, we don't know the knowledge of server is
282                // more up to day than the client, just try it.
283                return Some((pool, i));
284            }
285        }
286        None
287    }
288}
289
impl<F, P> FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Background loop that re-dispatches failed tasks to the next candidate
    /// pool. Holds only a `Weak` reference to the pool, so it drains and exits
    /// once the owning `FailoverPool` is gone and the channel closes.
    async fn retry_worker(
        weak_self: Weak<Self>, logger: Arc<LogFilter>,
        retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
    ) {
        while let Ok(mut task) = retry_rx.recv().await {
            if let Some(inner) = weak_self.upgrade() {
                let cluster = inner.pools.load();
                let route = if cluster.ver == task.config_ver {
                    // Topology unchanged: advance past the node that just failed.
                    Err(task.last_index)
                } else {
                    // if cluster config changed (outside source update the address or leader
                    // changed), restart selection
                    task.config_ver = cluster.ver;
                    Ok(&inner.next_node) // restart selection
                };
                if let Some((pool, index)) = cluster.select(inner.stateless, route) {
                    if let Err(last) = &route {
                        logger_trace!(
                            logger,
                            "FailoverPool: task {:?} retry {}->{}",
                            task.inner,
                            last,
                            index
                        );
                    }
                    // Remember where we sent it, for the next failover hop.
                    task.last_index = index;
                    pool.send_req(task).await; // retry is async
                    continue;
                }
                // if we are here, something is wrong, no pool available or selection failed
                logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
                task.done();
            } else {
                // Pool already dropped: complete the task instead of retrying.
                logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
                task.done();
            }
        }
        logger_trace!(logger, "FailoverPool retry worker exit");
    }
}
335
impl<F, P> Drop for FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Trace-only drop hook: records when the shared inner state is released.
    #[inline]
    fn drop(&mut self) {
        logger_trace!(self.facts.logger, "FailoverPool dropped");
    }
}
346
347/// `orb::AsyncRuntime` will follow deref to blanket impl it for wrapper types
348impl<F> std::ops::Deref for FailoverFacts<F>
349where
350    F: ClientFacts,
351{
352    type Target = F;
353
354    #[inline]
355    fn deref(&self) -> &Self::Target {
356        self.facts.as_ref()
357    }
358}
359
360impl<F> ClientFacts for FailoverFacts<F>
361where
362    F: ClientFacts,
363{
364    type Codec = F::Codec;
365
366    type Task = FailoverTask<F::Task>;
367
368    #[inline]
369    fn new_logger(&self) -> Arc<LogFilter> {
370        self.facts.new_logger()
371    }
372
373    #[inline]
374    fn get_config(&self) -> &ClientConfig {
375        self.facts.get_config()
376    }
377
378    #[inline]
379    fn error_handle(&self, task: FailoverTask<F::Task>) {
380        // Use the max_retries from the task if set (non-zero), otherwise use the default retry_limit
381        let retry_limit = if task.max_retries > 0 { task.max_retries } else { self.retry_limit };
382        if task.should_retry && task.retry <= retry_limit {
383            if let Err(SendError(_task)) = self.retry_tx.send(task) {
384                _task.done();
385            }
386            return;
387        }
388        task.inner.done();
389    }
390}
391
392impl<F, P> Clone for FailoverPool<F, P>
393where
394    F: ClientFacts,
395    P: ClientTransport,
396{
397    #[inline]
398    fn clone(&self) -> Self {
399        Self { inner: self.inner.clone() }
400    }
401}
402
403impl<F, P> ClientCaller for FailoverPool<F, P>
404where
405    F: ClientFacts,
406    P: ClientTransport,
407{
408    type Facts = F;
409
410    async fn send_req(&self, mut task: F::Task) {
411        let cluster = self.inner.pools.load();
412        if let Some((pool, index)) = cluster.select(self.inner.stateless, Ok(&self.inner.next_node))
413        {
414            let failover_task = FailoverTask {
415                last_index: index,
416                config_ver: cluster.ver,
417                inner: task,
418                retry: 0,
419                should_retry: false,
420                max_retries: 0, // Use default retry_limit from FailoverFacts
421            };
422            pool.send_req(failover_task).await;
423            return;
424        }
425
426        // No pools available
427        task.set_rpc_error(RpcIntErr::Unreachable);
428        task.done();
429    }
430}
431
432impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
433where
434    F: ClientFacts,
435    P: ClientTransport,
436{
437    type Facts = F;
438    fn send_req_blocking(&self, mut task: F::Task) {
439        let cluster = self.inner.pools.load();
440        if let Some((pool, index)) = cluster.select(self.inner.stateless, Ok(&self.inner.next_node))
441        {
442            let failover_task = FailoverTask {
443                last_index: index,
444                config_ver: cluster.ver,
445                inner: task,
446                retry: 0,
447                should_retry: false,
448                max_retries: 0, // Use default retry_limit from FailoverFacts
449            };
450            pool.send_req_blocking(failover_task);
451            return;
452        }
453
454        // No pools available
455        task.set_rpc_error(RpcIntErr::Unreachable);
456        task.done();
457    }
458}
459
/// Wrapper task that carries failover bookkeeping alongside the user task.
pub struct FailoverTask<T: ClientTask> {
    // Index of the pool this task was last sent to (basis for "try next").
    last_index: usize,
    // Cluster version at dispatch time; a mismatch restarts selection.
    config_ver: u64,
    // The wrapped user task.
    inner: T,
    // How many times this task has been retried so far.
    retry: usize,
    // Set by set_rpc_error: whether the last failure is retryable.
    should_retry: bool,
    /// default to be 0, only set by resubmit with custom value for specified interface
    max_retries: usize,
}
469
impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
    /// Pure delegation: the failover wrapper adds no encoding behavior.
    #[inline(always)]
    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
        self.inner.encode_req(codec, buf)
    }

    /// Pure delegation to the wrapped task's request blob, if any.
    #[inline(always)]
    fn get_req_blob(&self) -> Option<&[u8]> {
        self.inner.get_req_blob()
    }
}
481
482impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
483    #[inline(always)]
484    fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
485        self.inner.decode_resp(codec, buf)
486    }
487
488    #[inline(always)]
489    fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
490        self.inner.reserve_resp_blob(_size)
491    }
492}
493
494impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
495    #[inline(always)]
496    fn set_custom_error<C: Codec>(
497        &mut self, codec: &C, e: EncodedErr, _last_index: usize, _conf_ver: u64,
498    ) {
499        self.should_retry = false;
500        self.inner.set_custom_error(codec, e, self.last_index, self.config_ver);
501    }
502
503    #[inline(always)]
504    fn set_rpc_error(&mut self, e: RpcIntErr) {
505        if e < RpcIntErr::Method {
506            self.should_retry = true;
507            self.retry += 1;
508        } else {
509            self.should_retry = false;
510        }
511        self.inner.set_rpc_error(e.clone());
512    }
513
514    #[inline(always)]
515    fn set_ok(&mut self) {
516        self.inner.set_ok();
517    }
518
519    #[inline(always)]
520    fn done(self) {
521        self.inner.done();
522    }
523}
524
impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
    /// Pure delegation: the RPC action comes from the wrapped task.
    #[inline(always)]
    fn get_action<'a>(&'a self) -> RpcAction<'a> {
        self.inner.get_action()
    }
}
531
532impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
533    type Target = ClientTaskCommon;
534    fn deref(&self) -> &Self::Target {
535        self.inner.deref()
536    }
537}
538
539impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
540    fn deref_mut(&mut self) -> &mut Self::Target {
541        self.inner.deref_mut()
542    }
543}
544
/// Marker impl: with the encode/decode/done/action impls in place,
/// `FailoverTask` satisfies the full `ClientTask` contract.
impl<T: ClientTask> ClientTask for FailoverTask<T> {}
546
547impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
548    #[inline]
549    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
550        self.inner.fmt(f)
551    }
552}