use std::any::Any;
use std::fmt::{self, Display, Formatter};
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema};
use datafusion::arrow::array::{Array, ArrayRef, LargeListArray, ListArray};
use datafusion::common::ScalarValue;
use datafusion::error::{DataFusionError, Result as DfResult};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_plan::PhysicalExpr;
use uni_common::core::schema::raw_bytes_field_metadata;
use uni_cypher::ast::{CypherLiteral, Expr};
/// Physical expression that delegates to a wrapped expression and attaches the
/// metadata returned by `raw_bytes_field_metadata()` to what it reports.
///
/// Depending on `on_child`, the metadata is placed either on the expression's
/// own return field or on the child field of its list-typed output.
#[derive(Debug)]
pub(crate) struct RawBytesMarkerExpr {
// Wrapped expression; evaluation, nullability and display are delegated to it.
inner: Arc<dyn PhysicalExpr>,
// `false`: mark the expression's own return field.
// `true`: mark the child field of the list type/array the expression yields.
on_child: bool,
}
impl RawBytesMarkerExpr {
    /// Wraps `inner` so that the marker metadata is attached to the
    /// expression's own return field.
    pub(crate) fn scalar(inner: Arc<dyn PhysicalExpr>) -> Self {
        Self::with_target(inner, false)
    }

    /// Wraps `inner` so that the marker metadata is attached to the child
    /// field of the list the expression produces.
    pub(crate) fn list_child(inner: Arc<dyn PhysicalExpr>) -> Self {
        Self::with_target(inner, true)
    }

