use reqwest::StatusCode;
use serde::Deserialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use url::Url;
use crate::error::{Error, Result};
/// HTTP client for a search backend reachable over `http`/`https`.
///
/// Build one with [`Client::connect`] and optionally attach credentials with
/// [`Client::basic_auth`].
#[derive(Clone)]
pub struct Client {
    /// Underlying HTTP client used for every request.
    http: reqwest::Client,
    /// Base URL (trailing `/` stripped by `connect`); endpoint paths are appended to it.
    base: String,
    /// Optional HTTP basic-auth credentials as `(username, password)`.
    auth: Option<(String, String)>,
}

// Hand-written instead of derived so the basic-auth password never appears in `{:?}`
// output (log lines, panic messages, error reports). The username is kept for diagnostics.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("http", &self.http)
            .field("base", &self.base)
            .field(
                "auth",
                &self.auth.as_ref().map(|(user, _)| (user, "<redacted>")),
            )
            .finish()
    }
}
impl Client {
pub fn connect(url: impl AsRef<str>) -> Result<Self> {
let raw = url.as_ref();
let parsed = Url::parse(raw).map_err(|error| Error::Url(format!("{raw}: {error}")))?;
match parsed.scheme() {
"http" | "https" => {}
other => return Err(Error::Url(format!("unsupported scheme `{other}` in {raw}"))),
}
let http = reqwest::Client::builder().build()?;
Ok(Self {
http,
base: raw.trim_end_matches('/').to_string(),
auth: None,
})
}
#[must_use]
pub fn basic_auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
self.auth = Some((username.into(), password.into()));
self
}
fn authed(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
match &self.auth {
Some((user, pass)) => builder.basic_auth(user, Some(pass)),
None => builder,
}
}
#[tracing::instrument(
name = "search.request",
level = "debug",
skip_all,
fields(path, status = tracing::field::Empty),
err,
)]
pub(crate) async fn search_at(&self, path: &str, body: &Value) -> Result<Value> {
let endpoint = format!("{}/{path}/_search", self.base);
tracing::debug!(%endpoint, "POST _search");
self.post_json(&endpoint, body).await
}
#[tracing::instrument(
name = "count.request",
level = "debug",
skip_all,
fields(path, status = tracing::field::Empty),
err,
)]
pub(crate) async fn count_at(&self, path: &str, body: &Value) -> Result<Value> {
let endpoint = format!("{}/{path}/_count", self.base);
tracing::debug!(%endpoint, "POST _count");
self.post_json(&endpoint, body).await
}
#[tracing::instrument(
name = "msearch.request",
level = "debug",
skip_all,
fields(bytes = ndjson.len(), status = tracing::field::Empty),
err,
)]
pub(crate) async fn msearch_raw(&self, ndjson: String) -> Result<Value> {
let endpoint = format!("{}/_msearch", self.base);
tracing::debug!(%endpoint, "POST _msearch");
let builder = self
.http
.post(&endpoint)
.header("Content-Type", "application/x-ndjson")
.body(ndjson);
self.execute_json(builder).await
}
async fn post_json(&self, endpoint: &str, body: &Value) -> Result<Value> {
self.execute_json(self.http.post(endpoint).json(body)).await
}
async fn execute_json(&self, builder: reqwest::RequestBuilder) -> Result<Value> {
let response = self.authed(builder).send().await?;
let status = response.status();
tracing::Span::current().record("status", status.as_u16());
if !status.is_success() {
return Err(Error::Status {
status: status.as_u16(),
body: response.text().await.unwrap_or_default(),
});
}
Ok(response.json::<Value>().await?)
}
#[tracing::instrument(
name = "search.get",
level = "debug",
skip_all,
fields(index, hash, id = %id, status = tracing::field::Empty),
err,
)]
pub async fn get_one<T>(
&self,
index: &str,
hash: &str,
id: impl std::fmt::Display,
) -> Result<Option<T>>
where
T: DeserializeOwned,
{
let endpoint = format!("{}/{index}_{hash}/_doc/{id}", self.base);
tracing::debug!(%endpoint, "GET _doc");
let response = self.authed(self.http.get(&endpoint)).send().await?;
let status = response.status();
tracing::Span::current().record("status", status.as_u16());
if status == StatusCode::NOT_FOUND {
return Ok(None);
}
if !status.is_success() {
return Err(Error::Status {
status: status.as_u16(),
body: response.text().await.unwrap_or_default(),
});
}
let doc: GetResponse<T> = response.json().await?;
match (doc.found, doc.source) {
(true, Some(source)) => Ok(Some(source)),
_ => Ok(None),
}
}
}
/// The parts of a single-document GET reply that [`Client::get_one`] inspects.
#[derive(Deserialize)]
struct GetResponse<T> {
    /// Whether the document exists; `false` when the key is missing from the reply.
    #[serde(default)]
    found: bool,
    /// The document body, read from the `_source` key; `None` when that key is missing.
    #[serde(rename = "_source")]
    #[serde(default = "none")]
    source: Option<T>,
}
/// Default-value provider for serde's `default = "none"`; always produces an empty `Option`.
fn none<T>() -> Option<T> {
    Option::default()
}