1use std::marker::PhantomData;
4use std::time::Duration;
5
6use serde::Deserialize;
7use serde::de::DeserializeOwned;
8use serde_json::{Map, Value};
9
10use crate::Client;
11use crate::error::Result;
12use crate::handles::{NestedProjection, Sort};
13use crate::query::{AsQuery, BoolBuilder, Root};
14
15pub trait FlussoDocument: DeserializeOwned {
23 const INDEX: &'static str;
25
26 const SCHEMA_HASH: &'static str;
28
29 fn physical_index() -> String {
34 format!("{}_{}", Self::INDEX, Self::SCHEMA_HASH)
35 }
36
37 fn query() -> Search<Self> {
43 Search::new(Self::INDEX, Self::SCHEMA_HASH)
44 }
45
46 fn get(
48 client: &Client,
49 id: impl std::fmt::Display,
50 ) -> impl std::future::Future<Output = Result<Option<Self>>> {
51 client.get_one::<Self>(Self::INDEX, Self::SCHEMA_HASH, id)
52 }
53}
54
55#[derive(Debug, Clone)]
66pub struct Search<T> {
67 index: String,
68 hash: String,
69 bool_query: BoolBuilder,
70 raw: Option<Value>,
71 sort: Vec<Sort>,
72 from: Option<u64>,
73 size: Option<u64>,
74 nested: Vec<NestedProjection>,
75 _marker: PhantomData<fn() -> T>,
76}
77
78impl<T> Search<T> {
79 pub fn new(index: impl Into<String>, hash: impl Into<String>) -> Self {
81 Self {
82 index: index.into(),
83 hash: hash.into(),
84 bool_query: BoolBuilder::default(),
85 raw: None,
86 sort: Vec::new(),
87 from: None,
88 size: None,
89 nested: Vec::new(),
90 _marker: PhantomData,
91 }
92 }
93
94 #[must_use]
97 pub fn query(mut self, query: impl AsQuery<Root>) -> Self {
98 if let Some(query) = query.into_query() {
99 self.bool_query.push_must(query.into_inner());
100 }
101 self
102 }
103
104 #[must_use]
107 pub fn filter(mut self, query: impl AsQuery<Root>) -> Self {
108 if let Some(query) = query.into_query() {
109 self.bool_query.push_filter(query.into_inner());
110 }
111 self
112 }
113
114 #[must_use]
116 pub fn must_not(mut self, query: impl AsQuery<Root>) -> Self {
117 if let Some(query) = query.into_query() {
118 self.bool_query.push_must_not(query.into_inner());
119 }
120 self
121 }
122
123 #[must_use]
125 pub fn should(mut self, query: impl AsQuery<Root>) -> Self {
126 if let Some(query) = query.into_query() {
127 self.bool_query.push_should(query.into_inner());
128 }
129 self
130 }
131
132 #[must_use]
134 pub fn sort(mut self, sort: Sort) -> Self {
135 self.sort.push(sort);
136 self
137 }
138
139 #[must_use]
141 pub fn from(mut self, from: u64) -> Self {
142 self.from = Some(from);
143 self
144 }
145
146 #[must_use]
148 pub fn size(mut self, size: u64) -> Self {
149 self.size = Some(size);
150 self
151 }
152
153 #[must_use]
157 pub fn raw(mut self, query: Value) -> Self {
158 self.raw = Some(query);
159 self
160 }
161
162 #[must_use]
166 pub fn filter_nested(mut self, projection: NestedProjection) -> Self {
167 self.nested.push(projection);
168 self
169 }
170
171 fn query_value(&self) -> Value {
175 match &self.raw {
176 Some(raw) => raw.clone(),
177 None if self.bool_query.is_empty() => crate::handles::match_all_value(),
178 None => self.bool_query.to_value(),
179 }
180 }
181
182 #[must_use]
185 pub fn body(&self) -> Value {
186 let query = self.query_value();
187
188 let query = if self.nested.is_empty() {
192 query
193 } else {
194 let mut bool_body = Map::new();
195 bool_body.insert("must".to_string(), Value::Array(vec![query]));
196 let shoulds = self.nested.iter().map(NestedProjection::to_value).collect();
197 bool_body.insert("should".to_string(), Value::Array(shoulds));
198 let mut outer = Map::new();
199 outer.insert("bool".to_string(), Value::Object(bool_body));
200 Value::Object(outer)
201 };
202
203 let mut root = Map::new();
204 root.insert("query".to_string(), query);
205 self.insert_page_params(&mut root);
206 Value::Object(root)
207 }
208
209 fn insert_page_params(&self, root: &mut Map<String, Value>) {
212 if !self.sort.is_empty() {
213 let keys = self.sort.iter().map(Sort::to_value).collect();
214 root.insert("sort".to_string(), Value::Array(keys));
215 }
216 if let Some(from) = self.from {
217 root.insert("from".to_string(), Value::from(from));
218 }
219 if let Some(size) = self.size {
220 root.insert("size".to_string(), Value::from(size));
221 }
222 }
223
224 #[must_use]
229 pub fn count_body(&self) -> Value {
230 let mut root = Map::new();
231 root.insert("query".to_string(), self.query_value());
232 Value::Object(root)
233 }
234
235 #[must_use]
241 pub fn ids_body(&self) -> Value {
242 let mut root = Map::new();
243 root.insert("query".to_string(), self.query_value());
244 self.insert_page_params(&mut root);
245 root.insert("_source".to_string(), Value::Bool(false));
246 Value::Object(root)
247 }
248
249 #[tracing::instrument(
255 name = "search.ids",
256 skip_all,
257 fields(index = %self.index, returned = tracing::field::Empty),
258 err,
259 )]
260 pub async fn ids(&self, client: &Client) -> Result<Vec<String>> {
261 let body = self.ids_body();
262 let response = client.search_at(&self.physical_index(), &body).await?;
263 let raw: RawIdsResponse = serde_json::from_value(response)?;
264 let ids: Vec<String> = raw.hits.hits.into_iter().map(|hit| hit.id).collect();
265 tracing::Span::current().record("returned", ids.len());
266 tracing::debug!(returned = ids.len(), "ids search completed");
267 Ok(ids)
268 }
269
270 pub(crate) fn physical_index(&self) -> String {
274 format!("{}_{}", self.index, self.hash)
275 }
276
277 pub(crate) fn nested_paths(&self) -> Vec<&str> {
280 self.nested.iter().map(NestedProjection::path).collect()
281 }
282
283 #[tracing::instrument(
288 name = "search.count",
289 skip_all,
290 fields(index = %self.index, count = tracing::field::Empty),
291 err,
292 )]
293 pub async fn count(&self, client: &Client) -> Result<u64> {
294 let body = self.count_body();
295 let response = client.count_at(&self.physical_index(), &body).await?;
296 let raw: RawCount = serde_json::from_value(response)?;
297 tracing::Span::current().record("count", raw.count);
298 tracing::debug!(count = raw.count, "count completed");
299 Ok(raw.count)
300 }
301}
302
303impl<T> Search<T>
304where
305 T: DeserializeOwned,
306{
307 #[tracing::instrument(
309 name = "search.send",
310 skip_all,
311 fields(
312 index = %self.index,
313 from = ?self.from,
314 size = ?self.size,
315 total = tracing::field::Empty,
316 took_ms = tracing::field::Empty,
317 ),
318 err,
319 )]
320 pub async fn send(&self, client: &Client) -> Result<SearchResponse<T>> {
321 let body = self.body();
322 let mut response = client.search_at(&self.physical_index(), &body).await?;
323 let paths = self.nested_paths();
324 if !paths.is_empty() {
325 merge_inner_hits(&mut response, &paths);
326 }
327 let page = SearchResponse::from_value(response)?;
328 let span = tracing::Span::current();
329 span.record("total", page.total);
330 span.record("took_ms", page.took.as_millis() as u64);
331 tracing::debug!(
332 total = page.total,
333 hits = page.hits.len(),
334 "search completed"
335 );
336 Ok(page)
337 }
338}
339
340pub(crate) fn merge_inner_hits(response: &mut Value, paths: &[&str]) {
343 let Some(hits) = response
344 .get_mut("hits")
345 .and_then(|hits| hits.get_mut("hits"))
346 .and_then(Value::as_array_mut)
347 else {
348 return;
349 };
350 for hit in hits {
351 let inner = match hit.get("inner_hits") {
352 Some(inner) => inner.clone(),
353 None => continue,
354 };
355 let Some(source) = hit.get_mut("_source").and_then(Value::as_object_mut) else {
356 continue;
357 };
358 for path in paths {
359 let subset: Vec<Value> = inner
360 .get(*path)
361 .and_then(|hit| hit.get("hits"))
362 .and_then(|hits| hits.get("hits"))
363 .and_then(Value::as_array)
364 .map(|hits| {
365 hits.iter()
366 .filter_map(|h| h.get("_source").cloned())
367 .collect()
368 })
369 .unwrap_or_default();
370 source.insert((*path).to_string(), Value::Array(subset));
371 }
372 }
373}
374
/// One page of typed search results, as produced by
/// `SearchResponse::from_value`.
#[derive(Debug)]
pub struct SearchResponse<T> {
    /// Hit count taken from the response's `hits.total.value`.
    // NOTE(review): only `value` is read; if the server reports the total as
    // a lower bound (e.g. a `relation` field), that is lost here — confirm.
    pub total: u64,
    /// The response's `hits.max_score`; `None` when absent or null.
    pub max_score: Option<f32>,
    /// The returned hits, in response order.
    pub hits: Vec<Hit<T>>,
    /// Execution time from the response's `took` field (milliseconds);
    /// zero when the response omits it.
    pub took: Duration,
}
387
388impl<T> SearchResponse<T>
389where
390 T: DeserializeOwned,
391{
392 pub fn from_value(value: Value) -> Result<Self> {
394 let raw: RawResponse<T> = serde_json::from_value(value)?;
395 let hits = raw
396 .hits
397 .hits
398 .into_iter()
399 .map(|hit| Hit {
400 id: hit.id,
401 score: hit.score.unwrap_or(0.0),
402 source: hit.source,
403 })
404 .collect();
405 Ok(Self {
406 total: raw.hits.total.value,
407 max_score: raw.hits.max_score,
408 hits,
409 took: Duration::from_millis(raw.took),
410 })
411 }
412}
413
/// A single typed search hit.
#[derive(Debug)]
pub struct Hit<T> {
    /// Document id, from the hit's `_id`.
    pub id: String,
    /// Score from the hit's `_score`; `SearchResponse::from_value` uses `0.0`
    /// when it is missing or null.
    pub score: f32,
    /// The hit's `_source`, deserialized as `T`.
    pub source: T,
}
424
/// Wire shape of a search response body, decoded by
/// `SearchResponse::from_value`.
#[derive(Deserialize)]
struct RawResponse<T> {
    /// The `took` field in milliseconds; `0` when the field is absent.
    #[serde(default)]
    took: u64,
    /// The `hits` envelope.
    hits: RawHits<T>,
}
433
/// The `hits` envelope of a search response.
#[derive(Deserialize)]
struct RawHits<T> {
    /// `hits.total`; only its `value` is read.
    total: RawTotal,
    /// `hits.max_score`; `None` when absent or null.
    #[serde(default)]
    max_score: Option<f32>,
    /// The individual hits.
    hits: Vec<RawHit<T>>,
}
441
/// Object form of `hits.total`.
///
/// A bare integer total will not deserialize into this struct. Keys other
/// than `value` are ignored.
#[derive(Deserialize)]
struct RawTotal {
    /// The reported hit count.
    value: u64,
}
446
/// Body of a count response; only `count` is read.
#[derive(Deserialize)]
pub(crate) struct RawCount {
    /// The `count` value from the response.
    pub(crate) count: u64,
}
453
/// Minimal response shape used by `Search::ids`, whose request sets
/// `_source: false`; only hit ids are decoded.
#[derive(Deserialize)]
struct RawIdsResponse {
    /// The `hits` envelope.
    hits: RawIdsHits,
}
460
/// `hits` envelope of an ids-only search response.
#[derive(Deserialize)]
struct RawIdsHits {
    /// The individual hits, in response order.
    hits: Vec<RawIdHit>,
}
465
/// One hit of an ids-only search; only `_id` is kept.
#[derive(Deserialize)]
struct RawIdHit {
    /// Document id from `_id`.
    #[serde(rename = "_id")]
    id: String,
}
471
/// One hit of a full search response.
#[derive(Deserialize)]
struct RawHit<T> {
    /// Document id from `_id`.
    #[serde(rename = "_id")]
    id: String,
    /// `_score`; `None` when absent or null.
    #[serde(rename = "_score", default)]
    score: Option<f32>,
    /// `_source`, decoded as `T`; required, so a hit without `_source`
    /// fails to deserialize.
    #[serde(rename = "_source")]
    source: T,
}