rdf_fusion_functions/builtin/encoding/
with_plain_term_encoding.rs

1use datafusion::arrow::array::ArrayRef;
2use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
3use datafusion::common::{ScalarValue, exec_datafusion_err, exec_err, plan_err};
4use datafusion::logical_expr::{
5    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
6    Signature, TypeSignature, Volatility,
7};
8use rdf_fusion_encoding::plain_term::PLAIN_TERM_ENCODING;
9use rdf_fusion_encoding::plain_term::encoders::TypedValueRefPlainTermEncoder;
10use rdf_fusion_encoding::typed_value::TYPED_VALUE_ENCODING;
11use rdf_fusion_encoding::typed_value::decoders::DefaultTypedValueDecoder;
12use rdf_fusion_encoding::{
13    EncodingArray, EncodingName, EncodingScalar, RdfFusionEncodings, TermDecoder,
14    TermEncoder, TermEncoding,
15};
16use rdf_fusion_extensions::functions::BuiltinName;
17use rdf_fusion_model::DFResult;
18use std::any::Any;
19use std::hash::{Hash, Hasher};
20
21pub fn with_plain_term_encoding(encodings: RdfFusionEncodings) -> ScalarUDF {
22    let udf_impl = WithPlainTermEncoding::new(encodings);
23    ScalarUDF::new_from_impl(udf_impl)
24}
25
26/// Transforms RDF Terms into the [PlainTermEncoding](rdf_fusion_encoding::plain_term::PlainTermEncoding).
27#[derive(Debug, PartialEq, Eq)]
28struct WithPlainTermEncoding {
29    /// The name of the UDF
30    name: String,
31    /// The signature
32    signature: Signature,
33    /// A reference to used encodings.
34    encodings: RdfFusionEncodings,
35}
36
37impl WithPlainTermEncoding {
38    pub fn new(encodings: RdfFusionEncodings) -> Self {
39        Self {
40            name: BuiltinName::WithPlainTermEncoding.to_string(),
41            signature: Signature::new(
42                TypeSignature::Uniform(
43                    1,
44                    encodings.get_data_types(&[
45                        EncodingName::PlainTerm,
46                        EncodingName::TypedValue,
47                        EncodingName::ObjectId,
48                    ]),
49                ),
50                Volatility::Immutable,
51            ),
52            encodings,
53        }
54    }
55
56    fn convert_array(
57        &self,
58        encoding_name: EncodingName,
59        array: ArrayRef,
60    ) -> DFResult<ColumnarValue> {
61        match encoding_name {
62            EncodingName::PlainTerm => Ok(ColumnarValue::Array(array)),
63            EncodingName::TypedValue => {
64                let array = TYPED_VALUE_ENCODING.try_new_array(array)?;
65                let input = DefaultTypedValueDecoder::decode_terms(&array);
66                let result = TypedValueRefPlainTermEncoder::encode_terms(input)?;
67                Ok(ColumnarValue::Array(result.into_array()))
68            }
69            EncodingName::Sortable => exec_err!("Cannot from sortable term."),
70            EncodingName::ObjectId => match self.encodings.object_id_mapping() {
71                None => exec_err!("Cannot from object id as no encoding is provided."),
72                Some(object_id_encoding) => {
73                    let array = object_id_encoding.encoding().try_new_array(array)?;
74                    let decoded = object_id_encoding.decode_array(&array)?;
75                    Ok(ColumnarValue::Array(decoded.into_array()))
76                }
77            },
78        }
79    }
80
81    fn convert_scalar(
82        &self,
83        encoding_name: EncodingName,
84        scalar: ScalarValue,
85    ) -> DFResult<ColumnarValue> {
86        match encoding_name {
87            EncodingName::PlainTerm => Ok(ColumnarValue::Scalar(scalar)),
88            EncodingName::TypedValue => {
89                let scalar = TYPED_VALUE_ENCODING.try_new_scalar(scalar)?;
90                let input = DefaultTypedValueDecoder::decode_term(&scalar);
91                let result = TypedValueRefPlainTermEncoder::encode_term(input)?;
92                Ok(ColumnarValue::Scalar(result.into_scalar_value()))
93            }
94            EncodingName::Sortable => exec_err!("Cannot from sortable term."),
95            EncodingName::ObjectId => match self.encodings.object_id_mapping() {
96                None => exec_err!("Cannot from object id as no encoding is provided."),
97                Some(object_id_encoding) => {
98                    let scalar = object_id_encoding.encoding().try_new_scalar(scalar)?;
99                    let decoded = object_id_encoding.decode_scalar(&scalar)?;
100                    Ok(ColumnarValue::Scalar(decoded.into_scalar_value()))
101                }
102            },
103        }
104    }
105}
106
107impl ScalarUDFImpl for WithPlainTermEncoding {
108    fn as_any(&self) -> &dyn Any {
109        self
110    }
111
112    fn name(&self) -> &str {
113        &self.name
114    }
115
116    fn signature(&self) -> &Signature {
117        &self.signature
118    }
119
120    fn return_type(
121        &self,
122        _arg_types: &[DataType],
123    ) -> datafusion::common::Result<DataType> {
124        exec_err!("return_field_from_args should be called")
125    }
126
127    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> DFResult<FieldRef> {
128        if args.arg_fields.len() != 1 {
129            return plan_err!(
130                "Unexpected number of arg fields in return_field_from_args."
131            );
132        }
133
134        let data_type = PLAIN_TERM_ENCODING.data_type();
135        let incoming_null = args.arg_fields[0].is_nullable();
136        Ok(FieldRef::new(Field::new(
137            "output",
138            data_type,
139            incoming_null,
140        )))
141    }
142
143    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
144        let args = TryInto::<[ColumnarValue; 1]>::try_into(args.args)
145            .map_err(|_| exec_datafusion_err!("Invalid number of arguments."))?;
146        let encoding_name = self
147            .encodings
148            .try_get_encoding_name(&args[0].data_type())
149            .ok_or(exec_datafusion_err!(
150                "Cannot obtain encoding from argument."
151            ))?;
152
153        match args {
154            [ColumnarValue::Array(array)] => self.convert_array(encoding_name, array),
155            [ColumnarValue::Scalar(scalar)] => self.convert_scalar(encoding_name, scalar),
156        }
157    }
158}
159
160impl Hash for WithPlainTermEncoding {
161    fn hash<H: Hasher>(&self, state: &mut H) {
162        self.as_any().type_id().hash(state);
163    }
164}