use std::collections::HashSet;
use std::hash::Hash;
use nalgebra::{DVector, Scalar};
use crate::basis_function::BasisFunction;
use crate::model::builder::error::ModelBuildError;
/// Validates a list of parameter names.
///
/// The checks run in the order listed below. The first one that fails determines the returned error.
///
/// # Errors
///
/// * [`ModelBuildError::EmptyParameters`] if `param_names` is empty.
/// * [`ModelBuildError::CommaInParameterNameNotAllowed`] for the first name that
///   contains a `,` character.
/// * [`ModelBuildError::DuplicateParameterNames`] if any name occurs more than once.
///   The error carries the complete list of names.
pub fn check_parameter_names<StrType>(param_names: &[StrType]) -> Result<(), ModelBuildError>
where
    StrType: Hash + Eq + Clone + Into<String>,
{
    if param_names.is_empty() {
        return Err(ModelBuildError::EmptyParameters);
    }

    for name in param_names {
        let owned: String = name.clone().into();
        if owned.contains(',') {
            return Err(ModelBuildError::CommaInParameterNameNotAllowed { param_name: owned });
        }
    }

    if has_only_unique_elements(param_names.iter()) {
        return Ok(());
    }

    let function_parameters: Vec<String> = param_names
        .iter()
        .map(|name| name.clone().into())
        .collect();
    Err(ModelBuildError::DuplicateParameterNames {
        function_parameters,
    })
}
/// Returns `true` if no item produced by `iter` is equal to an earlier one.
///
/// Iteration stops at the first repeated item. An empty iterator counts as unique.
fn has_only_unique_elements<T>(iter: T) -> bool
where
    T: IntoIterator,
    T::Item: Eq + Hash,
{
    let mut seen = HashSet::new();
    for item in iter {
        // `insert` returns false when an equal item is already in the set.
        if !seen.insert(item) {
            return false;
        }
    }
    true
}
/// Maps every element of `subset` to its position in `full`.
///
/// The result has one entry per element of `subset`, in the same order. If `full` contains an
/// element more than once, the position of its first occurrence is used.
///
/// # Errors
///
/// Returns [`ModelBuildError::FunctionParameterNotInModel`] for the first element of `subset`
/// that does not occur in `full`.
pub fn create_index_mapping<T1, T2>(
    full: &[T1],
    subset: &[T2],
) -> Result<Vec<usize>, ModelBuildError>
where
    T1: Clone + PartialEq + PartialEq<T2>,
    T2: Clone + PartialEq + Into<String>,
{
    let mut mapping = Vec::with_capacity(subset.len());
    for wanted in subset {
        match full.iter().position(|candidate| candidate == wanted) {
            Some(idx) => mapping.push(idx),
            None => {
                return Err(ModelBuildError::FunctionParameterNotInModel {
                    function_parameter: wanted.clone().into(),
                });
            }
        }
    }
    Ok(mapping)
}
/// Wraps `function` so it can be called with the complete model parameter slice.
///
/// `function_parameters` names the model parameters that `function` takes, in the order it
/// takes them. The returned closure has the signature `(x, params)`. It picks the named
/// entries out of `params`, using each name's position in `model_parameters`, and forwards
/// them to `function.eval`.
///
/// # Errors
///
/// * Any error from [`check_parameter_names`] for `model_parameters` or `function_parameters`.
/// * [`ModelBuildError::IncorrectParameterCount`] if the number of function parameters does
///   not match the argument count of `function`.
/// * [`ModelBuildError::FunctionParameterNotInModel`] if a function parameter is not one of
///   the model parameters.
///
/// # Panics
///
/// The returned closure indexes `params` directly. It panics if `params` is too short to
/// contain a mapped index. Presumably callers pass one value per model parameter, in model
/// order; confirm this at the call sites.
// The boxed closure type is the public return type and cannot be aliased away here.
#[allow(clippy::type_complexity)]
pub fn create_wrapped_basis_function<ScalarType, ArgList, F, StrType, StrType2>(
    model_parameters: &[StrType],
    function_parameters: &[StrType2],
    function: F,
) -> Result<Box<dyn Fn(&DVector<ScalarType>, &[ScalarType]) -> DVector<ScalarType>>, ModelBuildError>
where
    ScalarType: Scalar,
    F: BasisFunction<ScalarType, ArgList> + 'static,
    StrType: Into<String> + Clone + Hash + Eq + PartialEq<StrType2>,
    StrType2: Into<String> + Clone + Hash + Eq,
    String: PartialEq<StrType> + PartialEq<StrType2>,
{
    check_parameter_names(model_parameters)?;
    check_parameter_names(function_parameters)?;
    check_parameter_count(function_parameters, &function)?;
    // index_mapping[i] is the position in the model parameters of the i-th function parameter.
    let index_mapping = create_index_mapping(model_parameters, function_parameters)?;

    let wrapped = move |x: &DVector<ScalarType>, params: &[ScalarType]| {
        let parameters_for_function: Vec<ScalarType> = index_mapping
            .iter()
            .map(|&param_idx| params[param_idx].clone())
            .collect();
        function.eval(x, &parameters_for_function)
    };
    Ok(Box::new(wrapped))
}
/// Checks that `function_parameters` has exactly as many entries as `F` takes parameters.
///
/// The `_function` argument is never read. It only lets the compiler infer `F` (and therefore
/// `F::ARGUMENT_COUNT`) from a value.
///
/// # Errors
///
/// Returns [`ModelBuildError::IncorrectParameterCount`] when the counts differ. The error
/// carries the given names, how many there were, and `F::ARGUMENT_COUNT`.
pub fn check_parameter_count<StrType, ScalarType, F, ArgList>(
    function_parameters: &[StrType],
    _function: &F,
) -> Result<(), ModelBuildError>
where
    StrType: Into<String> + Clone,
    F: BasisFunction<ScalarType, ArgList> + 'static,
    ScalarType: Scalar,
{
    let given = function_parameters.len();
    let expected = F::ARGUMENT_COUNT;
    if given != expected {
        return Err(ModelBuildError::IncorrectParameterCount {
            params: function_parameters
                .iter()
                .map(|name| name.clone().into())
                .collect(),
            string_params_count: given,
            function_argument_count: expected,
        });
    }
    Ok(())
}
#[cfg(test)]
mod test {
    use super::*;

