use std::marker::PhantomData;
use std::time::Duration;
use serde::Deserialize;
use serde::de::DeserializeOwned;
use serde_json::{Map, Value};
use crate::Client;
use crate::error::Result;
use crate::handles::{NestedProjection, Sort};
use crate::query::{AsQuery, BoolBuilder, Root};
/// A deserializable document type tied to a named index and a schema hash.
///
/// Implementors supply the two constants; the provided methods derive the
/// physical index name, a typed [`Search`] builder and a by-id lookup from them.
pub trait FlussoDocument: DeserializeOwned {
    /// Logical index name; the first half of the physical index name.
    const INDEX: &'static str;
    /// Schema hash; appended to [`Self::INDEX`] to form the physical index name.
    const SCHEMA_HASH: &'static str;

    /// Returns the physical index name, `"{INDEX}_{SCHEMA_HASH}"`.
    fn physical_index() -> String {
        [Self::INDEX, Self::SCHEMA_HASH].join("_")
    }

    /// Starts an empty [`Search`] builder targeting this document's index.
    fn query() -> Search<Self> {
        Search::<Self>::new(Self::INDEX, Self::SCHEMA_HASH)
    }

    /// Looks up a single document by `id` through `client`.
    ///
    /// Delegates to `Client::get_one`; the returned future resolves to
    /// `Ok(None)` or `Ok(Some(doc))` as decided by that call.
    fn get(
        client: &Client,
        id: impl std::fmt::Display,
    ) -> impl std::future::Future<Output = Result<Option<Self>>> {
        let (index, hash) = (Self::INDEX, Self::SCHEMA_HASH);
        client.get_one::<Self>(index, hash, id)
    }
}
/// Typed builder for a search request against one document index.
///
/// `T` is the document type hits are deserialized into. The builder never
/// stores a `T`; it only carries the type through `PhantomData<fn() -> T>`.
pub struct Search<T> {
    /// Logical index name (first half of the physical index name).
    index: String,
    /// Schema hash (second half of the physical index name).
    hash: String,
    /// Clauses accumulated by `query`, `filter`, `must_not` and `should`.
    bool_query: BoolBuilder,
    /// Raw query set by `raw`; when present it is used instead of `bool_query`.
    raw: Option<Value>,
    /// Sort keys, in the order they were added.
    sort: Vec<Sort>,
    /// Optional `from` paging offset.
    from: Option<u64>,
    /// Optional `size` paging limit.
    size: Option<u64>,
    /// Nested projections registered through `filter_nested`.
    nested: Vec<NestedProjection>,
    /// Ties the builder to `T` without owning a `T`.
    _marker: PhantomData<fn() -> T>,
}

