use nalgebra::{DVector, Scalar};
use std::collections::HashSet;
use std::hash::Hash;
use crate::basis_function::BasisFunction;
use crate::model::builder::error::ModelBuildError;
/// Validates a list of parameter names.
///
/// The checks run in a fixed order, and the first one that fails determines the error.
///
/// # Errors
///
/// * [`ModelBuildError::EmptyParameters`] if `param_names` is empty.
/// * [`ModelBuildError::CommaInParameterNameNotAllowed`] for the first name (in slice order)
///   that contains a `,`.
/// * [`ModelBuildError::DuplicateParameterNames`] if any name occurs more than once. The
///   error carries the complete list of names in the order given.
pub fn check_parameter_names<StrType>(param_names: &[StrType]) -> Result<(), ModelBuildError>
where
    StrType: Hash + Eq + Clone + Into<String>,
{
    if param_names.is_empty() {
        return Err(ModelBuildError::EmptyParameters);
    }

    // `StrType` only offers a conversion into `String`, so each name is converted before it
    // can be searched for a comma.
    for name in param_names {
        let name_as_string: String = name.clone().into();
        if name_as_string.contains(',') {
            return Err(ModelBuildError::CommaInParameterNameNotAllowed {
                param_name: name_as_string,
            });
        }
    }

    if has_only_unique_elements(param_names.iter()) {
        return Ok(());
    }

    let function_parameters: Vec<String> = param_names
        .iter()
        .map(|name| name.clone().into())
        .collect();
    Err(ModelBuildError::DuplicateParameterNames {
        function_parameters,
    })
}
/// Returns `true` if no item yielded by `iter` is equal to an earlier one.
///
/// Consumption of the iterator stops at the first repeated item. An empty iterator counts
/// as unique.
fn has_only_unique_elements<T>(iter: T) -> bool
where
    T: IntoIterator,
    T::Item: Eq + Hash,
{
    let mut seen = HashSet::new();
    for item in iter {
        // `insert` returns `false` when an equal item is already stored.
        if !seen.insert(item) {
            return false;
        }
    }
    true
}
/// Maps every element of `subset` to its position in `full`.
///
/// The result contains one index per element of `subset`, in the same order. If a value
/// occurs more than once in `full`, its first position is used.
///
/// # Errors
///
/// Returns [`ModelBuildError::FunctionParameterNotInModel`] for the first element of
/// `subset` that has no equal element in `full`.
pub fn create_index_mapping<T1, T2>(
    full: &[T1],
    subset: &[T2],
) -> Result<Vec<usize>, ModelBuildError>
where
    T1: Clone + PartialEq + PartialEq<T2>,
    T2: Clone + PartialEq + Into<String>,
{
    let mut mapping = Vec::with_capacity(subset.len());
    for wanted in subset {
        match full.iter().position(|candidate| candidate == wanted) {
            Some(index) => mapping.push(index),
            None => {
                return Err(ModelBuildError::FunctionParameterNotInModel {
                    function_parameter: wanted.clone().into(),
                });
            }
        }
    }
    Ok(mapping)
}
/// Wraps `function` so that it can be called with the full model parameter slice.
///
/// `function_parameters` names the arguments of `function` in the order the function
/// expects them, and each name is looked up in `model_parameters`. The returned closure
/// takes `x` (passed through unchanged) and the parameter values ordered like
/// `model_parameters`. It picks out the values the function needs and forwards them to
/// `function.eval`.
///
/// # Errors
///
/// Fails in any of these cases:
///
/// * either name list is rejected by [`check_parameter_names`];
/// * the number of `function_parameters` does not match the function's argument count
///   (see [`check_parameter_count`]);
/// * a function parameter is not among the model parameters (see [`create_index_mapping`]).
///
/// # Panics
///
/// The returned closure panics if `params` is too short to contain one of the mapped
/// indices. Each mapped index is a position in `model_parameters`.
#[allow(clippy::type_complexity)]
pub fn create_wrapped_basis_function<ScalarType, ArgList, F, StrType, StrType2>(
    model_parameters: &[StrType],
    function_parameters: &[StrType2],
    function: F,
) -> Result<
    Box<dyn Fn(&DVector<ScalarType>, &[ScalarType]) -> DVector<ScalarType> + Send + Sync>,
    ModelBuildError,
>
where
    ScalarType: Scalar,
    F: BasisFunction<ScalarType, ArgList> + 'static,
    StrType: Into<String> + Clone + Hash + Eq + PartialEq<StrType2>,
    StrType2: Into<String> + Clone + Hash + Eq,
    String: PartialEq<StrType> + PartialEq<StrType2>,
{
    check_parameter_names(model_parameters)?;
    check_parameter_names(function_parameters)?;
    check_parameter_count(function_parameters, &function)?;
    let index_mapping = create_index_mapping(model_parameters, function_parameters)?;

    let wrapped = move |x: &DVector<ScalarType>, params: &[ScalarType]| {
        // Collect the function's own arguments from the full model parameter slice, in
        // the order the function declares them.
        let parameters_for_function: Vec<ScalarType> = index_mapping
            .iter()
            .map(|&param_idx| params[param_idx].clone())
            .collect();
        function.eval(x, &parameters_for_function)
    };
    Ok(Box::new(wrapped))
}
/// Checks that the number of declared function parameters matches the arity of `F`.
///
/// Only the associated constant `F::ARGUMENT_COUNT` is compared with
/// `function_parameters.len()`. The `_function` reference is never read, so its only role
/// is to determine the type `F`.
///
/// # Errors
///
/// Returns [`ModelBuildError::IncorrectParameterCount`] when the counts differ. The error
/// carries the number of names given (`actual`) and the number required (`expected`).
pub fn check_parameter_count<StrType, ScalarType, F, ArgList>(
    function_parameters: &[StrType],
    _function: &F,
) -> Result<(), ModelBuildError>
where
    StrType: Into<String> + Clone,
    F: BasisFunction<ScalarType, ArgList> + 'static,
    ScalarType: Scalar,
{
    let actual = function_parameters.len();
    let expected = F::ARGUMENT_COUNT;
    if actual != expected {
        return Err(ModelBuildError::IncorrectParameterCount { actual, expected });
    }
    Ok(())
}
// Unit tests for the helpers in this module live in a separate `test` module file, which
// is compiled only for test and doctest builds.
#[cfg(any(test, doctest))]
mod test;