1use serde::{Deserialize, Serialize};
53
54pub use crate::types::{OrderByClause, OrderDirection};
55use crate::{
56 compiler::{
57 aggregate_types::{AggregateFunction, HavingOperator, TemporalBucket},
58 fact_table::FactTableMetadata,
59 },
60 db::where_clause::WhereClause,
61 error::{FraiseQLError, Result},
62};
63
64#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
66pub struct AggregationRequest {
67 pub table_name: String,
69 pub where_clause: Option<WhereClause>,
71 pub group_by: Vec<GroupBySelection>,
73 pub aggregates: Vec<AggregateSelection>,
75 pub having: Vec<HavingCondition>,
77 pub order_by: Vec<OrderByClause>,
79 pub limit: Option<u32>,
81 pub offset: Option<u32>,
83}
84
85#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
87#[non_exhaustive]
88pub enum GroupBySelection {
89 Dimension {
91 path: String,
93 alias: String,
95 },
96 TemporalBucket {
98 column: String,
100 bucket: TemporalBucket,
102 alias: String,
104 },
105 CalendarDimension {
107 source_column: String,
109 calendar_column: String,
111 json_key: String,
113 bucket: TemporalBucket,
115 alias: String,
117 },
118}
119
120impl GroupBySelection {
121 #[must_use]
123 pub fn alias(&self) -> &str {
124 match self {
125 Self::Dimension { alias, .. }
126 | Self::TemporalBucket { alias, .. }
127 | Self::CalendarDimension { alias, .. } => alias,
128 }
129 }
130}
131
132#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
134#[non_exhaustive]
135pub enum AggregateSelection {
136 Count {
138 alias: String,
140 },
141 CountDistinct {
143 field: String,
145 alias: String,
147 },
148 MeasureAggregate {
150 measure: String,
152 function: AggregateFunction,
154 alias: String,
156 },
157 BoolAggregate {
159 field: String,
161 function: crate::compiler::aggregate_types::BoolAggregateFunction,
163 alias: String,
165 },
166}
167
168impl AggregateSelection {
169 #[must_use]
171 pub fn alias(&self) -> &str {
172 match self {
173 Self::Count { alias }
174 | Self::CountDistinct { alias, .. }
175 | Self::MeasureAggregate { alias, .. }
176 | Self::BoolAggregate { alias, .. } => alias,
177 }
178 }
179}
180
181#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
183pub struct HavingCondition {
184 pub aggregate: AggregateSelection,
186 pub operator: HavingOperator,
188 pub value: serde_json::Value,
190}
191
192#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
194pub struct AggregationPlan {
195 pub metadata: FactTableMetadata,
197 pub request: AggregationRequest,
199 pub group_by_expressions: Vec<GroupByExpression>,
201 pub aggregate_expressions: Vec<AggregateExpression>,
203 pub having_conditions: Vec<ValidatedHavingCondition>,
205}
206
207#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
209#[non_exhaustive]
210pub enum GroupByExpression {
211 JsonbPath {
213 jsonb_column: String,
215 path: String,
217 alias: String,
219 },
220 TemporalBucket {
222 column: String,
224 bucket: TemporalBucket,
226 alias: String,
228 },
229 CalendarPath {
231 calendar_column: String,
233 json_key: String,
235 alias: String,
237 },
238}
239
240#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
242#[non_exhaustive]
243pub enum AggregateExpression {
244 Count {
246 alias: String,
248 },
249 CountDistinct {
251 column: String,
253 alias: String,
255 },
256 MeasureAggregate {
258 column: String,
260 function: AggregateFunction,
262 alias: String,
264 },
265 AdvancedAggregate {
267 column: String,
269 function: AggregateFunction,
271 alias: String,
273 delimiter: Option<String>,
275 order_by: Option<Vec<OrderByClause>>,
277 },
278 BoolAggregate {
280 column: String,
282 function: crate::compiler::aggregate_types::BoolAggregateFunction,
284 alias: String,
286 },
287}
288
289#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
291pub struct ValidatedHavingCondition {
292 pub aggregate: AggregateExpression,
294 pub operator: HavingOperator,
296 pub value: serde_json::Value,
298}
299
/// Stateless entry point that turns an [`AggregationRequest`] into an
/// [`AggregationPlan`]; all functionality lives in associated functions.
pub struct AggregationPlanner;
302
303impl AggregationPlanner {
304 pub fn plan(
318 request: AggregationRequest,
319 metadata: FactTableMetadata,
320 ) -> Result<AggregationPlan> {
321 let group_by_expressions = Self::validate_group_by(&request.group_by, &metadata)?;
323
324 let aggregate_expressions = Self::validate_aggregates(&request.aggregates, &metadata)?;
326
327 let having_conditions = Self::validate_having(&request.having, &aggregate_expressions)?;
329
330 Ok(AggregationPlan {
331 metadata,
332 request,
333 group_by_expressions,
334 aggregate_expressions,
335 having_conditions,
336 })
337 }
338
339 fn validate_group_by(
341 selections: &[GroupBySelection],
342 metadata: &FactTableMetadata,
343 ) -> Result<Vec<GroupByExpression>> {
344 let mut expressions = Vec::new();
345
346 for selection in selections {
347 match selection {
348 GroupBySelection::Dimension { path, alias } => {
349 let known_paths = &metadata.dimensions.paths;
354 if !known_paths.is_empty() && !known_paths.iter().any(|p| p.name == *path) {
355 return Err(FraiseQLError::Validation {
356 message: format!(
357 "Dimension '{}' not found in fact table '{}'",
358 path, metadata.table_name
359 ),
360 path: None,
361 });
362 }
363 expressions.push(GroupByExpression::JsonbPath {
364 jsonb_column: metadata.dimensions.name.clone(),
365 path: path.clone(),
366 alias: alias.clone(),
367 });
368 },
369 GroupBySelection::TemporalBucket {
370 column,
371 bucket,
372 alias,
373 } => {
374 let filter_exists =
376 metadata.denormalized_filters.iter().any(|f| f.name == *column);
377
378 if !filter_exists {
379 return Err(FraiseQLError::Validation {
380 message: format!(
381 "Column '{}' not found in fact table '{}'",
382 column, metadata.table_name
383 ),
384 path: None,
385 });
386 }
387
388 expressions.push(GroupByExpression::TemporalBucket {
389 column: column.clone(),
390 bucket: *bucket,
391 alias: alias.clone(),
392 });
393 },
394 GroupBySelection::CalendarDimension {
395 calendar_column,
396 json_key,
397 alias,
398 ..
399 } => {
400 expressions.push(GroupByExpression::CalendarPath {
402 calendar_column: calendar_column.clone(),
403 json_key: json_key.clone(),
404 alias: alias.clone(),
405 });
406 },
407 }
408 }
409
410 Ok(expressions)
411 }
412
413 fn validate_aggregates(
415 selections: &[AggregateSelection],
416 metadata: &FactTableMetadata,
417 ) -> Result<Vec<AggregateExpression>> {
418 let mut expressions = Vec::new();
419
420 for selection in selections {
421 match selection {
422 AggregateSelection::Count { alias } => {
423 expressions.push(AggregateExpression::Count {
424 alias: alias.clone(),
425 });
426 },
427 AggregateSelection::CountDistinct { field, alias } => {
428 let measure_exists = metadata.measures.iter().any(|m| m.name == *field);
430
431 if !measure_exists {
432 return Err(FraiseQLError::Validation {
433 message: format!(
434 "Measure '{}' not found in fact table '{}'",
435 field, metadata.table_name
436 ),
437 path: None,
438 });
439 }
440
441 expressions.push(AggregateExpression::CountDistinct {
442 column: field.clone(),
443 alias: alias.clone(),
444 });
445 },
446 AggregateSelection::MeasureAggregate {
447 measure,
448 function,
449 alias,
450 } => {
451 let measure_exists = metadata.measures.iter().any(|m| m.name == *measure);
453 let is_dimension = metadata.dimensions.paths.iter().any(|p| p.name == *measure);
454 let is_filter =
455 metadata.denormalized_filters.iter().any(|f| f.name == *measure);
456
457 if !measure_exists && !is_dimension && !is_filter {
458 return Err(FraiseQLError::Validation {
459 message: format!(
460 "Measure or field '{}' not found in fact table '{}'",
461 measure, metadata.table_name
462 ),
463 path: None,
464 });
465 }
466
467 if matches!(
469 function,
470 AggregateFunction::ArrayAgg
471 | AggregateFunction::JsonAgg
472 | AggregateFunction::JsonbAgg
473 | AggregateFunction::StringAgg
474 ) {
475 expressions.push(AggregateExpression::AdvancedAggregate {
476 column: measure.clone(),
477 function: *function,
478 alias: alias.clone(),
479 delimiter: if *function == AggregateFunction::StringAgg {
480 Some(", ".to_string())
481 } else {
482 None
483 },
484 order_by: None,
485 });
486 } else {
487 expressions.push(AggregateExpression::MeasureAggregate {
488 column: measure.clone(),
489 function: *function,
490 alias: alias.clone(),
491 });
492 }
493 },
494 AggregateSelection::BoolAggregate {
495 field,
496 function,
497 alias,
498 } => {
499 let field_exists = metadata.dimensions.paths.iter().any(|p| p.name == *field)
501 || metadata.denormalized_filters.iter().any(|f| f.name == *field);
502
503 if !field_exists {
504 return Err(FraiseQLError::Validation {
505 message: format!(
506 "Boolean field '{}' not found in fact table '{}'",
507 field, metadata.table_name
508 ),
509 path: None,
510 });
511 }
512
513 expressions.push(AggregateExpression::BoolAggregate {
514 column: field.clone(),
515 function: *function,
516 alias: alias.clone(),
517 });
518 },
519 }
520 }
521
522 Ok(expressions)
523 }
524
525 fn validate_having(
527 conditions: &[HavingCondition],
528 _aggregate_expressions: &[AggregateExpression],
529 ) -> Result<Vec<ValidatedHavingCondition>> {
530 let mut validated = Vec::new();
531
532 for condition in conditions {
533 let aggregate_expr = match &condition.aggregate {
535 AggregateSelection::Count { alias } => AggregateExpression::Count {
536 alias: alias.clone(),
537 },
538 AggregateSelection::CountDistinct { field, alias } => {
539 AggregateExpression::CountDistinct {
540 column: field.clone(),
541 alias: alias.clone(),
542 }
543 },
544 AggregateSelection::MeasureAggregate {
545 measure,
546 function,
547 alias,
548 } => {
549 if matches!(
551 function,
552 AggregateFunction::ArrayAgg
553 | AggregateFunction::JsonAgg
554 | AggregateFunction::JsonbAgg
555 | AggregateFunction::StringAgg
556 ) {
557 AggregateExpression::AdvancedAggregate {
558 column: measure.clone(),
559 function: *function,
560 alias: alias.clone(),
561 delimiter: if *function == AggregateFunction::StringAgg {
562 Some(", ".to_string())
563 } else {
564 None
565 },
566 order_by: None,
567 }
568 } else {
569 AggregateExpression::MeasureAggregate {
570 column: measure.clone(),
571 function: *function,
572 alias: alias.clone(),
573 }
574 }
575 },
576 AggregateSelection::BoolAggregate {
577 field,
578 function,
579 alias,
580 } => AggregateExpression::BoolAggregate {
581 column: field.clone(),
582 function: *function,
583 alias: alias.clone(),
584 },
585 };
586
587 validated.push(ValidatedHavingCondition {
591 aggregate: aggregate_expr,
592 operator: condition.operator,
593 value: condition.value.clone(),
594 });
595 }
596
597 Ok(validated)
598 }
599}
600
#[cfg(test)]
mod tests {
    // Test code unwraps freely: a panic is the failure signal we want here.
    #![allow(clippy::unwrap_used)]

