use std::sync::Arc;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
use serde::{Deserialize, Serialize};
use crate::udf;
/// JSON wire payload for the `hamelin_array_cast` scalar UDF.
///
/// Holds the two values handed to `udf::array_cast_udf` so the UDF can be
/// rebuilt on the decoding side. `serde_json` writes fields in declaration
/// order, so the field order below defines the encoded byte layout.
#[derive(Serialize, Deserialize)]
struct ArrayCastPayload {
    /// Target `DataType` passed as the first argument to `udf::array_cast_udf`.
    target_type: DataType,
    /// Cast descriptor passed as the second argument to `udf::array_cast_udf`.
    cast_descriptor: udf::CastDescriptor,
}
/// Hamelin scalar UDFs that carry no per-instance state.
///
/// For these, encoding writes nothing into the payload buffer and decoding
/// ignores it; the name alone identifies the UDF. Both `try_decode_udf` and
/// `try_encode_udf` consult this single table, so the two sides cannot drift
/// apart when a UDF is added or removed.
const STATELESS_UDFS: &[(&str, fn() -> ScalarUDF)] = &[
    ("hamelin_array_avg", udf::array_avg_udf),
    ("hamelin_array_length", udf::hamelin_array_length_udf),
    ("hamelin_array_sum", udf::array_sum_udf),
    ("hamelin_array_variant_get", udf::array_variant_get_udf),
    ("hamelin_is_ipv4", udf::is_ipv4_udf),
    ("hamelin_is_ipv6", udf::is_ipv6_udf),
    ("hamelin_cidr_contains", udf::cidr_contains_udf),
    ("hamelin_json_to_variant", udf::json_to_variant_udf),
    ("hamelin_map_from_entries", udf::map_from_entries_udf),
    ("hamelin_parse_timestamp", udf::parse_timestamp_udf),
    ("hamelin_regexp_extract_all", udf::regexp_extract_all_udf),
    ("hamelin_regexp_split", udf::regexp_split_udf),
    ("hamelin_cast_to_variant", udf::cast_to_variant_udf),
    ("hamelin_uuid5", udf::uuid5_udf),
    ("hamelin_variant_get", udf::variant_get_udf),
    ("hamelin_to_json_string", udf::variant_to_json_udf),
    ("hamelin_width_bucket", udf::width_bucket_array_udf),
    ("hamelin_to_millis", udf::to_millis_udf),
    ("hamelin_to_nanos", udf::to_nanos_udf),
    ("hamelin_from_millis", udf::from_millis_udf),
    ("hamelin_from_nanos", udf::from_nanos_udf),
    ("hamelin_to_months", udf::to_months_udf),
    ("hamelin_from_months", udf::from_months_udf),
];

/// Returns the constructor for a stateless Hamelin scalar UDF, or `None` if
/// `name` is not in [`STATELESS_UDFS`].
fn stateless_udf_ctor(name: &str) -> Option<fn() -> ScalarUDF> {
    STATELESS_UDFS
        .iter()
        .find(|(udf_name, _)| *udf_name == name)
        .map(|&(_, ctor)| ctor)
}

/// Rebuilds a Hamelin scalar UDF from its serialized `name` and payload `buf`.
///
/// * `hamelin_from_variant` expects `buf` to hold a JSON-encoded `DataType`.
/// * `hamelin_array_cast` expects `buf` to hold a JSON-encoded `ArrayCastPayload`.
/// * Every name in [`STATELESS_UDFS`] ignores `buf`.
///
/// Returns `Ok(None)` when `name` is not a Hamelin scalar UDF, so the caller
/// can fall back to another codec.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when the payload of a parameterized
/// UDF cannot be deserialized.
pub fn try_decode_udf(name: &str, buf: &[u8]) -> Result<Option<Arc<ScalarUDF>>, DataFusionError> {
    let udf = match name {
        "hamelin_from_variant" => {
            let target_type: DataType = serde_json::from_slice(buf).map_err(|e| {
                DataFusionError::Internal(format!("Failed to decode hamelin_from_variant: {e}"))
            })?;
            udf::from_variant_udf(target_type)
        }
        "hamelin_array_cast" => {
            let payload: ArrayCastPayload = serde_json::from_slice(buf).map_err(|e| {
                DataFusionError::Internal(format!("Failed to decode hamelin_array_cast: {e}"))
            })?;
            udf::array_cast_udf(payload.target_type, payload.cast_descriptor)
        }
        _ => match stateless_udf_ctor(name) {
            Some(ctor) => ctor(),
            // Not one of ours; let the caller try another codec.
            None => return Ok(None),
        },
    };
    Ok(Some(Arc::new(udf)))
}

