1use crate::core::LuciError;
13
14use crate::agg::AggregationExpression;
15use crate::query::ast::{QueryExpression, ScoringExpression};
16use crate::query::parser::{opt_f64, opt_str, opt_u64, parse_query, parse_query_expression};
17use crate::search::{SortField, SortValue, TrackTotalHits};
18
/// Top-level keys accepted in a search request envelope.
///
/// Used both to tell a search envelope apart from a bare query object and to
/// reject unknown top-level fields. The order is significant only for the
/// "expected one of ..." list rendered in validation errors.
const SEARCH_LEVEL_KEYS: &[&str] = &[
    // Query and aggregations.
    "query",
    "aggs",
    "aggregations",
    // Paging, ordering and grouping.
    "size",
    "from",
    "sort",
    "search_after",
    "collapse",
    // Hit counting and rescoring.
    "track_total_hits",
    "rescore",
    // Accepted by key validation.
    "_source",
    "fields",
];
45
46fn is_bare_query(json: &serde_json::Value) -> bool {
51 let Some(obj) = json.as_object() else {
52 return false;
53 };
54 if obj.len() != 1 {
55 return false;
56 }
57 let key = obj.keys().next().expect("checked len == 1");
58 !SEARCH_LEVEL_KEYS.contains(&key.as_str())
59}
60
61pub(crate) fn validate_obj_keys<'a>(
68 val: &'a serde_json::Value,
69 expected: &[&str],
70 ctx: &str,
71) -> crate::core::Result<&'a serde_json::Map<String, serde_json::Value>> {
72 let obj = val
73 .as_object()
74 .ok_or_else(|| crate::core::LuciError::InvalidQuery(format!("{ctx}: must be an object")))?;
75 for key in obj.keys() {
76 if !expected.contains(&key.as_str()) {
77 let expected_list = expected
78 .iter()
79 .map(|k| format!("`{k}`"))
80 .collect::<Vec<_>>()
81 .join(", ");
82 return Err(crate::core::LuciError::InvalidQuery(format!(
83 "{ctx}: unknown field `{key}`, expected one of {expected_list}"
84 )));
85 }
86 }
87 Ok(obj)
88}
89
90fn validate_search_keys(
93 obj: &serde_json::Map<String, serde_json::Value>,
94) -> crate::core::Result<()> {
95 for key in obj.keys() {
96 if !SEARCH_LEVEL_KEYS.contains(&key.as_str()) {
97 let expected = SEARCH_LEVEL_KEYS
98 .iter()
99 .map(|k| format!("`{k}`"))
100 .collect::<Vec<_>>()
101 .join(", ");
102 return Err(crate::core::LuciError::InvalidQuery(format!(
103 "invalid search request: unknown field `{key}`, expected one of {expected}"
104 )));
105 }
106 }
107 Ok(())
108}
109
110pub struct SearchExpression {
118 pub(crate) query: Option<QueryExpression>,
120 pub(crate) aggs: Vec<(String, AggregationExpression)>,
122 pub(crate) size: usize,
124 pub(crate) from: usize,
126 pub(crate) sort: Option<Vec<SortField>>,
128 pub(crate) collapse: Option<String>,
130 pub(crate) search_after: Option<Vec<SortValue>>,
132 pub(crate) track_total_hits: TrackTotalHits,
134 pub(crate) rescore: Option<RescoreSpec>,
136}
137
138pub struct RescoreSpec {
140 pub(crate) query: Box<dyn crate::query::Query>,
141 pub window_size: usize,
142 pub query_weight: f32,
143 pub rescore_query_weight: f32,
144 pub score_mode: crate::search::RescoreScoreMode,
145}
146
147impl SearchExpression {
148 pub fn new() -> Self {
150 Self {
151 query: None,
152 aggs: Vec::new(),
153 size: 10,
154 from: 0,
155 sort: None,
156 collapse: None,
157 search_after: None,
158 track_total_hits: TrackTotalHits::Exact,
159 rescore: None,
160 }
161 }
162
163 pub fn query(mut self, query: QueryExpression) -> Self {
165 self.query = Some(query);
166 self
167 }
168
169 pub fn scoring_query(mut self, query: ScoringExpression) -> Self {
171 self.query = Some(QueryExpression::Scoring(query));
172 self
173 }
174
175 pub fn agg(mut self, name: impl Into<String>, agg: AggregationExpression) -> Self {
177 self.aggs.push((name.into(), agg));
178 self
179 }
180
181 pub fn size(mut self, size: usize) -> Self {
183 self.size = size;
184 self
185 }
186
187 pub fn from(mut self, from: usize) -> Self {
189 self.from = from;
190 self
191 }
192
193 pub fn sort(mut self, sort: Vec<SortField>) -> Self {
195 self.sort = Some(sort);
196 self
197 }
198
199 pub fn collapse(mut self, field: impl Into<String>) -> Self {
201 self.collapse = Some(field.into());
202 self
203 }
204
205 pub fn search_after(mut self, cursor: Vec<SortValue>) -> Self {
207 self.search_after = Some(cursor);
208 self
209 }
210
211 pub fn track_total_hits(mut self, mode: TrackTotalHits) -> Self {
213 self.track_total_hits = mode;
214 self
215 }
216
217 pub fn rescore(mut self, rescore: RescoreSpec) -> Self {
219 self.rescore = Some(rescore);
220 self
221 }
222}
223
224impl Default for SearchExpression {
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
230pub fn parse_search(
238 json: serde_json::Value,
239 default_size: usize,
240) -> Result<SearchExpression, crate::core::LuciError> {
241 SearchExpression::from_json(json, default_size)
242}
243
244impl SearchExpression {
245 pub fn from_json(
250 json: serde_json::Value,
251 default_size: usize,
252 ) -> Result<SearchExpression, crate::core::LuciError> {
253 let mut expr = SearchExpression::new();
254
255 if !json.is_object() || is_bare_query(&json) {
260 expr.query = Some(parse_query_expression(&json)?);
261 expr.size = default_size;
262 return Ok(expr);
263 }
264
265 let json_obj = json.as_object().expect("is_object checked above");
269 validate_search_keys(json_obj)?;
270
271 if let Some(q) = json.get("query") {
272 expr.query = Some(parse_query_expression(q)?);
273 }
274
275 if let Some(aggs_json) = json.get("aggs").or_else(|| json.get("aggregations")) {
276 expr.aggs = crate::agg::parser::parse_aggs(aggs_json)?;
277 }
278
279 expr.size = opt_u64(json_obj, "size", "search")?
280 .map(|v| v as usize)
281 .unwrap_or(default_size);
282 expr.from = opt_u64(json_obj, "from", "search")?
283 .map(|v| v as usize)
284 .unwrap_or(0);
285
286 expr.sort = crate::index::parse_sort(json.get("sort"))?;
287 expr.search_after = crate::index::parse_search_after(json.get("search_after"))?;
288
289 if let Some(collapse_val) = json.get("collapse") {
290 let obj = validate_obj_keys(collapse_val, &["field"], "collapse")?;
294 expr.collapse = opt_str(obj, "field", "collapse")?.map(String::from);
295 }
296
297 expr.track_total_hits = match json.get("track_total_hits") {
298 Some(serde_json::Value::Bool(true)) | None => TrackTotalHits::Exact,
299 Some(serde_json::Value::Bool(false)) => TrackTotalHits::Disabled,
300 Some(serde_json::Value::Number(n)) => {
301 TrackTotalHits::UpTo(n.as_u64().ok_or_else(|| {
302 LuciError::InvalidQuery(
303 "track_total_hits: integer count must be a non-negative integer".into(),
304 )
305 })?)
306 }
307 Some(other) => {
308 return Err(LuciError::InvalidQuery(format!(
309 "track_total_hits: must be a boolean or integer, got {other}"
310 )));
311 }
312 };
313
314 if let Some(rescore_val) = json.get("rescore") {
315 let rescore_obj = validate_obj_keys(rescore_val, &["window_size", "query"], "rescore")?;
316 let window_size = opt_u64(rescore_obj, "window_size", "rescore")?
317 .map(|v| v as usize)
318 .unwrap_or(10);
319 let inner_query = rescore_obj.get("query");
320 let inner_obj = match inner_query {
321 Some(v) => Some(validate_obj_keys(
322 v,
323 &[
324 "rescore_query",
325 "query_weight",
326 "rescore_query_weight",
327 "score_mode",
328 ],
329 "rescore.query",
330 )?),
331 None => None,
332 };
333 if let Some(rq) = inner_obj.and_then(|o| o.get("rescore_query")) {
334 let rescore_query: Box<dyn crate::query::Query> = Box::new(parse_query(rq)?);
335 let inner = inner_obj.expect("inner_obj checked above");
336 let query_weight = opt_f64(inner, "query_weight", "rescore.query")?
337 .map(|v| v as f32)
338 .unwrap_or(1.0);
339 let rescore_query_weight = opt_f64(inner, "rescore_query_weight", "rescore.query")?
340 .map(|v| v as f32)
341 .unwrap_or(1.0);
342 let score_mode = match opt_str(inner, "score_mode", "rescore.query")? {
343 Some("multiply") => crate::search::RescoreScoreMode::Multiply,
344 Some("avg") => crate::search::RescoreScoreMode::Avg,
345 Some("max") => crate::search::RescoreScoreMode::Max,
346 Some("min") => crate::search::RescoreScoreMode::Min,
347 Some("total") | None => crate::search::RescoreScoreMode::Total,
348 Some(other) => {
349 return Err(crate::core::LuciError::InvalidQuery(format!(
350 "rescore.query.score_mode: unknown value '{other}', expected \
351 one of `total`, `multiply`, `avg`, `max`, `min`"
352 )));
353 }
354 };
355 expr.rescore = Some(RescoreSpec {
356 query: rescore_query,
357 window_size,
358 query_weight,
359 rescore_query_weight,
360 score_mode,
361 });
362 }
363 }
364
365 Ok(expr)
366 }
367}