    use super::*;
    use crate::compiler::fact_table::{DimensionColumn, FilterColumn, MeasureColumn, SqlType};

    /// Metadata for a `tf_sales` fact table with two measures, two
    /// denormalized filter columns, and no declared dimension paths.
    fn create_test_metadata() -> FactTableMetadata {
        FactTableMetadata {
            table_name: "tf_sales".to_string(),
            measures: vec![
                MeasureColumn {
                    name: "revenue".to_string(),
                    sql_type: SqlType::Decimal,
                    nullable: false,
                },
                MeasureColumn {
                    name: "quantity".to_string(),
                    sql_type: SqlType::Int,
                    nullable: false,
                },
            ],
            dimensions: DimensionColumn {
                name: "dimensions".to_string(),
                paths: vec![],
            },
            denormalized_filters: vec![
                FilterColumn {
                    name: "customer_id".to_string(),
                    sql_type: SqlType::Uuid,
                    indexed: true,
                },
                FilterColumn {
                    name: "occurred_at".to_string(),
                    sql_type: SqlType::Timestamp,
                    indexed: true,
                },
            ],
            calendar_dimensions: vec![],
        }
    }

    /// Same as [`create_test_metadata`] but with `category` declared as the
    /// only dimension path, which switches the dimension allowlist on.
    fn create_metadata_with_paths() -> FactTableMetadata {
        use crate::compiler::fact_table::DimensionPath;

        let mut meta = create_test_metadata();
        meta.dimensions.paths = vec![DimensionPath {
            name: "category".to_string(),
            json_path: "dimensions->>'category'".to_string(),
            data_type: "text".to_string(),
        }];
        meta
    }

