use std::fmt::Display;
use itertools::Itertools;
use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata, ToCanonical};
use vortex_dtype::{DType, FieldNames};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
use vortex_proto::expr::select_opts::Opts;
use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
use crate::display::{DisplayAs, DisplayFormat};
use crate::field::DisplayFieldNames;
use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
/// Describes which fields of a struct a [`SelectExpr`] keeps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FieldSelection {
/// Keep only the named fields.
Include(FieldNames),
/// Keep every field except the named ones.
Exclude(FieldNames),
}
// Invokes the crate's `vtable!` macro for `Select`.
// NOTE(review): presumably this generates the `SelectVTable` type that the
// `impl VTable for SelectVTable` block relies on — confirm against the macro definition.
vtable!(Select);
/// Expression that applies a [`FieldSelection`] to the struct produced by its child expression.
// `Hash` is derived while `PartialEq` is implemented by hand for this type; both look at the
// same two fields (`selection` and `child`), which is why the clippy lint is silenced.
// NOTE(review): this assumes `ExprRef`'s own `Hash` and `eq` agree with each other — confirm.
#[derive(Debug, Clone, Hash, Eq)]
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct SelectExpr {
/// The fields to include or exclude.
selection: FieldSelection,
/// The expression whose result is projected.
child: ExprRef,
}
/// Two select expressions are equal when both their selection and their child are equal.
impl PartialEq for SelectExpr {
    fn eq(&self, other: &Self) -> bool {
        // Compare the selection first; the child is only compared when selections match.
        if self.selection != other.selection {
            return false;
        }
        self.child.eq(&other.child)
    }
}
/// Unit type used as the `Encoding` associated type in the `VTable` implementation for select
/// expressions.
pub struct SelectExprEncoding;
/// `VTable` implementation for [`SelectExpr`]: identity, metadata (de)serialization,
/// child management, evaluation and return-dtype inference.
impl VTable for SelectVTable {
    type Expr = SelectExpr;
    type Encoding = SelectExprEncoding;
    type Metadata = ProstMetadata<SelectOpts>;

    /// Returns the expression id `"select"`.
    fn id(_encoding: &Self::Encoding) -> ExprId {
        ExprId::new_ref("select")
    }

