use std::collections::HashSet;
use crate::core::{Function, StabilityMap, Transformation};
use crate::domains::{Context, DslPlanDomain, FrameDomain, SeriesDomain, WildExprDomain};
use crate::error::*;
use crate::metrics::{
Bound, Bounds, FrameDistance, L0PInfDistance, L01InfDistance, SymmetricDistance,
SymmetricIdDistance,
};
use crate::traits::{InfMul, option_min};
use crate::transformations::make_stable_expr;
use matching::Truncation;
use opendp_derive::proven;
use polars::prelude::*;
use polars_plan::prelude::GroupbyOptions;
use super::StableDslPlan;
use super::group_by::{Resize, check_infallible};
#[cfg(test)]
mod test;
mod matching;
pub(crate) use matching::match_truncations;
/// Make a transformation that applies the truncations matched in `plan`, converting an
/// identifier-level distance into a row-level distance.
///
/// `plan` is split by [`match_truncations`] into an input plan, a non-empty sequence of
/// [`Truncation`]s (filters and/or group-bys) and the bounds those truncations imply.
/// The returned transformation first makes the input plan stable under `input_metric`.
/// It then wraps the resulting plan in each truncation, in order.
///
/// # Arguments
/// * `input_domain` - domain of the input plan.
/// * `input_metric` - identifier-level frame distance; its identifier is used when matching
///   truncations.
/// * `plan` - the plan containing the input and its truncations.
///
/// # Errors
/// Fails if:
/// * no truncation is matched,
/// * the input plan cannot be made stable,
/// * a grouping key of a truncation bound is not a stable row-by-row expression on the
///   intermediate domain,
/// * the output domain cannot be derived by [`truncate_domain`].
pub fn make_stable_truncate(
    input_domain: DslPlanDomain,
    input_metric: FrameDistance<SymmetricIdDistance>,
    plan: DslPlan,
) -> Fallible<
    Transformation<
        DslPlanDomain,
        FrameDistance<SymmetricIdDistance>,
        DslPlanDomain,
        FrameDistance<SymmetricDistance>,
    >,
