rdf_fusion_functions/builtin/query/
is_compatible.rs

1use datafusion::arrow::array::{Array, BooleanArray, BooleanBuilder, make_comparator};
2use datafusion::arrow::compute::SortOptions;
3use datafusion::arrow::compute::kernels::cmp::eq;
4use datafusion::arrow::datatypes::DataType;
5use datafusion::common::{ScalarValue, exec_err};
6use datafusion::logical_expr::{
7    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
8    TypeSignature, Volatility,
9};
10use rdf_fusion_encoding::{EncodingName, RdfFusionEncodings};
11use rdf_fusion_extensions::functions::BuiltinName;
12use rdf_fusion_model::DFResult;
13use std::any::Any;
14use std::cmp::Ordering;
15use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17
18pub fn is_compatible(encodings: &RdfFusionEncodings) -> ScalarUDF {
19    let udf_impl = IsCompatible::new(encodings);
20    ScalarUDF::new_from_impl(udf_impl)
21}
22
23#[derive(Debug, Eq)]
24struct IsCompatible {
25    name: String,
26    signature: Signature,
27}
28
29impl IsCompatible {
30    pub fn new(encodings: &RdfFusionEncodings) -> Self {
31        Self {
32            name: BuiltinName::IsCompatible.to_string(),
33            signature: Signature::new(
34                TypeSignature::Uniform(
35                    2,
36                    encodings.get_data_types(&[
37                        EncodingName::PlainTerm,
38                        EncodingName::ObjectId,
39                    ]),
40                ),
41                Volatility::Immutable,
42            ),
43        }
44    }
45}
46
47impl ScalarUDFImpl for IsCompatible {
48    fn as_any(&self) -> &dyn Any {
49        self
50    }
51
52    fn name(&self) -> &str {
53        &self.name
54    }
55
56    fn signature(&self) -> &Signature {
57        &self.signature
58    }
59
60    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
61        Ok(DataType::Boolean)
62    }
63
64    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
65        match TryInto::<[_; 2]>::try_into(args.args) {
66            Ok([ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)]) => {
67                invoke_array_array(args.number_rows, lhs.as_ref(), rhs.as_ref())
68            }
69            Ok([ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)]) => {
70                invoke_scalar_array(args.number_rows, &lhs, rhs.as_ref())
71            }
72            Ok([ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)]) => {
73                // Commutative operation
74                invoke_scalar_array(args.number_rows, &rhs, lhs.as_ref())
75            }
76            Ok([ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)]) => {
77                Ok(invoke_scalar_scalar(&lhs, &rhs))
78            }
79            _ => exec_err!("Invalid arguments for IsCompatible"),
80        }
81    }
82}
83
84impl Hash for IsCompatible {
85    fn hash<H: Hasher>(&self, state: &mut H) {
86        self.as_any().type_id().hash(state);
87    }
88}
89
90impl PartialEq for IsCompatible {
91    fn eq(&self, other: &Self) -> bool {
92        self.as_any().type_id() == other.as_any().type_id()
93            && self.signature.eq(&other.signature)
94    }
95}
96
97pub(crate) fn invoke_array_array(
98    number_rows: usize,
99    lhs: &dyn Array,
100    rhs: &dyn Array,
101) -> DFResult<ColumnarValue> {
102    let mut eq_res = invoke_eq_array(number_rows, lhs, rhs)?;
103
104    if eq_res.null_count() > 0 {
105        eq_res = fill_nulls(&eq_res, true);
106    }
107
108    Ok(ColumnarValue::Array(Arc::new(eq_res)))
109}
110
111pub(crate) fn invoke_scalar_array(
112    number_rows: usize,
113    lhs: &ScalarValue,
114    rhs: &dyn Array,
115) -> DFResult<ColumnarValue> {
116    if lhs.is_null() {
117        return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))));
118    }
119
120    let eq_res = invoke_eq_array_scalar(number_rows, rhs, lhs)?;
121    if eq_res.null_count() > 0 {
122        let result = fill_nulls(&eq_res, true);
123        Ok(ColumnarValue::Array(Arc::new(result)))
124    } else {
125        Ok(ColumnarValue::Array(Arc::new(eq_res)))
126    }
127}
128
129pub(crate) fn invoke_scalar_scalar(
130    lhs: &ScalarValue,
131    rhs: &ScalarValue,
132) -> ColumnarValue {
133    ColumnarValue::Scalar(ScalarValue::Boolean(Some(
134        lhs.is_null() || rhs.is_null() || lhs == rhs,
135    )))
136}
137
138fn invoke_eq_array(
139    number_rows: usize,
140    lhs: &dyn Array,
141    rhs: &dyn Array,
142) -> DFResult<BooleanArray> {
143    let data_type = lhs.data_type();
144    if data_type.is_nested() {
145        let comparator = make_comparator(lhs, rhs, SortOptions::default())?;
146        let result = (0..number_rows)
147            .map(|i| {
148                Some(
149                    lhs.is_null(i)
150                        || rhs.is_null(i)
151                        || comparator(i, i) == Ordering::Equal,
152                )
153            })
154            .collect::<BooleanArray>();
155        Ok(result)
156    } else {
157        Ok(eq(&lhs, &rhs)?)
158    }
159}
160
/// Compares every row of `lhs` with the scalar `rhs` for equality.
///
/// The scalar is first expanded into an array of `number_rows` entries and then compared
/// with `lhs` via `invoke_eq_array`.
///
/// # Errors
///
/// Returns an error if the scalar cannot be expanded or if the comparison fails.
fn invoke_eq_array_scalar(
    number_rows: usize,
    lhs: &dyn Array,
    rhs: &ScalarValue,
) -> DFResult<BooleanArray> {
    let expanded = rhs.to_array_of_size(number_rows)?;
    invoke_eq_array(number_rows, lhs, expanded.as_ref())
}
169
170fn fill_nulls(bool_array: &BooleanArray, fill_value: bool) -> BooleanArray {
171    let mut builder = BooleanBuilder::with_capacity(bool_array.len());
172
173    for i in 0..bool_array.len() {
174        if bool_array.is_null(i) {
175            builder.append_value(fill_value);
176        } else {
177            builder.append_value(bool_array.value(i));
178        }
179    }
180
181    builder.finish()
182}