use std::future::Future;
use std::time::Duration;
use tonic::transport::Channel;
use dashmap::DashMap;
use crate::balance::circuit_breaker::{
get_circuit_breaker, get_instance_circuit_breaker, CircuitBreakerConfig,
};
use crate::balance::load_balancer::{get_load_balancer, LoadBalancer};
use crate::error::AppError;
/// Process-wide cache of connected gRPC channels, keyed by service instance id.
///
/// Populated on cache miss by `get_or_connect`; `evict_channel` removes entries.
static CHANNEL_POOL: std::sync::LazyLock<DashMap<String, Channel>> =
    std::sync::LazyLock::new(DashMap::new);
/// Returns a pooled channel for `endpoint`, connecting and caching one on a miss.
///
/// The pool key is `endpoint.instance_id`. `connect_timeout` bounds only the
/// `connect()` call made on a cache miss.
///
/// When several tasks miss at the same time, each of them connects, but only the
/// first channel to reach the pool is kept and handed to all of them. The
/// redundant channels are dropped instead of overwriting the pooled one, so
/// callers converge on a single connection per instance.
///
/// # Errors
///
/// Returns [`AppError::ServiceUnavailable`] if connecting times out or fails.
async fn get_or_connect(
    endpoint: &crate::balance::discovery::ServiceEndpoint,
    connect_timeout: Duration,
) -> Result<Channel, AppError> {
    if let Some(pooled) = CHANNEL_POOL.get(&endpoint.instance_id) {
        return Ok(pooled.value().clone());
    }
    let channel = tokio::time::timeout(connect_timeout, endpoint.endpoint.connect())
        .await
        .map_err(|_| {
            AppError::ServiceUnavailable(format!(
                "连接 {} 超时 ({}ms)",
                endpoint.instance_id,
                connect_timeout.as_millis()
            ))
        })?
        .map_err(|e| {
            AppError::ServiceUnavailable(format!(
                "连接 {} 失败: {}",
                endpoint.instance_id, e
            ))
        })?;
    // Another task may have connected while we were awaiting. Keep whichever
    // channel was pooled first rather than replacing it with ours.
    let pooled = CHANNEL_POOL
        .entry(endpoint.instance_id.clone())
        .or_insert(channel)
        .value()
        .clone();
    Ok(pooled)
}
/// Removes the pooled channel for `instance_id`, if present.
///
/// The next `get_or_connect` lookup for that instance therefore misses the
/// pool and connects afresh.
fn evict_channel(instance_id: &str) {
    let _ = CHANNEL_POOL.remove(instance_id);
}
/// Retry policy for gRPC calls.
///
/// It controls how many extra attempts to make and the exponential-backoff
/// bounds applied between them.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Number of retries after the first attempt (total attempts = `max_retries + 1`).
    pub max_retries: u32,
    /// Delay before the first retry; each later retry doubles it.
    pub base_delay: Duration,
    /// Upper bound on any single backoff delay.
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Two retries, with backoff starting at 200 ms and capped at 5 s.
    fn default() -> Self {
        let base_delay = Duration::from_millis(200);
        let max_delay = Duration::from_secs(5);
        RetryConfig {
            max_retries: 2,
            base_delay,
            max_delay,
        }
    }
}
/// gRPC client wrapper for a single logical service.
///
/// It adds load-balanced endpoint selection, a per-attempt timeout, retries
/// with exponential backoff, and circuit breaking. Construct it with
/// `ResilientGrpcClient::for_service` and adjust it with the builder methods.
pub struct ResilientGrpcClient {
    /// Logical service name, used for load-balancer and circuit-breaker lookup.
    service_name: String,
    /// `true`: one circuit breaker per instance; `false`: one shared by the whole service.
    instance_level_cb: bool,
    /// Configuration handed to the circuit-breaker registry.
    cb_config: CircuitBreakerConfig,
    /// Retry count and backoff bounds.
    retry: RetryConfig,
    /// Timeout applied to connecting and to each call attempt.
    timeout: Duration,
}
impl ResilientGrpcClient {
pub fn for_service(service_name: &str) -> Self {
Self {
service_name: service_name.to_string(),
timeout: Duration::from_secs(5),
retry: RetryConfig::default(),
cb_config: CircuitBreakerConfig::default(),
instance_level_cb: true,
}
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_retry(mut self, retry: RetryConfig) -> Self {
self.retry = retry;
self
}
pub fn with_max_retries(mut self, max_retries: u32) -> Self {
self.retry.max_retries = max_retries;
self
}
pub fn no_retry(mut self) -> Self {
self.retry.max_retries = 0;
self
}
pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
self.cb_config = config;
self
}
pub fn no_circuit_breaker(mut self) -> Self {
self.cb_config.failure_threshold = u32::MAX;
self
}
pub fn service_level_cb(mut self) -> Self {
self.instance_level_cb = false;
self
}
pub async fn call<F, Fut, R>(&self, f: F) -> Result<R, AppError>
where
F: Fn(Channel) -> Fut + Send + Sync,
Fut: Future<Output = Result<R, tonic::Status>> + Send,
R: Send,
{
if !self.instance_level_cb {
let cb = get_circuit_breaker(&self.service_name, self.cb_config.clone());
if !cb.allow_request() {
tracing::warn!(
service = %self.service_name,
state = %cb.state(),
"熔断器拒绝请求"
);
return Err(AppError::ServiceUnavailable(format!(
"服务 {} 熔断器已打开,请稍后再试",
self.service_name
)));
}
}
let balancer = get_load_balancer(&self.service_name);
let total_attempts = self.retry.max_retries + 1;
let mut last_error: Option<AppError> = None;
for attempt in 0..total_attempts {
if attempt > 0 {
let delay = self.calculate_backoff(attempt);
tracing::debug!(
service = %self.service_name,
attempt = attempt + 1,
total = total_attempts,
delay_ms = delay.as_millis(),
"重试 gRPC 调用"
);
tokio::time::sleep(delay).await;
}
let endpoint = match balancer.next_endpoint() {
Some(ep) => ep,
None => {
let err = AppError::ServiceUnavailable(format!(
"服务 {} 没有可用的实例",
self.service_name
));
return Err(err);
}
};
if self.instance_level_cb {
let inst_cb = get_instance_circuit_breaker(
&self.service_name,
&endpoint.instance_id,
self.cb_config.clone(),
);
if !inst_cb.allow_request() {
tracing::debug!(
service = %self.service_name,
instance = %endpoint.instance_id,
"实例熔断器打开,跳过此实例"
);
continue;
}
}
let channel = match get_or_connect(&endpoint, self.timeout).await {
Ok(ch) => ch,
Err(e) => {
tracing::warn!(
service = %self.service_name,
instance = %endpoint.instance_id,
error = %e,
"连接失败"
);
evict_channel(&endpoint.instance_id);
self.record_instance_failure(&endpoint.instance_id);
last_error = Some(e);
continue;
}
};
let result = tokio::time::timeout(self.timeout, f(channel)).await;
match result {
Ok(Ok(response)) => {
self.record_instance_success(&endpoint.instance_id);
if attempt > 0 {
tracing::info!(
service = %self.service_name,
instance = %endpoint.instance_id,
attempt = attempt + 1,
"gRPC 调用在第 {} 次尝试后成功",
attempt + 1
);
}
return Ok(response);
}
Ok(Err(status)) => {
tracing::warn!(
service = %self.service_name,
instance = %endpoint.instance_id,
attempt = attempt + 1,
error = %status,
"gRPC 调用失败"
);
if is_transport_error(&status) {
evict_channel(&endpoint.instance_id);
}
self.record_instance_failure(&endpoint.instance_id);
last_error = Some(AppError::Internal(anyhow::anyhow!(
"gRPC 调用 {} ({}) 失败: {}",
self.service_name,
endpoint.instance_id,
status
)));
}
Err(_) => {
tracing::warn!(
service = %self.service_name,
instance = %endpoint.instance_id,
attempt = attempt + 1,
timeout_ms = self.timeout.as_millis(),
"gRPC 调用超时"
);
self.record_instance_failure(&endpoint.instance_id);
last_error = Some(AppError::ServiceUnavailable(format!(
"gRPC 调用 {} ({}) 超时 ({}ms)",
self.service_name,
endpoint.instance_id,
self.timeout.as_millis()
)));
}
}
}
tracing::error!(
service = %self.service_name,
attempts = total_attempts,
"gRPC 调用在 {} 次尝试后全部失败",
total_attempts
);
Err(last_error.unwrap_or_else(|| {
AppError::ServiceUnavailable(format!(
"gRPC 调用 {} 失败(重试 {} 次后)",
self.service_name, self.retry.max_retries
))
}))
}
fn record_instance_success(&self, instance_id: &str) {
if self.instance_level_cb {
let cb = get_instance_circuit_breaker(
&self.service_name,
instance_id,
self.cb_config.clone(),
);
cb.record_success();
} else {
let cb = get_circuit_breaker(&self.service_name, self.cb_config.clone());
cb.record_success();
}
}
fn record_instance_failure(&self, instance_id: &str) {
if self.instance_level_cb {
let cb = get_instance_circuit_breaker(
&self.service_name,
instance_id,
self.cb_config.clone(),
);
cb.record_failure();
} else {
let cb = get_circuit_breaker(&self.service_name, self.cb_config.clone());
cb.record_failure();
}
}
fn calculate_backoff(&self, attempt: u32) -> Duration {
let delay = self
.retry
.base_delay
.saturating_mul(2u32.saturating_pow(attempt.saturating_sub(1)));
std::cmp::min(delay, self.retry.max_delay)
}
}
/// Classifies `status` as a connection-level failure.
///
/// Callers in this file respond by evicting the pooled channel for the instance.
fn is_transport_error(status: &tonic::Status) -> bool {
    const TRANSPORT_CODES: [tonic::Code; 3] = [
        tonic::Code::Unavailable,
        tonic::Code::Internal,
        tonic::Code::Unknown,
    ];
    TRANSPORT_CODES.contains(&status.code())
}
/// Convenience wrapper that runs `f` through a default-configured
/// `ResilientGrpcClient` for `service_name`.
///
/// # Errors
///
/// Propagates whatever `ResilientGrpcClient::call` returns.
pub async fn grpc_call<F, Fut, R>(service_name: &str, f: F) -> Result<R, AppError>
where
    F: Fn(Channel) -> Fut + Send + Sync,
    Fut: Future<Output = Result<R, tonic::Status>> + Send,
    R: Send,
{
    let client = ResilientGrpcClient::for_service(service_name);
    client.call(f).await
}