Skip to main content

fraiseql_core/compiler/
aggregation.rs

1//! Aggregation Execution Plan Module
2//!
3//! This module generates execution plans for GROUP BY queries with aggregations.
4//!
5//! # Execution Plan Flow
6//!
7//! ```text
8//! GraphQL Query
9//!      ↓
10//! AggregationRequest (parsed)
11//!      ↓
12//! AggregationPlan (validated, optimized)
13//!      ↓
14//! SQL Generation (database-specific)
15//!      ↓
16//! Query Execution
17//! ```
18//!
19//! # Example
20//!
21//! ```graphql
22//! query {
23//!   sales_aggregate(
24//!     where: { customer_id: { _eq: "uuid-123" } }
25//!     groupBy: { category: true, occurred_at_day: true }
26//!     having: { revenue_sum_gt: 1000 }
27//!   ) {
28//!     category
29//!     occurred_at_day
30//!     count
31//!     revenue_sum
32//!     revenue_avg
33//!   }
34//! }
35//! ```
36//!
37//! Generates:
38//!
39//! ```sql
40//! SELECT
41//!   data->>'category' AS category,
42//!   DATE_TRUNC('day', occurred_at) AS occurred_at_day,
43//!   COUNT(*) AS count,
44//!   SUM(revenue) AS revenue_sum,
45//!   AVG(revenue) AS revenue_avg
46//! FROM tf_sales
47//! WHERE customer_id = $1
48//! GROUP BY data->>'category', DATE_TRUNC('day', occurred_at)
49//! HAVING SUM(revenue) > $2
50//! ```
51
52use serde::{Deserialize, Serialize};
53
54use crate::{
55    compiler::{
56        aggregate_types::{AggregateFunction, HavingOperator, TemporalBucket},
57        fact_table::FactTableMetadata,
58    },
59    db::where_clause::WhereClause,
60    error::{FraiseQLError, Result},
61};
62
63/// Aggregation request from GraphQL query
64#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
65pub struct AggregationRequest {
66    /// Fact table name
67    pub table_name:   String,
68    /// WHERE clause filters (applied before GROUP BY)
69    pub where_clause: Option<WhereClause>,
70    /// GROUP BY selections
71    pub group_by:     Vec<GroupBySelection>,
72    /// Aggregate selections (what to compute)
73    pub aggregates:   Vec<AggregateSelection>,
74    /// HAVING clause filters (applied after GROUP BY)
75    pub having:       Vec<HavingCondition>,
76    /// ORDER BY clauses
77    pub order_by:     Vec<OrderByClause>,
78    /// LIMIT
79    pub limit:        Option<u32>,
80    /// OFFSET
81    pub offset:       Option<u32>,
82}
83
84/// GROUP BY selection
85#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub enum GroupBySelection {
87    /// Group by JSONB dimension
88    Dimension {
89        /// JSONB path (e.g., "category")
90        path:  String,
91        /// Alias for result
92        alias: String,
93    },
94    /// Group by temporal bucket
95    TemporalBucket {
96        /// Column name (e.g., "occurred_at")
97        column: String,
98        /// Bucket type
99        bucket: TemporalBucket,
100        /// Alias for result
101        alias:  String,
102    },
103    /// Group by pre-computed calendar dimension
104    CalendarDimension {
105        /// Source timestamp column (e.g., "occurred_at")
106        source_column:   String,
107        /// Calendar JSONB column (e.g., "date_info")
108        calendar_column: String,
109        /// JSON key within calendar column (e.g., "month")
110        json_key:        String,
111        /// Temporal bucket type
112        bucket:          TemporalBucket,
113        /// Alias for result
114        alias:           String,
115    },
116}
117
118impl GroupBySelection {
119    /// Get the result alias for this selection
120    #[must_use]
121    pub fn alias(&self) -> &str {
122        match self {
123            Self::Dimension { alias, .. }
124            | Self::TemporalBucket { alias, .. }
125            | Self::CalendarDimension { alias, .. } => alias,
126        }
127    }
128}
129
130/// Aggregate selection (what to compute)
131#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
132pub enum AggregateSelection {
133    /// COUNT(*)
134    Count {
135        /// Alias for result
136        alias: String,
137    },
138    /// COUNT(DISTINCT field)
139    CountDistinct {
140        /// Field to count
141        field: String,
142        /// Alias for result
143        alias: String,
144    },
145    /// Aggregate function on a measure
146    MeasureAggregate {
147        /// Measure column name
148        measure:  String,
149        /// Aggregate function
150        function: AggregateFunction,
151        /// Alias for result
152        alias:    String,
153    },
154    /// Boolean aggregate
155    BoolAggregate {
156        /// Field to aggregate
157        field:    String,
158        /// Boolean aggregate function
159        function: crate::compiler::aggregate_types::BoolAggregateFunction,
160        /// Alias for result
161        alias:    String,
162    },
163}
164
165impl AggregateSelection {
166    /// Get the result alias for this selection
167    #[must_use]
168    pub fn alias(&self) -> &str {
169        match self {
170            Self::Count { alias }
171            | Self::CountDistinct { alias, .. }
172            | Self::MeasureAggregate { alias, .. }
173            | Self::BoolAggregate { alias, .. } => alias,
174        }
175    }
176}
177
178/// HAVING condition (post-aggregation filter)
179#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
180pub struct HavingCondition {
181    /// Aggregate to filter on
182    pub aggregate: AggregateSelection,
183    /// Comparison operator
184    pub operator:  HavingOperator,
185    /// Value to compare against
186    pub value:     serde_json::Value,
187}
188
189/// ORDER BY clause
190#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
191pub struct OrderByClause {
192    /// Field to order by (can be dimension, aggregate, or temporal bucket)
193    pub field:     String,
194    /// Sort direction
195    pub direction: OrderDirection,
196}
197
198/// Sort direction
199#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
200pub enum OrderDirection {
201    /// Ascending (A-Z, 0-9)
202    Asc,
203    /// Descending (Z-A, 9-0)
204    Desc,
205}
206
207/// Validated and optimized aggregation execution plan
208#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
209pub struct AggregationPlan {
210    /// Fact table metadata
211    pub metadata:              FactTableMetadata,
212    /// Original request
213    pub request:               AggregationRequest,
214    /// Validated GROUP BY expressions
215    pub group_by_expressions:  Vec<GroupByExpression>,
216    /// Validated aggregate expressions
217    pub aggregate_expressions: Vec<AggregateExpression>,
218    /// Validated HAVING conditions
219    pub having_conditions:     Vec<ValidatedHavingCondition>,
220}
221
222/// Validated GROUP BY expression
223#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub enum GroupByExpression {
225    /// JSONB dimension extraction
226    JsonbPath {
227        /// JSONB column name (usually "data")
228        jsonb_column: String,
229        /// Path to extract (e.g., "category")
230        path:         String,
231        /// Result alias
232        alias:        String,
233    },
234    /// Temporal bucket with DATE_TRUNC
235    TemporalBucket {
236        /// Timestamp column name
237        column: String,
238        /// Bucket type
239        bucket: TemporalBucket,
240        /// Result alias
241        alias:  String,
242    },
243    /// Pre-computed calendar dimension extraction
244    CalendarPath {
245        /// Calendar JSONB column (e.g., "date_info")
246        calendar_column: String,
247        /// JSON key within calendar column (e.g., "month")
248        json_key:        String,
249        /// Result alias
250        alias:           String,
251    },
252}
253
254/// Validated aggregate expression
255#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
256pub enum AggregateExpression {
257    /// COUNT(*)
258    Count {
259        /// Result alias
260        alias: String,
261    },
262    /// COUNT(DISTINCT field)
263    CountDistinct {
264        /// Column to count
265        column: String,
266        /// Result alias
267        alias:  String,
268    },
269    /// Aggregate function on measure column
270    MeasureAggregate {
271        /// Measure column name
272        column:   String,
273        /// Aggregate function
274        function: AggregateFunction,
275        /// Result alias
276        alias:    String,
277    },
278    /// Advanced aggregate with optional parameters
279    AdvancedAggregate {
280        /// Column to aggregate
281        column:    String,
282        /// Aggregate function
283        function:  AggregateFunction,
284        /// Result alias
285        alias:     String,
286        /// Optional delimiter for STRING_AGG
287        delimiter: Option<String>,
288        /// Optional ORDER BY for ARRAY_AGG/STRING_AGG
289        order_by:  Option<Vec<OrderByClause>>,
290    },
291    /// Boolean aggregate (BOOL_AND/BOOL_OR)
292    BoolAggregate {
293        /// Column to aggregate (boolean expression)
294        column:   String,
295        /// Boolean aggregate function
296        function: crate::compiler::aggregate_types::BoolAggregateFunction,
297        /// Result alias
298        alias:    String,
299    },
300}
301
302/// Validated HAVING condition
303#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
304pub struct ValidatedHavingCondition {
305    /// Aggregate expression to filter on
306    pub aggregate: AggregateExpression,
307    /// Comparison operator
308    pub operator:  HavingOperator,
309    /// Value to compare against
310    pub value:     serde_json::Value,
311}
312
313/// Aggregation plan generator
314pub struct AggregationPlanner;
315
316impl AggregationPlanner {
317    /// Generate execution plan from request
318    ///
319    /// # Arguments
320    ///
321    /// * `request` - Aggregation request from GraphQL
322    /// * `metadata` - Fact table metadata
323    ///
324    /// # Errors
325    ///
326    /// Returns error if:
327    /// - Request references non-existent measures or dimensions
328    /// - GROUP BY selections are invalid
329    /// - HAVING conditions reference non-computed aggregates
330    pub fn plan(
331        request: AggregationRequest,
332        metadata: FactTableMetadata,
333    ) -> Result<AggregationPlan> {
334        // Validate and convert GROUP BY selections
335        let group_by_expressions = Self::validate_group_by(&request.group_by, &metadata)?;
336
337        // Validate and convert aggregate selections
338        let aggregate_expressions = Self::validate_aggregates(&request.aggregates, &metadata)?;
339
340        // Validate HAVING conditions
341        let having_conditions = Self::validate_having(&request.having, &aggregate_expressions)?;
342
343        Ok(AggregationPlan {
344            metadata,
345            request,
346            group_by_expressions,
347            aggregate_expressions,
348            having_conditions,
349        })
350    }
351
352    /// Validate GROUP BY selections
353    fn validate_group_by(
354        selections: &[GroupBySelection],
355        metadata: &FactTableMetadata,
356    ) -> Result<Vec<GroupByExpression>> {
357        let mut expressions = Vec::new();
358
359        for selection in selections {
360            match selection {
361                GroupBySelection::Dimension { path, alias } => {
362                    // Validate dimension exists in metadata (for now, just accept any path)
363                    expressions.push(GroupByExpression::JsonbPath {
364                        jsonb_column: metadata.dimensions.name.clone(),
365                        path:         path.clone(),
366                        alias:        alias.clone(),
367                    });
368                },
369                GroupBySelection::TemporalBucket {
370                    column,
371                    bucket,
372                    alias,
373                } => {
374                    // Validate column exists in denormalized filters
375                    let filter_exists =
376                        metadata.denormalized_filters.iter().any(|f| f.name == *column);
377
378                    if !filter_exists {
379                        return Err(FraiseQLError::Validation {
380                            message: format!(
381                                "Column '{}' not found in fact table '{}'",
382                                column, metadata.table_name
383                            ),
384                            path:    None,
385                        });
386                    }
387
388                    expressions.push(GroupByExpression::TemporalBucket {
389                        column: column.clone(),
390                        bucket: *bucket,
391                        alias:  alias.clone(),
392                    });
393                },
394                GroupBySelection::CalendarDimension {
395                    calendar_column,
396                    json_key,
397                    alias,
398                    ..
399                } => {
400                    // Calendar dimension - use pre-computed JSONB field
401                    expressions.push(GroupByExpression::CalendarPath {
402                        calendar_column: calendar_column.clone(),
403                        json_key:        json_key.clone(),
404                        alias:           alias.clone(),
405                    });
406                },
407            }
408        }
409
410        Ok(expressions)
411    }
412
413    /// Validate aggregate selections
414    fn validate_aggregates(
415        selections: &[AggregateSelection],
416        metadata: &FactTableMetadata,
417    ) -> Result<Vec<AggregateExpression>> {
418        let mut expressions = Vec::new();
419
420        for selection in selections {
421            match selection {
422                AggregateSelection::Count { alias } => {
423                    expressions.push(AggregateExpression::Count {
424                        alias: alias.clone(),
425                    });
426                },
427                AggregateSelection::CountDistinct { field, alias } => {
428                    // Validate field is a measure
429                    let measure_exists = metadata.measures.iter().any(|m| m.name == *field);
430
431                    if !measure_exists {
432                        return Err(FraiseQLError::Validation {
433                            message: format!(
434                                "Measure '{}' not found in fact table '{}'",
435                                field, metadata.table_name
436                            ),
437                            path:    None,
438                        });
439                    }
440
441                    expressions.push(AggregateExpression::CountDistinct {
442                        column: field.clone(),
443                        alias:  alias.clone(),
444                    });
445                },
446                AggregateSelection::MeasureAggregate {
447                    measure,
448                    function,
449                    alias,
450                } => {
451                    // Validate measure exists (or is a dimension path for advanced aggregates)
452                    let measure_exists = metadata.measures.iter().any(|m| m.name == *measure);
453                    let is_dimension = metadata.dimensions.paths.iter().any(|p| p.name == *measure);
454                    let is_filter =
455                        metadata.denormalized_filters.iter().any(|f| f.name == *measure);
456
457                    if !measure_exists && !is_dimension && !is_filter {
458                        return Err(FraiseQLError::Validation {
459                            message: format!(
460                                "Measure or field '{}' not found in fact table '{}'",
461                                measure, metadata.table_name
462                            ),
463                            path:    None,
464                        });
465                    }
466
467                    // For advanced aggregates, create AdvancedAggregate variant
468                    if matches!(
469                        function,
470                        AggregateFunction::ArrayAgg
471                            | AggregateFunction::JsonAgg
472                            | AggregateFunction::JsonbAgg
473                            | AggregateFunction::StringAgg
474                    ) {
475                        expressions.push(AggregateExpression::AdvancedAggregate {
476                            column:    measure.clone(),
477                            function:  *function,
478                            alias:     alias.clone(),
479                            delimiter: if *function == AggregateFunction::StringAgg {
480                                Some(", ".to_string())
481                            } else {
482                                None
483                            },
484                            order_by:  None,
485                        });
486                    } else {
487                        expressions.push(AggregateExpression::MeasureAggregate {
488                            column:   measure.clone(),
489                            function: *function,
490                            alias:    alias.clone(),
491                        });
492                    }
493                },
494                AggregateSelection::BoolAggregate {
495                    field,
496                    function,
497                    alias,
498                } => {
499                    // Validate field exists
500                    let field_exists = metadata.dimensions.paths.iter().any(|p| p.name == *field)
501                        || metadata.denormalized_filters.iter().any(|f| f.name == *field);
502
503                    if !field_exists {
504                        return Err(FraiseQLError::Validation {
505                            message: format!(
506                                "Boolean field '{}' not found in fact table '{}'",
507                                field, metadata.table_name
508                            ),
509                            path:    None,
510                        });
511                    }
512
513                    expressions.push(AggregateExpression::BoolAggregate {
514                        column:   field.clone(),
515                        function: *function,
516                        alias:    alias.clone(),
517                    });
518                },
519            }
520        }
521
522        Ok(expressions)
523    }
524
525    /// Validate HAVING conditions
526    fn validate_having(
527        conditions: &[HavingCondition],
528        _aggregate_expressions: &[AggregateExpression],
529    ) -> Result<Vec<ValidatedHavingCondition>> {
530        let mut validated = Vec::new();
531
532        for condition in conditions {
533            // Convert the aggregate selection to an expression
534            let aggregate_expr = match &condition.aggregate {
535                AggregateSelection::Count { alias } => AggregateExpression::Count {
536                    alias: alias.clone(),
537                },
538                AggregateSelection::CountDistinct { field, alias } => {
539                    AggregateExpression::CountDistinct {
540                        column: field.clone(),
541                        alias:  alias.clone(),
542                    }
543                },
544                AggregateSelection::MeasureAggregate {
545                    measure,
546                    function,
547                    alias,
548                } => {
549                    // For advanced aggregates in HAVING, create AdvancedAggregate variant
550                    if matches!(
551                        function,
552                        AggregateFunction::ArrayAgg
553                            | AggregateFunction::JsonAgg
554                            | AggregateFunction::JsonbAgg
555                            | AggregateFunction::StringAgg
556                    ) {
557                        AggregateExpression::AdvancedAggregate {
558                            column:    measure.clone(),
559                            function:  *function,
560                            alias:     alias.clone(),
561                            delimiter: if *function == AggregateFunction::StringAgg {
562                                Some(", ".to_string())
563                            } else {
564                                None
565                            },
566                            order_by:  None,
567                        }
568                    } else {
569                        AggregateExpression::MeasureAggregate {
570                            column:   measure.clone(),
571                            function: *function,
572                            alias:    alias.clone(),
573                        }
574                    }
575                },
576                AggregateSelection::BoolAggregate {
577                    field,
578                    function,
579                    alias,
580                } => AggregateExpression::BoolAggregate {
581                    column:   field.clone(),
582                    function: *function,
583                    alias:    alias.clone(),
584                },
585            };
586
587            // Note: We don't strictly require the aggregate to be in the SELECT list
588            // Some databases allow filtering on aggregates not in SELECT
589
590            validated.push(ValidatedHavingCondition {
591                aggregate: aggregate_expr,
592                operator:  condition.operator,
593                value:     condition.value.clone(),
594            });
595        }
596
597        Ok(validated)
598    }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::fact_table::{DimensionColumn, FilterColumn, MeasureColumn, SqlType};

    /// Non-nullable measure column fixture.
    fn measure(name: &str, sql_type: SqlType) -> MeasureColumn {
        MeasureColumn {
            name: name.to_string(),
            sql_type,
            nullable: false,
        }
    }

    /// Indexed denormalized filter column fixture.
    fn indexed_filter(name: &str, sql_type: SqlType) -> FilterColumn {
        FilterColumn {
            name: name.to_string(),
            sql_type,
            indexed: true,
        }
    }

    /// Metadata for a `tf_sales` fact table with two measures (`revenue`,
    /// `quantity`), two filters (`customer_id`, `occurred_at`) and no declared
    /// dimension paths or calendar dimensions.
    fn create_test_metadata() -> FactTableMetadata {
        FactTableMetadata {
            table_name: "tf_sales".to_string(),
            measures: vec![
                measure("revenue", SqlType::Decimal),
                measure("quantity", SqlType::Int),
            ],
            dimensions: DimensionColumn {
                name:  "dimensions".to_string(),
                paths: vec![],
            },
            denormalized_filters: vec![
                indexed_filter("customer_id", SqlType::Uuid),
                indexed_filter("occurred_at", SqlType::Timestamp),
            ],
            calendar_dimensions: vec![],
        }
    }

    /// Request against `tf_sales` with no WHERE, ORDER BY, LIMIT or OFFSET.
    fn sales_request(
        group_by: Vec<GroupBySelection>,
        aggregates: Vec<AggregateSelection>,
        having: Vec<HavingCondition>,
    ) -> AggregationRequest {
        AggregationRequest {
            table_name: "tf_sales".to_string(),
            where_clause: None,
            group_by,
            aggregates,
            having,
            order_by: vec![],
            limit: None,
            offset: None,
        }
    }

    /// `COUNT(*)` aliased as `count`.
    fn count_all() -> AggregateSelection {
        AggregateSelection::Count {
            alias: "count".to_string(),
        }
    }

    /// `SUM(<measure>)` aliased as `<measure>_sum`.
    fn sum_of(measure: &str) -> AggregateSelection {
        AggregateSelection::MeasureAggregate {
            measure:  measure.to_string(),
            function: AggregateFunction::Sum,
            alias:    format!("{}_sum", measure),
        }
    }

    /// Grouping on the `category` dimension path, aliased as `category`.
    fn by_category() -> GroupBySelection {
        GroupBySelection::Dimension {
            path:  "category".to_string(),
            alias: "category".to_string(),
        }
    }

    /// Grouping on a daily bucket of `column`, aliased as `alias`.
    fn by_day(column: &str, alias: &str) -> GroupBySelection {
        GroupBySelection::TemporalBucket {
            column: column.to_string(),
            bucket: TemporalBucket::Day,
            alias:  alias.to_string(),
        }
    }

    #[test]
    fn test_plan_simple_aggregation() {
        let request = sales_request(vec![], vec![count_all(), sum_of("revenue")], vec![]);

        let plan = AggregationPlanner::plan(request, create_test_metadata()).unwrap();

        // Exactly two expressions, in request order.
        assert!(matches!(
            plan.aggregate_expressions.as_slice(),
            [
                AggregateExpression::Count { .. },
                AggregateExpression::MeasureAggregate { .. }
            ]
        ));
    }

    #[test]
    fn test_plan_with_group_by() {
        let request = sales_request(
            vec![by_category(), by_day("occurred_at", "occurred_at_day")],
            vec![count_all()],
            vec![],
        );

        let plan = AggregationPlanner::plan(request, create_test_metadata()).unwrap();

        // Exactly two expressions, in request order.
        assert!(matches!(
            plan.group_by_expressions.as_slice(),
            [
                GroupByExpression::JsonbPath { .. },
                GroupByExpression::TemporalBucket { .. }
            ]
        ));
    }

    #[test]
    fn test_plan_with_having() {
        let revenue_over_1000 = HavingCondition {
            aggregate: sum_of("revenue"),
            operator:  HavingOperator::Gt,
            value:     serde_json::json!(1000),
        };
        let request = sales_request(
            vec![by_category()],
            vec![sum_of("revenue")],
            vec![revenue_over_1000],
        );

        let plan = AggregationPlanner::plan(request, create_test_metadata()).unwrap();

        assert_eq!(plan.having_conditions.len(), 1);
        assert_eq!(plan.having_conditions[0].operator, HavingOperator::Gt);
    }

    #[test]
    fn test_validate_invalid_measure() {
        let request = sales_request(vec![], vec![sum_of("nonexistent")], vec![]);

        let error = AggregationPlanner::plan(request, create_test_metadata())
            .expect_err("aggregating an unknown measure must be rejected");

        assert!(error.to_string().contains("not found"));
    }

    #[test]
    fn test_validate_invalid_temporal_column() {
        let request = sales_request(vec![by_day("nonexistent", "day")], vec![count_all()], vec![]);

        let error = AggregationPlanner::plan(request, create_test_metadata())
            .expect_err("bucketing an unknown column must be rejected");

        assert!(error.to_string().contains("not found"));
    }
}