    /// Returns the encoding reference backed by [`SelectExprEncoding`].
    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
        ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
    }

    /// Serializes the selection into `SelectOpts`.
    ///
    /// The field names are converted to strings and wrapped in `Opts::Include` or
    /// `Opts::Exclude` depending on the kind of selection. Always returns `Some`.
    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
        let names = expr
            .selection()
            .field_names()
            .iter()
            .map(|f| f.to_string())
            .collect_vec();
        let opts = match expr.selection() {
            FieldSelection::Include(_) => Opts::Include(ProtoFieldNames { names }),
            FieldSelection::Exclude(_) => Opts::Exclude(ProtoFieldNames { names }),
        };
        Some(ProstMetadata(SelectOpts { opts: Some(opts) }))
    }

    /// Returns the single child of the select expression.
    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
        vec![&expr.child]
    }

    /// Rebuilds the expression with the same selection and a replacement child.
    ///
    /// # Errors
    ///
    /// Returns an error unless exactly one child is supplied. This mirrors the check in
    /// [`Self::build`] and avoids an out-of-bounds panic on an empty `children` vector.
    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
        if children.len() != 1 {
            vortex_bail!(
                "Select expression must have exactly one child, got {}",
                children.len()
            );
        }
        // Move the child out of the vector rather than cloning it.
        let child = children
            .into_iter()
            .next()
            .vortex_expect("number of children validated to be one");
        Ok(SelectExpr {
            selection: expr.selection.clone(),
            child,
        })
    }

    /// Reconstructs a [`SelectExpr`] from deserialized metadata and its children.
    ///
    /// # Errors
    ///
    /// Returns an error if `children` does not contain exactly one expression, or if the
    /// metadata carries neither an include nor an exclude list.
    fn build(
        _encoding: &Self::Encoding,
        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
        children: Vec<ExprRef>,
    ) -> VortexResult<Self::Expr> {
        if children.len() != 1 {
            vortex_bail!("Select expression must have exactly one child");
        }
        let selection = match metadata.opts.as_ref() {
            Some(Opts::Include(field_names)) => FieldSelection::Include(FieldNames::from_iter(
                field_names.names.iter().map(|s| s.as_str()),
            )),
            Some(Opts::Exclude(field_names)) => FieldSelection::Exclude(FieldNames::from_iter(
                field_names.names.iter().map(|s| s.as_str()),
            )),
            None => {
                vortex_bail!("Select expressions must be provided with fields to select or exclude")
            }
        };
        let child = children
            .into_iter()
            .next()
            .vortex_expect("number of children validated to be one");
        Ok(SelectExpr { selection, child })
    }

    /// Evaluates the child, converts the result to a struct array and projects it.
    ///
    /// For `Include` the named fields are projected directly. For `Exclude` the struct's
    /// fields that are not in the excluded list are projected, in the struct's own order.
    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
        let batch = expr.child.unchecked_evaluate(scope)?.to_struct();
        let projected = match &expr.selection {
            FieldSelection::Include(names) => batch.project(names.as_ref())?,
            FieldSelection::Exclude(excluded) => {
                let kept = batch
                    .names()
                    .iter()
                    .filter(|&name| !excluded.as_ref().contains(name))
                    .cloned()
                    .collect::<Vec<_>>();
                batch.project(kept.as_slice())?
            }
        };
        Ok(projected.into_array())
    }

    /// Computes the struct dtype produced by the selection.
    ///
    /// The result keeps the nullability of the child's dtype.
    ///
    /// # Errors
    ///
    /// Returns an error if the child's dtype is not a struct, or if projecting the
    /// included fields fails.
    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
        let child_dtype = expr.child.return_dtype(scope)?;
        let child_struct_dtype = child_dtype
            .as_struct_fields_opt()
            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
        let projected = match &expr.selection {
            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
            // Pair each field name with its dtype and drop the excluded ones.
            FieldSelection::Exclude(fields) => child_struct_dtype
                .names()
                .iter()
                .cloned()
                .zip_eq(child_struct_dtype.fields())
                .filter(|(name, _)| !fields.as_ref().contains(name))
                .collect(),
        };
        Ok(DType::Struct(projected, child_dtype.nullability()))
    }
}
/// Creates an expression that keeps only `field_names` from the struct produced by `child`.
pub fn select(field_names: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
    let included: FieldNames = field_names.into();
    SelectExpr::include_expr(included, child)
}
/// Creates an expression that drops `fields` from the struct produced by `child`, keeping
/// every other field.
pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
    let excluded: FieldNames = fields.into();
    SelectExpr::exclude_expr(excluded, child)
}
impl SelectExpr {
    /// Creates a select expression applying `fields` to the result of `child`.
    pub fn new(fields: FieldSelection, child: ExprRef) -> Self {
        Self {
            selection: fields,
            child,
        }
    }

    /// Same as [`SelectExpr::new`], but returns the expression as an [`ExprRef`].
    pub fn new_expr(fields: FieldSelection, child: ExprRef) -> ExprRef {
        let expr = Self::new(fields, child);
        expr.into_expr()
    }

    /// Builds an [`ExprRef`] that keeps only `columns` of `child`.
    pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
        let selection = FieldSelection::Include(columns);
        Self::new(selection, child).into_expr()
    }

    /// Builds an [`ExprRef`] that keeps every field of `child` except `columns`.
    pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
        let selection = FieldSelection::Exclude(columns);
        Self::new(selection, child).into_expr()
    }

    /// The field selection applied by this expression.
    pub fn selection(&self) -> &FieldSelection {
        &self.selection
    }

    /// The expression whose result is being projected.
    pub fn child(&self) -> &ExprRef {
        &self.child
    }

    /// Returns an equivalent expression expressed as an `Include` selection, resolving the
    /// current selection against `field_names`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`FieldSelection::as_include_names`].
    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
        let included = self.selection.as_include_names(field_names)?;
        let rewritten = Self::new(FieldSelection::Include(included), self.child.clone());
        Ok(rewritten.into_expr())
    }
}
impl FieldSelection {
    /// Creates an `Include` selection.
    ///
    /// # Panics
    ///
    /// Panics if `columns` contains duplicate names.
    pub fn include(columns: FieldNames) -> Self {
        // Count distinct names directly instead of collecting them into a temporary Vec.
        assert_eq!(columns.iter().unique().count(), columns.len());
        Self::Include(columns)
    }

