datafusion_substrait/
extensions.rs1use datafusion::common::{plan_err, DataFusionError, HashMap};
19use substrait::proto::extensions::simple_extension_declaration::{
20 ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
21};
22use substrait::proto::extensions::SimpleExtensionDeclaration;
23
/// Registry of the Substrait simple-extension declarations used by a plan.
///
/// Each map is keyed by the declaration's anchor (`function_anchor`,
/// `type_anchor` or `type_variation_anchor` in the Substrait protos) and stores
/// the declared name.
#[derive(Default, Debug, PartialEq)]
pub struct Extensions {
    /// Function anchor -> function name.
    pub functions: HashMap<u32, String>,
    /// Type anchor -> type name.
    pub types: HashMap<u32, String>,
    /// Type-variation anchor -> type-variation name.
    pub type_variations: HashMap<u32, String>,
}

impl Extensions {
    /// Registers a function by name and returns its anchor.
    ///
    /// The name is lowercased first, and a few names are rewritten
    /// (`substr` -> `substring`, `log` -> `logb`, `isnan` -> `is_nan`),
    /// presumably to match the spellings used on the Substrait side.
    ///
    /// If the resulting name is already registered, its existing anchor is
    /// returned. Otherwise it is stored under a fresh, unused anchor.
    pub fn register_function(&mut self, function_name: String) -> u32 {
        let function_name = function_name.to_lowercase();

        // Some functions are named differently here than in the extension
        // declarations; rename them so the registered name matches.
        let function_name = match function_name.as_str() {
            "substr" => "substring".to_string(),
            "log" => "logb".to_string(),
            "isnan" => "is_nan".to_string(),
            _ => function_name,
        };

        Self::lookup_or_insert(&mut self.functions, function_name)
    }

    /// Registers a type by (lowercased) name and returns its anchor.
    ///
    /// If the name is already registered, its existing anchor is returned.
    /// Otherwise it is stored under a fresh, unused anchor.
    pub fn register_type(&mut self, type_name: String) -> u32 {
        let type_name = type_name.to_lowercase();
        Self::lookup_or_insert(&mut self.types, type_name)
    }

    /// Returns the anchor already mapped to `name`, or inserts `name` under a
    /// fresh anchor and returns that new anchor.
    fn lookup_or_insert(anchors: &mut HashMap<u32, String>, name: String) -> u32 {
        if let Some((&anchor, _)) = anchors.iter().find(|(_, existing)| **existing == name) {
            return anchor;
        }
        let anchor = Self::next_free_anchor(anchors);
        anchors.insert(anchor, name);
        anchor
    }

    /// Picks the first anchor `>= anchors.len()` that is not already in use.
    ///
    /// When the anchors are exactly `0..len` (every entry added through the
    /// `register_*` methods), this is simply `len`, so anchors stay dense from
    /// 0. Probing past occupied keys prevents overwriting an existing entry
    /// when the map was populated elsewhere (e.g. via `TryFrom`) with
    /// non-contiguous anchors.
    fn next_free_anchor(anchors: &HashMap<u32, String>) -> u32 {
        let start = anchors.len() as u32;
        // A map with n keys can occupy at most n of the n + 1 candidates
        // start..=start + n, so this search always terminates.
        (start..)
            .find(|candidate| !anchors.contains_key(candidate))
            .expect("a map with n entries always leaves a free anchor among n + 1 candidates")
    }
}
80
81impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
82 type Error = DataFusionError;
83
84 fn try_from(
85 value: &Vec<SimpleExtensionDeclaration>,
86 ) -> datafusion::common::Result<Self> {
87 let mut functions = HashMap::new();
88 let mut types = HashMap::new();
89 let mut type_variations = HashMap::new();
90
91 for ext in value {
92 match &ext.mapping_type {
93 Some(MappingType::ExtensionFunction(ext_f)) => {
94 functions.insert(ext_f.function_anchor, ext_f.name.to_owned());
95 }
96 Some(MappingType::ExtensionType(ext_t)) => {
97 types.insert(ext_t.type_anchor, ext_t.name.to_owned());
98 }
99 Some(MappingType::ExtensionTypeVariation(ext_v)) => {
100 type_variations
101 .insert(ext_v.type_variation_anchor, ext_v.name.to_owned());
102 }
103 None => return plan_err!("Cannot parse empty extension"),
104 }
105 }
106
107 Ok(Extensions {
108 functions,
109 types,
110 type_variations,
111 })
112 }
113}
114
115impl From<Extensions> for Vec<SimpleExtensionDeclaration> {
116 #[allow(deprecated)]
119 fn from(val: Extensions) -> Vec<SimpleExtensionDeclaration> {
120 let mut extensions = vec![];
121 for (f_anchor, f_name) in val.functions {
122 let function_extension = ExtensionFunction {
123 extension_uri_reference: u32::MAX,
124 extension_urn_reference: u32::MAX,
125 function_anchor: f_anchor,
126 name: f_name,
127 };
128 let simple_extension = SimpleExtensionDeclaration {
129 mapping_type: Some(MappingType::ExtensionFunction(function_extension)),
130 };
131 extensions.push(simple_extension);
132 }
133
134 for (t_anchor, t_name) in val.types {
135 let type_extension = ExtensionType {
136 extension_uri_reference: u32::MAX, extension_urn_reference: u32::MAX, type_anchor: t_anchor,
139 name: t_name,
140 };
141 let simple_extension = SimpleExtensionDeclaration {
142 mapping_type: Some(MappingType::ExtensionType(type_extension)),
143 };
144 extensions.push(simple_extension);
145 }
146
147 for (tv_anchor, tv_name) in val.type_variations {
148 let type_variation_extension = ExtensionTypeVariation {
149 extension_uri_reference: u32::MAX, extension_urn_reference: u32::MAX, type_variation_anchor: tv_anchor,
152 name: tv_name,
153 };
154 let simple_extension = SimpleExtensionDeclaration {
155 mapping_type: Some(MappingType::ExtensionTypeVariation(
156 type_variation_extension,
157 )),
158 };
159 extensions.push(simple_extension);
160 }
161
162 extensions
163 }
164}