use crate::balance::discovery::ServiceEndpoint;
use dashmap::DashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tonic::transport::Channel;
/// A strategy for choosing which instance of a single service to call next.
///
/// Implementors must be `Send + Sync`, so a balancer can be shared across threads.
pub trait LoadBalancer
where
    Self: Send + Sync,
{
    /// Name of the service whose instances this balancer chooses between.
    fn service_name(&self) -> &str;

    /// Chooses the next endpoint to use; `None` means no endpoint could be chosen.
    fn next_endpoint(&self) -> Option<ServiceEndpoint>;
}
/// Round-robin balancer over the discovered instances of one service.
///
/// Instances are looked up on every selection. A shared counter decides which instance of
/// the current candidate pool is returned next.
#[derive(Debug)]
pub struct RoundRobinLoadBalancer {
    /// Service whose instances are balanced; used for discovery, health checks and logging.
    service_name: String,
    /// Monotonically incremented (wrapping) selection counter; reduced modulo the pool size.
    current_index: Arc<AtomicUsize>,
}
/// Process-wide registry of round-robin balancers, keyed by service name.
///
/// `get_load_balancer` hands out clones of the stored `Arc`. Each service therefore keeps
/// a single balancer whose rotation counter persists across calls.
static LOAD_BALANCERS: std::sync::LazyLock<DashMap<String, Arc<RoundRobinLoadBalancer>>> =
    std::sync::LazyLock::new(DashMap::new);
/// Returns the shared round-robin balancer for `service_name`, creating it on first use.
///
/// Repeated calls with the same name return clones of the same `Arc`, so every caller
/// shares one rotation counter for that service.
pub fn get_load_balancer(service_name: &str) -> Arc<RoundRobinLoadBalancer> {
    // Fast path: a read-only lookup keyed by `&str` avoids allocating a `String` key and
    // taking the shard's write lock when the balancer already exists. The read guard is
    // consumed inside `map`, so it is released before `entry` below locks the same shard.
    if let Some(existing) = LOAD_BALANCERS
        .get(service_name)
        .map(|entry| Arc::clone(entry.value()))
    {
        return existing;
    }
    // Slow path: `entry` re-checks under the write lock, so concurrent first calls for the
    // same service still end up sharing a single balancer.
    LOAD_BALANCERS
        .entry(service_name.to_string())
        .or_insert_with(|| Arc::new(RoundRobinLoadBalancer::new(service_name.to_string())))
        .clone()
}
impl RoundRobinLoadBalancer {
pub fn new(service_name: String) -> Self {
Self {
service_name,
current_index: Arc::new(AtomicUsize::new(0)),
}
}
pub async fn build_channel(&self) -> Result<Channel, crate::error::AppError> {
let endpoint = self
.next_endpoint()
.ok_or_else(|| {
crate::error::AppError::ServiceUnavailable(format!(
"服务 {} 没有可用的实例",
self.service_name
))
})?;
endpoint.endpoint.connect().await.map_err(|e| {
crate::error::AppError::ServiceUnavailable(format!(
"连接服务 {} 失败: {}",
self.service_name, e
))
})
}
}
impl LoadBalancer for RoundRobinLoadBalancer {
    /// Returns the next endpoint for this service in round-robin order.
    ///
    /// The endpoint list is re-read from `crate::balance::discovery::get_service_endpoints`
    /// on every call.
    ///
    /// Endpoints for which `crate::balance::health::is_available` returns `true` form the
    /// candidate pool. If none qualify, every discovered endpoint is used instead and a
    /// warning is logged.
    ///
    /// Returns `None` (after logging a warning) only when discovery yields no endpoints.
    fn next_endpoint(&self) -> Option<ServiceEndpoint> {
        let endpoints = crate::balance::discovery::get_service_endpoints(&self.service_name);
        if endpoints.is_empty() {
            tracing::warn!("服务 {} 没有可用的实例", self.service_name);
            return None;
        }

        let available: Vec<&ServiceEndpoint> = endpoints
            .iter()
            .filter(|ep| {
                crate::balance::health::is_available(&self.service_name, &ep.instance_id)
            })
            .collect();
        // Capture the real healthy count now. In the fallback case `pool` holds every
        // endpoint, and logging `pool.len()` as the healthy count would overstate it.
        let healthy_count = available.len();

        let pool = if available.is_empty() {
            tracing::warn!(
                "服务 {} 所有 {} 个实例均不健康,fallback 到全部实例",
                self.service_name,
                endpoints.len()
            );
            endpoints.iter().collect::<Vec<_>>()
        } else {
            available
        };

        // Atomic `fetch_add` wraps on overflow instead of panicking. `Relaxed` suffices because
        // only the atomicity of the increment matters, not ordering with other memory.
        let index = self.current_index.fetch_add(1, Ordering::Relaxed);
        // `pool` is non-empty here (either healthy endpoints or the non-empty full list),
        // so the modulo cannot divide by zero.
        let slot = index % pool.len();
        let selected = pool.get(slot).copied().cloned();
        if let Some(ref endpoint) = selected {
            tracing::debug!(
                "负载均衡选择: 服务={}, 实例={}, 索引={}/{} (健康池/总: {}/{})",
                self.service_name,
                endpoint.instance_id,
                slot,
                pool.len(),
                healthy_count,
                endpoints.len()
            );
        }
        selected
    }

    /// Returns the name of the service this balancer selects instances for.
    fn service_name(&self) -> &str {
        &self.service_name
    }
}