use crate::core::LuciError;
use crate::agg::AggregationExpression;
use crate::query::ast::{QueryExpression, ScoringExpression};
use crate::query::parser::{opt_f64, opt_str, opt_u64, parse_query, parse_query_expression};
use crate::search::{SortField, SortValue, TrackTotalHits};
/// Keys accepted at the top level of a full search request.
///
/// `is_bare_query` uses this list to tell a bare query apart from a search
/// request, and `validate_search_keys` rejects any top-level key that is not
/// listed here. The entries are echoed, in this order, in the "expected one
/// of" part of the validation error message.
///
/// NOTE(review): `_source` and `fields` are accepted here but are not read by
/// `SearchExpression::from_json` in this file — confirm they are consumed
/// elsewhere.
const SEARCH_LEVEL_KEYS: &[&str] = &[
"query",
// `aggs` and `aggregations` are alternatives; `from_json` looks up `aggs` first.
"aggs",
"aggregations",
"size",
"from",
"sort",
"search_after",
"collapse",
"track_total_hits",
"rescore",
"_source",
"fields",
];
/// Returns `true` when `json` is an object with exactly one key and that key
/// is not one of `SEARCH_LEVEL_KEYS`.
///
/// Non-objects, empty objects and objects with more than one key all yield
/// `false`.
fn is_bare_query(json: &serde_json::Value) -> bool {
    match json.as_object() {
        // With exactly one entry, `all` is simply the predicate applied to
        // that single key.
        Some(map) if map.len() == 1 => map
            .keys()
            .all(|key| !SEARCH_LEVEL_KEYS.contains(&key.as_str())),
        _ => false,
    }
}
/// Checks that `val` is a JSON object whose keys all appear in `expected`,
/// and returns a reference to the underlying map.
///
/// `ctx` is used as the prefix of every error message.
///
/// # Errors
///
/// Returns `LuciError::InvalidQuery` when `val` is not an object, or when the
/// first key found (in the map's iteration order) is not listed in `expected`.
pub(crate) fn validate_obj_keys<'a>(
    val: &'a serde_json::Value,
    expected: &[&str],
    ctx: &str,
) -> crate::core::Result<&'a serde_json::Map<String, serde_json::Value>> {
    let Some(obj) = val.as_object() else {
        return Err(crate::core::LuciError::InvalidQuery(format!(
            "{ctx}: must be an object"
        )));
    };

    // Only the first offending key is reported.
    if let Some(key) = obj.keys().find(|k| !expected.contains(&k.as_str())) {
        let quoted: Vec<String> = expected.iter().map(|k| format!("`{k}`")).collect();
        let expected_list = quoted.join(", ");
        return Err(crate::core::LuciError::InvalidQuery(format!(
            "{ctx}: unknown field `{key}`, expected one of {expected_list}"
        )));
    }

    Ok(obj)
}
/// Rejects a top-level search object that contains a key not listed in
/// `SEARCH_LEVEL_KEYS`.
///
/// # Errors
///
/// Returns `LuciError::InvalidQuery` naming the first unknown key encountered
/// (in the map's iteration order) together with the full list of accepted keys.
fn validate_search_keys(
    obj: &serde_json::Map<String, serde_json::Value>,
) -> crate::core::Result<()> {
    let unknown = obj
        .keys()
        .find(|k| !SEARCH_LEVEL_KEYS.contains(&k.as_str()));

    match unknown {
        None => Ok(()),
        Some(key) => {
            let quoted: Vec<String> = SEARCH_LEVEL_KEYS
                .iter()
                .map(|k| format!("`{k}`"))
                .collect();
            let expected = quoted.join(", ");
            Err(crate::core::LuciError::InvalidQuery(format!(
                "invalid search request: unknown field `{key}`, expected one of {expected}"
            )))
        }
    }
}
/// A search request, either assembled through the builder methods or parsed
/// from JSON by `SearchExpression::from_json`.
pub struct SearchExpression {
/// Optional query. `None` after `new()`; `from_json` fills it from the
/// `query` key, or from the whole document when the input is a bare query.
pub(crate) query: Option<QueryExpression>,
/// `(name, aggregation)` pairs. Appended to by `agg()`; `from_json` replaces
/// the list with the result of parsing `aggs` (or `aggregations`).
pub(crate) aggs: Vec<(String, AggregationExpression)>,
/// Value of the `size` key. `new()` uses 10; `from_json` falls back to its
/// `default_size` argument when the key is absent.
/// NOTE(review): presumably the maximum number of hits to return — confirm.
pub(crate) size: usize,
/// Value of the `from` key; 0 when absent.
/// NOTE(review): presumably the offset of the first hit — confirm.
pub(crate) from: usize,
/// Sort specification, produced by `crate::index::parse_sort` from the
/// `sort` key when parsing JSON.
pub(crate) sort: Option<Vec<SortField>>,
/// Field name taken from `collapse.field`, if any.
pub(crate) collapse: Option<String>,
/// Cursor values, produced by `crate::index::parse_search_after` from the
/// `search_after` key when parsing JSON.
pub(crate) search_after: Option<Vec<SortValue>>,
/// Parsed `track_total_hits`: `true` or absent maps to `Exact`, `false` to
/// `Disabled`, and a non-negative integer to `UpTo`.
pub(crate) track_total_hits: TrackTotalHits,
/// Optional rescore stage. `from_json` only sets it when
/// `rescore.query.rescore_query` is present.
pub(crate) rescore: Option<RescoreSpec>,
}
/// Parameters of a rescore stage, as parsed from the `rescore` key of a
/// search request.
pub struct RescoreSpec {
/// Query parsed from `rescore.query.rescore_query`.
pub(crate) query: Box<dyn crate::query::Query>,
/// Value of `rescore.window_size`; `from_json` uses 10 when it is absent.
/// NOTE(review): presumably the number of top hits that get rescored — confirm.
pub window_size: usize,
/// Value of `rescore.query.query_weight`; `from_json` uses 1.0 when absent.
pub query_weight: f32,
/// Value of `rescore.query.rescore_query_weight`; `from_json` uses 1.0 when absent.
pub rescore_query_weight: f32,
/// Value of `rescore.query.score_mode` (`total`, `multiply`, `avg`, `max`
/// or `min`); `from_json` uses `Total` when absent.
pub score_mode: crate::search::RescoreScoreMode,
}
impl SearchExpression {
    /// Creates an expression with no query, no aggregations, `size` 10,
    /// `from` 0, exact total-hit tracking and every optional part unset.
    pub fn new() -> Self {
        Self {
            size: 10,
            from: 0,
            track_total_hits: TrackTotalHits::Exact,
            aggs: Vec::new(),
            query: None,
            sort: None,
            collapse: None,
            search_after: None,
            rescore: None,
        }
    }

