rdf_fusion_functions/
registry.rs

1use crate::aggregates::{
2    avg_typed_value, group_concat_typed_value, max_typed_value, min_typed_value,
3    sum_typed_value,
4};
5use crate::builtin::encoding::{
6    with_plain_term_encoding, with_sortable_term_encoding, with_typed_value_encoding,
7};
8use crate::builtin::native::{
9    effective_boolean_value, native_boolean_as_term, native_int64_as_term,
10};
11use crate::builtin::query::is_compatible;
12use crate::scalar::comparison::{
13    EqualSparqlOp, GreaterOrEqualSparqlOp, GreaterThanSparqlOp, LessOrEqualSparqlOp,
14    LessThanSparqlOp, SameTermSparqlOp,
15};
16use crate::scalar::conversion::CastBooleanSparqlOp;
17use crate::scalar::conversion::CastDateTimeSparqlOp;
18use crate::scalar::conversion::CastDecimalSparqlOp;
19use crate::scalar::conversion::CastDoubleSparqlOp;
20use crate::scalar::conversion::CastFloatSparqlOp;
21use crate::scalar::conversion::CastIntSparqlOp;
22use crate::scalar::conversion::CastIntegerSparqlOp;
23use crate::scalar::conversion::CastStringSparqlOp;
24use crate::scalar::dates_and_times::HoursSparqlOp;
25use crate::scalar::dates_and_times::MinutesSparqlOp;
26use crate::scalar::dates_and_times::MonthSparqlOp;
27use crate::scalar::dates_and_times::SecondsSparqlOp;
28use crate::scalar::dates_and_times::TimezoneSparqlOp;
29use crate::scalar::dates_and_times::YearSparqlOp;
30use crate::scalar::dates_and_times::{DaySparqlOp, TzSparqlOp};
31use crate::scalar::functional_form::{BoundSparqlOp, CoalesceSparqlOp, IfSparqlOp};
32use crate::scalar::numeric::RoundSparqlOp;
33use crate::scalar::numeric::{AbsSparqlOp, UnaryMinusSparqlOp, UnaryPlusSparqlOp};
34use crate::scalar::numeric::{
35    AddSparqlOp, DivSparqlOp, FloorSparqlOp, MulSparqlOp, SubSparqlOp,
36};
37use crate::scalar::numeric::{CeilSparqlOp, RandSparqlOp};
38use crate::scalar::strings::{
39    ConcatSparqlOp, ContainsSparqlOp, EncodeForUriSparqlOp, LCaseSparqlOp,
40    LangMatchesSparqlOp, Md5SparqlOp, RegexSparqlOp, ReplaceSparqlOp, Sha1SparqlOp,
41    Sha256SparqlOp, Sha384SparqlOp, Sha512SparqlOp, StrAfterSparqlOp, StrBeforeSparqlOp,
42    StrEndsSparqlOp, StrLenSparqlOp, StrStartsSparqlOp, StrUuidSparqlOp, SubStrSparqlOp,
43    UCaseSparqlOp,
44};
45use crate::scalar::terms::{
46    BNodeSparqlOp, DatatypeSparqlOp, IriSparqlOp, IsBlankSparqlOp, IsIriSparqlOp,
47    IsLiteralSparqlOp, IsNumericSparqlOp, LangSparqlOp, StrDtSparqlOp, StrLangSparqlOp,
48    StrSparqlOp, UuidSparqlOp,
49};
50use crate::scalar::{ScalarSparqlOp, ScalarSparqlOpAdapter};
51use datafusion::common::plan_datafusion_err;
52use datafusion::execution::FunctionRegistry;
53use datafusion::execution::registry::MemoryFunctionRegistry;
54use datafusion::logical_expr::{AggregateUDF, ScalarUDF, TypeSignature};
55use rdf_fusion_encoding::{EncodingName, RdfFusionEncodings};
56use rdf_fusion_extensions::functions::{FunctionName, RdfFusionFunctionRegistry};
57use rdf_fusion_model::DFResult;
58use std::collections::{BTreeSet, HashMap};
59use std::fmt::Debug;
60use std::sync::{Arc, RwLock};
61
62/// The default implementation of the `RdfFusionFunctionRegistry` trait.
63///
64/// This registry provides implementations for all standard SPARQL functions
65/// defined in the SPARQL 1.1 specification, mapping them to their corresponding
66/// DataFusion UDFs and UDAFs.
67///
68/// # Additional Resources
69/// - [SPARQL 1.1 Query Language - Function Library](https://www.w3.org/TR/sparql11-query/#SparqlOps)
70pub struct DefaultRdfFusionFunctionRegistry {
71    /// The registered encodings.
72    encodings: RdfFusionEncodings,
73    /// A DataFusion [MemoryFunctionRegistry] that is used for actually storing the functions. Note
74    /// that this registry is *not* connected to the [SessionContext] of the RDF Fusion engine.
75    inner: Arc<RwLock<RegistryContent>>,
76}
77
78/// The actual data storage of the registry.
79pub struct RegistryContent {
80    /// The supported encodings for each function.
81    ///
82    /// Currently, this is not needed for aggregate functions as they only support typed values.
83    udf_encodings: HashMap<String, Vec<EncodingName>>,
84    /// The actual function registry.
85    registry: MemoryFunctionRegistry,
86}
87
88impl Debug for DefaultRdfFusionFunctionRegistry {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("DefaultRdfFusionFunctionRegistry")
91            .field("encodings", &self.encodings)
92            .finish()
93    }
94}
95
96impl DefaultRdfFusionFunctionRegistry {
97    /// Create a new [DefaultRdfFusionFunctionRegistry].
98    pub fn new(encodings: RdfFusionEncodings) -> Self {
99        let mut registry = Self {
100            encodings,
101            inner: Arc::new(RwLock::new(RegistryContent {
102                udf_encodings: HashMap::default(),
103                registry: MemoryFunctionRegistry::default(),
104            })),
105        };
106        register_functions(&mut registry);
107        registry
108    }
109}
110
111impl RdfFusionFunctionRegistry for DefaultRdfFusionFunctionRegistry {
112    fn udf_supported_encodings(
113        &self,
114        function_name: &FunctionName,
115    ) -> DFResult<Vec<EncodingName>> {
116        self.inner
117            .read()
118            .unwrap()
119            .udf_encodings
120            .get(&function_name.to_string())
121            .cloned()
122            .ok_or_else(|| plan_datafusion_err!("Function {function_name} not found"))
123    }
124
125    fn udf(&self, function_name: &FunctionName) -> DFResult<Arc<ScalarUDF>> {
126        self.inner
127            .read()
128            .unwrap()
129            .registry
130            .udf(&function_name.to_string())
131    }
132
133    fn udaf(&self, function_name: &FunctionName) -> DFResult<Arc<AggregateUDF>> {
134        self.inner
135            .read()
136            .unwrap()
137            .registry
138            .udaf(&function_name.to_string())
139    }
140
141    fn register_udf(&self, udf: ScalarUDF) {
142        let supported_encodings =
143            supported_encodings(&self.encodings, &udf.signature().type_signature);
144
145        let mut lock = self.inner.write().unwrap();
146
147        lock.udf_encodings.insert(
148            udf.name().to_owned(),
149            supported_encodings.into_iter().collect(),
150        );
151        lock.registry
152            .register_udf(Arc::new(udf))
153            .expect("Cannot fail");
154    }
155
156    fn register_udaf(&self, udaf: AggregateUDF) {
157        self.inner
158            .write()
159            .unwrap()
160            .registry
161            .register_udaf(Arc::new(udaf))
162            .expect("Cannot fail");
163    }
164}
165
166/// Computes the supported encodings from the given type signature.
167fn supported_encodings(
168    encodings: &RdfFusionEncodings,
169    signature: &TypeSignature,
170) -> BTreeSet<EncodingName> {
171    match signature {
172        TypeSignature::Variadic(data_type) => data_type
173            .iter()
174            .flat_map(|dt| encodings.try_get_encoding_name(dt))
175            .collect(),
176        TypeSignature::Uniform(_, data_type) => data_type
177            .iter()
178            .flat_map(|dt| encodings.try_get_encoding_name(dt))
179            .collect(),
180        TypeSignature::OneOf(inner) => inner
181            .iter()
182            .flat_map(|ts| supported_encodings(encodings, ts).into_iter())
183            .collect(),
184        _ => BTreeSet::new(), // Unsupported type signature.
185    }
186}
187
188fn create_scalar_udf<TSparqlOp>(encodings: RdfFusionEncodings) -> ScalarUDF
189where
190    TSparqlOp: ScalarSparqlOp + 'static + Default,
191{
192    let adapter = ScalarSparqlOpAdapter::new(encodings, TSparqlOp::default());
193    ScalarUDF::new_from_impl(adapter)
194}
195
196fn register_functions(registry: &mut DefaultRdfFusionFunctionRegistry) {
197    let scalar_fns: Vec<ScalarUDF> = vec![
198        create_scalar_udf::<StrSparqlOp>(registry.encodings.clone()),
199        create_scalar_udf::<LangSparqlOp>(registry.encodings.clone()),
200        create_scalar_udf::<LangMatchesSparqlOp>(registry.encodings.clone()),
201        create_scalar_udf::<DatatypeSparqlOp>(registry.encodings.clone()),
202        create_scalar_udf::<BNodeSparqlOp>(registry.encodings.clone()),
203        create_scalar_udf::<RandSparqlOp>(registry.encodings.clone()),
204        create_scalar_udf::<AbsSparqlOp>(registry.encodings.clone()),
205        create_scalar_udf::<CeilSparqlOp>(registry.encodings.clone()),
206        create_scalar_udf::<FloorSparqlOp>(registry.encodings.clone()),
207        create_scalar_udf::<RoundSparqlOp>(registry.encodings.clone()),
208        create_scalar_udf::<ConcatSparqlOp>(registry.encodings.clone()),
209        create_scalar_udf::<SubStrSparqlOp>(registry.encodings.clone()),
210        create_scalar_udf::<StrLenSparqlOp>(registry.encodings.clone()),
211        create_scalar_udf::<ReplaceSparqlOp>(registry.encodings.clone()),
212        create_scalar_udf::<UCaseSparqlOp>(registry.encodings.clone()),
213        create_scalar_udf::<LCaseSparqlOp>(registry.encodings.clone()),
214        create_scalar_udf::<EncodeForUriSparqlOp>(registry.encodings.clone()),
215        create_scalar_udf::<ContainsSparqlOp>(registry.encodings.clone()),
216        create_scalar_udf::<StrStartsSparqlOp>(registry.encodings.clone()),
217        create_scalar_udf::<StrEndsSparqlOp>(registry.encodings.clone()),
218        create_scalar_udf::<StrBeforeSparqlOp>(registry.encodings.clone()),
219        create_scalar_udf::<StrAfterSparqlOp>(registry.encodings.clone()),
220        create_scalar_udf::<YearSparqlOp>(registry.encodings.clone()),
221        create_scalar_udf::<MonthSparqlOp>(registry.encodings.clone()),
222        create_scalar_udf::<DaySparqlOp>(registry.encodings.clone()),
223        create_scalar_udf::<HoursSparqlOp>(registry.encodings.clone()),
224        create_scalar_udf::<MinutesSparqlOp>(registry.encodings.clone()),
225        create_scalar_udf::<SecondsSparqlOp>(registry.encodings.clone()),
226        create_scalar_udf::<TimezoneSparqlOp>(registry.encodings.clone()),
227        create_scalar_udf::<TzSparqlOp>(registry.encodings.clone()),
228        create_scalar_udf::<UuidSparqlOp>(registry.encodings.clone()),
229        create_scalar_udf::<StrUuidSparqlOp>(registry.encodings.clone()),
230        create_scalar_udf::<Md5SparqlOp>(registry.encodings.clone()),
231        create_scalar_udf::<Sha1SparqlOp>(registry.encodings.clone()),
232        create_scalar_udf::<Sha256SparqlOp>(registry.encodings.clone()),
233        create_scalar_udf::<Sha384SparqlOp>(registry.encodings.clone()),
234        create_scalar_udf::<Sha512SparqlOp>(registry.encodings.clone()),
235        create_scalar_udf::<StrLangSparqlOp>(registry.encodings.clone()),
236        create_scalar_udf::<StrDtSparqlOp>(registry.encodings.clone()),
237        create_scalar_udf::<IsIriSparqlOp>(registry.encodings.clone()),
238        create_scalar_udf::<IsBlankSparqlOp>(registry.encodings.clone()),
239        create_scalar_udf::<IsLiteralSparqlOp>(registry.encodings.clone()),
240        create_scalar_udf::<IsNumericSparqlOp>(registry.encodings.clone()),
241        create_scalar_udf::<RegexSparqlOp>(registry.encodings.clone()),
242        create_scalar_udf::<BoundSparqlOp>(registry.encodings.clone()),
243        create_scalar_udf::<CoalesceSparqlOp>(registry.encodings.clone()),
244        create_scalar_udf::<IfSparqlOp>(registry.encodings.clone()),
245        create_scalar_udf::<SameTermSparqlOp>(registry.encodings.clone()),
246        create_scalar_udf::<EqualSparqlOp>(registry.encodings.clone()),
247        create_scalar_udf::<GreaterThanSparqlOp>(registry.encodings.clone()),
248        create_scalar_udf::<GreaterOrEqualSparqlOp>(registry.encodings.clone()),
249        create_scalar_udf::<LessThanSparqlOp>(registry.encodings.clone()),
250        create_scalar_udf::<LessOrEqualSparqlOp>(registry.encodings.clone()),
251        create_scalar_udf::<AddSparqlOp>(registry.encodings.clone()),
252        create_scalar_udf::<DivSparqlOp>(registry.encodings.clone()),
253        create_scalar_udf::<MulSparqlOp>(registry.encodings.clone()),
254        create_scalar_udf::<SubSparqlOp>(registry.encodings.clone()),
255        create_scalar_udf::<UnaryMinusSparqlOp>(registry.encodings.clone()),
256        create_scalar_udf::<UnaryPlusSparqlOp>(registry.encodings.clone()),
257        create_scalar_udf::<CastStringSparqlOp>(registry.encodings.clone()),
258        create_scalar_udf::<CastIntegerSparqlOp>(registry.encodings.clone()),
259        create_scalar_udf::<CastIntSparqlOp>(registry.encodings.clone()),
260        create_scalar_udf::<CastFloatSparqlOp>(registry.encodings.clone()),
261        create_scalar_udf::<CastDoubleSparqlOp>(registry.encodings.clone()),
262        create_scalar_udf::<CastDecimalSparqlOp>(registry.encodings.clone()),
263        create_scalar_udf::<CastDateTimeSparqlOp>(registry.encodings.clone()),
264        create_scalar_udf::<CastBooleanSparqlOp>(registry.encodings.clone()),
265        create_scalar_udf::<IriSparqlOp>(registry.encodings.clone()),
266        with_sortable_term_encoding(registry.encodings.clone()),
267        with_plain_term_encoding(registry.encodings.clone()),
268        with_typed_value_encoding(registry.encodings.clone()),
269        effective_boolean_value(),
270        native_boolean_as_term(),
271        is_compatible(&registry.encodings.clone()),
272        native_int64_as_term(),
273    ];
274
275    for udf in scalar_fns {
276        registry.register_udf(udf);
277    }
278
279    // Aggregate functions
280    let aggregate_fns: Vec<AggregateUDF> = vec![
281        sum_typed_value(),
282        min_typed_value(),
283        max_typed_value(),
284        avg_typed_value(),
285        group_concat_typed_value(),
286    ];
287
288    for udaf_information in aggregate_fns {
289        registry.register_udaf(udaf_information);
290    }
291}