/// Serializes the state of a Hamelin scalar UDF into `buf`.
///
/// * `hamelin_from_variant` writes its target `DataType` as JSON.
/// * `hamelin_array_cast` writes an `ArrayCastPayload` as JSON.
/// * Names in [`STATELESS_UDFS`] write nothing; the name alone is enough to
///   rebuild them.
///
/// Returns `Ok(true)` when `node` is a Hamelin UDF that this codec handles,
/// and `Ok(false)` otherwise.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` in two cases:
/// * a parameterized UDF's implementation is not the expected concrete type;
/// * its payload fails to serialize.
pub fn try_encode_udf(node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<bool, DataFusionError> {
    match node.name() {
        "hamelin_from_variant" => {
            let inner = node
                .inner()
                .as_any()
                .downcast_ref::<udf::FromVariantUdf>()
                .ok_or_else(|| {
                    DataFusionError::Internal("Failed to downcast hamelin_from_variant".to_string())
                })?;
            serde_json::to_writer(buf, inner.target_type()).map_err(|e| {
                DataFusionError::Internal(format!("Failed to encode hamelin_from_variant: {e}"))
            })?;
            Ok(true)
        }
        "hamelin_array_cast" => {
            let inner = node
                .inner()
                .as_any()
                .downcast_ref::<udf::ArrayCastUdf>()
                .ok_or_else(|| {
                    DataFusionError::Internal("Failed to downcast hamelin_array_cast".to_string())
                })?;
            let payload = ArrayCastPayload {
                target_type: inner.target_type().clone(),
                cast_descriptor: inner.cast_descriptor().clone(),
            };
            serde_json::to_writer(buf, &payload).map_err(|e| {
                DataFusionError::Internal(format!("Failed to encode hamelin_array_cast: {e}"))
            })?;
            Ok(true)
        }
        // Stateless UDFs need no payload; claim them only if the decoder can rebuild them.
        name => Ok(stateless_udf_ctor(name).is_some()),
    }
}
/// Hamelin aggregate UDFs, all of which are identified by name alone.
///
/// Both `try_decode_udaf` and `try_encode_udaf` consult this single table, so
/// the two sides cannot drift apart when an aggregate is added or removed.
const STATELESS_UDAFS: &[(&str, fn() -> AggregateUDF)] = &[
    ("hamelin_any_value", udf::any_value_udaf),
    ("hamelin_map_agg", udf::map_agg_udaf),
    ("hamelin_multimap_agg", udf::multimap_agg_udaf),
    ("hamelin_sliding_array_agg", udf::sliding_array_agg_udaf),
];

/// Returns the constructor for a Hamelin aggregate UDF, or `None` if `name`
/// is not in [`STATELESS_UDAFS`].
fn stateless_udaf_ctor(name: &str) -> Option<fn() -> AggregateUDF> {
    STATELESS_UDAFS
        .iter()
        .find(|(udaf_name, _)| *udaf_name == name)
        .map(|&(_, ctor)| ctor)
}

/// Rebuilds a Hamelin aggregate UDF from its serialized `name`.
///
/// The payload buffer is ignored because these aggregates carry no encoded state.
/// Returns `Ok(None)` when `name` is not a Hamelin aggregate, so the caller
/// can fall back to another codec. This function never returns an error.
pub fn try_decode_udaf(
    name: &str,
    _buf: &[u8],
) -> Result<Option<Arc<AggregateUDF>>, DataFusionError> {
    Ok(stateless_udaf_ctor(name).map(|ctor| Arc::new(ctor())))
}

/// Reports whether `node` is a Hamelin aggregate UDF handled by this codec.
///
/// Nothing is written to the buffer, because the name alone is enough for
/// `try_decode_udaf` to rebuild the aggregate. Returns `Ok(true)` for Hamelin
/// aggregates and `Ok(false)` otherwise. This function never returns an error.
pub fn try_encode_udaf(node: &AggregateUDF, _buf: &mut Vec<u8>) -> Result<bool, DataFusionError> {
    Ok(stateless_udaf_ctor(node.name()).is_some())
}