nacfahi/models/utility/composition/
mod.rs1use core::ops::{Add, Mul, Sub};
2
3use generic_array::{
4 ArrayLength, GenericArray,
5 functional::FunctionalSequence,
6 sequence::{Concat, Split},
7};
8use generic_array_storage::Conv;
9use typenum::Sum;
10
11use crate::models::{FitModel, FitModelXDeriv};
12
/// Function composition of two fit models: `outer(inner(x))`.
///
/// `inner` receives the independent variable and `outer` receives `inner`'s output.
/// Where its `FitModel` impl applies, the combined parameter vector lists all of
/// `inner`'s parameters first, followed by all of `outer`'s.
///
/// Can be built with [`CompositionExt::compose`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Composition<Inner, Outer> {
    // Field order matters for the derived `PartialOrd`/`Ord`: `inner` is compared first.
    /// Model applied first, directly to the independent variable `x`.
    pub inner: Inner,
    /// Model applied to the value produced by `inner`.
    pub outer: Outer,
}
21
22impl<Inner, Outer> FitModel for Composition<Inner, Outer>
23where
24 Inner: FitModel,
25 Outer: FitModelXDeriv<Scalar = Inner::Scalar>,
26 Inner::Scalar: Clone + Mul<Inner::Scalar, Output = Inner::Scalar>,
27 <Inner::ParamCount as Conv>::TNum: Add<<Outer::ParamCount as Conv>::TNum>,
28 Sum<<Inner::ParamCount as Conv>::TNum, <Outer::ParamCount as Conv>::TNum>: Conv<TNum = Sum<<Inner::ParamCount as Conv>::TNum, <Outer::ParamCount as Conv>::TNum>>
29 + ArrayLength
30 + Sub<<Inner::ParamCount as Conv>::TNum, Output = <Outer::ParamCount as Conv>::TNum>,
31{
32 type Scalar = Inner::Scalar;
33 type ParamCount = Sum<<Inner::ParamCount as Conv>::TNum, <Outer::ParamCount as Conv>::TNum>;
34
35 #[inline]
36 fn evaluate(&self, x: &Self::Scalar) -> Self::Scalar {
37 let y = self.inner.evaluate(x);
38 self.outer.evaluate(&y)
39 }
40
41 #[inline]
42 fn jacobian(
43 &self,
44 x: &Self::Scalar,
45 ) -> impl Into<GenericArray<Self::Scalar, <Self::ParamCount as Conv>::TNum>> {
46 let y = self.inner.evaluate(x);
52 let z_y = self.outer.deriv_x(&y);
53 let z_p_in = self.inner.jacobian(x).into().map(|v| v * z_y.clone());
54 let z_p_out = self.outer.jacobian(&y).into();
55
56 GenericArray::concat(z_p_in, z_p_out)
57 }
58
59 #[inline]
60 fn set_params(&mut self, new_params: GenericArray<Self::Scalar, Self::ParamCount>) {
61 let (inner_params, outer_params) = GenericArray::split(new_params);
62 self.inner.set_params(inner_params);
63 self.outer.set_params(outer_params);
64 }
65
66 #[inline]
67 fn get_params(&self) -> impl Into<GenericArray<Self::Scalar, Self::ParamCount>> {
68 GenericArray::concat(
69 self.inner.get_params().into(),
70 self.outer.get_params().into(),
71 )
72 }
73}
74
75impl<Inner, Outer> FitModelXDeriv for Composition<Inner, Outer>
76where
77 Inner: FitModelXDeriv,
78 Outer: FitModelXDeriv<Scalar = Inner::Scalar>,
79 Inner::Scalar: Mul<Output = Inner::Scalar>,
80 Self: FitModel<Scalar = Inner::Scalar>,
81{
82 #[inline]
83 fn deriv_x(&self, x: &Self::Scalar) -> Self::Scalar {
84 let y = self.inner.evaluate(x);
85 self.inner.deriv_x(x) * self.outer.deriv_x(&y)
86 }
87}
88
89pub trait CompositionExt: FitModel + Sized {
91 fn compose<Outer: FitModelXDeriv<Scalar = Self::Scalar>>(
93 self,
94 outer: Outer,
95 ) -> Composition<Self, Outer>;
96}
97
98impl<Inner: FitModel> CompositionExt for Inner {
99 #[inline]
100 fn compose<Outer: FitModelXDeriv<Scalar = Self::Scalar>>(
101 self,
102 outer: Outer,
103 ) -> Composition<Self, Outer> {
104 Composition { inner: self, outer }
105 }
106}
107
108#[cfg(test)]
109mod tests;