1use std::marker::PhantomData;
4use std::time::Duration;
5
6use serde::Deserialize;
7use serde::de::DeserializeOwned;
8use serde_json::{Map, Value};
9
10use crate::Client;
11use crate::error::Result;
12use crate::handles::{NestedProjection, Sort};
13use crate::query::{AsQuery, BoolBuilder, Root};
14
15pub trait FlussoDocument: DeserializeOwned {
23 const INDEX: &'static str;
25
26 const SCHEMA_HASH: &'static str;
28
29 fn physical_index() -> String {
34 format!("{}_{}", Self::INDEX, Self::SCHEMA_HASH)
35 }
36
37 fn query() -> Search<Self> {
43 Search::new(Self::INDEX, Self::SCHEMA_HASH)
44 }
45
46 fn get(
48 client: &Client,
49 id: impl std::fmt::Display,
50 ) -> impl std::future::Future<Output = Result<Option<Self>>> {
51 client.get_one::<Self>(Self::INDEX, Self::SCHEMA_HASH, id)
52 }
53}
54
55#[derive(Debug, Clone)]
66pub struct Search<T> {
67 index: String,
68 hash: String,
69 bool_query: BoolBuilder,
70 raw: Option<Value>,
71 sort: Vec<Sort>,
72 from: Option<u64>,
73 size: Option<u64>,
74 nested: Vec<NestedProjection>,
75 min_score: Option<f32>,
76 track_total_hits: Option<Value>,
77 track_scores: Option<bool>,
78 search_after: Option<Vec<Value>>,
79 collapse: Option<Value>,
80 post_filter: Option<Value>,
81 highlight: Option<Highlight>,
82 _marker: PhantomData<fn() -> T>,
83}
84
85impl<T> Search<T> {
86 pub fn new(index: impl Into<String>, hash: impl Into<String>) -> Self {
88 Self {
89 index: index.into(),
90 hash: hash.into(),
91 bool_query: BoolBuilder::default(),
92 raw: None,
93 sort: Vec::new(),
94 from: None,
95 size: None,
96 nested: Vec::new(),
97 min_score: None,
98 track_total_hits: None,
99 track_scores: None,
100 search_after: None,
101 collapse: None,
102 post_filter: None,
103 highlight: None,
104 _marker: PhantomData,
105 }
106 }
107
108 #[must_use]
111 pub fn query(mut self, query: impl AsQuery<Root>) -> Self {
112 if let Some(query) = query.into_query() {
113 self.bool_query.push_must(query.into_inner());
114 }
115 self
116 }
117
118 #[must_use]
121 pub fn filter(mut self, query: impl AsQuery<Root>) -> Self {
122 if let Some(query) = query.into_query() {
123 self.bool_query.push_filter(query.into_inner());
124 }
125 self
126 }
127
128 #[must_use]
130 pub fn must_not(mut self, query: impl AsQuery<Root>) -> Self {
131 if let Some(query) = query.into_query() {
132 self.bool_query.push_must_not(query.into_inner());
133 }
134 self
135 }
136
137 #[must_use]
139 pub fn should(mut self, query: impl AsQuery<Root>) -> Self {
140 if let Some(query) = query.into_query() {
141 self.bool_query.push_should(query.into_inner());
142 }
143 self
144 }
145
146 #[must_use]
151 pub fn min_should_match(mut self, value: impl Into<Value>) -> Self {
152 self.bool_query.set_min_should_match(value.into());
153 self
154 }
155
156 #[must_use]
157 pub fn sort(mut self, sort: Sort) -> Self {
158 self.sort.push(sort);
159 self
160 }
161
162 #[must_use]
164 pub fn min_score(mut self, min_score: f32) -> Self {
165 self.min_score = Some(min_score);
166 self
167 }
168
169 #[must_use]
172 pub fn track_total_hits(mut self, track: impl Into<Value>) -> Self {
173 self.track_total_hits = Some(track.into());
174 self
175 }
176
177 #[must_use]
179 pub fn track_scores(mut self, track: bool) -> Self {
180 self.track_scores = Some(track);
181 self
182 }
183
184 #[must_use]
187 pub fn search_after(mut self, values: impl IntoIterator<Item = impl Into<Value>>) -> Self {
188 self.search_after = Some(values.into_iter().map(Into::into).collect());
189 self
190 }
191
192 #[must_use]
194 pub fn collapse(mut self, field: impl Into<String>) -> Self {
195 let mut body = Map::new();
196 body.insert("field".to_string(), Value::String(field.into()));
197 self.collapse = Some(Value::Object(body));
198 self
199 }
200
201 #[must_use]
204 pub fn post_filter(mut self, query: impl AsQuery<Root>) -> Self {
205 if let Some(query) = query.into_query() {
206 self.post_filter = Some(query.to_value());
207 }
208 self
209 }
210
211 #[must_use]
213 pub fn highlight(mut self, highlight: Highlight) -> Self {
214 self.highlight = Some(highlight);
215 self
216 }
217
218 #[must_use]
220 pub fn from(mut self, from: u64) -> Self {
221 self.from = Some(from);
222 self
223 }
224
225 #[must_use]
227 pub fn size(mut self, size: u64) -> Self {
228 self.size = Some(size);
229 self
230 }
231
232 #[must_use]
236 pub fn raw(mut self, query: Value) -> Self {
237 self.raw = Some(query);
238 self
239 }
240
241 #[must_use]
245 pub fn filter_nested(mut self, projection: NestedProjection) -> Self {
246 self.nested.push(projection);
247 self
248 }
249
250 fn query_value(&self) -> Value {
254 match &self.raw {
255 Some(raw) => raw.clone(),
256 None if self.bool_query.is_empty() => crate::handles::match_all_value(),
257 None => self.bool_query.to_value(),
258 }
259 }
260
261 #[must_use]
264 pub fn body(&self) -> Value {
265 let query = self.query_value();
266
267 let query = if self.nested.is_empty() {
271 query
272 } else {
273 let mut bool_body = Map::new();
274 bool_body.insert("must".to_string(), Value::Array(vec![query]));
275 let shoulds = self.nested.iter().map(NestedProjection::to_value).collect();
276 bool_body.insert("should".to_string(), Value::Array(shoulds));
277 let mut outer = Map::new();
278 outer.insert("bool".to_string(), Value::Object(bool_body));
279 Value::Object(outer)
280 };
281
282 let mut root = Map::new();
283 root.insert("query".to_string(), query);
284 self.insert_page_params(&mut root);
285 self.insert_search_level(&mut root, true);
286 Value::Object(root)
287 }
288
289 fn insert_page_params(&self, root: &mut Map<String, Value>) {
292 if !self.sort.is_empty() {
293 let keys = self.sort.iter().map(Sort::to_value).collect();
294 root.insert("sort".to_string(), Value::Array(keys));
295 }
296 if let Some(from) = self.from {
297 root.insert("from".to_string(), Value::from(from));
298 }
299 if let Some(size) = self.size {
300 root.insert("size".to_string(), Value::from(size));
301 }
302 }
303
304 fn insert_search_level(&self, root: &mut Map<String, Value>, with_highlight: bool) {
311 if let Some(min_score) = self.min_score {
312 root.insert("min_score".to_string(), Value::from(min_score));
313 }
314 if let Some(track) = &self.track_total_hits {
315 root.insert("track_total_hits".to_string(), track.clone());
316 }
317 if let Some(track) = self.track_scores {
318 root.insert("track_scores".to_string(), Value::Bool(track));
319 }
320 if let Some(values) = &self.search_after {
321 root.insert("search_after".to_string(), Value::Array(values.clone()));
322 }
323 if let Some(collapse) = &self.collapse {
324 root.insert("collapse".to_string(), collapse.clone());
325 }
326 if let Some(post_filter) = &self.post_filter {
327 root.insert("post_filter".to_string(), post_filter.clone());
328 }
329 if with_highlight && let Some(highlight) = &self.highlight {
330 root.insert("highlight".to_string(), highlight.to_value());
331 }
332 }
333
334 #[must_use]
339 pub fn count_body(&self) -> Value {
340 let mut root = Map::new();
341 root.insert("query".to_string(), self.query_value());
342 Value::Object(root)
343 }
344
345 #[must_use]
351 pub fn ids_body(&self) -> Value {
352 let mut root = Map::new();
353 root.insert("query".to_string(), self.query_value());
354 self.insert_page_params(&mut root);
355 self.insert_search_level(&mut root, false);
356 root.insert("_source".to_string(), Value::Bool(false));
357 Value::Object(root)
358 }
359
360 #[tracing::instrument(
366 name = "search.ids",
367 skip_all,
368 fields(index = %self.index, returned = tracing::field::Empty),
369 err,
370 )]
371 pub async fn ids(&self, client: &Client) -> Result<Vec<String>> {
372 let body = self.ids_body();
373 let response = client.search_at(&self.physical_index(), &body).await?;
374 let raw: RawIdsResponse = serde_json::from_value(response)?;
375 let ids: Vec<String> = raw.hits.hits.into_iter().map(|hit| hit.id).collect();
376 tracing::Span::current().record("returned", ids.len());
377 tracing::debug!(returned = ids.len(), "ids search completed");
378 Ok(ids)
379 }
380
381 pub(crate) fn physical_index(&self) -> String {
385 format!("{}_{}", self.index, self.hash)
386 }
387
388 pub(crate) fn nested_paths(&self) -> Vec<&str> {
391 self.nested.iter().map(NestedProjection::path).collect()
392 }
393
394 #[tracing::instrument(
399 name = "search.count",
400 skip_all,
401 fields(index = %self.index, count = tracing::field::Empty),
402 err,
403 )]
404 pub async fn count(&self, client: &Client) -> Result<u64> {
405 let body = self.count_body();
406 let response = client.count_at(&self.physical_index(), &body).await?;
407 let raw: RawCount = serde_json::from_value(response)?;
408 tracing::Span::current().record("count", raw.count);
409 tracing::debug!(count = raw.count, "count completed");
410 Ok(raw.count)
411 }
412}
413
414impl<T> Search<T>
415where
416 T: DeserializeOwned,
417{
418 #[tracing::instrument(
420 name = "search.send",
421 skip_all,
422 fields(
423 index = %self.index,
424 from = ?self.from,
425 size = ?self.size,
426 total = tracing::field::Empty,
427 took_ms = tracing::field::Empty,
428 ),
429 err,
430 )]
431 pub async fn send(&self, client: &Client) -> Result<SearchResponse<T>> {
432 let body = self.body();
433 let mut response = client.search_at(&self.physical_index(), &body).await?;
434 let paths = self.nested_paths();
435 if !paths.is_empty() {
436 merge_inner_hits(&mut response, &paths);
437 }
438 let page = SearchResponse::from_value(response)?;
439 let span = tracing::Span::current();
440 span.record("total", page.total);
441 span.record("took_ms", page.took.as_millis() as u64);
442 tracing::debug!(
443 total = page.total,
444 hits = page.hits.len(),
445 "search completed"
446 );
447 Ok(page)
448 }
449}
450
451pub(crate) fn merge_inner_hits(response: &mut Value, paths: &[&str]) {
454 let Some(hits) = response
455 .get_mut("hits")
456 .and_then(|hits| hits.get_mut("hits"))
457 .and_then(Value::as_array_mut)
458 else {
459 return;
460 };
461 for hit in hits {
462 let inner = match hit.get("inner_hits") {
463 Some(inner) => inner.clone(),
464 None => continue,
465 };
466 let Some(source) = hit.get_mut("_source").and_then(Value::as_object_mut) else {
467 continue;
468 };
469 for path in paths {
470 let subset: Vec<Value> = inner
471 .get(*path)
472 .and_then(|hit| hit.get("hits"))
473 .and_then(|hits| hits.get("hits"))
474 .and_then(Value::as_array)
475 .map(|hits| {
476 hits.iter()
477 .filter_map(|h| h.get("_source").cloned())
478 .collect()
479 })
480 .unwrap_or_default();
481 source.insert((*path).to_string(), Value::Array(subset));
482 }
483 }
484}
485
486#[derive(Debug, Clone, Default)]
490pub struct Highlight {
491 fields: Map<String, Value>,
492 opts: Map<String, Value>,
493}
494
495impl Highlight {
496 #[must_use]
498 pub fn new() -> Self {
499 Self::default()
500 }
501
502 #[must_use]
504 pub fn field(mut self, field: impl Into<String>) -> Self {
505 self.fields.insert(field.into(), Value::Object(Map::new()));
506 self
507 }
508
509 #[must_use]
512 pub fn field_with(mut self, field: impl Into<String>, settings: Value) -> Self {
513 self.fields.insert(field.into(), settings);
514 self
515 }
516
517 #[must_use]
519 pub fn pre_tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
520 self.opts.insert(
521 "pre_tags".to_string(),
522 Value::Array(tags.into_iter().map(|t| Value::String(t.into())).collect()),
523 );
524 self
525 }
526
527 #[must_use]
529 pub fn post_tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
530 self.opts.insert(
531 "post_tags".to_string(),
532 Value::Array(tags.into_iter().map(|t| Value::String(t.into())).collect()),
533 );
534 self
535 }
536
537 #[must_use]
539 pub fn fragment_size(mut self, fragment_size: u32) -> Self {
540 self.opts
541 .insert("fragment_size".to_string(), Value::from(fragment_size));
542 self
543 }
544
545 #[must_use]
547 pub fn number_of_fragments(mut self, number_of_fragments: u32) -> Self {
548 self.opts.insert(
549 "number_of_fragments".to_string(),
550 Value::from(number_of_fragments),
551 );
552 self
553 }
554
555 #[must_use]
557 pub fn require_field_match(mut self, require: bool) -> Self {
558 self.opts
559 .insert("require_field_match".to_string(), Value::Bool(require));
560 self
561 }
562
563 fn to_value(&self) -> Value {
564 let mut body = self.opts.clone();
565 body.insert("fields".to_string(), Value::Object(self.fields.clone()));
566 Value::Object(body)
567 }
568}
569
570#[derive(Debug)]
572pub struct SearchResponse<T> {
573 pub total: u64,
575 pub max_score: Option<f32>,
577 pub hits: Vec<Hit<T>>,
579 pub took: Duration,
581}
582
583impl<T> SearchResponse<T>
584where
585 T: DeserializeOwned,
586{
587 pub fn from_value(value: Value) -> Result<Self> {
589 let raw: RawResponse<T> = serde_json::from_value(value)?;
590 let hits = raw
591 .hits
592 .hits
593 .into_iter()
594 .map(|hit| Hit {
595 id: hit.id,
596 score: hit.score.unwrap_or(0.0),
597 source: hit.source,
598 })
599 .collect();
600 Ok(Self {
601 total: raw.hits.total.value,
602 max_score: raw.hits.max_score,
603 hits,
604 took: Duration::from_millis(raw.took),
605 })
606 }
607}
608
/// One search hit together with its deserialized source document.
#[derive(Debug)]
pub struct Hit<T> {
    /// Document id (`_id`).
    pub id: String,
    /// Relevance score (`_score`); `SearchResponse::from_value` uses `0.0`
    /// when the response omits it or sends `null`.
    pub score: f32,
    /// The hit's `_source`, deserialized into `T`.
    pub source: T,
}
619
620#[derive(Deserialize)]
621struct RawResponse<T> {
622 #[serde(default)]
623 took: u64,
624 hits: RawHits<T>,
625}
626
627#[derive(Deserialize)]
628struct RawHits<T> {
629 total: RawTotal,
630 #[serde(default)]
631 max_score: Option<f32>,
632 hits: Vec<RawHit<T>>,
633}
634
635#[derive(Deserialize)]
636struct RawTotal {
637 value: u64,
638}
639
640#[derive(Deserialize)]
643pub(crate) struct RawCount {
644 pub(crate) count: u64,
645}
646
647#[derive(Deserialize)]
650struct RawIdsResponse {
651 hits: RawIdsHits,
652}
653
654#[derive(Deserialize)]
655struct RawIdsHits {
656 hits: Vec<RawIdHit>,
657}
658
659#[derive(Deserialize)]
660struct RawIdHit {
661 #[serde(rename = "_id")]
662 id: String,
663}
664
665#[derive(Deserialize)]
666struct RawHit<T> {
667 #[serde(rename = "_id")]
668 id: String,
669 #[serde(rename = "_score", default)]
670 score: Option<f32>,
671 #[serde(rename = "_source")]
672 source: T,
673}