datafusion_substrait/
extensions.rs1use datafusion::common::{DataFusionError, HashMap, plan_err};
19use substrait::proto::extensions::SimpleExtensionDeclaration;
20use substrait::proto::extensions::simple_extension_declaration::{
21 ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
22};
23
/// Anchor-to-name tables for the extension functions, types, and type
/// variations known to a plan.
///
/// All three maps are public, so they may contain arbitrary (non-contiguous)
/// anchors; the `register_*` methods take care never to overwrite an entry
/// that is already present.
#[derive(Default, Debug, PartialEq)]
pub struct Extensions {
    /// Function anchor -> function name.
    pub functions: HashMap<u32, String>,
    /// Type anchor -> type name.
    pub types: HashMap<u32, String>,
    /// Type-variation anchor -> type-variation name.
    pub type_variations: HashMap<u32, String>,
}

impl Extensions {
    /// Registers `function_name` and returns its anchor.
    ///
    /// The name is lowercased first, and a few names are rewritten to a
    /// different spelling (`substr` -> `substring`, `log` -> `logb`,
    /// `isnan` -> `is_nan`).
    /// NOTE(review): presumably these are the spellings used by the Substrait
    /// default extensions - confirm.
    ///
    /// If the resulting name is already registered its existing anchor is
    /// returned; otherwise a new, unused anchor is allocated.
    pub fn register_function(&mut self, function_name: &str) -> u32 {
        let lowered = function_name.to_lowercase();

        let canonical = match lowered.as_str() {
            "substr" => "substring".to_string(),
            "log" => "logb".to_string(),
            "isnan" => "is_nan".to_string(),
            _ => lowered,
        };

        Self::anchor_for(&mut self.functions, canonical)
    }

    /// Registers `type_name` (lowercased) and returns its anchor.
    ///
    /// If the lowercased name is already registered its existing anchor is
    /// returned; otherwise a new, unused anchor is allocated.
    pub fn register_type(&mut self, type_name: &str) -> u32 {
        Self::anchor_for(&mut self.types, type_name.to_lowercase())
    }

    /// Returns the anchor already mapped to `name` in `map`, or inserts `name`
    /// under a fresh anchor and returns that.
    fn anchor_for(map: &mut HashMap<u32, String>, name: String) -> u32 {
        let existing = map
            .iter()
            .find_map(|(anchor, registered)| (*registered == name).then_some(*anchor));
        if let Some(anchor) = existing {
            return anchor;
        }

        // Start at `len()` so densely numbered maps keep getting 0, 1, 2, ...
        // but skip over anchors that are already taken: the map may hold
        // sparse anchors (e.g. {0, 2}), in which case `len()` alone would
        // silently overwrite an existing entry.
        let mut anchor = map.len() as u32;
        while map.contains_key(&anchor) {
            anchor = anchor.wrapping_add(1);
        }

        map.insert(anchor, name);
        anchor
    }
}
80
81impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
82 type Error = DataFusionError;
83
84 fn try_from(
85 value: &Vec<SimpleExtensionDeclaration>,
86 ) -> datafusion::common::Result<Self> {
87 let mut functions = HashMap::new();
88 let mut types = HashMap::new();
89 let mut type_variations = HashMap::new();
90
91 for ext in value {
92 match &ext.mapping_type {
93 Some(MappingType::ExtensionFunction(ext_f)) => {
94 functions.insert(ext_f.function_anchor, ext_f.name.to_owned());
95 }
96 Some(MappingType::ExtensionType(ext_t)) => {
97 types.insert(ext_t.type_anchor, ext_t.name.to_owned());
98 }
99 Some(MappingType::ExtensionTypeVariation(ext_v)) => {
100 type_variations
101 .insert(ext_v.type_variation_anchor, ext_v.name.to_owned());
102 }
103 None => return plan_err!("Cannot parse empty extension"),
104 }
105 }
106
107 Ok(Extensions {
108 functions,
109 types,
110 type_variations,
111 })
112 }
113}
114
115impl From<Extensions> for Vec<SimpleExtensionDeclaration> {
116 fn from(val: Extensions) -> Vec<SimpleExtensionDeclaration> {
117 let mut extensions = vec![];
118 for (f_anchor, f_name) in val.functions {
119 let function_extension = ExtensionFunction {
120 extension_urn_reference: u32::MAX,
121 function_anchor: f_anchor,
122 name: f_name,
123 };
124 let simple_extension = SimpleExtensionDeclaration {
125 mapping_type: Some(MappingType::ExtensionFunction(function_extension)),
126 };
127 extensions.push(simple_extension);
128 }
129
130 for (t_anchor, t_name) in val.types {
131 let type_extension = ExtensionType {
132 extension_urn_reference: u32::MAX, type_anchor: t_anchor,
134 name: t_name,
135 };
136 let simple_extension = SimpleExtensionDeclaration {
137 mapping_type: Some(MappingType::ExtensionType(type_extension)),
138 };
139 extensions.push(simple_extension);
140 }
141
142 for (tv_anchor, tv_name) in val.type_variations {
143 let type_variation_extension = ExtensionTypeVariation {
144 extension_urn_reference: u32::MAX, type_variation_anchor: tv_anchor,
146 name: tv_name,
147 };
148 let simple_extension = SimpleExtensionDeclaration {
149 mapping_type: Some(MappingType::ExtensionTypeVariation(
150 type_variation_extension,
151 )),
152 };
153 extensions.push(simple_extension);
154 }
155
156 extensions
157 }
158}