use parking_lot::Mutex;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use bytes::Bytes;
use serde::{de::DeserializeOwned, Serialize};
use crate::mesh::Mesh;
use crate::mesh_rpc::{
CallOptions, CallOptionsTyped, CodecDirection, RpcError, RpcReply, RpcStatus,
};
/// Shared, thread-safe predicate that decides whether a failed call's
/// [`RpcError`] is worth another attempt (`true` = retry).
pub type RetryablePredicate = Arc<dyn Fn(&RpcError) -> bool + Send + Sync>;
/// Parameters controlling how the `*_with_retry` helpers re-issue failed calls.
#[derive(Clone)]
pub struct RetryPolicy {
/// Maximum number of attempts, counting the first call. `retry_loop`
/// treats values below 1 as 1.
pub max_attempts: u32,
/// Base delay slept after the first failed attempt (before jitter).
pub initial_backoff: Duration,
/// Upper bound applied to every computed delay.
pub max_backoff: Duration,
/// Factor by which the delay grows per attempt. `compute_backoff` treats
/// values below 1.0 as 1.0.
pub backoff_multiplier: f64,
/// When `true`, each delay is scaled by a pseudo-random factor in
/// `[0.5, 1.0]`.
pub jitter: bool,
/// Decides which errors are retried; see [`default_retryable`].
pub retryable: RetryablePredicate,
}
/// Hand-written `Debug`: the `retryable` closure has no `Debug` impl, so it
/// is rendered as the placeholder `"<fn>"`.
impl std::fmt::Debug for RetryPolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this
        // impl becomes a compile error.
        let Self {
            max_attempts,
            initial_backoff,
            max_backoff,
            backoff_multiplier,
            jitter,
            retryable: _,
        } = self;
        let mut out = f.debug_struct("RetryPolicy");
        out.field("max_attempts", max_attempts);
        out.field("initial_backoff", initial_backoff);
        out.field("max_backoff", max_backoff);
        out.field("backoff_multiplier", backoff_multiplier);
        out.field("jitter", jitter);
        out.field("retryable", &"<fn>");
        out.finish()
    }
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
initial_backoff: Duration::from_millis(50),
max_backoff: Duration::from_secs(1),
backoff_multiplier: 2.0,
jitter: true,
retryable: Arc::new(default_retryable),
}
}
}
impl RetryPolicy {
    /// Returns this policy with `predicate` installed as the retry
    /// predicate; every other field is carried over unchanged.
    pub fn with_retryable<F>(self, predicate: F) -> Self
    where
        F: Fn(&RpcError) -> bool + Send + Sync + 'static,
    {
        Self {
            retryable: Arc::new(predicate),
            ..self
        }
    }
}
/// Default retry classification.
///
/// Retries timeouts, transport failures, and server errors whose wire status
/// is `Internal`, `Backpressure`, or `Timeout`. Everything else (no route,
/// codec errors, capability denial, cancellation, other server statuses) is
/// not retried.
pub fn default_retryable(err: &RpcError) -> bool {
    // The match stays exhaustive on purpose: a new `RpcError` variant must be
    // classified explicitly.
    match err {
        RpcError::Timeout { .. } | RpcError::Transport(_) => true,
        RpcError::ServerError { status, .. } => {
            let transient = [
                RpcStatus::Internal.to_wire(),
                RpcStatus::Backpressure.to_wire(),
                RpcStatus::Timeout.to_wire(),
            ];
            transient.iter().any(|wire| *status == *wire)
        }
        RpcError::NoRoute { .. }
        | RpcError::Codec { .. }
        | RpcError::CapabilityDenied { .. }
        | RpcError::Cancelled => false,
    }
}
impl Mesh {
    /// Calls `service` on `target_node_id`, re-issuing the call according to
    /// `policy` when it fails with a retryable error.
    ///
    /// `payload` and `opts` are cloned for every attempt. Returns the first
    /// successful reply, or the error of the attempt that ended the loop.
    pub async fn call_with_retry(
        &self,
        target_node_id: u64,
        service: &str,
        payload: Bytes,
        opts: CallOptions,
        policy: &RetryPolicy,
    ) -> std::result::Result<RpcReply, RpcError> {
        let attempt_once = |_attempt: u32| {
            let (body, call_opts) = (payload.clone(), opts.clone());
            async move { self.call(target_node_id, service, body, call_opts).await }
        };
        retry_loop(policy, attempt_once).await
    }

