use std::marker::PhantomData;
use std::time::Duration;
use serde::Deserialize;
use serde::de::DeserializeOwned;
use serde_json::{Map, Value};
use crate::Client;
use crate::error::Result;
use crate::handles::{NestedProjection, Sort};
use crate::query::{AsQuery, BoolBuilder, Root};
/// A deserializable document type bound to a search index.
///
/// Implementors supply the logical index name and a schema hash; the provided
/// methods derive the physical index name and hand out typed entry points for
/// searching and fetching documents of this type.
pub trait FlussoDocument: DeserializeOwned {
    /// Logical index name.
    const INDEX: &'static str;
    /// Schema hash; appended to [`Self::INDEX`] to form the physical index name.
    const SCHEMA_HASH: &'static str;

    /// Returns the physical index name, `INDEX` and `SCHEMA_HASH` joined by `_`.
    fn physical_index() -> String {
        [Self::INDEX, Self::SCHEMA_HASH].join("_")
    }

    /// Starts an empty [`Search`] for this document type.
    fn query() -> Search<Self> {
        let (index, hash) = (Self::INDEX, Self::SCHEMA_HASH);
        Search::new(index, hash)
    }

    /// Fetches a single document by `id`.
    ///
    /// Delegates to `Client::get_one` with this type's index name and schema hash.
    /// NOTE(review): `None` presumably means no document has that id — confirm
    /// against `Client::get_one`.
    fn get(
        client: &Client,
        id: impl std::fmt::Display,
    ) -> impl std::future::Future<Output = Result<Option<Self>>> {
        let (index, hash) = (Self::INDEX, Self::SCHEMA_HASH);
        client.get_one::<Self>(index, hash, id)
    }
}
/// Builder for a search request against one index, typed by the document `T`
/// that its hits deserialize into.
///
/// Created with [`Search::new`] and refined through the chained setter methods.
/// Unset optional fields are left out of the generated request body.
#[derive(Debug, Clone)]
pub struct Search<T> {
/// Logical index name; joined with `hash` to form the physical index name.
index: String,
/// Schema hash appended to `index` when forming the physical index name.
hash: String,
/// Accumulates the `must` / `filter` / `must_not` / `should` clauses.
bool_query: BoolBuilder,
/// When set, used verbatim as the request's query instead of `bool_query`.
raw: Option<Value>,
/// Sort keys, emitted under `sort` in the order they were added.
sort: Vec<Sort>,
/// Emitted as `from` when set.
from: Option<u64>,
/// Emitted as `size` when set.
size: Option<u64>,
/// Nested projections; the full search body adds each one as a `should` clause.
nested: Vec<NestedProjection>,
/// Emitted as `min_score` when set.
min_score: Option<f32>,
/// Emitted as `track_total_hits` when set.
track_total_hits: Option<Value>,
/// Emitted as `track_scores` when set.
track_scores: Option<bool>,
/// Emitted as `search_after` when set.
search_after: Option<Vec<Value>>,
/// Emitted as `collapse` when set.
collapse: Option<Value>,
/// Emitted as `post_filter` when set.
post_filter: Option<Value>,
/// Emitted as `highlight` by the full search body only (not the ids-only body).
highlight: Option<Highlight>,
/// Records the document type without storing a value of it.
_marker: PhantomData<fn() -> T>,
}
impl<T> Search<T> {
pub fn new(index: impl Into<String>, hash: impl Into<String>) -> Self {
Self {
index: index.into(),
hash: hash.into(),
bool_query: BoolBuilder::default(),
raw: None,
sort: Vec::new(),
from: None,
size: None,
nested: Vec::new(),
min_score: None,
track_total_hits: None,
track_scores: None,
search_after: None,
collapse: None,
post_filter: None,
highlight: None,
_marker: PhantomData,
}
}
#[must_use]
pub fn query(mut self, query: impl AsQuery<Root>) -> Self {
if let Some(query) = query.into_query() {
self.bool_query.push_must(query.into_inner());
}
self
}
#[must_use]
pub fn filter(mut self, query: impl AsQuery<Root>) -> Self {
if let Some(query) = query.into_query() {
self.bool_query.push_filter(query.into_inner());
}
self
}
#[must_use]
pub fn must_not(mut self, query: impl AsQuery<Root>) -> Self {
if let Some(query) = query.into_query() {
self.bool_query.push_must_not(query.into_inner());
}
self
}
#[must_use]
pub fn should(mut self, query: impl AsQuery<Root>) -> Self {
if let Some(query) = query.into_query() {
self.bool_query.push_should(query.into_inner());
}
self
}
#[must_use]
pub fn min_should_match(mut self, value: impl Into<Value>) -> Self {
self.bool_query.set_min_should_match(value.into());
self
}
#[must_use]
pub fn sort(mut self, sort: Sort) -> Self {
self.sort.push(sort);
self
}
#[must_use]
pub fn min_score(mut self, min_score: f32) -> Self {
self.min_score = Some(min_score);
self
}
#[must_use]
pub fn track_total_hits(mut self, track: impl Into<Value>) -> Self {
self.track_total_hits = Some(track.into());
self
}
#[must_use]
pub fn track_scores(mut self, track: bool) -> Self {
self.track_scores = Some(track);
self
}
#[must_use]
pub fn search_after(mut self, values: impl IntoIterator<Item = impl Into<Value>>) -> Self {
self.search_after = Some(values.into_iter().map(Into::into).collect());
self
}
#[must_use]
pub fn collapse(mut self, field: impl Into<String>) -> Self {
let mut body = Map::new();
body.insert("field".to_string(), Value::String(field.into()));
self.collapse = Some(Value::Object(body));
self
}
#[must_use]
pub fn post_filter(mut self, query: impl AsQuery<Root>) -> Self {
if let Some(query) = query.into_query() {
self.post_filter = Some(query.to_value());
}
self
}
#[must_use]
pub fn highlight(mut self, highlight: Highlight) -> Self {
self.highlight = Some(highlight);
self
}
#[must_use]
pub fn from(mut self, from: u64) -> Self {
self.from = Some(from);
self
}
#[must_use]
pub fn size(mut self, size: u64) -> Self {
self.size = Some(size);
self
}
#[must_use]
pub fn raw(mut self, query: Value) -> Self {
self.raw = Some(query);
self
}
#[must_use]
pub fn filter_nested(mut self, projection: NestedProjection) -> Self {
self.nested.push(projection);
self
}
fn query_value(&self) -> Value {
match &self.raw {
Some(raw) => raw.clone(),
None if self.bool_query.is_empty() => crate::handles::match_all_value(),
None => self.bool_query.to_value(),
}
}
#[must_use]
pub fn body(&self) -> Value {
let query = self.query_value();
let query = if self.nested.is_empty() {
query
} else {
let mut bool_body = Map::new();
bool_body.insert("must".to_string(), Value::Array(vec![query]));
let shoulds = self.nested.iter().map(NestedProjection::to_value).collect();
bool_body.insert("should".to_string(), Value::Array(shoulds));
let mut outer = Map::new();
outer.insert("bool".to_string(), Value::Object(bool_body));
Value::Object(outer)
};
let mut root = Map::new();
root.insert("query".to_string(), query);
self.insert_page_params(&mut root);
self.insert_search_level(&mut root, true);
Value::Object(root)
}
fn insert_page_params(&self, root: &mut Map<String, Value>) {
if !self.sort.is_empty() {
let keys = self.sort.iter().map(Sort::to_value).collect();
root.insert("sort".to_string(), Value::Array(keys));
}
if let Some(from) = self.from {
root.insert("from".to_string(), Value::from(from));
}
if let Some(size) = self.size {
root.insert("size".to_string(), Value::from(size));
}
}
fn insert_search_level(&self, root: &mut Map<String, Value>, with_highlight: bool) {
if let Some(min_score) = self.min_score {
root.insert("min_score".to_string(), Value::from(min_score));
}
if let Some(track) = &self.track_total_hits {
root.insert("track_total_hits".to_string(), track.clone());
}
if let Some(track) = self.track_scores {
root.insert("track_scores".to_string(), Value::Bool(track));
}
if let Some(values) = &self.search_after {
root.insert("search_after".to_string(), Value::Array(values.clone()));
}
if let Some(collapse) = &self.collapse {
root.insert("collapse".to_string(), collapse.clone());
}
if let Some(post_filter) = &self.post_filter {
root.insert("post_filter".to_string(), post_filter.clone());
}
if with_highlight && let Some(highlight) = &self.highlight {
root.insert("highlight".to_string(), highlight.to_value());
}
}
#[must_use]
pub fn count_body(&self) -> Value {
let mut root = Map::new();
root.insert("query".to_string(), self.query_value());
Value::Object(root)
}
#[must_use]
pub fn ids_body(&self) -> Value {
let mut root = Map::new();
root.insert("query".to_string(), self.query_value());
self.insert_page_params(&mut root);
self.insert_search_level(&mut root, false);
root.insert("_source".to_string(), Value::Bool(false));
Value::Object(root)
}
#[tracing::instrument(
name = "search.ids",
skip_all,
fields(index = %self.index, returned = tracing::field::Empty),
err,
)]
pub async fn ids(&self, client: &Client) -> Result<Vec<String>> {
let body = self.ids_body();
let response = client.search_at(&self.physical_index(), &body).await?;
let raw: RawIdsResponse = serde_json::from_value(response)?;
let ids: Vec<String> = raw.hits.hits.into_iter().map(|hit| hit.id).collect();
tracing::Span::current().record("returned", ids.len());
tracing::debug!(returned = ids.len(), "ids search completed");
Ok(ids)
}
pub(crate) fn physical_index(&self) -> String {
format!("{}_{}", self.index, self.hash)
}
pub(crate) fn nested_paths(&self) -> Vec<&str> {
self.nested.iter().map(NestedProjection::path).collect()
}
#[tracing::instrument(
name = "search.count",
skip_all,
fields(index = %self.index, count = tracing::field::Empty),
err,
)]
pub async fn count(&self, client: &Client) -> Result<u64> {
let body = self.count_body();
let response = client.count_at(&self.physical_index(), &body).await?;
let raw: RawCount = serde_json::from_value(response)?;
tracing::Span::current().record("count", raw.count);
tracing::debug!(count = raw.count, "count completed");
Ok(raw.count)
}
}
impl<T> Search<T>
where
    T: DeserializeOwned,
{
    /// Executes the search and decodes the response into typed hits.
    ///
    /// When nested projections are registered, each hit's `_source` is first
    /// rewritten by [`merge_inner_hits`] so that the projected paths hold only the
    /// inner-hit documents; the result is then parsed with
    /// [`SearchResponse::from_value`].
    ///
    /// # Errors
    ///
    /// Propagates the client error or the response deserialization error.
    #[tracing::instrument(
        name = "search.send",
        skip_all,
        fields(
            index = %self.index,
            from = ?self.from,
            size = ?self.size,
            total = tracing::field::Empty,
            took_ms = tracing::field::Empty,
        ),
        err,
    )]
    pub async fn send(&self, client: &Client) -> Result<SearchResponse<T>> {
        let request = self.body();
        let mut payload = client.search_at(&self.physical_index(), &request).await?;

        let nested = self.nested_paths();
        if !nested.is_empty() {
            merge_inner_hits(&mut payload, &nested);
        }

        let result = SearchResponse::from_value(payload)?;

        let current = tracing::Span::current();
        current.record("total", result.total);
        current.record("took_ms", result.took.as_millis() as u64);
        tracing::debug!(
            total = result.total,
            hits = result.hits.len(),
            "search completed"
        );
        Ok(result)
    }
}
/// Overwrites nested arrays in each hit's `_source` with the documents found in
/// that hit's `inner_hits`.
///
/// For every element of `response["hits"]["hits"]` that has an `inner_hits` value
/// and an object-valued `_source`, and for every `path` in `paths`,
/// `_source[path]` is replaced by an array holding the `_source` of each entry
/// under `inner_hits[path]["hits"]["hits"]`. A path whose inner hits are missing
/// or malformed is replaced by an empty array.
///
/// Hits without `inner_hits`, or whose `_source` is absent or not an object, are
/// left unchanged. A response without a `hits.hits` array is left unchanged. The
/// `inner_hits` values themselves are not modified.
pub(crate) fn merge_inner_hits(response: &mut Value, paths: &[&str]) {
    let Some(hits) = response
        .get_mut("hits")
        .and_then(|hits| hits.get_mut("hits"))
        .and_then(Value::as_array_mut)
    else {
        return;
    };

    for hit in hits {
        let Some(inner) = hit.get("inner_hits") else {
            continue;
        };
        // Skip early so no replacement arrays are built for a hit that cannot
        // receive them.
        if !hit.get("_source").is_some_and(Value::is_object) {
            continue;
        }

        // Build the replacements while `hit` is only borrowed immutably. This
        // avoids deep-cloning the whole `inner_hits` subtree just to release the
        // borrow before mutating `_source`; only the inner `_source` documents
        // that end up in the result are cloned.
        let replacements: Vec<(String, Value)> = paths
            .iter()
            .map(|path| {
                let subset: Vec<Value> = inner
                    .get(*path)
                    .and_then(|entry| entry.get("hits"))
                    .and_then(|hits| hits.get("hits"))
                    .and_then(Value::as_array)
                    .map(|hits| {
                        hits.iter()
                            .filter_map(|h| h.get("_source").cloned())
                            .collect()
                    })
                    .unwrap_or_default();
                ((*path).to_string(), Value::Array(subset))
            })
            .collect();

        let Some(source) = hit.get_mut("_source").and_then(Value::as_object_mut) else {
            continue;
        };
        for (path, subset) in replacements {
            source.insert(path, subset);
        }
    }
}
/// Builder for the `highlight` section of a search request.
#[derive(Debug, Clone, Default)]
pub struct Highlight {
/// Per-field highlight settings keyed by field name; rendered under `"fields"`.
fields: Map<String, Value>,
/// Highlight-wide options (such as `pre_tags` or `fragment_size`), rendered
/// next to `"fields"` at the top level of the highlight object.
opts: Map<String, Value>,
}
impl Highlight {
    /// Creates a highlight configuration with no fields and no options.
    #[must_use]
    pub fn new() -> Self {
        Self {
            fields: Map::new(),
            opts: Map::new(),
        }
    }