    /// Replaces the query with `query`.
    pub fn query(self, query: QueryExpression) -> Self {
        Self {
            query: Some(query),
            ..self
        }
    }

    /// Replaces the query with `query` wrapped in `QueryExpression::Scoring`.
    pub fn scoring_query(self, query: ScoringExpression) -> Self {
        Self {
            query: Some(QueryExpression::Scoring(query)),
            ..self
        }
    }

    /// Appends a named aggregation; earlier entries are kept.
    pub fn agg(mut self, name: impl Into<String>, agg: AggregationExpression) -> Self {
        let entry = (name.into(), agg);
        self.aggs.push(entry);
        self
    }

    /// Sets `size`.
    pub fn size(self, size: usize) -> Self {
        Self { size, ..self }
    }

    /// Sets `from`.
    pub fn from(self, from: usize) -> Self {
        Self { from, ..self }
    }

    /// Replaces the sort specification.
    pub fn sort(self, sort: Vec<SortField>) -> Self {
        Self {
            sort: Some(sort),
            ..self
        }
    }

    /// Sets the collapse field name.
    pub fn collapse(self, field: impl Into<String>) -> Self {
        Self {
            collapse: Some(field.into()),
            ..self
        }
    }

    /// Sets the `search_after` cursor.
    pub fn search_after(self, cursor: Vec<SortValue>) -> Self {
        Self {
            search_after: Some(cursor),
            ..self
        }
    }

    /// Sets the total-hit tracking mode.
    pub fn track_total_hits(self, mode: TrackTotalHits) -> Self {
        Self {
            track_total_hits: mode,
            ..self
        }
    }

