// flare_rpc_core/discover/consul/discover.rs
use super::{ConsulConfig, ConsulService};
use crate::discover::discover::{Change, ServiceError};
use crate::discover::{LoadBalanceStrategy, LoadBalancer, RpcDiscovery, ServiceEndpoint};
use async_trait::async_trait;
use reqwest;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast::{self, Sender};
use tokio::sync::RwLock;

13#[derive(Clone)]
14pub struct ConsulDiscover {
15 client: reqwest::Client,
16 config: ConsulConfig,
17 services: Arc<RwLock<HashMap<String, Vec<ServiceEndpoint>>>>,
18 watch_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
19 load_balancer: Arc<LoadBalancer>,
20 broadcaster: Sender<Change>,
21}
22
23impl ConsulDiscover {
24 pub fn new(config: ConsulConfig, strategy: LoadBalanceStrategy) -> Self {
25 let client = reqwest::Client::builder()
26 .timeout(config.timeout)
27 .build()
28 .unwrap();
29
30 let (broadcaster, _rx) = broadcast::channel(100);
31
32 Self {
33 client,
34 config,
35 services: Arc::new(RwLock::new(HashMap::new())),
36 watch_task: Arc::new(RwLock::new(None)),
37 load_balancer: Arc::new(LoadBalancer::new(strategy)),
38 broadcaster,
39 }
40 }
41
42 async fn clear_services(services: &Arc<RwLock<HashMap<String, Vec<ServiceEndpoint>>>>, broadcaster: &Sender<Change>) {
43 let mut services_lock = services.write().await;
44 for (service_name, old_endpoints) in services_lock.drain() {
45 let change = Change {
46 service_name,
47 all: vec![],
48 added: vec![],
49 updated: vec![],
50 removed: old_endpoints,
51 };
52 if let Err(e) = broadcaster.send(change) {
53 log::error!("Failed to broadcast service removal: {}", e);
54 }
55 }
56 }
57 async fn sync_services(&self) {
59 let health_url = format!("{}/v1/health/state/passing", self.config.url());
61 let mut health_request = self.client.get(&health_url);
62 if let Some(token) = &self.config.token {
63 health_request = health_request.header("X-Consul-Token", token);
64 }
65
66 let healthy_services = match health_request.send().await {
67 Ok(health_response) => {
68 match health_response.json::<Vec<serde_json::Value>>().await {
69 Ok(checks) => {
70 checks.iter()
72 .filter_map(|check| {
73 check.get("ServiceID")
74 .and_then(|v| v.as_str())
75 .map(|s| s.to_string())
76 })
77 .collect::<std::collections::HashSet<String>>()
78 }
79 Err(e) => {
80 log::error!("Failed to parse health checks response: {}", e);
81 std::collections::HashSet::new()
82 }
83 }
84 }
85 Err(e) => {
86 log::error!("Failed to fetch health checks: {}", e);
87 std::collections::HashSet::new()
88 }
89 };
90
91 let url = format!("{}/v1/agent/services", self.config.url());
93 let mut request = self.client.get(&url);
94 if let Some(token) = &self.config.token {
95 request = request.header("X-Consul-Token", token);
96 }
97
98 match request.send().await {
99 Ok(response) => {
100 if let Ok(services_map) = response.json::<HashMap<String, ConsulService>>().await {
101 let mut new_services = HashMap::new();
102
103 for (id, service) in services_map {
105 if !healthy_services.contains(&id) {
106 continue;
107 }
108
109 let endpoints = new_services
110 .entry(service.service.clone())
111 .or_insert_with(Vec::new);
112
113 let weight = service.meta.get("weight")
114 .and_then(|w| w.parse::<u32>().ok())
115 .unwrap_or(1);
116
117 endpoints.push(ServiceEndpoint {
118 address: service.address.clone(),
119 port: service.port,
120 weight,
121 });
122 }
123
124 let mut services_lock = self.services.write().await;
126 let mut old_services: Vec<String> = services_lock.keys().cloned().collect();
127
128 for (service_name, new_endpoints) in new_services {
130 old_services.retain(|s| s != &service_name);
131
132 let old_endpoints = services_lock.get(&service_name)
133 .cloned()
134 .unwrap_or_default();
135
136 let added: Vec<_> = new_endpoints.iter()
138 .filter(|ep| !old_endpoints.iter().any(|old|
139 old.address == ep.address && old.port == ep.port
140 ))
141 .cloned()
142 .collect();
143
144 let removed: Vec<_> = old_endpoints.iter()
145 .filter(|ep| !new_endpoints.iter().any(|new|
146 new.address == ep.address && new.port == ep.port
147 ))
148 .cloned()
149 .collect();
150
151 if !added.is_empty() || !removed.is_empty() {
153 services_lock.insert(service_name.clone(), new_endpoints.clone());
155
156 let change = Change {
158 service_name,
159 all: new_endpoints,
160 added,
161 updated: vec![],
162 removed,
163 };
164
165 if self.broadcaster.receiver_count() > 0 {
167 if let Err(e) = self.broadcaster.send(change) {
168 log::debug!("Failed to broadcast service changes: {}", e);
169 }
170 }
171 }
172 }
173
174 for service_name in old_services {
176 if let Some(old_endpoints) = services_lock.remove(&service_name) {
177 let change = Change {
178 service_name,
179 all: vec![],
180 added: vec![],
181 updated: vec![],
182 removed: old_endpoints,
183 };
184
185 if self.broadcaster.receiver_count() > 0 {
187 if let Err(e) = self.broadcaster.send(change) {
188 log::debug!("Failed to broadcast service removal: {}", e);
189 }
190 }
191 }
192 }
193 } else {
194 log::error!("Failed to parse services response");
195 Self::clear_services(&self.services, &self.broadcaster).await;
196 }
197 }
198 Err(e) => {
199 log::error!("Failed to sync services: {}", e);
200 Self::clear_services(&self.services, &self.broadcaster).await;
201 }
202 }
203 }
204}
205
206#[async_trait]
207impl RpcDiscovery for ConsulDiscover {
208 async fn discover(&self, service_name: &str) -> Result<ServiceEndpoint, ServiceError> {
209 let services = self.services.read().await;
210 let endpoints = services.get(service_name)
211 .ok_or_else(|| ServiceError::NotFound(service_name.to_string()))?;
212
213 if endpoints.is_empty() {
214 return Err(ServiceError::NotFound(format!("No endpoints found for service {}", service_name)));
215 }
216 self.load_balancer.select_endpoint(service_name, endpoints).await
217 .ok_or_else(|| ServiceError::ResourceError("Failed to choose endpoint".to_string()))
218 }
219
220 async fn get_all_endpoints(&self, service_name: &str) -> Result<Vec<ServiceEndpoint>, ServiceError> {
221 let services = self.services.read().await;
222 let endpoints = services.get(service_name)
223 .ok_or_else(|| ServiceError::NotFound(service_name.to_string()))?;
224
225 Ok(endpoints.clone())
226 }
227
228 async fn start_watch(&self) {
229 self.sync_services().await;
231
232 let this = self.clone();
233 let task = tokio::spawn(async move {
234 let mut interval = tokio::time::interval(Duration::from_secs(30));
235
236 loop {
237 interval.tick().await;
238 this.sync_services().await;
239 }
240 });
241
242 *self.watch_task.write().await = Some(task);
243 }
244
245 async fn stop_watch(&self) {
246 if let Some(task) = self.watch_task.write().await.take() {
247 task.abort();
248 }
249 }
250}