    /// Highlights `field` with empty (`{}`) per-field settings.
    #[must_use]
    pub fn field(self, field: impl Into<String>) -> Self {
        self.field_with(field, Value::Object(Map::new()))
    }

    /// Highlights `field` using the given per-field `settings`; a later call for
    /// the same field replaces the earlier settings.
    #[must_use]
    pub fn field_with(mut self, field: impl Into<String>, settings: Value) -> Self {
        let name: String = field.into();
        self.fields.insert(name, settings);
        self
    }

    /// Sets the `pre_tags` option to the given list of strings.
    #[must_use]
    pub fn pre_tags(self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.with_opt("pre_tags", Self::tag_list(tags))
    }

    /// Sets the `post_tags` option to the given list of strings.
    #[must_use]
    pub fn post_tags(self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.with_opt("post_tags", Self::tag_list(tags))
    }

    /// Sets the `fragment_size` option.
    #[must_use]
    pub fn fragment_size(self, fragment_size: u32) -> Self {
        self.with_opt("fragment_size", Value::from(fragment_size))
    }

    /// Sets the `number_of_fragments` option.
    #[must_use]
    pub fn number_of_fragments(self, number_of_fragments: u32) -> Self {
        self.with_opt("number_of_fragments", Value::from(number_of_fragments))
    }

