use std::sync::Arc;
use crate::{
core::Function,
domains::ExprPlan,
error::Fallible,
interactive::{Answer, Query, Queryable},
measurements::{
expr_dp_counting_query::{DPCountShim, DPLenShim, DPNUniqueShim, DPNullCountShim},
expr_dp_frame_len::DPFrameLenShim,
expr_dp_mean::DPMeanShim,
expr_dp_median::DPMedianShim,
expr_dp_quantile::DPQuantileShim,
expr_dp_sum::DPSumShim,
expr_noise::NoiseShim,
expr_noisy_max::NoisyMaxShim,
},
};
use polars::prelude::AnonymousColumnsUdf;
use polars::{
frame::DataFrame,
lazy::frame::LazyFrame,
prelude::{AnyValue, DslPlan, LazySerde, NULL, len, lit, repeat},
series::Series,
};
#[cfg(feature = "ffi")]
use polars_plan::dsl::FunctionExpr;
use polars_plan::{
dsl::{Expr, SpecialEq},
plans::{LiteralValue, Null},
prelude::FunctionOptions,
};
#[cfg(feature = "ffi")]
use serde::{Deserialize, Serialize};
#[cfg(test)]
mod test;
/// Trait implemented by every OpenDP expression plugin (build without the `ffi` feature).
///
/// Plugins are anonymous polars column UDFs, recognized by downcasting to the implementor.
#[cfg(not(feature = "ffi"))]
pub(crate) trait OpenDPPlugin: 'static + Clone + AnonymousColumnsUdf {
    /// Function options attached to every polars expression built from this plugin.
    fn function_options() -> FunctionOptions;

    /// Identifier used as the display name of expressions built from this plugin.
    const NAME: &'static str;
}
/// Trait implemented by every OpenDP expression plugin (build with the `ffi` feature).
///
/// With `ffi`, plugins must also be serde-(de)serializable so their fields can travel as
/// pickled FFI keyword arguments.
#[cfg(feature = "ffi")]
pub(crate) trait OpenDPPlugin:
    'static + Clone + AnonymousColumnsUdf + Serialize + for<'de> Deserialize<'de>
{
    /// Function options attached to every polars expression built from this plugin.
    fn function_options() -> FunctionOptions;

    /// Identifier matched against FFI plugin symbols and used as the display name of
    /// anonymous-function expressions.
    const NAME: &'static str;

    /// When `true`, FFI expressions for this plugin carry default (empty) keyword arguments
    /// instead of the pickled plugin. Defaults to `false`.
    const SHIM: bool = false;
}
/// Substring that an FFI plugin's library path must contain for the call to be treated as
/// an OpenDP plugin.
#[cfg(feature = "ffi")]
const OPENDP_LIB_NAME: &str = "opendp";
/// Returns the input expressions of `expr` when it invokes the OpenDP plugin `KW`.
///
/// Two representations are recognized:
/// * with the `ffi` feature, an FFI plugin call whose library path contains
///   `OPENDP_LIB_NAME` and whose symbol equals `KW::NAME`;
/// * an anonymous function whose materialized UDF downcasts to `KW`.
///
/// Any other expression yields `Ok(None)`.
///
/// # Errors
/// * An FFI plugin call that carries *any* pickled keyword arguments is rejected, because
///   unpickling untrusted bytes may enable remote code execution. Shim plugins are always
///   emitted with default (empty) keyword arguments, so legitimate calls have none.
/// * Materializing an anonymous function may fail.
pub(crate) fn match_plugin<'e, KW>(expr: &'e Expr) -> Fallible<Option<&'e Vec<Expr>>>
where
    KW: OpenDPPlugin,
{
    Ok(Some(match expr {
        #[cfg(feature = "ffi")]
        Expr::Function {
            input,
            function:
                FunctionExpr::FfiPlugin {
                    lib,
                    symbol,
                    kwargs,
                    ..
                },
            ..
        } => {
            // Only calls into the OpenDP library with this plugin's symbol match.
            if !lib.contains(OPENDP_LIB_NAME) || symbol.as_str() != KW::NAME {
                return Ok(None);
            }
            // Reject every non-empty payload. A length threshold would still let short
            // pickles through.
            if !kwargs.is_empty() {
                return fallible!(
                    FailedFunction,
                    "OpenDP does not allow pickled keyword arguments as they may enable remote code execution."
                );
            }
            input
        }
        Expr::AnonymousFunction {
            input, function, ..
        } => {
            let udf = function.clone().materialize()?;
            if udf.as_any().downcast_ref::<KW>().is_none() {
                return Ok(None);
            }
            input
        }
        _ => return Ok(None),
    }))
}
/// Like `match_plugin`, but also recovers the plugin's keyword arguments as a `KW`.
///
/// * With `ffi`, an FFI plugin call whose library path contains `OPENDP_LIB_NAME` and
///   whose symbol equals `KW::NAME` has its keyword arguments unpickled into `KW`. Only use
///   this on trusted expressions, because unpickling untrusted bytes may enable remote code
///   execution.
/// * An anonymous function whose materialized UDF downcasts to `KW` yields a clone of it.
///
/// Any other expression yields `Ok(None)`.
///
/// # Errors
/// Fails if unpickling the keyword arguments fails or if materializing the UDF fails.
pub(crate) fn match_trusted_plugin<'e, KW>(expr: &'e Expr) -> Fallible<Option<(&'e Vec<Expr>, KW)>>
where
    KW: OpenDPPlugin,
{
    match expr {
        #[cfg(feature = "ffi")]
        Expr::Function {
            input,
            function:
                FunctionExpr::FfiPlugin {
                    lib,
                    symbol,
                    kwargs,
                    ..
                },
            ..
        } => {
            let targets_this_plugin =
                lib.contains(OPENDP_LIB_NAME) && symbol.as_str() == KW::NAME;
            if !targets_this_plugin {
                return Ok(None);
            }
            let plugin: KW = serde_pickle::from_slice(kwargs.as_ref(), Default::default())
                .map_err(|e| err!(FailedFunction, "{}", e))?;
            Ok(Some((input, plugin)))
        }
        Expr::AnonymousFunction {
            input, function, ..
        } => {
            let udf = function.clone().materialize()?;
            let matched = udf
                .as_any()
                .downcast_ref::<KW>()
                .map(|plugin| (input, plugin.clone()));
            Ok(matched)
        }
        _ => Ok(None),
    }
}
/// Matches `expr` against shim plugin `P` and returns exactly `V` input expressions.
///
/// Missing trailing inputs are padded with null literals. Returns `Ok(None)` when `expr`
/// is not an invocation of `P`.
///
/// # Errors
/// Fails if the plugin was given more than `V` inputs, or if `match_plugin` fails.
pub(crate) fn match_shim<P: OpenDPPlugin, const V: usize>(
    expr: &Expr,
) -> Fallible<Option<[Expr; V]>> {
    let input = match match_plugin::<P>(expr)? {
        Some(input) => input,
        None => return Ok(None),
    };
    let given = input.len();
    if given > V {
        return fallible!(
            MakeMeasurement,
            "{} expects no more than {V} arguments",
            P::NAME
        );
    }
    // `given <= V`, so `resize` only appends null literals up to length `V`.
    let mut padded = input.clone();
    padded.resize(V, lit(NULL));
    let args = <[Expr; V]>::try_from(padded).expect("input always has expected length");
    Ok(Some(args))
}
/// Rebuilds `plugin_expr` as an invocation of plugin `KW` over `input_exprs`.
///
/// The representation of `plugin_expr` is preserved:
/// * (with `ffi`) an FFI plugin call stays an FFI plugin call. Its library path comes from
///   the `OPENDP_POLARS_LIB_PATH` environment variable when set, otherwise from the
///   original call. Its keyword arguments are default (empty) for shim plugins and the
///   pickled `kwargs_new` otherwise.
/// * An anonymous function stays an anonymous function wrapping `kwargs_new`.
///
/// # Panics
/// Panics if `plugin_expr` is neither of these (callers are expected to have matched it
/// first), or if pickling `kwargs_new` fails.
pub(crate) fn apply_plugin<KW: OpenDPPlugin>(
    input_exprs: Vec<Expr>,
    plugin_expr: Expr,
    kwargs_new: KW,
) -> Expr {
    match plugin_expr {
        #[cfg(feature = "ffi")]
        Expr::Function { function, .. } => {
            // An explicit library path in the environment wins over the recorded one.
            let lib = match std::env::var("OPENDP_POLARS_LIB_PATH") {
                Ok(path) => path.into(),
                Err(_) => match function {
                    FunctionExpr::FfiPlugin { lib, .. } => lib,
                    _ => unreachable!("plugin expressions are always an FfiPlugin"),
                },
            };
            let kwargs = if KW::SHIM {
                Default::default()
            } else {
                let pickled = serde_pickle::to_vec(&kwargs_new, Default::default())
                    .expect("pickling does not fail");
                pickled.as_slice().into()
            };
            Expr::Function {
                input: input_exprs,
                function: FunctionExpr::FfiPlugin {
                    flags: KW::function_options(),
                    lib,
                    symbol: KW::NAME.into(),
                    kwargs,
                },
            }
        }
        Expr::AnonymousFunction { .. } => Expr::AnonymousFunction {
            options: KW::function_options(),
            fmt_str: Box::new(KW::NAME.into()),
            function: LazySerde::Deserialized(SpecialEq::new(Arc::new(kwargs_new))),
            input: input_exprs,
        },
        _ => unreachable!("only called after constructor checks"),
    }
}
/// Wraps the plugin `kwargs` as an anonymous polars UDF applied to `input`.
///
/// The expression is labeled with `KW::NAME` and uses `KW::function_options()`.
pub(crate) fn apply_anonymous_function<KW: OpenDPPlugin>(input: Vec<Expr>, kwargs: KW) -> Expr {
    Expr::AnonymousFunction {
        input,
        fmt_str: Box::new(KW::NAME.into()),
        // `kwargs` is owned and not used again, so it moves into the `Arc` without a clone.
        function: LazySerde::Deserialized(SpecialEq::new(Arc::new(kwargs))),
        options: KW::function_options(),
    }
}
pub(crate) fn literal_value_of<T: ExtractValue>(expr: &Expr) -> Fallible<Option<T>> {
let Expr::Literal(literal) = expr else {
return fallible!(FailedFunction, "Expected literal, found: {:?}", expr);
};
T::extract(literal.clone())
}
/// Conversion from a polars literal into a Rust value.
///
/// `Ok(None)` represents a null literal.
pub(crate) trait ExtractValue
where
    Self: Sized,
{
    /// Extracts `Self` from `literal`.
    fn extract(literal: LiteralValue) -> Fallible<Option<Self>>;
}
/// Implements `ExtractValue` for numeric types via `AnyValue::try_extract`.
///
/// A null literal maps to `Ok(None)`. A literal with no `AnyValue` representation, or one
/// that cannot be converted to the target type, is an error.
macro_rules! impl_extract_value_number {
    ($($ty:ty)+) => {
        $(
            impl ExtractValue for $ty {
                fn extract(literal: LiteralValue) -> Fallible<Option<Self>> {
                    if literal.is_null() {
                        return Ok(None);
                    }
                    let Some(any_value) = literal.to_any_value() else {
                        return Err(err!(FailedFunction));
                    };
                    let value: $ty = any_value.try_extract()?;
                    Ok(Some(value))
                }
            }
        )+
    };
}
impl_extract_value_number!(u8 u16 u32 u64 i8 i16 i32 i64 f32 f64);
impl ExtractValue for bool {
fn extract(literal: LiteralValue) -> Fallible<Option<Self>> {
let any_value = literal.to_any_value().ok_or_else(|| err!(FailedFunction))?;
if matches!(any_value, AnyValue::Null) {
return Ok(None);
}
let AnyValue::Boolean(value) = any_value else {
return fallible!(FailedFunction, "expected boolean, found {:?}", any_value);
};
Ok(Some(value))
}
}
impl ExtractValue for Series {
fn extract(literal: LiteralValue) -> Fallible<Option<Self>> {
if literal.is_null() {
return Ok(None);
}
Ok(match literal {
LiteralValue::Series(series) => Some((*series).clone()),
_ => return fallible!(FailedFunction, "expected series, found: {:?}", literal),
})
}
}
impl ExtractValue for String {
fn extract(literal: LiteralValue) -> Fallible<Option<Self>> {
if literal.is_null() {
return Ok(None);
}
literal
.extract_str()
.map(|s| Some(s.to_string()))
.ok_or_else(|| err!(FailedFunction, "expected String, found: {:?}", literal))
}
}
impl Function<ExprPlan, ExprPlan> {
    /// Builds a function that maps each input plan through `ExprPlan::then` with `function`.
    pub(crate) fn then_expr(function: impl Fn(Expr) -> Expr + 'static + Send + Sync) -> Self {
        Self::new(move |plan: &ExprPlan| plan.then(&function))
    }
}
impl Function<DslPlan, ExprPlan> {
    /// Builds a function that pairs any input plan with the fixed expression `expr`.
    ///
    /// The produced [`ExprPlan`] carries a clone of the input plan and no fill expression.
    pub(crate) fn from_expr(expr: Expr) -> Self {
        Self::new_fallible(move |plan: &DslPlan| {
            let expr_plan = ExprPlan {
                plan: plan.clone(),
                expr: expr.clone(),
                fill: None,
            };
            Ok(expr_plan)
        })
    }
}
impl<TI: 'static> Function<TI, ExprPlan> {
    /// Wraps this function so that every [`ExprPlan`] it produces carries a fill expression.
    ///
    /// The fill is `repeat(value, len())`. It is built once here and cloned into each
    /// produced plan, overwriting any fill the inner function set.
    ///
    /// # Errors
    /// Errors from evaluating the wrapped function are propagated unchanged.
    pub(crate) fn fill_with(self, value: Expr) -> Self {
        // `value` is owned and not used again, so it moves into `repeat` without a clone.
        let fill = repeat(value, len());
        Self::new_fallible(move |arg: &TI| {
            let mut plan = self.eval(arg)?;
            plan.fill = Some(fill.clone());
            Ok(plan)
        })
    }
}
/// Expression wrapper exposing the differentially private operations reached via
/// [`PrivacyNamespace::dp`].
pub struct DPExpr(Expr);

