rdf_fusion_functions/scalar/
sparql_op.rs1use crate::scalar::ScalarSparqlOpArgs;
2use crate::scalar::sparql_op_impl::ScalarSparqlOpImpl;
3use datafusion::arrow::datatypes::DataType;
4use datafusion::common::{exec_datafusion_err, exec_err, plan_err};
5use datafusion::logical_expr::{
6 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
7 Volatility,
8};
9use rdf_fusion_encoding::object_id::ObjectIdEncoding;
10use rdf_fusion_encoding::plain_term::{PLAIN_TERM_ENCODING, PlainTermEncoding};
11use rdf_fusion_encoding::typed_value::{TYPED_VALUE_ENCODING, TypedValueEncoding};
12use rdf_fusion_encoding::{EncodingName, RdfFusionEncodings, TermEncoding};
13use rdf_fusion_extensions::functions::FunctionName;
14use rdf_fusion_model::DFResult;
15use std::any::Any;
16use std::collections::HashSet;
17use std::fmt::Debug;
18use std::hash::{Hash, Hasher};
19
/// Describes how many arguments a SPARQL operation accepts.
///
/// The arity is independent of the term encoding; it is turned into a concrete
/// DataFusion type signature for a given encoding by
/// `SparqlOpArity::type_signature`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SparqlOpArity {
    /// The operation takes no arguments.
    Nullary,
    /// The operation takes exactly the given number of arguments.
    Fixed(usize),
    /// The operation accepts any one of the listed arities.
    OneOf(Vec<SparqlOpArity>),
    /// The operation takes an arbitrary number of arguments, including none.
    Variadic,
}
35
36impl SparqlOpArity {
37 pub fn type_signature<TEncoding: TermEncoding>(
39 &self,
40 encoding: &TEncoding,
41 ) -> TypeSignature {
42 match self {
43 SparqlOpArity::Nullary => TypeSignature::Nullary,
44 SparqlOpArity::Fixed(n) => {
45 TypeSignature::Uniform(*n, vec![encoding.data_type()])
46 }
47 SparqlOpArity::OneOf(ns) => {
48 let inner = ns
49 .iter()
50 .map(|n| n.type_signature(encoding))
51 .collect::<Vec<_>>();
52 TypeSignature::OneOf(inner)
53 }
54 SparqlOpArity::Variadic => TypeSignature::OneOf(vec![
55 TypeSignature::Nullary,
56 TypeSignature::Variadic(vec![encoding.data_type()]),
57 ]),
58 }
59 }
60}
61
62pub struct ScalarSparqlOpSignature {
64 pub volatility: Volatility,
66 pub arity: SparqlOpArity,
68}
69
70impl ScalarSparqlOpSignature {
71 pub fn default_with_arity(arity: SparqlOpArity) -> Self {
73 Self {
74 volatility: Volatility::Immutable,
75 arity,
76 }
77 }
78}
79
80pub trait ScalarSparqlOp: Debug + Hash + Eq + Send + Sync {
90 fn name(&self) -> &FunctionName;
92
93 fn signature(&self) -> ScalarSparqlOpSignature;
95
96 fn typed_value_encoding_op(
100 &self,
101 ) -> Option<Box<dyn ScalarSparqlOpImpl<TypedValueEncoding>>> {
102 None
103 }
104
105 fn plain_term_encoding_op(
109 &self,
110 ) -> Option<Box<dyn ScalarSparqlOpImpl<PlainTermEncoding>>> {
111 None
112 }
113
114 fn object_id_encoding_op(
118 &self,
119 _object_id_encoding: &ObjectIdEncoding,
120 ) -> Option<Box<dyn ScalarSparqlOpImpl<ObjectIdEncoding>>> {
121 None
122 }
123}
124
125#[derive(Debug, Eq)]
132pub struct ScalarSparqlOpAdapter<TScalarSparqlOp: ScalarSparqlOp> {
133 name: String,
135 signature: Signature,
137 op: TScalarSparqlOp,
139 encodings: RdfFusionEncodings,
141}
142
143impl<TScalarSparqlOp: ScalarSparqlOp> ScalarSparqlOpAdapter<TScalarSparqlOp> {
144 pub fn new(encodings: RdfFusionEncodings, op: TScalarSparqlOp) -> Self {
146 let name = op.name().to_string();
147 let details = op.signature();
148
149 let mut type_signatures = Vec::new();
150 if op.plain_term_encoding_op().is_some() {
151 let type_signature = details.arity.type_signature(encodings.plain_term());
152 type_signatures.push(type_signature);
153 }
154
155 if op.typed_value_encoding_op().is_some() {
156 let type_signature = details.arity.type_signature(encodings.typed_value());
157 type_signatures.push(type_signature);
158 }
159
160 if let Some(oid_encoding) = encodings.object_id() {
161 if op.object_id_encoding_op(oid_encoding).is_some() {
162 let type_signature = details.arity.type_signature(oid_encoding);
163 type_signatures.push(type_signature);
164 }
165 }
166
167 let type_signature = if type_signatures.len() == 1 {
168 type_signatures.pop().unwrap()
169 } else if type_signatures.is_empty() {
170 TypeSignature::Variadic(vec![]) } else {
172 TypeSignature::OneOf(type_signatures)
173 };
174 let signature = Signature::new(type_signature, details.volatility);
175
176 Self {
177 name,
178 signature,
179 op,
180 encodings,
181 }
182 }
183
184 fn detect_input_encoding(
185 &self,
186 arg_types: &[DataType],
187 ) -> DFResult<Option<EncodingName>> {
188 let encoding_name = arg_types
189 .iter()
190 .map(|dt| {
191 self.encodings
192 .try_get_encoding_name(dt)
193 .ok_or(exec_datafusion_err!(
194 "Cannot extract RDF term encoding from argument."
195 ))
196 })
197 .collect::<DFResult<HashSet<_>>>()?;
198
199 if encoding_name.is_empty() {
200 return Ok(None);
201 }
202
203 if encoding_name.len() > 1 {
204 return plan_err!("More than one RDF term encoding used for arguments.");
205 }
206 Ok(encoding_name.into_iter().next())
207 }
208
209 fn prepare_args<TEncoding: TermEncoding>(
210 &self,
211 encoding: &TEncoding,
212 args: ScalarFunctionArgs,
213 ) -> DFResult<ScalarSparqlOpArgs<TEncoding>> {
214 let sparql_args = args
215 .args
216 .into_iter()
217 .map(|cv| encoding.try_new_datum(cv, args.number_rows))
218 .collect::<DFResult<Vec<_>>>()?;
219
220 Ok(ScalarSparqlOpArgs {
221 number_rows: args.number_rows,
222 args: sparql_args,
223 })
224 }
225}
226
227impl<TScalarSparqlOp: ScalarSparqlOp + 'static> ScalarUDFImpl
228 for ScalarSparqlOpAdapter<TScalarSparqlOp>
229{
230 fn as_any(&self) -> &dyn Any {
231 self
232 }
233
234 fn name(&self) -> &str {
235 &self.name
236 }
237
238 fn signature(&self) -> &Signature {
239 &self.signature
240 }
241
242 fn return_type(&self, arg_types: &[DataType]) -> DFResult<DataType> {
243 let encoding_name = self.detect_input_encoding(arg_types)?;
244 match encoding_name {
245 None => {
246 if let Some(op_impl) = self.op.plain_term_encoding_op() {
247 Ok(op_impl.return_type())
248 } else if let Some(op_impl) = self.op.typed_value_encoding_op() {
249 Ok(op_impl.return_type())
250 } else if let Some(oid_encoding) = self.encodings.object_id()
251 && let Some(op_impl) = self.op.object_id_encoding_op(oid_encoding)
252 {
253 Ok(op_impl.return_type())
254 } else {
255 exec_err!(
256 "The SPARQL operation '{}' does not support any encoding.",
257 &self.name
258 )
259 }
260 }
261 Some(EncodingName::PlainTerm) => self
262 .op
263 .plain_term_encoding_op()
264 .ok_or(exec_datafusion_err!(
265 "The SPARQL operation '{}' does not support the PlainTerm encoding.",
266 &self.name
267 ))
268 .map(|op_impl| op_impl.return_type()),
269 Some(EncodingName::TypedValue) => self
270 .op
271 .typed_value_encoding_op()
272 .ok_or(exec_datafusion_err!(
273 "The SPARQL operation '{}' does not support the TypedValue encoding.",
274 &self.name
275 ))
276 .map(|op_impl| op_impl.return_type()),
277 Some(EncodingName::Sortable) => {
278 exec_err!(
279 "The SparqlOp infrastructure does not support the Sortable encoding."
280 )
281 }
282 Some(EncodingName::ObjectId) => {
283 let encoding = self.encodings.object_id().ok_or(exec_datafusion_err!(
284 "Could not find the object id encoding."
285 ))?;
286
287 self
288 .op
289 .object_id_encoding_op(encoding)
290 .ok_or(exec_datafusion_err!(
291 "The SPARQL operation '{}' does not support the ObjectID encoding.",
292 &self.name
293 ))
294 .map(|op_impl| op_impl.return_type())
295 }
296 }
297 }
298
299 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
300 let data_types = args
301 .arg_fields
302 .iter()
303 .map(|f| f.data_type().clone())
304 .collect::<Vec<_>>();
305 let encoding = self.detect_input_encoding(&data_types)?;
306
307 let encoding = match encoding {
308 None => {
309 if self.op.typed_value_encoding_op().is_some() {
310 EncodingName::TypedValue
311 } else if self.op.plain_term_encoding_op().is_some() {
312 EncodingName::PlainTerm
313 } else if let Some(oid_encoding) = self.encodings.object_id()
314 && self.op.object_id_encoding_op(oid_encoding).is_some()
315 {
316 EncodingName::ObjectId
317 } else {
318 return exec_err!("No supported encodings");
319 }
320 }
321 Some(encoding) => encoding,
322 };
323
324 match encoding {
325 EncodingName::PlainTerm => {
326 if let Some(op) = self.op.plain_term_encoding_op() {
327 op.invoke(self.prepare_args(&PLAIN_TERM_ENCODING, args)?)
328 } else {
329 exec_err!("PlainTerm encoding not supported for this operation")
330 }
331 }
332 EncodingName::TypedValue => {
333 if let Some(op) = self.op.typed_value_encoding_op() {
334 op.invoke(self.prepare_args(&TYPED_VALUE_ENCODING, args)?)
335 } else {
336 exec_err!("TypedValue encoding not supported for this operation")
337 }
338 }
339 EncodingName::ObjectId => {
340 let Some(object_id_encoding) = self.encodings.object_id() else {
341 return exec_err!("Object ID is not registered.");
342 };
343
344 if let Some(op) = self.op.object_id_encoding_op(object_id_encoding) {
345 op.invoke(self.prepare_args(object_id_encoding, args)?)
346 } else {
347 exec_err!("TypedValue encoding not supported for this operation")
348 }
349 }
350 EncodingName::Sortable => exec_err!("Not supported"),
351 }
352 }
353}
354
355impl<TScalarSparqlOp: ScalarSparqlOp> Hash for ScalarSparqlOpAdapter<TScalarSparqlOp> {
359 fn hash<H: Hasher>(&self, state: &mut H) {
360 self.op.hash(state);
361 }
362}
363
364impl<TScalarSparqlOp: ScalarSparqlOp> PartialEq
365 for ScalarSparqlOpAdapter<TScalarSparqlOp>
366{
367 fn eq(&self, other: &Self) -> bool {
368 self.op == other.op && self.encodings == other.encodings
369 }
370}