    /// Sets the `require_field_match` option.
    #[must_use]
    pub fn require_field_match(self, require: bool) -> Self {
        self.with_opt("require_field_match", Value::Bool(require))
    }

    /// Stores one highlight-wide option under `key`, replacing any previous value.
    fn with_opt(mut self, key: &str, value: Value) -> Self {
        self.opts.insert(key.to_string(), value);
        self
    }

    /// Converts an iterable of string-like tags into a JSON array of strings.
    fn tag_list(tags: impl IntoIterator<Item = impl Into<String>>) -> Value {
        Value::Array(
            tags.into_iter()
                .map(|tag| Value::String(tag.into()))
                .collect(),
        )
    }

    /// Renders the highlight object: every option at the top level plus the
    /// per-field settings under `"fields"`.
    fn to_value(&self) -> Value {
        let fields = Value::Object(self.fields.clone());
        let mut rendered = self.opts.clone();
        rendered.insert(String::from("fields"), fields);
        Value::Object(rendered)
    }
}
/// Decoded result of a search request, with hits deserialized into `T`.
#[derive(Debug)]
pub struct SearchResponse<T> {
/// Total hit count, taken from `hits.total.value` of the response.
pub total: u64,
/// `hits.max_score` of the response, if present.
pub max_score: Option<f32>,
/// The returned hits, in response order.
pub hits: Vec<Hit<T>>,
/// The response's `took` value, converted from milliseconds.
pub took: Duration,
}
impl<T> SearchResponse<T>
where
T: DeserializeOwned,
{
pub fn from_value(value: Value) -> Result<Self> {
let raw: RawResponse<T> = serde_json::from_value(value)?;
let hits = raw
.hits
.hits
.into_iter()
.map(|hit| Hit {
id: hit.id,
score: hit.score.unwrap_or(0.0),
source: hit.source,
})
.collect();
Ok(Self {
total: raw.hits.total.value,
max_score: raw.hits.max_score,
hits,
took: Duration::from_millis(raw.took),
})
}
}
/// A single search hit with its source deserialized into `T`.
#[derive(Debug)]
pub struct Hit<T> {
/// The hit's `_id`.
pub id: String,
/// The hit's `_score`; `0.0` when the response carried no score for the hit.
pub score: f32,
/// The hit's `_source`, deserialized into `T`.
pub source: T,
}
/// Wire shape of a search response; only the fields this module reads.
#[derive(Deserialize)]
struct RawResponse<T> {
/// Milliseconds reported under `took`; defaults to `0` when absent.
#[serde(default)]
took: u64,
/// The `hits` envelope.
hits: RawHits<T>,
}
/// Wire shape of the `hits` envelope of a search response.
#[derive(Deserialize)]
struct RawHits<T> {
    /// Total-hits object. Elasticsearch leaves `hits.total` out of the response
    /// when the request disables hit counting (`track_total_hits: false`, which
    /// `Search::track_total_hits` can send), so a missing value falls back to a
    /// zero total instead of failing deserialization.
    #[serde(default)]
    total: RawTotal,
    /// Highest score among the hits, when the response reports one.
    #[serde(default)]
    max_score: Option<f32>,
    /// The individual hits, in response order.
    hits: Vec<RawHit<T>>,
}

