// datafusion_spark/session_state.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
use std::collections::HashMap;
use std::sync::Arc;

use datafusion::execution::SessionStateBuilder;

use crate::planner::SparkFunctionPlanner;
use crate::{
    all_default_aggregate_functions, all_default_scalar_functions,
    all_default_table_functions, all_default_window_functions,
};
28
/// Extension trait that adds Apache Spark features to [`SessionStateBuilder`].
///
/// Provides a single entry point for registering every Apache Spark-compatible
/// function and expression planner shipped by this crate onto a DataFusion
/// session.
///
/// # Example
///
/// ```rust
/// use datafusion::execution::SessionStateBuilder;
/// use datafusion_spark::SessionStateBuilderSpark;
///
/// // Create a SessionState with Apache Spark features enabled.
/// // Note: ordering matters here — `with_spark_features` should be called
/// // after `with_default_features` so that the Spark variants overwrite any
/// // functions already registered under the same name.
/// let state = SessionStateBuilder::new()
///     .with_default_features()
///     .with_spark_features()
///     .build();
/// ```
pub trait SessionStateBuilderSpark {
    /// Adds all expr_planners, scalar, aggregate, window and table functions
    /// compatible with Apache Spark.
    ///
    /// Note: This overwrites any previously registered items with the same name.
    fn with_spark_features(self) -> Self;
}
55
56impl SessionStateBuilderSpark for SessionStateBuilder {
57    fn with_spark_features(mut self) -> Self {
58        self.expr_planners()
59            .get_or_insert_with(Vec::new)
60            // planners are evaluated in order of insertion. Push Apache Spark function planner to the front
61            // to take precedence over others
62            .insert(0, Arc::new(SparkFunctionPlanner));
63
64        self.scalar_functions()
65            .get_or_insert_with(Vec::new)
66            .extend(all_default_scalar_functions());
67
68        self.aggregate_functions()
69            .get_or_insert_with(Vec::new)
70            .extend(all_default_aggregate_functions());
71
72        self.window_functions()
73            .get_or_insert_with(Vec::new)
74            .extend(all_default_window_functions());
75
76        self.table_functions()
77            .get_or_insert_with(HashMap::new)
78            .extend(
79                all_default_table_functions()
80                    .into_iter()
81                    .map(|f| (f.name().to_string(), f)),
82            );
83
84        self
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_state_with_spark_features() {
        let state = SessionStateBuilder::new().with_spark_features().build();

        // Spot-check one scalar and one aggregate function to confirm the
        // bulk registration ran, and verify at least one planner is present.
        let scalars = state.scalar_functions();
        assert!(
            scalars.contains_key("sha2"),
            "Apache Spark scalar function 'sha2' should be registered"
        );

        let aggregates = state.aggregate_functions();
        assert!(
            aggregates.contains_key("try_sum"),
            "Apache Spark aggregate function 'try_sum' should be registered"
        );

        let planners = state.expr_planners();
        assert!(
            !planners.is_empty(),
            "Apache Spark expr planners should be registered"
        );
    }
}
111}