Skip to main content

reddb_server/storage/query/optimizer/
decorrelate.rs

//! Subquery Decorrelation Optimizer
//!
//! Transforms correlated subqueries into equivalent join-based queries.
//!
//! # Motivation
//!
//! A correlated subquery is re-evaluated once per row of the outer query, so its cost
//! grows as O(outer × inner), i.e. O(n²). Decorrelation rewrites it as a join, which
//! hash- or sort-based join operators can execute far more cheaply (roughly O(n log n)).
//!
//! # Example Transformation
//!
//! **Before (correlated):**
//! ```sql
//! SELECT * FROM orders o
//! WHERE total > (SELECT AVG(total) FROM orders WHERE customer_id = o.customer_id)
//! ```
//!
//! **After (decorrelated):**
//! ```sql
//! SELECT o.* FROM orders o
//! JOIN (SELECT customer_id, AVG(total) as avg_total FROM orders GROUP BY customer_id) sub
//!   ON o.customer_id = sub.customer_id
//! WHERE o.total > sub.avg_total
//! ```
//!
//! # Supported Patterns
//!
//! - Scalar correlated subqueries whose correlation predicates are all equalities
//! - IN/EXISTS correlated subqueries (semi-join) and NOT IN/NOT EXISTS (anti-join)
//! - Aggregation subqueries (grouped by the correlation columns)
31
32/// Represents a correlation predicate between outer and inner queries
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct CorrelationPredicate {
35    /// Column from outer query
36    pub outer_col: String,
37    /// Column from inner query
38    pub inner_col: String,
39    /// Comparison operator (typically Eq for decorrelation)
40    pub op: CorrelationOp,
41}
42
/// Comparison used by a correlation predicate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CorrelationOp {
    /// `outer = inner` — the common case, and the one that maps cleanly onto a join.
    Eq,
    /// `outer < inner` — only partially decorrelatable.
    Lt,
    /// `outer > inner` — only partially decorrelatable.
    Gt,
}
53
54/// Analysis result for a subquery
55#[derive(Debug, Clone)]
56pub struct SubqueryAnalysis {
57    /// Whether the subquery is correlated
58    pub is_correlated: bool,
59    /// Correlation predicates (if correlated)
60    pub correlation_predicates: Vec<CorrelationPredicate>,
61    /// Whether decorrelation is possible
62    pub can_decorrelate: bool,
63    /// Reason if decorrelation is not possible
64    pub decorrelation_blocker: Option<DecorrelationBlocker>,
65    /// Suggested decorrelation strategy
66    pub strategy: Option<DecorrelationStrategy>,
67}
68
/// Why a correlated subquery could not be turned into a join.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecorrelationBlocker {
    /// A correlation predicate is not an equality, so it has no join equivalent.
    NonEqualityCorrelation,
    /// LIMIT/OFFSET applies per outer row and cannot be pushed below a join.
    CorrelationInLimit,
    /// Correlation spans more than one nesting level.
    NestedCorrelation,
    /// Correlation appears in HAVING, which needs a more involved rewrite.
    CorrelationInHaving,
    /// The rewrite would need LATERAL join semantics, which are unsupported.
    RequiresLateralJoin,
}
83
/// How a decorrelatable subquery is rewritten into a join.
///
/// Every `join_condition` is a list of `(outer_col, inner_col)` equality pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecorrelationStrategy {
    /// INNER JOIN against the subquery grouped by its correlation columns;
    /// the shape used for aggregating scalar subqueries.
    JoinWithGroupBy {
        /// Grouping columns of the derived table (the inner correlation columns).
        group_by_cols: Vec<String>,
        /// Equality pairs joining the outer query to the derived table.
        join_condition: Vec<(String, String)>,
    },

    /// LEFT JOIN variant of the above, for results that may be absent (NULL).
    LeftJoinWithGroupBy {
        group_by_cols: Vec<String>,
        join_condition: Vec<(String, String)>,
    },

    /// SEMI JOIN rewrite for IN / EXISTS.
    SemiJoin {
        join_condition: Vec<(String, String)>,
    },

    /// ANTI JOIN rewrite for NOT IN / NOT EXISTS.
    AntiJoin {
        join_condition: Vec<(String, String)>,
    },

    /// De-duplicate the inner side, then join; for cases where the inner
    /// query may repeat rows but only existence matters.
    DistinctJoin {
        join_condition: Vec<(String, String)>,
    },
}
118
/// Subquery decorrelation optimizer.
///
/// Its only state is the counter used to hand out unique derived-table aliases,
/// so a single instance can plan several rewrites without alias collisions.
pub struct Decorrelator {
    /// Number of derived-table aliases issued so far.
    alias_counter: usize,
}
124
125impl Decorrelator {
126    /// Create a new decorrelator
127    pub fn new() -> Self {
128        Self { alias_counter: 0 }
129    }
130
131    /// Generate a unique alias for derived tables
132    fn next_alias(&mut self) -> String {
133        self.alias_counter += 1;
134        format!("__derived_{}", self.alias_counter)
135    }
136
137    /// Analyze a correlated subquery for decorrelation potential
138    pub fn analyze(
139        &self,
140        outer_refs: &[String],
141        inner_cols: &[String],
142        correlation_predicates: &[(String, String)], // (outer, inner) equality pairs
143        subquery_type: SubqueryKind,
144        has_aggregation: bool,
145        has_limit: bool,
146    ) -> SubqueryAnalysis {
147        // Non-correlated subqueries don't need decorrelation
148        if outer_refs.is_empty() {
149            return SubqueryAnalysis {
150                is_correlated: false,
151                correlation_predicates: Vec::new(),
152                can_decorrelate: false,
153                decorrelation_blocker: None,
154                strategy: None,
155            };
156        }
157
158        // Build correlation predicates
159        let predicates: Vec<CorrelationPredicate> = correlation_predicates
160            .iter()
161            .map(|(outer, inner)| CorrelationPredicate {
162                outer_col: outer.clone(),
163                inner_col: inner.clone(),
164                op: CorrelationOp::Eq,
165            })
166            .collect();
167
168        // Check for blockers
169        if has_limit {
170            return SubqueryAnalysis {
171                is_correlated: true,
172                correlation_predicates: predicates,
173                can_decorrelate: false,
174                decorrelation_blocker: Some(DecorrelationBlocker::CorrelationInLimit),
175                strategy: None,
176            };
177        }
178
179        // Determine strategy based on subquery type
180        let strategy = match subquery_type {
181            SubqueryKind::Scalar if has_aggregation => {
182                // Scalar aggregation can be decorrelated with GROUP BY
183                let group_by_cols: Vec<String> =
184                    predicates.iter().map(|p| p.inner_col.clone()).collect();
185                let join_condition: Vec<(String, String)> = predicates
186                    .iter()
187                    .map(|p| (p.outer_col.clone(), p.inner_col.clone()))
188                    .collect();
189
190                Some(DecorrelationStrategy::JoinWithGroupBy {
191                    group_by_cols,
192                    join_condition,
193                })
194            }
195            SubqueryKind::Scalar => {
196                // Non-aggregation scalar - use LEFT JOIN
197                let group_by_cols: Vec<String> =
198                    predicates.iter().map(|p| p.inner_col.clone()).collect();
199                let join_condition: Vec<(String, String)> = predicates
200                    .iter()
201                    .map(|p| (p.outer_col.clone(), p.inner_col.clone()))
202                    .collect();
203
204                Some(DecorrelationStrategy::LeftJoinWithGroupBy {
205                    group_by_cols,
206                    join_condition,
207                })
208            }
209            SubqueryKind::Exists | SubqueryKind::In => {
210                // EXISTS/IN becomes SEMI JOIN
211                let join_condition: Vec<(String, String)> = predicates
212                    .iter()
213                    .map(|p| (p.outer_col.clone(), p.inner_col.clone()))
214                    .collect();
215
216                Some(DecorrelationStrategy::SemiJoin { join_condition })
217            }
218            SubqueryKind::NotExists | SubqueryKind::NotIn => {
219                // NOT EXISTS/NOT IN becomes ANTI JOIN
220                let join_condition: Vec<(String, String)> = predicates
221                    .iter()
222                    .map(|p| (p.outer_col.clone(), p.inner_col.clone()))
223                    .collect();
224
225                Some(DecorrelationStrategy::AntiJoin { join_condition })
226            }
227            SubqueryKind::Any | SubqueryKind::All => {
228                // ANY/ALL with equality can be semi/anti join
229                // For other comparisons, more complex transformation needed
230                None
231            }
232        };
233
234        SubqueryAnalysis {
235            is_correlated: true,
236            correlation_predicates: predicates,
237            can_decorrelate: strategy.is_some(),
238            decorrelation_blocker: if strategy.is_none() {
239                Some(DecorrelationBlocker::RequiresLateralJoin)
240            } else {
241                None
242            },
243            strategy,
244        }
245    }
246
247    /// Estimate cost improvement from decorrelation
248    /// Returns the ratio of (correlated cost) / (decorrelated cost)
249    pub fn estimate_speedup(
250        &self,
251        outer_cardinality: usize,
252        inner_cardinality: usize,
253        strategy: &DecorrelationStrategy,
254    ) -> f64 {
255        // Correlated: O(outer * inner) - subquery runs once per outer row
256        let correlated_cost = (outer_cardinality * inner_cardinality) as f64;
257
258        // Decorrelated: O(outer + inner + join) - depends on strategy
259        let decorrelated_cost = match strategy {
260            DecorrelationStrategy::JoinWithGroupBy { group_by_cols, .. } => {
261                // GROUP BY inner + hash join
262                let group_by_cost = inner_cardinality as f64 * (group_by_cols.len() as f64).log2();
263                let join_cost = (outer_cardinality + inner_cardinality) as f64;
264                group_by_cost + join_cost
265            }
266            DecorrelationStrategy::LeftJoinWithGroupBy { .. } => {
267                // Similar to inner join but may produce more rows
268                (outer_cardinality + inner_cardinality) as f64 * 1.5
269            }
270            DecorrelationStrategy::SemiJoin { .. } | DecorrelationStrategy::AntiJoin { .. } => {
271                // Hash-based semi/anti join
272                (outer_cardinality + inner_cardinality) as f64
273            }
274            DecorrelationStrategy::DistinctJoin { .. } => {
275                // Distinct + join
276                let distinct_cost = inner_cardinality as f64 * 1.2;
277                let join_cost = (outer_cardinality + inner_cardinality) as f64;
278                distinct_cost + join_cost
279            }
280        };
281
282        // Avoid division by zero
283        if decorrelated_cost < 1.0 {
284            return correlated_cost;
285        }
286
287        correlated_cost / decorrelated_cost
288    }
289
290    /// Check if decorrelation is worthwhile based on cardinality
291    pub fn should_decorrelate(
292        &self,
293        outer_cardinality: usize,
294        inner_cardinality: usize,
295        strategy: &DecorrelationStrategy,
296    ) -> bool {
297        // Always decorrelate if speedup > 1.5x
298        let speedup = self.estimate_speedup(outer_cardinality, inner_cardinality, strategy);
299        speedup > 1.5
300    }
301}
302
303impl Default for Decorrelator {
304    fn default() -> Self {
305        Self::new()
306    }
307}
308
/// How the outer query consumes a subquery; drives the choice of rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubqueryKind {
    /// Scalar subquery, producing a single value.
    Scalar,
    /// `EXISTS (...)`.
    Exists,
    /// `NOT EXISTS (...)`.
    NotExists,
    /// `x IN (...)`.
    In,
    /// `x NOT IN (...)`.
    NotIn,
    /// `x <op> ANY (...)`.
    Any,
    /// `x <op> ALL (...)`.
    All,
}
327
328// ============================================================================
329// Rewrite Rules
330// ============================================================================
331
332/// Represents a rewrite of a correlated subquery to a join
333#[derive(Debug, Clone)]
334pub struct SubqueryRewrite {
335    /// Alias for the derived table
336    pub derived_alias: String,
337    /// Join type to use
338    pub join_type: RewriteJoinType,
339    /// Columns to select from inner query (for the derived table)
340    pub inner_select: Vec<String>,
341    /// GROUP BY columns for the derived table (if aggregation)
342    pub group_by: Vec<String>,
343    /// Join condition (outer_col = derived_alias.inner_col pairs)
344    pub join_on: Vec<(String, String)>,
345    /// Column in derived table that replaces the subquery result
346    pub result_col: Option<String>,
347}
348
/// Join operator used by a [`SubqueryRewrite`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RewriteJoinType {
    /// Keep only outer rows with a match.
    Inner,
    /// Keep every outer row, NULL-extending unmatched ones.
    Left,
    /// Keep outer rows that have at least one match, without duplicating them.
    Semi,
    /// Keep outer rows that have no match.
    Anti,
}
357
358impl Decorrelator {
359    /// Generate a rewrite plan for a decorrelatable subquery
360    pub fn plan_rewrite(
361        &mut self,
362        analysis: &SubqueryAnalysis,
363        aggregation_col: Option<&str>,
364    ) -> Option<SubqueryRewrite> {
365        let strategy = analysis.strategy.as_ref()?;
366
367        let alias = self.next_alias();
368
369        match strategy {
370            DecorrelationStrategy::JoinWithGroupBy {
371                group_by_cols,
372                join_condition,
373            } => {
374                let mut inner_select = group_by_cols.clone();
375                let result_col = aggregation_col.map(|c| {
376                    let col_name = format!("__agg_{}", c);
377                    inner_select.push(col_name.clone());
378                    col_name
379                });
380
381                Some(SubqueryRewrite {
382                    derived_alias: alias.clone(),
383                    join_type: RewriteJoinType::Inner,
384                    inner_select,
385                    group_by: group_by_cols.clone(),
386                    join_on: join_condition
387                        .iter()
388                        .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
389                        .collect(),
390                    result_col,
391                })
392            }
393            DecorrelationStrategy::LeftJoinWithGroupBy {
394                group_by_cols,
395                join_condition,
396            } => {
397                let mut inner_select = group_by_cols.clone();
398                let result_col = aggregation_col.map(|c| {
399                    let col_name = format!("__agg_{}", c);
400                    inner_select.push(col_name.clone());
401                    col_name
402                });
403
404                Some(SubqueryRewrite {
405                    derived_alias: alias.clone(),
406                    join_type: RewriteJoinType::Left,
407                    inner_select,
408                    group_by: group_by_cols.clone(),
409                    join_on: join_condition
410                        .iter()
411                        .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
412                        .collect(),
413                    result_col,
414                })
415            }
416            DecorrelationStrategy::SemiJoin { join_condition } => Some(SubqueryRewrite {
417                derived_alias: alias.clone(),
418                join_type: RewriteJoinType::Semi,
419                inner_select: join_condition.iter().map(|(_, i)| i.clone()).collect(),
420                group_by: Vec::new(),
421                join_on: join_condition
422                    .iter()
423                    .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
424                    .collect(),
425                result_col: None,
426            }),
427            DecorrelationStrategy::AntiJoin { join_condition } => Some(SubqueryRewrite {
428                derived_alias: alias.clone(),
429                join_type: RewriteJoinType::Anti,
430                inner_select: join_condition.iter().map(|(_, i)| i.clone()).collect(),
431                group_by: Vec::new(),
432                join_on: join_condition
433                    .iter()
434                    .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
435                    .collect(),
436                result_col: None,
437            }),
438            DecorrelationStrategy::DistinctJoin { join_condition } => {
439                Some(SubqueryRewrite {
440                    derived_alias: alias.clone(),
441                    join_type: RewriteJoinType::Semi,
442                    inner_select: join_condition.iter().map(|(_, i)| i.clone()).collect(),
443                    group_by: join_condition.iter().map(|(_, i)| i.clone()).collect(), // DISTINCT via GROUP BY
444                    join_on: join_condition
445                        .iter()
446                        .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
447                        .collect(),
448                    result_col: None,
449                })
450            }
451        }
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn string literals into the owned column names the API takes.
    fn cols(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Turn `(outer, inner)` literal pairs into owned correlation pairs.
    fn pairs(items: &[(&str, &str)]) -> Vec<(String, String)> {
        items
            .iter()
            .map(|(outer, inner)| (outer.to_string(), inner.to_string()))
            .collect()
    }

    #[test]
    fn test_non_correlated() {
        // No outer references at all: nothing to decorrelate.
        let analysis = Decorrelator::new().analyze(
            &[],
            &cols(&["id", "value"]),
            &[],
            SubqueryKind::Scalar,
            true,
            false,
        );

        assert!(!analysis.is_correlated);
        assert!(!analysis.can_decorrelate);
    }

    #[test]
    fn test_scalar_aggregation_decorrelation() {
        let analysis = Decorrelator::new().analyze(
            &cols(&["o.customer_id"]),
            &cols(&["customer_id", "total"]),
            &pairs(&[("o.customer_id", "customer_id")]),
            SubqueryKind::Scalar,
            true,  // has aggregation
            false, // no limit
        );

        assert!(analysis.is_correlated);
        assert!(analysis.can_decorrelate);
        assert!(matches!(
            analysis.strategy,
            Some(DecorrelationStrategy::JoinWithGroupBy { .. })
        ));
    }

    #[test]
    fn test_exists_decorrelation() {
        let analysis = Decorrelator::new().analyze(
            &cols(&["o.id"]),
            &cols(&["order_id"]),
            &pairs(&[("o.id", "order_id")]),
            SubqueryKind::Exists,
            false,
            false,
        );

        assert!(analysis.is_correlated);
        assert!(analysis.can_decorrelate);
        assert!(matches!(
            analysis.strategy,
            Some(DecorrelationStrategy::SemiJoin { .. })
        ));
    }

    #[test]
    fn test_limit_blocks_decorrelation() {
        let analysis = Decorrelator::new().analyze(
            &cols(&["o.id"]),
            &cols(&["order_id"]),
            &pairs(&[("o.id", "order_id")]),
            SubqueryKind::Scalar,
            false,
            true, // LIMIT present: decorrelation must be refused
        );

        assert!(analysis.is_correlated);
        assert!(!analysis.can_decorrelate);
        assert_eq!(
            analysis.decorrelation_blocker,
            Some(DecorrelationBlocker::CorrelationInLimit)
        );
    }

    #[test]
    fn test_speedup_estimation() {
        // 1000 x 1000: correlated ~1,000,000 row visits versus ~2000 for a hash semi join.
        let strategy = DecorrelationStrategy::SemiJoin {
            join_condition: pairs(&[("a", "b")]),
        };
        let speedup = Decorrelator::new().estimate_speedup(1000, 1000, &strategy);

        assert!(speedup > 100.0);
    }

    #[test]
    fn test_rewrite_plan() {
        let mut decorrelator = Decorrelator::new();
        let analysis = decorrelator.analyze(
            &cols(&["o.customer_id"]),
            &cols(&["customer_id", "total"]),
            &pairs(&[("o.customer_id", "customer_id")]),
            SubqueryKind::Scalar,
            true,
            false,
        );

        let rewrite = decorrelator.plan_rewrite(&analysis, Some("avg_total"));
        assert!(rewrite.is_some());

        let rewrite = rewrite.unwrap();
        assert_eq!(rewrite.join_type, RewriteJoinType::Inner);
        assert!(rewrite.group_by.contains(&"customer_id".to_string()));
        assert!(rewrite.result_col.is_some());
    }
}