datafusion_substrait/
extensions.rs1use datafusion::common::{plan_err, DataFusionError, HashMap};
19use substrait::proto::extensions::simple_extension_declaration::{
20 ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
21};
22use substrait::proto::extensions::SimpleExtensionDeclaration;
23
/// Anchor → name registry for the extensions declared by a Substrait plan.
///
/// Each extension declaration (function, type, or type variation) carries a
/// numeric anchor and a name; each map below goes from that anchor to the name.
#[derive(Default, Debug, PartialEq)]
pub struct Extensions {
    /// Function extensions, keyed by function anchor.
    pub functions: HashMap<u32, String>,
    /// Type extensions, keyed by type anchor.
    pub types: HashMap<u32, String>,
    /// Type-variation extensions, keyed by type-variation anchor.
    pub type_variations: HashMap<u32, String>,
}

impl Extensions {
    /// Registers a function by name and returns its anchor.
    ///
    /// The name is lowercased, and `substr` is renamed to `substring`, before
    /// lookup. If a function with the resulting name is already registered, its
    /// existing anchor is returned and the registry is left unchanged; otherwise
    /// the name is stored under a fresh anchor that no existing entry uses.
    pub fn register_function(&mut self, function_name: String) -> u32 {
        let function_name = function_name.to_lowercase();

        // DataFusion's `substr` is registered under the name `substring`.
        // NOTE(review): presumably to match the name used by Substrait's
        // standard function extensions — confirm.
        let function_name = match function_name.as_str() {
            "substr" => "substring".to_string(),
            _ => function_name,
        };

        Self::register_name(&mut self.functions, function_name)
    }

    /// Registers a type by name and returns its anchor.
    ///
    /// The name is lowercased before lookup. An already-registered type returns
    /// its existing anchor; a new one is stored under a fresh, unused anchor.
    pub fn register_type(&mut self, type_name: String) -> u32 {
        Self::register_name(&mut self.types, type_name.to_lowercase())
    }

    /// Returns the anchor under which `name` is stored in `registry`, inserting
    /// it under a fresh anchor first if it is not present yet.
    fn register_name(registry: &mut HashMap<u32, String>, name: String) -> u32 {
        let existing = registry
            .iter()
            .find(|(_, registered)| **registered == name)
            .map(|(&anchor, _)| anchor);
        if let Some(anchor) = existing {
            return anchor;
        }

        let anchor = Self::next_free_anchor(registry);
        registry.insert(anchor, name);
        anchor
    }

    /// Picks an anchor that is not yet a key of `registry`.
    ///
    /// Starts at `registry.len()`, which is already free whenever the anchors are
    /// exactly `0..len` (the layout produced by registering into an empty
    /// registry), and otherwise walks upward past taken anchors. Using `len()`
    /// alone would collide with — and `insert` would silently overwrite — an
    /// existing entry when the anchors are sparse, e.g. after `TryFrom` loaded
    /// them from a plan.
    fn next_free_anchor(registry: &HashMap<u32, String>) -> u32 {
        let start = u32::try_from(registry.len()).unwrap_or(u32::MAX);
        (start..=u32::MAX)
            .find(|anchor| !registry.contains_key(anchor))
            .expect("every u32 anchor is already in use")
    }
}
78
79impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
80 type Error = DataFusionError;
81
82 fn try_from(
83 value: &Vec<SimpleExtensionDeclaration>,
84 ) -> datafusion::common::Result<Self> {
85 let mut functions = HashMap::new();
86 let mut types = HashMap::new();
87 let mut type_variations = HashMap::new();
88
89 for ext in value {
90 match &ext.mapping_type {
91 Some(MappingType::ExtensionFunction(ext_f)) => {
92 functions.insert(ext_f.function_anchor, ext_f.name.to_owned());
93 }
94 Some(MappingType::ExtensionType(ext_t)) => {
95 types.insert(ext_t.type_anchor, ext_t.name.to_owned());
96 }
97 Some(MappingType::ExtensionTypeVariation(ext_v)) => {
98 type_variations
99 .insert(ext_v.type_variation_anchor, ext_v.name.to_owned());
100 }
101 None => return plan_err!("Cannot parse empty extension"),
102 }
103 }
104
105 Ok(Extensions {
106 functions,
107 types,
108 type_variations,
109 })
110 }
111}
112
113impl From<Extensions> for Vec<SimpleExtensionDeclaration> {
114 fn from(val: Extensions) -> Vec<SimpleExtensionDeclaration> {
115 let mut extensions = vec![];
116 for (f_anchor, f_name) in val.functions {
117 let function_extension = ExtensionFunction {
118 extension_uri_reference: u32::MAX,
119 function_anchor: f_anchor,
120 name: f_name,
121 };
122 let simple_extension = SimpleExtensionDeclaration {
123 mapping_type: Some(MappingType::ExtensionFunction(function_extension)),
124 };
125 extensions.push(simple_extension);
126 }
127
128 for (t_anchor, t_name) in val.types {
129 let type_extension = ExtensionType {
130 extension_uri_reference: u32::MAX, type_anchor: t_anchor,
132 name: t_name,
133 };
134 let simple_extension = SimpleExtensionDeclaration {
135 mapping_type: Some(MappingType::ExtensionType(type_extension)),
136 };
137 extensions.push(simple_extension);
138 }
139
140 for (tv_anchor, tv_name) in val.type_variations {
141 let type_variation_extension = ExtensionTypeVariation {
142 extension_uri_reference: u32::MAX, type_variation_anchor: tv_anchor,
144 name: tv_name,
145 };
146 let simple_extension = SimpleExtensionDeclaration {
147 mapping_type: Some(MappingType::ExtensionTypeVariation(
148 type_variation_extension,
149 )),
150 };
151 extensions.push(simple_extension);
152 }
153
154 extensions
155 }
156}