use smartnoise_validator::errors::*;
use crate::NodeArguments;
use smartnoise_validator::base::{Value, Array, Jagged, ReleaseNode, IndexKey};
use smartnoise_validator::utilities::{standardize_numeric_argument, standardize_categorical_argument, standardize_null_target_argument, take_argument};
use crate::components::Evaluable;
use ndarray::ArrayD;
use crate::utilities::get_num_columns;
use smartnoise_validator::{proto, Float, Integer};
use std::hash::Hash;
impl Evaluable for proto::Clamp {
fn evaluate(&self, _privacy_definition: &Option<proto::PrivacyDefinition>, mut arguments: NodeArguments) -> Result<ReleaseNode> {
if arguments.contains_key::<IndexKey>(&"categories".into()) {
match (take_argument(&mut arguments, "data")?, take_argument(&mut arguments, "categories")?, take_argument(&mut arguments, "null_value")?) {
(Value::Array(data), Value::Jagged(categories), Value::Array(nulls)) => Ok(match (data, categories, nulls) {
(Array::Bool(data), Jagged::Bool(categories), Array::Bool(nulls)) =>
clamp_categorical(data, categories, nulls)?.into(),
(Array::Float(_), Jagged::Float(_), Array::Float(_)) =>
return Err("float clamping is not supported".into()),
(Array::Int(data), Jagged::Int(categories), Array::Int(nulls)) =>
clamp_categorical(data, categories, nulls)?.into(),
(Array::Str(data), Jagged::Str(categories), Array::Str(nulls)) =>
clamp_categorical(data, categories, nulls)?.into(),
_ => return Err("types of data, categories, and null must be consistent".into())
}),
_ => return Err("data must be ArrayND, categories must be Vector2DJagged, and null must be ArrayND".into())
}
}
else {
match (take_argument(&mut arguments, "data")?, take_argument(&mut arguments, "lower")?, take_argument(&mut arguments, "upper")?) {
(Value::Array(data), Value::Array(lower), Value::Array(upper)) => Ok(match (data, lower, upper) {
(Array::Float(data), Array::Float(lower), Array::Float(upper)) =>
clamp_numeric_float(data, lower, upper)?.into(),
(Array::Int(data), Array::Int(lower), Array::Int(upper)) =>
clamp_numeric_integer(data, lower, upper)?.into(),
_ => return Err("data, lower, and upper must all have type f64".into())
}),
_ => return Err("data, lower, and upper must all be ArrayND".into())
}
}.map(ReleaseNode::new)
}
}
/// Clamps each non-NaN element of every column of `data` into that column's bounds.
///
/// `lower` and `upper` are expanded to one bound per column by
/// `standardize_numeric_argument`. Each non-NaN element `x` becomes
/// `low.max(high.min(x))`. If a column's lower bound exceeds its upper bound, the
/// lower bound wins. NaN elements are skipped and remain NaN.
///
/// # Errors
/// Propagates errors from `get_num_columns` and `standardize_numeric_argument`.
pub fn clamp_numeric_float(
    mut data: ArrayD<Float>, lower: ArrayD<Float>, upper: ArrayD<Float>,
) -> Result<ArrayD<Float>> {
    let num_columns = get_num_columns(&data)?;
    let lower_bounds = standardize_numeric_argument(lower, num_columns)?;
    let upper_bounds = standardize_numeric_argument(upper, num_columns)?;

    let bounded_columns = data.gencolumns_mut().into_iter()
        .zip(lower_bounds.iter())
        .zip(upper_bounds.iter());

    for ((mut column, &low), &high) in bounded_columns {
        for value in column.iter_mut() {
            // f64::min/max would replace a NaN with a bound, so NaN cells are left alone.
            if value.is_nan() {
                continue;
            }
            *value = low.max(high.min(*value));
        }
    }
    Ok(data)
}
/// Clamps every element of each column of `data` into that column's bounds.
///
/// `lower` and `upper` are expanded to one bound per column by
/// `standardize_numeric_argument`. Each element `x` becomes `max(low, min(x, high))`,
/// so if a column's lower bound exceeds its upper bound, the lower bound wins.
///
/// # Errors
/// Propagates errors from `get_num_columns` and `standardize_numeric_argument`.
pub fn clamp_numeric_integer(
    mut data: ArrayD<Integer>, lower: ArrayD<Integer>, upper: ArrayD<Integer>,
) -> Result<ArrayD<Integer>> {
    let num_columns = get_num_columns(&data)?;
    let lower_bounds = standardize_numeric_argument(lower, num_columns)?;
    let upper_bounds = standardize_numeric_argument(upper, num_columns)?;

    for ((mut column, &low), &high) in data.gencolumns_mut().into_iter()
        .zip(lower_bounds.iter())
        .zip(upper_bounds.iter())
    {
        column
            .iter_mut()
            .for_each(|value| *value = low.max((*value).min(high)));
    }
    Ok(data)
}
/// Replaces every element of `data` that is not one of its column's categories with
/// that column's null value.
///
/// * `categories` - allowed values. They are expanded to one list per column by
///   `standardize_categorical_argument`.
/// * `null_value` - replacement value. It is expanded to one value per column by
///   `standardize_null_target_argument`.
///
/// Returns `data` with the out-of-category elements overwritten in place.
///
/// # Errors
/// Propagates errors from `get_num_columns`, `standardize_categorical_argument`, and
/// `standardize_null_target_argument`, in that order.
pub fn clamp_categorical<T: Ord + Hash + Clone>(
    mut data: ArrayD<T>,
    categories: Vec<Vec<T>>,
    null_value: ArrayD<T>
) -> Result<ArrayD<T>> where T:Clone, T:PartialEq, T:Default {
    let num_columns = get_num_columns(&data)?;
    // `categories` is already owned, so it is moved in rather than deep-copied.
    let column_categories = standardize_categorical_argument(categories, num_columns)?;
    let column_nulls = standardize_null_target_argument(null_value, num_columns)?;

    for ((mut column, categories_for_column), null) in data.gencolumns_mut().into_iter()
        .zip(column_categories)
        .zip(column_nulls)
    {
        // Build the lookup once per column so each membership test is O(1) rather than
        // a linear scan over the category list.
        let allowed: std::collections::HashSet<&T> = categories_for_column.iter().collect();
        for value in column.iter_mut() {
            if !allowed.contains(&*value) {
                *value = null.clone();
            }
        }
    }
    Ok(data)
}