1pub mod plan_node;
13pub mod planner;
14
15use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use datafusion::arrow::datatypes::{DataType, SchemaRef};
21use datafusion::catalog::cte_worktable::CteWorkTable;
22use datafusion::common::file_options::file_type::FileType;
23use datafusion::common::plan_datafusion_err;
24use datafusion::config::ConfigOptions;
25use datafusion::datasource::file_format::format_as_file_type;
26use datafusion::datasource::provider_as_source;
27use datafusion::error::Result;
28use datafusion::execution::SessionState;
29use datafusion::execution::context::QueryPlanner;
30use datafusion::logical_expr::planner::{ExprPlanner, TypePlanner};
31use datafusion::logical_expr::var_provider::is_system_variables;
32use datafusion::logical_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
33use datafusion::optimizer::AnalyzerRule;
34use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
35use datafusion::physical_plan::ExecutionPlan;
36use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner};
37use datafusion::prelude::{DataFrame, Expr, SQLOptions, SessionContext};
38use datafusion::sql::parser::Statement;
39use datafusion::sql::planner::{ContextProvider, NullOrdering, ParserOptions, SqlToRel};
40use datafusion::sql::{ResolvedTableReference, TableReference};
41use datafusion::variable::VarType;
42
43use self::planner::ClickHouseExtensionPlanner;
44use crate::analyzer::function_pushdown::ClickHouseFunctionPushdown;
45use crate::udfs::apply::{CLICKHOUSE_APPLY_ALIASES, clickhouse_apply_udf};
46use crate::udfs::clickhouse::{CLICKHOUSE_UDF_ALIASES, clickhouse_udf};
47use crate::udfs::placeholder::PlaceholderUDF;
48
49pub fn prepare_session_context(
60 ctx: SessionContext,
61 extension_planners: Option<Vec<Arc<dyn ExtensionPlanner + Send + Sync>>>,
62) -> SessionContext {
63 #[cfg(feature = "federation")]
64 use crate::federation::FederatedContext as _;
65
66 #[cfg(feature = "federation")]
69 let ctx = ctx.federate();
70 let state = ctx.state();
72 let config = state.config().clone();
73 let config = config.set_str("datafusion.sql_parser.dialect", "ClickHouse");
81 let state_builder = if state
83 .analyzer()
84 .rules
85 .iter()
86 .any(|rule| rule.name() == ClickHouseFunctionPushdown.name())
87 {
88 ctx.into_state_builder()
89 } else {
90 let analyzer_rules = configure_analyzer_rules(&state);
91 ctx.into_state_builder().with_analyzer_rules(analyzer_rules)
92 };
93 let ctx = SessionContext::new_with_state(
95 state_builder
96 .with_config(config)
97 .with_query_planner(Arc::new(ClickHouseQueryPlanner::new_with_planners(
98 extension_planners.unwrap_or_default(),
99 )))
100 .build(),
101 );
102 ctx.register_udf(clickhouse_udf());
103 ctx.register_udf(clickhouse_apply_udf());
104 ctx
105}
106
107pub fn configure_analyzer_rules(state: &SessionState) -> Vec<Arc<dyn AnalyzerRule + Send + Sync>> {
109 let mut analyzer_rules = state.analyzer().rules.clone();
111
112 let type_coercion = TypeCoercion::default();
115 let pos = analyzer_rules.iter().position(|x| x.name() == type_coercion.name()).unwrap_or(0);
116
117 let pushdown_rule = Arc::new(ClickHouseFunctionPushdown);
118 analyzer_rules.insert(pos, pushdown_rule);
119 analyzer_rules
120}
121
/// [`QueryPlanner`] that creates physical plans with a [`DefaultPhysicalPlanner`] extended by a
/// list of [`ExtensionPlanner`]s.
#[derive(Clone)]
pub struct ClickHouseQueryPlanner {
    /// Extension planners handed to the physical planner, in the order they were added.
    planners: Vec<Arc<dyn ExtensionPlanner + Send + Sync>>,
}
127
128impl std::fmt::Debug for ClickHouseQueryPlanner {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("ClickHouseQueryPlanner").finish()
131 }
132}
133
134impl Default for ClickHouseQueryPlanner {
135 fn default() -> Self { Self::new() }
136}
137
138impl ClickHouseQueryPlanner {
139 pub fn new() -> Self {
141 let planners = vec![
142 #[cfg(feature = "federation")]
143 Arc::new(datafusion_federation::FederatedPlanner::new()),
144 Arc::new(ClickHouseExtensionPlanner {}) as Arc<dyn ExtensionPlanner + Send + Sync>,
145 ];
146 ClickHouseQueryPlanner { planners }
147 }
148
149 pub fn new_with_planners(planners: Vec<Arc<dyn ExtensionPlanner + Send + Sync>>) -> Self {
151 let mut this = Self::new();
152 this.planners.extend(planners);
153 this
154 }
155
156 #[must_use]
158 pub fn with_planner(mut self, planner: Arc<dyn ExtensionPlanner + Send + Sync>) -> Self {
159 self.planners.push(planner);
160 self
161 }
162}
163
164#[async_trait]
165impl QueryPlanner for ClickHouseQueryPlanner {
166 async fn create_physical_plan(
167 &self,
168 logical_plan: &LogicalPlan,
169 session_state: &SessionState,
170 ) -> Result<Arc<dyn ExecutionPlan>> {
171 let planner = DefaultPhysicalPlanner::with_extension_planners(self.planners.clone());
173 planner.create_physical_plan(logical_plan, session_state).await
174 }
175}
176
/// Wrapper around a [`SessionContext`] whose `sql` methods parse with the ClickHouse dialect and
/// plan through a [`ClickHouseContextProvider`].
#[derive(Clone)]
pub struct ClickHouseSessionContext {
    /// The wrapped context; `new` stores the result of [`prepare_session_context`] here.
    inner: SessionContext,
    /// Optional expression planner forwarded to the context provider when planning statements.
    expr_planner: Option<Arc<dyn ExprPlanner>>,
}
183
184impl ClickHouseSessionContext {
185 pub fn new(
187 ctx: SessionContext,
188 extension_planners: Option<Vec<Arc<dyn ExtensionPlanner + Send + Sync>>>,
189 ) -> Self {
190 Self { inner: prepare_session_context(ctx, extension_planners), expr_planner: None }
191 }
192
193 #[must_use]
194 pub fn with_expr_planner(mut self, expr_planner: Arc<dyn ExprPlanner>) -> Self {
195 self.expr_planner = Some(expr_planner);
196 self
197 }
198
199 pub fn session_context(&self) -> &SessionContext { &self.inner }
201
202 pub fn into_session_context(self) -> SessionContext { self.inner }
204
205 pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
210 self.sql_with_options(sql, SQLOptions::new()).await
211 }
212
213 pub async fn sql_with_options(&self, sql: &str, options: SQLOptions) -> Result<DataFrame> {
218 let state = self.inner.state();
219 let statement = state.sql_to_statement(sql, "ClickHouse")?;
220 let plan = self.statement_to_plan(&state, statement).await?;
221 options.verify_plan(&plan)?;
222 self.execute_logical_plan(plan).await
223 }
224
225 pub async fn statement_to_plan(
232 &self,
233 state: &SessionState,
234 statement: Statement,
235 ) -> Result<LogicalPlan> {
236 let references = state.resolve_table_references(&statement)?;
237
238 let provider =
239 ClickHouseContextProvider::new(state.clone(), HashMap::with_capacity(references.len()));
240
241 let mut provider = if let Some(planner) = self.expr_planner.as_ref() {
242 provider.with_expr_planner(Arc::clone(planner))
243 } else {
244 provider
245 };
246
247 for reference in references {
248 let catalog = &state.config_options().catalog;
251 let resolved = reference.resolve(&catalog.default_catalog, &catalog.default_schema);
252 if let Entry::Vacant(v) = provider.tables.entry(resolved) {
253 let resolved = v.key();
254 if let Ok(schema) = provider.state.schema_for_ref(resolved.clone())
255 && let Some(table) = schema.table(&resolved.table).await?
256 {
257 let _ = v.insert(provider_as_source(table));
258 }
259 }
260 }
261
262 SqlToRel::new_with_options(&provider, Self::get_parser_options(&self.state()))
263 .statement_to_plan(statement)
264 }
265
266 fn get_parser_options(state: &SessionState) -> ParserOptions {
267 let sql_parser_options = &state.config().options().sql_parser;
268 ParserOptions {
269 parse_float_as_decimal: sql_parser_options.parse_float_as_decimal,
270 enable_ident_normalization: sql_parser_options.enable_ident_normalization,
271 enable_options_value_normalization: sql_parser_options
272 .enable_options_value_normalization,
273 support_varchar_with_length: sql_parser_options.support_varchar_with_length,
274 map_string_types_to_utf8view: sql_parser_options.map_string_types_to_utf8view,
275 collect_spans: sql_parser_options.collect_spans,
276 default_null_ordering: NullOrdering::NullsMax,
277 }
278 }
279}
280
281impl From<SessionContext> for ClickHouseSessionContext {
282 fn from(inner: SessionContext) -> Self { Self::new(inner, None) }
283}
284
285impl From<&SessionContext> for ClickHouseSessionContext {
286 fn from(inner: &SessionContext) -> Self { Self::new(inner.clone(), None) }
287}
288
289impl std::ops::Deref for ClickHouseSessionContext {
290 type Target = SessionContext;
291
292 fn deref(&self) -> &Self::Target { &self.inner }
293}
294
/// [`ContextProvider`] backed by a [`SessionState`] and an explicit map of table sources.
pub struct ClickHouseContextProvider {
    /// Session state consulted for functions, variables, options, and file formats.
    state: SessionState,
    /// Table sources keyed by fully resolved reference; served by `get_table_source`.
    tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
    /// Expression planners returned from `get_expr_planners`.
    expr_planners: Vec<Arc<dyn ExprPlanner>>,
    /// Optional type planner returned from `get_type_planner`.
    type_planner: Option<Arc<dyn TypePlanner>>,
}
305
306impl ClickHouseContextProvider {
307 pub fn new(
308 state: SessionState,
309 tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
310 ) -> Self {
311 Self { state, tables, expr_planners: vec![], type_planner: None }
312 }
313
314 #[must_use]
315 pub fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
316 self.expr_planners.push(planner);
317 self
318 }
319
320 #[must_use]
321 pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
322 self.type_planner = Some(type_planner);
323 self
324 }
325
326 fn resolve_table_ref(&self, table_ref: impl Into<TableReference>) -> ResolvedTableReference {
329 let catalog = &self.state.config_options().catalog;
330 table_ref.into().resolve(&catalog.default_catalog, &catalog.default_schema)
331 }
332}
333
334impl ContextProvider for ClickHouseContextProvider {
335 fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
336 if CLICKHOUSE_UDF_ALIASES.contains(&name) {
338 return Some(Arc::new(clickhouse_udf()));
339 }
340
341 if CLICKHOUSE_APPLY_ALIASES.contains(&name) {
343 return Some(Arc::new(clickhouse_apply_udf()));
344 }
345
346 if let Some(func) = self.state.scalar_functions().get(name) {
348 return Some(Arc::clone(func));
349 }
350
351 if self.state.aggregate_functions().contains_key(name) {
354 return None;
355 }
356 if self.state.window_functions().contains_key(name) {
357 return None;
358 }
359
360 Some(Arc::new(ScalarUDF::new_from_impl(PlaceholderUDF::new(name))))
362 }
363
364 fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] { &self.expr_planners }
365
366 fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
367 if let Some(type_planner) = &self.type_planner {
368 Some(Arc::clone(type_planner))
369 } else {
370 None
371 }
372 }
373
374 fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
375 let name = self.resolve_table_ref(name);
376 self.tables
377 .get(&name)
378 .cloned()
379 .ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
380 }
381
382 fn get_table_function_source(
383 &self,
384 name: &str,
385 args: Vec<Expr>,
386 ) -> Result<Arc<dyn TableSource>> {
387 let tbl_func = self
388 .state
389 .table_functions()
390 .get(name)
391 .cloned()
392 .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
393 let provider = tbl_func.create_table_provider(&args)?;
394
395 Ok(provider_as_source(provider))
396 }
397
398 fn create_cte_work_table(&self, name: &str, schema: SchemaRef) -> Result<Arc<dyn TableSource>> {
402 let table = Arc::new(CteWorkTable::new(name, schema));
403 Ok(provider_as_source(table))
404 }
405
406 fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
407 self.state.aggregate_functions().get(name).cloned()
408 }
409
410 fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
411 self.state.window_functions().get(name).cloned()
412 }
413
414 fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
415 if variable_names.is_empty() {
416 return None;
417 }
418
419 let provider_type = if is_system_variables(variable_names) {
420 VarType::System
421 } else {
422 VarType::UserDefined
423 };
424
425 self.state
426 .execution_props()
427 .var_providers
428 .as_ref()
429 .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
430 }
431
432 fn options(&self) -> &ConfigOptions { self.state.config_options() }
433
434 fn udf_names(&self) -> Vec<String> { self.state.scalar_functions().keys().cloned().collect() }
436
437 fn udaf_names(&self) -> Vec<String> {
438 self.state.aggregate_functions().keys().cloned().collect()
439 }
440
441 fn udwf_names(&self) -> Vec<String> { self.state.window_functions().keys().cloned().collect() }
442
443 fn get_file_type(&self, ext: &str) -> Result<Arc<dyn FileType>> {
444 self.state
445 .get_file_format_factory(ext)
446 .ok_or(plan_datafusion_err!("There is no registered file format with ext {ext}"))
447 .map(|file_type| format_as_file_type(file_type))
448 }
449}
450
#[cfg(all(test, feature = "test-utils"))]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::planner::{
        ExprPlanner, PlannerResult, RawBinaryExpr, TypePlanner,
    };
    use datafusion::prelude::{SessionContext, lit};
    use datafusion::sql::TableReference;
    use datafusion::sql::sqlparser::ast;

    use super::*;

    /// Type planner stub that maps every SQL type to `Utf8`.
    #[derive(Debug)]
    struct MockTypePlanner;

    impl TypePlanner for MockTypePlanner {
        fn plan_type(&self, _expr: &ast::DataType) -> Result<Option<DataType>> {
            Ok(Some(DataType::Utf8))
        }
    }

    /// Expression planner stub that hands binary expressions back unchanged.
    #[derive(Debug)]
    struct MockExprPlanner;

    impl ExprPlanner for MockExprPlanner {
        fn plan_binary_op(
            &self,
            expr: RawBinaryExpr,
            _schema: &DFSchema,
        ) -> Result<PlannerResult<RawBinaryExpr>> {
            Ok(PlannerResult::Original(expr))
        }
    }

    /// Builds a provider over a fresh default session state with no tables registered.
    fn create_test_context_provider() -> ClickHouseContextProvider {
        ClickHouseContextProvider::new(SessionContext::new().state(), HashMap::new())
    }

    #[test]
    fn test_with_expr_planner() {
        let empty = create_test_context_provider();
        assert!(empty.expr_planners.is_empty());

        let planner: Arc<dyn ExprPlanner> = Arc::new(MockExprPlanner);
        let configured = empty.with_expr_planner(Arc::clone(&planner));

        assert_eq!(configured.expr_planners.len(), 1);
        assert_eq!(configured.get_expr_planners().len(), 1);
    }

    #[test]
    fn test_with_type_planner() {
        let bare = create_test_context_provider();
        assert!(bare.type_planner.is_none());

        let planner: Arc<dyn TypePlanner> = Arc::new(MockTypePlanner);
        let configured = bare.with_type_planner(Arc::clone(&planner));

        assert!(configured.type_planner.is_some());
    }

    #[test]
    fn test_get_type_planner() {
        let bare = create_test_context_provider();
        assert!(bare.get_type_planner().is_none());

        let planner: Arc<dyn TypePlanner> = Arc::new(MockTypePlanner);
        let configured = bare.with_type_planner(Arc::clone(&planner));

        assert!(configured.get_type_planner().is_some());
    }

    #[test]
    fn test_get_table_function_source_not_found() {
        let provider = create_test_context_provider();

        let outcome =
            provider.get_table_function_source("nonexistent_function", vec![lit("test")]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_create_cte_work_table() {
        let provider = create_test_context_provider();
        let fields =
            vec![Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, true)];
        let expected_schema = Arc::new(Schema::new(fields));

        let outcome = provider.create_cte_work_table("test_cte", Arc::clone(&expected_schema));
        assert!(outcome.is_ok());

        let source = outcome.unwrap();
        assert_eq!(source.schema(), expected_schema);
    }

    #[test]
    fn test_get_variable_type_empty() {
        let provider = create_test_context_provider();
        assert!(provider.get_variable_type(&[]).is_none());
    }

    #[test]
    fn test_get_variable_type_system_variables() {
        let provider = create_test_context_provider();
        let names = [String::from("@@version")];
        // The default session state used here yields no type for system variables.
        assert!(provider.get_variable_type(&names).is_none());
    }

    #[test]
    fn test_get_variable_type_user_defined() {
        let provider = create_test_context_provider();
        let names = [String::from("user_var")];
        // The default session state used here yields no type for user-defined variables.
        assert!(provider.get_variable_type(&names).is_none());
    }

    #[test]
    fn test_get_file_type_unknown_extension() {
        let provider = create_test_context_provider();
        assert!(provider.get_file_type("unknown_ext").is_err());
    }

    #[test]
    fn test_get_file_type_known_extension() {
        let provider = create_test_context_provider();
        assert!(provider.get_file_type("csv").is_ok());
    }

    #[test]
    fn test_get_function_meta_clickhouse_udf() {
        let provider = create_test_context_provider();

        let udf = provider
            .get_function_meta("clickhouse")
            .expect("the `clickhouse` alias should resolve to a UDF");
        assert_eq!(udf.name(), "clickhouse");
    }

    #[test]
    fn test_get_function_meta_placeholder_udf() {
        let provider = create_test_context_provider();

        let udf = provider
            .get_function_meta("unknown_function")
            .expect("unknown names should resolve to a placeholder UDF");
        assert_eq!(udf.name(), "unknown_function");
    }

    #[test]
    fn test_get_function_meta_aggregate_function() {
        let provider = create_test_context_provider();
        // Aggregate names must not resolve as scalar functions.
        assert!(provider.get_function_meta("sum").is_none());
    }

    #[test]
    fn test_get_function_meta_window_function() {
        let provider = create_test_context_provider();
        // Window names must not resolve as scalar functions.
        assert!(provider.get_function_meta("row_number").is_none());
    }

    #[test]
    fn test_get_table_source_not_found() {
        let provider = create_test_context_provider();

        let outcome = provider.get_table_source(TableReference::bare("nonexistent_table"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_resolve_table_ref() {
        let provider = create_test_context_provider();

        // Bare reference: only the table name is asserted.
        let bare = provider.resolve_table_ref(TableReference::bare("test_table"));
        assert_eq!(bare.table.as_ref(), "test_table");

        // Partial reference: schema and table are taken from the reference.
        let partial =
            provider.resolve_table_ref(TableReference::partial("test_schema", "test_table"));
        assert_eq!(partial.schema.as_ref(), "test_schema");
        assert_eq!(partial.table.as_ref(), "test_table");

        // Full reference: every component is taken from the reference.
        let full = provider.resolve_table_ref(TableReference::full(
            "test_catalog",
            "test_schema",
            "test_table",
        ));
        assert_eq!(full.catalog.as_ref(), "test_catalog");
        assert_eq!(full.schema.as_ref(), "test_schema");
        assert_eq!(full.table.as_ref(), "test_table");
    }

    #[test]
    fn test_udf_names() {
        let provider = create_test_context_provider();
        assert!(!provider.udf_names().is_empty());
    }

    #[test]
    fn test_udaf_names() {
        let provider = create_test_context_provider();
        let names = provider.udaf_names();
        assert!(!names.is_empty());
        assert!(names.iter().any(|name| name == "sum"));
        assert!(names.iter().any(|name| name == "count"));
    }

    #[test]
    fn test_udwf_names() {
        let provider = create_test_context_provider();
        let names = provider.udwf_names();
        assert!(!names.is_empty());
        assert!(names.iter().any(|name| name == "row_number"));
    }

    #[test]
    fn test_options() {
        let provider = create_test_context_provider();
        let catalog = &provider.options().catalog;
        assert!(!catalog.default_catalog.is_empty());
        assert!(!catalog.default_schema.is_empty());
    }

    #[test]
    fn test_get_aggregate_meta() {
        let provider = create_test_context_provider();

        let known = provider.get_aggregate_meta("sum");
        assert!(known.is_some());
        let udaf = known.unwrap();
        assert_eq!(udaf.name().to_lowercase(), "sum");

        assert!(provider.get_aggregate_meta("unknown_aggregate").is_none());
    }

    #[test]
    fn test_get_window_meta() {
        let provider = create_test_context_provider();

        let known = provider.get_window_meta("row_number");
        assert!(known.is_some());
        let udwf = known.unwrap();
        assert_eq!(udwf.name().to_lowercase(), "row_number");

        assert!(provider.get_window_meta("unknown_window").is_none());
    }
}