use serde::{Deserialize, Serialize};
use super::GroupBy;
use crate::{
CanonicalColumnName, aggregation::Aggregate, interval::AggregationSource, unit::MeasurementUnit,
};
/// Result of planning how one measurement is aggregated within a group-by.
///
/// Produced by `GroupAggregationPlanner::plan`; records both the selected
/// aggregate and where that selection came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupAggregationPlan {
/// Name of the measurement unit this plan applies to.
pub measurement: CanonicalColumnName,
/// Aggregate selected for the measurement.
pub aggregation: Aggregate,
/// Whether the aggregate came from the measurement schema or a request override.
pub aggregation_source: AggregationSource,
/// Human-readable explanation: the aggregate's `Debug` form followed by its origin.
pub reason: String,
}
/// Chooses the aggregate for a single measurement under a group-by request.
///
/// Holds only borrows, so the unit (`'m`) and the request (`'g`) must both
/// outlive the planner.
pub struct GroupAggregationPlanner<'m, 'g> {
/// Measurement whose aggregation is being planned; supplies the name and the schema default.
unit: &'m MeasurementUnit,
/// Group-by request; its `aggregation_override` is consulted first.
group_by: &'g GroupBy,
}
impl<'m, 'g> GroupAggregationPlanner<'m, 'g> {
pub fn new(unit: &'m MeasurementUnit, group_by: &'g GroupBy) -> Self {
Self { unit, group_by }
}
pub fn plan(&self) -> GroupAggregationPlan {
let (aggregation, aggregation_source) = self.choose_aggregation();
let reason = match aggregation_source {
AggregationSource::Schema => {
format!("{:?} from measurement schema default", aggregation,)
}
AggregationSource::Override => {
format!("{:?} from request's aggregation_override", aggregation,)
}
};
GroupAggregationPlan {
measurement: self.unit.name.clone(),
aggregation,
aggregation_source,
reason,
}
}
fn choose_aggregation(&self) -> (Aggregate, AggregationSource) {
if let Some(ref overrides) = self.group_by.aggregation_override
&& let Some(agg) = overrides.get(&self.unit.name)
{
return (*agg, AggregationSource::Override);
}
(self.unit.signal_aggregation(), AggregationSource::Schema)
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{MeasurementKind, group::MissingQualityPolicy, signal_policy::SignalPolicy};

    /// Fixture: a unit named `name` of the given `kind`, configured with
    /// `SignalPolicy::instant()` and a 60 000 ms sample rate.
    fn unit_named(name: &str, kind: MeasurementKind) -> MeasurementUnit {
        MeasurementUnit::new("subject", "time", name, kind)
            .with_signal_policy(SignalPolicy::instant())
            .with_sample_rate_ms(60_000)
    }

    /// Fixture: group by the `parish` quality with no aggregation override.
    fn group_by_parish() -> GroupBy {
        GroupBy {
            qualities: vec![CanonicalColumnName::new("parish")],
            aggregation_override: None,
            missing_policy: MissingQualityPolicy::SyntheticGroup,
        }
    }

    /// Fixture: group by the `parish` quality with a single override entry
    /// mapping `measurement` to `aggregate`.
    fn group_by_parish_overriding(measurement: &'static str, aggregate: Aggregate) -> GroupBy {
        let overrides = HashMap::from([(CanonicalColumnName::new(measurement), aggregate)]);
        GroupBy {
            qualities: vec![CanonicalColumnName::new("parish")],
            aggregation_override: Some(overrides),
            missing_policy: MissingQualityPolicy::SyntheticGroup,
        }
    }

    #[test]
    fn schema_default_aggregation_used_when_no_override() {
        let unit = unit_named("sump", MeasurementKind::Measure);
        let request = group_by_parish();

        let plan = GroupAggregationPlanner::new(&unit, &request).plan();

        assert_eq!(plan.aggregation, Aggregate::Mean);
        assert_eq!(plan.aggregation_source, AggregationSource::Schema);
    }

    #[test]
    fn override_aggregation_beats_schema_default() {
        let unit = unit_named("sump", MeasurementKind::Measure);
        let request = group_by_parish_overriding("sump", Aggregate::Max);

        let plan = GroupAggregationPlanner::new(&unit, &request).plan();

        assert_eq!(plan.aggregation, Aggregate::Max);
        assert_eq!(plan.aggregation_source, AggregationSource::Override);
        assert!(plan.reason.contains("aggregation_override"));
    }

    #[test]
    fn override_keyed_on_other_measurement_is_ignored() {
        let unit = unit_named("sump", MeasurementKind::Measure);
        // The override targets a different measurement, so it must not apply to `sump`.
        let request = group_by_parish_overriding("engines_on_count", Aggregate::Sum);

        let plan = GroupAggregationPlanner::new(&unit, &request).plan();

        assert_eq!(plan.aggregation, Aggregate::Mean);
        assert_eq!(plan.aggregation_source, AggregationSource::Schema);
    }

    #[test]
    fn plan_records_measurement_name() {
        let unit = unit_named("sump", MeasurementKind::Measure);
        let request = group_by_parish();

        let plan = GroupAggregationPlanner::new(&unit, &request).plan();

        assert_eq!(plan.measurement, CanonicalColumnName::new("sump"));
        assert!(!plan.reason.is_empty());
    }
}