1use std::marker::PhantomData;
4use std::time::Duration;
5
6use serde::Deserialize;
7use serde::de::DeserializeOwned;
8use serde_json::{Map, Value};
9
10use crate::Client;
11use crate::error::Result;
12use crate::handles::{MinimumShouldMatch, NestedProjection, Sort};
13use crate::query::{AsQuery, BoolBuilder, Root};
14
15pub trait FlussoDocument: DeserializeOwned {
23 const INDEX: &'static str;
25
26 const SCHEMA_HASH: &'static str;
28
29 fn physical_index() -> String {
34 format!("{}_{}", Self::INDEX, Self::SCHEMA_HASH)
35 }
36
37 fn query() -> Search<Self> {
43 Search::new(Self::INDEX, Self::SCHEMA_HASH)
44 }
45
46 fn get(
48 client: &Client,
49 id: impl std::fmt::Display,
50 ) -> impl std::future::Future<Output = Result<Option<Self>>> {
51 client.get_one::<Self>(Self::INDEX, Self::SCHEMA_HASH, id)
52 }
53}
54
55#[derive(Debug, Clone)]
66pub struct Search<T> {
67 index: String,
68 hash: String,
69 bool_query: BoolBuilder,
70 raw: Option<Value>,
71 sort: Vec<Sort>,
72 from: Option<u64>,
73 size: Option<u64>,
74 nested: Vec<NestedProjection>,
75 min_score: Option<f32>,
76 track_total_hits: Option<Value>,
77 track_scores: Option<bool>,
78 search_after: Option<Vec<Value>>,
79 collapse: Option<Value>,
80 post_filter: Option<Value>,
81 highlight: Option<Highlight>,
82 _marker: PhantomData<fn() -> T>,
83}
84
85impl<T> Search<T> {
86 pub fn new(index: impl Into<String>, hash: impl Into<String>) -> Self {
88 Self {
89 index: index.into(),
90 hash: hash.into(),
91 bool_query: BoolBuilder::default(),
92 raw: None,
93 sort: Vec::new(),
94 from: None,
95 size: None,
96 nested: Vec::new(),
97 min_score: None,
98 track_total_hits: None,
99 track_scores: None,
100 search_after: None,
101 collapse: None,
102 post_filter: None,
103 highlight: None,
104 _marker: PhantomData,
105 }
106 }
107
108 #[must_use]
111 pub fn query(mut self, query: impl AsQuery<Root>) -> Self {
112 if let Some(query) = query.into_query() {
113 self.bool_query.push_must(query.into_inner());
114 }
115 self
116 }
117
118 #[must_use]
121 pub fn filter(mut self, query: impl AsQuery<Root>) -> Self {
122 if let Some(query) = query.into_query() {
123 self.bool_query.push_filter(query.into_inner());
124 }
125 self
126 }
127
128 #[must_use]
130 pub fn must_not(mut self, query: impl AsQuery<Root>) -> Self {
131 if let Some(query) = query.into_query() {
132 self.bool_query.push_must_not(query.into_inner());
133 }
134 self
135 }
136
137 #[must_use]
139 pub fn should(mut self, query: impl AsQuery<Root>) -> Self {
140 if let Some(query) = query.into_query() {
141 self.bool_query.push_should(query.into_inner());
142 }
143 self
144 }
145
146 #[must_use]
151 pub fn min_should_match(mut self, value: impl Into<MinimumShouldMatch>) -> Self {
152 self.bool_query
153 .set_min_should_match(value.into().to_value());
154 self
155 }
156
157 #[must_use]
158 pub fn sort(mut self, sort: Sort) -> Self {
159 self.sort.push(sort);
160 self
161 }
162
163 #[must_use]
165 pub fn min_score(mut self, min_score: f32) -> Self {
166 self.min_score = Some(min_score);
167 self
168 }
169
170 #[must_use]
173 pub fn track_total_hits(mut self, track: impl Into<Value>) -> Self {
174 self.track_total_hits = Some(track.into());
175 self
176 }
177
178 #[must_use]
180 pub fn track_scores(mut self, track: bool) -> Self {
181 self.track_scores = Some(track);
182 self
183 }
184
185 #[must_use]
188 pub fn search_after(mut self, values: impl IntoIterator<Item = impl Into<Value>>) -> Self {
189 self.search_after = Some(values.into_iter().map(Into::into).collect());
190 self
191 }
192
193 #[must_use]
195 pub fn collapse(mut self, field: impl Into<String>) -> Self {
196 let mut body = Map::new();
197 body.insert("field".to_string(), Value::String(field.into()));
198 self.collapse = Some(Value::Object(body));
199 self
200 }
201
202 #[must_use]
205 pub fn post_filter(mut self, query: impl AsQuery<Root>) -> Self {
206 if let Some(query) = query.into_query() {
207 self.post_filter = Some(query.to_value());
208 }
209 self
210 }
211
212 #[must_use]
214 pub fn highlight(mut self, highlight: Highlight) -> Self {
215 self.highlight = Some(highlight);
216 self
217 }
218
219 #[must_use]
221 pub fn from(mut self, from: u64) -> Self {
222 self.from = Some(from);
223 self
224 }
225
226 #[must_use]
228 pub fn size(mut self, size: u64) -> Self {
229 self.size = Some(size);
230 self
231 }
232
233 #[must_use]
237 pub fn raw(mut self, query: Value) -> Self {
238 self.raw = Some(query);
239 self
240 }
241
242 #[must_use]
246 pub fn filter_nested(mut self, projection: NestedProjection) -> Self {
247 self.nested.push(projection);
248 self
249 }
250
251 fn query_value(&self) -> Value {
255 match &self.raw {
256 Some(raw) => raw.clone(),
257 None if self.bool_query.is_empty() => crate::handles::match_all_value(),
258 None => self.bool_query.to_value(),
259 }
260 }
261
262 #[must_use]
265 pub fn body(&self) -> Value {
266 let query = self.query_value();
267
268 let query = if self.nested.is_empty() {
272 query
273 } else {
274 let mut bool_body = Map::new();
275 bool_body.insert("must".to_string(), Value::Array(vec![query]));
276 let shoulds = self.nested.iter().map(NestedProjection::to_value).collect();
277 bool_body.insert("should".to_string(), Value::Array(shoulds));
278 let mut outer = Map::new();
279 outer.insert("bool".to_string(), Value::Object(bool_body));
280 Value::Object(outer)
281 };
282
283 let mut root = Map::new();
284 root.insert("query".to_string(), query);
285 self.insert_page_params(&mut root);
286 self.insert_search_level(&mut root, true);
287 Value::Object(root)
288 }
289
290 fn insert_page_params(&self, root: &mut Map<String, Value>) {
293 if !self.sort.is_empty() {
294 let keys = self.sort.iter().map(Sort::to_value).collect();
295 root.insert("sort".to_string(), Value::Array(keys));
296 }
297 if let Some(from) = self.from {
298 root.insert("from".to_string(), Value::from(from));
299 }
300 if let Some(size) = self.size {
301 root.insert("size".to_string(), Value::from(size));
302 }
303 }
304
305 fn insert_search_level(&self, root: &mut Map<String, Value>, with_highlight: bool) {
312 if let Some(min_score) = self.min_score {
313 root.insert("min_score".to_string(), Value::from(min_score));
314 }
315 if let Some(track) = &self.track_total_hits {
316 root.insert("track_total_hits".to_string(), track.clone());
317 }
318 if let Some(track) = self.track_scores {
319 root.insert("track_scores".to_string(), Value::Bool(track));
320 }
321 if let Some(values) = &self.search_after {
322 root.insert("search_after".to_string(), Value::Array(values.clone()));
323 }
324 if let Some(collapse) = &self.collapse {
325 root.insert("collapse".to_string(), collapse.clone());
326 }
327 if let Some(post_filter) = &self.post_filter {
328 root.insert("post_filter".to_string(), post_filter.clone());
329 }
330 if with_highlight && let Some(highlight) = &self.highlight {
331 root.insert("highlight".to_string(), highlight.to_value());
332 }
333 }
334
335 #[must_use]
340 pub fn count_body(&self) -> Value {
341 let mut root = Map::new();
342 root.insert("query".to_string(), self.query_value());
343 Value::Object(root)
344 }
345
346 #[must_use]
352 pub fn ids_body(&self) -> Value {
353 let mut root = Map::new();
354 root.insert("query".to_string(), self.query_value());
355 self.insert_page_params(&mut root);
356 self.insert_search_level(&mut root, false);
357 root.insert("_source".to_string(), Value::Bool(false));
358 Value::Object(root)
359 }
360
361 #[tracing::instrument(
367 name = "search.ids",
368 skip_all,
369 fields(index = %self.index, returned = tracing::field::Empty),
370 err,
371 )]
372 pub async fn ids(&self, client: &Client) -> Result<Vec<String>> {
373 let body = self.ids_body();
374 let response = client.search_at(&self.physical_index(), &body).await?;
375 let raw: RawIdsResponse = serde_json::from_value(response)?;
376 let ids: Vec<String> = raw.hits.hits.into_iter().map(|hit| hit.id).collect();
377 tracing::Span::current().record("returned", ids.len());
378 tracing::debug!(returned = ids.len(), "ids search completed");
379 Ok(ids)
380 }
381
382 pub(crate) fn physical_index(&self) -> String {
386 format!("{}_{}", self.index, self.hash)
387 }
388
389 pub(crate) fn nested_paths(&self) -> Vec<&str> {
392 self.nested.iter().map(NestedProjection::path).collect()
393 }
394
395 #[tracing::instrument(
400 name = "search.count",
401 skip_all,
402 fields(index = %self.index, count = tracing::field::Empty),
403 err,
404 )]
405 pub async fn count(&self, client: &Client) -> Result<u64> {
406 let body = self.count_body();
407 let response = client.count_at(&self.physical_index(), &body).await?;
408 let raw: RawCount = serde_json::from_value(response)?;
409 tracing::Span::current().record("count", raw.count);
410 tracing::debug!(count = raw.count, "count completed");
411 Ok(raw.count)
412 }
413}
414
415impl<T> Search<T>
416where
417 T: DeserializeOwned,
418{
419 #[tracing::instrument(
421 name = "search.send",
422 skip_all,
423 fields(
424 index = %self.index,
425 from = ?self.from,
426 size = ?self.size,
427 total = tracing::field::Empty,
428 took_ms = tracing::field::Empty,
429 ),
430 err,
431 )]
432 pub async fn send(&self, client: &Client) -> Result<SearchResponse<T>> {
433 let body = self.body();
434 let mut response = client.search_at(&self.physical_index(), &body).await?;
435 let paths = self.nested_paths();
436 if !paths.is_empty() {
437 merge_inner_hits(&mut response, &paths);
438 }
439 let page = SearchResponse::from_value(response)?;
440 let span = tracing::Span::current();
441 span.record("total", page.total);
442 span.record("took_ms", page.took.as_millis() as u64);
443 tracing::debug!(
444 total = page.total,
445 hits = page.hits.len(),
446 "search completed"
447 );
448 Ok(page)
449 }
450}
451
452pub(crate) fn merge_inner_hits(response: &mut Value, paths: &[&str]) {
455 let Some(hits) = response
456 .get_mut("hits")
457 .and_then(|hits| hits.get_mut("hits"))
458 .and_then(Value::as_array_mut)
459 else {
460 return;
461 };
462 for hit in hits {
463 let inner = match hit.get("inner_hits") {
464 Some(inner) => inner.clone(),
465 None => continue,
466 };
467 let Some(source) = hit.get_mut("_source").and_then(Value::as_object_mut) else {
468 continue;
469 };
470 for path in paths {
471 let subset: Vec<Value> = inner
472 .get(*path)
473 .and_then(|hit| hit.get("hits"))
474 .and_then(|hits| hits.get("hits"))
475 .and_then(Value::as_array)
476 .map(|hits| {
477 hits.iter()
478 .filter_map(|h| h.get("_source").cloned())
479 .collect()
480 })
481 .unwrap_or_default();
482 source.insert((*path).to_string(), Value::Array(subset));
483 }
484 }
485}
486
487#[derive(Debug, Clone, Default)]
491pub struct Highlight {
492 fields: Map<String, Value>,
493 opts: Map<String, Value>,
494}
495
496impl Highlight {
497 #[must_use]
499 pub fn new() -> Self {
500 Self::default()
501 }
502
503 #[must_use]
505 pub fn field(mut self, field: impl Into<String>) -> Self {
506 self.fields.insert(field.into(), Value::Object(Map::new()));
507 self
508 }
509
510 #[must_use]
513 pub fn field_with(mut self, field: impl Into<String>, settings: Value) -> Self {
514 self.fields.insert(field.into(), settings);
515 self
516 }
517
518 #[must_use]
520 pub fn pre_tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
521 self.opts.insert(
522 "pre_tags".to_string(),
523 Value::Array(tags.into_iter().map(|t| Value::String(t.into())).collect()),
524 );
525 self
526 }
527
528 #[must_use]
530 pub fn post_tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
531 self.opts.insert(
532 "post_tags".to_string(),
533 Value::Array(tags.into_iter().map(|t| Value::String(t.into())).collect()),
534 );
535 self
536 }
537
538 #[must_use]
540 pub fn fragment_size(mut self, fragment_size: u32) -> Self {
541 self.opts
542 .insert("fragment_size".to_string(), Value::from(fragment_size));
543 self
544 }
545
546 #[must_use]
548 pub fn number_of_fragments(mut self, number_of_fragments: u32) -> Self {
549 self.opts.insert(
550 "number_of_fragments".to_string(),
551 Value::from(number_of_fragments),
552 );
553 self
554 }
555
556 #[must_use]
558 pub fn require_field_match(mut self, require: bool) -> Self {
559 self.opts
560 .insert("require_field_match".to_string(), Value::Bool(require));
561 self
562 }
563
564 fn to_value(&self) -> Value {
565 let mut body = self.opts.clone();
566 body.insert("fields".to_string(), Value::Object(self.fields.clone()));
567 Value::Object(body)
568 }
569}
570
571#[derive(Debug)]
573pub struct SearchResponse<T> {
574 pub total: u64,
576 pub max_score: Option<f32>,
578 pub hits: Vec<Hit<T>>,
580 pub took: Duration,
582}
583
584impl<T> SearchResponse<T>
585where
586 T: DeserializeOwned,
587{
588 pub fn from_value(value: Value) -> Result<Self> {
590 let raw: RawResponse<T> = serde_json::from_value(value)?;
591 let hits = raw
592 .hits
593 .hits
594 .into_iter()
595 .map(|hit| Hit {
596 id: hit.id,
597 score: hit.score.unwrap_or(0.0),
598 source: hit.source,
599 })
600 .collect();
601 Ok(Self {
602 total: raw.hits.total.value,
603 max_score: raw.hits.max_score,
604 hits,
605 took: Duration::from_millis(raw.took),
606 })
607 }
608}
609
/// A single search hit.
#[derive(Debug)]
pub struct Hit<T> {
    /// The hit's `_id`.
    pub id: String,
    /// The hit's `_score`; `SearchResponse::from_value` stores `0.0` when none was returned.
    pub score: f32,
    /// The hit's `_source`, deserialized into `T`.
    pub source: T,
}
620
621#[derive(Deserialize)]
622struct RawResponse<T> {
623 #[serde(default)]
624 took: u64,
625 hits: RawHits<T>,
626}
627
628#[derive(Deserialize)]
629struct RawHits<T> {
630 total: RawTotal,
631 #[serde(default)]
632 max_score: Option<f32>,
633 hits: Vec<RawHit<T>>,
634}
635
636#[derive(Deserialize)]
637struct RawTotal {
638 value: u64,
639}
640
641#[derive(Deserialize)]
644pub(crate) struct RawCount {
645 pub(crate) count: u64,
646}
647
648#[derive(Deserialize)]
651struct RawIdsResponse {
652 hits: RawIdsHits,
653}
654
655#[derive(Deserialize)]
656struct RawIdsHits {
657 hits: Vec<RawIdHit>,
658}
659
660#[derive(Deserialize)]
661struct RawIdHit {
662 #[serde(rename = "_id")]
663 id: String,
664}
665
666#[derive(Deserialize)]
667struct RawHit<T> {
668 #[serde(rename = "_id")]
669 id: String,
670 #[serde(rename = "_score", default)]
671 score: Option<f32>,
672 #[serde(rename = "_source")]
673 source: T,
674}