    /// Creates an `Exclude` selection.
    ///
    /// # Panics
    ///
    /// Panics if `columns` contains duplicate names.
    pub fn exclude(columns: FieldNames) -> Self {
        assert_eq!(columns.iter().unique().count(), columns.len());
        Self::Exclude(columns)
    }

    /// Returns `true` for an `Include` selection.
    pub fn is_include(&self) -> bool {
        matches!(self, Self::Include(_))
    }

    /// Returns `true` for an `Exclude` selection.
    pub fn is_exclude(&self) -> bool {
        matches!(self, Self::Exclude(_))
    }

    /// Returns the names carried by the selection, regardless of its kind.
    pub fn field_names(&self) -> &FieldNames {
        match self {
            Self::Include(fields) | Self::Exclude(fields) => fields,
        }
    }

    /// Resolves the selection against `field_names` and returns the names that end up included.
    ///
    /// An `Include` selection returns its own names; an `Exclude` selection returns the
    /// entries of `field_names` that are not excluded, in the order of `field_names`.
    ///
    /// # Errors
    ///
    /// Returns an error naming the first selected field that does not appear in `field_names`.
    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
        // Report the offending field itself rather than the whole selection.
        if let Some(missing) = self
            .field_names()
            .iter()
            .find(|&f| !field_names.iter().contains(f))
        {
            vortex_bail!(
                "Field {:?} in select not in field names {:?}",
                missing,
                field_names
            );
        }
        match self {
            FieldSelection::Include(fields) => Ok(fields.clone()),
            FieldSelection::Exclude(exc_fields) => Ok(field_names
                .iter()
                .filter(|f| !exc_fields.iter().contains(f))
                .cloned()
                .collect()),
        }
    }
}
/// Formats the selection as `{names}` for includes and `~{names}` for excludes.
impl Display for FieldSelection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the leading marker differs between the two variants.
        let (marker, fields) = match self {
            FieldSelection::Include(fields) => ("", fields),
            FieldSelection::Exclude(fields) => ("~", fields),
        };
        write!(f, "{}{{{}}}", marker, DisplayFieldNames(fields))
    }
}
impl DisplayAs for SelectExpr {
fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match df {
DisplayFormat::Compact => {
write!(f, "{}{}", self.child, self.selection)
}
DisplayFormat::Tree => {
let field_type = if self.selection.is_include() {
"include"
} else {
"exclude"
};
write!(
f,
"Select({}): {}",
field_type,
self.selection().field_names()
)
}
}
}
fn child_names(&self) -> Option<Vec<String>> {
None
}
}
// No `AnalysisExpr` methods are overridden here, so whatever defaults the trait provides apply.
impl AnalysisExpr for SelectExpr {}
#[cfg(test)]
mod tests {
    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::{FieldSelection, Scope, SelectExpr, root, select, select_exclude, test_harness};

    /// Struct fixture with two integer columns, `a` and `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Including `a` yields a struct containing only `a`.
    #[test]
    pub fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let scope = Scope::new(input.to_array());
        let selected = expr.evaluate(&scope).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    /// Excluding `a` yields a struct containing only `b`.
    #[test]
    pub fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let scope = Scope::new(input.to_array());
        let selected = expr.evaluate(&scope).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    /// `return_dtype` agrees with projecting the harness struct dtype by hand, for both
    /// include and exclude selections.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let struct_fields = dtype.as_struct_fields_opt().unwrap();

        let only_a = DType::Struct(
            struct_fields.project(&["a".into()]).unwrap(),
            Nullability::NonNullable,
        );

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding every other field leaves just `a`.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding only the `col*` fields keeps `a`, `bool1` and `bool2`.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let without_cols = DType::Struct(
            struct_fields
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), without_cols);
    }

    /// Including `a` and excluding `b`, `c` resolve to the same include expression.
    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let by_include = SelectExpr::new(FieldSelection::Include(["a"].into()), root());
        let by_exclude = SelectExpr::new(FieldSelection::Exclude(["b", "c"].into()), root());

        let from_include = by_include.as_include(&field_names).unwrap();
        let from_exclude = by_exclude.as_include(&field_names).unwrap();
        assert_eq!(&from_include, &from_exclude);
    }
}