#[cfg(feature = "ffi")]
mod ffi;
use std::collections::HashMap;
use std::iter::FromIterator;
use opendp_derive::bootstrap;
use crate::core::{MetricSpace, Transformation};
use crate::domains::{AtomDomain, OptionDomain, VectorDomain};
use crate::error::Fallible;
use crate::metrics::EventLevelMetric;
use crate::traits::{Hashable, Number, Primitive};
use crate::transformations::make_row_by_row;
#[bootstrap(
    features("contrib"),
    generics(TIA(suppress), M(suppress)),
    derived_types(TIA = "$get_atom(get_type(input_domain))")
)]
/// Make a transformation that looks up the position of each value in `categories`.
///
/// The per-row function handed to `make_row_by_row` returns `Some(index)` when the value
/// equals the category stored at `index`, and `None` when the value is not a category.
/// Construction fails with a `MakeTransformation` error if `categories` contains duplicates.
///
/// # Arguments
/// * `input_domain` - The domain of the input vector.
/// * `input_metric` - The metric of the input vector.
/// * `categories` - The distinct categories whose positions are looked up.
///
/// # Generics
/// * `M` - Metric Type
/// * `TIA` - Atomic Input Type; must be hashable so it can key the lookup table
pub fn make_find<M, TIA>(
    input_domain: VectorDomain<AtomDomain<TIA>>,
    input_metric: M,
    categories: Vec<TIA>,
) -> Fallible<
    Transformation<
        VectorDomain<AtomDomain<TIA>>,
        M,
        VectorDomain<OptionDomain<AtomDomain<usize>>>,
        M,
    >,
>
where
    TIA: Hashable,
    M: EventLevelMetric,
    (VectorDomain<AtomDomain<TIA>>, M): MetricSpace,
    (VectorDomain<OptionDomain<AtomDomain<usize>>>, M): MetricSpace,
{
    // Remember the length before `categories` is consumed by the lookup table.
    let expected_len = categories.len();

    // Key each category by value and store its position in the original vector.
    // Repeated categories collapse onto a single key.
    let lookup: HashMap<TIA, usize> = HashMap::from_iter(categories.into_iter().zip(0usize..));

    // A size mismatch means at least one category appeared more than once.
    if lookup.len() != expected_len {
        return fallible!(MakeTransformation, "categories must be unique");
    }

    make_row_by_row(
        input_domain,
        input_metric,
        OptionDomain::new(AtomDomain::default()),
        move |value| lookup.get(value).copied(),
    )
}
#[bootstrap(
    features("contrib"),
    generics(TIA(suppress), M(suppress)),
    derived_types(TIA = "$get_atom(get_type(input_domain))")
)]
/// Make a transformation that maps each value to the index of the bin it falls into.
///
/// `edges` must be strictly increasing; otherwise construction fails with a
/// `MakeTransformation` error. The per-row function returns the index of the first edge
/// that is strictly greater than the value, or `edges.len()` when no edge is greater.
/// Consequently bin `0` holds values below the first edge, bin `edges.len()` holds values
/// at or above the last edge, and every edge belongs to the bin on its right.
///
/// # Arguments
/// * `input_domain` - The domain of the input vector.
/// * `input_metric` - The metric of the input vector.
/// * `edges` - The strictly increasing bin boundaries.
///
/// # Generics
/// * `M` - Metric Type
/// * `TIA` - Atomic Input Type that is numeric
pub fn make_find_bin<M, TIA>(
    input_domain: VectorDomain<AtomDomain<TIA>>,
    input_metric: M,
    edges: Vec<TIA>,
) -> Fallible<Transformation<VectorDomain<AtomDomain<TIA>>, M, VectorDomain<AtomDomain<usize>>, M>>
where
    TIA: Number,
    M: EventLevelMetric,
    (VectorDomain<AtomDomain<TIA>>, M): MetricSpace,
    (VectorDomain<AtomDomain<usize>>, M): MetricSpace,
{
    // Every neighbouring pair must be strictly increasing. This rules out duplicates and
    // guarantees the sorted order that the binary search below relies on.
    if !edges.windows(2).all(|pair| pair[0] < pair[1]) {
        return fallible!(MakeTransformation, "edges must be unique and ordered");
    }

    make_row_by_row(
        input_domain,
        input_metric,
        AtomDomain::default(),
        move |v| {
            // Same test a linear "first edge greater than v" scan would use.
            let is_below = |edge: &TIA| v < edge;
            // Over sorted edges, "v is not below edge" holds for a prefix only, so the
            // partition point is the index of the first edge greater than `v`, found in
            // O(log n). When no edge is greater (including a value that is incomparable
            // with every edge, or an empty `edges`), the result is `edges.len()`.
            edges.partition_point(|edge| !is_below(edge))
        },
    )
}
#[bootstrap(features("contrib"), generics(M(suppress)))]
/// Make a transformation that treats each input element as an index into `categories`.
///
/// The per-row function returns a clone of `categories[index]` when the index is in range,
/// and a clone of `null` when it is not.
///
/// # Arguments
/// * `input_domain` - The domain of the input vector.
/// * `input_metric` - The metric of the input vector.
/// * `categories` - The values to index into.
/// * `null` - The value to return when an index is out of range of `categories`.
///
/// # Generics
/// * `M` - Metric Type
/// * `TOA` - Atomic Output Type
pub fn make_index<M, TOA>(
    input_domain: VectorDomain<AtomDomain<usize>>,
    input_metric: M,
    categories: Vec<TOA>,
    null: TOA,
) -> Fallible<Transformation<VectorDomain<AtomDomain<usize>>, M, VectorDomain<AtomDomain<TOA>>, M>>
where
    TOA: Primitive,
    M: EventLevelMetric,
    (VectorDomain<AtomDomain<usize>>, M): MetricSpace,
    (VectorDomain<AtomDomain<TOA>>, M): MetricSpace,
{
    make_row_by_row(
        input_domain,
        input_metric,
        AtomDomain::default(),
        move |index| match categories.get(*index) {
            // In range: hand back a copy of the matching category.
            Some(category) => category.clone(),
            // Out of range: fall back to the caller-supplied placeholder.
            None => null.clone(),
        },
    )
}
#[cfg(test)]
mod test;