1pub mod plan_node;
13pub mod planner;
14
15use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use datafusion::arrow::datatypes::{DataType, SchemaRef};
21use datafusion::catalog::cte_worktable::CteWorkTable;
22use datafusion::common::file_options::file_type::FileType;
23use datafusion::common::plan_datafusion_err;
24use datafusion::config::ConfigOptions;
25use datafusion::datasource::file_format::format_as_file_type;
26use datafusion::datasource::provider_as_source;
27use datafusion::error::Result;
28use datafusion::execution::SessionState;
29use datafusion::execution::context::QueryPlanner;
30use datafusion::logical_expr::planner::{ExprPlanner, TypePlanner};
31use datafusion::logical_expr::var_provider::is_system_variables;
32use datafusion::logical_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
33use datafusion::optimizer::AnalyzerRule;
34use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
35use datafusion::physical_plan::ExecutionPlan;
36use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner};
37use datafusion::prelude::{DataFrame, Expr, SQLOptions, SessionContext};
38use datafusion::sql::parser::Statement;
39use datafusion::sql::planner::{ContextProvider, ParserOptions, SqlToRel};
40use datafusion::sql::{ResolvedTableReference, TableReference};
41use datafusion::variable::VarType;
42
43use self::planner::ClickHouseExtensionPlanner;
44use crate::analyzer::function_pushdown::ClickHouseFunctionPushdown;
45use crate::udfs::apply::{CLICKHOUSE_APPLY_ALIASES, clickhouse_apply_udf};
46use crate::udfs::clickhouse::{CLICKHOUSE_UDF_ALIASES, clickhouse_udf};
47use crate::udfs::placeholder::PlaceholderUDF;
48
49pub fn prepare_session_context(
60 ctx: SessionContext,
61 extension_planners: Option<Vec<Arc<dyn ExtensionPlanner + Send + Sync>>>,
62) -> SessionContext {
63 #[cfg(feature = "federation")]
64 use crate::federation::FederatedContext as _;
65
66 #[cfg(feature = "federation")]
69 let ctx = ctx.federate();
70 let state = ctx.state();
72 let config = state.config().clone();
73 let config = config.set_str("datafusion.sql_parser.dialect", "ClickHouse");
81 let state_builder = if state
83 .analyzer()
84 .rules
85 .iter()
86 .any(|rule| rule.name() == ClickHouseFunctionPushdown.name())
87 {
88 ctx.into_state_builder()
89 } else {
90 let analyzer_rules = configure_analyzer_rules(&state);
91 ctx.into_state_builder().with_analyzer_rules(analyzer_rules)
92 };
93 let ctx = SessionContext::new_with_state(
95 state_builder
96 .with_config(config)
97 .with_query_planner(Arc::new(ClickHouseQueryPlanner::new_with_planners(
98 extension_planners.unwrap_or_default(),
99 )))
100 .build(),
101 );
102 ctx.register_udf(clickhouse_udf());
103 ctx.register_udf(clickhouse_apply_udf());
104 ctx
105}
106
107pub fn configure_analyzer_rules(state: &SessionState) -> Vec<Arc<dyn AnalyzerRule + Send + Sync>> {
109 let mut analyzer_rules = state.analyzer().rules.clone();
111
112 let type_coercion = TypeCoercion::default();
115 let pos = analyzer_rules.iter().position(|x| x.name() == type_coercion.name()).unwrap_or(0);
116
117 let pushdown_rule = Arc::new(ClickHouseFunctionPushdown);
118 analyzer_rules.insert(pos, pushdown_rule);
119 analyzer_rules
120}
121
122#[derive(Clone)]
124pub struct ClickHouseQueryPlanner {
125 planners: Vec<Arc<dyn ExtensionPlanner + Send + Sync>>,
126}
127
128impl std::fmt::Debug for ClickHouseQueryPlanner {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("ClickHouseQueryPlanner").finish()
131 }
132}
133
134impl Default for ClickHouseQueryPlanner {
135 fn default() -> Self { Self::new() }
136}
137
138impl ClickHouseQueryPlanner {
139 pub fn new() -> Self {
141 let planners = vec![
142 #[cfg(feature = "federation")]
143 Arc::new(datafusion_federation::FederatedPlanner::new()),
144 Arc::new(ClickHouseExtensionPlanner {}) as Arc<dyn ExtensionPlanner + Send + Sync>,
145 ];
146 ClickHouseQueryPlanner { planners }
147 }
148
149 pub fn new_with_planners(planners: Vec<Arc<dyn ExtensionPlanner + Send + Sync>>) -> Self {
151 let mut this = Self::new();
152 this.planners.extend(planners);
153 this
154 }
155
156 #[must_use]
158 pub fn with_planner(mut self, planner: Arc<dyn ExtensionPlanner + Send + Sync>) -> Self {
159 self.planners.push(planner);
160 self
161 }
162}
163
164#[async_trait]
165impl QueryPlanner for ClickHouseQueryPlanner {
166 async fn create_physical_plan(
167 &self,
168 logical_plan: &LogicalPlan,
169 session_state: &SessionState,
170 ) -> Result<Arc<dyn ExecutionPlan>> {
171 let planner = DefaultPhysicalPlanner::with_extension_planners(self.planners.clone());
173 planner.create_physical_plan(logical_plan, session_state).await
174 }
175}
176
177#[derive(Clone)]
179pub struct ClickHouseSessionContext {
180 inner: SessionContext,
181 expr_planner: Option<Arc<dyn ExprPlanner>>,
182}
183
184impl ClickHouseSessionContext {
185 pub fn new(
187 ctx: SessionContext,
188 extension_planners: Option<Vec<Arc<dyn ExtensionPlanner + Send + Sync>>>,
189 ) -> Self {
190 Self { inner: prepare_session_context(ctx, extension_planners), expr_planner: None }
191 }
192
193 #[must_use]
194 pub fn with_expr_planner(mut self, expr_planner: Arc<dyn ExprPlanner>) -> Self {
195 self.expr_planner = Some(expr_planner);
196 self
197 }
198
199 pub fn into_session_context(self) -> SessionContext { self.inner }
201
202 pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
207 self.sql_with_options(sql, SQLOptions::new()).await
208 }
209
210 pub async fn sql_with_options(&self, sql: &str, options: SQLOptions) -> Result<DataFrame> {
215 let state = self.inner.state();
216 let statement = state.sql_to_statement(sql, "ClickHouse")?;
217 let plan = self.statement_to_plan(&state, statement).await?;
218 options.verify_plan(&plan)?;
219 self.execute_logical_plan(plan).await
220 }
221
222 pub async fn statement_to_plan(
226 &self,
227 state: &SessionState,
228 statement: Statement,
229 ) -> Result<LogicalPlan> {
230 let references = state.resolve_table_references(&statement)?;
231
232 let provider =
233 ClickHouseContextProvider::new(state.clone(), HashMap::with_capacity(references.len()));
234
235 let mut provider = if let Some(planner) = self.expr_planner.as_ref() {
236 provider.with_expr_planner(Arc::clone(planner))
237 } else {
238 provider
239 };
240
241 for reference in references {
242 let catalog = &state.config_options().catalog;
245 let resolved = reference.resolve(&catalog.default_catalog, &catalog.default_schema);
246 if let Entry::Vacant(v) = provider.tables.entry(resolved) {
247 let resolved = v.key();
248 if let Ok(schema) = provider.state.schema_for_ref(resolved.clone())
249 && let Some(table) = schema.table(&resolved.table).await?
250 {
251 let _ = v.insert(provider_as_source(table));
252 }
253 }
254 }
255
256 SqlToRel::new_with_options(&provider, Self::get_parser_options(&self.state()))
257 .statement_to_plan(statement)
258 }
259
260 fn get_parser_options(state: &SessionState) -> ParserOptions {
261 let sql_parser_options = &state.config().options().sql_parser;
262 ParserOptions {
263 parse_float_as_decimal: sql_parser_options.parse_float_as_decimal,
264 enable_ident_normalization: sql_parser_options.enable_ident_normalization,
265 enable_options_value_normalization: sql_parser_options
266 .enable_options_value_normalization,
267 support_varchar_with_length: sql_parser_options.support_varchar_with_length,
268 map_string_types_to_utf8view: sql_parser_options.map_string_types_to_utf8view,
269 collect_spans: sql_parser_options.collect_spans,
270 }
271 }
272}
273
274impl From<SessionContext> for ClickHouseSessionContext {
275 fn from(inner: SessionContext) -> Self { Self::new(inner, None) }
276}
277
278impl From<&SessionContext> for ClickHouseSessionContext {
279 fn from(inner: &SessionContext) -> Self { Self::new(inner.clone(), None) }
280}
281
282impl std::ops::Deref for ClickHouseSessionContext {
283 type Target = SessionContext;
284
285 fn deref(&self) -> &Self::Target { &self.inner }
286}
287
288pub struct ClickHouseContextProvider {
293 state: SessionState,
294 tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
295 expr_planners: Vec<Arc<dyn ExprPlanner>>,
296 type_planner: Option<Arc<dyn TypePlanner>>,
297}
298
299impl ClickHouseContextProvider {
300 pub fn new(
301 state: SessionState,
302 tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
303 ) -> Self {
304 Self { state, tables, expr_planners: vec![], type_planner: None }
305 }
306
307 #[must_use]
308 pub fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
309 self.expr_planners.push(planner);
310 self
311 }
312
313 #[must_use]
314 pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
315 self.type_planner = Some(type_planner);
316 self
317 }
318
319 fn resolve_table_ref(&self, table_ref: impl Into<TableReference>) -> ResolvedTableReference {
322 let catalog = &self.state.config_options().catalog;
323 table_ref.into().resolve(&catalog.default_catalog, &catalog.default_schema)
324 }
325}
326
327impl ContextProvider for ClickHouseContextProvider {
328 fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
329 if CLICKHOUSE_UDF_ALIASES.contains(&name) {
331 return Some(Arc::new(clickhouse_udf()));
332 }
333
334 if CLICKHOUSE_APPLY_ALIASES.contains(&name) {
336 return Some(Arc::new(clickhouse_apply_udf()));
337 }
338
339 if let Some(func) = self.state.scalar_functions().get(name) {
341 return Some(Arc::clone(func));
342 }
343
344 if self.state.aggregate_functions().contains_key(name) {
347 return None;
348 }
349 if self.state.window_functions().contains_key(name) {
350 return None;
351 }
352
353 Some(Arc::new(ScalarUDF::new_from_impl(PlaceholderUDF::new(name))))
355 }
356
357 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] { &self.expr_planners }
358
359 fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
360 if let Some(type_planner) = &self.type_planner {
361 Some(Arc::clone(type_planner))
362 } else {
363 None
364 }
365 }
366
367 fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
368 let name = self.resolve_table_ref(name);
369 self.tables
370 .get(&name)
371 .cloned()
372 .ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
373 }
374
375 fn get_table_function_source(
376 &self,
377 name: &str,
378 args: Vec<Expr>,
379 ) -> Result<Arc<dyn TableSource>> {
380 let tbl_func = self
381 .state
382 .table_functions()
383 .get(name)
384 .cloned()
385 .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
386 let provider = tbl_func.create_table_provider(&args)?;
387
388 Ok(provider_as_source(provider))
389 }
390
391 fn create_cte_work_table(&self, name: &str, schema: SchemaRef) -> Result<Arc<dyn TableSource>> {
395 let table = Arc::new(CteWorkTable::new(name, schema));
396 Ok(provider_as_source(table))
397 }
398
399 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
400 self.state.aggregate_functions().get(name).cloned()
401 }
402
403 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
404 self.state.window_functions().get(name).cloned()
405 }
406
407 fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
408 if variable_names.is_empty() {
409 return None;
410 }
411
412 let provider_type = if is_system_variables(variable_names) {
413 VarType::System
414 } else {
415 VarType::UserDefined
416 };
417
418 self.state
419 .execution_props()
420 .var_providers
421 .as_ref()
422 .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
423 }
424
425 fn options(&self) -> &ConfigOptions { self.state.config_options() }
426
427 fn udf_names(&self) -> Vec<String> { self.state.scalar_functions().keys().cloned().collect() }
429
430 fn udaf_names(&self) -> Vec<String> {
431 self.state.aggregate_functions().keys().cloned().collect()
432 }
433
434 fn udwf_names(&self) -> Vec<String> { self.state.window_functions().keys().cloned().collect() }
435
436 fn get_file_type(&self, ext: &str) -> Result<Arc<dyn FileType>> {
437 self.state
438 .get_file_format_factory(ext)
439 .ok_or(plan_datafusion_err!("There is no registered file format with ext {ext}"))
440 .map(|file_type| format_as_file_type(file_type))
441 }
442}
443
#[cfg(all(test, feature = "test-utils"))]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::planner::{
        ExprPlanner, PlannerResult, RawBinaryExpr, TypePlanner,
    };
    use datafusion::prelude::{SessionContext, lit};
    use datafusion::sql::TableReference;
    use datafusion::sql::sqlparser::ast;

    use super::*;

    /// Type planner stub that maps every SQL type to `Utf8`.
    #[derive(Debug)]
    struct MockTypePlanner;

    impl TypePlanner for MockTypePlanner {
        fn plan_type(&self, _expr: &ast::DataType) -> Result<Option<DataType>> {
            Ok(Some(DataType::Utf8))
        }
    }

    /// Expression planner stub that leaves every binary expression untouched.
    #[derive(Debug)]
    struct MockExprPlanner;

    impl ExprPlanner for MockExprPlanner {
        fn plan_binary_op(
            &self,
            expr: RawBinaryExpr,
            _schema: &DFSchema,
        ) -> Result<PlannerResult<RawBinaryExpr>> {
            Ok(PlannerResult::Original(expr))
        }
    }

    /// Provider over a fresh default session with no pre-resolved tables.
    fn create_test_context_provider() -> ClickHouseContextProvider {
        let session = SessionContext::new();
        ClickHouseContextProvider::new(session.state(), HashMap::new())
    }

    #[test]
    fn test_with_expr_planner() {
        let provider = create_test_context_provider();
        assert!(provider.expr_planners.is_empty());

        let provider = provider.with_expr_planner(Arc::new(MockExprPlanner));

        assert_eq!(provider.expr_planners.len(), 1);
        assert_eq!(provider.get_expr_planners().len(), 1);
    }

    #[test]
    fn test_with_type_planner() {
        let provider = create_test_context_provider();
        assert!(provider.type_planner.is_none());

        let provider = provider.with_type_planner(Arc::new(MockTypePlanner));

        assert!(provider.type_planner.is_some());
    }

    #[test]
    fn test_get_type_planner() {
        let provider = create_test_context_provider();
        assert!(provider.get_type_planner().is_none());

        let provider = provider.with_type_planner(Arc::new(MockTypePlanner));

        assert!(provider.get_type_planner().is_some());
    }

    #[test]
    fn test_get_table_function_source_not_found() {
        let provider = create_test_context_provider();

        let outcome =
            provider.get_table_function_source("nonexistent_function", vec![lit("test")]);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_create_cte_work_table() {
        let provider = create_test_context_provider();
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ];
        let schema = Arc::new(Schema::new(fields));

        let table_source = provider
            .create_cte_work_table("test_cte", Arc::clone(&schema))
            .expect("creating a CTE work table should succeed");

        assert_eq!(table_source.schema(), schema);
    }

    #[test]
    fn test_get_variable_type_empty() {
        let provider = create_test_context_provider();
        assert!(provider.get_variable_type(&[]).is_none());
    }

    #[test]
    fn test_get_variable_type_system_variables() {
        // The default session registers no variable providers.
        let provider = create_test_context_provider();
        assert!(provider.get_variable_type(&["@@version".to_string()]).is_none());
    }

    #[test]
    fn test_get_variable_type_user_defined() {
        let provider = create_test_context_provider();
        assert!(provider.get_variable_type(&["user_var".to_string()]).is_none());
    }

    #[test]
    fn test_get_file_type_unknown_extension() {
        let provider = create_test_context_provider();
        assert!(provider.get_file_type("unknown_ext").is_err());
    }

    #[test]
    fn test_get_file_type_known_extension() {
        // CSV is registered on a default session.
        let provider = create_test_context_provider();
        assert!(provider.get_file_type("csv").is_ok());
    }

    #[test]
    fn test_get_function_meta_clickhouse_udf() {
        let provider = create_test_context_provider();

        let udf = provider
            .get_function_meta("clickhouse")
            .expect("`clickhouse` should resolve to a UDF");

        assert_eq!(udf.name(), "clickhouse");
    }

    #[test]
    fn test_get_function_meta_placeholder_udf() {
        let provider = create_test_context_provider();

        let udf = provider
            .get_function_meta("unknown_function")
            .expect("unknown names should resolve to a placeholder UDF");

        assert_eq!(udf.name(), "unknown_function");
    }

    #[test]
    fn test_get_function_meta_aggregate_function() {
        // Aggregates must not be returned as scalar functions.
        let provider = create_test_context_provider();
        assert!(provider.get_function_meta("sum").is_none());
    }

    #[test]
    fn test_get_function_meta_window_function() {
        // Window functions must not be returned as scalar functions.
        let provider = create_test_context_provider();
        assert!(provider.get_function_meta("row_number").is_none());
    }

    #[test]
    fn test_get_table_source_not_found() {
        let provider = create_test_context_provider();

        let outcome = provider.get_table_source(TableReference::bare("nonexistent_table"));

        assert!(outcome.is_err());
    }

    #[test]
    fn test_resolve_table_ref() {
        let provider = create_test_context_provider();

        let bare = provider.resolve_table_ref(TableReference::bare("test_table"));
        assert_eq!(bare.table.as_ref(), "test_table");

        let partial =
            provider.resolve_table_ref(TableReference::partial("test_schema", "test_table"));
        assert_eq!(partial.schema.as_ref(), "test_schema");
        assert_eq!(partial.table.as_ref(), "test_table");

        let full = provider.resolve_table_ref(TableReference::full(
            "test_catalog",
            "test_schema",
            "test_table",
        ));
        assert_eq!(full.catalog.as_ref(), "test_catalog");
        assert_eq!(full.schema.as_ref(), "test_schema");
        assert_eq!(full.table.as_ref(), "test_table");
    }

    #[test]
    fn test_udf_names() {
        let provider = create_test_context_provider();
        assert!(!provider.udf_names().is_empty());
    }

    #[test]
    fn test_udaf_names() {
        let provider = create_test_context_provider();
        let names = provider.udaf_names();

        assert!(!names.is_empty());
        for expected in ["sum", "count"] {
            assert!(names.iter().any(|n| n == expected), "missing aggregate `{expected}`");
        }
    }

    #[test]
    fn test_udwf_names() {
        let provider = create_test_context_provider();
        let names = provider.udwf_names();

        assert!(!names.is_empty());
        assert!(names.iter().any(|n| n == "row_number"));
    }

    #[test]
    fn test_options() {
        let provider = create_test_context_provider();
        let catalog = &provider.options().catalog;

        assert!(!catalog.default_catalog.is_empty());
        assert!(!catalog.default_schema.is_empty());
    }

    #[test]
    fn test_get_aggregate_meta() {
        let provider = create_test_context_provider();

        let sum = provider.get_aggregate_meta("sum").expect("`sum` should be registered");
        assert_eq!(sum.name().to_lowercase(), "sum");

        assert!(provider.get_aggregate_meta("unknown_aggregate").is_none());
    }

    #[test]
    fn test_get_window_meta() {
        let provider = create_test_context_provider();

        let row_number =
            provider.get_window_meta("row_number").expect("`row_number` should be registered");
        assert_eq!(row_number.name().to_lowercase(), "row_number");

        assert!(provider.get_window_meta("unknown_window").is_none());
    }
}