use whitenoise_validator::errors::*;
use crate::components::Evaluable;
use whitenoise_validator::base::{Value, Array, Jagged, ReleaseNode, IndexKey};
use whitenoise_validator::utilities::{standardize_numeric_argument, standardize_categorical_argument, standardize_weight_argument, take_argument, standardize_null_candidates_argument};
use crate::NodeArguments;
use crate::utilities::{noise};
use crate::utilities;
use ndarray::{ArrayD};
use crate::utilities::get_num_columns;
use whitenoise_validator::{proto, Float};
use std::hash::Hash;
impl Evaluable for proto::Impute {
fn evaluate(&self, privacy_definition: &Option<proto::PrivacyDefinition>, mut arguments: NodeArguments) -> Result<ReleaseNode> {
let enforce_constant_time = privacy_definition.as_ref()
.map(|v| v.protect_elapsed_time).unwrap_or(false);
if arguments.contains_key::<IndexKey>(&"categories".into()) {
let weights = take_argument(&mut arguments, "weights")
.and_then(|v| v.jagged()).and_then(|v| v.float()).ok();
Ok(ReleaseNode::new(match (
take_argument(&mut arguments, "data")?.array()?,
take_argument(&mut arguments, "categories")?.jagged()?,
take_argument(&mut arguments, "null_values")?.jagged()?) {
(Array::Bool(data), Jagged::Bool(categories), Jagged::Bool(nulls)) =>
impute_categorical(data, categories, weights, nulls, enforce_constant_time)?.into(),
(Array::Float(_), Jagged::Float(_), Jagged::Float(_)) =>
return Err("categorical imputation over floats is not currently supported".into()),
(Array::Int(data), Jagged::Int(categories), Jagged::Int(nulls)) =>
impute_categorical(data, categories, weights, nulls, enforce_constant_time)?.into(),
(Array::Str(data), Jagged::Str(categories), Jagged::Str(nulls)) =>
impute_categorical(data, categories, weights, nulls, enforce_constant_time)?.into(),
_ => return Err("types of data, categories, and null must be consistent and probabilities must be f64".into()),
}))
}
else {
let distribution = match take_argument(&mut arguments, "distribution") {
Ok(distribution) => distribution.array()?.first_string()?,
Err(_) => "Uniform".to_string()
};
match distribution.to_lowercase().as_str() {
"uniform" => {
Ok(match (take_argument(&mut arguments, "data")?, take_argument(&mut arguments, "lower")?, take_argument(&mut arguments, "upper")?) {
(Value::Array(data), Value::Array(lower), Value::Array(upper)) => match (data, lower, upper) {
(Array::Float(data), Array::Float(lower), Array::Float(upper)) =>
impute_float_uniform(data, lower, upper, enforce_constant_time)?.into(),
(Array::Int(data), Array::Int(_lower), Array::Int(_upper)) =>
data.into(),
_ => return Err("data, lower, and upper must all be the same type".into())
},
_ => return Err("data, lower, upper, shift, and scale must be ArrayND".into())
})
},
"gaussian" => {
let data = take_argument(&mut arguments, "data")?.array()?.float()?;
let lower = take_argument(&mut arguments, "lower")?.array()?.float()?;
let upper = take_argument(&mut arguments, "upper")?.array()?.float()?;
let scale = take_argument(&mut arguments, "scale")?.array()?.float()?;
let shift = take_argument(&mut arguments, "shift")?.array()?.float()?;
Ok(impute_float_gaussian(data, lower, upper, shift, scale, enforce_constant_time)?.into())
},
_ => return Err("Distribution not supported".into())
}.map(ReleaseNode::new)
}
}
}
pub fn impute_float_uniform(
mut data: ArrayD<Float>,
lower: ArrayD<Float>, upper: ArrayD<Float>,
enforce_constant_time: bool
) -> Result<ArrayD<Float>> {
let num_columns = get_num_columns(&data)?;
data.gencolumns_mut().into_iter()
.zip(standardize_numeric_argument(lower, num_columns)?.into_iter())
.zip(standardize_numeric_argument(upper, num_columns)?.into_iter())
.try_for_each(|((mut column, min), max)| column.iter_mut()
.filter(|v| v.is_nan())
.try_for_each(|v| {
*v = noise::sample_uniform(
*min as f64, *max as f64, enforce_constant_time)? as Float;
Ok::<_, Error>(())
}))?;
Ok(data)
}
pub fn impute_float_gaussian(
mut data: ArrayD<Float>,
lower: ArrayD<Float>, upper: ArrayD<Float>,
shift: ArrayD<Float>, scale: ArrayD<Float>,
enforce_constant_time: bool
) -> Result<ArrayD<Float>> {
let num_columns = get_num_columns(&data)?;
data.gencolumns_mut().into_iter()
.zip(standardize_numeric_argument(lower, num_columns)?.into_iter()
.zip(standardize_numeric_argument(upper, num_columns)?.into_iter()))
.zip(standardize_numeric_argument(shift, num_columns)?.into_iter()
.zip(standardize_numeric_argument(scale, num_columns)?.into_iter()))
.try_for_each(|((mut column, (min, max)), (shift, scale))| column.iter_mut()
.filter(|v| v.is_nan())
.try_for_each(|v| {
*v = noise::sample_gaussian_truncated(
*min as f64, *max as f64, *shift as f64, *scale as f64,
enforce_constant_time
)? as Float;
Ok::<_, Error>(())
}))?;
Ok(data)
}
pub fn impute_categorical<T: Clone>(
mut data: ArrayD<T>, categories: Vec<Vec<T>>,
weights: Option<Vec<Vec<Float>>>, null_value: Vec<Vec<T>>,
enforce_constant_time: bool
) -> Result<ArrayD<T>> where T: Clone, T: PartialEq, T: Default, T: Ord, T: Hash {
let num_columns = get_num_columns(&data)?;
let categories = standardize_categorical_argument(categories.to_vec(), num_columns)?;
let lengths = categories.iter().map(|cats| cats.len() as i64).collect::<Vec<i64>>();
let probabilities = standardize_weight_argument(&weights, &lengths)?;
let null_value = standardize_null_candidates_argument(null_value, num_columns)?;
data.gencolumns_mut().into_iter()
.zip(categories.iter())
.zip(probabilities.iter())
.zip(null_value.iter())
.try_for_each(|(((mut column, cats), probs), null)| column.iter_mut()
.filter(|v| null.contains(v))
.try_for_each(|v| {
*v = utilities::sample_from_set(&cats, &probs, enforce_constant_time)?;
Ok::<_, Error>(())
}))?;
Ok(data)
}