use std::any::Any;
use std::sync::Arc;
use datafusion::arrow::array::{BinaryArray, BooleanArray, StringArray};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::Result as DfResult;
use datafusion::logical_expr::{
ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use super::nav::{expand_to_array, navigate_json, navigate_rmpv, rmpv_to_string};
/// Scalar UDF `doc_array_contains(document, path, value)` returning a Boolean.
///
/// The declared signature accepts a `Binary` or `Utf8` document together with
/// a `Utf8` path and a `Utf8` value. `Binary` documents are decoded as
/// MessagePack and `Utf8` documents are parsed as JSON when the function is
/// invoked.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct DocArrayContains {
/// Accepted argument-type combinations; populated by `new`.
signature: Signature,
}
impl DocArrayContains {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Binary, DataType::Utf8, DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
],
Volatility::Immutable,
),
}
}
}
impl Default for DocArrayContains {
fn default() -> Self {
Self::new()
}
}
impl ScalarUDFImpl for DocArrayContains {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"doc_array_contains"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DfResult<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(
&self,
args: datafusion::logical_expr::ScalarFunctionArgs,
) -> DfResult<ColumnarValue> {
let num_rows = args.number_rows;
let docs = expand_to_array(&args.args[0], num_rows)?;
let paths = expand_to_array(&args.args[1], num_rows)?;
let values = expand_to_array(&args.args[2], num_rows)?;
let paths = paths
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
datafusion::error::DataFusionError::Internal(
"doc_array_contains: expected Utf8 for path".into(),
)
})?;
let values = values
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
datafusion::error::DataFusionError::Internal(
"doc_array_contains: expected Utf8 for value".into(),
)
})?;
if let Some(bin_docs) = docs.as_any().downcast_ref::<BinaryArray>() {
let result: BooleanArray = bin_docs
.iter()
.zip(paths.iter())
.zip(values.iter())
.map(|((doc, path), val)| Some(contains_msgpack(doc?, path?, val?)))
.collect();
Ok(ColumnarValue::Array(Arc::new(result)))
} else if let Some(str_docs) = docs.as_any().downcast_ref::<StringArray>() {
let result: BooleanArray = str_docs
.iter()
.zip(paths.iter())
.zip(values.iter())
.map(|((doc, path), val)| Some(contains_json(doc?, path?, val?)))
.collect();
Ok(ColumnarValue::Array(Arc::new(result)))
} else {
Err(datafusion::error::DataFusionError::Internal(
"doc_array_contains: expected Binary or Utf8 for document".into(),
))
}
}
}
/// Reports whether the array located at `path` inside a MessagePack document
/// contains `needle`.
///
/// Returns `false` when `data` cannot be decoded, when `navigate_rmpv` finds
/// nothing at `path`, or when the value found there is not an array.
///
/// Matching is done in two passes:
///
/// 1. If `needle` parses as a number, every element is compared numerically.
///    When both the element and `needle` are integers they are compared
///    exactly; otherwise they are compared as `f64` with an absolute tolerance
///    of `f64::EPSILON`.
/// 2. Every element is rendered with `rmpv_to_string` and compared to `needle`
///    as text.
fn contains_msgpack(data: &[u8], path: &str, needle: &str) -> bool {
    let Ok(value) = rmpv::decode::read_value(&mut &data[..]) else {
        return false;
    };
    let Some(target) = navigate_rmpv(&value, path) else {
        return false;
    };
    let Some(arr) = target.as_array() else {
        return false;
    };

    if let Ok(needle_num) = needle.parse::<f64>() {
        // Integer view of the needle, when it has one. i128 covers both the
        // i64 and u64 ranges an element can report.
        let needle_int = needle.parse::<i128>().ok();

        for item in arr {
            let item_int = item
                .as_i64()
                .map(i128::from)
                .or_else(|| item.as_u64().map(i128::from));

            let hit = match (item_int, needle_int) {
                // Integer vs integer: compare exactly. Going through f64 would
                // make distinct integers beyond 2^53 compare equal.
                (Some(have), Some(want)) => have == want,
                // NOTE(review): the absolute EPSILON tolerance also lets distinct
                // values near zero match (e.g. 0 and 1e-17) - confirm that is intended.
                _ => {
                    item.as_f64()
                        .is_some_and(|n| (n - needle_num).abs() < f64::EPSILON)
                        || item_int
                            .is_some_and(|i| (i as f64 - needle_num).abs() < f64::EPSILON)
                }
            };
            if hit {
                return true;
            }
        }
    }

    // Textual comparison for anything the numeric pass did not match.
    arr.iter().any(|item| rmpv_to_string(item) == needle)
}
/// Reports whether the array located at `path` inside a JSON document
/// contains `needle`.
///
/// Returns `false` when `json_str` is not valid JSON, when `navigate_json`
/// finds nothing at `path`, or when the value found there is not an array.
///
/// Matching is done in two passes:
///
/// 1. If `needle` parses as a number, every element is compared numerically.
///    When both the element and `needle` are integers they are compared
///    exactly; otherwise they are compared as `f64` with an absolute tolerance
///    of `f64::EPSILON`.
/// 2. Every element is compared as text: strings by content, booleans against
///    the literals `true` / `false`, and everything else (including strings
///    that did not match by content) by its JSON serialization.
fn contains_json(json_str: &str, path: &str, needle: &str) -> bool {
    let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) else {
        return false;
    };
    let Some(target) = navigate_json(&parsed, path) else {
        return false;
    };
    let Some(arr) = target.as_array() else {
        return false;
    };

    if let Ok(needle_num) = needle.parse::<f64>() {
        // Integer view of the needle, when it has one. i128 covers both the
        // i64 and u64 ranges an element can report.
        let needle_int = needle.parse::<i128>().ok();

        for item in arr {
            let item_int = item
                .as_i64()
                .map(i128::from)
                .or_else(|| item.as_u64().map(i128::from));

            let hit = match (item_int, needle_int) {
                // Integer vs integer: compare exactly. `as_f64` would round
                // integers beyond 2^53 and make distinct values compare equal.
                (Some(have), Some(want)) => have == want,
                // NOTE(review): the absolute EPSILON tolerance also lets distinct
                // values near zero match (e.g. 0 and 1e-17) - confirm that is intended.
                _ => item
                    .as_f64()
                    .is_some_and(|n| (n - needle_num).abs() < f64::EPSILON),
            };
            if hit {
                return true;
            }
        }
    }

    // Textual comparison for anything the numeric pass did not match.
    arr.iter().any(|item| match item {
        serde_json::Value::String(s) if s == needle => true,
        serde_json::Value::Bool(b) => (needle == "true" && *b) || (needle == "false" && !*b),
        other => other.to_string() == needle,
    })
}
// Unit tests for the MessagePack and JSON helpers, plus a macro-generated
// batch test for the UDF itself.
#[cfg(test)]
mod tests {
use super::super::nav::test_util::to_msgpack;
use super::*;
// String elements are found by exact text; absent strings are not.
#[test]
fn msgpack_contains_string() {
let json = serde_json::json!({"tags": ["admin", "user", "editor"]});
let data = to_msgpack(&json);
assert!(contains_msgpack(&data, "$.tags", "admin"));
assert!(!contains_msgpack(&data, "$.tags", "guest"));
}
// A numeric needle given as text matches an integer element.
#[test]
fn msgpack_contains_number() {
let json = serde_json::json!({"scores": [10, 20, 30]});
let data = to_msgpack(&json);
assert!(contains_msgpack(&data, "$.scores", "20"));
assert!(!contains_msgpack(&data, "$.scores", "25"));
}
// The path may descend through nested objects before reaching the array.
#[test]
fn msgpack_contains_nested() {
let json = serde_json::json!({"user": {"roles": ["admin"]}});
let data = to_msgpack(&json);
assert!(contains_msgpack(&data, "$.user.roles", "admin"));
}
// A scalar at the path is not treated as a one-element array.
#[test]
fn msgpack_not_array_returns_false() {
let json = serde_json::json!({"name": "alice"});
let data = to_msgpack(&json);
assert!(!contains_msgpack(&data, "$.name", "alice"));
}
// A path that resolves to nothing yields false rather than an error.
#[test]
fn msgpack_missing_path_returns_false() {
let json = serde_json::json!({"a": 1});
let data = to_msgpack(&json);
assert!(!contains_msgpack(&data, "$.missing", "1"));
}
// Presumably expands to a test that feeds undecodable bytes to
// `contains_msgpack` - see `nav::test_util` to confirm.
super::super::nav::test_util::assert_invalid_msgpack!(contains_msgpack, "$.tags", "x");
// The JSON helper finds and rejects string elements like the MessagePack one.
#[test]
fn json_compat_contains() {
assert!(contains_json(
r#"{"tags": ["admin", "user"]}"#,
"$.tags",
"admin"
));
assert!(!contains_json(r#"{"tags": ["admin"]}"#, "$.tags", "guest"));
}
// JSON booleans match the textual needle "true".
#[test]
fn json_contains_boolean() {
assert!(contains_json(
r#"{"flags": [true, false]}"#,
"$.flags",
"true"
));
}
// Presumably expands to a batch test that runs the UDF over the two documents
// with path "$.tags" and value "a" - see `nav::test_util` to confirm.
// The closure expects row 0 to be true and row 1 to be false.
super::super::nav::test_util::generate_udf_batch_test_3arg!(
DocArrayContains,
DataType::Boolean,
BooleanArray,
serde_json::json!({"tags": ["a", "b"]}),
serde_json::json!({"tags": ["c"]}),
"$.tags",
"a",
|arr| {
assert!(arr.value(0));
assert!(!arr.value(1));
}
);
}