use core::ops::{Mul, Sub};
use generic_array::{ArrayLength, GenericArray, IntoArrayLength, functional::FunctionalSequence};
use generic_array_storage::Conv;
use num_traits::{Float, One, Pow};
use typenum::{U0, U1, U2};
use crate::models::{FitModel, FitModelXDeriv};
/// A scalar function `f(x)` that can be both evaluated and differentiated
/// with respect to `x`.
///
/// The function's constants are split into two fixed-length groups: the ones
/// [`value`](Self::value) needs and the ones [`derivative`](Self::derivative)
/// needs. This lets an implementor precompute derivative-only constants once
/// (for example `Power` stores `p - 1` for its derivative).
pub trait DifferentiableFunction<Scalar>
where
    Self: Sized,
{
    /// Type-level number of parameters consumed by [`value`](Self::value).
    type ValueParams: ArrayLength;
    /// Type-level number of parameters consumed by [`derivative`](Self::derivative).
    type DerivativeParams: ArrayLength;

    /// Consumes the function and returns `(value_params, derivative_params)`.
    fn into_params(
        self,
    ) -> (
        impl Into<GenericArray<Scalar, Self::ValueParams>>,
        impl Into<GenericArray<Scalar, Self::DerivativeParams>>,
    );

    /// Computes `f(x)` from the value parameters.
    fn value(params: &GenericArray<Scalar, Self::ValueParams>, x: &Scalar) -> Scalar;

    /// Computes `f'(x)` from the derivative parameters.
    fn derivative(params: &GenericArray<Scalar, Self::DerivativeParams>, x: &Scalar) -> Scalar;
}
/// Borrows every element of `array` and returns the references as a plain
/// Rust array, so callers can destructure parameters with
/// `let [a, b] = func_pars(params);`.
///
/// The array length `N` must correspond to the type-level length `L`; the
/// `IntoArrayLength` bound enforces this at compile time.
#[inline]
pub fn func_pars<T, L, const N: usize>(array: &GenericArray<T, L>) -> [&T; N]
where
    L: ArrayLength,
    typenum::Const<N>: IntoArrayLength<ArrayLength = L>,
{
    // Mapping a borrowed GenericArray through `identity` yields a
    // GenericArray of references with the same type-level length.
    let borrowed = array.map(core::convert::identity);
    borrowed.into_array()
}
/// Adds a constant offset: `f(x) = x + c`, hence `f'(x) = 1`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Addition<Scalar>(pub Scalar);

impl<Scalar> DifferentiableFunction<Scalar> for Addition<Scalar>
where
    Scalar: Clone + core::ops::Add<Scalar, Output = Scalar> + num_traits::One,
{
    /// Holds the offset `c`.
    type ValueParams = U1;
    /// The derivative is the constant `1`, so nothing is stored.
    type DerivativeParams = U0;

    #[inline]
    fn into_params(
        self,
    ) -> (
        impl Into<GenericArray<Scalar, Self::ValueParams>>,
        impl Into<GenericArray<Scalar, Self::DerivativeParams>>,
    ) {
        let Self(offset) = self;
        ([offset], [])
    }

    #[inline]
    fn value(params: &GenericArray<Scalar, Self::ValueParams>, x: &Scalar) -> Scalar {
        let [offset] = func_pars(params);
        x.clone() + offset.clone()
    }

    #[inline]
    fn derivative(params: &GenericArray<Scalar, Self::DerivativeParams>, _: &Scalar) -> Scalar {
        // Destructuring into `[]` statically asserts there are no parameters.
        let [] = func_pars(params);
        <Scalar as num_traits::One>::one()
    }
}
/// Scales by a constant factor: `f(x) = c * x`, hence `f'(x) = c`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Multiplier<Scalar>(pub Scalar);