    /// Basis function that ignores `x` and returns its two parameters as a vector,
    /// so tests can observe which parameters were forwarded to it.
    fn dummy_unit_function_for_parameters<ScalarType>(
        _x: &DVector<ScalarType>,
        param1: ScalarType,
        param2: ScalarType,
    ) -> DVector<ScalarType>
    where
        ScalarType: Scalar,
    {
        DVector::from(vec![param1, param2])
    }

    /// Basis function that ignores its parameters and returns a copy of `x`.
    fn dummy_unit_function_for_x(x: &DVector<f64>, _param1: f64, _param2: f64) -> DVector<f64> {
        x.clone()
    }

    #[test]
    fn test_has_only_unique_elements() {
        assert!(!has_only_unique_elements(vec![10, 20, 30, 10, 50]));
        assert!(has_only_unique_elements(vec![10, 20, 30, 40, 50]));
        assert!(has_only_unique_elements(Vec::<u8>::new()));
    }

    #[test]
    fn test_check_parameter_names() {
        assert!(matches!(
            check_parameter_names(&Vec::<String>::default()),
            Err(ModelBuildError::EmptyParameters)
        ));
        assert!(check_parameter_names(&["a"]).is_ok());
        assert!(check_parameter_names(&["a", "b", "c"]).is_ok());
        assert!(matches!(
            check_parameter_names(&["a", "b", "b"]),
            Err(ModelBuildError::DuplicateParameterNames { .. })
        ));
        assert!(matches!(
            check_parameter_names(&["a,b", "c"]),
            Err(ModelBuildError::CommaInParameterNameNotAllowed { .. })
        ));
    }

