clickhouse_datafusion/analyzer/
function_pushdown.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use clickhouse_arrow::rustc_hash::FxHashMap;
5use datafusion::arrow::datatypes::Field;
6use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
7use datafusion::common::{Column, DFSchema, DFSchemaRef, Result, plan_err};
8use datafusion::logical_expr::expr::ScalarFunction;
9use datafusion::logical_expr::{LogicalPlan, Projection, SubqueryAlias};
10use datafusion::optimizer::AnalyzerRule;
11use datafusion::prelude::Expr;
12use datafusion::sql::TableReference;
13
14use super::source_context::ResolvedSource;
15use super::source_visitor::{ColumnId, SourceLineageVistor};
16use super::utils::{
17    extract_function_and_return_type, is_clickhouse_function, use_clickhouse_function_context,
18};
19use crate::utils::analyze::push_exprs_below_subquery;
20
/// State for tracking `ClickHouse` functions during pushdown analysis.
///
/// An instance is threaded through `analyze_and_transform_plan`; while recursing, a fresh
/// instance is built for every plan input.
#[derive(Default, Debug, Clone)]
struct PushdownState {
    /// `ClickHouse` functions being pushed down, organized by resolved sources
    functions:        FxHashMap<ResolvedSource, Vec<Expr>>,
    /// Tracks the resolved sources of all functions collected, for non-branching plans
    function_sources: ResolvedSource,
    /// Tracks the sources of the schemas of the plans visited so far
    plan_sources:     ResolvedSource,
}
31
32impl PushdownState {
33    fn take_functions(&mut self) -> Vec<Expr> {
34        self.functions.values_mut().flatten().map(std::mem::take).collect::<Vec<_>>()
35    }
36
37    fn has_functions(&self) -> bool { self.functions.values().any(|f| !f.is_empty()) }
38
39    /// Using the provided `LogicalPlan`s schema, pull out any functions that can be pushed safely.
40    fn take_relevant_functions(
41        &mut self,
42        schema: &DFSchemaRef,
43        visitor: &SourceLineageVistor,
44    ) -> FxHashMap<ResolvedSource, Vec<Expr>> {
45        if !self.has_functions() {
46            return FxHashMap::default();
47        }
48        let sources = visitor.resolve_schema(schema.as_ref());
49        let column_ids = schema
50            .columns()
51            .into_iter()
52            .flat_map(|col| visitor.collect_column_ids(&col))
53            .collect::<HashSet<_>>();
54        let extracted = self
55            .functions
56            .iter_mut()
57            .filter(|(r, _)| r.resolves_intersects(&sources))
58            .map(|(resolved, funcs)| {
59                let extracted = funcs
60                    .extract_if(.., |f| {
61                        f.column_refs()
62                            .into_iter()
63                            .flat_map(|col| visitor.collect_column_ids(col))
64                            .collect::<HashSet<_>>()
65                            .is_subset(&column_ids)
66                    })
67                    .collect::<Vec<_>>();
68                (resolved.clone(), extracted)
69            })
70            .collect::<FxHashMap<_, _>>();
71        // Clean up any empty function lists
72        self.functions.retain(|_, funcs| !funcs.is_empty());
73        extracted
74    }
75}
76
/// A `DataFusion` `AnalyzerRule` that identifies the largest subtree of a plan to wrap with an
/// extension node, and otherwise "pushes down" `ClickHouse` functions when required.
#[derive(Debug, Clone, Copy)]
pub struct ClickHouseFunctionPushdown;
81
82impl AnalyzerRule for ClickHouseFunctionPushdown {
83    fn analyze(
84        &self,
85        plan: LogicalPlan,
86        _config: &datafusion::common::config::ConfigOptions,
87    ) -> Result<LogicalPlan> {
88        if matches!(plan, LogicalPlan::Ddl(_) | LogicalPlan::DescribeTable(_)) {
89            return Ok(plan);
90        }
91
92        // Create source lineage visitor and build column -> source lineage for the plan
93        #[cfg_attr(feature = "test-utils", expect(unused))]
94        let mut lineage_visitor = SourceLineageVistor::new();
95
96        #[cfg(feature = "test-utils")]
97        let mut lineage_visitor = lineage_visitor
98            .with_source_grouping(HashSet::from(["table1".to_string(), "table2".to_string()]));
99
100        let _ = plan.visit(&mut lineage_visitor)?;
101
102        // Nothing to transform
103        if lineage_visitor.clickhouse_function_count == 0 {
104            return Ok(plan);
105        }
106
107        // Initialize state, setting the plan_sources
108        let mut state = PushdownState::default();
109
110        let (new_plan, _) = analyze_and_transform_plan(plan, &mut state, &lineage_visitor)?;
111        Ok(new_plan.data)
112    }
113
114    fn name(&self) -> &'static str { "clickhouse_function_pushdown" }
115}
116
// TODOS: * Implement error handling for required pushdowns
//        * Check is_volatile and other invariants
//        * Handle Sort in violations checks. Sorting may prevent pushdown in some cases
//        * Refine Subquery in violations checks, might need some tuning
//        * Investigate how to separate mixed functions, ie in aggr_exprs
//        * Initial violations checks allocate, could be improved for cases where no alloc needed
//        * Match more plans to reduce allocations
//
/// Analyze and transform a plan. Functions flow DOWN - recurse if current level cannot handle them.
///
/// The basic flow is:
/// 1. Determine any semantic violations with pushed functions or plan functions
/// 2. Handle some plans specially, ie joins which require function routing
/// 3. Update `PushdownState`'s resolved sources to reflect any functions in plan expressions
/// 4. Check if the resolved sources fully encompass the functions' resolved sources
/// 5. If so, wrap there, highest subtree identified
/// 6. If not, enforce any semantic violations (wrapping prevents this need)
/// 7. Extract any clickhouse functions, replace with alias, and store in state
/// 8. Recurse into inputs
/// 9. Reconstruct plan while unwinding recursion
///
/// # Returns
/// A pair of the (possibly transformed) plan and an optional `ResolvedSource`. When the second
/// element is `Some`, the plan is handed back untransformed and the function sources are sent
/// up so that the caller can decide whether to wrap at its own level.
///
/// # Errors
/// - Returns a plan error if the sources of pushed functions cannot be resolved by this plan or
///   by any of its inputs.
/// - Returns a plan error if pushing functions through this plan violates SQL semantics.
/// - Propagates errors from wrapping, expression rewriting and plan reconstruction.
fn analyze_and_transform_plan(
    plan: LogicalPlan,
    state: &mut PushdownState,
    visitor: &SourceLineageVistor,
) -> Result<(Transformed<LogicalPlan>, Option<ResolvedSource>)> {
    // First determine any semantic violations with pushed functions
    let function_violations = check_state_for_violations(state, &plan, visitor);

    // Then check any plan violations for plan functions
    let plan_violations = check_plan_for_violations(&plan, visitor);

    // Sources of the plan's output schema, and of the functions in the plan's own expressions
    let plan_sources = visitor.resolve_schema(plan.schema());
    let plan_function_sources = resolve_plan_expr_functions(&plan, visitor);

    // If plan's schema resolves pushed functions, wrap it and return
    if state.has_functions() {
        if state.function_sources.resolves_eq(&plan_sources) {
            let wrapped_plan = wrap_plan_with_functions(plan, state.take_functions(), visitor)?;
            return Ok((Transformed::yes(wrapped_plan), None));
        }

        // All ClickHouse functions need to be resolvable by this point
        if !state.function_sources.resolves_within(&plan_sources) {
            return plan_err!(
                "SQL not supported, could not determine sources of following functions: {:?}",
                state.functions.iter().collect::<Vec<_>>()
            );
        }
    }

    // It may be that the very top level plan is wrapped, need to know if it is the first plan
    let top_level_plan = !state.plan_sources.is_known();
    if top_level_plan {
        state.plan_sources = plan_sources.clone();
    }

    // Wrap higher if function sources resolve into parent plan sources
    if state.plan_sources.resolves_eq(&plan_function_sources) {
        // Top level plan can be wrapped
        if top_level_plan {
            let wrapped_plan = wrap_plan_with_functions(plan, state.take_functions(), visitor)?;
            return Ok((Transformed::yes(wrapped_plan), None));
        }
        // Otherwise send function sources back up
        return Ok((Transformed::no(plan), Some(plan_function_sources)));
    }

    // Merge function sources
    state.function_sources =
        std::mem::take(&mut state.function_sources).merge(plan_function_sources);

    // Finally check if plan can be wrapped itself
    if state.function_sources.resolves_eq(&plan_sources) {
        let wrapped_plan = wrap_plan_with_functions(plan, state.take_functions(), visitor)?;
        return Ok((Transformed::yes(wrapped_plan), None));
    }

    // Either pushdowns or plan functions that need to be pushed down
    if state.function_sources.is_known() {
        // Ensure no semantic violations in the result of pushdown
        semantic_err(
            plan.display(),
            "SQL unsupported, pushed functions violate sql semantics in current plan.",
            &function_violations,
        )?;
        // Ensure no plan violations in the result of pushdown
        semantic_err(
            plan.display(),
            "SQL unsupported, plan violates sql semantics in plan's expressions.",
            &plan_violations,
        )?;
    }

    // Pull out plan sources for comparison (leaves `state.plan_sources` at its default)
    let parent_sources = std::mem::take(&mut state.plan_sources);

    // SubqueryAlias handling for pushdowns
    if state.has_functions()
        && let LogicalPlan::SubqueryAlias(alias) = &plan
    {
        state.functions = std::mem::take(&mut state.functions)
            .into_iter()
            .map(|(resolv, funcs)| (resolv, push_exprs_below_subquery(funcs, alias)))
            .collect();
    }

    // Update w/ schema and expression changes
    let aliased_exprs = collect_and_transform_exprs(plan.expressions(), visitor, state)?;

    // Track if this plan will be transformed in any way
    let mut was_transformed = state.has_functions();

    // Recurse deeper, handing each input only the functions its schema can resolve
    let inputs_transformed = plan
        .inputs()
        .into_iter()
        .cloned()
        .map(|input| {
            let extracted = state.take_relevant_functions(input.schema(), visitor);
            let mut input_state = PushdownState {
                function_sources: extracted
                    .keys()
                    .cloned()
                    .reduce(ResolvedSource::merge)
                    .unwrap_or_default(),
                functions:        extracted,
                plan_sources:     plan_sources.clone(),
            };
            analyze_and_transform_plan(input, &mut input_state, visitor)
        })
        .collect::<Result<Vec<_>>>()?;

    // All ClickHouse functions need to be resolved by this point
    if state.has_functions() {
        return plan_err!(
            "SQL not supported, could not determine sources of following functions: {:?}",
            state.functions.iter().collect::<Vec<_>>()
        );
    }

    // If the inputs determined wrap higher up, merge their functions sources
    let input_resolution = inputs_transformed
        .iter()
        .inspect(|(i, _)| was_transformed |= i.transformed)
        .filter_map(|(_, func_sources)| func_sources.clone())
        .reduce(ResolvedSource::merge)
        .unwrap_or_default();

    let new_inputs = inputs_transformed.into_iter().map(|(i, _)| i.data).collect::<Vec<_>>();
    let new_plan = plan.with_new_exprs(aliased_exprs, new_inputs)?;

    let new_plan =
        // If this node's parent sources resolves, defer to higher plan
        // NOTE(review): the original `plan` is returned here rather than `new_plan`; presumably
        // because the caller is expected to wrap the untouched subtree - confirm.
        if !top_level_plan && parent_sources.resolves_eq(&input_resolution) {
            return Ok((Transformed::no(plan), Some(input_resolution)));

        // If this plan's schema resolves the function sources, wrap here
        } else if plan_sources.resolves_eq(&input_resolution) {
            let wrapped = wrap_plan_with_functions(new_plan, state.take_functions(), visitor)?;
            Transformed::yes(wrapped)

        // Otherwise, return the rebuilt plan
        } else if was_transformed {
            Transformed::yes(new_plan)
        } else {
            Transformed::no(new_plan)
        };

    Ok((new_plan, None))
}
287
288fn resolve_plan_expr_functions(
289    input: &LogicalPlan,
290    visitor: &SourceLineageVistor,
291) -> ResolvedSource {
292    let mut exprs_resolved = ResolvedSource::default();
293    let _ = input
294        .apply_expressions(|expr| {
295            use_clickhouse_function_context(expr, |e| {
296                exprs_resolved = std::mem::take(&mut exprs_resolved).merge(visitor.resolve_expr(e));
297                Ok(TreeNodeRecursion::Stop)
298            })
299            .unwrap();
300            Ok(TreeNodeRecursion::Continue)
301        })
302        .unwrap();
303    exprs_resolved
304}
305
306/// Add functions to a plan by wrapping with Extension node that persists through optimization
307fn wrap_plan_with_functions(
308    plan: LogicalPlan,
309    functions: Vec<Expr>,
310    visitor: &SourceLineageVistor,
311) -> Result<LogicalPlan> {
312    #[cfg(feature = "federation")]
313    #[expect(clippy::unnecessary_wraps)]
314    fn return_wrapped_plan(plan: LogicalPlan) -> Result<LogicalPlan> { Ok(plan) }
315
316    #[cfg(not(feature = "federation"))]
317    fn return_wrapped_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
318        use datafusion::logical_expr::Extension;
319
320        use crate::context::plan_node::ClickHouseFunctionNode;
321
322        Ok(LogicalPlan::Extension(Extension {
323            node: Arc::new(ClickHouseFunctionNode::try_new(plan)?),
324        }))
325    }
326
327    // Collect pushed functions
328    let (func_fields, func_cols) = functions_to_field_and_cols(functions, visitor)?;
329
330    // Remove catalog from table scan
331    let plan = plan.transform_up_with_subqueries(strip_table_scan_catalog).unwrap().data;
332    // Recompute schema
333    let plan = plan.recompute_schema()?;
334
335    match plan {
336        LogicalPlan::SubqueryAlias(alias) => {
337            // Ensure the alias isn't carried down.
338            let func_cols = push_exprs_below_subquery(func_cols, &alias);
339
340            let input = Arc::unwrap_or_clone(alias.input);
341            let new_input = wrap_plan_in_projection(input, func_fields, func_cols)?.into();
342            let new_alias = SubqueryAlias::try_new(new_input, alias.alias)?;
343            return_wrapped_plan(LogicalPlan::SubqueryAlias(new_alias))
344        }
345        _ => return_wrapped_plan(wrap_plan_in_projection(plan, func_fields, func_cols)?),
346    }
347}
348
/// A schema field paired with its optional table qualifier.
type QualifiedField = (Option<TableReference>, Arc<Field>);
350
351/// Extract inner functions from `ClickHouse` `UDF` wrappers
352fn functions_to_field_and_cols(
353    functions: Vec<Expr>,
354    visitor: &SourceLineageVistor,
355) -> Result<(Vec<QualifiedField>, Vec<Expr>)> {
356    let mut fields = Vec::new();
357    let mut columns = Vec::new();
358    for function in functions {
359        let is_nullable = visitor.resolve_nullable(&function);
360        let alias = function.schema_name().to_string();
361        let (inner_function, data_type) = extract_function_and_return_type(function)?;
362        fields.push((None, Arc::new(Field::new(&alias, data_type, is_nullable))));
363        columns.push(inner_function.alias(alias));
364    }
365    Ok((fields, columns))
366}
367
368fn wrap_plan_in_projection(
369    plan: LogicalPlan,
370    func_fields: Vec<QualifiedField>,
371    func_cols: Vec<Expr>,
372) -> Result<LogicalPlan> {
373    // No functions, no modification needed
374    if func_cols.is_empty() {
375        return Ok(plan);
376    }
377
378    let metadata = plan.schema().metadata().clone();
379    let mut fields =
380        plan.schema().iter().map(|(q, f)| (q.cloned(), Arc::clone(f))).collect::<Vec<_>>();
381    fields.extend(func_fields);
382
383    // Create new schema accounting for pushed functions
384    let new_schema = DFSchema::new_with_metadata(fields, metadata)?;
385
386    // Wrap in a projection only if not already a projection
387    let new_plan = if let LogicalPlan::Projection(mut projection) = plan {
388        projection.expr.extend(func_cols);
389        Projection::try_new_with_schema(projection.expr, projection.input, new_schema.into())?
390    } else {
391        let mut exprs = plan.schema().columns().into_iter().map(Expr::Column).collect::<Vec<_>>();
392        exprs.extend(func_cols);
393        Projection::try_new_with_schema(exprs, plan.into(), new_schema.into())?
394    };
395
396    Ok(LogicalPlan::Projection(new_plan))
397}
398
399/// Helper to both collect clickhouse functions into state as well as transform the original
400/// expression into an alias
401fn collect_and_transform_exprs(
402    exprs: Vec<Expr>,
403    visitor: &SourceLineageVistor,
404    state: &mut PushdownState,
405) -> Result<Vec<Expr>> {
406    exprs
407        .into_iter()
408        .map(|expr| collect_and_transform_function(expr, visitor, state).map(|t| t.data))
409        .collect::<Result<Vec<_>>>()
410}
411
412/// Transform an expression possibly containing a `ClickHouse` function.
413///
414/// # Errors
415/// - Returns an error if the
416fn collect_and_transform_function(
417    expr: Expr,
418    visitor: &SourceLineageVistor,
419    state: &mut PushdownState,
420) -> Result<Transformed<Expr>> {
421    expr.transform_down(|e| {
422        if is_clickhouse_function(&e) {
423            let func_resolved = visitor.resolve_expr(&e);
424            let alias = e.schema_name().to_string();
425
426            // If the only source of the function is a Scalar, unwrap and leave as is.
427            if matches!(
428                func_resolved,
429                ResolvedSource::Scalar(_) | ResolvedSource::Scalars(_) | ResolvedSource::Unknown
430            ) {
431                let Expr::ScalarFunction(ScalarFunction { mut args, .. }) = e else {
432                    unreachable!(); // Guaranteed by `is_clickhouse_function` check
433                };
434                if args.is_empty() {
435                    return plan_err!("`clickhouse` function requires an arg, none found: {alias}");
436                }
437                return Ok(Transformed::new(args.remove(0), true, TreeNodeRecursion::Jump));
438            }
439
440            state.function_sources =
441                std::mem::take(&mut state.function_sources).merge(func_resolved.clone());
442            let current_funcs = state.functions.entry(func_resolved).or_default();
443            // Store the original expression for pushdown
444            if !current_funcs.contains(&e) {
445                current_funcs.push(e);
446            }
447
448            Ok(Transformed::new(
449                Expr::Column(Column::from_name(alias)),
450                true,
451                TreeNodeRecursion::Jump,
452            ))
453        } else {
454            Ok(Transformed::no(e))
455        }
456    })
457}
458
459/// Strip table scan of catalog name before passing to extension node
460fn strip_table_scan_catalog(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
461    plan.transform_up_with_subqueries(|node| {
462        if let LogicalPlan::TableScan(mut scan) = node {
463            if let TableReference::Full { schema, table, .. } = scan.table_name {
464                scan.table_name = TableReference::Partial { schema, table };
465                return Ok(Transformed::yes(LogicalPlan::TableScan(scan)));
466            }
467
468            Ok(Transformed::no(LogicalPlan::TableScan(scan)))
469        } else {
470            Ok(Transformed::no(node))
471        }
472    })
473}
474
475/// Analyze state and its functions against current plan for any possible semantic violations
476fn check_state_for_violations(
477    state: &PushdownState,
478    plan: &LogicalPlan,
479    visitor: &SourceLineageVistor,
480) -> HashSet<Column> {
481    let mut function_violations = HashSet::new();
482    if !state.functions.is_empty() {
483        for func in state.functions.values().flatten() {
484            function_violations
485                .extend(violates_pushdown_semantics(func, plan, visitor).into_iter().cloned());
486        }
487    }
488    function_violations
489}
490
491/// Analyze current plan for any possible semantic violations for found functions
492fn check_plan_for_violations(plan: &LogicalPlan, visitor: &SourceLineageVistor) -> HashSet<Column> {
493    violates_plan_semantics(plan, visitor).into_iter().cloned().collect()
494}
495
496/// Check the provided function expr's column references to ensure the plan being analyzed would
497/// violate the plan's semantics if the function needed to be pushed below it.
498///
499/// NOTE: Important! The function must NOT be from the plan passed in but from a plan higher up,
500/// otherwise use `violates_plan_semantics` below.
501fn violates_pushdown_semantics<'a>(
502    function: &Expr,
503    plan: &'a LogicalPlan,
504    visitor: &SourceLineageVistor,
505) -> HashSet<&'a Column> {
506    // Resolve the function's column references to match against plan expressions
507    //
508    // NOTE: The actual Column IDs are used to provide more granular information.
509    let function_column_ids = function
510        .column_refs()
511        .iter()
512        .flat_map(|col| visitor.collect_column_ids(col))
513        .collect::<HashSet<_>>();
514
515    // If column ids are empty, nothing to analyze
516    if function_column_ids.is_empty() {
517        return HashSet::new();
518    }
519
520    match plan {
521        LogicalPlan::Aggregate(agg) => {
522            // Ensure aggr exprs do not change the semantic meaning of function column references
523            if let Some(related_cols) = check_function_against_exprs(
524                function,
525                &agg.aggr_expr,
526                &function_column_ids,
527                visitor,
528                true,
529            ) {
530                return related_cols;
531            }
532
533            // Ensure function column references are exposed in group exprs
534            if let Some(related_cols) = check_function_against_exprs(
535                function,
536                &agg.group_expr,
537                &function_column_ids,
538                visitor,
539                false,
540            ) {
541                return related_cols;
542            }
543        }
544        LogicalPlan::Window(window) => {
545            if let Some(related_cols) = check_function_against_exprs(
546                function,
547                &window.window_expr,
548                &function_column_ids,
549                visitor,
550                true,
551            ) {
552                return related_cols;
553            }
554        }
555        LogicalPlan::Subquery(query) => {
556            if let Some(related_cols) = check_function_against_exprs(
557                function,
558                &query.outer_ref_columns,
559                &function_column_ids,
560                visitor,
561                true,
562            ) {
563                return related_cols;
564            }
565        }
566        _ => {}
567    }
568    HashSet::new()
569}
570
/// Check the `ClickHouse` functions used within the plan's own expressions to ensure the plan
/// does not disallow pushing them down further.
///
/// Only `Aggregate` plans are inspected: a `ClickHouse` function found in an aggregate
/// expression must also be found, as an equal expression, in one of the group expressions.
/// The column references of the first aggregate expression failing that check are returned.
/// An empty set is returned for every other plan or when no violation is found.
///
/// NOTE(review): presumably `use_clickhouse_function_context` invokes the callback for the
/// `ClickHouse` function(s) it finds inside the given expression - confirm against its
/// definition.
fn violates_plan_semantics<'a>(
    plan: &'a LogicalPlan,
    _visitor: &SourceLineageVistor,
) -> HashSet<&'a Column> {
    if let LogicalPlan::Aggregate(agg) = plan {
        // Finally, ensure any clickhouse functions WITHIN the aggregate plan do not violate
        // grouping semantics
        for expr in &agg.aggr_expr {
            let mut violations = None;
            // The result is dropped on purpose: the error below only serves to stop the walk,
            // the outcome is carried by `violations`.
            drop(use_clickhouse_function_context(expr, |agg_func| {
                // A clickhouse function was found, ensure it exists in the group by
                let found = agg.group_expr.iter().any(|group_e| {
                    let mut found = false;
                    use_clickhouse_function_context(group_e, |group_func| {
                        found |= group_func == agg_func;
                        Ok(TreeNodeRecursion::Stop)
                    })
                    .unwrap();
                    found
                });

                if !found {
                    violations = Some(expr.column_refs());
                    // Exit early
                    return plan_err!("Aggregate functions must be in group by");
                }
                Ok(TreeNodeRecursion::Stop)
            }));

            if let Some(violations) = violations {
                // Return early as invariant is already violated
                return violations;
            }
        }
    }

    HashSet::new()
}
613
614/// Iterate over each plan expression provided and determine whether the function's column IDs must
615/// be in the expr's column references or must not be.
616fn check_function_against_exprs<'a>(
617    func: &Expr,
618    exprs: &'a [Expr],
619    func_column_ids: &HashSet<ColumnId>,
620    visitor: &SourceLineageVistor,
621    disjoint_required: bool,
622) -> Option<HashSet<&'a Column>> {
623    for expr in exprs {
624        // Little safeguard to allow replacing entire functions in the plan.
625        if expr == func {
626            continue;
627        }
628        let col_refs = expr.column_refs();
629        let expr_column_ids =
630            col_refs.iter().flat_map(|col| visitor.collect_column_ids(col)).collect::<HashSet<_>>();
631
632        // Not sure when this would occur, scalar maybe?
633        if expr_column_ids.is_empty() && !disjoint_required {
634            continue;
635        }
636        // Check if the intersection must be empty or must not be empty
637        if expr_column_ids.is_disjoint(func_column_ids) != disjoint_required {
638            // Return early as invariant is already violated
639            return Some(col_refs);
640        }
641    }
642    None
643}
644
645/// Helper to emit an error in the situation where a pushdown would violate query semantics
646fn semantic_err(
647    name: impl std::fmt::Display,
648    msg: &str,
649    violations: &HashSet<Column>,
650) -> Result<()> {
651    if !violations.is_empty() {
652        let violations =
653            violations.iter().map(Column::quoted_flat_name).collect::<Vec<_>>().join(", ");
654        return plan_err!("[{name}] - {msg} Violations: {violations}");
655    }
656    Ok(())
657}
658
659#[cfg(all(test, feature = "test-utils"))]
660mod tests {
661    use std::collections::HashSet;
662    use std::sync::Arc;
663
664    use datafusion::arrow::datatypes::{DataType, Field, Schema};
665    use datafusion::catalog::TableProvider;
666    use datafusion::common::{Column, Result};
667    use datafusion::datasource::empty::EmptyTable;
668    use datafusion::datasource::provider_as_source;
669    use datafusion::functions_aggregate::count::count;
670    use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder, table_scan};
671    use datafusion::prelude::*;
672    use datafusion::scalar::ScalarValue;
673    use datafusion::sql::TableReference;
674
675    use super::super::source_visitor::SourceLineageVistor;
676    use super::*;
677    use crate::analyzer::source_context::SourceContext;
678    use crate::analyzer::source_visitor::ColumnLineage;
679    use crate::udfs::clickhouse::clickhouse_udf;
680    #[cfg(feature = "mocks")]
681    use crate::{
682        ClickHouseConnectionPool, ClickHouseTableProvider,
683        analyzer::function_pushdown::ClickHouseFunctionPushdown,
684        plan_node::CLICKHOUSE_FUNCTION_NODE_NAME,
685    };
686
687    fn create_table_scan(table: TableReference, provider: Arc<dyn TableProvider>) -> LogicalPlan {
688        LogicalPlanBuilder::scan(table, provider_as_source(provider), None)
689            .unwrap()
690            .build()
691            .unwrap()
692    }
693
694    #[test]
695    fn test_functions_to_field_and_cols_empty() -> Result<()> {
696        let visitor = SourceLineageVistor::default();
697        let functions = Vec::new();
698        let result = functions_to_field_and_cols(functions, &visitor)?;
699        assert!(result.0.is_empty());
700        assert!(result.1.is_empty());
701        Ok(())
702    }
703
704    #[test]
705    fn test_functions_to_field_and_cols_single_function() -> Result<()> {
706        let visitor = SourceLineageVistor::default();
707        let functions = vec![Expr::ScalarFunction(ScalarFunction {
708            func: Arc::new(clickhouse_udf()),
709            args: vec![lit("count()"), lit("Int64")],
710        })];
711
712        let (fields, funcs) = functions_to_field_and_cols(functions, &visitor)?;
713        assert_eq!(fields.len(), 1);
714        assert_eq!(funcs.len(), 1);
715
716        let (field_ref, field) = &fields[0];
717        assert!(field_ref.is_none());
718        assert_eq!(field.data_type(), &DataType::Int64);
719        assert!(!field.is_nullable());
720
721        Ok(())
722    }
723
724    #[test]
725    fn test_functions_to_field_and_cols_multiple_functions() -> Result<()> {
726        let visitor = SourceLineageVistor::default();
727        let functions = vec![
728            Expr::ScalarFunction(ScalarFunction {
729                func: Arc::new(clickhouse_udf()),
730                args: vec![lit("count()"), lit("Int64")],
731            }),
732            Expr::ScalarFunction(ScalarFunction {
733                func: Arc::new(clickhouse_udf()),
734                args: vec![lit("sum(x)"), lit("Float64")],
735            }),
736        ];
737
738        let (fields, funcs) = functions_to_field_and_cols(functions, &visitor)?;
739        assert_eq!(fields.len(), 2);
740        assert_eq!(funcs.len(), 2);
741
742        // Check data types
743        assert_eq!(fields[0].1.data_type(), &DataType::Int64);
744        assert_eq!(fields[1].1.data_type(), &DataType::Float64);
745
746        Ok(())
747    }
748
749    #[test]
750    fn test_wrap_plan_in_projection_no_functions() -> Result<()> {
751        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
752        let provider = Arc::new(EmptyTable::new(schema));
753        let plan = create_table_scan(TableReference::bare("test_table"), provider);
754
755        let result = wrap_plan_in_projection(plan.clone(), vec![], vec![])?;
756
757        // Should return original plan unchanged when no functions
758        match (&plan, &result) {
759            (LogicalPlan::TableScan(original), LogicalPlan::TableScan(result)) => {
760                assert_eq!(original.table_name, result.table_name);
761            }
762            _ => panic!("Expected TableScan plans"),
763        }
764
765        Ok(())
766    }
767
768    #[test]
769    fn test_wrap_plan_in_projection_with_functions() -> Result<()> {
770        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
771        let provider = Arc::new(EmptyTable::new(schema));
772        let plan = create_table_scan(TableReference::bare("test_table"), provider);
773
774        let func_fields =
775            vec![(None, Arc::new(Field::new("func_result", DataType::Float64, true)))];
776        let func_cols = vec![lit("test_function").alias("func_result")];
777
778        let result = wrap_plan_in_projection(plan, func_fields, func_cols)?;
779
780        // Should be wrapped in a projection
781        match result {
782            LogicalPlan::Projection(projection) => {
783                assert_eq!(projection.expr.len(), 2); // original column + function
784                assert_eq!(projection.schema.fields().len(), 2);
785            }
786            _ => panic!("Expected Projection plan"),
787        }
788
789        Ok(())
790    }
791
792    #[test]
793    fn test_strip_table_scan_catalog_no_catalog() -> Result<()> {
794        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
795        let provider = Arc::new(EmptyTable::new(schema));
796        let plan = create_table_scan(TableReference::bare("test_table"), provider);
797
798        let result = strip_table_scan_catalog(plan)?;
799
800        // Should not be transformed since no catalog to strip
801        assert!(!result.transformed);
802
803        Ok(())
804    }
805
806    #[test]
807    fn test_strip_table_scan_catalog_with_catalog() -> Result<()> {
808        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
809        let provider = Arc::new(EmptyTable::new(schema));
810        let table_ref = TableReference::Full {
811            catalog: "catalog".into(),
812            schema:  "schema".into(),
813            table:   "table".into(),
814        };
815        let plan = create_table_scan(table_ref, provider);
816
817        let result = strip_table_scan_catalog(plan)?;
818
819        // Should be transformed to remove catalog
820        assert!(result.transformed);
821        if let LogicalPlan::TableScan(scan) = result.data {
822            match scan.table_name {
823                TableReference::Partial { schema, table } => {
824                    assert_eq!(schema.as_ref(), "schema");
825                    assert_eq!(table.as_ref(), "table");
826                }
827                _ => panic!("Expected Partial table reference after catalog stripping"),
828            }
829        } else {
830            panic!("Expected TableScan after transformation");
831        }
832
833        Ok(())
834    }
835
836    #[test]
837    fn test_check_state_for_violations_empty_state() {
838        let state = PushdownState::default();
839        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
840        let provider = Arc::new(EmptyTable::new(schema));
841        let plan = create_table_scan(TableReference::bare("test_table"), provider);
842        let visitor = SourceLineageVistor::new();
843
844        let violations = check_state_for_violations(&state, &plan, &visitor);
845        assert!(violations.is_empty());
846    }
847
848    #[test]
849    fn test_check_plan_for_violations_non_aggregate() {
850        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
851        let provider = Arc::new(EmptyTable::new(schema));
852        let plan = create_table_scan(TableReference::bare("test_table"), provider);
853
854        let visitor = SourceLineageVistor::new();
855        let violations = check_plan_for_violations(&plan, &visitor);
856        assert!(violations.is_empty());
857    }
858
859    #[test]
860    fn test_semantic_err_no_violations() {
861        let violations = HashSet::new();
862        let result = semantic_err("TestPlan", "Test message", &violations);
863        assert!(result.is_ok());
864    }
865
866    #[test]
867    fn test_semantic_err_with_violations() {
868        let mut violations = HashSet::new();
869        let _ = violations.insert(Column::new_unqualified("test_col"));
870
871        let result = semantic_err("TestPlan", "Test message", &violations);
872        assert!(result.is_err());
873        let err_msg = result.unwrap_err().to_string();
874        assert!(err_msg.contains("[TestPlan]"));
875        assert!(err_msg.contains("Test message"));
876        assert!(err_msg.contains("test_col"));
877    }
878
879    #[test]
880    fn test_collect_and_transform_exprs_empty() -> Result<()> {
881        let exprs = Vec::new();
882        let visitor = SourceLineageVistor::new();
883        let mut state = PushdownState::default();
884
885        let result = collect_and_transform_exprs(exprs, &visitor, &mut state)?;
886        assert!(result.is_empty());
887        Ok(())
888    }
889
890    #[test]
891    fn test_collect_and_transform_exprs_no_clickhouse_functions() -> Result<()> {
892        let exprs = vec![col("test_col"), lit(42)];
893        let visitor = SourceLineageVistor::new();
894        let mut state = PushdownState::default();
895
896        let result = collect_and_transform_exprs(exprs.clone(), &visitor, &mut state)?;
897        assert_eq!(result.len(), 2);
898        // Original expressions should be preserved when no ClickHouse functions
899        assert_eq!(result[0], col("test_col"));
900        assert_eq!(result[1], lit(42));
901        Ok(())
902    }
903
904    #[test]
905    fn test_collect_and_transform_function_non_clickhouse() -> Result<()> {
906        let expr = col("test_col");
907        let visitor = SourceLineageVistor::new();
908        let mut state = PushdownState::default();
909
910        let result = collect_and_transform_function(expr.clone(), &visitor, &mut state)?;
911        assert!(!result.transformed);
912        assert_eq!(result.data, expr);
913        Ok(())
914    }
915
916    #[test]
917    fn test_collect_and_transform_function_with_clickhouse() -> Result<()> {
918        let table = TableReference::bare("test");
919        let test_col = Column::new(Some(table.clone()), "test_col");
920
921        let schema = Schema::new(vec![Field::new("test_col", DataType::Int32, false)]);
922
923        let clickhouse_expr = Expr::ScalarFunction(ScalarFunction {
924            func: Arc::new(clickhouse_udf()),
925            args: vec![count(Expr::Column(test_col.clone())), lit("Int32")],
926        });
927
928        let mut visitor = SourceLineageVistor::new();
929
930        // Seed the column from a dummy plan
931        let dummy_plan =
932            table_scan(Some(table.clone()), &schema, None)?.select(vec![0])?.build()?;
933        let _ = dummy_plan.visit(&mut visitor)?;
934        let mut state = PushdownState::default();
935        let result = collect_and_transform_function(clickhouse_expr, &visitor, &mut state)?;
936        assert!(result.transformed);
937
938        // Should be replaced with a column alias
939        match result.data {
940            Expr::Column(_) => {} // Expected
941            _ => panic!("Expected Column expression after ClickHouse function transformation"),
942        }
943        // State should have functions added
944        assert!(!state.functions.is_empty());
945        Ok(())
946    }
947
948    #[test]
949    fn test_column_id_resolution() -> Result<()> {
950        let table1 = TableReference::bare("test1");
951        let table2 = TableReference::bare("test2");
952        let test_col1 = Column::new(Some(table1.clone()), "test_col1");
953        let test_col_alt1 = Column::new(Some(table1.clone()), "test_col_alt1");
954        let test_col2 = Column::new(Some(table2.clone()), "test_col2");
955
956        let schema1 = Schema::new(vec![
957            Field::new("test_col1", DataType::Int32, false),
958            Field::new("test_col_alt1", DataType::Int32, false),
959        ]);
960
961        let schema2 = Schema::new(vec![Field::new("test_col2", DataType::Int32, false)]);
962
963        let mut visitor = SourceLineageVistor::new();
964
965        // Test Columns Ids
966        let clickhouse_expr_simple = Expr::ScalarFunction(ScalarFunction {
967            func: Arc::new(clickhouse_udf()),
968            args: vec![
969                Expr::Column(test_col1.clone()) + Expr::Column(test_col_alt1.clone()),
970                lit("Int64"),
971            ],
972        })
973        .alias("dummy_udf");
974        let clickhouse_expr_comp = Expr::ScalarFunction(ScalarFunction {
975            func: Arc::new(clickhouse_udf()),
976            args: vec![
977                Expr::Column(test_col1.clone()) + Expr::Column(test_col2.clone()) + lit(2),
978                lit("Int64"),
979            ],
980        })
981        .alias("dummy_udf_comp");
982        let simple_col = Column::from_name("dummy_udf");
983        let comp_col = Column::from_name("dummy_udf_comp");
984        let right_plan =
985            table_scan(Some(table2.clone()), &schema2, None)?.select(vec![0])?.build()?;
986
987        let dummy_plan = table_scan(Some(table1.clone()), &schema1, None)?
988            .select(vec![0, 1])?
989            .project(vec![Expr::Column(test_col1.clone()), clickhouse_expr_simple])?
990            .filter(Expr::Column(test_col1.clone()).gt(lit(0)))?
991            .join_on(right_plan, JoinType::Inner, vec![
992                col("table1.test_col1").eq(col("table2.test_col2")),
993            ])?
994            .project(vec![
995                Expr::Column(simple_col.clone()),
996                clickhouse_expr_comp,
997                lit("hello").alias("scalar_col"),
998            ])?
999            .build()?;
1000
1001        let _ = dummy_plan.visit(&mut visitor)?;
1002
1003        // Simple
1004        let lineage = visitor.column_lineage.get(&simple_col);
1005        let resolved = visitor.resolve_column(&simple_col);
1006        let col_ids = visitor.collect_column_ids(&simple_col);
1007
1008        let Some(ColumnLineage::Compound(ids)) = lineage else {
1009            panic!("Derived columns of clickhouse functions should be stored as `Compound`");
1010        };
1011
1012        assert_eq!(col_ids.len(), 2, "Expected 2 columns in when resolving context");
1013        assert_eq!(ids.len(), 3, "Expected 3 columns in Compound context");
1014        assert!(col_ids.is_subset(ids), "Expected collected columns to be contained in context");
1015
1016        let ResolvedSource::Compound(sources) = resolved else {
1017            panic!("Expected Compound source for simple column");
1018        };
1019
1020        let mut resolved_sources = sources.to_vec();
1021        let scalar_source = resolved_sources.pop().unwrap();
1022        let table_source = resolved_sources.pop().unwrap();
1023
1024        assert_eq!(
1025            scalar_source.as_ref(),
1026            &SourceContext::Scalar(ScalarValue::Utf8(Some("Int64".into())))
1027        );
1028        assert_eq!(table_source.as_ref(), &SourceContext::Table(table1.clone()));
1029
1030        let nullable = visitor.resolve_nullable(&Expr::Column(simple_col.clone()));
1031        assert!(!nullable);
1032
1033        // Compound
1034        let lineage = visitor.column_lineage.get(&comp_col);
1035        let col_ids = visitor.collect_column_ids(&comp_col);
1036
1037        let Some(ColumnLineage::Compound(cols)) = lineage else {
1038            panic!("Derived columns should be stored as `Compound`");
1039        };
1040
1041        // NOTE: It does NOT return the ScalarValue!
1042        assert_eq!(cols.len(), 4, "Expected 4 columns in Compound context, inc Scalars");
1043        assert_eq!(col_ids.len(), 2, "Expected 2 columns in collected columns, exc Scalars");
1044        assert!(col_ids.is_subset(cols), "Expected collected columns to be a subset of context");
1045
1046        Ok(())
1047    }
1048
1049    // TODO: Add tests for Unnest plan after upstream bug is addressed
1050    //
1051    // Ref: https://github.com/datafusion-contrib/datafusion-federation/pull/135
1052
1053    // ----
1054    // All of the following tests require the "mocks" feature to be enabled
1055    // ----
1056
1057    /// Helper to check if a plan is a `ClickHouse` Extension node
1058    #[cfg(feature = "mocks")]
1059    fn is_clickhouse_extension(plan: &LogicalPlan) -> bool {
1060        if let LogicalPlan::Extension(ext) = plan {
1061            ext.node.name() == CLICKHOUSE_FUNCTION_NODE_NAME
1062        } else {
1063            false
1064        }
1065    }
1066
1067    /// Helper to find all `ClickHouse` Extension nodes in the tree with their positions
1068    #[cfg(feature = "mocks")]
1069    fn find_wrapped_plans(plan: &LogicalPlan) -> Vec<String> {
1070        fn traverse(plan: &LogicalPlan, wrapped_plans: &mut Vec<String>, path: &str) {
1071            if is_clickhouse_extension(plan) {
1072                wrapped_plans.push(format!("{path}: {plan}"));
1073            }
1074            for (i, input) in plan.inputs().iter().enumerate() {
1075                traverse(input, wrapped_plans, &format!("{path}/input[{i}]"));
1076            }
1077        }
1078        let mut wrapped_plans = Vec::new();
1079        traverse(plan, &mut wrapped_plans, "root");
1080        wrapped_plans
1081    }
1082
1083    #[cfg(all(feature = "federation", feature = "mocks"))]
1084    fn compare_plan_display(plan: &LogicalPlan, expected: impl Into<String>) {
1085        let mut plan_display = plan.display_indent().to_string();
1086        plan_display.retain(|c| !c.is_whitespace());
1087        let mut expected = expected.into();
1088        expected.retain(|c| !c.is_whitespace());
1089        assert_eq!(plan_display, expected, "Expected equal plans");
1090    }
1091
1092    /// Create a `SessionContext` with registered tables and analyzer for SQL testing
1093    #[cfg(feature = "mocks")]
1094    fn create_test_context() -> Result<SessionContext> {
1095        let ctx = SessionContext::new();
1096        // Register the ClickHouse pushdown UDF
1097        ctx.register_udf(clickhouse_udf());
1098        // Register the ClickHouse function pushdown analyzer
1099        ctx.add_analyzer_rule(Arc::new(ClickHouseFunctionPushdown));
1100
1101        let schema1 = Arc::new(Schema::new(vec![
1102            Field::new("col1", DataType::Int32, false),
1103            Field::new("col2", DataType::Int32, false),
1104            Field::new("col3", DataType::Utf8, false),
1105        ]));
1106        let schema2 = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1107
1108        let pool = Arc::new(ClickHouseConnectionPool::new("pool".to_string(), ()));
1109        let table1 = ClickHouseTableProvider::new_with_schema_unchecked(
1110            Arc::clone(&pool),
1111            "table1".into(),
1112            Arc::clone(&schema1),
1113        );
1114        let table2 = ClickHouseTableProvider::new_with_schema_unchecked(
1115            Arc::clone(&pool),
1116            "table1".into(),
1117            schema2,
1118        );
1119
1120        // Register table1 (col1: Int32, col2: Int32, col3: Utf8)
1121        drop(ctx.register_table("table1", Arc::new(table1))?);
1122
1123        // Register table2 (id: Int32)
1124        drop(ctx.register_table("table2", Arc::new(table2))?);
1125
1126        // Register table3 (outside grouping) (col1: Int32, col2: Int32, col3: Utf8)
1127        let table3 = Arc::new(EmptyTable::new(schema1));
1128        drop(ctx.register_table("table3", table3)?);
1129
1130        Ok(ctx)
1131    }
1132
1133    /// Run a query and return the analyzed plan
1134    #[cfg(feature = "mocks")]
1135    async fn run_query(sql: &str) -> Result<LogicalPlan> {
1136        let ctx = create_test_context()?;
1137        let analyzed_plan = ctx.sql(sql).await?.into_optimized_plan()?; // Analyzer runs automatically
1138        SQLOptions::default().verify_plan(&analyzed_plan)?;
1139        Ok(analyzed_plan)
1140    }
1141
1142    #[cfg(feature = "mocks")]
1143    #[tokio::test]
1144    async fn test_simple_projection_with_clickhouse_function() -> Result<()> {
1145        // Test: SELECT clickhouse(exp(col1 + col2), 'Float64'), col2 * 2, UPPER(col3) FROM table1
1146        // Expected: Entire plan wrapped because all functions and columns from same table
1147        let sql =
1148            "SELECT clickhouse(exp(col1 + col2), 'Float64'), col2 * 2, UPPER(col3) FROM table1";
1149        let analyzed_plan = run_query(sql).await?;
1150        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1151
1152        #[cfg(feature = "federation")]
1153        {
1154            let expected_plan = r#"
1155            Projection: clickhouse(exp(CAST(table1.col1 + table1.col2 AS Float64)), Utf8("Float64")), CAST(table1.col2 AS Int64) * Int64(2), upper(table1.col3)
1156              TableScan: table1 projection=[col1, col2, col3]
1157            "#;
1158            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
1159            compare_plan_display(&analyzed_plan, expected_plan);
1160        }
1161        #[cfg(not(feature = "federation"))]
1162        {
1163            assert_eq!(wrapped_plans.len(), 1, "Expected exactly one wrapped plan");
1164            assert!(wrapped_plans[0].starts_with("root:"), "Expected wrapping at root level");
1165        }
1166        Ok(())
1167    }
1168
1169    #[cfg(feature = "mocks")]
1170    #[tokio::test]
1171    async fn test_filter_with_clickhouse_function() -> Result<()> {
1172        // Test: SELECT col2, col3 FROM table1 WHERE clickhouse(exp(col1), 'Float64') > 10
1173        // Expected: Filter wrapped because it contains ClickHouse function, outer projection has no
1174        // functions
1175        let sql = "SELECT col2, col3 FROM table1 WHERE clickhouse(exp(col1), 'Float64') > 10";
1176        let analyzed_plan = run_query(sql).await?;
1177        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1178
1179        #[cfg(feature = "federation")]
1180        {
1181            let expected_plan = r#"
1182               TableScan: table1 projection=[col2, col3], full_filters=[clickhouse(exp(CAST(table1.col1 AS Float64)), Utf8("Float64")) > Float64(10)]
1183            "#;
1184            assert!(wrapped_plans.is_empty(), "No wrapping expected");
1185            compare_plan_display(&analyzed_plan, expected_plan);
1186        }
1187        #[cfg(not(feature = "federation"))]
1188        {
1189            // Should have exactly one wrapped plan at the filter level (since outer projection has
1190            // no functions)
1191            assert_eq!(wrapped_plans.len(), 1, "Expected exactly one wrapped plan");
1192            assert!(wrapped_plans[0].contains("root"), "Expected wrapping at filter level, root");
1193        }
1194        Ok(())
1195    }
1196
1197    #[cfg(feature = "mocks")]
1198    #[tokio::test]
1199    async fn test_aggregate_blocks_pushdown() -> Result<()> {
1200        // Test: SELECT col2, COUNT(*) FROM table1 WHERE clickhouse(exp(col1), 'Float64') > 5 GROUP
1201        // BY col2 Expected: Function should be wrapped at a level above the aggregate
1202        let sql = "SELECT col2, COUNT(*) FROM table1 WHERE clickhouse(exp(col1), 'Float64') > 5 \
1203                   GROUP BY col2";
1204        let analyzed_plan = run_query(sql).await?;
1205        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1206
1207        #[cfg(feature = "federation")]
1208        {
1209            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
1210            assert!(
1211                analyzed_plan.display().to_string().to_lowercase().starts_with("projection"),
1212                "Expected projection"
1213            );
1214        }
1215
1216        #[cfg(not(feature = "federation"))]
1217        {
1218            // Should have exactly one wrapped plan, and it should be at the aggregate input level
1219            assert_eq!(wrapped_plans.len(), 1, "Expected exactly one wrapped plan");
1220            // The wrapped plan should be at the input to the aggregate (aggregate blocks pushdown)
1221            assert!(
1222                wrapped_plans.iter().any(|w| w.contains("root")),
1223                "Expected function to be wrapped at aggregate input level due to blocking"
1224            );
1225        }
1226        Ok(())
1227    }
1228
1229    #[cfg(feature = "mocks")]
1230    #[tokio::test]
1231    async fn test_multiple_clickhouse_functions_same_table() -> Result<()> {
1232        // Test: SELECT clickhouse(exp(col1), 'Float64'), clickhouse(exp(col2), 'Float64') FROM
1233        // table1 Expected: Both functions use same table, should be wrapped together at
1234        // root level
1235        let sql = "SELECT clickhouse(exp(col1), 'Float64') AS f1, clickhouse(exp(col2), \
1236                   'Float64') AS f2 FROM table1";
1237        let analyzed_plan = run_query(sql).await?;
1238        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1239
1240        #[cfg(feature = "federation")]
1241        {
1242            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
1243            assert!(
1244                analyzed_plan.display().to_string().to_lowercase().starts_with("projection"),
1245                "Expected projection"
1246            );
1247        }
1248
1249        #[cfg(not(feature = "federation"))]
1250        {
1251            // Should have exactly one wrapped plan containing both functions
1252            assert_eq!(
1253                wrapped_plans.len(),
1254                1,
1255                "Expected exactly one wrapped plan for both functions"
1256            );
1257            assert!(
1258                wrapped_plans[0].starts_with("root:"),
1259                "Expected both functions wrapped together at root level"
1260            );
1261        }
1262        Ok(())
1263    }
1264
1265    #[cfg(feature = "mocks")]
1266    #[tokio::test]
1267    async fn test_no_functions_no_wrapping() -> Result<()> {
1268        // Test: SELECT col1, col2 FROM table1
1269        // Expected: No exp functions, so no wrapping should occur
1270        let sql = "SELECT col1, col2 FROM table1";
1271        let analyzed_plan = run_query(sql).await?;
1272        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1273        // Should have no wrapped plans
1274        assert_eq!(wrapped_plans.len(), 0, "Expected no wrapped plans when no functions present");
1275        Ok(())
1276    }
1277
1278    #[cfg(feature = "mocks")]
1279    #[tokio::test]
1280    async fn test_wrapped_disjoint_tables() -> Result<()> {
1281        // Test: Verify the disjoint check wraps the entire tree, since `table1` and `table2` have
1282        // the same context during testing.
1283        let sql = "SELECT t1.col1, clickhouse(exp(t2.id), 'Float64') FROM (SELECT col1 FROM \
1284                   table1) t1 JOIN (SELECT id from table2) t2 ON t1.col1 = t2.id";
1285        let analyzed_plan = run_query(sql).await?;
1286        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1287
1288        #[cfg(feature = "federation")]
1289        {
1290            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
1291            assert!(
1292                analyzed_plan.display().to_string().to_lowercase().starts_with("projection"),
1293                "Expected projection"
1294            );
1295        }
1296
1297        #[cfg(not(feature = "federation"))]
1298        {
1299            assert_eq!(wrapped_plans.len(), 1, "Expected function wrapped entire plan");
1300            assert!(
1301                wrapped_plans[0].contains("root"),
1302                "Expected function wrapped on right side of JOIN"
1303            );
1304        }
1305        Ok(())
1306    }
1307
1308    #[cfg(feature = "mocks")]
1309    #[tokio::test]
1310    async fn test_disjoint_tables() -> Result<()> {
1311        // Test: Verify the disjoint check wraps plan on join side only.
1312        //       The test tables "test1" and "test2" will use the same context, "test" will not.
1313        let sql = "SELECT t3.col1, clickhouse(exp(t2.id), 'Float64')
1314            FROM (SELECT col1 FROM table3) t3
1315            JOIN (SELECT id from table2) t2 ON t3.col1 = t2.id
1316        ";
1317
1318        let analyzed_plan = run_query(sql).await?;
1319        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1320
1321        // TODO: Update with expected plan display
1322        #[cfg(feature = "federation")]
1323        {
1324            let expected_plan = r#"
1325            Projection: t3.col1, clickhouse(exp(t2.id),Utf8("Float64"))
1326              Inner Join: t3.col1 = t2.id
1327                SubqueryAlias: t3
1328                  TableScan: table3 projection=[col1]
1329                SubqueryAlias: t2
1330                  Projection: table2.id, clickhouse(exp(CAST(table2.id AS Float64)), Utf8("Float64")) AS clickhouse(exp(t2.id),Utf8("Float64"))
1331                    TableScan: table2 projection=[id]
1332            "#;
1333            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
1334            compare_plan_display(&analyzed_plan, expected_plan);
1335        }
1336
1337        #[cfg(not(feature = "federation"))]
1338        {
1339            // This test verifies that the algorithm correctly handles complex JOIN scenarios:
1340            // 1. Function column refs resolve to table2 source (identified by "grouped source")
1341            // 2. Projection column refs include both table and table2 sources
1342            // 3. Algorithm should route function to the right side of JOIN where t2.id is available
1343            // 4. Should have exactly one wrapped plan on the right side
1344            assert_eq!(wrapped_plans.len(), 1, "Expected function routed to right side of JOIN");
1345            assert!(
1346                wrapped_plans[0].contains("input[1]"),
1347                "Expected function wrapped on right side of JOIN"
1348            );
1349        }
1350
1351        Ok(())
1352    }
1353
1354    // TODO: This plan represents a feature that needs to be implemented: how to handle "mixed"
1355    // functions in a plan node. Ideally the functions would be separated and the clickhouse
1356    // function lowered, but that will take quite a bit of logic.
1357    #[cfg(feature = "mocks")]
1358    #[tokio::test]
1359    async fn test_complex_agg() -> Result<()> {
1360        let sql = "SELECT
1361                clickhouse(pow(t.id, 2), 'Int32') as id_mod,
1362                COUNT(t.id) as total,
1363                MAX(clickhouse(exp(t.id), 'Float64')) as max_exp
1364            FROM table2 t
1365            GROUP BY id_mod";
1366        let analyzed_plan = run_query(sql).await?;
1367        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1368
1369        #[cfg(feature = "federation")]
1370        {
1371            let expected_plan = r#"
1372            Projection: clickhouse(power(t.id,Int64(2)),Utf8("Int32")) AS id_mod, count(t.id) AS total, max(clickhouse(exp(t.id),Utf8("Float64"))) AS max_exp
1373              Aggregate: groupBy=[[clickhouse(power(CAST(t.id AS Int64), Int64(2)), Utf8("Int32"))]], aggr=[[count(t.id), max(clickhouse(exp(CAST(t.id AS Float64)), Utf8("Float64")))]]
1374                SubqueryAlias: t
1375                  TableScan: table2 projection=[id]
1376            "#
1377            .trim();
1378            assert!(wrapped_plans.is_empty(), "No wrapping expected");
1379            compare_plan_display(&analyzed_plan, expected_plan);
1380        }
1381
1382        #[cfg(not(feature = "federation"))]
1383        {
1384            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
1385            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
1386        }
1387
1388        Ok(())
1389    }
1390
1391    #[cfg(feature = "mocks")]
1392    #[tokio::test]
1393    async fn test_union() -> Result<()> {
1394        let sql = "
1395            SELECT col1 as id, clickhouse(exp(col1), 'Float64') as func_id
1396            FROM table1 WHERE table1.col1 = 1
1397            UNION ALL
1398            SELECT id, clickhouse(pow(id, 2), 'Float64') as func_id
1399            FROM table2 WHERE table2.id = 1
1400        ";
1401        let analyzed_plan = run_query(sql).await?;
1402        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1403
1404        #[cfg(feature = "federation")]
1405        {
1406            let expected_plan = r#"
1407            Union
1408              Projection: table1.col1 AS id, clickhouse(exp(CAST(table1.col1 AS Float64)), Utf8("Float64")) AS func_id
1409                TableScan: table1 projection=[col1], full_filters=[table1.col1 = Int32(1)]
1410              Projection: table2.id, clickhouse(power(CAST(table2.id AS Int64), Int64(2)), Utf8("Float64")) AS func_id
1411                TableScan: table2 projection=[id], full_filters=[table2.id = Int32(1)]
1412            "#
1413            .trim();
1414            assert!(wrapped_plans.is_empty(), "No wrapping expected");
1415            compare_plan_display(&analyzed_plan, expected_plan);
1416        }
1417
1418        #[cfg(not(feature = "federation"))]
1419        {
1420            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
1421            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
1422        }
1423
1424        Ok(())
1425    }
1426
1427    #[cfg(feature = "mocks")]
1428    #[tokio::test]
1429    async fn test_limit() -> Result<()> {
1430        let sql = "SELECT clickhouse(abs(t2.id), 'Int32') FROM table2 t2 LIMIT 1";
1431        let analyzed_plan = run_query(sql).await?;
1432        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1433
1434        #[cfg(feature = "federation")]
1435        {
1436            let expected_plan = r#"
1437                Projection: clickhouse(abs(t2.id), Utf8("Int32"))
1438                  SubqueryAlias: t2
1439                    Limit: skip=0, fetch=1
1440                      TableScan: table2 projection=[id], fetch=1
1441            "#
1442            .trim();
1443            assert!(wrapped_plans.is_empty(), "No wrapping expected");
1444            compare_plan_display(&analyzed_plan, expected_plan);
1445        }
1446
1447        #[cfg(not(feature = "federation"))]
1448        {
1449            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
1450            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
1451        }
1452        Ok(())
1453    }
1454
1455    #[cfg(feature = "mocks")]
1456    #[tokio::test]
1457    async fn test_sort() -> Result<()> {
1458        let sql = "SELECT t2.id FROM table2 t2 ORDER BY clickhouse(abs(t2.id), 'Int64')";
1459        let analyzed_plan = run_query(sql).await?;
1460        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1461
1462        #[cfg(feature = "federation")]
1463        {
1464            let expected_plan = r#"
1465            Sort: clickhouse(abs(t2.id), Utf8("Int64")) ASC NULLS LAST
1466              SubqueryAlias: t2
1467                TableScan: table2 projection=[id]
1468            "#
1469            .trim();
1470            assert!(wrapped_plans.is_empty(), "No wrapping expected");
1471            compare_plan_display(&analyzed_plan, expected_plan);
1472        }
1473
1474        #[cfg(not(feature = "federation"))]
1475        {
1476            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
1477            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
1478        }
1479        Ok(())
1480    }
1481
1482    // NOTE: First query will fail. That is expected due to how DataFusion groups joins. The
1483    // unresolved table will block pushing the function to the left side of the join, and the usage
1484    // of both tables will block pushing the function to the right side of the join.
1485    //
1486    // But, in the second query, the joins are re-organized. This groups the "clickhouse" tables on
1487    // the same side of the join.
1488    #[cfg(feature = "mocks")]
1489    #[tokio::test]
1490    async fn test_multiple_cols_same_function() -> Result<()> {
1491        let sql = "SELECT t3.col2
1492                , t1.col2
1493                , t2.id
1494                , clickhouse(t1.col1 + t2.id, 'Int64') as sum_ids
1495            FROM table3 t3
1496            JOIN table1 t1 ON t1.col1 = t3.col1
1497            JOIN table2 t2 ON t2.id = t1.col1
1498        ";
1499        let result = run_query(sql).await;
1500        assert!(result.is_err(), "Cannot push to either side of join");
1501
1502        let sql = "SELECT t3.col2
1503                , t1.col2
1504                , t2.id
1505                , clickhouse(t1.col1 + t2.id, 'Int64') as sum_ids
1506            FROM table1 t1
1507            JOIN table2 t2 ON t2.id = t1.col1
1508            JOIN table3 t3 ON t2.id = t3.col1
1509        ";
1510
1511        let analyzed_plan = run_query(sql).await?;
1512        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
1513
1514        #[cfg(feature = "federation")]
1515        {
1516            let expected_plan = r#"
1517            Projection: t3.col2, t1.col2, t2.id, clickhouse(t1.col1 + t2.id,Utf8("Int64")) AS sum_ids
1518              Inner Join: t2.id = t3.col1
1519                Projection: t1.col2, t2.id, clickhouse(t1.col1 + t2.id, Utf8("Int64")) AS clickhouse(t1.col1 + t2.id,Utf8("Int64"))
1520                  Inner Join: t1.col1 = t2.id
1521                    SubqueryAlias: t1
1522                      TableScan: table1 projection=[col1, col2]
1523                    SubqueryAlias: t2
1524                      TableScan: table2 projection=[id]
1525              SubqueryAlias: t3
1526                TableScan: table3 projection=[col1, col2]
1527            "#.trim();
1528
1529            assert!(wrapped_plans.is_empty(), "No wrapping expected");
1530            compare_plan_display(&analyzed_plan, expected_plan);
1531        }
1532
1533        #[cfg(not(feature = "federation"))]
1534        {
1535            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
1536            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
1537        }
1538        Ok(())
1539    }
1540}