use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use arrow_array::{
Array, BooleanArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int32Array, Int64Array,
LargeBinaryArray, LargeStringArray, StringArray, UInt64Array,
};
use chrono::Offset;
use datafusion::error::Result as DFResult;
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
use datafusion::prelude::SessionContext;
use datafusion::scalar::ScalarValue;
use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use uni_common::Value;
use uni_cypher::ast::BinaryOp;
use uni_store::storage::arrow_convert::values_to_array;
use super::expr_eval::cypher_eq;
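/// Implements `PartialEq`/`Eq`/`Hash` for a UDF type so DataFusion can
/// compare and deduplicate registered functions: equality compares
/// signatures, hashing uses the function name.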
macro_rules! impl_udf_eq_hash {
($type:ty) => {
impl PartialEq for $type {
fn eq(&self, other: &Self) -> bool {
self.signature == other.signature
}
}
impl Eq for $type {}
impl Hash for $type {
fn hash<H: Hasher>(&self, state: &mut H) {
self.name().hash(state);
}
}
};
}
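/// Registers all built-in Cypher scalar UDFs and aggregate UDAFs on a
/// `SessionContext`. Call once per context before executing translated
/// Cypher plans.
///
/// A minimal usage sketch (assuming a default context):
/// ```ignore
/// let ctx = SessionContext::new();
/// register_cypher_udfs(&ctx)?;
/// ```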
pub fn register_cypher_udfs(ctx: &SessionContext) -> DFResult<()> {
ctx.register_udf(create_id_udf());
ctx.register_udf(create_type_udf());
ctx.register_udf(create_keys_udf());
ctx.register_udf(create_properties_udf());
ctx.register_udf(create_labels_udf());
ctx.register_udf(create_nodes_udf());
ctx.register_udf(create_relationships_udf());
ctx.register_udf(create_range_udf());
ctx.register_udf(create_index_udf());
ctx.register_udf(create_startnode_udf());
ctx.register_udf(create_endnode_udf());
ctx.register_udf(create_to_integer_udf());
ctx.register_udf(create_to_float_udf());
ctx.register_udf(create_to_boolean_udf());
ctx.register_udf(create_bitwise_or_udf());
ctx.register_udf(create_bitwise_and_udf());
ctx.register_udf(create_bitwise_xor_udf());
ctx.register_udf(create_bitwise_not_udf());
ctx.register_udf(create_shift_left_udf());
ctx.register_udf(create_shift_right_udf());
for name in &[
"date",
"time",
"localtime",
"localdatetime",
"datetime",
"duration",
"btic",
"duration.between",
"duration.inmonths",
"duration.indays",
"duration.inseconds",
"datetime.fromepoch",
"datetime.fromepochmillis",
"date.truncate",
"time.truncate",
"datetime.truncate",
"localdatetime.truncate",
"localtime.truncate",
"datetime.transaction",
"datetime.statement",
"datetime.realtime",
"date.transaction",
"date.statement",
"date.realtime",
"time.transaction",
"time.statement",
"time.realtime",
"localtime.transaction",
"localtime.statement",
"localtime.realtime",
"localdatetime.transaction",
"localdatetime.statement",
"localdatetime.realtime",
] {
ctx.register_udf(create_temporal_udf(name));
}
ctx.register_udf(create_duration_property_udf());
ctx.register_udf(create_temporal_property_udf());
ctx.register_udf(create_tostring_udf());
ctx.register_udf(create_cypher_sort_key_udf());
ctx.register_udf(create_has_null_udf());
ctx.register_udf(create_cypher_size_udf());
ctx.register_udf(create_cypher_starts_with_udf());
ctx.register_udf(create_cypher_ends_with_udf());
ctx.register_udf(create_cypher_contains_udf());
ctx.register_udf(create_cypher_list_compare_udf());
ctx.register_udf(create_cypher_xor_udf());
ctx.register_udf(create_cypher_equal_udf());
ctx.register_udf(create_cypher_not_equal_udf());
ctx.register_udf(create_cypher_gt_udf());
ctx.register_udf(create_cypher_gt_eq_udf());
ctx.register_udf(create_cypher_lt_udf());
ctx.register_udf(create_cypher_lt_eq_udf());
ctx.register_udf(create_cv_to_bool_udf());
ctx.register_udf(create_cypher_add_udf());
ctx.register_udf(create_cypher_sub_udf());
ctx.register_udf(create_cypher_mul_udf());
ctx.register_udf(create_cypher_div_udf());
ctx.register_udf(create_cypher_mod_udf());
ctx.register_udf(create_map_project_udf());
ctx.register_udf(create_make_cypher_list_udf());
ctx.register_udf(create_cypher_in_udf());
ctx.register_udf(create_cypher_list_concat_udf());
ctx.register_udf(create_cypher_list_append_udf());
ctx.register_udf(create_cypher_list_slice_udf());
ctx.register_udf(create_cypher_tail_udf());
ctx.register_udf(create_cypher_head_udf());
ctx.register_udf(create_cypher_last_udf());
ctx.register_udf(create_cypher_reverse_udf());
ctx.register_udf(create_cypher_substring_udf());
ctx.register_udf(create_cypher_split_udf());
ctx.register_udf(create_cypher_list_to_cv_udf());
ctx.register_udf(create_cypher_scalar_to_cv_udf());
for name in &["year", "month", "day", "hour", "minute", "second"] {
ctx.register_udf(create_temporal_udf(name));
}
ctx.register_udf(create_cypher_to_float64_udf());
ctx.register_udf(create_similar_to_udf());
ctx.register_udf(create_vector_similarity_udf());
ctx.register_udaf(create_cypher_min_udaf());
ctx.register_udaf(create_cypher_max_udaf());
ctx.register_udaf(create_cypher_sum_udaf());
ctx.register_udaf(create_cypher_collect_udaf());
ctx.register_udaf(create_cypher_percentile_disc_udaf());
ctx.register_udaf(create_cypher_percentile_cont_udaf());
register_btic_scalar_udfs(ctx)?;
ctx.register_udaf(create_btic_min_udaf());
ctx.register_udaf(create_btic_max_udaf());
ctx.register_udaf(create_btic_span_agg_udaf());
ctx.register_udaf(create_btic_count_at_udaf());
Ok(())
}
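/// Registers each user-supplied scalar function twice, under its original
/// name and its lowercased name, so case-insensitive Cypher function lookup
/// resolves either spelling to the same implementation.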
pub fn register_custom_udfs(
ctx: &SessionContext,
registry: &super::executor::custom_functions::CustomFunctionRegistry,
) -> DFResult<()> {
for (name, func) in registry.iter() {
let lower = name.to_lowercase();
ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
lower,
func.clone(),
)));
ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
name.to_string(),
func.clone(),
)));
}
Ok(())
}
struct CustomScalarUdf {
name: String,
func: super::executor::custom_functions::CustomScalarFn,
signature: Signature,
}
impl CustomScalarUdf {
fn new(name: String, func: super::executor::custom_functions::CustomScalarFn) -> Self {
Self {
signature: Signature::new(TypeSignature::VariadicAny, Volatility::Volatile),
name,
func,
}
}
}
impl std::fmt::Debug for CustomScalarUdf {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CustomScalarUdf")
.field("name", &self.name)
.finish()
}
}
impl_udf_eq_hash!(CustomScalarUdf);
impl ScalarUDFImpl for CustomScalarUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let func = &self.func;
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
func(vals).map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
})
}
}
pub fn create_id_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(IdUdf::new())
}
#[derive(Debug)]
struct IdUdf {
signature: Signature,
}
impl IdUdf {
fn new() -> Self {
Self {
signature: Signature::new(
TypeSignature::Exact(vec![DataType::UInt64]),
Volatility::Immutable,
),
}
}
}
impl_udf_eq_hash!(IdUdf);
impl ScalarUDFImpl for IdUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"id"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::UInt64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"id(): requires 1 argument".to_string(),
));
}
Ok(args.args[0].clone())
}
}
pub fn create_type_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(TypeUdf::new())
}
#[derive(Debug)]
struct TypeUdf {
signature: Signature,
}
impl TypeUdf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(TypeUdf);
impl ScalarUDFImpl for TypeUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"type"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Utf8)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"type(): requires 1 argument".to_string(),
));
}
let output_type = DataType::Utf8;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"type(): requires 1 argument".to_string(),
));
}
let val = &val_args[0];
match val {
Value::Map(map) => {
if let Some(Value::String(t)) = map.get("_type") {
Ok(Value::String(t.clone()))
} else {
Err(datafusion::error::DataFusionError::Execution(
"TypeError: InvalidArgumentValue - type() requires a relationship argument".to_string(),
))
}
}
Value::Null => Ok(Value::Null),
_ => Err(datafusion::error::DataFusionError::Execution(
"TypeError: InvalidArgumentValue - type() requires a relationship argument"
.to_string(),
)),
}
})
}
}
pub fn create_keys_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(KeysUdf::new())
}
#[derive(Debug)]
struct KeysUdf {
signature: Signature,
}
impl KeysUdf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(KeysUdf);
impl ScalarUDFImpl for KeysUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"keys"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::List(Arc::new(
arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
)))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"keys(): requires 1 argument".to_string(),
));
}
let arg = &val_args[0];
let keys = match arg {
Value::Map(map) => {
let (source, is_entity) = match map.get("_all_props") {
Some(Value::Map(all)) => (all, true),
_ => (map, false),
};
let mut key_strings: Vec<String> = source
.iter()
.filter(|(k, v)| !k.starts_with('_') && (!is_entity || !v.is_null()))
.map(|(k, _)| k.clone())
.collect();
key_strings.sort();
key_strings
.into_iter()
.map(Value::String)
.collect::<Vec<_>>()
}
Value::Null => {
return Ok(Value::Null);
}
_ => {
vec![]
}
};
Ok(Value::List(keys))
})
}
}
pub fn create_properties_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(PropertiesUdf::new())
}
#[derive(Debug)]
struct PropertiesUdf {
signature: Signature,
}
impl PropertiesUdf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(PropertiesUdf);
impl ScalarUDFImpl for PropertiesUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"properties"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"properties(): requires 1 argument".to_string(),
));
}
let arg = &val_args[0];
match arg {
Value::Map(map) => {
let identity_null = map
.get("_vid")
.map(|v| v.is_null())
.or_else(|| map.get("_eid").map(|v| v.is_null()))
.unwrap_or(false);
if identity_null {
return Ok(Value::Null);
}
let source = match map.get("_all_props") {
Some(Value::Map(all)) => all,
_ => map,
};
let filtered: std::collections::HashMap<String, Value> = source
.iter()
.filter(|(k, _)| !k.starts_with('_'))
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
Ok(Value::Map(filtered))
}
_ => Ok(Value::Null),
}
})
}
}
pub fn create_index_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(IndexUdf::new())
}
#[derive(Debug)]
struct IndexUdf {
signature: Signature,
}
impl IndexUdf {
fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(IndexUdf);
impl ScalarUDFImpl for IndexUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"index"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"index(): requires 2 arguments".to_string(),
));
}
let container = &val_args[0];
let index = &val_args[1];
let index_as_int = index.as_i64();
let result = match container {
Value::List(arr) => {
if let Some(i) = index_as_int {
let idx = if i < 0 {
let pos = arr.len() as i64 + i;
if pos < 0 { -1 } else { pos }
} else {
i
};
if idx >= 0 && (idx as usize) < arr.len() {
arr[idx as usize].clone()
} else {
Value::Null
}
} else if index.is_null() {
Value::Null
} else {
return Err(datafusion::error::DataFusionError::Execution(format!(
"TypeError: InvalidArgumentType - list index must be an integer, got: {:?}",
index
)));
}
}
Value::Map(map) => {
if let Some(key) = index.as_str() {
if let Some(val) = map.get(key) {
val.clone()
} else if let Some(Value::Map(all_props)) = map.get("_all_props") {
all_props.get(key).cloned().unwrap_or(Value::Null)
} else if let Some(Value::Map(props)) = map.get("properties") {
props.get(key).cloned().unwrap_or(Value::Null)
} else {
Value::Null
}
} else if !index.is_null() {
return Err(datafusion::error::DataFusionError::Execution(
"index(): map index must be a string".to_string(),
));
} else {
Value::Null
}
}
Value::Node(node) => {
if let Some(key) = index.as_str() {
node.properties.get(key).cloned().unwrap_or(Value::Null)
} else if !index.is_null() {
return Err(datafusion::error::DataFusionError::Execution(
"index(): node index must be a string".to_string(),
));
} else {
Value::Null
}
}
Value::Edge(edge) => {
if let Some(key) = index.as_str() {
edge.properties.get(key).cloned().unwrap_or(Value::Null)
} else if !index.is_null() {
return Err(datafusion::error::DataFusionError::Execution(
"index(): edge index must be a string".to_string(),
));
} else {
Value::Null
}
}
Value::Null => Value::Null,
_ => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"TypeError: InvalidArgumentType - cannot index into {:?}",
container
)));
}
};
Ok(result)
})
}
}
pub fn create_labels_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(LabelsUdf::new())
}
#[derive(Debug)]
struct LabelsUdf {
signature: Signature,
}
impl LabelsUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(LabelsUdf);
impl ScalarUDFImpl for LabelsUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"labels"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::List(Arc::new(
arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
)))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"labels(): requires 1 argument".to_string(),
));
}
let node = &val_args[0];
match node {
Value::Map(map) => {
if let Some(Value::List(arr)) = map.get("_labels") {
Ok(Value::List(arr.clone()))
} else {
Err(datafusion::error::DataFusionError::Execution(
"TypeError: InvalidArgumentValue - labels() requires a node argument"
.to_string(),
))
}
}
Value::Null => Ok(Value::Null),
_ => Err(datafusion::error::DataFusionError::Execution(
"TypeError: InvalidArgumentValue - labels() requires a node argument"
.to_string(),
)),
}
})
}
}
pub fn create_nodes_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(NodesUdf::new())
}
#[derive(Debug)]
struct NodesUdf {
signature: Signature,
}
impl NodesUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(NodesUdf);
impl ScalarUDFImpl for NodesUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"nodes"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"nodes(): requires 1 argument".to_string(),
));
}
let path = &val_args[0];
let nodes = match path {
Value::Map(map) => map.get("nodes").cloned().unwrap_or(Value::Null),
_ => Value::Null,
};
Ok(nodes)
})
}
}
pub fn create_relationships_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(RelationshipsUdf::new())
}
#[derive(Debug)]
struct RelationshipsUdf {
signature: Signature,
}
impl RelationshipsUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(RelationshipsUdf);
impl ScalarUDFImpl for RelationshipsUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"relationships"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"relationships(): requires 1 argument".to_string(),
));
}
let path = &val_args[0];
let rels = match path {
Value::Map(map) => map.get("relationships").cloned().unwrap_or(Value::Null),
_ => Value::Null,
};
Ok(rels)
})
}
}
pub fn create_startnode_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(StartNodeUdf::new())
}
#[derive(Debug)]
struct StartNodeUdf {
signature: Signature,
}
impl StartNodeUdf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(StartNodeUdf);
impl ScalarUDFImpl for StartNodeUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"startnode"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = DataType::LargeBinary;
invoke_cypher_udf(args, &output_type, |val_args| {
startnode_endnode_impl(val_args, true)
})
}
}
pub fn create_endnode_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(EndNodeUdf::new())
}
#[derive(Debug)]
struct EndNodeUdf {
signature: Signature,
}
impl EndNodeUdf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(EndNodeUdf);
impl ScalarUDFImpl for EndNodeUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"endnode"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = DataType::LargeBinary;
invoke_cypher_udf(args, &output_type, |val_args| {
startnode_endnode_impl(val_args, false)
})
}
}
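// Shared body of startNode()/endNode(). The first argument is the
// relationship; any further arguments are candidate node values. A candidate
// whose `_vid` matches the requested endpoint is returned whole (keeping its
// properties); otherwise a minimal `{_vid}` map is synthesized.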
fn startnode_endnode_impl(val_args: &[Value], is_start: bool) -> DFResult<Value> {
if val_args.is_empty() {
let fn_name = if is_start { "startNode" } else { "endNode" };
return Err(datafusion::error::DataFusionError::Execution(format!(
"{fn_name}(): requires at least 1 argument"
)));
}
let edge_val = &val_args[0];
let target_vid = extract_endpoint_vid(edge_val, is_start);
let target_vid = match target_vid {
Some(vid) => vid,
None => return Ok(Value::Null),
};
for node_val in val_args.iter().skip(1) {
if let Some(vid) = extract_vid(node_val)
&& vid == target_vid
{
return Ok(node_val.clone());
}
}
let mut map = std::collections::HashMap::new();
map.insert("_vid".to_string(), Value::Int(target_vid as i64));
Ok(Value::Map(map))
}
fn extract_endpoint_vid(val: &Value, is_start: bool) -> Option<u64> {
match val {
Value::Edge(edge) => {
let vid = if is_start { edge.src } else { edge.dst };
Some(vid.as_u64())
}
Value::Map(map) => {
let key = if is_start { "_src_vid" } else { "_dst_vid" };
if let Some(v) = map.get(key) {
return v.as_u64();
}
let key2 = if is_start { "_src" } else { "_dst" };
if let Some(v) = map.get(key2) {
return v.as_u64();
}
let node_key = if is_start { "_startNode" } else { "_endNode" };
if let Some(node_val) = map.get(node_key) {
return extract_vid(node_val);
}
None
}
_ => None,
}
}
fn extract_vid(val: &Value) -> Option<u64> {
match val {
Value::Map(map) => map.get("_vid").and_then(|v| v.as_u64()),
_ => None,
}
}
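// Pulls one integer argument for range() out of a `ColumnarValue`, accepting
// every Arrow integer width plus codec-encoded `LargeBinary`, and raising a
// Cypher-style type error for anything else.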
fn extract_i64_range_arg(arg: &ColumnarValue, row_idx: usize, name: &str) -> DFResult<i64> {
match arg {
ColumnarValue::Scalar(sv) => match sv {
ScalarValue::Int8(Some(v)) => Ok(*v as i64),
ScalarValue::Int16(Some(v)) => Ok(*v as i64),
ScalarValue::Int32(Some(v)) => Ok(*v as i64),
ScalarValue::Int64(Some(v)) => Ok(*v),
ScalarValue::UInt8(Some(v)) => Ok(*v as i64),
ScalarValue::UInt16(Some(v)) => Ok(*v as i64),
ScalarValue::UInt32(Some(v)) => Ok(*v as i64),
ScalarValue::UInt64(Some(v)) => Ok(*v as i64),
ScalarValue::LargeBinary(Some(bytes)) => {
scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
datafusion::error::DataFusionError::Execution(format!(
"ArgumentError: InvalidArgumentType - range() {} must be an integer",
name
))
})
}
_ => Err(datafusion::error::DataFusionError::Execution(format!(
"ArgumentError: InvalidArgumentType - range() {} must be an integer",
name
))),
},
ColumnarValue::Array(arr) => {
if row_idx >= arr.len() || arr.is_null(row_idx) {
return Err(datafusion::error::DataFusionError::Execution(format!(
"ArgumentError: InvalidArgumentType - range() {} must be an integer",
name
)));
}
            // `row_idx` was bounds-checked above, so the array is non-empty;
            // the previous `!arr.is_empty()` guard and its else branch were
            // unreachable.
            use datafusion::arrow::array::{
                Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array, UInt16Array,
                UInt32Array, UInt64Array,
            };
            match arr.data_type() {
                DataType::Int8 => Ok(arr
                    .as_any()
                    .downcast_ref::<Int8Array>()
                    .unwrap()
                    .value(row_idx) as i64),
                DataType::Int16 => Ok(arr
                    .as_any()
                    .downcast_ref::<Int16Array>()
                    .unwrap()
                    .value(row_idx) as i64),
                DataType::Int32 => Ok(arr
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap()
                    .value(row_idx) as i64),
                DataType::Int64 => Ok(arr
                    .as_any()
                    .downcast_ref::<Int64Array>()
                    .unwrap()
                    .value(row_idx)),
                DataType::UInt8 => Ok(arr
                    .as_any()
                    .downcast_ref::<UInt8Array>()
                    .unwrap()
                    .value(row_idx) as i64),
                DataType::UInt16 => Ok(arr
                    .as_any()
                    .downcast_ref::<UInt16Array>()
                    .unwrap()
                    .value(row_idx) as i64),
                DataType::UInt32 => Ok(arr
                    .as_any()
                    .downcast_ref::<UInt32Array>()
                    .unwrap()
                    .value(row_idx) as i64),
                DataType::UInt64 => Ok(arr
                    .as_any()
                    .downcast_ref::<UInt64Array>()
                    .unwrap()
                    .value(row_idx) as i64),
                DataType::LargeBinary => {
                    let bytes = arr
                        .as_any()
                        .downcast_ref::<LargeBinaryArray>()
                        .unwrap()
                        .value(row_idx);
                    scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
                        datafusion::error::DataFusionError::Execution(format!(
                            "ArgumentError: InvalidArgumentType - range() {} must be an integer",
                            name
                        ))
                    })
                }
                _ => Err(datafusion::error::DataFusionError::Execution(format!(
                    "ArgumentError: InvalidArgumentType - range() {} must be an integer",
                    name
                ))),
            }
}
}
}
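/// Cypher `range(start, end[, step])`: builds the inclusive integer sequence
/// from `start` to `end`, so `range(0, 6, 2)` yields `[0, 2, 4, 6]`. A
/// negative step counts down; a zero step is an execution error.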
pub fn create_range_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(RangeUdf::new())
}
#[derive(Debug)]
struct RangeUdf {
signature: Signature,
}
impl RangeUdf {
fn new() -> Self {
Self {
signature: Signature::one_of(
vec![TypeSignature::Any(2), TypeSignature::Any(3)],
Volatility::Immutable,
),
}
}
}
impl_udf_eq_hash!(RangeUdf);
impl ScalarUDFImpl for RangeUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"range"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::List(Arc::new(
arrow::datatypes::Field::new_list_field(DataType::Int64, true),
)))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.len() < 2 || args.args.len() > 3 {
return Err(datafusion::error::DataFusionError::Execution(
"range(): requires 2 or 3 arguments".to_string(),
));
}
let len = args
.args
.iter()
.find_map(|arg| match arg {
ColumnarValue::Array(arr) => Some(arr.len()),
_ => None,
})
.unwrap_or(1);
let mut list_builder =
arrow_array::builder::ListBuilder::new(arrow_array::builder::Int64Builder::new());
for row_idx in 0..len {
let start = extract_i64_range_arg(&args.args[0], row_idx, "start")?;
let end = extract_i64_range_arg(&args.args[1], row_idx, "end")?;
let step = if args.args.len() == 3 {
extract_i64_range_arg(&args.args[2], row_idx, "step")?
} else {
1
};
if step == 0 {
return Err(datafusion::error::DataFusionError::Execution(
"range(): step cannot be zero".to_string(),
));
}
if step > 0 && start <= end {
let mut current = start;
while current <= end {
list_builder.values().append_value(current);
current += step;
}
} else if step < 0 && start >= end {
let mut current = start;
while current >= end {
list_builder.values().append_value(current);
current += step;
}
}
list_builder.append(true);
}
let list_arr = Arc::new(list_builder.finish()) as ArrayRef;
if len == 1
&& args
.args
.iter()
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
{
Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
&list_arr, 0,
)?))
} else {
Ok(ColumnarValue::Array(list_arr))
}
}
}
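// Applies a binary i64 operator to two Int64 scalars or elementwise to two
// Int64 arrays (nulls propagate); mixed scalar/array inputs are rejected.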
fn invoke_binary_bitwise_op<F>(
args: &ScalarFunctionArgs,
name: &str,
op: F,
) -> DFResult<ColumnarValue>
where
F: Fn(i64, i64) -> i64,
{
use arrow_array::Int64Array;
use datafusion::common::ScalarValue;
use datafusion::error::DataFusionError;
if args.args.len() != 2 {
return Err(DataFusionError::Execution(format!(
"{}(): requires exactly 2 arguments",
name
)));
}
let left = &args.args[0];
let right = &args.args[1];
match (left, right) {
(
ColumnarValue::Scalar(ScalarValue::Int64(Some(l))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(r))),
) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*l, *r))))),
(ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
let l_arr = l_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
DataFusionError::Execution(format!("{}(): left array must be Int64", name))
})?;
let r_arr = r_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
DataFusionError::Execution(format!("{}(): right array must be Int64", name))
})?;
let result: Int64Array = l_arr
.iter()
.zip(r_arr.iter())
.map(|(l, r)| match (l, r) {
(Some(l), Some(r)) => Some(op(l, r)),
_ => None,
})
.collect();
Ok(ColumnarValue::Array(Arc::new(result)))
}
_ => Err(DataFusionError::Execution(format!(
"{}(): mixed scalar/array not supported",
name
))),
}
}
fn invoke_unary_bitwise_op<F>(
args: &ScalarFunctionArgs,
name: &str,
op: F,
) -> DFResult<ColumnarValue>
where
F: Fn(i64) -> i64,
{
use arrow_array::Int64Array;
use datafusion::common::ScalarValue;
use datafusion::error::DataFusionError;
if args.args.len() != 1 {
return Err(DataFusionError::Execution(format!(
"{}(): requires exactly 1 argument",
name
)));
}
let operand = &args.args[0];
match operand {
ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => {
Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*v)))))
}
ColumnarValue::Array(arr) => {
let arr = arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
DataFusionError::Execution(format!("{}(): array must be Int64", name))
})?;
let result: Int64Array = arr.iter().map(|v| v.map(&op)).collect();
Ok(ColumnarValue::Array(Arc::new(result)))
}
_ => Err(DataFusionError::Execution(format!(
"{}(): invalid argument type",
name
))),
}
}
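// Generates the boilerplate for a binary Int64 bitwise UDF: the struct, its
// exact (Int64, Int64) -> Int64 signature, and the dispatch into
// `invoke_binary_bitwise_op`.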
macro_rules! define_binary_bitwise_udf {
($struct_name:ident, $udf_name:literal, $op:expr) => {
#[derive(Debug)]
struct $struct_name {
signature: Signature,
}
impl $struct_name {
fn new() -> Self {
Self {
signature: Signature::exact(
vec![DataType::Int64, DataType::Int64],
Volatility::Immutable,
),
}
}
}
impl_udf_eq_hash!($struct_name);
impl ScalarUDFImpl for $struct_name {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
$udf_name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Int64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_binary_bitwise_op(&args, $udf_name, $op)
}
}
};
}
macro_rules! define_unary_bitwise_udf {
($struct_name:ident, $udf_name:literal, $op:expr) => {
#[derive(Debug)]
struct $struct_name {
signature: Signature,
}
impl $struct_name {
fn new() -> Self {
Self {
signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
}
}
}
impl_udf_eq_hash!($struct_name);
impl ScalarUDFImpl for $struct_name {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
$udf_name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Int64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_unary_bitwise_op(&args, $udf_name, $op)
}
}
};
}
define_binary_bitwise_udf!(BitwiseOrUdf, "uni.bitwise.or", |l, r| l | r);
define_binary_bitwise_udf!(BitwiseAndUdf, "uni.bitwise.and", |l, r| l & r);
define_binary_bitwise_udf!(BitwiseXorUdf, "uni.bitwise.xor", |l, r| l ^ r);
define_binary_bitwise_udf!(ShiftLeftUdf, "uni.bitwise.shiftLeft", |l, r| l << r);
define_binary_bitwise_udf!(ShiftRightUdf, "uni.bitwise.shiftRight", |l, r| l >> r);
define_unary_bitwise_udf!(BitwiseNotUdf, "uni.bitwise.not", |v| !v);
pub fn create_bitwise_or_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(BitwiseOrUdf::new())
}
pub fn create_bitwise_and_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(BitwiseAndUdf::new())
}
pub fn create_bitwise_xor_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(BitwiseXorUdf::new())
}
pub fn create_bitwise_not_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(BitwiseNotUdf::new())
}
pub fn create_shift_left_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(ShiftLeftUdf::new())
}
pub fn create_shift_right_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(ShiftRightUdf::new())
}
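// Every temporal constructor (`date`, `datetime`, `duration.between`, ...)
// shares this one implementation, which dispatches on the registered name to
// `crate::query::datetime::eval_datetime_function`.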
fn create_temporal_udf(name: &str) -> ScalarUDF {
ScalarUDF::new_from_impl(TemporalUdf::new(name.to_string()))
}
#[derive(Debug)]
struct TemporalUdf {
name: String,
signature: Signature,
}
impl TemporalUdf {
fn new(name: String) -> Self {
Self {
name,
signature: Signature::new(
TypeSignature::OneOf(vec![
TypeSignature::Exact(vec![]),
TypeSignature::VariadicAny,
]),
Volatility::Immutable,
),
}
}
}
impl_udf_eq_hash!(TemporalUdf);
impl ScalarUDFImpl for TemporalUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
let name = self.name.to_lowercase();
match name.as_str() {
"year" | "month" | "day" | "hour" | "minute" | "second" => Ok(DataType::Int64),
"datetime"
| "localdatetime"
| "date"
| "time"
| "localtime"
| "duration"
| "date.truncate"
| "time.truncate"
| "datetime.truncate"
| "localdatetime.truncate"
| "localtime.truncate"
| "duration.between"
| "duration.inmonths"
| "duration.indays"
| "duration.inseconds"
| "datetime.fromepoch"
| "datetime.fromepochmillis"
| "datetime.transaction"
| "datetime.statement"
| "datetime.realtime"
| "date.transaction"
| "date.statement"
| "date.realtime"
| "time.transaction"
| "time.statement"
| "time.realtime"
| "localtime.transaction"
| "localtime.statement"
| "localtime.realtime"
| "localdatetime.transaction"
| "localdatetime.statement"
| "localdatetime.realtime"
| "btic" => Ok(DataType::LargeBinary),
_ => Ok(DataType::Utf8),
}
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let func_name = self.name.to_uppercase();
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
crate::query::datetime::eval_datetime_function(&func_name, val_args).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!("{}(): {}", self.name, e))
})
})
}
}
fn create_duration_property_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(DurationPropertyUdf::new())
}
#[derive(Debug)]
struct DurationPropertyUdf {
signature: Signature,
}
impl DurationPropertyUdf {
fn new() -> Self {
Self {
signature: Signature::new(
TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
Volatility::Immutable,
),
}
}
}
impl_udf_eq_hash!(DurationPropertyUdf);
impl ScalarUDFImpl for DurationPropertyUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_duration_property"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Int64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"_duration_property(): requires 2 arguments (duration_string, component)"
.to_string(),
));
}
let dur_string_owned;
let dur_str = match &val_args[0] {
Value::String(s) => s.as_str(),
Value::Temporal(uni_common::TemporalValue::Duration { .. }) => {
dur_string_owned = val_args[0].to_string();
&dur_string_owned
}
Value::Null => return Ok(Value::Null),
_ => {
return Err(datafusion::error::DataFusionError::Execution(
"_duration_property(): duration must be a string or temporal duration"
.to_string(),
));
}
};
let component = match &val_args[1] {
Value::String(s) => s,
_ => {
return Err(datafusion::error::DataFusionError::Execution(
"_duration_property(): component must be a string".to_string(),
));
}
};
crate::query::datetime::eval_duration_accessor(dur_str, component).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"_duration_property(): {}",
e
))
})
})
}
}
fn create_tostring_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(ToStringUdf::new())
}
#[derive(Debug)]
struct ToStringUdf {
signature: Signature,
}
impl ToStringUdf {
fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(ToStringUdf);
impl ScalarUDFImpl for ToStringUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"tostring"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Utf8)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"toString(): requires 1 argument".to_string(),
));
}
match &val_args[0] {
Value::Null => Ok(Value::Null),
Value::String(s) => Ok(Value::String(s.clone())),
Value::Int(i) => Ok(Value::String(i.to_string())),
Value::Float(f) => Ok(Value::String(f.to_string())),
Value::Bool(b) => Ok(Value::String(b.to_string())),
Value::Temporal(t) => Ok(Value::String(t.to_string())),
other => {
let type_name = match other {
Value::List(_) => "List",
Value::Map(_) => "Map",
                        Value::Node(_) => "Node",
                        Value::Edge(_) => "Relationship",
                        Value::Path(_) => "Path",
_ => "Unknown",
};
Err(datafusion::error::DataFusionError::Execution(format!(
"TypeError: InvalidArgumentValue - toString() does not accept {} values",
type_name
)))
}
}
})
}
}
fn create_temporal_property_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(TemporalPropertyUdf::new())
}
#[derive(Debug)]
struct TemporalPropertyUdf {
signature: Signature,
}
impl TemporalPropertyUdf {
fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(TemporalPropertyUdf);
impl ScalarUDFImpl for TemporalPropertyUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_temporal_property"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"_temporal_property(): requires 2 arguments (temporal_value, component)"
.to_string(),
));
}
let component = match &val_args[1] {
Value::String(s) => s.clone(),
_ => {
return Err(datafusion::error::DataFusionError::Execution(
"_temporal_property(): component must be a string".to_string(),
));
}
};
crate::query::datetime::eval_temporal_accessor_value(&val_args[0], &component).map_err(
|e| {
datafusion::error::DataFusionError::Execution(format!(
"_temporal_property(): {}",
e
))
},
)
})
}
}
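// Downcasts an Arrow array to a concrete array type, converting a failed
// downcast into an execution error.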
macro_rules! downcast_arr {
($arr:expr, $array_type:ty) => {
$arr.as_any().downcast_ref::<$array_type>().ok_or_else(|| {
datafusion::error::DataFusionError::Execution(format!(
"Failed to downcast to {}",
stringify!($array_type)
))
})?
};
}
fn cypher_type_name(val: &Value) -> &'static str {
match val {
Value::Null => "Null",
Value::Bool(_) => "Boolean",
Value::Int(_) => "Integer",
Value::Float(_) => "Float",
Value::String(_) => "String",
Value::Bytes(_) => "Bytes",
Value::List(_) => "List",
Value::Map(_) => "Map",
Value::Node(_) => "Node",
Value::Edge(_) => "Relationship",
Value::Path(_) => "Path",
Value::Vector(_) => "Vector",
Value::Temporal(_) => "Temporal",
_ => "Unknown",
}
}
fn string_to_value(s: &str) -> Value {
if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
&& let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
{
return Value::from(obj);
}
Value::String(s.to_string())
}
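/// Decodes a single row of an Arrow array into a Cypher `Value`.
/// `LargeBinary` columns are tried against the binary value codec first,
/// then JSON; types without a fast path fall back through `ScalarValue`.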
fn get_value_from_array(arr: &ArrayRef, row: usize) -> DFResult<Value> {
if arr.is_null(row) {
return Ok(Value::Null);
}
match arr.data_type() {
DataType::LargeBinary => {
let typed = downcast_arr!(arr, LargeBinaryArray);
let bytes = typed.value(row);
if let Ok(val) = uni_common::cypher_value_codec::decode(bytes) {
return Ok(val);
}
Ok(serde_json::from_slice::<serde_json::Value>(bytes)
.map(Value::from)
.unwrap_or(Value::Null))
}
DataType::Int64 => Ok(Value::Int(downcast_arr!(arr, Int64Array).value(row))),
DataType::Float64 => Ok(Value::Float(downcast_arr!(arr, Float64Array).value(row))),
DataType::Utf8 => Ok(string_to_value(downcast_arr!(arr, StringArray).value(row))),
DataType::LargeUtf8 => Ok(string_to_value(
downcast_arr!(arr, LargeStringArray).value(row),
)),
DataType::Boolean => Ok(Value::Bool(downcast_arr!(arr, BooleanArray).value(row))),
DataType::UInt64 => Ok(Value::Int(downcast_arr!(arr, UInt64Array).value(row) as i64)),
DataType::Int32 => Ok(Value::Int(downcast_arr!(arr, Int32Array).value(row) as i64)),
DataType::Float32 => Ok(Value::Float(
downcast_arr!(arr, Float32Array).value(row) as f64
)),
_ => {
let scalar = ScalarValue::try_from_array(arr, row).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!(
"Cannot extract scalar from array at row {}: {}",
row, e
))
})?;
scalar_to_value(&scalar)
}
}
}
fn get_value_args_for_row(args: &[ColumnarValue], row: usize) -> DFResult<Vec<Value>> {
args.iter()
.map(|arg| match arg {
ColumnarValue::Scalar(scalar) => scalar_to_value(scalar),
ColumnarValue::Array(arr) => get_value_from_array(arr, row),
})
.collect()
}
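/// Bridges DataFusion's columnar calling convention to a row-oriented Cypher
/// kernel `Fn(&[Value]) -> DFResult<Value>`: all-scalar inputs take a
/// single-row fast path, otherwise each row is decoded, evaluated, and the
/// results re-encoded through `values_to_array`.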
fn invoke_cypher_udf<F>(
args: ScalarFunctionArgs,
output_type: &DataType,
f: F,
) -> DFResult<ColumnarValue>
where
F: Fn(&[Value]) -> DFResult<Value>,
{
let len = args
.args
.iter()
.find_map(|arg| match arg {
ColumnarValue::Array(arr) => Some(arr.len()),
_ => None,
})
.unwrap_or(1);
if len == 1
&& args
.args
.iter()
.all(|a| matches!(a, ColumnarValue::Scalar(_)))
{
let row_args = get_value_args_for_row(&args.args, 0)?;
let res = f(&row_args)?;
if matches!(output_type, DataType::LargeBinary | DataType::List(_)) {
let arr = values_to_array(&[res], output_type)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
return Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&arr, 0)?));
}
if res.is_null() {
let typed_null = ScalarValue::try_from(output_type).unwrap_or(ScalarValue::Utf8(None));
return Ok(ColumnarValue::Scalar(typed_null));
}
return value_to_columnar(&res);
}
let mut results = Vec::with_capacity(len);
for i in 0..len {
let row_args = get_value_args_for_row(&args.args, i)?;
results.push(f(&row_args)?);
}
let arr = values_to_array(&results, output_type)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
Ok(ColumnarValue::Array(arr))
}
fn scalar_arr_to_value(arr: &dyn arrow::array::Array) -> DFResult<Value> {
if arr.is_empty() || arr.is_null(0) {
Ok(Value::Null)
} else {
Ok(uni_store::storage::arrow_convert::arrow_to_value(
arr, 0, None,
))
}
}
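// Resolves a named IANA timezone to its UTC offset in seconds at the given
// instant; unrecognized names fall back to UTC (offset 0).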
fn resolve_timezone_offset(tz_name: &str, nanos_utc: i64) -> i32 {
if tz_name == "UTC" || tz_name == "Z" {
return 0;
}
if let Ok(tz) = tz_name.parse::<chrono_tz::Tz>() {
let dt = chrono::DateTime::from_timestamp_nanos(nanos_utc).with_timezone(&tz);
dt.offset().fix().local_minus_utc()
} else {
0
}
}
fn duration_micros_to_value(micros: i64) -> Value {
let dur = crate::query::datetime::CypherDuration::from_micros(micros);
Value::Temporal(uni_common::TemporalValue::Duration {
months: dur.months,
days: dur.days,
nanos: dur.nanos,
})
}
fn timestamp_nanos_to_value(nanos: i64, tz: Option<&Arc<str>>) -> DFResult<Value> {
if let Some(tz_str) = tz {
let offset = resolve_timezone_offset(tz_str.as_ref(), nanos);
let tz_name = if tz_str.as_ref() == "UTC" {
None
} else {
Some(tz_str.to_string())
};
Ok(Value::Temporal(uni_common::TemporalValue::DateTime {
nanos_since_epoch: nanos,
offset_seconds: offset,
timezone_name: tz_name,
}))
} else {
Ok(Value::Temporal(uni_common::TemporalValue::LocalDateTime {
nanos_since_epoch: nanos,
}))
}
}
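/// Converts a DataFusion `ScalarValue` into a Cypher `Value`, covering
/// primitives, codec/JSON-encoded binaries, temporal types, and 24-byte BTIC
/// fixed-size binaries; typed nulls all collapse to `Value::Null`.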
pub(crate) fn scalar_to_value(scalar: &ScalarValue) -> DFResult<Value> {
match scalar {
        ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => Ok(string_to_value(s)),
ScalarValue::LargeBinary(Some(b)) => {
if let Ok(val) = uni_common::cypher_value_codec::decode(b) {
return Ok(val);
}
if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(b) {
Ok(Value::from(obj))
} else {
Ok(Value::Null)
}
}
ScalarValue::Int64(Some(i)) => Ok(Value::Int(*i)),
ScalarValue::Int32(Some(i)) => Ok(Value::Int(*i as i64)),
ScalarValue::Float64(Some(f)) => {
Ok(Value::Float(*f))
}
ScalarValue::Boolean(Some(b)) => Ok(Value::Bool(*b)),
ScalarValue::Struct(arr) => scalar_arr_to_value(arr.as_ref()),
ScalarValue::List(arr) => scalar_arr_to_value(arr.as_ref()),
ScalarValue::LargeList(arr) => scalar_arr_to_value(arr.as_ref()),
ScalarValue::FixedSizeList(arr) => scalar_arr_to_value(arr.as_ref()),
ScalarValue::UInt64(Some(u)) => Ok(Value::Int(*u as i64)),
ScalarValue::UInt32(Some(u)) => Ok(Value::Int(*u as i64)),
ScalarValue::UInt16(Some(u)) => Ok(Value::Int(*u as i64)),
ScalarValue::UInt8(Some(u)) => Ok(Value::Int(*u as i64)),
ScalarValue::Int16(Some(i)) => Ok(Value::Int(*i as i64)),
ScalarValue::Int8(Some(i)) => Ok(Value::Int(*i as i64)),
ScalarValue::Date32(Some(days)) => Ok(Value::Temporal(uni_common::TemporalValue::Date {
days_since_epoch: *days,
})),
ScalarValue::Date64(Some(millis)) => {
let days = (*millis / 86_400_000) as i32;
Ok(Value::Temporal(uni_common::TemporalValue::Date {
days_since_epoch: days,
}))
}
ScalarValue::TimestampNanosecond(Some(nanos), tz) => {
timestamp_nanos_to_value(*nanos, tz.as_ref())
}
ScalarValue::TimestampMicrosecond(Some(micros), tz) => {
timestamp_nanos_to_value(*micros * 1_000, tz.as_ref())
}
ScalarValue::TimestampMillisecond(Some(millis), tz) => {
timestamp_nanos_to_value(*millis * 1_000_000, tz.as_ref())
}
ScalarValue::TimestampSecond(Some(secs), tz) => {
timestamp_nanos_to_value(*secs * 1_000_000_000, tz.as_ref())
}
ScalarValue::Time64Nanosecond(Some(nanos)) => {
Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
nanos_since_midnight: *nanos,
}))
}
ScalarValue::Time64Microsecond(Some(micros)) => {
Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
nanos_since_midnight: *micros * 1_000,
}))
}
ScalarValue::IntervalMonthDayNano(Some(v)) => {
Ok(Value::Temporal(uni_common::TemporalValue::Duration {
months: v.months as i64,
days: v.days as i64,
nanos: v.nanoseconds,
}))
}
ScalarValue::DurationMicrosecond(Some(micros)) => Ok(duration_micros_to_value(*micros)),
ScalarValue::DurationMillisecond(Some(millis)) => {
Ok(duration_micros_to_value(*millis * 1_000))
}
ScalarValue::DurationSecond(Some(secs)) => Ok(duration_micros_to_value(*secs * 1_000_000)),
ScalarValue::DurationNanosecond(Some(nanos)) => {
Ok(Value::Temporal(uni_common::TemporalValue::Duration {
months: 0,
days: 0,
nanos: *nanos,
}))
}
ScalarValue::Float32(Some(f)) => Ok(Value::Float(*f as f64)),
ScalarValue::FixedSizeBinary(24, Some(bytes)) => {
match uni_btic::encode::decode_slice(bytes) {
Ok(btic) => Ok(Value::Temporal(uni_common::TemporalValue::Btic {
lo: btic.lo(),
hi: btic.hi(),
meta: btic.meta(),
})),
Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
"BTIC decode error: {e}"
))),
}
}
ScalarValue::Null
| ScalarValue::Utf8(None)
| ScalarValue::LargeUtf8(None)
| ScalarValue::LargeBinary(None)
| ScalarValue::Int64(None)
| ScalarValue::Int32(None)
| ScalarValue::Int16(None)
| ScalarValue::Int8(None)
| ScalarValue::UInt64(None)
| ScalarValue::UInt32(None)
| ScalarValue::UInt16(None)
| ScalarValue::UInt8(None)
| ScalarValue::Float64(None)
| ScalarValue::Float32(None)
| ScalarValue::Boolean(None)
| ScalarValue::Date32(None)
| ScalarValue::Date64(None)
| ScalarValue::TimestampMicrosecond(None, _)
| ScalarValue::TimestampMillisecond(None, _)
| ScalarValue::TimestampSecond(None, _)
| ScalarValue::TimestampNanosecond(None, _)
| ScalarValue::Time64Microsecond(None)
| ScalarValue::Time64Nanosecond(None)
| ScalarValue::DurationMicrosecond(None)
| ScalarValue::DurationMillisecond(None)
| ScalarValue::DurationSecond(None)
| ScalarValue::DurationNanosecond(None)
| ScalarValue::IntervalMonthDayNano(None)
| ScalarValue::FixedSizeBinary(_, None) => Ok(Value::Null),
other => Err(datafusion::error::DataFusionError::Execution(format!(
"scalar_to_value(): unsupported scalar type {other:?}"
))),
}
}
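// Inverse of `scalar_to_value` for types with a natural scalar form;
// containers and graph entities are rejected here (they are handled via the
// encoded `LargeBinary` path instead).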
fn value_to_columnar(val: &Value) -> DFResult<ColumnarValue> {
let scalar = match val {
Value::String(s) => ScalarValue::Utf8(Some(s.clone())),
Value::Int(i) => ScalarValue::Int64(Some(*i)),
Value::Float(f) => ScalarValue::Float64(Some(*f)),
Value::Bool(b) => ScalarValue::Boolean(Some(*b)),
Value::Null => ScalarValue::Utf8(None),
Value::Temporal(tv) => {
use uni_common::TemporalValue;
match tv {
TemporalValue::Date { days_since_epoch } => {
ScalarValue::Date32(Some(*days_since_epoch))
}
TemporalValue::LocalTime {
nanos_since_midnight,
} => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
TemporalValue::Time {
nanos_since_midnight,
..
} => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
TemporalValue::LocalDateTime { nanos_since_epoch } => {
ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None)
}
TemporalValue::DateTime {
nanos_since_epoch,
timezone_name,
..
} => {
let tz = timezone_name.as_deref().unwrap_or("UTC");
ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), Some(tz.into()))
}
TemporalValue::Duration {
months,
days,
nanos,
} => ScalarValue::IntervalMonthDayNano(Some(
arrow::datatypes::IntervalMonthDayNano {
months: *months as i32,
days: *days as i32,
nanoseconds: *nanos,
},
)),
TemporalValue::Btic { lo, hi, meta } => {
let btic = uni_btic::Btic::new(*lo, *hi, *meta).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!("invalid BTIC: {e}"))
})?;
let packed = uni_btic::encode::encode(&btic);
ScalarValue::FixedSizeBinary(24, Some(packed.to_vec()))
}
}
}
other => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"value_to_columnar(): unsupported type {other:?}"
)));
}
};
Ok(ColumnarValue::Scalar(scalar))
}
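// `_has_null` reports whether a list value contains any null element.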
pub fn create_has_null_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(HasNullUdf::new())
}
#[derive(Debug)]
struct HasNullUdf {
signature: Signature,
}
impl HasNullUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(HasNullUdf);
impl ScalarUDFImpl for HasNullUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_has_null"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_has_null(): requires 1 argument".to_string(),
));
}
fn check_list_nulls<T: arrow_array::OffsetSizeTrait>(
arr: &arrow_array::GenericListArray<T>,
idx: usize,
) -> bool {
if arr.is_null(idx) || arr.is_empty() {
false
} else {
arr.value(idx).null_count() > 0
}
}
match &args.args[0] {
ColumnarValue::Scalar(scalar) => {
let has_null = match scalar {
ScalarValue::List(arr) => arr
.as_any()
.downcast_ref::<arrow::array::ListArray>()
.map(|a| !a.is_empty() && a.value(0).null_count() > 0)
.unwrap_or(arr.null_count() > 0),
                    ScalarValue::LargeList(arr) => !arr.is_empty() && arr.value(0).null_count() > 0,
                    ScalarValue::FixedSizeList(arr) => {
                        !arr.is_empty() && arr.value(0).null_count() > 0
                    }
_ => false,
};
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(has_null))))
}
ColumnarValue::Array(arr) => {
use arrow_array::{LargeListArray, ListArray};
let results: arrow::array::BooleanArray =
if let Some(list_arr) = arr.as_any().downcast_ref::<ListArray>() {
(0..list_arr.len())
.map(|i| {
if list_arr.is_null(i) {
None
} else {
Some(check_list_nulls(list_arr, i))
}
})
.collect()
} else if let Some(large) = arr.as_any().downcast_ref::<LargeListArray>() {
(0..large.len())
.map(|i| {
if large.is_null(i) {
None
} else {
Some(check_list_nulls(large, i))
}
})
.collect()
} else {
return Err(datafusion::error::DataFusionError::Execution(
"_has_null(): requires list array".to_string(),
));
};
Ok(ColumnarValue::Array(Arc::new(results)))
}
}
}
}
pub fn create_to_integer_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(ToIntegerUdf::new())
}
#[derive(Debug)]
struct ToIntegerUdf {
signature: Signature,
}
impl ToIntegerUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(ToIntegerUdf);
impl ScalarUDFImpl for ToIntegerUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"tointeger"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Int64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"tointeger(): requires 1 argument".to_string(),
));
}
let val = &val_args[0];
match val {
Value::Int(i) => Ok(Value::Int(*i)),
Value::Float(f) => Ok(Value::Int(*f as i64)),
Value::String(s) => {
if let Ok(i) = s.parse::<i64>() {
Ok(Value::Int(i))
} else if let Ok(f) = s.parse::<f64>() {
Ok(Value::Int(f as i64))
} else {
Ok(Value::Null)
}
}
Value::Null => Ok(Value::Null),
other => Err(datafusion::error::DataFusionError::Execution(format!(
"InvalidArgumentValue: tointeger(): cannot convert {} to integer",
cypher_type_name(other)
))),
}
})
}
}
pub fn create_to_float_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(ToFloatUdf::new())
}
#[derive(Debug)]
struct ToFloatUdf {
signature: Signature,
}
impl ToFloatUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(ToFloatUdf);
impl ScalarUDFImpl for ToFloatUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"tofloat"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Float64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"tofloat(): requires 1 argument".to_string(),
));
}
let val = &val_args[0];
match val {
Value::Int(i) => Ok(Value::Float(*i as f64)),
Value::Float(f) => Ok(Value::Float(*f)),
Value::String(s) => {
if let Ok(f) = s.parse::<f64>() {
Ok(Value::Float(f))
} else {
Ok(Value::Null)
}
}
Value::Null => Ok(Value::Null),
other => Err(datafusion::error::DataFusionError::Execution(format!(
"InvalidArgumentValue: tofloat(): cannot convert {} to float",
cypher_type_name(other)
))),
}
})
}
}
pub fn create_to_boolean_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(ToBooleanUdf::new())
}
#[derive(Debug)]
struct ToBooleanUdf {
signature: Signature,
}
impl ToBooleanUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(ToBooleanUdf);
impl ScalarUDFImpl for ToBooleanUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"toboolean"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"toboolean(): requires 1 argument".to_string(),
));
}
let val = &val_args[0];
match val {
Value::Bool(b) => Ok(Value::Bool(*b)),
Value::String(s) => {
let s_lower = s.to_lowercase();
if s_lower == "true" {
Ok(Value::Bool(true))
} else if s_lower == "false" {
Ok(Value::Bool(false))
} else {
Ok(Value::Null)
}
}
Value::Null => Ok(Value::Null),
Value::Int(i) => Ok(Value::Bool(*i != 0)),
other => Err(datafusion::error::DataFusionError::Execution(format!(
"InvalidArgumentValue: toboolean(): cannot convert {} to boolean",
cypher_type_name(other)
))),
}
})
}
}
pub fn create_cypher_sort_key_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherSortKeyUdf::new())
}
#[derive(Debug)]
struct CypherSortKeyUdf {
signature: Signature,
}
impl CypherSortKeyUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherSortKeyUdf);
impl ScalarUDFImpl for CypherSortKeyUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_sort_key"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_sort_key(): requires 1 argument".to_string(),
));
}
let arg = &args.args[0];
match arg {
ColumnarValue::Scalar(s) => {
let val = if s.is_null() {
Value::Null
} else {
scalar_to_value(s)?
};
let key = encode_cypher_sort_key(&val);
Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(key))))
}
ColumnarValue::Array(arr) => {
let mut keys: Vec<Option<Vec<u8>>> = Vec::with_capacity(arr.len());
for i in 0..arr.len() {
let val = if arr.is_null(i) {
Value::Null
} else {
get_value_from_array(arr, i)?
};
keys.push(Some(encode_cypher_sort_key(&val)));
}
let array = LargeBinaryArray::from(
keys.iter()
.map(|k| k.as_deref())
.collect::<Vec<Option<&[u8]>>>(),
);
Ok(ColumnarValue::Array(Arc::new(array)))
}
}
}
}
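/// Encodes a `Value` as an order-preserving byte key for Cypher ORDER BY.
/// The first byte is a type rank (see `sort_key_type_rank`) so mixed-type
/// columns order by type first; the payload then preserves ordering within
/// the type (e.g. numbers via `encode_order_preserving_f64`).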
pub fn encode_cypher_sort_key(value: &Value) -> Vec<u8> {
let mut buf = Vec::with_capacity(32);
encode_sort_key_to_buf(value, &mut buf);
buf
}
fn encode_sort_key_to_buf(value: &Value, buf: &mut Vec<u8>) {
if let Value::Map(map) = value {
if let Some(tv) = sort_key_map_as_temporal(map) {
            buf.push(0x07);
            encode_temporal_payload(&tv, buf);
return;
}
let rank = sort_key_map_rank(map);
if rank != 0 {
buf.push(rank);
match rank {
0x01 => encode_map_as_node_payload(map, buf),
0x02 => encode_map_as_edge_payload(map, buf),
0x04 => encode_map_as_path_payload(map, buf),
                _ => {}
            }
return;
}
}
if let Value::String(s) = value {
if let Some(tv) = sort_key_string_as_temporal(s) {
            buf.push(0x07);
            encode_temporal_payload(&tv, buf);
return;
}
if let Some(temporal_type) = crate::query::datetime::classify_temporal(s) {
            buf.push(0x07);
            if encode_wide_temporal_sort_key(s, temporal_type, buf) {
return;
}
buf.pop();
}
}
let rank = sort_key_type_rank(value);
buf.push(rank);
match value {
        // Null and NaN carry no payload: their rank byte alone determines order.
        Value::Null => {}
        Value::Float(f) if f.is_nan() => {}
        Value::Bool(b) => buf.push(if *b { 0x01 } else { 0x00 }),
Value::Int(i) => {
let f = *i as f64;
buf.extend_from_slice(&encode_order_preserving_f64(f));
}
Value::Float(f) => {
buf.extend_from_slice(&encode_order_preserving_f64(*f));
}
Value::String(s) => {
byte_stuff_terminate(s.as_bytes(), buf);
}
Value::Temporal(tv) => {
encode_temporal_payload(tv, buf);
}
Value::List(items) => {
encode_list_payload(items, buf);
}
Value::Map(map) => {
encode_map_payload(map, buf);
}
Value::Node(node) => {
encode_node_payload(node, buf);
}
Value::Edge(edge) => {
encode_edge_payload(edge, buf);
}
Value::Path(path) => {
encode_path_payload(path, buf);
}
Value::Bytes(b) => {
byte_stuff_terminate(b, buf);
}
Value::Vector(v) => {
for f in v {
buf.extend_from_slice(&encode_order_preserving_f64(*f as f64));
}
}
        _ => {}
    }
}
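// One-byte type-rank prefixes for sort keys. Lower ranks sort first, so the
// resulting order is: node (0x01), edge (0x02), list (0x03), path (0x04),
// string (0x05), boolean (0x06), temporal (0x07), numbers (0x08), NaN (0x09),
// null (0x0A), then everything else (0x0B).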
fn sort_key_type_rank(v: &Value) -> u8 {
match v {
Value::Map(map) => sort_key_map_rank(map),
Value::Node(_) => 0x01,
Value::Edge(_) => 0x02,
Value::List(_) => 0x03,
Value::Path(_) => 0x04,
Value::String(_) => 0x05,
Value::Bool(_) => 0x06,
Value::Temporal(_) => 0x07,
Value::Int(_) => 0x08,
Value::Float(f) if f.is_nan() => 0x09,
Value::Float(_) => 0x08,
Value::Null => 0x0A,
Value::Bytes(_) | Value::Vector(_) => 0x0B,
        _ => 0x0B,
    }
}
fn sort_key_map_rank(map: &std::collections::HashMap<String, Value>) -> u8 {
if sort_key_map_as_temporal(map).is_some() {
0x07
} else if map.contains_key("nodes")
&& (map.contains_key("relationships") || map.contains_key("edges"))
{
0x04 } else if map.contains_key("_eid")
|| map.contains_key("_src")
|| map.contains_key("_dst")
|| map.contains_key("_type")
|| map.contains_key("_type_name")
{
0x02 } else if map.contains_key("_vid") || map.contains_key("_labels") || map.contains_key("_label")
{
0x01 } else {
0x00 }
}
fn sort_key_map_as_temporal(
map: &std::collections::HashMap<String, Value>,
) -> Option<uni_common::TemporalValue> {
super::expr_eval::temporal_from_map_wrapper(map)
}
fn sort_key_string_as_temporal(s: &str) -> Option<uni_common::TemporalValue> {
super::expr_eval::temporal_from_value(&Value::String(s.to_string()))
}
fn encode_wide_temporal_sort_key(
s: &str,
temporal_type: uni_common::TemporalType,
buf: &mut Vec<u8>,
) -> bool {
match temporal_type {
uni_common::TemporalType::LocalDateTime => {
if let Some(ndt) = parse_naive_datetime(s) {
                buf.push(0x03); // temporal payload tag: LocalDateTime
                let wide_nanos = naive_datetime_to_wide_nanos(&ndt);
buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
return true;
}
false
}
uni_common::TemporalType::DateTime => {
let base = if let Some(bracket_pos) = s.find('[') {
&s[..bracket_pos]
} else {
s
};
if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%.f%:z") {
                buf.push(0x04); // temporal payload tag: DateTime
                let utc = dt.naive_utc();
let wide_nanos = naive_datetime_to_wide_nanos(&utc);
buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
return true;
}
if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%:z") {
                buf.push(0x04); // temporal payload tag: DateTime
                let utc = dt.naive_utc();
let wide_nanos = naive_datetime_to_wide_nanos(&utc);
buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
return true;
}
false
}
uni_common::TemporalType::Date => {
if let Ok(nd) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
&& let Some(epoch) = chrono::NaiveDate::from_ymd_opt(1970, 1, 1)
{
                buf.push(0x00); // temporal payload tag: Date
                let days = nd.signed_duration_since(epoch).num_days() as i32;
buf.extend_from_slice(&encode_order_preserving_i32(days));
return true;
}
false
}
_ => false,
}
}
fn parse_naive_datetime(s: &str) -> Option<chrono::NaiveDateTime> {
chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
.ok()
.or_else(|| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").ok())
}
fn naive_datetime_to_wide_nanos(ndt: &chrono::NaiveDateTime) -> i128 {
let secs = ndt.and_utc().timestamp() as i128;
let subsec_nanos = ndt.and_utc().timestamp_subsec_nanos() as i128;
secs * 1_000_000_000 + subsec_nanos
}
fn encode_map_as_node_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
let mut labels: Vec<String> = Vec::new();
if let Some(Value::List(lbls)) = map.get("_labels") {
for l in lbls {
if let Value::String(s) = l {
labels.push(s.clone());
}
}
} else if let Some(Value::String(lbl)) = map.get("_label") {
labels.push(lbl.clone());
}
labels.sort();
let vid = map.get("_vid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
let labels_joined = labels.join("\x01");
byte_stuff_terminate(labels_joined.as_bytes(), buf);
buf.extend_from_slice(&vid.to_be_bytes());
let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
for (k, v) in map {
if !k.starts_with('_') {
props.insert(k.clone(), v.clone());
}
}
encode_map_payload(&props, buf);
}
fn encode_map_as_edge_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
let edge_type = map
.get("_type")
.or_else(|| map.get("_type_name"))
.and_then(|v| {
if let Value::String(s) = v {
Some(s.as_str())
} else {
None
}
})
.unwrap_or("");
byte_stuff_terminate(edge_type.as_bytes(), buf);
let src = map.get("_src").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
let dst = map.get("_dst").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
let eid = map.get("_eid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
buf.extend_from_slice(&src.to_be_bytes());
buf.extend_from_slice(&dst.to_be_bytes());
buf.extend_from_slice(&eid.to_be_bytes());
let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
for (k, v) in map {
if !k.starts_with('_') {
props.insert(k.clone(), v.clone());
}
}
encode_map_payload(&props, buf);
}
fn encode_map_as_path_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
if let Some(Value::List(nodes)) = map.get("nodes") {
encode_list_payload(nodes, buf);
} else {
        buf.push(0x00);
    }
let edges = map.get("relationships").or_else(|| map.get("edges"));
if let Some(Value::List(edges)) = edges {
encode_list_payload(edges, buf);
} else {
        buf.push(0x00);
    }
}
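// Order-preserving integer/float encodings: each maps a signed value to
// big-endian bytes whose unsigned lexicographic order equals the numeric
// order. Floats are encoded by bit-inverting negative values and flipping the
// sign bit of non-negative ones; signed integers just flip the sign bit.
// E.g. -1.0 encodes below 0.0, which encodes below 1.0.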
fn encode_order_preserving_f64(f: f64) -> [u8; 8] {
let bits = f.to_bits();
let encoded = if bits >> 63 == 1 {
!bits
} else {
bits ^ (1u64 << 63)
};
encoded.to_be_bytes()
}
fn encode_order_preserving_i64(i: i64) -> [u8; 8] {
((i as u64) ^ (1u64 << 63)).to_be_bytes()
}
fn encode_order_preserving_i32(i: i32) -> [u8; 4] {
((i as u32) ^ (1u32 << 31)).to_be_bytes()
}
fn encode_order_preserving_i128(i: i128) -> [u8; 16] {
((i as u128) ^ (1u128 << 127)).to_be_bytes()
}
fn byte_stuff_terminate(data: &[u8], buf: &mut Vec<u8>) {
byte_stuff(data, buf);
buf.push(0x00);
buf.push(0x00);
}
fn byte_stuff(data: &[u8], buf: &mut Vec<u8>) {
for &b in data {
buf.push(b);
if b == 0x00 {
buf.push(0xFF);
}
}
}
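// A minimal sanity check for the encoders above; the sample values are
// illustrative, not drawn from any spec.
#[cfg(test)]
mod sort_key_encoding_tests {
    use super::*;

    #[test]
    fn f64_encoding_preserves_order() {
        let samples = [f64::NEG_INFINITY, -1.5, -0.0, 0.0, 2.5, f64::INFINITY];
        for pair in samples.windows(2) {
            assert!(encode_order_preserving_f64(pair[0]) <= encode_order_preserving_f64(pair[1]));
        }
    }

    #[test]
    fn i64_encoding_preserves_order() {
        let samples = [i64::MIN, -1, 0, 1, i64::MAX];
        for pair in samples.windows(2) {
            assert!(encode_order_preserving_i64(pair[0]) < encode_order_preserving_i64(pair[1]));
        }
    }

    #[test]
    fn byte_stuffing_escapes_embedded_zeros() {
        let mut buf = Vec::new();
        byte_stuff_terminate(&[0x00, 0x01], &mut buf);
        // 0x00 is escaped to 0x00 0xFF, then the 0x00 0x00 terminator follows.
        assert_eq!(buf, vec![0x00, 0xFF, 0x01, 0x00, 0x00]);
    }
}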
fn encode_list_payload(items: &[Value], buf: &mut Vec<u8>) {
for item in items {
        buf.push(0x01); // element-present marker
        let elem_key = encode_cypher_sort_key(item);
byte_stuff_terminate(&elem_key, buf);
}
    buf.push(0x00); // end-of-list marker
}
fn encode_map_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
let mut pairs: Vec<(&String, &Value)> = map.iter().collect();
pairs.sort_by_key(|(k, _)| *k);
for (key, value) in pairs {
        buf.push(0x01); // entry-present marker
        byte_stuff_terminate(key.as_bytes(), buf);
let val_key = encode_cypher_sort_key(value);
byte_stuff_terminate(&val_key, buf);
}
    buf.push(0x00); // end-of-map marker
}
fn encode_node_payload(node: &uni_common::Node, buf: &mut Vec<u8>) {
let mut labels = node.labels.clone();
labels.sort();
let labels_joined = labels.join("\x01");
byte_stuff_terminate(labels_joined.as_bytes(), buf);
buf.extend_from_slice(&node.vid.as_u64().to_be_bytes());
encode_map_payload(&node.properties, buf);
}
fn encode_edge_payload(edge: &uni_common::Edge, buf: &mut Vec<u8>) {
byte_stuff_terminate(edge.edge_type.as_bytes(), buf);
buf.extend_from_slice(&edge.src.as_u64().to_be_bytes());
buf.extend_from_slice(&edge.dst.as_u64().to_be_bytes());
buf.extend_from_slice(&edge.eid.as_u64().to_be_bytes());
encode_map_payload(&edge.properties, buf);
}
fn encode_path_payload(path: &uni_common::Path, buf: &mut Vec<u8>) {
for node in &path.nodes {
        buf.push(0x01); // element-present marker
        let mut node_key = Vec::new();
        node_key.push(0x01); // node type-rank tag
        encode_node_payload(node, &mut node_key);
byte_stuff_terminate(&node_key, buf);
}
buf.push(0x00);
for edge in &path.edges {
        buf.push(0x01); // element-present marker
        let mut edge_key = Vec::new();
        edge_key.push(0x02); // edge type-rank tag
        encode_edge_payload(edge, &mut edge_key);
byte_stuff_terminate(&edge_key, buf);
}
    buf.push(0x00);
}
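// Temporal payloads carry their own one-byte tag so that Date (0x00),
// LocalTime (0x01), Time (0x02), LocalDateTime (0x03), DateTime (0x04),
// Duration (0x05), and Btic (0x06) values group by kind before comparing
// by instant.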
fn encode_temporal_payload(tv: &uni_common::TemporalValue, buf: &mut Vec<u8>) {
match tv {
uni_common::TemporalValue::Date { days_since_epoch } => {
            buf.push(0x00); // temporal payload tag: Date
            buf.extend_from_slice(&encode_order_preserving_i32(*days_since_epoch));
}
uni_common::TemporalValue::LocalTime {
nanos_since_midnight,
} => {
            buf.push(0x01); // temporal payload tag: LocalTime
            buf.extend_from_slice(&encode_order_preserving_i64(*nanos_since_midnight));
}
uni_common::TemporalValue::Time {
nanos_since_midnight,
offset_seconds,
} => {
            buf.push(0x02); // temporal payload tag: Time (normalized to UTC)
            let utc_nanos =
*nanos_since_midnight as i128 - (*offset_seconds as i128) * 1_000_000_000;
buf.extend_from_slice(&encode_order_preserving_i128(utc_nanos));
}
uni_common::TemporalValue::LocalDateTime { nanos_since_epoch } => {
            buf.push(0x03); // temporal payload tag: LocalDateTime
            buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
}
uni_common::TemporalValue::DateTime {
nanos_since_epoch, ..
} => {
            buf.push(0x04); // temporal payload tag: DateTime
            buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
}
uni_common::TemporalValue::Duration {
months,
days,
nanos,
} => {
            buf.push(0x05); // temporal payload tag: Duration
            buf.extend_from_slice(&encode_order_preserving_i64(*months));
buf.extend_from_slice(&encode_order_preserving_i64(*days));
buf.extend_from_slice(&encode_order_preserving_i64(*nanos));
}
uni_common::TemporalValue::Btic { lo, hi, meta } => {
            buf.push(0x06); // temporal payload tag: Btic
            if let Ok(btic) = uni_btic::Btic::new(*lo, *hi, *meta) {
buf.extend_from_slice(&uni_btic::encode::encode(&btic));
} else {
buf.extend_from_slice(&encode_order_preserving_i64(*lo));
buf.extend_from_slice(&encode_order_preserving_i64(*hi));
}
}
}
}
#[derive(Debug)]
struct BticScalarUdf {
name: String,
signature: Signature,
return_type: DataType,
}
impl BticScalarUdf {
fn new(name: &str, num_args: usize, return_type: DataType) -> Self {
Self {
name: name.to_string(),
signature: Signature::new(TypeSignature::Any(num_args), Volatility::Immutable),
return_type,
}
}
}
impl_udf_eq_hash!(BticScalarUdf);
impl ScalarUDFImpl for BticScalarUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(self.return_type.clone())
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let fname = self.name.to_uppercase();
let rt = self.return_type.clone();
invoke_cypher_udf(args, &rt, |val_args| {
crate::query::expr_eval::eval_btic_function(&fname, val_args)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
})
}
}
fn register_btic_scalar_udfs(ctx: &SessionContext) -> DFResult<()> {
for name in &["btic_lo", "btic_hi"] {
ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
name,
1,
DataType::LargeBinary,
)));
}
ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
"btic_duration",
1,
DataType::Int64,
)));
for name in &["btic_is_instant", "btic_is_unbounded", "btic_is_finite"] {
ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
name,
1,
DataType::Boolean,
)));
}
for name in &[
"btic_granularity",
"btic_lo_granularity",
"btic_hi_granularity",
"btic_certainty",
"btic_lo_certainty",
"btic_hi_certainty",
] {
ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
name,
1,
DataType::Utf8,
)));
}
for name in &[
"btic_contains_point",
"btic_overlaps",
"btic_contains",
"btic_before",
"btic_after",
"btic_meets",
"btic_adjacent",
"btic_disjoint",
"btic_equals",
"btic_starts",
"btic_during",
"btic_finishes",
] {
ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
name,
2,
DataType::Boolean,
)));
}
for name in &["btic_intersection", "btic_span", "btic_gap"] {
ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
name,
2,
DataType::LargeBinary,
)));
}
Ok(())
}
#[derive(Debug, Clone)]
struct BticMinMaxUdaf {
name: String,
signature: Signature,
is_max: bool,
}
impl BticMinMaxUdaf {
fn new(is_max: bool) -> Self {
Self {
name: (if is_max { "btic_max" } else { "btic_min" }).to_string(),
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
is_max,
}
}
}
impl_udf_eq_hash!(BticMinMaxUdaf);
impl AggregateUDFImpl for BticMinMaxUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn accumulator(
&self,
_acc_args: datafusion::logical_expr::function::AccumulatorArgs,
) -> DFResult<Box<dyn DfAccumulator>> {
Ok(Box::new(BticMinMaxAccumulator {
current: None,
is_max: self.is_max,
}))
}
fn state_fields(
&self,
args: datafusion::logical_expr::function::StateFieldsArgs,
) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
Ok(vec![Arc::new(arrow::datatypes::Field::new(
args.name,
DataType::LargeBinary,
true,
))])
}
}
#[derive(Debug)]
struct BticMinMaxAccumulator {
current: Option<uni_btic::Btic>,
is_max: bool,
}
impl DfAccumulator for BticMinMaxAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
let arr = &values[0];
for i in 0..arr.len() {
if arr.is_null(i) {
continue;
}
let Some(btic) = decode_btic_from_array(arr, i)? else {
continue;
};
self.current = Some(match self.current.take() {
None => btic,
Some(cur) => {
if (self.is_max && btic > cur) || (!self.is_max && btic < cur) {
btic
} else {
cur
}
}
});
}
Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
Ok(btic_to_scalar_value(self.current.as_ref()))
}
fn size(&self) -> usize {
std::mem::size_of::<Self>()
}
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
self.update_batch(states)
}
}
#[derive(Debug, Clone)]
struct BticSpanAggUdaf {
signature: Signature,
}
impl BticSpanAggUdaf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(BticSpanAggUdaf);
impl AggregateUDFImpl for BticSpanAggUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"btic_span_agg"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn accumulator(
&self,
_acc_args: datafusion::logical_expr::function::AccumulatorArgs,
) -> DFResult<Box<dyn DfAccumulator>> {
Ok(Box::new(BticSpanAggAccumulator { current: None }))
}
fn state_fields(
&self,
args: datafusion::logical_expr::function::StateFieldsArgs,
) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
Ok(vec![Arc::new(arrow::datatypes::Field::new(
args.name,
DataType::LargeBinary,
true,
))])
}
}
#[derive(Debug)]
struct BticSpanAggAccumulator {
current: Option<uni_btic::Btic>,
}
impl DfAccumulator for BticSpanAggAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
let arr = &values[0];
for i in 0..arr.len() {
if arr.is_null(i) {
continue;
}
let Some(btic) = decode_btic_from_array(arr, i)? else {
continue;
};
self.current = Some(match self.current.take() {
None => btic,
Some(cur) => uni_btic::set_ops::span(&cur, &btic),
});
}
Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
Ok(btic_to_scalar_value(self.current.as_ref()))
}
fn size(&self) -> usize {
std::mem::size_of::<Self>()
}
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
self.update_batch(states)
}
}
#[derive(Debug, Clone)]
struct BticCountAtUdaf {
signature: Signature,
}
impl BticCountAtUdaf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(BticCountAtUdaf);
impl AggregateUDFImpl for BticCountAtUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"btic_count_at"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Int64)
}
fn accumulator(
&self,
_acc_args: datafusion::logical_expr::function::AccumulatorArgs,
) -> DFResult<Box<dyn DfAccumulator>> {
Ok(Box::new(BticCountAtAccumulator { count: 0 }))
}
fn state_fields(
&self,
args: datafusion::logical_expr::function::StateFieldsArgs,
) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
Ok(vec![Arc::new(arrow::datatypes::Field::new(
args.name,
DataType::Int64,
true,
))])
}
}
#[derive(Debug)]
struct BticCountAtAccumulator {
count: i64,
}
impl DfAccumulator for BticCountAtAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
if values.len() < 2 {
return Ok(());
}
let btic_arr = &values[0];
let point_arr = &values[1];
for i in 0..btic_arr.len() {
if btic_arr.is_null(i) || point_arr.is_null(i) {
continue;
}
let Some(btic) = decode_btic_from_array(btic_arr, i)? else {
continue;
};
let point_ms = if let Some(int_arr) = point_arr.as_any().downcast_ref::<Int64Array>() {
int_arr.value(i)
} else if let Some(lb) = point_arr.as_any().downcast_ref::<LargeBinaryArray>() {
let val = scalar_binary_to_value(lb.value(i));
match &val {
Value::Int(ms) => *ms,
Value::Temporal(uni_common::TemporalValue::DateTime {
nanos_since_epoch,
..
}) => nanos_since_epoch / 1_000_000,
_ => continue,
}
} else {
continue;
};
if uni_btic::predicates::contains_point(&btic, point_ms) {
self.count += 1;
}
}
Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.count)))
}
fn size(&self) -> usize {
std::mem::size_of::<Self>()
}
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
Ok(vec![ScalarValue::Int64(Some(self.count))])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
let arr = &states[0];
if let Some(int_arr) = arr.as_any().downcast_ref::<Int64Array>() {
for i in 0..int_arr.len() {
if !int_arr.is_null(i) {
self.count += int_arr.value(i);
}
}
}
Ok(())
}
}
fn btic_to_scalar_value(btic: Option<&uni_btic::Btic>) -> ScalarValue {
match btic {
None => ScalarValue::LargeBinary(None),
Some(b) => {
let val = Value::Temporal(uni_common::TemporalValue::Btic {
lo: b.lo(),
hi: b.hi(),
meta: b.meta(),
});
ScalarValue::LargeBinary(Some(uni_common::cypher_value_codec::encode(&val)))
}
}
}
fn decode_btic_from_array(arr: &ArrayRef, row: usize) -> DFResult<Option<uni_btic::Btic>> {
if let Some(fsb) = arr.as_any().downcast_ref::<FixedSizeBinaryArray>() {
let bytes = fsb.value(row);
return uni_btic::encode::decode_slice(bytes)
.map(Some)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
}
if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
let val = scalar_binary_to_value(lb.value(row));
if let Value::Temporal(uni_common::TemporalValue::Btic { lo, hi, meta }) = val {
return uni_btic::Btic::new(lo, hi, meta)
.map(Some)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
}
return Ok(None);
}
Ok(None)
}
pub(crate) fn create_btic_min_udaf() -> AggregateUDF {
AggregateUDF::from(BticMinMaxUdaf::new(false))
}
pub(crate) fn create_btic_max_udaf() -> AggregateUDF {
AggregateUDF::from(BticMinMaxUdaf::new(true))
}
pub(crate) fn create_btic_span_agg_udaf() -> AggregateUDF {
AggregateUDF::from(BticSpanAggUdaf::new())
}
pub(crate) fn create_btic_count_at_udaf() -> AggregateUDF {
AggregateUDF::from(BticCountAtUdaf::new())
}
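/// Shared driver for the string-predicate UDFs (STARTS WITH / ENDS WITH /
/// CONTAINS). It accepts scalars, arrays, or a mix of both, extracting
/// strings from Utf8, LargeUtf8, or codec-encoded LargeBinary inputs; any
/// null or non-string input yields a null result rather than an error.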
pub fn invoke_cypher_string_op<F>(
args: &ScalarFunctionArgs,
name: &str,
op: F,
) -> DFResult<ColumnarValue>
where
F: Fn(&str, &str) -> bool,
{
use arrow_array::{BooleanArray, LargeBinaryArray, LargeStringArray, StringArray};
use datafusion::common::ScalarValue;
use datafusion::error::DataFusionError;
if args.args.len() != 2 {
return Err(DataFusionError::Execution(format!(
"{}(): requires exactly 2 arguments",
name
)));
}
let left = &args.args[0];
let right = &args.args[1];
let extract_string = |scalar: &ScalarValue| -> Option<String> {
match scalar {
ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => Some(s.clone()),
ScalarValue::LargeBinary(Some(bytes)) => {
match uni_common::cypher_value_codec::decode(bytes) {
Ok(uni_common::Value::String(s)) => Some(s),
_ => None,
}
}
ScalarValue::Utf8(None)
| ScalarValue::LargeUtf8(None)
| ScalarValue::LargeBinary(None)
| ScalarValue::Null => None,
_ => None,
}
};
match (left, right) {
(ColumnarValue::Scalar(l_scalar), ColumnarValue::Scalar(r_scalar)) => {
let l_str = extract_string(l_scalar);
let r_str = extract_string(r_scalar);
match (l_str, r_str) {
(Some(l), Some(r)) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(op(
&l, &r,
))))),
_ => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
}
}
(ColumnarValue::Array(l_arr), ColumnarValue::Scalar(r_scalar)) => {
let r_val = extract_string(r_scalar);
if r_val.is_none() {
let nulls = arrow_array::new_null_array(&DataType::Boolean, l_arr.len());
return Ok(ColumnarValue::Array(nulls));
}
let pattern = r_val.unwrap();
let result_array = if let Some(arr) = l_arr.as_any().downcast_ref::<StringArray>() {
arr.iter()
.map(|opt_s| opt_s.map(|s| op(s, &pattern)))
.collect::<BooleanArray>()
} else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeStringArray>() {
arr.iter()
.map(|opt_s| opt_s.map(|s| op(s, &pattern)))
.collect::<BooleanArray>()
} else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeBinaryArray>() {
arr.iter()
.map(|opt_bytes| {
opt_bytes.and_then(|bytes| {
match uni_common::cypher_value_codec::decode(bytes) {
Ok(uni_common::Value::String(s)) => Some(op(&s, &pattern)),
_ => None,
}
})
})
.collect::<BooleanArray>()
} else {
arrow_array::new_null_array(&DataType::Boolean, l_arr.len())
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.clone()
};
Ok(ColumnarValue::Array(Arc::new(result_array)))
}
(ColumnarValue::Scalar(l_scalar), ColumnarValue::Array(r_arr)) => {
let l_val = extract_string(l_scalar);
if l_val.is_none() {
let nulls = arrow_array::new_null_array(&DataType::Boolean, r_arr.len());
return Ok(ColumnarValue::Array(nulls));
}
let target = l_val.unwrap();
let result_array = if let Some(arr) = r_arr.as_any().downcast_ref::<StringArray>() {
arr.iter()
.map(|opt_s| opt_s.map(|s| op(&target, s)))
.collect::<BooleanArray>()
} else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeStringArray>() {
arr.iter()
.map(|opt_s| opt_s.map(|s| op(&target, s)))
.collect::<BooleanArray>()
} else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeBinaryArray>() {
arr.iter()
.map(|opt_bytes| {
opt_bytes.and_then(|bytes| {
match uni_common::cypher_value_codec::decode(bytes) {
Ok(uni_common::Value::String(s)) => Some(op(&target, &s)),
_ => None,
}
})
})
.collect::<BooleanArray>()
} else {
arrow_array::new_null_array(&DataType::Boolean, r_arr.len())
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.clone()
};
Ok(ColumnarValue::Array(Arc::new(result_array)))
}
(ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
if l_arr.len() != r_arr.len() {
return Err(DataFusionError::Execution(format!(
"{}(): array lengths must match",
name
)));
}
            // Nulls in any input array yield None so the row's result is null.
            let extract_string_at = |arr: &dyn Array, idx: usize| -> Option<String> {
                if arr.is_null(idx) {
                    return None;
                }
                if let Some(str_arr) = arr.as_any().downcast_ref::<StringArray>() {
                    Some(str_arr.value(idx).to_string())
                } else if let Some(str_arr) = arr.as_any().downcast_ref::<LargeStringArray>() {
                    Some(str_arr.value(idx).to_string())
                } else if let Some(bin_arr) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
                    let bytes = bin_arr.value(idx);
                    match uni_common::cypher_value_codec::decode(bytes) {
                        Ok(uni_common::Value::String(s)) => Some(s),
                        _ => None,
                    }
                } else {
                    None
                }
            };
let result: BooleanArray = (0..l_arr.len())
.map(|idx| {
match (
extract_string_at(l_arr.as_ref(), idx),
extract_string_at(r_arr.as_ref(), idx),
) {
(Some(l_str), Some(r_str)) => Some(op(&l_str, &r_str)),
_ => None,
}
})
.collect();
Ok(ColumnarValue::Array(Arc::new(result)))
}
}
}
macro_rules! define_string_op_udf {
($struct_name:ident, $udf_name:literal, $op:expr) => {
#[derive(Debug)]
struct $struct_name {
signature: Signature,
}
impl $struct_name {
fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!($struct_name);
impl ScalarUDFImpl for $struct_name {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
$udf_name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_string_op(&args, $udf_name, $op)
}
}
};
}
define_string_op_udf!(CypherStartsWithUdf, "_cypher_starts_with", |s, p| s
.starts_with(p));
define_string_op_udf!(CypherEndsWithUdf, "_cypher_ends_with", |s, p| s
.ends_with(p));
define_string_op_udf!(CypherContainsUdf, "_cypher_contains", |s, p| s.contains(p));
pub fn create_cypher_starts_with_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherStartsWithUdf::new())
}
pub fn create_cypher_ends_with_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherEndsWithUdf::new())
}
pub fn create_cypher_contains_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherContainsUdf::new())
}
pub fn create_cypher_equal_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_equal", BinaryOp::Eq))
}
pub fn create_cypher_not_equal_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_not_equal", BinaryOp::NotEq))
}
pub fn create_cypher_lt_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt", BinaryOp::Lt))
}
pub fn create_cypher_lt_eq_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt_eq", BinaryOp::LtEq))
}
pub fn create_cypher_gt_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt", BinaryOp::Gt))
}
pub fn create_cypher_gt_eq_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt_eq", BinaryOp::GtEq))
}
#[expect(clippy::match_like_matches_macro)]
fn apply_comparison_op(ord: std::cmp::Ordering, op: &BinaryOp) -> bool {
use std::cmp::Ordering;
match (ord, op) {
(Ordering::Less, BinaryOp::Lt | BinaryOp::LtEq | BinaryOp::NotEq) => true,
(Ordering::Equal, BinaryOp::Eq | BinaryOp::LtEq | BinaryOp::GtEq) => true,
(Ordering::Greater, BinaryOp::Gt | BinaryOp::GtEq | BinaryOp::NotEq) => true,
_ => false,
}
}
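// NaN handling: a NaN operand compares as unequal to everything, so `<>`
// yields true and every other operator yields false (never null).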
fn compare_f64(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<bool> {
if lhs.is_nan() || rhs.is_nan() {
Some(matches!(op, BinaryOp::NotEq))
} else {
Some(apply_comparison_op(lhs.partial_cmp(&rhs)?, op))
}
}
fn cv_bytes_as_f64(bytes: &[u8]) -> Option<f64> {
use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
match peek_tag(bytes)? {
TAG_INT => decode_int(bytes).map(|i| i as f64),
TAG_FLOAT => decode_float(bytes),
_ => None,
}
}
fn compare_cv_numeric(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<bool> {
use uni_common::cypher_value_codec::{TAG_INT, TAG_NULL, decode_int, peek_tag};
if peek_tag(bytes) == Some(TAG_INT)
&& let Some(lhs_int) = decode_int(bytes)
&& rhs.fract() == 0.0
&& rhs >= i64::MIN as f64
&& rhs <= i64::MAX as f64
{
return Some(apply_comparison_op(lhs_int.cmp(&(rhs as i64)), op));
}
if peek_tag(bytes) == Some(TAG_NULL) {
return None;
}
let lhs = cv_bytes_as_f64(bytes)?;
compare_f64(lhs, rhs, op)
}
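// Fast path for comparisons: when both sides are arrays and the left side is
// codec-encoded LargeBinary, compare row-by-row against native Int64,
// Float64, Utf8, or LargeBinary right-hand sides without materializing full
// `Value`s. Returns None to fall back to the generic `eval_binary_op` path.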
fn try_fast_compare(
lhs: &ColumnarValue,
rhs: &ColumnarValue,
op: &BinaryOp,
) -> Option<ColumnarValue> {
use arrow_array::builder::BooleanBuilder;
use uni_common::cypher_value_codec::{
TAG_INT, TAG_NULL, TAG_STRING, decode_int, decode_string, peek_tag,
};
let (lhs_arr, rhs_arr) = match (lhs, rhs) {
(ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
_ => return None,
};
if !matches!(lhs_arr.data_type(), DataType::LargeBinary) {
return None;
}
let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
match rhs_arr.data_type() {
DataType::Int64 => {
let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
for i in 0..lb_arr.len() {
if lb_arr.is_null(i) || int_arr.is_null(i) {
builder.append_null();
} else {
match compare_cv_numeric(lb_arr.value(i), int_arr.value(i) as f64, op) {
Some(result) => builder.append_value(result),
None => builder.append_null(),
}
}
}
Some(ColumnarValue::Array(Arc::new(builder.finish())))
}
DataType::Float64 => {
let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
for i in 0..lb_arr.len() {
if lb_arr.is_null(i) || float_arr.is_null(i) {
builder.append_null();
} else {
match compare_cv_numeric(lb_arr.value(i), float_arr.value(i), op) {
Some(result) => builder.append_value(result),
None => builder.append_null(),
}
}
}
Some(ColumnarValue::Array(Arc::new(builder.finish())))
}
DataType::Utf8 | DataType::LargeUtf8 => {
let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
for i in 0..lb_arr.len() {
if lb_arr.is_null(i) || rhs_arr.is_null(i) {
builder.append_null();
} else {
let bytes = lb_arr.value(i);
let rhs_str = if matches!(rhs_arr.data_type(), DataType::Utf8) {
rhs_arr.as_any().downcast_ref::<StringArray>()?.value(i)
} else {
rhs_arr
.as_any()
.downcast_ref::<LargeStringArray>()?
.value(i)
};
match peek_tag(bytes) {
Some(TAG_STRING) => {
if let Some(lhs_str) = decode_string(bytes) {
builder.append_value(apply_comparison_op(
lhs_str.as_str().cmp(rhs_str),
op,
));
} else {
builder.append_null();
}
}
_ => builder.append_null(),
}
}
}
Some(ColumnarValue::Array(Arc::new(builder.finish())))
}
DataType::LargeBinary => {
let rhs_lb = rhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
for i in 0..lb_arr.len() {
if lb_arr.is_null(i) || rhs_lb.is_null(i) {
builder.append_null();
} else {
let lhs_bytes = lb_arr.value(i);
let rhs_bytes = rhs_lb.value(i);
let lhs_tag = peek_tag(lhs_bytes);
let rhs_tag = peek_tag(rhs_bytes);
if lhs_tag == Some(TAG_NULL) || rhs_tag == Some(TAG_NULL) {
builder.append_null();
continue;
}
if lhs_tag == Some(TAG_INT) && rhs_tag == Some(TAG_INT) {
if let (Some(l), Some(r)) = (decode_int(lhs_bytes), decode_int(rhs_bytes)) {
builder.append_value(apply_comparison_op(l.cmp(&r), op));
} else {
builder.append_null();
}
continue;
}
if lhs_tag == Some(TAG_STRING) && rhs_tag == Some(TAG_STRING) {
if let (Some(l), Some(r)) =
(decode_string(lhs_bytes), decode_string(rhs_bytes))
{
builder.append_value(apply_comparison_op(l.cmp(&r), op));
} else {
builder.append_null();
}
continue;
}
if let (Some(l), Some(r)) =
(cv_bytes_as_f64(lhs_bytes), cv_bytes_as_f64(rhs_bytes))
{
match compare_f64(l, r, op) {
Some(result) => builder.append_value(result),
None => builder.append_null(),
}
} else {
return None;
}
}
}
Some(ColumnarValue::Array(Arc::new(builder.finish())))
}
        _ => None,
    }
}
#[derive(Debug)]
struct CypherCompareUdf {
name: String,
op: BinaryOp,
signature: Signature,
}
impl CypherCompareUdf {
fn new(name: &str, op: BinaryOp) -> Self {
Self {
name: name.to_string(),
op,
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl PartialEq for CypherCompareUdf {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}
impl Eq for CypherCompareUdf {}
impl std::hash::Hash for CypherCompareUdf {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
}
}
impl ScalarUDFImpl for CypherCompareUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(format!(
"{}(): requires 2 arguments",
self.name
)));
}
if let Some(result) = try_fast_compare(&args.args[0], &args.args[1], &self.op) {
return Ok(result);
}
let output_type = DataType::Boolean;
invoke_cypher_udf(args, &output_type, |val_args| {
crate::query::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
})
}
}
pub fn create_cypher_add_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_add", BinaryOp::Add))
}
pub fn create_cypher_sub_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_sub", BinaryOp::Sub))
}
pub fn create_cypher_mul_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mul", BinaryOp::Mul))
}
pub fn create_cypher_div_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_div", BinaryOp::Div))
}
pub fn create_cypher_mod_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mod", BinaryOp::Mod))
}
pub fn create_cypher_abs_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherAbsUdf::new())
}
pub(crate) fn cypher_abs_expr(
arg: datafusion::logical_expr::Expr,
) -> datafusion::logical_expr::Expr {
datafusion::logical_expr::Expr::ScalarFunction(
datafusion::logical_expr::expr::ScalarFunction::new_udf(
Arc::new(create_cypher_abs_udf()),
vec![arg],
),
)
}
#[derive(Debug)]
struct CypherAbsUdf {
signature: Signature,
}
impl CypherAbsUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherAbsUdf);
impl ScalarUDFImpl for CypherAbsUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_abs"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_abs requires exactly 1 argument".into(),
));
}
invoke_cypher_udf(args, &DataType::LargeBinary, |val_args| {
match &val_args[0] {
Value::Int(i) => i.checked_abs().map(Value::Int).ok_or_else(|| {
datafusion::error::DataFusionError::Execution(
"integer overflow in abs()".into(),
)
}),
Value::Float(f) => Ok(Value::Float(f.abs())),
Value::Null => Ok(Value::Null),
other => Err(datafusion::error::DataFusionError::Execution(format!(
"abs() requires a numeric argument, got {other:?}"
))),
}
})
}
}
fn apply_int_arithmetic(lhs: i64, rhs: i64, op: &BinaryOp) -> Option<Vec<u8>> {
use uni_common::cypher_value_codec::encode_int;
match op {
BinaryOp::Add => lhs.checked_add(rhs).map(encode_int),
BinaryOp::Sub => lhs.checked_sub(rhs).map(encode_int),
BinaryOp::Mul => lhs.checked_mul(rhs).map(encode_int),
BinaryOp::Div => {
if rhs == 0 {
None
} else {
lhs.checked_div(rhs).map(encode_int)
}
}
BinaryOp::Mod => {
if rhs == 0 {
None
} else {
lhs.checked_rem(rhs).map(encode_int)
}
}
_ => None,
}
}
fn apply_float_arithmetic(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<Vec<u8>> {
use uni_common::cypher_value_codec::encode_float;
let result = match op {
BinaryOp::Add => lhs + rhs,
BinaryOp::Sub => lhs - rhs,
BinaryOp::Mul => lhs * rhs,
        BinaryOp::Div => lhs / rhs,
        BinaryOp::Mod => lhs % rhs,
_ => return None,
};
Some(encode_float(result))
}
fn cv_arithmetic_int(bytes: &[u8], rhs: i64, op: &BinaryOp) -> Option<Vec<u8>> {
use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
match peek_tag(bytes)? {
TAG_INT => apply_int_arithmetic(decode_int(bytes)?, rhs, op),
TAG_FLOAT => apply_float_arithmetic(decode_float(bytes)?, rhs as f64, op),
_ => None,
}
}
fn cv_arithmetic_float(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<Vec<u8>> {
let lhs = cv_bytes_as_f64(bytes)?;
apply_float_arithmetic(lhs, rhs, op)
}
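// Fast path for arithmetic, mirroring `try_fast_compare`: handles
// LargeBinary-vs-Int64, LargeBinary-vs-Float64, and Int64-vs-Int64 array
// pairs directly, encoding each result back into the codec format. Integer
// overflow and division by zero produce null rows here; None defers the
// whole call to the generic `eval_binary_op` path.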
fn try_fast_arithmetic(
lhs: &ColumnarValue,
rhs: &ColumnarValue,
op: &BinaryOp,
) -> Option<ColumnarValue> {
use arrow_array::builder::LargeBinaryBuilder;
let (lhs_arr, rhs_arr) = match (lhs, rhs) {
(ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
_ => return None,
};
match (lhs_arr.data_type(), rhs_arr.data_type()) {
(DataType::LargeBinary, DataType::Int64) => {
let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
let mut builder = LargeBinaryBuilder::new();
for i in 0..lb_arr.len() {
if lb_arr.is_null(i) || int_arr.is_null(i) {
builder.append_null();
} else if let Some(bytes) = cv_arithmetic_int(lb_arr.value(i), int_arr.value(i), op)
{
builder.append_value(&bytes);
} else {
builder.append_null();
}
}
Some(ColumnarValue::Array(Arc::new(builder.finish())))
}
(DataType::LargeBinary, DataType::Float64) => {
let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
let mut builder = LargeBinaryBuilder::new();
for i in 0..lb_arr.len() {
if lb_arr.is_null(i) || float_arr.is_null(i) {
builder.append_null();
} else if let Some(bytes) =
cv_arithmetic_float(lb_arr.value(i), float_arr.value(i), op)
{
builder.append_value(&bytes);
} else {
builder.append_null();
}
}
Some(ColumnarValue::Array(Arc::new(builder.finish())))
}
(DataType::Int64, DataType::Int64) => {
let lhs_int = lhs_arr.as_any().downcast_ref::<Int64Array>()?;
let rhs_int = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
let mut builder = LargeBinaryBuilder::new();
for i in 0..lhs_int.len() {
if lhs_int.is_null(i) || rhs_int.is_null(i) {
builder.append_null();
} else if let Some(bytes) =
apply_int_arithmetic(lhs_int.value(i), rhs_int.value(i), op)
{
builder.append_value(&bytes);
} else {
builder.append_null();
}
}
Some(ColumnarValue::Array(Arc::new(builder.finish())))
}
        _ => None,
    }
}
#[derive(Debug)]
struct CypherArithmeticUdf {
name: String,
op: BinaryOp,
signature: Signature,
}
impl CypherArithmeticUdf {
fn new(name: &str, op: BinaryOp) -> Self {
Self {
name: name.to_string(),
op,
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl PartialEq for CypherArithmeticUdf {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}
impl Eq for CypherArithmeticUdf {}
impl std::hash::Hash for CypherArithmeticUdf {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
}
}
impl ScalarUDFImpl for CypherArithmeticUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        Ok(DataType::LargeBinary)
    }
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(format!(
"{}(): requires 2 arguments",
self.name
)));
}
if let Some(result) = try_fast_arithmetic(&args.args[0], &args.args[1], &self.op) {
return Ok(result);
}
let output_type = DataType::LargeBinary;
invoke_cypher_udf(args, &output_type, |val_args| {
crate::query::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
})
}
}
pub fn create_cypher_xor_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherXorUdf::new())
}
#[derive(Debug)]
struct CypherXorUdf {
signature: Signature,
}
impl CypherXorUdf {
fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherXorUdf);
impl ScalarUDFImpl for CypherXorUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_xor"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = DataType::Boolean;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_xor(): requires 2 arguments".to_string(),
));
}
let coerce_bool = |v: &Value| -> Value {
match v {
Value::String(s) if s == "true" => Value::Bool(true),
Value::String(s) if s == "false" => Value::Bool(false),
other => other.clone(),
}
};
let left = coerce_bool(&val_args[0]);
let right = coerce_bool(&val_args[1]);
crate::query::expr_eval::eval_binary_op(&left, &BinaryOp::Xor, &right)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
})
}
}
pub fn create_cv_to_bool_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CvToBoolUdf::new())
}
#[derive(Debug)]
struct CvToBoolUdf {
signature: Signature,
}
impl CvToBoolUdf {
fn new() -> Self {
Self {
signature: Signature::exact(vec![DataType::LargeBinary], Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CvToBoolUdf);
impl ScalarUDFImpl for CvToBoolUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cv_to_bool"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_cv_to_bool() requires exactly 1 argument".to_string(),
));
}
match &args.args[0] {
ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
let b = match peek_tag(bytes) {
Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
Some(TAG_NULL) => false,
                    _ => false, // non-boolean payloads coerce to false
                };
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))))
}
ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
ColumnarValue::Array(arr) => {
let lb_arr = arr
.as_any()
.downcast_ref::<arrow_array::LargeBinaryArray>()
.ok_or_else(|| {
datafusion::error::DataFusionError::Execution(format!(
"_cv_to_bool(): expected LargeBinary array, got {:?}",
arr.data_type()
))
})?;
let mut builder = arrow_array::builder::BooleanBuilder::with_capacity(lb_arr.len());
use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
for i in 0..lb_arr.len() {
if lb_arr.is_null(i) {
builder.append_null();
} else {
let bytes = lb_arr.value(i);
let b = match peek_tag(bytes) {
Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
Some(TAG_NULL) => false,
                            _ => false, // non-boolean payloads coerce to false
                        };
builder.append_value(b);
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
}
}
pub fn create_cypher_size_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherSizeUdf::new())
}
#[derive(Debug)]
struct CypherSizeUdf {
signature: Signature,
}
impl CypherSizeUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherSizeUdf);
impl ScalarUDFImpl for CypherSizeUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_size"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Int64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_size() requires exactly 1 argument".to_string(),
));
}
match &args.args[0] {
ColumnarValue::Scalar(scalar) => {
let result = cypher_size_scalar(scalar)?;
Ok(ColumnarValue::Scalar(result))
}
ColumnarValue::Array(arr) => {
let mut results: Vec<Option<i64>> = Vec::with_capacity(arr.len());
for i in 0..arr.len() {
if arr.is_null(i) {
results.push(None);
} else {
let scalar = ScalarValue::try_from_array(arr, i)?;
match cypher_size_scalar(&scalar)? {
ScalarValue::Int64(v) => results.push(v),
_ => results.push(None),
}
}
}
let arr: ArrayRef = Arc::new(arrow_array::Int64Array::from(results));
Ok(ColumnarValue::Array(arr))
}
}
}
}
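// size()/length() semantics for a single scalar: character count for
// strings, element count for lists and maps, relationship count for
// path-shaped structs, an error for nodes and relationships, null for null
// inputs, and an error for any other Arrow type.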
fn cypher_size_scalar(scalar: &ScalarValue) -> DFResult<ScalarValue> {
match scalar {
ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
}
ScalarValue::List(arr) => {
if arr.is_empty() || arr.is_null(0) {
Ok(ScalarValue::Int64(None))
} else {
Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
}
}
ScalarValue::LargeList(arr) => {
if arr.is_empty() || arr.is_null(0) {
Ok(ScalarValue::Int64(None))
} else {
Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
}
}
ScalarValue::LargeBinary(Some(b)) => {
if let Ok(uni_val) = uni_common::cypher_value_codec::decode(b) {
match &uni_val {
uni_common::Value::Node(_) => {
Err(datafusion::error::DataFusionError::Execution(
"TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
))
}
uni_common::Value::Edge(_) => {
Err(datafusion::error::DataFusionError::Execution(
"TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
))
}
_ => {
let json_val: serde_json::Value = uni_val.into();
match json_val {
serde_json::Value::Array(arr) => Ok(ScalarValue::Int64(Some(arr.len() as i64))),
serde_json::Value::String(s) => {
Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
}
serde_json::Value::Object(m) => Ok(ScalarValue::Int64(Some(m.len() as i64))),
_ => Ok(ScalarValue::Int64(None)),
}
}
}
} else {
Ok(ScalarValue::Int64(None))
}
}
ScalarValue::Map(arr) => {
if arr.is_empty() || arr.is_null(0) {
Ok(ScalarValue::Int64(None))
} else {
Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
}
}
ScalarValue::Struct(arr) => {
if arr.is_null(0) {
Ok(ScalarValue::Int64(None))
} else {
let schema = arr.fields();
let field_names: Vec<&str> = schema.iter().map(|f| f.name().as_str()).collect();
if field_names.contains(&"_vid") && !field_names.contains(&"relationships") {
return Err(datafusion::error::DataFusionError::Execution(
"TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
));
}
if field_names.contains(&"_eid")
|| (field_names.contains(&"_src") && field_names.contains(&"_dst"))
{
return Err(datafusion::error::DataFusionError::Execution(
"TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
));
}
if let Some((rels_idx, _)) = schema
.iter()
.enumerate()
.find(|(_, f)| f.name() == "relationships")
{
let rels_col = arr.column(rels_idx);
if let Some(list_arr) =
rels_col.as_any().downcast_ref::<arrow_array::ListArray>()
{
if list_arr.is_null(0) {
Ok(ScalarValue::Int64(Some(0)))
} else {
Ok(ScalarValue::Int64(Some(list_arr.value(0).len() as i64)))
}
} else {
Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
}
} else {
Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
}
}
}
ScalarValue::Null
| ScalarValue::Utf8(None)
| ScalarValue::LargeUtf8(None)
| ScalarValue::LargeBinary(None) => Ok(ScalarValue::Int64(None)),
other => Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_size(): unsupported type {other:?}"
))),
}
}
pub fn create_cypher_list_compare_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherListCompareUdf::new())
}
#[derive(Debug)]
struct CypherListCompareUdf {
signature: Signature,
}
impl CypherListCompareUdf {
fn new() -> Self {
Self {
signature: Signature::any(3, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherListCompareUdf);
impl ScalarUDFImpl for CypherListCompareUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_list_compare"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = DataType::Boolean;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.len() != 3 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_list_compare(): requires 3 arguments (left, right, op)".to_string(),
));
}
let left = &val_args[0];
let right = &val_args[1];
let op_str = match &val_args[2] {
Value::String(s) => s.as_str(),
_ => {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_list_compare(): op must be a string".to_string(),
));
}
};
let (left_items, right_items) = match (left, right) {
(Value::List(l), Value::List(r)) => (l, r),
(Value::Null, _) | (_, Value::Null) => return Ok(Value::Null),
_ => {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_list_compare(): both arguments must be lists".to_string(),
));
}
};
let cmp = cypher_list_cmp(left_items, right_items);
let result = match (op_str, cmp) {
(_, None) => Value::Null,
("lt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Less),
("lteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Greater),
("gt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Greater),
("gteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Less),
_ => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_list_compare(): unknown op '{}'",
op_str
)));
}
};
Ok(result)
})
}
}
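// _map_project builds a map from alternating key/value arguments. The
// sentinel key "__all__" splats every property of a map, node, or edge into
// the result (skipping underscore-prefixed internal keys for maps), while
// ordinary keys insert their paired value directly.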
pub fn create_map_project_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(MapProjectUdf::new())
}
#[derive(Debug)]
struct MapProjectUdf {
signature: Signature,
}
impl MapProjectUdf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(MapProjectUdf);
impl ScalarUDFImpl for MapProjectUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_map_project"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
let mut result_map = std::collections::HashMap::new();
let mut i = 0;
while i + 1 < val_args.len() {
let key = &val_args[i];
let value = &val_args[i + 1];
if let Some(k) = key.as_str() {
if k == "__all__" {
match value {
Value::Map(map) => {
for (mk, mv) in map {
if !mk.starts_with('_') {
result_map.insert(mk.clone(), mv.clone());
}
}
}
Value::Node(node) => {
for (pk, pv) in &node.properties {
result_map.insert(pk.clone(), pv.clone());
}
}
Value::Edge(edge) => {
for (pk, pv) in &edge.properties {
result_map.insert(pk.clone(), pv.clone());
}
}
_ => {}
}
} else {
result_map.insert(k.to_string(), value.clone());
}
}
i += 2;
}
Ok(Value::Map(result_map))
})
}
}
pub fn create_make_cypher_list_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(MakeCypherListUdf::new())
}
#[derive(Debug)]
struct MakeCypherListUdf {
signature: Signature,
}
impl MakeCypherListUdf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(MakeCypherListUdf);
impl ScalarUDFImpl for MakeCypherListUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_make_cypher_list"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let output_type = self.return_type(&[])?;
invoke_cypher_udf(args, &output_type, |val_args| {
Ok(Value::List(val_args.to_vec()))
})
}
}
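// _cypher_in implements three-valued IN: a null list yields null; a null
// element yields false against an empty list and null otherwise; a matching
// element yields true; and if no element matches but some comparison was
// null, the result is null rather than false.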
pub fn create_cypher_in_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherInUdf::new())
}
#[derive(Debug)]
struct CypherInUdf {
signature: Signature,
}
impl CypherInUdf {
fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherInUdf);
impl ScalarUDFImpl for CypherInUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_in"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::Boolean, |vals| {
if vals.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_in(): requires 2 arguments".to_string(),
));
}
let element = &vals[0];
let list_val = &vals[1];
if list_val.is_null() {
return Ok(Value::Null);
}
let items = match list_val {
Value::List(items) => items.as_slice(),
_ => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_in(): second argument must be a list, got {:?}",
list_val
)));
}
};
if element.is_null() {
return if items.is_empty() {
Ok(Value::Bool(false))
} else {
                    Ok(Value::Null)
                };
}
let mut has_null = false;
for item in items {
match cypher_eq(element, item) {
Some(true) => return Ok(Value::Bool(true)),
None => has_null = true,
Some(false) => {}
}
}
if has_null {
                Ok(Value::Null)
            } else {
Ok(Value::Bool(false))
}
})
}
}
pub fn create_cypher_list_concat_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherListConcatUdf::new())
}
#[derive(Debug)]
struct CypherListConcatUdf {
signature: Signature,
}
impl CypherListConcatUdf {
fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherListConcatUdf);
impl ScalarUDFImpl for CypherListConcatUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_list_concat"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_list_concat(): requires 2 arguments".to_string(),
));
}
if vals[0].is_null() || vals[1].is_null() {
return Ok(Value::Null);
}
match (&vals[0], &vals[1]) {
(Value::List(left), Value::List(right)) => {
let mut result = left.clone();
result.extend(right.iter().cloned());
Ok(Value::List(result))
}
(Value::List(list), elem) => {
let mut result = list.clone();
result.push(elem.clone());
Ok(Value::List(result))
}
(elem, Value::List(list)) => {
let mut result = vec![elem.clone()];
result.extend(list.iter().cloned());
Ok(Value::List(result))
}
_ => {
crate::query::expr_eval::eval_binary_op(
&vals[0],
&uni_cypher::ast::BinaryOp::Add,
&vals[1],
)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
}
}
})
}
}
pub fn create_cypher_list_append_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherListAppendUdf::new())
}
#[derive(Debug)]
struct CypherListAppendUdf {
signature: Signature,
}
impl CypherListAppendUdf {
fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherListAppendUdf);
impl ScalarUDFImpl for CypherListAppendUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_list_append"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_list_append(): requires 2 arguments".to_string(),
));
}
let left = &vals[0];
let right = &vals[1];
if left.is_null() || right.is_null() {
return Ok(Value::Null);
}
match (left, right) {
(Value::List(list), elem) => {
let mut result = list.clone();
result.push(elem.clone());
Ok(Value::List(result))
}
(elem, Value::List(list)) => {
let mut result = vec![elem.clone()];
result.extend(list.iter().cloned());
Ok(Value::List(result))
}
_ => Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_list_append(): at least one argument must be a list, got {:?} and {:?}",
left, right
))),
}
})
}
}
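// _cypher_list_slice mirrors Cypher's list[start..end] slicing: negative
// indices count from the end, out-of-range bounds clamp to the list, a
// sentinel end of i64::MAX means "to the end of the list", and an empty
// list is returned when start >= end.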
pub fn create_cypher_list_slice_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherListSliceUdf::new())
}
#[derive(Debug)]
struct CypherListSliceUdf {
signature: Signature,
}
impl CypherListSliceUdf {
fn new() -> Self {
Self {
signature: Signature::any(3, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherListSliceUdf);
impl ScalarUDFImpl for CypherListSliceUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_list_slice"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 3 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_list_slice(): requires 3 arguments (list, start, end)".to_string(),
));
}
if vals[0].is_null() {
return Ok(Value::Null);
}
let list = match &vals[0] {
Value::List(l) => l,
_ => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_list_slice(): first argument must be a list, got {:?}",
vals[0]
)));
}
};
if vals[1].is_null() || vals[2].is_null() {
return Ok(Value::Null);
}
let len = list.len() as i64;
let raw_start = match &vals[1] {
Value::Int(i) => *i,
_ => 0,
};
let raw_end = match &vals[2] {
Value::Int(i) => *i,
_ => len,
};
let start = if raw_start < 0 {
(len + raw_start).max(0) as usize
} else {
(raw_start).min(len) as usize
};
let end = if raw_end == i64::MAX {
len as usize
} else if raw_end < 0 {
(len + raw_end).max(0) as usize
} else {
(raw_end).min(len) as usize
};
if start >= end {
return Ok(Value::List(vec![]));
}
Ok(Value::List(list[start..end.min(list.len())].to_vec()))
})
}
}
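/// `_cypher_reverse`: reverses a string by Unicode scalar value (`char`),
/// keeping multi-byte UTF-8 intact, or reverses a list element-wise.
/// NULL yields NULL; other types are an execution error.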
pub fn create_cypher_reverse_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherReverseUdf::new())
}
#[derive(Debug)]
struct CypherReverseUdf {
signature: Signature,
}
impl CypherReverseUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherReverseUdf);
impl ScalarUDFImpl for CypherReverseUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_reverse"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_reverse(): requires exactly 1 argument".to_string(),
));
}
match &vals[0] {
Value::Null => Ok(Value::Null),
Value::String(s) => Ok(Value::String(s.chars().rev().collect())),
Value::List(l) => {
let mut reversed = l.clone();
reversed.reverse();
Ok(Value::List(reversed))
}
other => Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_reverse(): expected string or list, got {:?}",
other
))),
}
})
}
}
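/// `_cypher_substring`: `substring(s, start [, length])` with 0-based,
/// char-based indexing. `start` is clamped to the string bounds; a negative
/// `length` raises the Cypher NegativeIntegerArgument error.
///
/// Example: `substring("hello", 1, 3)` -> chars[1..4] -> `"ell"`.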
pub fn create_cypher_substring_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherSubstringUdf::new())
}
#[derive(Debug)]
struct CypherSubstringUdf {
signature: Signature,
}
impl CypherSubstringUdf {
fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherSubstringUdf);
impl ScalarUDFImpl for CypherSubstringUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_substring"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Utf8)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::Utf8, |vals| {
if vals.len() < 2 || vals.len() > 3 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_substring(): requires 2 or 3 arguments".to_string(),
));
}
if vals.iter().any(|v| v.is_null()) {
return Ok(Value::Null);
}
let s = match &vals[0] {
Value::String(s) => s.as_str(),
other => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_substring(): first argument must be a string, got {:?}",
other
)));
}
};
let start = match &vals[1] {
Value::Int(i) => *i,
other => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_substring(): second argument must be an integer, got {:?}",
other
)));
}
};
let chars: Vec<char> = s.chars().collect();
let len = chars.len() as i64;
let start_idx = start.max(0).min(len) as usize;
let end_idx = if vals.len() == 3 {
let length = match &vals[2] {
Value::Int(i) => *i,
other => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_substring(): third argument must be an integer, got {:?}",
other
)));
}
};
if length < 0 {
return Err(datafusion::error::DataFusionError::Execution(
"ArgumentError: NegativeIntegerArgument - substring length must be non-negative".to_string(),
));
}
(start_idx as i64 + length).min(len) as usize
} else {
len as usize
};
Ok(Value::String(chars[start_idx..end_idx].iter().collect()))
})
}
}
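/// `_cypher_split`: splits a string on a literal delimiter via `str::split`,
/// so adjacent delimiters produce empty-string entries
/// (e.g. `split("a,,b", ",")` -> `["a", "", "b"]`). NULL propagates.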
pub fn create_cypher_split_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherSplitUdf::new())
}
#[derive(Debug)]
struct CypherSplitUdf {
signature: Signature,
}
impl CypherSplitUdf {
fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherSplitUdf);
impl ScalarUDFImpl for CypherSplitUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_split"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 2 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_split(): requires exactly 2 arguments".to_string(),
));
}
if vals.iter().any(|v| v.is_null()) {
return Ok(Value::Null);
}
let s = match &vals[0] {
Value::String(s) => s.clone(),
other => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_split(): first argument must be a string, got {:?}",
other
)));
}
};
let delimiter = match &vals[1] {
Value::String(d) => d.clone(),
other => {
return Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_split(): second argument must be a string, got {:?}",
other
)));
}
};
let parts: Vec<Value> = s
.split(&delimiter)
.map(|p| Value::String(p.to_string()))
.collect();
Ok(Value::List(parts))
})
}
}
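/// `_cypher_list_to_cv`: identity at the `Value` level. Its purpose,
/// presumably, is to let the planner coerce an arbitrary input column into
/// the LargeBinary CypherValue encoding (the actual conversion happens
/// inside `invoke_cypher_udf`). `_cypher_scalar_to_cv` below is the same
/// wrapper for scalar inputs.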
pub fn create_cypher_list_to_cv_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherListToCvUdf::new())
}
#[derive(Debug)]
struct CypherListToCvUdf {
signature: Signature,
}
impl CypherListToCvUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherListToCvUdf);
impl ScalarUDFImpl for CypherListToCvUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_list_to_cv"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_list_to_cv(): requires exactly 1 argument".to_string(),
));
}
Ok(vals[0].clone())
})
}
}
pub fn create_cypher_scalar_to_cv_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherScalarToCvUdf::new())
}
#[derive(Debug)]
struct CypherScalarToCvUdf {
signature: Signature,
}
impl CypherScalarToCvUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherScalarToCvUdf);
impl ScalarUDFImpl for CypherScalarToCvUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_scalar_to_cv"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_scalar_to_cv(): requires exactly 1 argument".to_string(),
));
}
Ok(vals[0].clone())
})
}
}
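/// `_cypher_tail`: returns the list without its first element. An empty
/// list yields an empty list (not NULL); NULL yields NULL.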
pub fn create_cypher_tail_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherTailUdf::new())
}
#[derive(Debug)]
struct CypherTailUdf {
signature: Signature,
}
impl CypherTailUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherTailUdf);
impl ScalarUDFImpl for CypherTailUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_tail"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_tail(): requires exactly 1 argument".to_string(),
));
}
match &vals[0] {
Value::Null => Ok(Value::Null),
Value::List(l) => {
if l.is_empty() {
Ok(Value::List(vec![]))
} else {
Ok(Value::List(l[1..].to_vec()))
}
}
other => Err(datafusion::error::DataFusionError::Execution(format!(
"_cypher_tail(): expected list, got {:?}",
other
))),
}
})
}
}
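/// `head`: first element of a list, or NULL when the list is empty or NULL.
/// `last` below is the symmetric accessor for the final element.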
pub fn create_cypher_head_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherHeadUdf::new())
}
#[derive(Debug)]
struct CypherHeadUdf {
signature: Signature,
}
impl CypherHeadUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherHeadUdf);
impl ScalarUDFImpl for CypherHeadUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"head"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"head(): requires exactly 1 argument".to_string(),
));
}
match &vals[0] {
Value::Null => Ok(Value::Null),
Value::List(l) => Ok(l.first().cloned().unwrap_or(Value::Null)),
other => Err(datafusion::error::DataFusionError::Execution(format!(
"head(): expected list, got {:?}",
other
))),
}
})
}
}
pub fn create_cypher_last_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(CypherLastUdf::new())
}
#[derive(Debug)]
struct CypherLastUdf {
signature: Signature,
}
impl CypherLastUdf {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherLastUdf);
impl ScalarUDFImpl for CypherLastUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"last"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
if vals.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"last(): requires exactly 1 argument".to_string(),
));
}
match &vals[0] {
Value::Null => Ok(Value::Null),
Value::List(l) => Ok(l.last().cloned().unwrap_or(Value::Null)),
other => Err(datafusion::error::DataFusionError::Execution(format!(
"last(): expected list, got {:?}",
other
))),
}
})
}
}
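/// Three-valued Cypher comparison: `None` means "incomparable" (a NULL
/// operand, or mixed incomparable types such as Int vs String). Lists
/// compare lexicographically element by element with length as the
/// tiebreak, and Int/Float compare numerically across the two types.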
fn cypher_list_cmp(left: &[Value], right: &[Value]) -> Option<std::cmp::Ordering> {
let min_len = left.len().min(right.len());
for i in 0..min_len {
let cmp = cypher_value_cmp(&left[i], &right[i])?;
if cmp != std::cmp::Ordering::Equal {
return Some(cmp);
}
}
Some(left.len().cmp(&right.len()))
}
fn cypher_value_cmp(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
match (a, b) {
(Value::Null, Value::Null) => Some(std::cmp::Ordering::Equal),
(Value::Null, _) | (_, Value::Null) => None,
(Value::Int(l), Value::Int(r)) => Some(l.cmp(r)),
(Value::Float(l), Value::Float(r)) => l.partial_cmp(r),
(Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r),
(Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)),
(Value::String(l), Value::String(r)) => Some(l.cmp(r)),
(Value::Bool(l), Value::Bool(r)) => Some(l.cmp(r)),
(Value::List(l), Value::List(r)) => cypher_list_cmp(l, r),
        _ => None,
    }
}
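/// `_cypher_to_float64`: a fast-path numeric cast that works directly on
/// Arrow scalars/arrays, avoiding a per-row CypherValue decode for Int64 and
/// Float64 columns. LargeBinary inputs go through `cv_bytes_as_f64`, and
/// anything unconvertible becomes NULL rather than an error.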
struct CypherToFloat64Udf {
signature: Signature,
}
impl CypherToFloat64Udf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(CypherToFloat64Udf);
impl std::fmt::Debug for CypherToFloat64Udf {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CypherToFloat64Udf").finish()
}
}
impl ScalarUDFImpl for CypherToFloat64Udf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_to_float64"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Float64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
if args.args.len() != 1 {
return Err(datafusion::error::DataFusionError::Execution(
"_cypher_to_float64 requires exactly 1 argument".into(),
));
}
match &args.args[0] {
ColumnarValue::Scalar(scalar) => {
let f = match scalar {
ScalarValue::LargeBinary(Some(bytes)) => cv_bytes_as_f64(bytes),
ScalarValue::Int64(Some(i)) => Some(*i as f64),
ScalarValue::Int32(Some(i)) => Some(*i as f64),
ScalarValue::Float64(Some(f)) => Some(*f),
ScalarValue::Float32(Some(f)) => Some(*f as f64),
_ => None,
};
Ok(ColumnarValue::Scalar(ScalarValue::Float64(f)))
}
ColumnarValue::Array(arr) => {
let len = arr.len();
let mut builder = arrow::array::Float64Builder::with_capacity(len);
match arr.data_type() {
DataType::LargeBinary => {
let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
for i in 0..len {
if lb.is_null(i) {
builder.append_null();
} else {
match cv_bytes_as_f64(lb.value(i)) {
Some(f) => builder.append_value(f),
None => builder.append_null(),
}
}
}
}
DataType::Int64 => {
let int_arr = arr.as_any().downcast_ref::<Int64Array>().unwrap();
for i in 0..len {
if int_arr.is_null(i) {
builder.append_null();
} else {
builder.append_value(int_arr.value(i) as f64);
}
}
}
DataType::Float64 => {
let f_arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
for i in 0..len {
if f_arr.is_null(i) {
builder.append_null();
} else {
builder.append_value(f_arr.value(i));
}
}
}
_ => {
for _ in 0..len {
builder.append_null();
}
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
}
}
fn create_cypher_to_float64_udf() -> ScalarUDF {
ScalarUDF::from(CypherToFloat64Udf::new())
}
pub(crate) fn cypher_to_float64_expr(
arg: datafusion::logical_expr::Expr,
) -> datafusion::logical_expr::Expr {
datafusion::logical_expr::Expr::ScalarFunction(
datafusion::logical_expr::expr::ScalarFunction::new_udf(
Arc::new(create_cypher_to_float64_udf()),
vec![arg],
),
)
}
pub(crate) fn cypher_to_float64_udf() -> datafusion::logical_expr::ScalarUDF {
create_cypher_to_float64_udf()
}
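/// Total order across mixed types for `_cypher_min`/`_cypher_max`: values
/// are ranked by type first (Null < List < String < Bool < numbers < rest),
/// then compared within the type. The accumulators skip NULLs before ever
/// calling this, so the Null rank only matters defensively.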
fn cypher_type_rank(val: &Value) -> u8 {
match val {
Value::Null => 0,
Value::List(_) => 1,
Value::String(_) => 2,
Value::Bool(_) => 3,
Value::Int(_) | Value::Float(_) => 4,
        _ => 5,
    }
}
fn cypher_cross_type_cmp(a: &Value, b: &Value) -> std::cmp::Ordering {
use std::cmp::Ordering;
let ra = cypher_type_rank(a);
let rb = cypher_type_rank(b);
if ra != rb {
return ra.cmp(&rb);
}
match (a, b) {
(Value::Int(l), Value::Int(r)) => l.cmp(r),
(Value::Float(l), Value::Float(r)) => l.partial_cmp(r).unwrap_or(Ordering::Equal),
(Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r).unwrap_or(Ordering::Equal),
(Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)).unwrap_or(Ordering::Equal),
(Value::String(l), Value::String(r)) => l.cmp(r),
(Value::Bool(l), Value::Bool(r)) => l.cmp(r),
(Value::List(l), Value::List(r)) => cypher_list_cmp(l, r).unwrap_or(Ordering::Equal),
_ => Ordering::Equal,
}
}
fn scalar_binary_to_value(bytes: &[u8]) -> Value {
uni_common::cypher_value_codec::decode(bytes).unwrap_or(Value::Null)
}
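// Aggregate UDAFs: Cypher min/max/sum/collect/percentileDisc/percentileCont
// built on DataFusion's Accumulator API.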
use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, AggregateUDFImpl};
#[derive(Debug, Clone)]
struct CypherMinMaxUdaf {
name: String,
signature: Signature,
is_max: bool,
}
impl CypherMinMaxUdaf {
fn new(is_max: bool) -> Self {
let name = if is_max { "_cypher_max" } else { "_cypher_min" };
Self {
name: name.to_string(),
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
is_max,
}
}
}
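// Manual PartialEq/Hash rather than `impl_udf_eq_hash!`: min and max share
// an identical signature, so the macro's signature-based equality would make
// `_cypher_min` and `_cypher_max` compare equal; comparing names keeps them
// distinct.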
impl PartialEq for CypherMinMaxUdaf {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
}
}
impl Eq for CypherMinMaxUdaf {}
impl Hash for CypherMinMaxUdaf {
fn hash<H: Hasher>(&self, state: &mut H) {
self.name.hash(state);
}
}
impl AggregateUDFImpl for CypherMinMaxUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, args: &[DataType]) -> DFResult<DataType> {
Ok(args.first().cloned().unwrap_or(DataType::LargeBinary))
}
fn accumulator(
&self,
acc_args: datafusion::logical_expr::function::AccumulatorArgs,
) -> DFResult<Box<dyn DfAccumulator>> {
Ok(Box::new(CypherMinMaxAccumulator {
current: None,
is_max: self.is_max,
return_type: acc_args.return_field.data_type().clone(),
}))
}
fn state_fields(
&self,
args: datafusion::logical_expr::function::StateFieldsArgs,
) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
Ok(vec![Arc::new(arrow::datatypes::Field::new(
args.name,
DataType::LargeBinary,
true,
))])
}
}
#[derive(Debug)]
struct CypherMinMaxAccumulator {
current: Option<Value>,
is_max: bool,
return_type: DataType,
}
impl DfAccumulator for CypherMinMaxAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
let arr = &values[0];
match arr.data_type() {
DataType::LargeBinary => {
let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
for i in 0..lb.len() {
if lb.is_null(i) {
continue;
}
let val = scalar_binary_to_value(lb.value(i));
if val.is_null() {
continue;
}
self.current = Some(match self.current.take() {
None => val,
Some(cur) => {
let ord = cypher_cross_type_cmp(&val, &cur);
if (self.is_max && ord == std::cmp::Ordering::Greater)
|| (!self.is_max && ord == std::cmp::Ordering::Less)
{
val
} else {
cur
}
}
});
}
}
_ => {
for i in 0..arr.len() {
if arr.is_null(i) {
continue;
}
let sv = ScalarValue::try_from_array(arr, i).map_err(|e| {
datafusion::error::DataFusionError::Execution(e.to_string())
})?;
let val = scalar_to_value(&sv)?;
if val.is_null() {
continue;
}
self.current = Some(match self.current.take() {
None => val,
Some(cur) => {
let ord = cypher_cross_type_cmp(&val, &cur);
if (self.is_max && ord == std::cmp::Ordering::Greater)
|| (!self.is_max && ord == std::cmp::Ordering::Less)
{
val
} else {
cur
}
}
});
}
}
}
Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
match &self.current {
None => {
ScalarValue::try_from(&self.return_type)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
}
Some(val) => {
if matches!(self.return_type, DataType::LargeBinary) {
let bytes = uni_common::cypher_value_codec::encode(val);
return Ok(ScalarValue::LargeBinary(Some(bytes)));
}
match val {
Value::Int(i) => match &self.return_type {
DataType::Int64 => Ok(ScalarValue::Int64(Some(*i))),
DataType::UInt64 => Ok(ScalarValue::UInt64(Some(*i as u64))),
_ => {
let bytes = uni_common::cypher_value_codec::encode(val);
Ok(ScalarValue::LargeBinary(Some(bytes)))
}
},
Value::Float(f) => match &self.return_type {
DataType::Float64 => Ok(ScalarValue::Float64(Some(*f))),
_ => {
let bytes = uni_common::cypher_value_codec::encode(val);
Ok(ScalarValue::LargeBinary(Some(bytes)))
}
},
Value::String(s) => match &self.return_type {
DataType::Utf8 => Ok(ScalarValue::Utf8(Some(s.clone()))),
DataType::LargeUtf8 => Ok(ScalarValue::LargeUtf8(Some(s.clone()))),
_ => {
let bytes = uni_common::cypher_value_codec::encode(val);
Ok(ScalarValue::LargeBinary(Some(bytes)))
}
},
Value::Bool(b) => match &self.return_type {
DataType::Boolean => Ok(ScalarValue::Boolean(Some(*b))),
_ => {
let bytes = uni_common::cypher_value_codec::encode(val);
Ok(ScalarValue::LargeBinary(Some(bytes)))
}
},
_ => {
let bytes = uni_common::cypher_value_codec::encode(val);
Ok(ScalarValue::LargeBinary(Some(bytes)))
}
}
}
}
}
    fn size(&self) -> usize {
        // Rough estimate: charge a flat 64 bytes for the held Value.
        std::mem::size_of_val(self) + self.current.as_ref().map_or(0, |_| 64)
    }
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
self.update_batch(states)
}
}
pub(crate) fn create_cypher_min_udaf() -> AggregateUDF {
AggregateUDF::from(CypherMinMaxUdaf::new(false))
}
pub(crate) fn create_cypher_max_udaf() -> AggregateUDF {
AggregateUDF::from(CypherMinMaxUdaf::new(true))
}
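/// `_cypher_sum` keeps two running totals: an `i64` sum (with wrapping add)
/// and an `f64` sum. If every non-null input was an integer the result is
/// `Value::Int(int_sum)`, otherwise `Value::Float(sum)`; with no non-null
/// inputs it returns a NULL LargeBinary.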
#[derive(Debug, Clone)]
struct CypherSumUdaf {
signature: Signature,
}
impl CypherSumUdaf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
}
impl PartialEq for CypherSumUdaf {
fn eq(&self, other: &Self) -> bool {
self.signature == other.signature
}
}
impl Eq for CypherSumUdaf {}
impl Hash for CypherSumUdaf {
fn hash<H: Hasher>(&self, state: &mut H) {
self.name().hash(state);
}
}
impl AggregateUDFImpl for CypherSumUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_sum"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn accumulator(
&self,
_acc_args: datafusion::logical_expr::function::AccumulatorArgs,
) -> DFResult<Box<dyn DfAccumulator>> {
Ok(Box::new(CypherSumAccumulator {
sum: 0.0,
all_ints: true,
int_sum: 0i64,
has_value: false,
}))
}
fn state_fields(
&self,
args: datafusion::logical_expr::function::StateFieldsArgs,
) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
Ok(vec![
Arc::new(arrow::datatypes::Field::new(
format!("{}_sum", args.name),
DataType::Float64,
true,
)),
Arc::new(arrow::datatypes::Field::new(
format!("{}_int_sum", args.name),
DataType::Int64,
true,
)),
Arc::new(arrow::datatypes::Field::new(
format!("{}_all_ints", args.name),
DataType::Boolean,
true,
)),
Arc::new(arrow::datatypes::Field::new(
format!("{}_has_value", args.name),
DataType::Boolean,
true,
)),
])
}
}
#[derive(Debug)]
struct CypherSumAccumulator {
sum: f64,
all_ints: bool,
int_sum: i64,
has_value: bool,
}
impl DfAccumulator for CypherSumAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
let arr = &values[0];
for i in 0..arr.len() {
if arr.is_null(i) {
continue;
}
match arr.data_type() {
DataType::LargeBinary => {
let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
let bytes = lb.value(i);
use uni_common::cypher_value_codec::{
TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag,
};
match peek_tag(bytes) {
Some(TAG_INT) => {
if let Some(v) = decode_int(bytes) {
self.sum += v as f64;
self.int_sum = self.int_sum.wrapping_add(v);
self.has_value = true;
}
}
Some(TAG_FLOAT) => {
if let Some(v) = decode_float(bytes) {
self.sum += v;
self.all_ints = false;
self.has_value = true;
}
}
                        _ => {}
                    }
}
DataType::Int64 => {
let a = arr.as_any().downcast_ref::<Int64Array>().unwrap();
let v = a.value(i);
self.sum += v as f64;
self.int_sum = self.int_sum.wrapping_add(v);
self.has_value = true;
}
DataType::Float64 => {
let a = arr.as_any().downcast_ref::<Float64Array>().unwrap();
self.sum += a.value(i);
self.all_ints = false;
self.has_value = true;
}
_ => {}
}
}
Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
if !self.has_value {
return Ok(ScalarValue::LargeBinary(None));
}
let val = if self.all_ints {
Value::Int(self.int_sum)
} else {
Value::Float(self.sum)
};
let bytes = uni_common::cypher_value_codec::encode(&val);
Ok(ScalarValue::LargeBinary(Some(bytes)))
}
fn size(&self) -> usize {
std::mem::size_of_val(self)
}
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::Float64(Some(self.sum)),
ScalarValue::Int64(Some(self.int_sum)),
ScalarValue::Boolean(Some(self.all_ints)),
ScalarValue::Boolean(Some(self.has_value)),
])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
let sum_arr = states[0].as_any().downcast_ref::<Float64Array>().unwrap();
let int_sum_arr = states[1].as_any().downcast_ref::<Int64Array>().unwrap();
let all_ints_arr = states[2].as_any().downcast_ref::<BooleanArray>().unwrap();
let has_value_arr = states[3].as_any().downcast_ref::<BooleanArray>().unwrap();
for i in 0..sum_arr.len() {
if !has_value_arr.is_null(i) && has_value_arr.value(i) {
self.sum += sum_arr.value(i);
self.int_sum = self.int_sum.wrapping_add(int_sum_arr.value(i));
if !all_ints_arr.value(i) {
self.all_ints = false;
}
self.has_value = true;
}
}
Ok(())
}
}
pub(crate) fn create_cypher_sum_udaf() -> AggregateUDF {
AggregateUDF::from(CypherSumUdaf::new())
}
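/// `_cypher_collect`: Cypher `collect()`, skipping NULLs. DISTINCT is
/// enforced by comparing `Value::to_string()` representations, which is
/// O(n^2) in the number of collected values but avoids requiring `Eq`/`Hash`
/// on `Value`.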
#[derive(Debug, Clone)]
struct CypherCollectUdaf {
signature: Signature,
}
impl CypherCollectUdaf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
}
}
}
impl PartialEq for CypherCollectUdaf {
fn eq(&self, other: &Self) -> bool {
self.signature == other.signature
}
}
impl Eq for CypherCollectUdaf {}
impl Hash for CypherCollectUdaf {
fn hash<H: Hasher>(&self, state: &mut H) {
self.name().hash(state);
}
}
impl AggregateUDFImpl for CypherCollectUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"_cypher_collect"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
Ok(DataType::LargeBinary)
}
fn accumulator(
&self,
acc_args: datafusion::logical_expr::function::AccumulatorArgs,
) -> DFResult<Box<dyn DfAccumulator>> {
Ok(Box::new(CypherCollectAccumulator {
values: Vec::new(),
distinct: acc_args.is_distinct,
}))
}
fn state_fields(
&self,
args: datafusion::logical_expr::function::StateFieldsArgs,
) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
Ok(vec![Arc::new(arrow::datatypes::Field::new(
args.name,
DataType::LargeBinary,
true,
))])
}
}
#[derive(Debug)]
struct CypherCollectAccumulator {
values: Vec<Value>,
distinct: bool,
}
impl DfAccumulator for CypherCollectAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
let arr = &values[0];
for i in 0..arr.len() {
if arr.is_null(i) {
continue;
}
if let Some(struct_arr) = arr.as_any().downcast_ref::<arrow::array::StructArray>()
&& struct_arr.num_columns() > 0
&& struct_arr.column(0).is_null(i)
{
continue;
}
let sv = ScalarValue::try_from_array(arr, i)
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
let val = scalar_to_value(&sv)?;
if val.is_null() {
continue;
}
if self.distinct {
let repr = val.to_string();
if self.values.iter().any(|v| v.to_string() == repr) {
continue;
}
}
self.values.push(val);
}
Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
let val = Value::List(self.values.clone());
let bytes = uni_common::cypher_value_codec::encode(&val);
Ok(ScalarValue::LargeBinary(Some(bytes)))
}
    fn size(&self) -> usize {
        // Rough estimate: charge a flat 64 bytes per collected Value.
        std::mem::size_of_val(self) + self.values.len() * 64
    }
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
let arr = &states[0];
if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
for i in 0..lb.len() {
if lb.is_null(i) {
continue;
}
let val = scalar_binary_to_value(lb.value(i));
if let Value::List(items) = val {
for item in items {
if !item.is_null() {
if self.distinct {
let repr = item.to_string();
if self.values.iter().any(|v| v.to_string() == repr) {
continue;
}
}
self.values.push(item);
}
}
}
}
}
Ok(())
}
}
pub(crate) fn create_cypher_collect_udaf() -> AggregateUDF {
AggregateUDF::from(CypherCollectUdaf::new())
}
pub(crate) fn create_cypher_collect_expr(
arg: datafusion::logical_expr::Expr,
distinct: bool,
) -> datafusion::logical_expr::Expr {
let udaf = Arc::new(create_cypher_collect_udaf());
if distinct {
datafusion::logical_expr::Expr::AggregateFunction(
datafusion::logical_expr::expr::AggregateFunction::new_udf(
udaf,
vec![arg],
                true, // distinct
                None, // filter
                vec![], // order_by
                None, // null_treatment
),
)
} else {
udaf.call(vec![arg])
}
}
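/// `percentiledisc`: nearest-rank discrete percentile. Values are collected
/// and sorted, then the element at index `round(p * (n - 1))` is returned.
/// `p` must lie in [0.0, 1.0] or a Cypher NumberOutOfRange error is raised;
/// it defaults to 0.0 if never observed, and an empty input yields NULL.
///
/// Example: values [10, 20, 30] with p = 0.5 -> index round(1.0) = 1 -> 20.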
#[derive(Debug, Clone)]
struct CypherPercentileDiscUdaf {
signature: Signature,
}
impl CypherPercentileDiscUdaf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
}
}
}
impl PartialEq for CypherPercentileDiscUdaf {
fn eq(&self, other: &Self) -> bool {
self.signature == other.signature
}
}
impl Eq for CypherPercentileDiscUdaf {}
impl Hash for CypherPercentileDiscUdaf {
fn hash<H: Hasher>(&self, state: &mut H) {
self.name().hash(state);
}
}
impl AggregateUDFImpl for CypherPercentileDiscUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"percentiledisc"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Float64)
}
fn accumulator(
&self,
_acc_args: datafusion::logical_expr::function::AccumulatorArgs,
) -> DFResult<Box<dyn DfAccumulator>> {
Ok(Box::new(CypherPercentileDiscAccumulator {
values: Vec::new(),
percentile: None,
}))
}
fn state_fields(
&self,
args: datafusion::logical_expr::function::StateFieldsArgs,
) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
Ok(vec![
Arc::new(arrow::datatypes::Field::new(
format!("{}_values", args.name),
DataType::List(Arc::new(arrow::datatypes::Field::new(
"item",
DataType::Float64,
true,
))),
true,
)),
Arc::new(arrow::datatypes::Field::new(
format!("{}_percentile", args.name),
DataType::Float64,
true,
)),
])
}
}
#[derive(Debug)]
struct CypherPercentileDiscAccumulator {
values: Vec<f64>,
percentile: Option<f64>,
}
impl CypherPercentileDiscAccumulator {
fn extract_f64(arr: &ArrayRef, i: usize) -> Option<f64> {
if arr.is_null(i) {
return None;
}
match arr.data_type() {
DataType::LargeBinary => {
let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
cv_bytes_as_f64(lb.value(i))
}
DataType::Int64 => {
let a = arr.as_any().downcast_ref::<Int64Array>()?;
Some(a.value(i) as f64)
}
DataType::Float64 => {
let a = arr.as_any().downcast_ref::<Float64Array>()?;
Some(a.value(i))
}
DataType::Int32 => {
let a = arr.as_any().downcast_ref::<Int32Array>()?;
Some(a.value(i) as f64)
}
DataType::Float32 => {
let a = arr.as_any().downcast_ref::<Float32Array>()?;
Some(a.value(i) as f64)
}
_ => None,
}
}
fn extract_percentile(arr: &ArrayRef, i: usize) -> Option<f64> {
if arr.is_null(i) {
return None;
}
match arr.data_type() {
DataType::Float64 => {
let a = arr.as_any().downcast_ref::<Float64Array>()?;
Some(a.value(i))
}
DataType::Int64 => {
let a = arr.as_any().downcast_ref::<Int64Array>()?;
Some(a.value(i) as f64)
}
DataType::LargeBinary => {
let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
cv_bytes_as_f64(lb.value(i))
}
_ => None,
}
}
}
impl DfAccumulator for CypherPercentileDiscAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
let expr_arr = &values[0];
let pct_arr = &values[1];
for i in 0..expr_arr.len() {
if self.percentile.is_none()
&& let Some(p) = Self::extract_percentile(pct_arr, i)
{
if !(0.0..=1.0).contains(&p) {
return Err(datafusion::error::DataFusionError::Execution(
"ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
));
}
self.percentile = Some(p);
}
if let Some(f) = Self::extract_f64(expr_arr, i) {
self.values.push(f);
}
}
Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
let pct = match self.percentile {
Some(p) if !(0.0..=1.0).contains(&p) => {
return Err(datafusion::error::DataFusionError::Execution(
"ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
));
}
Some(p) => p,
None => 0.0,
};
if self.values.is_empty() {
return Ok(ScalarValue::Float64(None));
}
self.values
.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let n = self.values.len();
let idx = (pct * (n as f64 - 1.0)).round() as usize;
let idx = idx.min(n - 1);
let result = self.values[idx];
Ok(ScalarValue::Float64(Some(result)))
}
fn size(&self) -> usize {
std::mem::size_of_val(self) + self.values.capacity() * 8
}
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
let list_values: Vec<ScalarValue> = self
.values
.iter()
.map(|f| ScalarValue::Float64(Some(*f)))
.collect();
let list_scalar = ScalarValue::List(ScalarValue::new_list(
&list_values,
&DataType::Float64,
true,
));
Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
let list_arr = &states[0];
let pct_arr = &states[1];
if self.percentile.is_none()
&& let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
{
for i in 0..f64_arr.len() {
if !f64_arr.is_null(i) {
self.percentile = Some(f64_arr.value(i));
break;
}
}
}
if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
for i in 0..list_array.len() {
if list_array.is_null(i) {
continue;
}
let inner = list_array.value(i);
if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
for j in 0..f64_arr.len() {
if !f64_arr.is_null(j) {
self.values.push(f64_arr.value(j));
}
}
}
}
}
Ok(())
}
}
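/// `percentilecont`: continuous percentile with linear interpolation between
/// the two sorted neighbors of position `p * (n - 1)`.
///
/// Example: values [10, 20] with p = 0.25 -> pos 0.25 -> 10 + 0.25 * 10 = 12.5.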
#[derive(Debug, Clone)]
struct CypherPercentileContUdaf {
signature: Signature,
}
impl CypherPercentileContUdaf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
}
}
}
impl PartialEq for CypherPercentileContUdaf {
fn eq(&self, other: &Self) -> bool {
self.signature == other.signature
}
}
impl Eq for CypherPercentileContUdaf {}
impl Hash for CypherPercentileContUdaf {
fn hash<H: Hasher>(&self, state: &mut H) {
self.name().hash(state);
}
}
impl AggregateUDFImpl for CypherPercentileContUdaf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"percentilecont"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Float64)
}
fn accumulator(
&self,
_acc_args: datafusion::logical_expr::function::AccumulatorArgs,
) -> DFResult<Box<dyn DfAccumulator>> {
Ok(Box::new(CypherPercentileContAccumulator {
values: Vec::new(),
percentile: None,
}))
}
fn state_fields(
&self,
args: datafusion::logical_expr::function::StateFieldsArgs,
) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
Ok(vec![
Arc::new(arrow::datatypes::Field::new(
format!("{}_values", args.name),
DataType::List(Arc::new(arrow::datatypes::Field::new(
"item",
DataType::Float64,
true,
))),
true,
)),
Arc::new(arrow::datatypes::Field::new(
format!("{}_percentile", args.name),
DataType::Float64,
true,
)),
])
}
}
#[derive(Debug)]
struct CypherPercentileContAccumulator {
values: Vec<f64>,
percentile: Option<f64>,
}
impl DfAccumulator for CypherPercentileContAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
let expr_arr = &values[0];
let pct_arr = &values[1];
for i in 0..expr_arr.len() {
if self.percentile.is_none()
&& let Some(p) = CypherPercentileDiscAccumulator::extract_percentile(pct_arr, i)
{
if !(0.0..=1.0).contains(&p) {
return Err(datafusion::error::DataFusionError::Execution(
"ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
));
}
self.percentile = Some(p);
}
if let Some(f) = CypherPercentileDiscAccumulator::extract_f64(expr_arr, i) {
self.values.push(f);
}
}
Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
let pct = match self.percentile {
Some(p) if !(0.0..=1.0).contains(&p) => {
return Err(datafusion::error::DataFusionError::Execution(
"ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
));
}
Some(p) => p,
None => 0.0,
};
if self.values.is_empty() {
return Ok(ScalarValue::Float64(None));
}
self.values
.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let n = self.values.len();
if n == 1 {
return Ok(ScalarValue::Float64(Some(self.values[0])));
}
let pos = pct * (n as f64 - 1.0);
let lower = pos.floor() as usize;
let upper = pos.ceil() as usize;
let lower = lower.min(n - 1);
let upper = upper.min(n - 1);
if lower == upper {
Ok(ScalarValue::Float64(Some(self.values[lower])))
} else {
let frac = pos - lower as f64;
let result = self.values[lower] + frac * (self.values[upper] - self.values[lower]);
Ok(ScalarValue::Float64(Some(result)))
}
}
fn size(&self) -> usize {
std::mem::size_of_val(self) + self.values.capacity() * 8
}
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
let list_values: Vec<ScalarValue> = self
.values
.iter()
.map(|f| ScalarValue::Float64(Some(*f)))
.collect();
let list_scalar = ScalarValue::List(ScalarValue::new_list(
&list_values,
&DataType::Float64,
true,
));
Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
let list_arr = &states[0];
let pct_arr = &states[1];
if self.percentile.is_none()
&& let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
{
for i in 0..f64_arr.len() {
if !f64_arr.is_null(i) {
self.percentile = Some(f64_arr.value(i));
break;
}
}
}
if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
for i in 0..list_array.len() {
if list_array.is_null(i) {
continue;
}
let inner = list_array.value(i);
if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
for j in 0..f64_arr.len() {
if !f64_arr.is_null(j) {
self.values.push(f64_arr.value(j));
}
}
}
}
}
Ok(())
}
}
pub(crate) fn create_cypher_percentile_disc_udaf() -> AggregateUDF {
AggregateUDF::from(CypherPercentileDiscUdaf::new())
}
pub(crate) fn create_cypher_percentile_cont_udaf() -> AggregateUDF {
AggregateUDF::from(CypherPercentileContUdaf::new())
}
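/// Shared driver for the similarity UDFs: checks arity, then delegates the
/// first two arguments to `eval_similar_to_pure` (extras are ignored),
/// returning Float64.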
fn invoke_similarity_udf(
func_name: &str,
min_args: usize,
args: ScalarFunctionArgs,
) -> DFResult<ColumnarValue> {
let output_type = DataType::Float64;
invoke_cypher_udf(args, &output_type, |val_args| {
if val_args.len() < min_args {
return Err(datafusion::error::DataFusionError::Execution(format!(
"{} requires at least {} arguments",
func_name, min_args
)));
}
crate::query::similar_to::eval_similar_to_pure(&val_args[0], &val_args[1])
.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
})
}
pub fn create_similar_to_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(SimilarToUdf::new())
}
#[derive(Debug)]
struct SimilarToUdf {
signature: Signature,
}
impl SimilarToUdf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(SimilarToUdf);
impl ScalarUDFImpl for SimilarToUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"similar_to"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Float64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_similarity_udf("similar_to", 2, args)
}
}
pub fn create_vector_similarity_udf() -> ScalarUDF {
ScalarUDF::new_from_impl(VectorSimilarityUdf::new())
}
#[derive(Debug)]
struct VectorSimilarityUdf {
signature: Signature,
}
impl VectorSimilarityUdf {
fn new() -> Self {
Self {
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
}
}
}
impl_udf_eq_hash!(VectorSimilarityUdf);
impl ScalarUDFImpl for VectorSimilarityUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"vector_similarity"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Float64)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
invoke_similarity_udf("vector_similarity", 2, args)
}
}
#[cfg(test)]
mod tests {
use super::*;
use datafusion::execution::FunctionRegistry;
#[test]
fn test_register_udfs() {
let ctx = SessionContext::new();
register_cypher_udfs(&ctx).unwrap();
assert!(ctx.udf("id").is_ok());
assert!(ctx.udf("type").is_ok());
assert!(ctx.udf("keys").is_ok());
assert!(ctx.udf("range").is_ok());
assert!(
ctx.udf("_make_cypher_list").is_ok(),
"_make_cypher_list UDF should be registered"
);
assert!(
ctx.udf("_cv_to_bool").is_ok(),
"_cv_to_bool UDF should be registered"
);
}
#[test]
fn test_id_udf_signature() {
let udf = create_id_udf();
assert_eq!(udf.name(), "id");
}
#[test]
fn test_has_null_udf() {
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::config::ConfigOptions;
use datafusion::scalar::ScalarValue;
use std::sync::Arc;
let udf = create_has_null_udf();
let values = vec![
ScalarValue::Int64(Some(1)),
ScalarValue::Int64(Some(2)),
ScalarValue::Int64(None),
];
let list_scalar = ScalarValue::List(ScalarValue::new_list(&values, &DataType::Int64, true));
let list_field = Arc::new(Field::new(
"item",
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
true,
));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Scalar(list_scalar)],
arg_fields: vec![list_field],
number_rows: 1,
return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
config_options: Arc::new(ConfigOptions::default()),
};
let result = udf.invoke_with_args(args).unwrap();
if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = result {
assert!(b, "has_null should return true for list with null");
} else {
panic!("Unexpected result: {:?}", result);
}
}
fn json_to_cv_bytes(val: &serde_json::Value) -> Vec<u8> {
let uni_val: uni_common::Value = val.clone().into();
uni_common::cypher_value_codec::encode(&uni_val)
}
fn make_multi_scalar_args(scalars: Vec<ScalarValue>) -> ScalarFunctionArgs {
make_multi_scalar_args_with_return(scalars, DataType::LargeBinary)
}
fn make_multi_scalar_args_with_return(
scalars: Vec<ScalarValue>,
return_type: DataType,
) -> ScalarFunctionArgs {
use datafusion::arrow::datatypes::Field;
use datafusion::config::ConfigOptions;
let arg_fields: Vec<_> = scalars
.iter()
.enumerate()
.map(|(i, s)| Arc::new(Field::new(format!("arg{i}"), s.data_type(), true)))
.collect();
let args: Vec<_> = scalars.into_iter().map(ColumnarValue::Scalar).collect();
ScalarFunctionArgs {
args,
arg_fields,
number_rows: 1,
return_field: Arc::new(Field::new("result", return_type, true)),
config_options: Arc::new(ConfigOptions::default()),
}
}
fn decode_cv_scalar(cv: &ColumnarValue) -> serde_json::Value {
match cv {
ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
let val = uni_common::cypher_value_codec::decode(bytes)
.expect("failed to decode CypherValue output");
val.into()
}
other => panic!("expected LargeBinary scalar, got {other:?}"),
}
}
#[test]
fn test_make_cypher_list_scalars() {
let udf = create_make_cypher_list_udf();
let args = make_multi_scalar_args(vec![
ScalarValue::Int64(Some(1)),
ScalarValue::Float64(Some(3.21)),
ScalarValue::Utf8(Some("hello".to_string())),
ScalarValue::Boolean(Some(true)),
ScalarValue::Null,
]);
let result = udf.invoke_with_args(args).unwrap();
let json = decode_cv_scalar(&result);
let arr = json.as_array().expect("should be array");
assert_eq!(arr.len(), 5);
assert_eq!(arr[0], serde_json::json!(1));
assert_eq!(arr[1], serde_json::json!(3.21));
assert_eq!(arr[2], serde_json::json!("hello"));
assert_eq!(arr[3], serde_json::json!(true));
assert!(arr[4].is_null());
}
#[test]
fn test_make_cypher_list_empty() {
let udf = create_make_cypher_list_udf();
let args = make_multi_scalar_args(vec![]);
let result = udf.invoke_with_args(args).unwrap();
let json = decode_cv_scalar(&result);
let arr = json.as_array().expect("should be array");
assert!(arr.is_empty());
}
#[test]
fn test_make_cypher_list_single() {
let udf = create_make_cypher_list_udf();
let args = make_multi_scalar_args(vec![ScalarValue::Int64(Some(42))]);
let result = udf.invoke_with_args(args).unwrap();
let json = decode_cv_scalar(&result);
let arr = json.as_array().expect("should be array");
assert_eq!(arr.len(), 1);
assert_eq!(arr[0], serde_json::json!(42));
}
#[test]
fn test_make_cypher_list_nested_cypher_value() {
let udf = create_make_cypher_list_udf();
let nested_bytes = json_to_cv_bytes(&serde_json::json!([1, 2]));
let args = make_multi_scalar_args(vec![
ScalarValue::LargeBinary(Some(nested_bytes)),
ScalarValue::Int64(Some(3)),
]);
let result = udf.invoke_with_args(args).unwrap();
let json = decode_cv_scalar(&result);
let arr = json.as_array().expect("should be array");
assert_eq!(arr.len(), 2);
assert_eq!(arr[0], serde_json::json!([1, 2]));
assert_eq!(arr[1], serde_json::json!(3));
}
fn make_cypher_in_args(
element: &serde_json::Value,
list: &serde_json::Value,
) -> ScalarFunctionArgs {
make_multi_scalar_args_with_return(
vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(element))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(list))),
],
DataType::Boolean,
)
}
#[test]
fn test_cypher_in_found() {
let udf = create_cypher_in_udf();
let args = make_cypher_in_args(&serde_json::json!(3), &serde_json::json!([1, 2, 3]));
let result = udf.invoke_with_args(args).unwrap();
match result {
ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(b),
other => panic!("expected Boolean(true), got {other:?}"),
}
}
#[test]
fn test_cypher_in_not_found() {
let udf = create_cypher_in_udf();
let args = make_cypher_in_args(&serde_json::json!(4), &serde_json::json!([1, 2, 3]));
let result = udf.invoke_with_args(args).unwrap();
match result {
ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(!b),
other => panic!("expected Boolean(false), got {other:?}"),
}
}
#[test]
fn test_cypher_in_null_list() {
let udf = create_cypher_in_udf();
let args = make_multi_scalar_args_with_return(
vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(1)))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
],
DataType::Boolean,
);
let result = udf.invoke_with_args(args).unwrap();
match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {}
            other => panic!("expected Boolean(None) for null list, got {other:?}"),
}
}
#[test]
fn test_cypher_in_null_element_nonempty() {
let udf = create_cypher_in_udf();
let args = make_multi_scalar_args_with_return(
vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
],
DataType::Boolean,
);
let result = udf.invoke_with_args(args).unwrap();
match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {}
            other => panic!("expected Boolean(None) for null IN non-empty list, got {other:?}"),
}
}
#[test]
fn test_cypher_in_null_element_empty() {
let udf = create_cypher_in_udf();
let args = make_multi_scalar_args_with_return(
vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([])))),
],
DataType::Boolean,
);
let result = udf.invoke_with_args(args).unwrap();
match result {
ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(!b),
other => panic!("expected Boolean(false) for null IN [], got {other:?}"),
}
}
#[test]
fn test_cypher_in_not_found_with_null() {
let udf = create_cypher_in_udf();
let args = make_cypher_in_args(&serde_json::json!(4), &serde_json::json!([1, null, 3]));
let result = udf.invoke_with_args(args).unwrap();
match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {}
            other => panic!("expected Boolean(None) for 4 IN [1,null,3], got {other:?}"),
}
}
#[test]
fn test_cypher_in_cross_type_int_float() {
let udf = create_cypher_in_udf();
let args = make_cypher_in_args(&serde_json::json!(1), &serde_json::json!([1.0, 2.0]));
let result = udf.invoke_with_args(args).unwrap();
match result {
ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(b),
other => panic!("expected Boolean(true) for 1 IN [1.0, 2.0], got {other:?}"),
}
}
#[test]
fn test_list_concat_basic() {
let udf = create_cypher_list_concat_udf();
let args = make_multi_scalar_args(vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([3, 4])))),
]);
let result = udf.invoke_with_args(args).unwrap();
let json = decode_cv_scalar(&result);
assert_eq!(json, serde_json::json!([1, 2, 3, 4]));
}
#[test]
fn test_list_concat_empty() {
let udf = create_cypher_list_concat_udf();
let args = make_multi_scalar_args(vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([])))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
]);
let result = udf.invoke_with_args(args).unwrap();
let json = decode_cv_scalar(&result);
assert_eq!(json, serde_json::json!([1]));
}
#[test]
fn test_list_concat_null_left() {
let udf = create_cypher_list_concat_udf();
let args = make_multi_scalar_args(vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
]);
let result = udf.invoke_with_args(args).unwrap();
match result {
ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
let json: serde_json::Value = uni_val.into();
assert!(json.is_null(), "expected null, got {json}");
}
            ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
            other => panic!("expected null result, got {other:?}"),
}
}
#[test]
fn test_list_concat_null_right() {
let udf = create_cypher_list_concat_udf();
let args = make_multi_scalar_args(vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
]);
let result = udf.invoke_with_args(args).unwrap();
match result {
ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
let json: serde_json::Value = uni_val.into();
assert!(json.is_null(), "expected null, got {json}");
}
ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
other => panic!("expected null result, got {other:?}"),
}
}
#[test]
fn test_list_append_scalar() {
let udf = create_cypher_list_append_udf();
let args = make_multi_scalar_args(vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3)))),
]);
let result = udf.invoke_with_args(args).unwrap();
let json = decode_cv_scalar(&result);
assert_eq!(json, serde_json::json!([1, 2, 3]));
}
#[test]
fn test_list_prepend_scalar() {
let udf = create_cypher_list_append_udf();
let args = make_multi_scalar_args(vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3)))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
]);
let result = udf.invoke_with_args(args).unwrap();
let json = decode_cv_scalar(&result);
assert_eq!(json, serde_json::json!([3, 1, 2]));
}
#[test]
fn test_list_append_null_list() {
let udf = create_cypher_list_append_udf();
let args = make_multi_scalar_args(vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3)))),
]);
let result = udf.invoke_with_args(args).unwrap();
match result {
ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
let json: serde_json::Value = uni_val.into();
assert!(json.is_null(), "expected null, got {json}");
}
ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
other => panic!("expected null result, got {other:?}"),
}
}
#[test]
fn test_list_append_null_scalar() {
let udf = create_cypher_list_append_udf();
let args = make_multi_scalar_args(vec![
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
]);
let result = udf.invoke_with_args(args).unwrap();
match result {
ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
let json: serde_json::Value = uni_val.into();
assert!(json.is_null(), "expected null, got {json}");
}
ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
other => panic!("expected null result, got {other:?}"),
}
}
#[test]
fn test_sort_key_cross_type_ordering() {
use uni_common::core::id::{Eid, Vid};
use uni_common::{Edge, Node, Path, TemporalValue, Value};
let map_val = Value::Map([("a".to_string(), Value::String("map".to_string()))].into());
let node_val = Value::Node(Node {
vid: Vid::new(1),
labels: vec!["L".to_string()],
properties: Default::default(),
});
let edge_val = Value::Edge(Edge {
eid: Eid::new(1),
edge_type: "T".to_string(),
src: Vid::new(1),
dst: Vid::new(2),
properties: Default::default(),
});
let list_val = Value::List(vec![Value::Int(1)]);
let path_val = Value::Path(Path {
nodes: vec![Node {
vid: Vid::new(1),
labels: vec!["L".to_string()],
properties: Default::default(),
}],
edges: vec![],
});
let string_val = Value::String("hello".to_string());
let bool_val = Value::Bool(false);
let temporal_val = Value::Temporal(TemporalValue::Date {
days_since_epoch: 1000,
});
let number_val = Value::Int(42);
let nan_val = Value::Float(f64::NAN);
let null_val = Value::Null;
let values = vec![
&map_val,
&node_val,
&edge_val,
&list_val,
&path_val,
&string_val,
&bool_val,
&temporal_val,
&number_val,
&nan_val,
&null_val,
];
let keys: Vec<Vec<u8>> = values.iter().map(|v| encode_cypher_sort_key(v)).collect();
for i in 0..keys.len() - 1 {
assert!(
keys[i] < keys[i + 1],
"Expected sort_key({:?}) < sort_key({:?}), but {:?} >= {:?}",
values[i],
values[i + 1],
keys[i],
keys[i + 1]
);
}
}
#[test]
fn test_sort_key_numbers() {
let neg_inf = encode_cypher_sort_key(&Value::Float(f64::NEG_INFINITY));
let neg_100 = encode_cypher_sort_key(&Value::Float(-100.0));
let neg_1 = encode_cypher_sort_key(&Value::Int(-1));
let zero_int = encode_cypher_sort_key(&Value::Int(0));
let zero_float = encode_cypher_sort_key(&Value::Float(0.0));
let one_int = encode_cypher_sort_key(&Value::Int(1));
let one_float = encode_cypher_sort_key(&Value::Float(1.0));
let hundred = encode_cypher_sort_key(&Value::Int(100));
let pos_inf = encode_cypher_sort_key(&Value::Float(f64::INFINITY));
let nan = encode_cypher_sort_key(&Value::Float(f64::NAN));
assert!(neg_inf < neg_100, "-inf < -100");
assert!(neg_100 < neg_1, "-100 < -1");
assert!(neg_1 < zero_int, "-1 < 0");
assert_eq!(zero_int, zero_float, "0 int == 0.0 float");
assert!(zero_int < one_int, "0 < 1");
assert_eq!(one_int, one_float, "1 int == 1.0 float");
assert!(one_int < hundred, "1 < 100");
assert!(hundred < pos_inf, "100 < +inf");
assert!(pos_inf < nan, "+inf < NaN");
}
#[test]
fn test_sort_key_booleans() {
let f = encode_cypher_sort_key(&Value::Bool(false));
let t = encode_cypher_sort_key(&Value::Bool(true));
assert!(f < t, "false < true");
}
#[test]
fn test_sort_key_strings() {
let empty = encode_cypher_sort_key(&Value::String(String::new()));
let a = encode_cypher_sort_key(&Value::String("a".to_string()));
let ab = encode_cypher_sort_key(&Value::String("ab".to_string()));
let b = encode_cypher_sort_key(&Value::String("b".to_string()));
assert!(empty < a, "'' < 'a'");
assert!(a < ab, "'a' < 'ab'");
assert!(ab < b, "'ab' < 'b'");
}
#[test]
fn test_sort_key_lists() {
let empty = encode_cypher_sort_key(&Value::List(vec![]));
let one = encode_cypher_sort_key(&Value::List(vec![Value::Int(1)]));
let one_two = encode_cypher_sort_key(&Value::List(vec![Value::Int(1), Value::Int(2)]));
let two = encode_cypher_sort_key(&Value::List(vec![Value::Int(2)]));
assert!(empty < one, "[] < [1]");
assert!(one < one_two, "[1] < [1,2]");
assert!(one_two < two, "[1,2] < [2]");
}
#[test]
fn test_sort_key_temporal() {
use uni_common::TemporalValue;
let date1 = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
days_since_epoch: 100,
}));
let date2 = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
days_since_epoch: 200,
}));
assert!(date1 < date2, "earlier date < later date");
let date = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
days_since_epoch: i32::MAX,
}));
let local_time = encode_cypher_sort_key(&Value::Temporal(TemporalValue::LocalTime {
nanos_since_midnight: 0,
}));
assert!(date < local_time, "Date < LocalTime (by variant rank)");
}
#[test]
fn test_sort_key_nested_lists() {
let inner_a = Value::List(vec![Value::Int(1)]);
let inner_b = Value::List(vec![Value::Int(2)]);
let list_a = encode_cypher_sort_key(&Value::List(vec![inner_a.clone()]));
let list_b = encode_cypher_sort_key(&Value::List(vec![inner_b.clone()]));
assert!(list_a < list_b, "[[1]] < [[2]]");
}
#[test]
fn test_sort_key_null_handling() {
let null_key = encode_cypher_sort_key(&Value::Null);
assert_eq!(null_key, vec![0x0A], "Null produces [0x0A]");
let number_key = encode_cypher_sort_key(&Value::Int(42));
assert!(number_key < null_key, "number < null");
}
#[test]
fn test_byte_stuff_roundtrip() {
let s1 = Value::String("a\x00b".to_string());
let s2 = Value::String("a\x00c".to_string());
let s3 = Value::String("a\x01".to_string());
let k1 = encode_cypher_sort_key(&s1);
let k2 = encode_cypher_sort_key(&s2);
let k3 = encode_cypher_sort_key(&s3);
assert!(k1 < k2, "a\\x00b < a\\x00c");
assert!(k1 < k3, "a\\x00b < a\\x01");
}
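    // `encode_order_preserving_f64` (defined elsewhere in this module) is
    // expected to use the usual order-preserving trick: flip the sign bit of
    // non-negative floats and all bits of negative ones, so byte-wise
    // comparison matches numeric order. The `<=` assertion below accommodates
    // -0.0 vs 0.0, which are numerically equal but may encode to distinct
    // adjacent keys.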
#[test]
fn test_sort_key_order_preserving_f64() {
let vals = [f64::NEG_INFINITY, -1.0, -0.0, 0.0, 1.0, f64::INFINITY];
let encoded: Vec<[u8; 8]> = vals
.iter()
.map(|f| encode_order_preserving_f64(*f))
.collect();
for i in 0..encoded.len() - 1 {
assert!(
encoded[i] <= encoded[i + 1],
"encode({}) should <= encode({}), got {:?} vs {:?}",
vals[i],
vals[i + 1],
encoded[i],
encoded[i + 1]
);
}
}
#[test]
fn test_sort_key_string_as_temporal_time_with_offset() {
let tv = sort_key_string_as_temporal("12:35:15+05:00")
.expect("should parse Time with positive offset");
match tv {
uni_common::TemporalValue::Time {
nanos_since_midnight,
offset_seconds,
} => {
assert_eq!(offset_seconds, 5 * 3600, "offset should be +05:00 = 18000s");
let expected_nanos = (12 * 3600 + 35 * 60 + 15) * 1_000_000_000i64;
assert_eq!(nanos_since_midnight, expected_nanos);
}
other => panic!("expected TemporalValue::Time, got {other:?}"),
}
}
#[test]
fn test_sort_key_string_as_temporal_time_negative_offset() {
let tv = sort_key_string_as_temporal("10:35:00-08:00")
.expect("should parse Time with negative offset");
match tv {
uni_common::TemporalValue::Time {
nanos_since_midnight,
offset_seconds,
} => {
assert_eq!(
offset_seconds,
-8 * 3600,
"offset should be -08:00 = -28800s"
);
let expected_nanos = (10 * 3600 + 35 * 60) * 1_000_000_000i64;
assert_eq!(nanos_since_midnight, expected_nanos);
}
other => panic!("expected TemporalValue::Time, got {other:?}"),
}
}
#[test]
fn test_sort_key_string_as_temporal_date() {
use super::super::expr_eval::temporal_from_value;
let tv = temporal_from_value(&Value::String("2024-01-15".into()))
.expect("should parse Date string");
match tv {
uni_common::TemporalValue::Date { days_since_epoch } => {
assert!(days_since_epoch > 0, "2024-01-15 should be after epoch");
}
other => panic!("expected TemporalValue::Date, got {other:?}"),
}
}
}