    #[test]
    fn test_create_index_mapping() {
        let full_set = ['A', 'B', 'C', 'D'];
        assert_eq!(
            create_index_mapping(&full_set, &Vec::<char>::new()),
            Ok(Vec::new()),
            "Empty subset must produce empty index list"
        );
        assert!(
            create_index_mapping(&Vec::<char>::new(), &['B', 'A']).is_err(),
            "Empty full set must produce an error"
        );
        assert_eq!(
            create_index_mapping(&Vec::<char>::new(), &Vec::<char>::new()),
            Ok(Vec::new()),
            "Empty subset must produce empty index list even if full set is empty"
        );
        assert_eq!(
            create_index_mapping(&full_set, &['B', 'A']),
            Ok(vec![1, 0]),
            "Indices must be correctly assigned"
        );
        assert!(
            create_index_mapping(&full_set, &['Z', 'Q']).is_err(),
            "Indices that are not in the full set must produce an error"
        );
        assert_eq!(
            create_index_mapping(&full_set, &['B', 'Z', 'Q']),
            Err(ModelBuildError::FunctionParameterNotInModel {
                function_parameter: "Z".to_string()
            }),
            "The first element missing from the full set must be reported"
        );
        assert_eq!(
            create_index_mapping(&['A', 'A', 'B', 'D'], &['B', 'A']),
            Ok(vec![2, 0]),
            "For duplicates in the full set, the first index is used"
        );
    }

    #[test]
    fn test_create_wrapped_function_gives_error_for_empty_function_parameters_or_duplicate_elements(
    ) {
        let model_parameters = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        assert!(
            matches!(
                create_wrapped_basis_function(
                    &model_parameters,
                    &Vec::<String>::new(),
                    dummy_unit_function_for_x
                ),
                Err(ModelBuildError::EmptyParameters)
            ),
            "creating wrapper function with empty parameter list should report error"
        );
        assert!(
            matches!(
                create_wrapped_basis_function(
                    &model_parameters,
                    &["a", "a"],
                    dummy_unit_function_for_x
                ),
                Err(ModelBuildError::DuplicateParameterNames { .. })
            ),
            "creating wrapper function with duplicates in function params should report error"
        );
        assert!(
            matches!(
                create_wrapped_basis_function(
                    &model_parameters,
                    &["a", "b", "c", "d", "e"],
                    dummy_unit_function_for_x
                ),
                Err(ModelBuildError::IncorrectParameterCount { .. })
            ),
            "creating wrapper function with a different number of function parameters than \
             the function takes arguments should report error"
        );
        assert!(
            matches!(
                create_wrapped_basis_function(
                    &model_parameters,
                    &["a,b", "c"],
                    dummy_unit_function_for_x
                ),
                Err(ModelBuildError::CommaInParameterNameNotAllowed { .. })
            ),
            "creating wrapper function with a comma in a function param should report error"
        );
        assert!(
            matches!(
                create_wrapped_basis_function(
                    &model_parameters,
                    &["a", "z"],
                    dummy_unit_function_for_x
                ),
                Err(ModelBuildError::FunctionParameterNotInModel { .. })
            ),
            "creating wrapper function with a function param missing from the model should report error"
        );
    }

    #[test]
    fn creating_wrapped_basis_function_dispatches_elements_correctly_to_underlying_function() {
        let model_parameters = vec!["a", "b", "c", "d"];
        let function_parameters = vec!["c", "a"];
        let x = DVector::<f64>::from(vec![1., 3., 3., 7.]);
        let params = vec![1., 2., 3., 4.];
        assert_eq!(
            dummy_unit_function_for_parameters(&x, params[0], params[1]),
            DVector::from(vec![params[0], params[1]]),
            "dummy function must return parameters passed to it"
        );
        assert_eq!(
            dummy_unit_function_for_x(&x, params[0], params[1]),
            x,
            "dummy function must return the x argument passed to it"
        );

        // "c" is model parameter #2 (value 3.) and "a" is #0 (value 1.).
        let expected_out_params = DVector::<f64>::from(vec![3., 1.]);
        let wrapped_function_params = create_wrapped_basis_function(
            model_parameters.as_slice(),
            function_parameters.as_slice(),
            dummy_unit_function_for_parameters,
        )
        .unwrap();
        assert_eq!(
            wrapped_function_params(&x, params.as_slice()),
            expected_out_params,
            "Wrapped function must assign the correct function params from model params"
        );

        let wrapped_function_x = create_wrapped_basis_function(
            &model_parameters,
            &function_parameters,
            dummy_unit_function_for_x,
        )
        .unwrap();
        assert_eq!(
            wrapped_function_x(&x, params.as_slice()),
            x,
            "Wrapped function must pass the location argument unaltered"
        );
    }
}