use vortex_error::VortexResult;
use vortex_error::vortex_err;
use crate::dtype::DType;
use crate::dtype::Field;
use crate::dtype::FieldPath;
use crate::dtype::FieldPathSet;
use crate::expr::Expression;
use crate::expr::traversal::FoldDownContext;
use crate::expr::traversal::FoldUp;
use crate::expr::traversal::NodeExt;
use crate::expr::traversal::NodeFolderContext;
use crate::scalar_fn::fns::get_item::GetItem;
use crate::scalar_fn::fns::root::Root;
use crate::scalar_fn::fns::select::Select;
/// Collects the field paths of `scope` that evaluating `expr` reads.
///
/// The expression is walked from the top down:
/// - A [`GetItem`] extends the paths its parent requested by one field.
/// - A [`Select`] keeps only the requested paths that go through one of its included fields.
///   A request for the whole struct is expanded into those included fields.
/// - Every other node requests all of each child.
///
/// When the walk reaches [`Root`], the paths accumulated along the way are added to the result.
/// A path equal to [`FieldPath::root()`] means the whole scope is referenced.
///
/// # Errors
///
/// Fails if `expr` does not type-check against `scope` (for example, a [`GetItem`] of a field
/// that does not exist). Also fails if a [`Select`] is applied to a child that is not a struct.
pub fn referenced_field_paths(expr: &Expression, scope: &DType) -> VortexResult<FieldPathSet> {
    // Validate the whole expression up front; the resulting dtype itself is not needed.
    expr.return_dtype(scope)?;

    let mut folder = ReferencedFieldPaths {
        field_paths: FieldPathSet::default(),
        scope,
    };
    // The top of the expression is asked for its entire output.
    let whole_output = vec![FieldPath::root()];
    expr.clone().fold_context(&whole_output, &mut folder)?;
    let ReferencedFieldPaths { field_paths, .. } = folder;

    // Debug-only sanity check: the first component of every referenced path must agree with
    // the set of scope fields that the expression accesses directly.
    #[cfg(debug_assertions)]
    if let Some(scope_fields) = scope.as_struct_fields_opt() {
        use vortex_utils::aliases::hash_set::HashSet;

        use crate::dtype::FieldName;
        use crate::expr::analysis::immediate_access::immediate_scope_access;

        let references_everything = field_paths.iter().any(FieldPath::is_root);
        let referenced_heads: HashSet<FieldName> = if references_everything {
            // A root path stands for every top-level field of the scope.
            scope_fields.names().iter().cloned().collect()
        } else {
            field_paths
                .iter()
                .filter_map(|path| {
                    if let Some(Field::Name(head)) = path.parts().first() {
                        Some(head.clone())
                    } else {
                        None
                    }
                })
                .collect()
        };
        debug_assert_eq!(
            referenced_heads,
            immediate_scope_access(expr, scope_fields),
            "referenced field path heads must match the immediately accessed scope fields"
        );
    }

    Ok(field_paths)
}
/// Top-down folder that records which field paths of `scope` an expression reads.
///
/// Paths are recorded into `field_paths` when the fold reaches a [`Root`] node.
struct ReferencedFieldPaths<'a> {
    /// Paths collected so far, in scope order (outermost field first).
    field_paths: FieldPathSet,
    /// Scope the expression is evaluated against; used to resolve the dtype of a
    /// [`Select`]'s child.
    scope: &'a DType,
}
impl NodeFolderContext for ReferencedFieldPaths<'_> {
    type NodeTy = Expression;
    type Result = ();
    // Paths requested from the current node, stored in reverse scope order. Each `GetItem`
    // appends its field, so the last element is the field read directly from this node's output.
    type Context = Vec<FieldPath>;

    fn visit_down(
        &mut self,
        requested: &Self::Context,
        node: &Expression,
    ) -> VortexResult<FoldDownContext<Self::Context, ()>> {
        if node.is::<Root>() {
            // Reached the scope itself: flip each path back into scope order, record it, and
            // stop descending.
            let resolved = requested
                .iter()
                .map(|reversed| FieldPath::from_iter(reversed.parts().iter().rev().cloned()));
            self.field_paths.extend(resolved);
            Ok(FoldDownContext::Skip(()))
        } else if let Some(field_name) = node.as_opt::<GetItem>() {
            // The child must supply `field_name`, plus whatever the parent wanted beneath it.
            let extended: Vec<FieldPath> = requested
                .iter()
                .map(|path| path.clone().push(Field::Name(field_name.clone())))
                .collect();
            Ok(FoldDownContext::Continue(extended))
        } else if let Some(selection) = node.as_opt::<Select>() {
            let child_dtype = node.child(0).return_dtype(self.scope)?;
            let child_fields = child_dtype
                .as_struct_fields_opt()
                .ok_or_else(|| vortex_err!("Select child is not a struct"))?;
            let included = selection.normalize_to_included_fields(child_fields.names())?;

            let mut narrowed = Vec::with_capacity(requested.len());
            for path in requested {
                if path.is_root() {
                    // The parent wants the whole selected struct, i.e. every included field.
                    narrowed.extend(included.iter().cloned().map(FieldPath::from_name));
                    continue;
                }
                // Keep the path only if the field read directly off this Select's output
                // (its last element) survives the selection. Field names pass through unchanged.
                let survives = matches!(
                    path.parts().last(),
                    Some(Field::Name(name)) if included.iter().any(|kept| kept == name)
                );
                if survives {
                    narrowed.push(path.clone());
                }
            }

            if narrowed.is_empty() {
                Ok(FoldDownContext::Skip(()))
            } else {
                Ok(FoldDownContext::Continue(narrowed))
            }
        } else {
            // Any other node is opaque: assume it may read everything from each child.
            Ok(FoldDownContext::Continue(vec![FieldPath::root()]))
        }
    }

    fn visit_up(
        &mut self,
        _node: Expression,
        _context: &Self::Context,
        _children: Vec<()>,
    ) -> VortexResult<FoldUp<()>> {
        // All bookkeeping happens on the way down; there is nothing to combine on the way up.
        Ok(FoldUp::Continue(()))
    }
}
#[cfg(test)]
mod tests {
    use vortex_utils::aliases::hash_set::HashSet;

