use arrow::datatypes::DataType;
use serde::{Deserialize, Serialize};
/// Aggregation function that can be attached to a [`Measure`] as its default
/// aggregation.
///
/// Each variant is rendered to a SQL keyword by [`AggFunc::sql_name`] (which is
/// also what `Display` prints), and its applicability to an Arrow column type is
/// checked by [`AggFunc::is_compatible_with`].
///
/// NOTE: the derived `Hash` and serde's variant-index based encodings depend on
/// variant order; append new variants rather than reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AggFunc {
/// Rendered as `SUM`; numeric and decimal types only.
Sum,
/// Rendered as `AVG`; numeric and decimal types only.
Avg,
/// Rendered as `MIN`; accepted for every data type.
Min,
/// Rendered as `MAX`; accepted for every data type.
Max,
/// Rendered as `COUNT`; accepted for every data type.
Count,
/// Distinct count. `sql_name` returns plain `COUNT` for this variant, so the
/// `DISTINCT` modifier is not carried by the keyword itself.
CountDistinct,
/// Rendered as `MEDIAN`; integer and `Float32`/`Float64` types only (no decimals).
Median,
/// Rendered as `STDDEV`; numeric and decimal types only.
StdDev,
/// Rendered as `VAR`; numeric and decimal types only.
Variance,
/// Rendered as `FIRST_VALUE`; accepted for every data type.
First,
/// Rendered as `LAST_VALUE`; accepted for every data type.
Last,
}
impl AggFunc {
pub fn sql_name(&self) -> &'static str {
match self {
AggFunc::Sum => "SUM",
AggFunc::Avg => "AVG",
AggFunc::Min => "MIN",
AggFunc::Max => "MAX",
AggFunc::Count => "COUNT",
AggFunc::CountDistinct => "COUNT",
AggFunc::Median => "MEDIAN",
AggFunc::StdDev => "STDDEV",
AggFunc::Variance => "VAR",
AggFunc::First => "FIRST_VALUE",
AggFunc::Last => "LAST_VALUE",
}
}
pub fn is_compatible_with(&self, data_type: &DataType) -> bool {
use DataType::*;
match self {
AggFunc::Sum | AggFunc::Avg | AggFunc::StdDev | AggFunc::Variance => {
matches!(
data_type,
Int8 | Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
| Float64
| Decimal128(_, _)
| Decimal256(_, _)
)
}
AggFunc::Min | AggFunc::Max | AggFunc::First | AggFunc::Last => true,
AggFunc::Count | AggFunc::CountDistinct => true,
AggFunc::Median => {
matches!(
data_type,
Int8 | Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
| Float64
)
}
}
}
}
impl std::fmt::Display for AggFunc {
    /// Writes the SQL keyword returned by [`AggFunc::sql_name`], so e.g.
    /// `CountDistinct` is displayed as `COUNT`. Width/fill flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.sql_name())
    }
}
/// A named, typed value column together with the aggregation applied to it by
/// default.
///
/// Construct with [`Measure::new`] or [`Measure::with_config`]; use
/// [`Measure::validate`] to check that the default aggregation fits the data type.
/// Field order and names are part of the serde representation — change with care.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Measure {
/// Measure name as supplied by the caller (not validated here).
name: String,
/// Arrow data type of the underlying values.
data_type: DataType,
/// Aggregation applied when none is requested explicitly.
default_agg: AggFunc,
/// Whether values may be null; `Measure::new` sets this to `true`.
nullable: bool,
/// Optional free-form, human-readable description.
description: Option<String>,
/// Optional display format string (the tests use `"$,.2f"`); its syntax is not
/// interpreted by this type.
format: Option<String>,
}
impl Measure {
    /// Creates a measure with the given name, Arrow data type and default aggregation.
    ///
    /// The new measure is nullable and has no description or format string; use the
    /// `with_*` builder methods or the `set_*` setters to change that. No
    /// validation is performed — see [`Measure::validate`].
    pub fn new(name: impl Into<String>, data_type: DataType, default_agg: AggFunc) -> Self {
        Self::with_config(name, data_type, default_agg, true, None, None)
    }

    /// Creates a measure with every field supplied explicitly.
    ///
    /// No validation is performed; call [`Measure::validate`] to check that
    /// `default_agg` is compatible with `data_type`.
    pub fn with_config(
        name: impl Into<String>,
        data_type: DataType,
        default_agg: AggFunc,
        nullable: bool,
        description: Option<String>,
        format: Option<String>,
    ) -> Self {
        Self {
            name: name.into(),
            data_type,
            default_agg,
            nullable,
            description,
            format,
        }
    }

    /// Returns the measure's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the Arrow data type of the measure's values.
    pub fn data_type(&self) -> &DataType {
        &self.data_type
    }

    /// Returns the aggregation applied by default.
    pub fn default_agg(&self) -> AggFunc {
        self.default_agg
    }

    /// Returns `true` if the measure's values may be null.
    pub fn is_nullable(&self) -> bool {
        self.nullable
    }

