use std::collections::HashSet;
use std::sync::Arc;

use clickhouse_arrow::rustc_hash::FxHashMap;
use datafusion::arrow::datatypes::Field;
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion::common::{Column, DFSchema, DFSchemaRef, Result, plan_err};
use datafusion::logical_expr::expr::ScalarFunction;
use datafusion::logical_expr::{LogicalPlan, Projection, SubqueryAlias};
use datafusion::optimizer::AnalyzerRule;
use datafusion::prelude::Expr;
use datafusion::sql::TableReference;

use super::source_context::ResolvedSource;
use super::source_visitor::{ColumnId, SourceLineageVistor};
use super::utils::{
    extract_function_and_return_type, is_clickhouse_function, use_clickhouse_function_context,
};
use crate::utils::analyze::push_exprs_below_subquery;

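/// State threaded through the pushdown walk: ClickHouse function expressions collected so far,
/// keyed by their resolved source, along with the merged source of those functions and the
/// resolved source of the plan node currently being analyzed.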
#[derive(Default, Debug, Clone)]
struct PushdownState {
    functions: FxHashMap<ResolvedSource, Vec<Expr>>,
    function_sources: ResolvedSource,
    plan_sources: ResolvedSource,
}

impl PushdownState {
    fn take_functions(&mut self) -> Vec<Expr> {
        self.functions.values_mut().flatten().map(std::mem::take).collect::<Vec<_>>()
    }

    fn has_functions(&self) -> bool { self.functions.values().any(|f| !f.is_empty()) }

    fn take_relevant_functions(
        &mut self,
        schema: &DFSchemaRef,
        visitor: &SourceLineageVistor,
    ) -> FxHashMap<ResolvedSource, Vec<Expr>> {
        if !self.has_functions() {
            return FxHashMap::default();
        }
        let sources = visitor.resolve_schema(schema.as_ref());
        let column_ids = schema
            .columns()
            .into_iter()
            .flat_map(|col| visitor.collect_column_ids(&col))
            .collect::<HashSet<_>>();
        let extracted = self
            .functions
            .iter_mut()
            .filter(|(r, _)| r.resolves_intersects(&sources))
            .map(|(resolved, funcs)| {
                let extracted = funcs
                    .extract_if(.., |f| {
                        f.column_refs()
                            .into_iter()
                            .flat_map(|col| visitor.collect_column_ids(col))
                            .collect::<HashSet<_>>()
                            .is_subset(&column_ids)
                    })
                    .collect::<Vec<_>>();
                (resolved.clone(), extracted)
            })
            .collect::<FxHashMap<_, _>>();
        self.functions.retain(|_, funcs| !funcs.is_empty());
        extracted
    }
}

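/// Analyzer rule that collects `clickhouse(..)` UDF calls and re-anchors them in a projection
/// above the source they resolve to, so the wrapped expressions can be evaluated by ClickHouse
/// rather than DataFusion.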
#[derive(Debug, Clone, Copy)]
pub struct ClickHouseFunctionPushdown;

impl AnalyzerRule for ClickHouseFunctionPushdown {
    fn analyze(
        &self,
        plan: LogicalPlan,
        _config: &datafusion::common::config::ConfigOptions,
    ) -> Result<LogicalPlan> {
        if matches!(plan, LogicalPlan::Ddl(_) | LogicalPlan::DescribeTable(_)) {
            return Ok(plan);
        }

        #[cfg_attr(feature = "test-utils", expect(unused))]
        let mut lineage_visitor = SourceLineageVistor::new();

        #[cfg(feature = "test-utils")]
        let mut lineage_visitor = lineage_visitor
            .with_source_grouping(HashSet::from(["table1".to_string(), "table2".to_string()]));

        let _ = plan.visit(&mut lineage_visitor)?;

        if lineage_visitor.clickhouse_function_count == 0 {
            return Ok(plan);
        }

        let mut state = PushdownState::default();

        let (new_plan, _) = analyze_and_transform_plan(plan, &mut state, &lineage_visitor)?;
        Ok(new_plan.data)
    }

    fn name(&self) -> &'static str { "clickhouse_function_pushdown" }
}

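/// Recursive driver for the pushdown. Collects ClickHouse functions from the current node's
/// expressions into `state`, recurses into each input with the functions relevant to it, and
/// wraps a node with the pending functions once its resolved sources cover the functions'
/// sources. Returns the (possibly) transformed plan together with the resolved source of
/// functions that were left in place for an ancestor to handle.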
fn analyze_and_transform_plan(
    plan: LogicalPlan,
    state: &mut PushdownState,
    visitor: &SourceLineageVistor,
) -> Result<(Transformed<LogicalPlan>, Option<ResolvedSource>)> {
    let function_violations = check_state_for_violations(state, &plan, visitor);

    let plan_violations = check_plan_for_violations(&plan, visitor);

    let plan_sources = visitor.resolve_schema(plan.schema());
    let plan_function_sources = resolve_plan_expr_functions(&plan, visitor);

    if state.has_functions() {
        if state.function_sources.resolves_eq(&plan_sources) {
            let wrapped_plan = wrap_plan_with_functions(plan, state.take_functions(), visitor)?;
            return Ok((Transformed::yes(wrapped_plan), None));
        }

        if !state.function_sources.resolves_within(&plan_sources) {
            return plan_err!(
                "SQL not supported, could not determine sources of following functions: {:?}",
                state.functions.iter().collect::<Vec<_>>()
            );
        }
    }

    let top_level_plan = !state.plan_sources.is_known();
    if top_level_plan {
        state.plan_sources = plan_sources.clone();
    }

    if state.plan_sources.resolves_eq(&plan_function_sources) {
        if top_level_plan {
            let wrapped_plan = wrap_plan_with_functions(plan, state.take_functions(), visitor)?;
            return Ok((Transformed::yes(wrapped_plan), None));
        }
        return Ok((Transformed::no(plan), Some(plan_function_sources)));
    }

    state.function_sources =
        std::mem::take(&mut state.function_sources).merge(plan_function_sources);

    if state.function_sources.resolves_eq(&plan_sources) {
        let wrapped_plan = wrap_plan_with_functions(plan, state.take_functions(), visitor)?;
        return Ok((Transformed::yes(wrapped_plan), None));
    }

    if state.function_sources.is_known() {
        semantic_err(
            plan.display(),
            "SQL unsupported, pushed functions violate sql semantics in current plan.",
            &function_violations,
        )?;
        semantic_err(
            plan.display(),
            "SQL unsupported, plan violates sql semantics in plan's expressions.",
            &plan_violations,
        )?;
    }

    let parent_sources = std::mem::take(&mut state.plan_sources);

    if state.has_functions()
        && let LogicalPlan::SubqueryAlias(alias) = &plan
    {
        state.functions = std::mem::take(&mut state.functions)
            .into_iter()
            .map(|(resolv, funcs)| (resolv, push_exprs_below_subquery(funcs, alias)))
            .collect();
    }

    let aliased_exprs = collect_and_transform_exprs(plan.expressions(), visitor, state)?;

    let mut was_transformed = state.has_functions();

    let inputs_transformed = plan
        .inputs()
        .into_iter()
        .cloned()
        .map(|input| {
            let extracted = state.take_relevant_functions(input.schema(), visitor);
            let mut input_state = PushdownState {
                function_sources: extracted
                    .keys()
                    .cloned()
                    .reduce(ResolvedSource::merge)
                    .unwrap_or_default(),
                functions: extracted,
                plan_sources: plan_sources.clone(),
            };
            analyze_and_transform_plan(input, &mut input_state, visitor)
        })
        .collect::<Result<Vec<_>>>()?;

    if state.has_functions() {
        return plan_err!(
            "SQL not supported, could not determine sources of following functions: {:?}",
            state.functions.iter().collect::<Vec<_>>()
        );
    }

    let input_resolution = inputs_transformed
        .iter()
        .inspect(|(i, _)| was_transformed |= i.transformed)
        .filter_map(|(_, func_sources)| func_sources.clone())
        .reduce(ResolvedSource::merge)
        .unwrap_or_default();

    let new_inputs = inputs_transformed.into_iter().map(|(i, _)| i.data).collect::<Vec<_>>();
    let new_plan = plan.with_new_exprs(aliased_exprs, new_inputs)?;

    let new_plan = if !top_level_plan && parent_sources.resolves_eq(&input_resolution) {
        return Ok((Transformed::no(plan), Some(input_resolution)));
    } else if plan_sources.resolves_eq(&input_resolution) {
        let wrapped = wrap_plan_with_functions(new_plan, state.take_functions(), visitor)?;
        Transformed::yes(wrapped)
    } else if was_transformed {
        Transformed::yes(new_plan)
    } else {
        Transformed::no(new_plan)
    };

    Ok((new_plan, None))
}

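/// Resolves the source of ClickHouse functions appearing in the node's own expressions
/// (without descending into its inputs) and merges them into a single `ResolvedSource`.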
fn resolve_plan_expr_functions(
    input: &LogicalPlan,
    visitor: &SourceLineageVistor,
) -> ResolvedSource {
    let mut exprs_resolved = ResolvedSource::default();
    let _ = input
        .apply_expressions(|expr| {
            use_clickhouse_function_context(expr, |e| {
                exprs_resolved = std::mem::take(&mut exprs_resolved).merge(visitor.resolve_expr(e));
                Ok(TreeNodeRecursion::Stop)
            })
            .unwrap();
            Ok(TreeNodeRecursion::Continue)
        })
        .unwrap();
    exprs_resolved
}

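/// Appends the pushed-down function expressions to `plan` via a projection, stripping catalog
/// qualifiers from table scans first. Without the `federation` feature the result is wrapped in
/// a `ClickHouseFunctionNode` extension node.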
fn wrap_plan_with_functions(
    plan: LogicalPlan,
    functions: Vec<Expr>,
    visitor: &SourceLineageVistor,
) -> Result<LogicalPlan> {
    #[cfg(feature = "federation")]
    #[expect(clippy::unnecessary_wraps)]
    fn return_wrapped_plan(plan: LogicalPlan) -> Result<LogicalPlan> { Ok(plan) }

    #[cfg(not(feature = "federation"))]
    fn return_wrapped_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
        use datafusion::logical_expr::Extension;

        use crate::context::plan_node::ClickHouseFunctionNode;

        Ok(LogicalPlan::Extension(Extension {
            node: Arc::new(ClickHouseFunctionNode::try_new(plan)?),
        }))
    }

    let (func_fields, func_cols) = functions_to_field_and_cols(functions, visitor)?;

    let plan = plan.transform_up_with_subqueries(strip_table_scan_catalog).unwrap().data;
    let plan = plan.recompute_schema()?;

    match plan {
        LogicalPlan::SubqueryAlias(alias) => {
            let func_cols = push_exprs_below_subquery(func_cols, &alias);

            let input = Arc::unwrap_or_clone(alias.input);
            let new_input = wrap_plan_in_projection(input, func_fields, func_cols)?.into();
            let new_alias = SubqueryAlias::try_new(new_input, alias.alias)?;
            return_wrapped_plan(LogicalPlan::SubqueryAlias(new_alias))
        }
        _ => return_wrapped_plan(wrap_plan_in_projection(plan, func_fields, func_cols)?),
    }
}

type QualifiedField = (Option<TableReference>, Arc<Field>);

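/// Converts each function expression into an unqualified schema field (named after the
/// expression's schema name, typed by its declared return type) plus the inner expression
/// aliased to that name.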
fn functions_to_field_and_cols(
    functions: Vec<Expr>,
    visitor: &SourceLineageVistor,
) -> Result<(Vec<QualifiedField>, Vec<Expr>)> {
    let mut fields = Vec::new();
    let mut columns = Vec::new();
    for function in functions {
        let is_nullable = visitor.resolve_nullable(&function);
        let alias = function.schema_name().to_string();
        let (inner_function, data_type) = extract_function_and_return_type(function)?;
        fields.push((None, Arc::new(Field::new(&alias, data_type, is_nullable))));
        columns.push(inner_function.alias(alias));
    }
    Ok((fields, columns))
}

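/// Adds `func_cols` to an existing projection, or creates a new projection over `plan`, with
/// the schema extended by `func_fields`. Returns the plan unchanged when there is nothing to
/// add.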
fn wrap_plan_in_projection(
    plan: LogicalPlan,
    func_fields: Vec<QualifiedField>,
    func_cols: Vec<Expr>,
) -> Result<LogicalPlan> {
    if func_cols.is_empty() {
        return Ok(plan);
    }

    let metadata = plan.schema().metadata().clone();
    let mut fields =
        plan.schema().iter().map(|(q, f)| (q.cloned(), Arc::clone(f))).collect::<Vec<_>>();
    fields.extend(func_fields);

    let new_schema = DFSchema::new_with_metadata(fields, metadata)?;

    let new_plan = if let LogicalPlan::Projection(mut projection) = plan {
        projection.expr.extend(func_cols);
        Projection::try_new_with_schema(projection.expr, projection.input, new_schema.into())?
    } else {
        let mut exprs = plan.schema().columns().into_iter().map(Expr::Column).collect::<Vec<_>>();
        exprs.extend(func_cols);
        Projection::try_new_with_schema(exprs, plan.into(), new_schema.into())?
    };

    Ok(LogicalPlan::Projection(new_plan))
}

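/// Runs `collect_and_transform_function` over each expression, returning the rewritten
/// expressions.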
fn collect_and_transform_exprs(
    exprs: Vec<Expr>,
    visitor: &SourceLineageVistor,
    state: &mut PushdownState,
) -> Result<Vec<Expr>> {
    exprs
        .into_iter()
        .map(|expr| collect_and_transform_function(expr, visitor, state).map(|t| t.data))
        .collect::<Result<Vec<_>>>()
}

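/// Rewrites a single expression: `clickhouse(..)` calls that resolve only to scalars or to an
/// unknown source are unwrapped to their first argument, while calls tied to a concrete source
/// are recorded in `state` and replaced with a column reference named after the call's schema
/// name.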
fn collect_and_transform_function(
    expr: Expr,
    visitor: &SourceLineageVistor,
    state: &mut PushdownState,
) -> Result<Transformed<Expr>> {
    expr.transform_down(|e| {
        if is_clickhouse_function(&e) {
            let func_resolved = visitor.resolve_expr(&e);
            let alias = e.schema_name().to_string();

            if matches!(
                func_resolved,
                ResolvedSource::Scalar(_) | ResolvedSource::Scalars(_) | ResolvedSource::Unknown
            ) {
                let Expr::ScalarFunction(ScalarFunction { mut args, .. }) = e else {
                    unreachable!();
                };
                if args.is_empty() {
                    return plan_err!("`clickhouse` function requires an arg, none found: {alias}");
                }
                return Ok(Transformed::new(args.remove(0), true, TreeNodeRecursion::Jump));
            }

            state.function_sources =
                std::mem::take(&mut state.function_sources).merge(func_resolved.clone());
            let current_funcs = state.functions.entry(func_resolved).or_default();
            if !current_funcs.contains(&e) {
                current_funcs.push(e);
            }

            Ok(Transformed::new(
                Expr::Column(Column::from_name(alias)),
                true,
                TreeNodeRecursion::Jump,
            ))
        } else {
            Ok(Transformed::no(e))
        }
    })
}

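/// Downgrades fully qualified table scan references (`catalog.schema.table`) to partial
/// references (`schema.table`).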
fn strip_table_scan_catalog(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
    plan.transform_up_with_subqueries(|node| {
        if let LogicalPlan::TableScan(mut scan) = node {
            if let TableReference::Full { schema, table, .. } = scan.table_name {
                scan.table_name = TableReference::Partial { schema, table };
                return Ok(Transformed::yes(LogicalPlan::TableScan(scan)));
            }

            Ok(Transformed::no(LogicalPlan::TableScan(scan)))
        } else {
            Ok(Transformed::no(node))
        }
    })
}

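/// Collects the column references of expressions in `plan` that the pending pushdown functions
/// would conflict with.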
fn check_state_for_violations(
    state: &PushdownState,
    plan: &LogicalPlan,
    visitor: &SourceLineageVistor,
) -> HashSet<Column> {
    let mut function_violations = HashSet::new();
    if !state.functions.is_empty() {
        for func in state.functions.values().flatten() {
            function_violations
                .extend(violates_pushdown_semantics(func, plan, visitor).into_iter().cloned());
        }
    }
    function_violations
}

fn check_plan_for_violations(plan: &LogicalPlan, visitor: &SourceLineageVistor) -> HashSet<Column> {
    violates_plan_semantics(plan, visitor).into_iter().cloned().collect()
}

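/// Checks a single pending function against the aggregate, window, or subquery expressions of
/// `plan`, returning the column references of the first expression whose column lineage
/// overlaps (or, for group-by expressions, fails to overlap) with the function's columns.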
fn violates_pushdown_semantics<'a>(
    function: &Expr,
    plan: &'a LogicalPlan,
    visitor: &SourceLineageVistor,
) -> HashSet<&'a Column> {
    let function_column_ids = function
        .column_refs()
        .iter()
        .flat_map(|col| visitor.collect_column_ids(col))
        .collect::<HashSet<_>>();

    if function_column_ids.is_empty() {
        return HashSet::new();
    }

    match plan {
        LogicalPlan::Aggregate(agg) => {
            if let Some(related_cols) = check_function_against_exprs(
                function,
                &agg.aggr_expr,
                &function_column_ids,
                visitor,
                true,
            ) {
                return related_cols;
            }

            if let Some(related_cols) = check_function_against_exprs(
                function,
                &agg.group_expr,
                &function_column_ids,
                visitor,
                false,
            ) {
                return related_cols;
            }
        }
        LogicalPlan::Window(window) => {
            if let Some(related_cols) = check_function_against_exprs(
                function,
                &window.window_expr,
                &function_column_ids,
                visitor,
                true,
            ) {
                return related_cols;
            }
        }
        LogicalPlan::Subquery(query) => {
            if let Some(related_cols) = check_function_against_exprs(
                function,
                &query.outer_ref_columns,
                &function_column_ids,
                visitor,
                true,
            ) {
                return related_cols;
            }
        }
        _ => {}
    }
    HashSet::new()
}

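/// Flags aggregate expressions containing a ClickHouse function that does not also appear in
/// the GROUP BY expressions, returning the offending expression's column references.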
fn violates_plan_semantics<'a>(
    plan: &'a LogicalPlan,
    _visitor: &SourceLineageVistor,
) -> HashSet<&'a Column> {
    if let LogicalPlan::Aggregate(agg) = plan {
        for expr in &agg.aggr_expr {
            let mut violations = None;
            drop(use_clickhouse_function_context(expr, |agg_func| {
                let found = agg.group_expr.iter().any(|group_e| {
                    let mut found = false;
                    use_clickhouse_function_context(group_e, |group_func| {
                        found |= group_func == agg_func;
                        Ok(TreeNodeRecursion::Stop)
                    })
                    .unwrap();
                    found
                });

                if !found {
                    violations = Some(expr.column_refs());
                    return plan_err!("Aggregate functions must be in group by");
                }
                Ok(TreeNodeRecursion::Stop)
            }));

            if let Some(violations) = violations {
                return violations;
            }
        }
    }

    HashSet::new()
}

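/// Compares the function's column lineage with each expression in `exprs`. When
/// `disjoint_required` is true the column sets must not intersect; when false they must.
/// Returns the column references of the first expression breaking that rule.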
fn check_function_against_exprs<'a>(
    func: &Expr,
    exprs: &'a [Expr],
    func_column_ids: &HashSet<ColumnId>,
    visitor: &SourceLineageVistor,
    disjoint_required: bool,
) -> Option<HashSet<&'a Column>> {
    for expr in exprs {
        if expr == func {
            continue;
        }
        let col_refs = expr.column_refs();
        let expr_column_ids =
            col_refs.iter().flat_map(|col| visitor.collect_column_ids(col)).collect::<HashSet<_>>();

        if expr_column_ids.is_empty() && !disjoint_required {
            continue;
        }
        if expr_column_ids.is_disjoint(func_column_ids) != disjoint_required {
            return Some(col_refs);
        }
    }
    None
}

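/// Returns a plan error naming the violating columns when `violations` is non-empty.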
fn semantic_err(
    name: impl std::fmt::Display,
    msg: &str,
    violations: &HashSet<Column>,
) -> Result<()> {
    if !violations.is_empty() {
        let violations =
            violations.iter().map(Column::quoted_flat_name).collect::<Vec<_>>().join(", ");
        return plan_err!("[{name}] - {msg} Violations: {violations}");
    }
    Ok(())
}

#[cfg(all(test, feature = "test-utils"))]
mod tests {
    use std::collections::HashSet;
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::catalog::TableProvider;
    use datafusion::common::{Column, Result};
    use datafusion::datasource::empty::EmptyTable;
    use datafusion::datasource::provider_as_source;
    use datafusion::functions_aggregate::count::count;
    use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder, table_scan};
    use datafusion::prelude::*;
    use datafusion::scalar::ScalarValue;
    use datafusion::sql::TableReference;

    use super::super::source_visitor::SourceLineageVistor;
    use super::*;
    use crate::analyzer::source_context::SourceContext;
    use crate::analyzer::source_visitor::ColumnLineage;
    use crate::udfs::clickhouse::clickhouse_udf;
    #[cfg(feature = "mocks")]
    use crate::{
        ClickHouseConnectionPool, ClickHouseTableProvider,
        analyzer::function_pushdown::ClickHouseFunctionPushdown,
        plan_node::CLICKHOUSE_FUNCTION_NODE_NAME,
    };

    fn create_table_scan(table: TableReference, provider: Arc<dyn TableProvider>) -> LogicalPlan {
        LogicalPlanBuilder::scan(table, provider_as_source(provider), None)
            .unwrap()
            .build()
            .unwrap()
    }

    #[test]
    fn test_functions_to_field_and_cols_empty() -> Result<()> {
        let visitor = SourceLineageVistor::default();
        let functions = Vec::new();
        let result = functions_to_field_and_cols(functions, &visitor)?;
        assert!(result.0.is_empty());
        assert!(result.1.is_empty());
        Ok(())
    }

    #[test]
    fn test_functions_to_field_and_cols_single_function() -> Result<()> {
        let visitor = SourceLineageVistor::default();
        let functions = vec![Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(clickhouse_udf()),
            args: vec![lit("count()"), lit("Int64")],
        })];

        let (fields, funcs) = functions_to_field_and_cols(functions, &visitor)?;
        assert_eq!(fields.len(), 1);
        assert_eq!(funcs.len(), 1);

        let (field_ref, field) = &fields[0];
        assert!(field_ref.is_none());
        assert_eq!(field.data_type(), &DataType::Int64);
        assert!(!field.is_nullable());

        Ok(())
    }

    #[test]
    fn test_functions_to_field_and_cols_multiple_functions() -> Result<()> {
        let visitor = SourceLineageVistor::default();
        let functions = vec![
            Expr::ScalarFunction(ScalarFunction {
                func: Arc::new(clickhouse_udf()),
                args: vec![lit("count()"), lit("Int64")],
            }),
            Expr::ScalarFunction(ScalarFunction {
                func: Arc::new(clickhouse_udf()),
                args: vec![lit("sum(x)"), lit("Float64")],
            }),
        ];

        let (fields, funcs) = functions_to_field_and_cols(functions, &visitor)?;
        assert_eq!(fields.len(), 2);
        assert_eq!(funcs.len(), 2);

        assert_eq!(fields[0].1.data_type(), &DataType::Int64);
        assert_eq!(fields[1].1.data_type(), &DataType::Float64);

        Ok(())
    }

    #[test]
    fn test_wrap_plan_in_projection_no_functions() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
        let provider = Arc::new(EmptyTable::new(schema));
        let plan = create_table_scan(TableReference::bare("test_table"), provider);

        let result = wrap_plan_in_projection(plan.clone(), vec![], vec![])?;

        match (&plan, &result) {
            (LogicalPlan::TableScan(original), LogicalPlan::TableScan(result)) => {
                assert_eq!(original.table_name, result.table_name);
            }
            _ => panic!("Expected TableScan plans"),
        }

        Ok(())
    }

    #[test]
    fn test_wrap_plan_in_projection_with_functions() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
        let provider = Arc::new(EmptyTable::new(schema));
        let plan = create_table_scan(TableReference::bare("test_table"), provider);

        let func_fields =
            vec![(None, Arc::new(Field::new("func_result", DataType::Float64, true)))];
        let func_cols = vec![lit("test_function").alias("func_result")];

        let result = wrap_plan_in_projection(plan, func_fields, func_cols)?;

        match result {
            LogicalPlan::Projection(projection) => {
                assert_eq!(projection.expr.len(), 2);
                assert_eq!(projection.schema.fields().len(), 2);
            }
            _ => panic!("Expected Projection plan"),
        }

        Ok(())
    }

    #[test]
    fn test_strip_table_scan_catalog_no_catalog() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
        let provider = Arc::new(EmptyTable::new(schema));
        let plan = create_table_scan(TableReference::bare("test_table"), provider);

        let result = strip_table_scan_catalog(plan)?;

        assert!(!result.transformed);

        Ok(())
    }

    #[test]
    fn test_strip_table_scan_catalog_with_catalog() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
        let provider = Arc::new(EmptyTable::new(schema));
        let table_ref = TableReference::Full {
            catalog: "catalog".into(),
            schema: "schema".into(),
            table: "table".into(),
        };
        let plan = create_table_scan(table_ref, provider);

        let result = strip_table_scan_catalog(plan)?;

        assert!(result.transformed);
        if let LogicalPlan::TableScan(scan) = result.data {
            match scan.table_name {
                TableReference::Partial { schema, table } => {
                    assert_eq!(schema.as_ref(), "schema");
                    assert_eq!(table.as_ref(), "table");
                }
                _ => panic!("Expected Partial table reference after catalog stripping"),
            }
        } else {
            panic!("Expected TableScan after transformation");
        }

        Ok(())
    }

    #[test]
    fn test_check_state_for_violations_empty_state() {
        let state = PushdownState::default();
        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
        let provider = Arc::new(EmptyTable::new(schema));
        let plan = create_table_scan(TableReference::bare("test_table"), provider);
        let visitor = SourceLineageVistor::new();

        let violations = check_state_for_violations(&state, &plan, &visitor);
        assert!(violations.is_empty());
    }

    #[test]
    fn test_check_plan_for_violations_non_aggregate() {
        let schema = Arc::new(Schema::new(vec![Field::new("col1", DataType::Int32, false)]));
        let provider = Arc::new(EmptyTable::new(schema));
        let plan = create_table_scan(TableReference::bare("test_table"), provider);

        let visitor = SourceLineageVistor::new();
        let violations = check_plan_for_violations(&plan, &visitor);
        assert!(violations.is_empty());
    }

    #[test]
    fn test_semantic_err_no_violations() {
        let violations = HashSet::new();
        let result = semantic_err("TestPlan", "Test message", &violations);
        assert!(result.is_ok());
    }

    #[test]
    fn test_semantic_err_with_violations() {
        let mut violations = HashSet::new();
        let _ = violations.insert(Column::new_unqualified("test_col"));

        let result = semantic_err("TestPlan", "Test message", &violations);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("[TestPlan]"));
        assert!(err_msg.contains("Test message"));
        assert!(err_msg.contains("test_col"));
    }

    #[test]
    fn test_collect_and_transform_exprs_empty() -> Result<()> {
        let exprs = Vec::new();
        let visitor = SourceLineageVistor::new();
        let mut state = PushdownState::default();

        let result = collect_and_transform_exprs(exprs, &visitor, &mut state)?;
        assert!(result.is_empty());
        Ok(())
    }

    #[test]
    fn test_collect_and_transform_exprs_no_clickhouse_functions() -> Result<()> {
        let exprs = vec![col("test_col"), lit(42)];
        let visitor = SourceLineageVistor::new();
        let mut state = PushdownState::default();

        let result = collect_and_transform_exprs(exprs.clone(), &visitor, &mut state)?;
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], col("test_col"));
        assert_eq!(result[1], lit(42));
        Ok(())
    }

    #[test]
    fn test_collect_and_transform_function_non_clickhouse() -> Result<()> {
        let expr = col("test_col");
        let visitor = SourceLineageVistor::new();
        let mut state = PushdownState::default();

        let result = collect_and_transform_function(expr.clone(), &visitor, &mut state)?;
        assert!(!result.transformed);
        assert_eq!(result.data, expr);
        Ok(())
    }

    #[test]
    fn test_collect_and_transform_function_with_clickhouse() -> Result<()> {
        let table = TableReference::bare("test");
        let test_col = Column::new(Some(table.clone()), "test_col");

        let schema = Schema::new(vec![Field::new("test_col", DataType::Int32, false)]);

        let clickhouse_expr = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(clickhouse_udf()),
            args: vec![count(Expr::Column(test_col.clone())), lit("Int32")],
        });

        let mut visitor = SourceLineageVistor::new();

        let dummy_plan =
            table_scan(Some(table.clone()), &schema, None)?.select(vec![0])?.build()?;
        let _ = dummy_plan.visit(&mut visitor)?;
        let mut state = PushdownState::default();
        let result = collect_and_transform_function(clickhouse_expr, &visitor, &mut state)?;
        assert!(result.transformed);

        match result.data {
            Expr::Column(_) => {}
            _ => panic!("Expected Column expression after ClickHouse function transformation"),
        }
        assert!(!state.functions.is_empty());
        Ok(())
    }

    #[test]
    fn test_column_id_resolution() -> Result<()> {
        let table1 = TableReference::bare("test1");
        let table2 = TableReference::bare("test2");
        let test_col1 = Column::new(Some(table1.clone()), "test_col1");
        let test_col_alt1 = Column::new(Some(table1.clone()), "test_col_alt1");
        let test_col2 = Column::new(Some(table2.clone()), "test_col2");

        let schema1 = Schema::new(vec![
            Field::new("test_col1", DataType::Int32, false),
            Field::new("test_col_alt1", DataType::Int32, false),
        ]);

        let schema2 = Schema::new(vec![Field::new("test_col2", DataType::Int32, false)]);

        let mut visitor = SourceLineageVistor::new();

        let clickhouse_expr_simple = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(clickhouse_udf()),
            args: vec![
                Expr::Column(test_col1.clone()) + Expr::Column(test_col_alt1.clone()),
                lit("Int64"),
            ],
        })
        .alias("dummy_udf");
        let clickhouse_expr_comp = Expr::ScalarFunction(ScalarFunction {
            func: Arc::new(clickhouse_udf()),
            args: vec![
                Expr::Column(test_col1.clone()) + Expr::Column(test_col2.clone()) + lit(2),
                lit("Int64"),
            ],
        })
        .alias("dummy_udf_comp");
        let simple_col = Column::from_name("dummy_udf");
        let comp_col = Column::from_name("dummy_udf_comp");
        let right_plan =
            table_scan(Some(table2.clone()), &schema2, None)?.select(vec![0])?.build()?;

        let dummy_plan = table_scan(Some(table1.clone()), &schema1, None)?
            .select(vec![0, 1])?
            .project(vec![Expr::Column(test_col1.clone()), clickhouse_expr_simple])?
            .filter(Expr::Column(test_col1.clone()).gt(lit(0)))?
            .join_on(right_plan, JoinType::Inner, vec![
                col("table1.test_col1").eq(col("table2.test_col2")),
            ])?
            .project(vec![
                Expr::Column(simple_col.clone()),
                clickhouse_expr_comp,
                lit("hello").alias("scalar_col"),
            ])?
            .build()?;

        let _ = dummy_plan.visit(&mut visitor)?;

        let lineage = visitor.column_lineage.get(&simple_col);
        let resolved = visitor.resolve_column(&simple_col);
        let col_ids = visitor.collect_column_ids(&simple_col);

        let Some(ColumnLineage::Compound(ids)) = lineage else {
            panic!("Derived columns of clickhouse functions should be stored as `Compound`");
        };

        assert_eq!(col_ids.len(), 2, "Expected 2 columns when resolving context");
        assert_eq!(ids.len(), 3, "Expected 3 columns in Compound context");
        assert!(col_ids.is_subset(ids), "Expected collected columns to be contained in context");

        let ResolvedSource::Compound(sources) = resolved else {
            panic!("Expected Compound source for simple column");
        };

        let mut resolved_sources = sources.to_vec();
        let scalar_source = resolved_sources.pop().unwrap();
        let table_source = resolved_sources.pop().unwrap();

        assert_eq!(
            scalar_source.as_ref(),
            &SourceContext::Scalar(ScalarValue::Utf8(Some("Int64".into())))
        );
        assert_eq!(table_source.as_ref(), &SourceContext::Table(table1.clone()));

        let nullable = visitor.resolve_nullable(&Expr::Column(simple_col.clone()));
        assert!(!nullable);

        let lineage = visitor.column_lineage.get(&comp_col);
        let col_ids = visitor.collect_column_ids(&comp_col);

        let Some(ColumnLineage::Compound(cols)) = lineage else {
            panic!("Derived columns should be stored as `Compound`");
        };

        assert_eq!(cols.len(), 4, "Expected 4 columns in Compound context, inc Scalars");
        assert_eq!(col_ids.len(), 2, "Expected 2 columns in collected columns, exc Scalars");
        assert!(col_ids.is_subset(cols), "Expected collected columns to be a subset of context");

        Ok(())
    }

    #[cfg(feature = "mocks")]
    fn is_clickhouse_extension(plan: &LogicalPlan) -> bool {
        if let LogicalPlan::Extension(ext) = plan {
            ext.node.name() == CLICKHOUSE_FUNCTION_NODE_NAME
        } else {
            false
        }
    }

    #[cfg(feature = "mocks")]
    fn find_wrapped_plans(plan: &LogicalPlan) -> Vec<String> {
        fn traverse(plan: &LogicalPlan, wrapped_plans: &mut Vec<String>, path: &str) {
            if is_clickhouse_extension(plan) {
                wrapped_plans.push(format!("{path}: {plan}"));
            }
            for (i, input) in plan.inputs().iter().enumerate() {
                traverse(input, wrapped_plans, &format!("{path}/input[{i}]"));
            }
        }
        let mut wrapped_plans = Vec::new();
        traverse(plan, &mut wrapped_plans, "root");
        wrapped_plans
    }

    #[cfg(all(feature = "federation", feature = "mocks"))]
    fn compare_plan_display(plan: &LogicalPlan, expected: impl Into<String>) {
        let mut plan_display = plan.display_indent().to_string();
        plan_display.retain(|c| !c.is_whitespace());
        let mut expected = expected.into();
        expected.retain(|c| !c.is_whitespace());
        assert_eq!(plan_display, expected, "Expected equal plans");
    }

    #[cfg(feature = "mocks")]
    fn create_test_context() -> Result<SessionContext> {
        let ctx = SessionContext::new();
        ctx.register_udf(clickhouse_udf());
        ctx.add_analyzer_rule(Arc::new(ClickHouseFunctionPushdown));

        let schema1 = Arc::new(Schema::new(vec![
            Field::new("col1", DataType::Int32, false),
            Field::new("col2", DataType::Int32, false),
            Field::new("col3", DataType::Utf8, false),
        ]));
        let schema2 = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

        let pool = Arc::new(ClickHouseConnectionPool::new("pool".to_string(), ()));
        let table1 = ClickHouseTableProvider::new_with_schema_unchecked(
            Arc::clone(&pool),
            "table1".into(),
            Arc::clone(&schema1),
        );
        let table2 = ClickHouseTableProvider::new_with_schema_unchecked(
            Arc::clone(&pool),
            "table1".into(),
            schema2,
        );

        drop(ctx.register_table("table1", Arc::new(table1))?);

        drop(ctx.register_table("table2", Arc::new(table2))?);

        let table3 = Arc::new(EmptyTable::new(schema1));
        drop(ctx.register_table("table3", table3)?);

        Ok(ctx)
    }

    #[cfg(feature = "mocks")]
    async fn run_query(sql: &str) -> Result<LogicalPlan> {
        let ctx = create_test_context()?;
        let analyzed_plan = ctx.sql(sql).await?.into_optimized_plan()?;
        SQLOptions::default().verify_plan(&analyzed_plan)?;
        Ok(analyzed_plan)
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_simple_projection_with_clickhouse_function() -> Result<()> {
        let sql =
            "SELECT clickhouse(exp(col1 + col2), 'Float64'), col2 * 2, UPPER(col3) FROM table1";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            let expected_plan = r#"
                Projection: clickhouse(exp(CAST(table1.col1 + table1.col2 AS Float64)), Utf8("Float64")), CAST(table1.col2 AS Int64) * Int64(2), upper(table1.col3)
                  TableScan: table1 projection=[col1, col2, col3]
            "#;
            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
            compare_plan_display(&analyzed_plan, expected_plan);
        }
        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected exactly one wrapped plan");
            assert!(wrapped_plans[0].starts_with("root:"), "Expected wrapping at root level");
        }
        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_filter_with_clickhouse_function() -> Result<()> {
        let sql = "SELECT col2, col3 FROM table1 WHERE clickhouse(exp(col1), 'Float64') > 10";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            let expected_plan = r#"
                TableScan: table1 projection=[col2, col3], full_filters=[clickhouse(exp(CAST(table1.col1 AS Float64)), Utf8("Float64")) > Float64(10)]
            "#;
            assert!(wrapped_plans.is_empty(), "No wrapping expected");
            compare_plan_display(&analyzed_plan, expected_plan);
        }
        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected exactly one wrapped plan");
            assert!(wrapped_plans[0].contains("root"), "Expected wrapping at filter level, root");
        }
        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_aggregate_blocks_pushdown() -> Result<()> {
        let sql = "SELECT col2, COUNT(*) FROM table1 WHERE clickhouse(exp(col1), 'Float64') > 5 \
                   GROUP BY col2";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
            assert!(
                analyzed_plan.display().to_string().to_lowercase().starts_with("projection"),
                "Expected projection"
            );
        }

        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected exactly one wrapped plan");
            assert!(
                wrapped_plans.iter().any(|w| w.contains("root")),
                "Expected function to be wrapped at aggregate input level due to blocking"
            );
        }
        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_multiple_clickhouse_functions_same_table() -> Result<()> {
        let sql = "SELECT clickhouse(exp(col1), 'Float64') AS f1, clickhouse(exp(col2), \
                   'Float64') AS f2 FROM table1";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
            assert!(
                analyzed_plan.display().to_string().to_lowercase().starts_with("projection"),
                "Expected projection"
            );
        }

        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(
                wrapped_plans.len(),
                1,
                "Expected exactly one wrapped plan for both functions"
            );
            assert!(
                wrapped_plans[0].starts_with("root:"),
                "Expected both functions wrapped together at root level"
            );
        }
        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_no_functions_no_wrapping() -> Result<()> {
        let sql = "SELECT col1, col2 FROM table1";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);
        assert_eq!(wrapped_plans.len(), 0, "Expected no wrapped plans when no functions present");
        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_wrapped_disjoint_tables() -> Result<()> {
        let sql = "SELECT t1.col1, clickhouse(exp(t2.id), 'Float64') FROM (SELECT col1 FROM \
                   table1) t1 JOIN (SELECT id from table2) t2 ON t1.col1 = t2.id";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
            assert!(
                analyzed_plan.display().to_string().to_lowercase().starts_with("projection"),
                "Expected projection"
            );
        }

        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected function to wrap the entire plan");
            assert!(
                wrapped_plans[0].contains("root"),
                "Expected function wrapped on right side of JOIN"
            );
        }
        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_disjoint_tables() -> Result<()> {
        let sql = "SELECT t3.col1, clickhouse(exp(t2.id), 'Float64')
            FROM (SELECT col1 FROM table3) t3
            JOIN (SELECT id from table2) t2 ON t3.col1 = t2.id
            ";

        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            let expected_plan = r#"
                Projection: t3.col1, clickhouse(exp(t2.id),Utf8("Float64"))
                  Inner Join: t3.col1 = t2.id
                    SubqueryAlias: t3
                      TableScan: table3 projection=[col1]
                    SubqueryAlias: t2
                      Projection: table2.id, clickhouse(exp(CAST(table2.id AS Float64)), Utf8("Float64")) AS clickhouse(exp(t2.id),Utf8("Float64"))
                        TableScan: table2 projection=[id]
            "#;
            assert!(wrapped_plans.is_empty(), "Expected no wrapped plans");
            compare_plan_display(&analyzed_plan, expected_plan);
        }

        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected function routed to right side of JOIN");
            assert!(
                wrapped_plans[0].contains("input[1]"),
                "Expected function wrapped on right side of JOIN"
            );
        }

        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_complex_agg() -> Result<()> {
        let sql = "SELECT
                clickhouse(pow(t.id, 2), 'Int32') as id_mod,
                COUNT(t.id) as total,
                MAX(clickhouse(exp(t.id), 'Float64')) as max_exp
            FROM table2 t
            GROUP BY id_mod";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            let expected_plan = r#"
                Projection: clickhouse(power(t.id,Int64(2)),Utf8("Int32")) AS id_mod, count(t.id) AS total, max(clickhouse(exp(t.id),Utf8("Float64"))) AS max_exp
                  Aggregate: groupBy=[[clickhouse(power(CAST(t.id AS Int64), Int64(2)), Utf8("Int32"))]], aggr=[[count(t.id), max(clickhouse(exp(CAST(t.id AS Float64)), Utf8("Float64")))]]
                    SubqueryAlias: t
                      TableScan: table2 projection=[id]
            "#
            .trim();
            assert!(wrapped_plans.is_empty(), "No wrapping expected");
            compare_plan_display(&analyzed_plan, expected_plan);
        }

        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
        }

        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_union() -> Result<()> {
        let sql = "
            SELECT col1 as id, clickhouse(exp(col1), 'Float64') as func_id
            FROM table1 WHERE table1.col1 = 1
            UNION ALL
            SELECT id, clickhouse(pow(id, 2), 'Float64') as func_id
            FROM table2 WHERE table2.id = 1
        ";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            let expected_plan = r#"
                Union
                  Projection: table1.col1 AS id, clickhouse(exp(CAST(table1.col1 AS Float64)), Utf8("Float64")) AS func_id
                    TableScan: table1 projection=[col1], full_filters=[table1.col1 = Int32(1)]
                  Projection: table2.id, clickhouse(power(CAST(table2.id AS Int64), Int64(2)), Utf8("Float64")) AS func_id
                    TableScan: table2 projection=[id], full_filters=[table2.id = Int32(1)]
            "#
            .trim();
            assert!(wrapped_plans.is_empty(), "No wrapping expected");
            compare_plan_display(&analyzed_plan, expected_plan);
        }

        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
        }

        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_limit() -> Result<()> {
        let sql = "SELECT clickhouse(abs(t2.id), 'Int32') FROM table2 t2 LIMIT 1";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            let expected_plan = r#"
                Projection: clickhouse(abs(t2.id), Utf8("Int32"))
                  SubqueryAlias: t2
                    Limit: skip=0, fetch=1
                      TableScan: table2 projection=[id], fetch=1
            "#
            .trim();
            assert!(wrapped_plans.is_empty(), "No wrapping expected");
            compare_plan_display(&analyzed_plan, expected_plan);
        }

        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
        }
        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_sort() -> Result<()> {
        let sql = "SELECT t2.id FROM table2 t2 ORDER BY clickhouse(abs(t2.id), 'Int64')";
        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            let expected_plan = r#"
                Sort: clickhouse(abs(t2.id), Utf8("Int64")) ASC NULLS LAST
                  SubqueryAlias: t2
                    TableScan: table2 projection=[id]
            "#
            .trim();
            assert!(wrapped_plans.is_empty(), "No wrapping expected");
            compare_plan_display(&analyzed_plan, expected_plan);
        }

        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
        }
        Ok(())
    }

    #[cfg(feature = "mocks")]
    #[tokio::test]
    async fn test_multiple_cols_same_function() -> Result<()> {
        let sql = "SELECT t3.col2
              , t1.col2
              , t2.id
              , clickhouse(t1.col1 + t2.id, 'Int64') as sum_ids
            FROM table3 t3
            JOIN table1 t1 ON t1.col1 = t3.col1
            JOIN table2 t2 ON t2.id = t1.col1
            ";
        let result = run_query(sql).await;
        assert!(result.is_err(), "Cannot push to either side of join");

        let sql = "SELECT t3.col2
              , t1.col2
              , t2.id
              , clickhouse(t1.col1 + t2.id, 'Int64') as sum_ids
            FROM table1 t1
            JOIN table2 t2 ON t2.id = t1.col1
            JOIN table3 t3 ON t2.id = t3.col1
            ";

        let analyzed_plan = run_query(sql).await?;
        let wrapped_plans = find_wrapped_plans(&analyzed_plan);

        #[cfg(feature = "federation")]
        {
            let expected_plan = r#"
                Projection: t3.col2, t1.col2, t2.id, clickhouse(t1.col1 + t2.id,Utf8("Int64")) AS sum_ids
                  Inner Join: t2.id = t3.col1
                    Projection: t1.col2, t2.id, clickhouse(t1.col1 + t2.id, Utf8("Int64")) AS clickhouse(t1.col1 + t2.id,Utf8("Int64"))
                      Inner Join: t1.col1 = t2.id
                        SubqueryAlias: t1
                          TableScan: table1 projection=[col1, col2]
                        SubqueryAlias: t2
                          TableScan: table2 projection=[id]
                    SubqueryAlias: t3
                      TableScan: table3 projection=[col1, col2]
            "#
            .trim();

            assert!(wrapped_plans.is_empty(), "No wrapping expected");
            compare_plan_display(&analyzed_plan, expected_plan);
        }

        #[cfg(not(feature = "federation"))]
        {
            assert_eq!(wrapped_plans.len(), 1, "Expected entire plan wrapped");
            assert!(wrapped_plans[0].contains("root"), "Expected function wrapped at root");
        }
        Ok(())
    }
}