use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
use std::marker::PhantomData;
use opendp_derive::proven;
use polars::lazy::dsl::len;
use polars::prelude::*;
use crate::core::Domain;
use crate::metrics::{FrameDistance, LInfDistance, LpDistance, MicrodataMetric};
use crate::traits::{InfMul, ProductOrd};
use crate::transformations::traits::UnboundedMetric;
use crate::{core::MetricSpace, domains::SeriesDomain, error::Fallible};
use super::NumericDataType;
#[cfg(test)]
mod test;
#[cfg(feature = "ffi")]
pub(crate) mod ffi;
/// A dataset representation that can be turned into a lazy query or into materialized data.
///
/// This file implements it for `LazyFrame`, `DslPlan` and `DataFrame`.
pub trait Frame: Clone + Send + Sync {
/// Consumes `self` and returns it as a `LazyFrame`.
fn lazyframe(self) -> LazyFrame;
/// Consumes `self` and returns it as a `DataFrame`; the conversion is fallible.
fn dataframe(self) -> Fallible<DataFrame>;
}
/// A `LazyFrame` is already lazy; materializing it means collecting the query.
impl Frame for LazyFrame {
    /// Returns `self` unchanged.
    fn lazyframe(self) -> LazyFrame {
        self
    }

    /// Collects the query, converting a collection failure into the crate's error type.
    fn dataframe(self) -> Fallible<DataFrame> {
        let collected = self.collect()?;
        Ok(collected)
    }
}
/// A `DslPlan` is wrapped into a `LazyFrame` and, when materialized, collected from there.
impl Frame for DslPlan {
    /// Wraps the plan in a `LazyFrame`.
    fn lazyframe(self) -> LazyFrame {
        self.into()
    }

    /// Wraps the plan in a `LazyFrame` and collects it.
    fn dataframe(self) -> Fallible<DataFrame> {
        let lazy = self.lazyframe();
        let collected = lazy.collect()?;
        Ok(collected)
    }
}
/// A `DataFrame` is already materialized; the lazy view is obtained through `IntoLazy`.
impl Frame for DataFrame {
    /// Converts the frame into a `LazyFrame`.
    fn lazyframe(self) -> LazyFrame {
        self.lazy()
    }

