Skip to main content

razor_stream/client/
failover.rs

1use crate::client::task::*;
2use crate::client::{
3    ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientTransport, ConnPool,
4};
5use crate::proto::RpcAction;
6use crate::{
7    Codec,
8    error::{EncodedErr, RpcIntErr},
9};
10use ahash::AHashMap;
11use arc_swap::ArcSwap;
12use captains_log::filter::LogFilter;
13use crossfire::{AsyncRx, MTx, SendError, mpsc};
14use orb::prelude::{AsyncExec, AsyncRuntime};
15use parking_lot::Mutex;
16use std::fmt;
17use std::sync::{
18    Arc, Weak,
19    atomic::{AtomicUsize, Ordering},
20};
21
22/// A pool supports failover to multiple addresses with stateless (round-robin) or stateful (leader-based) strategy
23///
24/// Supports async and blocking context.
25///
26/// Only retry RpcIntErr that less than RpcIntErr::Method,
27/// currently ignore custom error due to complexity of generic.
28/// (If you need to custom failover logic, copy the code and impl your own pool.)
29pub struct FailoverPool<F, P>
30where
31    F: ClientFacts,
32    P: ClientTransport,
33{
34    inner: Arc<FailoverPoolInner<F, P>>,
35}
36
37struct FailoverFacts<F>
38where
39    F: ClientFacts,
40{
41    retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
42    facts: Arc<F>,
43    logger: Arc<LogFilter>,
44    retry_limit: usize,
45}
46
47struct FailoverPoolInner<F, P>
48where
49    F: ClientFacts,
50    P: ClientTransport,
51{
52    pools: ArcSwap<ClusterConfig<F, P>>,
53    stateless: bool,
54    /// Next node index for routing:
55    /// - In stateless mode: used for round-robin selection
56    /// - In stateful mode: used as the current leader index
57    next_node: AtomicUsize,
58    pool_channel_size: usize,
59    facts: Arc<FailoverFacts<F>>,
60    /// Mutex to protect concurrent pool addition
61    add_pool_mutex: Mutex<()>,
62    rt: Option<<P::RT as AsyncRuntime>::Exec>,
63}
64
65struct ClusterConfig<F, P>
66where
67    F: ClientFacts,
68    P: ClientTransport,
69{
70    pools: Vec<ConnPool<FailoverFacts<F>, P>>,
71    ver: u64,
72}
73
74impl<F, P> FailoverPool<F, P>
75where
76    F: ClientFacts,
77    P: ClientTransport,
78{
79    /// Initiate the pool with multiple addresses.
80
81    ///
82    /// # Argument
83    ///
84    /// - `rt`: When we are in orb async context, just pass None, otherwise (in thread context),
85    ///   pass the AsyncRuntime::Exec.
86    /// - `stateless`:  When true, all addresses in the pool will be selected with equal chance (round-robin);
87    ///   When  false, the leader address will always be picked unless error happens.
88    pub fn new(
89        facts: Arc<F>, rt: Option<&<P::RT as AsyncRuntime>::Exec>, addrs: Vec<String>,
90        stateless: bool, retry_limit: usize, pool_channel_size: usize,
91    ) -> Self {
92        let (retry_tx, retry_rx) = mpsc::unbounded_async();
93        let retry_logger = facts.new_logger();
94        let wrapped_facts =
95            Arc::new(FailoverFacts { retry_limit, retry_tx, logger: facts.new_logger(), facts });
96        let mut pools = Vec::with_capacity(addrs.len());
97        for addr in addrs.iter() {
98            let pool = ConnPool::new(wrapped_facts.clone(), rt, addr, pool_channel_size);
99            pools.push(pool);
100        }
101        // NOTE: the ConnPool has cycle reference with FailoverPoolInner
102        let inner = Arc::new(FailoverPoolInner::<F, P> {
103            pools: ArcSwap::new(Arc::new(ClusterConfig { ver: 0, pools })),
104            stateless,
105            facts: wrapped_facts,
106            next_node: AtomicUsize::new(0),
107            pool_channel_size,
108            add_pool_mutex: Mutex::new(()),
109            rt: rt.cloned(),
110        });
111        let weak_self = Arc::downgrade(&inner);
112        let f = FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx);
113        if let Some(_rt) = rt {
114            _rt.spawn_detach(f);
115        } else {
116            P::RT::spawn_detach(f);
117        }
118        Self { inner }
119    }
120
121    /// Get the retry limit for redirect operations
122    #[inline]
123    pub fn get_retry_limit(&self) -> usize {
124        self.inner.facts.retry_limit
125    }
126
127    /// Resubmit a request for retry with optional specific address.
128    /// Called by APIClientCaller when should_failover returns Ok(_).
129    ///
130    /// NOTE: max_retries is currently not used yet (TODO api interface)
131    pub async fn resubmit(
132        &self, task: F::Task, addr_or_retry: Result<String, usize>, retry_count: usize,
133        max_retries: Option<usize>,
134    ) where
135        F::Task: ClientTask,
136    {
137        // If specific address provided, try to find matching pool
138        match &addr_or_retry {
139            Ok(addr) => {
140                let (pool, index, conf_ver) = self.get_or_add_addr(addr);
141                // TODO should add addr to the pool on-the-fly if not exists
142                // Update next_node (leader in stateful mode)
143                self.inner.next_node.store(index, Ordering::SeqCst);
144                let failover_task = FailoverTask {
145                    last_index: index,
146                    config_ver: conf_ver,
147                    inner: task,
148                    retry: retry_count,
149                    should_retry: false,
150                    max_retries: max_retries.unwrap_or(0),
151                };
152                pool.send_req(failover_task).await;
153                return;
154            }
155            Err(last_index) => {
156                let cluster = self.inner.pools.load();
157                // Fallback to select next node
158                if let Some((pool, index)) = cluster.select(self.inner.stateless, Err(*last_index))
159                {
160                    let failover_task = FailoverTask {
161                        last_index: index,
162                        config_ver: cluster.ver,
163                        inner: task,
164                        retry: 0,
165                        should_retry: false,
166                        max_retries: 0,
167                    };
168                    pool.send_req(failover_task).await;
169                    return;
170                }
171                // No pools available
172                let mut task = task;
173                task.set_rpc_error(RpcIntErr::Unreachable);
174                task.done();
175            }
176        }
177    }
178
179    // return pool, idx, config_ver
180    fn get_or_add_addr(&self, addr: &str) -> (ConnPool<FailoverFacts<F>, P>, usize, u64) {
181        let inner = &self.inner;
182        // Fast path: check if address already exists
183        {
184            let cluster = inner.pools.load();
185            if let Some((pool, idx)) = cluster.get_by_addr(addr) {
186                return (pool.clone(), idx, cluster.ver);
187            }
188        }
189        {
190            // Slow path: need to add new pool, acquire lock to prevent concurrent modification
191            let _guard = self.inner.add_pool_mutex.lock();
192            // Double-check after acquiring lock (another thread might have added it)
193            let old_cluster = self.inner.pools.load_full();
194            if let Some((pool, idx)) = old_cluster.get_by_addr(addr) {
195                return (pool.clone(), idx, old_cluster.ver);
196            }
197            let mut new_cluster = Vec::with_capacity(old_cluster.pools.len() + 1);
198            // Create new pool for the address
199            let new_pool = ConnPool::new(inner.facts.clone(), None, addr, inner.pool_channel_size);
200
201            // Build new cluster config with the new pool inserted at front (index 0)
202            // New address is likely the leader, so prioritize it
203            new_cluster.push(new_pool.clone());
204            new_cluster.extend(old_cluster.pools.iter().cloned());
205            let new_ver = old_cluster.ver.wrapping_add(1);
206            drop(old_cluster);
207            let new_cluster = ClusterConfig { pools: new_cluster, ver: new_ver };
208            inner.pools.store(Arc::new(new_cluster));
209            (new_pool, 0, new_ver)
210        }
211    }
212
213    pub fn update_addrs(&self, addrs: Vec<String>) {
214        let inner = &self.inner;
215        {
216            let _guard = self.inner.add_pool_mutex.lock();
217            let old_cluster = inner.pools.load_full();
218            let mut new_pools: Vec<ConnPool<FailoverFacts<F>, P>> = Vec::with_capacity(addrs.len());
219            let mut old_pools_map = AHashMap::with_capacity(old_cluster.pools.len());
220            for pool in &old_cluster.pools {
221                old_pools_map.insert(pool.get_addr().to_string(), pool);
222            }
223            for addr in addrs {
224                if let Some(reused_pool) = old_pools_map.remove(&addr) {
225                    new_pools.push(reused_pool.clone());
226                } else {
227                    // Create a new pool for the new address
228                    let new_pool = ConnPool::new(
229                        inner.facts.clone(),
230                        inner.rt.as_ref(),
231                        &addr,
232                        inner.pool_channel_size,
233                    );
234                    new_pools.push(new_pool);
235                }
236            }
237            let new_ver = old_cluster.ver.wrapping_add(1);
238            drop(old_cluster);
239            let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
240            inner.pools.store(Arc::new(new_cluster));
241        }
242    }
243}
244
245impl<F, P> ClusterConfig<F, P>
246where
247    F: ClientFacts,
248    P: ClientTransport,
249{
250    /// Select a pool based on routing strategy
251    /// - stateless=true: round-robin selection using rr_counter
252    /// - stateless=false: leader-based selection, fallback to round-robin if leader unhealthy
253    /// - route: Ok(next_node), Err(last_index)
254    #[inline]
255    fn select(
256        &self, stateless: bool, route: Result<&AtomicUsize, usize>,
257    ) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
258        let l = self.pools.len();
259        if l == 0 {
260            return None;
261        }
262        // TODO should compare config version (if version is changed, the order or addr of the pool
263        // might be changed)
264        let seed = match &route {
265            Err(index) => *index + 1, // try next backup node
266            Ok(next_node) => {
267                // first time
268                if stateless {
269                    // round-robin
270                    next_node.fetch_add(1, Ordering::Relaxed)
271                } else {
272                    next_node.load(Ordering::SeqCst)
273                }
274            }
275        };
276        for i in seed..seed + l {
277            let pool = &self.pools[i % l];
278            if pool.is_healthy() {
279                return Some((pool, i));
280            }
281        }
282        None
283    }
284
285    /// Find pool by address
286    fn get_by_addr(&self, addr: &str) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
287        for (i, pool) in self.pools.iter().enumerate() {
288            if pool.get_addr() == addr {
289                // NOTE: we ignore pool healthy state, we don't know the knowledge of server is
290                // more up to day than the client, just try it.
291                return Some((pool, i));
292            }
293        }
294        None
295    }
296}
297
298impl<F, P> FailoverPoolInner<F, P>
299where
300    F: ClientFacts,
301    P: ClientTransport,
302{
303    async fn retry_worker(
304        weak_self: Weak<Self>, logger: Arc<LogFilter>,
305        retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
306    ) {
307        while let Ok(mut task) = retry_rx.recv().await {
308            if let Some(inner) = weak_self.upgrade() {
309                let cluster = inner.pools.load();
310                let route = if cluster.ver == task.config_ver {
311                    Err(task.last_index)
312                } else {
313                    // if cluster config changed (outside source update the address or leader
314                    // changed), restart selection
315                    task.config_ver = cluster.ver;
316                    Ok(&inner.next_node) // restart selection
317                };
318                if let Some((pool, index)) = cluster.select(inner.stateless, route) {
319                    if let Err(last) = &route {
320                        logger_trace!(
321                            logger,
322                            "FailoverPool: task {:?} retry {}->{}",
323                            task.inner,
324                            last,
325                            index
326                        );
327                    }
328                    task.last_index = index;
329                    pool.send_req(task).await; // retry is async
330                    continue;
331                }
332                // if we are here, something is wrong, no pool available or selection failed
333                logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
334                task.done();
335            } else {
336                logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
337                task.done();
338            }
339        }
340        logger_trace!(logger, "FailoverPool retry worker exit");
341    }
342}
343
344impl<F, P> Drop for FailoverPoolInner<F, P>
345where
346    F: ClientFacts,
347    P: ClientTransport,
348{
349    #[inline]
350    fn drop(&mut self) {
351        logger_trace!(self.facts.logger, "FailoverPool dropped");
352    }
353}
354
355/// `orb::AsyncRuntime` will follow deref to blanket impl it for wrapper types
356impl<F> std::ops::Deref for FailoverFacts<F>
357where
358    F: ClientFacts,
359{
360    type Target = F;
361
362    #[inline]
363    fn deref(&self) -> &Self::Target {
364        self.facts.as_ref()
365    }
366}
367
368impl<F> ClientFacts for FailoverFacts<F>
369where
370    F: ClientFacts,
371{
372    type Codec = F::Codec;
373
374    type Task = FailoverTask<F::Task>;
375
376    #[inline]
377    fn new_logger(&self) -> Arc<LogFilter> {
378        self.facts.new_logger()
379    }
380
381    #[inline]
382    fn get_config(&self) -> &ClientConfig {
383        self.facts.get_config()
384    }
385
386    #[inline]
387    fn error_handle(&self, task: FailoverTask<F::Task>) {
388        // Use the max_retries from the task if set (non-zero), otherwise use the default retry_limit
389        let retry_limit = if task.max_retries > 0 { task.max_retries } else { self.retry_limit };
390        if task.should_retry && task.retry <= retry_limit {
391            if let Err(SendError(_task)) = self.retry_tx.send(task) {
392                _task.done();
393            }
394            return;
395        }
396        task.inner.done();
397    }
398}
399
400impl<F, P> Clone for FailoverPool<F, P>
401where
402    F: ClientFacts,
403    P: ClientTransport,
404{
405    #[inline]
406    fn clone(&self) -> Self {
407        Self { inner: self.inner.clone() }
408    }
409}
410
411impl<F, P> ClientCaller for FailoverPool<F, P>
412where
413    F: ClientFacts,
414    P: ClientTransport,
415{
416    type Facts = F;
417
418    async fn send_req(&self, mut task: F::Task) {
419        let cluster = self.inner.pools.load();
420        if let Some((pool, index)) = cluster.select(self.inner.stateless, Ok(&self.inner.next_node))
421        {
422            let failover_task = FailoverTask {
423                last_index: index,
424                config_ver: cluster.ver,
425                inner: task,
426                retry: 0,
427                should_retry: false,
428                max_retries: 0, // Use default retry_limit from FailoverFacts
429            };
430            pool.send_req(failover_task).await;
431            return;
432        }
433
434        // No pools available
435        task.set_rpc_error(RpcIntErr::Unreachable);
436        task.done();
437    }
438}
439
440impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
441where
442    F: ClientFacts,
443    P: ClientTransport,
444{
445    type Facts = F;
446    fn send_req_blocking(&self, mut task: F::Task) {
447        let cluster = self.inner.pools.load();
448        if let Some((pool, index)) = cluster.select(self.inner.stateless, Ok(&self.inner.next_node))
449        {
450            let failover_task = FailoverTask {
451                last_index: index,
452                config_ver: cluster.ver,
453                inner: task,
454                retry: 0,
455                should_retry: false,
456                max_retries: 0, // Use default retry_limit from FailoverFacts
457            };
458            pool.send_req_blocking(failover_task);
459            return;
460        }
461
462        // No pools available
463        task.set_rpc_error(RpcIntErr::Unreachable);
464        task.done();
465    }
466}
467
468pub struct FailoverTask<T: ClientTask> {
469    last_index: usize,
470    config_ver: u64,
471    inner: T,
472    retry: usize,
473    should_retry: bool,
474    /// default to be 0, only set by resubmit with custom value for specified interface
475    max_retries: usize,
476}
477
478impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
479    #[inline(always)]
480    fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
481        self.inner.encode_req(codec, buf)
482    }
483
484    #[inline(always)]
485    fn get_req_blob(&self) -> Option<&[u8]> {
486        self.inner.get_req_blob()
487    }
488}
489
490impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
491    #[inline(always)]
492    fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
493        self.inner.decode_resp(codec, buf)
494    }
495
496    #[inline(always)]
497    fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
498        self.inner.reserve_resp_blob(_size)
499    }
500}
501
502impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
503    #[inline(always)]
504    fn set_custom_error<C: Codec>(
505        &mut self, codec: &C, e: EncodedErr, _last_index: usize, _conf_ver: u64,
506    ) {
507        self.should_retry = false;
508        self.inner.set_custom_error(codec, e, self.last_index, self.config_ver);
509    }
510
511    #[inline(always)]
512    fn set_rpc_error(&mut self, e: RpcIntErr) {
513        if e < RpcIntErr::Method {
514            self.should_retry = true;
515            self.retry += 1;
516        } else {
517            self.should_retry = false;
518        }
519        self.inner.set_rpc_error(e.clone());
520    }
521
522    #[inline(always)]
523    fn set_ok(&mut self) {
524        self.inner.set_ok();
525    }
526
527    #[inline(always)]
528    fn done(self) {
529        self.inner.done();
530    }
531}
532
533impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
534    #[inline(always)]
535    fn get_action<'a>(&'a self) -> RpcAction<'a> {
536        self.inner.get_action()
537    }
538}
539
540impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
541    type Target = ClientTaskCommon;
542    fn deref(&self) -> &Self::Target {
543        self.inner.deref()
544    }
545}
546
547impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
548    fn deref_mut(&mut self) -> &mut Self::Target {
549        self.inner.deref_mut()
550    }
551}
552
553impl<T: ClientTask> ClientTask for FailoverTask<T> {}
554
555impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
556    #[inline]
557    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
558        self.inner.fmt(f)
559    }
560}