// flare_rpc_core/discover/consul/discover.rs

use super::{ConsulConfig, ConsulService};
use crate::discover::discover::{Change, ServiceError};
use crate::discover::{LoadBalanceStrategy, LoadBalancer, RpcDiscovery, ServiceEndpoint};
use async_trait::async_trait;
use reqwest;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::sync::broadcast::{self, Sender};

13#[derive(Clone)]
14pub struct ConsulDiscover {
15    client: reqwest::Client,
16    config: ConsulConfig,
17    services: Arc<RwLock<HashMap<String, Vec<ServiceEndpoint>>>>,
18    watch_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
19    load_balancer: Arc<LoadBalancer>,
20    broadcaster: Sender<Change>,
21}
22
23impl ConsulDiscover {
24    pub fn new(config: ConsulConfig, strategy: LoadBalanceStrategy) -> Self {
25        let client = reqwest::Client::builder()
26            .timeout(config.timeout)
27            .build()
28            .unwrap();
29            
30        let (broadcaster, _rx) = broadcast::channel(100);
31            
32        Self {
33            client,
34            config,
35            services: Arc::new(RwLock::new(HashMap::new())),
36            watch_task: Arc::new(RwLock::new(None)),
37            load_balancer: Arc::new(LoadBalancer::new(strategy)),
38            broadcaster,
39        }
40    }
41
42    async fn clear_services(services: &Arc<RwLock<HashMap<String, Vec<ServiceEndpoint>>>>, broadcaster: &Sender<Change>) {
43        let mut services_lock = services.write().await;
44        for (service_name, old_endpoints) in services_lock.drain() {
45            let change = Change {
46                service_name,
47                all: vec![],
48                added: vec![],
49                updated: vec![],
50                removed: old_endpoints,
51            };
52            if let Err(e) = broadcaster.send(change) {
53                log::error!("Failed to broadcast service removal: {}", e);
54            }
55        }
56    }
57    // 同步服务列表
58    async fn sync_services(&self) {
59        // 获取通过健康检查的服务
60        let health_url = format!("{}/v1/health/state/passing", self.config.url());
61        let mut health_request = self.client.get(&health_url);
62        if let Some(token) = &self.config.token {
63            health_request = health_request.header("X-Consul-Token", token);
64        }
65
66        let healthy_services = match health_request.send().await {
67            Ok(health_response) => {
68                match health_response.json::<Vec<serde_json::Value>>().await {
69                    Ok(checks) => {
70                        // 获取所有通过健康检查的服务 ID
71                        checks.iter()
72                            .filter_map(|check| {
73                                check.get("ServiceID")
74                                    .and_then(|v| v.as_str())
75                                    .map(|s| s.to_string())
76                            })
77                            .collect::<std::collections::HashSet<String>>()
78                    }
79                    Err(e) => {
80                        log::error!("Failed to parse health checks response: {}", e);
81                        std::collections::HashSet::new()
82                    }
83                }
84            }
85            Err(e) => {
86                log::error!("Failed to fetch health checks: {}", e);
87                std::collections::HashSet::new()
88            }
89        };
90
91        // 获取服务详情
92        let url = format!("{}/v1/agent/services", self.config.url());
93        let mut request = self.client.get(&url);
94        if let Some(token) = &self.config.token {
95            request = request.header("X-Consul-Token", token);
96        }
97        
98        match request.send().await {
99            Ok(response) => {
100                if let Ok(services_map) = response.json::<HashMap<String, ConsulService>>().await {
101                    let mut new_services = HashMap::new();
102                    
103                    // 构建新的服务列表,只包含健康的服务
104                    for (id, service) in services_map {
105                        if !healthy_services.contains(&id) {
106                            continue;
107                        }
108                        
109                        let endpoints = new_services
110                            .entry(service.service.clone())
111                            .or_insert_with(Vec::new);
112                            
113                        let weight = service.meta.get("weight")
114                            .and_then(|w| w.parse::<u32>().ok())
115                            .unwrap_or(1);
116                            
117                        endpoints.push(ServiceEndpoint {
118                            address: service.address.clone(),
119                            port: service.port,
120                            weight,
121                        });
122                    }
123                    
124                    // 获取旧服务列表
125                    let mut services_lock = self.services.write().await;
126                    let mut old_services: Vec<String> = services_lock.keys().cloned().collect();
127                    
128                    // 处理服务变更
129                    for (service_name, new_endpoints) in new_services {
130                        old_services.retain(|s| s != &service_name);
131                        
132                        let old_endpoints = services_lock.get(&service_name)
133                            .cloned()
134                            .unwrap_or_default();
135                            
136                        // 计算新增和移除的端点
137                        let added: Vec<_> = new_endpoints.iter()
138                            .filter(|ep| !old_endpoints.iter().any(|old| 
139                                old.address == ep.address && old.port == ep.port
140                            ))
141                            .cloned()
142                            .collect();
143                            
144                        let removed: Vec<_> = old_endpoints.iter()
145                            .filter(|ep| !new_endpoints.iter().any(|new| 
146                                new.address == ep.address && new.port == ep.port
147                            ))
148                            .cloned()
149                            .collect();
150                            
151                        // 只有在有变更时才发送通知
152                        if !added.is_empty() || !removed.is_empty() {
153                            // 更新服务列表
154                            services_lock.insert(service_name.clone(), new_endpoints.clone());
155                            
156                            // 发送变更通知
157                            let change = Change {
158                                service_name,
159                                all: new_endpoints,
160                                added,
161                                updated: vec![],
162                                removed,
163                            };
164                            
165                            // 检查是否有接收者
166                            if self.broadcaster.receiver_count() > 0 {
167                                if let Err(e) = self.broadcaster.send(change) {
168                                    log::debug!("Failed to broadcast service changes: {}", e);
169                                }
170                            }
171                        }
172                    }
173                    
174                    // 处理已删除的服务
175                    for service_name in old_services {
176                        if let Some(old_endpoints) = services_lock.remove(&service_name) {
177                            let change = Change {
178                                service_name,
179                                all: vec![],
180                                added: vec![],
181                                updated: vec![],
182                                removed: old_endpoints,
183                            };
184                            
185                            // 检查是否有接收者
186                            if self.broadcaster.receiver_count() > 0 {
187                                if let Err(e) = self.broadcaster.send(change) {
188                                    log::debug!("Failed to broadcast service removal: {}", e);
189                                }
190                            }
191                        }
192                    }
193                } else {
194                    log::error!("Failed to parse services response");
195                    Self::clear_services(&self.services, &self.broadcaster).await;
196                }
197            }
198            Err(e) => {
199                log::error!("Failed to sync services: {}", e);
200                Self::clear_services(&self.services, &self.broadcaster).await;
201            }
202        }
203    }
204}
205
206#[async_trait]
207impl RpcDiscovery for ConsulDiscover {
208    async fn discover(&self, service_name: &str) -> Result<ServiceEndpoint, ServiceError> {
209        let services = self.services.read().await;
210        let endpoints = services.get(service_name)
211            .ok_or_else(|| ServiceError::NotFound(service_name.to_string()))?;
212
213        if endpoints.is_empty() {
214            return Err(ServiceError::NotFound(format!("No endpoints found for service {}", service_name)));
215        }
216        self.load_balancer.select_endpoint(service_name, endpoints).await
217            .ok_or_else(|| ServiceError::ResourceError("Failed to choose endpoint".to_string()))
218    }
219
220    async fn get_all_endpoints(&self, service_name: &str) -> Result<Vec<ServiceEndpoint>, ServiceError> {
221        let services = self.services.read().await;
222        let endpoints = services.get(service_name)
223            .ok_or_else(|| ServiceError::NotFound(service_name.to_string()))?;
224
225        Ok(endpoints.clone())
226    }
227
228    async fn start_watch(&self) {
229        // 首次同步服务列表
230        self.sync_services().await;
231
232        let this = self.clone();
233        let task = tokio::spawn(async move {
234            let mut interval = tokio::time::interval(Duration::from_secs(30));
235            
236            loop {
237                interval.tick().await;
238                this.sync_services().await;
239            }
240        });
241
242        *self.watch_task.write().await = Some(task);
243    }
244
245    async fn stop_watch(&self) {
246        if let Some(task) = self.watch_task.write().await.take() {
247            task.abort();
248        }
249    }
250}