use std::sync::Arc;
use crate::core::PrivacyMap;
use crate::domains::{ArrayDomain, AtomDomain, ExprPlan, VectorDomain, WildExprDomain};
use crate::measurements::{TopKMeasure, make_noisy_max, noisy_top_k};
use crate::metrics::{IntDistance, L0InfDistance, L01InfDistance, LInfDistance};
use crate::polars::{OpenDPPlugin, apply_plugin, literal_value_of, match_plugin};
use crate::traits::{CastInternalRational, InfCast, InfMul, Number};
use crate::transformations::StableExpr;
use crate::transformations::traits::UnboundedMetric;
use crate::{
core::{Function, Measurement},
error::Fallible,
};
use dashu::float::FBig;
use polars::datatypes::{
DataType, Field, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type,
PolarsDataType, UInt32Type, UInt64Type,
};
use polars::error::polars_bail;
use polars::error::{PolarsError, PolarsResult};
use polars::lazy::dsl::Expr;
use polars::prelude::{AnonymousColumnsUdf, Column, IntoColumn};
use polars::series::IntoSeries;
#[cfg(feature = "ffi")]
use polars::series::Series;
use polars_arrow::array::PrimitiveArray;
use polars_arrow::types::NativeType;
use polars_plan::dsl::ColumnsUdf;
use polars_plan::prelude::FunctionOptions;
use serde::{Deserialize, Serialize};
use super::approximate_c_stability;
#[cfg(test)]
mod test;
/// Make a measurement from a `noisy_max` expression.
///
/// The data input of the expression is first made stable under `input_metric`. The resulting
/// non-nullable, numeric array column is then released with report-noisy-max through the
/// `NoisyMaxPlugin`, which emits a `u32` index for each array.
///
/// # Arguments
/// * `input_domain` - Domain of the expression's input.
/// * `input_metric` - Metric on the expression's input.
/// * `expr` - Expression built from the `noisy_max` shim.
/// * `global_scale` - Multiplier applied to the noise scale (defaults to 1). When the
///   expression carries no scale of its own, the scale is first derived from the stability
///   constants returned by `approximate_c_stability`, and then multiplied by this value.
///
/// # Errors
/// Fails when `expr` is not a `noisy_max` expression, when neither the expression nor
/// `global_scale` supplies a scale, when either scale is NaN or negative, when the input
/// column is nullable, or when the array elements are not a supported numeric type.
pub fn make_expr_noisy_max<MI: 'static + UnboundedMetric, MO: 'static + TopKMeasure>(
    input_domain: WildExprDomain,
    input_metric: L01InfDistance<MI>,
    expr: Expr,
    global_scale: Option<f64>,
) -> Fallible<Measurement<WildExprDomain, L01InfDistance<MI>, MO, ExprPlan>>
where
    Expr: StableExpr<L01InfDistance<MI>, L0InfDistance<LInfDistance<f64>>>,
{
    // `match_noisy_max` recognizes the user-facing shim, so the error must name the shim
    // ("noisy_max") rather than the internal execution plugin ("noisy_max_plugin").
    let (input, negate, scale) = match_noisy_max(&expr)?
        .ok_or_else(|| err!(MakeMeasurement, "Expected {}", NoisyMaxShim::NAME))?;

    // Stabilize the data input; its output space is the input space of the mechanism.
    let t_prior = input.clone().make_stable(input_domain, input_metric)?;
    let (middle_domain, middle_metric) = t_prior.output_space();

    if scale.is_none() && global_scale.is_none() {
        return fallible!(
            MakeMeasurement,
            "{} requires a scale parameter",
            NoisyMaxPlugin::NAME
        );
    }

    // Without an explicit scale, fall back to the product of the stability constants of the
    // prior transformation. `global_scale` is applied on top of either choice below.
    let scale = match scale {
        Some(scale) => scale,
        None => {
            let (l_0, l_inf) = approximate_c_stability(&t_prior)?;
            f64::inf_cast(l_0)?.inf_mul(&l_inf)?
        }
    };
    let global_scale = global_scale.unwrap_or(1.);

    if scale.is_nan() || scale.is_sign_negative() {
        return fallible!(
            MakeMeasurement,
            "{} scale must be a non-negative number",
            NoisyMaxPlugin::NAME
        );
    }
    if global_scale.is_nan() || global_scale.is_sign_negative() {
        return fallible!(
            MakeMeasurement,
            "global_scale ({}) must be a non-negative number",
            global_scale
        );
    }
    let scale = scale.inf_mul(&global_scale)?;

    if middle_domain.column.nullable {
        return fallible!(
            MakeMeasurement,
            "{} requires non-nullable input",
            NoisyMaxPlugin::NAME
        );
    }

    // Dispatch on the array element dtype to build the matching privacy map.
    let array_domain = middle_domain.column.element_domain::<ArrayDomain>()?;
    use DataType::*;
    let privacy_map = match array_domain.element_domain.dtype() {
        UInt32 => rnm_privacy_map::<u32, _>(array_domain, scale, negate)?,
        UInt64 => rnm_privacy_map::<u64, _>(array_domain, scale, negate)?,
        Int8 => rnm_privacy_map::<i8, _>(array_domain, scale, negate)?,
        Int16 => rnm_privacy_map::<i16, _>(array_domain, scale, negate)?,
        Int32 => rnm_privacy_map::<i32, _>(array_domain, scale, negate)?,
        Int64 => rnm_privacy_map::<i64, _>(array_domain, scale, negate)?,
        Float32 => rnm_privacy_map::<f32, _>(array_domain, scale, negate)?,
        Float64 => rnm_privacy_map::<f64, _>(array_domain, scale, negate)?,
        _ => {
            return fallible!(
                MakeMeasurement,
                "{} requires numeric array input",
                NoisyMaxPlugin::NAME
            );
        }
    };

    // Replace the shim with the executable plugin, carrying the validated parameters.
    let m_rnm = Measurement::<_, L0InfDistance<LInfDistance<f64>>, _, _>::new(
        middle_domain,
        middle_metric.clone(),
        MO::default(),
        Function::then_expr(move |input_expr| {
            apply_plugin(
                vec![input_expr],
                expr.clone(),
                NoisyMaxPlugin {
                    replacement: MO::REPLACEMENT,
                    negate,
                    scale,
                },
            )
        }),
        privacy_map,
    )?;

    t_prior >> m_rnm
}
/// Build the privacy map for report-noisy-max over arrays whose elements have type `T`.
///
/// Constructs the scalar `make_noisy_max` measurement on vectors with the array's element
/// domain. It then returns a map that evaluates that measurement at the `l_inf` component of
/// the distance (cast to `T`) and multiplies the result by the `l_0` component.
///
/// # Errors
/// Fails if the array's element domain is not an `AtomDomain<T>`, or if `make_noisy_max`
/// rejects the arguments.
fn rnm_privacy_map<T, MO: TopKMeasure>(
    array_domain: &ArrayDomain,
    scale: f64,
    negate: bool,
) -> Fallible<PrivacyMap<L0InfDistance<LInfDistance<f64>>, MO>>
where
    T: Number + InfCast<f64> + CastInternalRational,
    FBig: TryFrom<T> + TryFrom<f64>,
    f64: InfCast<T> + InfCast<IntDistance>,
{
    let element_domain = array_domain.element_domain.as_any();
    let Some(atom_domain) = element_domain.downcast_ref::<AtomDomain<T>>() else {
        return Err(err!(MakeMeasurement, "failed to downcast domain"));
    };

    let scalar_rnm = make_noisy_max(
        VectorDomain::new(atom_domain.clone()),
        LInfDistance::default(),
        MO::default(),
        scale,
        negate,
    )?;

    let map = PrivacyMap::new_fallible(move |&(l0, li): &(IntDistance, f64)| {
        let per_group_loss = scalar_rnm.map(&T::inf_cast(li)?)?;
        f64::inf_cast(l0)?.inf_mul(&per_group_loss)
    });
    Ok(map)
}
/// Destructure a `noisy_max` shim expression into its parts.
///
/// Returns `Ok(None)` when `expr` is not built from the `noisy_max` shim. Otherwise returns
/// the data input expression, the literal `negate` flag, and the literal scale. The scale is
/// `None` when `literal_value_of` yields no value (presumably a null literal; confirm against
/// callers).
///
/// # Errors
/// Fails when the shim does not have exactly three inputs, when `negate` is not a boolean
/// literal, or when `literal_value_of` itself fails.
pub(crate) fn match_noisy_max(expr: &Expr) -> Fallible<Option<(&Expr, bool, Option<f64>)>> {
    let Some(input) = match_plugin::<NoisyMaxShim>(expr)? else {
        return Ok(None);
    };

    let Ok([data, negate, scale]) = <&[_; 3]>::try_from(input.as_slice()) else {
        // Use `Display` formatting like every other message in this module; `{:?}` would
        // wrap the name in quotes.
        return fallible!(
            MakeMeasurement,
            "{} expects three inputs",
            NoisyMaxShim::NAME
        );
    };

    let negate = literal_value_of::<bool>(negate)?.ok_or_else(|| {
        err!(
            MakeMeasurement,
            "Negate must be true or false, found \"{}\".",
            negate
        )
    })?;
    let scale = literal_value_of::<f64>(scale)?;

    Ok(Some((data, negate, scale)))
}
/// Stateless placeholder UDF for the user-facing `noisy_max` expression.
///
/// It cannot execute on its own. `make_expr_noisy_max` swaps it for a `NoisyMaxPlugin`
/// carrying the validated parameters.
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct NoisyMaxShim;
impl ColumnsUdf for NoisyMaxShim {
    /// Returns `self` as a type-erased reference so callers can downcast it.
    fn as_any(&self) -> &dyn std::any::Any {
        let erased: &dyn std::any::Any = self;
        erased
    }