    /// Returns `self` unchanged; this conversion never fails.
    fn dataframe(self) -> Fallible<DataFrame> {
        Ok(self)
    }
}
/// Domain of data frames, generic over the carrier type `F` that holds the data.
///
/// Membership (see the `Domain` impl) requires each column to be a member of the
/// corresponding series domain and every margin to be satisfied.
#[derive(Clone)]
pub struct FrameDomain<F: Frame> {
/// One domain per column; the constructors reject duplicate column names.
pub series_domains: Vec<SeriesDomain>,
/// Descriptors about groupings of the data.
pub margins: Vec<Margin>,
// Records the carrier type `F` without storing a value of it.
_marker: PhantomData<F>,
}
/// Two frame domains are equal when their series domains and their margins are equal.
/// The margins are compared as vectors, so their order matters.
impl<F: Frame> PartialEq for FrameDomain<F> {
    fn eq(&self, other: &Self) -> bool {
        if self.series_domains != other.series_domains {
            return false;
        }
        self.margins == other.margins
    }
}
/// A [`FrameDomain`] whose carrier type is `LazyFrame`.
pub type LazyFrameDomain = FrameDomain<LazyFrame>;
/// A [`FrameDomain`] whose carrier type is `DslPlan` (crate-internal).
pub(crate) type DslPlanDomain = FrameDomain<DslPlan>;
/// A frame domain paired with a microdata metric is a valid space when every column
/// referenced by the metric's identifier expression (if it has one) exists in the domain.
impl<F: Frame, M: MicrodataMetric> MetricSpace for (FrameDomain<F>, M) {
    fn check_space(&self) -> Fallible<()> {
        let (domain, metric) = self;
        if let Some(identifier) = metric.identifier() {
            // Each root column of the identifier must resolve to a series domain;
            // the first unknown column aborts with that lookup's error.
            for column in identifier.meta().root_names() {
                domain.series_domain(column)?;
            }
        }
        Ok(())
    }
}
impl<F: Frame, M: UnboundedMetric> MetricSpace for (FrameDomain<F>, FrameDistance<M>) {
fn check_space(&self) -> Fallible<()> {
(self.0.clone(), self.1.0.clone()).check_space()
}
}
/// A frame domain paired with `LpDistance<P, T>` is a valid space only when every
/// column has the data type reported by `T::dtype()`.
impl<F: Frame, const P: usize, T: ProductOrd + NumericDataType> MetricSpace
    for (FrameDomain<F>, LpDistance<P, T>)
{
    fn check_space(&self) -> Fallible<()> {
        let (domain, _) = self;
        let dtypes_match = domain
            .series_domains
            .iter()
            .all(|column| column.dtype() == T::dtype());

        if dtypes_match {
            Ok(())
        } else {
            fallible!(
                MetricSpace,
                "LpDistance requires columns of type {}",
                T::dtype()
            )
        }
    }
}
/// Every pairing of a `LazyFrameDomain` with `LInfDistance<Q>` is accepted:
/// no compatibility checks are performed.
impl<Q> MetricSpace for (LazyFrameDomain, LInfDistance<Q>) {
fn check_space(&self) -> Fallible<()> {
Ok(())
}
}
impl<F: Frame> FrameDomain<F> {
/// Builds a frame domain from per-column domains, with no margins.
///
/// # Errors
/// Fails under the same conditions as `new_with_margins` (duplicate column names).
pub fn new(series_domains: Vec<SeriesDomain>) -> Fallible<Self> {
    let no_margins = vec![];
    Self::new_with_margins(series_domains, no_margins)
}
/// Builds a frame domain from per-column domains and a list of margins.
///
/// The margins are stored as given; only the column names are validated here.
///
/// # Errors
/// Returns a `MakeDomain` error when two series domains share a name.
pub(crate) fn new_with_margins(
    series_domains: Vec<SeriesDomain>,
    margins: Vec<Margin>,
) -> Fallible<Self> {
    // `insert` returns false on the first repeated name.
    let mut seen_names = HashSet::with_capacity(series_domains.len());
    let names_distinct = series_domains
        .iter()
        .all(|column| seen_names.insert(&column.name));

    if !names_distinct {
        return fallible!(MakeDomain, "column names must be distinct");
    }

    Ok(FrameDomain {
        series_domains,
        margins,
        _marker: PhantomData,
    })
}
pub fn new_from_schema(schema: Schema) -> Fallible<Self> {
let series_domains = (schema.iter_fields())
.map(SeriesDomain::new_from_field)
.collect::<Fallible<_>>()?;
FrameDomain::new(series_domains)
}
/// Returns a schema with one field per series domain, carrying that domain's
/// name and data type.
pub fn schema(&self) -> Schema {
    let fields = self
        .series_domains
        .iter()
        .map(|column| Field::new(column.name.clone(), column.dtype()));
    Schema::from_iter(fields)
}
/// Determines the schema that `plan` produces for data in this domain.
///
/// The plan is applied to an empty frame that has this domain's schema, the result
/// is collected, and the schema of the collected frame is returned.
///
/// # Errors
/// Returns a `MakeTransformation` error when collecting the plan's output fails.
pub(crate) fn simulate_schema(
    &self,
    plan: impl Fn(LazyFrame) -> LazyFrame,
) -> Fallible<Schema> {
    let empty_input = DataFrame::empty_with_schema(&self.schema()).lazy();

    let output = match plan(empty_input).collect() {
        Ok(frame) => frame,
        Err(e) => {
            return Err(err!(
                MakeTransformation,
                "Failed to determine output dtypes: {}",
                e
            ));
        }
    };

    Ok((&**output.schema()).clone())
}
/// Returns a copy of this domain (same series domains and margins) whose carrier
/// type is `FO` instead of `F`.
pub(crate) fn cast_carrier<FO: Frame>(&self) -> FrameDomain<FO> {
    let series_domains = self.series_domains.clone();
    let margins = self.margins.clone();
    FrameDomain {
        series_domains,
        margins,
        _marker: PhantomData,
    }
}
/// Returns the domain with `margin` appended to its margins.
///
/// # Errors
/// * Propagates the `series_domain` lookup error when a grouping expression of the
///   margin refers to a column that is not in this domain.
/// * Returns a `MakeDomain` error when a margin with the same `by` set is already present.
#[must_use]
pub fn with_margin(mut self, margin: Margin) -> Fallible<Self> {
    // Every root column named by the grouping expressions must exist in the domain.
    for grouping_expr in margin.by.iter() {
        for column in grouping_expr.clone().meta().root_names() {
            self.series_domain(column)?;
        }
    }

    let already_present = self.margins.iter().any(|known| known.by == margin.by);
    if already_present {
        return fallible!(
            MakeDomain,
            "margin ({:?}) is already present in domain",
            margin.by
        );
    }

    self.margins.push(margin);
    Ok(self)
}
#[proven]
/// Returns a margin for the grouping `by`, with descriptors derived from the
/// margins stored in this domain.
///
/// The result starts as a clone of the stored margin whose `by` equals the requested
/// set, or as a fresh `Margin::by(..)` when there is none. Its `max_length`,
/// `max_groups` and `invariant` fields are then overwritten:
///
/// * `max_length`: the smallest `max_length` among stored margins whose `by` is a
///   subset of the requested `by` (`None` if no such margin sets one).
/// * `max_groups`: the product of the weights of a covering chosen by
///   `find_min_covering` from the stored margins that set `max_groups`
///   (`None` if no covering exists or a multiplication fails).
/// * `invariant`: the greatest `invariant` among stored margins whose `by` is a
///   superset of the requested `by` (`None` if there is none).
pub fn get_margin(&self, by: &HashSet<Expr>) -> Margin {
// Start from the exactly-matching stored margin, or an empty one for `by`.
let mut margin = (self.margins.iter())
.find(|m| &m.by == by)
.cloned()
.unwrap_or_else(|| Margin::by(by.iter().cloned().collect::<Vec<_>>()));
// Stored margins that group by a subset of the requested columns
// (this includes an exact match).
let coarser_margins = (self.margins.iter())
.filter(|m| m.by.is_subset(by))
.collect::<Vec<&Margin>>();
// Tightest length bound among those margins.
// NOTE(review): rationale is presumably that grouping by additional columns can
// only split groups further, never enlarge them — confirm against the proof.
margin.max_length = coarser_margins.iter().filter_map(|m| m.max_length).min();
// Candidate (grouping set, max_groups) pairs: every stored margin that sets max_groups.
let all_max_groups = (self.margins.iter())
.filter_map(|m| Some((&m.by, m.max_groups?)))
.collect();
// Multiply the bounds of the selected covering with `inf_mul`; an error from any
// multiplication turns the whole bound into `None`.
margin.max_groups = find_min_covering(by.clone(), all_max_groups).and_then(|cover| {
cover
.iter()
.try_fold(1u32, |acc, (_, v)| acc.inf_mul(v).ok())
});
// Take the greatest invariant among margins that group by a superset of `by`.
// `max` is over `Option<Invariant>`, where `None` orders below any `Some`;
// `flatten` then collapses "no such margin" and "no invariant" into `None`.
margin.invariant = (self.margins.iter())
.filter(|m| by.is_subset(&m.by))
.map(|m| m.invariant)
.max()
.flatten();
margin
}
/// Returns a clone of the series domain for the column called `name`.
///
/// # Errors
/// Returns a `MakeTransformation` error when no series domain has that name.
pub fn series_domain(&self, name: PlSmallStr) -> Fallible<SeriesDomain> {
    for candidate in self.series_domains.iter() {
        if candidate.name == name {
            return Ok(candidate.clone());
        }
    }
    Err(err!(
        MakeTransformation,
        "unrecognized column in series domain: {}",
        name
    ))
}
}
/// Formats the domain as `FrameDomain(<name>: <dtype>, ...; margins=[<by>, ...])`.
impl<F: Frame> Debug for FrameDomain<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let columns: Vec<String> = self
            .series_domains
            .iter()
            .map(|column| format!("{}: {}", column.name, column.dtype()))
            .collect();

        // Only the grouping sets of the margins are shown, not their descriptors.
        let margins: Vec<String> = self
            .margins
            .iter()
            .map(|margin| format!("{:?}", margin.by))
            .collect();

        write!(
            f,
            "FrameDomain({}; margins=[{}])",
            columns.join(", "),
            margins.join(", ")
        )
    }
}
/// Membership check for frames.
impl<F: Frame> Domain for FrameDomain<F> {
    type Carrier = F;

