use std::marker::PhantomData;
use std::time::Duration;
use serde::Deserialize;
use serde_json::{Map, Value};
use crate::Client;
use crate::error::Result;
use crate::handles::Sort;
use crate::query::{AsQuery, BoolBuilder, Root};
use crate::search::{Hit, RawCount, SearchResponse};
/// A document type that can be searched across several physical indices in
/// a single request and decoded back from whichever index a hit came from.
pub trait FlussoMultiDocument: Sized {
    /// The `(index, hash)` pairs this type searches.
    ///
    /// [`MultiSearch::new`] renders each pair as `{index}_{hash}` and joins
    /// the results with commas to form the request path.
    const TARGETS: &'static [(&'static str, &'static str)];

    /// Builds `Self` from a single hit.
    ///
    /// As called from `decode_response` in this file, `physical_index` is
    /// the hit's `_index` with the client's index prefix removed when it is
    /// present, and `source` is the hit's `_source` value.
    ///
    /// # Errors
    ///
    /// Implementations return an error when `source` cannot be turned into
    /// `Self`.
    fn decode(physical_index: &str, source: Value) -> Result<Self>;

    /// Starts an empty [`MultiSearch`] over [`Self::TARGETS`].
    fn query() -> MultiSearch<Self> {
        MultiSearch::<Self>::new()
    }
}
/// Builder for one search request that spans several physical indices and
/// decodes its hits into `U`.
///
/// `Clone` and `Debug` are implemented by hand below so that they are
/// available for every `U`, not only for `U: Clone` / `U: Debug`.
pub struct MultiSearch<U> {
    /// Comma-separated physical index names the request is sent to.
    path: String,
    /// Clauses accumulated through `query`, `filter`, `must_not` and `should`.
    bool_query: BoolBuilder,
    /// Caller-supplied query that, when set, is sent instead of `bool_query`.
    raw: Option<Value>,
    /// Sort keys, in the order they were added.
    sort: Vec<Sort>,
    /// Optional `from` value for the request body.
    from: Option<u64>,
    /// Optional `size` value for the request body.
    size: Option<u64>,
    /// Ties the builder to `U` without storing a `U` value.
    _marker: PhantomData<fn() -> U>,
}

// Hand-written instead of derived: `#[derive(Clone)]` would add a `U: Clone`
// bound even though `U` only appears inside `PhantomData`.
impl<U> Clone for MultiSearch<U> {
    fn clone(&self) -> Self {
        Self {
            path: self.path.clone(),
            bool_query: self.bool_query.clone(),
            raw: self.raw.clone(),
            sort: self.sort.clone(),
            from: self.from,
            size: self.size,
            _marker: PhantomData,
        }
    }
}