    /// Request against `tf_sales` with the given selections and no filter,
    /// ordering, limit, or offset.
    fn sales_request(
        group_by: Vec<GroupBySelection>,
        aggregates: Vec<AggregateSelection>,
        having: Vec<HavingCondition>,
    ) -> AggregationRequest {
        AggregationRequest {
            table_name: "tf_sales".to_string(),
            where_clause: None,
            group_by,
            aggregates,
            having,
            order_by: vec![],
            limit: None,
            offset: None,
        }
    }

    /// `Count` aggregate aliased `count`.
    fn count_all() -> AggregateSelection {
        AggregateSelection::Count {
            alias: "count".to_string(),
        }
    }

    /// `Sum` over the `revenue` measure aliased `revenue_sum`.
    fn sum_of_revenue() -> AggregateSelection {
        AggregateSelection::MeasureAggregate {
            measure: "revenue".to_string(),
            function: AggregateFunction::Sum,
            alias: "revenue_sum".to_string(),
        }
    }

    /// Dimension grouping on `path` aliased `alias`.
    fn by_dimension(path: &str, alias: &str) -> GroupBySelection {
        GroupBySelection::Dimension {
            path: path.to_string(),
            alias: alias.to_string(),
        }
    }

    /// Asserts that planning failed with a validation error mentioning
    /// "not found"; `what` completes the failure message.
    #[track_caller]
    fn assert_not_found(result: &Result<AggregationPlan>, what: &str) {
        assert!(
            matches!(result, Err(FraiseQLError::Validation { message, .. }) if message.contains("not found")),
            "expected Validation error about {what}, got: {result:?}"
        );
    }

