use opendp_derive::bootstrap;
use polars::{
datatypes::{AnyValue, DataType, Field},
frame::{DataFrame, row::Row},
prelude::{AggExpr, DslPlan, Expr, FunctionExpr, IntoLazy, LazyFrame, Schema},
};
#[cfg(test)]
mod test;
use crate::{
accuracy::{
discrete_gaussian_scale_to_accuracy, discrete_laplacian_scale_to_accuracy,
gaussian_scale_to_accuracy, laplacian_scale_to_accuracy,
},
core::{Measure, Measurement, Metric, MetricSpace},
domains::LazyFrameDomain,
error::Fallible,
measurements::{
KeySanitizer, MatchGroupBy,
expr_index_candidates::IndexCandidatesPlugin,
expr_noise::{NoiseDistribution, NoisePlugin, Support},
expr_noisy_max::NoisyMaxPlugin,
is_threshold_predicate, match_group_by,
},
polars::{ExtractLazyFrame, OnceFrame, match_trusted_plugin},
transformations::expr_discrete_quantile_score::DiscreteQuantileScorePlugin,
};
#[cfg(feature = "ffi")]
mod ffi;
#[bootstrap(
    name = "summarize_polars_measurement",
    features("contrib"),
    arguments(
        measurement(rust_type = "AnyMeasurement"),
        alpha(c_type = "AnyObject *", default = b"null")
    ),
    generics(MI(suppress), MO(suppress)),
    returns(c_type = "FfiResult<AnyObject *>")
)]
/// Summarize the statistics released by this measurement.
///
/// The measurement is invoked on an empty frame built from the schema of its input domain,
/// and the query plan extracted from the resulting `OnceFrame` is summarized.
/// The summary has one row per recognized statistic.
/// The `accuracy` column is omitted when `alpha` is not given,
/// and the `threshold` column is omitted when no row carries a threshold.
///
/// # Arguments
/// * `measurement` - computation from which to read noise scale parameters
/// * `alpha` - optional statistical significance used to compute accuracy estimates
pub fn summarize_polars_measurement<MI: Metric, MO: 'static + Measure>(
    measurement: Measurement<LazyFrameDomain, MI, MO, OnceFrame>,
    alpha: Option<f64>,
) -> Fallible<DataFrame>
where
    (LazyFrameDomain, MI): MetricSpace,
{
    // An empty frame with the domain's schema is enough to build the query plan.
    let domain_schema = measurement.input_domain.schema();
    let empty_frame = DataFrame::from_rows_and_schema(&[], &domain_schema)?.lazy();

    let mut once_frame = measurement.invoke(&empty_frame)?;
    let planned: LazyFrame = once_frame.eval_internal(&ExtractLazyFrame)?;

    summarize_lazyframe(&planned, alpha)
}
/// One row of the utility summary table: describes a single statistic found in a query plan.
#[derive(Clone)]
struct UtilitySummary {
/// Output name of the expression; emitted under the `column` field of the summary frame.
pub name: String,
/// Label of the aggregation being released, e.g. "Sum", "Count" or "Frame Length".
pub aggregate: String,
/// Label of the noise distribution, or `None` when the statistic has no recognized noise plugin.
pub distribution: Option<String>,
/// Scale parameter read from the matched plugin, if any.
pub scale: Option<f64>,
/// Accuracy estimate; only populated when an `alpha` was supplied.
pub accuracy: Option<f64>,
/// Threshold value, set when the statistic's name matches the thresholded column.
pub threshold: Option<u32>,
}
/// Summarize the statistics described by the logical plan of `lazyframe`.
///
/// The returned frame has one row per recognized statistic.
/// The `accuracy` column is dropped when `alpha` is `None`,
/// and the `threshold` column is dropped when every one of its entries is null.
///
/// # Arguments
/// * `lazyframe` - lazy frame whose logical plan is inspected
/// * `alpha` - optional statistical significance used to compute accuracy estimates
pub fn summarize_lazyframe(lazyframe: &LazyFrame, alpha: Option<f64>) -> Fallible<DataFrame> {
    let summary = summarize_logical_plan(&lazyframe.logical_plan, alpha)?;

    // Without a significance level there are no accuracy estimates to report.
    let summary = match alpha {
        Some(_) => summary,
        None => summary.drop("accuracy")?,
    };

    // Hide the threshold column when no statistic is thresholded.
    let no_thresholds = summary.column("threshold")?.is_null().all();
    Ok(if no_thresholds {
        summary.drop("threshold")?
    } else {
        summary
    })
}
/// Walks a logical plan and summarizes the statistics it releases.
///
/// * Plans recognized by `match_group_by` are summarized from their aggregate expressions;
///   when the keys are sanitized by a filter, the filter must be a threshold predicate.
/// * `Select` plans are summarized from their selected expressions.
/// * `Slice`, `Sink` and `HStack` plans are skipped over by recursing into their input.
///
/// # Errors
/// Returns `FailedFunction` if a key-sanitizing filter is not a threshold predicate,
/// or if the plan is none of the recognized forms.
fn summarize_logical_plan(logical_plan: &DslPlan, alpha: Option<f64>) -> Fallible<DataFrame> {
    let group_by = match_group_by(logical_plan.clone())?;

    if let Some(MatchGroupBy {
        aggs,
        key_sanitizer,
        ..
    }) = group_by
    {
        let threshold = match key_sanitizer {
            Some(KeySanitizer::Filter(predicate)) => {
                let parsed = is_threshold_predicate(predicate.clone()).ok_or_else(|| {
                    err!(
                        FailedFunction,
                        "predicate is not a valid filter: {}",
                        predicate
                    )
                })?;
                Some(parsed)
            }
            _ => None,
        };
        return agg_dataframe(&aggs, threshold, alpha);
    }

    match logical_plan {
        DslPlan::Select { expr, .. } => agg_dataframe(expr, None, alpha),
        // These nodes do not change which statistics are released; look through them.
        DslPlan::Slice { input, .. }
        | DslPlan::Sink { input, .. }
        | DslPlan::HStack { input, .. } => summarize_logical_plan(input.as_ref(), alpha),
        _ => fallible!(
            FailedFunction,
            "unrecognized dsl: {}",
            logical_plan.describe()?
        ),
    }
}
/// Builds the utility summary frame for a collection of output expressions.
///
/// Every expression is summarized with `summarize_expr`. Each resulting record is
/// relabeled with the output name of the top-level expression it was found in,
/// so that nested statistics are reported under the column they contribute to.
///
/// The returned frame has the columns
/// `column`, `aggregate`, `distribution`, `scale`, `accuracy` and `threshold`.
///
/// # Arguments
/// * `exprs` - output expressions to summarize
/// * `threshold` - optional `(column name, threshold)` pair forwarded to `summarize_expr`
/// * `alpha` - optional statistical significance used to compute accuracy estimates
///
/// # Errors
/// Propagates failures from resolving an output name, from `summarize_expr`,
/// and from constructing the data frame.
fn agg_dataframe(
    exprs: &[Expr],
    threshold: Option<(String, u32)>,
    alpha: Option<f64>,
) -> Fallible<DataFrame> {
    // Accumulate into one flat list instead of a vector of per-expression vectors.
    let mut summaries: Vec<UtilitySummary> = Vec::new();
    for expr in exprs {
        let name = expr.clone().meta().output_name()?.to_string();
        let records = summarize_expr(expr, alpha, threshold.clone())?;
        summaries.extend(records.into_iter().map(|mut summary| {
            summary.name = name.clone();
            summary
        }));
    }

    // The rows borrow their string data from `summaries`, which outlives the frame construction.
    let rows = summaries
        .iter()
        .map(|summary| {
            Row(vec![
                AnyValue::String(summary.name.as_str()),
                AnyValue::String(summary.aggregate.as_str()),
                match &summary.distribution {
                    Some(distribution) => AnyValue::String(distribution.as_str()),
                    None => AnyValue::Null,
                },
                AnyValue::from(summary.scale),
                AnyValue::from(summary.accuracy),
                AnyValue::from(summary.threshold),
            ])
        })
        .collect::<Vec<_>>();

    let schema = Schema::from_iter(vec![
        Field::new("column".into(), DataType::String),
        Field::new("aggregate".into(), DataType::String),
        Field::new("distribution".into(), DataType::String),
        Field::new("scale".into(), DataType::Float64),
        Field::new("accuracy".into(), DataType::Float64),
        Field::new("threshold".into(), DataType::UInt32),
    ]);

    Ok(DataFrame::from_rows_and_schema(&rows, &schema)?)
}
fn summarize_expr<'a>(
expr: &Expr,
alpha: Option<f64>,
threshold: Option<(String, u32)>,
) -> Fallible<Vec<UtilitySummary>> {
let name = expr.clone().meta().output_name()?.to_string();
let expr = expr.clone().meta().undo_aliases();
let t_value = threshold
.clone()
.and_then(|(t_name, t_value)| (name == t_name).then_some(t_value));
if let Some((input, plugin)) = match_trusted_plugin::<NoisePlugin>(&expr)? {
let accuracy = if let Some(alpha) = alpha {
use {NoiseDistribution::*, Support::*};
Some(match (plugin.distribution, plugin.support) {
(Laplace, Float) => laplacian_scale_to_accuracy(plugin.scale, alpha),
(Gaussian, Float) => gaussian_scale_to_accuracy(plugin.scale, alpha),
(Laplace, Integer) => discrete_laplacian_scale_to_accuracy(plugin.scale, alpha),
(Gaussian, Integer) => discrete_gaussian_scale_to_accuracy(plugin.scale, alpha),
}?)
} else {
None
};
return Ok(vec![UtilitySummary {
name,
aggregate: expr_aggregate(&input[0])?.to_string(),
distribution: Some(format!("{:?} {:?}", plugin.support, plugin.distribution)),
scale: Some(plugin.scale),
accuracy,
threshold: t_value,
}]);
}
if let Some((inputs, _)) = match_trusted_plugin::<IndexCandidatesPlugin>(&expr)? {
return summarize_expr(&inputs[0], alpha, threshold);
}
if let Some((inputs, plugin)) = match_trusted_plugin::<NoisyMaxPlugin>(&expr)? {
return Ok(vec![UtilitySummary {
name,
aggregate: expr_aggregate(&inputs[0])?.to_string(),
distribution: Some(format!(
"{}{}",
match plugin.replacement {
false => "Exponential",
true => "Gumbel",
},
if plugin.negate { "Min" } else { "Max" }
)),
scale: Some(plugin.scale),
accuracy: None,
threshold: t_value,
}]);
}
Ok(match expr {
Expr::Len => vec![UtilitySummary {
name: name.clone(),
aggregate: "Frame Length".to_string(),
distribution: None,
scale: None,
accuracy: alpha.is_some().then_some(0.0),
threshold: t_value,
}],
Expr::Function { input, .. } => input
.iter()
.map(|e| summarize_expr(e, alpha, threshold.clone()))
.collect::<Fallible<Vec<_>>>()?
.into_iter()
.flatten()
.collect(),
Expr::BinaryExpr { left, op: _, right } => [
summarize_expr(&left, alpha, threshold.clone())?,
summarize_expr(&right, alpha, threshold)?,
]
.concat(),
e => return fallible!(FailedFunction, "unrecognized primitive: {:?}", e),
})
}
/// Returns a human-readable label for the aggregation performed by `expr`.
///
/// A discrete quantile score plugin is labeled `"<alpha>-Quantile"`, where `<alpha>`
/// is the plugin's fractional alpha rendered as a float. Otherwise the label is one of
/// "Sum", "Frame Length", "Length", "Count", "Null Count" or "N Unique".
///
/// # Errors
/// Returns `FailedFunction` when the expression is not a recognized aggregation.
fn expr_aggregate(expr: &Expr) -> Fallible<String> {
    if let Some((_, quantile)) = match_trusted_plugin::<DiscreteQuantileScorePlugin>(expr)? {
        let (numer, denom) = quantile.alpha;
        return Ok(format!("{}-Quantile", numer as f64 / denom as f64));
    }

    let label = match expr {
        Expr::Len => "Frame Length",
        Expr::Agg(AggExpr::Sum(_)) => "Sum",
        Expr::Agg(AggExpr::NUnique(_)) => "N Unique",
        // A count that includes nulls is the length of the column.
        Expr::Agg(AggExpr::Count {
            include_nulls: true,
            ..
        }) => "Length",
        Expr::Agg(AggExpr::Count {
            include_nulls: false,
            ..
        }) => "Count",
        Expr::Function {
            function: FunctionExpr::NullCount,
            ..
        } => "Null Count",
        other => return fallible!(FailedFunction, "unrecognized aggregation: {:?}", other),
    };

    Ok(label.to_string())
}