// flare_rpc_core/discover/discover.rs

1use async_trait::async_trait;
2use rand::seq::IndexedRandom;
3use rand::Rng;
4use std::collections::HashMap;
5use std::sync::Arc;
6use thiserror::Error;
7use tokio::sync::Mutex;
8
9
10#[derive(Debug, Error)]
11pub enum ServiceError {
12    #[error("service not found: {0}")]
13    NotFound(String),
14    
15    #[error("connection error: {0}")]
16    ConnectionError(String),
17    
18    #[error("decode error: {0}")]
19    DecodeError(String),
20    
21    #[error("resource error: {0}")]
22    ResourceError(String),
23}
24
/// How a load balancer picks one endpoint out of several candidates.
///
/// Derives `PartialEq`/`Eq`/`Hash` so strategies can be compared (e.g. in config
/// validation or tests) and used as map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoadBalanceStrategy {
    /// Uniformly random choice among all candidates.
    Random,
    /// Cycle through the candidates in order, tracked per service name.
    RoundRobin,
    /// Random choice with probability proportional to each endpoint's `weight`.
    WeightedRandom,
}
31
32#[async_trait]
33pub trait RpcDiscovery: Send + Sync + Clone + 'static {
34    /// 发现服务并返回连接通道
35    async fn discover(&self, service_name: &str) -> Result<ServiceEndpoint, ServiceError>;
36    /// 根据服务名获取所有服务节点
37    async fn get_all_endpoints(&self, service_name: &str) -> Result<Vec<ServiceEndpoint>, ServiceError>;
38    /// 启动服务发现监听
39    async fn start_watch(&self);
40
41    /// 停止服务发现监听
42    async fn stop_watch(&self);
43}
44
/// One reachable instance of a service, as reported by service discovery.
///
/// Derives `Debug` so endpoints can be logged, and `PartialEq`/`Eq`/`Hash` so they can be
/// compared, deduplicated, or used as set/map keys (e.g. when diffing endpoint lists).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ServiceEndpoint {
    /// Host part of the endpoint as supplied by the discovery backend (not validated here).
    pub address: String,
    /// Port the instance listens on.
    pub port: u16,
    /// Relative weight for `LoadBalanceStrategy::WeightedRandom`; an endpoint with weight 0
    /// is only chosen when every candidate has weight 0.
    pub weight: u32,
}
51
52#[derive(Clone)]
53pub struct LoadBalancer {
54    strategy: LoadBalanceStrategy,
55    round_robin_index: Arc<Mutex<HashMap<String, usize>>>,
56}
57
58impl LoadBalancer {
59    pub fn new(strategy: LoadBalanceStrategy) -> Self {
60        Self {
61            strategy,
62            round_robin_index: Arc::new(Mutex::new(HashMap::new())),
63        }
64    }
65
66    pub async fn select_endpoint(&self, service_name: &str, endpoints: &[ServiceEndpoint]) -> Option<ServiceEndpoint> {
67        if endpoints.is_empty() {
68            return None;
69        }
70
71        match self.strategy {
72            LoadBalanceStrategy::Random => {
73                let mut rng = rand::rng();
74                endpoints.choose(&mut rng).map(|ep| ep.clone())
75            },
76            LoadBalanceStrategy::WeightedRandom => {
77                let total_weight: u32 = endpoints.iter().map(|ep| ep.weight).sum();
78                if total_weight == 0 {
79                    return endpoints.choose(&mut rand::rng()).map(|ep| ep.clone());
80                }
81
82                let mut rng = rand::rng();
83                let chosen_weight = rng.random_range(0..total_weight);
84                let mut accumulated_weight = 0;
85                
86                // 累加权重,权重大的服务占据更大的随机空间
87                for endpoint in endpoints {
88                    accumulated_weight += endpoint.weight;
89                    if chosen_weight < accumulated_weight {
90                        return Some(endpoint.clone());
91                    }
92                }
93                
94                // 保底返回第一个服务
95                Some(endpoints[0].clone())
96            },
97            LoadBalanceStrategy::RoundRobin => {
98                let mut indices = self.round_robin_index.lock().await;
99                let index = indices.entry(service_name.to_string())
100                    .and_modify(|i| *i = (*i + 1) % endpoints.len())
101                    .or_insert(0);
102                Some(endpoints[*index].clone())
103            }
104        }
105    }
106}
107
108#[derive(Clone)]
109pub struct Change {
110    pub service_name: String,
111    pub all: Vec<ServiceEndpoint>,
112    pub added: Vec<ServiceEndpoint>,
113    pub updated: Vec<ServiceEndpoint>,
114    pub removed: Vec<ServiceEndpoint>,
115}