rdf_fusion_functions/builtin/encoding/
with_sortable_encoding.rs

1use datafusion::arrow::array::ArrayRef;
2use datafusion::arrow::datatypes::DataType;
3use datafusion::common::{ScalarValue, exec_datafusion_err, exec_err};
4use datafusion::logical_expr::{
5    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
6    TypeSignature, Volatility,
7};
8use rdf_fusion_encoding::plain_term::PLAIN_TERM_ENCODING;
9use rdf_fusion_encoding::plain_term::decoders::DefaultPlainTermDecoder;
10use rdf_fusion_encoding::sortable_term::SORTABLE_TERM_ENCODING;
11use rdf_fusion_encoding::sortable_term::encoders::{
12    TermRefSortableTermEncoder, TypedValueRefSortableTermEncoder,
13};
14use rdf_fusion_encoding::typed_value::TYPED_VALUE_ENCODING;
15use rdf_fusion_encoding::typed_value::decoders::DefaultTypedValueDecoder;
16use rdf_fusion_encoding::{
17    EncodingArray, EncodingName, EncodingScalar, RdfFusionEncodings, TermDecoder,
18    TermEncoder, TermEncoding,
19};
20use rdf_fusion_extensions::functions::BuiltinName;
21use rdf_fusion_model::DFResult;
22use std::any::Any;
23use std::hash::{Hash, Hasher};
24
25pub fn with_sortable_term_encoding(encodings: RdfFusionEncodings) -> ScalarUDF {
26    let udf_impl = WithSortableEncoding::new(encodings);
27    ScalarUDF::new_from_impl(udf_impl)
28}
29
30/// Transforms RDF Terms into the [SortableTermEncoding](rdf_fusion_encoding::sortable_term::SortableTermEncoding).
31#[derive(Debug, PartialEq, Eq)]
32struct WithSortableEncoding {
33    /// The name of this function
34    name: String,
35    /// The signature of this function
36    signature: Signature,
37    /// The registered encodings
38    encodings: RdfFusionEncodings,
39}
40
41impl WithSortableEncoding {
42    pub fn new(encodings: RdfFusionEncodings) -> Self {
43        Self {
44            name: BuiltinName::WithSortableEncoding.to_string(),
45            signature: Signature::new(
46                TypeSignature::Uniform(
47                    1,
48                    vec![
49                        PLAIN_TERM_ENCODING.data_type(),
50                        TYPED_VALUE_ENCODING.data_type(),
51                    ],
52                ),
53                Volatility::Immutable,
54            ),
55            encodings,
56        }
57    }
58
59    fn convert_scalar(
60        encoding_name: EncodingName,
61        scalar: ScalarValue,
62    ) -> DFResult<ColumnarValue> {
63        match encoding_name {
64            EncodingName::PlainTerm => {
65                let scalar = PLAIN_TERM_ENCODING.try_new_scalar(scalar)?;
66                let input = DefaultPlainTermDecoder::decode_term(&scalar);
67                let result = TermRefSortableTermEncoder::encode_term(input)?;
68                Ok(ColumnarValue::Scalar(result.into_scalar_value()))
69            }
70            EncodingName::TypedValue => {
71                let scalar = TYPED_VALUE_ENCODING.try_new_scalar(scalar)?;
72                let input = DefaultTypedValueDecoder::decode_term(&scalar);
73                let result = TypedValueRefSortableTermEncoder::encode_term(input)?;
74                Ok(ColumnarValue::Scalar(result.into_scalar_value()))
75            }
76            EncodingName::Sortable => Ok(ColumnarValue::Scalar(scalar)),
77            EncodingName::ObjectId => exec_err!("Cannot from object id."),
78        }
79    }
80
81    fn convert_array(
82        encoding_name: EncodingName,
83        array: ArrayRef,
84    ) -> DFResult<ColumnarValue> {
85        match encoding_name {
86            EncodingName::PlainTerm => {
87                let array = PLAIN_TERM_ENCODING.try_new_array(array)?;
88                let input = DefaultPlainTermDecoder::decode_terms(&array);
89                let result = TermRefSortableTermEncoder::encode_terms(input)?;
90                Ok(ColumnarValue::Array(result.into_array()))
91            }
92            EncodingName::TypedValue => {
93                let array = TYPED_VALUE_ENCODING.try_new_array(array)?;
94                let input = DefaultTypedValueDecoder::decode_terms(&array);
95                let result = TypedValueRefSortableTermEncoder::encode_terms(input)?;
96                Ok(ColumnarValue::Array(result.into_array()))
97            }
98            EncodingName::Sortable => Ok(ColumnarValue::Array(array)),
99            EncodingName::ObjectId => exec_err!("Cannot from object id."),
100        }
101    }
102}
103
104impl ScalarUDFImpl for WithSortableEncoding {
105    fn as_any(&self) -> &dyn Any {
106        self
107    }
108
109    fn name(&self) -> &str {
110        &self.name
111    }
112
113    fn signature(&self) -> &Signature {
114        &self.signature
115    }
116
117    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
118        Ok(SORTABLE_TERM_ENCODING.data_type())
119    }
120
121    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
122        let args = TryInto::<[ColumnarValue; 1]>::try_into(args.args)
123            .map_err(|_| exec_datafusion_err!("Invalid number of arguments."))?;
124        let encoding_name = self
125            .encodings
126            .try_get_encoding_name(&args[0].data_type())
127            .ok_or(exec_datafusion_err!(
128                "Cannot obtain encoding from argument."
129            ))?;
130
131        match args {
132            [ColumnarValue::Array(array)] => Self::convert_array(encoding_name, array),
133            [ColumnarValue::Scalar(scalar)] => {
134                Self::convert_scalar(encoding_name, scalar)
135            }
136        }
137    }
138}
139
140impl Hash for WithSortableEncoding {
141    fn hash<H: Hasher>(&self, state: &mut H) {
142        self.as_any().type_id().hash(state);
143    }
144}