use std::fmt::Formatter;
use prost::Message;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use vortex_proto::expr as pb;
use vortex_session::VortexSession;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::arrays::StructArray;
use crate::arrays::struct_::StructArrayExt;
use crate::builtins::ArrayBuiltins;
use crate::builtins::ExprBuiltins;
use crate::dtype::DType;
use crate::dtype::FieldName;
use crate::dtype::FieldPath;
use crate::dtype::Nullability;
use crate::expr::Expression;
use crate::expr::StatsCatalog;
use crate::expr::lit;
use crate::expr::stats::Stat;
use crate::scalar_fn::Arity;
use crate::scalar_fn::ChildName;
use crate::scalar_fn::EmptyOptions;
use crate::scalar_fn::ExecutionArgs;
use crate::scalar_fn::ReduceCtx;
use crate::scalar_fn::ReduceNode;
use crate::scalar_fn::ReduceNodeRef;
use crate::scalar_fn::ScalarFnId;
use crate::scalar_fn::ScalarFnVTable;
use crate::scalar_fn::ScalarFnVTableExt;
use crate::scalar_fn::fns::literal::Literal;
use crate::scalar_fn::fns::mask::Mask;
use crate::scalar_fn::fns::pack::Pack;
#[derive(Clone)]
pub struct GetItem;
impl ScalarFnVTable for GetItem {
type Options = FieldName;
fn id(&self) -> ScalarFnId {
ScalarFnId::from("vortex.get_item")
}
fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(
pb::GetItemOpts {
path: instance.to_string(),
}
.encode_to_vec(),
))
}
fn deserialize(
&self,
_metadata: &[u8],
_session: &VortexSession,
) -> VortexResult<Self::Options> {
let opts = pb::GetItemOpts::decode(_metadata)?;
Ok(FieldName::from(opts.path))
}
fn arity(&self, _field_name: &FieldName) -> Arity {
Arity::Exact(1)
}
fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
match child_idx {
0 => ChildName::from("input"),
_ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
}
}
fn fmt_sql(
&self,
field_name: &FieldName,
expr: &Expression,
f: &mut Formatter<'_>,
) -> std::fmt::Result {
expr.children()[0].fmt_sql(f)?;
write!(f, ".{}", field_name)
}
fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
let struct_dtype = &arg_dtypes[0];
let field_dtype = struct_dtype
.as_struct_fields_opt()
.and_then(|st| st.field(field_name))
.ok_or_else(|| {
vortex_err!("Couldn't find the {} field in the input scope", field_name)
})?;
if matches!(
(struct_dtype.nullability(), field_dtype.nullability()),
(Nullability::Nullable, Nullability::NonNullable)
) {
return Ok(field_dtype.with_nullability(Nullability::Nullable));
}
Ok(field_dtype)
}
fn execute(
&self,
field_name: &FieldName,
args: &dyn ExecutionArgs,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let input = args.get(0)?.execute::<StructArray>(ctx)?;
let field = input.unmasked_field_by_name(field_name).cloned()?;
match input.dtype().nullability() {
Nullability::NonNullable => Ok(field),
Nullability::Nullable => field.mask(input.validity()?.to_array(input.len())),
}
}
fn reduce(
&self,
field_name: &FieldName,
node: &dyn ReduceNode,
ctx: &dyn ReduceCtx,
) -> VortexResult<Option<ReduceNodeRef>> {
let child = node.child(0);
if let Some(child_fn) = child.scalar_fn()
&& let Some(pack) = child_fn.as_opt::<Pack>()
&& let Some(idx) = pack.names.find(field_name)
{
let mut field = child.child(idx);
if pack.nullability.is_nullable() {
field = ctx.new_node(
Mask.bind(EmptyOptions),
&[field, ctx.new_node(Literal.bind(true.into()), &[])?],
)?;
}
return Ok(Some(field));
}
Ok(None)
}
fn simplify_untyped(
&self,
field_name: &FieldName,
expr: &Expression,
) -> VortexResult<Option<Expression>> {
let child = expr.child(0);
if let Some(pack) = child.as_opt::<Pack>() {
let idx = pack
.names
.iter()
.position(|name| name == field_name)
.ok_or_else(|| {
vortex_err!(
"Cannot find field {} in pack fields {:?}",
field_name,
pack.names
)
})?;
let mut field = child.child(idx).clone();
if pack.nullability.is_nullable() {
field = field.mask(lit(true))?;
}
return Ok(Some(field));
}
Ok(None)
}
fn stat_expression(
&self,
field_name: &FieldName,
_expr: &Expression,
stat: Stat,
catalog: &dyn StatsCatalog,
) -> Option<Expression> {
catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
}
fn is_null_sensitive(&self, _field_name: &FieldName) -> bool {
true
}
fn is_fallible(&self, _field_name: &FieldName) -> bool {
false
}
}
#[cfg(test)]
mod tests {
use vortex_buffer::buffer;
use crate::IntoArray;
use crate::dtype::DType;
use crate::dtype::FieldNames;
use crate::dtype::Nullability;
use crate::dtype::Nullability::NonNullable;
use crate::dtype::PType;
use crate::dtype::StructFields;
use crate::expr::checked_add;
use crate::expr::get_item;
use crate::expr::lit;
use crate::expr::pack;
use crate::expr::root;
use crate::scalar_fn::fns::get_item::StructArray;
use crate::validity::Validity;
fn test_array() -> StructArray {
StructArray::from_fields(&[
("a", buffer![0i32, 1, 2].into_array()),
("b", buffer![4i64, 5, 6].into_array()),
])
.unwrap()
}
#[test]
fn get_item_by_name() {
let st = test_array();
let get_item = get_item("a", root());
let item = st.into_array().apply(&get_item).unwrap();
assert_eq!(item.dtype(), &DType::from(PType::I32))
}
#[test]
fn get_item_by_name_none() {
let st = test_array();
let get_item = get_item("c", root());
assert!(st.into_array().apply(&get_item).is_err());
}
#[test]
#[ignore = "apply() has a bug with null propagation from struct validity to non-nullable child fields"]
fn get_nullable_field() {
let st = StructArray::try_new(
FieldNames::from(["a"]),
vec![buffer![1i32].into_array()],
1,
Validity::AllInvalid,
)
.unwrap()
.into_array();
let get_item_expr = get_item("a", root());
let item = st.apply(&get_item_expr).unwrap();
assert_eq!(
item.dtype(),
&DType::Primitive(PType::I32, Nullability::Nullable)
);
}
#[test]
fn test_pack_get_item_rule() {
let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
let get_item_expr = get_item("b", pack_expr);
let result = get_item_expr
.optimize_recursive(&DType::Struct(StructFields::empty(), NonNullable))
.unwrap();
assert_eq!(result, lit(2));
}
#[test]
fn test_multi_level_pack_get_item_simplify() {
let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
let get_a = get_item("a", inner_pack);
let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable);
let get_z = get_item("z", outer_pack);
let dtype = DType::Primitive(PType::I32, NonNullable);
let result = get_z.optimize_recursive(&dtype).unwrap();
assert_eq!(result, lit(4));
}
#[test]
fn test_deeply_nested_pack_get_item() {
let innermost = pack([("a", lit(42))], NonNullable);
let get_a = get_item("a", innermost);
let level2 = pack([("b", get_a)], NonNullable);
let get_b = get_item("b", level2);
let level3 = pack([("c", get_b)], NonNullable);
let get_c = get_item("c", level3);
let outermost = pack([("final", get_c)], NonNullable);
let get_final = get_item("final", outermost);
let dtype = DType::Primitive(PType::I32, NonNullable);
let result = get_final.optimize_recursive(&dtype).unwrap();
assert_eq!(result, lit(42));
}
#[test]
fn test_partial_pack_get_item_simplify() {
let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
let get_x = get_item("x", inner_pack);
let add_expr = checked_add(get_x, lit(10));
let outer_pack = pack([("result", add_expr)], NonNullable);
let get_result = get_item("result", outer_pack);
let dtype = DType::Primitive(PType::I32, NonNullable);
let result = get_result.optimize_recursive(&dtype).unwrap();
let expected = checked_add(lit(1), lit(10));
assert_eq!(&result, &expected);
}
#[test]
fn get_item_filter_list_field() {
use vortex_mask::Mask;
use crate::arrays::BoolArray;
use crate::arrays::FilterArray;
use crate::arrays::ListArray;
let list = ListArray::try_new(
buffer![0f32, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.].into_array(),
buffer![2u64, 4, 6, 8, 10, 12].into_array(),
Validity::Array(BoolArray::from_iter([true, true, false, true, true]).into_array()),
)
.unwrap();
let filtered = FilterArray::try_new(
list.into_array(),
Mask::from_iter([true, true, false, false, false]),
)
.unwrap();
let st = StructArray::try_new(
FieldNames::from(["data"]),
vec![filtered.into_array()],
2,
Validity::AllValid,
)
.unwrap();
st.into_array().apply(&get_item("data", root())).unwrap();
}
}