1use std::marker::PhantomData;
4use std::time::Duration;
5
6use serde::Deserialize;
7use serde::de::DeserializeOwned;
8use serde_json::{Map, Value};
9
10use crate::Client;
11use crate::error::Result;
12use crate::handles::{NestedProjection, Sort};
13use crate::query::{AsQuery, BoolBuilder, Root};
14
15pub trait FlussoDocument: DeserializeOwned {
23 const INDEX: &'static str;
25
26 const SCHEMA_HASH: &'static str;
28
29 fn physical_index() -> String {
34 format!("{}_{}", Self::INDEX, Self::SCHEMA_HASH)
35 }
36
37 fn query() -> Search<Self> {
43 Search::new(Self::INDEX, Self::SCHEMA_HASH)
44 }
45
46 fn get(
48 client: &Client,
49 id: impl std::fmt::Display,
50 ) -> impl std::future::Future<Output = Result<Option<Self>>> {
51 client.get_one::<Self>(Self::INDEX, Self::SCHEMA_HASH, id)
52 }
53}
54
55#[derive(Debug, Clone)]
66pub struct Search<T> {
67 index: String,
68 hash: String,
69 bool_query: BoolBuilder,
70 raw: Option<Value>,
71 sort: Vec<Sort>,
72 from: Option<u64>,
73 size: Option<u64>,
74 nested: Vec<NestedProjection>,
75 _marker: PhantomData<fn() -> T>,
76}
77
78impl<T> Search<T> {
79 pub fn new(index: impl Into<String>, hash: impl Into<String>) -> Self {
81 Self {
82 index: index.into(),
83 hash: hash.into(),
84 bool_query: BoolBuilder::default(),
85 raw: None,
86 sort: Vec::new(),
87 from: None,
88 size: None,
89 nested: Vec::new(),
90 _marker: PhantomData,
91 }
92 }
93
94 #[must_use]
97 pub fn query(mut self, query: impl AsQuery<Root>) -> Self {
98 if let Some(query) = query.into_query() {
99 self.bool_query.push_must(query.into_inner());
100 }
101 self
102 }
103
104 #[must_use]
107 pub fn filter(mut self, query: impl AsQuery<Root>) -> Self {
108 if let Some(query) = query.into_query() {
109 self.bool_query.push_filter(query.into_inner());
110 }
111 self
112 }
113
114 #[must_use]
116 pub fn must_not(mut self, query: impl AsQuery<Root>) -> Self {
117 if let Some(query) = query.into_query() {
118 self.bool_query.push_must_not(query.into_inner());
119 }
120 self
121 }
122
123 #[must_use]
125 pub fn should(mut self, query: impl AsQuery<Root>) -> Self {
126 if let Some(query) = query.into_query() {
127 self.bool_query.push_should(query.into_inner());
128 }
129 self
130 }
131
132 #[must_use]
133 pub fn sort(mut self, sort: Sort) -> Self {
134 self.sort.push(sort);
135 self
136 }
137
138 #[must_use]
140 pub fn from(mut self, from: u64) -> Self {
141 self.from = Some(from);
142 self
143 }
144
145 #[must_use]
147 pub fn size(mut self, size: u64) -> Self {
148 self.size = Some(size);
149 self
150 }
151
152 #[must_use]
156 pub fn raw(mut self, query: Value) -> Self {
157 self.raw = Some(query);
158 self
159 }
160
161 #[must_use]
165 pub fn filter_nested(mut self, projection: NestedProjection) -> Self {
166 self.nested.push(projection);
167 self
168 }
169
170 fn query_value(&self) -> Value {
174 match &self.raw {
175 Some(raw) => raw.clone(),
176 None if self.bool_query.is_empty() => crate::handles::match_all_value(),
177 None => self.bool_query.to_value(),
178 }
179 }
180
181 #[must_use]
184 pub fn body(&self) -> Value {
185 let query = self.query_value();
186
187 let query = if self.nested.is_empty() {
191 query
192 } else {
193 let mut bool_body = Map::new();
194 bool_body.insert("must".to_string(), Value::Array(vec![query]));
195 let shoulds = self.nested.iter().map(NestedProjection::to_value).collect();
196 bool_body.insert("should".to_string(), Value::Array(shoulds));
197 let mut outer = Map::new();
198 outer.insert("bool".to_string(), Value::Object(bool_body));
199 Value::Object(outer)
200 };
201
202 let mut root = Map::new();
203 root.insert("query".to_string(), query);
204 self.insert_page_params(&mut root);
205 Value::Object(root)
206 }
207
208 fn insert_page_params(&self, root: &mut Map<String, Value>) {
211 if !self.sort.is_empty() {
212 let keys = self.sort.iter().map(Sort::to_value).collect();
213 root.insert("sort".to_string(), Value::Array(keys));
214 }
215 if let Some(from) = self.from {
216 root.insert("from".to_string(), Value::from(from));
217 }
218 if let Some(size) = self.size {
219 root.insert("size".to_string(), Value::from(size));
220 }
221 }
222
223 #[must_use]
228 pub fn count_body(&self) -> Value {
229 let mut root = Map::new();
230 root.insert("query".to_string(), self.query_value());
231 Value::Object(root)
232 }
233
234 #[must_use]
240 pub fn ids_body(&self) -> Value {
241 let mut root = Map::new();
242 root.insert("query".to_string(), self.query_value());
243 self.insert_page_params(&mut root);
244 root.insert("_source".to_string(), Value::Bool(false));
245 Value::Object(root)
246 }
247
248 #[tracing::instrument(
254 name = "search.ids",
255 skip_all,
256 fields(index = %self.index, returned = tracing::field::Empty),
257 err,
258 )]
259 pub async fn ids(&self, client: &Client) -> Result<Vec<String>> {
260 let body = self.ids_body();
261 let response = client.search_at(&self.physical_index(), &body).await?;
262 let raw: RawIdsResponse = serde_json::from_value(response)?;
263 let ids: Vec<String> = raw.hits.hits.into_iter().map(|hit| hit.id).collect();
264 tracing::Span::current().record("returned", ids.len());
265 tracing::debug!(returned = ids.len(), "ids search completed");
266 Ok(ids)
267 }
268
269 pub(crate) fn physical_index(&self) -> String {
273 format!("{}_{}", self.index, self.hash)
274 }
275
276 pub(crate) fn nested_paths(&self) -> Vec<&str> {
279 self.nested.iter().map(NestedProjection::path).collect()
280 }
281
282 #[tracing::instrument(
287 name = "search.count",
288 skip_all,
289 fields(index = %self.index, count = tracing::field::Empty),
290 err,
291 )]
292 pub async fn count(&self, client: &Client) -> Result<u64> {
293 let body = self.count_body();
294 let response = client.count_at(&self.physical_index(), &body).await?;
295 let raw: RawCount = serde_json::from_value(response)?;
296 tracing::Span::current().record("count", raw.count);
297 tracing::debug!(count = raw.count, "count completed");
298 Ok(raw.count)
299 }
300}
301
302impl<T> Search<T>
303where
304 T: DeserializeOwned,
305{
306 #[tracing::instrument(
308 name = "search.send",
309 skip_all,
310 fields(
311 index = %self.index,
312 from = ?self.from,
313 size = ?self.size,
314 total = tracing::field::Empty,
315 took_ms = tracing::field::Empty,
316 ),
317 err,
318 )]
319 pub async fn send(&self, client: &Client) -> Result<SearchResponse<T>> {
320 let body = self.body();
321 let mut response = client.search_at(&self.physical_index(), &body).await?;
322 let paths = self.nested_paths();
323 if !paths.is_empty() {
324 merge_inner_hits(&mut response, &paths);
325 }
326 let page = SearchResponse::from_value(response)?;
327 let span = tracing::Span::current();
328 span.record("total", page.total);
329 span.record("took_ms", page.took.as_millis() as u64);
330 tracing::debug!(
331 total = page.total,
332 hits = page.hits.len(),
333 "search completed"
334 );
335 Ok(page)
336 }
337}
338
339pub(crate) fn merge_inner_hits(response: &mut Value, paths: &[&str]) {
342 let Some(hits) = response
343 .get_mut("hits")
344 .and_then(|hits| hits.get_mut("hits"))
345 .and_then(Value::as_array_mut)
346 else {
347 return;
348 };
349 for hit in hits {
350 let inner = match hit.get("inner_hits") {
351 Some(inner) => inner.clone(),
352 None => continue,
353 };
354 let Some(source) = hit.get_mut("_source").and_then(Value::as_object_mut) else {
355 continue;
356 };
357 for path in paths {
358 let subset: Vec<Value> = inner
359 .get(*path)
360 .and_then(|hit| hit.get("hits"))
361 .and_then(|hits| hits.get("hits"))
362 .and_then(Value::as_array)
363 .map(|hits| {
364 hits.iter()
365 .filter_map(|h| h.get("_source").cloned())
366 .collect()
367 })
368 .unwrap_or_default();
369 source.insert((*path).to_string(), Value::Array(subset));
370 }
371 }
372}
373
374#[derive(Debug)]
376pub struct SearchResponse<T> {
377 pub total: u64,
379 pub max_score: Option<f32>,
381 pub hits: Vec<Hit<T>>,
383 pub took: Duration,
385}
386
387impl<T> SearchResponse<T>
388where
389 T: DeserializeOwned,
390{
391 pub fn from_value(value: Value) -> Result<Self> {
393 let raw: RawResponse<T> = serde_json::from_value(value)?;
394 let hits = raw
395 .hits
396 .hits
397 .into_iter()
398 .map(|hit| Hit {
399 id: hit.id,
400 score: hit.score.unwrap_or(0.0),
401 source: hit.source,
402 })
403 .collect();
404 Ok(Self {
405 total: raw.hits.total.value,
406 max_score: raw.hits.max_score,
407 hits,
408 took: Duration::from_millis(raw.took),
409 })
410 }
411}
412
/// One decoded search hit.
///
/// Derives `Clone` and `PartialEq` (when `T` does) so hits can be duplicated
/// and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct Hit<T> {
    /// Document id (the hit's `_id`).
    pub id: String,
    /// Relevance score (the hit's `_score`), or `0.0` when the server sent none.
    pub score: f32,
    /// The hit's `_source`, decoded into `T`.
    pub source: T,
}
423
/// Wire shape of a search response, decoded by `SearchResponse::from_value`.
#[derive(Deserialize)]
struct RawResponse<T> {
    /// Execution time in milliseconds; defaults to 0 when the key is absent.
    #[serde(default)]
    took: u64,
    /// The `hits` section of the response.
    hits: RawHits<T>,
}