/// Extension trait that adds a `.dp()` namespace to polars expressions.
pub trait PrivacyNamespace {
    /// Enters the DP namespace for this expression.
    fn dp(self) -> DPExpr;
}

impl PrivacyNamespace for Expr {
    fn dp(self) -> DPExpr {
        DPExpr(self)
    }
}
impl DPExpr {
pub fn noise(self, scale: Option<f64>) -> Expr {
let scale = scale.map(lit).unwrap_or_else(|| lit(Null {}));
apply_anonymous_function(vec![self.0, scale], NoiseShim)
}
pub fn len(self, scale: Option<f64>) -> Expr {
let scale = scale.map(lit).unwrap_or_else(|| lit(Null {}));
apply_anonymous_function(vec![self.0, scale], DPLenShim)
}
pub fn count(self, scale: Option<f64>) -> Expr {
let scale = scale.map(lit).unwrap_or_else(|| lit(Null {}));
apply_anonymous_function(vec![self.0, scale], DPCountShim)
}
pub fn null_count(self, scale: Option<f64>) -> Expr {
let scale = scale.map(lit).unwrap_or_else(|| lit(Null {}));
apply_anonymous_function(vec![self.0, scale], DPNullCountShim)
}
pub fn n_unique(self, scale: Option<f64>) -> Expr {
let scale = scale.map(lit).unwrap_or_else(|| lit(Null {}));
apply_anonymous_function(vec![self.0, scale], DPNUniqueShim)
}
pub fn sum(self, bounds: (Expr, Expr), scale: Option<f64>) -> Expr {
let scale = scale.map(lit).unwrap_or_default();
apply_anonymous_function(vec![self.0, bounds.0, bounds.1, scale], DPSumShim)
}
pub fn mean(self, bounds: (Expr, Expr), scale: Option<f64>) -> Expr {
let scale = scale.map(lit).unwrap_or_default();
apply_anonymous_function(vec![self.0, bounds.0, bounds.1, scale], DPMeanShim)
}
pub fn noisy_max(self, negate: bool, scale: Option<f64>) -> Expr {
let negate = lit(negate);
let scale = scale.map(lit).unwrap_or_else(|| lit(Null {}));
apply_anonymous_function(vec![self.0, negate, scale], NoisyMaxShim)
}
pub fn quantile(self, alpha: f64, candidates: Series, scale: Option<f64>) -> Expr {
let scale = scale.map(lit).unwrap_or_else(|| lit(Null {}));
apply_anonymous_function(
vec![self.0, lit(alpha), lit(candidates), scale],
DPQuantileShim,
)
}
pub fn median(self, candidates: Series, scale: Option<f64>) -> Expr {
let scale = scale.map(lit).unwrap_or_else(|| lit(Null {}));
apply_anonymous_function(vec![self.0, lit(candidates), scale], DPMedianShim)
}
}
/// Applies the `DPFrameLenShim` plugin with noise `scale`.
///
/// A `scale` of `None` is passed as a null literal.
pub fn dp_len(scale: Option<f64>) -> Expr {
    let scale = match scale {
        Some(value) => lit(value),
        None => lit(Null {}),
    };
    apply_anonymous_function(vec![scale], DPFrameLenShim)
}
/// External queries accepted by a [`OnceFrame`].
pub enum OnceFrameQuery {
    /// Evaluate the wrapped lazy frame and return the resulting data frame.
    Collect,
}
/// External answers produced by a [`OnceFrame`].
pub enum OnceFrameAnswer {
    /// The data frame produced by a `Collect` query.
    Collect(DataFrame),
}
/// Internal query asking a [`OnceFrame`] for a clone of its underlying [`LazyFrame`].
pub(crate) struct ExtractLazyFrame;
/// A queryable over a lazy frame. Instances built with `From<LazyFrame>` answer at most
/// one successful `Collect`; after that, every query fails.
pub type OnceFrame = Queryable<OnceFrameQuery, OnceFrameAnswer>;
impl From<LazyFrame> for OnceFrame {
    /// Wraps `value` in a queryable that can be successfully collected only once.
    ///
    /// * `Collect` evaluates the lazy frame, resamples all of its rows (see the note on the
    ///   sampling call), and then marks the frame as exhausted.
    /// * The internal `ExtractLazyFrame` query returns a clone of the lazy frame and does
    ///   not exhaust it. Any other internal query is rejected.
    /// * Once exhausted, every query fails with "OnceFrame has been exhausted".
    fn from(value: LazyFrame) -> Self {
        // Holds the lazy frame until a `Collect` succeeds.
        let mut state = Some(value);
        Self::new_raw(move |_self: &Self, query: Query<OnceFrameQuery>| {
            // Work on a clone: `collect` consumes the frame, and the stored copy is only
            // dropped once collection has succeeded.
            let Some(lazyframe) = state.clone() else {
                return fallible!(FailedFunction, "OnceFrame has been exhausted");
            };
            Ok(match query {
                Query::External(q_external) => Answer::External(match q_external {
                    OnceFrameQuery::Collect => {
                        let dataframe = lazyframe.collect()?;
                        let n = dataframe.height();
                        // Draw all `n` rows. NOTE(review): the arguments are presumably
                        // (n, with_replacement = false, shuffle = true, seed = None), i.e. a
                        // random permutation of the rows. Confirm against the polars docs.
                        let dataframe = dataframe.sample_n_literal(n, false, true, None)?;
                        // Exhaust the frame only after collection and sampling succeeded.
                        // NOTE(review): if either step errors, the frame stays collectable.
                        state.take();
                        OnceFrameAnswer::Collect(dataframe)
                    }
                }),
                Query::Internal(q_internal) => Answer::Internal({
                    if q_internal.downcast_ref::<ExtractLazyFrame>().is_some() {
                        Box::new(lazyframe)
                    } else {
                        return fallible!(FailedFunction, "Unrecognized internal query");
                    }
                }),
            })
        })
    }
}
impl OnceFrame {
    /// Collects the wrapped lazy frame into a [`DataFrame`], consuming this `OnceFrame`.
    ///
    /// # Errors
    /// Propagates any failure from evaluating the query. Reports an internal bug if the
    /// queryable answers with anything other than a collected frame.
    pub fn collect(mut self) -> Fallible<DataFrame> {
        let answer = self.eval_query(Query::External(&OnceFrameQuery::Collect))?;
        match answer {
            Answer::External(OnceFrameAnswer::Collect(dataframe)) => Ok(dataframe),
            _ => fallible!(
                FailedFunction,
                "Collect returned invalid answer: Please report this bug"
            ),
        }
    }