    /// Shared constructor; `on_child` selects where the marker is placed.
    fn with_target(inner: Arc<dyn PhysicalExpr>, on_child: bool) -> Self {
        Self { inner, on_child }
    }
}
fn mark_list_child_type(dt: DataType) -> DataType {
let marked = |child: &Arc<Field>| {
Arc::new(
child
.as_ref()
.clone()
.with_metadata(raw_bytes_field_metadata()),
)
};
match dt {
DataType::List(child) => DataType::List(marked(&child)),
DataType::LargeList(child) => DataType::LargeList(marked(&child)),
DataType::FixedSizeList(child, n) => DataType::FixedSizeList(marked(&child), n),
other => other,
}
}
fn restamp_list_child(array: ArrayRef) -> ArrayRef {
if let Some(list) = array.as_any().downcast_ref::<ListArray>() {
let (field, offsets, values, nulls) = list.clone().into_parts();
let new_field = Arc::new(
field
.as_ref()
.clone()
.with_metadata(raw_bytes_field_metadata()),
);
return Arc::new(ListArray::new(new_field, offsets, values, nulls));
}
if let Some(list) = array.as_any().downcast_ref::<LargeListArray>() {
let (field, offsets, values, nulls) = list.clone().into_parts();
let new_field = Arc::new(
field
.as_ref()
.clone()
.with_metadata(raw_bytes_field_metadata()),
);
return Arc::new(LargeListArray::new(new_field, offsets, values, nulls));
}
array
}
impl Display for RawBytesMarkerExpr {
    /// Renders exactly what the wrapped expression renders; the wrapper adds
    /// no text of its own.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let wrapped = &self.inner;
        f.write_fmt(format_args!("{}", wrapped))
    }
}
impl PartialEq for RawBytesMarkerExpr {
    /// Two wrappers are equal when they use the same marking mode and share
    /// the very same `inner` allocation (pointer identity, not structural
    /// equality of the wrapped expressions).
    fn eq(&self, other: &Self) -> bool {
        if self.on_child != other.on_child {
            return false;
        }
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}
// Marker impl: the `PartialEq` comparison (a bool flag plus `Arc::ptr_eq` on
// `inner`) is reflexive, so the type can also claim full equivalence.
impl Eq for RawBytesMarkerExpr {}
impl std::hash::Hash for RawBytesMarkerExpr {
    /// Feeds the type name and the marking mode into `state`.
    ///
    /// The wrapped expression is not hashed, so wrappers that differ only in
    /// `inner` hash identically; that stays consistent with `PartialEq`
    /// because equal wrappers always share the same `on_child` flag.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let type_tag: &str = std::any::type_name::<Self>();
        std::hash::Hash::hash(type_tag, state);
        std::hash::Hash::hash(&self.on_child, state);
    }
}
impl PartialEq<dyn Any> for RawBytesMarkerExpr {
    /// Compares against a type-erased value: anything that is not a
    /// `RawBytesMarkerExpr` is unequal, otherwise the typed `PartialEq`
    /// comparison decides.
    fn eq(&self, other: &dyn Any) -> bool {
        match other.downcast_ref::<Self>() {
            Some(same_type) => self == same_type,
            None => false,
        }
    }
}
impl PhysicalExpr for RawBytesMarkerExpr {
fn as_any(&self) -> &dyn Any {
self
}
fn data_type(&self, input_schema: &Schema) -> DfResult<DataType> {
let dt = self.inner.data_type(input_schema)?;
Ok(if self.on_child {
mark_list_child_type(dt)
} else {
dt
})
}
fn nullable(&self, input_schema: &Schema) -> DfResult<bool> {
self.inner.nullable(input_schema)
}
fn return_field(&self, input_schema: &Schema) -> DfResult<Arc<Field>> {
let field = self.inner.return_field(input_schema)?;
if self.on_child {
let dt = mark_list_child_type(field.data_type().clone());
Ok(Arc::new(
Field::new(field.name(), dt, field.is_nullable())
.with_metadata(field.metadata().clone()),
))
} else {
let mut metadata = field.metadata().clone();
metadata.extend(raw_bytes_field_metadata());
Ok(Arc::new(field.as_ref().clone().with_metadata(metadata)))
}
}
fn evaluate(&self, batch: &datafusion::arrow::array::RecordBatch) -> DfResult<ColumnarValue> {
let value = self.inner.evaluate(batch)?;
if !self.on_child {
return Ok(value);
}
match value {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(restamp_list_child(array))),
ColumnarValue::Scalar(ScalarValue::List(arr)) => {
let restamped = restamp_list_child(arr as ArrayRef);
match restamped.as_any().downcast_ref::<ListArray>() {
Some(list) => Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new(
list.clone(),
)))),
None => Err(DataFusionError::Internal(
"RawBytesMarkerExpr: restamped scalar list is not a ListArray".to_string(),
)),
}
}
other => Ok(other),
}
}
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.inner]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> DfResult<Arc<dyn PhysicalExpr>> {
if children.len() != 1 {
return Err(DataFusionError::Internal(
"RawBytesMarkerExpr expects exactly 1 child".to_string(),
));
}
Ok(Arc::new(Self {
inner: children[0].clone(),
on_child: self.on_child,
}))
}
fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.inner.fmt_sql(f)
}
}
/// Classification of an expression or schema field with respect to the
/// `"uni_raw_bytes"` field-metadata marker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Shape {
    /// Not marked, not recognised, or a mix of scalar and list markers.
    None,
    /// A null literal, a `Null`-typed field, or a property found nowhere in
    /// the schema; ignored when shapes are merged.
    Null,
    /// The field itself carries `"uni_raw_bytes" = "true"`.
    RawScalar,
    /// A list whose child field carries `"uni_raw_bytes" = "true"`.
    RawList,
}
/// Classifies a schema field by its `"uni_raw_bytes"` marker.
///
/// * marker on the field itself: `RawScalar` (checked first, whatever the
///   data type);
/// * list / large list / fixed-size list whose child field has the marker:
///   `RawList`;
/// * `DataType::Null`: `Null`;
/// * anything else: `None`.
fn shape_of_field(field: &Field) -> Shape {
    /// True when the field's metadata maps `"uni_raw_bytes"` to `"true"`.
    fn has_raw_marker(candidate: &Field) -> bool {
        candidate
            .metadata()
            .get("uni_raw_bytes")
            .is_some_and(|v| v == "true")
    }

    if has_raw_marker(field) {
        return Shape::RawScalar;
    }
    match field.data_type() {
        DataType::Null => Shape::Null,
        DataType::List(child) | DataType::LargeList(child) | DataType::FixedSizeList(child, _) => {
            if has_raw_marker(child) {
                Shape::RawList
            } else {
                Shape::None
            }
        }
        _ => Shape::None,
    }
}
/// Shape of the schema column called `name`; `Shape::None` when the schema
/// has no such column.
fn variable_shape(schema: &Schema, name: &str) -> Shape {
    schema
        .column_with_name(name)
        .map_or(Shape::None, |(_, field)| shape_of_field(field))
}
/// Shape of the property `var.prop`, looked up in two places:
///
/// 1. as a field named `prop` inside a struct-typed column named `var`;
/// 2. as a flat column literally named `"{var}.{prop}"`.
///
/// The first location yielding a shape other than `Shape::None` wins. If the
/// property exists somewhere but is nowhere marked the result is
/// `Shape::None`; if it exists in neither location the result is
/// `Shape::Null`.
fn property_shape(schema: &Schema, var: &str, prop: &str) -> Shape {
    // Location 1: struct column `var` with a member field `prop`.
    let nested: Option<Shape> =
        schema
            .index_of(var)
            .ok()
            .and_then(|idx| match schema.field(idx).data_type() {
                DataType::Struct(fields) => fields
                    .iter()
                    .find(|f| f.name() == prop)
                    .map(|f| shape_of_field(f)),
                _ => None,
            });
    if let Some(shape) = nested.filter(|s| *s != Shape::None) {
        return shape;
    }

    // Location 2: flattened column name.
    let flat_name = format!("{var}.{prop}");
    let flat: Option<Shape> = schema
        .column_with_name(&flat_name)
        .map(|(_, field)| shape_of_field(field));

    match (nested, flat) {
        (_, Some(shape)) if shape != Shape::None => shape,
        // Found in neither location.
        (None, None) => Shape::Null,
        // Present at least once, but never marked.
        _ => Shape::None,
    }
}
/// Combines the shapes of several alternatives into one.
///
/// * any `Shape::None` gives `Shape::None` immediately;
/// * `Shape::Null` entries are ignored;
/// * only `RawScalar` (plus nulls) gives `RawScalar`, only `RawList` (plus
///   nulls) gives `RawList`;
/// * a mix of `RawScalar` and `RawList` gives `Shape::None`;
/// * nothing but nulls, or no entries at all, gives `Shape::Null`.
fn merge_shapes(shapes: impl IntoIterator<Item = Shape>) -> Shape {
    // `Shape::Null` is the neutral start value. Inside the accumulator,
    // `Shape::None` records a scalar/list conflict while iteration goes on.
    let mut merged = Shape::Null;
    for shape in shapes {
        merged = match (merged, shape) {
            (_, Shape::None) => return Shape::None,
            (so_far, Shape::Null) => so_far,
            (Shape::Null, raw) => raw,
            (so_far, raw) if so_far == raw => so_far,
            _ => Shape::None,
        };
    }
    merged
}
/// Determines the raw-bytes shape of a Cypher expression against `schema`.
///
/// * `null` literal: `Shape::Null`;
/// * variable: shape of the column with that name;
/// * `variable.property`: shape from `property_shape`; a property on any
///   other base is `Shape::None`;
/// * `coalesce(...)` (name matched case-insensitively): merged shape of the
///   arguments;
/// * `CASE`: merged shape of every `THEN` branch and the `ELSE` branch, if
///   there is one;
/// * list literal: `RawList` when its items merge to `RawScalar`, otherwise
///   `Shape::None`;
/// * everything else: `Shape::None`.
pub(crate) fn bytes_shape(expr: &Expr, schema: &Schema) -> Shape {
    match expr {
        Expr::Literal(CypherLiteral::Null) => Shape::Null,
        Expr::Variable(name) => variable_shape(schema, name),
        Expr::Property(base, prop) => {
            if let Expr::Variable(var) = base.as_ref() {
                property_shape(schema, var, prop)
            } else {
                Shape::None
            }
        }
        Expr::FunctionCall { name, args, .. } if name.eq_ignore_ascii_case("coalesce") => {
            let arg_shapes = args.iter().map(|arg| bytes_shape(arg, schema));
            merge_shapes(arg_shapes)
        }
        Expr::Case {
            when_then,
            else_expr,
            ..
        } => {
            let then_shapes = when_then
                .iter()
                .map(|(_, then)| bytes_shape(then, schema));
            let else_shape = else_expr
                .as_deref()
                .map(|fallback| bytes_shape(fallback, schema));
            merge_shapes(then_shapes.chain(else_shape))
        }
        Expr::List(items) => {
            let item_shape = merge_shapes(items.iter().map(|item| bytes_shape(item, schema)));
            if item_shape == Shape::RawScalar {
                Shape::RawList
            } else {
                Shape::None
            }
        }
        _ => Shape::None,
    }
}
/// True when `expr` classifies as `Shape::RawScalar` against `schema`.
pub(crate) fn is_raw_scalar(expr: &Expr, schema: &Schema) -> bool {
    matches!(bytes_shape(expr, schema), Shape::RawScalar)
}
/// True when `expr` classifies as `Shape::RawList` against `schema`.
pub(crate) fn is_markable_list(expr: &Expr, schema: &Schema) -> bool {
    matches!(bytes_shape(expr, schema), Shape::RawList)
}
pub(crate) fn coalesce_needs_cv_unify(args: &[Expr], schema: &Schema) -> bool {
let any_raw = args
.iter()
.any(|a| bytes_shape(a, schema) == Shape::RawScalar);
any_raw && merge_shapes(args.iter().map(|a| bytes_shape(a, schema))) != Shape::RawScalar
}