use std::collections::HashMap;
use crate::model::errors::ModelError;
use nalgebra::base::Scalar;
use nalgebra::DVector;
/// Boxed, dynamically dispatched function taking a location vector and a slice of
/// parameters and producing a vector.
type BaseFuncType<S> = Box<dyn Fn(&DVector<S>, &[S]) -> DVector<S>>;
/// A model basis function together with a map of derivative functions.
pub(crate) struct ModelBasisFunction<ScalarType: Scalar> {
    /// The function itself, invoked with a location vector and a parameter slice.
    pub function: BaseFuncType<ScalarType>,
    /// Derivative functions keyed by a `usize`.
    /// NOTE(review): presumably the key is the index of the parameter the derivative
    /// is taken with respect to — confirm against the code that fills this map.
    pub derivatives: HashMap<usize, BaseFuncType<ScalarType>>,
}
impl<ScalarType: Scalar> ModelBasisFunction<ScalarType> {
    /// Creates a basis function whose output depends only on the location vector.
    ///
    /// The parameter slice passed to the stored function is ignored. The resulting
    /// value has no derivative functions registered: its `derivatives` map is empty.
    pub fn parameter_independent<F>(function: F) -> Self
    where
        F: Fn(&DVector<ScalarType>) -> DVector<ScalarType> + 'static,
    {
        Self {
            derivatives: HashMap::new(),
            function: Box::new(move |location, _ignored_parameters| function(location)),
        }
    }
}
/// Evaluates `func` at `location` with the given `parameters` and checks that the
/// returned vector has exactly as many elements as `location`.
///
/// `func` is accepted as `&dyn Fn(..)` rather than `&Box<dyn Fn(..)>`. Existing
/// callers that pass a reference to a boxed function keep compiling through
/// coercion, and plain closures can now be passed as well.
///
/// # Errors
///
/// Returns [`ModelError::UnexpectedFunctionOutput`] when the output length differs
/// from `location.len()`. In that error, `expected_length` is `location.len()` and
/// `actual_length` is the length the function actually returned.
pub fn evaluate_and_check<ScalarType: Scalar>(
    func: &dyn Fn(&DVector<ScalarType>, &[ScalarType]) -> DVector<ScalarType>,
    location: &DVector<ScalarType>,
    parameters: &[ScalarType],
) -> Result<DVector<ScalarType>, ModelError> {
    let result = func(location, parameters);
    let expected_length = location.len();
    let actual_length = result.len();
    if actual_length == expected_length {
        Ok(result)
    } else {
        Err(ModelError::UnexpectedFunctionOutput {
            expected_length,
            actual_length,
        })
    }
}