// Hand-written for the same reason as `Clone`: the derive would require
// `U: Debug`. Field names and order mirror what the derive would print.
impl<U> std::fmt::Debug for MultiSearch<U> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MultiSearch")
            .field("path", &self.path)
            .field("bool_query", &self.bool_query)
            .field("raw", &self.raw)
            .field("sort", &self.sort)
            .field("from", &self.from)
            .field("size", &self.size)
            .field("_marker", &self._marker)
            .finish()
    }
}
impl<U: FlussoMultiDocument> MultiSearch<U> {
#[must_use]
pub fn new() -> Self {
let path = U::TARGETS
.iter()
.map(|(index, hash)| format!("{index}_{hash}"))
.collect::<Vec<_>>()
.join(",");
Self {
path,
bool_query: BoolBuilder::default(),
raw: None,
sort: Vec::new(),
from: None,
size: None,
_marker: PhantomData,
}
}
#[must_use]
pub fn query(mut self, query: impl AsQuery<Root>) -> Self {
if let Some(query) = query.into_query() {
self.bool_query.push_must(query.into_inner());
}
self
}
#[must_use]
pub fn filter(mut self, query: impl AsQuery<Root>) -> Self {
if let Some(query) = query.into_query() {
self.bool_query.push_filter(query.into_inner());
}
self
}
#[must_use]
pub fn must_not(mut self, query: impl AsQuery<Root>) -> Self {
if let Some(query) = query.into_query() {
self.bool_query.push_must_not(query.into_inner());
}
self
}
#[must_use]
pub fn should(mut self, query: impl AsQuery<Root>) -> Self {
if let Some(query) = query.into_query() {
self.bool_query.push_should(query.into_inner());
}
self
}
#[must_use]
pub fn sort(mut self, sort: Sort) -> Self {
self.sort.push(sort);
self
}
#[must_use]
pub fn sorts(mut self, sorts: impl IntoIterator<Item = Sort>) -> Self {
self.sort.extend(sorts);
self
}
#[must_use]
pub fn from(mut self, from: u64) -> Self {
self.from = Some(from);
self
}
#[must_use]
pub fn size(mut self, size: u64) -> Self {
self.size = Some(size);
self
}
#[must_use]
pub fn raw(mut self, query: Value) -> Self {
self.raw = Some(query);
self
}
#[must_use]
pub fn physical_path(&self) -> &str {
&self.path
}
fn query_value(&self) -> Value {
match &self.raw {
Some(raw) => raw.clone(),
None if self.bool_query.is_empty() => crate::handles::match_all_value(),
None => self.bool_query.to_value(),
}
}
#[must_use]
pub fn body(&self) -> Value {
let mut root = Map::new();
root.insert("query".to_string(), self.query_value());
if !self.sort.is_empty() {
let keys = self.sort.iter().map(Sort::to_value).collect();
root.insert("sort".to_string(), Value::Array(keys));
}
if let Some(from) = self.from {
root.insert("from".to_string(), Value::from(from));
}
if let Some(size) = self.size {
root.insert("size".to_string(), Value::from(size));
}
Value::Object(root)
}
#[must_use]
pub fn count_body(&self) -> Value {
let mut root = Map::new();
root.insert("query".to_string(), self.query_value());
Value::Object(root)
}
#[tracing::instrument(
name = "search.multi",
skip_all,
fields(
path = %self.path,
from = ?self.from,
size = ?self.size,
total = tracing::field::Empty,
took_ms = tracing::field::Empty,
),
err,
)]
pub async fn send(&self, client: &Client) -> Result<SearchResponse<U>> {
let body = self.body();
let response = client.search_at(&self.path, &body).await?;
let page = decode_response::<U>(response, &client.index_prefix)?;
let span = tracing::Span::current();
span.record("total", page.total);
span.record("took_ms", page.took.as_millis() as u64);
tracing::debug!(
total = page.total,
hits = page.hits.len(),
"combined search completed"
);
Ok(page)
}
#[tracing::instrument(
name = "search.multi_count",
skip_all,
fields(path = %self.path, count = tracing::field::Empty),
err,
)]
pub async fn count(&self, client: &Client) -> Result<u64> {
let body = self.count_body();
let response = client.count_at(&self.path, &body).await?;
let raw: RawCount = serde_json::from_value(response)?;
tracing::Span::current().record("count", raw.count);
tracing::debug!(count = raw.count, "combined count completed");
Ok(raw.count)
}
}
impl<U: FlussoMultiDocument> Default for MultiSearch<U> {
fn default() -> Self {
Self::new()
}
}
pub(crate) fn decode_response<U: FlussoMultiDocument>(
value: Value,
prefix: &str,
) -> Result<SearchResponse<U>> {
let raw: RawMultiResponse = serde_json::from_value(value)?;
let hits = raw
.hits
.hits
.into_iter()
.map(|hit| {
let index = hit.index.strip_prefix(prefix).unwrap_or(&hit.index);
Ok(Hit {
id: hit.id,
score: hit.score.unwrap_or(0.0),
source: U::decode(index, hit.source)?,
})
})
.collect::<Result<Vec<_>>>()?;
Ok(SearchResponse {
total: raw.hits.total.value,
max_score: raw.hits.max_score,
hits,
took: Duration::from_millis(raw.took),
})
}
/// Top-level shape of a search response as read by `decode_response`.
#[derive(Deserialize)]
struct RawMultiResponse {
// Converted with `Duration::from_millis` by `decode_response`; falls back to 0
// when the key is absent.
#[serde(default)]
took: u64,
// The `hits` envelope: total, max score and the individual hits.
hits: RawMultiHits,
}
/// The `hits` object of a search response.
#[derive(Deserialize)]
struct RawMultiHits {
// Required: deserialization fails when `total` is absent.
total: RawMultiTotal,
// `None` when the key is missing or null.
#[serde(default)]
max_score: Option<f32>,
// The individual hits, in response order.
hits: Vec<RawMultiHit>,
}
/// The `total` object of a search response; only `value` is read here.
#[derive(Deserialize)]
struct RawMultiTotal {
// Copied into `SearchResponse::total` by `decode_response`.
value: u64,
}
/// One entry of `hits.hits`, before its `_source` is decoded into a document.
#[derive(Deserialize)]
struct RawMultiHit {
// Index the hit came from; `decode_response` strips the client prefix
// from it before passing it to `FlussoMultiDocument::decode`.
#[serde(rename = "_index")]
index: String,
// Document id, copied into `Hit::id`.
#[serde(rename = "_id")]
id: String,
// `None` when `_score` is missing or null; `decode_response` maps that to 0.0.
#[serde(rename = "_score", default)]
score: Option<f32>,
// Raw document body handed to `FlussoMultiDocument::decode`; required.
#[serde(rename = "_source")]
source: Value,
}