use crate::core::{LuciError, Result};
use serde_json::Value;
use super::{AggregationExpression, RangeDef};
use crate::query::parser::{opt_f64, opt_str, opt_u64, parse_query};
/// Ensures `val` is a JSON object containing only keys listed in `expected`.
///
/// On success the underlying map is returned so the caller can read fields
/// from it. `ctx` is prepended to every error message to identify the clause
/// being validated.
///
/// # Errors
///
/// Returns `LuciError::InvalidQuery` when `val` is not an object, or when the
/// first key encountered (in map iteration order) is not in `expected`.
fn validate_keys<'a>(
    val: &'a Value,
    expected: &[&str],
    ctx: &str,
) -> Result<&'a serde_json::Map<String, Value>> {
    let map = match val.as_object() {
        Some(map) => map,
        None => {
            return Err(LuciError::InvalidQuery(format!(
                "{ctx}: must be an object"
            )))
        }
    };

    // Locate the first key that is not part of the allow-list, if any.
    let offending = map
        .keys()
        .find(|candidate| !expected.iter().any(|allowed| *allowed == candidate.as_str()));

    match offending {
        None => Ok(map),
        Some(key) => {
            // Render the allow-list as "`a`, `b`, `c`" for the error message.
            let mut expected_list = String::new();
            for (idx, allowed) in expected.iter().enumerate() {
                if idx > 0 {
                    expected_list.push_str(", ");
                }
                expected_list.push('`');
                expected_list.push_str(allowed);
                expected_list.push('`');
            }
            Err(LuciError::InvalidQuery(format!(
                "{ctx}: unknown field `{key}`, expected one of {expected_list}"
            )))
        }
    }
}
/// Parses an `aggs` JSON object into a list of `(name, expression)` pairs.
///
/// Each entry of the object is handed to `parse_single_agg`; results are
/// returned in the object's iteration order.
///
/// # Errors
///
/// Returns `LuciError::InvalidQuery` when `json` is not an object, and
/// propagates the first error produced while parsing an entry.
pub fn parse_aggs(json: &Value) -> Result<Vec<(String, AggregationExpression)>> {
    json.as_object()
        .ok_or_else(|| LuciError::InvalidQuery("aggs must be an object".into()))?
        .iter()
        .map(|(agg_name, agg_body)| parse_single_agg(agg_name, agg_body))
        // Collecting into `Result` stops at the first failing entry.
        .collect()
}
/// Parses one named aggregation: its single type key plus optional
/// sub-aggregations under `aggs` / `aggregations`.
///
/// * `name` – the aggregation's name (the key it was stored under).
/// * `val` – the aggregation body; must be a JSON object.
///
/// Returns the name (owned) together with the parsed expression.
///
/// # Errors
///
/// Returns `LuciError::InvalidQuery` when:
/// * `val` is not an object,
/// * more than one type key is present, or none at all,
/// * both `aggs` and `aggregations` are present (previously one silently
///   overwrote the other),
/// * non-empty sub-aggregations are attached to a type that does not accept
///   them.
///
/// Errors from parsing the sub-aggregations or the type body are propagated.
fn parse_single_agg(name: &str, val: &Value) -> Result<(String, AggregationExpression)> {
    let obj = val.as_object().ok_or_else(|| {
        LuciError::InvalidQuery(format!("aggregation '{name}' must be an object"))
    })?;

    let mut agg_type: Option<(&str, &Value)> = None;
    let mut sub_aggs_val: Option<&Value> = None;
    for (key, v) in obj {
        match key.as_str() {
            "aggs" | "aggregations" => {
                // `replace` returns the previous value; a previous value means
                // both spellings were supplied, which is ambiguous.
                if sub_aggs_val.replace(v).is_some() {
                    return Err(LuciError::InvalidQuery(format!(
                        "aggregation '{name}' has both 'aggs' and 'aggregations'"
                    )));
                }
            }
            _ => {
                if agg_type.replace((key.as_str(), v)).is_some() {
                    return Err(LuciError::InvalidQuery(format!(
                        "aggregation '{name}' has multiple type keys"
                    )));
                }
            }
        }
    }

    let (type_key, type_val) = agg_type
        .ok_or_else(|| LuciError::InvalidQuery(format!("aggregation '{name}' has no type")))?;

    let sub_aggs = match sub_aggs_val {
        Some(v) => parse_aggs(v)?,
        None => Vec::new(),
    };

    // An empty `aggs: {}` is tolerated on any type; only real sub-aggregations
    // are rejected for types that cannot carry them.
    if !sub_aggs.is_empty() && !agg_type_accepts_sub_aggs(type_key) {
        return Err(LuciError::InvalidQuery(format!(
            "aggregation '{name}' of type [{type_key}] cannot have sub-aggregations"
        )));
    }

    let expr = parse_agg_expr(name, type_key, type_val, sub_aggs)?;
    Ok((name.to_string(), expr))
}
/// Reports whether the aggregation type `type_key` may carry sub-aggregations.
///
/// The comparison is exact and case-sensitive; any key not in the table below
/// yields `false`.
fn agg_type_accepts_sub_aggs(type_key: &str) -> bool {
    /// Type keys that are allowed to have sub-aggregations attached.
    const SUB_AGG_CAPABLE: [&str; 10] = [
        "terms",
        "range",
        "date_range",
        "histogram",
        "date_histogram",
        "filter",
        "filters",
        "nested",
        "reverse_nested",
        "geohash_grid",
    ];
    SUB_AGG_CAPABLE.contains(&type_key)
}
/// Builds the [`AggregationExpression`] for a single aggregation type key.
///
/// * `name` – the aggregation's name; combined with `key` into the
///   `"{name}.{key}"` context passed to the shared validation helpers.
/// * `key` – the aggregation type key, e.g. `"avg"` or `"terms"`.
/// * `val` – the JSON body stored under `key`.
/// * `sub_aggs` – already-parsed sub-aggregations; moved into the variants
///   that carry them and dropped for the others.
///
/// # Errors
///
/// * `LuciError::UnsupportedQuery` when `key` is not a known aggregation type.
/// * `LuciError::InvalidQuery` when the body is malformed: not an object,
///   unknown keys, missing or mistyped required fields, a `histogram`
///   interval that is not a positive finite number, a `percents` entry
///   outside `[0, 100]`, or an unparseable date-histogram interval.
/// * Errors from `parse_query` (for `filter` / `filters`) are propagated.
fn parse_agg_expr(
    name: &str,
    key: &str,
    val: &Value,
    sub_aggs: Vec<(String, AggregationExpression)>,
) -> Result<AggregationExpression> {
    let ctx = format!("{name}.{key}");
    match key {
        "avg" => Ok(AggregationExpression::Avg {
            field: parse_field_only(val, &ctx)?,
        }),
        "sum" => Ok(AggregationExpression::Sum {
            field: parse_field_only(val, &ctx)?,
        }),
        "min" => Ok(AggregationExpression::Min {
            field: parse_field_only(val, &ctx)?,
        }),
        "max" => Ok(AggregationExpression::Max {
            field: parse_field_only(val, &ctx)?,
        }),
        "value_count" => Ok(AggregationExpression::ValueCount {
            field: parse_field_only(val, &ctx)?,
        }),
        "stats" => Ok(AggregationExpression::Stats {
            field: parse_field_only(val, &ctx)?,
        }),
        "extended_stats" => Ok(AggregationExpression::ExtendedStats {
            field: parse_field_only(val, &ctx)?,
        }),
        "terms" => {
            let obj = validate_keys(val, &["field", "size"], &ctx)?;
            Ok(AggregationExpression::Terms {
                field: require_field(obj, &ctx)?,
                size: saturating_usize(opt_u64(obj, "size", &ctx)?.unwrap_or(10)),
                sub_aggs,
            })
        }
        "range" => {
            let obj = validate_keys(val, &["field", "ranges"], &ctx)?;
            let field = require_field(obj, &ctx)?;
            let ranges = parse_range_defs(obj, &ctx, false)?;
            Ok(AggregationExpression::Range {
                field,
                ranges,
                sub_aggs,
            })
        }
        "histogram" => {
            let obj = validate_keys(val, &["field", "interval"], &ctx)?;
            let field = require_field(obj, &ctx)?;
            let interval = obj
                .get("interval")
                .and_then(|v| v.as_f64())
                .ok_or_else(|| LuciError::InvalidQuery("histogram requires 'interval'".into()))?;
            // A zero or negative bucket width cannot partition the value axis.
            if !interval.is_finite() || interval <= 0.0 {
                return Err(LuciError::InvalidQuery(format!(
                    "{ctx}: 'interval' must be a positive number, got {interval}"
                )));
            }
            Ok(AggregationExpression::Histogram {
                field,
                interval,
                sub_aggs,
            })
        }
        "filter" => {
            // The body is a query clause; its validation is left to `parse_query`.
            let query = parse_query(val)?;
            Ok(AggregationExpression::Filter { query, sub_aggs })
        }
        "cardinality" => {
            let obj = validate_keys(val, &["field", "precision_threshold"], &ctx)?;
            Ok(AggregationExpression::Cardinality {
                field: require_field(obj, &ctx)?,
                precision_threshold: saturating_u32(
                    opt_u64(obj, "precision_threshold", &ctx)?.unwrap_or(3000),
                ),
            })
        }
        "percentiles" => {
            let obj = validate_keys(val, &["field", "percents", "tdigest"], &ctx)?;
            let field = require_field(obj, &ctx)?;
            let percents = match obj.get("percents") {
                Some(v) => {
                    let arr = v.as_array().ok_or_else(|| {
                        LuciError::InvalidQuery(
                            "percentiles: \"percents\" must be an array of numbers".into(),
                        )
                    })?;
                    arr.iter()
                        .map(|p| {
                            let pct = p.as_f64().ok_or_else(|| {
                                LuciError::InvalidQuery(format!(
                                    "percentiles: percents[] entries must be numbers, got {p}"
                                ))
                            })?;
                            // A percentile rank is only defined on [0, 100].
                            if !(0.0..=100.0).contains(&pct) {
                                return Err(LuciError::InvalidQuery(format!(
                                    "percentiles: percents[] entries must be in [0, 100], got {pct}"
                                )));
                            }
                            Ok(pct)
                        })
                        .collect::<Result<Vec<f64>>>()?
                }
                None => vec![1.0, 5.0, 25.0, 50.0, 75.0, 95.0, 99.0],
            };
            let compression = match obj.get("tdigest") {
                Some(t) => {
                    let tdigest_obj = validate_keys(t, &["compression"], "percentiles.tdigest")?;
                    opt_f64(tdigest_obj, "compression", "percentiles.tdigest")?.unwrap_or(100.0)
                }
                None => 100.0,
            };
            Ok(AggregationExpression::Percentiles {
                field,
                percents,
                compression,
            })
        }
        "geo_bounds" => Ok(AggregationExpression::GeoBounds {
            field: parse_field_only(val, &ctx)?,
        }),
        "geo_centroid" => Ok(AggregationExpression::GeoCentroid {
            field: parse_field_only(val, &ctx)?,
        }),
        "nested" => {
            let obj = validate_keys(val, &["path"], &ctx)?;
            let path = obj
                .get("path")
                .and_then(|v| v.as_str())
                .ok_or_else(|| LuciError::InvalidQuery("nested agg requires 'path'".into()))?
                .to_string();
            Ok(AggregationExpression::Nested { path, sub_aggs })
        }
        "reverse_nested" => {
            // No options are supported; the body must be an empty object.
            validate_keys(val, &[], &ctx)?;
            Ok(AggregationExpression::ReverseNested { sub_aggs })
        }
        "geohash_grid" => {
            let obj = validate_keys(val, &["field", "precision", "size"], &ctx)?;
            Ok(AggregationExpression::GeohashGrid {
                field: require_field(obj, &ctx)?,
                precision: saturating_usize(opt_u64(obj, "precision", &ctx)?.unwrap_or(5)),
                size: saturating_usize(opt_u64(obj, "size", &ctx)?.unwrap_or(10000)),
                sub_aggs,
            })
        }
        "top_hits" => {
            let obj = validate_keys(val, &["size"], &ctx)?;
            Ok(AggregationExpression::TopHits {
                size: saturating_usize(opt_u64(obj, "size", &ctx)?.unwrap_or(3)),
            })
        }
        "date_histogram" => {
            let obj = validate_keys(
                val,
                &["field", "calendar_interval", "fixed_interval", "interval"],
                &ctx,
            )?;
            let field = require_field(obj, &ctx)?;
            // Precedence when several are given: calendar_interval, then
            // fixed_interval, then the plain `interval` key.
            let interval = if let Some(cal) = opt_str(obj, "calendar_interval", &ctx)? {
                let cal_int = match cal {
                    "minute" | "1m" => super::CalendarInterval::Minute,
                    "hour" | "1h" => super::CalendarInterval::Hour,
                    "day" | "1d" => super::CalendarInterval::Day,
                    "week" | "1w" => super::CalendarInterval::Week,
                    "month" | "1M" => super::CalendarInterval::Month,
                    "quarter" | "1q" => super::CalendarInterval::Quarter,
                    "year" | "1y" => super::CalendarInterval::Year,
                    other => {
                        return Err(LuciError::InvalidQuery(format!(
                            "date_histogram: unknown calendar_interval '{other}'"
                        )));
                    }
                };
                super::DateInterval::Calendar(cal_int)
            } else if let Some(fixed) = opt_str(obj, "fixed_interval", &ctx)? {
                let ms = parse_fixed_interval(fixed)?;
                super::DateInterval::Fixed(ms)
            } else if let Some(interval_str) = opt_str(obj, "interval", &ctx)? {
                // The plain `interval` key is only understood as a fixed duration.
                match parse_fixed_interval(interval_str) {
                    Ok(ms) => super::DateInterval::Fixed(ms),
                    Err(_) => {
                        return Err(LuciError::InvalidQuery(format!(
                            "date_histogram: invalid interval '{interval_str}'"
                        )));
                    }
                }
            } else {
                return Err(LuciError::InvalidQuery(
                    "date_histogram requires 'calendar_interval' or 'fixed_interval'".into(),
                ));
            };
            Ok(AggregationExpression::DateHistogram {
                field,
                interval,
                sub_aggs,
            })
        }
        "date_range" => {
            let obj = validate_keys(val, &["field", "ranges"], &ctx)?;
            let field = require_field(obj, &ctx)?;
            let ranges = parse_range_defs(obj, &ctx, true)?;
            Ok(AggregationExpression::DateRange {
                field,
                ranges,
                sub_aggs,
            })
        }
        "filters" => {
            let obj = validate_keys(val, &["filters"], &ctx)?;
            let filters_obj = obj
                .get("filters")
                .and_then(|v| v.as_object())
                .ok_or_else(|| {
                    LuciError::InvalidQuery("filters requires 'filters' object".into())
                })?;
            let mut filters = Vec::with_capacity(filters_obj.len());
            for (filter_name, query_val) in filters_obj {
                let query = parse_query(query_val)?;
                filters.push((filter_name.clone(), query));
            }
            Ok(AggregationExpression::Filters { filters, sub_aggs })
        }
        _ => Err(LuciError::UnsupportedQuery(format!(
            "unknown aggregation type: {key}"
        ))),
    }
}

