rdf_fusion_functions/builtin/encoding/
with_plain_term_encoding.rs1use datafusion::arrow::array::ArrayRef;
2use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
3use datafusion::common::{ScalarValue, exec_datafusion_err, exec_err, plan_err};
4use datafusion::logical_expr::{
5 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
6 Signature, TypeSignature, Volatility,
7};
8use rdf_fusion_encoding::plain_term::PLAIN_TERM_ENCODING;
9use rdf_fusion_encoding::plain_term::encoders::TypedValueRefPlainTermEncoder;
10use rdf_fusion_encoding::typed_value::TYPED_VALUE_ENCODING;
11use rdf_fusion_encoding::typed_value::decoders::DefaultTypedValueDecoder;
12use rdf_fusion_encoding::{
13 EncodingArray, EncodingName, EncodingScalar, RdfFusionEncodings, TermDecoder,
14 TermEncoder, TermEncoding,
15};
16use rdf_fusion_extensions::functions::BuiltinName;
17use rdf_fusion_model::DFResult;
18use std::any::Any;
19use std::hash::{Hash, Hasher};
20
21pub fn with_plain_term_encoding(encodings: RdfFusionEncodings) -> ScalarUDF {
22 let udf_impl = WithPlainTermEncoding::new(encodings);
23 ScalarUDF::new_from_impl(udf_impl)
24}
25
26#[derive(Debug, PartialEq, Eq)]
28struct WithPlainTermEncoding {
29 name: String,
31 signature: Signature,
33 encodings: RdfFusionEncodings,
35}
36
37impl WithPlainTermEncoding {
38 pub fn new(encodings: RdfFusionEncodings) -> Self {
39 Self {
40 name: BuiltinName::WithPlainTermEncoding.to_string(),
41 signature: Signature::new(
42 TypeSignature::Uniform(
43 1,
44 encodings.get_data_types(&[
45 EncodingName::PlainTerm,
46 EncodingName::TypedValue,
47 EncodingName::ObjectId,
48 ]),
49 ),
50 Volatility::Immutable,
51 ),
52 encodings,
53 }
54 }
55
56 fn convert_array(
57 &self,
58 encoding_name: EncodingName,
59 array: ArrayRef,
60 ) -> DFResult<ColumnarValue> {
61 match encoding_name {
62 EncodingName::PlainTerm => Ok(ColumnarValue::Array(array)),
63 EncodingName::TypedValue => {
64 let array = TYPED_VALUE_ENCODING.try_new_array(array)?;
65 let input = DefaultTypedValueDecoder::decode_terms(&array);
66 let result = TypedValueRefPlainTermEncoder::encode_terms(input)?;
67 Ok(ColumnarValue::Array(result.into_array()))
68 }
69 EncodingName::Sortable => exec_err!("Cannot from sortable term."),
70 EncodingName::ObjectId => match self.encodings.object_id_mapping() {
71 None => exec_err!("Cannot from object id as no encoding is provided."),
72 Some(object_id_encoding) => {
73 let array = object_id_encoding.encoding().try_new_array(array)?;
74 let decoded = object_id_encoding.decode_array(&array)?;
75 Ok(ColumnarValue::Array(decoded.into_array()))
76 }
77 },
78 }
79 }
80
81 fn convert_scalar(
82 &self,
83 encoding_name: EncodingName,
84 scalar: ScalarValue,
85 ) -> DFResult<ColumnarValue> {
86 match encoding_name {
87 EncodingName::PlainTerm => Ok(ColumnarValue::Scalar(scalar)),
88 EncodingName::TypedValue => {
89 let scalar = TYPED_VALUE_ENCODING.try_new_scalar(scalar)?;
90 let input = DefaultTypedValueDecoder::decode_term(&scalar);
91 let result = TypedValueRefPlainTermEncoder::encode_term(input)?;
92 Ok(ColumnarValue::Scalar(result.into_scalar_value()))
93 }
94 EncodingName::Sortable => exec_err!("Cannot from sortable term."),
95 EncodingName::ObjectId => match self.encodings.object_id_mapping() {
96 None => exec_err!("Cannot from object id as no encoding is provided."),
97 Some(object_id_encoding) => {
98 let scalar = object_id_encoding.encoding().try_new_scalar(scalar)?;
99 let decoded = object_id_encoding.decode_scalar(&scalar)?;
100 Ok(ColumnarValue::Scalar(decoded.into_scalar_value()))
101 }
102 },
103 }
104 }
105}
106
107impl ScalarUDFImpl for WithPlainTermEncoding {
108 fn as_any(&self) -> &dyn Any {
109 self
110 }
111
112 fn name(&self) -> &str {
113 &self.name
114 }
115
116 fn signature(&self) -> &Signature {
117 &self.signature
118 }
119
120 fn return_type(
121 &self,
122 _arg_types: &[DataType],
123 ) -> datafusion::common::Result<DataType> {
124 exec_err!("return_field_from_args should be called")
125 }
126
127 fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> DFResult<FieldRef> {
128 if args.arg_fields.len() != 1 {
129 return plan_err!(
130 "Unexpected number of arg fields in return_field_from_args."
131 );
132 }
133
134 let data_type = PLAIN_TERM_ENCODING.data_type();
135 let incoming_null = args.arg_fields[0].is_nullable();
136 Ok(FieldRef::new(Field::new(
137 "output",
138 data_type,
139 incoming_null,
140 )))
141 }
142
143 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
144 let args = TryInto::<[ColumnarValue; 1]>::try_into(args.args)
145 .map_err(|_| exec_datafusion_err!("Invalid number of arguments."))?;
146 let encoding_name = self
147 .encodings
148 .try_get_encoding_name(&args[0].data_type())
149 .ok_or(exec_datafusion_err!(
150 "Cannot obtain encoding from argument."
151 ))?;
152
153 match args {
154 [ColumnarValue::Array(array)] => self.convert_array(encoding_name, array),
155 [ColumnarValue::Scalar(scalar)] => self.convert_scalar(encoding_name, scalar),
156 }
157 }
158}
159
160impl Hash for WithPlainTermEncoding {
161 fn hash<H: Hasher>(&self, state: &mut H) {
162 self.as_any().type_id().hash(state);
163 }
164}