    /// Always fails: the shim is a placeholder and is never executed directly.
    fn call_udf(&self, _columns: &mut [Column]) -> PolarsResult<Column> {
        polars_bail!(InvalidOperation: "OpenDP expressions must be passed through make_private_lazyframe to be executed.")
    }
}
impl AnonymousColumnsUdf for NoisyMaxShim {
    /// Upcasts the shared shim to the column-UDF trait object.
    fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {
        let udf: Arc<dyn ColumnsUdf> = self;
        udf
    }

    /// Produces an independently owned copy. The value is cloned only if the `Arc` is shared.
    fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {
        let owned: Self = Arc::unwrap_or_clone(self);
        Arc::new(owned)
    }

    /// Output field: delegates to `noisy_max_plugin_type_udf`, the same typing the execution
    /// plugin uses.
    fn get_field(
        &self,
        _schema: &polars::prelude::Schema,
        fields: &[polars::prelude::Field],
    ) -> PolarsResult<polars::prelude::Field> {
        noisy_max_plugin_type_udf(fields)
    }
}
impl OpenDPPlugin for NoisyMaxShim {
    /// The shim is registered as element-wise, like the execution plugin.
    fn function_options() -> FunctionOptions {
        FunctionOptions::elementwise()
    }

    /// Plugin name; matches the FFI stub `noisy_max`.
    const NAME: &'static str = "noisy_max";

