#![deny(missing_docs)]
use std::collections::HashMap;
use std::time::{Duration, Instant};
use std::{env, str::Utf8Error};
use base64::Engine;
use hyper::{body::Buf, client::HttpConnector, Body, Method};
#[cfg(any(feature = "rustls-native", feature = "rustls-webpki"))]
#[cfg(feature = "metrics")]
use lazy_static::lazy_static;
use quick_error::quick_error;
use serde::{Deserialize, Serialize};
use slog_scope::{error, info};
use tokio::time::timeout;
#[cfg(feature = "trace")]
use opentelemetry::global;
#[cfg(feature = "trace")]
use opentelemetry::global::BoxedTracer;
#[cfg(feature = "trace")]
use opentelemetry::trace::Span;
#[cfg(feature = "trace")]
use opentelemetry::trace::Status;
pub use types::*;
#[cfg(feature = "trace")]
mod hyper_wrapper;
pub mod types;
quick_error! {
    /// Errors returned by the [`Consul`] client.
    #[derive(Debug)]
    pub enum ConsulError {
        /// A request payload could not be serialized to JSON.
        InvalidRequest(err: serde_json::error::Error) {}
        /// The HTTP request could not be built (for example an invalid URI or header value).
        RequestError(err: http::Error) {}
        /// Sending the request, or receiving the response headers, failed.
        ResponseError(err: hyper::Error) {}
        /// The body of a 200 response could not be read.
        InvalidResponse(err: hyper::Error) {}
        /// A response body was not the JSON shape expected by the called method.
        ResponseDeserializationFailed(err: serde_json::error::Error) {}
        /// A response body was not valid UTF-8.
        /// NOTE(review): not constructed anywhere in this file — confirm it is still used.
        ResponseStringDeserializationFailed(err: std::str::Utf8Error) {}
        /// Consul answered with a non-200 status. Carries the status and the response body
        /// (or the error text when the body could not be read or decoded).
        UnexpectedResponseCode(status_code: hyper::http::StatusCode, body: String) {}
        /// The lock is held by another session. Carries the key's current `X-Consul-Index`,
        /// which can be used as the starting index for [`Consul::watch_lock`].
        LockAcquisitionFailure(err: u64) {}
        /// A base64-decoded KV value is not valid UTF-8.
        InvalidUtf8(err: Utf8Error) {
            from()
        }
        /// A KV value returned by Consul is not valid base64.
        InvalidBase64(err: base64::DecodeError) {
            from()
        }
        /// Reading the body of a synchronous (`ureq`) response failed.
        SyncIoError(err: std::io::Error) {
            from()
        }
        /// A synchronous KV write returned a body that does not parse as `true`/`false`.
        SyncInvalidResponseError(err: std::str::ParseBoolError) {
            from()
        }
        /// A synchronous request returned a non-200 status that `ureq` did not report as an
        /// error. Carries the status code and the body.
        SyncUnexpectedResponseCode(status_code: u16, body: String) {}
        /// No response was received within the given timeout.
        TimeoutExceeded(timeout: std::time::Duration) {
            display("Consul request exceeded timeout of {:?}", timeout)
        }
        /// Querying the instances of a service failed. The underlying error is logged, not carried.
        ServiceInstanceResolutionFailed(service_name: String) {
            display("Unable to resolve service '{}' to a concrete list of addresses and ports for its instances via consul.", service_name)
        }
        /// A synchronous (`ureq`) request failed at the transport level. Carries ureq's error
        /// kind and message.
        TransportError(kind: ureq::ErrorKind, message: String) {
            display("Transport error: {} - {}", kind, message)
        }
    }
}
// Prometheus metrics for Consul requests. They exist only with the `metrics` feature and are
// registered in the default registry on first use; registration failures panic via `unwrap`.
// Every metric is labelled by HTTP method and by the client function name.
#[cfg(feature = "metrics")]
lazy_static! {
    // Incremented once per request attempt (see `record_request_metric_if_enabled`).
    static ref CONSUL_REQUESTS_TOTAL: prometheus::CounterVec = prometheus::register_counter_vec!(
        prometheus::opts!("consul_requests_total", "Total requests made to consul"),
        &["method", "function"]
    )
    .unwrap();
    // Incremented when a request fails (see `record_failure_metric_if_enabled`).
    static ref CONSUL_REQUESTS_FAILED_TOTAL: prometheus::CounterVec =
        prometheus::register_counter_vec!(
            prometheus::opts!(
                "consul_requests_failed_total",
                "Total requests made to consul that failed"
            ),
            &["method", "function"]
        )
        .unwrap();
    // Request latency in MILLISECONDS (callers observe `elapsed().as_millis()`).
    // The default Prometheus buckets (0.005..=10) are sized for seconds, so with millisecond
    // observations nearly every sample fell into `+Inf`. Use the same bucket shape scaled to ms.
    static ref CONSUL_REQUESTS_DURATION_MS: prometheus::HistogramVec =
        prometheus::register_histogram_vec!(
            prometheus::histogram_opts!(
                "consul_requests_duration_milliseconds",
                "Time it takes for a consul request to complete",
                vec![5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0, 2500.0, 5000.0, 10000.0]
            ),
            &["method", "function"]
        )
        .unwrap();
}
// Names passed as `request_name` to `execute_request` (and used directly by
// `create_or_update_key_sync`). They become the `function` label on the request metrics.
const READ_KEY_METHOD_NAME: &str = "read_key";
const CREATE_OR_UPDATE_KEY_METHOD_NAME: &str = "create_or_update_key";
const CREATE_OR_UPDATE_KEY_SYNC_METHOD_NAME: &str = "create_or_update_key_sync";
const DELETE_KEY_METHOD_NAME: &str = "delete_key";
// Used by `get_lock` only for the follow-up read of the key's index after a failed acquisition.
const GET_LOCK_METHOD_NAME: &str = "get_lock";
const REGISTER_ENTITY_METHOD_NAME: &str = "register_entity";
const DEREGISTER_ENTITY_METHOD_NAME: &str = "deregister_entity";
const GET_ALL_REGISTERED_SERVICE_NAMES_METHOD_NAME: &str = "get_all_registered_service_names";
const GET_SERVICE_NODES_METHOD_NAME: &str = "get_service_nodes";
const GET_SESSION_METHOD_NAME: &str = "get_session";
/// Crate-local result alias whose error type is [`ConsulError`].
pub(crate) type Result<T> = std::result::Result<T, ConsulError>;
/// Connection settings for a [`Consul`] client.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Config {
    /// Base URL of the Consul HTTP API, e.g. `http://127.0.0.1:8500`. Request paths such as
    /// `/v1/kv/...` are appended to it verbatim.
    pub address: String,
    /// ACL token sent in the `X-Consul-Token` header of every request. `None` sends an empty
    /// header value.
    pub token: Option<String>,
    /// Builder used by [`Consul::new`] to construct the underlying hyper client. It is skipped
    /// by serde, so deserialization fills it with its `Default`.
    #[serde(skip)]
    pub hyper_builder: hyper::client::Builder,
}
impl Config {
pub fn from_env() -> Self {
let token = env::var("CONSUL_HTTP_TOKEN").unwrap_or_default();
let addr =
env::var("CONSUL_HTTP_ADDR").unwrap_or_else(|_| "http://127.0.0.1:8500".to_string());
Config {
address: addr,
token: Some(token),
hyper_builder: Default::default(),
}
}
}
/// Guard for a distributed lock acquired with [`Consul::get_lock`].
///
/// When dropped, it writes `value` back to `key` with `release` set to `session_id` (see its
/// `Drop` impl). Because the type derives `Clone`, every clone is a separate guard that will
/// also send that release request when it is dropped.
#[derive(Clone, Debug)]
pub struct Lock<'a> {
    /// ID of the Consul session that acquired the lock.
    pub session_id: String,
    /// KV key the lock is held on.
    pub key: String,
    /// Copied from the originating `LockRequest::timeout`, which [`Consul::get_lock`] also
    /// used as the session TTL.
    pub timeout: std::time::Duration,
    /// Namespace of the key. Empty means no `ns` parameter is sent.
    pub namespace: String,
    /// Datacenter of the key. Empty means no `dc` parameter is sent.
    pub datacenter: String,
    /// Bytes written to the key when the lock is released. `None` writes an empty value.
    pub value: Option<Vec<u8>>,
    /// Client used to release the lock.
    pub consul: &'a Consul,
}
impl Drop for Lock<'_> {
    /// Best-effort release of the lock.
    ///
    /// Writes the guard's value (an empty body when `None`) back to the key with `release` set
    /// to this guard's session. It uses the synchronous [`Consul::create_or_update_key_sync`],
    /// so it can run from the non-async `drop`. The outcome is discarded because `drop` has no
    /// way to report it.
    fn drop(&mut self) {
        // The guard is going away, so the payload can be moved out instead of cloned.
        let payload = self.value.take().unwrap_or_default();
        let release_request = CreateOrUpdateKeyRequest {
            key: &self.key,
            namespace: &self.namespace,
            datacenter: &self.datacenter,
            release: &self.session_id,
            ..Default::default()
        };
        let _ignored = self
            .consul
            .create_or_update_key_sync(release_request, payload);
    }
}
/// Client for the Consul HTTP API.
///
/// Async operations go through a hyper client. [`Consul::create_or_update_key_sync`] uses
/// `ureq` instead, so it can be called from non-async code such as `Drop`.
#[derive(Debug)]
pub struct Consul {
    // hyper client over the rustls connector from `https_connector` (accepts http and https URLs).
    https_client: hyper::Client<hyper_rustls::HttpsConnector<HttpConnector>, Body>,
    // Address, token and hyper builder this client was created from.
    config: Config,
    // Tracer named "consul"; `execute_request` creates one span per request from it.
    #[cfg(feature = "trace")]
    tracer: BoxedTracer,
}
/// Builds the rustls-backed connector used by every [`Consul`] client.
///
/// With the `rustls-webpki` feature the bundled webpki root certificates are trusted.
/// Otherwise the platform's native root store is used. In both cases plain `http://` URLs
/// are accepted as well, and only HTTP/1 is enabled.
fn https_connector() -> hyper_rustls::HttpsConnector<HttpConnector> {
    #[cfg(feature = "rustls-webpki")]
    let with_roots = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots();
    #[cfg(not(feature = "rustls-webpki"))]
    let with_roots = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots();

    with_roots.https_or_http().enable_http1().build()
}
impl Consul {
pub fn new(config: Config) -> Self {
let https = https_connector();
let https_client = config.hyper_builder.build::<_, hyper::Body>(https);
Consul {
https_client,
config,
#[cfg(feature = "trace")]
tracer: global::tracer("consul"),
}
}
pub async fn read_key(&self, request: ReadKeyRequest<'_>) -> Result<Vec<ReadKeyResponse>> {
let req = self.build_read_key_req(request);
let (mut response_body, _index) = self
.execute_request(req, hyper::Body::empty(), None, READ_KEY_METHOD_NAME)
.await?;
let bytes = response_body.copy_to_bytes(response_body.remaining());
serde_json::from_slice::<Vec<ReadKeyResponse>>(&bytes)
.map_err(ConsulError::ResponseDeserializationFailed)?
.into_iter()
.map(|mut r| {
r.value = match r.value {
Some(val) => Some(
std::str::from_utf8(
&base64::engine::general_purpose::STANDARD.decode(val)?,
)?
.to_string(),
),
None => None,
};
Ok(r)
})
.collect()
}
pub async fn create_or_update_key(
&self,
request: CreateOrUpdateKeyRequest<'_>,
value: Vec<u8>,
) -> Result<(bool, u64)> {
let url = self.build_create_or_update_url(request);
let req = hyper::Request::builder().method(Method::PUT).uri(url);
let (mut response_body, index) = self
.execute_request(
req,
Body::from(value),
None,
CREATE_OR_UPDATE_KEY_METHOD_NAME,
)
.await?;
let bytes = response_body.copy_to_bytes(response_body.remaining());
Ok((
serde_json::from_slice(&bytes).map_err(ConsulError::ResponseDeserializationFailed)?,
index,
))
}
pub fn create_or_update_key_sync(
&self,
request: CreateOrUpdateKeyRequest<'_>,
value: Vec<u8>,
) -> Result<bool> {
let url = self.build_create_or_update_url(request);
record_request_metric_if_enabled(&Method::PUT, CREATE_OR_UPDATE_KEY_SYNC_METHOD_NAME);
let step_start_instant = Instant::now();
let result = ureq::put(&url)
.set(
"X-Consul-Token",
&self.config.token.clone().unwrap_or_default(),
)
.send_bytes(&value);
record_duration_metric_if_enabled(
&Method::PUT,
CREATE_OR_UPDATE_KEY_SYNC_METHOD_NAME,
step_start_instant.elapsed().as_millis() as f64,
);
let response = result.map_err(|e| {
record_failure_metric_if_enabled(&Method::PUT, CREATE_OR_UPDATE_KEY_SYNC_METHOD_NAME);
match e {
ureq::Error::Status(code, response) => ConsulError::UnexpectedResponseCode(
hyper::StatusCode::from_u16(code).unwrap_or_default(),
response.into_string().unwrap_or_default(),
),
ureq::Error::Transport(t) => ConsulError::TransportError(
t.kind(),
t.message().unwrap_or_default().to_string(),
),
}
})?;
let status = response.status();
if status == 200 {
let val = response.into_string()?;
let response: bool = std::str::FromStr::from_str(val.trim())?;
return Ok(response);
}
let body = response.into_string()?;
record_failure_metric_if_enabled(&Method::PUT, CREATE_OR_UPDATE_KEY_SYNC_METHOD_NAME);
Err(ConsulError::SyncUnexpectedResponseCode(status, body))
}
pub async fn delete_key(&self, request: DeleteKeyRequest<'_>) -> Result<bool> {
let mut req = hyper::Request::builder().method(Method::DELETE);
let mut url = String::new();
url.push_str(&format!(
"{}/v1/kv/{}?recurse={}",
self.config.address, request.key, request.recurse
));
if request.check_and_set != 0 {
url.push_str(&format!("&cas={}", request.check_and_set));
}
url = add_namespace_and_datacenter(url, request.namespace, request.datacenter);
req = req.uri(url);
let (mut response_body, _index) = self
.execute_request(req, hyper::Body::empty(), None, DELETE_KEY_METHOD_NAME)
.await?;
let bytes = response_body.copy_to_bytes(response_body.remaining());
serde_json::from_slice(&bytes).map_err(ConsulError::ResponseDeserializationFailed)
}
pub async fn get_lock(&self, request: LockRequest<'_>, value: &[u8]) -> Result<Lock<'_>> {
let session = self.get_session(request).await?;
let req = CreateOrUpdateKeyRequest {
key: request.key,
namespace: request.namespace,
datacenter: request.datacenter,
acquire: &session.id,
..Default::default()
};
let value_copy = value.to_vec();
let (lock_acquisition_result, _index) = self.create_or_update_key(req, value_copy).await?;
if lock_acquisition_result {
let value_copy = value.to_vec();
Ok(Lock {
timeout: request.timeout,
key: request.key.to_string(),
session_id: session.id,
consul: self,
datacenter: request.datacenter.to_string(),
namespace: request.namespace.to_string(),
value: Some(value_copy),
})
} else {
let watch_req = ReadKeyRequest {
key: request.key,
datacenter: request.datacenter,
namespace: request.namespace,
index: Some(0),
wait: std::time::Duration::from_secs(0),
..Default::default()
};
let lock_index_req = self.build_read_key_req(watch_req);
let (_watch, index) = self
.execute_request(
lock_index_req,
hyper::Body::empty(),
None,
GET_LOCK_METHOD_NAME,
)
.await?;
Err(ConsulError::LockAcquisitionFailure(index))
}
}
pub async fn watch_lock<'a>(
&self,
request: LockWatchRequest<'_>,
) -> Result<Vec<ReadKeyResponse>> {
let req = ReadKeyRequest {
key: request.key,
namespace: request.namespace,
datacenter: request.datacenter,
index: request.index,
wait: request.wait,
consistency: request.consistency,
..Default::default()
};
self.read_key(req).await
}
pub async fn register_entity(&self, payload: &RegisterEntityPayload) -> Result<()> {
let uri = format!("{}/v1/catalog/register", self.config.address);
let request = hyper::Request::builder().method(Method::PUT).uri(uri);
let payload = serde_json::to_string(payload).map_err(ConsulError::InvalidRequest)?;
self.execute_request(
request,
payload.into(),
Some(Duration::from_secs(5)),
REGISTER_ENTITY_METHOD_NAME,
)
.await?;
Ok(())
}
pub async fn deregister_entity(&self, payload: &DeregisterEntityPayload) -> Result<()> {
let uri = format!("{}/v1/catalog/deregister", self.config.address);
let request = hyper::Request::builder().method(Method::PUT).uri(uri);
let payload = serde_json::to_string(payload).map_err(ConsulError::InvalidRequest)?;
self.execute_request(
request,
payload.into(),
Some(Duration::from_secs(5)),
DEREGISTER_ENTITY_METHOD_NAME,
)
.await?;
Ok(())
}
pub async fn get_all_registered_service_names(
&self,
query_opts: Option<QueryOptions>,
) -> Result<ResponseMeta<Vec<String>>> {
let mut uri = format!("{}/v1/catalog/services", self.config.address);
let query_opts = query_opts.unwrap_or_default();
add_query_option_params(&mut uri, &query_opts, '?');
let request = hyper::Request::builder()
.method(Method::GET)
.uri(uri.clone());
let (mut response_body, index) = self
.execute_request(
request,
hyper::Body::empty(),
query_opts.timeout,
GET_ALL_REGISTERED_SERVICE_NAMES_METHOD_NAME,
)
.await?;
let bytes = response_body.copy_to_bytes(response_body.remaining());
let service_tags_by_name = serde_json::from_slice::<HashMap<String, Vec<String>>>(&bytes)
.map_err(ConsulError::ResponseDeserializationFailed)?;
Ok(ResponseMeta {
response: service_tags_by_name.keys().cloned().collect(),
index,
})
}
pub async fn get_service_nodes(
&self,
request: GetServiceNodesRequest<'_>,
query_opts: Option<QueryOptions>,
) -> Result<ResponseMeta<GetServiceNodesResponse>> {
let query_opts = query_opts.unwrap_or_default();
let req = self.build_get_service_nodes_req(request, &query_opts);
let (mut response_body, index) = self
.execute_request(
req,
hyper::Body::empty(),
query_opts.timeout,
GET_SERVICE_NODES_METHOD_NAME,
)
.await?;
let bytes = response_body.copy_to_bytes(response_body.remaining());
let response = serde_json::from_slice::<GetServiceNodesResponse>(&bytes)
.map_err(ConsulError::ResponseDeserializationFailed)?;
Ok(ResponseMeta { response, index })
}
pub async fn get_service_addresses_and_ports(
&self,
service_name: &str,
query_opts: Option<QueryOptions>,
) -> Result<Vec<(String, u16)>> {
let request = GetServiceNodesRequest {
service: service_name,
passing: true,
..Default::default()
};
let services = self.get_service_nodes(request, query_opts).await.map_err(|e| {
let err = format!(
"Unable to query consul to resolve service '{}' to a list of addresses and ports: {:?}",
service_name, e
);
error!("{}", err);
ConsulError::ServiceInstanceResolutionFailed(service_name.to_string())
})?;
let addresses_and_ports = services
.response
.into_iter()
.map(Self::parse_host_port_from_service_node_response)
.collect();
info!(
"resolved service '{}' to addresses and ports: '{:?}'",
service_name, addresses_and_ports
);
Ok(addresses_and_ports)
}
fn parse_host_port_from_service_node_response(sn: ServiceNode) -> (String, u16) {
(
if sn.service.address.is_empty() {
info!(
"Consul service {service_name} instance had an empty Service address, with port:{port}",
service_name = &sn.service.service, port = sn.service.port
);
sn.node.address
} else {
sn.service.address
},
sn.service.port,
)
}
fn build_read_key_req(&self, request: ReadKeyRequest<'_>) -> http::request::Builder {
let req = hyper::Request::builder().method(Method::GET);
let mut url = String::new();
url.push_str(&format!(
"{}/v1/kv/{}?recurse={}",
self.config.address, request.key, request.recurse
));
if !request.separator.is_empty() {
url.push_str(&format!("&separator={}", request.separator));
}
if request.consistency == ConsistencyMode::Consistent {
url.push_str("&consistent");
} else if request.consistency == ConsistencyMode::Stale {
url.push_str("&stale");
}
if let Some(index) = request.index {
url.push_str(&format!("&index={}", index));
if request.wait.as_secs() > 0 {
url.push_str(&format!(
"&wait={}",
types::duration_as_string(&request.wait)
));
}
}
url = add_namespace_and_datacenter(url, request.namespace, request.datacenter);
req.uri(url)
}
async fn get_session(&self, request: LockRequest<'_>) -> Result<SessionResponse> {
let session_req = CreateSessionRequest {
lock_delay: request.lock_delay,
behavior: request.behavior,
ttl: request.timeout,
..Default::default()
};
let mut req = hyper::Request::builder().method(Method::PUT);
let mut url = String::new();
url.push_str(&format!("{}/v1/session/create?", self.config.address));
url = add_namespace_and_datacenter(url, request.namespace, request.datacenter);
req = req.uri(url);
let create_session_json =
serde_json::to_string(&session_req).map_err(ConsulError::InvalidRequest)?;
let (mut response_body, _index) = self
.execute_request(
req,
hyper::Body::from(create_session_json),
None,
GET_SESSION_METHOD_NAME,
)
.await?;
let bytes = response_body.copy_to_bytes(response_body.remaining());
serde_json::from_slice(&bytes).map_err(ConsulError::ResponseDeserializationFailed)
}
fn build_get_service_nodes_req(
&self,
request: GetServiceNodesRequest<'_>,
query_opts: &QueryOptions,
) -> http::request::Builder {
let req = hyper::Request::builder().method(Method::GET);
let mut url = String::new();
url.push_str(&format!(
"{}/v1/health/service/{}",
self.config.address, request.service
));
url.push_str(&format!("?passing={}", request.passing));
if let Some(near) = request.near {
url.push_str(&format!("&near={}", near));
}
if let Some(filter) = request.filter {
url.push_str(&format!("&filter={}", filter));
}
add_query_option_params(&mut url, query_opts, '&');
req.uri(url)
}
async fn execute_request<'a>(
&self,
req: http::request::Builder,
body: hyper::Body,
duration: Option<std::time::Duration>,
request_name: &str,
) -> Result<(Box<dyn Buf>, u64)> {
let req = req
.header(
"X-Consul-Token",
self.config.token.clone().unwrap_or_default(),
)
.body(body);
let req = req.map_err(ConsulError::RequestError)?;
#[cfg(feature = "trace")]
let mut span = crate::hyper_wrapper::span_for_request(&self.tracer, &req);
let method = req.method().clone();
record_request_metric_if_enabled(&method, request_name);
let future = self.https_client.request(req);
let step_start_instant = Instant::now();
let response = if let Some(dur) = duration {
match timeout(dur, future).await {
Ok(resp) => resp.map_err(ConsulError::ResponseError),
Err(_) => Err(ConsulError::TimeoutExceeded(dur)),
}
} else {
future.await.map_err(ConsulError::ResponseError)
};
record_duration_metric_if_enabled(
&method,
request_name,
step_start_instant.elapsed().as_millis() as f64,
);
if response.is_err() {
record_failure_metric_if_enabled(&method, request_name);
}
let response = response?;
#[cfg(feature = "trace")]
crate::hyper_wrapper::annotate_span_for_response(&mut span, &response);
let status = response.status();
if status != hyper::StatusCode::OK {
record_failure_metric_if_enabled(&method, request_name);
let mut response_body = hyper::body::aggregate(response.into_body())
.await
.map_err(|e| ConsulError::UnexpectedResponseCode(status, e.to_string()))?;
let bytes = response_body.copy_to_bytes(response_body.remaining());
let resp = std::str::from_utf8(&bytes)
.map_err(|e| ConsulError::UnexpectedResponseCode(status, e.to_string()))?;
return Err(ConsulError::UnexpectedResponseCode(
status,
resp.to_string(),
));
}
let index = match response.headers().get("x-consul-index") {
Some(header) => header.to_str().unwrap_or("0").parse::<u64>().unwrap_or(0),
None => 0,
};
match hyper::body::aggregate(response.into_body()).await {
Ok(body) => Ok((Box::new(body), index)),
Err(e) => {
record_failure_metric_if_enabled(&method, request_name);
#[cfg(feature = "trace")]
span.set_status(Status::error(e.to_string()));
Err(ConsulError::InvalidResponse(e))
}
}
}
fn build_create_or_update_url(&self, request: CreateOrUpdateKeyRequest<'_>) -> String {
let mut url = String::new();
url.push_str(&format!("{}/v1/kv/{}", self.config.address, request.key));
let mut added_query_param = false;
if request.flags != 0 {
url = add_query_param_separator(url, added_query_param);
url.push_str(&format!("flags={}", request.flags));
added_query_param = true;
}
if !request.acquire.is_empty() {
url = add_query_param_separator(url, added_query_param);
url.push_str(&format!("acquire={}", request.acquire));
added_query_param = true;
}
if !request.release.is_empty() {
url = add_query_param_separator(url, added_query_param);
url.push_str(&format!("release={}", request.release));
added_query_param = true;
}
if let Some(cas_idx) = request.check_and_set {
url = add_query_param_separator(url, added_query_param);
url.push_str(&format!("cas={}", cas_idx));
}
add_namespace_and_datacenter(url, request.namespace, request.datacenter)
}
}
/// Appends the optional `ns`, `dc`, `index` and `wait` parameters from `query_opts` to `uri`.
///
/// `separator` is written before the first appended parameter (`'?'` when `uri` has no query
/// yet, `'&'` otherwise). Later parameters are joined with `'&'`. Empty namespace and
/// datacenter strings are skipped, and `wait` is only sent together with `index`.
fn add_query_option_params(uri: &mut String, query_opts: &QueryOptions, mut separator: char) {
    if let Some(ns) = query_opts.namespace.as_ref().filter(|ns| !ns.is_empty()) {
        uri.push_str(&format!("{}ns={}", separator, ns));
        separator = '&';
    }
    if let Some(dc) = query_opts.datacenter.as_ref().filter(|dc| !dc.is_empty()) {
        uri.push_str(&format!("{}dc={}", separator, dc));
        separator = '&';
    }
    if let Some(idx) = query_opts.index {
        uri.push_str(&format!("{}index={}", separator, idx));
        if let Some(wait) = query_opts.wait {
            // `index` has just been written, so `wait` always follows with '&'.
            uri.push_str(&format!("&wait={}", types::duration_as_string(&wait)));
        }
    }
}
/// Appends `&ns=<namespace>` and `&dc=<datacenter>` to `url`, skipping empty values.
///
/// Always uses `&` as the separator, so `url` must already contain a `?` query component
/// whenever either value is non-empty.
fn add_namespace_and_datacenter(mut url: String, namespace: &str, datacenter: &str) -> String {
    if !namespace.is_empty() {
        url.push_str("&ns=");
        url.push_str(namespace);
    }
    if !datacenter.is_empty() {
        url.push_str("&dc=");
        url.push_str(datacenter);
    }
    url
}
/// Appends the separator that introduces the next query parameter.
///
/// Writes `&` when a parameter has already been written (`already_added`), and `?` when this
/// is the first one.
fn add_query_param_separator(url: String, already_added: bool) -> String {
    let separator = if already_added { '&' } else { '?' };
    let mut extended = url;
    extended.push(separator);
    extended
}
/// Counts one request attempt in `consul_requests_total`, labelled by HTTP method and client
/// function. This is a no-op unless the `metrics` feature is enabled.
fn record_request_metric_if_enabled(_method: &Method, _function: &str) {
    #[cfg(feature = "metrics")]
    {
        let labels = [_method.as_str(), _function];
        CONSUL_REQUESTS_TOTAL.with_label_values(&labels).inc();
    }
}
/// Counts one failed request in `consul_requests_failed_total`, labelled by HTTP method and
/// client function. This is a no-op unless the `metrics` feature is enabled.
fn record_failure_metric_if_enabled(_method: &Method, _function: &str) {
    #[cfg(feature = "metrics")]
    {
        let labels = [_method.as_str(), _function];
        CONSUL_REQUESTS_FAILED_TOTAL.with_label_values(&labels).inc();
    }
}
/// Records one request latency (callers pass milliseconds) in
/// `consul_requests_duration_milliseconds`. This is a no-op unless the `metrics` feature is
/// enabled.
fn record_duration_metric_if_enabled(_method: &Method, _function: &str, _duration: f64) {
    #[cfg(feature = "metrics")]
    {
        let labels = [_method.as_str(), _function];
        CONSUL_REQUESTS_DURATION_MS
            .with_label_values(&labels)
            .observe(_duration);
    }
}
#[cfg(test)]
mod tests {
    // Integration tests: they need a reachable Consul agent, configured through
    // `CONSUL_HTTP_ADDR` / `CONSUL_HTTP_TOKEN` (see `get_client`). `get_services_nodes` also
    // expects three passing `test-service` instances in `dc1` to be pre-registered.
    use std::time::Duration;

