1use serde::{Deserialize, Serialize};
53
54use crate::{
55 compiler::{
56 aggregate_types::{AggregateFunction, HavingOperator, TemporalBucket},
57 fact_table::FactTableMetadata,
58 },
59 db::where_clause::WhereClause,
60 error::{FraiseQLError, Result},
61};
62
/// A client-supplied aggregation query against a single fact table, before validation.
///
/// [`AggregationPlanner::plan`] checks it against the table's [`FactTableMetadata`]
/// and produces an [`AggregationPlan`] that keeps this request unchanged.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AggregationRequest {
    /// Name of the fact table being queried.
    ///
    /// NOTE(review): the planner does not compare this against
    /// `FactTableMetadata::table_name` — confirm callers guarantee they match.
    pub table_name: String,
    /// Optional row filter; carried through but not validated by [`AggregationPlanner`].
    pub where_clause: Option<WhereClause>,
    /// Grouping keys, lowered in order into [`GroupByExpression`]s.
    pub group_by: Vec<GroupBySelection>,
    /// Aggregates to compute, lowered in order into [`AggregateExpression`]s.
    pub aggregates: Vec<AggregateSelection>,
    /// Conditions applied to aggregate results.
    pub having: Vec<HavingCondition>,
    /// Result ordering; carried through but not validated by [`AggregationPlanner`].
    pub order_by: Vec<OrderByClause>,
    /// Optional row limit (presumably rendered as SQL `LIMIT`); not validated here.
    pub limit: Option<u32>,
    /// Optional row offset (presumably rendered as SQL `OFFSET`); not validated here.
    pub offset: Option<u32>,
}
83
/// A single GROUP BY key requested by the client.
///
/// [`AggregationPlanner::plan`] lowers each selection into a [`GroupByExpression`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GroupBySelection {
    /// Group by a path inside the fact table's dimensions column.
    ///
    /// Lowered to [`GroupByExpression::JsonbPath`] over `metadata.dimensions.name`;
    /// the path is not checked against the metadata's declared dimension paths.
    Dimension {
        /// Path within the dimensions column.
        path: String,
        /// Output name of the grouping key.
        alias: String,
    },
    /// Group by a time bucket of a column.
    ///
    /// The planner rejects the request unless `column` is one of the fact
    /// table's denormalized filter columns.
    TemporalBucket {
        /// Name of the denormalized column to bucket.
        column: String,
        /// Bucket granularity.
        bucket: TemporalBucket,
        /// Output name of the grouping key.
        alias: String,
    },
    /// Group by a key read from a calendar column.
    ///
    /// The planner lowers this to [`GroupByExpression::CalendarPath`] using only
    /// `calendar_column`, `json_key` and `alias`; `source_column` and `bucket`
    /// are ignored during planning.
    CalendarDimension {
        /// Column the calendar data presumably derives from — TODO confirm; unused by the planner.
        source_column: String,
        /// Column holding the calendar data.
        calendar_column: String,
        /// Key read from `calendar_column`.
        json_key: String,
        /// Granularity this key presumably represents — TODO confirm; unused by the planner.
        bucket: TemporalBucket,
        /// Output name of the grouping key.
        alias: String,
    },
}
117
118impl GroupBySelection {
119 #[must_use]
121 pub fn alias(&self) -> &str {
122 match self {
123 Self::Dimension { alias, .. }
124 | Self::TemporalBucket { alias, .. }
125 | Self::CalendarDimension { alias, .. } => alias,
126 }
127 }
128}
129
/// A single aggregate requested by the client (in SELECT or inside a HAVING condition).
///
/// [`AggregationPlanner::plan`] validates the referenced field against the fact
/// table metadata and lowers the selection into an [`AggregateExpression`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AggregateSelection {
    /// Row count (presumably `COUNT(*)`); references no column.
    Count {
        /// Output name of the aggregate.
        alias: String,
    },
    /// Count of distinct values of a field.
    ///
    /// The planner requires `field` to be a measure column.
    CountDistinct {
        /// Measure column whose distinct values are counted.
        field: String,
        /// Output name of the aggregate.
        alias: String,
    },
    /// Applies `function` to a measure or field.
    ///
    /// The planner accepts a measure, a dimension path, or a denormalized filter
    /// name. `ArrayAgg`, `JsonAgg`, `JsonbAgg` and `StringAgg` are lowered to
    /// [`AggregateExpression::AdvancedAggregate`]; everything else to
    /// [`AggregateExpression::MeasureAggregate`].
    MeasureAggregate {
        /// Measure, dimension path, or denormalized filter name.
        measure: String,
        /// Aggregate function to apply.
        function: AggregateFunction,
        /// Output name of the aggregate.
        alias: String,
    },
    /// Boolean aggregate over a field.
    ///
    /// The planner requires `field` to be a dimension path or a denormalized filter.
    BoolAggregate {
        /// Dimension path or denormalized filter name.
        field: String,
        /// Boolean aggregate function to apply.
        function: crate::compiler::aggregate_types::BoolAggregateFunction,
        /// Output name of the aggregate.
        alias: String,
    },
}
164
165impl AggregateSelection {
166 #[must_use]
168 pub fn alias(&self) -> &str {
169 match self {
170 Self::Count { alias }
171 | Self::CountDistinct { alias, .. }
172 | Self::MeasureAggregate { alias, .. }
173 | Self::BoolAggregate { alias, .. } => alias,
174 }
175 }
176}
177
/// A client-supplied HAVING condition: `aggregate <operator> value`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HavingCondition {
    /// Aggregate whose result is compared.
    pub aggregate: AggregateSelection,
    /// Comparison operator.
    pub operator: HavingOperator,
    /// Right-hand operand as JSON; its type is not validated by the planner.
    pub value: serde_json::Value,
}
188
/// One sort key of a result ordering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByClause {
    /// Name of the field to sort by (presumably an output alias — not validated by the planner).
    pub field: String,
    /// Sort direction.
    pub direction: OrderDirection,
}
197
/// Sort direction for an [`OrderByClause`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OrderDirection {
    /// Ascending order.
    Asc,
    /// Descending order.
    Desc,
}
206
/// A validated aggregation request, produced by [`AggregationPlanner::plan`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AggregationPlan {
    /// Metadata the request was validated against.
    pub metadata: FactTableMetadata,
    /// The original request, stored unchanged.
    pub request: AggregationRequest,
    /// Lowered GROUP BY keys, in request order.
    pub group_by_expressions: Vec<GroupByExpression>,
    /// Lowered aggregates, in request order.
    pub aggregate_expressions: Vec<AggregateExpression>,
    /// Lowered HAVING conditions, in request order.
    pub having_conditions: Vec<ValidatedHavingCondition>,
}
221
/// A GROUP BY key after planning.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GroupByExpression {
    /// Path inside a JSONB column; produced from [`GroupBySelection::Dimension`].
    JsonbPath {
        /// JSONB column name (the planner uses `metadata.dimensions.name`).
        jsonb_column: String,
        /// Path within `jsonb_column`.
        path: String,
        /// Output name of the grouping key.
        alias: String,
    },
    /// Time bucket of a column; produced from [`GroupBySelection::TemporalBucket`].
    TemporalBucket {
        /// Denormalized column being bucketed.
        column: String,
        /// Bucket granularity.
        bucket: TemporalBucket,
        /// Output name of the grouping key.
        alias: String,
    },
    /// Key read from a calendar column; produced from [`GroupBySelection::CalendarDimension`].
    CalendarPath {
        /// Column holding the calendar data.
        calendar_column: String,
        /// Key read from `calendar_column`.
        json_key: String,
        /// Output name of the grouping key.
        alias: String,
    },
}
253
/// An aggregate after planning.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AggregateExpression {
    /// Row count; produced from [`AggregateSelection::Count`].
    Count {
        /// Output name of the aggregate.
        alias: String,
    },
    /// Distinct count of a measure column.
    CountDistinct {
        /// Measure column whose distinct values are counted.
        column: String,
        /// Output name of the aggregate.
        alias: String,
    },
    /// Standard aggregate (any function other than the array/JSON/string aggregates).
    MeasureAggregate {
        /// Measure, dimension path, or denormalized filter name.
        column: String,
        /// Aggregate function to apply.
        function: AggregateFunction,
        /// Output name of the aggregate.
        alias: String,
    },
    /// Collection aggregate; the planner emits this for `ArrayAgg`, `JsonAgg`,
    /// `JsonbAgg` and `StringAgg`.
    AdvancedAggregate {
        /// Measure, dimension path, or denormalized filter name.
        column: String,
        /// Aggregate function to apply.
        function: AggregateFunction,
        /// Output name of the aggregate.
        alias: String,
        /// Separator; the planner sets `Some(", ")` for `StringAgg` and `None` otherwise.
        delimiter: Option<String>,
        /// Ordering inside the aggregate; the planner always sets `None`.
        order_by: Option<Vec<OrderByClause>>,
    },
    /// Boolean aggregate over a dimension path or denormalized filter.
    BoolAggregate {
        /// Dimension path or denormalized filter name.
        column: String,
        /// Boolean aggregate function to apply.
        function: crate::compiler::aggregate_types::BoolAggregateFunction,
        /// Output name of the aggregate.
        alias: String,
    },
}
301
/// A HAVING condition whose aggregate has been lowered to an [`AggregateExpression`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ValidatedHavingCondition {
    /// Lowered aggregate whose result is compared.
    pub aggregate: AggregateExpression,
    /// Comparison operator, copied from the request.
    pub operator: HavingOperator,
    /// Right-hand operand, copied from the request.
    pub value: serde_json::Value,
}
312
/// Stateless validator that turns an [`AggregationRequest`] into an [`AggregationPlan`].
pub struct AggregationPlanner;
315
316impl AggregationPlanner {
317 pub fn plan(
331 request: AggregationRequest,
332 metadata: FactTableMetadata,
333 ) -> Result<AggregationPlan> {
334 let group_by_expressions = Self::validate_group_by(&request.group_by, &metadata)?;
336
337 let aggregate_expressions = Self::validate_aggregates(&request.aggregates, &metadata)?;
339
340 let having_conditions = Self::validate_having(&request.having, &aggregate_expressions)?;
342
343 Ok(AggregationPlan {
344 metadata,
345 request,
346 group_by_expressions,
347 aggregate_expressions,
348 having_conditions,
349 })
350 }
351
352 fn validate_group_by(
354 selections: &[GroupBySelection],
355 metadata: &FactTableMetadata,
356 ) -> Result<Vec<GroupByExpression>> {
357 let mut expressions = Vec::new();
358
359 for selection in selections {
360 match selection {
361 GroupBySelection::Dimension { path, alias } => {
362 expressions.push(GroupByExpression::JsonbPath {
364 jsonb_column: metadata.dimensions.name.clone(),
365 path: path.clone(),
366 alias: alias.clone(),
367 });
368 },
369 GroupBySelection::TemporalBucket {
370 column,
371 bucket,
372 alias,
373 } => {
374 let filter_exists =
376 metadata.denormalized_filters.iter().any(|f| f.name == *column);
377
378 if !filter_exists {
379 return Err(FraiseQLError::Validation {
380 message: format!(
381 "Column '{}' not found in fact table '{}'",
382 column, metadata.table_name
383 ),
384 path: None,
385 });
386 }
387
388 expressions.push(GroupByExpression::TemporalBucket {
389 column: column.clone(),
390 bucket: *bucket,
391 alias: alias.clone(),
392 });
393 },
394 GroupBySelection::CalendarDimension {
395 calendar_column,
396 json_key,
397 alias,
398 ..
399 } => {
400 expressions.push(GroupByExpression::CalendarPath {
402 calendar_column: calendar_column.clone(),
403 json_key: json_key.clone(),
404 alias: alias.clone(),
405 });
406 },
407 }
408 }
409
410 Ok(expressions)
411 }
412
413 fn validate_aggregates(
415 selections: &[AggregateSelection],
416 metadata: &FactTableMetadata,
417 ) -> Result<Vec<AggregateExpression>> {
418 let mut expressions = Vec::new();
419
420 for selection in selections {
421 match selection {
422 AggregateSelection::Count { alias } => {
423 expressions.push(AggregateExpression::Count {
424 alias: alias.clone(),
425 });
426 },
427 AggregateSelection::CountDistinct { field, alias } => {
428 let measure_exists = metadata.measures.iter().any(|m| m.name == *field);
430
431 if !measure_exists {
432 return Err(FraiseQLError::Validation {
433 message: format!(
434 "Measure '{}' not found in fact table '{}'",
435 field, metadata.table_name
436 ),
437 path: None,
438 });
439 }
440
441 expressions.push(AggregateExpression::CountDistinct {
442 column: field.clone(),
443 alias: alias.clone(),
444 });
445 },
446 AggregateSelection::MeasureAggregate {
447 measure,
448 function,
449 alias,
450 } => {
451 let measure_exists = metadata.measures.iter().any(|m| m.name == *measure);
453 let is_dimension = metadata.dimensions.paths.iter().any(|p| p.name == *measure);
454 let is_filter =
455 metadata.denormalized_filters.iter().any(|f| f.name == *measure);
456
457 if !measure_exists && !is_dimension && !is_filter {
458 return Err(FraiseQLError::Validation {
459 message: format!(
460 "Measure or field '{}' not found in fact table '{}'",
461 measure, metadata.table_name
462 ),
463 path: None,
464 });
465 }
466
467 if matches!(
469 function,
470 AggregateFunction::ArrayAgg
471 | AggregateFunction::JsonAgg
472 | AggregateFunction::JsonbAgg
473 | AggregateFunction::StringAgg
474 ) {
475 expressions.push(AggregateExpression::AdvancedAggregate {
476 column: measure.clone(),
477 function: *function,
478 alias: alias.clone(),
479 delimiter: if *function == AggregateFunction::StringAgg {
480 Some(", ".to_string())
481 } else {
482 None
483 },
484 order_by: None,
485 });
486 } else {
487 expressions.push(AggregateExpression::MeasureAggregate {
488 column: measure.clone(),
489 function: *function,
490 alias: alias.clone(),
491 });
492 }
493 },
494 AggregateSelection::BoolAggregate {
495 field,
496 function,
497 alias,
498 } => {
499 let field_exists = metadata.dimensions.paths.iter().any(|p| p.name == *field)
501 || metadata.denormalized_filters.iter().any(|f| f.name == *field);
502
503 if !field_exists {
504 return Err(FraiseQLError::Validation {
505 message: format!(
506 "Boolean field '{}' not found in fact table '{}'",
507 field, metadata.table_name
508 ),
509 path: None,
510 });
511 }
512
513 expressions.push(AggregateExpression::BoolAggregate {
514 column: field.clone(),
515 function: *function,
516 alias: alias.clone(),
517 });
518 },
519 }
520 }
521
522 Ok(expressions)
523 }
524
525 fn validate_having(
527 conditions: &[HavingCondition],
528 _aggregate_expressions: &[AggregateExpression],
529 ) -> Result<Vec<ValidatedHavingCondition>> {
530 let mut validated = Vec::new();
531
532 for condition in conditions {
533 let aggregate_expr = match &condition.aggregate {
535 AggregateSelection::Count { alias } => AggregateExpression::Count {
536 alias: alias.clone(),
537 },
538 AggregateSelection::CountDistinct { field, alias } => {
539 AggregateExpression::CountDistinct {
540 column: field.clone(),
541 alias: alias.clone(),
542 }
543 },
544 AggregateSelection::MeasureAggregate {
545 measure,
546 function,
547 alias,
548 } => {
549 if matches!(
551 function,
552 AggregateFunction::ArrayAgg
553 | AggregateFunction::JsonAgg
554 | AggregateFunction::JsonbAgg
555 | AggregateFunction::StringAgg
556 ) {
557 AggregateExpression::AdvancedAggregate {
558 column: measure.clone(),
559 function: *function,
560 alias: alias.clone(),
561 delimiter: if *function == AggregateFunction::StringAgg {
562 Some(", ".to_string())
563 } else {
564 None
565 },
566 order_by: None,
567 }
568 } else {
569 AggregateExpression::MeasureAggregate {
570 column: measure.clone(),
571 function: *function,
572 alias: alias.clone(),
573 }
574 }
575 },
576 AggregateSelection::BoolAggregate {
577 field,
578 function,
579 alias,
580 } => AggregateExpression::BoolAggregate {
581 column: field.clone(),
582 function: *function,
583 alias: alias.clone(),
584 },
585 };
586
587 validated.push(ValidatedHavingCondition {
591 aggregate: aggregate_expr,
592 operator: condition.operator,
593 value: condition.value.clone(),
594 });
595 }
596
597 Ok(validated)
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::fact_table::{DimensionColumn, FilterColumn, MeasureColumn, SqlType};

    /// Non-nullable measure column fixture.
    fn measure(name: &str, sql_type: SqlType) -> MeasureColumn {
        MeasureColumn {
            name: name.to_string(),
            sql_type,
            nullable: false,
        }
    }

    /// Indexed denormalized filter column fixture.
    fn filter(name: &str, sql_type: SqlType) -> FilterColumn {
        FilterColumn {
            name: name.to_string(),
            sql_type,
            indexed: true,
        }
    }

    /// `tf_sales` metadata: measures `revenue`/`quantity`, a `dimensions` column
    /// with no declared paths, and filters `customer_id`/`occurred_at`.
    fn create_test_metadata() -> FactTableMetadata {
        FactTableMetadata {
            table_name: "tf_sales".to_string(),
            measures: vec![
                measure("revenue", SqlType::Decimal),
                measure("quantity", SqlType::Int),
            ],
            dimensions: DimensionColumn {
                name: "dimensions".to_string(),
                paths: vec![],
            },
            denormalized_filters: vec![
                filter("customer_id", SqlType::Uuid),
                filter("occurred_at", SqlType::Timestamp),
            ],
            calendar_dimensions: vec![],
        }
    }

    /// A `tf_sales` request with no WHERE, ORDER BY, LIMIT or OFFSET.
    fn sales_request(
        group_by: Vec<GroupBySelection>,
        aggregates: Vec<AggregateSelection>,
        having: Vec<HavingCondition>,
    ) -> AggregationRequest {
        AggregationRequest {
            table_name: "tf_sales".to_string(),
            where_clause: None,
            group_by,
            aggregates,
            having,
            order_by: vec![],
            limit: None,
            offset: None,
        }
    }

    /// `COUNT` aggregate aliased `count`.
    fn count_all() -> AggregateSelection {
        AggregateSelection::Count {
            alias: "count".to_string(),
        }
    }

    /// `SUM(revenue)` aggregate aliased `revenue_sum`.
    fn revenue_sum() -> AggregateSelection {
        AggregateSelection::MeasureAggregate {
            measure: "revenue".to_string(),
            function: AggregateFunction::Sum,
            alias: "revenue_sum".to_string(),
        }
    }

    /// Group by the `category` dimension path.
    fn by_category() -> GroupBySelection {
        GroupBySelection::Dimension {
            path: "category".to_string(),
            alias: "category".to_string(),
        }
    }

    #[test]
    fn test_plan_simple_aggregation() {
        let request = sales_request(vec![], vec![count_all(), revenue_sum()], vec![]);

        let plan = AggregationPlanner::plan(request, create_test_metadata()).unwrap();

        assert_eq!(plan.aggregate_expressions.len(), 2);
        assert!(matches!(plan.aggregate_expressions[0], AggregateExpression::Count { .. }));
        assert!(matches!(
            plan.aggregate_expressions[1],
            AggregateExpression::MeasureAggregate { .. }
        ));
    }

    #[test]
    fn test_plan_with_group_by() {
        let daily = GroupBySelection::TemporalBucket {
            column: "occurred_at".to_string(),
            bucket: TemporalBucket::Day,
            alias: "occurred_at_day".to_string(),
        };
        let request = sales_request(vec![by_category(), daily], vec![count_all()], vec![]);

        let plan = AggregationPlanner::plan(request, create_test_metadata()).unwrap();

        assert_eq!(plan.group_by_expressions.len(), 2);
        assert!(matches!(plan.group_by_expressions[0], GroupByExpression::JsonbPath { .. }));
        assert!(matches!(plan.group_by_expressions[1], GroupByExpression::TemporalBucket { .. }));
    }

    #[test]
    fn test_plan_with_having() {
        let revenue_over_1000 = HavingCondition {
            aggregate: revenue_sum(),
            operator: HavingOperator::Gt,
            value: serde_json::json!(1000),
        };
        let request =
            sales_request(vec![by_category()], vec![revenue_sum()], vec![revenue_over_1000]);

        let plan = AggregationPlanner::plan(request, create_test_metadata()).unwrap();

        assert_eq!(plan.having_conditions.len(), 1);
        assert_eq!(plan.having_conditions[0].operator, HavingOperator::Gt);
    }

    #[test]
    fn test_validate_invalid_measure() {
        let unknown_sum = AggregateSelection::MeasureAggregate {
            measure: "nonexistent".to_string(),
            function: AggregateFunction::Sum,
            alias: "nonexistent_sum".to_string(),
        };
        let request = sales_request(vec![], vec![unknown_sum], vec![]);

        let err = AggregationPlanner::plan(request, create_test_metadata())
            .expect_err("an unknown measure must be rejected");

        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_validate_invalid_temporal_column() {
        let bad_bucket = GroupBySelection::TemporalBucket {
            column: "nonexistent".to_string(),
            bucket: TemporalBucket::Day,
            alias: "day".to_string(),
        };
        let request = sales_request(vec![bad_bucket], vec![count_all()], vec![]);

        let err = AggregationPlanner::plan(request, create_test_metadata())
            .expect_err("an unknown temporal column must be rejected");

        assert!(err.to_string().contains("not found"));
    }
}