rdf_fusion_functions/builtin/query/
is_compatible.rs1use datafusion::arrow::array::{Array, BooleanArray, BooleanBuilder, make_comparator};
2use datafusion::arrow::compute::SortOptions;
3use datafusion::arrow::compute::kernels::cmp::eq;
4use datafusion::arrow::datatypes::DataType;
5use datafusion::common::{ScalarValue, exec_err};
6use datafusion::logical_expr::{
7 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
8 TypeSignature, Volatility,
9};
10use rdf_fusion_encoding::{EncodingName, RdfFusionEncodings};
11use rdf_fusion_extensions::functions::BuiltinName;
12use rdf_fusion_model::DFResult;
13use std::any::Any;
14use std::cmp::Ordering;
15use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17
18pub fn is_compatible(encodings: &RdfFusionEncodings) -> ScalarUDF {
19 let udf_impl = IsCompatible::new(encodings);
20 ScalarUDF::new_from_impl(udf_impl)
21}
22
23#[derive(Debug, Eq)]
24struct IsCompatible {
25 name: String,
26 signature: Signature,
27}
28
29impl IsCompatible {
30 pub fn new(encodings: &RdfFusionEncodings) -> Self {
31 Self {
32 name: BuiltinName::IsCompatible.to_string(),
33 signature: Signature::new(
34 TypeSignature::Uniform(
35 2,
36 encodings.get_data_types(&[
37 EncodingName::PlainTerm,
38 EncodingName::ObjectId,
39 ]),
40 ),
41 Volatility::Immutable,
42 ),
43 }
44 }
45}
46
47impl ScalarUDFImpl for IsCompatible {
48 fn as_any(&self) -> &dyn Any {
49 self
50 }
51
52 fn name(&self) -> &str {
53 &self.name
54 }
55
56 fn signature(&self) -> &Signature {
57 &self.signature
58 }
59
60 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
61 Ok(DataType::Boolean)
62 }
63
64 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
65 match TryInto::<[_; 2]>::try_into(args.args) {
66 Ok([ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)]) => {
67 invoke_array_array(args.number_rows, lhs.as_ref(), rhs.as_ref())
68 }
69 Ok([ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)]) => {
70 invoke_scalar_array(args.number_rows, &lhs, rhs.as_ref())
71 }
72 Ok([ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)]) => {
73 invoke_scalar_array(args.number_rows, &rhs, lhs.as_ref())
75 }
76 Ok([ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)]) => {
77 Ok(invoke_scalar_scalar(&lhs, &rhs))
78 }
79 _ => exec_err!("Invalid arguments for IsCompatible"),
80 }
81 }
82}
83
84impl Hash for IsCompatible {
85 fn hash<H: Hasher>(&self, state: &mut H) {
86 self.as_any().type_id().hash(state);
87 }
88}
89
90impl PartialEq for IsCompatible {
91 fn eq(&self, other: &Self) -> bool {
92 self.as_any().type_id() == other.as_any().type_id()
93 && self.signature.eq(&other.signature)
94 }
95}
96
97pub(crate) fn invoke_array_array(
98 number_rows: usize,
99 lhs: &dyn Array,
100 rhs: &dyn Array,
101) -> DFResult<ColumnarValue> {
102 let mut eq_res = invoke_eq_array(number_rows, lhs, rhs)?;
103
104 if eq_res.null_count() > 0 {
105 eq_res = fill_nulls(&eq_res, true);
106 }
107
108 Ok(ColumnarValue::Array(Arc::new(eq_res)))
109}
110
111pub(crate) fn invoke_scalar_array(
112 number_rows: usize,
113 lhs: &ScalarValue,
114 rhs: &dyn Array,
115) -> DFResult<ColumnarValue> {
116 if lhs.is_null() {
117 return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))));
118 }
119
120 let eq_res = invoke_eq_array_scalar(number_rows, rhs, lhs)?;
121 if eq_res.null_count() > 0 {
122 let result = fill_nulls(&eq_res, true);
123 Ok(ColumnarValue::Array(Arc::new(result)))
124 } else {
125 Ok(ColumnarValue::Array(Arc::new(eq_res)))
126 }
127}
128
129pub(crate) fn invoke_scalar_scalar(
130 lhs: &ScalarValue,
131 rhs: &ScalarValue,
132) -> ColumnarValue {
133 ColumnarValue::Scalar(ScalarValue::Boolean(Some(
134 lhs.is_null() || rhs.is_null() || lhs == rhs,
135 )))
136}
137
138fn invoke_eq_array(
139 number_rows: usize,
140 lhs: &dyn Array,
141 rhs: &dyn Array,
142) -> DFResult<BooleanArray> {
143 let data_type = lhs.data_type();
144 if data_type.is_nested() {
145 let comparator = make_comparator(lhs, rhs, SortOptions::default())?;
146 let result = (0..number_rows)
147 .map(|i| {
148 Some(
149 lhs.is_null(i)
150 || rhs.is_null(i)
151 || comparator(i, i) == Ordering::Equal,
152 )
153 })
154 .collect::<BooleanArray>();
155 Ok(result)
156 } else {
157 Ok(eq(&lhs, &rhs)?)
158 }
159}
160
161fn invoke_eq_array_scalar(
162 number_rows: usize,
163 lhs: &dyn Array,
164 rhs: &ScalarValue,
165) -> DFResult<BooleanArray> {
166 let rhs = rhs.to_array_of_size(number_rows)?;
167 invoke_eq_array(number_rows, lhs, &rhs)
168}
169
170fn fill_nulls(bool_array: &BooleanArray, fill_value: bool) -> BooleanArray {
171 let mut builder = BooleanBuilder::with_capacity(bool_array.len());
172
173 for i in 0..bool_array.len() {
174 if bool_array.is_null(i) {
175 builder.append_value(fill_value);
176 } else {
177 builder.append_value(bool_array.value(i));
178 }
179 }
180
181 builder.finish()
182}