1use crate::aggregates::{
2 avg_typed_value, group_concat_typed_value, max_typed_value, min_typed_value,
3 sum_typed_value,
4};
5use crate::builtin::encoding::{
6 with_plain_term_encoding, with_sortable_term_encoding, with_typed_value_encoding,
7};
8use crate::builtin::native::{
9 effective_boolean_value, native_boolean_as_term, native_int64_as_term,
10};
11use crate::builtin::query::is_compatible;
12use crate::scalar::comparison::{
13 EqualSparqlOp, GreaterOrEqualSparqlOp, GreaterThanSparqlOp, LessOrEqualSparqlOp,
14 LessThanSparqlOp, SameTermSparqlOp,
15};
16use crate::scalar::conversion::CastBooleanSparqlOp;
17use crate::scalar::conversion::CastDateTimeSparqlOp;
18use crate::scalar::conversion::CastDecimalSparqlOp;
19use crate::scalar::conversion::CastDoubleSparqlOp;
20use crate::scalar::conversion::CastFloatSparqlOp;
21use crate::scalar::conversion::CastIntSparqlOp;
22use crate::scalar::conversion::CastIntegerSparqlOp;
23use crate::scalar::conversion::CastStringSparqlOp;
24use crate::scalar::dates_and_times::HoursSparqlOp;
25use crate::scalar::dates_and_times::MinutesSparqlOp;
26use crate::scalar::dates_and_times::MonthSparqlOp;
27use crate::scalar::dates_and_times::SecondsSparqlOp;
28use crate::scalar::dates_and_times::TimezoneSparqlOp;
29use crate::scalar::dates_and_times::YearSparqlOp;
30use crate::scalar::dates_and_times::{DaySparqlOp, TzSparqlOp};
31use crate::scalar::functional_form::{BoundSparqlOp, CoalesceSparqlOp, IfSparqlOp};
32use crate::scalar::numeric::RoundSparqlOp;
33use crate::scalar::numeric::{AbsSparqlOp, UnaryMinusSparqlOp, UnaryPlusSparqlOp};
34use crate::scalar::numeric::{
35 AddSparqlOp, DivSparqlOp, FloorSparqlOp, MulSparqlOp, SubSparqlOp,
36};
37use crate::scalar::numeric::{CeilSparqlOp, RandSparqlOp};
38use crate::scalar::strings::{
39 ConcatSparqlOp, ContainsSparqlOp, EncodeForUriSparqlOp, LCaseSparqlOp,
40 LangMatchesSparqlOp, Md5SparqlOp, RegexSparqlOp, ReplaceSparqlOp, Sha1SparqlOp,
41 Sha256SparqlOp, Sha384SparqlOp, Sha512SparqlOp, StrAfterSparqlOp, StrBeforeSparqlOp,
42 StrEndsSparqlOp, StrLenSparqlOp, StrStartsSparqlOp, StrUuidSparqlOp, SubStrSparqlOp,
43 UCaseSparqlOp,
44};
45use crate::scalar::terms::{
46 BNodeSparqlOp, DatatypeSparqlOp, IriSparqlOp, IsBlankSparqlOp, IsIriSparqlOp,
47 IsLiteralSparqlOp, IsNumericSparqlOp, LangSparqlOp, StrDtSparqlOp, StrLangSparqlOp,
48 StrSparqlOp, UuidSparqlOp,
49};
50use crate::scalar::{ScalarSparqlOp, ScalarSparqlOpAdapter};
51use datafusion::common::plan_datafusion_err;
52use datafusion::execution::FunctionRegistry;
53use datafusion::execution::registry::MemoryFunctionRegistry;
54use datafusion::logical_expr::{AggregateUDF, ScalarUDF, TypeSignature};
55use rdf_fusion_encoding::{EncodingName, RdfFusionEncodings};
56use rdf_fusion_extensions::functions::{FunctionName, RdfFusionFunctionRegistry};
57use rdf_fusion_model::DFResult;
58use std::collections::{BTreeSet, HashMap};
59use std::fmt::Debug;
60use std::sync::{Arc, RwLock};
61
/// Default [`RdfFusionFunctionRegistry`] implementation.
///
/// The built-in SPARQL scalar and aggregate UDFs are registered when the registry is created
/// via `new`. Registered functions live in a DataFusion [`MemoryFunctionRegistry`] behind an
/// [`RwLock`], which is why registering a function only requires `&self`.
pub struct DefaultRdfFusionFunctionRegistry {
    /// Encodings handed to the built-in scalar UDFs and used, on registration, to map the data
    /// types found in a UDF's type signature to [`EncodingName`]s.
    encodings: RdfFusionEncodings,
    /// Shared, lock-protected registry state (registered UDFs/UDAFs and per-UDF encodings).
    inner: Arc<RwLock<RegistryContent>>,
}
77
/// Mutable state of a [`DefaultRdfFusionFunctionRegistry`], guarded by its `RwLock`.
pub struct RegistryContent {
    /// Encodings supported by each registered scalar UDF, keyed by the UDF's name.
    ///
    /// Filled in by `register_udf` from the UDF's type signature; registering an aggregate
    /// UDF does not add an entry here.
    udf_encodings: HashMap<String, Vec<EncodingName>>,
    /// DataFusion registry that stores the registered scalar and aggregate UDFs.
    registry: MemoryFunctionRegistry,
}
87
88impl Debug for DefaultRdfFusionFunctionRegistry {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("DefaultRdfFusionFunctionRegistry")
91 .field("encodings", &self.encodings)
92 .finish()
93 }
94}
95
96impl DefaultRdfFusionFunctionRegistry {
97 pub fn new(encodings: RdfFusionEncodings) -> Self {
99 let mut registry = Self {
100 encodings,
101 inner: Arc::new(RwLock::new(RegistryContent {
102 udf_encodings: HashMap::default(),
103 registry: MemoryFunctionRegistry::default(),
104 })),
105 };
106 register_functions(&mut registry);
107 registry
108 }
109}
110
111impl RdfFusionFunctionRegistry for DefaultRdfFusionFunctionRegistry {
112 fn udf_supported_encodings(
113 &self,
114 function_name: &FunctionName,
115 ) -> DFResult<Vec<EncodingName>> {
116 self.inner
117 .read()
118 .unwrap()
119 .udf_encodings
120 .get(&function_name.to_string())
121 .cloned()
122 .ok_or_else(|| plan_datafusion_err!("Function {function_name} not found"))
123 }
124
125 fn udf(&self, function_name: &FunctionName) -> DFResult<Arc<ScalarUDF>> {
126 self.inner
127 .read()
128 .unwrap()
129 .registry
130 .udf(&function_name.to_string())
131 }
132
133 fn udaf(&self, function_name: &FunctionName) -> DFResult<Arc<AggregateUDF>> {
134 self.inner
135 .read()
136 .unwrap()
137 .registry
138 .udaf(&function_name.to_string())
139 }
140
141 fn register_udf(&self, udf: ScalarUDF) {
142 let supported_encodings =
143 supported_encodings(&self.encodings, &udf.signature().type_signature);
144
145 let mut lock = self.inner.write().unwrap();
146
147 lock.udf_encodings.insert(
148 udf.name().to_owned(),
149 supported_encodings.into_iter().collect(),
150 );
151 lock.registry
152 .register_udf(Arc::new(udf))
153 .expect("Cannot fail");
154 }
155
156 fn register_udaf(&self, udaf: AggregateUDF) {
157 self.inner
158 .write()
159 .unwrap()
160 .registry
161 .register_udaf(Arc::new(udaf))
162 .expect("Cannot fail");
163 }
164}
165
166fn supported_encodings(
168 encodings: &RdfFusionEncodings,
169 signature: &TypeSignature,
170) -> BTreeSet<EncodingName> {
171 match signature {
172 TypeSignature::Variadic(data_type) => data_type
173 .iter()
174 .flat_map(|dt| encodings.try_get_encoding_name(dt))
175 .collect(),
176 TypeSignature::Uniform(_, data_type) => data_type
177 .iter()
178 .flat_map(|dt| encodings.try_get_encoding_name(dt))
179 .collect(),
180 TypeSignature::OneOf(inner) => inner
181 .iter()
182 .flat_map(|ts| supported_encodings(encodings, ts).into_iter())
183 .collect(),
184 _ => BTreeSet::new(), }
186}
187
188fn create_scalar_udf<TSparqlOp>(encodings: RdfFusionEncodings) -> ScalarUDF
189where
190 TSparqlOp: ScalarSparqlOp + 'static + Default,
191{
192 let adapter = ScalarSparqlOpAdapter::new(encodings, TSparqlOp::default());
193 ScalarUDF::new_from_impl(adapter)
194}
195
196fn register_functions(registry: &mut DefaultRdfFusionFunctionRegistry) {
197 let scalar_fns: Vec<ScalarUDF> = vec![
198 create_scalar_udf::<StrSparqlOp>(registry.encodings.clone()),
199 create_scalar_udf::<LangSparqlOp>(registry.encodings.clone()),
200 create_scalar_udf::<LangMatchesSparqlOp>(registry.encodings.clone()),
201 create_scalar_udf::<DatatypeSparqlOp>(registry.encodings.clone()),
202 create_scalar_udf::<BNodeSparqlOp>(registry.encodings.clone()),
203 create_scalar_udf::<RandSparqlOp>(registry.encodings.clone()),
204 create_scalar_udf::<AbsSparqlOp>(registry.encodings.clone()),
205 create_scalar_udf::<CeilSparqlOp>(registry.encodings.clone()),
206 create_scalar_udf::<FloorSparqlOp>(registry.encodings.clone()),
207 create_scalar_udf::<RoundSparqlOp>(registry.encodings.clone()),
208 create_scalar_udf::<ConcatSparqlOp>(registry.encodings.clone()),
209 create_scalar_udf::<SubStrSparqlOp>(registry.encodings.clone()),
210 create_scalar_udf::<StrLenSparqlOp>(registry.encodings.clone()),
211 create_scalar_udf::<ReplaceSparqlOp>(registry.encodings.clone()),
212 create_scalar_udf::<UCaseSparqlOp>(registry.encodings.clone()),
213 create_scalar_udf::<LCaseSparqlOp>(registry.encodings.clone()),
214 create_scalar_udf::<EncodeForUriSparqlOp>(registry.encodings.clone()),
215 create_scalar_udf::<ContainsSparqlOp>(registry.encodings.clone()),
216 create_scalar_udf::<StrStartsSparqlOp>(registry.encodings.clone()),
217 create_scalar_udf::<StrEndsSparqlOp>(registry.encodings.clone()),
218 create_scalar_udf::<StrBeforeSparqlOp>(registry.encodings.clone()),
219 create_scalar_udf::<StrAfterSparqlOp>(registry.encodings.clone()),
220 create_scalar_udf::<YearSparqlOp>(registry.encodings.clone()),
221 create_scalar_udf::<MonthSparqlOp>(registry.encodings.clone()),
222 create_scalar_udf::<DaySparqlOp>(registry.encodings.clone()),
223 create_scalar_udf::<HoursSparqlOp>(registry.encodings.clone()),
224 create_scalar_udf::<MinutesSparqlOp>(registry.encodings.clone()),
225 create_scalar_udf::<SecondsSparqlOp>(registry.encodings.clone()),
226 create_scalar_udf::<TimezoneSparqlOp>(registry.encodings.clone()),
227 create_scalar_udf::<TzSparqlOp>(registry.encodings.clone()),
228 create_scalar_udf::<UuidSparqlOp>(registry.encodings.clone()),
229 create_scalar_udf::<StrUuidSparqlOp>(registry.encodings.clone()),
230 create_scalar_udf::<Md5SparqlOp>(registry.encodings.clone()),
231 create_scalar_udf::<Sha1SparqlOp>(registry.encodings.clone()),
232 create_scalar_udf::<Sha256SparqlOp>(registry.encodings.clone()),
233 create_scalar_udf::<Sha384SparqlOp>(registry.encodings.clone()),
234 create_scalar_udf::<Sha512SparqlOp>(registry.encodings.clone()),
235 create_scalar_udf::<StrLangSparqlOp>(registry.encodings.clone()),
236 create_scalar_udf::<StrDtSparqlOp>(registry.encodings.clone()),
237 create_scalar_udf::<IsIriSparqlOp>(registry.encodings.clone()),
238 create_scalar_udf::<IsBlankSparqlOp>(registry.encodings.clone()),
239 create_scalar_udf::<IsLiteralSparqlOp>(registry.encodings.clone()),
240 create_scalar_udf::<IsNumericSparqlOp>(registry.encodings.clone()),
241 create_scalar_udf::<RegexSparqlOp>(registry.encodings.clone()),
242 create_scalar_udf::<BoundSparqlOp>(registry.encodings.clone()),
243 create_scalar_udf::<CoalesceSparqlOp>(registry.encodings.clone()),
244 create_scalar_udf::<IfSparqlOp>(registry.encodings.clone()),
245 create_scalar_udf::<SameTermSparqlOp>(registry.encodings.clone()),
246 create_scalar_udf::<EqualSparqlOp>(registry.encodings.clone()),
247 create_scalar_udf::<GreaterThanSparqlOp>(registry.encodings.clone()),
248 create_scalar_udf::<GreaterOrEqualSparqlOp>(registry.encodings.clone()),
249 create_scalar_udf::<LessThanSparqlOp>(registry.encodings.clone()),
250 create_scalar_udf::<LessOrEqualSparqlOp>(registry.encodings.clone()),
251 create_scalar_udf::<AddSparqlOp>(registry.encodings.clone()),
252 create_scalar_udf::<DivSparqlOp>(registry.encodings.clone()),
253 create_scalar_udf::<MulSparqlOp>(registry.encodings.clone()),
254 create_scalar_udf::<SubSparqlOp>(registry.encodings.clone()),
255 create_scalar_udf::<UnaryMinusSparqlOp>(registry.encodings.clone()),
256 create_scalar_udf::<UnaryPlusSparqlOp>(registry.encodings.clone()),
257 create_scalar_udf::<CastStringSparqlOp>(registry.encodings.clone()),
258 create_scalar_udf::<CastIntegerSparqlOp>(registry.encodings.clone()),
259 create_scalar_udf::<CastIntSparqlOp>(registry.encodings.clone()),
260 create_scalar_udf::<CastFloatSparqlOp>(registry.encodings.clone()),
261 create_scalar_udf::<CastDoubleSparqlOp>(registry.encodings.clone()),
262 create_scalar_udf::<CastDecimalSparqlOp>(registry.encodings.clone()),
263 create_scalar_udf::<CastDateTimeSparqlOp>(registry.encodings.clone()),
264 create_scalar_udf::<CastBooleanSparqlOp>(registry.encodings.clone()),
265 create_scalar_udf::<IriSparqlOp>(registry.encodings.clone()),
266 with_sortable_term_encoding(registry.encodings.clone()),
267 with_plain_term_encoding(registry.encodings.clone()),
268 with_typed_value_encoding(registry.encodings.clone()),
269 effective_boolean_value(),
270 native_boolean_as_term(),
271 is_compatible(®istry.encodings.clone()),
272 native_int64_as_term(),
273 ];
274
275 for udf in scalar_fns {
276 registry.register_udf(udf);
277 }
278
279 let aggregate_fns: Vec<AggregateUDF> = vec![
281 sum_typed_value(),
282 min_typed_value(),
283 max_typed_value(),
284 avg_typed_value(),
285 group_concat_typed_value(),
286 ];
287
288 for udaf_information in aggregate_fns {
289 registry.register_udaf(udaf_information);
290 }
291}