    use super::*;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::get_item;
    use crate::expr::pack;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;

    /// Scope shared by every test: `{a: {x: i32, y: i32}}`.
    fn scope() -> DType {
        DType::Struct(
            StructFields::from_iter([(
                "a",
                DType::Struct(
                    StructFields::from_iter([("x", I32), ("y", I32)]),
                    NonNullable,
                ),
            )]),
            NonNullable,
        )
    }

    /// Runs `referenced_field_paths` against `scope()` and collects the result into a set.
    fn referenced(expr: &Expression) -> VortexResult<HashSet<FieldPath>> {
        Ok(referenced_field_paths(expr, &scope())?
            .into_iter()
            .collect())
    }

    #[test]
    fn nested_get_item_references_full_path() -> VortexResult<()> {
        let expr = get_item("x", get_item("a", root()));
        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a").push("x")])
        );
        Ok(())
    }

    #[test]
    fn sibling_paths_are_kept_separately() -> VortexResult<()> {
        let expr = pack(
            [
                ("x", get_item("x", get_item("a", root()))),
                ("y", get_item("y", get_item("a", root()))),
            ],
            NonNullable,
        );
        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([
                FieldPath::from_name("a").push("x"),
                FieldPath::from_name("a").push("y"),
            ])
        );
        Ok(())
    }

    #[test]
    fn nested_select_preserves_field_path() -> VortexResult<()> {
        let expr = select(["x"], get_item("a", root()));
        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a").push("x")])
        );
        Ok(())
    }

    #[test]
    fn get_item_after_select_only_references_requested_field() -> VortexResult<()> {
        let expr = get_item("x", select(["x", "y"], get_item("a", root())));
        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a").push("x")])
        );
        Ok(())
    }

    #[test]
    fn select_exclude_references_included_fields() -> VortexResult<()> {
        let expr = select_exclude(["y"], get_item("a", root()));
        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a").push("x")])
        );
        Ok(())
    }

    #[test]
    fn ancestor_path_subsumes_descendant() -> VortexResult<()> {
        let expr = pack(
            [
                ("a", get_item("a", root())),
                ("x", get_item("x", get_item("a", root()))),
            ],
            NonNullable,
        );
        assert_eq!(
            referenced(&expr)?,
            HashSet::from_iter([FieldPath::from_name("a")])
        );
        Ok(())
    }

    #[test]
    fn get_item_through_opaque_fn_references_all_fields() -> VortexResult<()> {
        let expr = get_item("x", pack([("x", root())], NonNullable));
        assert_eq!(referenced(&expr)?, HashSet::from_iter([FieldPath::root()]));
        Ok(())
    }

    #[test]
    fn root_references_all_fields() -> VortexResult<()> {
        assert_eq!(
            referenced(&root())?,
            HashSet::from_iter([FieldPath::root()])
        );
        Ok(())
    }

    #[test]
    fn invalid_get_item_path_returns_error() {
        assert!(referenced_field_paths(&get_item("missing", root()), &scope()).is_err());
    }
}