1#![deny(missing_docs)]
29use http_body_util::BodyExt;
30
31#[cfg(feature = "metrics")]
32use metrics::MetricInfoWrapper;
33
34use std::collections::HashMap;
35use std::convert::Infallible;
36use std::env;
37use std::time::Duration;
38
39use base64::Engine;
40use http_body_util::combinators::BoxBody;
41use http_body_util::{Empty, Full};
42use hyper::body::Bytes;
43use hyper::{Method, body::Buf};
44use hyper_util::client::legacy::{Builder, Client, connect::HttpConnector};
45use serde::{Deserialize, Serialize};
46use slog_scope::{error, info};
47use tokio::time::timeout;
48
49pub use errors::ConsulError;
50use errors::Result;
51mod lock;
53mod utils;
55#[cfg(feature = "metrics")]
56use http::StatusCode;
57
58#[cfg(feature = "trace")]
59use opentelemetry::global;
60#[cfg(feature = "trace")]
61use opentelemetry::global::BoxedTracer;
62#[cfg(feature = "trace")]
63use opentelemetry::trace::Span;
64#[cfg(feature = "trace")]
65use opentelemetry::trace::Status;
66
67pub use lock::*;
68#[cfg(feature = "metrics")]
69pub use metrics::MetricInfo;
70pub use metrics::{Function, HttpMethod};
71pub use types::*;
72
73pub mod acl;
75
76pub mod acl_types;
78pub use acl_types::*;
79
80mod errors;
82#[cfg(feature = "trace")]
83mod hyper_wrapper;
84mod metrics;
86pub mod types;
88
89#[derive(Clone, Debug, Serialize, Deserialize)]
91pub struct Config {
92 pub address: String,
94 pub token: Option<String>,
96
97 #[serde(skip)]
99 #[serde(default = "default_builder")]
100 pub hyper_builder: hyper_util::client::legacy::Builder,
101}
102
103fn default_builder() -> Builder {
104 Builder::new(hyper_util::rt::TokioExecutor::new())
106 .pool_idle_timeout(std::time::Duration::from_millis(0))
107 .pool_max_idle_per_host(0)
108 .to_owned()
109}
110
111impl Default for Config {
112 fn default() -> Self {
113 Config {
114 address: String::default(),
115 token: None,
116 hyper_builder: default_builder(),
117 }
118 }
119}
120
121impl Config {
122 pub fn from_env() -> Self {
127 let token = env::var("CONSUL_HTTP_TOKEN").unwrap_or_default();
128 let addr =
129 env::var("CONSUL_HTTP_ADDR").unwrap_or_else(|_| "http://127.0.0.1:8500".to_string());
130
131 Config {
132 address: addr,
133 token: Some(token),
134 hyper_builder: default_builder(),
135 }
136 }
137}
/// Hyper client type used for every request issued by [`Consul`]: a legacy
/// hyper-util `Client` over a rustls HTTPS connector, with boxed request
/// bodies whose error type is `Infallible`.
pub type HttpsClient =
    Client<hyper_rustls::HttpsConnector<HttpConnector>, BoxBody<Bytes, Infallible>>;
