1use reqwest::StatusCode;
5use serde::Deserialize;
6use serde::de::DeserializeOwned;
7use serde_json::Value;
8use url::Url;
9
10use crate::error::{Error, Result};
11
12#[derive(Debug, Clone)]
18pub struct Client {
19 http: reqwest::Client,
20 base: String,
22 auth: Option<(String, String)>,
23}
24
25impl Client {
26 pub fn connect(url: impl AsRef<str>) -> Result<Self> {
29 let raw = url.as_ref();
30 let parsed = Url::parse(raw).map_err(|error| Error::Url(format!("{raw}: {error}")))?;
31 match parsed.scheme() {
32 "http" | "https" => {}
33 other => return Err(Error::Url(format!("unsupported scheme `{other}` in {raw}"))),
34 }
35 let http = reqwest::Client::builder().build()?;
36 Ok(Self {
37 http,
38 base: raw.trim_end_matches('/').to_string(),
39 auth: None,
40 })
41 }
42
43 #[must_use]
45 pub fn basic_auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
46 self.auth = Some((username.into(), password.into()));
47 self
48 }
49
50 fn authed(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
52 match &self.auth {
53 Some((user, pass)) => builder.basic_auth(user, Some(pass)),
54 None => builder,
55 }
56 }
57
58 #[tracing::instrument(
63 name = "search.request",
64 level = "debug",
65 skip_all,
66 fields(path, status = tracing::field::Empty),
67 err,
68 )]
69 pub(crate) async fn search_at(&self, path: &str, body: &Value) -> Result<Value> {
70 let endpoint = format!("{}/{path}/_search", self.base);
71 tracing::debug!(%endpoint, "POST _search");
72 self.post_json(&endpoint, body).await
73 }
74
75 #[tracing::instrument(
79 name = "count.request",
80 level = "debug",
81 skip_all,
82 fields(path, status = tracing::field::Empty),
83 err,
84 )]
85 pub(crate) async fn count_at(&self, path: &str, body: &Value) -> Result<Value> {
86 let endpoint = format!("{}/{path}/_count", self.base);
87 tracing::debug!(%endpoint, "POST _count");
88 self.post_json(&endpoint, body).await
89 }
90
91 #[tracing::instrument(
95 name = "msearch.request",
96 level = "debug",
97 skip_all,
98 fields(bytes = ndjson.len(), status = tracing::field::Empty),
99 err,
100 )]
101 pub(crate) async fn msearch_raw(&self, ndjson: String) -> Result<Value> {
102 let endpoint = format!("{}/_msearch", self.base);
103 tracing::debug!(%endpoint, "POST _msearch");
104 let builder = self
105 .http
106 .post(&endpoint)
107 .header("Content-Type", "application/x-ndjson")
108 .body(ndjson);
109 self.execute_json(builder).await
110 }
111
112 async fn post_json(&self, endpoint: &str, body: &Value) -> Result<Value> {
115 self.execute_json(self.http.post(endpoint).json(body)).await
116 }
117
118 async fn execute_json(&self, builder: reqwest::RequestBuilder) -> Result<Value> {
121 let response = self.authed(builder).send().await?;
122 let status = response.status();
123 tracing::Span::current().record("status", status.as_u16());
124 if !status.is_success() {
125 return Err(Error::Status {
126 status: status.as_u16(),
127 body: response.text().await.unwrap_or_default(),
128 });
129 }
130 Ok(response.json::<Value>().await?)
131 }
132
133 #[tracing::instrument(
139 name = "search.get",
140 level = "debug",
141 skip_all,
142 fields(index, hash, id = %id, status = tracing::field::Empty),
143 err,
144 )]
145 pub async fn get_one<T>(
146 &self,
147 index: &str,
148 hash: &str,
149 id: impl std::fmt::Display,
150 ) -> Result<Option<T>>
151 where
152 T: DeserializeOwned,
153 {
154 let endpoint = format!("{}/{index}_{hash}/_doc/{id}", self.base);
155 tracing::debug!(%endpoint, "GET _doc");
156 let response = self.authed(self.http.get(&endpoint)).send().await?;
157 let status = response.status();
158 tracing::Span::current().record("status", status.as_u16());
159 if status == StatusCode::NOT_FOUND {
160 return Ok(None);
161 }
162 if !status.is_success() {
163 return Err(Error::Status {
164 status: status.as_u16(),
165 body: response.text().await.unwrap_or_default(),
166 });
167 }
168 let doc: GetResponse<T> = response.json().await?;
169 match (doc.found, doc.source) {
170 (true, Some(source)) => Ok(Some(source)),
171 _ => Ok(None),
172 }
173 }
174}
175
176#[derive(Deserialize)]
177struct GetResponse<T> {
178 #[serde(default)]
179 found: bool,
180 #[serde(rename = "_source", default = "none")]
181 source: Option<T>,
182}
183
/// Serde default for optional fields of generic type: always `None`.
///
/// `Option<T>` implements `Default` for every `T`, so this function adds no `T: Default`
/// bound. A bare `#[serde(default)]` on a generic field would cause the derive to add one.
fn none<T>() -> Option<T> {
    Option::default()
}