/// Wire shape of `hits.total`; `Default` yields a `value` of `0`.
#[derive(Deserialize, Default)]
struct RawTotal {
    /// Number of matching documents as reported by the server.
    value: u64,
}
/// Wire shape of a count response.
#[derive(Deserialize)]
pub(crate) struct RawCount {
/// The `count` field of the response.
pub(crate) count: u64,
}
/// Wire shape of a search response when only hit ids are read.
#[derive(Deserialize)]
struct RawIdsResponse {
/// The `hits` envelope.
hits: RawIdsHits,
}
/// Wire shape of the `hits` envelope when only hit ids are read.
#[derive(Deserialize)]
struct RawIdsHits {
/// The individual hits, in response order.
hits: Vec<RawIdHit>,
}
/// Wire shape of a single hit when only its id is read.
#[derive(Deserialize)]
struct RawIdHit {
/// The hit's `_id`.
#[serde(rename = "_id")]
id: String,
}
/// Wire shape of a single search hit whose source deserializes into `T`.
#[derive(Deserialize)]
struct RawHit<T> {
/// The hit's `_id`.
#[serde(rename = "_id")]
id: String,
/// The hit's `_score`; `None` when the field is absent or null.
#[serde(rename = "_score", default)]
score: Option<f32>,
/// The hit's `_source`.
#[serde(rename = "_source")]
source: T,
}