impl<Scalar> DifferentiableFunction<Scalar> for Multiplier<Scalar>
where
    Scalar: Clone + core::ops::Mul<Scalar, Output = Scalar>,
{
    /// Holds the factor `c`.
    type ValueParams = U1;
    /// Holds the factor `c` again, since it is also the derivative.
    type DerivativeParams = U1;

    fn into_params(
        self,
    ) -> (
        impl Into<GenericArray<Scalar, Self::ValueParams>>,
        impl Into<GenericArray<Scalar, Self::DerivativeParams>>,
    ) {
        let Self(factor) = self;
        ([factor.clone()], [factor])
    }

    fn value(params: &GenericArray<Scalar, Self::ValueParams>, x: &Scalar) -> Scalar {
        let [factor] = func_pars(params);
        factor.clone() * x.clone()
    }

    fn derivative(params: &GenericArray<Scalar, Self::DerivativeParams>, _x: &Scalar) -> Scalar {
        // d/dx (c * x) = c, independent of x.
        let [factor] = func_pars(params);
        Scalar::clone(factor)
    }
}
/// Raises to a constant exponent: `f(x) = x^p`, hence `f'(x) = p * x^(p - 1)`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Power<Scalar>(pub Scalar);

impl<Scalar> DifferentiableFunction<Scalar> for Power<Scalar>
where
    Scalar: Clone + Sub<Scalar, Output = Scalar> + One + Pow<Scalar, Output = Scalar>,
{
    /// Holds the exponent `p`.
    type ValueParams = U1;
    /// Holds `[p, p - 1]`; `p - 1` is precomputed once here instead of on
    /// every derivative evaluation.
    type DerivativeParams = U2;

    fn into_params(
        self,
    ) -> (
        impl Into<GenericArray<Scalar, Self::ValueParams>>,
        impl Into<GenericArray<Scalar, Self::DerivativeParams>>,
    ) {
        let Self(exponent) = self;
        let lowered = exponent.clone() - Scalar::one();
        ([exponent.clone()], [exponent, lowered])
    }

    fn value(params: &GenericArray<Scalar, Self::ValueParams>, x: &Scalar) -> Scalar {
        let [exponent] = func_pars(params);
        x.clone().pow(exponent.clone())
    }

    fn derivative(params: &GenericArray<Scalar, Self::DerivativeParams>, x: &Scalar) -> Scalar {
        let [exponent, lowered] = func_pars(params);
        // x^(p - 1) * p.
        // NOTE(review): for IEEE floats with p == 0 and x == 0 this evaluates
        // `inf * 0 = NaN` rather than the mathematical derivative 0.
        let scaled = x.clone().pow(lowered.clone());
        scaled * exponent.clone()
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LnMap;
impl<Scalar: Float> DifferentiableFunction<Scalar> for LnMap {
type ValueParams = U0;
type DerivativeParams = U0;
fn into_params(
self,
) -> (
impl Into<GenericArray<Scalar, Self::ValueParams>>,
impl Into<GenericArray<Scalar, Self::DerivativeParams>>,
) {
([], [])
}
fn value(params: &GenericArray<Scalar, Self::ValueParams>, x: &Scalar) -> Scalar {
let [] = func_pars(params);
x.ln()
}
fn derivative(params: &GenericArray<Scalar, Self::DerivativeParams>, &x: &Scalar) -> Scalar {
let [] = func_pars(params);
Scalar::one() / x
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ExpMap;
impl<Scalar: Float> DifferentiableFunction<Scalar> for ExpMap {
type ValueParams = U0;
type DerivativeParams = U0;
fn into_params(
self,
) -> (
impl Into<GenericArray<Scalar, Self::ValueParams>>,
impl Into<GenericArray<Scalar, Self::DerivativeParams>>,
) {
([], [])
}
fn value(params: &GenericArray<Scalar, Self::ValueParams>, x: &Scalar) -> Scalar {
let [] = func_pars(params);
x.exp()
}
fn derivative(params: &GenericArray<Scalar, Self::DerivativeParams>, x: &Scalar) -> Scalar {
let [] = func_pars(params);
x.exp()
}
}
/// A fit model whose output is `Map` applied to the output of [`inner`](Self::inner).
///
/// No `Map` value is stored; only the parameters produced by
/// [`DifferentiableFunction::into_params`] are kept.
///
/// NOTE: `derive` adds a bound on every type parameter, so for example
/// `ModelMap<Inner, Map>: Debug` also requires `Map: Debug`, even though no
/// `Map` value is stored.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ModelMap<Inner, Map>
where
    Inner: FitModel,
    Map: DifferentiableFunction<Inner::Scalar>,
{
    /// The wrapped model whose output is transformed.
    pub inner: Inner,
    /// Parameters handed to `Map::value`.
    value_params: GenericArray<Inner::Scalar, Map::ValueParams>,
    /// Parameters handed to `Map::derivative`.
    derivative_params: GenericArray<Inner::Scalar, Map::DerivativeParams>,
}
pub fn model_map<Inner: FitModel, Map: DifferentiableFunction<Inner::Scalar>>(
inner: Inner,
map: Map,
) -> ModelMap<Inner, Map> {
let (value_params, derivative_params) = map.into_params();
ModelMap {
inner,
value_params: value_params.into(),
derivative_params: derivative_params.into(),
}
}
// `ModelMap` evaluates `Map(inner(x))`. Its fit parameters are exactly those
// of the inner model; the map's own constants are never updated.
impl<Inner, Map> FitModel for ModelMap<Inner, Map>
where
    Inner: FitModel,
    Inner::Scalar: Clone + core::ops::Mul<Inner::Scalar, Output = Inner::Scalar>,
    Map: DifferentiableFunction<Inner::Scalar>,
{
    type Scalar = Inner::Scalar;
    type ParamCount = Inner::ParamCount;

    #[inline]
    fn evaluate(&self, x: &Self::Scalar) -> Self::Scalar {
        let inner_value = self.inner.evaluate(x);
        Map::value(&self.value_params, &inner_value)
    }

    /// Chain rule: `d Map(inner(x; p)) / dp_i = Map'(inner(x; p)) * d inner / dp_i`.
    #[inline]
    fn jacobian(
        &self,
        x: &Self::Scalar,
    ) -> impl Into<GenericArray<Self::Scalar, <Self::ParamCount as Conv>::TNum>> {
        let inner_value = self.inner.evaluate(x);
        let inner_gradient = self.inner.jacobian(x).into();
        let outer_slope = Map::derivative(&self.derivative_params, &inner_value);
        // The same outer slope scales every partial derivative, hence the clone.
        inner_gradient.map(|partial| outer_slope.clone() * partial)
    }

    #[inline]
    fn set_params(
        &mut self,
        new_params: GenericArray<Self::Scalar, <Self::ParamCount as Conv>::TNum>,
    ) {
        // Only the inner model is fitted; the map parameters stay fixed.
        self.inner.set_params(new_params)
    }

    #[inline]
    fn get_params(
        &self,
    ) -> impl Into<GenericArray<Self::Scalar, <Self::ParamCount as Conv>::TNum>> {
        let inner = &self.inner;
        inner.get_params()
    }
}
// Derivative of `Map(inner(x))` with respect to `x`, via the chain rule:
// `Map'(inner(x)) * inner'(x)`.
impl<Inner, Map> FitModelXDeriv for ModelMap<Inner, Map>
where
    Inner: FitModelXDeriv,
    Inner::Scalar: Mul<Output = Inner::Scalar>,
    Map: DifferentiableFunction<Inner::Scalar>,
    Self: FitModel<Scalar = Inner::Scalar, ParamCount = Inner::ParamCount>,
{
    #[inline]
    fn deriv_x(&self, x: &Self::Scalar) -> Self::Scalar {
        let inner_value = self.inner.evaluate(x);
        let inner_slope = self.inner.deriv_x(x);
        let outer_slope = Map::derivative(&self.derivative_params, &inner_value);
        outer_slope * inner_slope
    }
}
#[cfg(test)]
mod tests;