use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::fmt;
use tracing::{debug, instrument};
use super::{Analyzer, AnalyzerResult, AnalyzerState, MetricValue};
/// Settings describing how an analyzer's metric is broken down by group.
///
/// Build one with [`GroupingConfig::new`] and adjust it through the `with_*`
/// builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupingConfig {
/// Column names to group by, in order. They are interpolated verbatim
/// (unquoted) into the SQL fragments produced by `group_by_sql` and
/// `select_group_columns_sql`.
pub columns: Vec<String>,
/// Optional cap on the number of groups. `sql_helpers::build_limit_clause`
/// renders it as a `LIMIT` clause; `None` produces no clause.
pub max_groups: Option<usize>,
/// NOTE(review): presumably requests an ungrouped "overall" metric next to
/// the per-group values; nothing in this file reads the flag — confirm
/// against the analyzers that consume this config.
pub include_overall: bool,
/// NOTE(review): presumably selects what happens when the number of groups
/// exceeds `max_groups`; nothing in this file reads it — confirm against
/// the implementing analyzers.
pub overflow_strategy: OverflowStrategy,
}
impl GroupingConfig {
    /// Creates a configuration that groups by `columns`.
    ///
    /// Defaults: `max_groups` is `Some(10000)`, `include_overall` is `true`
    /// and the overflow strategy is [`OverflowStrategy::TopK`].
    pub fn new(columns: Vec<String>) -> Self {
        let max_groups = Some(10000);
        let include_overall = true;
        let overflow_strategy = OverflowStrategy::TopK;
        Self {
            columns,
            max_groups,
            include_overall,
            overflow_strategy,
        }
    }

    /// Returns the configuration with `max_groups` replaced by `Some(max)`.
    pub fn with_max_groups(self, max: usize) -> Self {
        Self {
            max_groups: Some(max),
            ..self
        }
    }

    /// Returns the configuration with `include_overall` replaced by `include`.
    pub fn with_overall(self, include: bool) -> Self {
        Self {
            include_overall: include,
            ..self
        }
    }

    /// Returns the configuration with the overflow strategy replaced by `strategy`.
    pub fn with_overflow_strategy(self, strategy: OverflowStrategy) -> Self {
        Self {
            overflow_strategy: strategy,
            ..self
        }
    }

    /// Renders the grouping columns as a comma-separated list, e.g. `country, city`.
    ///
    /// Column names are emitted verbatim, without quoting or escaping.
    pub fn group_by_sql(&self) -> String {
        let separator = ", ";
        self.columns.join(separator)
    }

