use crate::model::errors::ModelError;
use crate::model::model_basis_function::ModelBasisFunction;
use nalgebra::base::Scalar;
use nalgebra::{DMatrix, DVector, Dyn};
use num_traits::Zero;
mod detail;
pub mod errors;
pub mod builder;
mod model_basis_function;
#[cfg(test)]
mod test;
/// A model made of a list of basis functions that share one list of named
/// parameters.
///
/// Evaluating the model with [`SeparableModel::eval`] produces a matrix with
/// one column per basis function.
pub struct SeparableModel<ScalarType: Scalar> {
    /// Names of the model parameters. The position of a name in this vector
    /// is the index used when asking for a derivative by index.
    parameter_names: Vec<String>,
    /// The basis functions; their order defines the column order of the
    /// evaluated model matrix.
    basefunctions: Vec<ModelBasisFunction<ScalarType>>,
}
impl<ScalarType> SeparableModel<ScalarType>
where
    ScalarType: Scalar,
{
    /// Returns the parameter names of this model, in model order.
    ///
    /// The index of a name in this slice is the index accepted by
    /// [`DerivativeProxy::at`].
    pub fn parameters(&self) -> &[String] {
        self.parameter_names.as_slice()
    }

    /// Returns how many parameters the model expects, i.e. the length of
    /// [`SeparableModel::parameters`].
    pub fn parameter_count(&self) -> usize {
        let names = self.parameters();
        names.len()
    }

    /// Returns how many basis functions the model contains. This equals the
    /// number of columns of the matrix produced by `eval`.
    pub fn basis_function_count(&self) -> usize {
        self.basefunctions.as_slice().len()
    }
}
impl<ScalarType> SeparableModel<ScalarType>
where
    ScalarType: Scalar + Zero,
{
    /// Evaluates every basis function of the model at `location` with the
    /// given `parameters`.
    ///
    /// The result has `location.len()` rows and one column per basis
    /// function. Column `j` holds the values of the `j`-th basis function.
    ///
    /// # Errors
    ///
    /// * [`ModelError::IncorrectParameterCount`] if `parameters.len()` does
    ///   not equal [`SeparableModel::parameter_count`].
    /// * Any error returned by `model_basis_function::evaluate_and_check`
    ///   while evaluating a basis function is propagated unchanged.
    ///   NOTE(review): presumably this covers a basis function returning a
    ///   vector whose length differs from `location.len()` — confirm.
    pub fn eval(
        &self,
        location: &DVector<ScalarType>,
        parameters: &[ScalarType],
    ) -> Result<DMatrix<ScalarType>, ModelError> {
        if parameters.len() != self.parameter_count() {
            return Err(ModelError::IncorrectParameterCount {
                required: self.parameter_count(),
                actual: parameters.len(),
            });
        }

        let nrows = location.len();
        let ncols = self.basis_function_count();

        // Zero-initialise instead of using uninitialised storage.
        // `assume_init` on memory that was never written is undefined
        // behaviour. An early `?` return below would also drop elements that
        // were never initialised.
        let mut function_value_matrix = DMatrix::<ScalarType>::from_element_generic(
            Dyn(nrows),
            Dyn(ncols),
            ScalarType::zero(),
        );

        for (basefunc, mut column) in self
            .basefunctions
            .iter()
            .zip(function_value_matrix.column_iter_mut())
        {
            let function_value =
                model_basis_function::evaluate_and_check(&basefunc.function, location, parameters)?;
            column.copy_from(&function_value);
        }
        Ok(function_value_matrix)
    }

    /// Creates a proxy for evaluating partial derivatives of the basis
    /// functions with respect to a single model parameter.
    ///
    /// The proxy only borrows the model, `location` and `parameters`. No
    /// validation or evaluation happens here. Both are deferred until
    /// [`DerivativeProxy::at`] or [`DerivativeProxy::at_param_name`] is called.
    pub fn eval_deriv<'a, 'b, 'c, 'd>(
        &'a self,
        location: &'b DVector<ScalarType>,
        parameters: &'c [ScalarType],
    ) -> DerivativeProxy<'d, ScalarType>
    where
        'a: 'd,
        'b: 'd,
        'c: 'd,
    {
        DerivativeProxy {
            basefunctions: &self.basefunctions,
            location,
            parameters,
            model_parameter_names: &self.parameter_names,
        }
    }
}
/// Borrowed view of a model plus an evaluation point. It is used to compute
/// the matrix of partial derivatives of all basis functions with respect to
/// one parameter.
#[must_use = "Derivative Proxy should be used immediately to evaluate a derivative matrix"]
pub struct DerivativeProxy<'a, ScalarType>
where
    ScalarType: Scalar,
{
    /// Basis functions of the model, in column order.
    basefunctions: &'a [ModelBasisFunction<ScalarType>],
    /// Points at which derivatives are evaluated; one output row per entry.
    location: &'a DVector<ScalarType>,
    /// Parameter values supplied by the caller. They are only validated when
    /// a derivative is requested.
    parameters: &'a [ScalarType],
    /// Parameter names of the model, used for count checks and name lookup.
    model_parameter_names: &'a [String],
}
impl<'a, ScalarType: Scalar + Zero> DerivativeProxy<'a, ScalarType> {
    /// Evaluates the partial derivatives of all basis functions with respect
    /// to the parameter at `param_index`.
    ///
    /// The result has `location.len()` rows and one column per basis
    /// function. A basis function with no derivative entry for `param_index`
    /// gets an all-zero column.
    ///
    /// # Errors
    ///
    /// * [`ModelError::IncorrectParameterCount`] if the number of supplied
    ///   parameters differs from the number of model parameters.
    /// * [`ModelError::DerivativeIndexOutOfBounds`] if `param_index` is not
    ///   smaller than the number of model parameters.
    /// * Any error returned by `model_basis_function::evaluate_and_check`
    ///   while evaluating a derivative is propagated unchanged.
    #[inline]
    pub fn at(&self, param_index: usize) -> Result<DMatrix<ScalarType>, ModelError> {
        if self.parameters.len() != self.model_parameter_names.len() {
            return Err(ModelError::IncorrectParameterCount {
                required: self.model_parameter_names.len(),
                actual: self.parameters.len(),
            });
        }
        if param_index >= self.model_parameter_names.len() {
            return Err(ModelError::DerivativeIndexOutOfBounds { index: param_index });
        }

        let nrows = self.location.len();
        let ncols = self.basefunctions.len();
        // Start from zeros: columns of basis functions that do not depend on
        // this parameter are left untouched and must read as zero.
        let mut derivative_function_value_matrix =
            DMatrix::<ScalarType>::from_element(nrows, ncols, Zero::zero());

        for (basefunc, mut column) in self
            .basefunctions
            .iter()
            .zip(derivative_function_value_matrix.column_iter_mut())
        {
            // NOTE(review): `derivatives` appears to be a map keyed by
            // parameter index — confirm in `model_basis_function`.
            if let Some(derivative) = basefunc.derivatives.get(&param_index) {
                let deriv_value = model_basis_function::evaluate_and_check(
                    derivative,
                    self.location,
                    self.parameters,
                )?;
                column.copy_from(&deriv_value);
            }
        }
        Ok(derivative_function_value_matrix)
    }

    /// Same as [`DerivativeProxy::at`], but selects the parameter by name.
    ///
    /// # Errors
    ///
    /// * [`ModelError::ParameterNotInModel`] if `param_name` is not one of the
    ///   model's parameter names.
    /// * Otherwise, any error returned by [`DerivativeProxy::at`].
    #[inline]
    pub fn at_param_name<StrType: AsRef<str>>(
        &self,
        param_name: StrType,
    ) -> Result<DMatrix<ScalarType>, ModelError> {
        let index = self
            .model_parameter_names
            .iter()
            .position(|p| p == param_name.as_ref())
            // Build the error lazily: it allocates a String, which would
            // otherwise be created and dropped on every successful lookup.
            .ok_or_else(|| ModelError::ParameterNotInModel {
                parameter: param_name.as_ref().into(),
            })?;
        self.at(index)
    }
}