// `Clone` and `Debug` are implemented by hand rather than derived: the derive
// macros would add `T: Clone` / `T: Debug` bounds even though no `T` value is
// stored, which would make the builder unusable with document types that do
// not implement those traits.
impl<T> Clone for Search<T> {
    fn clone(&self) -> Self {
        Self {
            index: self.index.clone(),
            hash: self.hash.clone(),
            bool_query: self.bool_query.clone(),
            raw: self.raw.clone(),
            sort: self.sort.clone(),
            from: self.from,
            size: self.size,
            nested: self.nested.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for Search<T> {
    // Same field list and order as the derived implementation would print.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Search")
            .field("index", &self.index)
            .field("hash", &self.hash)
            .field("bool_query", &self.bool_query)
            .field("raw", &self.raw)
            .field("sort", &self.sort)
            .field("from", &self.from)
            .field("size", &self.size)
            .field("nested", &self.nested)
            .field("_marker", &self._marker)
            .finish()
    }
}
impl<T> Search<T> {
    /// Creates an empty builder for the index identified by `index` and `hash`.
    ///
    /// No clauses, raw query, sort keys, paging or nested projections are set.
    pub fn new(index: impl Into<String>, hash: impl Into<String>) -> Self {
        let (index, hash) = (index.into(), hash.into());
        Self {
            index,
            hash,
            bool_query: BoolBuilder::default(),
            raw: None,
            sort: Vec::new(),
            from: None,
            size: None,
            nested: Vec::new(),
            _marker: PhantomData,
        }
    }

    /// Adds `query` as a `must` clause; a query that converts to `None` is skipped.
    #[must_use]
    pub fn query(mut self, query: impl AsQuery<Root>) -> Self {
        let Some(clause) = query.into_query() else {
            return self;
        };
        self.bool_query.push_must(clause.into_inner());
        self
    }

    /// Adds `query` as a `filter` clause; a query that converts to `None` is skipped.
    #[must_use]
    pub fn filter(mut self, query: impl AsQuery<Root>) -> Self {
        let Some(clause) = query.into_query() else {
            return self;
        };
        self.bool_query.push_filter(clause.into_inner());
        self
    }

    /// Adds `query` as a `must_not` clause; a query that converts to `None` is skipped.
    #[must_use]
    pub fn must_not(mut self, query: impl AsQuery<Root>) -> Self {
        let Some(clause) = query.into_query() else {
            return self;
        };
        self.bool_query.push_must_not(clause.into_inner());
        self
    }

    /// Adds `query` as a `should` clause; a query that converts to `None` is skipped.
    #[must_use]
    pub fn should(mut self, query: impl AsQuery<Root>) -> Self {
        let Some(clause) = query.into_query() else {
            return self;
        };
        self.bool_query.push_should(clause.into_inner());
        self
    }

    /// Appends a sort key; keys are emitted in the order they were added.
    #[must_use]
    pub fn sort(mut self, sort: Sort) -> Self {
        self.sort.push(sort);
        self
    }

    /// Sets the `from` paging offset, replacing any previous value.
    #[must_use]
    pub fn from(self, from: u64) -> Self {
        Self {
            from: Some(from),
            ..self
        }
    }

    /// Sets the `size` paging limit, replacing any previous value.
    #[must_use]
    pub fn size(self, size: u64) -> Self {
        Self {
            size: Some(size),
            ..self
        }
    }

    /// Sets a raw query value.
    ///
    /// While a raw query is set it is used as the query verbatim and the
    /// clauses added through `query`/`filter`/`must_not`/`should` are ignored.
    #[must_use]
    pub fn raw(self, query: Value) -> Self {
        Self {
            raw: Some(query),
            ..self
        }
    }

    /// Registers a nested projection.
    ///
    /// Projections are added as `should` clauses by [`Self::body`]; their
    /// paths are reported by `nested_paths`.
    #[must_use]
    pub fn filter_nested(mut self, projection: NestedProjection) -> Self {
        self.nested.push(projection);
        self
    }

    /// Resolves the effective query: the raw query if set, otherwise a
    /// match-all when no bool clause was added, otherwise the bool query.
    fn query_value(&self) -> Value {
        if let Some(raw) = &self.raw {
            return raw.clone();
        }
        if self.bool_query.is_empty() {
            crate::handles::match_all_value()
        } else {
            self.bool_query.to_value()
        }
    }

    /// Builds a request root object holding only the `"query"` key.
    fn query_root(query: Value) -> Map<String, Value> {
        let mut root = Map::new();
        root.insert(String::from("query"), query);
        root
    }

    /// Builds the full search request body.
    ///
    /// With nested projections registered, the effective query becomes the
    /// single `must` clause of an outer `bool` whose `should` clauses are the
    /// projections. Sort and paging parameters are added when set.
    #[must_use]
    pub fn body(&self) -> Value {
        let base = self.query_value();
        let query = if self.nested.is_empty() {
            base
        } else {
            let projections: Value = self
                .nested
                .iter()
                .map(NestedProjection::to_value)
                .collect();
            let mut clauses = Map::new();
            clauses.insert(String::from("must"), Value::Array(vec![base]));
            clauses.insert(String::from("should"), projections);
            let wrapper: Map<String, Value> =
                std::iter::once((String::from("bool"), Value::Object(clauses))).collect();
            Value::Object(wrapper)
        };
        let mut root = Self::query_root(query);
        self.insert_page_params(&mut root);
        Value::Object(root)
    }

    /// Writes `sort`, `from` and `size` into `root`, each only when set.
    fn insert_page_params(&self, root: &mut Map<String, Value>) {
        if !self.sort.is_empty() {
            let sort_keys: Vec<Value> = self.sort.iter().map(Sort::to_value).collect();
            root.insert(String::from("sort"), Value::Array(sort_keys));
        }
        let paging = [("from", self.from), ("size", self.size)];
        for (key, setting) in paging.iter() {
            if let Some(number) = setting {
                root.insert((*key).to_string(), Value::from(*number));
            }
        }
    }

    /// Builds the body for a count request: the effective query only, without
    /// sort, paging or nested projections.
    #[must_use]
    pub fn count_body(&self) -> Value {
        Value::Object(Self::query_root(self.query_value()))
    }

    /// Builds the body for an ids-only search: the effective query, sort and
    /// paging parameters, and `"_source": false`. Nested projections are not
    /// included.
    #[must_use]
    pub fn ids_body(&self) -> Value {
        let mut root = Self::query_root(self.query_value());
        self.insert_page_params(&mut root);
        root.insert(String::from("_source"), Value::Bool(false));
        Value::Object(root)
    }

    /// Runs the ids-only search and returns the `_id` of every returned hit.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Client::search_at`, or the deserialization
    /// error when the response does not have the expected `hits.hits[]._id`
    /// shape.
    #[tracing::instrument(
        name = "search.ids",
        skip_all,
        fields(index = %self.index, returned = tracing::field::Empty),
        err,
    )]
    pub async fn ids(&self, client: &Client) -> Result<Vec<String>> {
        let request = self.ids_body();
        let target = self.physical_index();
        let response = client.search_at(&target, &request).await?;
        let parsed = serde_json::from_value::<RawIdsResponse>(response)?;
        let ids: Vec<String> = parsed
            .hits
            .hits
            .into_iter()
            .map(|RawIdHit { id }| id)
            .collect();
        tracing::Span::current().record("returned", ids.len());
        tracing::debug!(returned = ids.len(), "ids search completed");
        Ok(ids)
    }