    /// Asserts that parsing `json` as an ORDER BY input fails with a
    /// validation error; `reason` completes the failure message.
    #[track_caller]
    fn assert_order_by_rejected(json: &serde_json::Value, reason: &str) {
        let result = OrderByClause::from_graphql_json(json);
        assert!(
            matches!(result, Err(FraiseQLError::Validation { .. })),
            "expected Validation error for {reason}, got: {result:?}"
        );
    }

    #[test]
    fn test_plan_simple_aggregation() {
        let request = sales_request(vec![], vec![count_all(), sum_of_revenue()], vec![]);

        let plan = AggregationPlanner::plan(request, create_test_metadata()).unwrap();

        assert_eq!(plan.aggregate_expressions.len(), 2);
        assert!(matches!(plan.aggregate_expressions[0], AggregateExpression::Count { .. }));
        assert!(matches!(
            plan.aggregate_expressions[1],
            AggregateExpression::MeasureAggregate { .. }
        ));
    }

    #[test]
    fn test_plan_with_group_by() {
        let request = sales_request(
            vec![
                by_dimension("category", "category"),
                GroupBySelection::TemporalBucket {
                    column: "occurred_at".to_string(),
                    bucket: TemporalBucket::Day,
                    alias: "occurred_at_day".to_string(),
                },
            ],
            vec![count_all()],
            vec![],
        );

        let plan = AggregationPlanner::plan(request, create_test_metadata()).unwrap();

        assert_eq!(plan.group_by_expressions.len(), 2);
        assert!(matches!(plan.group_by_expressions[0], GroupByExpression::JsonbPath { .. }));
        assert!(matches!(plan.group_by_expressions[1], GroupByExpression::TemporalBucket { .. }));
    }

    #[test]
    fn test_plan_with_having() {
        let request = sales_request(
            vec![by_dimension("category", "category")],
            vec![sum_of_revenue()],
            vec![HavingCondition {
                aggregate: sum_of_revenue(),
                operator: HavingOperator::Gt,
                value: serde_json::json!(1000),
            }],
        );

        let plan = AggregationPlanner::plan(request, create_test_metadata()).unwrap();

        assert_eq!(plan.having_conditions.len(), 1);
        assert_eq!(plan.having_conditions[0].operator, HavingOperator::Gt);
    }

