flare_rpc_core/discover/
discover.rs1use async_trait::async_trait;
2use rand::seq::IndexedRandom;
3use rand::Rng;
4use std::collections::HashMap;
5use std::sync::Arc;
6use thiserror::Error;
7use tokio::sync::Mutex;
8
9
10#[derive(Debug, Error)]
11pub enum ServiceError {
12 #[error("service not found: {0}")]
13 NotFound(String),
14
15 #[error("connection error: {0}")]
16 ConnectionError(String),
17
18 #[error("decode error: {0}")]
19 DecodeError(String),
20
21 #[error("resource error: {0}")]
22 ResourceError(String),
23}
24
/// Policy a [`LoadBalancer`] uses to pick one endpoint out of several.
#[derive(Debug, Clone, Copy)]
pub enum LoadBalanceStrategy {
    /// Pick uniformly at random.
    Random,
    /// Cycle through the endpoints in order, keeping a separate cursor per service name.
    RoundRobin,
    /// Pick at random with probability proportional to each endpoint's `weight`;
    /// falls back to a uniform pick when every weight is zero.
    WeightedRandom,
}
31
32#[async_trait]
33pub trait RpcDiscovery: Send + Sync + Clone + 'static {
34 async fn discover(&self, service_name: &str) -> Result<ServiceEndpoint, ServiceError>;
36 async fn get_all_endpoints(&self, service_name: &str) -> Result<Vec<ServiceEndpoint>, ServiceError>;
38 async fn start_watch(&self);
40
41 async fn stop_watch(&self);
43}
44
/// A single addressable instance of a service.
///
/// Derives `Debug` (like the other public types in this module) so endpoints
/// can be logged and inspected in diagnostics.
#[derive(Clone, Debug)]
pub struct ServiceEndpoint {
    /// Address of the instance (presumably a host name or IP; not validated here).
    pub address: String,
    /// Port of the instance.
    pub port: u16,
    /// Relative weight used by weighted-random load balancing. An endpoint with
    /// weight 0 is never picked unless every candidate has weight 0.
    pub weight: u32,
}
51
52#[derive(Clone)]
53pub struct LoadBalancer {
54 strategy: LoadBalanceStrategy,
55 round_robin_index: Arc<Mutex<HashMap<String, usize>>>,
56}
57
58impl LoadBalancer {
59 pub fn new(strategy: LoadBalanceStrategy) -> Self {
60 Self {
61 strategy,
62 round_robin_index: Arc::new(Mutex::new(HashMap::new())),
63 }
64 }
65
66 pub async fn select_endpoint(&self, service_name: &str, endpoints: &[ServiceEndpoint]) -> Option<ServiceEndpoint> {
67 if endpoints.is_empty() {
68 return None;
69 }
70
71 match self.strategy {
72 LoadBalanceStrategy::Random => {
73 let mut rng = rand::rng();
74 endpoints.choose(&mut rng).map(|ep| ep.clone())
75 },
76 LoadBalanceStrategy::WeightedRandom => {
77 let total_weight: u32 = endpoints.iter().map(|ep| ep.weight).sum();
78 if total_weight == 0 {
79 return endpoints.choose(&mut rand::rng()).map(|ep| ep.clone());
80 }
81
82 let mut rng = rand::rng();
83 let chosen_weight = rng.random_range(0..total_weight);
84 let mut accumulated_weight = 0;
85
86 for endpoint in endpoints {
88 accumulated_weight += endpoint.weight;
89 if chosen_weight < accumulated_weight {
90 return Some(endpoint.clone());
91 }
92 }
93
94 Some(endpoints[0].clone())
96 },
97 LoadBalanceStrategy::RoundRobin => {
98 let mut indices = self.round_robin_index.lock().await;
99 let index = indices.entry(service_name.to_string())
100 .and_modify(|i| *i = (*i + 1) % endpoints.len())
101 .or_insert(0);
102 Some(endpoints[*index].clone())
103 }
104 }
105 }
106}
107
108#[derive(Clone)]
109pub struct Change {
110 pub service_name: String,
111 pub all: Vec<ServiceEndpoint>,
112 pub added: Vec<ServiceEndpoint>,
113 pub updated: Vec<ServiceEndpoint>,
114 pub removed: Vec<ServiceEndpoint>,
115}