use crate::client::task::*;
use crate::client::{
ClientCaller, ClientCallerBlocking, ClientConfig, ClientFacts, ClientTransport, ConnPool,
};
use crate::proto::RpcAction;
use crate::{
Codec,
error::{EncodedErr, RpcIntErr},
};
use ahash::AHashMap;
use arc_swap::ArcSwap;
use captains_log::filter::LogFilter;
use crossfire::{AsyncRx, MTx, SendError, mpsc};
use orb::prelude::{AsyncExec, AsyncRuntime};
use parking_lot::Mutex;
use std::fmt;
use std::sync::{
Arc, Weak,
atomic::{AtomicUsize, Ordering},
};
/// Client-side pool that spreads requests over several node addresses and
/// re-routes failed tasks to another node.
///
/// The handle only holds an `Arc` to the shared state, so cloning it yields
/// another handle to the same pool.
pub struct FailoverPool<F, P>
where
F: ClientFacts,
P: ClientTransport,
{
// Shared state; the retry worker only keeps a `Weak` to it.
inner: Arc<FailoverPoolInner<F, P>>,
}
/// Adapter around the user's `ClientFacts` that is handed to every `ConnPool`.
///
/// It delegates logger/config lookups to the wrapped facts and intercepts
/// `error_handle` so that retryable tasks are queued for the retry worker.
struct FailoverFacts<F>
where
F: ClientFacts,
{
/// Sending side of the retry queue; `error_handle` pushes retryable tasks here.
retry_tx: MTx<mpsc::List<FailoverTask<F::Task>>>,
/// The user-supplied facts all `ClientFacts` calls are forwarded to.
facts: Arc<F>,
/// Logger obtained from `facts.new_logger()` at construction time.
logger: Arc<LogFilter>,
/// Default retry budget, used when a task's own `max_retries` is 0.
retry_limit: usize,
}
/// Shared state behind every `FailoverPool` handle.
struct FailoverPoolInner<F, P>
where
F: ClientFacts,
P: ClientTransport,
{
/// Current cluster snapshot (pool list + version); replaced wholesale via
/// `store` whenever the membership changes.
pools: ArcSwap<ClusterConfig<F, P>>,
/// Routing mode handed to `ClusterConfig::select`: when true, a cursor-based
/// selection advances `next_node`; otherwise the cursor is only read.
stateless: bool,
/// Routing cursor consulted by `ClusterConfig::select`.
next_node: AtomicUsize,
/// Channel size forwarded to every `ConnPool::new` call.
pool_channel_size: usize,
/// Wrapped user facts shared with every `ConnPool`.
facts: Arc<FailoverFacts<F>>,
/// Serializes the load-modify-store sequences that replace `pools`.
add_pool_mutex: Mutex<()>,
/// Executor supplied at construction, if any (cloned from the `rt` argument of `new`).
rt: Option<<P::RT as AsyncRuntime>::Exec>,
}
/// One snapshot of the cluster membership.
///
/// A snapshot is never mutated after being published; a membership change
/// builds a new snapshot with a different `ver`.
struct ClusterConfig<F, P>
where
F: ClientFacts,
P: ClientTransport,
{
/// One connection pool per node; positions in this vector are the indices
/// used for routing.
pools: Vec<ConnPool<FailoverFacts<F>, P>>,
/// Snapshot version; each new snapshot uses the previous version plus one (wrapping).
ver: u64,
}
impl<F, P> FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Builds a pool over `addrs` and starts the background retry worker.
    ///
    /// # Arguments
    /// * `facts` - user facts; wrapped in `FailoverFacts` so failed tasks come back for retry.
    /// * `rt` - executor for the pools and the retry worker; with `None` the worker is
    ///   spawned through `P::RT::spawn_detach`.
    /// * `addrs` - initial node addresses; one `ConnPool` is created per entry, in order.
    /// * `stateless` - routing mode, see `ClusterConfig::select`.
    /// * `retry_limit` - default retry budget for tasks that do not carry their own.
    /// * `pool_channel_size` - forwarded to every `ConnPool::new`.
    pub fn new(
        facts: Arc<F>, rt: Option<&<P::RT as AsyncRuntime>::Exec>, addrs: Vec<String>,
        stateless: bool, retry_limit: usize, pool_channel_size: usize,
    ) -> Self {
        let (retry_tx, retry_rx) = mpsc::unbounded_async();
        // The worker gets its own logger; `facts` is moved into the wrapper below.
        let retry_logger = facts.new_logger();
        let wrapped_facts =
            Arc::new(FailoverFacts { retry_limit, retry_tx, logger: facts.new_logger(), facts });
        let mut pools = Vec::with_capacity(addrs.len());
        for addr in addrs.iter() {
            let pool = ConnPool::new(wrapped_facts.clone(), rt, addr, pool_channel_size);
            pools.push(pool);
        }
        let inner = Arc::new(FailoverPoolInner::<F, P> {
            pools: ArcSwap::new(Arc::new(ClusterConfig { ver: 0, pools })),
            stateless,
            facts: wrapped_facts,
            next_node: AtomicUsize::new(0),
            pool_channel_size,
            add_pool_mutex: Mutex::new(()),
            rt: rt.cloned(),
        });
        // The worker only holds a weak reference so it never keeps the pool alive.
        let weak_self = Arc::downgrade(&inner);
        let f = FailoverPoolInner::retry_worker(weak_self, retry_logger, retry_rx);
        if let Some(_rt) = rt {
            _rt.spawn_detach(f);
        } else {
            P::RT::spawn_detach(f);
        }
        Self { inner }
    }

    /// Returns the pool-wide default retry budget given to `new`.
    #[inline]
    pub fn get_retry_limit(&self) -> usize {
        self.inner.facts.retry_limit
    }

    /// Sends `task` again, either to an explicit address or to the next healthy node.
    ///
    /// # Arguments
    /// * `task` - the user task to send.
    /// * `addr_or_retry` - `Ok(addr)` routes to the pool for `addr` (creating it when
    ///   unknown) and makes that node the routing cursor; `Err(last_index)` picks the
    ///   next healthy node after `last_index`.
    /// * `retry_count` - retries already consumed by this task.
    /// * `max_retries` - per-task retry budget; `None` (or `Some(0)`) falls back to the
    ///   pool-wide limit.
    ///
    /// When no healthy node is available on the `Err` path, the task is completed
    /// with `RpcIntErr::Unreachable`.
    pub async fn resubmit(
        &self, task: F::Task, addr_or_retry: Result<String, usize>, retry_count: usize,
        max_retries: Option<usize>,
    ) where
        F::Task: ClientTask,
    {
        let max_retries = max_retries.unwrap_or(0);
        match &addr_or_retry {
            Ok(addr) => {
                let (pool, index, conf_ver) = self.get_or_add_addr(addr);
                // Subsequent cursor-based selections start from the redirected node.
                self.inner.next_node.store(index, Ordering::SeqCst);
                let failover_task = FailoverTask {
                    last_index: index,
                    config_ver: conf_ver,
                    inner: task,
                    retry: retry_count,
                    should_retry: false,
                    max_retries,
                };
                pool.send_req(failover_task).await;
            }
            Err(last_index) => {
                let cluster = self.inner.pools.load();
                if let Some((pool, index)) = cluster.select(self.inner.stateless, Err(*last_index))
                {
                    // Carry the caller's retry accounting over, exactly as the
                    // `Ok` branch does; resetting it here would hand the task a
                    // fresh budget on every resubmission.
                    let failover_task = FailoverTask {
                        last_index: index,
                        config_ver: cluster.ver,
                        inner: task,
                        retry: retry_count,
                        should_retry: false,
                        max_retries,
                    };
                    pool.send_req(failover_task).await;
                    return;
                }
                let mut task = task;
                task.set_rpc_error(RpcIntErr::Unreachable);
                task.done();
            }
        }
    }

    /// Looks up the pool for `addr`, creating and publishing it when missing.
    ///
    /// Returns the pool, its index in the snapshot it belongs to, and that
    /// snapshot's version.
    fn get_or_add_addr(&self, addr: &str) -> (ConnPool<FailoverFacts<F>, P>, usize, u64) {
        let inner = &self.inner;
        // First attempt: look the address up without taking the mutex.
        {
            let cluster = inner.pools.load();
            if let Some((pool, idx)) = cluster.get_by_addr(addr) {
                return (pool.clone(), idx, cluster.ver);
            }
        }
        // Not found: take the mutex and look again, since another caller may
        // have added the address in the meantime.
        let _guard = inner.add_pool_mutex.lock();
        let old_cluster = inner.pools.load_full();
        if let Some((pool, idx)) = old_cluster.get_by_addr(addr) {
            return (pool.clone(), idx, old_cluster.ver);
        }
        let mut new_cluster = Vec::with_capacity(old_cluster.pools.len() + 1);
        // Use the executor the pool was configured with, like `new` and `update_addrs`.
        let new_pool =
            ConnPool::new(inner.facts.clone(), inner.rt.as_ref(), addr, inner.pool_channel_size);
        // The new pool goes first, hence index 0 in the returned tuple.
        new_cluster.push(new_pool.clone());
        new_cluster.extend(old_cluster.pools.iter().cloned());
        let new_ver = old_cluster.ver.wrapping_add(1);
        drop(old_cluster);
        let new_cluster = ClusterConfig { pools: new_cluster, ver: new_ver };
        inner.pools.store(Arc::new(new_cluster));
        (new_pool, 0, new_ver)
    }

    /// Replaces the cluster membership with `addrs`.
    ///
    /// Pools whose address is still listed are reused; unknown addresses get a
    /// fresh `ConnPool`; pools not listed any more are left out of the new
    /// snapshot. The snapshot version is bumped.
    pub fn update_addrs(&self, addrs: Vec<String>) {
        let inner = &self.inner;
        {
            let _guard = self.inner.add_pool_mutex.lock();
            let old_cluster = inner.pools.load_full();
            let mut new_pools: Vec<ConnPool<FailoverFacts<F>, P>> = Vec::with_capacity(addrs.len());
            let mut old_pools_map = AHashMap::with_capacity(old_cluster.pools.len());
            for pool in &old_cluster.pools {
                old_pools_map.insert(pool.get_addr().to_string(), pool);
            }
            for addr in addrs {
                // `remove` means a duplicated address reuses the old pool only once.
                if let Some(reused_pool) = old_pools_map.remove(&addr) {
                    new_pools.push(reused_pool.clone());
                } else {
                    let new_pool = ConnPool::new(
                        inner.facts.clone(),
                        inner.rt.as_ref(),
                        &addr,
                        inner.pool_channel_size,
                    );
                    new_pools.push(new_pool);
                }
            }
            let new_ver = old_cluster.ver.wrapping_add(1);
            drop(old_cluster);
            let new_cluster = ClusterConfig { pools: new_pools, ver: new_ver };
            inner.pools.store(Arc::new(new_cluster));
        }
    }
}
impl<F, P> ClusterConfig<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Picks the first healthy pool, scanning every pool once in ring order.
    ///
    /// # Arguments
    /// * `stateless` - only relevant for `Ok` routes: when true the cursor is advanced
    ///   (`fetch_add`) so consecutive calls start at consecutive nodes; when false the
    ///   cursor is only read.
    /// * `route` - `Ok(cursor)` starts at the cursor value; `Err(index)` starts at the
    ///   node following `index`.
    ///
    /// # Returns
    /// The selected pool and its index, always in `0..self.pools.len()`, or `None`
    /// when the snapshot is empty or no pool reports healthy.
    #[inline]
    fn select(
        &self, stateless: bool, route: Result<&AtomicUsize, usize>,
    ) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
        let l = self.pools.len();
        if l == 0 {
            return None;
        }
        let seed = match route {
            Ok(next_node) => {
                if stateless {
                    next_node.fetch_add(1, Ordering::Relaxed)
                } else {
                    next_node.load(Ordering::SeqCst)
                }
            }
            // Reduce before adding so a huge stored index cannot overflow.
            Err(index) => index % l + 1,
        };
        // The stateless cursor grows without bound; reducing it first keeps
        // `start + offset` below `2 * l`, so the scan can never overflow.
        let start = seed % l;
        for offset in 0..l {
            let i = (start + offset) % l;
            let pool = &self.pools[i];
            if pool.is_healthy() {
                return Some((pool, i));
            }
        }
        None
    }

    /// Finds the pool whose address equals `addr`, returning it with its index.
    fn get_by_addr(&self, addr: &str) -> Option<(&ConnPool<FailoverFacts<F>, P>, usize)> {
        self.pools
            .iter()
            .enumerate()
            .find(|(_, pool)| pool.get_addr() == addr)
            .map(|(i, pool)| (pool, i))
    }
}
impl<F, P> FailoverPoolInner<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    /// Background loop that re-routes tasks received on the retry queue.
    ///
    /// For every task: if the pool state is gone the task is completed as is.
    /// Otherwise a node is selected, starting after the task's previous node
    /// when the cluster version is unchanged, or from the shared cursor when it
    /// changed. The task is then sent there, or completed when no node is
    /// selectable. The loop ends once `recv` reports an error.
    async fn retry_worker(
        weak_self: Weak<Self>, logger: Arc<LogFilter>,
        retry_rx: AsyncRx<mpsc::List<FailoverTask<F::Task>>>,
    ) {
        loop {
            let mut task = match retry_rx.recv().await {
                Ok(received) => received,
                Err(_) => break,
            };
            // Upgrade per task so the strong reference is released again
            // before waiting on the queue.
            let inner = match weak_self.upgrade() {
                Some(strong) => strong,
                None => {
                    logger_trace!(logger, "FailoverPool: skip {:?} due to drop", task.inner);
                    task.done();
                    continue;
                }
            };
            let cluster = inner.pools.load();
            let same_version = cluster.ver == task.config_ver;
            let route = if same_version {
                // Indices are still meaningful: move on from the last node.
                Err(task.last_index)
            } else {
                // Membership changed: the old index is stale, use the cursor.
                task.config_ver = cluster.ver;
                Ok(&inner.next_node)
            };
            match cluster.select(inner.stateless, route) {
                Some((pool, index)) => {
                    if let Err(last) = &route {
                        logger_trace!(
                            logger,
                            "FailoverPool: task {:?} retry {}->{}",
                            task.inner,
                            last,
                            index
                        );
                    }
                    task.last_index = index;
                    pool.send_req(task).await;
                }
                None => {
                    logger_debug!(logger, "FailoverPool: no next hoop for {:?}", task.inner);
                    task.done();
                }
            }
        }
        logger_trace!(logger, "FailoverPool retry worker exit");
    }
}
/// Emits a trace log line when the shared pool state is dropped.
impl<F, P> Drop for FailoverPoolInner<F, P>
where
F: ClientFacts,
P: ClientTransport,
{
#[inline]
fn drop(&mut self) {
// Uses the logger stored in the wrapped facts.
logger_trace!(self.facts.logger, "FailoverPool dropped");
}
}
/// Exposes the wrapped user facts directly through the adapter.
impl<F> std::ops::Deref for FailoverFacts<F>
where
    F: ClientFacts,
{
    type Target = F;

    /// Borrows the user facts stored behind the `Arc`.
    #[inline]
    fn deref(&self) -> &Self::Target {
        &*self.facts
    }
}
/// `ClientFacts` as seen by the connection pools: same codec and configuration
/// as the user facts, but tasks are `FailoverTask`s and errors may be retried.
impl<F> ClientFacts for FailoverFacts<F>
where
    F: ClientFacts,
{
    type Codec = F::Codec;
    type Task = FailoverTask<F::Task>;

    /// Forwards to the user facts.
    #[inline]
    fn new_logger(&self) -> Arc<LogFilter> {
        self.facts.new_logger()
    }

    /// Forwards to the user facts.
    #[inline]
    fn get_config(&self) -> &ClientConfig {
        self.facts.get_config()
    }

    /// Queues a failed task for retry, or completes it.
    ///
    /// A task is queued when its last error was retryable and its retry count
    /// does not exceed the applicable limit (the task's own `max_retries` when
    /// non-zero, the pool-wide limit otherwise). If the queue rejects the task
    /// it is completed instead.
    #[inline]
    fn error_handle(&self, task: FailoverTask<F::Task>) {
        let limit = match task.max_retries {
            0 => self.retry_limit,
            own => own,
        };
        if !task.should_retry || task.retry > limit {
            task.inner.done();
            return;
        }
        if let Err(SendError(rejected)) = self.retry_tx.send(task) {
            rejected.done();
        }
    }
}
impl<F, P> Clone for FailoverPool<F, P>
where
F: ClientFacts,
P: ClientTransport,
{
#[inline]
fn clone(&self) -> Self {
Self { inner: self.inner.clone() }
}
}
/// Asynchronous entry point for fresh requests.
impl<F, P> ClientCaller for FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    type Facts = F;

    /// Routes `task` to a node chosen through the shared cursor.
    ///
    /// The task is wrapped with zeroed retry bookkeeping. When no node can be
    /// selected it is completed with `RpcIntErr::Unreachable`.
    async fn send_req(&self, mut task: F::Task) {
        let cluster = self.inner.pools.load();
        match cluster.select(self.inner.stateless, Ok(&self.inner.next_node)) {
            Some((pool, index)) => {
                let wrapped = FailoverTask {
                    last_index: index,
                    config_ver: cluster.ver,
                    inner: task,
                    retry: 0,
                    should_retry: false,
                    max_retries: 0,
                };
                pool.send_req(wrapped).await;
            }
            None => {
                task.set_rpc_error(RpcIntErr::Unreachable);
                task.done();
            }
        }
    }
}
/// Blocking entry point for fresh requests.
impl<F, P> ClientCallerBlocking for FailoverPool<F, P>
where
    F: ClientFacts,
    P: ClientTransport,
{
    type Facts = F;

    /// Blocking counterpart of `send_req`: same routing and wrapping, but the
    /// task is handed over with `send_req_blocking`.
    ///
    /// When no node can be selected the task is completed with
    /// `RpcIntErr::Unreachable`.
    fn send_req_blocking(&self, mut task: F::Task) {
        let cluster = self.inner.pools.load();
        match cluster.select(self.inner.stateless, Ok(&self.inner.next_node)) {
            Some((pool, index)) => {
                let wrapped = FailoverTask {
                    last_index: index,
                    config_ver: cluster.ver,
                    inner: task,
                    retry: 0,
                    should_retry: false,
                    max_retries: 0,
                };
                pool.send_req_blocking(wrapped);
            }
            None => {
                task.set_rpc_error(RpcIntErr::Unreachable);
                task.done();
            }
        }
    }
}
/// A user task together with the bookkeeping needed to re-route it.
pub struct FailoverTask<T: ClientTask> {
/// Index reported by the routing step that last placed this task.
last_index: usize,
/// Cluster snapshot version observed when the task was last routed; the retry
/// worker compares it with the current version to detect stale indices.
config_ver: u64,
/// The wrapped user task.
inner: T,
/// Retry counter: seeded by the caller (0 for fresh sends) and incremented by
/// `set_rpc_error` for every retryable error.
retry: usize,
/// Whether the most recently recorded error allows a retry.
should_retry: bool,
/// Per-task retry budget; 0 means "use the pool-wide `retry_limit`".
max_retries: usize,
}
impl<T: ClientTask> ClientTaskEncode for FailoverTask<T> {
#[inline(always)]
fn encode_req<C: Codec>(&self, codec: &C, buf: &mut Vec<u8>) -> Result<usize, ()> {
self.inner.encode_req(codec, buf)
}
#[inline(always)]
fn get_req_blob(&self) -> Option<&[u8]> {
self.inner.get_req_blob()
}
}
impl<T: ClientTask> ClientTaskDecode for FailoverTask<T> {
#[inline(always)]
fn decode_resp<C: Codec>(&mut self, codec: &C, buf: &[u8]) -> Result<(), ()> {
self.inner.decode_resp(codec, buf)
}
#[inline(always)]
fn reserve_resp_blob(&mut self, _size: i32) -> Option<&mut [u8]> {
self.inner.reserve_resp_blob(_size)
}
}
/// Completion hooks: record whether the outcome is retryable, then forward to
/// the wrapped task.
impl<T: ClientTask> ClientTaskDone for FailoverTask<T> {
    /// Records a custom (application-level) error; such errors are never retried.
    ///
    /// The incoming index/version arguments are ignored; the wrapped task
    /// receives this wrapper's own `last_index` and `config_ver` instead.
    #[inline(always)]
    fn set_custom_error<C: Codec>(
        &mut self, codec: &C, e: EncodedErr, _last_index: usize, _conf_ver: u64,
    ) {
        let (index, ver) = (self.last_index, self.config_ver);
        self.should_retry = false;
        self.inner.set_custom_error(codec, e, index, ver);
    }

