// rdf_fusion_functions/scalar/sparql_op.rs

1use crate::scalar::ScalarSparqlOpArgs;
2use crate::scalar::sparql_op_impl::ScalarSparqlOpImpl;
3use datafusion::arrow::datatypes::DataType;
4use datafusion::common::{exec_datafusion_err, exec_err, plan_err};
5use datafusion::logical_expr::{
6    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
7    Volatility,
8};
9use rdf_fusion_encoding::object_id::ObjectIdEncoding;
10use rdf_fusion_encoding::plain_term::{PLAIN_TERM_ENCODING, PlainTermEncoding};
11use rdf_fusion_encoding::typed_value::{TYPED_VALUE_ENCODING, TypedValueEncoding};
12use rdf_fusion_encoding::{EncodingName, RdfFusionEncodings, TermEncoding};
13use rdf_fusion_extensions::functions::FunctionName;
14use rdf_fusion_model::DFResult;
15use std::any::Any;
16use std::collections::HashSet;
17use std::fmt::Debug;
18use std::hash::{Hash, Hasher};
19
/// Defines the arity of a SPARQL operation. In other words, the number of arguments that the
/// [ScalarSparqlOp] has.
///
/// The arity is translated into a DataFusion [TypeSignature] via
/// [SparqlOpArity::type_signature].
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum SparqlOpArity {
    /// No arguments.
    Nullary,
    /// A fixed number of arguments.
    ///
    /// `SparqlOpArity::Fixed(0)` is equivalent to [SparqlOpArity::Nullary].
    Fixed(usize),
    /// One of the given [SparqlOpArity]. Allows an operation to accept multiple
    /// alternative arities (e.g., an optional last argument).
    OneOf(Vec<SparqlOpArity>),
    /// Any number of arguments (including zero).
    Variadic,
}
35
36impl SparqlOpArity {
37    /// Returns a [TypeSignature] for the given [SparqlOpArity].
38    pub fn type_signature<TEncoding: TermEncoding>(
39        &self,
40        encoding: &TEncoding,
41    ) -> TypeSignature {
42        match self {
43            SparqlOpArity::Nullary => TypeSignature::Nullary,
44            SparqlOpArity::Fixed(n) => {
45                TypeSignature::Uniform(*n, vec![encoding.data_type()])
46            }
47            SparqlOpArity::OneOf(ns) => {
48                let inner = ns
49                    .iter()
50                    .map(|n| n.type_signature(encoding))
51                    .collect::<Vec<_>>();
52                TypeSignature::OneOf(inner)
53            }
54            SparqlOpArity::Variadic => TypeSignature::OneOf(vec![
55                TypeSignature::Nullary,
56                TypeSignature::Variadic(vec![encoding.data_type()]),
57            ]),
58        }
59    }
60}
61
/// Defines further details about a [ScalarSparqlOp].
///
/// Used by [ScalarSparqlOpAdapter] to construct the DataFusion
/// [Signature] of the operation.
pub struct ScalarSparqlOpSignature {
    /// Whether the [ScalarSparqlOp] is volatile. See [Volatility] for more information.
    pub volatility: Volatility,
    /// The [SparqlOpArity] of the [ScalarSparqlOp].
    pub arity: SparqlOpArity,
}
69
70impl ScalarSparqlOpSignature {
71    /// Returns a [ScalarSparqlOpSignature] with the given arity and [Volatility::Immutable].
72    pub fn default_with_arity(arity: SparqlOpArity) -> Self {
73        Self {
74            volatility: Volatility::Immutable,
75            arity,
76        }
77    }
78}
79
/// A [ScalarSparqlOp] is a function that operates on RDF terms. The function may return a different
/// type of value. For example, a function that takes two RDF terms and outputs an integer can be
/// implemented using this trait.
///
/// The goal is to make it easier for users to implement custom SPARQL functions. The different
/// encodings of RDF Fusion are handled by providing a [ScalarSparqlOpImpl] for any given encoding.
///
/// To install a [ScalarSparqlOp] in DataFusion, use the [ScalarSparqlOpAdapter]. The adapter will
/// mediate between DataFusion's API and the given [ScalarSparqlOpImpl].
///
/// All encoding-specific methods default to [None], so an implementation only needs to
/// override the methods for the encodings it actually supports. An operation that
/// overrides none of them supports no encoding and will fail at planning/execution time.
pub trait ScalarSparqlOp: Debug + Hash + Eq + Send + Sync {
    /// Returns the name of the operation.
    fn name(&self) -> &FunctionName;

    /// Returns the signature of this operation.
    fn signature(&self) -> ScalarSparqlOpSignature;

    /// Returns the [ScalarSparqlOpImpl] for the [TypedValueEncoding].
    ///
    /// If [None] is returned, the operation does not support the [TypedValueEncoding].
    fn typed_value_encoding_op(
        &self,
    ) -> Option<Box<dyn ScalarSparqlOpImpl<TypedValueEncoding>>> {
        None
    }

