#[cfg(test)]
mod test;
use nalgebra::base::Scalar;
use crate::basis_function::BasisFunction;
use crate::model::builder::error::ModelBuildError;
use crate::model::detail::{
check_parameter_names, create_index_mapping, create_wrapped_basis_function,
};
use crate::model::model_basis_function::ModelBasisFunction;
/// Incrementally assembles a [`ModelBasisFunction`] from a basis function and its
/// partial derivatives.
///
/// Errors that occur while building are not returned immediately; they are kept in
/// `model_function_result` and handed back to the caller when the build finishes.
#[doc(hidden)]
pub struct ModelBasisFunctionBuilder<ScalarType: Scalar> {
    /// Names of all parameters of the enclosing model.
    model_parameters: Vec<String>,
    /// Names of the parameters the basis function depends on.
    function_parameters: Vec<String>,
    /// The basis function under construction, or the error that stopped the build.
    model_function_result: Result<ModelBasisFunction<ScalarType>, ModelBuildError>,
}
impl<ScalarType> ModelBasisFunctionBuilder<ScalarType>
where
    ScalarType: Scalar,
{
    /// Starts building a basis function for a model.
    ///
    /// * `model_parameters`: the parameter names of the whole model.
    /// * `function_parameters`: the names of the parameters that `function` depends on.
    /// * `function`: the basis function itself.
    ///
    /// This constructor never fails. If `check_parameter_names` rejects the function
    /// parameter names, or the function cannot be wrapped with respect to the model
    /// parameters, the error is stored and reported by [`build`](Self::build).
    pub fn new<FuncType, StrCollection, ArgList>(
        model_parameters: Vec<String>,
        function_parameters: StrCollection,
        function: FuncType,
    ) -> Self
    where
        FuncType: BasisFunction<ScalarType, ArgList> + 'static,
        StrCollection: IntoIterator,
        StrCollection::Item: AsRef<str>,
    {
        let function_parameters: Vec<String> = function_parameters
            .into_iter()
            .map(|s| s.as_ref().to_string())
            .collect();
        // Validate the function's own parameter names first; the function is only
        // wrapped (mapped onto the model parameters) if that check succeeds.
        let model_function_result = check_parameter_names(&function_parameters).and_then(|_| {
            create_wrapped_basis_function(&model_parameters, &function_parameters, function).map(
                |function| ModelBasisFunction {
                    function,
                    derivatives: Default::default(),
                },
            )
        });
        // The borrows above have ended, so the parameter lists can be moved in
        // without cloning.
        Self {
            model_function_result,
            model_parameters,
            function_parameters,
        }
    }

    /// Adds the partial derivative of the basis function with respect to `parameter`.
    ///
    /// `parameter` must be one of the model parameters *and* one of the function
    /// parameters. Otherwise the builder records [`ModelBuildError::InvalidDerivative`],
    /// replacing any previously recorded error.
    ///
    /// Adding a derivative for the same parameter twice records
    /// [`ModelBuildError::DuplicateDerivative`]. If the builder already holds an error
    /// and `parameter` is valid, the existing error is kept and `derivative` is
    /// discarded. Any error is reported by [`build`](Self::build).
    pub fn partial_deriv<FuncType, ArgList>(mut self, parameter: &str, derivative: FuncType) -> Self
    where
        FuncType: BasisFunction<ScalarType, ArgList> + 'static,
    {
        // The derivative is keyed by the parameter's index in the *model* parameter
        // list, but it is only accepted if the parameter is also a function parameter.
        let deriv_index_in_model = self.model_parameters.iter().position(|model_param| {
            model_param == parameter && self.function_parameters.contains(model_param)
        });

        match deriv_index_in_model {
            Some(deriv_index_in_model) => {
                // Only extend a builder that has not failed yet; an earlier error wins.
                if let Ok(model_function) = self.model_function_result.as_mut() {
                    match create_wrapped_basis_function(
                        &self.model_parameters,
                        &self.function_parameters,
                        derivative,
                    ) {
                        Ok(deriv) => {
                            // `insert` returns the previous entry if one existed, which
                            // means this derivative was given twice.
                            if model_function
                                .derivatives
                                .insert(deriv_index_in_model, deriv)
                                .is_some()
                            {
                                self.model_function_result =
                                    Err(ModelBuildError::DuplicateDerivative {
                                        parameter: parameter.into(),
                                    });
                            }
                        }
                        Err(error) => {
                            self.model_function_result = Err(error);
                        }
                    }
                }
                self
            }
            None => Self {
                model_function_result: Err(ModelBuildError::InvalidDerivative {
                    parameter: parameter.into(),
                    function_parameters: self.function_parameters.clone(),
                }),
                ..self
            },
        }
    }

    /// Finishes building and returns the basis function together with its derivatives.
    ///
    /// # Errors
    ///
    /// Returns the error recorded by an earlier builder step, if there is one.
    /// Otherwise returns an error in any of these cases:
    /// * `check_parameter_names` rejects the model or function parameter names.
    /// * The function parameters cannot be mapped onto the model parameters.
    /// * A function parameter has no derivative
    ///   ([`ModelBuildError::MissingDerivative`]).
    ///
    /// # Panics
    ///
    /// Panics if the number of stored derivatives does not match the length of the
    /// index mapping between function and model parameters. This indicates a logic
    /// error in this library.
    pub fn build(self) -> Result<ModelBasisFunction<ScalarType>, ModelBuildError> {
        self.check_completion()?;
        self.model_function_result
    }

    /// Checks that a successfully built function has exactly one derivative per
    /// function parameter.
    ///
    /// If the builder already holds an error, this returns `Ok(())` so that
    /// [`build`](Self::build) reports the stored error instead.
    fn check_completion(&self) -> Result<(), ModelBuildError> {
        match self.model_function_result.as_ref() {
            Ok(modelfunction) => {
                check_parameter_names(self.model_parameters.as_slice())?;
                check_parameter_names(self.function_parameters.as_slice())?;
                let index_mapping = create_index_mapping(
                    self.model_parameters.as_slice(),
                    self.function_parameters.as_slice(),
                )?;
                // Every function parameter must have a derivative registered under its
                // model index.
                for (index, parameter) in index_mapping.iter().zip(self.function_parameters.iter())
                {
                    if !modelfunction.derivatives.contains_key(index) {
                        return Err(ModelBuildError::MissingDerivative {
                            missing_parameter: parameter.clone(),
                            function_parameters: self.function_parameters.clone(),
                        });
                    }
                }
                // `partial_deriv` only accepts indices of function parameters, so extra
                // derivatives cannot appear unless that invariant is broken.
                if index_mapping.len() != modelfunction.derivatives.len() {
                    panic!(
                        "Incorrect number of derivatives in set. This indicates a logic error in this library."
                    );
                }
                Ok(())
            }
            Err(_) => Ok(()),
        }
    }
}