    #[test]
    fn test_validate_invalid_measure() {
        let request = sales_request(
            vec![],
            vec![AggregateSelection::MeasureAggregate {
                measure: "nonexistent".to_string(),
                function: AggregateFunction::Sum,
                alias: "nonexistent_sum".to_string(),
            }],
            vec![],
        );

        let result = AggregationPlanner::plan(request, create_test_metadata());
        assert_not_found(&result, "measure not found");
    }

    #[test]
    fn test_validate_invalid_temporal_column() {
        let request = sales_request(
            vec![GroupBySelection::TemporalBucket {
                column: "nonexistent".to_string(),
                bucket: TemporalBucket::Day,
                alias: "day".to_string(),
            }],
            vec![count_all()],
            vec![],
        );

        let result = AggregationPlanner::plan(request, create_test_metadata());
        assert_not_found(&result, "column not found");
    }

    #[test]
    fn test_order_by_from_graphql_json_object_format() {
        let json = serde_json::json!({ "name": "DESC", "created_at": "ASC" });
        let clauses = OrderByClause::from_graphql_json(&json).unwrap();
        assert_eq!(clauses.len(), 2);
        assert!(clauses.iter().any(|c| c.field == "name" && c.direction == OrderDirection::Desc));
        assert!(
            clauses
                .iter()
                .any(|c| c.field == "created_at" && c.direction == OrderDirection::Asc)
        );
    }

    #[test]
    fn test_order_by_from_graphql_json_array_format() {
        let json = serde_json::json!([
            { "field": "name", "direction": "DESC" },
            { "field": "age" }
        ]);
        let clauses = OrderByClause::from_graphql_json(&json).unwrap();
        assert_eq!(clauses.len(), 2);
        assert_eq!(clauses[0].field, "name");
        assert_eq!(clauses[0].direction, OrderDirection::Desc);
        assert_eq!(clauses[1].field, "age");
        // The second entry has no "direction" key, so this is the fallback.
        assert_eq!(clauses[1].direction, OrderDirection::Asc);
    }

    #[test]
    fn test_order_by_from_graphql_json_invalid_direction() {
        assert_order_by_rejected(&serde_json::json!({ "name": "INVALID" }), "invalid direction");
    }

    #[test]
    fn test_order_by_rejects_sql_injection_in_field() {
        assert_order_by_rejected(
            &serde_json::json!({ "x' || pg_sleep(5) || '": "ASC" }),
            "SQL injection in field",
        );
    }

    #[test]
    fn test_order_by_rejects_field_with_dot() {
        assert_order_by_rejected(&serde_json::json!({ "a.b": "ASC" }), "dot in field name");
    }

    #[test]
    fn test_order_by_rejects_empty_field() {
        assert_order_by_rejected(&serde_json::json!({ "": "ASC" }), "empty field name");
    }

    #[test]
    fn test_order_by_accepts_valid_identifiers() {
        let json = serde_json::json!({ "created_at": "DESC", "_score": "ASC" });
        let clauses = OrderByClause::from_graphql_json(&json).unwrap();
        assert_eq!(clauses.len(), 2);
    }

    #[test]
    fn test_order_by_array_rejects_injection_field() {
        assert_order_by_rejected(
            &serde_json::json!([{ "field": "x' OR '1'='1", "direction": "ASC" }]),
            "SQL injection in array field",
        );
    }

    #[test]
    fn test_dimension_allowlist_accepts_declared_path() {
        let request = sales_request(
            vec![by_dimension("category", "category")],
            vec![count_all()],
            vec![],
        );

        AggregationPlanner::plan(request, create_metadata_with_paths())
            .unwrap_or_else(|e| panic!("declared dimension path should be accepted: {e}"));
    }

    #[test]
    fn test_dimension_allowlist_rejects_unknown_path() {
        let request = sales_request(
            vec![by_dimension("undeclared_path", "x")],
            vec![count_all()],
            vec![],
        );

        let result = AggregationPlanner::plan(request, create_metadata_with_paths());
        assert_not_found(&result, "undeclared dimension path");
    }

    #[test]
    fn test_dimension_allowlist_accepts_any_path_when_paths_empty() {
        // `create_test_metadata` declares no dimension paths, so the
        // allowlist is inactive.
        let request = sales_request(
            vec![by_dimension("any_undeclared_path", "x")],
            vec![count_all()],
            vec![],
        );

        AggregationPlanner::plan(request, create_test_metadata())
            .unwrap_or_else(|e| panic!("any path should be accepted when paths empty: {e}"));
    }
}