/// Converts a JSON-supplied count to `usize`, saturating at `usize::MAX`
/// instead of silently wrapping on targets where `usize` is narrower than 64 bits.
fn saturating_usize(n: u64) -> usize {
    <usize as std::convert::TryFrom<u64>>::try_from(n).unwrap_or(usize::MAX)
}

/// Converts a JSON-supplied count to `u32`, saturating at `u32::MAX` instead
/// of wrapping (a plain `as u32` would turn 4294967296 into 0).
fn saturating_u32(n: u64) -> u32 {
    <u32 as std::convert::TryFrom<u64>>::try_from(n).unwrap_or(u32::MAX)
}
/// Parses an aggregation body whose only permitted key is `field` and
/// returns that field's name.
///
/// # Errors
///
/// Propagates the errors of `validate_keys` (not an object, unknown key) and
/// `require_field` (missing or non-string `field`).
fn parse_field_only(val: &Value, ctx: &str) -> Result<String> {
    validate_keys(val, &["field"], ctx).and_then(|body| require_field(body, ctx))
}
fn require_field(obj: &serde_json::Map<String, Value>, ctx: &str) -> Result<String> {
obj.get("field")
.and_then(|v| v.as_str())
.map(String::from)
.ok_or_else(|| LuciError::InvalidQuery(format!("{ctx} requires 'field'")))
}
fn parse_range_defs(
obj: &serde_json::Map<String, Value>,
ctx: &str,
dates: bool,
) -> Result<Vec<RangeDef>> {
let ranges_val = obj
.get("ranges")
.and_then(|v| v.as_array())
.ok_or_else(|| LuciError::InvalidQuery(format!("{ctx}: missing 'ranges' array")))?;
let mut ranges = Vec::with_capacity(ranges_val.len());
for r in ranges_val {
let r_obj = validate_keys(r, &["key", "from", "to"], &format!("{ctx}.ranges[]"))?;
let key = r_obj.get("key").and_then(|v| v.as_str()).map(String::from);
let (from, to) = if dates {
(
r_obj.get("from").and_then(parse_date_value),
r_obj.get("to").and_then(parse_date_value),
)
} else {
(
r_obj.get("from").and_then(|v| v.as_f64()),
r_obj.get("to").and_then(|v| v.as_f64()),
)
};
ranges.push(RangeDef { key, from, to });
}
Ok(ranges)
}
/// Parses a fixed-duration string such as `"500ms"`, `"30s"`, `"5m"`,
/// `"2h"` or `"1d"` into milliseconds.
///
/// Surrounding whitespace is trimmed. The numeric part may be fractional
/// (it is parsed as `f64`).
///
/// # Errors
///
/// Returns `LuciError::InvalidQuery` when the unit suffix is unknown, the
/// numeric part does not parse, or the resulting duration is not a positive
/// finite number. The last check matters because `f64`'s parser accepts
/// `"nan"`, `"inf"` and negative values, so inputs like `"nans"`, `"infd"`,
/// `"-5m"` or `"0s"` would otherwise yield an unusable bucket width.
fn parse_fixed_interval(s: &str) -> Result<f64> {
    // Order matters: "ms" must be tried before "s" and "m", otherwise "5ms"
    // would be split as "5m" + "s".
    const UNITS: [(&str, f64); 5] = [
        ("ms", 1.0),
        ("s", 1_000.0),
        ("m", 60_000.0),
        ("h", 3_600_000.0),
        ("d", 86_400_000.0),
    ];

    let s = s.trim();
    for &(suffix, unit_ms) in UNITS.iter() {
        if let Some(number) = s.strip_suffix(suffix) {
            let count = number
                .parse::<f64>()
                .map_err(|_| LuciError::InvalidQuery(format!("invalid interval: {s}")))?;
            let ms = count * unit_ms;
            if !ms.is_finite() || ms <= 0.0 {
                return Err(LuciError::InvalidQuery(format!(
                    "invalid interval: {s} (must be a positive, finite duration)"
                )));
            }
            return Ok(ms);
        }
    }
    Err(LuciError::InvalidQuery(format!(
        "invalid fixed_interval: {s}"
    )))
}
/// Interprets a JSON value as a point in time, in milliseconds since the
/// Unix epoch.
///
/// Accepted forms:
/// * a JSON number – returned as-is via `as_f64`;
/// * a string that parses as a finite `f64` – returned as that number;
/// * a string of at least 10 bytes whose part before the first `T` is
///   `YYYY-MM-DD` – converted to midnight of that proleptic Gregorian date.
///
/// Returns `None` for anything else, including out-of-range months or days
/// (e.g. `2023-02-29`) and years above `MAX_YEAR`.
///
/// The previous implementation approximated every month as 30 days and
/// every fourth year as leap, so most dates were off by several days; it
/// also accepted impossible dates and could overflow on huge years.
///
/// NOTE(review): anything after the `T` (time of day, zone offset) is still
/// ignored, exactly as before — confirm whether callers need sub-day
/// precision.
fn parse_date_value(v: &Value) -> Option<f64> {
    const MS_PER_DAY: f64 = 86_400_000.0;
    // Upper bound keeps all day arithmetic far away from i64 overflow.
    const MAX_YEAR: i64 = 999_999;

    match v {
        Value::Number(n) => n.as_f64(),
        Value::String(s) => {
            match s.parse::<f64>() {
                Ok(ms) if ms.is_finite() => return Some(ms),
                // "nan" / "inf" parse successfully but are not timestamps.
                Ok(_) => return None,
                Err(_) => {}
            }
            if s.len() < 10 {
                return None;
            }
            let date_part = s.split('T').next()?;
            let mut fields = date_part.split('-');
            let year: i64 = fields.next()?.parse().ok()?;
            let month: i64 = fields.next()?.parse().ok()?;
            let day: i64 = fields.next()?.parse().ok()?;
            if fields.next().is_some() {
                // More than three dash-separated components: not YYYY-MM-DD.
                return None;
            }
            if !(0..=MAX_YEAR).contains(&year)
                || !(1..=12).contains(&month)
                || day < 1
                || day > days_in_month(year, month)
            {
                return None;
            }
            Some(days_from_civil(year, month, day) as f64 * MS_PER_DAY)
        }
        _ => None,
    }
}