    /// Marks this plugin as a shim when FFI support is compiled in.
    #[cfg(feature = "ffi")]
    const SHIM: bool = true;
}
/// Execution-time parameters for report-noisy-max, constructed by `make_expr_noisy_max`.
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct NoisyMaxPlugin {
    /// Forwarded to `noisy_top_k` as its `replacement` flag; set from `MO::REPLACEMENT`.
    pub replacement: bool,
    /// Noise scale forwarded to `noisy_top_k`; rejected at execution time if negative.
    pub scale: f64,
    /// Forwarded to `noisy_top_k`; presumably selects the minimum instead of the maximum.
    pub negate: bool,
}
impl AnonymousColumnsUdf for NoisyMaxPlugin {
    /// Upcasts the shared plugin to the column-UDF trait object.
    fn as_column_udf(self: Arc<Self>) -> Arc<dyn ColumnsUdf> {
        let udf: Arc<dyn ColumnsUdf> = self;
        udf
    }

    /// Produces an independently owned copy. The value is cloned only if the `Arc` is shared.
    fn deep_clone(self: Arc<Self>) -> Arc<dyn AnonymousColumnsUdf> {
        let owned: Self = Arc::unwrap_or_clone(self);
        Arc::new(owned)
    }

    /// Output field: a `UInt32` index column, as computed by `noisy_max_plugin_type_udf`.
    fn get_field(
        &self,
        _schema: &polars::prelude::Schema,
        fields: &[polars::prelude::Field],
    ) -> PolarsResult<polars::prelude::Field> {
        noisy_max_plugin_type_udf(fields)
    }
}
impl OpenDPPlugin for NoisyMaxPlugin {
    /// Report-noisy-max runs element-wise, producing one index per array.
    fn function_options() -> FunctionOptions {
        FunctionOptions::elementwise()
    }

