1use crate::core::{LuciError, Result};
11use serde_json::Value;
12
13use super::{AggregationExpression, RangeDef};
14use crate::query::parser::{opt_f64, opt_str, opt_u64, parse_query};
15
16fn validate_keys<'a>(
18 val: &'a Value,
19 expected: &[&str],
20 ctx: &str,
21) -> Result<&'a serde_json::Map<String, Value>> {
22 let obj = val
23 .as_object()
24 .ok_or_else(|| LuciError::InvalidQuery(format!("{ctx}: must be an object")))?;
25 for key in obj.keys() {
26 if !expected.contains(&key.as_str()) {
27 let expected_list = expected
28 .iter()
29 .map(|k| format!("`{k}`"))
30 .collect::<Vec<_>>()
31 .join(", ");
32 return Err(LuciError::InvalidQuery(format!(
33 "{ctx}: unknown field `{key}`, expected one of {expected_list}"
34 )));
35 }
36 }
37 Ok(obj)
38}
39
40pub fn parse_aggs(json: &Value) -> Result<Vec<(String, AggregationExpression)>> {
46 let obj = match json.as_object() {
47 Some(o) => o,
48 None => return Err(LuciError::InvalidQuery("aggs must be an object".into())),
49 };
50
51 let mut aggs = Vec::new();
52 for (name, agg_val) in obj {
53 aggs.push(parse_single_agg(name, agg_val)?);
54 }
55 Ok(aggs)
56}
57
58fn parse_single_agg(name: &str, val: &Value) -> Result<(String, AggregationExpression)> {
59 let obj = val.as_object().ok_or_else(|| {
60 LuciError::InvalidQuery(format!("aggregation '{name}' must be an object"))
61 })?;
62
63 let mut agg_type = None;
65 let mut sub_aggs_val = None;
66
67 for (key, v) in obj {
68 match key.as_str() {
69 "aggs" | "aggregations" => sub_aggs_val = Some(v),
70 _ => {
71 if agg_type.is_some() {
72 return Err(LuciError::InvalidQuery(format!(
73 "aggregation '{name}' has multiple type keys"
74 )));
75 }
76 agg_type = Some((key.as_str(), v));
77 }
78 }
79 }
80
81 let (type_key, type_val) = agg_type
82 .ok_or_else(|| LuciError::InvalidQuery(format!("aggregation '{name}' has no type")))?;
83
84 let sub_aggs = match sub_aggs_val {
85 Some(v) => parse_aggs(v)?,
86 None => Vec::new(),
87 };
88
89 if !sub_aggs.is_empty() && !agg_type_accepts_sub_aggs(type_key) {
98 return Err(LuciError::InvalidQuery(format!(
99 "aggregation '{name}' of type [{type_key}] cannot have sub-aggregations"
100 )));
101 }
102
103 let expr = parse_agg_expr(name, type_key, type_val, sub_aggs)?;
104 Ok((name.to_string(), expr))
105}
106
/// Reports whether an aggregation of type `type_key` may carry sub-aggregations.
///
/// The comparison is exact and case-sensitive; any type not listed here yields `false`.
fn agg_type_accepts_sub_aggs(type_key: &str) -> bool {
    const SUB_AGG_PARENTS: [&str; 10] = [
        "terms",
        "range",
        "date_range",
        "histogram",
        "date_histogram",
        "filter",
        "filters",
        "nested",
        "reverse_nested",
        "geohash_grid",
    ];
    SUB_AGG_PARENTS.iter().any(|&known| known == type_key)
}
130
131fn parse_agg_expr(
132 name: &str,
133 key: &str,
134 val: &Value,
135 sub_aggs: Vec<(String, AggregationExpression)>,
136) -> Result<AggregationExpression> {
137 let ctx = format!("{name}.{key}");
138 match key {
139 "avg" => Ok(AggregationExpression::Avg {
140 field: parse_field_only(val, &ctx)?,
141 }),
142 "sum" => Ok(AggregationExpression::Sum {
143 field: parse_field_only(val, &ctx)?,
144 }),
145 "min" => Ok(AggregationExpression::Min {
146 field: parse_field_only(val, &ctx)?,
147 }),
148 "max" => Ok(AggregationExpression::Max {
149 field: parse_field_only(val, &ctx)?,
150 }),
151 "value_count" => Ok(AggregationExpression::ValueCount {
152 field: parse_field_only(val, &ctx)?,
153 }),
154 "stats" => Ok(AggregationExpression::Stats {
155 field: parse_field_only(val, &ctx)?,
156 }),
157 "extended_stats" => Ok(AggregationExpression::ExtendedStats {
158 field: parse_field_only(val, &ctx)?,
159 }),
160 "terms" => {
161 let obj = validate_keys(val, &["field", "size"], &ctx)?;
162 Ok(AggregationExpression::Terms {
163 field: require_field(obj, &ctx)?,
164 size: opt_u64(obj, "size", &ctx)?.unwrap_or(10) as usize,
165 sub_aggs,
166 })
167 }
168 "range" => {
169 let obj = validate_keys(val, &["field", "ranges"], &ctx)?;
170 let field = require_field(obj, &ctx)?;
171 let ranges = parse_range_defs(obj, &ctx, false)?;
172 Ok(AggregationExpression::Range {
173 field,
174 ranges,
175 sub_aggs,
176 })
177 }
178 "histogram" => {
179 let obj = validate_keys(val, &["field", "interval"], &ctx)?;
180 let field = require_field(obj, &ctx)?;
181 let interval = obj
182 .get("interval")
183 .and_then(|v| v.as_f64())
184 .ok_or_else(|| LuciError::InvalidQuery("histogram requires 'interval'".into()))?;
185 Ok(AggregationExpression::Histogram {
186 field,
187 interval,
188 sub_aggs,
189 })
190 }
191 "filter" => {
192 let query = parse_query(val)?;
195 Ok(AggregationExpression::Filter { query, sub_aggs })
196 }
197 "cardinality" => {
198 let obj = validate_keys(val, &["field", "precision_threshold"], &ctx)?;
199 Ok(AggregationExpression::Cardinality {
200 field: require_field(obj, &ctx)?,
201 precision_threshold: opt_u64(obj, "precision_threshold", &ctx)?.unwrap_or(3000)
202 as u32,
203 })
204 }
205 "percentiles" => {
206 let obj = validate_keys(val, &["field", "percents", "tdigest"], &ctx)?;
207 let field = require_field(obj, &ctx)?;
208 let percents = match obj.get("percents") {
209 Some(v) => {
210 let arr = v.as_array().ok_or_else(|| {
211 LuciError::InvalidQuery(
212 "percentiles: \"percents\" must be an array of numbers".into(),
213 )
214 })?;
215 arr.iter()
216 .map(|p| {
217 p.as_f64().ok_or_else(|| {
218 LuciError::InvalidQuery(format!(
219 "percentiles: percents[] entries must be numbers, got {p}"
220 ))
221 })
222 })
223 .collect::<Result<Vec<f64>>>()?
224 }
225 None => vec![1.0, 5.0, 25.0, 50.0, 75.0, 95.0, 99.0],
226 };
227 let compression = match obj.get("tdigest") {
228 Some(t) => {
229 let tdigest_obj = validate_keys(t, &["compression"], "percentiles.tdigest")?;
230 opt_f64(tdigest_obj, "compression", "percentiles.tdigest")?.unwrap_or(100.0)
231 }
232 None => 100.0,
233 };
234 Ok(AggregationExpression::Percentiles {
235 field,
236 percents,
237 compression,
238 })
239 }
240 "geo_bounds" => Ok(AggregationExpression::GeoBounds {
241 field: parse_field_only(val, &ctx)?,
242 }),
243 "geo_centroid" => Ok(AggregationExpression::GeoCentroid {
244 field: parse_field_only(val, &ctx)?,
245 }),
246 "nested" => {
247 let obj = validate_keys(val, &["path"], &ctx)?;
248 let path = obj
249 .get("path")
250 .and_then(|v| v.as_str())
251 .ok_or_else(|| LuciError::InvalidQuery("nested agg requires 'path'".into()))?
252 .to_string();
253 Ok(AggregationExpression::Nested { path, sub_aggs })
254 }
255 "reverse_nested" => {
256 validate_keys(val, &[], &ctx)?;
257 Ok(AggregationExpression::ReverseNested { sub_aggs })
258 }
259 "geohash_grid" => {
260 let obj = validate_keys(val, &["field", "precision", "size"], &ctx)?;
261 Ok(AggregationExpression::GeohashGrid {
262 field: require_field(obj, &ctx)?,
263 precision: opt_u64(obj, "precision", &ctx)?.unwrap_or(5) as usize,
264 size: opt_u64(obj, "size", &ctx)?.unwrap_or(10000) as usize,
265 sub_aggs,
266 })
267 }
268 "top_hits" => {
269 let obj = validate_keys(val, &["size"], &ctx)?;
270 Ok(AggregationExpression::TopHits {
271 size: opt_u64(obj, "size", &ctx)?.unwrap_or(3) as usize,
272 })
273 }
274 "date_histogram" => {
275 let obj = validate_keys(
276 val,
277 &["field", "calendar_interval", "fixed_interval", "interval"],
278 &ctx,
279 )?;
280 let field = require_field(obj, &ctx)?;
281 let interval = if let Some(cal) = opt_str(obj, "calendar_interval", &ctx)? {
282 let cal_int = match cal {
283 "minute" | "1m" => super::CalendarInterval::Minute,
284 "hour" | "1h" => super::CalendarInterval::Hour,
285 "day" | "1d" => super::CalendarInterval::Day,
286 "week" | "1w" => super::CalendarInterval::Week,
287 "month" | "1M" => super::CalendarInterval::Month,
288 "quarter" | "1q" => super::CalendarInterval::Quarter,
289 "year" | "1y" => super::CalendarInterval::Year,
290 other => {
291 return Err(LuciError::InvalidQuery(format!(
292 "date_histogram: unknown calendar_interval '{other}'"
293 )));
294 }
295 };
296 super::DateInterval::Calendar(cal_int)
297 } else if let Some(fixed) = opt_str(obj, "fixed_interval", &ctx)? {
298 let ms = parse_fixed_interval(fixed)?;
299 super::DateInterval::Fixed(ms)
300 } else if let Some(interval_str) = opt_str(obj, "interval", &ctx)? {
301 if let Ok(ms) = parse_fixed_interval(interval_str) {
303 super::DateInterval::Fixed(ms)
304 } else {
305 return Err(LuciError::InvalidQuery(format!(
306 "date_histogram: invalid interval '{interval_str}'"
307 )));
308 }
309 } else {
310 return Err(LuciError::InvalidQuery(
311 "date_histogram requires 'calendar_interval' or 'fixed_interval'".into(),
312 ));
313 };
314 Ok(AggregationExpression::DateHistogram {
315 field,
316 interval,
317 sub_aggs,
318 })
319 }
320 "date_range" => {
321 let obj = validate_keys(val, &["field", "ranges"], &ctx)?;
322 let field = require_field(obj, &ctx)?;
323 let ranges = parse_range_defs(obj, &ctx, true)?;
324 Ok(AggregationExpression::DateRange {
325 field,
326 ranges,
327 sub_aggs,
328 })
329 }
330 "filters" => {
331 let obj = validate_keys(val, &["filters"], &ctx)?;
332 let filters_obj = obj
333 .get("filters")
334 .and_then(|v| v.as_object())
335 .ok_or_else(|| {
336 LuciError::InvalidQuery("filters requires 'filters' object".into())
337 })?;
338 let mut filters = Vec::new();
339 for (name, query_val) in filters_obj {
340 let query = parse_query(query_val)?;
341 filters.push((name.clone(), query));
342 }
343 Ok(AggregationExpression::Filters { filters, sub_aggs })
344 }
345 _ => Err(LuciError::UnsupportedQuery(format!(
346 "unknown aggregation type: {key}"
347 ))),
348 }
349}
350
351fn parse_field_only(val: &Value, ctx: &str) -> Result<String> {
352 let obj = validate_keys(val, &["field"], ctx)?;
353 require_field(obj, ctx)
354}
355
356fn require_field(obj: &serde_json::Map<String, Value>, ctx: &str) -> Result<String> {
357 obj.get("field")
358 .and_then(|v| v.as_str())
359 .map(String::from)
360 .ok_or_else(|| LuciError::InvalidQuery(format!("{ctx} requires 'field'")))
361}
362
363fn parse_range_defs(
364 obj: &serde_json::Map<String, Value>,
365 ctx: &str,
366 dates: bool,
367) -> Result<Vec<RangeDef>> {
368 let ranges_val = obj
369 .get("ranges")
370 .and_then(|v| v.as_array())
371 .ok_or_else(|| LuciError::InvalidQuery(format!("{ctx}: missing 'ranges' array")))?;
372 let mut ranges = Vec::with_capacity(ranges_val.len());
373 for r in ranges_val {
374 let r_obj = validate_keys(r, &["key", "from", "to"], &format!("{ctx}.ranges[]"))?;
375 let key = r_obj.get("key").and_then(|v| v.as_str()).map(String::from);
376 let (from, to) = if dates {
377 (
378 r_obj.get("from").and_then(parse_date_value),
379 r_obj.get("to").and_then(parse_date_value),
380 )
381 } else {
382 (
383 r_obj.get("from").and_then(|v| v.as_f64()),
384 r_obj.get("to").and_then(|v| v.as_f64()),
385 )
386 };
387 ranges.push(RangeDef { key, from, to });
388 }
389 Ok(ranges)
390}
391
392fn parse_fixed_interval(s: &str) -> Result<f64> {
394 let s = s.trim();
395 if let Some(n) = s.strip_suffix("ms") {
396 return n
397 .parse::<f64>()
398 .map_err(|_| LuciError::InvalidQuery(format!("invalid interval: {s}")));
399 }
400 if let Some(n) = s.strip_suffix('s') {
401 return Ok(n
402 .parse::<f64>()
403 .map_err(|_| LuciError::InvalidQuery(format!("invalid interval: {s}")))?
404 * 1_000.0);
405 }
406 if let Some(n) = s.strip_suffix('m') {
407 return Ok(n
408 .parse::<f64>()
409 .map_err(|_| LuciError::InvalidQuery(format!("invalid interval: {s}")))?
410 * 60_000.0);
411 }
412 if let Some(n) = s.strip_suffix('h') {
413 return Ok(n
414 .parse::<f64>()
415 .map_err(|_| LuciError::InvalidQuery(format!("invalid interval: {s}")))?
416 * 3_600_000.0);
417 }
418 if let Some(n) = s.strip_suffix('d') {
419 return Ok(n
420 .parse::<f64>()
421 .map_err(|_| LuciError::InvalidQuery(format!("invalid interval: {s}")))?
422 * 86_400_000.0);
423 }
424 Err(LuciError::InvalidQuery(format!(
425 "invalid fixed_interval: {s}"
426 )))
427}
428
429fn parse_date_value(v: &Value) -> Option<f64> {
432 match v {
433 Value::Number(n) => n.as_f64(),
434 Value::String(s) => {
435 if let Ok(ms) = s.parse::<f64>() {
437 return Some(ms);
438 }
439 if s.len() >= 10 {
441 let parts: Vec<&str> = s.split('T').collect();
442 let date_parts: Vec<&str> = parts[0].split('-').collect();
443 if date_parts.len() == 3 {
444 let y: i64 = date_parts[0].parse().ok()?;
445 let m: i64 = date_parts[1].parse().ok()?;
446 let d: i64 = date_parts[2].parse().ok()?;
447 let days = (y - 1970) * 365 + (y - 1969) / 4 + (m - 1) * 30 + d - 1;
449 return Some(days as f64 * 86_400_000.0);
450 }
451 }
452 None
453 }
454 _ => None,
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A metric aggregation keeps its name and field.
    #[test]
    fn parse_avg() {
        let request = json!({"my_avg": {"avg": {"field": "price"}}});
        let parsed = parse_aggs(&request).unwrap();
        assert_eq!(parsed.len(), 1);
        let (name, expr) = &parsed[0];
        assert_eq!(name.as_str(), "my_avg");
        assert!(matches!(expr, AggregationExpression::Avg { field } if field == "price"));
    }

    /// An explicit `size` overrides the default.
    #[test]
    fn parse_terms_with_size() {
        let request = json!({"by_tag": {"terms": {"field": "tag", "size": 5}}});
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Terms { field, size, .. } => {
                assert_eq!(field.as_str(), "tag");
                assert_eq!(*size, 5);
            }
            _ => panic!(),
        }
    }

    /// Without `size`, terms falls back to 10.
    #[test]
    fn parse_terms_default_size() {
        let request = json!({"by_tag": {"terms": {"field": "tag"}}});
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Terms { size, .. } => assert_eq!(*size, 10),
            _ => panic!(),
        }
    }

    /// A string where an integer is required is an error that names the key.
    #[test]
    fn parse_terms_string_size_rejected() {
        let request = json!({"by_tag": {"terms": {"field": "tag", "size": "5"}}});
        let err = parse_aggs(&request).unwrap_err();
        assert!(err.to_string().contains("size"), "{err}");
    }

    /// Every `percents` entry must be a number.
    #[test]
    fn parse_percentiles_non_number_percent_rejected() {
        let request = json!({
            "p": {"percentiles": {"field": "price", "percents": [50, "99"]}}
        });
        let err = parse_aggs(&request).unwrap_err();
        assert!(err.to_string().contains("percents"), "{err}");
    }

    /// Open-ended and closed ranges are all kept.
    #[test]
    fn parse_range() {
        let request = json!({
            "price_ranges": {"range": {"field": "price", "ranges": [
                {"to": 50.0},
                {"from": 50.0, "to": 100.0},
                {"from": 100.0}
            ]}}
        });
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Range { ranges, .. } => assert_eq!(ranges.len(), 3),
            _ => panic!(),
        }
    }

    /// The histogram interval is carried through unchanged.
    #[test]
    fn parse_histogram() {
        let request = json!({
            "prices": {"histogram": {"field": "price", "interval": 10.0}}
        });
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Histogram { interval, .. } => assert_eq!(*interval, 10.0),
            _ => panic!(),
        }
    }

    /// Sub-aggregations under `aggs` are attached to the parent.
    #[test]
    fn parse_nested_sub_aggs() {
        let request = json!({
            "by_tag": {
                "terms": {"field": "tag"},
                "aggs": {
                    "avg_price": {"avg": {"field": "price"}}
                }
            }
        });
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Terms { sub_aggs, .. } => {
                assert_eq!(sub_aggs.len(), 1);
                assert_eq!(sub_aggs[0].0.as_str(), "avg_price");
            }
            _ => panic!(),
        }
    }

    /// Sibling aggregations are all returned.
    #[test]
    fn parse_multiple_aggs() {
        let request = json!({
            "total": {"sum": {"field": "amount"}},
            "average": {"avg": {"field": "amount"}}
        });
        let parsed = parse_aggs(&request).unwrap();
        assert_eq!(parsed.len(), 2);
    }

    /// The `filter` body is parsed as a query.
    #[test]
    fn parse_filter_agg() {
        let request = json!({
            "active": {"filter": {"term": {"status": "active"}}}
        });
        let parsed = parse_aggs(&request).unwrap();
        assert!(matches!(&parsed[0].1, AggregationExpression::Filter { .. }));
    }

    /// Unknown aggregation types are rejected.
    #[test]
    fn unknown_agg_type_error() {
        let request = json!({"x": {"unknown_type": {"field": "f"}}});
        assert!(parse_aggs(&request).is_err());
    }

    /// `field` is mandatory for metric aggregations.
    #[test]
    fn missing_field_error() {
        let request = json!({"x": {"avg": {}}});
        assert!(parse_aggs(&request).is_err());
    }

    /// Unknown keys in an aggregation body are rejected and named in the error.
    #[test]
    fn unknown_agg_body_key_error() {
        let request = json!({
            "x": {"avg": {"field": "price", "missing_value": 0}}
        });
        let outcome = parse_aggs(&request);
        assert!(outcome.is_err(), "missing_value is not a valid avg key");
        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains("missing_value"));
    }
}