use futures_util::TryFutureExt;
use reqwest::{Client as HttpClient, RequestBuilder, Response as HttpResponse};
use std::collections::{BTreeMap, HashMap};
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use crate::query::QueryType;
use crate::{Error, Query};
/// Handle for talking to one database on an InfluxDB server over HTTP.
///
/// Cloning is cheap for the URL and parameter map, which are shared through [`Arc`].
#[derive(Clone)]
pub struct Client {
    /// Underlying HTTP client used to send every request.
    pub(crate) client: HttpClient,
    /// Base server URL; endpoint paths such as `/ping`, `/query` and `/write` are appended to it.
    pub(crate) url: Arc<String>,
    /// Optional token; when set, queries carry an `Authorization: Token <token>` header.
    pub(crate) token: Option<String>,
    /// Query-string parameters attached to each request (`db`, plus `u`/`p` after `with_auth`).
    pub(crate) parameters: Arc<HashMap<&'static str, String>>,
}
/// `Debug` adapter for a client parameter map that masks the password.
///
/// Entries are printed in key order. The value stored under `"p"` is replaced by
/// `<redacted>`, and every other value is printed as-is.
struct RedactPassword<'a>(&'a HashMap<&'static str, String>);

impl Debug for RedactPassword<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Collect into a BTreeMap so the output order is deterministic, unlike HashMap iteration.
        let mut sorted: BTreeMap<&'static str, &str> = BTreeMap::new();
        for (&key, value) in self.0 {
            let shown = if key == "p" { "<redacted>" } else { value.as_str() };
            sorted.insert(key, shown);
        }
        f.debug_map().entries(sorted).finish()
    }
}
impl Debug for Client {
    /// Prints the URL and the parameter map, with the password masked.
    ///
    /// The token and the HTTP client are omitted; the output ends with `..` to show
    /// that not every field is listed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let masked = RedactPassword(&self.parameters);
        let mut out = f.debug_struct("Client");
        out.field("url", &self.url);
        out.field("parameters", &masked);
        out.finish_non_exhaustive()
    }
}
impl Client {
#[must_use = "Creating a client is pointless unless you use it"]
pub fn new<S1, S2>(url: S1, database: S2) -> Self
where
S1: Into<String>,
S2: Into<String>,
{
let mut parameters = HashMap::<&str, String>::new();
parameters.insert("db", database.into());
Client {
url: Arc::new(url.into()),
parameters: Arc::new(parameters),
client: HttpClient::new(),
token: None,
}
}
#[must_use = "Creating a client is pointless unless you use it"]
pub fn with_auth<S1, S2>(mut self, username: S1, password: S2) -> Self
where
S1: Into<String>,
S2: Into<String>,
{
let mut with_auth = self.parameters.as_ref().clone();
with_auth.insert("u", username.into());
with_auth.insert("p", password.into());
self.parameters = Arc::new(with_auth);
self
}
#[must_use = "Creating a client is pointless unless you use it"]
pub fn with_http_client(mut self, http_client: HttpClient) -> Self {
self.client = http_client;
self
}
pub fn with_token<S>(mut self, token: S) -> Self
where
S: Into<String>,
{
self.token = Some(token.into());
self
}
pub fn database_name(&self) -> &str {
self.parameters.get("db").unwrap()
}
pub fn database_url(&self) -> &str {
&self.url
}
pub async fn ping(&self) -> Result<(String, String), Error> {
let url = &format!("{}/ping", self.url);
let res = self
.client
.get(url)
.send()
.await
.map_err(|err| Error::ProtocolError {
error: err.to_string(),
})?;
const BUILD_HEADER: &str = "X-Influxdb-Build";
const VERSION_HEADER: &str = "X-Influxdb-Version";
let (build, version) = {
let hdrs = res.headers();
(
hdrs.get(BUILD_HEADER).and_then(|value| value.to_str().ok()),
hdrs.get(VERSION_HEADER)
.and_then(|value| value.to_str().ok()),
)
};
Ok((build.unwrap().to_owned(), version.unwrap().to_owned()))
}
pub async fn query<Q>(&self, q: Q) -> Result<String, Error>
where
Q: Query,
{
let query = q.build().map_err(|err| Error::InvalidQueryError {
error: err.to_string(),
})?;
let mut parameters = self.parameters.as_ref().clone();
let request_builder = match q.get_type() {
QueryType::ReadQuery => {
let read_query = query.get();
let url = &format!("{}/query", &self.url);
parameters.insert("q", read_query.clone());
if read_query.contains("SELECT") || read_query.contains("SHOW") {
self.client.get(url).query(¶meters)
} else {
self.client.post(url).query(¶meters)
}
}
QueryType::WriteQuery(precision) => {
let url = &format!("{}/write", &self.url);
let mut parameters = self.parameters.as_ref().clone();
parameters.insert("precision", precision);
self.client.post(url).body(query.get()).query(¶meters)
}
};
let res = self
.auth_if_needed(request_builder)
.send()
.map_err(|err| Error::ConnectionError {
error: err.to_string(),
})
.await?;
check_status(&res)?;
let body = res.text();
let s = body.await.map_err(|_| Error::DeserializationError {
error: "response could not be converted to UTF-8".into(),
})?;
if s.contains("\"error\"") || s.contains("\"Error\"") {
return Err(Error::DatabaseError {
error: format!("influxdb error: {s:?}"),
});
}
Ok(s)
}
fn auth_if_needed(&self, rb: RequestBuilder) -> RequestBuilder {
if let Some(ref token) = self.token {
rb.header("Authorization", format!("Token {token}"))
} else {
rb
}
}
}
pub(crate) fn check_status(res: &HttpResponse) -> Result<(), Error> {
let status = res.status();
if !status.is_success() {
return Err(Error::ApiError(status.into()));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::Client;
    use indoc::indoc;

    #[test]
    fn test_client_debug_redacted_password() {
        let client = Client::new("https://localhost:8086", "db").with_auth("user", "pass");
        let actual = format!("{client:#?}");
        // Relative indentation inside this literal must match the pretty-printed `{:#?}` output;
        // indoc strips only the common leading whitespace.
        let expected = indoc! { r#"
            Client {
                url: "https://localhost:8086",
                parameters: {
                    "db": "db",
                    "p": "<redacted>",
                    "u": "user",
                },
                ..
            }
        "# };
        assert_eq!(actual.trim(), expected.trim());
    }

    #[test]
    fn test_client_debug_hides_secrets() {
        // Neither the password nor the token may appear in Debug output.
        let with_password =
            Client::new("http://localhost:8068", "database").with_auth("username", "hunter2");
        assert!(!format!("{with_password:?}").contains("hunter2"));
        assert!(!format!("{with_password:#?}").contains("hunter2"));

        let with_token = Client::new("http://localhost:8068", "database").with_token("s3cr3t-token");
        assert!(!format!("{with_token:?}").contains("s3cr3t-token"));
        assert!(!format!("{with_token:#?}").contains("s3cr3t-token"));
    }

    #[test]
    fn test_fn_database() {
        let client = Client::new("http://localhost:8068", "database");
        assert_eq!(client.database_name(), "database");
        assert_eq!(client.database_url(), "http://localhost:8068");
    }

    #[test]
    fn test_with_auth() {
        let client = Client::new("http://localhost:8068", "database");
        assert_eq!(client.parameters.len(), 1);
        assert_eq!(client.parameters.get("db").unwrap(), "database");

        let with_auth = client.with_auth("username", "password");
        assert_eq!(with_auth.parameters.len(), 3);
        assert_eq!(with_auth.parameters.get("db").unwrap(), "database");
        assert_eq!(with_auth.parameters.get("u").unwrap(), "username");
        assert_eq!(with_auth.parameters.get("p").unwrap(), "password");

        let client = Client::new("http://localhost:8068", "database");
        let with_auth = client.with_token("token");
        assert_eq!(with_auth.parameters.len(), 1);
        assert_eq!(with_auth.parameters.get("db").unwrap(), "database");
        assert_eq!(with_auth.token.unwrap(), "token");
    }

    #[test]
    fn test_with_http_client_keeps_settings() {
        // Swapping the HTTP client must not drop previously configured parameters or token.
        let client = Client::new("http://localhost:8068", "database")
            .with_auth("username", "password")
            .with_token("token")
            .with_http_client(reqwest::Client::new());
        assert_eq!(client.database_name(), "database");
        assert_eq!(client.parameters.len(), 3);
        assert_eq!(client.token.as_deref(), Some("token"));
    }
}