use vortex_error::VortexExpect;
use crate::dtype::Nullability;
use crate::dtype::StructFields;
use crate::expr::Expression;
use crate::expr::col;
use crate::expr::pack;
use crate::expr::root;
use crate::expr::traversal::NodeExt;
use crate::expr::traversal::Transformed;
use crate::expr::traversal::TraversalOrder;
/// Substitutes `replacement` for every sub-expression of `expr` that compares equal to `needle`.
///
/// The tree is walked with `transform_down`. A node equal to `needle` is swapped for a clone of
/// `replacement` and tagged with [`TraversalOrder::Skip`]; every other node is handed back
/// unchanged via [`Transformed::no`].
///
/// # Panics
///
/// Unwraps the traversal result with `vortex_expect`. The closure passed here only ever
/// returns `Ok`, so an error is not expected.
pub fn replace(expr: Expression, needle: &Expression, replacement: Expression) -> Expression {
    let rewritten = expr.transform_down(|candidate| {
        let outcome = if &candidate == needle {
            // Match: emit the replacement and flag the node as changed.
            Transformed {
                value: replacement.clone(),
                order: TraversalOrder::Skip,
                changed: true,
            }
        } else {
            Transformed::no(candidate)
        };
        Ok(outcome)
    });

    rewritten
        .vortex_expect("ReplaceVisitor should not fail")
        .into_inner()
}
/// Rewrites every `root()` reference in `expr` into an explicit, non-nullable `pack` of the
/// given struct fields.
///
/// Each name in `fields` becomes one `(name, col(name))` entry of the pack, in the order
/// yielded by `fields.names()`.
pub fn replace_root_fields(expr: Expression, fields: &StructFields) -> Expression {
    let root_expr = root();

    // One `(field name, column expression)` pair per struct field.
    let field_columns = fields
        .names()
        .iter()
        .map(|name| (name.clone(), col(name.clone())));
    let packed_fields = pack(field_columns, Nullability::NonNullable);

    replace(expr, &root_expr, packed_fields)
}
#[cfg(test)]
mod test {
    use super::replace;
    use crate::dtype::Nullability::NonNullable;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::pack;

    /// A needle equal to the whole expression yields exactly the replacement.
    #[test]
    fn test_replace_full_tree() {
        let haystack = get_item("b", pack([("a", lit(1)), ("b", lit(2))], NonNullable));
        let target = get_item("b", pack([("a", lit(1)), ("b", lit(2))], NonNullable));
        let substitute = lit(42);

        let actual = replace(haystack, &target, substitute.clone());

        assert_eq!(actual, substitute);
    }

    /// A needle equal to a single leaf only swaps that leaf; the rest of the tree is kept.
    #[test]
    fn test_replace_leaf() {
        let actual = replace(
            pack([("a", lit(1)), ("b", lit(2))], NonNullable),
            &lit(2),
            lit(42),
        );

        assert_eq!(actual.to_string(), "pack(a: 1i32, b: 42i32)");
    }
}