use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Deref;
use vortex_array::operator::OperatorRef;
use vortex_array::{ArrayRef, DeserializeMetadata, SerializeMetadata};
use vortex_dtype::DType;
use vortex_error::VortexResult;
use crate::display::DisplayAs;
use crate::{
AnalysisExpr, ExprEncoding, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VortexExpr,
};
/// Static dispatch table for one kind of expression.
///
/// An implementation ties together an expression type ([`VTable::Expr`]), its encoding type
/// ([`VTable::Encoding`]) and its metadata type ([`VTable::Metadata`]), and supplies the
/// associated functions that operate on them. None of the functions take `self`; the
/// expression or encoding is always passed explicitly.
pub trait VTable: 'static + Sized + Send + Sync + Debug {
/// Concrete expression type.
///
/// It must dereference to `dyn VortexExpr`, be convertible into an expression reference
/// through `IntoExpr`, and support cloning, equality, hashing and display.
type Expr: 'static
+ Send
+ Sync
+ Clone
+ Debug
+ DisplayAs
+ PartialEq
+ Eq
+ Hash
+ Deref<Target = dyn VortexExpr>
+ IntoExpr
+ AnalysisExpr;
/// Concrete encoding type; it must dereference to `dyn ExprEncoding`.
type Encoding: 'static + Send + Sync + Deref<Target = dyn ExprEncoding>;
/// Metadata type that can be both serialized and deserialized.
///
/// NOTE(review): presumably the per-expression payload written next to the children when an
/// expression is serialized — confirm against the serialization code.
type Metadata: SerializeMetadata + DeserializeMetadata + Debug;
/// Returns the [`ExprId`] associated with `encoding`.
fn id(encoding: &Self::Encoding) -> ExprId;
/// Returns the encoding reference for `expr`.
fn encoding(expr: &Self::Expr) -> ExprEncodingRef;
/// Returns the metadata of `expr`, or `None` when there is none to report.
///
/// NOTE(review): `None` presumably marks an expression that cannot be serialized — confirm
/// against callers.
fn metadata(expr: &Self::Expr) -> Option<Self::Metadata>;
/// Returns borrowed references to the child expressions of `expr`.
fn children(expr: &Self::Expr) -> Vec<&ExprRef>;
/// Produces an expression based on `expr` that uses `children` as its children.
///
/// # Errors
///
/// Implementation-defined. NOTE(review): presumably fails when `children` does not match
/// what the expression expects — confirm per implementation.
fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr>;
/// Constructs an expression from its encoding, already-deserialized metadata and children.
///
/// # Errors
///
/// Implementation-defined; returned through [`VortexResult`].
fn build(
encoding: &Self::Encoding,
metadata: &<Self::Metadata as DeserializeMetadata>::Output,
children: Vec<ExprRef>,
) -> VortexResult<Self::Expr>;
/// Evaluates `expr` against `scope`, yielding an array.
fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef>;
/// Computes the [`DType`] that `expr` yields, given the dtype `scope`.
///
/// NOTE(review): `scope` here is presumably the dtype of the input scope passed to
/// [`VTable::evaluate`] — confirm.
fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType>;
/// Optionally produces an operator for `expr` over the `scope` operator.
///
/// The default implementation ignores both arguments and returns `Ok(None)`.
fn operator(_expr: &Self::Expr, _scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
Ok(None)
}
}
/// Generates the boilerplate that connects a set of `<V>*` types to the expression traits.
///
/// For an identifier `V`, the expansion:
///
/// * declares a unit struct `<V>VTable` deriving `Debug`;
/// * implements `AsRef<dyn VortexExpr>` and `Deref<Target = dyn VortexExpr>` for `<V>Expr`;
/// * implements `IntoExpr` for `<V>Expr`, plus `From<<V>Expr>` for `ExprRef`;
/// * implements `AsRef<dyn ExprEncoding>` and `Deref<Target = dyn ExprEncoding>` for
///   `<V>ExprEncoding`.
///
/// The types `<V>Expr` and `<V>ExprEncoding` are not generated; they must already exist at the
/// call site. NOTE(review): the call site presumably must also implement `VTable` for
/// `<V>VTable` so that the adapter types are usable — confirm.
///
/// # Safety
///
/// Every `unsafe` block below reinterprets `<V>Expr` as `ExprAdapter<<V>VTable>` or
/// `<V>ExprEncoding` as `ExprEncodingAdapter<<V>VTable>`. NOTE(review): this is only sound if
/// the adapters are layout-compatible wrappers (e.g. `#[repr(transparent)]`) around those
/// types; the adapter definitions are not visible from this file — confirm.
#[macro_export]
macro_rules! vtable {
($V:ident) => {
$crate::aliases::paste::paste! {
#[derive(Debug)]
pub struct [<$V VTable>];
impl AsRef<dyn $crate::VortexExpr> for [<$V Expr>] {
fn as_ref(&self) -> &dyn $crate::VortexExpr {
// SAFETY (assumed, TODO confirm): the expression is reinterpreted in place as its
// `ExprAdapter`; requires the adapter to share the expression's layout.
unsafe { &*(self as *const [<$V Expr>] as *const $crate::ExprAdapter<[<$V VTable>]>) }
}
}
impl std::ops::Deref for [<$V Expr>] {
type Target = dyn $crate::VortexExpr;
fn deref(&self) -> &Self::Target {
// SAFETY (assumed, TODO confirm): same pointer reinterpretation as in `as_ref` above.
unsafe { &*(self as *const [<$V Expr>] as *const $crate::ExprAdapter<[<$V VTable>]>) }
}
}
impl $crate::IntoExpr for [<$V Expr>] {
fn into_expr(self) -> $crate::ExprRef {
// SAFETY (assumed, TODO confirm): by-value reinterpretation of the expression as its
// `ExprAdapter`; `transmute` enforces equal size at compile time, layout equivalence
// is the adapter's responsibility.
std::sync::Arc::new(unsafe { std::mem::transmute::<[<$V Expr>], $crate::ExprAdapter::<[<$V VTable>]>>(self) })
}
}
impl From<[<$V Expr>]> for $crate::ExprRef {
fn from(value: [<$V Expr>]) -> $crate::ExprRef {
// Bring the trait into scope locally so the call resolves at any expansion site.
use $crate::IntoExpr;
value.into_expr()
}
}
impl AsRef<dyn $crate::ExprEncoding> for [<$V ExprEncoding>] {
fn as_ref(&self) -> &dyn $crate::ExprEncoding {
// SAFETY (assumed, TODO confirm): the encoding is reinterpreted in place as its
// `ExprEncodingAdapter`; requires the adapter to share the encoding's layout.
unsafe { &*(self as *const [<$V ExprEncoding>] as *const $crate::ExprEncodingAdapter<[<$V VTable>]>) }
}
}
impl std::ops::Deref for [<$V ExprEncoding>] {
type Target = dyn $crate::ExprEncoding;
fn deref(&self) -> &Self::Target {
// SAFETY (assumed, TODO confirm): same pointer reinterpretation as in `as_ref` above.
unsafe { &*(self as *const [<$V ExprEncoding>] as *const $crate::ExprEncodingAdapter<[<$V VTable>]>) }
}
}
}
};
}
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_array::compute::{BetweenOptions, StrictComparison};
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::proto::{ExprSerializeProtoExt, deserialize_expr_proto};
    use crate::*;

    /// Default expression registry, built once and shared by every case.
    #[fixture]
    #[once]
    fn registry() -> ExprRegistry {
        ExprRegistry::default()
    }

    /// Serializes each expression to its protobuf form, deserializes it through the registry,
    /// and checks that the result equals the expression we started with.
    #[rstest]
    // Root, selection and literals.
    #[case(root())]
    #[case(select(["hello", "world"], root()))]
    #[case(select_exclude(["world", "hello"], root()))]
    #[case(lit(42i32))]
    #[case(lit(std::f64::consts::PI))]
    #[case(lit(true))]
    #[case(lit("hello"))]
    // Column and field access.
    #[case(col("column_name"))]
    #[case(get_item("field", root()))]
    // Comparisons.
    #[case(eq(col("a"), lit(10)))]
    #[case(not_eq(col("a"), lit(10)))]
    #[case(gt(col("a"), lit(10)))]
    #[case(gt_eq(col("a"), lit(10)))]
    #[case(lt(col("a"), lit(10)))]
    #[case(lt_eq(col("a"), lit(10)))]
    // Boolean logic, arithmetic and null checks.
    #[case(and(col("a"), col("b")))]
    #[case(or(col("a"), col("b")))]
    #[case(not(col("a")))]
    #[case(checked_add(col("a"), lit(5)))]
    #[case(is_null(col("nullable_col")))]
    // Expressions carrying extra options or types.
    #[case(cast(
        col("a"),
        DType::Primitive(PType::I64, Nullability::NonNullable)
    ))]
    #[case(between(
        col("a"),
        lit(10),
        lit(20),
        BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        }
    ))]
    #[case(list_contains(col("list_col"), lit("item")))]
    #[case(pack([("field1", col("a")), ("field2", col("b"))], Nullability::NonNullable))]
    #[case(merge([col("struct1"), col("struct2")]))]
    // Nested combinations.
    #[case(and(gt(col("a"), lit(0)), lt(col("a"), lit(100))))]
    #[case(or(is_null(col("a")), eq(col("a"), lit(0))))]
    #[case(not(and(eq(col("status"), lit("active")), gt(col("age"), lit(18)))))]
    fn text_expr_serde_round_trip(
        registry: &ExprRegistry,
        #[case] expr: ExprRef,
    ) -> anyhow::Result<()> {
        let proto = expr.serialize_proto()?;
        let round_tripped = deserialize_expr_proto(&proto, registry)?;
        assert_eq!(&expr, &round_tripped);
        Ok(())
    }
}