    /// Renders the grouping columns as aliased select expressions, e.g.
    /// `country as group_country, city as group_city`.
    ///
    /// Column names are emitted verbatim, without quoting or escaping.
    pub fn select_group_columns_sql(&self) -> String {
        let aliased: Vec<String> = self
            .columns
            .iter()
            .map(|col| format!("{col} as group_{col}"))
            .collect();
        aliased.join(", ")
    }
}
/// Strategy recorded in [`GroupingConfig`] and [`GroupedMetadata`] for results
/// that contain more groups than can be kept.
///
/// The enum is a plain set of unit variants, so it derives `PartialEq`, `Eq`
/// and `Hash` in addition to `Copy`; strategies can be compared with `==`
/// and used as map or set keys.
///
/// NOTE(review): nothing in this file interprets the variants. The per-variant
/// descriptions below are inferred from the names and must be confirmed
/// against the analyzers that implement the overflow handling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OverflowStrategy {
    /// Presumably keeps the K highest-ranked groups — TODO confirm.
    TopK,
    /// Presumably keeps the K lowest-ranked groups — TODO confirm.
    BottomK,
    /// Presumably keeps a sample of the groups — TODO confirm.
    Sample,
    /// Presumably reports an error instead of truncating — TODO confirm.
    Fail,
}
/// Result of a grouped computation: one metric per group key, plus an
/// optional overall metric and bookkeeping metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupedMetrics {
/// Metric for each group, keyed by a list of strings.
/// NOTE(review): presumably one key element per grouping column, in column
/// order — confirm against the code that builds this map.
pub groups: BTreeMap<Vec<String>, MetricValue>,
/// Optional metric stored next to the per-group values; `to_metric_value`
/// exports it under the `__overall__` key when present.
pub overall: Option<MetricValue>,
/// Bookkeeping about the grouping; `is_truncated` reads its `truncated` flag.
pub metadata: GroupedMetadata,
}
impl GroupedMetrics {
pub fn new(
groups: BTreeMap<Vec<String>, MetricValue>,
overall: Option<MetricValue>,
metadata: GroupedMetadata,
) -> Self {
Self {
groups,
overall,
metadata,
}
}
pub fn group_count(&self) -> usize {
self.groups.len()
}
pub fn get_group(&self, key: &[String]) -> Option<&MetricValue> {
self.groups.get(key)
}
pub fn is_truncated(&self) -> bool {
self.metadata.truncated
}
pub fn to_metric_value(&self) -> MetricValue {
let mut map = HashMap::new();
for (key, value) in &self.groups {
let key_str = key.join("_");
map.insert(key_str, value.clone());
}
if let Some(ref overall) = self.overall {
map.insert("__overall__".to_string(), overall.clone());
}
map.insert(
"__metadata__".to_string(),
MetricValue::String(serde_json::to_string(&self.metadata).unwrap_or_default()),
);
MetricValue::Map(map)
}
}
/// Bookkeeping that accompanies a [`GroupedMetrics`] result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupedMetadata {
/// Names of the columns the result was grouped by.
pub group_columns: Vec<String>,
/// Group count supplied by the caller as the total.
/// NOTE(review): presumably the count before any limit was applied — confirm.
pub total_groups: usize,
/// Group count supplied by the caller as the number actually included.
pub included_groups: usize,
/// `GroupedMetadata::new` sets this to `total_groups > included_groups`.
pub truncated: bool,
/// Strategy associated with the truncation, if recorded;
/// `GroupedMetadata::new` always initializes it to `None`.
pub overflow_strategy: Option<OverflowStrategy>,
}
impl GroupedMetadata {
pub fn new(group_columns: Vec<String>, total_groups: usize, included_groups: usize) -> Self {
Self {
group_columns,
total_groups,
included_groups,
truncated: total_groups > included_groups,
overflow_strategy: None,
}
}
}
/// Extension of [`Analyzer`] for analyzers that can compute their metric per
/// group as described by a [`GroupingConfig`].
#[async_trait]
pub trait GroupedAnalyzer: Analyzer {
/// State type produced by the grouped computation.
type GroupedState: GroupedAnalyzerState;
/// Consumes the analyzer and pairs it with `config` in a
/// [`GroupedAnalyzerWrapper`], which itself implements [`Analyzer`].
fn with_grouping(self, config: GroupingConfig) -> GroupedAnalyzerWrapper<Self>
where
Self: Sized + 'static,
{
GroupedAnalyzerWrapper::new(self, config)
}
/// Computes the grouped state from the data reachable through `ctx`,
/// using the grouping described by `config`.
///
/// # Errors
///
/// Returns whatever error the implementation reports through
/// `AnalyzerResult`.
async fn compute_grouped_state_from_data(
&self,
ctx: &SessionContext,
config: &GroupingConfig,
) -> AnalyzerResult<Self::GroupedState>;
/// Converts a previously computed grouped state into [`GroupedMetrics`].
///
/// # Errors
///
/// Returns whatever error the implementation reports through
/// `AnalyzerResult`.
fn compute_grouped_metrics_from_state(
&self,
state: &Self::GroupedState,
) -> AnalyzerResult<GroupedMetrics>;
}
/// Marker trait for state types usable as [`GroupedAnalyzer::GroupedState`];
/// it declares no items beyond those of its `AnalyzerState` supertrait.
pub trait GroupedAnalyzerState: AnalyzerState {}
/// Pairs a [`GroupedAnalyzer`] with a [`GroupingConfig`] so that the pair can
/// be used through the plain [`Analyzer`] interface.
pub struct GroupedAnalyzerWrapper<A: GroupedAnalyzer> {
/// The wrapped analyzer that performs the grouped computation.
analyzer: A,
/// Grouping settings passed to the wrapped analyzer when computing state.
config: GroupingConfig,
}
impl<A: GroupedAnalyzer> GroupedAnalyzerWrapper<A> {
pub fn new(analyzer: A, config: GroupingConfig) -> Self {
Self { analyzer, config }
}
}
impl<A> fmt::Debug for GroupedAnalyzerWrapper<A>
where
    A: GroupedAnalyzer,
{
    /// Formats the wrapper using the wrapped analyzer's name and the grouping
    /// columns, so `A` itself does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let analyzer_name = self.analyzer.name();
        let mut builder = f.debug_struct("GroupedAnalyzerWrapper");
        builder.field("analyzer", &analyzer_name);
        builder.field("group_columns", &self.config.columns);
        builder.finish()
    }
}
#[async_trait]
impl<A> Analyzer for GroupedAnalyzerWrapper<A>
where
    A: GroupedAnalyzer + Send + Sync + 'static,
    A::GroupedState: AnalyzerState + 'static,
{
    type State = A::GroupedState;
    type Metric = MetricValue;

    /// Delegates to the wrapped analyzer's grouped state computation, passing
    /// along the stored grouping configuration.
    #[instrument(skip(ctx), fields(
        analyzer = %self.analyzer.name(),
        group_columns = ?self.config.columns
    ))]
    async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
        let inner = &self.analyzer;
        debug!("Computing grouped state for {} analyzer", inner.name());
        inner
            .compute_grouped_state_from_data(ctx, &self.config)
            .await
    }

    /// Computes the grouped metrics from `state` and flattens them into a
    /// single metric value via `GroupedMetrics::to_metric_value`.
    fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
        self.analyzer
            .compute_grouped_metrics_from_state(state)
            .map(|grouped| grouped.to_metric_value())
    }

    /// Name of the wrapped analyzer.
    fn name(&self) -> &str {
        self.analyzer.name()
    }

    /// Description of the wrapped analyzer.
    fn description(&self) -> &str {
        self.analyzer.description()
    }

    /// Metric key of the wrapped analyzer, followed by `_grouped_by_` and the
    /// grouping columns joined with `_`.
    fn metric_key(&self) -> String {
        let base_key = self.analyzer.metric_key();
        let grouping_suffix = self.config.columns.join("_");
        format!("{}_grouped_by_{}", base_key, grouping_suffix)
    }

    /// Columns of the wrapped analyzer followed by the grouping columns.
    /// Duplicates are not removed.
    fn columns(&self) -> Vec<&str> {
        let mut all_columns = self.analyzer.columns();
        all_columns.extend(self.config.columns.iter().map(String::as_str));
        all_columns
    }
}
/// Helpers that turn a [`GroupingConfig`] into SQL fragments.
pub mod sql_helpers {
    use super::GroupingConfig;