141
/// Client for the Consul HTTP API.
#[derive(Debug)]
pub struct Consul {
    /// HTTP(S) client every asynchronous request is sent through.
    https_client: HttpsClient,
    /// Address and token used to build requests.
    config: Config,
    /// Tracer used to create a span per request (only with the `trace` feature).
    #[cfg(feature = "trace")]
    tracer: BoxedTracer,
    /// Sending half of the metrics channel, cloned into each request's metrics wrapper.
    #[cfg(feature = "metrics")]
    metrics_tx: tokio::sync::mpsc::UnboundedSender<MetricInfo>,
    /// Receiving half of the metrics channel; `None` once it has been handed
    /// out through `metrics_receiver`.
    #[cfg(feature = "metrics")]
    metrics_rx: Option<tokio::sync::mpsc::UnboundedReceiver<MetricInfo>>,
}
154
155fn https_connector() -> hyper_rustls::HttpsConnector<HttpConnector> {
156 hyper_rustls::HttpsConnectorBuilder::new()
157 .with_webpki_roots()
158 .https_or_http()
159 .enable_http1()
160 .build()
161}
162
/// Builder for [`Consul`] that optionally accepts a caller-supplied HTTPS client.
pub struct ConsulBuilder {
    /// Configuration handed to the built client.
    config: Config,
    /// Caller-supplied client; when `None`, `build` creates one from the
    /// configuration's hyper builder.
    https_client: Option<HttpsClient>,
}
169
170impl ConsulBuilder {
171 pub fn new(config: Config) -> Self {
173 Self {
174 config,
175 https_client: None,
176 }
177 }
178
179 pub fn with_https_client(mut self, https_client: HttpsClient) -> Self {
183 self.https_client = Some(https_client);
184 self
185 }
186
187 pub fn build(self) -> Consul {
189 let https_client = self.https_client.unwrap_or_else(|| {
190 let https = https_connector();
191 self.config
192 .hyper_builder
193 .build::<_, BoxBody<Bytes, Infallible>>(https)
194 });
195
196 Consul::new_with_client(self.config, https_client)
197 }
198}
199
200impl Consul {
201 pub fn new(config: Config) -> Self {
206 ConsulBuilder::new(config).build()
207 }
208
209 pub fn new_with_client(config: Config, https_client: HttpsClient) -> Self {
215 #[cfg(feature = "metrics")]
216 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<MetricInfo>();
217 Consul {
218 https_client,
219 config,
220 #[cfg(feature = "trace")]
221 tracer: global::tracer("consul"),
222 #[cfg(feature = "metrics")]
223 metrics_tx: tx,
224 #[cfg(feature = "metrics")]
225 metrics_rx: Some(rx),
226 }
227 }
228
/// Hands out the receiving half of the metrics channel.
///
/// The receiver is moved out of the client, so only the first call returns
/// `Some`; later calls return `None`.
#[cfg(feature = "metrics")]
pub fn metrics_receiver(&mut self) -> Option<tokio::sync::mpsc::UnboundedReceiver<MetricInfo>> {
    std::mem::take(&mut self.metrics_rx)
}
234
235 pub async fn read_key(
241 &self,
242 request: ReadKeyRequest<'_>,
243 ) -> Result<ResponseMeta<Vec<ReadKeyResponse>>> {
244 let req = self.build_read_key_req(request);
245 let (response_body, index) = self
246 .execute_request(
247 req,
248 BoxBody::new(http_body_util::Empty::<Bytes>::new()),
249 None,
250 Function::ReadKey,
251 )
252 .await?;
253 Ok(ResponseMeta {
254 response: serde_json::from_reader::<_, Vec<ReadKeyResponse>>(response_body.reader())
255 .map_err(ConsulError::ResponseDeserializationFailed)?
256 .into_iter()
257 .map(|mut r| {
258 r.value = match r.value {
259 Some(val) => Some(
260 std::str::from_utf8(
261 &base64::engine::general_purpose::STANDARD.decode(val)?,
262 )?
263 .to_string(),
264 ),
265 None => None,
266 };
267
268 Ok(r)
269 })
270 .collect::<Result<Vec<_>>>()?,
271 index,
272 })
273 }
274
275 pub async fn create_or_update_key(
284 &self,
285 request: CreateOrUpdateKeyRequest<'_>,
286 value: Vec<u8>,
287 ) -> Result<(bool, u64)> {
288 let url = self.build_create_or_update_url(request);
289 let req = hyper::Request::builder().method(Method::PUT).uri(url);
290 let (response_body, index) = self
291 .execute_request(
292 req,
293 BoxBody::new(Full::<Bytes>::new(Bytes::from(value))),
294 None,
295 Function::CreateOrUpdateKey,
296 )
297 .await?;
298 Ok((
299 serde_json::from_reader(response_body.reader())
300 .map_err(ConsulError::ResponseDeserializationFailed)?,
301 index,
302 ))
303 }
304
305 pub fn create_or_update_key_sync(
315 &self,
316 request: CreateOrUpdateKeyRequest<'_>,
317 value: Vec<u8>,
318 ) -> Result<bool> {
319 let url = self.build_create_or_update_url(request);
322 #[cfg(feature = "metrics")]
323 let mut metrics_info_wrapper = MetricInfoWrapper::new(
324 HttpMethod::Put,
325 Function::CreateOrUpdateKey,
326 None,
327 self.metrics_tx.clone(),
328 );
329 let result = ureq::put(&url)
330 .header(
331 "X-Consul-Token",
332 &self.config.token.clone().unwrap_or_default(),
333 )
334 .send(&value);
335
336 let response = result.map_err(|e| match e {
337 ureq::Error::StatusCode(code) => {
338 let code = hyper::StatusCode::from_u16(code).unwrap_or_default();
339 #[cfg(feature = "metrics")]
340 {
341 metrics_info_wrapper.set_status(code);
342 metrics_info_wrapper.emit_metrics();
343 }
344 ConsulError::UnexpectedResponseCode(code, None)
345 }
346 e => ConsulError::UReqError(e),
347 })?;
348 let status = response.status();
349 if status == 200 {
350 let val = response
351 .into_body()
352 .read_to_string()
353 .map_err(ConsulError::UReqError)?;
354 let response: bool = std::str::FromStr::from_str(val.trim())?;
355 #[cfg(feature = "metrics")]
356 {
357 metrics_info_wrapper.set_status(StatusCode::OK);
358 metrics_info_wrapper.emit_metrics();
359 }
360 return Ok(response);
361 }
362
363 let body = response
364 .into_body()
365 .read_to_string()
366 .map_err(ConsulError::UReqError)?;
367 Err(ConsulError::SyncUnexpectedResponseCode(
368 hyper::StatusCode::as_u16(&status),
369 body,
370 ))
371 }
372
373 pub async fn delete_key(&self, request: DeleteKeyRequest<'_>) -> Result<bool> {
379 let mut req = hyper::Request::builder().method(Method::DELETE);
380 let mut url = String::new();
381 url.push_str(&format!(
382 "{}/v1/kv/{}?recurse={}",
383 self.config.address, request.key, request.recurse
384 ));
385 if request.check_and_set != 0 {
386 url.push_str(&format!("&cas={}", request.check_and_set));
387 }
388
389 url = utils::add_namespace_and_datacenter(url, request.namespace, request.datacenter);
390 req = req.uri(url);
391 let (response_body, _index) = self
392 .execute_request(
393 req,
394 BoxBody::new(Empty::<Bytes>::new()),
395 None,
396 Function::DeleteKey,
397 )
398 .await?;
399 serde_json::from_reader(response_body.reader())
400 .map_err(ConsulError::ResponseDeserializationFailed)
401 }
402
403 pub async fn register_entity(&self, payload: &RegisterEntityPayload) -> Result<()> {
410 let uri = format!("{}/v1/catalog/register", self.config.address);
411 let request = hyper::Request::builder().method(Method::PUT).uri(uri);
412 let payload = serde_json::to_string(payload).map_err(ConsulError::InvalidRequest)?;
413 self.execute_request(
414 request,
415 BoxBody::new(Full::<Bytes>::new(Bytes::from(payload.into_bytes()))),
416 Some(Duration::from_secs(5)),
417 Function::RegisterEntity,
418 )
419 .await?;
420 Ok(())
421 }
422
423 pub async fn deregister_entity(&self, payload: &DeregisterEntityPayload) -> Result<()> {
430 let uri = format!("{}/v1/catalog/deregister", self.config.address);
431 let request = hyper::Request::builder().method(Method::PUT).uri(uri);
432 let payload = serde_json::to_string(payload).map_err(ConsulError::InvalidRequest)?;
433 self.execute_request(
434 request,
435 BoxBody::new(Full::<Bytes>::new(Bytes::from(payload.into_bytes()))),
436 Some(Duration::from_secs(5)),
437 Function::DeregisterEntity,
438 )
439 .await?;
440 Ok(())
441 }
442
443 pub async fn get_all_registered_service_names(
450 &self,
451 query_opts: Option<QueryOptions>,
452 ) -> Result<ResponseMeta<Vec<String>>> {
453 let mut uri = format!("{}/v1/catalog/services", self.config.address);
454 let query_opts = query_opts.unwrap_or_default();
455 utils::add_query_option_params(&mut uri, &query_opts, '?');
456
457 let request = hyper::Request::builder()
458 .method(Method::GET)
459 .uri(uri.clone());
460 let (response_body, index) = self
461 .execute_request(
462 request,
463 BoxBody::new(Empty::<Bytes>::new()),
464 query_opts.timeout,
465 Function::GetAllRegisteredServices,
466 )
467 .await?;
468 let service_tags_by_name =
469 serde_json::from_reader::<_, HashMap<String, Vec<String>>>(response_body.reader())
470 .map_err(ConsulError::ResponseDeserializationFailed)?;
471
472 Ok(ResponseMeta {
473 response: service_tags_by_name.keys().cloned().collect(),
474 index,
475 })
476 }
477
478 pub async fn get_service_nodes(
486 &self,
487 request: GetServiceNodesRequest<'_>,
488 query_opts: Option<QueryOptions>,
489 ) -> Result<ResponseMeta<GetServiceNodesResponse>> {
490 let query_opts = query_opts.unwrap_or_default();
491 let req = self.build_get_service_nodes_req(request, &query_opts);
492 let (response_body, index) = self
493 .execute_request(
494 req,
495 BoxBody::new(Empty::<Bytes>::new()),
496 query_opts.timeout,
497 Function::GetServiceNodes,
498 )
499 .await?;
500 let response =
501 serde_json::from_reader::<_, GetServiceNodesResponse>(response_body.reader())
502 .map_err(ConsulError::ResponseDeserializationFailed)?;
503 Ok(ResponseMeta { response, index })
504 }
505
506 pub async fn get_service_addresses_and_ports(
508 &self,
509 service_name: &str,
510 query_opts: Option<QueryOptions>,
511 ) -> Result<Vec<(String, u16)>> {
512 let request = GetServiceNodesRequest {
513 service: service_name,
514 passing: true,
515 ..Default::default()
516 };
517 let services = self.get_service_nodes(request, query_opts).await.map_err(|e| {
518 let err = format!(
519 "Unable to query consul to resolve service '{}' to a list of addresses and ports: {:?}",
520 service_name, e
521 );
522 error!("{}", err);
523 ConsulError::ServiceInstanceResolutionFailed(service_name.to_string())
524 })?;
525
526 let addresses_and_ports = services
527 .response
528 .into_iter()
529 .map(Self::parse_host_port_from_service_node_response)
530 .collect();
531 info!(
532 "resolved service '{}' to addresses and ports: '{:?}'",
533 service_name, addresses_and_ports
534 );
535
536 Ok(addresses_and_ports)
537 }
538
539 pub async fn get_nodes(
545 &self,
546 request: GetNodesRequest<'_>,
547 query_opts: Option<QueryOptions>,
548 ) -> Result<ResponseMeta<GetNodesResponse>> {
549 let query_opts = query_opts.unwrap_or_default();
550 let req = self.build_get_nodes_req(request, &query_opts);
551 let (response_body, index) = self
552 .execute_request(
553 req,
554 BoxBody::new(Empty::<Bytes>::new()),
555 query_opts.timeout,
556 Function::GetNodes,
557 )
558 .await?;
559 let response = serde_json::from_reader::<_, GetNodesResponse>(response_body.reader())
560 .map_err(ConsulError::ResponseDeserializationFailed)?;
561 Ok(ResponseMeta { response, index })
562 }
563
564 fn parse_host_port_from_service_node_response(sn: ServiceNode) -> (String, u16) {
576 (
577 if sn.service.address.is_empty() {
578 info!(
579 "Consul service {service_name} instance had an empty Service address, with port:{port}",
580 service_name = &sn.service.service,
581 port = sn.service.port
582 );
583 sn.node.address
584 } else {
585 sn.service.address
586 },
587 sn.service.port,
588 )
589 }
590
591 fn build_read_key_req(&self, request: ReadKeyRequest<'_>) -> http::request::Builder {
592 let req = hyper::Request::builder().method(Method::GET);
593 let mut url = String::new();
594 url.push_str(&format!(
595 "{}/v1/kv/{}?recurse={}",
596 self.config.address, request.key, request.recurse
597 ));
598
599 if !request.separator.is_empty() {
600 url.push_str(&format!("&separator={}", request.separator));
601 }
602 if request.consistency == ConsistencyMode::Consistent {
603 url.push_str("&consistent");
604 } else if request.consistency == ConsistencyMode::Stale {
605 url.push_str("&stale");
606 }
607
608 if let Some(index) = request.index {
609 url.push_str(&format!("&index={}", index));
610 if request.wait.as_secs() > 0 {
611 url.push_str(&format!(
612 "&wait={}",
613 types::duration_as_string(&request.wait)
614 ));
615 }
616 }
617 url = utils::add_namespace_and_datacenter(url, request.namespace, request.datacenter);
618 req.uri(url)
619 }
620
621 fn build_get_service_nodes_req(
622 &self,
623 request: GetServiceNodesRequest<'_>,
624 query_opts: &QueryOptions,
625 ) -> http::request::Builder {
626 let req = hyper::Request::builder().method(Method::GET);
627 let mut url = String::new();
628 url.push_str(&format!(
629 "{}/v1/health/service/{}",
630 self.config.address, request.service
631 ));
632 url.push_str(&format!("?passing={}", request.passing));
633 if let Some(near) = request.near {
634 url.push_str(&format!("&near={}", near));
635 }
636 if let Some(filter) = request.filter {
637 url.push_str(&format!("&filter={}", filter));
638 }
639 utils::add_query_option_params(&mut url, query_opts, '&');
640 req.uri(url)
641 }
642
643 fn build_get_nodes_req(
645 &self,
646 request: GetNodesRequest<'_>,
647 query_opts: &QueryOptions,
648 ) -> http::request::Builder {
649 let req = hyper::Request::builder().method(Method::GET);
650 let mut url = String::new();
651 url.push_str(&format!("{}/v1/catalog/nodes", self.config.address));
652 let mut added_query_param = false;
653 if let Some(near) = request.near {
654 url = utils::add_query_param_separator(url, added_query_param);
655 url.push_str(&format!("near={}", near));
656 added_query_param = true;
657 }
658 if let Some(filter) = request.filter {
659 url = utils::add_query_param_separator(url, added_query_param);
660 url.push_str(&format!("filter={}", filter));
661 added_query_param = true;
662 }
663 if let Some(dc) = &query_opts.datacenter
664 && !dc.is_empty()
665 {
666 url = utils::add_query_param_separator(url, added_query_param);
667 url.push_str(&format!("dc={}", dc));
668 }
669
670 req.uri(url)
671 }
672
673 async fn execute_request(
674 &self,
675 req: http::request::Builder,
676 body: BoxBody<Bytes, Infallible>,
677 duration: Option<std::time::Duration>,
678 _function: Function,
679 ) -> Result<(Box<dyn Buf>, u64)> {
680 let req = req
681 .header(
682 "X-Consul-Token",
683 self.config.token.clone().unwrap_or_default(),
684 )
685 .body(body);
686 let req = req.map_err(ConsulError::RequestError)?;
687 #[cfg(feature = "trace")]
688 let mut span = crate::hyper_wrapper::span_for_request(&self.tracer, &req);
689
690 #[cfg(feature = "metrics")]
691 let mut metrics_info_wrapper = MetricInfoWrapper::new(
692 req.method().clone().into(),
693 _function,
694 None,
695 self.metrics_tx.clone(),
696 );
697 let future = self.https_client.request(req);
698 let response = if let Some(dur) = duration {
699 match timeout(dur, future).await {
700 Ok(resp) => resp.map_err(ConsulError::ResponseError),
701 Err(_) => {
702 #[cfg(feature = "metrics")]
703 {
704 metrics_info_wrapper.set_status(StatusCode::REQUEST_TIMEOUT);
705 metrics_info_wrapper.emit_metrics();
706 }
707 Err(ConsulError::TimeoutExceeded(dur))
708 }
709 }
710 } else {
711 future.await.map_err(ConsulError::ResponseError)
712 };
713
714 let response = response.inspect_err(|_| {
715 #[cfg(feature = "metrics")]
716 metrics_info_wrapper.emit_metrics();
717 })?;
718
719 #[cfg(feature = "trace")]
720 crate::hyper_wrapper::annotate_span_for_response(&mut span, &response);
721
722 let status = response.status();
723 if status != hyper::StatusCode::OK {
724 #[cfg(feature = "metrics")]
725 {
726 metrics_info_wrapper.set_status(status);
727 metrics_info_wrapper.emit_metrics();
728 }
729
730 let mut response_body = response
731 .into_body()
732 .collect()
733 .await
734 .map_err(|e| ConsulError::UnexpectedResponseCode(status, Some(e.to_string())))?
735 .aggregate();
736 let bytes = response_body.copy_to_bytes(response_body.remaining());
737 let resp = std::str::from_utf8(&bytes)
738 .map_err(|e| ConsulError::UnexpectedResponseCode(status, Some(e.to_string())))?;
739 return Err(ConsulError::UnexpectedResponseCode(
740 status,
741 Some(resp.to_string()),
742 ));
743 }
744
745 let index = match response.headers().get("x-consul-index") {
746 Some(header) => header.to_str().unwrap_or("0").parse::<u64>().unwrap_or(0),
747 None => 0,
748 };
749
750 match response.into_body().collect().await.map(|b| b.aggregate()) {
751 Ok(body) => Ok((Box::new(body), index)),
752 Err(e) => {
753 #[cfg(feature = "trace")]
754 span.set_status(Status::error(e.to_string()));
755 Err(ConsulError::InvalidResponse(e))
756 }
757 }
758 }
759
760 fn build_create_or_update_url(&self, request: CreateOrUpdateKeyRequest<'_>) -> String {
761 let mut url = String::new();
762 url.push_str(&format!("{}/v1/kv/{}", self.config.address, request.key));
763 let mut added_query_param = false;
764 if request.flags != 0 {
765 url = utils::add_query_param_separator(url, added_query_param);
766 url.push_str(&format!("flags={}", request.flags));
767 added_query_param = true;
768 }
769 if !request.acquire.is_empty() {
770 url = utils::add_query_param_separator(url, added_query_param);
771 url.push_str(&format!("acquire={}", request.acquire));
772 added_query_param = true;
773 }
774 if !request.release.is_empty() {
775 url = utils::add_query_param_separator(url, added_query_param);
776 url.push_str(&format!("release={}", request.release));
777 added_query_param = true;
778 }
779 if let Some(cas_idx) = request.check_and_set {
780 url = utils::add_query_param_separator(url, added_query_param);
781 url.push_str(&format!("cas={}", cas_idx));
782 }
783
784 utils::add_namespace_and_datacenter(url, request.namespace, request.datacenter)
785 }
786}
787
#[cfg(test)]
mod tests {
    use super::*;

    /// The service address wins when present; otherwise the node address is
    /// used. The port always comes from the service.
    #[test]
    fn test_service_node_parsing() {
        let node = Node {
            id: String::from("node"),
            node: String::from("node"),
            address: String::from("1.1.1.1"),
            datacenter: String::from("datacenter"),
        };
        let service = Service {
            id: String::from("node"),
            service: String::from("node"),
            address: String::from("2.2.2.2"),
            port: 32,
            tags: vec![String::from("foo"), String::from("bar=baz")],
        };
        let empty_service = Service {
            id: String::new(),
            service: String::new(),
            address: String::new(),
            port: 32,
            tags: Vec::new(),
        };

        // Service address present: it is returned as the host.
        let with_address = ServiceNode {
            node: node.clone(),
            service: service.clone(),
        };
        let (host, port) = Consul::parse_host_port_from_service_node_response(with_address);
        assert_eq!(service.port, port);
        assert_eq!(service.address, host);

        // Service address empty: the node address is returned instead.
        let without_address = ServiceNode {
            node: node.clone(),
            service: empty_service,
        };
        let (host, port) = Consul::parse_host_port_from_service_node_response(without_address);
        assert_eq!(service.port, port);
        assert_eq!(node.address, host);
    }
}