1use std::marker::PhantomData;
4use std::time::Duration;
5
6use serde::Deserialize;
7use serde::de::DeserializeOwned;
8use serde_json::{Map, Value};
9
10use crate::Client;
11use crate::error::Result;
12use crate::handles::{MinimumShouldMatch, NestedProjection, Sort};
13use crate::path::Segment;
14use crate::query::{AsQuery, BoolBuilder, Root};
15
/// Base trait for document types handled by this crate; required by
/// [`FlussoIndex`].
pub trait FlussoDocument {
    /// Static path segments associated with this document type.
    ///
    /// NOTE(review): the meaning of each [`Segment`] is defined in
    /// `crate::path`; confirm how callers consume this constant.
    const PATH: &'static [Segment];
}
30
31pub trait FlussoIndex: FlussoDocument + DeserializeOwned {
40 const INDEX: &'static str;
42
43 const SCHEMA_HASH: &'static str;
45
46 fn physical_index() -> String {
51 format!("{}_{}", Self::INDEX, Self::SCHEMA_HASH)
52 }
53
54 fn query() -> Search<Self> {
60 Search::new(Self::INDEX, Self::SCHEMA_HASH)
61 }
62
63 fn get(
65 client: &Client,
66 id: impl std::fmt::Display,
67 ) -> impl std::future::Future<Output = Result<Option<Self>>> {
68 client.get_one::<Self>(Self::INDEX, Self::SCHEMA_HASH, id)
69 }
70}
71
72#[derive(Debug, Clone)]
83pub struct Search<T> {
84 index: String,
85 hash: String,
86 bool_query: BoolBuilder,
87 raw: Option<Value>,
88 sort: Vec<Sort>,
89 from: Option<u64>,
90 size: Option<u64>,
91 nested: Vec<NestedProjection>,
92 min_score: Option<f32>,
93 track_total_hits: Option<Value>,
94 track_scores: Option<bool>,
95 search_after: Option<Vec<Value>>,
96 collapse: Option<Value>,
97 post_filter: Option<Value>,
98 highlight: Option<Highlight>,
99 _marker: PhantomData<fn() -> T>,
100}
101
102impl<T> Search<T> {
103 pub fn new(index: impl Into<String>, hash: impl Into<String>) -> Self {
105 Self {
106 index: index.into(),
107 hash: hash.into(),
108 bool_query: BoolBuilder::default(),
109 raw: None,
110 sort: Vec::new(),
111 from: None,
112 size: None,
113 nested: Vec::new(),
114 min_score: None,
115 track_total_hits: None,
116 track_scores: None,
117 search_after: None,
118 collapse: None,
119 post_filter: None,
120 highlight: None,
121 _marker: PhantomData,
122 }
123 }
124
125 #[must_use]
128 pub fn query(mut self, query: impl AsQuery<Root>) -> Self {
129 if let Some(query) = query.into_query() {
130 self.bool_query.push_must(query.into_inner());
131 }
132 self
133 }
134
135 #[must_use]
138 pub fn filter(mut self, query: impl AsQuery<Root>) -> Self {
139 if let Some(query) = query.into_query() {
140 self.bool_query.push_filter(query.into_inner());
141 }
142 self
143 }
144
145 #[must_use]
147 pub fn must_not(mut self, query: impl AsQuery<Root>) -> Self {
148 if let Some(query) = query.into_query() {
149 self.bool_query.push_must_not(query.into_inner());
150 }
151 self
152 }
153
154 #[must_use]
156 pub fn should(mut self, query: impl AsQuery<Root>) -> Self {
157 if let Some(query) = query.into_query() {
158 self.bool_query.push_should(query.into_inner());
159 }
160 self
161 }
162
163 #[must_use]
168 pub fn min_should_match(mut self, value: impl Into<MinimumShouldMatch>) -> Self {
169 self.bool_query
170 .set_min_should_match(value.into().to_value());
171 self
172 }
173
174 #[must_use]
175 pub fn sort(mut self, sort: Sort) -> Self {
176 self.sort.push(sort);
177 self
178 }
179
180 #[must_use]
183 pub fn sorts(mut self, sorts: impl IntoIterator<Item = Sort>) -> Self {
184 self.sort.extend(sorts);
185 self
186 }
187
188 #[must_use]
190 pub fn min_score(mut self, min_score: f32) -> Self {
191 self.min_score = Some(min_score);
192 self
193 }
194
195 #[must_use]
198 pub fn track_total_hits(mut self, track: impl Into<Value>) -> Self {
199 self.track_total_hits = Some(track.into());
200 self
201 }
202
203 #[must_use]
205 pub fn track_scores(mut self, track: bool) -> Self {
206 self.track_scores = Some(track);
207 self
208 }
209
210 #[must_use]
213 pub fn search_after(mut self, values: impl IntoIterator<Item = impl Into<Value>>) -> Self {
214 self.search_after = Some(values.into_iter().map(Into::into).collect());
215 self
216 }
217
218 #[must_use]
220 pub fn collapse(mut self, field: impl Into<String>) -> Self {
221 let mut body = Map::new();
222 body.insert("field".to_string(), Value::String(field.into()));
223 self.collapse = Some(Value::Object(body));
224 self
225 }
226
227 #[must_use]
230 pub fn post_filter(mut self, query: impl AsQuery<Root>) -> Self {
231 if let Some(query) = query.into_query() {
232 self.post_filter = Some(query.to_value());
233 }
234 self
235 }
236
237 #[must_use]
239 pub fn highlight(mut self, highlight: Highlight) -> Self {
240 self.highlight = Some(highlight);
241 self
242 }
243
244 #[must_use]
246 pub fn from(mut self, from: u64) -> Self {
247 self.from = Some(from);
248 self
249 }
250
251 #[must_use]
253 pub fn size(mut self, size: u64) -> Self {
254 self.size = Some(size);
255 self
256 }
257
258 #[must_use]
262 pub fn raw(mut self, query: Value) -> Self {
263 self.raw = Some(query);
264 self
265 }
266
267 #[must_use]
271 pub fn filter_nested(mut self, projection: NestedProjection) -> Self {
272 self.nested.push(projection);
273 self
274 }
275
276 fn query_value(&self) -> Value {
280 match &self.raw {
281 Some(raw) => raw.clone(),
282 None if self.bool_query.is_empty() => crate::handles::match_all_value(),
283 None => self.bool_query.to_value(),
284 }
285 }
286
287 #[must_use]
290 pub fn body(&self) -> Value {
291 let query = self.query_value();
292
293 let query = if self.nested.is_empty() {
297 query
298 } else {
299 let mut bool_body = Map::new();
300 bool_body.insert("must".to_string(), Value::Array(vec![query]));
301 let shoulds = self.nested.iter().map(NestedProjection::to_value).collect();
302 bool_body.insert("should".to_string(), Value::Array(shoulds));
303 let mut outer = Map::new();
304 outer.insert("bool".to_string(), Value::Object(bool_body));
305 Value::Object(outer)
306 };
307
308 let mut root = Map::new();
309 root.insert("query".to_string(), query);
310 self.insert_page_params(&mut root);
311 self.insert_search_level(&mut root, true);
312 Value::Object(root)
313 }
314
315 fn insert_page_params(&self, root: &mut Map<String, Value>) {
318 if !self.sort.is_empty() {
319 let keys = self.sort.iter().map(Sort::to_value).collect();
320 root.insert("sort".to_string(), Value::Array(keys));
321 }
322 if let Some(from) = self.from {
323 root.insert("from".to_string(), Value::from(from));
324 }
325 if let Some(size) = self.size {
326 root.insert("size".to_string(), Value::from(size));
327 }
328 }
329
330 fn insert_search_level(&self, root: &mut Map<String, Value>, with_highlight: bool) {
337 if let Some(min_score) = self.min_score {
338 root.insert("min_score".to_string(), Value::from(min_score));
339 }
340 if let Some(track) = &self.track_total_hits {
341 root.insert("track_total_hits".to_string(), track.clone());
342 }
343 if let Some(track) = self.track_scores {
344 root.insert("track_scores".to_string(), Value::Bool(track));
345 }
346 if let Some(values) = &self.search_after {
347 root.insert("search_after".to_string(), Value::Array(values.clone()));
348 }
349 if let Some(collapse) = &self.collapse {
350 root.insert("collapse".to_string(), collapse.clone());
351 }
352 if let Some(post_filter) = &self.post_filter {
353 root.insert("post_filter".to_string(), post_filter.clone());
354 }
355 if with_highlight && let Some(highlight) = &self.highlight {
356 root.insert("highlight".to_string(), highlight.to_value());
357 }
358 }
359
360 #[must_use]
365 pub fn count_body(&self) -> Value {
366 let mut root = Map::new();
367 root.insert("query".to_string(), self.query_value());
368 Value::Object(root)
369 }
370
371 #[must_use]
377 pub fn ids_body(&self) -> Value {
378 let mut root = Map::new();
379 root.insert("query".to_string(), self.query_value());
380 self.insert_page_params(&mut root);
381 self.insert_search_level(&mut root, false);
382 root.insert("_source".to_string(), Value::Bool(false));
383 Value::Object(root)
384 }
385
386 #[tracing::instrument(
392 name = "search.ids",
393 skip_all,
394 fields(index = %self.index, returned = tracing::field::Empty),
395 err,
396 )]
397 pub async fn ids(&self, client: &Client) -> Result<Vec<String>> {
398 let body = self.ids_body();
399 let response = client.search_at(&self.physical_index(), &body).await?;
400 let raw: RawIdsResponse = serde_json::from_value(response)?;
401 let ids: Vec<String> = raw.hits.hits.into_iter().map(|hit| hit.id).collect();
402 tracing::Span::current().record("returned", ids.len());
403 tracing::debug!(returned = ids.len(), "ids search completed");
404 Ok(ids)
405 }
406
407 pub(crate) fn physical_index(&self) -> String {
411 format!("{}_{}", self.index, self.hash)
412 }
413
414 pub(crate) fn nested_paths(&self) -> Vec<&str> {
417 self.nested.iter().map(NestedProjection::path).collect()
418 }
419
420 #[tracing::instrument(
425 name = "search.count",
426 skip_all,
427 fields(index = %self.index, count = tracing::field::Empty),
428 err,
429 )]
430 pub async fn count(&self, client: &Client) -> Result<u64> {
431 let body = self.count_body();
432 let response = client.count_at(&self.physical_index(), &body).await?;
433 let raw: RawCount = serde_json::from_value(response)?;
434 tracing::Span::current().record("count", raw.count);
435 tracing::debug!(count = raw.count, "count completed");
436 Ok(raw.count)
437 }
438}
439
440impl<T> Search<T>
441where
442 T: DeserializeOwned,
443{
444 #[tracing::instrument(
446 name = "search.send",
447 skip_all,
448 fields(
449 index = %self.index,
450 from = ?self.from,
451 size = ?self.size,
452 total = tracing::field::Empty,
453 took_ms = tracing::field::Empty,
454 ),
455 err,
456 )]
457 pub async fn send(&self, client: &Client) -> Result<SearchResponse<T>> {
458 let body = self.body();
459 let mut response = client.search_at(&self.physical_index(), &body).await?;
460 let paths = self.nested_paths();
461 if !paths.is_empty() {
462 merge_inner_hits(&mut response, &paths);
463 }
464 let page = SearchResponse::from_value(response)?;
465 let span = tracing::Span::current();
466 span.record("total", page.total);
467 span.record("took_ms", page.took.as_millis() as u64);
468 if page.is_partial() {
469 tracing::warn!(
470 index = %self.index,
471 timed_out = page.timed_out,
472 shards_failed = page.shards.failed,
473 shards_total = page.shards.total,
474 "search returned partial results"
475 );
476 }
477 tracing::debug!(
478 total = page.total,
479 hits = page.hits.len(),
480 "search completed"
481 );
482 Ok(page)
483 }
484}
485
486pub(crate) fn merge_inner_hits(response: &mut Value, paths: &[&str]) {
489 let Some(hits) = response
490 .get_mut("hits")
491 .and_then(|hits| hits.get_mut("hits"))
492 .and_then(Value::as_array_mut)
493 else {
494 return;
495 };
496 for hit in hits {
497 let inner = match hit.get("inner_hits") {
498 Some(inner) => inner.clone(),
499 None => continue,
500 };
501 let Some(source) = hit.get_mut("_source").and_then(Value::as_object_mut) else {
502 continue;
503 };
504 for path in paths {
505 let subset: Vec<Value> = inner
506 .get(*path)
507 .and_then(|hit| hit.get("hits"))
508 .and_then(|hits| hits.get("hits"))
509 .and_then(Value::as_array)
510 .map(|hits| {
511 hits.iter()
512 .filter_map(|h| h.get("_source").cloned())
513 .collect()
514 })
515 .unwrap_or_default();
516 source.insert((*path).to_string(), Value::Array(subset));
517 }
518 }
519}
520
/// Highlight configuration for a search request.
///
/// Rendered as the top-level options plus a `fields` object.
#[derive(Debug, Clone, Default)]
pub struct Highlight {
    // Per-field highlight settings, keyed by field name.
    fields: Map<String, Value>,
    // Top-level options (`pre_tags`, `post_tags`, `fragment_size`, ...).
    opts: Map<String, Value>,
}
529
530impl Highlight {
531 #[must_use]
533 pub fn new() -> Self {
534 Self::default()
535 }
536
537 #[must_use]
539 pub fn field(mut self, field: impl Into<String>) -> Self {
540 self.fields.insert(field.into(), Value::Object(Map::new()));
541 self
542 }
543
544 #[must_use]
547 pub fn field_with(mut self, field: impl Into<String>, settings: Value) -> Self {
548 self.fields.insert(field.into(), settings);
549 self
550 }
551
552 #[must_use]
554 pub fn pre_tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
555 self.opts.insert(
556 "pre_tags".to_string(),
557 Value::Array(tags.into_iter().map(|t| Value::String(t.into())).collect()),
558 );
559 self
560 }
561
562 #[must_use]
564 pub fn post_tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
565 self.opts.insert(
566 "post_tags".to_string(),
567 Value::Array(tags.into_iter().map(|t| Value::String(t.into())).collect()),
568 );
569 self
570 }
571
572 #[must_use]
574 pub fn fragment_size(mut self, fragment_size: u32) -> Self {
575 self.opts
576 .insert("fragment_size".to_string(), Value::from(fragment_size));
577 self
578 }
579
580 #[must_use]
582 pub fn number_of_fragments(mut self, number_of_fragments: u32) -> Self {
583 self.opts.insert(
584 "number_of_fragments".to_string(),
585 Value::from(number_of_fragments),
586 );
587 self
588 }
589
590 #[must_use]
592 pub fn require_field_match(mut self, require: bool) -> Self {
593 self.opts
594 .insert("require_field_match".to_string(), Value::Bool(require));
595 self
596 }
597
598 fn to_value(&self) -> Value {
599 let mut body = self.opts.clone();
600 body.insert("fields".to_string(), Value::Object(self.fields.clone()));
601 Value::Object(body)
602 }
603}
604
/// A typed page of search results, built by `SearchResponse::from_value`.
#[derive(Debug)]
pub struct SearchResponse<T> {
    /// Total matching hits, read from `hits.total.value`.
    pub total: u64,
    /// Highest score among the hits; `None` when absent or `null`.
    pub max_score: Option<f32>,
    /// Hits in response order.
    pub hits: Vec<Hit<T>>,
    /// Server-side duration, from the `took` millisecond count (zero when absent).
    pub took: Duration,
    /// Whether the search timed out (false when absent).
    pub timed_out: bool,
    /// Shard statistics from `_shards` (all zero when absent).
    pub shards: ShardStats,
}
625
626impl<T> SearchResponse<T> {
627 #[must_use]
631 pub fn is_partial(&self) -> bool {
632 self.timed_out || self.shards.failed > 0
633 }
634}
635
/// Shard counts reported in a response's `_shards` object.
///
/// Each count is zero when the response omits it.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShardStats {
    /// Number of shards the search targeted.
    pub total: u64,
    /// Number of shards that completed successfully.
    pub successful: u64,
    /// Number of shards that were skipped.
    pub skipped: u64,
    /// Number of shards that failed; non-zero makes `is_partial` true.
    pub failed: u64,
}
651
652impl<T> SearchResponse<T>
653where
654 T: DeserializeOwned,
655{
656 pub fn from_value(value: Value) -> Result<Self> {
658 let raw: RawResponse<T> = serde_json::from_value(value)?;
659 let hits = raw
660 .hits
661 .hits
662 .into_iter()
663 .map(|hit| Hit {
664 id: hit.id,
665 score: hit.score.unwrap_or(0.0),
666 source: hit.source,
667 })
668 .collect();
669 Ok(Self {
670 total: raw.hits.total.value,
671 max_score: raw.hits.max_score,
672 hits,
673 took: Duration::from_millis(raw.took),
674 timed_out: raw.timed_out,
675 shards: raw.shards.into(),
676 })
677 }
678}
679
/// A single search hit with its deserialized document.
#[derive(Debug)]
pub struct Hit<T> {
    /// Document id, from `_id`.
    pub id: String,
    /// Relevance score from `_score`; `0.0` when missing or `null`.
    pub score: f32,
    /// The document, deserialized from `_source`.
    pub source: T,
}
690
/// Wire shape of a search response, as deserialized by `SearchResponse::from_value`.
#[derive(Deserialize)]
struct RawResponse<T> {
    // Milliseconds; 0 when absent.
    #[serde(default)]
    took: u64,
    // False when absent.
    #[serde(default)]
    timed_out: bool,
    // All-zero counts when `_shards` is absent.
    #[serde(rename = "_shards", default)]
    shards: RawShards,
    // Required.
    hits: RawHits<T>,
}
701
/// Wire shape of the `_shards` object; every missing count defaults to 0.
#[derive(Deserialize, Default)]
pub(crate) struct RawShards {
    #[serde(default)]
    total: u64,
    #[serde(default)]
    successful: u64,
    #[serde(default)]
    skipped: u64,
    #[serde(default)]
    failed: u64,
}
714
715impl From<RawShards> for ShardStats {
716 fn from(raw: RawShards) -> Self {
717 Self {
718 total: raw.total,
719 successful: raw.successful,
720 skipped: raw.skipped,
721 failed: raw.failed,
722 }
723 }
724}
725
726#[derive(Deserialize)]
727struct RawHits<T> {
728 total: RawTotal,
729 #[serde(default)]
730 max_score: Option<f32>,
731 hits: Vec<RawHit<T>>,
732}
733
734#[derive(Deserialize)]
735struct RawTotal {
736 value: u64,
737}
738
/// Wire shape of a count response; only `count` is read.
#[derive(Deserialize)]
pub(crate) struct RawCount {
    // Number of matching documents.
    pub(crate) count: u64,
}
745
/// Minimal search response shape used by `Search::ids`: only each hit's `_id`
/// is read, so hits without `_source` deserialize fine.
#[derive(Deserialize)]
struct RawIdsResponse {
    hits: RawIdsHits,
}

/// The `hits` object of an ids-only response.
#[derive(Deserialize)]
struct RawIdsHits {
    hits: Vec<RawIdHit>,
}

/// A single hit reduced to its `_id`.
#[derive(Deserialize)]
struct RawIdHit {
    #[serde(rename = "_id")]
    id: String,
}
763
/// Wire shape of one entry in `hits.hits`.
#[derive(Deserialize)]
struct RawHit<T> {
    #[serde(rename = "_id")]
    id: String,
    // `None` when `_score` is absent or `null`.
    #[serde(rename = "_score", default)]
    score: Option<f32>,
    // Required: a hit without `_source` fails deserialization.
    #[serde(rename = "_source")]
    source: T,
}