nacfahi/models/utility/composition/
mod.rs

1use core::ops::{Add, Mul, Sub};
2
3use generic_array::{
4    ArrayLength, GenericArray,
5    functional::FunctionalSequence,
6    sequence::{Concat, Split},
7};
8use generic_array_storage::Conv;
9use typenum::Sum;
10
11use crate::models::{FitModel, FitModelXDeriv};
12
/// A model that applies `inner` first and then feeds its output into `outer`, i.e.
/// `outer(inner(x))`.
///
/// Its parameter vector is `inner`'s parameters followed by `outer`'s parameters. Construct it
/// directly or via [`CompositionExt::compose`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Composition<Inner, Outer> {
    // Field order matters: the derived `PartialOrd`/`Ord` compare `inner` first, then `outer`.
    /// The model applied first, directly to the input `x`.
    pub inner: Inner,
    /// The model applied second, to the output of `inner`.
    pub outer: Outer,
}
21
22impl<Inner, Outer> FitModel for Composition<Inner, Outer>
23where
24    Inner: FitModel,
25    Outer: FitModelXDeriv<Scalar = Inner::Scalar>,
26    Inner::Scalar: Clone + Mul<Inner::Scalar, Output = Inner::Scalar>,
27    <Inner::ParamCount as Conv>::TNum: Add<<Outer::ParamCount as Conv>::TNum>,
28    Sum<<Inner::ParamCount as Conv>::TNum, <Outer::ParamCount as Conv>::TNum>: Conv<TNum = Sum<<Inner::ParamCount as Conv>::TNum, <Outer::ParamCount as Conv>::TNum>>
29        + ArrayLength
30        + Sub<<Inner::ParamCount as Conv>::TNum, Output = <Outer::ParamCount as Conv>::TNum>,
31{
32    type Scalar = Inner::Scalar;
33    type ParamCount = Sum<<Inner::ParamCount as Conv>::TNum, <Outer::ParamCount as Conv>::TNum>;
34
35    #[inline]
36    fn evaluate(&self, x: &Self::Scalar) -> Self::Scalar {
37        let y = self.inner.evaluate(x);
38        self.outer.evaluate(&y)
39    }
40
41    #[inline]
42    fn jacobian(
43        &self,
44        x: &Self::Scalar,
45    ) -> impl Into<GenericArray<Self::Scalar, <Self::ParamCount as Conv>::TNum>> {
46        // y = inner(x, p_in)
47        // z = outer(y, p_out)
48        //
49        // z_p_in = z_y * inner_j(z, p_in)
50        // z_p_out = outer_j(y, p_out)
51        let y = self.inner.evaluate(x);
52        let z_y = self.outer.deriv_x(&y);
53        let z_p_in = self.inner.jacobian(x).into().map(|v| v * z_y.clone());
54        let z_p_out = self.outer.jacobian(&y).into();
55
56        GenericArray::concat(z_p_in, z_p_out)
57    }
58
59    #[inline]
60    fn set_params(&mut self, new_params: GenericArray<Self::Scalar, Self::ParamCount>) {
61        let (inner_params, outer_params) = GenericArray::split(new_params);
62        self.inner.set_params(inner_params);
63        self.outer.set_params(outer_params);
64    }
65
66    #[inline]
67    fn get_params(&self) -> impl Into<GenericArray<Self::Scalar, Self::ParamCount>> {
68        GenericArray::concat(
69            self.inner.get_params().into(),
70            self.outer.get_params().into(),
71        )
72    }
73}
74
75impl<Inner, Outer> FitModelXDeriv for Composition<Inner, Outer>
76where
77    Inner: FitModelXDeriv,
78    Outer: FitModelXDeriv<Scalar = Inner::Scalar>,
79    Inner::Scalar: Mul<Output = Inner::Scalar>,
80    Self: FitModel<Scalar = Inner::Scalar>,
81{
82    #[inline]
83    fn deriv_x(&self, x: &Self::Scalar) -> Self::Scalar {
84        let y = self.inner.evaluate(x);
85        self.inner.deriv_x(x) * self.outer.deriv_x(&y)
86    }
87}
88
89/// Convenience trait to construct [`Composition`]. Alternatively, just construct it manually.
90pub trait CompositionExt: FitModel + Sized {
91    /// Applies second model on top of current one.
92    fn compose<Outer: FitModelXDeriv<Scalar = Self::Scalar>>(
93        self,
94        outer: Outer,
95    ) -> Composition<Self, Outer>;
96}
97
98impl<Inner: FitModel> CompositionExt for Inner {
99    #[inline]
100    fn compose<Outer: FitModelXDeriv<Scalar = Self::Scalar>>(
101        self,
102        outer: Outer,
103    ) -> Composition<Self, Outer> {
104        Composition { inner: self, outer }
105    }
106}
107
// Unit tests for this module live in a separate `tests` submodule file next to this one and are
// compiled only for test builds.
#[cfg(test)]
mod tests;