use std::collections::HashSet;
use crate::{
error::Fallible,
metrics::Bound,
polars::literal_value_of,
traits::InfAdd,
transformations::make_stable_lazyframe::group_by::{Resize, check_infallible},
};
use opendp_derive::proven;
#[cfg(not(patch_polars))]
use polars_plan::dsl::WindowType;
use polars_plan::prelude::GroupbyOptions;
use polars::prelude::{
BooleanFunction, DataType, DslPlan, Expr, FunctionExpr, Operator, RankMethod, WindowMapping,
int_range, len, lit,
};
#[cfg(test)]
mod test;
/// A truncation recognized at the top of a lazy plan by [`match_truncations`].
#[derive(Debug, PartialEq)]
pub(crate) enum Truncation {
    /// A `filter` whose predicate was recognized as a truncation predicate.
    Filter(Expr),
    /// A `group_by` whose `keys` include the identifier, together with its aggregations.
    GroupBy {
        keys: Vec<Expr>,
        aggs: Vec<Expr>,
    },
}
/// Peels truncations off the top of `plan`.
///
/// At most one group-by truncation is accepted, and only as the outermost node. It may be
/// followed (further toward the input) by any chain of filter truncations. Matching stops at the
/// first node that is not a filter truncation.
///
/// Returns the remaining (un-truncated) plan together with the recognized truncations and their
/// bounds, both ordered innermost-first (closest to the remaining plan first).
///
/// # Errors
/// * A filter truncation's `by` keys are not a subset of the group-by truncation's keys.
/// * A group-by truncation appears beneath the matched filter truncations.
/// * Errors propagated from [`match_truncation_predicate`].
#[proven]
pub(crate) fn match_truncations(
    mut plan: DslPlan,
    identifier: &Expr,
) -> Fallible<(DslPlan, Vec<Truncation>, Vec<Bound>)> {
    let mut truncations = Vec::new();
    let mut bounds = Vec::new();

    // An outermost group-by truncation restricts which keys later filter truncations may use.
    let allowed_keys = match match_group_by_truncation(&plan, identifier) {
        Some((input, truncation, bound)) => {
            plan = input;
            truncations.push(truncation);
            let keys = bound.by.clone();
            bounds.push(bound);
            Some(keys)
        }
        None => None,
    };

    // Consume consecutive filter truncations, walking toward the input.
    while let DslPlan::Filter { input, predicate } = &plan {
        let Some(new_bounds) = match_truncation_predicate(predicate, identifier)? else {
            break;
        };

        if let Some(allowed_keys) = &allowed_keys {
            let offending = new_bounds
                .iter()
                .find(|bound| !bound.by.is_subset(allowed_keys));
            if let Some(bound) = offending {
                return fallible!(
                    MakeTransformation,
                    "Filter truncation keys ({:?}) must be a subset of groupby truncation keys ({:?}). Otherwise the groupby truncation may invalidate filter truncation.",
                    bound.by, allowed_keys
                );
            }
        }

        truncations.push(Truncation::Filter(predicate.clone()));
        bounds.extend(new_bounds);
        plan = (**input).clone();
    }

    // A group-by truncation underneath the filters would be applied before them.
    if match_group_by_truncation(&plan, identifier).is_some() {
        return fallible!(
            MakeTransformation,
            "Groupby truncation must be the last truncation in the plan. Otherwise the groupby truncation may invalidate later truncations."
        );
    }

    // Matching collected outermost-first; hand back innermost-first.
    truncations.reverse();
    bounds.reverse();
    Ok((plan, truncations, bounds))
}
/// Recognizes a group-by truncation: a plain `group_by` whose keys include `identifier`.
///
/// Returns `None` unless `plan` is a `GroupBy` node with no `apply` function, default
/// `GroupbyOptions`, and at least one key equal to `identifier`. On a match, returns the node's
/// input plan, the truncation, and a bound allowing one row per group over the keys other than
/// `identifier`.
#[proven]
fn match_group_by_truncation(
    plan: &DslPlan,
    identifier: &Expr,
) -> Option<(DslPlan, Truncation, Bound)> {
    // The destructuring is the same with and without `patch_polars`, so one pattern covers both.
    let DslPlan::GroupBy {
        input,
        keys,
        aggs,
        apply,
        options,
        ..
    } = plan
    else {
        return None;
    };

    // Only a plain group-by is recognized: no custom apply and default options.
    let is_plain_group_by = apply.is_none() && **options == GroupbyOptions::default();
    if !is_plain_group_by {
        return None;
    }

    let (id_keys, other_keys): (HashSet<Expr>, HashSet<Expr>) =
        keys.iter().cloned().partition(|key| key == identifier);
    if id_keys.is_empty() {
        return None;
    }

    let truncation = Truncation::GroupBy {
        keys: keys.clone(),
        aggs: aggs.clone(),
    };
    let bound = Bound {
        by: other_keys,
        per_group: Some(1),
        num_groups: None,
    };
    Some(((**input).clone(), truncation, bound))
}
/// Recognizes a filter `predicate` that truncates contributions, returning the bounds it imposes.
///
/// * `all_horizontal([...])` and `a & b` match only if every operand matches. Their bounds are
///   concatenated in operand order.
/// * A comparison (`<`, `<=`, `>`, `>=`) between a `GroupsToRows` window expression and a `u32`
///   literal yields exactly one bound. It is a `num_groups` bound if the window function is a rank
///   recognized by `match_num_groups_predicate`, or a `per_group` bound if it is a row counter
///   recognized by `match_per_group_predicate`. Non-strict comparisons add one to the threshold.
///
/// Returns `Ok(None)` for any other predicate shape.
///
/// # Errors
/// The threshold is not a `u32` literal, or the window function matches neither a rank nor a row
/// counter. Errors from the threshold addition and from the two helper matchers are propagated.
#[proven]
fn match_truncation_predicate(predicate: &Expr, identifier: &Expr) -> Fallible<Option<Vec<Bound>>> {
    match predicate {
        Expr::Function {
            input,
            function: FunctionExpr::Boolean(BooleanFunction::AllHorizontal),
            ..
        } => {
            // Visit every operand (so errors surface in order) before deciding whether the
            // conjunction as a whole matched.
            let mut combined = Vec::new();
            let mut all_matched = true;
            for operand in input {
                match match_truncation_predicate(operand, identifier)? {
                    Some(operand_bounds) => combined.extend(operand_bounds),
                    None => all_matched = false,
                }
            }
            Ok(all_matched.then_some(combined))
        }
        Expr::BinaryExpr {
            left,
            op: Operator::And,
            right,
        } => {
            // Both sides are evaluated before combining, so an error on either side is reported.
            let lhs = match_truncation_predicate(left, identifier)?;
            let rhs = match_truncation_predicate(right, identifier)?;
            Ok(match (lhs, rhs) {
                (Some(mut merged), Some(rest)) => {
                    merged.extend(rest);
                    Some(merged)
                }
                _ => None,
            })
        }
        Expr::BinaryExpr { left, op, right } => {
            // Orient as `window_expr < threshold` (offset 0) or `window_expr <= threshold` (offset 1).
            let (over, threshold_expr, offset) = match op {
                Operator::Lt => (left, right, 0),
                Operator::LtEq => (left, right, 1),
                Operator::Gt => (right, left, 0),
                Operator::GtEq => (right, left, 1),
                _ => return Ok(None),
            };

            #[cfg(patch_polars)]
            let Expr::Over {
                function,
                partition_by,
                mapping: WindowMapping::GroupsToRows,
                ..
            } = over.as_ref()
            else {
                return Ok(None);
            };
            #[cfg(not(patch_polars))]
            let Expr::Window {
                function,
                partition_by,
                options: WindowType::Over(WindowMapping::GroupsToRows),
                ..
            } = over.as_ref()
            else {
                return Ok(None);
            };

            let Some(literal) = literal_value_of::<u32>(&threshold_expr)? else {
                return fallible!(
                    MakeTransformation,
                    "literal value for truncation threshold ({:?}) must be representable as a u32",
                    threshold_expr
                );
            };
            let limit = literal.inf_add(&offset)?;

            let num_groups =
                match_num_groups_predicate(function.as_ref(), partition_by, identifier, limit)?;
            let per_group =
                match_per_group_predicate(function.as_ref(), partition_by, identifier, limit)?;

            match num_groups.or(per_group) {
                Some(bound) => Ok(Some(vec![bound])),
                None => fallible!(
                    MakeTransformation,
                    "expected a predicate that limits per_group contributions (via int_range) or num_groups contributions (via rank). Found {:?}",
                    function
                ),
            }
        }
        _ => Ok(None),
    }
}
#[proven]
fn match_num_groups_predicate(
ranks: &Expr,
partition_by: &Vec<Expr>,
identifier: &Expr,
threshold: u32,
) -> Fallible<Option<Bound>> {
let Expr::Function {
input,
function: FunctionExpr::Rank { options, .. },
..
} = ranks
else {
return Ok(None);
};
if partition_by != &vec![identifier.clone()] {
return fallible!(
MakeTransformation,
"num_groups truncation must use the identifier in the over clause"
);
}
if !matches!(options.method, RankMethod::Dense) {
return fallible!(
MakeTransformation,
"num_groups truncation's rank must be dense"
);
}
let Ok([input_item]) = <&[_; 1]>::try_from(input.as_slice()) else {
return fallible!(
MakeTransformation,
"rank function must be applied to a single input, found {:?}",
input.len()
);
};
let by = match input_item.clone() {
Expr::Function {
function: FunctionExpr::AsStruct,
mut input,
..
} => {
if let Some(Expr::Function {
input: hash_input,
function: FunctionExpr::Hash(_, _, _, _),
..
}) = input.get(0)
{
if hash_input.get(0) == input.get(1) {
let Some(Expr::Function {
input: true_input,
function: FunctionExpr::AsStruct,
..
}) = hash_input.get(0)
else {
return fallible!(
MakeTransformation,
"expected hash input to be a struct, found {:?}",
hash_input
);
};
input = true_input.clone();
}
}
input.into_iter().collect()
}
input => HashSet::from([input.clone()]),
};
Ok(Some(Bound {
by,
per_group: None,
num_groups: Some(threshold),
}))
}
#[proven]
fn match_per_group_predicate(
mut enumeration: &Expr,
partition_by: &Vec<Expr>,
identifier: &Expr,
threshold: u32,
) -> Fallible<Option<Bound>> {
match enumeration {
Expr::Function {
input, function, ..
} => {
let is_reorder = match function {
FunctionExpr::Reverse => true,
FunctionExpr::Random { method, .. } => {
let method: &'static str = method.into();
method == "shuffle"
}
_ => false,
};
if is_reorder {
enumeration = input
.get(0)
.ok_or_else(|| err!(MakeTransformation, "expected one input"))?;
}
}
Expr::SortBy { expr, by, .. } => {
by.iter()
.try_for_each(|key| check_infallible(key, Resize::Ban))?;
enumeration = expr.as_ref()
}
_ => (),
};
if enumeration.ne(&int_range(lit(0), len(), 1, DataType::Int64)) {
return Ok(None);
}
let (ids, by) = partition_by
.iter()
.cloned()
.partition::<HashSet<_>, _>(|expr| expr == identifier);
if ids.is_empty() {
return fallible!(
MakeTransformation,
"failed to find identifier column in per_group predicate condition"
);
}
Ok(Some(Bound {
by,
per_group: Some(threshold),
num_groups: None,
}))
}