1use arrow::datatypes::DataType;
4use serde::{Deserialize, Serialize};
5
/// Aggregation functions that can be attached to a [`Measure`].
///
/// [`AggFunc::sql_name`] gives the SQL keyword for each variant and
/// [`AggFunc::is_compatible_with`] reports which Arrow data types a variant
/// accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AggFunc {
    /// Summation; SQL keyword `SUM`.
    Sum,
    /// Average; SQL keyword `AVG`.
    Avg,
    /// Minimum; SQL keyword `MIN`.
    Min,
    /// Maximum; SQL keyword `MAX`.
    Max,
    /// Count; SQL keyword `COUNT`.
    Count,
    /// Distinct count.
    ///
    /// NOTE(review): `sql_name` returns plain `COUNT` for this variant, so the
    /// `DISTINCT` qualifier is presumably added by the query builder — confirm.
    CountDistinct,
    /// Median; SQL keyword `MEDIAN`.
    Median,
    /// Standard deviation; SQL keyword `STDDEV`.
    StdDev,
    /// Variance; SQL keyword `VAR`.
    Variance,
    /// First value; SQL keyword `FIRST_VALUE`.
    First,
    /// Last value; SQL keyword `LAST_VALUE`.
    Last,
}
32
33impl AggFunc {
34 pub fn sql_name(&self) -> &'static str {
36 match self {
37 AggFunc::Sum => "SUM",
38 AggFunc::Avg => "AVG",
39 AggFunc::Min => "MIN",
40 AggFunc::Max => "MAX",
41 AggFunc::Count => "COUNT",
42 AggFunc::CountDistinct => "COUNT",
43 AggFunc::Median => "MEDIAN",
44 AggFunc::StdDev => "STDDEV",
45 AggFunc::Variance => "VAR",
46 AggFunc::First => "FIRST_VALUE",
47 AggFunc::Last => "LAST_VALUE",
48 }
49 }
50
51 pub fn is_compatible_with(&self, data_type: &DataType) -> bool {
53 use DataType::*;
54 match self {
55 AggFunc::Sum | AggFunc::Avg | AggFunc::StdDev | AggFunc::Variance => {
56 matches!(
57 data_type,
58 Int8 | Int16
59 | Int32
60 | Int64
61 | UInt8
62 | UInt16
63 | UInt32
64 | UInt64
65 | Float32
66 | Float64
67 | Decimal128(_, _)
68 | Decimal256(_, _)
69 )
70 }
71 AggFunc::Min | AggFunc::Max | AggFunc::First | AggFunc::Last => true,
72 AggFunc::Count | AggFunc::CountDistinct => true,
73 AggFunc::Median => {
74 matches!(
75 data_type,
76 Int8 | Int16
77 | Int32
78 | Int64
79 | UInt8
80 | UInt16
81 | UInt32
82 | UInt64
83 | Float32
84 | Float64
85 )
86 }
87 }
88 }
89}
90
91impl std::fmt::Display for AggFunc {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 write!(f, "{}", self.sql_name())
94 }
95}
96
/// A named, typed measure together with its default aggregation function.
///
/// Fields are private; they are read through the accessor methods on
/// `Measure` and set through its constructors, setters and `with_*` builders.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Measure {
    /// Name of the measure.
    name: String,

    /// Arrow data type of the measure.
    data_type: DataType,

    /// Aggregation function used by default; `Measure::validate` checks it
    /// against `data_type`.
    default_agg: AggFunc,

    /// Nullability flag; `Measure::new` initialises it to `true`.
    nullable: bool,

    /// Optional free-text description.
    description: Option<String>,

    /// Optional format string.
    ///
    /// NOTE(review): presumably a display/number format pattern (the unit
    /// tests use `"$,.2f"`) — confirm with whatever consumes it.
    format: Option<String>,
}
121
122impl Measure {
123 pub fn new(name: impl Into<String>, data_type: DataType, default_agg: AggFunc) -> Self {
125 Self {
126 name: name.into(),
127 data_type,
128 default_agg,
129 nullable: true,
130 description: None,
131 format: None,
132 }
133 }
134
135 pub fn with_config(
137 name: impl Into<String>,
138 data_type: DataType,
139 default_agg: AggFunc,
140 nullable: bool,
141 description: Option<String>,
142 format: Option<String>,
143 ) -> Self {
144 Self {
145 name: name.into(),
146 data_type,
147 default_agg,
148 nullable,
149 description,
150 format,
151 }
152 }
153
154 pub fn name(&self) -> &str {
156 &self.name
157 }
158
159 pub fn data_type(&self) -> &DataType {
161 &self.data_type
162 }
163
164 pub fn default_agg(&self) -> AggFunc {
166 self.default_agg
167 }
168
169 pub fn is_nullable(&self) -> bool {
171 self.nullable
172 }
173
174 pub fn description(&self) -> Option<&str> {
176 self.description.as_deref()
177 }
178
179 pub fn format(&self) -> Option<&str> {
181 self.format.as_deref()
182 }
183
184 pub fn set_description(&mut self, description: impl Into<String>) {
186 self.description = Some(description.into());
187 }
188
189 pub fn set_format(&mut self, format: impl Into<String>) {
191 self.format = Some(format.into());
192 }
193
194 pub fn with_nullable(mut self, nullable: bool) -> Self {
196 self.nullable = nullable;
197 self
198 }
199
200 pub fn with_description(mut self, description: impl Into<String>) -> Self {
202 self.description = Some(description.into());
203 self
204 }
205
206 pub fn with_format(mut self, format: impl Into<String>) -> Self {
208 self.format = Some(format.into());
209 self
210 }
211
212 pub fn validate(&self) -> Result<(), String> {
214 if !self.default_agg.is_compatible_with(&self.data_type) {
215 return Err(format!(
216 "Aggregation function {} is not compatible with data type {:?}",
217 self.default_agg, self.data_type
218 ));
219 }
220 Ok(())
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// `Measure::new` stores its arguments and defaults to nullable.
    #[test]
    fn test_measure_creation() {
        let revenue = Measure::new("revenue", DataType::Float64, AggFunc::Sum);

        assert_eq!("revenue", revenue.name());
        assert_eq!(&DataType::Float64, revenue.data_type());
        assert_eq!(AggFunc::Sum, revenue.default_agg());
        assert!(revenue.is_nullable());
    }

    /// `validate` accepts SUM over floats and rejects SUM over strings.
    #[test]
    fn test_measure_validation() {
        let float_sum = Measure::new("amount", DataType::Float64, AggFunc::Sum);
        assert!(float_sum.validate().is_ok());

        let text_sum = Measure::new("category", DataType::Utf8, AggFunc::Sum);
        assert!(text_sum.validate().is_err());
    }

    /// Spot-checks the aggregation / data-type compatibility table.
    #[test]
    fn test_agg_func_compatibility() {
        let cases = [
            (AggFunc::Sum, DataType::Float64, true),
            (AggFunc::Sum, DataType::Int32, true),
            (AggFunc::Sum, DataType::Utf8, false),
            (AggFunc::Count, DataType::Utf8, true),
            (AggFunc::Max, DataType::Utf8, true),
        ];

        for (func, data_type, expected) in &cases {
            assert_eq!(
                func.is_compatible_with(data_type),
                *expected,
                "{:?} with {:?}",
                func,
                data_type
            );
        }
    }

    /// The `with_*` builders chain and each one updates its own field.
    #[test]
    fn test_measure_builder() {
        let base = Measure::new("sales", DataType::Float64, AggFunc::Sum);
        let sales = base
            .with_nullable(false)
            .with_description("Total sales amount")
            .with_format("$,.2f");

        assert_eq!(sales.name(), "sales");
        assert!(!sales.is_nullable());
        assert_eq!(sales.description(), Some("Total sales amount"));
        assert_eq!(sales.format(), Some("$,.2f"));
    }
}