/// The `hits` object of a search response.
#[derive(Deserialize)]
struct RawHits<T> {
    // NOTE(review): `total` is required and must be an object with a `value`
    // key. Responses that omit `hits.total` or send it as a bare integer (e.g.
    // Elasticsearch with `track_total_hits: false` or `rest_total_hits_as_int`)
    // fail to deserialize — confirm no caller sends those options.
    /// Total-hits object; only its `value` is read.
    total: RawTotal,
    /// Highest score among the hits; `None` when absent or `null`.
    #[serde(default)]
    max_score: Option<f32>,
    /// The individual hits of this page.
    hits: Vec<RawHit<T>>,
}

/// The `hits.total` object. Only `value` is read; other keys are ignored.
#[derive(Deserialize)]
struct RawTotal {
    /// Reported total hit count (copied into `SearchResponse::total`).
    value: u64,
}
443
/// Body of a count response; only the `count` key is read.
#[derive(Deserialize)]
pub(crate) struct RawCount {
    /// Number of documents matching the count query.
    pub(crate) count: u64,
}
450
/// Minimal search-response shape used by `Search::ids`; only
/// `hits.hits[]._id` is read, so `_source` is not required.
#[derive(Deserialize)]
struct RawIdsResponse {
    /// The `hits` section of the response.
    hits: RawIdsHits,
}

/// The `hits` object, reduced to the list of hits.
#[derive(Deserialize)]
struct RawIdsHits {
    /// The individual hits, in response order.
    hits: Vec<RawIdHit>,
}

/// A single hit, reduced to its document id.
#[derive(Deserialize)]
struct RawIdHit {
    /// Document id, from `_id`.
    #[serde(rename = "_id")]
    id: String,
}
468
/// One entry of `hits.hits` in a search response.
#[derive(Deserialize)]
struct RawHit<T> {
    /// Document id, from `_id`.
    #[serde(rename = "_id")]
    id: String,
    /// Relevance score from `_score`; `None` when absent or `null`.
    #[serde(rename = "_score", default)]
    score: Option<f32>,
    /// Document body from `_source`, decoded as `T`. A hit without `_source`
    /// fails to deserialize unless `T` accepts a missing value (e.g. `Option<_>`).
    #[serde(rename = "_source")]
    source: T,
}