    /// Returns the description, if one has been set.
    pub fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }

    /// Returns the display format string, if one has been set.
    pub fn format(&self) -> Option<&str> {
        self.format.as_deref()
    }

    /// Sets (or replaces) the description in place.
    pub fn set_description(&mut self, description: impl Into<String>) {
        self.description = Some(description.into());
    }

    /// Sets (or replaces) the display format string in place.
    pub fn set_format(&mut self, format: impl Into<String>) {
        self.format = Some(format.into());
    }

    /// Builder form of setting nullability: consumes the measure and returns it
    /// with `nullable` applied.
    #[must_use = "builder methods consume the measure and return the updated one"]
    pub fn with_nullable(mut self, nullable: bool) -> Self {
        self.nullable = nullable;
        self
    }

    /// Builder form of [`Measure::set_description`]: consumes the measure and
    /// returns it with the description set.
    #[must_use = "builder methods consume the measure and return the updated one"]
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.set_description(description);
        self
    }

    /// Builder form of [`Measure::set_format`]: consumes the measure and returns
    /// it with the format string set.
    #[must_use = "builder methods consume the measure and return the updated one"]
    pub fn with_format(mut self, format: impl Into<String>) -> Self {
        self.set_format(format);
        self
    }

    /// Checks that the default aggregation can be applied to the measure's data type.
    ///
    /// # Errors
    ///
    /// Returns a message naming the aggregation (by its SQL keyword) and the data
    /// type (in `Debug` form) when [`AggFunc::is_compatible_with`] returns `false`.
    pub fn validate(&self) -> Result<(), String> {
        if self.default_agg.is_compatible_with(&self.data_type) {
            Ok(())
        } else {
            Err(format!(
                "Aggregation function {} is not compatible with data type {:?}",
                self.default_agg, self.data_type
            ))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_measure_creation() {
        let measure = Measure::new("revenue", DataType::Float64, AggFunc::Sum);
        assert_eq!(measure.name(), "revenue");
        assert_eq!(measure.data_type(), &DataType::Float64);
        assert_eq!(measure.default_agg(), AggFunc::Sum);
        assert!(measure.is_nullable());
        assert_eq!(measure.description(), None);
        assert_eq!(measure.format(), None);
    }

    #[test]
    fn test_measure_validation() {
        let valid_measure = Measure::new("amount", DataType::Float64, AggFunc::Sum);
        assert!(valid_measure.validate().is_ok());
        let invalid_measure = Measure::new("category", DataType::Utf8, AggFunc::Sum);
        assert!(invalid_measure.validate().is_err());
    }

    #[test]
    fn test_validation_error_message() {
        let err = Measure::new("category", DataType::Utf8, AggFunc::Sum)
            .validate()
            .unwrap_err();
        assert_eq!(
            err,
            "Aggregation function SUM is not compatible with data type Utf8"
        );
    }

    #[test]
    fn test_agg_func_compatibility() {
        assert!(AggFunc::Sum.is_compatible_with(&DataType::Float64));
        assert!(AggFunc::Sum.is_compatible_with(&DataType::Int32));
        assert!(!AggFunc::Sum.is_compatible_with(&DataType::Utf8));
        assert!(AggFunc::Count.is_compatible_with(&DataType::Utf8));
        assert!(AggFunc::Max.is_compatible_with(&DataType::Utf8));
    }

    #[test]
    fn test_decimal_compatibility() {
        // Arithmetic aggregates accept decimals...
        assert!(AggFunc::Sum.is_compatible_with(&DataType::Decimal128(10, 2)));
        assert!(AggFunc::Avg.is_compatible_with(&DataType::Decimal256(38, 4)));
        assert!(AggFunc::Variance.is_compatible_with(&DataType::Decimal128(18, 0)));
        // ...while Median is restricted to integers and Float32/Float64.
        assert!(AggFunc::Median.is_compatible_with(&DataType::UInt64));
        assert!(!AggFunc::Median.is_compatible_with(&DataType::Decimal128(10, 2)));
        assert!(!AggFunc::Median.is_compatible_with(&DataType::Utf8));
    }

    #[test]
    fn test_sql_name_and_display() {
        assert_eq!(AggFunc::CountDistinct.sql_name(), "COUNT");
        assert_eq!(AggFunc::Variance.to_string(), "VAR");
        assert_eq!(AggFunc::First.to_string(), "FIRST_VALUE");
        assert_eq!(AggFunc::Last.to_string(), "LAST_VALUE");
    }

    #[test]
    fn test_measure_builder() {
        let measure = Measure::new("sales", DataType::Float64, AggFunc::Sum)
            .with_nullable(false)
            .with_description("Total sales amount")
            .with_format("$,.2f");
        assert_eq!(measure.name(), "sales");
        assert!(!measure.is_nullable());
        assert_eq!(measure.description(), Some("Total sales amount"));
        assert_eq!(measure.format(), Some("$,.2f"));
    }

    #[test]
    fn test_with_config_and_setters() {
        let mut measure = Measure::with_config(
            "qty",
            DataType::Int32,
            AggFunc::Max,
            false,
            Some("Units".to_string()),
            None,
        );
        assert!(!measure.is_nullable());
        assert_eq!(measure.description(), Some("Units"));
        assert_eq!(measure.format(), None);

        measure.set_description("Units sold");
        measure.set_format(",d");
        assert_eq!(measure.description(), Some("Units sold"));
        assert_eq!(measure.format(), Some(",d"));
        assert!(measure.validate().is_ok());
    }
}