/// Number of days in `month` (1-12) of `year` in the proleptic Gregorian
/// calendar; returns 0 for a month outside 1-12.
fn days_in_month(year: i64, month: i64) -> i64 {
    let leap = (year % 4 == 0 && year % 100 != 0) || year % 400 == 0;
    match month {
        1 | 3 | 5 | 7 | 8 | 10 | 12 => 31,
        4 | 6 | 9 | 11 => 30,
        2 if leap => 29,
        2 => 28,
        _ => 0,
    }
}

/// Days between 1970-01-01 and the given proleptic Gregorian date
/// (negative for earlier dates). `month` must be 1-12 and `day` valid for
/// that month.
fn days_from_civil(year: i64, month: i64, day: i64) -> i64 {
    // Shift the year so it starts on 1 March; the leap day then falls at the
    // very end of the shifted year, which makes the day-of-year formula uniform.
    let y = if month <= 2 { year - 1 } else { year };
    let era = y.div_euclid(400);
    let year_of_era = y - era * 400; // [0, 399]
    let shifted_month = (month + 9) % 12; // March = 0 ... February = 11
    let day_of_year = (153 * shifted_month + 2) / 5 + day - 1; // [0, 365]
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;
    // 146_097 days per 400-year era; 719_468 days from 0000-03-01 to 1970-01-01.
    era * 146_097 + day_of_era - 719_468
}
#[cfg(test)]
mod tests {
    //! Unit tests for the aggregation DSL parser.