    /// Sets the rescore stage.
    pub fn rescore(self, rescore: RescoreSpec) -> Self {
        Self {
            rescore: Some(rescore),
            ..self
        }
    }
}
/// `Default` delegates to `SearchExpression::new`, so both produce the same
/// initial values.
impl Default for SearchExpression {
fn default() -> Self {
Self::new()
}
}
/// Parses a JSON search request into a [`SearchExpression`].
///
/// Free-function form of `SearchExpression::from_json`; the arguments are
/// forwarded unchanged and the result is returned as is.
///
/// # Errors
///
/// Returns whatever error `SearchExpression::from_json` returns.
pub fn parse_search(
json: serde_json::Value,
default_size: usize,
) -> Result<SearchExpression, crate::core::LuciError> {
SearchExpression::from_json(json, default_size)
}
impl SearchExpression {
pub fn from_json(
json: serde_json::Value,
default_size: usize,
) -> Result<SearchExpression, crate::core::LuciError> {
let mut expr = SearchExpression::new();
if !json.is_object() || is_bare_query(&json) {
expr.query = Some(parse_query_expression(&json)?);
expr.size = default_size;
return Ok(expr);
}
let json_obj = json.as_object().expect("is_object checked above");
validate_search_keys(json_obj)?;
if let Some(q) = json.get("query") {
expr.query = Some(parse_query_expression(q)?);
}
if let Some(aggs_json) = json.get("aggs").or_else(|| json.get("aggregations")) {
expr.aggs = crate::agg::parser::parse_aggs(aggs_json)?;
}
expr.size = opt_u64(json_obj, "size", "search")?
.map(|v| v as usize)
.unwrap_or(default_size);
expr.from = opt_u64(json_obj, "from", "search")?
.map(|v| v as usize)
.unwrap_or(0);
expr.sort = crate::index::parse_sort(json.get("sort"))?;
expr.search_after = crate::index::parse_search_after(json.get("search_after"))?;
if let Some(collapse_val) = json.get("collapse") {
let obj = validate_obj_keys(collapse_val, &["field"], "collapse")?;
expr.collapse = opt_str(obj, "field", "collapse")?.map(String::from);
}
expr.track_total_hits = match json.get("track_total_hits") {
Some(serde_json::Value::Bool(true)) | None => TrackTotalHits::Exact,
Some(serde_json::Value::Bool(false)) => TrackTotalHits::Disabled,
Some(serde_json::Value::Number(n)) => {
TrackTotalHits::UpTo(n.as_u64().ok_or_else(|| {
LuciError::InvalidQuery(
"track_total_hits: integer count must be a non-negative integer".into(),
)
})?)
}
Some(other) => {
return Err(LuciError::InvalidQuery(format!(
"track_total_hits: must be a boolean or integer, got {other}"
)));
}
};
if let Some(rescore_val) = json.get("rescore") {
let rescore_obj = validate_obj_keys(rescore_val, &["window_size", "query"], "rescore")?;
let window_size = opt_u64(rescore_obj, "window_size", "rescore")?
.map(|v| v as usize)
.unwrap_or(10);
let inner_query = rescore_obj.get("query");
let inner_obj = match inner_query {
Some(v) => Some(validate_obj_keys(
v,
&[
"rescore_query",
"query_weight",
"rescore_query_weight",
"score_mode",
],
"rescore.query",
)?),
None => None,
};
if let Some(rq) = inner_obj.and_then(|o| o.get("rescore_query")) {
let rescore_query: Box<dyn crate::query::Query> = Box::new(parse_query(rq)?);
let inner = inner_obj.expect("inner_obj checked above");
let query_weight = opt_f64(inner, "query_weight", "rescore.query")?
.map(|v| v as f32)
.unwrap_or(1.0);
let rescore_query_weight = opt_f64(inner, "rescore_query_weight", "rescore.query")?
.map(|v| v as f32)
.unwrap_or(1.0);
let score_mode = match opt_str(inner, "score_mode", "rescore.query")? {
Some("multiply") => crate::search::RescoreScoreMode::Multiply,
Some("avg") => crate::search::RescoreScoreMode::Avg,
Some("max") => crate::search::RescoreScoreMode::Max,
Some("min") => crate::search::RescoreScoreMode::Min,
Some("total") | None => crate::search::RescoreScoreMode::Total,
Some(other) => {
return Err(crate::core::LuciError::InvalidQuery(format!(
"rescore.query.score_mode: unknown value '{other}', expected \
one of `total`, `multiply`, `avg`, `max`, `min`"
)));
}
};
expr.rescore = Some(RescoreSpec {
query: rescore_query,
window_size,
query_weight,
rescore_query_weight,
score_mode,
});
}
}
Ok(expr)
}
}