    /// Returns whether `val` belongs to this domain.
    ///
    /// `val` is a member when it has as many columns as there are series domains,
    /// each column (taken in order) is a member of the series domain at the same
    /// position, and every margin accepts the data grouped by that margin's `by`.
    ///
    /// # Errors
    /// Propagates failures from materializing `val`, from the series-domain checks,
    /// and from the margin checks.
    fn member(&self, val: &Self::Carrier) -> Fallible<bool> {
        let frame = val.clone().dataframe()?;

        if frame.schema().len() != self.series_domains.len() {
            return Ok(false);
        }

        let columns = frame
            .get_columns()
            .iter()
            .map(|column| column.as_materialized_series());
        for (domain, series) in self.series_domains.iter().zip(columns) {
            if !domain.member(series)? {
                return Ok(false);
            }
        }

        for margin in &self.margins {
            let keys: Vec<_> = margin.by.iter().cloned().collect();
            let grouped = val.clone().lazyframe().group_by(keys);
            if !margin.member(grouped)? {
                return Ok(false);
            }
        }

        Ok(true)
    }
}
/// Descriptors about the data when grouped by a set of expressions.
#[derive(Clone, PartialEq, Debug)]
pub struct Margin {
/// The grouping expressions this margin describes.
pub by: HashSet<Expr>,
/// Upper bound on the number of rows in any one group, if set.
pub max_length: Option<u32>,
/// Upper bound on the number of groups, if set.
pub max_groups: Option<u32>,
/// Which invariant, if any, is attached to this grouping.
pub invariant: Option<Invariant>,
}
/// Kind of invariant that a [`Margin`] can carry.
///
/// Variants are ordered by declaration (`Keys < Lengths`). Both `PartialOrd` and `Ord`
/// are derived so that the two orderings cannot disagree; this ordering is what
/// `get_margin` relies on when it takes the maximum invariant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Invariant {
    /// Set by `Margin::with_invariant_keys`.
    Keys,
    /// Set by `Margin::with_invariant_lengths`; orders above `Keys`.
    Lengths,
}
impl Margin {
/// Returns a margin with an empty set of grouping expressions and no descriptors.
pub fn select() -> Margin {
    let no_keys: &[Expr] = &[];
    Self::by::<&[Expr], Expr>(no_keys)
}
pub fn by<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(by: E) -> Self {
Self {
by: by.as_ref().iter().cloned().map(Into::into).collect(),
max_length: None,
max_groups: None,
invariant: None,
}
}
pub fn with_max_length(mut self, value: u32) -> Self {
self.max_length = Some(value);
self
}
pub fn with_max_groups(mut self, value: u32) -> Self {
self.max_groups = Some(value);
self
}
pub fn with_invariant_keys(mut self) -> Self {
self.invariant = Some(Invariant::Keys);
self
}
pub fn with_invariant_lengths(mut self) -> Self {
self.invariant = Some(Invariant::Lengths);
self
}
pub(crate) fn member(&self, value: LazyGroupBy) -> Fallible<bool> {
macro_rules! item {
($tgt:expr, $ty:ident) => {
($tgt.collect()?.get_columns()[0].$ty()?.get(0))
.ok_or_else(|| err!(FailedFunction))?
};
}
let max_part_length = value.clone().agg([len()]).select(&[max("len")]);
if item!(max_part_length, u32) > self.max_length.unwrap_or(u32::MAX) {
return Ok(false);
}
let max_num_parts = value.agg([]).select(&[len()]);
if item!(max_num_parts, u32) > self.max_groups.unwrap_or(u32::MAX) {
return Ok(false);
}
Ok(true)
}
/// Returns the smaller of `max_groups` and `l_1`, or `l_1` when `max_groups` is not set.
pub(crate) fn l_0(&self, l_1: u32) -> u32 {
    match self.max_groups {
        Some(bound) => bound.min(l_1),
        None => l_1,
    }
}
/// Returns the smaller of `max_length` and `l_1`, or `l_1` when `max_length` is not set.
pub(crate) fn l_inf(&self, l_1: u32) -> u32 {
    match self.max_length {
        Some(bound) => bound.min(l_1),
        None => l_1,
    }
}
}
#[proven]
/// Greedily selects entries of `sets` whose union contains every element of `must_cover`.
///
/// In each round the entry that covers the most still-uncovered elements is chosen;
/// ties are broken in favor of the set with fewer elements, then the smaller weight.
/// The elements of the chosen set are then removed from `must_cover`, and the loop
/// repeats until nothing is left to cover. The selection is greedy, so the returned
/// covering is not guaranteed to be the smallest possible one.
///
/// # Returns
/// * `Some(cover)` with the chosen `(set, weight)` pairs, in selection order
///   (empty when `must_cover` is empty).
/// * `None` when some element of `must_cover` is contained in none of the sets.
pub(crate) fn find_min_covering<T: Hash + Eq>(
    mut must_cover: HashSet<T>,
    sets: Vec<(&HashSet<T>, u32)>,
) -> Option<Vec<(&HashSet<T>, u32)>> {
    use std::cmp::Reverse;

    let mut covered = Vec::new();
    while !must_cover.is_empty() {
        // `Reverse` flips the ordering without a lossy cast: negating `x as i32`
        // wraps for values above `i32::MAX` (making huge weights look preferable)
        // and overflows for a value of exactly 2^31.
        let &(best_set, weight) = sets.iter().max_by_key(|(set, weight)| {
            (
                set.intersection(&must_cover).count(),
                Reverse(set.len()),
                Reverse(*weight),
            )
        })?;

        // If even the best candidate covers nothing, the remainder cannot be covered.
        if best_set.is_disjoint(&must_cover) {
            return None;
        }

        must_cover.retain(|x| !best_set.contains(x));
        covered.push((best_set, weight));
    }
    Some(covered)
}