    use super::*;
    use serde_json::json;

    /// A lone `avg` aggregation parses to one named `Avg` expression.
    #[test]
    fn parse_avg() {
        let request = json!({"my_avg": {"avg": {"field": "price"}}});
        let parsed = parse_aggs(&request).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].0, "my_avg");
        let is_avg_on_price = match &parsed[0].1 {
            AggregationExpression::Avg { field } => field == "price",
            _ => false,
        };
        assert!(is_avg_on_price);
    }

    /// An explicit `size` on `terms` is honoured.
    #[test]
    fn parse_terms_with_size() {
        let request = json!({"by_tag": {"terms": {"field": "tag", "size": 5}}});
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Terms { field, size, .. } => {
                assert_eq!(field, "tag");
                assert_eq!(*size, 5);
            }
            _ => panic!(),
        }
    }

    /// Without `size`, `terms` falls back to 10.
    #[test]
    fn parse_terms_default_size() {
        let request = json!({"by_tag": {"terms": {"field": "tag"}}});
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Terms { size, .. } => assert_eq!(*size, 10),
            _ => panic!(),
        }
    }

    /// A string where a numeric `size` is expected is an error that names the key.
    #[test]
    fn parse_terms_string_size_rejected() {
        let request = json!({"by_tag": {"terms": {"field": "tag", "size": "5"}}});
        let err = parse_aggs(&request).unwrap_err();
        let rendered = err.to_string();
        assert!(rendered.contains("size"), "{err}");
    }

    /// Non-numeric entries in `percents` are rejected with a message naming the key.
    #[test]
    fn parse_percentiles_non_number_percent_rejected() {
        let request = json!({"p": {"percentiles": {"field": "price", "percents": [50, "99"]}}});
        let err = parse_aggs(&request).unwrap_err();
        let rendered = err.to_string();
        assert!(rendered.contains("percents"), "{err}");
    }

    /// Open-ended and closed ranges are all retained.
    #[test]
    fn parse_range() {
        let request = json!({
            "price_ranges": {"range": {"field": "price", "ranges": [
                {"to": 50.0},
                {"from": 50.0, "to": 100.0},
                {"from": 100.0}
            ]}}
        });
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Range { ranges, .. } => assert_eq!(ranges.len(), 3),
            _ => panic!(),
        }
    }

    /// The numeric `interval` of a histogram is carried through unchanged.
    #[test]
    fn parse_histogram() {
        let request = json!({
            "prices": {"histogram": {"field": "price", "interval": 10.0}}
        });
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Histogram { interval, .. } => assert_eq!(*interval, 10.0),
            _ => panic!(),
        }
    }

    /// Sub-aggregations under `aggs` are attached to their parent.
    #[test]
    fn parse_nested_sub_aggs() {
        let request = json!({
            "by_tag": {
                "terms": {"field": "tag"},
                "aggs": {
                    "avg_price": {"avg": {"field": "price"}}
                }
            }
        });
        let parsed = parse_aggs(&request).unwrap();
        match &parsed[0].1 {
            AggregationExpression::Terms { sub_aggs, .. } => {
                assert_eq!(sub_aggs.len(), 1);
                assert_eq!(sub_aggs[0].0, "avg_price");
            }
            _ => panic!(),
        }
    }

    /// Sibling aggregations are all returned.
    #[test]
    fn parse_multiple_aggs() {
        let request = json!({
            "total": {"sum": {"field": "amount"}},
            "average": {"avg": {"field": "amount"}}
        });
        let parsed = parse_aggs(&request).unwrap();
        assert_eq!(parsed.len(), 2);
    }

    /// A `filter` aggregation wraps a query clause.
    #[test]
    fn parse_filter_agg() {
        let request = json!({
            "active": {"filter": {"term": {"status": "active"}}}
        });
        let parsed = parse_aggs(&request).unwrap();
        let is_filter = match &parsed[0].1 {
            AggregationExpression::Filter { .. } => true,
            _ => false,
        };
        assert!(is_filter);
    }

    /// Unknown aggregation type keys are rejected.
    #[test]
    fn unknown_agg_type_error() {
        let request = json!({"x": {"unknown_type": {"field": "f"}}});
        let outcome = parse_aggs(&request);
        assert!(outcome.is_err());
    }

    /// A metric aggregation without `field` is rejected.
    #[test]
    fn missing_field_error() {
        let request = json!({"x": {"avg": {}}});
        let outcome = parse_aggs(&request);
        assert!(outcome.is_err());
    }

    /// Unrecognised keys in an aggregation body are rejected and named in the error.
    #[test]
    fn unknown_agg_body_key_error() {
        let request = json!({
            "x": {"avg": {"field": "price", "missing_value": 0}}
        });
        let outcome = parse_aggs(&request);
        assert!(outcome.is_err(), "missing_value is not a valid avg key");
        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains("missing_value"));
    }
}