use toasty_core::stmt;
/// Rewrites an index filter into an `Any(Map(list, shape))` form where possible.
///
/// - `lhs IN (<literal list>)` becomes `Any(Map(list, lhs == arg(0)))`.
/// - Anything else is expanded into disjunctive normal form (an OR of ANDs) and
///   the branches are merged by shape into a single `Any(Map(values, shape))`.
///   If the expansion yields exactly one branch, that branch is returned as is.
///
/// # Panics
///
/// Via `todo!`/`assert!` in the helpers for filter forms that are not yet
/// supported (e.g. OR branches with differing shapes, or non-literal
/// right-hand sides).
pub(super) fn index_filter_to_any_map(expr: stmt::Expr) -> stmt::Expr {
    match expr {
        // The guard only inspects the list by reference, so `in_list` can be
        // consumed in the arm body without cloning the lhs or the (possibly
        // large) literal list.
        stmt::Expr::InList(in_list)
            if matches!(*in_list.list, stmt::Expr::Value(stmt::Value::List(_))) =>
        {
            let shape = stmt::Expr::from(stmt::ExprBinaryOp {
                lhs: in_list.expr,
                op: stmt::BinaryOp::Eq,
                rhs: Box::new(stmt::Expr::arg(0)),
            });
            stmt::Expr::any(stmt::Expr::map(*in_list.list, shape))
        }
        other => {
            let branches = flatten_to_dnf(other);
            unify_dnf_branches(branches)
        }
    }
}
fn flatten_to_dnf(expr: stmt::Expr) -> Vec<stmt::Expr> {
let mut branches: Vec<stmt::Expr> = Vec::new();
let mut queue: Vec<stmt::Expr> = vec![expr];
while let Some(expr) = queue.pop() {
match expr {
stmt::Expr::Or(or) => queue.extend(or.operands.into_iter().rev()),
stmt::Expr::And(and) => process_and(and, &mut queue, &mut branches),
leaf => branches.push(leaf),
}
}
for branch in &branches {
assert_no_or_in_any(branch);
}
branches
}
/// Handles one conjunction popped from the DNF worklist.
///
/// Nested `And` operands are flattened first. AND is associative, so this keeps
/// the meaning while making sure that conjunctions produced by distributing over
/// an `Or` whose alternatives are themselves `And`s (for example
/// `x = 1 AND ((a = 1 AND b = 2) OR (a = 3 AND b = 4))`) become a single flat
/// list of predicates, rather than an `And` nested inside an `And`.
///
/// Then, in priority order:
/// - if some operand is an `Or`, AND is distributed over it and the resulting
///   conjunctions are pushed back onto `queue`;
/// - else if some operand is an `Any`, the other operands are moved into the
///   `Any`'s map and the result is pushed back onto `queue`;
/// - otherwise the conjunction is a finished DNF branch and goes to `branches`.
fn process_and(and: stmt::ExprAnd, queue: &mut Vec<stmt::Expr>, branches: &mut Vec<stmt::Expr>) {
    let mut operands = Vec::with_capacity(and.operands.len());
    flatten_and_operands(and.operands, &mut operands);
    let and = stmt::ExprAnd { operands };

    let or_pos = and
        .operands
        .iter()
        .position(|op| matches!(op, stmt::Expr::Or(_)));
    if let Some(pos) = or_pos {
        distribute_over_or(and, pos, queue);
        return;
    }

    let any_pos = and
        .operands
        .iter()
        .position(|op| matches!(op, stmt::Expr::Any(_)));
    if let Some(pos) = any_pos {
        distribute_into_any(and, pos, queue);
        return;
    }

    branches.push(stmt::Expr::And(and));
}

