Skip to main content

nova_boot_discovery_consul/
lib.rs

1use async_trait::async_trait;
2use nova_boot::discovery::{
3    Discovery, DiscoveryError, InstanceStatus, ServiceInstance, WatchStream,
4};
5use serde::Deserialize;
6use serde_json::Value as JsonValue;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::{RwLock, mpsc};
10use tracing::{debug, warn};
11
12type WatchersMap = HashMap<String, Vec<mpsc::Sender<Vec<ServiceInstance>>>>;
13
14#[derive(Clone)]
15pub struct ConsulDiscovery {
16    client: reqwest::Client,
17    base_url: String,
18    datacenter: Option<String>,
19    token: Option<String>,
20    // Internal: tracks watch channels per service.
21    watchers: Arc<RwLock<WatchersMap>>,
22    watch_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
23}
24
25impl ConsulDiscovery {
26    pub fn new(
27        base_url: impl Into<String>,
28        datacenter: Option<String>,
29        token: Option<String>,
30    ) -> Self {
31        Self {
32            client: reqwest::Client::new(),
33            base_url: base_url.into().trim_end_matches('/').to_string(),
34            datacenter,
35            token,
36            watchers: Arc::new(RwLock::new(HashMap::new())),
37            watch_tasks: Arc::new(RwLock::new(HashMap::new())),
38        }
39    }
40
41    fn url(&self, path: &str) -> String {
42        format!(
43            "{}/{}",
44            self.base_url.trim_end_matches('/'),
45            path.trim_start_matches('/')
46        )
47    }
48
49    fn request(&self, method: reqwest::Method, path: &str) -> reqwest::RequestBuilder {
50        let mut builder = self.client.request(method, self.url(path));
51        if let Some(dc) = &self.datacenter {
52            builder = builder.query(&[("dc", dc)]);
53        }
54        if let Some(token) = &self.token {
55            builder = builder.header("X-Consul-Token", token);
56        }
57        builder
58    }
59
60    async fn send_request(
61        &self,
62        builder: reqwest::RequestBuilder,
63    ) -> Result<reqwest::Response, DiscoveryError> {
64        let response = builder
65            .send()
66            .await
67            .map_err(|e| DiscoveryError::Backend(e.to_string()))?;
68
69        if response.status() == reqwest::StatusCode::NOT_FOUND {
70            return Err(DiscoveryError::NotFound(
71                "consul resource not found".to_string(),
72            ));
73        }
74
75        if !response.status().is_success() {
76            return Err(DiscoveryError::Backend(format!(
77                "consul request failed with status {}",
78                response.status()
79            )));
80        }
81
82        Ok(response)
83    }
84
85    fn split_address(address: &str) -> Result<(String, u16), DiscoveryError> {
86        let (host, port) = address
87            .rsplit_once(':')
88            .ok_or_else(|| DiscoveryError::Backend(format!("invalid address: {address}")))?;
89        let port = port
90            .parse::<u16>()
91            .map_err(|e| DiscoveryError::Backend(format!("invalid address port: {e}")))?;
92        Ok((host.to_string(), port))
93    }
94
95    fn metadata_from_value(value: Option<&JsonValue>) -> HashMap<String, String> {
96        match value.and_then(JsonValue::as_object) {
97            Some(map) => map
98                .iter()
99                .map(|(key, value)| {
100                    let rendered = value
101                        .as_str()
102                        .map(ToString::to_string)
103                        .unwrap_or_else(|| value.to_string());
104                    (key.clone(), rendered)
105                })
106                .collect(),
107            None => HashMap::new(),
108        }
109    }
110
111    fn service_instance_from_consul(
112        service: &ConsulServiceEntry,
113        status: InstanceStatus,
114    ) -> Result<ServiceInstance, DiscoveryError> {
115        let address = if service.service.address.is_empty() {
116            service.node.address.clone()
117        } else {
118            service.service.address.clone()
119        };
120        let address = format!("{}:{}", address, service.service.port);
121        let metadata = Self::metadata_from_value(service.service.meta.as_ref());
122
123        Ok(ServiceInstance {
124            id: service.service.id.clone(),
125            name: service.service.service.clone(),
126            address,
127            metadata,
128            status,
129            last_heartbeat: None,
130        })
131    }
132
133    async fn notify_watchers(&self, service_name: &str, instances: Vec<ServiceInstance>) {
134        let watchers = {
135            let watchers = self.watchers.read().await;
136            watchers.get(service_name).cloned().unwrap_or_default()
137        };
138
139        if watchers.is_empty() {
140            return;
141        }
142
143        let mut alive = Vec::with_capacity(watchers.len());
144        for watcher in watchers {
145            if watcher.send(instances.clone()).await.is_ok() {
146                alive.push(watcher);
147            }
148        }
149
150        let mut watchers_map = self.watchers.write().await;
151        if let Some(entry) = watchers_map.get_mut(service_name) {
152            *entry = alive;
153        }
154    }
155
156    async fn current_instances(
157        &self,
158        service_name: &str,
159    ) -> Result<Vec<ServiceInstance>, DiscoveryError> {
160        self.discover(service_name).await
161    }
162
163    async fn watch_loop(self, service_name: String) {
164        let mut last_index: u64 = 0;
165
166        loop {
167            let has_watchers = {
168                let watchers = self.watchers.read().await;
169                watchers
170                    .get(&service_name)
171                    .map(|items| !items.is_empty())
172                    .unwrap_or(false)
173            };
174
175            if !has_watchers {
176                break;
177            }
178
179            match self.long_poll(&service_name, last_index).await {
180                Ok((instances, index)) => {
181                    if index > last_index {
182                        last_index = index;
183                        self.notify_watchers(&service_name, instances).await;
184                    }
185                }
186                Err(err) => {
187                    warn!(service = %service_name, error = %err, "Consul watch loop error");
188                    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
189                }
190            }
191        }
192
193        self.watch_tasks.write().await.remove(&service_name);
194    }
195
196    async fn long_poll(
197        &self,
198        service_name: &str,
199        last_index: u64,
200    ) -> Result<(Vec<ServiceInstance>, u64), DiscoveryError> {
201        let mut request = self
202            .request(
203                reqwest::Method::GET,
204                &format!("/v1/health/service/{service_name}"),
205            )
206            .query(&[("passing", "true"), ("wait", "30s")]);
207        if last_index > 0 {
208            request = request.query(&[("index", &last_index.to_string())]);
209        }
210
211        let response = self.send_request(request).await?;
212        let index = response
213            .headers()
214            .get("X-Consul-Index")
215            .and_then(|value| value.to_str().ok())
216            .and_then(|value| value.parse::<u64>().ok())
217            .unwrap_or(last_index);
218        let entries = response
219            .json::<Vec<ConsulServiceEntry>>()
220            .await
221            .map_err(|e| DiscoveryError::Backend(e.to_string()))?;
222        let instances = entries
223            .iter()
224            .map(|entry| Self::service_instance_from_consul(entry, InstanceStatus::Healthy))
225            .collect::<Result<Vec<_>, _>>()?;
226        Ok((instances, index))
227    }
228}
229
230#[derive(Debug, Deserialize)]
231struct ConsulNodeEntry {
232    #[serde(rename = "Address")]
233    address: String,
234}
235
236#[derive(Debug, Deserialize)]
237struct ConsulServiceInfo {
238    #[serde(rename = "ID")]
239    id: String,
240    #[serde(rename = "Service")]
241    service: String,
242    #[serde(rename = "Address", default)]
243    address: String,
244    #[serde(rename = "Port")]
245    port: u16,
246    #[serde(rename = "Meta", default)]
247    meta: Option<JsonValue>,
248}
249
250#[derive(Debug, Deserialize)]
251struct ConsulServiceEntry {
252    #[serde(rename = "Node")]
253    node: ConsulNodeEntry,
254    #[serde(rename = "Service")]
255    service: ConsulServiceInfo,
256}
257
258#[async_trait]
259impl Discovery for ConsulDiscovery {
260    async fn register(&self, instance: ServiceInstance) -> Result<(), DiscoveryError> {
261        let (address, port) = Self::split_address(&instance.address)?;
262        let check_id = format!("service:{}", instance.id);
263        let payload = serde_json::json!({
264            "ID": instance.id,
265            "Name": instance.name,
266            "Address": address,
267            "Port": port,
268            "Meta": instance.metadata,
269            "Check": {
270                "CheckID": check_id,
271                "TTL": "30s",
272                "DeregisterCriticalServiceAfter": "90s"
273            }
274        });
275
276        self.send_request(
277            self.request(reqwest::Method::PUT, "/v1/agent/service/register")
278                .json(&payload),
279        )
280        .await?;
281
282        self.notify_watchers(
283            &instance.name,
284            self.current_instances(&instance.name).await?,
285        )
286        .await;
287        Ok(())
288    }
289
290    async fn discover(&self, service_name: &str) -> Result<Vec<ServiceInstance>, DiscoveryError> {
291        let response = self
292            .send_request(
293                self.request(
294                    reqwest::Method::GET,
295                    &format!("/v1/health/service/{service_name}"),
296                )
297                .query(&[("passing", "true")]),
298            )
299            .await?;
300
301        let entries = response
302            .json::<Vec<ConsulServiceEntry>>()
303            .await
304            .map_err(|e| DiscoveryError::Backend(e.to_string()))?;
305
306        entries
307            .iter()
308            .map(|entry| Self::service_instance_from_consul(entry, InstanceStatus::Healthy))
309            .collect()
310    }
311
312    async fn heartbeat(&self, service_name: &str, instance_id: &str) -> Result<(), DiscoveryError> {
313        let check_id = format!("service:{instance_id}");
314        debug!(service = %service_name, instance = %instance_id, check_id = %check_id, "sending consul heartbeat");
315        self.send_request(self.request(
316            reqwest::Method::PUT,
317            &format!("/v1/agent/check/pass/{check_id}"),
318        ))
319        .await?;
320        Ok(())
321    }
322
323    async fn deregister(
324        &self,
325        service_name: &str,
326        instance_id: &str,
327    ) -> Result<(), DiscoveryError> {
328        self.send_request(self.request(
329            reqwest::Method::PUT,
330            &format!("/v1/agent/service/deregister/{instance_id}"),
331        ))
332        .await?;
333
334        self.notify_watchers(service_name, self.current_instances(service_name).await?)
335            .await;
336        Ok(())
337    }
338
339    async fn watch(&self, service_name: &str) -> Result<WatchStream, DiscoveryError> {
340        let (tx, rx) = mpsc::channel(16);
341        {
342            let mut watchers = self.watchers.write().await;
343            watchers
344                .entry(service_name.to_string())
345                .or_default()
346                .push(tx.clone());
347        }
348
349        let initial = self.discover(service_name).await?;
350        let _ = tx.send(initial).await;
351
352        let mut tasks = self.watch_tasks.write().await;
353        if !tasks.contains_key(service_name) {
354            let service = service_name.to_string();
355            let discovery = self.clone();
356            let handle = tokio::spawn(async move {
357                discovery.watch_loop(service).await;
358            });
359            tasks.insert(service_name.to_string(), handle);
360        }
361
362        Ok(WatchStream { rx })
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use nova_boot::discovery::DiscoveryError;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Raw HTTP response served by the one-shot test server.
    const NOT_FOUND_RESPONSE: &[u8] =
        b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";

    /// Builds a client pointed at `http://127.0.0.1:1` with no datacenter or token.
    fn discovery() -> ConsulDiscovery {
        ConsulDiscovery::new("http://127.0.0.1:1", None, None)
    }

    /// Accepts a single connection on `listener`, reads one chunk of the
    /// request, writes `response`, and shuts the socket down.
    fn serve_once(listener: tokio::net::TcpListener, response: &'static [u8]) {
        tokio::spawn(async move {
            if let Ok((mut socket, _)) = listener.accept().await {
                let mut scratch = [0u8; 1024];
                let _ = socket.read(&mut scratch).await;
                let _ = socket.write_all(response).await;
                let _ = socket.shutdown().await;
            }
        });
    }

    #[tokio::test]
    async fn invalid_url_returns_backend_error() {
        let client = ConsulDiscovery::new("http://[", None, None);
        let outcome = client.discover("users").await;
        let err = outcome.expect_err("should fail");
        assert!(matches!(err, DiscoveryError::Backend(_)));
    }

    #[tokio::test]
    async fn discover_404_maps_to_not_found() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        serve_once(listener, NOT_FOUND_RESPONSE);

        let client = ConsulDiscovery::new(format!("http://{addr}"), None, None);
        let outcome = client.discover("users").await;
        let err = outcome.expect_err("should fail");
        assert!(matches!(err, DiscoveryError::NotFound(_)));
    }

    #[tokio::test]
    async fn register_invalid_address_returns_backend_error() {
        let client = discovery();
        let instance = ServiceInstance {
            id: String::from("users-1"),
            name: String::from("users"),
            address: String::from("invalid-address"),
            metadata: HashMap::new(),
            status: InstanceStatus::Healthy,
            last_heartbeat: None,
        };

        let outcome = client.register(instance).await;
        let err = outcome.expect_err("should fail");
        assert!(matches!(err, DiscoveryError::Backend(_)));
    }
}