nova_boot_discovery_consul/
lib.rs1use async_trait::async_trait;
2use nova_boot::discovery::{
3 Discovery, DiscoveryError, InstanceStatus, ServiceInstance, WatchStream,
4};
5use serde::Deserialize;
6use serde_json::Value as JsonValue;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::{RwLock, mpsc};
10use tracing::{debug, warn};
11
12type WatchersMap = HashMap<String, Vec<mpsc::Sender<Vec<ServiceInstance>>>>;
13
14#[derive(Clone)]
15pub struct ConsulDiscovery {
16 client: reqwest::Client,
17 base_url: String,
18 datacenter: Option<String>,
19 token: Option<String>,
20 watchers: Arc<RwLock<WatchersMap>>,
22 watch_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
23}
24
25impl ConsulDiscovery {
26 pub fn new(
27 base_url: impl Into<String>,
28 datacenter: Option<String>,
29 token: Option<String>,
30 ) -> Self {
31 Self {
32 client: reqwest::Client::new(),
33 base_url: base_url.into().trim_end_matches('/').to_string(),
34 datacenter,
35 token,
36 watchers: Arc::new(RwLock::new(HashMap::new())),
37 watch_tasks: Arc::new(RwLock::new(HashMap::new())),
38 }
39 }
40
41 fn url(&self, path: &str) -> String {
42 format!(
43 "{}/{}",
44 self.base_url.trim_end_matches('/'),
45 path.trim_start_matches('/')
46 )
47 }
48
49 fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
50 let mut builder = self.client.request(method, self.url(path));
51 if let Some(dc) = &self.datacenter {
52 builder = builder.query(&[("dc", dc)]);
53 }
54 if let Some(token) = &self.token {
55 builder = builder.header("X-Consul-Token", token);
56 }
57 builder
58 }
59
60 async fn send_request(
61 &self,
62 builder: reqwest::RequestBuilder,
63 ) -> Result<reqwest::Response, DiscoveryError> {
64 let response = builder
65 .send()
66 .await
67 .map_err(|e| DiscoveryError::Backend(e.to_string()))?;
68
69 if response.status() == reqwest::StatusCode::NOT_FOUND {
70 return Err(DiscoveryError::NotFound(
71 "consul resource not found".to_string(),
72 ));
73 }
74
75 if !response.status().is_success() {
76 return Err(DiscoveryError::Backend(format!(
77 "consul request failed with status {}",
78 response.status()
79 )));
80 }
81
82 Ok(response)
83 }
84
85 fn split_address(address: &str) -> Result<(String, u16), DiscoveryError> {
86 let (host, port) = address
87 .rsplit_once(':')
88 .ok_or_else(|| DiscoveryError::Backend(format!("invalid address: {address}")))?;
89 let port = port
90 .parse::<u16>()
91 .map_err(|e| DiscoveryError::Backend(format!("invalid address port: {e}")))?;
92 Ok((host.to_string(), port))
93 }
94
95 fn metadata_from_value(value: Option<&JsonValue>) -> HashMap<String, String> {
96 match value.and_then(JsonValue::as_object) {
97 Some(map) => map
98 .iter()
99 .map(|(key, value)| {
100 let rendered = value
101 .as_str()
102 .map(ToString::to_string)
103 .unwrap_or_else(|| value.to_string());
104 (key.clone(), rendered)
105 })
106 .collect(),
107 None => HashMap::new(),
108 }
109 }
110
111 fn service_instance_from_consul(
112 service: &ConsulServiceEntry,
113 status: InstanceStatus,
114 ) -> Result<ServiceInstance, DiscoveryError> {
115 let address = if service.service.address.is_empty() {
116 service.node.address.clone()
117 } else {
118 service.service.address.clone()
119 };
120 let address = format!("{}:{}", address, service.service.port);
121 let metadata = Self::metadata_from_value(service.service.meta.as_ref());
122
123 Ok(ServiceInstance {
124 id: service.service.id.clone(),
125 name: service.service.service.clone(),
126 address,
127 metadata,
128 status,
129 last_heartbeat: None,
130 })
131 }
132
133 async fn notify_watchers(&self, service_name: &str, instances: Vec<ServiceInstance>) {
134 let watchers = {
135 let watchers = self.watchers.read().await;
136 watchers.get(service_name).cloned().unwrap_or_default()
137 };
138
139 if watchers.is_empty() {
140 return;
141 }
142
143 let mut alive = Vec::with_capacity(watchers.len());
144 for watcher in watchers {
145 if watcher.send(instances.clone()).await.is_ok() {
146 alive.push(watcher);
147 }
148 }
149
150 let mut watchers_map = self.watchers.write().await;
151 if let Some(entry) = watchers_map.get_mut(service_name) {
152 *entry = alive;
153 }
154 }
155
156 async fn current_instances(
157 &self,
158 service_name: &str,
159 ) -> Result<Vec<ServiceInstance>, DiscoveryError> {
160 self.discover(service_name).await
161 }
162
163 async fn watch_loop(self, service_name: String) {
164 let mut last_index: u64 = 0;
165
166 loop {
167 let has_watchers = {
168 let watchers = self.watchers.read().await;
169 watchers
170 .get(&service_name)
171 .map(|items| !items.is_empty())
172 .unwrap_or(false)
173 };
174
175 if !has_watchers {
176 break;
177 }
178
179 match self.long_poll(&service_name, last_index).await {
180 Ok((instances, index)) => {
181 if index > last_index {
182 last_index = index;
183 self.notify_watchers(&service_name, instances).await;
184 }
185 }
186 Err(err) => {
187 warn!(service = %service_name, error = %err, "Consul watch loop error");
188 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
189 }
190 }
191 }
192
193 self.watch_tasks.write().await.remove(&service_name);
194 }
195
196 async fn long_poll(
197 &self,
198 service_name: &str,
199 last_index: u64,
200 ) -> Result<(Vec<ServiceInstance>, u64), DiscoveryError> {
201 let mut request = self
202 .request(
203 reqwest::Method::GET,
204 &format!("/v1/health/service/{service_name}"),
205 )
206 .query(&[("passing", "true"), ("wait", "30s")]);
207 if last_index > 0 {
208 request = request.query(&[("index", &last_index.to_string())]);
209 }
210
211 let response = self.send_request(request).await?;
212 let index = response
213 .headers()
214 .get("X-Consul-Index")
215 .and_then(|value| value.to_str().ok())
216 .and_then(|value| value.parse::<u64>().ok())
217 .unwrap_or(last_index);
218 let entries = response
219 .json::<Vec<ConsulServiceEntry>>()
220 .await
221 .map_err(|e| DiscoveryError::Backend(e.to_string()))?;
222 let instances = entries
223 .iter()
224 .map(|entry| Self::service_instance_from_consul(entry, InstanceStatus::Healthy))
225 .collect::<Result<Vec<_>, _>>()?;
226 Ok((instances, index))
227 }
228}
229
230#[derive(Debug, Deserialize)]
231struct ConsulNodeEntry {
232 #[serde(rename = "Address")]
233 address: String,
234}
235
236#[derive(Debug, Deserialize)]
237struct ConsulServiceInfo {
238 #[serde(rename = "ID")]
239 id: String,
240 #[serde(rename = "Service")]
241 service: String,
242 #[serde(rename = "Address", default)]
243 address: String,
244 #[serde(rename = "Port")]
245 port: u16,
246 #[serde(rename = "Meta", default)]
247 meta: Option<JsonValue>,
248}
249
250#[derive(Debug, Deserialize)]
251struct ConsulServiceEntry {
252 #[serde(rename = "Node")]
253 node: ConsulNodeEntry,
254 #[serde(rename = "Service")]
255 service: ConsulServiceInfo,
256}
257
258#[async_trait]
259impl Discovery for ConsulDiscovery {
260 async fn register(&self, instance: ServiceInstance) -> Result<(), DiscoveryError> {
261 let (address, port) = Self::split_address(&instance.address)?;
262 let check_id = format!("service:{}", instance.id);
263 let payload = serde_json::json!({
264 "ID": instance.id,
265 "Name": instance.name,
266 "Address": address,
267 "Port": port,
268 "Meta": instance.metadata,
269 "Check": {
270 "CheckID": check_id,
271 "TTL": "30s",
272 "DeregisterCriticalServiceAfter": "90s"
273 }
274 });
275
276 self.send_request(
277 self.request(reqwest::Method::PUT, "/v1/agent/service/register")
278 .json(&payload),
279 )
280 .await?;
281
282 self.notify_watchers(
283 &instance.name,
284 self.current_instances(&instance.name).await?,
285 )
286 .await;
287 Ok(())
288 }
289
290 async fn discover(&self, service_name: &str) -> Result<Vec<ServiceInstance>, DiscoveryError> {
291 let response = self
292 .send_request(
293 self.request(
294 reqwest::Method::GET,
295 &format!("/v1/health/service/{service_name}"),
296 )
297 .query(&[("passing", "true")]),
298 )
299 .await?;
300
301 let entries = response
302 .json::<Vec<ConsulServiceEntry>>()
303 .await
304 .map_err(|e| DiscoveryError::Backend(e.to_string()))?;
305
306 entries
307 .iter()
308 .map(|entry| Self::service_instance_from_consul(entry, InstanceStatus::Healthy))
309 .collect()
310 }
311
312 async fn heartbeat(&self, service_name: &str, instance_id: &str) -> Result<(), DiscoveryError> {
313 let check_id = format!("service:{instance_id}");
314 debug!(service = %service_name, instance = %instance_id, check_id = %check_id, "sending consul heartbeat");
315 self.send_request(self.request(
316 reqwest::Method::PUT,
317 &format!("/v1/agent/check/pass/{check_id}"),
318 ))
319 .await?;
320 Ok(())
321 }
322
323 async fn deregister(
324 &self,
325 service_name: &str,
326 instance_id: &str,
327 ) -> Result<(), DiscoveryError> {
328 self.send_request(self.request(
329 reqwest::Method::PUT,
330 &format!("/v1/agent/service/deregister/{instance_id}"),
331 ))
332 .await?;
333
334 self.notify_watchers(service_name, self.current_instances(service_name).await?)
335 .await;
336 Ok(())
337 }
338
339 async fn watch(&self, service_name: &str) -> Result<WatchStream, DiscoveryError> {
340 let (tx, rx) = mpsc::channel(16);
341 {
342 let mut watchers = self.watchers.write().await;
343 watchers
344 .entry(service_name.to_string())
345 .or_default()
346 .push(tx.clone());
347 }
348
349 let initial = self.discover(service_name).await?;
350 let _ = tx.send(initial).await;
351
352 let mut tasks = self.watch_tasks.write().await;
353 if !tasks.contains_key(service_name) {
354 let service = service_name.to_string();
355 let discovery = self.clone();
356 let handle = tokio::spawn(async move {
357 discovery.watch_loop(service).await;
358 });
359 tasks.insert(service_name.to_string(), handle);
360 }
361
362 Ok(WatchStream { rx })
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use nova_boot::discovery::DiscoveryError;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Client aimed at `127.0.0.1:1`. The tests using it must fail on input validation,
    /// before any request is sent.
    fn unreachable_discovery() -> ConsulDiscovery {
        ConsulDiscovery::new("http://127.0.0.1:1", None, None)
    }

    /// Spawns a one-shot TCP server. It reads up to 1 KiB of the request, writes
    /// `response` verbatim and shuts the socket down. Returns the bound address.
    async fn serve_once(response: &'static [u8]) -> std::net::SocketAddr {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind test server");
        let addr = listener.local_addr().expect("local addr");

        tokio::spawn(async move {
            let Ok((mut socket, _)) = listener.accept().await else {
                return;
            };
            let mut request = [0u8; 1024];
            let _ = socket.read(&mut request).await;
            let _ = socket.write_all(response).await;
            let _ = socket.shutdown().await;
        });

        addr
    }

    #[tokio::test]
    async fn invalid_url_returns_backend_error() {
        let outcome = ConsulDiscovery::new("http://[", None, None)
            .discover("users")
            .await;
        let err = outcome.expect_err("should fail");
        assert!(matches!(err, DiscoveryError::Backend(_)));
    }

    #[tokio::test]
    async fn discover_404_maps_to_not_found() {
        let addr =
            serve_once(b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
                .await;

        let client = ConsulDiscovery::new(format!("http://{addr}"), None, None);
        let err = client.discover("users").await.expect_err("should fail");
        assert!(matches!(err, DiscoveryError::NotFound(_)));
    }

    #[tokio::test]
    async fn register_invalid_address_returns_backend_error() {
        let instance = ServiceInstance {
            id: "users-1".to_string(),
            name: "users".to_string(),
            address: "invalid-address".to_string(),
            metadata: HashMap::new(),
            status: InstanceStatus::Healthy,
            last_heartbeat: None,
        };

        let err = unreachable_discovery()
            .register(instance)
            .await
            .expect_err("should fail");
        assert!(matches!(err, DiscoveryError::Backend(_)));
    }
}