1use reqwest::StatusCode;
5use serde::Deserialize;
6use serde::de::DeserializeOwned;
7use serde_json::Value;
8use url::Url;
9
10use crate::error::{Error, Result};
11
12#[derive(Debug, Clone)]
18pub struct Client {
19 http: reqwest::Client,
20 base: String,
22 auth: Option<(String, String)>,
23 pub(crate) index_prefix: String,
28}
29
30impl Client {
31 pub fn connect(url: impl AsRef<str>) -> Result<Self> {
34 let raw = url.as_ref();
35 let parsed = Url::parse(raw).map_err(|error| Error::Url(format!("{raw}: {error}")))?;
36 match parsed.scheme() {
37 "http" | "https" => {}
38 other => return Err(Error::Url(format!("unsupported scheme `{other}` in {raw}"))),
39 }
40 let http = reqwest::Client::builder().build()?;
41 Ok(Self {
42 http,
43 base: raw.trim_end_matches('/').to_string(),
44 auth: None,
45 index_prefix: String::new(),
46 })
47 }
48
49 #[must_use]
51 pub fn basic_auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
52 self.auth = Some((username.into(), password.into()));
53 self
54 }
55
56 #[must_use]
61 pub fn index_prefix(mut self, prefix: impl Into<String>) -> Self {
62 self.index_prefix = prefix.into();
63 self
64 }
65
66 pub(crate) fn prefixed(&self, path: &str) -> String {
70 if self.index_prefix.is_empty() {
71 return path.to_owned();
72 }
73 path.split(',')
74 .map(|segment| format!("{}{segment}", self.index_prefix))
75 .collect::<Vec<_>>()
76 .join(",")
77 }
78
79 fn authed(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
81 match &self.auth {
82 Some((user, pass)) => builder.basic_auth(user, Some(pass)),
83 None => builder,
84 }
85 }
86
87 #[tracing::instrument(
92 name = "search.request",
93 level = "debug",
94 skip_all,
95 fields(path, status = tracing::field::Empty),
96 err,
97 )]
98 pub(crate) async fn search_at(&self, path: &str, body: &Value) -> Result<Value> {
99 let endpoint = format!("{}/{}/_search", self.base, self.prefixed(path));
100 tracing::debug!(%endpoint, query = %body, "POST _search");
101 self.post_json(&endpoint, body).await
102 }
103
104 #[tracing::instrument(
108 name = "count.request",
109 level = "debug",
110 skip_all,
111 fields(path, status = tracing::field::Empty),
112 err,
113 )]
114 pub(crate) async fn count_at(&self, path: &str, body: &Value) -> Result<Value> {
115 let endpoint = format!("{}/{}/_count", self.base, self.prefixed(path));
116 tracing::debug!(%endpoint, query = %body, "POST _count");
117 self.post_json(&endpoint, body).await
118 }
119
120 #[tracing::instrument(
124 name = "msearch.request",
125 level = "debug",
126 skip_all,
127 fields(bytes = ndjson.len(), status = tracing::field::Empty),
128 err,
129 )]
130 pub(crate) async fn msearch_raw(&self, ndjson: String) -> Result<Value> {
131 let endpoint = format!("{}/_msearch", self.base);
132 tracing::debug!(%endpoint, query = %ndjson, "POST _msearch");
133 let builder = self
134 .http
135 .post(&endpoint)
136 .header("Content-Type", "application/x-ndjson")
137 .body(ndjson);
138 self.execute_json(builder).await
139 }
140
141 async fn post_json(&self, endpoint: &str, body: &Value) -> Result<Value> {
144 self.execute_json(self.http.post(endpoint).json(body)).await
145 }
146
147 async fn execute_json(&self, builder: reqwest::RequestBuilder) -> Result<Value> {
150 let response = self.authed(builder).send().await?;
151 let status = response.status();
152 tracing::Span::current().record("status", status.as_u16());
153 if !status.is_success() {
154 return Err(Error::Status {
155 status: status.as_u16(),
156 body: response.text().await.unwrap_or_default(),
157 });
158 }
159 Ok(response.json::<Value>().await?)
160 }
161
162 #[tracing::instrument(
168 name = "search.get",
169 level = "debug",
170 skip_all,
171 fields(index, hash, id = %id, status = tracing::field::Empty),
172 err,
173 )]
174 pub async fn get_one<T>(
175 &self,
176 index: &str,
177 hash: &str,
178 id: impl std::fmt::Display,
179 ) -> Result<Option<T>>
180 where
181 T: DeserializeOwned,
182 {
183 let endpoint = format!(
184 "{}/{}/_doc/{id}",
185 self.base,
186 self.prefixed(&format!("{index}_{hash}"))
187 );
188 tracing::debug!(%endpoint, "GET _doc");
189 let response = self.authed(self.http.get(&endpoint)).send().await?;
190 let status = response.status();
191 tracing::Span::current().record("status", status.as_u16());
192 if status == StatusCode::NOT_FOUND {
193 tracing::debug!(found = false, "GET _doc resolved");
194 return Ok(None);
195 }
196 if !status.is_success() {
197 return Err(Error::Status {
198 status: status.as_u16(),
199 body: response.text().await.unwrap_or_default(),
200 });
201 }
202 let doc: GetResponse<T> = response.json().await?;
203 match (doc.found, doc.source) {
204 (true, Some(source)) => {
205 tracing::debug!(found = true, "GET _doc resolved");
206 Ok(Some(source))
207 }
208 _ => {
209 tracing::debug!(found = false, "GET _doc resolved");
210 Ok(None)
211 }
212 }
213 }
214}
215
216#[derive(Deserialize)]
217struct GetResponse<T> {
218 #[serde(default)]
219 found: bool,
220 #[serde(rename = "_source", default = "none")]
221 source: Option<T>,
222}
223
/// Serde default helper that always yields an empty `Option`.
///
/// It places no bound on the element type.
fn none<U>() -> Option<U> {
    Option::None
}