    /// Returns the [ScalarSparqlOpImpl] for the [PlainTermEncoding].
    ///
    /// If [None] is returned, the operation does not support the [PlainTermEncoding].
    fn plain_term_encoding_op(
        &self,
    ) -> Option<Box<dyn ScalarSparqlOpImpl<PlainTermEncoding>>> {
        None
    }

    /// Returns the [ScalarSparqlOpImpl] for the [ObjectIdEncoding].
    ///
    /// Unlike the other encodings, the [ObjectIdEncoding] instance is passed in, as it is
    /// only available when configured in the engine.
    ///
    /// If [None] is returned, the operation does not support the [ObjectIdEncoding].
    fn object_id_encoding_op(
        &self,
        _object_id_encoding: &ObjectIdEncoding,
    ) -> Option<Box<dyn ScalarSparqlOpImpl<ObjectIdEncoding>>> {
        None
    }
}
124
/// Mediates between DataFusion's API and a [ScalarSparqlOp].
///
/// This includes the following tasks:
/// - Set up the argument types of the UDFs depending on the supported encodings
/// - Set up the argument types of the UDFs depending on the configured encodings in the engine
/// - Detecting the used input encoding and calling the correct [ScalarSparqlOpImpl].
#[derive(Debug, Eq)]
pub struct ScalarSparqlOpAdapter<TScalarSparqlOp: ScalarSparqlOp> {
    /// The stringified name of the [ScalarSparqlOp].
    name: String,
    /// The DataFusion [Signature] of the [ScalarSparqlOp].
    ///
    /// Computed once in [ScalarSparqlOpAdapter::new] from the encodings that the
    /// operation supports.
    signature: Signature,
    /// The instance of the [ScalarSparqlOp].
    op: TScalarSparqlOp,
    /// The configured [RdfFusionEncodings] in the engine.
    encodings: RdfFusionEncodings,
}
142
143impl<TScalarSparqlOp: ScalarSparqlOp> ScalarSparqlOpAdapter<TScalarSparqlOp> {
144    /// Creates a new adapter for the given `op`.
145    pub fn new(encodings: RdfFusionEncodings, op: TScalarSparqlOp) -> Self {
146        let name = op.name().to_string();
147        let details = op.signature();
148
149        let mut type_signatures = Vec::new();
150        if op.plain_term_encoding_op().is_some() {
151            let type_signature = details.arity.type_signature(encodings.plain_term());
152            type_signatures.push(type_signature);
153        }
154
155        if op.typed_value_encoding_op().is_some() {
156            let type_signature = details.arity.type_signature(encodings.typed_value());
157            type_signatures.push(type_signature);
158        }
159
160        if let Some(oid_encoding) = encodings.object_id() {
161            if op.object_id_encoding_op(oid_encoding).is_some() {
162                let type_signature = details.arity.type_signature(oid_encoding);
163                type_signatures.push(type_signature);
164            }
165        }
166
167        let type_signature = if type_signatures.len() == 1 {
168            type_signatures.pop().unwrap()
169        } else if type_signatures.is_empty() {
170            TypeSignature::Variadic(vec![]) // Or handle this case as an error if no encodings are supported
171        } else {
172            TypeSignature::OneOf(type_signatures)
173        };
174        let signature = Signature::new(type_signature, details.volatility);
175
176        Self {
177            name,
178            signature,
179            op,
180            encodings,
181        }
182    }
183
184    fn detect_input_encoding(
185        &self,
186        arg_types: &[DataType],
187    ) -> DFResult<Option<EncodingName>> {
188        let encoding_name = arg_types
189            .iter()
190            .map(|dt| {
191                self.encodings
192                    .try_get_encoding_name(dt)
193                    .ok_or(exec_datafusion_err!(
194                        "Cannot extract RDF term encoding from argument."
195                    ))
196            })
197            .collect::<DFResult<HashSet<_>>>()?;
198
199        if encoding_name.is_empty() {
200            return Ok(None);
201        }
202
203        if encoding_name.len() > 1 {
204            return plan_err!("More than one RDF term encoding used for arguments.");
205        }
206        Ok(encoding_name.into_iter().next())
207    }
208
209    fn prepare_args<TEncoding: TermEncoding>(
210        &self,
211        encoding: &TEncoding,
212        args: ScalarFunctionArgs,
213    ) -> DFResult<ScalarSparqlOpArgs<TEncoding>> {
214        let sparql_args = args
215            .args
216            .into_iter()
217            .map(|cv| encoding.try_new_datum(cv, args.number_rows))
218            .collect::<DFResult<Vec<_>>>()?;
219
220        Ok(ScalarSparqlOpArgs {
221            number_rows: args.number_rows,
222            args: sparql_args,
223        })
224    }
225}
226
227impl<TScalarSparqlOp: ScalarSparqlOp + 'static> ScalarUDFImpl
228    for ScalarSparqlOpAdapter<TScalarSparqlOp>
229{
230    fn as_any(&self) -> &dyn Any {
231        self
232    }
233
234    fn name(&self) -> &str {
235        &self.name
236    }
237
238    fn signature(&self) -> &Signature {
239        &self.signature
240    }
241
242    fn return_type(&self, arg_types: &[DataType]) -> DFResult<DataType> {
243        let encoding_name = self.detect_input_encoding(arg_types)?;
244        match encoding_name {
245            None => {
246                if let Some(op_impl) = self.op.plain_term_encoding_op() {
247                    Ok(op_impl.return_type())
248                } else if let Some(op_impl) = self.op.typed_value_encoding_op() {
249                    Ok(op_impl.return_type())
250                } else if let Some(oid_encoding) = self.encodings.object_id()
251                    && let Some(op_impl) = self.op.object_id_encoding_op(oid_encoding)
252                {
253                    Ok(op_impl.return_type())
254                } else {
255                    exec_err!(
256                        "The SPARQL operation '{}' does not support any encoding.",
257                        &self.name
258                    )
259                }
260            }
261            Some(EncodingName::PlainTerm) => self
262                .op
263                .plain_term_encoding_op()
264                .ok_or(exec_datafusion_err!(
265                    "The SPARQL operation '{}' does not support the PlainTerm encoding.",
266                    &self.name
267                ))
268                .map(|op_impl| op_impl.return_type()),
269            Some(EncodingName::TypedValue) => self
270                .op
271                .typed_value_encoding_op()
272                .ok_or(exec_datafusion_err!(
273                    "The SPARQL operation '{}' does not support the TypedValue encoding.",
274                    &self.name
275                ))
276                .map(|op_impl| op_impl.return_type()),
277            Some(EncodingName::Sortable) => {
278                exec_err!(
279                    "The SparqlOp infrastructure does not support the Sortable encoding."
280                )
281            }
282            Some(EncodingName::ObjectId) => {
283                let encoding = self.encodings.object_id().ok_or(exec_datafusion_err!(
284                    "Could not find the object id encoding."
285                ))?;
286
287                self
288                    .op
289                    .object_id_encoding_op(encoding)
290                    .ok_or(exec_datafusion_err!(
291                    "The SPARQL operation '{}' does not support the ObjectID encoding.",
292                    &self.name
293                ))
294                    .map(|op_impl| op_impl.return_type())
295            }
296        }
297    }
298
299    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
300        let data_types = args
301            .arg_fields
302            .iter()
303            .map(|f| f.data_type().clone())
304            .collect::<Vec<_>>();
305        let encoding = self.detect_input_encoding(&data_types)?;
306
307        let encoding = match encoding {
308            None => {
309                if self.op.typed_value_encoding_op().is_some() {
310                    EncodingName::TypedValue
311                } else if self.op.plain_term_encoding_op().is_some() {
312                    EncodingName::PlainTerm
313                } else if let Some(oid_encoding) = self.encodings.object_id()
314                    && self.op.object_id_encoding_op(oid_encoding).is_some()
315                {
316                    EncodingName::ObjectId
317                } else {
318                    return exec_err!("No supported encodings");
319                }
320            }
321            Some(encoding) => encoding,
322        };
323
324        match encoding {
325            EncodingName::PlainTerm => {
326                if let Some(op) = self.op.plain_term_encoding_op() {
327                    op.invoke(self.prepare_args(&PLAIN_TERM_ENCODING, args)?)
328                } else {
329                    exec_err!("PlainTerm encoding not supported for this operation")
330                }
331            }
332            EncodingName::TypedValue => {
333                if let Some(op) = self.op.typed_value_encoding_op() {
334                    op.invoke(self.prepare_args(&TYPED_VALUE_ENCODING, args)?)
335                } else {
336                    exec_err!("TypedValue encoding not supported for this operation")
337                }
338            }
339            EncodingName::ObjectId => {
340                let Some(object_id_encoding) = self.encodings.object_id() else {
341                    return exec_err!("Object ID is not registered.");
342                };
343
344                if let Some(op) = self.op.object_id_encoding_op(object_id_encoding) {
345                    op.invoke(self.prepare_args(object_id_encoding, args)?)
346                } else {
347                    exec_err!("TypedValue encoding not supported for this operation")
348                }
349            }
350            EncodingName::Sortable => exec_err!("Not supported"),
351        }
352    }
353}
354
/// While it would be possible to create two different SparqlOpAdapters for the same
/// [ScalarSparqlOp] that are not identical (different encodings), this situation is unlikely to
/// happen when using RDF Fusion "normally". Therefore, we only hash the contained [ScalarSparqlOp].
///
/// This remains consistent with the [PartialEq] impl below: equal adapters have equal
/// `op`s and therefore equal hashes. Hashing fewer fields than `eq` compares never
/// violates the `Hash`/`Eq` contract.
impl<TScalarSparqlOp: ScalarSparqlOp> Hash for ScalarSparqlOpAdapter<TScalarSparqlOp> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.op.hash(state);
    }
}
363
/// Two adapters are equal if they wrap equal operations AND were configured with the
/// same encodings. `name` and `signature` are derived from `op` and `encodings` in
/// [ScalarSparqlOpAdapter::new], so comparing them as well would be redundant.
impl<TScalarSparqlOp: ScalarSparqlOp> PartialEq
    for ScalarSparqlOpAdapter<TScalarSparqlOp>
{
    fn eq(&self, other: &Self) -> bool {
        self.op == other.op && self.encodings == other.encodings
    }
}