    /// Returns the physical index name, `"{index}_{hash}"`.
    pub(crate) fn physical_index(&self) -> String {
        let Self { index, hash, .. } = self;
        format!("{index}_{hash}")
    }

    /// Returns the path of every registered nested projection, in
    /// registration order.
    pub(crate) fn nested_paths(&self) -> Vec<&str> {
        let mut paths = Vec::with_capacity(self.nested.len());
        for projection in &self.nested {
            paths.push(projection.path());
        }
        paths
    }

    /// Runs a count request and returns the `count` field of the response.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Client::count_at`, or the deserialization
    /// error when the response has no numeric `count` field.
    #[tracing::instrument(
        name = "search.count",
        skip_all,
        fields(index = %self.index, count = tracing::field::Empty),
        err,
    )]
    pub async fn count(&self, client: &Client) -> Result<u64> {
        let request = self.count_body();
        let target = self.physical_index();
        let response = client.count_at(&target, &request).await?;
        let RawCount { count } = serde_json::from_value::<RawCount>(response)?;
        tracing::Span::current().record("count", count);
        tracing::debug!(count = count, "count completed");
        Ok(count)
    }
}
impl<T> Search<T>
where
    T: DeserializeOwned,
{
    /// Executes the search and decodes the response into typed hits.
    ///
    /// When nested projections are registered, each hit's inner hits are
    /// merged into its `_source` (see [`merge_inner_hits`]) before decoding.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Client::search_at`, or the error from
    /// [`SearchResponse::from_value`] when the response cannot be decoded.
    #[tracing::instrument(
        name = "search.send",
        skip_all,
        fields(
            index = %self.index,
            from = ?self.from,
            size = ?self.size,
            total = tracing::field::Empty,
            took_ms = tracing::field::Empty,
        ),
        err,
    )]
    pub async fn send(&self, client: &Client) -> Result<SearchResponse<T>> {
        let request = self.body();
        let target = self.physical_index();
        let mut payload = client.search_at(&target, &request).await?;

        let nested_paths = self.nested_paths();
        if !nested_paths.is_empty() {
            merge_inner_hits(&mut payload, &nested_paths);
        }

        let page = SearchResponse::<T>::from_value(payload)?;

        // `took` is built from a `u64` millisecond count, so narrowing the
        // `u128` returned by `as_millis` back to `u64` cannot truncate.
        let took_ms = page.took.as_millis() as u64;
        let span = tracing::Span::current();
        span.record("total", page.total);
        span.record("took_ms", took_ms);
        tracing::debug!(
            total = page.total,
            hits = page.hits.len(),
            "search completed"
        );
        Ok(page)
    }
}
/// Replaces nested fields in each hit's `_source` with that hit's inner hits.
///
/// For every element of `response["hits"]["hits"]` that has an `inner_hits`
/// entry and an object `_source`, and for every `path` in `paths`,
/// `_source[path]` is overwritten with the array of `_source` values found at
/// `inner_hits[path]["hits"]["hits"]`. A path with no such array yields an
/// empty array. Hits without `inner_hits` or without an object `_source` are
/// left untouched, as is a response without a `hits.hits` array.
pub(crate) fn merge_inner_hits(response: &mut Value, paths: &[&str]) {
    let Some(hits) = response
        .get_mut("hits")
        .and_then(|outer| outer.get_mut("hits"))
        .and_then(Value::as_array_mut)
    else {
        return;
    };
    for hit in hits {
        // Phase 1: read-only pass over `inner_hits`. Collecting the projected
        // arrays up front ends the shared borrow of `hit`, so the whole
        // `inner_hits` subtree never has to be cloned.
        let Some(inner) = hit.get("inner_hits") else {
            continue;
        };
        let projected: Vec<Value> = paths
            .iter()
            .map(|path| {
                let sources: Vec<Value> = inner
                    .get(*path)
                    .and_then(|entry| entry.get("hits"))
                    .and_then(|nested| nested.get("hits"))
                    .and_then(Value::as_array)
                    .map(|nested_hits| {
                        nested_hits
                            .iter()
                            .filter_map(|nested_hit| nested_hit.get("_source").cloned())
                            .collect()
                    })
                    .unwrap_or_default();
                Value::Array(sources)
            })
            .collect();

        // Phase 2: write the projected arrays into the hit's `_source`.
        let Some(source) = hit.get_mut("_source").and_then(Value::as_object_mut) else {
            continue;
        };
        for (path, subset) in paths.iter().zip(projected) {
            source.insert((*path).to_string(), subset);
        }
    }
}
/// Decoded search response with hits deserialized into `T`.
#[derive(Debug)]
pub struct SearchResponse<T> {
/// Value of `hits.total.value` in the response.
pub total: u64,
/// Value of `hits.max_score` in the response; `None` when absent or null.
pub max_score: Option<f32>,
/// Returned hits, in response order.
pub hits: Vec<Hit<T>>,
/// The response's `took` field interpreted as milliseconds (zero when the
/// field is absent).
pub took: Duration,
}
impl<T> SearchResponse<T>
where
T: DeserializeOwned,
{
pub fn from_value(value: Value) -> Result<Self> {
let raw: RawResponse<T> = serde_json::from_value(value)?;
let hits = raw
.hits
.hits
.into_iter()
.map(|hit| Hit {
id: hit.id,
score: hit.score.unwrap_or(0.0),
source: hit.source,
})
.collect();
Ok(Self {
total: raw.hits.total.value,
max_score: raw.hits.max_score,
hits,
took: Duration::from_millis(raw.took),
})
}
}
/// A single search hit with its source deserialized into `T`.
#[derive(Debug)]
pub struct Hit<T> {
/// The hit's `_id`.
pub id: String,
/// The hit's `_score`; `0.0` when the response carried no score for it.
pub score: f32,
/// The hit's `_source`, deserialized into `T`.
pub source: T,
}
/// Wire shape of a search response; only the fields this module reads.
#[derive(Deserialize)]
struct RawResponse<T> {
// Falls back to 0 when the response has no `took` field.
#[serde(default)]
took: u64,
hits: RawHits<T>,
}
/// Wire shape of the top-level `hits` object of a search response.
#[derive(Deserialize)]
struct RawHits<T> {
// Required: deserialization fails when `total` is missing.
total: RawTotal,
// Falls back to `None` when `max_score` is missing.
#[serde(default)]
max_score: Option<f32>,
hits: Vec<RawHit<T>>,
}
/// Wire shape of `hits.total`; only its `value` field is read.
#[derive(Deserialize)]
struct RawTotal {
value: u64,
}
/// Wire shape of a count response; only its `count` field is read.
#[derive(Deserialize)]
pub(crate) struct RawCount {
pub(crate) count: u64,
}
/// Wire shape of an ids-only search response; only `hits` is read.
#[derive(Deserialize)]
struct RawIdsResponse {
hits: RawIdsHits,
}
/// Wire shape of the `hits` object of an ids-only search response.
#[derive(Deserialize)]
struct RawIdsHits {
hits: Vec<RawIdHit>,
}
/// Wire shape of one hit in an ids-only search response; only `_id` is read.
#[derive(Deserialize)]
struct RawIdHit {
#[serde(rename = "_id")]
id: String,
}
/// Wire shape of one hit in a full search response.
#[derive(Deserialize)]
struct RawHit<T> {
#[serde(rename = "_id")]
id: String,
// `None` when `_score` is missing or null.
#[serde(rename = "_score", default)]
score: Option<f32>,
// Required: deserialization fails when `_source` is missing or is not a valid `T`.
#[serde(rename = "_source")]
source: T,
}