    /// Records an internal RPC error.
    ///
    /// Errors ordered below `RpcIntErr::Method` mark the task retryable and
    /// bump its retry counter; all others clear the retry flag.
    #[inline(always)]
    fn set_rpc_error(&mut self, e: RpcIntErr) {
        let retryable = e < RpcIntErr::Method;
        if retryable {
            self.retry += 1;
        }
        self.should_retry = retryable;
        self.inner.set_rpc_error(e.clone());
    }

    /// Marks the wrapped task as successful.
    #[inline(always)]
    fn set_ok(&mut self) {
        let task = &mut self.inner;
        task.set_ok();
    }

    /// Completes the wrapped task, discarding the failover bookkeeping.
    #[inline(always)]
    fn done(self) {
        let Self { inner, .. } = self;
        inner.done();
    }
}
impl<T: ClientTask> ClientTaskAction for FailoverTask<T> {
#[inline(always)]
fn get_action<'a>(&'a self) -> RpcAction<'a> {
self.inner.get_action()
}
}
/// Read access to the common task fields goes through the wrapped task.
impl<T: ClientTask> std::ops::Deref for FailoverTask<T> {
    type Target = ClientTaskCommon;

    /// Borrows the wrapped task's `ClientTaskCommon`.
    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
/// Write access to the common task fields goes through the wrapped task.
impl<T: ClientTask> std::ops::DerefMut for FailoverTask<T> {
    /// Mutably borrows the wrapped task's `ClientTaskCommon`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.inner
    }
}
// Empty impl: marks `FailoverTask` itself as a `ClientTask`, which is the task
// type `FailoverFacts` declares for the connection pools.
// NOTE(review): presumably `ClientTask` only aggregates the task traits
// implemented in this file — confirm against its definition.
impl<T: ClientTask> ClientTask for FailoverTask<T> {}
/// Debug output is that of the wrapped task; the failover bookkeeping is not printed.
impl<T: ClientTask> fmt::Debug for FailoverTask<T> {
    /// Formats the wrapped task with its own `Debug` implementation.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(&self.inner, f)
    }
}