> {
    let (input, truncations, truncation_bounds) =
        match_truncations(plan, &input_metric.0.identifier)?;
    if truncations.is_empty() {
        return fallible!(MakeTransformation, "failed to match truncation");
    };

    // Stabilize everything beneath the truncations first.
    let t_prior = input.make_stable(input_domain, input_metric)?;
    let (middle_domain, middle_metric): (_, FrameDistance<SymmetricIdDistance>) =
        t_prior.output_space();

    // Every grouping key used by a truncation bound must be a stable row-by-row expression
    // on the intermediate data. The built expression transformations are discarded; only
    // the check matters.
    (truncation_bounds.iter().flat_map(|b| &b.by)).try_for_each(|key| {
        make_stable_expr::<_, L01InfDistance<SymmetricIdDistance>>(
            WildExprDomain {
                columns: middle_domain.series_domains.clone(),
                context: Context::RowByRow,
            },
            L0PInfDistance(middle_metric.0.clone()),
            key.clone(),
        )
        .map(|_| ())
    })?;

    // Thread the domain through each truncation, in the same order they are applied below.
    let output_domain = (truncations.iter())
        .try_fold(middle_domain.clone(), |domain, truncation| {
            truncate_domain(domain, truncation)
        })?;

    let t_truncate = Transformation::new(
        middle_domain,
        middle_metric.clone(),
        output_domain,
        FrameDistance(SymmetricDistance),
        Function::new(move |plan: &DslPlan| {
            // The fold owns `plan`, so each arm can move it into the new node
            // without cloning it.
            (truncations.iter()).fold(plan.clone(), |plan, truncation| match truncation {
                Truncation::Filter(predicate) => DslPlan::Filter {
                    input: Arc::new(plan),
                    predicate: predicate.clone(),
                },
                Truncation::GroupBy { keys, aggs } => {
                    // Under `patch_polars`, `DslPlan::GroupBy` has an additional
                    // `predicates` field.
                    #[cfg(patch_polars)]
                    let output = DslPlan::GroupBy {
                        input: Arc::new(plan),
                        keys: keys.clone(),
                        aggs: aggs.clone(),
                        predicates: vec![],
                        apply: None,
                        maintain_order: false,
                        options: Arc::new(GroupbyOptions::default()),
                    };
                    #[cfg(not(patch_polars))]
                    let output = DslPlan::GroupBy {
                        input: Arc::new(plan),
                        keys: keys.clone(),
                        aggs: aggs.clone(),
                        apply: None,
                        maintain_order: false,
                        options: Arc::new(GroupbyOptions::default()),
                    };
                    output
                }
            })
        }),
        StabilityMap::new_fallible(move |id_bounds: &Bounds| {
            // Bound keyed by the default (empty) grouping. Its `per_group` is used as the
            // total number of identifiers; NOTE(review): confirm against `Bounds::get_bound`.
            let total_num_ids = id_bounds.get_bound(&Default::default()).per_group;
            // Translate each truncation's limits into a row-level bound on its grouping.
            let new_bounds = (truncation_bounds.iter())
                .map(|truncation_bound| {
                    truncate_id_bound(
                        id_bounds.get_bound(&truncation_bound.by),
                        truncation_bound.clone(),
                        total_num_ids,
                    )
                })
                .collect::<Fallible<Vec<Bound>>>()?;
            Ok(Bounds(new_bounds))
        }),
    )?;
    t_prior >> t_truncate
}
/// Derive the domain of the data after applying a single truncation to `domain`.
///
/// * `Truncation::Filter`: the series domains and margins are kept, but every margin's
///   `invariant` is cleared, since removing rows may break any previously-held invariant.
/// * `Truncation::GroupBy`: each aggregation must pass `check_infallible` (with resizing
///   allowed). The output series domains are built from the schema obtained by simulating
///   `group_by(keys).agg(aggs)`. Only margins whose grouping columns are a subset of `keys`
///   are kept, and each kept margin has its `invariant` cleared.
///
/// # Errors
/// Fails if an aggregation is rejected by `check_infallible`, if the schema simulation
/// fails, if a field cannot be turned into a [`SeriesDomain`], or if
/// `FrameDomain::new_with_margins` rejects the result.
#[proven]
fn truncate_domain(mut domain: DslPlanDomain, truncation: &Truncation) -> Fallible<DslPlanDomain> {
    match truncation {
        Truncation::Filter { .. } => {
            domain.margins.iter_mut().for_each(|m| {
                m.invariant = None;
            });
            Ok(domain)
        }
        Truncation::GroupBy { keys, aggs } => {
            aggs.iter()
                .try_for_each(|e| check_infallible(e, Resize::Allow))?;

            let series_domains = domain
                .simulate_schema(|lf| lf.group_by(&keys).agg(&aggs))?
                .iter_fields()
                .map(SeriesDomain::new_from_field)
                .collect::<Fallible<_>>()?;

            // The key set does not depend on the margin: build it once instead of
            // re-cloning every key for each margin.
            let key_set = HashSet::from_iter(keys.clone());
            let margins = domain
                .margins
                .into_iter()
                .filter(|m| m.by.is_subset(&key_set))
                .map(|mut m| {
                    m.invariant = None;
                    m
                })
                .collect();

            FrameDomain::new_with_margins(series_domains, margins)
        }
    }
}
/// Convert an identifier-level bound into a row-level bound after truncation.
///
/// The result is keyed by the same grouping columns as `truncation`:
/// * `per_group` is `id_bound.per_group * truncation.per_group` (saturating via `inf_mul`),
///   set only when both are known.
/// * `num_groups` is the smaller of `total_ids * truncation.num_groups` (when both are
///   known) and `id_bound.num_groups`, as combined by `option_min`.
///
/// # Errors
/// Propagates any error from `inf_mul`.
#[proven]
fn truncate_id_bound(
    id_bound: Bound,
    truncation: Bound,
    total_ids: Option<u32>,
) -> Fallible<Bound> {
    let grouping: Vec<_> = truncation.by.iter().cloned().collect();
    let base = Bound::by(&grouping);

    let with_rows = match (id_bound.per_group, truncation.per_group) {
        (Some(ids_per_group), Some(rows_per_id)) => {
            base.with_per_group(ids_per_group.inf_mul(&rows_per_id)?)
        }
        _ => base,
    };

    let groups_from_truncation = match (total_ids, truncation.num_groups) {
        (Some(ids), Some(groups_per_id)) => Some(ids.inf_mul(&groups_per_id)?),
        _ => None,
    };

    Ok(match option_min(groups_from_truncation, id_bound.num_groups) {
        Some(group_count) => with_rows.with_num_groups(group_count),
        None => with_rows,
    })
}