/// Appends `operands` to `out`, recursively splicing in the operands of any
/// nested `And` so that `out` contains no `And` expressions at its top level.
fn flatten_and_operands(operands: Vec<stmt::Expr>, out: &mut Vec<stmt::Expr>) {
    for operand in operands {
        match operand {
            stmt::Expr::And(inner) => flatten_and_operands(inner.operands, out),
            other => out.push(other),
        }
    }
}
/// Distributes AND over the `Or` found at `operands[pos]` of `and`.
///
/// `A AND (B OR C) AND D` is replaced by the conjunctions `A AND B AND D` and
/// `A AND C AND D`. Each alternative is placed at the position the `Or`
/// occupied. The new conjunctions are pushed onto `queue` in reverse, so that a
/// LIFO consumer pops them in the `Or`'s original operand order.
///
/// The caller must guarantee that `operands[pos]` is an `Or`.
fn distribute_over_or(and: stmt::ExprAnd, pos: usize, queue: &mut Vec<stmt::Expr>) {
    let mut remaining = and.operands;
    let stmt::Expr::Or(or) = remaining.remove(pos) else {
        unreachable!()
    };

    queue.extend(or.operands.into_iter().rev().map(|alternative| {
        let mut conjunction = remaining.clone();
        conjunction.insert(pos, alternative);
        stmt::Expr::from(stmt::ExprAnd {
            operands: conjunction,
        })
    }));
}
/// Pushes the other conjuncts of `and` inside the `Any(Map(..))` at `operands[pos]`.
///
/// `Any(Map(base, f)) AND g AND h` becomes `Any(Map(base, f AND g AND h))`. If
/// the `Any` was the only operand, the map body `f` is kept as is, without
/// wrapping it in an `And`. The rewritten `Any` is pushed onto `queue`.
///
/// The caller must guarantee that `operands[pos]` is an `Any`.
///
/// # Panics
///
/// Via `todo!` if the `Any` wraps something other than a `Map`.
fn distribute_into_any(and: stmt::ExprAnd, pos: usize, queue: &mut Vec<stmt::Expr>) {
    let mut siblings = and.operands;
    let stmt::Expr::Any(any) = siblings.remove(pos) else {
        unreachable!()
    };
    let stmt::Expr::Map(map) = *any.expr else {
        todo!("Any with non-Map expr in AND distribution");
    };

    let body: stmt::Expr = if siblings.is_empty() {
        *map.map
    } else {
        // The original map body comes first, followed by the other conjuncts in order.
        let mut operands = Vec::with_capacity(siblings.len() + 1);
        operands.push(*map.map);
        operands.extend(siblings);
        stmt::ExprAnd { operands }.into()
    };

    let widened = stmt::ExprMap {
        base: map.base,
        map: Box::new(body),
    };
    queue.push(
        stmt::ExprAny {
            expr: Box::new(stmt::Expr::Map(widened)),
        }
        .into(),
    );
}
fn unify_dnf_branches(branches: Vec<stmt::Expr>) -> stmt::Expr {
if branches.len() == 1 {
return branches.into_iter().next().unwrap();
}
let mut groups: Vec<(stmt::Expr, Vec<stmt::Value>)> = vec![];
for branch in branches {
let (shape, value) = extract_shape(branch);
if let Some((_, values)) = groups.iter_mut().find(|(s, _)| *s == shape) {
values.push(value);
} else {
groups.push((shape, vec![value]));
}
}
if groups.len() > 1 {
todo!(
"OR index filter with multiple distinct branch shapes is not yet implemented; \
shapes: {:#?}",
groups.iter().map(|(s, _)| s).collect::<Vec<_>>()
);
}
let (shape, values) = groups.into_iter().next().unwrap();
stmt::Expr::any(stmt::Expr::map(
stmt::Expr::Value(stmt::Value::List(values)),
shape,
))
}
/// Splits one DNF branch into a parameterized "shape" and its literal value(s).
///
/// - `lhs <op> <literal>` yields the shape `lhs <op> arg(0)` and the literal.
/// - `And(lhs_0 <op_0> lit_0, ..., lhs_n <op_n> lit_n)` yields the shape
///   `And(lhs_0 <op_0> arg(0), ..., lhs_n <op_n> arg(n))` and the record
///   `(lit_0, ..., lit_n)`.
///
/// Branches with equal shapes can then be merged by collecting their values.
///
/// # Panics
///
/// Via `todo!` for any other branch kind, for a conjunct that is not a
/// `BinaryOp`, or for a right-hand side that is not a literal `Expr::Value`.
fn extract_shape(branch: stmt::Expr) -> (stmt::Expr, stmt::Value) {
    match branch {
        stmt::Expr::BinaryOp(binary) => {
            let stmt::Expr::Value(literal) = *binary.rhs else {
                todo!("non-literal value in OR branch rhs: {:#?}", binary.rhs);
            };
            let shape = stmt::Expr::from(stmt::ExprBinaryOp {
                lhs: binary.lhs,
                op: binary.op,
                rhs: Box::new(stmt::Expr::arg(0)),
            });
            (shape, literal)
        }
        stmt::Expr::And(and) => {
            // Each conjunct's literal is replaced by an argument placeholder
            // whose index matches that literal's position in the record.
            let (shape_operands, values): (Vec<stmt::Expr>, Vec<stmt::Value>) = and
                .operands
                .into_iter()
                .enumerate()
                .map(|(i, operand)| {
                    let stmt::Expr::BinaryOp(binary) = operand else {
                        todo!(
                            "non-BinaryOp operand in composite AND branch: {:#?}",
                            operand
                        );
                    };
                    let stmt::Expr::Value(literal) = *binary.rhs else {
                        todo!(
                            "non-literal value in composite AND branch rhs: {:#?}",
                            binary.rhs
                        );
                    };
                    let parameterized = stmt::Expr::from(stmt::ExprBinaryOp {
                        lhs: binary.lhs,
                        op: binary.op,
                        rhs: Box::new(stmt::Expr::arg(i)),
                    });
                    (parameterized, literal)
                })
                .unzip();
            let shape = stmt::Expr::from(stmt::ExprAnd {
                operands: shape_operands,
            });
            let record = stmt::Value::Record(stmt::ValueRecord::from_vec(values));
            (shape, record)
        }
        _ => todo!("unsupported branch type in OR index filter: {branch:#?}"),
    }
}
/// Sanity check run on every finished DNF branch.
///
/// Panics if an `Any` reached directly, or through top-level `And` operands,
/// still contains an `Or` anywhere `contains_or` looks. Other expression kinds
/// are not inspected.
fn assert_no_or_in_any(expr: &stmt::Expr) {
    if let stmt::Expr::Any(any) = expr {
        assert!(
            !contains_or(&any.expr),
            "Any(Map(...)) contains an Or expression after DNF distribution; \
             this is a bug in flatten_to_dnf: {:#?}",
            any.expr
        );
    } else if let stmt::Expr::And(and) = expr {
        for operand in &and.operands {
            assert_no_or_in_any(operand);
        }
    }
}
/// Returns `true` if `expr` is, or (through the variants handled below)
/// contains, an `Or`.
///
/// Only `And`, `Any`, `Map`, `BinaryOp`, `Not` and `IsNull` are descended into.
/// Every other variant is treated as containing no `Or`.
fn contains_or(expr: &stmt::Expr) -> bool {
    match expr {
        stmt::Expr::Or(_) => true,
        stmt::Expr::Not(not) => contains_or(&not.expr),
        stmt::Expr::IsNull(is_null) => contains_or(&is_null.expr),
        stmt::Expr::Any(any) => contains_or(&any.expr),
        stmt::Expr::And(and) => and.operands.iter().any(contains_or),
        stmt::Expr::BinaryOp(binary) => contains_or(&binary.lhs) || contains_or(&binary.rhs),
        stmt::Expr::Map(map) => {
            // The base is checked first; the body is only visited if the base is Or-free.
            let base_has_or = contains_or(&map.base);
            base_has_or || contains_or(&map.map)
        }
        _ => false,
    }
}