datafusion_substrait/
extensions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::common::{plan_err, DataFusionError, HashMap};
19use substrait::proto::extensions::simple_extension_declaration::{
20    ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
21};
22use substrait::proto::extensions::SimpleExtensionDeclaration;
23
/// Registry of the Substrait [simple extensions](https://substrait.io/extensions/#simple-extensions)
/// referenced by a plan.
///
/// Substrait uses simple extensions to define behavior of plans beyond what the protobuf
/// definitions support directly. That includes functions, but also custom types and variations
/// of existing types. Each map below associates an anchor (the integer a plan uses to reference
/// the extension) with the extension's name. This struct facilitates the use of these
/// extensions in DataFusion.
///
/// TODO: DF doesn't yet use extensions for type variations <https://github.com/apache/datafusion/issues/11544>
/// TODO: DF doesn't yet provide valid extensionUris <https://github.com/apache/datafusion/issues/11545>
#[derive(Default, Debug, PartialEq)]
pub struct Extensions {
    /// Function anchor -> function name.
    pub functions: HashMap<u32, String>,
    /// Type anchor -> type name.
    pub types: HashMap<u32, String>,
    /// Type variation anchor -> type variation name.
    pub type_variations: HashMap<u32, String>,
}
36
37impl Extensions {
38    /// Registers a function and returns the anchor (reference) to it. If the function has already
39    /// been registered, it returns the existing anchor.
40    /// Function names are case-insensitive (converted to lowercase).
41    pub fn register_function(&mut self, function_name: String) -> u32 {
42        let function_name = function_name.to_lowercase();
43
44        // Some functions are named differently in Substrait default extensions than in DF
45        // Rename those to match the Substrait extensions for interoperability
46        let function_name = match function_name.as_str() {
47            "substr" => "substring".to_string(),
48            _ => function_name,
49        };
50
51        match self.functions.iter().find(|(_, f)| *f == &function_name) {
52            Some((function_anchor, _)) => *function_anchor, // Function has been registered
53            None => {
54                // Function has NOT been registered
55                let function_anchor = self.functions.len() as u32;
56                self.functions
57                    .insert(function_anchor, function_name.clone());
58                function_anchor
59            }
60        }
61    }
62
63    /// Registers a type and returns the anchor (reference) to it. If the type has already
64    /// been registered, it returns the existing anchor.
65    pub fn register_type(&mut self, type_name: String) -> u32 {
66        let type_name = type_name.to_lowercase();
67        match self.types.iter().find(|(_, t)| *t == &type_name) {
68            Some((type_anchor, _)) => *type_anchor, // Type has been registered
69            None => {
70                // Type has NOT been registered
71                let type_anchor = self.types.len() as u32;
72                self.types.insert(type_anchor, type_name.clone());
73                type_anchor
74            }
75        }
76    }
77}
78
79impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
80    type Error = DataFusionError;
81
82    fn try_from(
83        value: &Vec<SimpleExtensionDeclaration>,
84    ) -> datafusion::common::Result<Self> {
85        let mut functions = HashMap::new();
86        let mut types = HashMap::new();
87        let mut type_variations = HashMap::new();
88
89        for ext in value {
90            match &ext.mapping_type {
91                Some(MappingType::ExtensionFunction(ext_f)) => {
92                    functions.insert(ext_f.function_anchor, ext_f.name.to_owned());
93                }
94                Some(MappingType::ExtensionType(ext_t)) => {
95                    types.insert(ext_t.type_anchor, ext_t.name.to_owned());
96                }
97                Some(MappingType::ExtensionTypeVariation(ext_v)) => {
98                    type_variations
99                        .insert(ext_v.type_variation_anchor, ext_v.name.to_owned());
100                }
101                None => return plan_err!("Cannot parse empty extension"),
102            }
103        }
104
105        Ok(Extensions {
106            functions,
107            types,
108            type_variations,
109        })
110    }
111}
112
113impl From<Extensions> for Vec<SimpleExtensionDeclaration> {
114    fn from(val: Extensions) -> Vec<SimpleExtensionDeclaration> {
115        let mut extensions = vec![];
116        for (f_anchor, f_name) in val.functions {
117            let function_extension = ExtensionFunction {
118                extension_uri_reference: u32::MAX,
119                function_anchor: f_anchor,
120                name: f_name,
121            };
122            let simple_extension = SimpleExtensionDeclaration {
123                mapping_type: Some(MappingType::ExtensionFunction(function_extension)),
124            };
125            extensions.push(simple_extension);
126        }
127
128        for (t_anchor, t_name) in val.types {
129            let type_extension = ExtensionType {
130                extension_uri_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545
131                type_anchor: t_anchor,
132                name: t_name,
133            };
134            let simple_extension = SimpleExtensionDeclaration {
135                mapping_type: Some(MappingType::ExtensionType(type_extension)),
136            };
137            extensions.push(simple_extension);
138        }
139
140        for (tv_anchor, tv_name) in val.type_variations {
141            let type_variation_extension = ExtensionTypeVariation {
142                extension_uri_reference: u32::MAX, // We don't register proper extension URIs yet
143                type_variation_anchor: tv_anchor,
144                name: tv_name,
145            };
146            let simple_extension = SimpleExtensionDeclaration {
147                mapping_type: Some(MappingType::ExtensionTypeVariation(
148                    type_variation_extension,
149                )),
150            };
151            extensions.push(simple_extension);
152        }
153
154        extensions
155    }
156}