    use tokio::time::sleep;

    use super::*;

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn create_and_read_key() {
        let consul = get_client();
        let key = "test/consul/read";
        let string_value = "This is a test";
        let res = create_or_update_key_value(&consul, key, string_value).await;
        assert_expected_result_with_index(res);
        let res = read_key(&consul, key).await;
        verify_single_value_matches(res, string_value);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_register_and_retrieve_services() {
        let consul = get_client();
        let new_service_name = "test-service-44";
        register_entity(&consul, new_service_name, "local").await;
        assert!(is_registered(&consul, new_service_name).await);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_deregister_and_retrieve_services() {
        let consul = get_client();
        let new_service_name = "test-service-45";
        let node_id = "local";
        register_entity(&consul, new_service_name, node_id).await;
        let payload = DeregisterEntityPayload {
            Node: Some(node_id.to_string()),
            Datacenter: None,
            CheckID: None,
            ServiceID: None,
            Namespace: None,
        };
        consul
            .deregister_entity(&payload)
            .await
            .expect("expected deregister_entity request to succeed");
        assert!(!is_registered(&consul, new_service_name).await);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn get_services_nodes() {
        let consul = get_client();
        let req = GetServiceNodesRequest {
            service: "nonexistent",
            passing: true,
            ..Default::default()
        };
        let ResponseMeta { response, .. } = consul.get_service_nodes(req, None).await.unwrap();
        assert_eq!(response.len(), 0);

        let req = GetServiceNodesRequest {
            service: "test-service",
            passing: true,
            ..Default::default()
        };
        let ResponseMeta { response, .. } = consul.get_service_nodes(req, None).await.unwrap();
        assert_eq!(response.len(), 3);
        let addresses: Vec<String> = response
            .iter()
            .map(|sn| sn.service.address.clone())
            .collect();
        let expected_addresses = vec![
            "1.1.1.1".to_string(),
            "2.2.2.2".to_string(),
            "3.3.3.3".to_string(),
        ];
        assert!(expected_addresses
            .iter()
            .all(|item| addresses.contains(item)));
        for sn in response.iter() {
            assert_eq!("dc1", sn.node.datacenter);
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn create_and_delete_key() {
        let consul = get_client();
        let key = "test/consul/again";
        let string_value = "This is a new test";
        let res = create_or_update_key_value(&consul, key, string_value).await;
        assert_expected_result_with_index(res);
        let res = delete_key(&consul, key).await;
        assert_expected_result(res);
        let res = read_key(&consul, key).await.unwrap_err();
        match res {
            ConsulError::UnexpectedResponseCode(code, _body) => {
                assert_eq!(code, hyper::http::StatusCode::NOT_FOUND)
            }
            _ => panic!(
                "Expected ConsulError::UnexpectedResponseCode, got {:#?}",
                res
            ),
        };
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn create_and_release_lock() {
        let consul = get_client();
        let key = "test/consul/lock";
        let string_value = "This is a lock test";
        let new_string_value = "This is a changed lock test";
        let req = LockRequest {
            key,
            behavior: LockExpirationBehavior::Release,
            lock_delay: std::time::Duration::from_secs(1),
            ..Default::default()
        };
        let session_id: String;
        {
            let mut lock = consul
                .get_lock(req, string_value.as_bytes())
                .await
                .expect("expected the first get_lock to succeed");
            let err = consul
                .get_lock(req, string_value.as_bytes())
                .await
                .expect_err("expected a second get_lock on a held key to fail");
            match err {
                ConsulError::LockAcquisitionFailure(_index) => (),
                _ => panic!(
                    "Expected ConsulError::LockAcquisitionFailure, got {:#?}",
                    err
                ),
            }
            session_id = lock.session_id.clone();
            // Released on drop at the end of this scope with the new value.
            lock.value = Some(new_string_value.as_bytes().to_vec());
        }
        // Outlast the 1s lock delay before reading and re-acquiring.
        sleep(Duration::from_secs(2)).await;
        let key_resp = read_key(&consul, key).await;
        verify_single_value_matches(key_resp, new_string_value);
        let req = LockRequest {
            key,
            behavior: LockExpirationBehavior::Delete,
            lock_delay: std::time::Duration::from_secs(1),
            session_id: &session_id,
            ..Default::default()
        };
        let res = consul.get_lock(req, string_value.as_bytes()).await;
        assert!(res.is_ok());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn create_and_watch_lock() {
        let consul = get_client();
        let key = "test/consul/watchedlock";
        let string_value = "This is a lock test";
        let req = LockRequest {
            key,
            behavior: LockExpirationBehavior::Release,
            lock_delay: std::time::Duration::from_secs(0),
            ..Default::default()
        };
        let lock = consul
            .get_lock(req, string_value.as_bytes())
            .await
            .expect("expected the first get_lock to succeed");
        let err = consul
            .get_lock(req, string_value.as_bytes())
            .await
            .expect_err("expected a second get_lock on a held key to fail");
        let start_index = match err {
            ConsulError::LockAcquisitionFailure(index) => index,
            _ => panic!(
                "Expected ConsulError::LockAcquisitionFailure, got {:#?}",
                err
            ),
        };
        assert!(start_index > 0);
        let watch_req = LockWatchRequest {
            key,
            consistency: ConsistencyMode::Consistent,
            index: Some(start_index),
            wait: Duration::from_secs(60),
            ..Default::default()
        };
        let res = consul.watch_lock(watch_req).await;
        assert!(res.is_ok());
        drop(lock);
        let res = consul.get_lock(req, string_value.as_bytes()).await;
        assert!(res.is_ok());
    }

    #[test]
    fn test_service_node_parsing() {
        let node = Node {
            id: "node".to_string(),
            node: "node".to_string(),
            address: "1.1.1.1".to_string(),
            datacenter: "datacenter".to_string(),
        };
        let service = Service {
            id: "node".to_string(),
            service: "node".to_string(),
            address: "2.2.2.2".to_string(),
            port: 32,
        };
        let empty_service = Service {
            id: "".to_string(),
            service: "".to_string(),
            address: "".to_string(),
            port: 32,
        };
        let sn = ServiceNode {
            node: node.clone(),
            service: service.clone(),
        };
        let (host, port) = Consul::parse_host_port_from_service_node_response(sn);
        assert_eq!(service.port, port);
        assert_eq!(service.address, host);
        let sn = ServiceNode {
            node: node.clone(),
            service: empty_service,
        };
        let (host, port) = Consul::parse_host_port_from_service_node_response(sn);
        assert_eq!(service.port, port);
        assert_eq!(node.address, host);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn properly_handle_check_and_set() {
        let consul = get_client();
        let key = "test/consul/proper_cas_handling";
        let string_value1 = "This is CAS test";
        let req = CreateOrUpdateKeyRequest {
            key,
            check_and_set: Some(0),
            ..Default::default()
        };
        let (set, _) = consul
            .create_or_update_key(req.clone(), string_value1.as_bytes().to_vec())
            .await
            .expect("failed to create key initially");
        assert!(set);
        let (value, mod_idx1) = get_single_key_value_with_index(&consul, key).await;
        assert_eq!(string_value1, &value.unwrap());

        // cas=0 only creates; the key now exists, so this write must be rejected.
        let string_value2 = "This is CAS test - not valid";
        let (set, _) = consul
            .create_or_update_key(req, string_value2.as_bytes().to_vec())
            .await
            .expect("failed to run subsequent create_or_update_key");
        assert!(!set);
        let (value, mod_idx2) = get_single_key_value_with_index(&consul, key).await;
        assert_eq!(string_value1, &value.unwrap());
        assert_eq!(mod_idx1, mod_idx2);

        let req = CreateOrUpdateKeyRequest {
            key,
            check_and_set: Some(mod_idx1),
            ..Default::default()
        };
        let string_value3 = "This is correct CAS updated";
        let (set, _) = consul
            .create_or_update_key(req, string_value3.as_bytes().to_vec())
            .await
            .expect("failed to run create_or_update_key with proper CAS value");
        assert!(set);
        let (value, mod_idx3) = get_single_key_value_with_index(&consul, key).await;
        assert_eq!(string_value3, &value.unwrap());
        assert_ne!(mod_idx1, mod_idx3);

        let req = CreateOrUpdateKeyRequest {
            key,
            check_and_set: None,
            ..Default::default()
        };
        let string_value4 = "This is non CAS update";
        let (set, _) = consul
            .create_or_update_key(req, string_value4.as_bytes().to_vec())
            .await
            .expect("failed to run create_or_update_key without CAS");
        assert!(set);
        let (value, mod_idx4) = get_single_key_value_with_index(&consul, key).await;
        assert_eq!(string_value4, &value.unwrap());
        assert_ne!(mod_idx3, mod_idx4);
    }

    /// Registers `service_name` on `node_id`, asserting it was not registered beforehand.
    async fn register_entity(consul: &Consul, service_name: &str, node_id: &str) {
        let ResponseMeta {
            response: service_names_before_register,
            ..
        } = consul
            .get_all_registered_service_names(None)
            .await
            .expect("expected get_registered_service_names request to succeed");
        assert!(!service_names_before_register
            .iter()
            .any(|name| name.as_str() == service_name));
        let payload = RegisterEntityPayload {
            ID: None,
            Node: node_id.to_string(),
            Address: "127.0.0.1".to_string(),
            Datacenter: None,
            TaggedAddresses: Default::default(),
            NodeMeta: Default::default(),
            Service: Some(RegisterEntityService {
                ID: None,
                Service: service_name.to_string(),
                Tags: vec![],
                TaggedAddresses: Default::default(),
                Meta: Default::default(),
                Port: Some(42424),
                Namespace: None,
            }),
            Check: None,
            SkipNodeUpdate: None,
        };
        consul
            .register_entity(&payload)
            .await
            .expect("expected register_entity request to succeed");
    }

    /// Returns whether `service_name` appears in the catalog's service list.
    async fn is_registered(consul: &Consul, service_name: &str) -> bool {
        let ResponseMeta {
            response: service_names_after_register,
            ..
        } = consul
            .get_all_registered_service_names(None)
            .await
            .expect("expected get_registered_service_names request to succeed");
        service_names_after_register
            .iter()
            .any(|name| name.as_str() == service_name)
    }

    fn get_client() -> Consul {
        let conf: Config = Config::from_env();
        Consul::new(conf)
    }

    async fn create_or_update_key_value(
        consul: &Consul,
        key: &str,
        value: &str,
    ) -> Result<(bool, u64)> {
        let req = CreateOrUpdateKeyRequest {
            key,
            ..Default::default()
        };
        consul
            .create_or_update_key(req, value.as_bytes().to_vec())
            .await
    }

    async fn read_key(consul: &Consul, key: &str) -> Result<Vec<ReadKeyResponse>> {
        let req = ReadKeyRequest {
            key,
            ..Default::default()
        };
        consul.read_key(req).await
    }

    async fn delete_key(consul: &Consul, key: &str) -> Result<bool> {
        let req = DeleteKeyRequest {
            key,
            ..Default::default()
        };
        consul.delete_key(req).await
    }

    fn assert_expected_result_with_index(res: Result<(bool, u64)>) {
        let (result, _index) = res.expect("expected the request to succeed");
        assert!(result);
    }

    fn assert_expected_result(res: Result<bool>) {
        assert!(res.expect("expected the request to succeed"));
    }

    /// Reads `key` and returns the first entry's decoded value and modify index.
    async fn get_single_key_value_with_index(consul: &Consul, key: &str) -> (Option<String>, i64) {
        let res = read_key(consul, key).await.expect("failed to read key");
        let r = res.into_iter().next().unwrap();
        (r.value, r.modify_index)
    }

    fn verify_single_value_matches(res: Result<Vec<ReadKeyResponse>>, value: &str) {
        let entries = res.expect("expected read_key to succeed");
        assert_eq!(entries.into_iter().next().unwrap().value.unwrap(), value)
    }
}