datafusion_substrait/
extensions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::common::{plan_err, DataFusionError, HashMap};
19use substrait::proto::extensions::simple_extension_declaration::{
20    ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
21};
22use substrait::proto::extensions::SimpleExtensionDeclaration;
23
/// Substrait uses [SimpleExtensions](https://substrait.io/extensions/#simple-extensions) to define
/// behavior of plans in addition to what's supported directly by the protobuf definitions.
/// That includes functions, but also provides support for custom types and variations for existing
/// types. This struct facilitates the use of these extensions in DataFusion.
///
/// Every map is keyed by an *anchor*: the `u32` reference under which a name is registered.
///
/// TODO: DF doesn't yet use extensions for type variations <https://github.com/apache/datafusion/issues/11544>
/// TODO: DF doesn't yet provide valid extensionUris <https://github.com/apache/datafusion/issues/11545>
#[derive(Default, Debug, PartialEq)]
pub struct Extensions {
    /// anchor -> function name
    pub functions: HashMap<u32, String>,
    /// anchor -> type name
    pub types: HashMap<u32, String>,
    /// anchor -> type variation name
    pub type_variations: HashMap<u32, String>,
}

impl Extensions {
    /// Registers a function and returns the anchor (reference) to it. If the function has already
    /// been registered, it returns the existing anchor.
    ///
    /// Function names are case-insensitive (converted to lowercase). A newly assigned anchor
    /// never collides with an anchor that is already present in `functions`, even when the
    /// existing anchors are not the dense range `0..len`.
    pub fn register_function(&mut self, function_name: String) -> u32 {
        let lowered = function_name.to_lowercase();

        // Some functions are named differently in Substrait default extensions than in DF
        // Rename those to match the Substrait extensions for interoperability
        let substrait_name = match lowered.as_str() {
            "substr" => "substring".to_string(),
            "log" => "logb".to_string(),
            "isnan" => "is_nan".to_string(),
            _ => lowered,
        };

        Self::register_name(&mut self.functions, substrait_name)
    }

    /// Registers a type and returns the anchor (reference) to it. If the type has already
    /// been registered, it returns the existing anchor.
    ///
    /// Type names are converted to lowercase before lookup and registration. A newly assigned
    /// anchor never collides with an anchor that is already present in `types`.
    pub fn register_type(&mut self, type_name: String) -> u32 {
        Self::register_name(&mut self.types, type_name.to_lowercase())
    }

    /// Returns the anchor under which `name` is stored in `registry`, inserting it under a
    /// previously unused anchor if it is not present yet.
    fn register_name(registry: &mut HashMap<u32, String>, name: String) -> u32 {
        if let Some((&anchor, _)) =
            registry.iter().find(|(_, existing)| **existing == name)
        {
            return anchor; // Name has been registered
        }

        // Name has NOT been registered. Start at `len`, which is the next free anchor whenever
        // the anchors are dense (`0..len`). The map may also have been populated with arbitrary
        // anchors (e.g. parsed from a plan), so probe upwards until an unused one is found
        // instead of overwriting an existing entry. At most `len` anchors are occupied, so
        // the probe ends after at most `len + 1` steps.
        // The cast cannot truncate in practice: keys are `u32`, so `len` only exceeds
        // `u32::MAX` if every single anchor is taken.
        let mut anchor = registry.len() as u32;
        while registry.contains_key(&anchor) {
            anchor = anchor.wrapping_add(1);
        }
        registry.insert(anchor, name);
        anchor
    }
}
80
81impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
82    type Error = DataFusionError;
83
84    fn try_from(
85        value: &Vec<SimpleExtensionDeclaration>,
86    ) -> datafusion::common::Result<Self> {
87        let mut functions = HashMap::new();
88        let mut types = HashMap::new();
89        let mut type_variations = HashMap::new();
90
91        for ext in value {
92            match &ext.mapping_type {
93                Some(MappingType::ExtensionFunction(ext_f)) => {
94                    functions.insert(ext_f.function_anchor, ext_f.name.to_owned());
95                }
96                Some(MappingType::ExtensionType(ext_t)) => {
97                    types.insert(ext_t.type_anchor, ext_t.name.to_owned());
98                }
99                Some(MappingType::ExtensionTypeVariation(ext_v)) => {
100                    type_variations
101                        .insert(ext_v.type_variation_anchor, ext_v.name.to_owned());
102                }
103                None => return plan_err!("Cannot parse empty extension"),
104            }
105        }
106
107        Ok(Extensions {
108            functions,
109            types,
110            type_variations,
111        })
112    }
113}
114
115impl From<Extensions> for Vec<SimpleExtensionDeclaration> {
116    // Silence deprecation warnings for `extension_uri_reference` during the uri -> urn migration
117    // See: https://github.com/substrait-io/substrait/issues/856
118    #[allow(deprecated)]
119    fn from(val: Extensions) -> Vec<SimpleExtensionDeclaration> {
120        let mut extensions = vec![];
121        for (f_anchor, f_name) in val.functions {
122            let function_extension = ExtensionFunction {
123                extension_uri_reference: u32::MAX,
124                extension_urn_reference: u32::MAX,
125                function_anchor: f_anchor,
126                name: f_name,
127            };
128            let simple_extension = SimpleExtensionDeclaration {
129                mapping_type: Some(MappingType::ExtensionFunction(function_extension)),
130            };
131            extensions.push(simple_extension);
132        }
133
134        for (t_anchor, t_name) in val.types {
135            let type_extension = ExtensionType {
136                extension_uri_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545
137                extension_urn_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545
138                type_anchor: t_anchor,
139                name: t_name,
140            };
141            let simple_extension = SimpleExtensionDeclaration {
142                mapping_type: Some(MappingType::ExtensionType(type_extension)),
143            };
144            extensions.push(simple_extension);
145        }
146
147        for (tv_anchor, tv_name) in val.type_variations {
148            let type_variation_extension = ExtensionTypeVariation {
149                extension_uri_reference: u32::MAX, // We don't register proper extension URIs yet
150                extension_urn_reference: u32::MAX, // We don't register proper extension URIs yet
151                type_variation_anchor: tv_anchor,
152                name: tv_name,
153            };
154            let simple_extension = SimpleExtensionDeclaration {
155                mapping_type: Some(MappingType::ExtensionTypeVariation(
156                    type_variation_extension,
157                )),
158            };
159            extensions.push(simple_extension);
160        }
161
162        extensions
163    }
164}