    /// Returns ` GROUP BY <columns>` (with a leading space), or an empty
    /// string when no grouping columns are configured.
    pub fn build_group_by_clause(config: &GroupingConfig) -> String {
        if config.columns.is_empty() {
            return String::new();
        }
        let grouping = config.group_by_sql();
        format!(" GROUP BY {}", grouping)
    }

    /// Returns the select list: the aliased grouping columns followed by
    /// `metric_sql`, or just `metric_sql` when no grouping columns are
    /// configured.
    pub fn build_group_select(config: &GroupingConfig, metric_sql: &str) -> String {
        if config.columns.is_empty() {
            return metric_sql.to_string();
        }
        let group_columns = config.select_group_columns_sql();
        format!("{}, {metric_sql}", group_columns)
    }

    /// Returns ` LIMIT <max_groups>` (with a leading space), or an empty
    /// string when `max_groups` is `None`.
    pub fn build_limit_clause(config: &GroupingConfig) -> String {
        match config.max_groups {
            Some(max) => format!(" LIMIT {max}"),
            None => String::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns borrowed string parts into the owned vectors used for column
    /// lists and group keys.
    fn owned(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| (*part).to_string()).collect()
    }

    /// Builder overrides are stored and the SQL fragments list the columns in order.
    #[test]
    fn test_grouping_config() {
        let config = GroupingConfig::new(owned(&["country", "city"]))
            .with_max_groups(1000)
            .with_overall(false);

        assert_eq!(config.group_by_sql(), "country, city");
        assert_eq!(
            config.select_group_columns_sql(),
            "country as group_country, city as group_city"
        );
        assert_eq!(config.max_groups, Some(1000));
        assert!(!config.include_overall);
    }

    /// Group counting, truncation flag and keyed lookup on a two-group result.
    #[test]
    fn test_grouped_metrics() {
        let groups = BTreeMap::from([
            (owned(&["US", "NYC"]), MetricValue::Double(0.95)),
            (owned(&["US", "LA"]), MetricValue::Double(0.92)),
        ]);
        let metadata = GroupedMetadata::new(owned(&["country", "city"]), 2, 2);
        let grouped = GroupedMetrics::new(groups, Some(MetricValue::Double(0.935)), metadata);

        assert_eq!(grouped.group_count(), 2);
        assert!(!grouped.is_truncated());

        let nyc_key = owned(&["US", "NYC"]);
        let us_nyc = grouped.get_group(&nyc_key);
        assert_eq!(us_nyc, Some(&MetricValue::Double(0.95)));
    }
}