    /// Same as [`Mesh::call_with_retry`], but each attempt goes through
    /// `call_service` instead of naming a target node.
    pub async fn call_service_with_retry(
        &self,
        service: &str,
        payload: Bytes,
        opts: CallOptions,
        policy: &RetryPolicy,
    ) -> std::result::Result<RpcReply, RpcError> {
        let attempt_once = |_attempt: u32| {
            let (body, call_opts) = (payload.clone(), opts.clone());
            async move { self.call_service(service, body, call_opts).await }
        };
        retry_loop(policy, attempt_once).await
    }

    /// Typed variant of [`Mesh::call_with_retry`].
    ///
    /// The request is encoded once with `opts.codec` before the first
    /// attempt and the final reply body is decoded with the same codec.
    /// Encode and decode failures are reported as `RpcError::Codec` and are
    /// not retried, because they happen outside the retry loop.
    pub async fn call_typed_with_retry<Req, Resp>(
        &self,
        target_node_id: u64,
        service: &str,
        request: &Req,
        opts: CallOptionsTyped,
        policy: &RetryPolicy,
    ) -> std::result::Result<Resp, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let codec = opts.codec;
        let encoded = match codec.encode(request) {
            Ok(raw) => Bytes::from(raw),
            Err(e) => {
                return Err(RpcError::Codec {
                    direction: CodecDirection::Encode,
                    message: format!("client encode: {e}"),
                })
            }
        };
        let reply = self
            .call_with_retry(target_node_id, service, encoded, opts.raw, policy)
            .await?;
        codec.decode(&reply.body).map_err(|e| RpcError::Codec {
            direction: CodecDirection::Decode,
            message: format!("client decode: {e}"),
        })
    }

    /// Typed variant of [`Mesh::call_service_with_retry`]; encoding and
    /// decoding behave as in [`Mesh::call_typed_with_retry`].
    pub async fn call_service_typed_with_retry<Req, Resp>(
        &self,
        service: &str,
        request: &Req,
        opts: CallOptionsTyped,
        policy: &RetryPolicy,
    ) -> std::result::Result<Resp, RpcError>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
    {
        let codec = opts.codec;
        let encoded = match codec.encode(request) {
            Ok(raw) => Bytes::from(raw),
            Err(e) => {
                return Err(RpcError::Codec {
                    direction: CodecDirection::Encode,
                    message: format!("client encode: {e}"),
                })
            }
        };
        let reply = self
            .call_service_with_retry(service, encoded, opts.raw, policy)
            .await?;
        codec.decode(&reply.body).map_err(|e| RpcError::Codec {
            direction: CodecDirection::Decode,
            message: format!("client decode: {e}"),
        })
    }
}
/// Drives `attempt_fn` until it succeeds, fails with a non-retryable error,
/// or `policy.max_attempts` attempts have been made.
///
/// * `attempt_fn` receives the 1-based attempt number and returns the future
///   for that attempt.
/// * `policy.max_attempts` below 1 is treated as 1, so at least one attempt
///   is always made.
/// * Between attempts the task sleeps for `compute_backoff(policy, attempt)`
///   unless that delay is zero.
///
/// Returns the first `Ok` value, or the error of the attempt that ended the
/// loop. Every exit goes through an explicit `return`, so there is no
/// "exhausted without an error" fallback path.
async fn retry_loop<T, F, Fut>(
    policy: &RetryPolicy,
    mut attempt_fn: F,
) -> std::result::Result<T, RpcError>
where
    F: FnMut(u32) -> Fut,
    Fut: std::future::Future<Output = std::result::Result<T, RpcError>>,
{
    let max = policy.max_attempts.max(1);
    let mut attempt: u32 = 1;
    loop {
        let err = match attempt_fn(attempt).await {
            Ok(value) => return Ok(value),
            Err(err) => err,
        };
        // The predicate is consulted on every failure, including the last
        // attempt, so predicates that observe errors see all of them.
        let retryable = (policy.retryable)(&err);
        if !retryable || attempt >= max {
            return Err(err);
        }
        // The error is no longer needed; release it before sleeping.
        drop(err);
        let backoff = compute_backoff(policy, attempt);
        if !backoff.is_zero() {
            tokio::time::sleep(backoff).await;
        }
        // Cannot overflow: `attempt < max <= u32::MAX` at this point.
        attempt += 1;
    }
}
/// Parameters for hedged calls, where the same request is sent to several
/// targets and the first success wins.
#[derive(Debug, Clone)]
pub struct HedgePolicy {
/// Stagger between targets: the target at index `i` is called after
/// `delay * i` (the first target is called immediately).
pub delay: Duration,
/// Number of extra targets beyond the first. The total used is
/// `1 + hedges`, capped at the number of available targets.
pub hedges: u32,
}
/// Defaults: one hedge, started 50 ms after the primary call.
impl Default for HedgePolicy {
    fn default() -> Self {
        let delay = Duration::from_millis(50);
        let hedges = 1;
        Self { delay, hedges }
    }
}
impl Mesh {
pub async fn call_with_hedge_to(
&self,
targets: &[u64],
service: &str,
payload: Bytes,
opts: CallOptions,
policy: &HedgePolicy,
) -> std::result::Result<RpcReply, RpcError> {
if targets.is_empty() {
return Err(RpcError::NoRoute {
target: 0,
reason: "call_with_hedge_to: targets is empty".into(),
});
}
let total = (1 + policy.hedges as usize).min(targets.len());
let chosen: Vec<u64> = targets[..total].to_vec();
hedge_race(self, &chosen, service, payload, opts, policy.delay).await
}
pub async fn call_service_with_hedge(
&self,
service: &str,
payload: Bytes,
opts: CallOptions,
policy: &HedgePolicy,
) -> std::result::Result<RpcReply, RpcError> {
let candidates = self.resolve_hedge_candidates(service)?;
let total = (1 + policy.hedges as usize).min(candidates.len());
let chosen = &candidates[..total];
hedge_race(self, chosen, service, payload, opts, policy.delay).await
}
pub async fn call_typed_with_hedge_to<Req, Resp>(
&self,
targets: &[u64],
service: &str,
request: &Req,
opts: CallOptionsTyped,
policy: &HedgePolicy,
) -> std::result::Result<Resp, RpcError>
where
Req: Serialize,
Resp: DeserializeOwned,
{
let codec = opts.codec;
let body = codec.encode(request).map_err(|e| RpcError::Codec {
direction: CodecDirection::Encode,
message: format!("client encode: {e}"),
})?;
let reply = self
.call_with_hedge_to(targets, service, Bytes::from(body), opts.raw, policy)
.await?;
codec.decode(&reply.body).map_err(|e| RpcError::Codec {
direction: CodecDirection::Decode,
message: format!("client decode: {e}"),
})
}
pub async fn call_service_typed_with_hedge<Req, Resp>(
&self,
service: &str,
request: &Req,
opts: CallOptionsTyped,
policy: &HedgePolicy,
) -> std::result::Result<Resp, RpcError>
where
Req: Serialize,
Resp: DeserializeOwned,
{
let codec = opts.codec;
let body = codec.encode(request).map_err(|e| RpcError::Codec {
direction: CodecDirection::Encode,
message: format!("client encode: {e}"),
})?;
let reply = self
.call_service_with_hedge(service, Bytes::from(body), opts.raw, policy)
.await?;
codec.decode(&reply.body).map_err(|e| RpcError::Codec {
direction: CodecDirection::Decode,
message: format!("client decode: {e}"),
})
}
fn resolve_hedge_candidates(&self, service: &str) -> std::result::Result<Vec<u64>, RpcError> {
let mut candidates = self.find_service_nodes(service);
if candidates.is_empty() {
return Err(RpcError::NoRoute {
target: 0,
reason: format!("no nodes advertise `nrpc:{service}`"),
});
}
candidates.sort_unstable();
Ok(candidates)
}
}
async fn hedge_race(
mesh: &Mesh,
targets: &[u64],
service: &str,
payload: Bytes,
opts: CallOptions,
delay: Duration,
) -> std::result::Result<RpcReply, RpcError> {
use futures::future::FutureExt;
let node = mesh.node_arc();
let service_owned = service.to_string();
let mut futures: Vec<
futures::future::BoxFuture<'static, (usize, std::result::Result<RpcReply, RpcError>)>,
> = targets
.iter()
.copied()
.enumerate()
.map(|(idx, target)| {
let node = Arc::clone(&node);
let service = service_owned.clone();
let payload = payload.clone();
let opts = opts.clone();
let wait = delay.saturating_mul(idx as u32);
async move {
if !wait.is_zero() {
tokio::time::sleep(wait).await;
}
let r = node.call(target, &service, payload, opts).await;
(idx, r)
}
.boxed()
})
.collect();
let mut errors: Vec<Option<RpcError>> = (0..targets.len()).map(|_| None).collect();
while !futures.is_empty() {
let ((target_idx, result), _select_idx, remaining) =
futures::future::select_all(futures).await;
match result {
Ok(reply) => return Ok(reply),
Err(e) => {
if target_idx < errors.len() {
errors[target_idx] = Some(e);
}
futures = remaining;
}
}
}
let chosen = errors.into_iter().flatten().next();
Err(chosen.unwrap_or_else(|| {
RpcError::Transport(net::error::AdapterError::Connection(
"hedge_race: drained with no error captured (bug)".into(),
))
}))
}
/// Computes the delay to sleep after failed attempt number `attempt`
/// (1-based).
///
/// The un-jittered delay is `initial_backoff * multiplier^(attempt - 1)`,
/// capped at `max_backoff`; multipliers below 1.0 are treated as 1.0. With
/// `policy.jitter` set, the capped delay is scaled by a pseudo-random factor
/// in `[0.5, 1.0]`.
///
/// The function never panics and never returns more than
/// `policy.max_backoff`. A zero `initial_backoff` always yields a zero delay.
fn compute_backoff(policy: &RetryPolicy, attempt: u32) -> Duration {
    // Zero base delay means "no backoff". Returning early also avoids
    // `0.0 * inf = NaN` once the exponential overflows, which `f64::min`
    // would otherwise turn into the full `max_backoff`.
    if policy.initial_backoff.is_zero() {
        return Duration::ZERO;
    }
    let mult = policy.backoff_multiplier.max(1.0);
    // Clamp before the cast so very large attempt numbers cannot wrap to a
    // negative exponent.
    let exp = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
    let scaled = policy.initial_backoff.as_secs_f64() * mult.powi(exp);
    let max_secs = policy.max_backoff.as_secs_f64();
    let pre_cap = scaled.min(max_secs);
    let jittered = if policy.jitter {
        // Cheap, non-cryptographic jitter source: mixes time since first
        // use, the thread id hash, a stack address and the attempt number.
        static PROCESS_EPOCH: OnceLock<Instant> = OnceLock::new();
        let epoch = PROCESS_EPOCH.get_or_init(Instant::now);
        let now_ns = epoch.elapsed().as_nanos() as u64;
        let thread_id_bits = {
            let mut s = std::collections::hash_map::DefaultHasher::new();
            std::hash::Hash::hash(&std::thread::current().id(), &mut s);
            std::hash::Hasher::finish(&s)
        };
        let stack_addr = (&attempt as *const u32) as usize as u64;
        let seed = now_ns
            ^ thread_id_bits
            ^ stack_addr.rotate_left(17)
            ^ (attempt as u64).wrapping_mul(0x9E3779B97F4A7C15);
        let mixed = seed
            .wrapping_mul(0x100000001B3)
            .wrapping_add(0xCBF29CE484222325);
        // Upper 32 bits mapped onto [0.0, 1.0].
        let frac = ((mixed >> 32) as u32) as f64 / u32::MAX as f64;
        pre_cap * (0.5 + 0.5 * frac)
    } else {
        pre_cap
    };
    let final_secs = jittered.min(max_secs).max(0.0);
    // `from_secs_f64` panics when the value does not fit in a `Duration`
    // (possible when `max_backoff` is close to `Duration::MAX`). Fall back to
    // the cap, and clamp again to absorb float rounding.
    Duration::try_from_secs_f64(final_secs)
        .unwrap_or(policy.max_backoff)
        .min(policy.max_backoff)
}
/// Shared, thread-safe predicate that decides whether an [`RpcError`] counts
/// as a failure for the circuit breaker (`true` = counts).
pub type BreakerFailurePredicate = Arc<dyn Fn(&RpcError) -> bool + Send + Sync>;
/// Configuration for [`CircuitBreaker`].
#[derive(Clone)]
pub struct CircuitBreakerConfig {
/// Consecutive counted failures (while closed) that open the breaker.
/// Values below 1 are treated as 1.
pub failure_threshold: u32,
/// Successful half-open probes required to close the breaker again.
/// Values below 1 are treated as 1.
pub success_threshold: u32,
/// How long the breaker stays open before a probe call is admitted.
pub reset_after: Duration,
/// Decides which errors count as failures; see [`default_breaker_failure`].
pub failure_predicate: BreakerFailurePredicate,
}
/// Hand-written `Debug`: the `failure_predicate` closure has no `Debug`
/// impl, so it is rendered as the placeholder `"<fn>"`.
impl std::fmt::Debug for CircuitBreakerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this
        // impl becomes a compile error.
        let Self {
            failure_threshold,
            success_threshold,
            reset_after,
            failure_predicate: _,
        } = self;
        let mut out = f.debug_struct("CircuitBreakerConfig");
        out.field("failure_threshold", failure_threshold);
        out.field("success_threshold", success_threshold);
        out.field("reset_after", reset_after);
        out.field("failure_predicate", &"<fn>");
        out.finish()
    }
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
success_threshold: 1,
reset_after: Duration::from_secs(30),
failure_predicate: Arc::new(default_breaker_failure),
}
}
}
/// Default breaker failure classification: an error counts as a failure
/// exactly when [`default_retryable`] would retry it.
pub fn default_breaker_failure(err: &RpcError) -> bool {
default_retryable(err)
}
/// Observable state of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BreakerState {
/// Calls are admitted.
Closed,
/// Calls are rejected until `reset_after` has elapsed since opening.
Open,
/// Probing: one call at a time is admitted; others are rejected.
HalfOpen,
}
/// Error returned by [`CircuitBreaker::call`].
#[derive(Debug, thiserror::Error)]
pub enum BreakerError {
/// The breaker rejected the call; the wrapped operation was not invoked.
#[error("circuit breaker is open")]
Open,
/// The wrapped operation ran and failed with this error.
#[error("inner: {0}")]
Inner(#[from] RpcError),
}
impl BreakerError {
    /// Flattens this error into an [`RpcError`].
    ///
    /// `Inner` yields the wrapped error unchanged; `Open` becomes
    /// `RpcError::NoRoute` with target `0` and the reason
    /// "circuit breaker is open".
    pub fn into_rpc_error(self) -> RpcError {
        match self {
            Self::Inner(inner) => inner,
            Self::Open => {
                let reason = "circuit breaker is open".into();
                RpcError::NoRoute { target: 0, reason }
            }
        }
    }
}
/// Circuit breaker that gates calls based on their recent outcomes.
///
/// Thresholds and timing come from `config`; the mutable state lives behind
/// a `parking_lot::Mutex`.
pub struct CircuitBreaker {
// Thresholds, reset delay and failure predicate; not modified by the
// methods in this file.
config: CircuitBreakerConfig,
// Mutable state machine, guarded by the mutex.
inner: Mutex<BreakerInner>,
}
/// Mutable state of a [`CircuitBreaker`], accessed under its mutex.
struct BreakerInner {
// Current state of the breaker.
state: BreakerState,
// Streak of counted failures from calls admitted while closed.
consecutive_failures: u32,
// Streak of successful half-open probes.
consecutive_successes: u32,
// When the breaker last transitioned to `Open`; `None` after a reset or
// after closing.
opened_at: Option<std::time::Instant>,
// True while a half-open probe has been admitted and not yet resolved.
probe_in_flight: bool,
}
impl CircuitBreaker {
/// Creates a breaker in the `Closed` state with zeroed counters.
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
inner: Mutex::new(BreakerInner {
state: BreakerState::Closed,
consecutive_failures: 0,
consecutive_successes: 0,
opened_at: None,
probe_in_flight: false,
}),
}
}
/// Locks and returns the inner state.
fn lock_inner(&self) -> parking_lot::MutexGuard<'_, BreakerInner> {
self.inner.lock()
}
/// Returns the stored state. This is a plain read: an `Open` breaker whose
/// `reset_after` has elapsed still reports `Open` until the next
/// admission check moves it to `HalfOpen`.
pub fn state(&self) -> BreakerState {
self.lock_inner().state
}
/// Returns the current streak of counted failures.
pub fn consecutive_failures(&self) -> u32 {
self.lock_inner().consecutive_failures
}
/// Forces the breaker back to `Closed` and clears all counters, the open
/// timestamp and the probe-in-flight flag.
pub fn reset(&self) {
let mut g = self.lock_inner();
g.state = BreakerState::Closed;
g.consecutive_failures = 0;
g.consecutive_successes = 0;
g.opened_at = None;
g.probe_in_flight = false;
}
/// Runs `f` under the breaker.
///
/// Returns `Err(BreakerError::Open)` without invoking `f` when the call is
/// not admitted. Otherwise awaits the future produced by `f`, records the
/// outcome, and returns it with errors wrapped in `BreakerError::Inner`.
///
/// If the future is dropped (or unwinds) before the outcome is recorded
/// and the call was a half-open probe, the guard's `Drop` re-opens the
/// breaker and frees the probe slot.
pub async fn call<F, Fut, T>(&self, f: F) -> std::result::Result<T, BreakerError>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = std::result::Result<T, RpcError>>,
{
let admitted_as = self.try_admit();
// Narrow the three-way admission outcome to the two admitted kinds;
// rejection returns immediately.
let admitted_as = match admitted_as {
AdmissionOutcome::Closed => Admission::Closed,
AdmissionOutcome::HalfOpenProbe => Admission::HalfOpenProbe,
AdmissionOutcome::Reject => return Err(BreakerError::Open),
};
// Armed while the call is in flight; see `impl Drop for ProbeGuard`.
let _probe_guard = ProbeGuard {
breaker: self,
admission: admitted_as,
disarmed: std::cell::Cell::new(false),
};
let outcome = f().await;
// The call completed normally: disarm the guard so the bookkeeping
// below, not the guard's `Drop`, decides the next state.
_probe_guard.disarmed.set(true);
// The lock is taken only after the await, so it is never held across
// a suspension point.
let mut g = self.lock_inner();
match (&outcome, admitted_as) {
(Ok(_), Admission::Closed) => {
// Any success from a closed-admitted call resets the failure streak.
g.consecutive_failures = 0;
}
(Ok(_), Admission::HalfOpenProbe) => {
g.probe_in_flight = false;
g.consecutive_successes = g.consecutive_successes.saturating_add(1);
// Enough successful probes: close and clear all bookkeeping.
if g.consecutive_successes >= self.config.success_threshold.max(1) {
g.state = BreakerState::Closed;
g.consecutive_failures = 0;
g.consecutive_successes = 0;
g.opened_at = None;
}
}
(Err(e), admission) => {
let counts = (self.config.failure_predicate)(e);
// A finished probe frees the probe slot whether or not the error
// counts as a failure.
if matches!(admission, Admission::HalfOpenProbe) {
g.probe_in_flight = false;
}
// Errors the predicate does not count leave state and counters
// untouched.
if counts {
match admission {
Admission::Closed => {
g.consecutive_failures = g.consecutive_failures.saturating_add(1);
if g.consecutive_failures >= self.config.failure_threshold.max(1) {
g.state = BreakerState::Open;
g.opened_at = Some(std::time::Instant::now());
g.consecutive_successes = 0;
}
}
Admission::HalfOpenProbe => {
// A failed probe re-opens immediately and restarts the
// `reset_after` window.
g.state = BreakerState::Open;
g.opened_at = Some(std::time::Instant::now());
g.consecutive_failures = 0;
g.consecutive_successes = 0;
}
}
}
}
}
drop(g);
outcome.map_err(BreakerError::Inner)
}
/// Decides whether a new call may proceed and, when admitting a probe,
/// claims the probe slot.
///
/// * `Closed`: always admitted.
/// * `Open`: rejected until `reset_after` has elapsed since `opened_at`
///   (a missing timestamp counts as zero elapsed time); then the breaker
///   moves to `HalfOpen` and this call becomes the probe.
/// * `HalfOpen`: admitted as the probe only if no probe is in flight.
fn try_admit(&self) -> AdmissionOutcome {
let mut g = self.lock_inner();
match g.state {
BreakerState::Closed => AdmissionOutcome::Closed,
BreakerState::Open => {
let elapsed = g.opened_at.map(|i| i.elapsed()).unwrap_or(Duration::ZERO);
if elapsed >= self.config.reset_after {
g.state = BreakerState::HalfOpen;
g.consecutive_successes = 0;
g.probe_in_flight = true;
AdmissionOutcome::HalfOpenProbe
} else {
AdmissionOutcome::Reject
}
}
BreakerState::HalfOpen => {
if g.probe_in_flight {
AdmissionOutcome::Reject
} else {
g.probe_in_flight = true;
AdmissionOutcome::HalfOpenProbe
}
}
}
}
}
/// Drop guard held by `CircuitBreaker::call` while the wrapped call is in
/// flight. Unless disarmed, dropping it for a half-open probe re-opens the
/// breaker.
struct ProbeGuard<'a> {
// Breaker whose state is repaired on drop.
breaker: &'a CircuitBreaker,
// How the guarded call was admitted; only `HalfOpenProbe` triggers repair.
admission: Admission,
// Set to true once the call completed and its outcome is being recorded
// normally; the drop handler then does nothing.
disarmed: std::cell::Cell<bool>,
}
/// Repairs breaker state when a half-open probe is abandoned before its
/// outcome was recorded (the guard is still armed on drop).
impl Drop for ProbeGuard<'_> {
    fn drop(&mut self) {
        let abandoned_probe =
            !self.disarmed.get() && matches!(self.admission, Admission::HalfOpenProbe);
        if !abandoned_probe {
            return;
        }
        // Treat the abandoned probe like a failed one: free the slot,
        // re-open, and restart the reset window.
        let mut inner = self.breaker.lock_inner();
        inner.probe_in_flight = false;
        inner.state = BreakerState::Open;
        inner.opened_at = Some(std::time::Instant::now());
        inner.consecutive_failures = 0;
        inner.consecutive_successes = 0;
    }
}
/// Result of `CircuitBreaker::try_admit`.
#[derive(Clone, Copy)]
enum AdmissionOutcome {
// Breaker is closed; the call proceeds normally.
Closed,
// The call proceeds as the single half-open probe.
HalfOpenProbe,
// The call must not run.
Reject,
}
/// How an admitted call entered the breaker; this is `AdmissionOutcome`
/// without the `Reject` case.
#[derive(Clone, Copy)]
enum Admission {
// Admitted while the breaker was closed.
Closed,
// Admitted as the half-open probe.
HalfOpenProbe,
}