    /// Returns the underlying [`LazyFrame`] via the internal `ExtractLazyFrame` query.
    ///
    /// Only available with the `honest-but-curious` feature.
    ///
    /// # Panics
    /// Panics if the internal query fails or does not yield a `LazyFrame`.
    #[cfg(feature = "honest-but-curious")]
    pub fn lazyframe(&mut self) -> LazyFrame {
        let answer = self.eval_query(Query::Internal(&ExtractLazyFrame)).unwrap();
        let boxed = match answer {
            Answer::Internal(boxed) => boxed,
            _ => panic!("failed to extract"),
        };
        match boxed.downcast::<LazyFrame>() {
            Ok(lazyframe) => *lazyframe,
            Err(_) => panic!("failed to extract"),
        }
    }
}
/// Describes which optional cargo features were compiled out, for use in error messages.
///
/// Returns an empty string when `contrib`, `floating-point` and `honest-but-curious` are
/// all enabled. Otherwise it returns a sentence, ending in a space, that names the
/// disabled ones in that order.
pub(crate) fn get_disabled_features_message() -> String {
    // (feature name, whether it was enabled at compile time)
    const FEATURES: [(&str, bool); 3] = [
        ("contrib", cfg!(feature = "contrib")),
        ("floating-point", cfg!(feature = "floating-point")),
        ("honest-but-curious", cfg!(feature = "honest-but-curious")),
    ];
    let disabled: Vec<&str> = FEATURES
        .iter()
        .filter(|&&(_, enabled)| !enabled)
        .map(|&(name, _)| name)
        .collect();
    match disabled.as_slice() {
        [] => String::new(),
        names => format!(
            "This may be due to disabled features: {}. ",
            names.join(", ")
        ),
    }
}