    /// Plugin name; matches the FFI entry point `noisy_max_plugin`.
    const NAME: &'static str = "noisy_max_plugin";
}
impl ColumnsUdf for NoisyMaxPlugin {
    /// Returns `self` as a type-erased reference so callers can downcast it.
    fn as_any(&self) -> &dyn std::any::Any {
        let erased: &dyn std::any::Any = self;
        erased
    }

    /// Runs report-noisy-max over the input columns using this plugin's parameters.
    fn call_udf(&self, columns: &mut [Column]) -> PolarsResult<Column> {
        let kwargs = Self::clone(self);
        noisy_max_udf(columns, kwargs)
    }
}
/// Execute report-noisy-max on a single array-valued column.
///
/// For every non-null array in the input, emits the `u32` index that `noisy_top_k`
/// selects with `k = 1`.
///
/// # Errors
/// Fails when there is not exactly one input column, when `scale` is NaN or negative, when
/// the input is not an array of a supported numeric type, when `noisy_top_k` fails, or when
/// the selected index cannot be represented.
fn noisy_max_udf(inputs: &[Column], kwargs: NoisyMaxPlugin) -> PolarsResult<Column> {
    let Ok([series]) = <&[_; 1]>::try_from(inputs) else {
        polars_bail!(InvalidOperation: "{} expects a single input field", NoisyMaxPlugin::NAME);
    };
    let NoisyMaxPlugin {
        replacement,
        negate,
        scale,
    } = kwargs;
    // The kwargs may arrive deserialized over FFI, so re-validate here. NaN must be rejected
    // just like a negative scale, matching the check in `make_expr_noisy_max`.
    if scale.is_nan() || scale.is_sign_negative() {
        polars_bail!(InvalidOperation: "{} scale ({}) must be non-negative", NoisyMaxPlugin::NAME, scale);
    }

    // Apply report-noisy-max to each array of a column whose elements have physical type `PT`.
    fn rnm_impl<PT: 'static + PolarsDataType>(
        column: &Column,
        scale: f64,
        negate: bool,
        replacement: bool,
    ) -> PolarsResult<Column>
    where
        PT::Physical<'static>: NativeType + Number + CastInternalRational,
        FBig: TryFrom<PT::Physical<'static>> + TryFrom<f64>,
        f64: InfCast<PT::Physical<'static>>,
    {
        Ok(column
            .as_materialized_series()
            .array()?
            .try_apply_nonnull_values_generic::<UInt32Type, _, _, _>(move |v| {
                let arr = v
                    .as_any()
                    .downcast_ref::<PrimitiveArray<PT::Physical<'static>>>()
                    .ok_or_else(|| {
                        PolarsError::InvalidOperation("input dtype does not match".into())
                    })?;
                let scores = arr.values_iter().copied().collect::<Vec<_>>();
                let selected = noisy_top_k(&scores, scale, 1, negate, replacement)?;
                // Return an error instead of panicking inside the UDF if nothing was selected.
                let index = selected.first().copied().ok_or_else(|| {
                    PolarsError::InvalidOperation("noisy max did not select a candidate".into())
                })?;
                // Checked narrowing instead of a silently truncating `as` cast.
                let index = u32::try_from(index).map_err(|_| {
                    PolarsError::InvalidOperation("selected index does not fit in a u32".into())
                })?;
                PolarsResult::Ok(index)
            })?
            .into_series()
            .into_column())
    }

    use DataType::*;
    let Array(dtype, _) = series.dtype() else {
        polars_bail!(InvalidOperation: "Expected array data type, found {:?}", series.dtype())
    };
    match dtype.as_ref() {
        UInt32 => rnm_impl::<UInt32Type>(series, scale, negate, replacement),
        UInt64 => rnm_impl::<UInt64Type>(series, scale, negate, replacement),
        Int8 => rnm_impl::<Int8Type>(series, scale, negate, replacement),
        Int16 => rnm_impl::<Int16Type>(series, scale, negate, replacement),
        Int32 => rnm_impl::<Int32Type>(series, scale, negate, replacement),
        Int64 => rnm_impl::<Int64Type>(series, scale, negate, replacement),
        Float32 => rnm_impl::<Float32Type>(series, scale, negate, replacement),
        Float64 => rnm_impl::<Float64Type>(series, scale, negate, replacement),
        UInt8 | UInt16 => {
            polars_bail!(InvalidOperation: "u8 and u16 not supported in the OpenDP Polars plugin. Please use u32 or u64.")
        }
        dtype => polars_bail!(InvalidOperation: "Expected numeric data type found {}", dtype),
    }
}
/// FFI stub for the `noisy_max` shim.
///
/// Always returns an error whose message directs users to `make_private_lazyframe`.
#[cfg(feature = "ffi")]
#[pyo3_polars::derive::polars_expr(output_type=Null)]
fn noisy_max(_inputs: &[Series]) -> PolarsResult<Series> {
    polars_bail!(InvalidOperation: "OpenDP expressions must be passed through make_private_lazyframe to be executed.")
}
/// Compute the output field for report-noisy-max: a `UInt32` column that keeps the input
/// field's name.
///
/// # Errors
/// Fails unless there is exactly one input field, typed as a non-empty array whose elements
/// are one of u32, u64, i8, i16, i32, i64, f32 or f64. u8 and u16 get a dedicated message.
pub(crate) fn noisy_max_plugin_type_udf(input_fields: &[Field]) -> PolarsResult<Field> {
    let Ok([field]) = <&[Field; 1]>::try_from(input_fields) else {
        polars_bail!(InvalidOperation: "{} expects a single input field", NoisyMaxPlugin::NAME)
    };
    use DataType::*;
    let Array(dtype, n) = field.dtype() else {
        polars_bail!(InvalidOperation: "Expected array data type, found {:?}", field.dtype())
    };
    if *n == 0 {
        polars_bail!(InvalidOperation: "Array must have a non-zero length");
    }
    // Validate the element dtype in one place. The error reports the offending element dtype
    // rather than the surrounding `Array(..)` dtype.
    match dtype.as_ref() {
        UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => {}
        UInt8 | UInt16 => polars_bail!(
            InvalidOperation: "u8 and u16 not supported in the OpenDP Polars plugin. Please use u32 or u64."
        ),
        other => polars_bail!(
            InvalidOperation: "Expected numeric data type, found {:?}",
            other
        ),
    }
    Ok(Field::new(field.name().clone(), UInt32))
}
/// FFI entry point for `NoisyMaxPlugin`.
///
/// Converts the incoming series to columns, runs `noisy_max_udf` with the deserialized
/// kwargs, and returns the materialized result as a series.
#[cfg(feature = "ffi")]
#[pyo3_polars::derive::polars_expr(output_type_func=noisy_max_plugin_type_udf)]
fn noisy_max_plugin(inputs: &[Series], kwargs: NoisyMaxPlugin) -> PolarsResult<Series> {
    let columns: Vec<Column> = inputs
        .iter()
        .map(|series| series.clone().into_column())
        .collect();
    noisy_max_udf(&columns, kwargs).map(|column| column.take_materialized_series())
}