// datafusion_spark/session_state.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::sync::Arc;
20
21use datafusion::execution::SessionStateBuilder;
22
23use crate::planner::SparkFunctionPlanner;
24use crate::{
25 all_default_aggregate_functions, all_default_scalar_functions,
26 all_default_table_functions, all_default_window_functions,
27};
28
/// Extension trait for adding Apache Spark features to [`SessionStateBuilder`].
///
/// This trait provides a convenient way to register all Apache Spark-compatible
/// functions (scalar, aggregate, window, and table) and expression planners
/// with a DataFusion session in a single call.
///
/// # Example
///
/// ```rust
/// use datafusion::execution::SessionStateBuilder;
/// use datafusion_spark::SessionStateBuilderSpark;
///
/// // Create a SessionState with Apache Spark features enabled.
/// // NOTE: the order matters here — `with_spark_features` should be
/// // called after `with_default_features` so that Spark variants
/// // overwrite any default functions registered under the same name.
/// let state = SessionStateBuilder::new()
///     .with_default_features()
///     .with_spark_features()
///     .build();
/// ```
pub trait SessionStateBuilderSpark {
    /// Adds all expr_planners, scalar, aggregate, window and table functions
    /// compatible with Apache Spark.
    ///
    /// Note: This overwrites any previously registered items with the same name.
    fn with_spark_features(self) -> Self;
}
55
56impl SessionStateBuilderSpark for SessionStateBuilder {
57 fn with_spark_features(mut self) -> Self {
58 self.expr_planners()
59 .get_or_insert_with(Vec::new)
60 // planners are evaluated in order of insertion. Push Apache Spark function planner to the front
61 // to take precedence over others
62 .insert(0, Arc::new(SparkFunctionPlanner));
63
64 self.scalar_functions()
65 .get_or_insert_with(Vec::new)
66 .extend(all_default_scalar_functions());
67
68 self.aggregate_functions()
69 .get_or_insert_with(Vec::new)
70 .extend(all_default_aggregate_functions());
71
72 self.window_functions()
73 .get_or_insert_with(Vec::new)
74 .extend(all_default_window_functions());
75
76 self.table_functions()
77 .get_or_insert_with(HashMap::new)
78 .extend(
79 all_default_table_functions()
80 .into_iter()
81 .map(|f| (f.name().to_string(), f)),
82 );
83
84 self
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Spark features alone should register the Spark function sets and
    /// expression planners.
    #[test]
    fn test_session_state_with_spark_features() {
        let state = SessionStateBuilder::new().with_spark_features().build();

        assert!(
            state.scalar_functions().contains_key("sha2"),
            "Apache Spark scalar function 'sha2' should be registered"
        );

        assert!(
            state.aggregate_functions().contains_key("try_sum"),
            "Apache Spark aggregate function 'try_sum' should be registered"
        );

        assert!(
            !state.expr_planners().is_empty(),
            "Apache Spark expr planners should be registered"
        );
    }

    /// Exercises the documented usage pattern: `with_spark_features` called
    /// after `with_default_features`. The trait docs promise that Spark
    /// registrations survive (and overwrite) the defaults; this was
    /// previously untested.
    #[test]
    fn test_spark_features_after_default_features() {
        let state = SessionStateBuilder::new()
            .with_default_features()
            .with_spark_features()
            .build();

        assert!(
            state.scalar_functions().contains_key("sha2"),
            "Spark scalar function 'sha2' should be registered alongside defaults"
        );

        assert!(
            state.aggregate_functions().contains_key("try_sum"),
            "Spark aggregate function 'try_sum' should be registered alongside defaults"
        );

        assert!(
            !state.expr_planners().is_empty(),
            "expr planners should be registered"
        );
    }
}
111}